Merge branch 'Comfy-Org:master' into offloader-maifee

Maifee Ul Asad 2026-01-07 21:26:04 +06:00 committed by GitHub
commit ffa7a369ba
144 changed files with 11142 additions and 3427 deletions

View File

@@ -53,6 +53,16 @@ try:
     repo.stash(ident)
 except KeyError:
     print("nothing to stash") # noqa: T201
+except:
+    print("Could not stash, cleaning index and trying again.") # noqa: T201
+    repo.state_cleanup()
+    repo.index.read_tree(repo.head.peel().tree)
+    repo.index.write()
+    try:
+        repo.stash(ident)
+    except KeyError:
+        print("nothing to stash.") # noqa: T201
 backup_branch_name = 'backup_branch_{}'.format(datetime.today().strftime('%Y-%m-%d_%H_%M_%S'))
 print("creating backup branch: {}".format(backup_branch_name)) # noqa: T201
 try:

View File

@@ -1,3 +1,3 @@
 ..\python_embeded\python.exe -s ..\ComfyUI\main.py --windows-standalone-build --disable-api-nodes
-echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest.
+echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest. If you get a c10.dll error you need to install vc redist that you can find: https://aka.ms/vc14/vc_redist.x64.exe
 pause

View File

@@ -1,3 +1,3 @@
 .\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build
-echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest.
+echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest. If you get a c10.dll error you need to install vc redist that you can find: https://aka.ms/vc14/vc_redist.x64.exe
 pause

View File

@@ -1,3 +1,3 @@
 .\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --fast fp16_accumulation
-echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest.
+echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest. If you get a c10.dll error you need to install vc redist that you can find: https://aka.ms/vc14/vc_redist.x64.exe
 pause

View File

@@ -117,7 +117,7 @@ jobs:
           ./python.exe get-pip.py
           ./python.exe -s -m pip install ../${{ inputs.cache_tag }}_python_deps/*
-          grep comfyui ../ComfyUI/requirements.txt > ./requirements_comfyui.txt
+          grep comfy ../ComfyUI/requirements.txt > ./requirements_comfyui.txt
          ./python.exe -s -m pip install -r requirements_comfyui.txt
          rm requirements_comfyui.txt

View File

@@ -18,7 +18,7 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
+        python-version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
     steps:
       - uses: actions/checkout@v4
       - name: Set up Python ${{ matrix.python-version }}

View File

@@ -5,6 +5,7 @@ on:
   push:
     branches:
       - master
+      - release/**
     paths-ignore:
       - 'app/**'
      - 'input/**'
@@ -19,6 +20,7 @@ jobs:
   test-stable:
     strategy:
       fail-fast: false
+      max-parallel: 1 # This forces sequential execution
       matrix:
         # os: [macos, linux, windows]
        # os: [macos, linux]
@@ -73,6 +75,7 @@ jobs:
   test-unix-nightly:
     strategy:
       fail-fast: false
+      max-parallel: 1 # This forces sequential execution
       matrix:
         # os: [macos, linux]
         os: [linux]

View File

@@ -2,9 +2,9 @@ name: Execution Tests
 on:
   push:
-    branches: [ main, master ]
+    branches: [ main, master, release/** ]
   pull_request:
-    branches: [ main, master ]
+    branches: [ main, master, release/** ]
 jobs:
   test:

View File

@@ -2,9 +2,9 @@ name: Test server launches without errors
 on:
   push:
-    branches: [ main, master ]
+    branches: [ main, master, release/** ]
   pull_request:
-    branches: [ main, master ]
+    branches: [ main, master, release/** ]
 jobs:
  test:
@@ -32,7 +32,9 @@ jobs:
         working-directory: ComfyUI
       - name: Check for unhandled exceptions in server log
         run: |
-          if grep -qE "Exception|Error" console_output.log; then
+          grep -v "Found comfy_kitchen backend triton: {'available': False, 'disabled': True, 'unavailable_reason': \"ImportError: No module named 'triton'\", 'capabilities': \[\]}" console_output.log | grep -v "Found comfy_kitchen backend triton: {'available': False, 'disabled': False, 'unavailable_reason': \"ImportError: No module named 'triton'\", 'capabilities': \[\]}" > console_output_filtered.log
+          cat console_output_filtered.log
+          if grep -qE "Exception|Error" console_output_filtered.log; then
            echo "Unhandled exception/error found in server log."
            exit 1
          fi

View File

@@ -2,9 +2,9 @@ name: Unit Tests
 on:
   push:
-    branches: [ main, master ]
+    branches: [ main, master, release/** ]
   pull_request:
-    branches: [ main, master ]
+    branches: [ main, master, release/** ]
 jobs:
   test:

View File

@@ -6,6 +6,7 @@ on:
       - "pyproject.toml"
     branches:
       - master
+      - release/**
 jobs:
   update-version:

View File

@@ -119,6 +119,9 @@ ComfyUI follows a weekly release cycle targeting Monday but this regularly changes
 1. **[ComfyUI Core](https://github.com/comfyanonymous/ComfyUI)**
    - Releases a new stable version (e.g., v0.7.0) roughly every week.
+     - Starting from v0.4.0 patch versions will be used for fixes backported onto the current stable release.
+     - Minor versions will be used for releases off the master branch.
+     - Patch versions may still be used for releases on the master branch in cases where a backport would not make sense.
    - Commits outside of the stable release tags may be very unstable and break many custom nodes.
    - Serves as the foundation for the desktop release
@@ -209,6 +212,8 @@ Python 3.14 works but you may encounter issues with the torch compile node. The
 Python 3.13 is very well supported. If you have trouble with some custom node dependencies on 3.13 you can try 3.12
+torch 2.4 and above is supported but some features might only work on newer versions. We generally recommend using the latest major version of pytorch unless it is less than 2 weeks old.
 ### Instructions:
 Git clone this repo.

View File

@@ -58,8 +58,13 @@ class InternalRoutes:
                 return web.json_response({"error": "Invalid directory type"}, status=400)
             directory = get_directory_by_type(directory_type)
+            def is_visible_file(entry: os.DirEntry) -> bool:
+                """Filter out hidden files (e.g., .DS_Store on macOS)."""
+                return entry.is_file() and not entry.name.startswith('.')
             sorted_files = sorted(
-                (entry for entry in os.scandir(directory) if entry.is_file()),
+                (entry for entry in os.scandir(directory) if is_visible_file(entry)),
                 key=lambda entry: -entry.stat().st_mtime
             )
             return web.json_response([entry.name for entry in sorted_files], status=200)
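A self-contained sketch of the same filtering idea, for trying it outside the route handler (the function name and directory argument here are illustrative, not part of the change):

import os

def list_visible_files_newest_first(directory: str) -> list[str]:
    # Skip hidden files such as .DS_Store and sort by modification time, newest first,
    # mirroring the is_visible_file() filter and the sorted() call in the diff above.
    entries = [e for e in os.scandir(directory) if e.is_file() and not e.name.startswith('.')]
    entries.sort(key=lambda e: e.stat().st_mtime, reverse=True)
    return [e.name for e in entries]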

View File

@@ -44,7 +44,7 @@ class ModelFileManager:
         @routes.get("/experiment/models/{folder}")
         async def get_all_models(request):
             folder = request.match_info.get("folder", None)
-            if not folder in folder_paths.folder_names_and_paths:
+            if folder not in folder_paths.folder_names_and_paths:
                 return web.Response(status=404)
             files = self.get_model_file_list(folder)
            return web.json_response(files)
@@ -55,7 +55,7 @@ class ModelFileManager:
             path_index = int(request.match_info.get("path_index", None))
             filename = request.match_info.get("filename", None)
-            if not folder_name in folder_paths.folder_names_and_paths:
+            if folder_name not in folder_paths.folder_names_and_paths:
                 return web.Response(status=404)
             folders = folder_paths.folder_names_and_paths[folder_name]

View File

@@ -97,6 +97,13 @@ class LatentPreviewMethod(enum.Enum):
     Latent2RGB = "latent2rgb"
     TAESD = "taesd"
+    @classmethod
+    def from_string(cls, value: str):
+        for member in cls:
+            if member.value == value:
+                return member
+        return None
 parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews, help="Default preview method for sampler nodes.", action=EnumAction)
 parser.add_argument("--preview-size", type=int, default=512, help="Sets the maximum preview size for sampler nodes.")

View File

@@ -2,6 +2,25 @@ import torch
 from comfy.ldm.modules.attention import optimized_attention_for_device
 import comfy.ops
+def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], crop=True):
+    image = image[:, :, :, :3] if image.shape[3] > 3 else image
+    mean = torch.tensor(mean, device=image.device, dtype=image.dtype)
+    std = torch.tensor(std, device=image.device, dtype=image.dtype)
+    image = image.movedim(-1, 1)
+    if not (image.shape[2] == size and image.shape[3] == size):
+        if crop:
+            scale = (size / min(image.shape[2], image.shape[3]))
+            scale_size = (round(scale * image.shape[2]), round(scale * image.shape[3]))
+        else:
+            scale_size = (size, size)
+        image = torch.nn.functional.interpolate(image, size=scale_size, mode="bicubic", antialias=True)
+        h = (image.shape[2] - size)//2
+        w = (image.shape[3] - size)//2
+        image = image[:,:,h:h+size,w:w+size]
+    image = torch.clip((255. * image), 0, 255).round() / 255.0
+    return (image - mean.view([3,1,1])) / std.view([3,1,1])
 class CLIPAttention(torch.nn.Module):
     def __init__(self, embed_dim, heads, dtype, device, operations):
         super().__init__()
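A minimal usage sketch of the relocated clip_preprocess helper, assuming a ComfyUI-style image tensor of shape [batch, height, width, channels] with values in 0..1 (shapes and names here are illustrative):

import torch
import comfy.clip_model

image = torch.rand(1, 512, 768, 3)  # hypothetical [B, H, W, C] input
pixel_values = comfy.clip_model.clip_preprocess(image, size=224, crop=True)
print(pixel_values.shape)  # expected: torch.Size([1, 3, 224, 224])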

View File

@@ -1,6 +1,5 @@
 from .utils import load_torch_file, transformers_convert, state_dict_prefix_replace
 import os
-import torch
 import json
 import logging
@@ -17,24 +16,7 @@ class Output:
     def __setitem__(self, key, item):
         setattr(self, key, item)
-def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], crop=True):
-    image = image[:, :, :, :3] if image.shape[3] > 3 else image
-    mean = torch.tensor(mean, device=image.device, dtype=image.dtype)
-    std = torch.tensor(std, device=image.device, dtype=image.dtype)
-    image = image.movedim(-1, 1)
-    if not (image.shape[2] == size and image.shape[3] == size):
-        if crop:
-            scale = (size / min(image.shape[2], image.shape[3]))
-            scale_size = (round(scale * image.shape[2]), round(scale * image.shape[3]))
-        else:
-            scale_size = (size, size)
-        image = torch.nn.functional.interpolate(image, size=scale_size, mode="bicubic", antialias=True)
-        h = (image.shape[2] - size)//2
-        w = (image.shape[3] - size)//2
-        image = image[:,:,h:h+size,w:w+size]
-    image = torch.clip((255. * image), 0, 255).round() / 255.0
-    return (image - mean.view([3,1,1])) / std.view([3,1,1])
+clip_preprocess = comfy.clip_model.clip_preprocess # Prevent some stuff from breaking, TODO: remove eventually
 IMAGE_ENCODERS = {
     "clip_vision_model": comfy.clip_model.CLIPVisionModelProjection,
@@ -73,7 +55,7 @@ class ClipVisionModel():
     def encode_image(self, image, crop=True):
         comfy.model_management.load_model_gpu(self.patcher)
-        pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float()
+        pixel_values = comfy.clip_model.clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float()
         out = self.model(pixel_values=pixel_values, intermediate_output='all' if self.return_all_hidden_states else -2)
         outputs = Output()

View File

@@ -87,6 +87,7 @@ class IndexListCallbacks:
     COMBINE_CONTEXT_WINDOW_RESULTS = "combine_context_window_results"
     EXECUTE_START = "execute_start"
     EXECUTE_CLEANUP = "execute_cleanup"
+    RESIZE_COND_ITEM = "resize_cond_item"
     def init_callbacks(self):
        return {}
@@ -142,7 +143,7 @@ class IndexListContextHandler(ContextHandlerABC):
         # if multiple conds, split based on primary region
         if self.split_conds_to_windows and len(cond_in) > 1:
             region = window.get_region_index(len(cond_in))
-            logging.info(f"Splitting conds to windows; using region {region} for window {window[0]}-{window[-1]} with center ratio {window.center_ratio:.3f}")
+            logging.info(f"Splitting conds to windows; using region {region} for window {window.index_list[0]}-{window.index_list[-1]} with center ratio {window.center_ratio:.3f}")
             cond_in = [cond_in[region]]
         # cond object is a list containing a dict - outer list is irrelevant, so just loop through it
        for actual_cond in cond_in:
@@ -166,6 +167,18 @@ class IndexListContextHandler(ContextHandlerABC):
                 new_cond_item = cond_item.copy()
                 # when in dictionary, look for tensors and CONDCrossAttn [comfy/conds.py] (has cond attr that is a tensor)
                 for cond_key, cond_value in new_cond_item.items():
+                    # Allow callbacks to handle custom conditioning items
+                    handled = False
+                    for callback in comfy.patcher_extension.get_all_callbacks(
+                        IndexListCallbacks.RESIZE_COND_ITEM, self.callbacks
+                    ):
+                        result = callback(cond_key, cond_value, window, x_in, device, new_cond_item)
+                        if result is not None:
+                            new_cond_item[cond_key] = result
+                            handled = True
+                            break
+                    if handled:
+                        continue
                     if isinstance(cond_value, torch.Tensor):
                         if (self.dim < cond_value.ndim and cond_value.size(self.dim) == x_in.size(self.dim)) or \
                             (cond_value.ndim < self.dim and cond_value.size(0) == x_in.size(self.dim)):
@@ -175,6 +188,12 @@ class IndexListContextHandler(ContextHandlerABC):
                         audio_cond = cond_value.cond
                         if audio_cond.ndim > 1 and audio_cond.size(1) == x_in.size(self.dim):
                             new_cond_item[cond_key] = cond_value._copy_with(window.get_tensor(audio_cond, device, dim=1))
+                    # Handle vace_context (temporal dim is 3)
+                    elif cond_key == "vace_context" and hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
+                        vace_cond = cond_value.cond
+                        if vace_cond.ndim >= 4 and vace_cond.size(3) == x_in.size(self.dim):
+                            sliced_vace = window.get_tensor(vace_cond, device, dim=3, retain_index_list=self.cond_retain_index_list)
+                            new_cond_item[cond_key] = cond_value._copy_with(sliced_vace)
                     # if has cond that is a Tensor, check if needs to be subset
                     elif hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
                         if (self.dim < cond_value.cond.ndim and cond_value.cond.size(self.dim) == x_in.size(self.dim)) or \

View File

@@ -527,7 +527,8 @@ class HookKeyframeGroup:
                 if self._current_keyframe.get_effective_guarantee_steps(max_sigma) > 0:
                     break
             # if eval_c is outside the percent range, stop looking further
-            else: break
+            else:
+                break
         # update steps current context is used
         self._current_used_steps += 1
         # update current timestep this was performed on

View File

@@ -74,6 +74,9 @@ def get_ancestral_step(sigma_from, sigma_to, eta=1.):
 def default_noise_sampler(x, seed=None):
     if seed is not None:
+        if x.device == torch.device("cpu"):
+            seed += 1
         generator = torch.Generator(device=x.device)
         generator.manual_seed(seed)
    else:
@@ -1557,10 +1560,13 @@ def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None
 @torch.no_grad()
-def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5):
+def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5, solver_type="phi_1"):
     """SEEDS-2 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 2.
     arXiv: https://arxiv.org/abs/2305.14267 (NeurIPS 2023)
     """
+    if solver_type not in {"phi_1", "phi_2"}:
+        raise ValueError("solver_type must be 'phi_1' or 'phi_2'")
     extra_args = {} if extra_args is None else extra_args
     seed = extra_args.get("seed", None)
    noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
@@ -1600,8 +1606,14 @@ def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=Non
         denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
         # Step 2
-        denoised_d = torch.lerp(denoised, denoised_2, fac)
-        x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * ei_h_phi_1(-h_eta) * denoised_d
+        if solver_type == "phi_1":
+            denoised_d = torch.lerp(denoised, denoised_2, fac)
+            x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * ei_h_phi_1(-h_eta) * denoised_d
+        elif solver_type == "phi_2":
+            b2 = ei_h_phi_2(-h_eta) / r
+            b1 = ei_h_phi_1(-h_eta) - b2
+            x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * (b1 * denoised + b2 * denoised_2)
         if inject_noise:
             segment_factor = (r - 1) * h * eta
            sde_noise = sde_noise * segment_factor.exp()
@@ -1609,6 +1621,17 @@ def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=Non
             x = x + sde_noise * sigmas[i + 1] * s_noise
     return x
+@torch.no_grad()
+def sample_exp_heun_2_x0(model, x, sigmas, extra_args=None, callback=None, disable=None, solver_type="phi_2"):
+    """Deterministic exponential Heun second order method in data prediction (x0) and logSNR time."""
+    return sample_seeds_2(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=0.0, s_noise=0.0, noise_sampler=None, r=1.0, solver_type=solver_type)
+@torch.no_grad()
+def sample_exp_heun_2_x0_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type="phi_2"):
+    """Stochastic exponential Heun second order method in data prediction (x0) and logSNR time."""
+    return sample_seeds_2(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, r=1.0, solver_type=solver_type)
 @torch.no_grad()
 def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r_1=1./3, r_2=2./3):
@@ -1756,7 +1779,7 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F
         # Predictor
         if sigmas[i + 1] == 0:
             # Denoising step
-            x = denoised
+            x_pred = denoised
         else:
             tau_t = tau_func(sigmas[i + 1])
            curr_lambdas = lambdas[i - predictor_order_used + 1:i + 1]
@@ -1777,7 +1800,7 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F
             if tau_t > 0 and s_noise > 0:
                 noise = noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * tau_t ** 2 * h).expm1().neg().sqrt() * s_noise
                 x_pred = x_pred + noise
-    return x
+    return x_pred
 @torch.no_grad()
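A hedged note on the new solver_type branch: assuming the standard exponential-integrator functions phi_1(z) = (exp(z) - 1)/z and phi_2(z) = (phi_1(z) - 1)/z (which is what ei_h_phi_1/ei_h_phi_2 appear to compute elsewhere in this file), the phi_2 weights b2 = phi_2(-h_eta)/r and b1 = phi_1(-h_eta) - b2 satisfy b1 + b2 = phi_1(-h_eta), so the total weight on the denoised estimates matches the original phi_1 update. A minimal sketch of those assumed helpers:

import torch

def phi_1(z):
    # (exp(z) - 1) / z
    return torch.expm1(z) / z

def phi_2(z):
    # (phi_1(z) - 1) / z
    return (phi_1(z) - 1.0) / z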

View File

@@ -407,6 +407,11 @@ class LTXV(LatentFormat):
         self.latent_rgb_factors_bias = [-0.0571, -0.1657, -0.2512]
+class LTXAV(LTXV):
+    def __init__(self):
+        self.latent_rgb_factors = None
+        self.latent_rgb_factors_bias = None
 class HunyuanVideo(LatentFormat):
     latent_channels = 16
     latent_dimensions = 3

View File

@@ -37,7 +37,7 @@ class ChromaRadianceParams(ChromaParams):
     nerf_final_head_type: str
     # None means use the same dtype as the model.
     nerf_embedder_dtype: Optional[torch.dtype]
+    use_x0: bool
 class ChromaRadiance(Chroma):
    """
@@ -159,6 +159,9 @@
         self.skip_dit = []
         self.lite = False
+        if params.use_x0:
+            self.register_buffer("__x0__", torch.tensor([]))
     @property
     def _nerf_final_layer(self) -> nn.Module:
        if self.params.nerf_final_head_type == "linear":
@@ -267,7 +270,7 @@
         bad_keys = tuple(
             k
             for k, v in overrides.items()
-            if type(v) != type(getattr(params, k)) and (v is not None or k not in nullable_keys)
+            if not isinstance(v, type(getattr(params, k))) and (v is not None or k not in nullable_keys)
         )
         if bad_keys:
            e = f"Invalid value(s) in transformer_options chroma_radiance_options: {', '.join(bad_keys)}"
@@ -276,6 +279,12 @@
         params_dict |= overrides
         return params.__class__(**params_dict)
+    def _apply_x0_residual(self, predicted, noisy, timesteps):
+        # non zero during training to prevent 0 div
+        eps = 0.0
+        return (noisy - predicted) / (timesteps.view(-1,1,1,1) + eps)
     def _forward(
         self,
        x: Tensor,
@@ -316,4 +325,11 @@
             transformer_options,
             attn_mask=kwargs.get("attention_mask", None),
         )
-        return self.forward_nerf(img, img_out, params)[:, :, :h, :w]
+        out = self.forward_nerf(img, img_out, params)[:, :, :h, :w]
+        # If x0 variant → v-pred, just return this instead
+        if hasattr(self, "__x0__"):
+            out = self._apply_x0_residual(out, img, timestep)
+        return out
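A hedged reading of the new x0 path above: _apply_x0_residual converts a network output interpreted as a clean-image prediction into the velocity-style quantity the rest of the pipeline expects, v = (x_t - x0) / t, where x_t is the noisy input (img), x0 is the NeRF head output (out) and t is the timestep; eps stays 0.0 at inference and, per the in-line comment, only matters during training. A minimal sketch of the same conversion (names illustrative):

def x0_to_velocity(x_t, x0, t):
    # rearrange x_t = x0 + t * v  ->  v = (x_t - x0) / t
    return (x_t - x0) / t.view(-1, 1, 1, 1)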

View File

@@ -4,6 +4,7 @@ from torch import Tensor
 from comfy.ldm.modules.attention import optimized_attention
 import comfy.model_management
+import logging
 def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transformer_options={}) -> Tensor:
@@ -13,7 +14,6 @@ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transforme
     x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask, transformer_options=transformer_options)
     return x
-
 def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
     assert dim % 2 == 0
    if comfy.model_management.is_device_mps(pos.device) or comfy.model_management.is_intel_xpu() or comfy.model_management.is_directml_enabled():
@@ -28,13 +28,20 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
     out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
     return out.to(dtype=torch.float32, device=pos.device)
-def apply_rope1(x: Tensor, freqs_cis: Tensor):
-    x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
-    x_out = freqs_cis[..., 0] * x_[..., 0]
-    x_out.addcmul_(freqs_cis[..., 1], x_[..., 1])
-    return x_out.reshape(*x.shape).type_as(x)
-def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
-    return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)
+try:
+    import comfy.quant_ops
+    apply_rope = comfy.quant_ops.ck.apply_rope
+    apply_rope1 = comfy.quant_ops.ck.apply_rope1
+except:
+    logging.warning("No comfy kitchen, using old apply_rope functions.")
+    def apply_rope1(x: Tensor, freqs_cis: Tensor):
+        x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
+        x_out = freqs_cis[..., 0] * x_[..., 0]
+        x_out.addcmul_(freqs_cis[..., 1], x_[..., 1])
+        return x_out.reshape(*x.shape).type_as(x)
+    def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
+        return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)

View File

@@ -43,6 +43,7 @@ class HunyuanVideoParams:
     meanflow: bool
     use_cond_type_embedding: bool
     vision_in_dim: int
+    meanflow_sum: bool
 class SelfAttentionRef(nn.Module):
@@ -317,7 +318,7 @@ class HunyuanVideo(nn.Module):
                 timesteps_r = transformer_options['sample_sigmas'][w[0] + 1]
                 timesteps_r = timesteps_r.unsqueeze(0).to(device=timesteps.device, dtype=timesteps.dtype)
                 vec_r = self.time_r_in(timestep_embedding(timesteps_r, 256, time_factor=1000.0).to(img.dtype))
-                vec = (vec + vec_r) / 2
+                vec = (vec + vec_r) if self.params.meanflow_sum else (vec + vec_r) / 2
             if ref_latent is not None:
                 ref_latent_ids = self.img_ids(ref_latent)

View File

@@ -3,7 +3,8 @@ import torch.nn as nn
 import torch.nn.functional as F
 from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, VideoConv3d
 from comfy.ldm.hunyuan_video.vae_refiner import RMS_norm
-import model_management, model_patcher
+import model_management
+import model_patcher
 class SRResidualCausalBlock3D(nn.Module):
     def __init__(self, channels: int):

View File

@@ -0,0 +1,837 @@
from typing import Tuple
import torch
import torch.nn as nn
from comfy.ldm.lightricks.model import (
CrossAttention,
FeedForward,
AdaLayerNormSingle,
PixArtAlphaTextProjection,
LTXVModel,
)
from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier
import comfy.ldm.common_dit
class BasicAVTransformerBlock(nn.Module):
def __init__(
self,
v_dim,
a_dim,
v_heads,
a_heads,
vd_head,
ad_head,
v_context_dim=None,
a_context_dim=None,
attn_precision=None,
dtype=None,
device=None,
operations=None,
):
super().__init__()
self.attn_precision = attn_precision
self.attn1 = CrossAttention(
query_dim=v_dim,
heads=v_heads,
dim_head=vd_head,
context_dim=None,
attn_precision=self.attn_precision,
dtype=dtype,
device=device,
operations=operations,
)
self.audio_attn1 = CrossAttention(
query_dim=a_dim,
heads=a_heads,
dim_head=ad_head,
context_dim=None,
attn_precision=self.attn_precision,
dtype=dtype,
device=device,
operations=operations,
)
self.attn2 = CrossAttention(
query_dim=v_dim,
context_dim=v_context_dim,
heads=v_heads,
dim_head=vd_head,
attn_precision=self.attn_precision,
dtype=dtype,
device=device,
operations=operations,
)
self.audio_attn2 = CrossAttention(
query_dim=a_dim,
context_dim=a_context_dim,
heads=a_heads,
dim_head=ad_head,
attn_precision=self.attn_precision,
dtype=dtype,
device=device,
operations=operations,
)
# Q: Video, K,V: Audio
self.audio_to_video_attn = CrossAttention(
query_dim=v_dim,
context_dim=a_dim,
heads=a_heads,
dim_head=ad_head,
attn_precision=self.attn_precision,
dtype=dtype,
device=device,
operations=operations,
)
# Q: Audio, K,V: Video
self.video_to_audio_attn = CrossAttention(
query_dim=a_dim,
context_dim=v_dim,
heads=a_heads,
dim_head=ad_head,
attn_precision=self.attn_precision,
dtype=dtype,
device=device,
operations=operations,
)
self.ff = FeedForward(
v_dim, dim_out=v_dim, glu=True, dtype=dtype, device=device, operations=operations
)
self.audio_ff = FeedForward(
a_dim, dim_out=a_dim, glu=True, dtype=dtype, device=device, operations=operations
)
self.scale_shift_table = nn.Parameter(torch.empty(6, v_dim, device=device, dtype=dtype))
self.audio_scale_shift_table = nn.Parameter(
torch.empty(6, a_dim, device=device, dtype=dtype)
)
self.scale_shift_table_a2v_ca_audio = nn.Parameter(
torch.empty(5, a_dim, device=device, dtype=dtype)
)
self.scale_shift_table_a2v_ca_video = nn.Parameter(
torch.empty(5, v_dim, device=device, dtype=dtype)
)
def get_ada_values(
self, scale_shift_table: torch.Tensor, batch_size: int, timestep: torch.Tensor, indices: slice = slice(None, None)
):
num_ada_params = scale_shift_table.shape[0]
ada_values = (
scale_shift_table[indices].unsqueeze(0).unsqueeze(0).to(device=timestep.device, dtype=timestep.dtype)
+ timestep.reshape(batch_size, timestep.shape[1], num_ada_params, -1)[:, :, indices, :]
).unbind(dim=2)
return ada_values
def get_av_ca_ada_values(
self,
scale_shift_table: torch.Tensor,
batch_size: int,
scale_shift_timestep: torch.Tensor,
gate_timestep: torch.Tensor,
num_scale_shift_values: int = 4,
):
scale_shift_ada_values = self.get_ada_values(
scale_shift_table[:num_scale_shift_values, :],
batch_size,
scale_shift_timestep,
)
gate_ada_values = self.get_ada_values(
scale_shift_table[num_scale_shift_values:, :],
batch_size,
gate_timestep,
)
scale_shift_chunks = [t.squeeze(2) for t in scale_shift_ada_values]
gate_ada_values = [t.squeeze(2) for t in gate_ada_values]
return (*scale_shift_chunks, *gate_ada_values)
def forward(
self,
x: Tuple[torch.Tensor, torch.Tensor],
v_context=None,
a_context=None,
attention_mask=None,
v_timestep=None,
a_timestep=None,
v_pe=None,
a_pe=None,
v_cross_pe=None,
a_cross_pe=None,
v_cross_scale_shift_timestep=None,
a_cross_scale_shift_timestep=None,
v_cross_gate_timestep=None,
a_cross_gate_timestep=None,
transformer_options=None,
) -> Tuple[torch.Tensor, torch.Tensor]:
run_vx = transformer_options.get("run_vx", True)
run_ax = transformer_options.get("run_ax", True)
vx, ax = x
run_ax = run_ax and ax.numel() > 0
run_a2v = run_vx and transformer_options.get("a2v_cross_attn", True) and ax.numel() > 0
run_v2a = run_ax and transformer_options.get("v2a_cross_attn", True)
if run_vx:
vshift_msa, vscale_msa, vgate_msa = (
self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(0, 3))
)
norm_vx = comfy.ldm.common_dit.rms_norm(vx) * (1 + vscale_msa) + vshift_msa
vx += self.attn1(norm_vx, pe=v_pe, transformer_options=transformer_options) * vgate_msa
vx += self.attn2(
comfy.ldm.common_dit.rms_norm(vx),
context=v_context,
mask=attention_mask,
transformer_options=transformer_options,
)
del vshift_msa, vscale_msa, vgate_msa
if run_ax:
ashift_msa, ascale_msa, agate_msa = (
self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(0, 3))
)
norm_ax = comfy.ldm.common_dit.rms_norm(ax) * (1 + ascale_msa) + ashift_msa
ax += (
self.audio_attn1(norm_ax, pe=a_pe, transformer_options=transformer_options)
* agate_msa
)
ax += self.audio_attn2(
comfy.ldm.common_dit.rms_norm(ax),
context=a_context,
mask=attention_mask,
transformer_options=transformer_options,
)
del ashift_msa, ascale_msa, agate_msa
# Audio - Video cross attention.
if run_a2v or run_v2a:
# norm3
vx_norm3 = comfy.ldm.common_dit.rms_norm(vx)
ax_norm3 = comfy.ldm.common_dit.rms_norm(ax)
(
scale_ca_audio_hidden_states_a2v,
shift_ca_audio_hidden_states_a2v,
scale_ca_audio_hidden_states_v2a,
shift_ca_audio_hidden_states_v2a,
gate_out_v2a,
) = self.get_av_ca_ada_values(
self.scale_shift_table_a2v_ca_audio,
ax.shape[0],
a_cross_scale_shift_timestep,
a_cross_gate_timestep,
)
(
scale_ca_video_hidden_states_a2v,
shift_ca_video_hidden_states_a2v,
scale_ca_video_hidden_states_v2a,
shift_ca_video_hidden_states_v2a,
gate_out_a2v,
) = self.get_av_ca_ada_values(
self.scale_shift_table_a2v_ca_video,
vx.shape[0],
v_cross_scale_shift_timestep,
v_cross_gate_timestep,
)
if run_a2v:
vx_scaled = (
vx_norm3 * (1 + scale_ca_video_hidden_states_a2v)
+ shift_ca_video_hidden_states_a2v
)
ax_scaled = (
ax_norm3 * (1 + scale_ca_audio_hidden_states_a2v)
+ shift_ca_audio_hidden_states_a2v
)
vx += (
self.audio_to_video_attn(
vx_scaled,
context=ax_scaled,
pe=v_cross_pe,
k_pe=a_cross_pe,
transformer_options=transformer_options,
)
* gate_out_a2v
)
del gate_out_a2v
del scale_ca_video_hidden_states_a2v,\
shift_ca_video_hidden_states_a2v,\
scale_ca_audio_hidden_states_a2v,\
shift_ca_audio_hidden_states_a2v
if run_v2a:
ax_scaled = (
ax_norm3 * (1 + scale_ca_audio_hidden_states_v2a)
+ shift_ca_audio_hidden_states_v2a
)
vx_scaled = (
vx_norm3 * (1 + scale_ca_video_hidden_states_v2a)
+ shift_ca_video_hidden_states_v2a
)
ax += (
self.video_to_audio_attn(
ax_scaled,
context=vx_scaled,
pe=a_cross_pe,
k_pe=v_cross_pe,
transformer_options=transformer_options,
)
* gate_out_v2a
)
del gate_out_v2a
del scale_ca_video_hidden_states_v2a,\
shift_ca_video_hidden_states_v2a,\
scale_ca_audio_hidden_states_v2a,\
shift_ca_audio_hidden_states_v2a
if run_vx:
vshift_mlp, vscale_mlp, vgate_mlp = (
self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(3, None))
)
vx_scaled = comfy.ldm.common_dit.rms_norm(vx) * (1 + vscale_mlp) + vshift_mlp
vx += self.ff(vx_scaled) * vgate_mlp
del vshift_mlp, vscale_mlp, vgate_mlp
if run_ax:
ashift_mlp, ascale_mlp, agate_mlp = (
self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(3, None))
)
ax_scaled = comfy.ldm.common_dit.rms_norm(ax) * (1 + ascale_mlp) + ashift_mlp
ax += self.audio_ff(ax_scaled) * agate_mlp
del ashift_mlp, ascale_mlp, agate_mlp
return vx, ax
class LTXAVModel(LTXVModel):
"""LTXAV model for audio-video generation."""
def __init__(
self,
in_channels=128,
audio_in_channels=128,
cross_attention_dim=4096,
audio_cross_attention_dim=2048,
attention_head_dim=128,
audio_attention_head_dim=64,
num_attention_heads=32,
audio_num_attention_heads=32,
caption_channels=3840,
num_layers=48,
positional_embedding_theta=10000.0,
positional_embedding_max_pos=[20, 2048, 2048],
audio_positional_embedding_max_pos=[20],
causal_temporal_positioning=False,
vae_scale_factors=(8, 32, 32),
use_middle_indices_grid=False,
timestep_scale_multiplier=1000.0,
av_ca_timestep_scale_multiplier=1.0,
dtype=None,
device=None,
operations=None,
**kwargs,
):
# Store audio-specific parameters
self.audio_in_channels = audio_in_channels
self.audio_cross_attention_dim = audio_cross_attention_dim
self.audio_attention_head_dim = audio_attention_head_dim
self.audio_num_attention_heads = audio_num_attention_heads
self.audio_positional_embedding_max_pos = audio_positional_embedding_max_pos
# Calculate audio dimensions
self.audio_inner_dim = audio_num_attention_heads * audio_attention_head_dim
self.audio_out_channels = audio_in_channels
# Audio-specific constants
self.num_audio_channels = 8
self.audio_frequency_bins = 16
self.av_ca_timestep_scale_multiplier = av_ca_timestep_scale_multiplier
super().__init__(
in_channels=in_channels,
cross_attention_dim=cross_attention_dim,
attention_head_dim=attention_head_dim,
num_attention_heads=num_attention_heads,
caption_channels=caption_channels,
num_layers=num_layers,
positional_embedding_theta=positional_embedding_theta,
positional_embedding_max_pos=positional_embedding_max_pos,
causal_temporal_positioning=causal_temporal_positioning,
vae_scale_factors=vae_scale_factors,
use_middle_indices_grid=use_middle_indices_grid,
timestep_scale_multiplier=timestep_scale_multiplier,
dtype=dtype,
device=device,
operations=operations,
**kwargs,
)
def _init_model_components(self, device, dtype, **kwargs):
"""Initialize LTXAV-specific components."""
# Audio-specific projections
self.audio_patchify_proj = self.operations.Linear(
self.audio_in_channels, self.audio_inner_dim, bias=True, dtype=dtype, device=device
)
# Audio-specific AdaLN
self.audio_adaln_single = AdaLayerNormSingle(
self.audio_inner_dim,
use_additional_conditions=False,
dtype=dtype,
device=device,
operations=self.operations,
)
num_scale_shift_values = 4
self.av_ca_video_scale_shift_adaln_single = AdaLayerNormSingle(
self.inner_dim,
use_additional_conditions=False,
embedding_coefficient=num_scale_shift_values,
dtype=dtype,
device=device,
operations=self.operations,
)
self.av_ca_a2v_gate_adaln_single = AdaLayerNormSingle(
self.inner_dim,
use_additional_conditions=False,
embedding_coefficient=1,
dtype=dtype,
device=device,
operations=self.operations,
)
self.av_ca_audio_scale_shift_adaln_single = AdaLayerNormSingle(
self.audio_inner_dim,
use_additional_conditions=False,
embedding_coefficient=num_scale_shift_values,
dtype=dtype,
device=device,
operations=self.operations,
)
self.av_ca_v2a_gate_adaln_single = AdaLayerNormSingle(
self.audio_inner_dim,
use_additional_conditions=False,
embedding_coefficient=1,
dtype=dtype,
device=device,
operations=self.operations,
)
# Audio caption projection
self.audio_caption_projection = PixArtAlphaTextProjection(
in_features=self.caption_channels,
hidden_size=self.audio_inner_dim,
dtype=dtype,
device=device,
operations=self.operations,
)
def _init_transformer_blocks(self, device, dtype, **kwargs):
"""Initialize transformer blocks for LTXAV."""
self.transformer_blocks = nn.ModuleList(
[
BasicAVTransformerBlock(
v_dim=self.inner_dim,
a_dim=self.audio_inner_dim,
v_heads=self.num_attention_heads,
a_heads=self.audio_num_attention_heads,
vd_head=self.attention_head_dim,
ad_head=self.audio_attention_head_dim,
v_context_dim=self.cross_attention_dim,
a_context_dim=self.audio_cross_attention_dim,
dtype=dtype,
device=device,
operations=self.operations,
)
for _ in range(self.num_layers)
]
)
def _init_output_components(self, device, dtype):
"""Initialize output components for LTXAV."""
# Video output components
super()._init_output_components(device, dtype)
# Audio output components
self.audio_scale_shift_table = nn.Parameter(
torch.empty(2, self.audio_inner_dim, dtype=dtype, device=device)
)
self.audio_norm_out = self.operations.LayerNorm(
self.audio_inner_dim, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device
)
self.audio_proj_out = self.operations.Linear(
self.audio_inner_dim, self.audio_out_channels, dtype=dtype, device=device
)
self.a_patchifier = AudioPatchifier(1, start_end=True)
def separate_audio_and_video_latents(self, x, audio_length):
"""Separate audio and video latents from combined input."""
# vx = x[:, : self.in_channels]
# ax = x[:, self.in_channels :]
#
# ax = ax.reshape(ax.shape[0], -1)
# ax = ax[:, : audio_length * self.num_audio_channels * self.audio_frequency_bins]
#
# ax = ax.reshape(
# ax.shape[0], self.num_audio_channels, audio_length, self.audio_frequency_bins
# )
vx = x[0]
ax = x[1] if len(x) > 1 else torch.zeros(
(vx.shape[0], self.num_audio_channels, 0, self.audio_frequency_bins),
device=vx.device, dtype=vx.dtype
)
return vx, ax
def recombine_audio_and_video_latents(self, vx, ax, target_shape=None):
if ax.numel() == 0:
return vx
else:
return [vx, ax]
"""Recombine audio and video latents for output."""
# if ax.device != vx.device or ax.dtype != vx.dtype:
# logging.warning("Audio and video latents are on different devices or dtypes.")
# ax = ax.to(device=vx.device, dtype=vx.dtype)
# logging.warning(f"Audio audio latent moved to device: {ax.device}, dtype: {ax.dtype}")
#
# ax = ax.reshape(ax.shape[0], -1)
# # pad to f x h x w of the video latents
# divisor = vx.shape[-1] * vx.shape[-2] * vx.shape[-3]
# if target_shape is None:
# repetitions = math.ceil(ax.shape[-1] / divisor)
# else:
# repetitions = target_shape[1] - vx.shape[1]
# padded_len = repetitions * divisor
# ax = F.pad(ax, (0, padded_len - ax.shape[-1]))
# ax = ax.reshape(ax.shape[0], -1, vx.shape[-3], vx.shape[-2], vx.shape[-1])
# return torch.cat([vx, ax], dim=1)
def _process_input(self, x, keyframe_idxs, denoise_mask, **kwargs):
"""Process input for LTXAV - separate audio and video, then patchify."""
audio_length = kwargs.get("audio_length", 0)
# Separate audio and video latents
vx, ax = self.separate_audio_and_video_latents(x, audio_length)
[vx, v_pixel_coords, additional_args] = super()._process_input(
vx, keyframe_idxs, denoise_mask, **kwargs
)
ax, a_latent_coords = self.a_patchifier.patchify(ax)
ax = self.audio_patchify_proj(ax)
# additional_args.update({"av_orig_shape": list(x.shape)})
return [vx, ax], [v_pixel_coords, a_latent_coords], additional_args
def _prepare_timestep(self, timestep, batch_size, hidden_dtype, **kwargs):
"""Prepare timestep embeddings."""
# TODO: some code reuse is needed here.
grid_mask = kwargs.get("grid_mask", None)
if grid_mask is not None:
timestep = timestep[:, grid_mask]
timestep = timestep * self.timestep_scale_multiplier
v_timestep, v_embedded_timestep = self.adaln_single(
timestep.flatten(),
{"resolution": None, "aspect_ratio": None},
batch_size=batch_size,
hidden_dtype=hidden_dtype,
)
# Second dimension is 1 or number of tokens (if timestep_per_token)
v_timestep = v_timestep.view(batch_size, -1, v_timestep.shape[-1])
v_embedded_timestep = v_embedded_timestep.view(
batch_size, -1, v_embedded_timestep.shape[-1]
)
# Prepare audio timestep
a_timestep = kwargs.get("a_timestep")
if a_timestep is not None:
a_timestep = a_timestep * self.timestep_scale_multiplier
av_ca_factor = self.av_ca_timestep_scale_multiplier / self.timestep_scale_multiplier
av_ca_audio_scale_shift_timestep, _ = self.av_ca_audio_scale_shift_adaln_single(
a_timestep.flatten(),
{"resolution": None, "aspect_ratio": None},
batch_size=batch_size,
hidden_dtype=hidden_dtype,
)
av_ca_video_scale_shift_timestep, _ = self.av_ca_video_scale_shift_adaln_single(
timestep.flatten(),
{"resolution": None, "aspect_ratio": None},
batch_size=batch_size,
hidden_dtype=hidden_dtype,
)
av_ca_a2v_gate_noise_timestep, _ = self.av_ca_a2v_gate_adaln_single(
timestep.flatten() * av_ca_factor,
{"resolution": None, "aspect_ratio": None},
batch_size=batch_size,
hidden_dtype=hidden_dtype,
)
av_ca_v2a_gate_noise_timestep, _ = self.av_ca_v2a_gate_adaln_single(
a_timestep.flatten() * av_ca_factor,
{"resolution": None, "aspect_ratio": None},
batch_size=batch_size,
hidden_dtype=hidden_dtype,
)
a_timestep, a_embedded_timestep = self.audio_adaln_single(
a_timestep.flatten(),
{"resolution": None, "aspect_ratio": None},
batch_size=batch_size,
hidden_dtype=hidden_dtype,
)
a_timestep = a_timestep.view(batch_size, -1, a_timestep.shape[-1])
a_embedded_timestep = a_embedded_timestep.view(
batch_size, -1, a_embedded_timestep.shape[-1]
)
cross_av_timestep_ss = [
av_ca_audio_scale_shift_timestep,
av_ca_video_scale_shift_timestep,
av_ca_a2v_gate_noise_timestep,
av_ca_v2a_gate_noise_timestep,
]
cross_av_timestep_ss = list(
[t.view(batch_size, -1, t.shape[-1]) for t in cross_av_timestep_ss]
)
else:
a_timestep = timestep
a_embedded_timestep = kwargs.get("embedded_timestep")
cross_av_timestep_ss = []
return [v_timestep, a_timestep, cross_av_timestep_ss], [
v_embedded_timestep,
a_embedded_timestep,
]
def _prepare_context(self, context, batch_size, x, attention_mask=None):
vx = x[0]
ax = x[1]
v_context, a_context = torch.split(
context, int(context.shape[-1] / 2), len(context.shape) - 1
)
v_context, attention_mask = super()._prepare_context(
v_context, batch_size, vx, attention_mask
)
if self.audio_caption_projection is not None:
a_context = self.audio_caption_projection(a_context)
a_context = a_context.view(batch_size, -1, ax.shape[-1])
return [v_context, a_context], attention_mask
def _prepare_positional_embeddings(self, pixel_coords, frame_rate, x_dtype):
v_pixel_coords = pixel_coords[0]
v_pe = super()._prepare_positional_embeddings(v_pixel_coords, frame_rate, x_dtype)
a_latent_coords = pixel_coords[1]
a_pe = self._precompute_freqs_cis(
a_latent_coords,
dim=self.audio_inner_dim,
out_dtype=x_dtype,
max_pos=self.audio_positional_embedding_max_pos,
use_middle_indices_grid=self.use_middle_indices_grid,
num_attention_heads=self.audio_num_attention_heads,
)
# calculate positional embeddings for the middle of the token duration, to use in av cross attention layers.
max_pos = max(
self.positional_embedding_max_pos[0], self.audio_positional_embedding_max_pos[0]
)
v_pixel_coords = v_pixel_coords.to(torch.float32)
v_pixel_coords[:, 0] = v_pixel_coords[:, 0] * (1.0 / frame_rate)
av_cross_video_freq_cis = self._precompute_freqs_cis(
v_pixel_coords[:, 0:1, :],
dim=self.audio_cross_attention_dim,
out_dtype=x_dtype,
max_pos=[max_pos],
use_middle_indices_grid=True,
num_attention_heads=self.audio_num_attention_heads,
)
av_cross_audio_freq_cis = self._precompute_freqs_cis(
a_latent_coords[:, 0:1, :],
dim=self.audio_cross_attention_dim,
out_dtype=x_dtype,
max_pos=[max_pos],
use_middle_indices_grid=True,
num_attention_heads=self.audio_num_attention_heads,
)
return [(v_pe, av_cross_video_freq_cis), (a_pe, av_cross_audio_freq_cis)]
def _process_transformer_blocks(
self, x, context, attention_mask, timestep, pe, transformer_options={}, **kwargs
):
vx = x[0]
ax = x[1]
v_context = context[0]
a_context = context[1]
v_timestep = timestep[0]
a_timestep = timestep[1]
v_pe, av_cross_video_freq_cis = pe[0]
a_pe, av_cross_audio_freq_cis = pe[1]
(
av_ca_audio_scale_shift_timestep,
av_ca_video_scale_shift_timestep,
av_ca_a2v_gate_noise_timestep,
av_ca_v2a_gate_noise_timestep,
) = timestep[2]
"""Process transformer blocks for LTXAV."""
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
# Process transformer blocks
for i, block in enumerate(self.transformer_blocks):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(
args["img"],
v_context=args["v_context"],
a_context=args["a_context"],
attention_mask=args["attention_mask"],
v_timestep=args["v_timestep"],
a_timestep=args["a_timestep"],
v_pe=args["v_pe"],
a_pe=args["a_pe"],
v_cross_pe=args["v_cross_pe"],
a_cross_pe=args["a_cross_pe"],
v_cross_scale_shift_timestep=args["v_cross_scale_shift_timestep"],
a_cross_scale_shift_timestep=args["a_cross_scale_shift_timestep"],
v_cross_gate_timestep=args["v_cross_gate_timestep"],
a_cross_gate_timestep=args["a_cross_gate_timestep"],
transformer_options=args["transformer_options"],
)
return out
out = blocks_replace[("double_block", i)](
{
"img": (vx, ax),
"v_context": v_context,
"a_context": a_context,
"attention_mask": attention_mask,
"v_timestep": v_timestep,
"a_timestep": a_timestep,
"v_pe": v_pe,
"a_pe": a_pe,
"v_cross_pe": av_cross_video_freq_cis,
"a_cross_pe": av_cross_audio_freq_cis,
"v_cross_scale_shift_timestep": av_ca_video_scale_shift_timestep,
"a_cross_scale_shift_timestep": av_ca_audio_scale_shift_timestep,
"v_cross_gate_timestep": av_ca_a2v_gate_noise_timestep,
"a_cross_gate_timestep": av_ca_v2a_gate_noise_timestep,
"transformer_options": transformer_options,
},
{"original_block": block_wrap},
)
vx, ax = out["img"]
else:
vx, ax = block(
(vx, ax),
v_context=v_context,
a_context=a_context,
attention_mask=attention_mask,
v_timestep=v_timestep,
a_timestep=a_timestep,
v_pe=v_pe,
a_pe=a_pe,
v_cross_pe=av_cross_video_freq_cis,
a_cross_pe=av_cross_audio_freq_cis,
v_cross_scale_shift_timestep=av_ca_video_scale_shift_timestep,
a_cross_scale_shift_timestep=av_ca_audio_scale_shift_timestep,
v_cross_gate_timestep=av_ca_a2v_gate_noise_timestep,
a_cross_gate_timestep=av_ca_v2a_gate_noise_timestep,
transformer_options=transformer_options,
)
return [vx, ax]
def _process_output(self, x, embedded_timestep, keyframe_idxs, **kwargs):
vx = x[0]
ax = x[1]
v_embedded_timestep = embedded_timestep[0]
a_embedded_timestep = embedded_timestep[1]
vx = super()._process_output(vx, v_embedded_timestep, keyframe_idxs, **kwargs)
# Process audio output
a_scale_shift_values = (
self.audio_scale_shift_table[None, None].to(device=a_embedded_timestep.device, dtype=a_embedded_timestep.dtype)
+ a_embedded_timestep[:, :, None]
)
a_shift, a_scale = a_scale_shift_values[:, :, 0], a_scale_shift_values[:, :, 1]
ax = self.audio_norm_out(ax)
ax = ax * (1 + a_scale) + a_shift
ax = self.audio_proj_out(ax)
# Unpatchify audio
ax = self.a_patchifier.unpatchify(
ax, channels=self.num_audio_channels, freq=self.audio_frequency_bins
)
# Recombine audio and video
original_shape = kwargs.get("av_orig_shape")
return self.recombine_audio_and_video_latents(vx, ax, original_shape)
def forward(
self,
x,
timestep,
context,
attention_mask=None,
frame_rate=25,
transformer_options={},
keyframe_idxs=None,
**kwargs,
):
"""
Forward pass for LTXAV model.
Args:
x: Combined audio-video input tensor
timestep: Tuple of (video_timestep, audio_timestep) or single timestep
context: Context tensor (e.g., text embeddings)
attention_mask: Attention mask tensor
frame_rate: Frame rate for temporal processing
transformer_options: Additional options for transformer blocks
keyframe_idxs: Keyframe indices for temporal processing
**kwargs: Additional keyword arguments including audio_length
Returns:
Combined audio-video output tensor
"""
# Handle timestep format
if isinstance(timestep, (tuple, list)) and len(timestep) == 2:
v_timestep, a_timestep = timestep
kwargs["a_timestep"] = a_timestep
timestep = v_timestep
else:
kwargs["a_timestep"] = timestep
# Call parent forward method
return super().forward(
x,
timestep,
context,
attention_mask,
frame_rate,
transformer_options,
keyframe_idxs,
**kwargs,
)
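A hedged call sketch for the new forward signature (tensor names and shapes are illustrative, not taken from the diff): x is a [video_latents, audio_latents] pair and timestep may be a single tensor or a (video, audio) pair, in which case the audio timestep is routed on through kwargs as a_timestep:

# hypothetical tensors prepared by the surrounding pipeline
out = model(
    [vx, ax],                  # video and audio latents
    (v_timestep, a_timestep),  # or one tensor shared by both streams
    context,                   # text embeddings; split in half for video/audio internally
    transformer_options={},
)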

View File

@@ -0,0 +1,305 @@
import math
from typing import Optional
import comfy.ldm.common_dit
import torch
from comfy.ldm.lightricks.model import (
CrossAttention,
FeedForward,
generate_freq_grid_np,
interleaved_freqs_cis,
split_freqs_cis,
)
from torch import nn
class BasicTransformerBlock1D(nn.Module):
r"""
A basic Transformer block.
Parameters:
dim (`int`): The number of channels in the input and output.
num_attention_heads (`int`): The number of heads to use for multi-head attention.
attention_head_dim (`int`): The number of channels in each head.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
attention_bias (:
obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
upcast_attention (`bool`, *optional*):
Whether to upcast the attention computation to float32. This is useful for mixed precision training.
norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
Whether to use learnable elementwise affine parameters for normalization.
standardization_norm (`str`, *optional*, defaults to `"layer_norm"`): The type of pre-normalization to use. Can be `"layer_norm"` or `"rms_norm"`.
norm_eps (`float`, *optional*, defaults to 1e-5): Epsilon value for normalization layers.
qk_norm (`str`, *optional*, defaults to None):
Set to 'layer_norm' or `rms_norm` to perform query and key normalization.
final_dropout (`bool` *optional*, defaults to False):
Whether to apply a final dropout after the last feed-forward layer.
ff_inner_dim (`int`, *optional*): Dimension of the inner feed-forward layer. If not provided, defaults to `dim * 4`.
ff_bias (`bool`, *optional*, defaults to `True`): Whether to use bias in the feed-forward layer.
attention_out_bias (`bool`, *optional*, defaults to `True`): Whether to use bias in the attention output layer.
use_rope (`bool`, *optional*, defaults to `False`): Whether to use Rotary Position Embeddings (RoPE).
ffn_dim_mult (`int`, *optional*, defaults to 4): Multiplier for the inner dimension of the feed-forward layer.
"""
def __init__(
self,
dim,
n_heads,
d_head,
context_dim=None,
attn_precision=None,
dtype=None,
device=None,
operations=None,
):
super().__init__()
# Define 3 blocks. Each block has its own normalization layer.
# 1. Self-Attn
self.attn1 = CrossAttention(
query_dim=dim,
heads=n_heads,
dim_head=d_head,
context_dim=None,
dtype=dtype,
device=device,
operations=operations,
)
# 3. Feed-forward
self.ff = FeedForward(
dim,
dim_out=dim,
glu=True,
dtype=dtype,
device=device,
operations=operations,
)
def forward(self, hidden_states, attention_mask=None, pe=None) -> torch.FloatTensor:
# Notice that normalization is always applied before the real computation in the following blocks.
# 1. Normalization Before Self-Attention
norm_hidden_states = comfy.ldm.common_dit.rms_norm(hidden_states)
norm_hidden_states = norm_hidden_states.squeeze(1)
# 2. Self-Attention
attn_output = self.attn1(norm_hidden_states, mask=attention_mask, pe=pe)
hidden_states = attn_output + hidden_states
if hidden_states.ndim == 4:
hidden_states = hidden_states.squeeze(1)
# 3. Normalization before Feed-Forward
norm_hidden_states = comfy.ldm.common_dit.rms_norm(hidden_states)
# 4. Feed-forward
ff_output = self.ff(norm_hidden_states)
hidden_states = ff_output + hidden_states
if hidden_states.ndim == 4:
hidden_states = hidden_states.squeeze(1)
return hidden_states
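# Editor's sketch (not part of the upstream diff): the block above is a pre-norm residual layer,
# i.e. RMSNorm -> self-attention -> residual add, then RMSNorm -> feed-forward -> residual add.
# The helper below is a hypothetical, shape-only smoke test; `comfy.ops.disable_weight_init` is
# assumed as the `operations` implementation and the sizes match the Embeddings1DConnector defaults.
def _example_basic_transformer_block_1d():
    import comfy.ops
    block = BasicTransformerBlock1D(dim=30 * 128, n_heads=30, d_head=128,
                                    operations=comfy.ops.disable_weight_init)
    tokens = torch.randn(1, 256, 30 * 128)
    out = block(tokens)  # residual connections preserve the shape: (1, 256, 3840)
    return out.shape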
class Embeddings1DConnector(nn.Module):
_supports_gradient_checkpointing = True
def __init__(
self,
in_channels=128,
cross_attention_dim=2048,
attention_head_dim=128,
num_attention_heads=30,
num_layers=2,
positional_embedding_theta=10000.0,
positional_embedding_max_pos=[4096],
causal_temporal_positioning=False,
num_learnable_registers: Optional[int] = 128,
dtype=None,
device=None,
operations=None,
split_rope=False,
double_precision_rope=False,
**kwargs,
):
super().__init__()
self.dtype = dtype
self.out_channels = in_channels
self.num_attention_heads = num_attention_heads
self.inner_dim = num_attention_heads * attention_head_dim
self.causal_temporal_positioning = causal_temporal_positioning
self.positional_embedding_theta = positional_embedding_theta
self.positional_embedding_max_pos = positional_embedding_max_pos
self.split_rope = split_rope
self.double_precision_rope = double_precision_rope
self.transformer_1d_blocks = nn.ModuleList(
[
BasicTransformerBlock1D(
self.inner_dim,
num_attention_heads,
attention_head_dim,
context_dim=cross_attention_dim,
dtype=dtype,
device=device,
operations=operations,
)
for _ in range(num_layers)
]
)
inner_dim = num_attention_heads * attention_head_dim
self.num_learnable_registers = num_learnable_registers
if self.num_learnable_registers:
self.learnable_registers = nn.Parameter(
torch.rand(
self.num_learnable_registers, inner_dim, dtype=dtype, device=device
)
* 2.0
- 1.0
)
def get_fractional_positions(self, indices_grid):
fractional_positions = torch.stack(
[
indices_grid[:, i] / self.positional_embedding_max_pos[i]
for i in range(1)
],
dim=-1,
)
return fractional_positions
def precompute_freqs(self, indices_grid, spacing):
source_dtype = indices_grid.dtype
dtype = (
torch.float32
if source_dtype in (torch.bfloat16, torch.float16)
else source_dtype
)
fractional_positions = self.get_fractional_positions(indices_grid)
indices = (
generate_freq_grid_np(
self.positional_embedding_theta,
indices_grid.shape[1],
self.inner_dim,
)
if self.double_precision_rope
else self.generate_freq_grid(spacing, dtype, fractional_positions.device)
).to(device=fractional_positions.device)
if spacing == "exp_2":
freqs = (
(indices * fractional_positions.unsqueeze(-1))
.transpose(-1, -2)
.flatten(2)
)
else:
freqs = (
(indices * (fractional_positions.unsqueeze(-1) * 2 - 1))
.transpose(-1, -2)
.flatten(2)
)
return freqs
def generate_freq_grid(self, spacing, dtype, device):
dim = self.inner_dim
theta = self.positional_embedding_theta
n_pos_dims = 1
n_elem = 2 * n_pos_dims # 2 (cos and sin) per position dimension
start = 1
end = theta
if spacing == "exp":
indices = theta ** (torch.arange(0, dim, n_elem, device="cpu", dtype=torch.float32) / (dim - n_elem))
indices = indices.to(dtype=dtype, device=device)
elif spacing == "exp_2":
indices = 1.0 / theta ** (torch.arange(0, dim, n_elem, device=device) / dim)
indices = indices.to(dtype=dtype)
elif spacing == "linear":
indices = torch.linspace(
start, end, dim // n_elem, device=device, dtype=dtype
)
elif spacing == "sqrt":
indices = torch.linspace(
start**2, end**2, dim // n_elem, device=device, dtype=dtype
).sqrt()
indices = indices * math.pi / 2
return indices
def precompute_freqs_cis(self, indices_grid, spacing="exp"):
dim = self.inner_dim
n_elem = 2 # 2 because of cos and sin
freqs = self.precompute_freqs(indices_grid, spacing)
if self.split_rope:
expected_freqs = dim // 2
current_freqs = freqs.shape[-1]
pad_size = expected_freqs - current_freqs
cos_freq, sin_freq = split_freqs_cis(
freqs, pad_size, self.num_attention_heads
)
else:
cos_freq, sin_freq = interleaved_freqs_cis(freqs, dim % n_elem)
return cos_freq.to(self.dtype), sin_freq.to(self.dtype), self.split_rope
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
):
"""
The [`Transformer2DModel`] forward method.
Args:
hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
Input `hidden_states`.
indices_grid (`torch.LongTensor` of shape `(batch size, 3, num latent pixels)`):
attention_mask ( `torch.Tensor`, *optional*):
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
negative values to the attention scores corresponding to "discard" tokens.
Returns:
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
`tuple` where the first element is the sample tensor.
"""
# 1. Input
if self.num_learnable_registers:
num_registers_duplications = math.ceil(
max(1024, hidden_states.shape[1]) / self.num_learnable_registers
)
learnable_registers = torch.tile(
self.learnable_registers.to(hidden_states), (num_registers_duplications, 1)
)
hidden_states = torch.cat((hidden_states, learnable_registers[hidden_states.shape[1]:].unsqueeze(0).repeat(hidden_states.shape[0], 1, 1)), dim=1)
if attention_mask is not None:
attention_mask = torch.zeros([1, 1, 1, hidden_states.shape[1]], dtype=attention_mask.dtype, device=attention_mask.device)
indices_grid = torch.arange(
hidden_states.shape[1], dtype=torch.float32, device=hidden_states.device
)
indices_grid = indices_grid[None, None, :]
freqs_cis = self.precompute_freqs_cis(indices_grid)
# 2. Blocks
for block_idx, block in enumerate(self.transformer_1d_blocks):
hidden_states = block(
hidden_states, attention_mask=attention_mask, pe=freqs_cis
)
# 3. Output
# if self.output_scale is not None:
# hidden_states = hidden_states / self.output_scale
hidden_states = comfy.ldm.common_dit.rms_norm(hidden_states)
return hidden_states, attention_mask
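# Editor's sketch (not part of the upstream diff): the connector pads the token sequence with its
# learnable registers up to at least 1024 positions, applies 1D RoPE over the token index, runs the
# transformer_1d_blocks, and RMS-normalizes the result. A hypothetical shape-only check, again
# assuming `comfy.ops.disable_weight_init` as the ops implementation:
def _example_embeddings_1d_connector():
    import comfy.ops
    connector = Embeddings1DConnector(num_attention_heads=30, attention_head_dim=128, num_layers=2,
                                      dtype=torch.float32, operations=comfy.ops.disable_weight_init)
    hidden = torch.randn(1, 700, 30 * 128)
    out, mask = connector(hidden)
    return out.shape, mask  # ((1, 1024, 3840), None): 324 register tokens were appended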

View File

@ -0,0 +1,292 @@
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
def _rational_for_scale(scale: float) -> Tuple[int, int]:
mapping = {0.75: (3, 4), 1.5: (3, 2), 2.0: (2, 1), 4.0: (4, 1)}
if float(scale) not in mapping:
raise ValueError(
f"Unsupported spatial_scale {scale}. Choose from {list(mapping.keys())}"
)
return mapping[float(scale)]
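# Editor's note (hedged, not part of the upstream diff): the supported factors map to
# (numerator, denominator) pairs, e.g. _rational_for_scale(1.5) == (3, 2), meaning "pixel-shuffle
# up by 3, then blur-downsample by stride 2"; _rational_for_scale(0.75) == (3, 4) shrinks the grid.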
class PixelShuffleND(nn.Module):
def __init__(self, dims, upscale_factors=(2, 2, 2)):
super().__init__()
assert dims in [1, 2, 3], "dims must be 1, 2, or 3"
self.dims = dims
self.upscale_factors = upscale_factors
def forward(self, x):
if self.dims == 3:
return rearrange(
x,
"b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)",
p1=self.upscale_factors[0],
p2=self.upscale_factors[1],
p3=self.upscale_factors[2],
)
elif self.dims == 2:
return rearrange(
x,
"b (c p1 p2) h w -> b c (h p1) (w p2)",
p1=self.upscale_factors[0],
p2=self.upscale_factors[1],
)
elif self.dims == 1:
return rearrange(
x,
"b (c p1) f h w -> b c (f p1) h w",
p1=self.upscale_factors[0],
)
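# Editor's sketch (not part of the upstream diff): PixelShuffleND trades channels for resolution.
# For dims=2 with factors (2, 2), a (B, 4*C, H, W) tensor becomes (B, C, 2H, 2W):
def _example_pixel_shuffle_2d():
    shuffle = PixelShuffleND(2, upscale_factors=(2, 2))
    x = torch.randn(1, 4 * 8, 16, 16)
    y = shuffle(x)
    assert y.shape == (1, 8, 32, 32)  # channel groups of 4 folded into the H and W axes
    return y.shape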
class BlurDownsample(nn.Module):
"""
Anti-aliased spatial downsampling by integer stride using a fixed separable binomial kernel.
Applies only on H,W. Works for dims=2 or dims=3 (per-frame).
"""
def __init__(self, dims: int, stride: int):
super().__init__()
assert dims in (2, 3)
assert stride >= 1 and isinstance(stride, int)
self.dims = dims
self.stride = stride
# 5x5 separable binomial kernel [1,4,6,4,1] (outer product), normalized
k = torch.tensor([1.0, 4.0, 6.0, 4.0, 1.0])
k2d = k[:, None] @ k[None, :]
k2d = (k2d / k2d.sum()).float() # shape (5,5)
self.register_buffer("kernel", k2d[None, None, :, :]) # (1,1,5,5)
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.stride == 1:
return x
def _apply_2d(x2d: torch.Tensor) -> torch.Tensor:
# x2d: (B, C, H, W)
B, C, H, W = x2d.shape
weight = self.kernel.expand(C, 1, 5, 5) # depthwise
x2d = F.conv2d(
x2d, weight=weight, bias=None, stride=self.stride, padding=2, groups=C
)
return x2d
if self.dims == 2:
return _apply_2d(x)
else:
# dims == 3: apply per-frame on H,W
b, c, f, h, w = x.shape
x = rearrange(x, "b c f h w -> (b f) c h w")
x = _apply_2d(x)
h2, w2 = x.shape[-2:]
x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f, h=h2, w=w2)
return x
class SpatialRationalResampler(nn.Module):
"""
Fully-learned rational spatial scaling: up by 'num' via PixelShuffle, then anti-aliased
downsample by 'den' using fixed blur + stride. Operates on H,W only.
For dims==3, the resampler works per-frame for spatial scaling (the temporal axis is untouched).
"""
def __init__(self, mid_channels: int, scale: float):
super().__init__()
self.scale = float(scale)
self.num, self.den = _rational_for_scale(self.scale)
self.conv = nn.Conv2d(
mid_channels, (self.num**2) * mid_channels, kernel_size=3, padding=1
)
self.pixel_shuffle = PixelShuffleND(2, upscale_factors=(self.num, self.num))
self.blur_down = BlurDownsample(dims=2, stride=self.den)
def forward(self, x: torch.Tensor) -> torch.Tensor:
b, c, f, h, w = x.shape
x = rearrange(x, "b c f h w -> (b f) c h w")
x = self.conv(x)
x = self.pixel_shuffle(x)
x = self.blur_down(x)
x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f)
return x
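# Editor's note (hedged, not part of the upstream diff): with scale=1.5 (num=3, den=2) a
# (B, C, F, 16, 16) latent is conv-expanded to 9*C channels per frame, pixel-shuffled to 48x48,
# then blur-downsampled with stride 2 back to 24x24, i.e. a net 1.5x spatial change per axis.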
class ResBlock(nn.Module):
def __init__(
self, channels: int, mid_channels: Optional[int] = None, dims: int = 3
):
super().__init__()
if mid_channels is None:
mid_channels = channels
Conv = nn.Conv2d if dims == 2 else nn.Conv3d
self.conv1 = Conv(channels, mid_channels, kernel_size=3, padding=1)
self.norm1 = nn.GroupNorm(32, mid_channels)
self.conv2 = Conv(mid_channels, channels, kernel_size=3, padding=1)
self.norm2 = nn.GroupNorm(32, channels)
self.activation = nn.SiLU()
def forward(self, x: torch.Tensor) -> torch.Tensor:
residual = x
x = self.conv1(x)
x = self.norm1(x)
x = self.activation(x)
x = self.conv2(x)
x = self.norm2(x)
x = self.activation(x + residual)
return x
class LatentUpsampler(nn.Module):
"""
Model to spatially upsample VAE latents.
Args:
in_channels (`int`): Number of channels in the input latent
mid_channels (`int`): Number of channels in the middle layers
num_blocks_per_stage (`int`): Number of ResBlocks to use in each stage (pre/post upsampling)
dims (`int`): Number of dimensions for convolutions (2 or 3)
spatial_upsample (`bool`): Whether to spatially upsample the latent
temporal_upsample (`bool`): Whether to temporally upsample the latent
"""
def __init__(
self,
in_channels: int = 128,
mid_channels: int = 512,
num_blocks_per_stage: int = 4,
dims: int = 3,
spatial_upsample: bool = True,
temporal_upsample: bool = False,
spatial_scale: float = 2.0,
rational_resampler: bool = False,
):
super().__init__()
self.in_channels = in_channels
self.mid_channels = mid_channels
self.num_blocks_per_stage = num_blocks_per_stage
self.dims = dims
self.spatial_upsample = spatial_upsample
self.temporal_upsample = temporal_upsample
self.spatial_scale = float(spatial_scale)
self.rational_resampler = rational_resampler
Conv = nn.Conv2d if dims == 2 else nn.Conv3d
self.initial_conv = Conv(in_channels, mid_channels, kernel_size=3, padding=1)
self.initial_norm = nn.GroupNorm(32, mid_channels)
self.initial_activation = nn.SiLU()
self.res_blocks = nn.ModuleList(
[ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)]
)
if spatial_upsample and temporal_upsample:
self.upsampler = nn.Sequential(
nn.Conv3d(mid_channels, 8 * mid_channels, kernel_size=3, padding=1),
PixelShuffleND(3),
)
elif spatial_upsample:
if rational_resampler:
self.upsampler = SpatialRationalResampler(
mid_channels=mid_channels, scale=self.spatial_scale
)
else:
self.upsampler = nn.Sequential(
nn.Conv2d(mid_channels, 4 * mid_channels, kernel_size=3, padding=1),
PixelShuffleND(2),
)
elif temporal_upsample:
self.upsampler = nn.Sequential(
nn.Conv3d(mid_channels, 2 * mid_channels, kernel_size=3, padding=1),
PixelShuffleND(1),
)
else:
raise ValueError(
"Either spatial_upsample or temporal_upsample must be True"
)
self.post_upsample_res_blocks = nn.ModuleList(
[ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)]
)
self.final_conv = Conv(mid_channels, in_channels, kernel_size=3, padding=1)
def forward(self, latent: torch.Tensor) -> torch.Tensor:
b, c, f, h, w = latent.shape
if self.dims == 2:
x = rearrange(latent, "b c f h w -> (b f) c h w")
x = self.initial_conv(x)
x = self.initial_norm(x)
x = self.initial_activation(x)
for block in self.res_blocks:
x = block(x)
x = self.upsampler(x)
for block in self.post_upsample_res_blocks:
x = block(x)
x = self.final_conv(x)
x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f)
else:
x = self.initial_conv(latent)
x = self.initial_norm(x)
x = self.initial_activation(x)
for block in self.res_blocks:
x = block(x)
if self.temporal_upsample:
x = self.upsampler(x)
x = x[:, :, 1:, :, :]
else:
if isinstance(self.upsampler, SpatialRationalResampler):
x = self.upsampler(x)
else:
x = rearrange(x, "b c f h w -> (b f) c h w")
x = self.upsampler(x)
x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f)
for block in self.post_upsample_res_blocks:
x = block(x)
x = self.final_conv(x)
return x
@classmethod
def from_config(cls, config):
return cls(
in_channels=config.get("in_channels", 4),
mid_channels=config.get("mid_channels", 128),
num_blocks_per_stage=config.get("num_blocks_per_stage", 4),
dims=config.get("dims", 2),
spatial_upsample=config.get("spatial_upsample", True),
temporal_upsample=config.get("temporal_upsample", False),
spatial_scale=config.get("spatial_scale", 2.0),
rational_resampler=config.get("rational_resampler", False),
)
def config(self):
return {
"_class_name": "LatentUpsampler",
"in_channels": self.in_channels,
"mid_channels": self.mid_channels,
"num_blocks_per_stage": self.num_blocks_per_stage,
"dims": self.dims,
"spatial_upsample": self.spatial_upsample,
"temporal_upsample": self.temporal_upsample,
"spatial_scale": self.spatial_scale,
"rational_resampler": self.rational_resampler,
}
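# Editor's sketch (not part of the upstream diff): a hypothetical shape-only round trip through the
# 3D-convolution path with the default 2x spatial upsampling. in_channels=128 matches the LTXV
# latent width used elsewhere in this commit; mid_channels is kept small only to keep this cheap.
def _example_latent_upsampler():
    upsampler = LatentUpsampler(in_channels=128, mid_channels=64, num_blocks_per_stage=1, dims=3)
    latent = torch.randn(1, 128, 5, 8, 8)     # (batch, channels, frames, height, width)
    out = upsampler(latent)
    assert out.shape == (1, 128, 5, 16, 16)   # H and W doubled, frame count untouched
    return out.shape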

View File

@ -1,13 +1,47 @@
from abc import ABC, abstractmethod
from enum import Enum
import functools
import math
from typing import Dict, Optional, Tuple
from einops import rearrange
import numpy as np
import torch
from torch import nn
import comfy.patcher_extension
import comfy.ldm.modules.attention
import comfy.ldm.common_dit
from .symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords
def _log_base(x, base):
return np.log(x) / np.log(base)
class LTXRopeType(str, Enum):
INTERLEAVED = "interleaved"
SPLIT = "split"
KEY = "rope_type"
@classmethod
def from_dict(cls, kwargs, default=None):
if default is None:
default = cls.INTERLEAVED
return cls(kwargs.get(cls.KEY, default))
class LTXFrequenciesPrecision(str, Enum):
FLOAT32 = "float32"
FLOAT64 = "float64"
KEY = "frequencies_precision"
@classmethod
def from_dict(cls, kwargs, default=None):
if default is None:
default = cls.FLOAT32
return cls(kwargs.get(cls.KEY, default))
def get_timestep_embedding(
timesteps: torch.Tensor,
@ -39,9 +73,7 @@ def get_timestep_embedding(
assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
half_dim = embedding_dim // 2
exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device)
exponent = exponent / (half_dim - downscale_freq_shift)
emb = torch.exp(exponent)
@ -73,7 +105,9 @@ class TimestepEmbedding(nn.Module):
post_act_fn: Optional[str] = None,
cond_proj_dim=None,
sample_proj_bias=True,
dtype=None,
device=None,
operations=None,
):
super().__init__()
@ -90,7 +124,9 @@ class TimestepEmbedding(nn.Module):
time_embed_dim_out = out_dim
else:
time_embed_dim_out = time_embed_dim
self.linear_2 = operations.Linear(
time_embed_dim, time_embed_dim_out, sample_proj_bias, dtype=dtype, device=device
)
if post_act_fn is None:
self.post_act = None
@ -139,12 +175,22 @@ class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29
"""
def __init__(
self,
embedding_dim,
size_emb_dim,
use_additional_conditions: bool = False,
dtype=None,
device=None,
operations=None,
):
super().__init__()
self.outdim = size_emb_dim
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
self.timestep_embedder = TimestepEmbedding(
in_channels=256, time_embed_dim=embedding_dim, dtype=dtype, device=device, operations=operations
)
def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype):
timesteps_proj = self.time_proj(timestep)
@ -163,15 +209,22 @@ class AdaLayerNormSingle(nn.Module):
use_additional_conditions (`bool`): To use additional conditions for normalization or not.
"""
def __init__(
self, embedding_dim: int, embedding_coefficient: int = 6, use_additional_conditions: bool = False, dtype=None, device=None, operations=None
):
super().__init__()
self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings(
embedding_dim,
size_emb_dim=embedding_dim // 3,
use_additional_conditions=use_additional_conditions,
dtype=dtype,
device=device,
operations=operations,
)
self.silu = nn.SiLU()
self.linear = operations.Linear(embedding_dim, embedding_coefficient * embedding_dim, bias=True, dtype=dtype, device=device)
def forward(
self,
@ -185,6 +238,7 @@ class AdaLayerNormSingle(nn.Module):
embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype)
return self.linear(self.silu(embedded_timestep)), embedded_timestep
class PixArtAlphaTextProjection(nn.Module):
"""
Projects caption embeddings. Also handles dropout for classifier-free guidance.
@ -192,18 +246,24 @@ class PixArtAlphaTextProjection(nn.Module):
Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
"""
def __init__(
self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh", dtype=None, device=None, operations=None
):
super().__init__()
if out_features is None:
out_features = hidden_size
self.linear_1 = operations.Linear(
in_features=in_features, out_features=hidden_size, bias=True, dtype=dtype, device=device
)
if act_fn == "gelu_tanh":
self.act_1 = nn.GELU(approximate="tanh")
elif act_fn == "silu":
self.act_1 = nn.SiLU()
else:
raise ValueError(f"Unknown activation function: {act_fn}")
self.linear_2 = operations.Linear(
in_features=hidden_size, out_features=out_features, bias=True, dtype=dtype, device=device
)
def forward(self, caption):
hidden_states = self.linear_1(caption)
@ -222,23 +282,68 @@ class GELU_approx(nn.Module):
class FeedForward(nn.Module):
def __init__(self, dim, dim_out, mult=4, glu=False, dropout=0.0, dtype=None, device=None, operations=None):
super().__init__()
inner_dim = int(dim * mult)
project_in = GELU_approx(dim, inner_dim, dtype=dtype, device=device, operations=operations)
self.net = nn.Sequential(
project_in, nn.Dropout(dropout), operations.Linear(inner_dim, dim_out, dtype=dtype, device=device)
)
def forward(self, x):
return self.net(x)
def apply_rotary_emb(input_tensor, freqs_cis):
cos_freqs, sin_freqs = freqs_cis[0], freqs_cis[1]
split_pe = freqs_cis[2] if len(freqs_cis) > 2 else False
return (
apply_split_rotary_emb(input_tensor, cos_freqs, sin_freqs)
if split_pe else
apply_interleaved_rotary_emb(input_tensor, cos_freqs, sin_freqs)
)
def apply_interleaved_rotary_emb(input_tensor, cos_freqs, sin_freqs): # TODO: remove duplicate funcs and pick the best/fastest one
t_dup = rearrange(input_tensor, "... (d r) -> ... d r", r=2)
t1, t2 = t_dup.unbind(dim=-1)
t_dup = torch.stack((-t2, t1), dim=-1)
input_tensor_rot = rearrange(t_dup, "... d r -> ... (d r)")
out = input_tensor * cos_freqs + input_tensor_rot * sin_freqs
return out
def apply_split_rotary_emb(input_tensor, cos, sin):
needs_reshape = False
if input_tensor.ndim != 4 and cos.ndim == 4:
B, H, T, _ = cos.shape
input_tensor = input_tensor.reshape(B, T, H, -1).swapaxes(1, 2)
needs_reshape = True
split_input = rearrange(input_tensor, "... (d r) -> ... d r", d=2)
first_half_input = split_input[..., :1, :]
second_half_input = split_input[..., 1:, :]
output = split_input * cos.unsqueeze(-2)
first_half_output = output[..., :1, :]
second_half_output = output[..., 1:, :]
first_half_output.addcmul_(-sin.unsqueeze(-2), second_half_input)
second_half_output.addcmul_(sin.unsqueeze(-2), first_half_input)
output = rearrange(output, "... d r -> ... (d r)")
return output.swapaxes(1, 2).reshape(B, T, -1) if needs_reshape else output
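# Editor's sketch (not part of the upstream diff): the interleaved variant is standard RoPE, rotating
# adjacent channel pairs (x0, x1) by a per-position angle; a toy shape check:
def _example_interleaved_rope():
    q = torch.randn(1, 4, 8)                      # (batch, tokens, dim), dim even
    angles = torch.rand(1, 4, 4)                  # one angle per channel pair
    cos = angles.cos().repeat_interleave(2, dim=-1)
    sin = angles.sin().repeat_interleave(2, dim=-1)
    out = apply_interleaved_rotary_emb(q, cos, sin)
    assert out.shape == q.shape                   # rotation preserves shape (and norm per pair)
    return out.shape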
class CrossAttention(nn.Module):
def __init__(
self,
query_dim,
context_dim=None,
heads=8,
dim_head=64,
dropout=0.0,
attn_precision=None,
dtype=None,
device=None,
operations=None,
):
super().__init__()
inner_dim = dim_head * heads
context_dim = query_dim if context_dim is None else context_dim
@ -254,9 +359,11 @@ class CrossAttention(nn.Module):
self.to_k = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)
self.to_v = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)
self.to_out = nn.Sequential(
operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout)
)
def forward(self, x, context=None, mask=None, pe=None, k_pe=None, transformer_options={}):
q = self.to_q(x)
context = x if context is None else context
k = self.to_k(context)
@ -266,8 +373,8 @@ class CrossAttention(nn.Module):
k = self.k_norm(k)
if pe is not None:
q = apply_rotary_emb(q, pe)
k = apply_rotary_emb(k, pe if k_pe is None else k_pe)
if mask is None:
out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options)
@ -277,14 +384,34 @@ class CrossAttention(nn.Module):
class BasicTransformerBlock(nn.Module):
def __init__(
self, dim, n_heads, d_head, context_dim=None, attn_precision=None, dtype=None, device=None, operations=None
):
super().__init__()
self.attn_precision = attn_precision
self.attn1 = CrossAttention(
query_dim=dim,
heads=n_heads,
dim_head=d_head,
context_dim=None,
attn_precision=self.attn_precision,
dtype=dtype,
device=device,
operations=operations,
)
self.ff = FeedForward(dim, dim_out=dim, glu=True, dtype=dtype, device=device, operations=operations)
self.attn2 = CrossAttention(
query_dim=dim,
context_dim=context_dim,
heads=n_heads,
dim_head=d_head,
attn_precision=self.attn_precision,
dtype=dtype,
device=device,
operations=operations,
)
self.scale_shift_table = nn.Parameter(torch.empty(6, dim, device=device, dtype=dtype))
@ -306,116 +433,446 @@ class BasicTransformerBlock(nn.Module):
return x
def get_fractional_positions(indices_grid, max_pos):
n_pos_dims = indices_grid.shape[1]
assert n_pos_dims == len(max_pos), f'Number of position dimensions ({n_pos_dims}) must match max_pos length ({len(max_pos)})'
fractional_positions = torch.stack(
[indices_grid[:, i] / max_pos[i] for i in range(n_pos_dims)],
axis=-1,
)
return fractional_positions
@functools.lru_cache(maxsize=5)
def generate_freq_grid_np(positional_embedding_theta, positional_embedding_max_pos_count, inner_dim, _=None):
theta = positional_embedding_theta
start = 1
end = theta
n_elem = 2 * positional_embedding_max_pos_count
pow_indices = np.power(
theta,
np.linspace(
_log_base(start, theta),
_log_base(end, theta),
inner_dim // n_elem,
dtype=np.float64,
),
)
return torch.tensor(pow_indices * math.pi / 2, dtype=torch.float32)
def generate_freq_grid_pytorch(positional_embedding_theta, positional_embedding_max_pos_count, inner_dim, device):
theta = positional_embedding_theta
start = 1
end = theta
n_elem = 2 * positional_embedding_max_pos_count
indices = theta ** (
torch.linspace(
math.log(start, theta),
math.log(end, theta),
inner_dim // n_elem,
device=device,
dtype=torch.float32,
)
)
indices = indices.to(dtype=torch.float32)
indices = indices * math.pi / 2
return indices
def generate_freqs(indices, indices_grid, max_pos, use_middle_indices_grid):
if use_middle_indices_grid:
assert len(indices_grid.shape) == 4 and indices_grid.shape[-1] == 2
indices_grid_start, indices_grid_end = indices_grid[..., 0], indices_grid[..., 1]
indices_grid = (indices_grid_start + indices_grid_end) / 2.0
elif len(indices_grid.shape) == 4:
indices_grid = indices_grid[..., 0]
# Get fractional positions and compute frequency indices
fractional_positions = get_fractional_positions(indices_grid, max_pos)
indices = indices.to(device=fractional_positions.device)
freqs = (
(indices * (fractional_positions.unsqueeze(-1) * 2 - 1))
.transpose(-1, -2)
.flatten(2)
)
return freqs
def interleaved_freqs_cis(freqs, pad_size):
cos_freq = freqs.cos().repeat_interleave(2, dim=-1)
sin_freq = freqs.sin().repeat_interleave(2, dim=-1)
if pad_size != 0:
cos_padding = torch.ones_like(cos_freq[:, :, :pad_size])
sin_padding = torch.zeros_like(cos_freq[:, :, :pad_size])
cos_freq = torch.cat([cos_padding, cos_freq], dim=-1)
sin_freq = torch.cat([sin_padding, sin_freq], dim=-1)
return cos_freq, sin_freq
def split_freqs_cis(freqs, pad_size, num_attention_heads):
cos_freq = freqs.cos()
sin_freq = freqs.sin()
if pad_size != 0:
cos_padding = torch.ones_like(cos_freq[:, :, :pad_size])
sin_padding = torch.zeros_like(sin_freq[:, :, :pad_size])
cos_freq = torch.concatenate([cos_padding, cos_freq], axis=-1)
sin_freq = torch.concatenate([sin_padding, sin_freq], axis=-1)
# Reshape freqs to be compatible with multi-head attention
B, T, half_HD = cos_freq.shape
cos_freq = cos_freq.reshape(B, T, num_attention_heads, half_HD // num_attention_heads)
sin_freq = sin_freq.reshape(B, T, num_attention_heads, half_HD // num_attention_heads)
cos_freq = torch.swapaxes(cos_freq, 1, 2)  # (B,H,T,D//2)
sin_freq = torch.swapaxes(sin_freq, 1, 2)  # (B,H,T,D//2)
return cos_freq, sin_freq
class LTXBaseModel(torch.nn.Module, ABC):
"""
Abstract base class for LTX models (Lightricks Transformer models).
This class defines the common interface and shared functionality for all LTX models,
including LTXV (video) and LTXAV (audio-video) variants.
"""
def __init__(
self,
in_channels: int,
cross_attention_dim: int,
attention_head_dim: int,
num_attention_heads: int,
caption_channels: int,
num_layers: int,
positional_embedding_theta: float = 10000.0,
positional_embedding_max_pos: list = [20, 2048, 2048],
causal_temporal_positioning: bool = False,
vae_scale_factors: tuple = (8, 32, 32),
use_middle_indices_grid=False,
timestep_scale_multiplier = 1000.0,
dtype=None,
device=None,
operations=None,
**kwargs,
):
super().__init__()
self.generator = None
self.vae_scale_factors = vae_scale_factors
self.use_middle_indices_grid = use_middle_indices_grid
self.dtype = dtype
self.in_channels = in_channels
self.cross_attention_dim = cross_attention_dim
self.attention_head_dim = attention_head_dim
self.num_attention_heads = num_attention_heads
self.caption_channels = caption_channels
self.num_layers = num_layers
self.positional_embedding_theta = positional_embedding_theta
self.positional_embedding_max_pos = positional_embedding_max_pos
self.split_positional_embedding = LTXRopeType.from_dict(kwargs)
self.freq_grid_generator = (
generate_freq_grid_np if LTXFrequenciesPrecision.from_dict(kwargs) == LTXFrequenciesPrecision.FLOAT64
else generate_freq_grid_pytorch
)
self.causal_temporal_positioning = causal_temporal_positioning
self.operations = operations
self.timestep_scale_multiplier = timestep_scale_multiplier
# Common dimensions
self.inner_dim = num_attention_heads * attention_head_dim
self.out_channels = in_channels
# Initialize common components
self._init_common_components(device, dtype)
# Initialize model-specific components
self._init_model_components(device, dtype, **kwargs)
# Initialize transformer blocks
self._init_transformer_blocks(device, dtype, **kwargs)
# Initialize output components
self._init_output_components(device, dtype)
def _init_common_components(self, device, dtype):
"""Initialize components common to all LTX models
- patchify_proj: Linear projection for patchifying input
- adaln_single: AdaLN layer for timestep embedding
- caption_projection: Linear projection for caption embedding
"""
self.patchify_proj = self.operations.Linear(
self.in_channels, self.inner_dim, bias=True, dtype=dtype, device=device
)
self.adaln_single = AdaLayerNormSingle(
self.inner_dim, use_additional_conditions=False, dtype=dtype, device=device, operations=self.operations
)
self.caption_projection = PixArtAlphaTextProjection(
in_features=self.caption_channels,
hidden_size=self.inner_dim,
dtype=dtype,
device=device,
operations=self.operations,
)
@abstractmethod
def _init_model_components(self, device, dtype, **kwargs):
"""Initialize model-specific components. Must be implemented by subclasses."""
pass
@abstractmethod
def _init_transformer_blocks(self, device, dtype, **kwargs):
"""Initialize transformer blocks. Must be implemented by subclasses."""
pass
@abstractmethod
def _init_output_components(self, device, dtype):
"""Initialize output components. Must be implemented by subclasses."""
pass
@abstractmethod
def _process_input(self, x, keyframe_idxs, denoise_mask, **kwargs):
"""Process input data. Must be implemented by subclasses."""
pass
@abstractmethod
def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, **kwargs):
"""Process transformer blocks. Must be implemented by subclasses."""
pass
@abstractmethod
def _process_output(self, x, embedded_timestep, keyframe_idxs, **kwargs):
"""Process output data. Must be implemented by subclasses."""
pass
def _prepare_timestep(self, timestep, batch_size, hidden_dtype, **kwargs):
"""Prepare timestep embeddings."""
grid_mask = kwargs.get("grid_mask", None)
if grid_mask is not None:
timestep = timestep[:, grid_mask]
timestep = timestep * self.timestep_scale_multiplier
timestep, embedded_timestep = self.adaln_single(
timestep.flatten(),
{"resolution": None, "aspect_ratio": None},
batch_size=batch_size,
hidden_dtype=hidden_dtype,
)
# Second dimension is 1 or number of tokens (if timestep_per_token)
timestep = timestep.view(batch_size, -1, timestep.shape[-1])
embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.shape[-1])
return timestep, embedded_timestep
def _prepare_context(self, context, batch_size, x, attention_mask=None):
"""Prepare context for transformer blocks."""
if self.caption_projection is not None:
context = self.caption_projection(context)
context = context.view(batch_size, -1, x.shape[-1])
return context, attention_mask
def _precompute_freqs_cis(
self,
indices_grid,
dim,
out_dtype,
theta=10000.0,
max_pos=[20, 2048, 2048],
use_middle_indices_grid=False,
num_attention_heads=32,
):
split_mode = self.split_positional_embedding == LTXRopeType.SPLIT
indices = self.freq_grid_generator(theta, indices_grid.shape[1], dim, indices_grid.device)
freqs = generate_freqs(indices, indices_grid, max_pos, use_middle_indices_grid)
if split_mode:
expected_freqs = dim // 2
current_freqs = freqs.shape[-1]
pad_size = expected_freqs - current_freqs
cos_freq, sin_freq = split_freqs_cis(freqs, pad_size, num_attention_heads)
else:
# 2 because of cos and sin by 3 for (t, x, y), 1 for temporal only
n_elem = 2 * indices_grid.shape[1]
cos_freq, sin_freq = interleaved_freqs_cis(freqs, dim % n_elem)
return cos_freq.to(out_dtype), sin_freq.to(out_dtype), split_mode
def _prepare_positional_embeddings(self, pixel_coords, frame_rate, x_dtype):
"""Prepare positional embeddings."""
fractional_coords = pixel_coords.to(torch.float32)
fractional_coords[:, 0] = fractional_coords[:, 0] * (1.0 / frame_rate)
pe = self._precompute_freqs_cis(
fractional_coords,
dim=self.inner_dim,
out_dtype=x_dtype,
max_pos=self.positional_embedding_max_pos,
use_middle_indices_grid=self.use_middle_indices_grid,
num_attention_heads=self.num_attention_heads,
)
return pe
def _prepare_attention_mask(self, attention_mask, x_dtype):
"""Prepare attention mask."""
if attention_mask is not None and not torch.is_floating_point(attention_mask):
attention_mask = (attention_mask - 1).to(x_dtype).reshape(
(attention_mask.shape[0], 1, -1, attention_mask.shape[-1])
) * torch.finfo(x_dtype).max
return attention_mask
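# Editor's note (hedged, not part of the upstream diff): for a boolean/int key mask such as
# [1, 1, 0], the conversion above yields the additive bias [0, 0, -finfo(dtype).max], which
# suppresses attention to the masked key while leaving unmasked keys unchanged.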
def forward(
self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, denoise_mask=None, **kwargs
):
"""
Forward pass for LTX models.
Args:
x: Input tensor
timestep: Timestep tensor
context: Context tensor (e.g., text embeddings)
attention_mask: Attention mask tensor
frame_rate: Frame rate for temporal processing
transformer_options: Additional options for transformer blocks
keyframe_idxs: Keyframe indices for temporal processing
**kwargs: Additional keyword arguments
Returns:
Processed output tensor
"""
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._forward,
self,
comfy.patcher_extension.get_all_wrappers(
comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options
),
).execute(x, timestep, context, attention_mask, frame_rate, transformer_options, keyframe_idxs, denoise_mask=denoise_mask, **kwargs)
def _forward(
self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, denoise_mask=None, **kwargs
):
"""
Internal forward pass for LTX models.
Args:
x: Input tensor
timestep: Timestep tensor
context: Context tensor (e.g., text embeddings)
attention_mask: Attention mask tensor
frame_rate: Frame rate for temporal processing
transformer_options: Additional options for transformer blocks
keyframe_idxs: Keyframe indices for temporal processing
**kwargs: Additional keyword arguments
Returns:
Processed output tensor
"""
if isinstance(x, list):
input_dtype = x[0].dtype
batch_size = x[0].shape[0]
else:
input_dtype = x.dtype
batch_size = x.shape[0]
# Process input
merged_args = {**transformer_options, **kwargs}
x, pixel_coords, additional_args = self._process_input(x, keyframe_idxs, denoise_mask, **merged_args)
merged_args.update(additional_args)
# Prepare timestep and context
timestep, embedded_timestep = self._prepare_timestep(timestep, batch_size, input_dtype, **merged_args)
context, attention_mask = self._prepare_context(context, batch_size, x, attention_mask)
# Prepare attention mask and positional embeddings
attention_mask = self._prepare_attention_mask(attention_mask, input_dtype)
pe = self._prepare_positional_embeddings(pixel_coords, frame_rate, input_dtype)
# Process transformer blocks
x = self._process_transformer_blocks(
x, context, attention_mask, timestep, pe, transformer_options=transformer_options, **merged_args
)
# Process output
x = self._process_output(x, embedded_timestep, keyframe_idxs, **merged_args)
return x
class LTXVModel(LTXBaseModel):
"""LTXV model for video generation."""
def __init__(
self,
in_channels=128,
cross_attention_dim=2048,
attention_head_dim=64,
num_attention_heads=32,
caption_channels=4096,
num_layers=28,
positional_embedding_theta=10000.0,
positional_embedding_max_pos=[20, 2048, 2048],
causal_temporal_positioning=False,
vae_scale_factors=(8, 32, 32),
use_middle_indices_grid=False,
timestep_scale_multiplier = 1000.0,
dtype=None,
device=None,
operations=None,
**kwargs,
):
super().__init__(
in_channels=in_channels,
cross_attention_dim=cross_attention_dim,
attention_head_dim=attention_head_dim,
num_attention_heads=num_attention_heads,
caption_channels=caption_channels,
num_layers=num_layers,
positional_embedding_theta=positional_embedding_theta,
positional_embedding_max_pos=positional_embedding_max_pos,
causal_temporal_positioning=causal_temporal_positioning,
vae_scale_factors=vae_scale_factors,
use_middle_indices_grid=use_middle_indices_grid,
timestep_scale_multiplier=timestep_scale_multiplier,
dtype=dtype,
device=device,
operations=operations,
**kwargs,
)
def _init_model_components(self, device, dtype, **kwargs):
"""Initialize LTXV-specific components."""
# No additional components needed for LTXV beyond base class
pass
def _init_transformer_blocks(self, device, dtype, **kwargs):
"""Initialize transformer blocks for LTXV."""
self.transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
self.inner_dim,
self.num_attention_heads,
self.attention_head_dim,
context_dim=self.cross_attention_dim,
dtype=dtype,
device=device,
operations=self.operations,
)
for _ in range(self.num_layers)
]
)
def _init_output_components(self, device, dtype):
"""Initialize output components for LTXV."""
self.scale_shift_table = nn.Parameter(torch.empty(2, self.inner_dim, dtype=dtype, device=device))
self.norm_out = self.operations.LayerNorm(
self.inner_dim, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device
)
self.proj_out = self.operations.Linear(self.inner_dim, self.out_channels, dtype=dtype, device=device)
self.patchifier = SymmetricPatchifier(1, start_end=True)
def _process_input(self, x, keyframe_idxs, denoise_mask, **kwargs):
"""Process input for LTXV."""
additional_args = {"orig_shape": list(x.shape)}
x, latent_coords = self.patchifier.patchify(x)
pixel_coords = latent_to_pixel_coords(
latent_coords=latent_coords,
@ -423,44 +880,30 @@ class LTXVModel(torch.nn.Module):
causal_fix=self.causal_temporal_positioning,
)
grid_mask = None
if keyframe_idxs is not None:
additional_args.update({"orig_patchified_shape": list(x.shape)})
denoise_mask = self.patchifier.patchify(denoise_mask)[0]
grid_mask = ~torch.any(denoise_mask < 0, dim=-1)[0]
additional_args.update({"grid_mask": grid_mask})
x = x[:, grid_mask, :]
pixel_coords = pixel_coords[:, :, grid_mask, ...]
kf_grid_mask = grid_mask[-keyframe_idxs.shape[2]:]
keyframe_idxs = keyframe_idxs[..., kf_grid_mask, :]
pixel_coords[:, :, -keyframe_idxs.shape[2]:, :] = keyframe_idxs
x = self.patchify_proj(x)
return x, pixel_coords, additional_args
def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, transformer_options={}, **kwargs):
"""Process transformer blocks for LTXV."""
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
for i, block in enumerate(self.transformer_blocks):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"], transformer_options=args["transformer_options"])
@ -478,16 +921,28 @@ class LTXVModel(torch.nn.Module):
transformer_options=transformer_options,
)
return x
def _process_output(self, x, embedded_timestep, keyframe_idxs, **kwargs):
"""Process output for LTXV."""
# Apply scale-shift modulation
scale_shift_values = (
self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + embedded_timestep[:, :, None]
)
shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
x = self.norm_out(x)
x = x * (1 + scale) + shift
x = self.proj_out(x)
if keyframe_idxs is not None:
grid_mask = kwargs["grid_mask"]
orig_patchified_shape = kwargs["orig_patchified_shape"]
full_x = torch.zeros(orig_patchified_shape, dtype=x.dtype, device=x.device)
full_x[:, grid_mask, :] = x
x = full_x
# Unpatchify to restore original dimensions
orig_shape = kwargs["orig_shape"]
x = self.patchifier.unpatchify(
latents=x,
output_height=orig_shape[3],

View File

@ -21,20 +21,23 @@ def latent_to_pixel_coords(
Returns:
Tensor: A tensor of pixel coordinates corresponding to the input latent coordinates.
"""
shape = [1] * latent_coords.ndim
shape[1] = -1
pixel_coords = (
latent_coords
* torch.tensor(scale_factors, device=latent_coords.device).view(*shape)
)
if causal_fix:
# Fix temporal scale for first frame to 1 due to causality
pixel_coords[:, 0, ...] = (pixel_coords[:, 0, ...] + 1 - scale_factors[0]).clamp(min=0)
return pixel_coords
class Patchifier(ABC):
def __init__(self, patch_size: int, start_end: bool = False):
super().__init__()
self._patch_size = (1, patch_size, patch_size)
self.start_end = start_end
@abstractmethod
def patchify(
@ -71,11 +74,23 @@ class Patchifier(ABC):
torch.arange(0, latent_width, self._patch_size[2], device=device),
indexing="ij",
)
latent_sample_coords_start = torch.stack(latent_sample_coords, dim=0)
delta = torch.tensor(self._patch_size, device=latent_sample_coords_start.device, dtype=latent_sample_coords_start.dtype)[:, None, None, None]
latent_sample_coords_end = latent_sample_coords_start + delta
latent_sample_coords_start = latent_sample_coords_start.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
latent_sample_coords_start = rearrange(
latent_sample_coords_start, "b c f h w -> b c (f h w)", b=batch_size
)
if self.start_end:
latent_sample_coords_end = latent_sample_coords_end.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
latent_sample_coords_end = rearrange(
latent_sample_coords_end, "b c f h w -> b c (f h w)", b=batch_size
)
latent_coords = torch.stack((latent_sample_coords_start, latent_sample_coords_end), dim=-1)
else:
latent_coords = latent_sample_coords_start
return latent_coords
@ -115,3 +130,61 @@ class SymmetricPatchifier(Patchifier):
q=self._patch_size[2],
)
return latents
class AudioPatchifier(Patchifier):
def __init__(self, patch_size: int,
sample_rate=16000,
hop_length=160,
audio_latent_downsample_factor=4,
is_causal=True,
start_end=False,
shift = 0
):
super().__init__(patch_size, start_end=start_end)
self.hop_length = hop_length
self.sample_rate = sample_rate
self.audio_latent_downsample_factor = audio_latent_downsample_factor
self.is_causal = is_causal
self.shift = shift
def copy_with_shift(self, shift):
return AudioPatchifier(
self.patch_size, self.sample_rate, self.hop_length, self.audio_latent_downsample_factor,
self.is_causal, self.start_end, shift
)
def _get_audio_latent_time_in_sec(self, start_latent, end_latent: int, dtype: torch.dtype, device=torch.device):
audio_latent_frame = torch.arange(start_latent, end_latent, dtype=dtype, device=device)
audio_mel_frame = audio_latent_frame * self.audio_latent_downsample_factor
if self.is_causal:
audio_mel_frame = (audio_mel_frame + 1 - self.audio_latent_downsample_factor).clip(min=0)
return audio_mel_frame * self.hop_length / self.sample_rate
def patchify(self, audio_latents: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
# audio_latents: (batch, channels, time, freq)
b, _, t, _ = audio_latents.shape
audio_latents = rearrange(
audio_latents,
"b c t f -> b t (c f)",
)
audio_latents_start_timings = self._get_audio_latent_time_in_sec(self.shift, t + self.shift, torch.float32, audio_latents.device)
audio_latents_start_timings = audio_latents_start_timings.unsqueeze(0).expand(b, -1).unsqueeze(1)
if self.start_end:
audio_latents_end_timings = self._get_audio_latent_time_in_sec(self.shift + 1, t + self.shift + 1, torch.float32, audio_latents.device)
audio_latents_end_timings = audio_latents_end_timings.unsqueeze(0).expand(b, -1).unsqueeze(1)
audio_latents_timings = torch.stack([audio_latents_start_timings, audio_latents_end_timings], dim=-1)
else:
audio_latents_timings = audio_latents_start_timings
return audio_latents, audio_latents_timings
def unpatchify(self, audio_latents: torch.Tensor, channels: int, freq: int) -> torch.Tensor:
# audio_latents: (batch, time, freq * channels)
audio_latents = rearrange(
audio_latents, "b t (c f) -> b c t f", c=channels, f=freq
)
return audio_latents
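# Editor's sketch (not part of the upstream diff): AudioPatchifier flattens (batch, channels, time,
# freq) latents into (batch, time, channels*freq) tokens and reports each token's start time in
# seconds (hop_length / sample_rate per mel frame, times the latent downsample factor). With the
# defaults (16 kHz, hop 160, downsample 4) each latent frame covers 40 ms and frame 0 starts at 0 s:
def _example_audio_patchifier():
    patchifier = AudioPatchifier(patch_size=1)
    latents = torch.randn(2, 8, 10, 16)            # (batch, channels, time, freq)
    tokens, timings = patchifier.patchify(latents)
    assert tokens.shape == (2, 10, 8 * 16)
    assert timings.shape == (2, 1, 10)
    restored = patchifier.unpatchify(tokens, channels=8, freq=16)
    assert restored.shape == latents.shape
    return timings[0, 0, :3]                       # tensor([0.00, 0.01, 0.05]) seconds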

View File

@ -0,0 +1,286 @@
import json
from dataclasses import dataclass
import math
import torch
import torchaudio
import comfy.model_management
import comfy.model_patcher
import comfy.utils as utils
from comfy.ldm.mmaudio.vae.distributions import DiagonalGaussianDistribution
from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier
from comfy.ldm.lightricks.vae.causal_audio_autoencoder import (
CausalityAxis,
CausalAudioAutoencoder,
)
from comfy.ldm.lightricks.vocoders.vocoder import Vocoder
LATENT_DOWNSAMPLE_FACTOR = 4
@dataclass(frozen=True)
class AudioVAEComponentConfig:
"""Container for model component configuration extracted from metadata."""
autoencoder: dict
vocoder: dict
@classmethod
def from_metadata(cls, metadata: dict) -> "AudioVAEComponentConfig":
assert metadata is not None and "config" in metadata, "Metadata is required for audio VAE"
raw_config = metadata["config"]
if isinstance(raw_config, str):
parsed_config = json.loads(raw_config)
else:
parsed_config = raw_config
audio_config = parsed_config.get("audio_vae")
vocoder_config = parsed_config.get("vocoder")
assert audio_config is not None, "Audio VAE config is required for audio VAE"
assert vocoder_config is not None, "Vocoder config is required for audio VAE"
return cls(autoencoder=audio_config, vocoder=vocoder_config)
class ModelDeviceManager:
"""Manages device placement and GPU residency for the composed model."""
def __init__(self, module: torch.nn.Module):
load_device = comfy.model_management.get_torch_device()
offload_device = comfy.model_management.vae_offload_device()
self.patcher = comfy.model_patcher.ModelPatcher(module, load_device, offload_device)
def ensure_model_loaded(self) -> None:
comfy.model_management.free_memory(
self.patcher.model_size(),
self.patcher.load_device,
)
comfy.model_management.load_model_gpu(self.patcher)
def move_to_load_device(self, tensor: torch.Tensor) -> torch.Tensor:
return tensor.to(self.patcher.load_device)
@property
def load_device(self):
return self.patcher.load_device
class AudioLatentNormalizer:
"""Applies per-channel statistics in patch space and restores original layout."""
def __init__(self, patchifier: AudioPatchifier, statistics_processor: torch.nn.Module):
self.patchifier = patchifier
self.statistics = statistics_processor
def normalize(self, latents: torch.Tensor) -> torch.Tensor:
channels = latents.shape[1]
freq = latents.shape[3]
patched, _ = self.patchifier.patchify(latents)
normalized = self.statistics.normalize(patched)
return self.patchifier.unpatchify(normalized, channels=channels, freq=freq)
def denormalize(self, latents: torch.Tensor) -> torch.Tensor:
channels = latents.shape[1]
freq = latents.shape[3]
patched, _ = self.patchifier.patchify(latents)
denormalized = self.statistics.un_normalize(patched)
return self.patchifier.unpatchify(denormalized, channels=channels, freq=freq)
class AudioPreprocessor:
"""Prepares raw waveforms for the autoencoder by matching training conditions."""
def __init__(self, target_sample_rate: int, mel_bins: int, mel_hop_length: int, n_fft: int):
self.target_sample_rate = target_sample_rate
self.mel_bins = mel_bins
self.mel_hop_length = mel_hop_length
self.n_fft = n_fft
def resample(self, waveform: torch.Tensor, source_rate: int) -> torch.Tensor:
if source_rate == self.target_sample_rate:
return waveform
return torchaudio.functional.resample(waveform, source_rate, self.target_sample_rate)
@staticmethod
def normalize_amplitude(
waveform: torch.Tensor, max_amplitude: float = 0.5, eps: float = 1e-5
) -> torch.Tensor:
waveform = waveform - waveform.mean(dim=2, keepdim=True)
peak = torch.max(torch.abs(waveform)) + eps
scale = peak.clamp(max=max_amplitude) / peak
return waveform * scale
def waveform_to_mel(
self, waveform: torch.Tensor, waveform_sample_rate: int, device
) -> torch.Tensor:
waveform = self.resample(waveform, waveform_sample_rate)
waveform = self.normalize_amplitude(waveform)
mel_transform = torchaudio.transforms.MelSpectrogram(
sample_rate=self.target_sample_rate,
n_fft=self.n_fft,
win_length=self.n_fft,
hop_length=self.mel_hop_length,
f_min=0.0,
f_max=self.target_sample_rate / 2.0,
n_mels=self.mel_bins,
window_fn=torch.hann_window,
center=True,
pad_mode="reflect",
power=1.0,
mel_scale="slaney",
norm="slaney",
).to(device)
mel = mel_transform(waveform)
mel = torch.log(torch.clamp(mel, min=1e-5))
return mel.permute(0, 1, 3, 2).contiguous()
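# Editorial sketch, not part of the diff: turning a 48 kHz stereo clip into the
# log-mel layout the autoencoder expects. The parameter values mirror the defaults
# guessed in CausalAudioAutoencoder._guess_config; real checkpoints supply their own.
example_preprocessor = AudioPreprocessor(target_sample_rate=16000, mel_bins=64, mel_hop_length=160, n_fft=1024)
example_waveform = torch.randn(1, 2, 48000)  # (batch, channels, samples) at 48 kHz
example_mel = example_preprocessor.waveform_to_mel(example_waveform, 48000, device="cpu")
# example_mel has shape (batch, channels, time_frames, mel_bins) after the final permute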
class AudioVAE(torch.nn.Module):
"""High-level Audio VAE wrapper exposing encode and decode entry points."""
def __init__(self, state_dict: dict, metadata: dict):
super().__init__()
component_config = AudioVAEComponentConfig.from_metadata(metadata)
vae_sd = utils.state_dict_prefix_replace(state_dict, {"audio_vae.": ""}, filter_keys=True)
vocoder_sd = utils.state_dict_prefix_replace(state_dict, {"vocoder.": ""}, filter_keys=True)
self.autoencoder = CausalAudioAutoencoder(config=component_config.autoencoder)
self.vocoder = Vocoder(config=component_config.vocoder)
self.autoencoder.load_state_dict(vae_sd, strict=False)
self.vocoder.load_state_dict(vocoder_sd, strict=False)
autoencoder_config = self.autoencoder.get_config()
self.normalizer = AudioLatentNormalizer(
AudioPatchifier(
patch_size=1,
audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR,
sample_rate=autoencoder_config["sampling_rate"],
hop_length=autoencoder_config["mel_hop_length"],
is_causal=autoencoder_config["is_causal"],
),
self.autoencoder.per_channel_statistics,
)
self.preprocessor = AudioPreprocessor(
target_sample_rate=autoencoder_config["sampling_rate"],
mel_bins=autoencoder_config["mel_bins"],
mel_hop_length=autoencoder_config["mel_hop_length"],
n_fft=autoencoder_config["n_fft"],
)
self.device_manager = ModelDeviceManager(self)
def encode(self, audio: dict) -> torch.Tensor:
"""Encode a waveform dictionary into normalized latent tensors."""
waveform = audio["waveform"]
waveform_sample_rate = audio["sample_rate"]
input_device = waveform.device
# Ensure that Audio VAE is loaded on the correct device.
self.device_manager.ensure_model_loaded()
waveform = self.device_manager.move_to_load_device(waveform)
expected_channels = self.autoencoder.encoder.in_channels
if waveform.shape[1] != expected_channels:
raise ValueError(
f"Input audio must have {expected_channels} channels, got {waveform.shape[1]}"
)
mel_spec = self.preprocessor.waveform_to_mel(
waveform, waveform_sample_rate, device=self.device_manager.load_device
)
latents = self.autoencoder.encode(mel_spec)
posterior = DiagonalGaussianDistribution(latents)
latent_mode = posterior.mode()
normalized = self.normalizer.normalize(latent_mode)
return normalized.to(input_device)
def decode(self, latents: torch.Tensor) -> torch.Tensor:
"""Decode normalized latent tensors into an audio waveform."""
original_shape = latents.shape
# Ensure that Audio VAE is loaded on the correct device.
self.device_manager.ensure_model_loaded()
latents = self.device_manager.move_to_load_device(latents)
latents = self.normalizer.denormalize(latents)
target_shape = self.target_shape_from_latents(original_shape)
mel_spec = self.autoencoder.decode(latents, target_shape=target_shape)
waveform = self.run_vocoder(mel_spec)
return self.device_manager.move_to_load_device(waveform)
def target_shape_from_latents(self, latents_shape):
batch, _, time, _ = latents_shape
target_length = time * LATENT_DOWNSAMPLE_FACTOR
if self.autoencoder.causality_axis != CausalityAxis.NONE:
target_length -= LATENT_DOWNSAMPLE_FACTOR - 1
return (
batch,
self.autoencoder.decoder.out_ch,
target_length,
self.autoencoder.mel_bins,
)
def num_of_latents_from_frames(self, frames_number: int, frame_rate: int) -> int:
return math.ceil((float(frames_number) / frame_rate) * self.latents_per_second)
def run_vocoder(self, mel_spec: torch.Tensor) -> torch.Tensor:
audio_channels = self.autoencoder.decoder.out_ch
vocoder_input = mel_spec.transpose(2, 3)
if audio_channels == 1:
vocoder_input = vocoder_input.squeeze(1)
elif audio_channels != 2:
raise ValueError(f"Unsupported audio_channels: {audio_channels}")
return self.vocoder(vocoder_input)
@property
def sample_rate(self) -> int:
return int(self.autoencoder.sampling_rate)
@property
def mel_hop_length(self) -> int:
return int(self.autoencoder.mel_hop_length)
@property
def mel_bins(self) -> int:
return int(self.autoencoder.mel_bins)
@property
def latent_channels(self) -> int:
return int(self.autoencoder.decoder.z_channels)
@property
def latent_frequency_bins(self) -> int:
return int(self.mel_bins // LATENT_DOWNSAMPLE_FACTOR)
@property
def latents_per_second(self) -> float:
return self.sample_rate / self.mel_hop_length / LATENT_DOWNSAMPLE_FACTOR
@property
def output_sample_rate(self) -> int:
output_rate = getattr(self.vocoder, "output_sample_rate", None)
if output_rate is not None:
return int(output_rate)
upsample_factor = getattr(self.vocoder, "upsample_factor", None)
if upsample_factor is None:
raise AttributeError(
"Vocoder is missing upsample_factor; cannot infer output sample rate"
)
return int(self.sample_rate * upsample_factor / self.mel_hop_length)
def memory_required(self, input_shape):
return self.device_manager.patcher.model_size()
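# Editorial note on the latent bookkeeping above, assuming the default 16 kHz /
# hop-160 configuration: latents_per_second = 16000 / 160 / 4 = 25, so
# num_of_latents_from_frames(121, 24) = ceil((121 / 24) * 25) = 127 audio latents
# for a 121-frame clip at 24 fps.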

View File

@ -0,0 +1,909 @@
from __future__ import annotations
import torch
from torch import nn
from torch.nn import functional as F
from typing import Optional
from enum import Enum
from .pixel_norm import PixelNorm
import comfy.ops
import logging
ops = comfy.ops.disable_weight_init
class StringConvertibleEnum(Enum):
"""
Base enum class that provides string-to-enum conversion functionality.
This mixin adds a str_to_enum() class method that handles conversion from
strings, None, or existing enum instances with case-insensitive matching.
"""
@classmethod
def str_to_enum(cls, value):
"""
Convert a string, enum instance, or None to the appropriate enum member.
Args:
value: Can be an enum instance of this class, a string, or None
Returns:
Enum member of this class
Raises:
ValueError: If the value cannot be converted to a valid enum member
"""
# Already an enum instance of this class
if isinstance(value, cls):
return value
# None maps to NONE member if it exists
if value is None:
if hasattr(cls, "NONE"):
return cls.NONE
raise ValueError(f"{cls.__name__} does not have a NONE member to map None to")
# String conversion (case-insensitive)
if isinstance(value, str):
value_lower = value.lower()
# Try to match against enum values
for member in cls:
# Handle members with None values
if member.value is None:
if value_lower == "none":
return member
# Handle members with string values
elif isinstance(member.value, str) and member.value.lower() == value_lower:
return member
# Build helpful error message with valid values
valid_values = []
for member in cls:
if member.value is None:
valid_values.append("none")
elif isinstance(member.value, str):
valid_values.append(member.value)
raise ValueError(f"Invalid {cls.__name__} string: '{value}'. " f"Valid values are: {valid_values}")
raise ValueError(
f"Cannot convert type {type(value).__name__} to {cls.__name__} enum. "
f"Expected string, None, or {cls.__name__} instance."
)
class AttentionType(StringConvertibleEnum):
"""Enum for specifying the attention mechanism type."""
VANILLA = "vanilla"
LINEAR = "linear"
NONE = "none"
class CausalityAxis(StringConvertibleEnum):
"""Enum for specifying the causality axis in causal convolutions."""
NONE = None
WIDTH = "width"
HEIGHT = "height"
WIDTH_COMPATIBILITY = "width-compatibility"
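# Editorial sketch of the conversion rules above, not part of the diff:
assert CausalityAxis.str_to_enum("Height") is CausalityAxis.HEIGHT  # case-insensitive match
assert CausalityAxis.str_to_enum(None) is CausalityAxis.NONE        # None maps to the NONE member
assert AttentionType.str_to_enum("vanilla") is AttentionType.VANILLA
# CausalityAxis.str_to_enum("depth") raises ValueError listing the valid values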
def Normalize(in_channels, *, num_groups=32, normtype="group"):
if normtype == "group":
return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
elif normtype == "pixel":
return PixelNorm(dim=1, eps=1e-6)
else:
raise ValueError(f"Invalid normalization type: {normtype}")
class CausalConv2d(nn.Module):
"""
A causal 2D convolution.
This layer ensures that the output at time `t` only depends on inputs
at time `t` and earlier. It achieves this by applying asymmetric padding
along the configured causality axis (width or height, whichever carries the time dimension) before the convolution.
"""
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
dilation=1,
groups=1,
bias=True,
causality_axis: CausalityAxis = CausalityAxis.HEIGHT,
):
super().__init__()
self.causality_axis = causality_axis
# Ensure kernel_size and dilation are tuples
kernel_size = nn.modules.utils._pair(kernel_size)
dilation = nn.modules.utils._pair(dilation)
# Calculate padding dimensions
pad_h = (kernel_size[0] - 1) * dilation[0]
pad_w = (kernel_size[1] - 1) * dilation[1]
# The padding tuple for F.pad is (pad_left, pad_right, pad_top, pad_bottom)
match self.causality_axis:
case CausalityAxis.NONE:
self.padding = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)
case CausalityAxis.WIDTH | CausalityAxis.WIDTH_COMPATIBILITY:
self.padding = (pad_w, 0, pad_h // 2, pad_h - pad_h // 2)
case CausalityAxis.HEIGHT:
self.padding = (pad_w // 2, pad_w - pad_w // 2, pad_h, 0)
case _:
raise ValueError(f"Invalid causality_axis: {causality_axis}")
# The internal convolution layer uses no padding, as we handle it manually
self.conv = ops.Conv2d(
in_channels,
out_channels,
kernel_size,
stride=stride,
padding=0,
dilation=dilation,
groups=groups,
bias=bias,
)
def forward(self, x):
# Apply causal padding before convolution
x = F.pad(x, self.padding)
return self.conv(x)
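# Editorial sketch, not part of the diff: with HEIGHT causality a 3x3 kernel pads
# two rows at the top and none at the bottom, so each output frame only sees the
# current and earlier time steps. comfy.ops leaves weights uninitialized, so only
# the padding tuple and output shape are meaningful here.
example_causal = CausalConv2d(1, 1, kernel_size=3, causality_axis=CausalityAxis.HEIGHT)
assert example_causal.padding == (1, 1, 2, 0)
assert example_causal(torch.randn(1, 1, 8, 16)).shape == (1, 1, 8, 16)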
def make_conv2d(
in_channels,
out_channels,
kernel_size,
stride=1,
padding=None,
dilation=1,
groups=1,
bias=True,
causality_axis: Optional[CausalityAxis] = None,
):
"""
Create a 2D convolution layer that can be either causal or non-causal.
Args:
in_channels: Number of input channels
out_channels: Number of output channels
kernel_size: Size of the convolution kernel
stride: Convolution stride
padding: Padding (if None, will be calculated based on causal flag)
dilation: Dilation rate
groups: Number of groups for grouped convolution
bias: Whether to use bias
causality_axis: Dimension along which to apply causality.
Returns:
Either a regular Conv2d or CausalConv2d layer
"""
if causality_axis is not None:
# For causal convolution, padding is handled internally by CausalConv2d
return CausalConv2d(in_channels, out_channels, kernel_size, stride, dilation, groups, bias, causality_axis)
else:
# For non-causal convolution, use symmetric padding if not specified
if padding is None:
if isinstance(kernel_size, int):
padding = kernel_size // 2
else:
padding = tuple(k // 2 for k in kernel_size)
return ops.Conv2d(
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation,
groups,
bias,
)
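# Editorial sketch of the two branches above, not part of the diff: the causal
# variant handles padding internally (the wrapped Conv2d gets padding=0), while
# the non-causal one falls back to symmetric padding of kernel_size // 2.
example_causal_conv = make_conv2d(4, 8, kernel_size=3, causality_axis=CausalityAxis.HEIGHT)  # CausalConv2d
example_plain_conv = make_conv2d(4, 8, kernel_size=3)  # ops.Conv2d(..., padding=1)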
class Upsample(nn.Module):
def __init__(self, in_channels, with_conv, causality_axis: CausalityAxis = CausalityAxis.HEIGHT):
super().__init__()
self.with_conv = with_conv
self.causality_axis = causality_axis
if self.with_conv:
self.conv = make_conv2d(in_channels, in_channels, kernel_size=3, stride=1, causality_axis=causality_axis)
def forward(self, x):
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
if self.with_conv:
x = self.conv(x)
# Drop FIRST element in the causal axis to undo encoder's padding, while keeping the length 1 + 2 * n.
# For example, if the input is [0, 1, 2], after interpolation, the output is [0, 0, 1, 1, 2, 2].
# The causal convolution will pad the first element as [-, -, 0, 0, 1, 1, 2, 2],
# So the output elements rely on the following windows:
# 0: [-,-,0]
# 1: [-,0,0]
# 2: [0,0,1]
# 3: [0,1,1]
# 4: [1,1,2]
# 5: [1,2,2]
# Notice that the first and second elements in the output rely only on the first element in the input,
# while all other elements rely on two elements in the input.
# So we can drop the first element to undo the padding (rather than the last element).
# This is a no-op for non-causal convolutions.
match self.causality_axis:
case CausalityAxis.NONE:
pass # x remains unchanged
case CausalityAxis.HEIGHT:
x = x[:, :, 1:, :]
case CausalityAxis.WIDTH:
x = x[:, :, :, 1:]
case CausalityAxis.WIDTH_COMPATIBILITY:
pass # x remains unchanged
case _:
raise ValueError(f"Invalid causality_axis: {self.causality_axis}")
return x
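# Editorial sketch for the comment above, not part of the diff: a length-3 causal
# time axis interpolates to 6 frames and the first, padding-only frame is dropped,
# giving 1 + 2 * (3 - 1) = 5 frames, while the non-causal frequency axis doubles.
example_up = Upsample(4, with_conv=True, causality_axis=CausalityAxis.HEIGHT)
assert example_up(torch.randn(1, 4, 3, 8)).shape == (1, 4, 5, 16)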
class Downsample(nn.Module):
"""
A downsampling layer that can use either a strided convolution
or average pooling. Supports standard and causal padding for the
convolutional mode.
"""
def __init__(self, in_channels, with_conv, causality_axis: CausalityAxis = CausalityAxis.WIDTH):
super().__init__()
self.with_conv = with_conv
self.causality_axis = causality_axis
if self.causality_axis != CausalityAxis.NONE and not self.with_conv:
raise ValueError("causality is only supported when `with_conv=True`.")
if self.with_conv:
# Do time downsampling here
# no asymmetric padding in torch conv, must do it ourselves
self.conv = ops.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
def forward(self, x):
if self.with_conv:
# (pad_left, pad_right, pad_top, pad_bottom)
match self.causality_axis:
case CausalityAxis.NONE:
pad = (0, 1, 0, 1)
case CausalityAxis.WIDTH:
pad = (2, 0, 0, 1)
case CausalityAxis.HEIGHT:
pad = (0, 1, 2, 0)
case CausalityAxis.WIDTH_COMPATIBILITY:
pad = (1, 0, 0, 1)
case _:
raise ValueError(f"Invalid causality_axis: {self.causality_axis}")
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
x = self.conv(x)
else:
# This branch is only taken if with_conv=False, which implies causality_axis is NONE.
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
return x
class ResnetBlock(nn.Module):
def __init__(
self,
*,
in_channels,
out_channels=None,
conv_shortcut=False,
dropout,
temb_channels=512,
norm_type="group",
causality_axis: CausalityAxis = CausalityAxis.HEIGHT,
):
super().__init__()
self.causality_axis = causality_axis
if self.causality_axis != CausalityAxis.NONE and norm_type == "group":
raise ValueError("Causal ResnetBlock with GroupNorm is not supported.")
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut
self.norm1 = Normalize(in_channels, normtype=norm_type)
self.non_linearity = nn.SiLU()
self.conv1 = make_conv2d(in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis)
if temb_channels > 0:
self.temb_proj = ops.Linear(temb_channels, out_channels)
self.norm2 = Normalize(out_channels, normtype=norm_type)
self.dropout = torch.nn.Dropout(dropout)
self.conv2 = make_conv2d(out_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
self.conv_shortcut = make_conv2d(
in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
)
else:
self.nin_shortcut = make_conv2d(
in_channels, out_channels, kernel_size=1, stride=1, causality_axis=causality_axis
)
def forward(self, x, temb):
h = x
h = self.norm1(h)
h = self.non_linearity(h)
h = self.conv1(h)
if temb is not None:
h = h + self.temb_proj(self.non_linearity(temb))[:, :, None, None]
h = self.norm2(h)
h = self.non_linearity(h)
h = self.dropout(h)
h = self.conv2(h)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
x = self.conv_shortcut(x)
else:
x = self.nin_shortcut(x)
return x + h
class AttnBlock(nn.Module):
def __init__(self, in_channels, norm_type="group"):
super().__init__()
self.in_channels = in_channels
self.norm = Normalize(in_channels, normtype=norm_type)
self.q = ops.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.k = ops.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.v = ops.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.proj_out = ops.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
def forward(self, x):
h_ = x
h_ = self.norm(h_)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
# compute attention
b, c, h, w = q.shape
q = q.reshape(b, c, h * w).contiguous()
q = q.permute(0, 2, 1).contiguous() # b,hw,c
k = k.reshape(b, c, h * w).contiguous() # b,c,hw
w_ = torch.bmm(q, k).contiguous() # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
w_ = w_ * (int(c) ** (-0.5))
w_ = torch.nn.functional.softmax(w_, dim=2)
# attend to values
v = v.reshape(b, c, h * w).contiguous()
w_ = w_.permute(0, 2, 1).contiguous() # b,hw,hw (first hw of k, second of q)
h_ = torch.bmm(v, w_).contiguous() # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
h_ = h_.reshape(b, c, h, w).contiguous()
h_ = self.proj_out(h_)
return x + h_
def make_attn(in_channels, attn_type="vanilla", norm_type="group"):
# Convert string to enum if needed
attn_type = AttentionType.str_to_enum(attn_type)
if attn_type != AttentionType.NONE:
logging.info(f"making attention of type '{attn_type.value}' with {in_channels} in_channels")
else:
logging.info(f"making identity attention with {in_channels} in_channels")
match attn_type:
case AttentionType.VANILLA:
return AttnBlock(in_channels, norm_type=norm_type)
case AttentionType.NONE:
return nn.Identity(in_channels)
case AttentionType.LINEAR:
raise NotImplementedError(f"Attention type {attn_type.value} is not supported yet.")
case _:
raise ValueError(f"Unknown attention type: {attn_type}")
class Encoder(nn.Module):
def __init__(
self,
*,
ch,
out_ch,
ch_mult=(1, 2, 4, 8),
num_res_blocks,
attn_resolutions,
dropout=0.0,
resamp_with_conv=True,
in_channels,
resolution,
z_channels,
double_z=True,
attn_type="vanilla",
mid_block_add_attention=True,
norm_type="group",
causality_axis=CausalityAxis.WIDTH.value,
**ignore_kwargs,
):
super().__init__()
self.ch = ch
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
self.z_channels = z_channels
self.double_z = double_z
self.norm_type = norm_type
# Convert string to enum if needed (for config loading)
causality_axis = CausalityAxis.str_to_enum(causality_axis)
self.attn_type = AttentionType.str_to_enum(attn_type)
# downsampling
self.conv_in = make_conv2d(
in_channels,
self.ch,
kernel_size=3,
stride=1,
causality_axis=causality_axis,
)
self.non_linearity = nn.SiLU()
curr_res = resolution
in_ch_mult = (1,) + tuple(ch_mult)
self.in_ch_mult = in_ch_mult
self.down = nn.ModuleList()
for i_level in range(self.num_resolutions):
block = nn.ModuleList()
attn = nn.ModuleList()
block_in = ch * in_ch_mult[i_level]
block_out = ch * ch_mult[i_level]
for _ in range(self.num_res_blocks):
block.append(
ResnetBlock(
in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout,
norm_type=self.norm_type,
causality_axis=causality_axis,
)
)
block_in = block_out
if curr_res in attn_resolutions:
attn.append(make_attn(block_in, attn_type=self.attn_type, norm_type=self.norm_type))
down = nn.Module()
down.block = block
down.attn = attn
if i_level != self.num_resolutions - 1:
down.downsample = Downsample(block_in, resamp_with_conv, causality_axis=causality_axis)
curr_res = curr_res // 2
self.down.append(down)
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
norm_type=self.norm_type,
causality_axis=causality_axis,
)
if mid_block_add_attention:
self.mid.attn_1 = make_attn(block_in, attn_type=self.attn_type, norm_type=self.norm_type)
else:
self.mid.attn_1 = nn.Identity()
self.mid.block_2 = ResnetBlock(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
norm_type=self.norm_type,
causality_axis=causality_axis,
)
# end
self.norm_out = Normalize(block_in, normtype=self.norm_type)
self.conv_out = make_conv2d(
block_in,
2 * z_channels if double_z else z_channels,
kernel_size=3,
stride=1,
causality_axis=causality_axis,
)
def forward(self, x):
"""
Forward pass through the encoder.
Args:
x: Input tensor of shape [batch, channels, time, n_mels]
Returns:
Encoded latent representation
"""
feature_maps = [self.conv_in(x)]
# Process each resolution level (from high to low resolution)
for resolution_level in range(self.num_resolutions):
# Apply residual blocks at current resolution level
for block_idx in range(self.num_res_blocks):
# Apply ResNet block with optional timestep embedding
current_features = self.down[resolution_level].block[block_idx](feature_maps[-1], temb=None)
# Apply attention if configured for this resolution level
if len(self.down[resolution_level].attn) > 0:
current_features = self.down[resolution_level].attn[block_idx](current_features)
# Store processed features
feature_maps.append(current_features)
# Downsample spatial dimensions (except at the final resolution level)
if resolution_level != self.num_resolutions - 1:
downsampled_features = self.down[resolution_level].downsample(feature_maps[-1])
feature_maps.append(downsampled_features)
# === MIDDLE PROCESSING PHASE ===
# Take the lowest resolution features for middle processing
bottleneck_features = feature_maps[-1]
# Apply first middle ResNet block
bottleneck_features = self.mid.block_1(bottleneck_features, temb=None)
# Apply middle attention block
bottleneck_features = self.mid.attn_1(bottleneck_features)
# Apply second middle ResNet block
bottleneck_features = self.mid.block_2(bottleneck_features, temb=None)
# === OUTPUT PHASE ===
# Normalize the bottleneck features
output_features = self.norm_out(bottleneck_features)
# Apply non-linearity (SiLU activation)
output_features = self.non_linearity(output_features)
# Final convolution to produce latent representation
# [batch, channels, time, n_mels] -> [batch, 2 * z_channels if double_z else z_channels, time, n_mels]
return self.conv_out(output_features)
class Decoder(nn.Module):
def __init__(
self,
*,
ch,
out_ch,
ch_mult=(1, 2, 4, 8),
num_res_blocks,
attn_resolutions,
dropout=0.0,
resamp_with_conv=True,
in_channels,
resolution,
z_channels,
give_pre_end=False,
tanh_out=False,
attn_type="vanilla",
mid_block_add_attention=True,
norm_type="group",
causality_axis=CausalityAxis.WIDTH.value,
**ignorekwargs,
):
super().__init__()
self.ch = ch
self.ch_mult = ch_mult
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
self.out_ch = out_ch
self.give_pre_end = give_pre_end
self.tanh_out = tanh_out
self.norm_type = norm_type
self.z_channels = z_channels
# Convert string to enum if needed (for config loading)
causality_axis = CausalityAxis.str_to_enum(causality_axis)
self.attn_type = AttentionType.str_to_enum(attn_type)
# compute block_in and curr_res at lowest res
block_in = ch * ch_mult[self.num_resolutions - 1]
curr_res = resolution // 2 ** (self.num_resolutions - 1)
self.z_shape = (1, z_channels, curr_res, curr_res)
# z to block_in
self.conv_in = make_conv2d(z_channels, block_in, kernel_size=3, stride=1, causality_axis=causality_axis)
self.non_linearity = nn.SiLU()
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
norm_type=self.norm_type,
causality_axis=causality_axis,
)
if mid_block_add_attention:
self.mid.attn_1 = make_attn(block_in, attn_type=self.attn_type, norm_type=self.norm_type)
else:
self.mid.attn_1 = nn.Identity()
self.mid.block_2 = ResnetBlock(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
norm_type=self.norm_type,
causality_axis=causality_axis,
)
# upsampling
self.up = nn.ModuleList()
for i_level in reversed(range(self.num_resolutions)):
block = nn.ModuleList()
attn = nn.ModuleList()
block_out = ch * ch_mult[i_level]
for _ in range(self.num_res_blocks + 1):
block.append(
ResnetBlock(
in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout,
norm_type=self.norm_type,
causality_axis=causality_axis,
)
)
block_in = block_out
if curr_res in attn_resolutions:
attn.append(make_attn(block_in, attn_type=self.attn_type, norm_type=self.norm_type))
up = nn.Module()
up.block = block
up.attn = attn
if i_level != 0:
up.upsample = Upsample(block_in, resamp_with_conv, causality_axis=causality_axis)
curr_res = curr_res * 2
self.up.insert(0, up) # prepend to get consistent order
# end
self.norm_out = Normalize(block_in, normtype=self.norm_type)
self.conv_out = make_conv2d(block_in, out_ch, kernel_size=3, stride=1, causality_axis=causality_axis)
def _adjust_output_shape(self, decoded_output, target_shape):
"""
Adjust output shape to match target dimensions for variable-length audio.
This function handles the common case where decoded audio spectrograms need to be
resized to match a specific target shape.
Args:
decoded_output: Tensor of shape (batch, channels, time, frequency)
target_shape: Target shape tuple (batch, channels, time, frequency)
Returns:
Tensor adjusted to match target_shape exactly
"""
# Current output shape: (batch, channels, time, frequency)
_, _, current_time, current_freq = decoded_output.shape
_, target_channels, target_time, target_freq = target_shape
# Step 1: Crop first to avoid exceeding target dimensions
decoded_output = decoded_output[
:, :target_channels, : min(current_time, target_time), : min(current_freq, target_freq)
]
# Step 2: Calculate padding needed for time and frequency dimensions
time_padding_needed = target_time - decoded_output.shape[2]
freq_padding_needed = target_freq - decoded_output.shape[3]
# Step 3: Apply padding if needed
if time_padding_needed > 0 or freq_padding_needed > 0:
# PyTorch padding format: (pad_left, pad_right, pad_top, pad_bottom)
# For audio: pad_left/right = frequency, pad_top/bottom = time
padding = (
0,
max(freq_padding_needed, 0), # frequency padding (left, right)
0,
max(time_padding_needed, 0), # time padding (top, bottom)
)
decoded_output = F.pad(decoded_output, padding)
# Step 4: Final safety crop to ensure exact target shape
decoded_output = decoded_output[:, :target_channels, :target_time, :target_freq]
return decoded_output
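# Editorial note, worked example of the crop-then-pad logic above: with target
# shape (1, 2, 97, 64), a decoded (1, 2, 101, 64) is cropped to 97 time frames,
# while a decoded (1, 2, 95, 64) is zero-padded by 2 frames; both land exactly
# on the target shape.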
def get_config(self):
return {
"ch": self.ch,
"out_ch": self.out_ch,
"ch_mult": self.ch_mult,
"num_res_blocks": self.num_res_blocks,
"in_channels": self.in_channels,
"resolution": self.resolution,
"z_channels": self.z_channels,
}
def forward(self, latent_features, target_shape=None):
"""
Decode latent features back to audio spectrograms.
Args:
latent_features: Encoded latent representation of shape (batch, channels, height, width)
target_shape: Target output shape (batch, channels, time, frequency).
Required by this decoder; the output is cropped/padded to match it exactly
Returns:
Reconstructed audio spectrogram of shape (batch, channels, time, frequency)
"""
assert target_shape is not None, "Target shape is required for CausalAudioAutoencoder Decoder"
# Transform latent features to decoder's internal feature dimension
hidden_features = self.conv_in(latent_features)
# Middle processing
hidden_features = self.mid.block_1(hidden_features, temb=None)
hidden_features = self.mid.attn_1(hidden_features)
hidden_features = self.mid.block_2(hidden_features, temb=None)
# Upsampling
# Progressively increase spatial resolution from lowest to highest
for resolution_level in reversed(range(self.num_resolutions)):
# Apply residual blocks at current resolution level
for block_index in range(self.num_res_blocks + 1):
hidden_features = self.up[resolution_level].block[block_index](hidden_features, temb=None)
if len(self.up[resolution_level].attn) > 0:
hidden_features = self.up[resolution_level].attn[block_index](hidden_features)
if resolution_level != 0:
hidden_features = self.up[resolution_level].upsample(hidden_features)
# Output
if self.give_pre_end:
# Return intermediate features before final processing (for debugging/analysis)
decoded_output = hidden_features
else:
# Standard output path: normalize, activate, and convert to output channels
# Final normalization layer
hidden_features = self.norm_out(hidden_features)
# Apply SiLU (Swish) activation function
hidden_features = self.non_linearity(hidden_features)
# Final convolution to map to output channels (typically 2 for stereo audio)
decoded_output = self.conv_out(hidden_features)
# Optional tanh activation to bound output values to [-1, 1] range
if self.tanh_out:
decoded_output = torch.tanh(decoded_output)
# Adjust shape for audio data
if target_shape is not None:
decoded_output = self._adjust_output_shape(decoded_output, target_shape)
return decoded_output
class processor(nn.Module):
def __init__(self):
super().__init__()
self.register_buffer("std-of-means", torch.empty(128))
self.register_buffer("mean-of-means", torch.empty(128))
def un_normalize(self, x):
return (x * self.get_buffer("std-of-means").to(x)) + self.get_buffer("mean-of-means").to(x)
def normalize(self, x):
return (x - self.get_buffer("mean-of-means").to(x)) / self.get_buffer("std-of-means").to(x)
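# Editorial sketch, not part of the diff: the statistics round trip with the
# hyphenated buffers filled in by hand; real values come from the checkpoint.
example_stats = processor()
example_stats.get_buffer("std-of-means").fill_(2.0)
example_stats.get_buffer("mean-of-means").fill_(0.5)
example_x = torch.randn(1, 10, 128)
assert torch.allclose(example_stats.un_normalize(example_stats.normalize(example_x)), example_x, atol=1e-5)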
class CausalAudioAutoencoder(nn.Module):
def __init__(self, config=None):
super().__init__()
if config is None:
config = self._guess_config()
# Extract encoder and decoder configs from the new format
model_config = config.get("model", {}).get("params", {})
variables_config = config.get("variables", {})
self.sampling_rate = variables_config.get(
"sampling_rate",
model_config.get("sampling_rate", config.get("sampling_rate", 16000)),
)
encoder_config = model_config.get("encoder", model_config.get("ddconfig", {}))
decoder_config = model_config.get("decoder", encoder_config)
# Load mel spectrogram parameters
self.mel_bins = encoder_config.get("mel_bins", 64)
self.mel_hop_length = model_config.get("preprocessing", {}).get("stft", {}).get("hop_length", 160)
self.n_fft = model_config.get("preprocessing", {}).get("stft", {}).get("filter_length", 1024)
# Store causality configuration at VAE level (not just in encoder internals)
causality_axis_value = encoder_config.get("causality_axis", CausalityAxis.WIDTH.value)
self.causality_axis = CausalityAxis.str_to_enum(causality_axis_value)
self.is_causal = self.causality_axis == CausalityAxis.HEIGHT
self.encoder = Encoder(**encoder_config)
self.decoder = Decoder(**decoder_config)
self.per_channel_statistics = processor()
def _guess_config(self):
encoder_config = {
# Required parameters - based on ltx-video-av-1679000 model metadata
"ch": 128,
"out_ch": 8,
"ch_mult": [1, 2, 4], # Based on metadata: [1, 2, 4] not [1, 2, 4, 8]
"num_res_blocks": 2,
"attn_resolutions": [], # Based on metadata: empty list, no attention
"dropout": 0.0,
"resamp_with_conv": True,
"in_channels": 2, # stereo
"resolution": 256,
"z_channels": 8,
"double_z": True,
"attn_type": "vanilla",
"mid_block_add_attention": False, # Based on metadata: false
"norm_type": "pixel",
"causality_axis": "height", # Based on metadata
"mel_bins": 64, # Based on metadata: mel_bins = 64
}
decoder_config = {
# Inherits encoder config, can override specific params
**encoder_config,
"out_ch": 2, # Stereo audio output (2 channels)
"give_pre_end": False,
"tanh_out": False,
}
config = {
"_class_name": "CausalAudioAutoencoder",
"sampling_rate": 16000,
"model": {
"params": {
"encoder": encoder_config,
"decoder": decoder_config,
}
},
}
return config
def get_config(self):
return {
"sampling_rate": self.sampling_rate,
"mel_bins": self.mel_bins,
"mel_hop_length": self.mel_hop_length,
"n_fft": self.n_fft,
"causality_axis": self.causality_axis.value,
"is_causal": self.is_causal,
}
def encode(self, x):
return self.encoder(x)
def decode(self, x, target_shape=None):
return self.decoder(x, target_shape=target_shape)
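# Editorial shape note for the guessed config above: a (1, 2, T, 64) log-mel input
# encodes to roughly (1, 2 * z_channels, ceil(T / 4), 16) because double_z
# concatenates mean and logvar, and decode() maps a (1, 8, t, 16) latent back to
# (1, 2, time, 64) using the target_shape computed by AudioVAE.target_shape_from_latents.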

View File

@ -0,0 +1,213 @@
import torch
import torch.nn.functional as F
import torch.nn as nn
import comfy.ops
import numpy as np
ops = comfy.ops.disable_weight_init
LRELU_SLOPE = 0.1
def get_padding(kernel_size, dilation=1):
return int((kernel_size * dilation - dilation) / 2)
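# Editorial note: for stride-1 convolutions this is "same" padding, so the
# ResBlocks below preserve the time length, e.g.
assert get_padding(3, 1) == 1 and get_padding(7, 3) == 9 and get_padding(11, 5) == 25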
class ResBlock1(torch.nn.Module):
def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
super(ResBlock1, self).__init__()
self.convs1 = nn.ModuleList(
[
ops.Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[0],
padding=get_padding(kernel_size, dilation[0]),
),
ops.Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[1],
padding=get_padding(kernel_size, dilation[1]),
),
ops.Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[2],
padding=get_padding(kernel_size, dilation[2]),
),
]
)
self.convs2 = nn.ModuleList(
[
ops.Conv1d(
channels,
channels,
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1),
),
ops.Conv1d(
channels,
channels,
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1),
),
ops.Conv1d(
channels,
channels,
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1),
),
]
)
def forward(self, x):
for c1, c2 in zip(self.convs1, self.convs2):
xt = F.leaky_relu(x, LRELU_SLOPE)
xt = c1(xt)
xt = F.leaky_relu(xt, LRELU_SLOPE)
xt = c2(xt)
x = xt + x
return x
class ResBlock2(torch.nn.Module):
def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
super(ResBlock2, self).__init__()
self.convs = nn.ModuleList(
[
ops.Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[0],
padding=get_padding(kernel_size, dilation[0]),
),
ops.Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[1],
padding=get_padding(kernel_size, dilation[1]),
),
]
)
def forward(self, x):
for c in self.convs:
xt = F.leaky_relu(x, LRELU_SLOPE)
xt = c(xt)
x = xt + x
return x
class Vocoder(torch.nn.Module):
"""
Vocoder model for synthesizing audio from spectrograms, based on: https://github.com/jik876/hifi-gan.
"""
def __init__(self, config=None):
super(Vocoder, self).__init__()
if config is None:
config = self.get_default_config()
resblock_kernel_sizes = config.get("resblock_kernel_sizes", [3, 7, 11])
upsample_rates = config.get("upsample_rates", [6, 5, 2, 2, 2])
upsample_kernel_sizes = config.get("upsample_kernel_sizes", [16, 15, 8, 4, 4])
resblock_dilation_sizes = config.get("resblock_dilation_sizes", [[1, 3, 5], [1, 3, 5], [1, 3, 5]])
upsample_initial_channel = config.get("upsample_initial_channel", 1024)
stereo = config.get("stereo", True)
resblock = config.get("resblock", "1")
self.output_sample_rate = config.get("output_sample_rate")
self.num_kernels = len(resblock_kernel_sizes)
self.num_upsamples = len(upsample_rates)
in_channels = 128 if stereo else 64
self.conv_pre = ops.Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3)
resblock_class = ResBlock1 if resblock == "1" else ResBlock2
self.ups = nn.ModuleList()
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
self.ups.append(
ops.ConvTranspose1d(
upsample_initial_channel // (2**i),
upsample_initial_channel // (2 ** (i + 1)),
k,
u,
padding=(k - u) // 2,
)
)
self.resblocks = nn.ModuleList()
for i in range(len(self.ups)):
ch = upsample_initial_channel // (2 ** (i + 1))
for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
self.resblocks.append(resblock_class(ch, k, d))
out_channels = 2 if stereo else 1
self.conv_post = ops.Conv1d(ch, out_channels, 7, 1, padding=3)
self.upsample_factor = np.prod([self.ups[i].stride[0] for i in range(len(self.ups))])
def get_default_config(self):
"""Generate default configuration for the vocoder."""
config = {
"resblock_kernel_sizes": [3, 7, 11],
"upsample_rates": [6, 5, 2, 2, 2],
"upsample_kernel_sizes": [16, 15, 8, 4, 4],
"resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
"upsample_initial_channel": 1024,
"stereo": True,
"resblock": "1",
}
return config
def forward(self, x):
"""
Forward pass of the vocoder.
Args:
x: Input spectrogram tensor. Can be:
- 3D: (batch_size, channels, time_steps) for mono
- 4D: (batch_size, 2, channels, time_steps) for stereo
Returns:
Audio tensor of shape (batch_size, out_channels, audio_length)
"""
if x.dim() == 4: # stereo
assert x.shape[1] == 2, "Input must have 2 channels for stereo"
x = torch.cat((x[:, 0, :, :], x[:, 1, :, :]), dim=1)
x = self.conv_pre(x)
for i in range(self.num_upsamples):
x = F.leaky_relu(x, LRELU_SLOPE)
x = self.ups[i](x)
xs = None
for j in range(self.num_kernels):
if xs is None:
xs = self.resblocks[i * self.num_kernels + j](x)
else:
xs += self.resblocks[i * self.num_kernels + j](x)
x = xs / self.num_kernels
x = F.leaky_relu(x)
x = self.conv_post(x)
x = torch.tanh(x)
return x
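# Editorial sketch, not part of the diff: shape check with the default config.
# Stereo input is stacked to 128 channels and upsampled by 6 * 5 * 2 * 2 * 2 = 240x,
# so 100 mel frames become 24000 waveform samples per channel; comfy.ops leaves
# weights uninitialized, so only the shapes are meaningful.
example_vocoder = Vocoder()
example_mel = torch.randn(1, 2, 64, 100)  # (batch, 2, mel_channels, time)
example_audio = example_vocoder(example_mel)  # -> torch.Size([1, 2, 24000])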

View File

@ -41,6 +41,11 @@ class ZImage_Control(torch.nn.Module):
ffn_dim_multiplier: float = (8.0 / 3.0), ffn_dim_multiplier: float = (8.0 / 3.0),
norm_eps: float = 1e-5, norm_eps: float = 1e-5,
qk_norm: bool = True, qk_norm: bool = True,
n_control_layers=6,
control_in_dim=16,
additional_in_dim=0,
broken=False,
refiner_control=False,
dtype=None, dtype=None,
device=None, device=None,
operations=None, operations=None,
@ -49,10 +54,11 @@ class ZImage_Control(torch.nn.Module):
super().__init__() super().__init__()
operation_settings = {"operations": operations, "device": device, "dtype": dtype} operation_settings = {"operations": operations, "device": device, "dtype": dtype}
self.additional_in_dim = 0 self.broken = broken
self.control_in_dim = 16 self.additional_in_dim = additional_in_dim
self.control_in_dim = control_in_dim
n_refiner_layers = 2 n_refiner_layers = 2
self.n_control_layers = 6 self.n_control_layers = n_control_layers
self.control_layers = nn.ModuleList( self.control_layers = nn.ModuleList(
[ [
ZImageControlTransformerBlock( ZImageControlTransformerBlock(
@ -74,28 +80,49 @@ class ZImage_Control(torch.nn.Module):
all_x_embedder = {} all_x_embedder = {}
patch_size = 2 patch_size = 2
f_patch_size = 1 f_patch_size = 1
x_embedder = operations.Linear(f_patch_size * patch_size * patch_size * self.control_in_dim, dim, bias=True, device=device, dtype=dtype) x_embedder = operations.Linear(f_patch_size * patch_size * patch_size * (self.control_in_dim + self.additional_in_dim), dim, bias=True, device=device, dtype=dtype)
all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder
self.refiner_control = refiner_control
self.control_all_x_embedder = nn.ModuleDict(all_x_embedder) self.control_all_x_embedder = nn.ModuleDict(all_x_embedder)
self.control_noise_refiner = nn.ModuleList( if self.refiner_control:
[ self.control_noise_refiner = nn.ModuleList(
JointTransformerBlock( [
layer_id, ZImageControlTransformerBlock(
dim, layer_id,
n_heads, dim,
n_kv_heads, n_heads,
multiple_of, n_kv_heads,
ffn_dim_multiplier, multiple_of,
norm_eps, ffn_dim_multiplier,
qk_norm, norm_eps,
modulation=True, qk_norm,
z_image_modulation=True, block_id=layer_id,
operation_settings=operation_settings, operation_settings=operation_settings,
) )
for layer_id in range(n_refiner_layers) for layer_id in range(n_refiner_layers)
] ]
) )
else:
self.control_noise_refiner = nn.ModuleList(
[
JointTransformerBlock(
layer_id,
dim,
n_heads,
n_kv_heads,
multiple_of,
ffn_dim_multiplier,
norm_eps,
qk_norm,
modulation=True,
z_image_modulation=True,
operation_settings=operation_settings,
)
for layer_id in range(n_refiner_layers)
]
)
def forward(self, cap_feats, control_context, x_freqs_cis, adaln_input): def forward(self, cap_feats, control_context, x_freqs_cis, adaln_input):
patch_size = 2 patch_size = 2
@ -105,9 +132,29 @@ class ZImage_Control(torch.nn.Module):
control_context = self.control_all_x_embedder[f"{patch_size}-{f_patch_size}"](control_context.view(B, C, H // pH, pH, W // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2)) control_context = self.control_all_x_embedder[f"{patch_size}-{f_patch_size}"](control_context.view(B, C, H // pH, pH, W // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2))
x_attn_mask = None x_attn_mask = None
for layer in self.control_noise_refiner: if not self.refiner_control:
control_context = layer(control_context, x_attn_mask, x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input) for layer in self.control_noise_refiner:
control_context = layer(control_context, x_attn_mask, x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input)
return control_context return control_context
def forward_noise_refiner_block(self, layer_id, control_context, x, x_attn_mask, x_freqs_cis, adaln_input):
if self.refiner_control:
if self.broken:
if layer_id == 0:
return self.control_layers[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)
if layer_id > 0:
out = None
for i in range(1, len(self.control_layers)):
o, control_context = self.control_layers[i](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)
if out is None:
out = o
return (out, control_context)
else:
return self.control_noise_refiner[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)
else:
return (None, control_context)
def forward_control_block(self, layer_id, control_context, x, x_attn_mask, x_freqs_cis, adaln_input): def forward_control_block(self, layer_id, control_context, x, x_attn_mask, x_freqs_cis, adaln_input):
return self.control_layers[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input) return self.control_layers[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)

View File

@ -377,6 +377,7 @@ class NextDiT(nn.Module):
z_image_modulation=False, z_image_modulation=False,
time_scale=1.0, time_scale=1.0,
pad_tokens_multiple=None, pad_tokens_multiple=None,
clip_text_dim=None,
image_model=None, image_model=None,
device=None, device=None,
dtype=None, dtype=None,
@ -447,6 +448,31 @@ class NextDiT(nn.Module):
), ),
) )
self.clip_text_pooled_proj = None
if clip_text_dim is not None:
self.clip_text_dim = clip_text_dim
self.clip_text_pooled_proj = nn.Sequential(
operation_settings.get("operations").RMSNorm(clip_text_dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")),
operation_settings.get("operations").Linear(
clip_text_dim,
clip_text_dim,
bias=True,
device=operation_settings.get("device"),
dtype=operation_settings.get("dtype"),
),
)
self.time_text_embed = nn.Sequential(
nn.SiLU(),
operation_settings.get("operations").Linear(
min(dim, 1024) + clip_text_dim,
min(dim, 1024),
bias=True,
device=operation_settings.get("device"),
dtype=operation_settings.get("dtype"),
),
)
self.layers = nn.ModuleList( self.layers = nn.ModuleList(
[ [
JointTransformerBlock( JointTransformerBlock(
@ -465,7 +491,8 @@ class NextDiT(nn.Module):
for layer_id in range(n_layers) for layer_id in range(n_layers)
] ]
) )
self.norm_final = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) # This norm final is in the lumina 2.0 code but isn't actually used for anything.
# self.norm_final = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.final_layer = FinalLayer(dim, patch_size, self.out_channels, z_image_modulation=z_image_modulation, operation_settings=operation_settings) self.final_layer = FinalLayer(dim, patch_size, self.out_channels, z_image_modulation=z_image_modulation, operation_settings=operation_settings)
if self.pad_tokens_multiple is not None: if self.pad_tokens_multiple is not None:
@ -510,6 +537,7 @@ class NextDiT(nn.Module):
bsz = len(x) bsz = len(x)
pH = pW = self.patch_size pH = pW = self.patch_size
device = x[0].device device = x[0].device
orig_x = x
if self.pad_tokens_multiple is not None: if self.pad_tokens_multiple is not None:
pad_extra = (-cap_feats.shape[1]) % self.pad_tokens_multiple pad_extra = (-cap_feats.shape[1]) % self.pad_tokens_multiple
@ -546,13 +574,21 @@ class NextDiT(nn.Module):
freqs_cis = self.rope_embedder(torch.cat((cap_pos_ids, x_pos_ids), dim=1)).movedim(1, 2) freqs_cis = self.rope_embedder(torch.cat((cap_pos_ids, x_pos_ids), dim=1)).movedim(1, 2)
patches = transformer_options.get("patches", {})
# refine context # refine context
for layer in self.context_refiner: for layer in self.context_refiner:
cap_feats = layer(cap_feats, cap_mask, freqs_cis[:, :cap_pos_ids.shape[1]], transformer_options=transformer_options) cap_feats = layer(cap_feats, cap_mask, freqs_cis[:, :cap_pos_ids.shape[1]], transformer_options=transformer_options)
padded_img_mask = None padded_img_mask = None
for layer in self.noise_refiner: x_input = x
for i, layer in enumerate(self.noise_refiner):
x = layer(x, padded_img_mask, freqs_cis[:, cap_pos_ids.shape[1]:], t, transformer_options=transformer_options) x = layer(x, padded_img_mask, freqs_cis[:, cap_pos_ids.shape[1]:], t, transformer_options=transformer_options)
if "noise_refiner" in patches:
for p in patches["noise_refiner"]:
out = p({"img": x, "img_input": x_input, "txt": cap_feats, "pe": freqs_cis[:, cap_pos_ids.shape[1]:], "vec": t, "x": orig_x, "block_index": i, "transformer_options": transformer_options, "block_type": "noise_refiner"})
if "img" in out:
x = out["img"]
padded_full_embed = torch.cat((cap_feats, x), dim=1) padded_full_embed = torch.cat((cap_feats, x), dim=1)
mask = None mask = None
@ -585,16 +621,29 @@ class NextDiT(nn.Module):
cap_feats = self.cap_embedder(cap_feats) # (N, L, D) # todo check if able to batchify w.o. redundant compute cap_feats = self.cap_embedder(cap_feats) # (N, L, D) # todo check if able to batchify w.o. redundant compute
if self.clip_text_pooled_proj is not None:
pooled = kwargs.get("clip_text_pooled", None)
if pooled is not None:
pooled = self.clip_text_pooled_proj(pooled)
else:
pooled = torch.zeros((x.shape[0], self.clip_text_dim), device=x.device, dtype=x.dtype)
adaln_input = self.time_text_embed(torch.cat((t, pooled), dim=-1))
patches = transformer_options.get("patches", {}) patches = transformer_options.get("patches", {})
x_is_tensor = isinstance(x, torch.Tensor) x_is_tensor = isinstance(x, torch.Tensor)
img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens, transformer_options=transformer_options) img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, adaln_input, num_tokens, transformer_options=transformer_options)
freqs_cis = freqs_cis.to(img.device) freqs_cis = freqs_cis.to(img.device)
transformer_options["total_blocks"] = len(self.layers)
transformer_options["block_type"] = "double"
img_input = img
for i, layer in enumerate(self.layers): for i, layer in enumerate(self.layers):
transformer_options["block_index"] = i
img = layer(img, mask, freqs_cis, adaln_input, transformer_options=transformer_options) img = layer(img, mask, freqs_cis, adaln_input, transformer_options=transformer_options)
if "double_block" in patches: if "double_block" in patches:
for p in patches["double_block"]: for p in patches["double_block"]:
out = p({"img": img[:, cap_size[0]:], "txt": img[:, :cap_size[0]], "pe": freqs_cis[:, cap_size[0]:], "vec": adaln_input, "x": x, "block_index": i, "transformer_options": transformer_options}) out = p({"img": img[:, cap_size[0]:], "img_input": img_input[:, cap_size[0]:], "txt": img[:, :cap_size[0]], "pe": freqs_cis[:, cap_size[0]:], "vec": adaln_input, "x": x, "block_index": i, "transformer_options": transformer_options})
if "img" in out: if "img" in out:
img[:, cap_size[0]:] = out["img"] img[:, cap_size[0]:] = out["img"]
if "txt" in out: if "txt" in out:

View File

@ -30,6 +30,13 @@ except ImportError as e:
raise e raise e
exit(-1) exit(-1)
SAGE_ATTENTION3_IS_AVAILABLE = False
try:
from sageattn3 import sageattn3_blackwell
SAGE_ATTENTION3_IS_AVAILABLE = True
except ImportError:
pass
FLASH_ATTENTION_IS_AVAILABLE = False FLASH_ATTENTION_IS_AVAILABLE = False
try: try:
from flash_attn import flash_attn_func from flash_attn import flash_attn_func
@ -563,6 +570,93 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
out = out.reshape(b, -1, heads * dim_head) out = out.reshape(b, -1, heads * dim_head)
return out return out
@wrap_attn
def attention3_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
exception_fallback = False
if (q.device.type != "cuda" or
q.dtype not in (torch.float16, torch.bfloat16) or
mask is not None):
return attention_pytorch(
q, k, v, heads,
mask=mask,
attn_precision=attn_precision,
skip_reshape=skip_reshape,
skip_output_reshape=skip_output_reshape,
**kwargs
)
if skip_reshape:
B, H, L, D = q.shape
if H != heads:
return attention_pytorch(
q, k, v, heads,
mask=mask,
attn_precision=attn_precision,
skip_reshape=True,
skip_output_reshape=skip_output_reshape,
**kwargs
)
q_s, k_s, v_s = q, k, v
N = q.shape[2]
dim_head = D
else:
B, N, inner_dim = q.shape
if inner_dim % heads != 0:
return attention_pytorch(
q, k, v, heads,
mask=mask,
attn_precision=attn_precision,
skip_reshape=False,
skip_output_reshape=skip_output_reshape,
**kwargs
)
dim_head = inner_dim // heads
if dim_head >= 256 or N <= 1024:
return attention_pytorch(
q, k, v, heads,
mask=mask,
attn_precision=attn_precision,
skip_reshape=skip_reshape,
skip_output_reshape=skip_output_reshape,
**kwargs
)
if not skip_reshape:
q_s, k_s, v_s = map(
lambda t: t.view(B, -1, heads, dim_head).permute(0, 2, 1, 3).contiguous(),
(q, k, v),
)
B, H, L, D = q_s.shape
try:
out = sageattn3_blackwell(q_s, k_s, v_s, is_causal=False)
except Exception as e:
exception_fallback = True
logging.error("Error running SageAttention3: %s, falling back to pytorch attention.", e)
if exception_fallback:
if not skip_reshape:
del q_s, k_s, v_s
return attention_pytorch(
q, k, v, heads,
mask=mask,
attn_precision=attn_precision,
skip_reshape=False,
skip_output_reshape=skip_output_reshape,
**kwargs
)
if skip_reshape:
if not skip_output_reshape:
out = out.permute(0, 2, 1, 3).reshape(B, L, H * D)
else:
if skip_output_reshape:
pass
else:
out = out.permute(0, 2, 1, 3).reshape(B, L, H * D)
return out
try: try:
@torch.library.custom_op("flash_attention::flash_attn", mutates_args=()) @torch.library.custom_op("flash_attention::flash_attn", mutates_args=())
@ -650,6 +744,8 @@ optimized_attention_masked = optimized_attention
# register core-supported attention functions # register core-supported attention functions
if SAGE_ATTENTION_IS_AVAILABLE: if SAGE_ATTENTION_IS_AVAILABLE:
register_attention_function("sage", attention_sage) register_attention_function("sage", attention_sage)
if SAGE_ATTENTION3_IS_AVAILABLE:
register_attention_function("sage3", attention3_sage)
if FLASH_ATTENTION_IS_AVAILABLE: if FLASH_ATTENTION_IS_AVAILABLE:
register_attention_function("flash", attention_flash) register_attention_function("flash", attention_flash)
if model_management.xformers_enabled(): if model_management.xformers_enabled():

View File

@ -394,7 +394,8 @@ class Model(nn.Module):
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"): resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"):
super().__init__() super().__init__()
if use_linear_attn: attn_type = "linear" if use_linear_attn:
attn_type = "linear"
self.ch = ch self.ch = ch
self.temb_ch = self.ch*4 self.temb_ch = self.ch*4
self.num_resolutions = len(ch_mult) self.num_resolutions = len(ch_mult)
@ -548,7 +549,8 @@ class Encoder(nn.Module):
conv3d=False, time_compress=None, conv3d=False, time_compress=None,
**ignore_kwargs): **ignore_kwargs):
super().__init__() super().__init__()
if use_linear_attn: attn_type = "linear" if use_linear_attn:
attn_type = "linear"
self.ch = ch self.ch = ch
self.temb_ch = 0 self.temb_ch = 0
self.num_resolutions = len(ch_mult) self.num_resolutions = len(ch_mult)

View File

@ -45,7 +45,7 @@ class LitEma(nn.Module):
shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
else: else:
assert not key in self.m_name2s_name assert key not in self.m_name2s_name
def copy_to(self, model): def copy_to(self, model):
m_param = dict(model.named_parameters()) m_param = dict(model.named_parameters())
@ -54,7 +54,7 @@ class LitEma(nn.Module):
if m_param[key].requires_grad: if m_param[key].requires_grad:
m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
else: else:
assert not key in self.m_name2s_name assert key not in self.m_name2s_name
def store(self, parameters): def store(self, parameters):
""" """

View File

@ -61,7 +61,7 @@ def apply_rotary_emb(x, freqs_cis):
class QwenTimestepProjEmbeddings(nn.Module): class QwenTimestepProjEmbeddings(nn.Module):
def __init__(self, embedding_dim, pooled_projection_dim, dtype=None, device=None, operations=None): def __init__(self, embedding_dim, pooled_projection_dim, use_additional_t_cond=False, dtype=None, device=None, operations=None):
super().__init__() super().__init__()
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1000) self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1000)
self.timestep_embedder = TimestepEmbedding( self.timestep_embedder = TimestepEmbedding(
@ -72,9 +72,19 @@ class QwenTimestepProjEmbeddings(nn.Module):
operations=operations operations=operations
) )
def forward(self, timestep, hidden_states): self.use_additional_t_cond = use_additional_t_cond
if self.use_additional_t_cond:
self.addition_t_embedding = operations.Embedding(2, embedding_dim, device=device, dtype=dtype)
def forward(self, timestep, hidden_states, addition_t_cond=None):
timesteps_proj = self.time_proj(timestep) timesteps_proj = self.time_proj(timestep)
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_states.dtype)) timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_states.dtype))
if self.use_additional_t_cond:
if addition_t_cond is None:
addition_t_cond = torch.zeros((timesteps_emb.shape[0]), device=timesteps_emb.device, dtype=torch.long)
timesteps_emb += self.addition_t_embedding(addition_t_cond, out_dtype=timesteps_emb.dtype)
return timesteps_emb return timesteps_emb
@ -218,9 +228,24 @@ class QwenImageTransformerBlock(nn.Module):
operations=operations, operations=operations,
) )
def _modulate(self, x: torch.Tensor, mod_params: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: def _apply_gate(self, x, y, gate, timestep_zero_index=None):
if timestep_zero_index is not None:
return y + torch.cat((x[:, :timestep_zero_index] * gate[0], x[:, timestep_zero_index:] * gate[1]), dim=1)
else:
return torch.addcmul(y, gate, x)
def _modulate(self, x: torch.Tensor, mod_params: torch.Tensor, timestep_zero_index=None) -> Tuple[torch.Tensor, torch.Tensor]:
shift, scale, gate = torch.chunk(mod_params, 3, dim=-1) shift, scale, gate = torch.chunk(mod_params, 3, dim=-1)
return torch.addcmul(shift.unsqueeze(1), x, 1 + scale.unsqueeze(1)), gate.unsqueeze(1) if timestep_zero_index is not None:
actual_batch = shift.size(0) // 2
shift, shift_0 = shift[:actual_batch], shift[actual_batch:]
scale, scale_0 = scale[:actual_batch], scale[actual_batch:]
gate, gate_0 = gate[:actual_batch], gate[actual_batch:]
reg = torch.addcmul(shift.unsqueeze(1), x[:, :timestep_zero_index], 1 + scale.unsqueeze(1))
zero = torch.addcmul(shift_0.unsqueeze(1), x[:, timestep_zero_index:], 1 + scale_0.unsqueeze(1))
return torch.cat((reg, zero), dim=1), (gate.unsqueeze(1), gate_0.unsqueeze(1))
else:
return torch.addcmul(shift.unsqueeze(1), x, 1 + scale.unsqueeze(1)), gate.unsqueeze(1)
def forward( def forward(
self, self,
@ -229,14 +254,19 @@ class QwenImageTransformerBlock(nn.Module):
encoder_hidden_states_mask: torch.Tensor, encoder_hidden_states_mask: torch.Tensor,
temb: torch.Tensor, temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
timestep_zero_index=None,
transformer_options={}, transformer_options={},
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
img_mod_params = self.img_mod(temb) img_mod_params = self.img_mod(temb)
if timestep_zero_index is not None:
temb = temb.chunk(2, dim=0)[0]
txt_mod_params = self.txt_mod(temb) txt_mod_params = self.txt_mod(temb)
img_mod1, img_mod2 = img_mod_params.chunk(2, dim=-1) img_mod1, img_mod2 = img_mod_params.chunk(2, dim=-1)
txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1) txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1)
img_modulated, img_gate1 = self._modulate(self.img_norm1(hidden_states), img_mod1) img_modulated, img_gate1 = self._modulate(self.img_norm1(hidden_states), img_mod1, timestep_zero_index)
del img_mod1 del img_mod1
txt_modulated, txt_gate1 = self._modulate(self.txt_norm1(encoder_hidden_states), txt_mod1) txt_modulated, txt_gate1 = self._modulate(self.txt_norm1(encoder_hidden_states), txt_mod1)
del txt_mod1 del txt_mod1
@ -251,15 +281,15 @@ class QwenImageTransformerBlock(nn.Module):
del img_modulated del img_modulated
del txt_modulated del txt_modulated
hidden_states = hidden_states + img_gate1 * img_attn_output hidden_states = self._apply_gate(img_attn_output, hidden_states, img_gate1, timestep_zero_index)
encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output
del img_attn_output del img_attn_output
del txt_attn_output del txt_attn_output
del img_gate1 del img_gate1
del txt_gate1 del txt_gate1
img_modulated2, img_gate2 = self._modulate(self.img_norm2(hidden_states), img_mod2) img_modulated2, img_gate2 = self._modulate(self.img_norm2(hidden_states), img_mod2, timestep_zero_index)
hidden_states = torch.addcmul(hidden_states, img_gate2, self.img_mlp(img_modulated2)) hidden_states = self._apply_gate(self.img_mlp(img_modulated2), hidden_states, img_gate2, timestep_zero_index)
txt_modulated2, txt_gate2 = self._modulate(self.txt_norm2(encoder_hidden_states), txt_mod2) txt_modulated2, txt_gate2 = self._modulate(self.txt_norm2(encoder_hidden_states), txt_mod2)
encoder_hidden_states = torch.addcmul(encoder_hidden_states, txt_gate2, self.txt_mlp(txt_modulated2)) encoder_hidden_states = torch.addcmul(encoder_hidden_states, txt_gate2, self.txt_mlp(txt_modulated2))
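Note on the timestep_zero path in the hunks above: the gate is applied separately to the regular tokens and to the reference-latent tokens. A minimal standalone sketch of that split (tensor names and shapes are assumed for illustration, not taken from this diff):

import torch

def apply_split_gate(x, residual, gate_regular, gate_zero, split):
    # Tokens before `split` are gated with the regular-timestep gate,
    # tokens from `split` onward (the reference latents) with the timestep-zero gate.
    gated = torch.cat((x[:, :split] * gate_regular, x[:, split:] * gate_zero), dim=1)
    return residual + gated

x = torch.randn(1, 6, 8)                  # 4 "image" tokens + 2 "reference" tokens
residual = torch.zeros_like(x)
gate_regular = torch.full((1, 1, 8), 0.5)
gate_zero = torch.zeros((1, 1, 8))        # reference tokens contribute nothing here
out = apply_split_gate(x, residual, gate_regular, gate_zero, split=4)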
@ -300,10 +330,11 @@ class QwenImageTransformer2DModel(nn.Module):
num_attention_heads: int = 24, num_attention_heads: int = 24,
joint_attention_dim: int = 3584, joint_attention_dim: int = 3584,
pooled_projection_dim: int = 768, pooled_projection_dim: int = 768,
guidance_embeds: bool = False,
axes_dims_rope: Tuple[int, int, int] = (16, 56, 56), axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
default_ref_method="index",
image_model=None, image_model=None,
final_layer=True, final_layer=True,
use_additional_t_cond=False,
dtype=None, dtype=None,
device=None, device=None,
operations=None, operations=None,
@ -314,12 +345,14 @@ class QwenImageTransformer2DModel(nn.Module):
self.in_channels = in_channels self.in_channels = in_channels
self.out_channels = out_channels or in_channels self.out_channels = out_channels or in_channels
self.inner_dim = num_attention_heads * attention_head_dim self.inner_dim = num_attention_heads * attention_head_dim
self.default_ref_method = default_ref_method
self.pe_embedder = EmbedND(dim=attention_head_dim, theta=10000, axes_dim=list(axes_dims_rope)) self.pe_embedder = EmbedND(dim=attention_head_dim, theta=10000, axes_dim=list(axes_dims_rope))
self.time_text_embed = QwenTimestepProjEmbeddings( self.time_text_embed = QwenTimestepProjEmbeddings(
embedding_dim=self.inner_dim, embedding_dim=self.inner_dim,
pooled_projection_dim=pooled_projection_dim, pooled_projection_dim=pooled_projection_dim,
use_additional_t_cond=use_additional_t_cond,
dtype=dtype, dtype=dtype,
device=device, device=device,
operations=operations operations=operations
@ -341,6 +374,9 @@ class QwenImageTransformer2DModel(nn.Module):
for _ in range(num_layers) for _ in range(num_layers)
]) ])
if self.default_ref_method == "index_timestep_zero":
self.register_buffer("__index_timestep_zero__", torch.tensor([]))
if final_layer: if final_layer:
self.norm_out = LastLayer(self.inner_dim, self.inner_dim, dtype=dtype, device=device, operations=operations) self.norm_out = LastLayer(self.inner_dim, self.inner_dim, dtype=dtype, device=device, operations=operations)
self.proj_out = operations.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True, dtype=dtype, device=device) self.proj_out = operations.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True, dtype=dtype, device=device)
@ -350,27 +386,33 @@ class QwenImageTransformer2DModel(nn.Module):
patch_size = self.patch_size patch_size = self.patch_size
hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (1, self.patch_size, self.patch_size)) hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (1, self.patch_size, self.patch_size))
orig_shape = hidden_states.shape orig_shape = hidden_states.shape
hidden_states = hidden_states.view(orig_shape[0], orig_shape[1], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2) hidden_states = hidden_states.view(orig_shape[0], orig_shape[1], orig_shape[-3], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2)
hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5) hidden_states = hidden_states.permute(0, 2, 3, 5, 1, 4, 6)
hidden_states = hidden_states.reshape(orig_shape[0], (orig_shape[-2] // 2) * (orig_shape[-1] // 2), orig_shape[1] * 4) hidden_states = hidden_states.reshape(orig_shape[0], orig_shape[-3] * (orig_shape[-2] // 2) * (orig_shape[-1] // 2), orig_shape[1] * 4)
t_len = t
h_len = ((h + (patch_size // 2)) // patch_size) h_len = ((h + (patch_size // 2)) // patch_size)
w_len = ((w + (patch_size // 2)) // patch_size) w_len = ((w + (patch_size // 2)) // patch_size)
h_offset = ((h_offset + (patch_size // 2)) // patch_size) h_offset = ((h_offset + (patch_size // 2)) // patch_size)
w_offset = ((w_offset + (patch_size // 2)) // patch_size) w_offset = ((w_offset + (patch_size // 2)) // patch_size)
img_ids = torch.zeros((h_len, w_len, 3), device=x.device) img_ids = torch.zeros((t_len, h_len, w_len, 3), device=x.device)
img_ids[:, :, 0] = img_ids[:, :, 1] + index
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1) - (h_len // 2)
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0) - (w_len // 2)
return hidden_states, repeat(img_ids, "h w c -> b (h w) c", b=bs), orig_shape
def forward(self, x, timestep, context, attention_mask=None, guidance=None, ref_latents=None, transformer_options={}, **kwargs): if t_len > 1:
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).unsqueeze(1).unsqueeze(1)
else:
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + index
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1).unsqueeze(0) - (h_len // 2)
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0).unsqueeze(0) - (w_len // 2)
return hidden_states, repeat(img_ids, "t h w c -> b (t h w) c", b=bs), orig_shape
def forward(self, x, timestep, context, attention_mask=None, ref_latents=None, additional_t_cond=None, transformer_options={}, **kwargs):
return comfy.patcher_extension.WrapperExecutor.new_class_executor( return comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._forward, self._forward,
self, self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options) comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
).execute(x, timestep, context, attention_mask, guidance, ref_latents, transformer_options, **kwargs) ).execute(x, timestep, context, attention_mask, ref_latents, additional_t_cond, transformer_options, **kwargs)
def _forward( def _forward(
self, self,
@ -378,8 +420,8 @@ class QwenImageTransformer2DModel(nn.Module):
timesteps, timesteps,
context, context,
attention_mask=None, attention_mask=None,
guidance: torch.Tensor = None,
ref_latents=None, ref_latents=None,
additional_t_cond=None,
transformer_options={}, transformer_options={},
control=None, control=None,
**kwargs **kwargs
@ -391,16 +433,24 @@ class QwenImageTransformer2DModel(nn.Module):
hidden_states, img_ids, orig_shape = self.process_img(x) hidden_states, img_ids, orig_shape = self.process_img(x)
num_embeds = hidden_states.shape[1] num_embeds = hidden_states.shape[1]
timestep_zero_index = None
if ref_latents is not None: if ref_latents is not None:
h = 0 h = 0
w = 0 w = 0
index = 0 index = 0
index_ref_method = kwargs.get("ref_latents_method", "index") == "index" ref_method = kwargs.get("ref_latents_method", self.default_ref_method)
index_ref_method = (ref_method == "index") or (ref_method == "index_timestep_zero")
negative_ref_method = ref_method == "negative_index"
timestep_zero = ref_method == "index_timestep_zero"
for ref in ref_latents: for ref in ref_latents:
if index_ref_method: if index_ref_method:
index += 1 index += 1
h_offset = 0 h_offset = 0
w_offset = 0 w_offset = 0
elif negative_ref_method:
index -= 1
h_offset = 0
w_offset = 0
else: else:
index = 1 index = 1
h_offset = 0 h_offset = 0
@ -415,6 +465,10 @@ class QwenImageTransformer2DModel(nn.Module):
kontext, kontext_ids, _ = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset) kontext, kontext_ids, _ = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset)
hidden_states = torch.cat([hidden_states, kontext], dim=1) hidden_states = torch.cat([hidden_states, kontext], dim=1)
img_ids = torch.cat([img_ids, kontext_ids], dim=1) img_ids = torch.cat([img_ids, kontext_ids], dim=1)
if timestep_zero:
if index > 0:
timestep = torch.cat([timestep, timestep * 0], dim=0)
timestep_zero_index = num_embeds
txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2)) txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3) txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
@ -426,14 +480,7 @@ class QwenImageTransformer2DModel(nn.Module):
encoder_hidden_states = self.txt_norm(encoder_hidden_states) encoder_hidden_states = self.txt_norm(encoder_hidden_states)
encoder_hidden_states = self.txt_in(encoder_hidden_states) encoder_hidden_states = self.txt_in(encoder_hidden_states)
if guidance is not None: temb = self.time_text_embed(timestep, hidden_states, additional_t_cond)
guidance = guidance * 1000
temb = (
self.time_text_embed(timestep, hidden_states)
if guidance is None
else self.time_text_embed(timestep, guidance, hidden_states)
)
patches_replace = transformer_options.get("patches_replace", {}) patches_replace = transformer_options.get("patches_replace", {})
patches = transformer_options.get("patches", {}) patches = transformer_options.get("patches", {})
@ -446,7 +493,7 @@ class QwenImageTransformer2DModel(nn.Module):
if ("double_block", i) in blocks_replace: if ("double_block", i) in blocks_replace:
def block_wrap(args): def block_wrap(args):
out = {} out = {}
out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"], transformer_options=args["transformer_options"]) out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"], timestep_zero_index=timestep_zero_index, transformer_options=args["transformer_options"])
return out return out
out = blocks_replace[("double_block", i)]({"img": hidden_states, "txt": encoder_hidden_states, "vec": temb, "pe": image_rotary_emb, "transformer_options": transformer_options}, {"original_block": block_wrap}) out = blocks_replace[("double_block", i)]({"img": hidden_states, "txt": encoder_hidden_states, "vec": temb, "pe": image_rotary_emb, "transformer_options": transformer_options}, {"original_block": block_wrap})
hidden_states = out["img"] hidden_states = out["img"]
@ -458,6 +505,7 @@ class QwenImageTransformer2DModel(nn.Module):
encoder_hidden_states_mask=encoder_hidden_states_mask, encoder_hidden_states_mask=encoder_hidden_states_mask,
temb=temb, temb=temb,
image_rotary_emb=image_rotary_emb, image_rotary_emb=image_rotary_emb,
timestep_zero_index=timestep_zero_index,
transformer_options=transformer_options, transformer_options=transformer_options,
) )
@ -474,9 +522,12 @@ class QwenImageTransformer2DModel(nn.Module):
if add is not None: if add is not None:
hidden_states[:, :add.shape[1]] += add hidden_states[:, :add.shape[1]] += add
if timestep_zero_index is not None:
temb = temb.chunk(2, dim=0)[0]
hidden_states = self.norm_out(hidden_states, temb) hidden_states = self.norm_out(hidden_states, temb)
hidden_states = self.proj_out(hidden_states) hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states[:, :num_embeds].view(orig_shape[0], orig_shape[-2] // 2, orig_shape[-1] // 2, orig_shape[1], 2, 2) hidden_states = hidden_states[:, :num_embeds].view(orig_shape[0], orig_shape[-3], orig_shape[-2] // 2, orig_shape[-1] // 2, orig_shape[1], 2, 2)
hidden_states = hidden_states.permute(0, 3, 1, 4, 2, 5) hidden_states = hidden_states.permute(0, 4, 1, 2, 5, 3, 6)
return hidden_states.reshape(orig_shape)[:, :, :, :x.shape[-2], :x.shape[-1]] return hidden_states.reshape(orig_shape)[:, :, :, :x.shape[-2], :x.shape[-1]]
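The new 5D patchify/unpatchify above folds the temporal axis into the token sequence. A quick round-trip check of that layout, mirroring the view/permute/reshape order in the hunks (sizes made up for illustration):

import torch

b, c, t, h, w = 1, 16, 3, 8, 8
x = torch.randn(b, c, t, h, w)

# patchify: (B, C, T, H, W) -> (B, T * H/2 * W/2, C * 4)
tokens = x.view(b, c, t, h // 2, 2, w // 2, 2)
tokens = tokens.permute(0, 2, 3, 5, 1, 4, 6)
tokens = tokens.reshape(b, t * (h // 2) * (w // 2), c * 4)

# unpatchify: invert the permutation and restore (B, C, T, H, W)
back = tokens.reshape(b, t, h // 2, w // 2, c, 2, 2)
back = back.permute(0, 4, 1, 2, 5, 3, 6).reshape(b, c, t, h, w)
assert torch.equal(back, x)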
View File
@ -71,7 +71,7 @@ def count_params(model, verbose=False):
def instantiate_from_config(config): def instantiate_from_config(config):
if not "target" in config: if "target" not in config:
if config == '__is_first_stage__': if config == '__is_first_stage__':
return None return None
elif config == "__is_unconditional__": elif config == "__is_unconditional__":
View File
@ -568,7 +568,10 @@ class WanModel(torch.nn.Module):
patches_replace = transformer_options.get("patches_replace", {}) patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {}) blocks_replace = patches_replace.get("dit", {})
transformer_options["total_blocks"] = len(self.blocks)
transformer_options["block_type"] = "double"
for i, block in enumerate(self.blocks): for i, block in enumerate(self.blocks):
transformer_options["block_index"] = i
if ("double_block", i) in blocks_replace: if ("double_block", i) in blocks_replace:
def block_wrap(args): def block_wrap(args):
out = {} out = {}
@ -763,7 +766,10 @@ class VaceWanModel(WanModel):
patches_replace = transformer_options.get("patches_replace", {}) patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {}) blocks_replace = patches_replace.get("dit", {})
transformer_options["total_blocks"] = len(self.blocks)
transformer_options["block_type"] = "double"
for i, block in enumerate(self.blocks): for i, block in enumerate(self.blocks):
transformer_options["block_index"] = i
if ("double_block", i) in blocks_replace: if ("double_block", i) in blocks_replace:
def block_wrap(args): def block_wrap(args):
out = {} out = {}
@ -862,7 +868,10 @@ class CameraWanModel(WanModel):
patches_replace = transformer_options.get("patches_replace", {}) patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {}) blocks_replace = patches_replace.get("dit", {})
transformer_options["total_blocks"] = len(self.blocks)
transformer_options["block_type"] = "double"
for i, block in enumerate(self.blocks): for i, block in enumerate(self.blocks):
transformer_options["block_index"] = i
if ("double_block", i) in blocks_replace: if ("double_block", i) in blocks_replace:
def block_wrap(args): def block_wrap(args):
out = {} out = {}
@ -1326,16 +1335,19 @@ class WanModel_S2V(WanModel):
patches_replace = transformer_options.get("patches_replace", {}) patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {}) blocks_replace = patches_replace.get("dit", {})
transformer_options["total_blocks"] = len(self.blocks)
transformer_options["block_type"] = "double"
for i, block in enumerate(self.blocks): for i, block in enumerate(self.blocks):
transformer_options["block_index"] = i
if ("double_block", i) in blocks_replace: if ("double_block", i) in blocks_replace:
def block_wrap(args): def block_wrap(args):
out = {} out = {}
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"]) out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], transformer_options=args["transformer_options"])
return out return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap}) out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})
x = out["img"] x = out["img"]
else: else:
x = block(x, e=e0, freqs=freqs, context=context) x = block(x, e=e0, freqs=freqs, context=context, transformer_options=transformer_options)
if audio_emb is not None: if audio_emb is not None:
x = self.audio_injector(x, i, audio_emb, audio_emb_global, seq_len) x = self.audio_injector(x, i, audio_emb, audio_emb_global, seq_len)
# head # head
@ -1574,7 +1586,10 @@ class HumoWanModel(WanModel):
patches_replace = transformer_options.get("patches_replace", {}) patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {}) blocks_replace = patches_replace.get("dit", {})
transformer_options["total_blocks"] = len(self.blocks)
transformer_options["block_type"] = "double"
for i, block in enumerate(self.blocks): for i, block in enumerate(self.blocks):
transformer_options["block_index"] = i
if ("double_block", i) in blocks_replace: if ("double_block", i) in blocks_replace:
def block_wrap(args): def block_wrap(args):
out = {} out = {}
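These hunks expose per-block metadata (block_index, total_blocks, block_type) through transformer_options before each block runs. A hedged sketch of how a block patch might consume it; the callback signature below is assumed for illustration, not taken from this diff:

def my_block_patch(args):
    opts = args["transformer_options"]
    i = opts.get("block_index", -1)
    total = opts.get("total_blocks", 0)
    if total and i >= (3 * total) // 4:
        # e.g. only act on the last quarter of the double blocks
        args["img"] = args["img"] * 1.0
    return args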
View File
@ -523,7 +523,10 @@ class AnimateWanModel(WanModel):
patches_replace = transformer_options.get("patches_replace", {}) patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {}) blocks_replace = patches_replace.get("dit", {})
transformer_options["total_blocks"] = len(self.blocks)
transformer_options["block_type"] = "double"
for i, block in enumerate(self.blocks): for i, block in enumerate(self.blocks):
transformer_options["block_index"] = i
if ("double_block", i) in blocks_replace: if ("double_block", i) in blocks_replace:
def block_wrap(args): def block_wrap(args):
out = {} out = {}
View File
@ -227,6 +227,7 @@ class Encoder3d(nn.Module):
def __init__(self, def __init__(self,
dim=128, dim=128,
z_dim=4, z_dim=4,
input_channels=3,
dim_mult=[1, 2, 4, 4], dim_mult=[1, 2, 4, 4],
num_res_blocks=2, num_res_blocks=2,
attn_scales=[], attn_scales=[],
@ -245,7 +246,7 @@ class Encoder3d(nn.Module):
scale = 1.0 scale = 1.0
# init block # init block
self.conv1 = CausalConv3d(3, dims[0], 3, padding=1) self.conv1 = CausalConv3d(input_channels, dims[0], 3, padding=1)
# downsample blocks # downsample blocks
downsamples = [] downsamples = []
@ -331,6 +332,7 @@ class Decoder3d(nn.Module):
def __init__(self, def __init__(self,
dim=128, dim=128,
z_dim=4, z_dim=4,
output_channels=3,
dim_mult=[1, 2, 4, 4], dim_mult=[1, 2, 4, 4],
num_res_blocks=2, num_res_blocks=2,
attn_scales=[], attn_scales=[],
@ -378,7 +380,7 @@ class Decoder3d(nn.Module):
# output blocks # output blocks
self.head = nn.Sequential( self.head = nn.Sequential(
RMS_norm(out_dim, images=False), nn.SiLU(), RMS_norm(out_dim, images=False), nn.SiLU(),
CausalConv3d(out_dim, 3, 3, padding=1)) CausalConv3d(out_dim, output_channels, 3, padding=1))
def forward(self, x, feat_cache=None, feat_idx=[0]): def forward(self, x, feat_cache=None, feat_idx=[0]):
## conv1 ## conv1
@ -449,6 +451,7 @@ class WanVAE(nn.Module):
num_res_blocks=2, num_res_blocks=2,
attn_scales=[], attn_scales=[],
temperal_downsample=[True, True, False], temperal_downsample=[True, True, False],
image_channels=3,
dropout=0.0): dropout=0.0):
super().__init__() super().__init__()
self.dim = dim self.dim = dim
@ -460,11 +463,11 @@ class WanVAE(nn.Module):
self.temperal_upsample = temperal_downsample[::-1] self.temperal_upsample = temperal_downsample[::-1]
# modules # modules
self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks, self.encoder = Encoder3d(dim, z_dim * 2, image_channels, dim_mult, num_res_blocks,
attn_scales, self.temperal_downsample, dropout) attn_scales, self.temperal_downsample, dropout)
self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1) self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
self.conv2 = CausalConv3d(z_dim, z_dim, 1) self.conv2 = CausalConv3d(z_dim, z_dim, 1)
self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks, self.decoder = Decoder3d(dim, z_dim, image_channels, dim_mult, num_res_blocks,
attn_scales, self.temperal_upsample, dropout) attn_scales, self.temperal_upsample, dropout)
def encode(self, x): def encode(self, x):
View File
@ -320,6 +320,7 @@ def model_lora_keys_unet(model, key_map={}):
to = diffusers_keys[k] to = diffusers_keys[k]
key_lora = k[:-len(".weight")] key_lora = k[:-len(".weight")]
key_map["diffusion_model.{}".format(key_lora)] = to key_map["diffusion_model.{}".format(key_lora)] = to
key_map["transformer.{}".format(key_lora)] = to
key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = to key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = to
if isinstance(model, comfy.model_base.Kandinsky5): if isinstance(model, comfy.model_base.Kandinsky5):
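With the added line, the same base weight can now be matched from a `transformer.` LoRA prefix as well. Purely illustrative sketch with placeholder key and target (the real values come from diffusers_keys):

key_lora = "transformer_blocks.0.attn.to_q"   # hypothetical example key
to = "target_weight_key"                      # stands in for diffusers_keys[k]
key_map = {
    "diffusion_model." + key_lora: to,
    "transformer." + key_lora: to,                        # newly accepted prefix
    "lycoris_" + key_lora.replace(".", "_"): to,
}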
View File
@ -20,6 +20,7 @@ import comfy.ldm.hunyuan3dv2_1
import comfy.ldm.hunyuan3dv2_1.hunyuandit import comfy.ldm.hunyuan3dv2_1.hunyuandit
import torch import torch
import logging import logging
import comfy.ldm.lightricks.av_model
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
from comfy.ldm.cascade.stage_c import StageC from comfy.ldm.cascade.stage_c import StageC
from comfy.ldm.cascade.stage_b import StageB from comfy.ldm.cascade.stage_b import StageB
@ -946,7 +947,7 @@ class GenmoMochi(BaseModel):
class LTXV(BaseModel): class LTXV(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLUX, device=None): def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lightricks.model.LTXVModel) #TODO super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lightricks.model.LTXVModel)
def extra_conds(self, **kwargs): def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs) out = super().extra_conds(**kwargs)
@ -977,6 +978,60 @@ class LTXV(BaseModel):
def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs): def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
return latent_image return latent_image
class LTXAV(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lightricks.av_model.LTXAVModel) #TODO
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
attention_mask = kwargs.get("attention_mask", None)
if attention_mask is not None:
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
out['frame_rate'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", 25))
denoise_mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
audio_denoise_mask = None
if denoise_mask is not None and "latent_shapes" in kwargs:
denoise_mask = utils.unpack_latents(denoise_mask, kwargs["latent_shapes"])
if len(denoise_mask) > 1:
audio_denoise_mask = denoise_mask[1]
denoise_mask = denoise_mask[0]
if denoise_mask is not None:
out["denoise_mask"] = comfy.conds.CONDRegular(denoise_mask)
if audio_denoise_mask is not None:
out["audio_denoise_mask"] = comfy.conds.CONDRegular(audio_denoise_mask)
keyframe_idxs = kwargs.get("keyframe_idxs", None)
if keyframe_idxs is not None:
out['keyframe_idxs'] = comfy.conds.CONDRegular(keyframe_idxs)
latent_shapes = kwargs.get("latent_shapes", None)
if latent_shapes is not None:
out['latent_shapes'] = comfy.conds.CONDConstant(latent_shapes)
return out
def process_timestep(self, timestep, x, denoise_mask=None, audio_denoise_mask=None, **kwargs):
v_timestep = timestep
a_timestep = timestep
if denoise_mask is not None:
v_timestep = self.diffusion_model.patchifier.patchify(((denoise_mask) * timestep.view([timestep.shape[0]] + [1] * (denoise_mask.ndim - 1)))[:, :1])[0]
if audio_denoise_mask is not None:
a_timestep = self.diffusion_model.a_patchifier.patchify(((audio_denoise_mask) * timestep.view([timestep.shape[0]] + [1] * (audio_denoise_mask.ndim - 1)))[:, :1, :, :1])[0]
return v_timestep, a_timestep
def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
return latent_image
class HunyuanVideo(BaseModel): class HunyuanVideo(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None): def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan_video.model.HunyuanVideo) super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan_video.model.HunyuanVideo)
@ -1110,6 +1165,10 @@ class Lumina2(BaseModel):
if 'num_tokens' not in out: if 'num_tokens' not in out:
out['num_tokens'] = comfy.conds.CONDConstant(cross_attn.shape[1]) out['num_tokens'] = comfy.conds.CONDConstant(cross_attn.shape[1])
clip_text_pooled = kwargs.get("pooled_output", None) # NewBie
if clip_text_pooled is not None:
out['clip_text_pooled'] = comfy.conds.CONDRegular(clip_text_pooled)
return out return out
class WAN21(BaseModel): class WAN21(BaseModel):
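LTXAV.process_timestep above turns a scalar timestep into per-token timesteps by scaling it with the video or audio denoise mask before patchifying. A stripped-down sketch of that scaling step (mask layout assumed for illustration):

import torch

def masked_timestep(timestep, denoise_mask):
    # timestep: (B,), denoise_mask: (B, 1, T, H, W); masked-out regions end up at timestep 0
    t = timestep.view([timestep.shape[0]] + [1] * (denoise_mask.ndim - 1))
    return denoise_mask * t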
View File
@ -180,8 +180,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["use_cond_type_embedding"] = False dit_config["use_cond_type_embedding"] = False
if '{}vision_in.proj.0.weight'.format(key_prefix) in state_dict_keys: if '{}vision_in.proj.0.weight'.format(key_prefix) in state_dict_keys:
dit_config["vision_in_dim"] = state_dict['{}vision_in.proj.0.weight'.format(key_prefix)].shape[0] dit_config["vision_in_dim"] = state_dict['{}vision_in.proj.0.weight'.format(key_prefix)].shape[0]
dit_config["meanflow_sum"] = True
else: else:
dit_config["vision_in_dim"] = None dit_config["vision_in_dim"] = None
dit_config["meanflow_sum"] = False
return dit_config return dit_config
if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or f"{key_prefix}distilled_guidance_layer.norms.0.scale" in state_dict_keys): #Flux, Chroma or Chroma Radiance (has no img_in.weight) if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or f"{key_prefix}distilled_guidance_layer.norms.0.scale" in state_dict_keys): #Flux, Chroma or Chroma Radiance (has no img_in.weight)
@ -257,6 +259,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["nerf_tile_size"] = 512 dit_config["nerf_tile_size"] = 512
dit_config["nerf_final_head_type"] = "conv" if f"{key_prefix}nerf_final_layer_conv.norm.scale" in state_dict_keys else "linear" dit_config["nerf_final_head_type"] = "conv" if f"{key_prefix}nerf_final_layer_conv.norm.scale" in state_dict_keys else "linear"
dit_config["nerf_embedder_dtype"] = torch.float32 dit_config["nerf_embedder_dtype"] = torch.float32
if "{}__x0__".format(key_prefix) in state_dict_keys: # x0 pred
dit_config["use_x0"] = True
else:
dit_config["use_x0"] = False
else: else:
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
dit_config["yak_mlp"] = '{}double_blocks.0.img_mlp.gate_proj.weight'.format(key_prefix) in state_dict_keys dit_config["yak_mlp"] = '{}double_blocks.0.img_mlp.gate_proj.weight'.format(key_prefix) in state_dict_keys
@ -299,7 +305,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
if '{}adaln_single.emb.timestep_embedder.linear_1.bias'.format(key_prefix) in state_dict_keys: #Lightricks ltxv if '{}adaln_single.emb.timestep_embedder.linear_1.bias'.format(key_prefix) in state_dict_keys: #Lightricks ltxv
dit_config = {} dit_config = {}
dit_config["image_model"] = "ltxv" dit_config["image_model"] = "ltxav" if f'{key_prefix}audio_adaln_single.linear.weight' in state_dict_keys else "ltxv"
dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.') dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.')
shape = state_dict['{}transformer_blocks.0.attn2.to_k.weight'.format(key_prefix)].shape shape = state_dict['{}transformer_blocks.0.attn2.to_k.weight'.format(key_prefix)].shape
dit_config["attention_head_dim"] = shape[0] // 32 dit_config["attention_head_dim"] = shape[0] // 32
@ -423,6 +429,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["axes_lens"] = [300, 512, 512] dit_config["axes_lens"] = [300, 512, 512]
dit_config["rope_theta"] = 10000.0 dit_config["rope_theta"] = 10000.0
dit_config["ffn_dim_multiplier"] = 4.0 dit_config["ffn_dim_multiplier"] = 4.0
ctd_weight = state_dict.get('{}clip_text_pooled_proj.0.weight'.format(key_prefix), None)
if ctd_weight is not None: # NewBie
dit_config["clip_text_dim"] = ctd_weight.shape[0]
# NewBie also sets axes_lens = [1024, 512, 512] but it's not used in ComfyUI
elif dit_config["dim"] == 3840: # Z image elif dit_config["dim"] == 3840: # Z image
dit_config["n_heads"] = 30 dit_config["n_heads"] = 30
dit_config["n_kv_heads"] = 30 dit_config["n_kv_heads"] = 30
@ -609,6 +619,11 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["image_model"] = "qwen_image" dit_config["image_model"] = "qwen_image"
dit_config["in_channels"] = state_dict['{}img_in.weight'.format(key_prefix)].shape[1] dit_config["in_channels"] = state_dict['{}img_in.weight'.format(key_prefix)].shape[1]
dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.') dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.')
if "{}__index_timestep_zero__".format(key_prefix) in state_dict_keys: # 2511
dit_config["default_ref_method"] = "index_timestep_zero"
if "{}time_text_embed.addition_t_embedding.weight".format(key_prefix) in state_dict_keys: # Layered
dit_config["use_additional_t_cond"] = True
dit_config["default_ref_method"] = "negative_index"
return dit_config return dit_config
if '{}visual_transformer_blocks.0.cross_attention.key_norm.weight'.format(key_prefix) in state_dict_keys: # Kandinsky 5 if '{}visual_transformer_blocks.0.cross_attention.key_norm.weight'.format(key_prefix) in state_dict_keys: # Kandinsky 5
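The detection logic keys off marker tensors in the checkpoint. Condensed sketch of the Qwen-image branch added above (same key names as in the hunk):

def detect_qwen_variant(state_dict_keys, key_prefix=""):
    cfg = {"image_model": "qwen_image"}
    if "{}__index_timestep_zero__".format(key_prefix) in state_dict_keys:  # 2511
        cfg["default_ref_method"] = "index_timestep_zero"
    if "{}time_text_embed.addition_t_embedding.weight".format(key_prefix) in state_dict_keys:  # Layered
        cfg["use_additional_t_cond"] = True
        cfg["default_ref_method"] = "negative_index"
    return cfg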
View File
@ -26,6 +26,7 @@ import importlib
import platform import platform
import weakref import weakref
import gc import gc
import os
class VRAMState(Enum): class VRAMState(Enum):
DISABLED = 0 #No vram present: no need to move models to vram DISABLED = 0 #No vram present: no need to move models to vram
@ -333,13 +334,15 @@ except:
SUPPORT_FP8_OPS = args.supports_fp8_compute SUPPORT_FP8_OPS = args.supports_fp8_compute
AMD_RDNA2_AND_OLDER_ARCH = ["gfx1030", "gfx1031", "gfx1010", "gfx1011", "gfx1012", "gfx906", "gfx900", "gfx803"] AMD_RDNA2_AND_OLDER_ARCH = ["gfx1030", "gfx1031", "gfx1010", "gfx1011", "gfx1012", "gfx906", "gfx900", "gfx803"]
AMD_ENABLE_MIOPEN_ENV = 'COMFYUI_ENABLE_MIOPEN'
try: try:
if is_amd(): if is_amd():
arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName
if not (any((a in arch) for a in AMD_RDNA2_AND_OLDER_ARCH)): if not (any((a in arch) for a in AMD_RDNA2_AND_OLDER_ARCH)):
torch.backends.cudnn.enabled = False # Seems to improve things a lot on AMD if os.getenv(AMD_ENABLE_MIOPEN_ENV) != '1':
logging.info("Set: torch.backends.cudnn.enabled = False for better AMD performance.") torch.backends.cudnn.enabled = False # Seems to improve things a lot on AMD
logging.info("Set: torch.backends.cudnn.enabled = False for better AMD performance.")
try: try:
rocm_version = tuple(map(int, str(torch.version.hip).split(".")[:2])) rocm_version = tuple(map(int, str(torch.version.hip).split(".")[:2]))
@ -453,7 +456,7 @@ def module_size(module):
sd = module.state_dict() sd = module.state_dict()
for k in sd: for k in sd:
t = sd[k] t = sd[k]
module_mem += t.nelement() * t.element_size() module_mem += t.nbytes
return module_mem return module_mem
class LoadedModel: class LoadedModel:
@ -1016,8 +1019,8 @@ NUM_STREAMS = 0
if args.async_offload is not None: if args.async_offload is not None:
NUM_STREAMS = args.async_offload NUM_STREAMS = args.async_offload
else: else:
# Enable by default on Nvidia # Enable by default on Nvidia and AMD
if is_nvidia(): if is_nvidia() or is_amd():
NUM_STREAMS = 2 NUM_STREAMS = 2
if args.disable_async_offload: if args.disable_async_offload:
@ -1123,6 +1126,16 @@ if not args.disable_pinned_memory:
PINNING_ALLOWED_TYPES = set(["Parameter", "QuantizedTensor"]) PINNING_ALLOWED_TYPES = set(["Parameter", "QuantizedTensor"])
def discard_cuda_async_error():
try:
a = torch.tensor([1], dtype=torch.uint8, device=get_torch_device())
b = torch.tensor([1], dtype=torch.uint8, device=get_torch_device())
_ = a + b
torch.cuda.synchronize()
except torch.AcceleratorError:
#Dump it! We already know about it from the synchronous return
pass
def pin_memory(tensor): def pin_memory(tensor):
global TOTAL_PINNED_MEMORY global TOTAL_PINNED_MEMORY
if MAX_PINNED_MEMORY <= 0: if MAX_PINNED_MEMORY <= 0:
@ -1143,7 +1156,7 @@ def pin_memory(tensor):
if not tensor.is_contiguous(): if not tensor.is_contiguous():
return False return False
size = tensor.numel() * tensor.element_size() size = tensor.nbytes
if (TOTAL_PINNED_MEMORY + size) > MAX_PINNED_MEMORY: if (TOTAL_PINNED_MEMORY + size) > MAX_PINNED_MEMORY:
return False return False
@ -1155,6 +1168,9 @@ def pin_memory(tensor):
PINNED_MEMORY[ptr] = size PINNED_MEMORY[ptr] = size
TOTAL_PINNED_MEMORY += size TOTAL_PINNED_MEMORY += size
return True return True
else:
logging.warning("Pin error.")
discard_cuda_async_error()
return False return False
@ -1167,7 +1183,7 @@ def unpin_memory(tensor):
return False return False
ptr = tensor.data_ptr() ptr = tensor.data_ptr()
size = tensor.numel() * tensor.element_size() size = tensor.nbytes
size_stored = PINNED_MEMORY.get(ptr, None) size_stored = PINNED_MEMORY.get(ptr, None)
if size_stored is None: if size_stored is None:
@ -1183,6 +1199,9 @@ def unpin_memory(tensor):
if len(PINNED_MEMORY) == 0: if len(PINNED_MEMORY) == 0:
TOTAL_PINNED_MEMORY = 0 TOTAL_PINNED_MEMORY = 0
return True return True
else:
logging.warning("Unpin error.")
discard_cuda_async_error()
return False return False
@ -1485,6 +1504,16 @@ def supports_fp8_compute(device=None):
return True return True
def supports_nvfp4_compute(device=None):
if not is_nvidia():
return False
props = torch.cuda.get_device_properties(device)
if props.major < 10:
return False
return True
def extended_fp16_support(): def extended_fp16_support():
# TODO: check why some models work with fp16 on newer torch versions but not on older # TODO: check why some models work with fp16 on newer torch versions but not on older
if torch_version_numeric < (2, 7): if torch_version_numeric < (2, 7):
@ -1523,6 +1552,10 @@ def soft_empty_cache(force=False):
def unload_all_models(): def unload_all_models():
free_memory(1e30, get_torch_device()) free_memory(1e30, get_torch_device())
def debug_memory_summary():
if is_amd() or is_nvidia():
return torch.cuda.memory.memory_summary()
return ""
#TODO: might be cleaner to put this somewhere else #TODO: might be cleaner to put this somewhere else
import threading import threading
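The module_size and pin_memory hunks switch from numel() * element_size() to Tensor.nbytes; the two are equivalent for dense tensors. Tiny sketch of the same accounting:

import torch

def module_size_bytes(module: torch.nn.Module) -> int:
    # Tensor.nbytes == numel() * element_size() for dense tensors
    return sum(t.nbytes for t in module.state_dict().values())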
View File
@ -35,6 +35,7 @@ import comfy.model_management
import comfy.patcher_extension import comfy.patcher_extension
import comfy.utils import comfy.utils
from comfy.comfy_types import UnetWrapperFunction from comfy.comfy_types import UnetWrapperFunction
from comfy.quant_ops import QuantizedTensor
from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP
@ -132,14 +133,17 @@ class LowVramPatch:
def __call__(self, weight): def __call__(self, weight):
return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=weight.dtype) return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=weight.dtype)
#The above patch logic may cast up the weight to fp32, and do math. Go with fp32 x 3 LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 2
LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 3
def low_vram_patch_estimate_vram(model, key): def low_vram_patch_estimate_vram(model, key):
weight, set_func, convert_func = get_key_weight(model, key) weight, set_func, convert_func = get_key_weight(model, key)
if weight is None: if weight is None:
return 0 return 0
return weight.numel() * torch.float32.itemsize * LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR model_dtype = getattr(model, "manual_cast_dtype", torch.float32)
if model_dtype is None:
model_dtype = weight.dtype
return weight.numel() * model_dtype.itemsize * LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR
def get_key_weight(model, key): def get_key_weight(model, key):
set_func = None set_func = None
@ -450,6 +454,9 @@ class ModelPatcher:
def set_model_post_input_patch(self, patch): def set_model_post_input_patch(self, patch):
self.set_model_patch(patch, "post_input") self.set_model_patch(patch, "post_input")
def set_model_noise_refiner_patch(self, patch):
self.set_model_patch(patch, "noise_refiner")
def set_model_rope_options(self, scale_x, shift_x, scale_y, shift_y, scale_t, shift_t, **kwargs): def set_model_rope_options(self, scale_x, shift_x, scale_y, shift_y, scale_t, shift_t, **kwargs):
rope_options = self.model_options["transformer_options"].get("rope_options", {}) rope_options = self.model_options["transformer_options"].get("rope_options", {})
rope_options["scale_x"] = scale_x rope_options["scale_x"] = scale_x
@ -662,12 +669,18 @@ class ModelPatcher:
module_mem = comfy.model_management.module_size(m) module_mem = comfy.model_management.module_size(m)
module_offload_mem = module_mem module_offload_mem = module_mem
if hasattr(m, "comfy_cast_weights"): if hasattr(m, "comfy_cast_weights"):
weight_key = "{}.weight".format(n) def check_module_offload_mem(key):
bias_key = "{}.bias".format(n) if key in self.patches:
if weight_key in self.patches: return low_vram_patch_estimate_vram(self.model, key)
module_offload_mem += low_vram_patch_estimate_vram(self.model, weight_key) model_dtype = getattr(self.model, "manual_cast_dtype", None)
if bias_key in self.patches: weight, _, _ = get_key_weight(self.model, key)
module_offload_mem += low_vram_patch_estimate_vram(self.model, bias_key) if model_dtype is None or weight is None:
return 0
if (weight.dtype != model_dtype or isinstance(weight, QuantizedTensor)):
return weight.numel() * model_dtype.itemsize
return 0
module_offload_mem += check_module_offload_mem("{}.weight".format(n))
module_offload_mem += check_module_offload_mem("{}.bias".format(n))
loading.append((module_offload_mem, module_mem, n, m, params)) loading.append((module_offload_mem, module_mem, n, m, params))
return loading return loading
@ -920,7 +933,7 @@ class ModelPatcher:
patch_counter += 1 patch_counter += 1
cast_weight = True cast_weight = True
if cast_weight: if cast_weight and hasattr(m, "comfy_cast_weights"):
m.prev_comfy_cast_weights = m.comfy_cast_weights m.prev_comfy_cast_weights = m.comfy_cast_weights
m.comfy_cast_weights = True m.comfy_cast_weights = True
m.comfy_patched_weights = False m.comfy_patched_weights = False
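The revised estimate charges the low-VRAM patch math at the model's cast dtype with a factor of 2, rather than a fixed fp32-times-3 assumption. Hedged sketch of the arithmetic (helper name invented, dtype assumed):

import torch

LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 2

def estimate_patch_vram(weight: torch.Tensor, compute_dtype: torch.dtype = torch.float16) -> int:
    # e.g. a 4096x4096 fp16 weight -> 4096*4096*2*2 bytes = 64 MiB of temporaries
    return weight.numel() * compute_dtype.itemsize * LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR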
View File
@ -22,7 +22,6 @@ import comfy.model_management
from comfy.cli_args import args, PerformanceFeature from comfy.cli_args import args, PerformanceFeature
import comfy.float import comfy.float
import comfy.rmsnorm import comfy.rmsnorm
import contextlib
import json import json
def run_every_op(): def run_every_op():
@ -80,7 +79,7 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
if input is not None: if input is not None:
if dtype is None: if dtype is None:
if isinstance(input, QuantizedTensor): if isinstance(input, QuantizedTensor):
dtype = input._layout_params["orig_dtype"] dtype = input.params.orig_dtype
else: else:
dtype = input.dtype dtype = input.dtype
if bias_dtype is None: if bias_dtype is None:
@ -94,13 +93,6 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
else: else:
offload_stream = None offload_stream = None
if offload_stream is not None:
wf_context = offload_stream
if hasattr(wf_context, "as_context"):
wf_context = wf_context.as_context(offload_stream)
else:
wf_context = contextlib.nullcontext()
non_blocking = comfy.model_management.device_supports_non_blocking(device) non_blocking = comfy.model_management.device_supports_non_blocking(device)
weight_has_function = len(s.weight_function) > 0 weight_has_function = len(s.weight_function) > 0
@ -420,26 +412,34 @@ def fp8_linear(self, input):
return None return None
input_dtype = input.dtype input_dtype = input.dtype
input_shape = input.shape
tensor_3d = input.ndim == 3
if input.ndim == 3 or input.ndim == 2: if tensor_3d:
w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True) input = input.reshape(-1, input_shape[2])
scale_weight = torch.ones((), device=input.device, dtype=torch.float32)
scale_input = torch.ones((), device=input.device, dtype=torch.float32) if input.ndim != 2:
input = torch.clamp(input, min=-448, max=448, out=input) return None
layout_params_weight = {'scale': scale_input, 'orig_dtype': input_dtype} w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True)
quantized_input = QuantizedTensor(input.to(dtype).contiguous(), "TensorCoreFP8Layout", layout_params_weight) scale_weight = torch.ones((), device=input.device, dtype=torch.float32)
# Wrap weight in QuantizedTensor - this enables unified dispatch scale_input = torch.ones((), device=input.device, dtype=torch.float32)
# Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py! input = torch.clamp(input, min=-448, max=448, out=input)
layout_params_weight = {'scale': scale_weight, 'orig_dtype': input_dtype} input_fp8 = input.to(dtype).contiguous()
quantized_weight = QuantizedTensor(w, "TensorCoreFP8Layout", layout_params_weight) layout_params_input = TensorCoreFP8Layout.Params(scale=scale_input, orig_dtype=input_dtype, orig_shape=tuple(input_fp8.shape))
o = torch.nn.functional.linear(quantized_input, quantized_weight, bias) quantized_input = QuantizedTensor(input_fp8, "TensorCoreFP8Layout", layout_params_input)
uncast_bias_weight(self, w, bias, offload_stream) # Wrap weight in QuantizedTensor - this enables unified dispatch
return o # Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py!
layout_params_weight = TensorCoreFP8Layout.Params(scale=scale_weight, orig_dtype=input_dtype, orig_shape=tuple(w.shape))
quantized_weight = QuantizedTensor(w, "TensorCoreFP8Layout", layout_params_weight)
o = torch.nn.functional.linear(quantized_input, quantized_weight, bias)
return None uncast_bias_weight(self, w, bias, offload_stream)
if tensor_3d:
o = o.reshape((input_shape[0], input_shape[1], w.shape[0]))
return o
class fp8_ops(manual_cast): class fp8_ops(manual_cast):
class Linear(manual_cast.Linear): class Linear(manual_cast.Linear):
@ -485,14 +485,20 @@ if CUBLAS_IS_AVAILABLE:
# ============================================================================== # ==============================================================================
# Mixed Precision Operations # Mixed Precision Operations
# ============================================================================== # ==============================================================================
from .quant_ops import QuantizedTensor, QUANT_ALGOS from .quant_ops import (
QuantizedTensor,
QUANT_ALGOS,
TensorCoreFP8Layout,
get_layout_class,
)
def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False): def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False, disabled=[]):
class MixedPrecisionOps(manual_cast): class MixedPrecisionOps(manual_cast):
_quant_config = quant_config _quant_config = quant_config
_compute_dtype = compute_dtype _compute_dtype = compute_dtype
_full_precision_mm = full_precision_mm _full_precision_mm = full_precision_mm
_disabled = disabled
class Linear(torch.nn.Module, CastWeightBiasOp): class Linear(torch.nn.Module, CastWeightBiasOp):
def __init__( def __init__(
@ -517,10 +523,21 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
self.tensor_class = None self.tensor_class = None
self._full_precision_mm = MixedPrecisionOps._full_precision_mm self._full_precision_mm = MixedPrecisionOps._full_precision_mm
self._full_precision_mm_config = False
def reset_parameters(self): def reset_parameters(self):
return None return None
def _load_scale_param(self, state_dict, prefix, param_name, device, manually_loaded_keys, dtype=None):
key = f"{prefix}{param_name}"
value = state_dict.pop(key, None)
if value is not None:
value = value.to(device=device)
if dtype is not None:
value = value.view(dtype=dtype)
manually_loaded_keys.append(key)
return value
def _load_from_state_dict(self, state_dict, prefix, local_metadata, def _load_from_state_dict(self, state_dict, prefix, local_metadata,
strict, missing_keys, unexpected_keys, error_msgs): strict, missing_keys, unexpected_keys, error_msgs):
@ -541,34 +558,58 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False) self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
else: else:
self.quant_format = layer_conf.get("format", None) self.quant_format = layer_conf.get("format", None)
self._full_precision_mm_config = layer_conf.get("full_precision_matrix_mult", False)
if not self._full_precision_mm: if not self._full_precision_mm:
self._full_precision_mm = layer_conf.get("full_precision_matrix_mult", False) self._full_precision_mm = self._full_precision_mm_config
if self.quant_format in MixedPrecisionOps._disabled:
self._full_precision_mm = True
if self.quant_format is None: if self.quant_format is None:
raise ValueError(f"Unknown quantization format for layer {layer_name}") raise ValueError(f"Unknown quantization format for layer {layer_name}")
qconfig = QUANT_ALGOS[self.quant_format] qconfig = QUANT_ALGOS[self.quant_format]
self.layout_type = qconfig["comfy_tensor_layout"] self.layout_type = qconfig["comfy_tensor_layout"]
layout_cls = get_layout_class(self.layout_type)
weight_scale_key = f"{prefix}weight_scale" # Load format-specific parameters
scale = state_dict.pop(weight_scale_key, None) if self.quant_format in ["float8_e4m3fn", "float8_e5m2"]:
if scale is not None: # FP8: single tensor scale
scale = scale.to(device) scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys)
layout_params = {
'scale': scale,
'orig_dtype': MixedPrecisionOps._compute_dtype,
'block_size': qconfig.get("group_size", None),
}
if scale is not None: params = layout_cls.Params(
manually_loaded_keys.append(weight_scale_key) scale=scale,
orig_dtype=MixedPrecisionOps._compute_dtype,
orig_shape=(self.out_features, self.in_features),
)
elif self.quant_format == "nvfp4":
# NVFP4: tensor_scale (weight_scale_2) + block_scale (weight_scale)
tensor_scale = self._load_scale_param(state_dict, prefix, "weight_scale_2", device, manually_loaded_keys)
block_scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys,
dtype=torch.float8_e4m3fn)
if tensor_scale is None or block_scale is None:
raise ValueError(f"Missing NVFP4 scales for layer {layer_name}")
params = layout_cls.Params(
scale=tensor_scale,
block_scale=block_scale,
orig_dtype=MixedPrecisionOps._compute_dtype,
orig_shape=(self.out_features, self.in_features),
)
else:
raise ValueError(f"Unsupported quantization format: {self.quant_format}")
self.weight = torch.nn.Parameter( self.weight = torch.nn.Parameter(
QuantizedTensor(weight.to(device=device, dtype=qconfig.get("storage_t", None)), self.layout_type, layout_params), QuantizedTensor(weight.to(device=device, dtype=qconfig["storage_t"]), self.layout_type, params),
requires_grad=False requires_grad=False
) )
for param_name in qconfig["parameters"]: for param_name in qconfig["parameters"]:
if param_name in {"weight_scale", "weight_scale_2"}:
continue # Already handled above
param_key = f"{prefix}{param_name}" param_key = f"{prefix}{param_name}"
_v = state_dict.pop(param_key, None) _v = state_dict.pop(param_key, None)
if _v is None: if _v is None:
@ -585,11 +626,19 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
def state_dict(self, *args, destination=None, prefix="", **kwargs): def state_dict(self, *args, destination=None, prefix="", **kwargs):
sd = super().state_dict(*args, destination=destination, prefix=prefix, **kwargs) sd = super().state_dict(*args, destination=destination, prefix=prefix, **kwargs)
if isinstance(self.weight, QuantizedTensor): if isinstance(self.weight, QuantizedTensor):
sd["{}weight_scale".format(prefix)] = self.weight._layout_params['scale'] layout_cls = self.weight._layout_cls
# Check if it's any FP8 variant (E4M3 or E5M2)
if layout_cls in ("TensorCoreFP8E4M3Layout", "TensorCoreFP8E5M2Layout", "TensorCoreFP8Layout"):
sd["{}weight_scale".format(prefix)] = self.weight._params.scale
elif layout_cls == "TensorCoreNVFP4Layout":
sd["{}weight_scale_2".format(prefix)] = self.weight._params.scale
sd["{}weight_scale".format(prefix)] = self.weight._params.block_scale
quant_conf = {"format": self.quant_format} quant_conf = {"format": self.quant_format}
if self._full_precision_mm: if self._full_precision_mm_config:
quant_conf["full_precision_matrix_mult"] = True quant_conf["full_precision_matrix_mult"] = True
sd["{}comfy_quant".format(prefix)] = torch.frombuffer(json.dumps(quant_conf).encode('utf-8'), dtype=torch.uint8) sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8)
return sd return sd
def _forward(self, input, weight, bias): def _forward(self, input, weight, bias):
@ -604,12 +653,33 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
def forward(self, input, *args, **kwargs): def forward(self, input, *args, **kwargs):
run_every_op() run_every_op()
input_shape = input.shape
tensor_3d = input.ndim == 3
if self._full_precision_mm or self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: if self._full_precision_mm or self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(input, *args, **kwargs) return self.forward_comfy_cast_weights(input, *args, **kwargs)
if (getattr(self, 'layout_type', None) is not None and if (getattr(self, 'layout_type', None) is not None and
not isinstance(input, QuantizedTensor)): not isinstance(input, QuantizedTensor)):
input = QuantizedTensor.from_float(input, self.layout_type, scale=getattr(self, 'input_scale', None), dtype=self.weight.dtype)
return self._forward(input, self.weight, self.bias) # Reshape 3D tensors to 2D for quantization (needed for NVFP4 and others)
if tensor_3d:
input = input.reshape(-1, input_shape[2])
if input.ndim != 2:
# Fall back to comfy_cast_weights for non-2D tensors
return self.forward_comfy_cast_weights(input.reshape(input_shape), *args, **kwargs)
# dtype is now implicit in the layout class
input = QuantizedTensor.from_float(input, self.layout_type, scale=getattr(self, 'input_scale', None))
output = self._forward(input, self.weight, self.bias)
# Reshape output back to 3D if input was 3D
if tensor_3d:
output = output.reshape((input_shape[0], input_shape[1], self.weight.shape[0]))
return output
def convert_weight(self, weight, inplace=False, **kwargs): def convert_weight(self, weight, inplace=False, **kwargs):
if isinstance(weight, QuantizedTensor): if isinstance(weight, QuantizedTensor):
@ -619,7 +689,8 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs): def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
if getattr(self, 'layout_type', None) is not None: if getattr(self, 'layout_type', None) is not None:
weight = QuantizedTensor.from_float(weight, self.layout_type, scale="recalculate", dtype=self.weight.dtype, stochastic_rounding=seed, inplace_ops=True) # dtype is now implicit in the layout class
weight = QuantizedTensor.from_float(weight, self.layout_type, scale="recalculate", stochastic_rounding=seed, inplace_ops=True)
else: else:
weight = weight.to(self.weight.dtype) weight = weight.to(self.weight.dtype)
if return_weight: if return_weight:
@ -646,10 +717,17 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None): def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None):
fp8_compute = comfy.model_management.supports_fp8_compute(load_device) # TODO: if we support more ops this needs to be more granular fp8_compute = comfy.model_management.supports_fp8_compute(load_device) # TODO: if we support more ops this needs to be more granular
nvfp4_compute = comfy.model_management.supports_nvfp4_compute(load_device)
if model_config and hasattr(model_config, 'quant_config') and model_config.quant_config: if model_config and hasattr(model_config, 'quant_config') and model_config.quant_config:
logging.info("Using mixed precision operations") logging.info("Using mixed precision operations")
return mixed_precision_ops(model_config.quant_config, compute_dtype, full_precision_mm=not fp8_compute) disabled = set()
if not nvfp4_compute:
disabled.add("nvfp4")
if not fp8_compute:
disabled.add("float8_e4m3fn")
disabled.add("float8_e5m2")
return mixed_precision_ops(model_config.quant_config, compute_dtype, disabled=disabled)
if ( if (
fp8_compute and fp8_compute and

View File

@ -1,577 +1,140 @@
import torch import torch
import logging import logging
from typing import Tuple, Dict
try:
import comfy_kitchen as ck
from comfy_kitchen.tensor import (
QuantizedTensor,
QuantizedLayout,
TensorCoreFP8Layout as _CKFp8Layout,
TensorCoreNVFP4Layout, # Direct import, no wrapper needed
register_layout_op,
register_layout_class,
get_layout_class,
)
_CK_AVAILABLE = True
if torch.version.cuda is None:
ck.registry.disable("cuda")
else:
cuda_version = tuple(map(int, str(torch.version.cuda).split('.')))
if cuda_version < (13,):
ck.registry.disable("cuda")
ck.registry.disable("triton")
for k, v in ck.list_backends().items():
logging.info(f"Found comfy_kitchen backend {k}: {v}")
except ImportError as e:
logging.error(f"Failed to import comfy_kitchen, Error: {e}, fp8 and fp4 support will not be available.")
_CK_AVAILABLE = False
class QuantizedTensor:
pass
class _CKFp8Layout:
pass
class TensorCoreNVFP4Layout:
pass
def register_layout_class(name, cls):
pass
def get_layout_class(name):
return None
import comfy.float import comfy.float
_LAYOUT_REGISTRY = {}
_GENERIC_UTILS = {}
def register_layout_op(torch_op, layout_type):
"""
Decorator to register a layout-specific operation handler.
Args:
torch_op: PyTorch operation (e.g., torch.ops.aten.linear.default)
layout_type: Layout class (e.g., TensorCoreFP8Layout)
Example:
@register_layout_op(torch.ops.aten.linear.default, TensorCoreFP8Layout)
def fp8_linear(func, args, kwargs):
# FP8-specific linear implementation
...
"""
def decorator(handler_func):
if torch_op not in _LAYOUT_REGISTRY:
_LAYOUT_REGISTRY[torch_op] = {}
_LAYOUT_REGISTRY[torch_op][layout_type] = handler_func
return handler_func
return decorator
def register_generic_util(torch_op):
"""
Decorator to register a generic utility that works for all layouts.
Args:
torch_op: PyTorch operation (e.g., torch.ops.aten.detach.default)
Example:
@register_generic_util(torch.ops.aten.detach.default)
def generic_detach(func, args, kwargs):
# Works for any layout
...
"""
def decorator(handler_func):
_GENERIC_UTILS[torch_op] = handler_func
return handler_func
return decorator
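# A minimal combined sketch of the two decorators above (the names and ops below are
# illustrative, not handlers that ship with this file): a layout-specific op can
# dequantize inputs it cannot handle, while a generic util works for every layout.
#
#   @register_layout_op(torch.ops.aten.mul.Tensor, "MyInt8Layout")
#   def my_int8_mul(func, args, kwargs):
#       a, b = args[0], args[1]
#       if isinstance(a, QuantizedTensor):
#           a = a.dequantize()
#       if isinstance(b, QuantizedTensor):
#           b = b.dequantize()
#       return func(a, b, **kwargs)
#
#   @register_generic_util(torch.ops.aten.alias.default)
#   def my_alias(func, args, kwargs):
#       qt = args[0]
#       return QuantizedTensor(qt._qdata, qt._layout_type, dict(qt._layout_params))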
def _get_layout_from_args(args):
for arg in args:
if isinstance(arg, QuantizedTensor):
return arg._layout_type
elif isinstance(arg, (list, tuple)):
for item in arg:
if isinstance(item, QuantizedTensor):
return item._layout_type
return None
def _move_layout_params_to_device(params, device):
new_params = {}
for k, v in params.items():
if isinstance(v, torch.Tensor):
new_params[k] = v.to(device=device)
else:
new_params[k] = v
return new_params
def _copy_layout_params(params):
new_params = {}
for k, v in params.items():
if isinstance(v, torch.Tensor):
new_params[k] = v.clone()
else:
new_params[k] = v
return new_params
def _copy_layout_params_inplace(src, dst, non_blocking=False):
for k, v in src.items():
if isinstance(v, torch.Tensor):
dst[k].copy_(v, non_blocking=non_blocking)
else:
dst[k] = v
class QuantizedLayout:
"""
Base class for quantization layouts.
A layout encapsulates the format-specific logic for quantization/dequantization
and provides a uniform interface for extracting raw tensors needed for computation.
New quantization formats should subclass this and implement the required methods.
"""
@classmethod
def quantize(cls, tensor, **kwargs) -> Tuple[torch.Tensor, Dict]:
raise NotImplementedError(f"{cls.__name__} must implement quantize()")
@staticmethod
def dequantize(qdata, **layout_params) -> torch.Tensor:
raise NotImplementedError("TensorLayout must implement dequantize()")
@classmethod
def get_plain_tensors(cls, qtensor) -> torch.Tensor:
raise NotImplementedError(f"{cls.__name__} must implement get_plain_tensors()")
class QuantizedTensor(torch.Tensor):
"""
Universal quantized tensor that works with any layout.
This tensor subclass uses a pluggable layout system to support multiple
quantization formats (FP8, INT4, INT8, etc.) without code duplication.
The layout_type determines format-specific behavior, while common operations
(detach, clone, to) are handled generically.
Attributes:
_qdata: The quantized tensor data
_layout_type: Layout class (e.g., TensorCoreFP8Layout)
_layout_params: Dict with layout-specific params (scale, zero_point, etc.)
"""
@staticmethod
def __new__(cls, qdata, layout_type, layout_params):
"""
Create a quantized tensor.
Args:
qdata: The quantized data tensor
layout_type: Layout class (subclass of QuantizedLayout)
layout_params: Dict with layout-specific parameters
"""
return torch.Tensor._make_wrapper_subclass(cls, qdata.shape, device=qdata.device, dtype=qdata.dtype, requires_grad=False)
def __init__(self, qdata, layout_type, layout_params):
self._qdata = qdata
self._layout_type = layout_type
self._layout_params = layout_params
def __repr__(self):
layout_name = self._layout_type
param_str = ", ".join(f"{k}={v}" for k, v in list(self._layout_params.items())[:2])
return f"QuantizedTensor(shape={self.shape}, layout={layout_name}, {param_str})"
@property
def layout_type(self):
return self._layout_type
def __tensor_flatten__(self):
"""
Tensor flattening protocol for proper device movement.
"""
inner_tensors = ["_qdata"]
ctx = {
"layout_type": self._layout_type,
}
tensor_params = {}
non_tensor_params = {}
for k, v in self._layout_params.items():
if isinstance(v, torch.Tensor):
tensor_params[k] = v
else:
non_tensor_params[k] = v
ctx["tensor_param_keys"] = list(tensor_params.keys())
ctx["non_tensor_params"] = non_tensor_params
for k, v in tensor_params.items():
attr_name = f"_layout_param_{k}"
object.__setattr__(self, attr_name, v)
inner_tensors.append(attr_name)
return inner_tensors, ctx
@staticmethod
def __tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride):
"""
Tensor unflattening protocol for proper device movement.
Reconstructs the QuantizedTensor after device movement.
"""
layout_type = ctx["layout_type"]
layout_params = dict(ctx["non_tensor_params"])
for key in ctx["tensor_param_keys"]:
attr_name = f"_layout_param_{key}"
layout_params[key] = inner_tensors[attr_name]
return QuantizedTensor(inner_tensors["_qdata"], layout_type, layout_params)
@classmethod
def from_float(cls, tensor, layout_type, **quantize_kwargs) -> 'QuantizedTensor':
qdata, layout_params = LAYOUTS[layout_type].quantize(tensor, **quantize_kwargs)
return cls(qdata, layout_type, layout_params)
def dequantize(self) -> torch.Tensor:
return LAYOUTS[self._layout_type].dequantize(self._qdata, **self._layout_params)
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
kwargs = kwargs or {}
# Step 1: Check generic utilities first (detach, clone, to, etc.)
if func in _GENERIC_UTILS:
return _GENERIC_UTILS[func](func, args, kwargs)
# Step 2: Check layout-specific handlers (linear, matmul, etc.)
layout_type = _get_layout_from_args(args)
if layout_type and func in _LAYOUT_REGISTRY:
handler = _LAYOUT_REGISTRY[func].get(layout_type)
if handler:
return handler(func, args, kwargs)
# Step 3: Fallback to dequantization
if isinstance(args[0] if args else None, QuantizedTensor):
logging.info(f"QuantizedTensor: Unhandled operation {func}, falling back to dequantization. kwargs={kwargs}")
return cls._dequant_and_fallback(func, args, kwargs)
@classmethod
def _dequant_and_fallback(cls, func, args, kwargs):
def dequant_arg(arg):
if isinstance(arg, QuantizedTensor):
return arg.dequantize()
elif isinstance(arg, (list, tuple)):
return type(arg)(dequant_arg(a) for a in arg)
return arg
new_args = dequant_arg(args)
new_kwargs = dequant_arg(kwargs)
return func(*new_args, **new_kwargs)
def data_ptr(self):
return self._qdata.data_ptr()
def is_pinned(self):
return self._qdata.is_pinned()
def is_contiguous(self, *arg, **kwargs):
return self._qdata.is_contiguous(*arg, **kwargs)
def storage(self):
return self._qdata.storage()
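# A minimal usage sketch of the wrapper subclass above (illustration only):
#
#   w = torch.randn(128, 256, dtype=torch.bfloat16)
#   qw = QuantizedTensor.from_float(w, "TensorCoreFP8Layout", scale="recalculate")
#   qw = qw.to("cuda")            # device moves go through the generic utils below
#   w_restored = qw.dequantize()  # approximately equal to w, up to fp8 rounding error
#
# Ops with a registered handler (e.g. torch.ops.aten.linear.default) run on the raw
# fp8 data; anything else dequantizes and falls back to the normal torch kernel.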
# ============================================================================== # ==============================================================================
# Generic Utilities (Layout-Agnostic Operations) # FP8 Layouts with Comfy-Specific Extensions
# ============================================================================== # ==============================================================================
def _create_transformed_qtensor(qt, transform_fn): class _TensorCoreFP8LayoutBase(_CKFp8Layout):
new_data = transform_fn(qt._qdata) FP8_DTYPE = None # Must be overridden in subclass
new_params = _copy_layout_params(qt._layout_params)
return QuantizedTensor(new_data, qt._layout_type, new_params)
def _handle_device_transfer(qt, target_device, target_dtype=None, target_layout=None, op_name="to"):
if target_layout is not None and target_layout != torch.strided:
logging.warning(
f"QuantizedTensor: layout change requested to {target_layout}, "
f"but not supported. Ignoring layout."
)
# Handle device transfer
current_device = qt._qdata.device
if target_device is not None:
# Normalize device for comparison
if isinstance(target_device, str):
target_device = torch.device(target_device)
if isinstance(current_device, str):
current_device = torch.device(current_device)
if target_device != current_device:
logging.debug(f"QuantizedTensor.{op_name}: Moving from {current_device} to {target_device}")
new_q_data = qt._qdata.to(device=target_device)
new_params = _move_layout_params_to_device(qt._layout_params, target_device)
if target_dtype is not None:
new_params["orig_dtype"] = target_dtype
new_qt = QuantizedTensor(new_q_data, qt._layout_type, new_params)
logging.debug(f"QuantizedTensor.{op_name}: Created new tensor on {target_device}")
return new_qt
logging.debug(f"QuantizedTensor.{op_name}: No device change needed, returning original")
return qt
@register_generic_util(torch.ops.aten.detach.default)
def generic_detach(func, args, kwargs):
"""Detach operation - creates a detached copy of the quantized tensor."""
qt = args[0]
if isinstance(qt, QuantizedTensor):
return _create_transformed_qtensor(qt, lambda x: x.detach())
return func(*args, **kwargs)
@register_generic_util(torch.ops.aten.clone.default)
def generic_clone(func, args, kwargs):
"""Clone operation - creates a deep copy of the quantized tensor."""
qt = args[0]
if isinstance(qt, QuantizedTensor):
return _create_transformed_qtensor(qt, lambda x: x.clone())
return func(*args, **kwargs)
@register_generic_util(torch.ops.aten._to_copy.default)
def generic_to_copy(func, args, kwargs):
"""Device/dtype transfer operation - handles .to(device) calls."""
qt = args[0]
if isinstance(qt, QuantizedTensor):
return _handle_device_transfer(
qt,
target_device=kwargs.get('device', None),
target_dtype=kwargs.get('dtype', None),
op_name="_to_copy"
)
return func(*args, **kwargs)
@register_generic_util(torch.ops.aten.to.dtype_layout)
def generic_to_dtype_layout(func, args, kwargs):
"""Handle .to(device) calls using the dtype_layout variant."""
qt = args[0]
if isinstance(qt, QuantizedTensor):
return _handle_device_transfer(
qt,
target_device=kwargs.get('device', None),
target_dtype=kwargs.get('dtype', None),
target_layout=kwargs.get('layout', None),
op_name="to"
)
return func(*args, **kwargs)
@register_generic_util(torch.ops.aten.copy_.default)
def generic_copy_(func, args, kwargs):
qt_dest = args[0]
src = args[1]
non_blocking = args[2] if len(args) > 2 else False
if isinstance(qt_dest, QuantizedTensor):
if isinstance(src, QuantizedTensor):
# Copy from another quantized tensor
qt_dest._qdata.copy_(src._qdata, non_blocking=non_blocking)
qt_dest._layout_type = src._layout_type
orig_dtype = qt_dest._layout_params["orig_dtype"]
_copy_layout_params_inplace(src._layout_params, qt_dest._layout_params, non_blocking=non_blocking)
qt_dest._layout_params["orig_dtype"] = orig_dtype
else:
# Copy from regular tensor - just copy raw data
qt_dest._qdata.copy_(src)
return qt_dest
return func(*args, **kwargs)
@register_generic_util(torch.ops.aten.to.dtype)
def generic_to_dtype(func, args, kwargs):
"""Handle .to(dtype) calls - dtype conversion only."""
src = args[0]
if isinstance(src, QuantizedTensor):
# For dtype-only conversion, just change the orig_dtype, no real cast is needed
target_dtype = args[1] if len(args) > 1 else kwargs.get('dtype')
src._layout_params["orig_dtype"] = target_dtype
return src
return func(*args, **kwargs)
@register_generic_util(torch.ops.aten._has_compatible_shallow_copy_type.default)
def generic_has_compatible_shallow_copy_type(func, args, kwargs):
return True
@register_generic_util(torch.ops.aten.empty_like.default)
def generic_empty_like(func, args, kwargs):
"""Empty_like operation - creates an empty tensor with the same quantized structure."""
qt = args[0]
if isinstance(qt, QuantizedTensor):
# Create empty tensor with same shape and dtype as the quantized data
hp_dtype = kwargs.pop('dtype', qt._layout_params["orig_dtype"])
new_qdata = torch.empty_like(qt._qdata, **kwargs)
# Handle device transfer for layout params
target_device = kwargs.get('device', new_qdata.device)
new_params = _move_layout_params_to_device(qt._layout_params, target_device)
# Update orig_dtype if dtype is specified
new_params['orig_dtype'] = hp_dtype
return QuantizedTensor(new_qdata, qt._layout_type, new_params)
return func(*args, **kwargs)
# ==============================================================================
# FP8 Layout + Operation Handlers
# ==============================================================================
class TensorCoreFP8Layout(QuantizedLayout):
"""
Storage format:
- qdata: FP8 tensor (torch.float8_e4m3fn or torch.float8_e5m2)
- scale: Scalar tensor (float32) for dequantization
- orig_dtype: Original dtype before quantization (for casting back)
"""
@classmethod @classmethod
def quantize(cls, tensor, scale=None, dtype=torch.float8_e4m3fn, stochastic_rounding=0, inplace_ops=False): def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False):
if cls.FP8_DTYPE is None:
raise NotImplementedError(f"{cls.__name__} must define FP8_DTYPE")
orig_dtype = tensor.dtype orig_dtype = tensor.dtype
orig_shape = tuple(tensor.shape)
if isinstance(scale, str) and scale == "recalculate": if isinstance(scale, str) and scale == "recalculate":
scale = torch.amax(tensor.abs()) / torch.finfo(dtype).max scale = torch.amax(tensor.abs()).to(dtype=torch.float32) / torch.finfo(cls.FP8_DTYPE).max
if tensor.dtype not in [torch.float32, torch.bfloat16]: # Prevent scale from being too small
tensor_info = torch.finfo(tensor.dtype)
scale = (1.0 / torch.clamp((1.0 / scale), min=tensor_info.min, max=tensor_info.max))
if scale is not None: if scale is None:
if not isinstance(scale, torch.Tensor): scale = torch.ones((), device=tensor.device, dtype=torch.float32)
scale = torch.tensor(scale) if not isinstance(scale, torch.Tensor):
scale = scale.to(device=tensor.device, dtype=torch.float32) scale = torch.tensor(scale, device=tensor.device, dtype=torch.float32)
if stochastic_rounding > 0:
if inplace_ops: if inplace_ops:
tensor *= (1.0 / scale).to(tensor.dtype) tensor *= (1.0 / scale).to(tensor.dtype)
else: else:
tensor = tensor * (1.0 / scale).to(tensor.dtype) tensor = tensor * (1.0 / scale).to(tensor.dtype)
qdata = comfy.float.stochastic_rounding(tensor, dtype=cls.FP8_DTYPE, seed=stochastic_rounding)
else: else:
scale = torch.ones((), device=tensor.device, dtype=torch.float32) qdata = ck.quantize_per_tensor_fp8(tensor, scale, cls.FP8_DTYPE)
if stochastic_rounding > 0: params = cls.Params(scale=scale.float(), orig_dtype=orig_dtype, orig_shape=orig_shape)
tensor = comfy.float.stochastic_rounding(tensor, dtype=dtype, seed=stochastic_rounding) return qdata, params
else:
lp_amax = torch.finfo(dtype).max
torch.clamp(tensor, min=-lp_amax, max=lp_amax, out=tensor)
tensor = tensor.to(dtype, memory_format=torch.contiguous_format)
layout_params = {
'scale': scale,
'orig_dtype': orig_dtype
}
return tensor, layout_params
@staticmethod class TensorCoreFP8E4M3Layout(_TensorCoreFP8LayoutBase):
def dequantize(qdata, scale, orig_dtype, **kwargs): FP8_DTYPE = torch.float8_e4m3fn
plain_tensor = torch.ops.aten._to_copy.default(qdata, dtype=orig_dtype)
plain_tensor.mul_(scale)
return plain_tensor
@classmethod
def get_plain_tensors(cls, qtensor): class TensorCoreFP8E5M2Layout(_TensorCoreFP8LayoutBase):
return qtensor._qdata, qtensor._layout_params['scale'] FP8_DTYPE = torch.float8_e5m2
# Backward compatibility alias - default to E4M3
TensorCoreFP8Layout = TensorCoreFP8E4M3Layout
# ==============================================================================
# Registry
# ==============================================================================
register_layout_class("TensorCoreFP8Layout", TensorCoreFP8Layout)
register_layout_class("TensorCoreFP8E4M3Layout", TensorCoreFP8E4M3Layout)
register_layout_class("TensorCoreFP8E5M2Layout", TensorCoreFP8E5M2Layout)
register_layout_class("TensorCoreNVFP4Layout", TensorCoreNVFP4Layout)
QUANT_ALGOS = { QUANT_ALGOS = {
"float8_e4m3fn": { "float8_e4m3fn": {
"storage_t": torch.float8_e4m3fn, "storage_t": torch.float8_e4m3fn,
"parameters": {"weight_scale", "input_scale"}, "parameters": {"weight_scale", "input_scale"},
"comfy_tensor_layout": "TensorCoreFP8Layout", "comfy_tensor_layout": "TensorCoreFP8E4M3Layout",
},
"float8_e5m2": {
"storage_t": torch.float8_e5m2,
"parameters": {"weight_scale", "input_scale"},
"comfy_tensor_layout": "TensorCoreFP8E5M2Layout",
},
"nvfp4": {
"storage_t": torch.uint8,
"parameters": {"weight_scale", "weight_scale_2", "input_scale"},
"comfy_tensor_layout": "TensorCoreNVFP4Layout",
"group_size": 16,
}, },
} }
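# A minimal sketch of how a QUANT_ALGOS entry is consumed (illustrative): a checkpoint
# quantized with "float8_e4m3fn" stores its weights as storage_t plus the listed scale
# parameters, and the "comfy_tensor_layout" string is resolved to a layout class by name:
#
#   algo = QUANT_ALGOS["float8_e4m3fn"]
#   layout_cls = get_layout_class(algo["comfy_tensor_layout"])  # name-based lookup
#   # layout_cls then provides quantize()/dequantize() for tensors stored as algo["storage_t"]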
LAYOUTS = {
"TensorCoreFP8Layout": TensorCoreFP8Layout,
}
# ==============================================================================
# Re-exports for backward compatibility
# ==============================================================================
@register_layout_op(torch.ops.aten.linear.default, "TensorCoreFP8Layout") __all__ = [
def fp8_linear(func, args, kwargs): "QuantizedTensor",
input_tensor = args[0] "QuantizedLayout",
weight = args[1] "TensorCoreFP8Layout",
bias = args[2] if len(args) > 2 else None "TensorCoreFP8E4M3Layout",
"TensorCoreFP8E5M2Layout",
if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor): "TensorCoreNVFP4Layout",
plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor) "QUANT_ALGOS",
plain_weight, scale_b = TensorCoreFP8Layout.get_plain_tensors(weight) "register_layout_op",
]
out_dtype = kwargs.get("out_dtype")
if out_dtype is None:
out_dtype = input_tensor._layout_params['orig_dtype']
weight_t = plain_weight.t()
tensor_2d = False
if len(plain_input.shape) == 2:
tensor_2d = True
plain_input = plain_input.unsqueeze(1)
input_shape = plain_input.shape
if len(input_shape) != 3:
return None
try:
output = torch._scaled_mm(
plain_input.reshape(-1, input_shape[2]).contiguous(),
weight_t,
bias=bias,
scale_a=scale_a,
scale_b=scale_b,
out_dtype=out_dtype,
)
if isinstance(output, tuple): # TODO: remove when we drop support for torch 2.4
output = output[0]
if not tensor_2d:
output = output.reshape((-1, input_shape[1], weight.shape[0]))
if output.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
output_scale = scale_a * scale_b
output_params = {
'scale': output_scale,
'orig_dtype': input_tensor._layout_params['orig_dtype']
}
return QuantizedTensor(output, "TensorCoreFP8Layout", output_params)
else:
return output
except Exception as e:
raise RuntimeError(f"FP8 _scaled_mm failed, falling back to dequantization: {e}")
# Case 2: DQ Fallback
if isinstance(weight, QuantizedTensor):
weight = weight.dequantize()
if isinstance(input_tensor, QuantizedTensor):
input_tensor = input_tensor.dequantize()
return torch.nn.functional.linear(input_tensor, weight, bias)
def fp8_mm_(input_tensor, weight, bias=None, out_dtype=None):
if out_dtype is None:
out_dtype = input_tensor._layout_params['orig_dtype']
plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor)
plain_weight, scale_b = TensorCoreFP8Layout.get_plain_tensors(weight)
output = torch._scaled_mm(
plain_input.contiguous(),
plain_weight,
bias=bias,
scale_a=scale_a,
scale_b=scale_b,
out_dtype=out_dtype,
)
if isinstance(output, tuple): # TODO: remove when we drop support for torch 2.4
output = output[0]
return output
@register_layout_op(torch.ops.aten.addmm.default, "TensorCoreFP8Layout")
def fp8_addmm(func, args, kwargs):
input_tensor = args[1]
weight = args[2]
bias = args[0]
if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor):
return fp8_mm_(input_tensor, weight, bias=bias, out_dtype=kwargs.get("out_dtype", None))
a = list(args)
if isinstance(args[0], QuantizedTensor):
a[0] = args[0].dequantize()
if isinstance(args[1], QuantizedTensor):
a[1] = args[1].dequantize()
if isinstance(args[2], QuantizedTensor):
a[2] = args[2].dequantize()
return func(*a, **kwargs)
@register_layout_op(torch.ops.aten.mm.default, "TensorCoreFP8Layout")
def fp8_mm(func, args, kwargs):
input_tensor = args[0]
weight = args[1]
if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor):
return fp8_mm_(input_tensor, weight, bias=None, out_dtype=kwargs.get("out_dtype", None))
a = list(args)
if isinstance(args[0], QuantizedTensor):
a[0] = args[0].dequantize()
if isinstance(args[1], QuantizedTensor):
a[1] = args[1].dequantize()
return func(*a, **kwargs)
@register_layout_op(torch.ops.aten.view.default, "TensorCoreFP8Layout")
@register_layout_op(torch.ops.aten.t.default, "TensorCoreFP8Layout")
def fp8_func(func, args, kwargs):
input_tensor = args[0]
if isinstance(input_tensor, QuantizedTensor):
plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor)
ar = list(args)
ar[0] = plain_input
return QuantizedTensor(func(*ar, **kwargs), "TensorCoreFP8Layout", input_tensor._layout_params)
return func(*args, **kwargs)

View File

@ -122,20 +122,20 @@ def estimate_memory(model, noise_shape, conds):
minimum_memory_required = model.model.memory_required([noise_shape[0]] + list(noise_shape[1:]), cond_shapes=cond_shapes_min) minimum_memory_required = model.model.memory_required([noise_shape[0]] + list(noise_shape[1:]), cond_shapes=cond_shapes_min)
return memory_required, minimum_memory_required return memory_required, minimum_memory_required
def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None): def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False):
executor = comfy.patcher_extension.WrapperExecutor.new_executor( executor = comfy.patcher_extension.WrapperExecutor.new_executor(
_prepare_sampling, _prepare_sampling,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.PREPARE_SAMPLING, model_options, is_model_options=True) comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.PREPARE_SAMPLING, model_options, is_model_options=True)
) )
return executor.execute(model, noise_shape, conds, model_options=model_options) return executor.execute(model, noise_shape, conds, model_options=model_options, force_full_load=force_full_load)
def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None): def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False):
real_model: BaseModel = None real_model: BaseModel = None
models, inference_memory = get_additional_models(conds, model.model_dtype()) models, inference_memory = get_additional_models(conds, model.model_dtype())
models += get_additional_models_from_model_options(model_options) models += get_additional_models_from_model_options(model_options)
models += model.get_nested_additional_models() # TODO: does this require inference_memory update? models += model.get_nested_additional_models() # TODO: does this require inference_memory update?
memory_required, minimum_memory_required = estimate_memory(model, noise_shape, conds) memory_required, minimum_memory_required = estimate_memory(model, noise_shape, conds)
comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required + inference_memory, minimum_memory_required=minimum_memory_required + inference_memory) comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required + inference_memory, minimum_memory_required=minimum_memory_required + inference_memory, force_full_load=force_full_load)
real_model = model.model real_model = model.model
return real_model, conds, models return real_model, conds, models

View File

@ -720,7 +720,7 @@ class Sampler:
sigma = float(sigmas[0]) sigma = float(sigmas[0])
return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma
KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2","dpm_2", "dpm_2_ancestral", KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2", "exp_heun_2_x0", "exp_heun_2_x0_sde", "dpm_2", "dpm_2_ancestral",
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu", "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
"dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_2m_sde_heun", "dpmpp_2m_sde_heun_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm", "dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_2m_sde_heun", "dpmpp_2m_sde_heun_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
"ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp", "res_multistep_ancestral", "res_multistep_ancestral_cfg_pp", "ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp", "res_multistep_ancestral", "res_multistep_ancestral_cfg_pp",
@ -984,9 +984,6 @@ class CFGGuider:
self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options) self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options)
device = self.model_patcher.load_device device = self.model_patcher.load_device
if denoise_mask is not None:
denoise_mask = comfy.sampler_helpers.prepare_mask(denoise_mask, noise.shape, device)
noise = noise.to(device) noise = noise.to(device)
latent_image = latent_image.to(device) latent_image = latent_image.to(device)
sigmas = sigmas.to(device) sigmas = sigmas.to(device)
@ -1013,6 +1010,24 @@ class CFGGuider:
else: else:
latent_shapes = [latent_image.shape] latent_shapes = [latent_image.shape]
if denoise_mask is not None:
if denoise_mask.is_nested:
denoise_masks = denoise_mask.unbind()
denoise_masks = denoise_masks[:len(latent_shapes)]
else:
denoise_masks = [denoise_mask]
for i in range(len(denoise_masks), len(latent_shapes)):
denoise_masks.append(torch.ones(latent_shapes[i]))
for i in range(len(denoise_masks)):
denoise_masks[i] = comfy.sampler_helpers.prepare_mask(denoise_masks[i], latent_shapes[i], self.model_patcher.load_device)
if len(denoise_masks) > 1:
denoise_mask, _ = comfy.utils.pack_latents(denoise_masks)
else:
denoise_mask = denoise_masks[0]
self.conds = {} self.conds = {}
for k in self.original_conds: for k in self.original_conds:
self.conds[k] = list(map(lambda a: a.copy(), self.original_conds[k])) self.conds[k] = list(map(lambda a: a.copy(), self.original_conds[k]))

View File

@ -55,6 +55,8 @@ import comfy.text_encoders.hunyuan_image
import comfy.text_encoders.z_image import comfy.text_encoders.z_image
import comfy.text_encoders.ovis import comfy.text_encoders.ovis
import comfy.text_encoders.kandinsky5 import comfy.text_encoders.kandinsky5
import comfy.text_encoders.jina_clip_2
import comfy.text_encoders.newbie
import comfy.model_patcher import comfy.model_patcher
import comfy.lora import comfy.lora
@ -127,6 +129,8 @@ class CLIP:
self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
self.patcher = comfy.model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device) self.patcher = comfy.model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
#Match torch.float32 hardcoded upcast in the TE implementation
self.patcher.set_model_compute_dtype(torch.float32)
self.patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram self.patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram
self.patcher.is_clip = True self.patcher.is_clip = True
self.apply_hooks_to_conds = None self.apply_hooks_to_conds = None
@ -319,6 +323,7 @@ class VAE:
self.latent_channels = 4 self.latent_channels = 4
self.latent_dim = 2 self.latent_dim = 2
self.output_channels = 3 self.output_channels = 3
self.pad_channel_value = None
self.process_input = lambda image: image * 2.0 - 1.0 self.process_input = lambda image: image * 2.0 - 1.0
self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0) self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
self.working_dtypes = [torch.bfloat16, torch.float32] self.working_dtypes = [torch.bfloat16, torch.float32]
@ -433,6 +438,7 @@ class VAE:
self.memory_used_decode = lambda shape, dtype: (1000 * shape[2] * 2048) * model_management.dtype_size(dtype) self.memory_used_decode = lambda shape, dtype: (1000 * shape[2] * 2048) * model_management.dtype_size(dtype)
self.latent_channels = 64 self.latent_channels = 64
self.output_channels = 2 self.output_channels = 2
self.pad_channel_value = "replicate"
self.upscale_ratio = 2048 self.upscale_ratio = 2048
self.downscale_ratio = 2048 self.downscale_ratio = 2048
self.latent_dim = 1 self.latent_dim = 1
@ -544,11 +550,15 @@ class VAE:
self.downscale_index_formula = (4, 8, 8) self.downscale_index_formula = (4, 8, 8)
self.latent_dim = 3 self.latent_dim = 3
self.latent_channels = 16 self.latent_channels = 16
ddconfig = {"dim": dim, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "dropout": 0.0} self.output_channels = sd["encoder.conv1.weight"].shape[1]
self.pad_channel_value = 1.0
ddconfig = {"dim": dim, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "image_channels": self.output_channels, "dropout": 0.0}
self.first_stage_model = comfy.ldm.wan.vae.WanVAE(**ddconfig) self.first_stage_model = comfy.ldm.wan.vae.WanVAE(**ddconfig)
self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32] self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
self.memory_used_encode = lambda shape, dtype: 6000 * shape[3] * shape[4] * model_management.dtype_size(dtype) self.memory_used_encode = lambda shape, dtype: (1500 if shape[2]<=4 else 6000) * shape[3] * shape[4] * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: 7000 * shape[3] * shape[4] * (8 * 8) * model_management.dtype_size(dtype) self.memory_used_decode = lambda shape, dtype: (2200 if shape[2]<=4 else 7000) * shape[3] * shape[4] * (8*8) * model_management.dtype_size(dtype)
# Hunyuan 3d v2 2.0 & 2.1 # Hunyuan 3d v2 2.0 & 2.1
elif "geo_decoder.cross_attn_decoder.ln_1.bias" in sd: elif "geo_decoder.cross_attn_decoder.ln_1.bias" in sd:
@ -578,6 +588,7 @@ class VAE:
self.memory_used_decode = lambda shape, dtype: (shape[2] * shape[3] * 87000) * model_management.dtype_size(dtype) self.memory_used_decode = lambda shape, dtype: (shape[2] * shape[3] * 87000) * model_management.dtype_size(dtype)
self.latent_channels = 8 self.latent_channels = 8
self.output_channels = 2 self.output_channels = 2
self.pad_channel_value = "replicate"
self.upscale_ratio = 4096 self.upscale_ratio = 4096
self.downscale_ratio = 4096 self.downscale_ratio = 4096
self.latent_dim = 2 self.latent_dim = 2
@ -686,17 +697,28 @@ class VAE:
raise RuntimeError("ERROR: VAE is invalid: None\n\nIf the VAE is from a checkpoint loader node your checkpoint does not contain a valid VAE.") raise RuntimeError("ERROR: VAE is invalid: None\n\nIf the VAE is from a checkpoint loader node your checkpoint does not contain a valid VAE.")
def vae_encode_crop_pixels(self, pixels): def vae_encode_crop_pixels(self, pixels):
if not self.crop_input: if self.crop_input:
return pixels downscale_ratio = self.spacial_compression_encode()
downscale_ratio = self.spacial_compression_encode() dims = pixels.shape[1:-1]
for d in range(len(dims)):
x = (dims[d] // downscale_ratio) * downscale_ratio
x_offset = (dims[d] % downscale_ratio) // 2
if x != dims[d]:
pixels = pixels.narrow(d + 1, x_offset, x)
dims = pixels.shape[1:-1] if pixels.shape[-1] > self.output_channels:
for d in range(len(dims)): pixels = pixels[..., :self.output_channels]
x = (dims[d] // downscale_ratio) * downscale_ratio elif pixels.shape[-1] < self.output_channels:
x_offset = (dims[d] % downscale_ratio) // 2 if self.pad_channel_value is not None:
if x != dims[d]: if isinstance(self.pad_channel_value, str):
pixels = pixels.narrow(d + 1, x_offset, x) mode = self.pad_channel_value
value = None
else:
mode = "constant"
value = self.pad_channel_value
pixels = torch.nn.functional.pad(pixels, (0, self.output_channels - pixels.shape[-1]), mode=mode, value=value)
return pixels return pixels
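# A minimal sketch of the channel adjustment above (shapes and values are illustrative):
# inputs with too many channels are cropped, and inputs with too few are padded on the
# last dimension when pad_channel_value is set, e.g.
#
#   audio = torch.randn(1, 44100, 1)                                           # [B, T, C], mono
#   audio = torch.nn.functional.pad(audio, (0, 1), mode="replicate")           # -> 2 channels
#   image = torch.randn(1, 512, 512, 1)                                        # [B, H, W, C]
#   image = torch.nn.functional.pad(image, (0, 2), mode="constant", value=1.0) # -> 3 channels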
def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16): def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16):
@ -988,6 +1010,7 @@ class CLIPType(Enum):
OVIS = 21 OVIS = 21
KANDINSKY5 = 22 KANDINSKY5 = 22
KANDINSKY5_IMAGE = 23 KANDINSKY5_IMAGE = 23
NEWBIE = 24
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}): def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
@ -1018,6 +1041,8 @@ class TEModel(Enum):
MISTRAL3_24B_PRUNED_FLUX2 = 15 MISTRAL3_24B_PRUNED_FLUX2 = 15
QWEN3_4B = 16 QWEN3_4B = 16
QWEN3_2B = 17 QWEN3_2B = 17
GEMMA_3_12B = 18
JINA_CLIP_2 = 19
def detect_te_model(sd): def detect_te_model(sd):
@ -1027,6 +1052,8 @@ def detect_te_model(sd):
return TEModel.CLIP_H return TEModel.CLIP_H
if "text_model.encoder.layers.0.mlp.fc1.weight" in sd: if "text_model.encoder.layers.0.mlp.fc1.weight" in sd:
return TEModel.CLIP_L return TEModel.CLIP_L
if "model.encoder.layers.0.mixer.Wqkv.weight" in sd:
return TEModel.JINA_CLIP_2
if "encoder.block.23.layer.1.DenseReluDense.wi_1.weight" in sd: if "encoder.block.23.layer.1.DenseReluDense.wi_1.weight" in sd:
weight = sd["encoder.block.23.layer.1.DenseReluDense.wi_1.weight"] weight = sd["encoder.block.23.layer.1.DenseReluDense.wi_1.weight"]
if weight.shape[-1] == 4096: if weight.shape[-1] == 4096:
@ -1041,6 +1068,8 @@ def detect_te_model(sd):
return TEModel.BYT5_SMALL_GLYPH return TEModel.BYT5_SMALL_GLYPH
return TEModel.T5_BASE return TEModel.T5_BASE
if 'model.layers.0.post_feedforward_layernorm.weight' in sd: if 'model.layers.0.post_feedforward_layernorm.weight' in sd:
if 'model.layers.47.self_attn.q_norm.weight' in sd:
return TEModel.GEMMA_3_12B
if 'model.layers.0.self_attn.q_norm.weight' in sd: if 'model.layers.0.self_attn.q_norm.weight' in sd:
return TEModel.GEMMA_3_4B return TEModel.GEMMA_3_4B
return TEModel.GEMMA_2_2B return TEModel.GEMMA_2_2B
@ -1187,6 +1216,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
elif te_model == TEModel.QWEN3_2B: elif te_model == TEModel.QWEN3_2B:
clip_target.clip = comfy.text_encoders.ovis.te(**llama_detect(clip_data)) clip_target.clip = comfy.text_encoders.ovis.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.ovis.OvisTokenizer clip_target.tokenizer = comfy.text_encoders.ovis.OvisTokenizer
elif te_model == TEModel.JINA_CLIP_2:
clip_target.clip = comfy.text_encoders.jina_clip_2.JinaClip2TextModelWrapper
clip_target.tokenizer = comfy.text_encoders.jina_clip_2.JinaClip2TokenizerWrapper
else: else:
# clip_l # clip_l
if clip_type == CLIPType.SD3: if clip_type == CLIPType.SD3:
@ -1242,6 +1274,21 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
elif clip_type == CLIPType.KANDINSKY5_IMAGE: elif clip_type == CLIPType.KANDINSKY5_IMAGE:
clip_target.clip = comfy.text_encoders.kandinsky5.te(**llama_detect(clip_data)) clip_target.clip = comfy.text_encoders.kandinsky5.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage clip_target.tokenizer = comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage
elif clip_type == CLIPType.LTXV:
clip_target.clip = comfy.text_encoders.lt.ltxav_te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.lt.LTXAVGemmaTokenizer
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
elif clip_type == CLIPType.NEWBIE:
clip_target.clip = comfy.text_encoders.newbie.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.newbie.NewBieTokenizer
if "model.layers.0.self_attn.q_norm.weight" in clip_data[0]:
clip_data_gemma = clip_data[0]
clip_data_jina = clip_data[1]
else:
clip_data_gemma = clip_data[1]
clip_data_jina = clip_data[0]
tokenizer_data["gemma_spiece_model"] = clip_data_gemma.get("spiece_model", None)
tokenizer_data["jina_spiece_model"] = clip_data_jina.get("spiece_model", None)
else: else:
clip_target.clip = sdxl_clip.SDXLClipModel clip_target.clip = sdxl_clip.SDXLClipModel
clip_target.tokenizer = sdxl_clip.SDXLTokenizer clip_target.tokenizer = sdxl_clip.SDXLTokenizer

View File

@ -466,7 +466,7 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
return embed_out return embed_out
class SDTokenizer: class SDTokenizer:
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, min_padding=None, pad_left=False, tokenizer_data={}, tokenizer_args={}): def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, min_padding=None, pad_left=False, disable_weights=False, tokenizer_data={}, tokenizer_args={}):
if tokenizer_path is None: if tokenizer_path is None:
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer") tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path, **tokenizer_args) self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path, **tokenizer_args)
@ -513,6 +513,8 @@ class SDTokenizer:
self.embedding_size = embedding_size self.embedding_size = embedding_size
self.embedding_key = embedding_key self.embedding_key = embedding_key
self.disable_weights = disable_weights
def _try_get_embedding(self, embedding_name:str): def _try_get_embedding(self, embedding_name:str):
''' '''
Takes a potential embedding name and tries to retrieve it. Takes a potential embedding name and tries to retrieve it.
@ -547,7 +549,7 @@ class SDTokenizer:
min_padding = tokenizer_options.get("{}_min_padding".format(self.embedding_key), self.min_padding) min_padding = tokenizer_options.get("{}_min_padding".format(self.embedding_key), self.min_padding)
text = escape_important(text) text = escape_important(text)
if kwargs.get("disable_weights", False): if kwargs.get("disable_weights", self.disable_weights):
parsed_weights = [(text, 1.0)] parsed_weights = [(text, 1.0)]
else: else:
parsed_weights = token_weights(text, 1.0) parsed_weights = token_weights(text, 1.0)

View File

@ -28,6 +28,7 @@ from . import supported_models_base
from . import latent_formats from . import latent_formats
from . import diffusers_convert from . import diffusers_convert
import comfy.model_management
class SD15(supported_models_base.BASE): class SD15(supported_models_base.BASE):
unet_config = { unet_config = {
@ -541,7 +542,7 @@ class SD3(supported_models_base.BASE):
unet_extra_config = {} unet_extra_config = {}
latent_format = latent_formats.SD3 latent_format = latent_formats.SD3
memory_usage_factor = 1.2 memory_usage_factor = 1.6
text_encoder_key_prefix = ["text_encoders."] text_encoder_key_prefix = ["text_encoders."]
@ -835,6 +836,21 @@ class LTXV(supported_models_base.BASE):
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref)) t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.lt.LTXVT5Tokenizer, comfy.text_encoders.lt.ltxv_te(**t5_detect)) return supported_models_base.ClipTarget(comfy.text_encoders.lt.LTXVT5Tokenizer, comfy.text_encoders.lt.ltxv_te(**t5_detect))
class LTXAV(LTXV):
unet_config = {
"image_model": "ltxav",
}
latent_format = latent_formats.LTXAV
def __init__(self, unet_config):
super().__init__(unet_config)
self.memory_usage_factor = 0.055 # TODO
def get_model(self, state_dict, prefix="", device=None):
out = model_base.LTXAV(self, device=device)
return out
class HunyuanVideo(supported_models_base.BASE): class HunyuanVideo(supported_models_base.BASE):
unet_config = { unet_config = {
"image_model": "hunyuan_video", "image_model": "hunyuan_video",
@ -965,7 +981,7 @@ class CosmosT2IPredict2(supported_models_base.BASE):
def __init__(self, unet_config): def __init__(self, unet_config):
super().__init__(unet_config) super().__init__(unet_config)
self.memory_usage_factor = (unet_config.get("model_channels", 2048) / 2048) * 0.9 self.memory_usage_factor = (unet_config.get("model_channels", 2048) / 2048) * 0.95
def get_model(self, state_dict, prefix="", device=None): def get_model(self, state_dict, prefix="", device=None):
out = model_base.CosmosPredict2(self, device=device) out = model_base.CosmosPredict2(self, device=device)
@ -1026,9 +1042,15 @@ class ZImage(Lumina2):
"shift": 3.0, "shift": 3.0,
} }
memory_usage_factor = 1.7 memory_usage_factor = 2.0
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32] supported_inference_dtypes = [torch.bfloat16, torch.float32]
def __init__(self, unet_config):
super().__init__(unet_config)
if comfy.model_management.extended_fp16_support():
self.supported_inference_dtypes = self.supported_inference_dtypes.copy()
self.supported_inference_dtypes.insert(1, torch.float16)
def clip_target(self, state_dict={}): def clip_target(self, state_dict={}):
pref = self.text_encoder_key_prefix[0] pref = self.text_encoder_key_prefix[0]
@ -1289,7 +1311,7 @@ class ChromaRadiance(Chroma):
latent_format = comfy.latent_formats.ChromaRadiance latent_format = comfy.latent_formats.ChromaRadiance
# Pixel-space model, no spatial compression for model input. # Pixel-space model, no spatial compression for model input.
memory_usage_factor = 0.038 memory_usage_factor = 0.044
def get_model(self, state_dict, prefix="", device=None): def get_model(self, state_dict, prefix="", device=None):
return model_base.ChromaRadiance(self, device=device) return model_base.ChromaRadiance(self, device=device)
@ -1332,7 +1354,7 @@ class Omnigen2(supported_models_base.BASE):
"shift": 2.6, "shift": 2.6,
} }
memory_usage_factor = 1.65 #TODO memory_usage_factor = 1.95 #TODO
unet_extra_config = {} unet_extra_config = {}
latent_format = latent_formats.Flux latent_format = latent_formats.Flux
@ -1397,7 +1419,7 @@ class HunyuanImage21(HunyuanVideo):
latent_format = latent_formats.HunyuanImage21 latent_format = latent_formats.HunyuanImage21
memory_usage_factor = 7.7 memory_usage_factor = 8.7
supported_inference_dtypes = [torch.bfloat16, torch.float32] supported_inference_dtypes = [torch.bfloat16, torch.float32]
@ -1488,7 +1510,7 @@ class Kandinsky5(supported_models_base.BASE):
unet_extra_config = {} unet_extra_config = {}
latent_format = latent_formats.HunyuanVideo latent_format = latent_formats.HunyuanVideo
memory_usage_factor = 1.1 #TODO memory_usage_factor = 1.25 #TODO
supported_inference_dtypes = [torch.bfloat16, torch.float32] supported_inference_dtypes = [torch.bfloat16, torch.float32]
@ -1517,7 +1539,7 @@ class Kandinsky5Image(Kandinsky5):
} }
latent_format = latent_formats.Flux latent_format = latent_formats.Flux
memory_usage_factor = 1.1 #TODO memory_usage_factor = 1.25 #TODO
def get_model(self, state_dict, prefix="", device=None): def get_model(self, state_dict, prefix="", device=None):
out = model_base.Kandinsky5Image(self, device=device) out = model_base.Kandinsky5Image(self, device=device)
@ -1529,6 +1551,6 @@ class Kandinsky5Image(Kandinsky5):
return supported_models_base.ClipTarget(comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage, comfy.text_encoders.kandinsky5.te(**hunyuan_detect)) return supported_models_base.ClipTarget(comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage, comfy.text_encoders.kandinsky5.te(**hunyuan_detect))
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5] models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5]
models += [SVD_img2vid] models += [SVD_img2vid]

View File

@ -154,7 +154,8 @@ class TAEHV(nn.Module):
self._show_progress_bar = value self._show_progress_bar = value
def encode(self, x, **kwargs): def encode(self, x, **kwargs):
if self.patch_size > 1: x = F.pixel_unshuffle(x, self.patch_size) if self.patch_size > 1:
x = F.pixel_unshuffle(x, self.patch_size)
x = x.movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W] x = x.movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W]
if x.shape[1] % 4 != 0: if x.shape[1] % 4 != 0:
# pad at end to multiple of 4 # pad at end to multiple of 4
@ -167,5 +168,6 @@ class TAEHV(nn.Module):
def decode(self, x, **kwargs): def decode(self, x, **kwargs):
x = self.process_in(x).movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W] x = self.process_in(x).movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W]
x = apply_model_with_memblocks(self.decoder, x, self.parallel, self.show_progress_bar) x = apply_model_with_memblocks(self.decoder, x, self.parallel, self.show_progress_bar)
if self.patch_size > 1: x = F.pixel_shuffle(x, self.patch_size) if self.patch_size > 1:
x = F.pixel_shuffle(x, self.patch_size)
return x[:, self.frames_to_trim:].movedim(2, 1) return x[:, self.frames_to_trim:].movedim(2, 1)

View File

@ -0,0 +1,219 @@
# Jina CLIP v2 and Jina Embeddings v3 both use Jina's modified XLM-RoBERTa architecture. Reference implementations:
# Jina CLIP v2 (both text and vision): https://huggingface.co/jinaai/jina-clip-implementation/blob/39e6a55ae971b59bea6e44675d237c99762e7ee2/modeling_clip.py
# Jina XLM-RoBERTa (text only): http://huggingface.co/jinaai/xlm-roberta-flash-implementation/blob/2b6bc3f30750b3a9648fe9b63448c09920efe9be/modeling_xlm_roberta.py
from dataclasses import dataclass
import torch
from torch import nn as nn
from torch.nn import functional as F
import comfy.model_management
import comfy.ops
from comfy import sd1_clip
from .spiece_tokenizer import SPieceTokenizer
class JinaClip2Tokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer = tokenizer_data.get("spiece_model", None)
# The official NewBie uses max_length=8000, but Jina Embeddings v3 actually supports 8192
super().__init__(tokenizer, pad_with_end=False, embedding_size=1024, embedding_key='jina_clip_2', tokenizer_class=SPieceTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=False, max_length=8192, min_length=1, pad_token=1, end_token=2, tokenizer_args={"add_bos": True, "add_eos": True}, tokenizer_data=tokenizer_data)
def state_dict(self):
return {"spiece_model": self.tokenizer.serialize_model()}
class JinaClip2TokenizerWrapper(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, tokenizer=JinaClip2Tokenizer, name="jina_clip_2")
# https://huggingface.co/jinaai/jina-embeddings-v3/blob/343dbf534c76fe845f304fa5c2d1fd87e1e78918/config.json
@dataclass
class XLMRobertaConfig:
vocab_size: int = 250002
type_vocab_size: int = 1
hidden_size: int = 1024
num_hidden_layers: int = 24
num_attention_heads: int = 16
rotary_emb_base: float = 20000.0
intermediate_size: int = 4096
hidden_act: str = "gelu"
hidden_dropout_prob: float = 0.1
attention_probs_dropout_prob: float = 0.1
layer_norm_eps: float = 1e-05
bos_token_id: int = 0
eos_token_id: int = 2
pad_token_id: int = 1
class XLMRobertaEmbeddings(nn.Module):
def __init__(self, config, device=None, dtype=None, ops=None):
super().__init__()
embed_dim = config.hidden_size
self.word_embeddings = ops.Embedding(config.vocab_size, embed_dim, padding_idx=config.pad_token_id, device=device, dtype=dtype)
self.token_type_embeddings = ops.Embedding(config.type_vocab_size, embed_dim, device=device, dtype=dtype)
def forward(self, input_ids=None, embeddings=None):
if input_ids is not None and embeddings is None:
embeddings = self.word_embeddings(input_ids)
if embeddings is not None:
token_type_ids = torch.zeros(embeddings.shape[1], device=embeddings.device, dtype=torch.int32)
token_type_embeddings = self.token_type_embeddings(token_type_ids)
embeddings = embeddings + token_type_embeddings
return embeddings
class RotaryEmbedding(nn.Module):
def __init__(self, dim, base, device=None):
super().__init__()
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
self._seq_len_cached = 0
self._cos_cached = None
self._sin_cached = None
def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
if seqlen > self._seq_len_cached or self._cos_cached is None or self._cos_cached.device != device or self._cos_cached.dtype != dtype:
self._seq_len_cached = seqlen
t = torch.arange(seqlen, device=device, dtype=torch.float32)
freqs = torch.outer(t, self.inv_freq.to(device=t.device))
emb = torch.cat((freqs, freqs), dim=-1)
self._cos_cached = emb.cos().to(dtype)
self._sin_cached = emb.sin().to(dtype)
def forward(self, q, k):
batch, seqlen, heads, head_dim = q.shape
self._update_cos_sin_cache(seqlen, device=q.device, dtype=q.dtype)
cos = self._cos_cached[:seqlen].view(1, seqlen, 1, head_dim)
sin = self._sin_cached[:seqlen].view(1, seqlen, 1, head_dim)
def rotate_half(x):
size = x.shape[-1] // 2
x1, x2 = x[..., :size], x[..., size:]
return torch.cat((-x2, x1), dim=-1)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
class MHA(nn.Module):
def __init__(self, config, device=None, dtype=None, ops=None):
super().__init__()
embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = embed_dim // config.num_attention_heads
self.rotary_emb = RotaryEmbedding(self.head_dim, config.rotary_emb_base, device=device)
self.Wqkv = ops.Linear(embed_dim, 3 * embed_dim, device=device, dtype=dtype)
self.out_proj = ops.Linear(embed_dim, embed_dim, device=device, dtype=dtype)
def forward(self, x, mask=None, optimized_attention=None):
qkv = self.Wqkv(x)
batch_size, seq_len, _ = qkv.shape
qkv = qkv.view(batch_size, seq_len, 3, self.num_heads, self.head_dim)
q, k, v = qkv.unbind(2)
q, k = self.rotary_emb(q, k)
# NHD -> HND
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
out = optimized_attention(q, k, v, heads=self.num_heads, mask=mask, skip_reshape=True)
return self.out_proj(out)
class MLP(nn.Module):
def __init__(self, config, device=None, dtype=None, ops=None):
super().__init__()
self.fc1 = ops.Linear(config.hidden_size, config.intermediate_size, device=device, dtype=dtype)
self.activation = F.gelu
self.fc2 = ops.Linear(config.intermediate_size, config.hidden_size, device=device, dtype=dtype)
def forward(self, x):
x = self.fc1(x)
x = self.activation(x)
x = self.fc2(x)
return x
class Block(nn.Module):
def __init__(self, config, device=None, dtype=None, ops=None):
super().__init__()
self.mixer = MHA(config, device=device, dtype=dtype, ops=ops)
self.dropout1 = nn.Dropout(config.hidden_dropout_prob)
self.norm1 = ops.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, device=device, dtype=dtype)
self.mlp = MLP(config, device=device, dtype=dtype, ops=ops)
self.dropout2 = nn.Dropout(config.hidden_dropout_prob)
self.norm2 = ops.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, device=device, dtype=dtype)
def forward(self, hidden_states, mask=None, optimized_attention=None):
mixer_out = self.mixer(hidden_states, mask=mask, optimized_attention=optimized_attention)
hidden_states = self.norm1(self.dropout1(mixer_out) + hidden_states)
mlp_out = self.mlp(hidden_states)
hidden_states = self.norm2(self.dropout2(mlp_out) + hidden_states)
return hidden_states
class XLMRobertaEncoder(nn.Module):
def __init__(self, config, device=None, dtype=None, ops=None):
super().__init__()
self.layers = nn.ModuleList([Block(config, device=device, dtype=dtype, ops=ops) for _ in range(config.num_hidden_layers)])
def forward(self, hidden_states, attention_mask=None):
optimized_attention = comfy.ldm.modules.attention.optimized_attention_for_device(hidden_states.device, mask=attention_mask is not None, small_input=True)
for layer in self.layers:
hidden_states = layer(hidden_states, mask=attention_mask, optimized_attention=optimized_attention)
return hidden_states
class XLMRobertaModel_(nn.Module):
def __init__(self, config, device=None, dtype=None, ops=None):
super().__init__()
self.embeddings = XLMRobertaEmbeddings(config, device=device, dtype=dtype, ops=ops)
self.emb_ln = ops.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, device=device, dtype=dtype)
self.emb_drop = nn.Dropout(config.hidden_dropout_prob)
self.encoder = XLMRobertaEncoder(config, device=device, dtype=dtype, ops=ops)
def forward(self, input_ids, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, embeds_info=[]):
x = self.embeddings(input_ids=input_ids, embeddings=embeds)
x = self.emb_ln(x)
x = self.emb_drop(x)
mask = None
if attention_mask is not None:
mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, 1, attention_mask.shape[-1]))
mask = mask.masked_fill(mask.to(torch.bool), -torch.finfo(x.dtype).max)
sequence_output = self.encoder(x, attention_mask=mask)
# Mean pool, see https://huggingface.co/jinaai/jina-clip-implementation/blob/39e6a55ae971b59bea6e44675d237c99762e7ee2/hf_model.py
pooled_output = None
if attention_mask is None:
pooled_output = sequence_output.mean(dim=1)
else:
attention_mask = attention_mask.to(sequence_output.dtype)
pooled_output = (sequence_output * attention_mask.unsqueeze(-1)).sum(dim=1) / attention_mask.sum(dim=-1, keepdim=True)
# Intermediate output is not yet implemented, use None for placeholder
return sequence_output, None, pooled_output
class XLMRobertaModel(nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
self.config = XLMRobertaConfig(**config_dict)
self.model = XLMRobertaModel_(self.config, device=device, dtype=dtype, ops=operations)
self.num_layers = self.config.num_hidden_layers
def get_input_embeddings(self):
return self.model.embeddings.word_embeddings
def set_input_embeddings(self, embeddings):
self.model.embeddings.word_embeddings = embeddings
def forward(self, *args, **kwargs):
return self.model(*args, **kwargs)
class JinaClip2TextModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(device=device, dtype=dtype, textmodel_json_config={}, model_class=XLMRobertaModel, special_tokens={"start": 0, "end": 2, "pad": 1}, enable_attention_masks=True, return_attention_masks=True, model_options=model_options)
class JinaClip2TextModelWrapper(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(device=device, dtype=dtype, clip_model=JinaClip2TextModel, name="jina_clip_2", model_options=model_options)


@ -3,13 +3,12 @@ import torch.nn as nn
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional, Any from typing import Optional, Any
import math import math
import logging
from comfy.ldm.modules.attention import optimized_attention_for_device from comfy.ldm.modules.attention import optimized_attention_for_device
import comfy.model_management import comfy.model_management
import comfy.ldm.common_dit import comfy.ldm.common_dit
import comfy.clip_model
import comfy.model_management
from . import qwen_vl from . import qwen_vl
@dataclass @dataclass
@ -177,7 +176,7 @@ class Gemma3_4B_Config:
num_key_value_heads: int = 4 num_key_value_heads: int = 4
max_position_embeddings: int = 131072 max_position_embeddings: int = 131072
rms_norm_eps: float = 1e-6 rms_norm_eps: float = 1e-6
rope_theta = [10000.0, 1000000.0] rope_theta = [1000000.0, 10000.0]
transformer_type: str = "gemma3" transformer_type: str = "gemma3"
head_dim = 256 head_dim = 256
rms_norm_add = True rms_norm_add = True
@ -186,10 +185,35 @@ class Gemma3_4B_Config:
rope_dims = None rope_dims = None
q_norm = "gemma3" q_norm = "gemma3"
k_norm = "gemma3" k_norm = "gemma3"
sliding_attention = [False, False, False, False, False, 1024] sliding_attention = [1024, 1024, 1024, 1024, 1024, False]
rope_scale = [1.0, 8.0] rope_scale = [8.0, 1.0]
final_norm: bool = True final_norm: bool = True
@dataclass
class Gemma3_12B_Config:
vocab_size: int = 262208
hidden_size: int = 3840
intermediate_size: int = 15360
num_hidden_layers: int = 48
num_attention_heads: int = 16
num_key_value_heads: int = 8
max_position_embeddings: int = 131072
rms_norm_eps: float = 1e-6
rope_theta = [1000000.0, 10000.0]
transformer_type: str = "gemma3"
head_dim = 256
rms_norm_add = True
mlp_activation = "gelu_pytorch_tanh"
qkv_bias = False
rope_dims = None
q_norm = "gemma3"
k_norm = "gemma3"
sliding_attention = [1024, 1024, 1024, 1024, 1024, False]
rope_scale = [8.0, 1.0]
final_norm: bool = True
vision_config = {"num_channels": 3, "hidden_act": "gelu_pytorch_tanh", "hidden_size": 1152, "image_size": 896, "intermediate_size": 4304, "model_type": "siglip_vision_model", "num_attention_heads": 16, "num_hidden_layers": 27, "patch_size": 14}
mm_tokens_per_image = 256
class RMSNorm(nn.Module): class RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-5, add=False, device=None, dtype=None): def __init__(self, dim: int, eps: float = 1e-5, add=False, device=None, dtype=None):
super().__init__() super().__init__()
@ -370,7 +394,7 @@ class TransformerBlockGemma2(nn.Module):
self.pre_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype) self.pre_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
self.post_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype) self.post_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
if config.sliding_attention is not None: # TODO: implement. (Not that necessary since models are trained on less than 1024 tokens) if config.sliding_attention is not None:
self.sliding_attention = config.sliding_attention[index % len(config.sliding_attention)] self.sliding_attention = config.sliding_attention[index % len(config.sliding_attention)]
else: else:
self.sliding_attention = False self.sliding_attention = False
@ -387,7 +411,12 @@ class TransformerBlockGemma2(nn.Module):
if self.transformer_type == 'gemma3': if self.transformer_type == 'gemma3':
if self.sliding_attention: if self.sliding_attention:
if x.shape[1] > self.sliding_attention: if x.shape[1] > self.sliding_attention:
logging.warning("Warning: sliding attention not implemented, results may be incorrect") sliding_mask = torch.full((x.shape[1], x.shape[1]), float("-inf"), device=x.device, dtype=x.dtype)
sliding_mask.tril_(diagonal=-self.sliding_attention)
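# tril_ with a negative diagonal keeps -inf only where a key sits at least `sliding_attention` positions behind the query;
# nearer (and future) positions become 0, so causal/padding constraints are assumed to still come from the attention_mask combined below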
if attention_mask is not None:
attention_mask = attention_mask + sliding_mask
else:
attention_mask = sliding_mask
freqs_cis = freqs_cis[1] freqs_cis = freqs_cis[1]
else: else:
freqs_cis = freqs_cis[0] freqs_cis = freqs_cis[0]
@ -517,6 +546,41 @@ class Llama2_(nn.Module):
return x, intermediate return x, intermediate
class Gemma3MultiModalProjector(torch.nn.Module):
def __init__(self, config, dtype, device, operations):
super().__init__()
self.mm_input_projection_weight = nn.Parameter(
torch.empty(config.vision_config["hidden_size"], config.hidden_size, device=device, dtype=dtype)
)
self.mm_soft_emb_norm = RMSNorm(config.vision_config["hidden_size"], eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
self.patches_per_image = int(config.vision_config["image_size"] // config.vision_config["patch_size"])
self.tokens_per_side = int(config.mm_tokens_per_image**0.5)
self.kernel_size = self.patches_per_image // self.tokens_per_side
self.avg_pool = nn.AvgPool2d(kernel_size=self.kernel_size, stride=self.kernel_size)
def forward(self, vision_outputs: torch.Tensor):
batch_size, _, seq_length = vision_outputs.shape
reshaped_vision_outputs = vision_outputs.transpose(1, 2)
reshaped_vision_outputs = reshaped_vision_outputs.reshape(
batch_size, seq_length, self.patches_per_image, self.patches_per_image
)
reshaped_vision_outputs = reshaped_vision_outputs.contiguous()
pooled_vision_outputs = self.avg_pool(reshaped_vision_outputs)
pooled_vision_outputs = pooled_vision_outputs.flatten(2)
pooled_vision_outputs = pooled_vision_outputs.transpose(1, 2)
normed_vision_outputs = self.mm_soft_emb_norm(pooled_vision_outputs)
projected_vision_outputs = torch.matmul(normed_vision_outputs, comfy.model_management.cast_to_device(self.mm_input_projection_weight, device=normed_vision_outputs.device, dtype=normed_vision_outputs.dtype))
return projected_vision_outputs.type_as(vision_outputs)
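# Rough shape walk-through of the projector above, assuming the Gemma3_12B defaults
# (image_size=896, patch_size=14, mm_tokens_per_image=256, vision hidden_size=1152, text hidden_size=3840):
#   patches_per_image = 896 // 14 = 64   -> the vision tower emits a 64x64 grid of patch tokens (4096 tokens of width 1152)
#   tokens_per_side = int(256 ** 0.5) = 16, so kernel_size = 64 // 16 = 4
#   AvgPool2d(4, 4) reduces the 64x64 grid to 16x16 = 256 pooled tokens
#   matmul with the (1152, 3840) projection weight yields 256 image tokens of width 3840 per image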
class BaseLlama: class BaseLlama:
def get_input_embeddings(self): def get_input_embeddings(self):
return self.model.embed_tokens return self.model.embed_tokens
@ -633,3 +697,21 @@ class Gemma3_4B(BaseLlama, torch.nn.Module):
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype self.dtype = dtype
class Gemma3_12B(BaseLlama, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = Gemma3_12B_Config(**config_dict)
self.num_layers = config.num_hidden_layers
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.multi_modal_projector = Gemma3MultiModalProjector(config, dtype, device, operations)
self.vision_model = comfy.clip_model.CLIPVision(config.vision_config, dtype, device, operations)
self.dtype = dtype
self.image_size = config.vision_config["image_size"]
def preprocess_embed(self, embed, device):
if embed["type"] == "image":
image = comfy.clip_model.clip_preprocess(embed["data"], size=self.image_size, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], crop=True)
return self.multi_modal_projector(self.vision_model(image.to(device, dtype=torch.float32))[0]), None
return None, None


@ -1,7 +1,11 @@
from comfy import sd1_clip from comfy import sd1_clip
import os import os
from transformers import T5TokenizerFast from transformers import T5TokenizerFast
from .spiece_tokenizer import SPieceTokenizer
import comfy.text_encoders.genmo import comfy.text_encoders.genmo
from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector
import torch
import comfy.utils
class T5XXLTokenizer(sd1_clip.SDTokenizer): class T5XXLTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}): def __init__(self, embedding_directory=None, tokenizer_data={}):
@ -16,3 +20,112 @@ class LTXVT5Tokenizer(sd1_clip.SD1Tokenizer):
def ltxv_te(*args, **kwargs): def ltxv_te(*args, **kwargs):
return comfy.text_encoders.genmo.mochi_te(*args, **kwargs) return comfy.text_encoders.genmo.mochi_te(*args, **kwargs)
class Gemma3_12BTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer = tokenizer_data.get("spiece_model", None)
super().__init__(tokenizer, pad_with_end=False, embedding_size=3840, embedding_key='gemma3_12b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data)
def state_dict(self):
return {"spiece_model": self.tokenizer.serialize_model()}
class LTXAVGemmaTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="gemma3_12b", tokenizer=Gemma3_12BTokenizer)
class Gemma3_12BModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="all", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
if llama_quantization_metadata is not None:
model_options = model_options.copy()
model_options["quantization_metadata"] = llama_quantization_metadata
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_12B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
def tokenize_with_weights(self, text, return_word_ids=False, llama_template="{}", image_embeds=None, **kwargs):
text = llama_template.format(text)
text_tokens = super().tokenize_with_weights(text, return_word_ids)
embed_count = 0
for k in text_tokens:
tt = text_tokens[k]
for r in tt:
for i in range(len(r)):
if r[i][0] == 262144:
if image_embeds is not None and embed_count < image_embeds.shape[0]:
r[i] = ({"type": "embedding", "data": image_embeds[embed_count], "original_type": "image"},) + r[i][1:]
embed_count += 1
return text_tokens
class LTXAVTEModel(torch.nn.Module):
def __init__(self, dtype_llama=None, device="cpu", dtype=None, model_options={}):
super().__init__()
self.dtypes = set()
self.dtypes.add(dtype)
self.gemma3_12b = Gemma3_12BModel(device=device, dtype=dtype_llama, model_options=model_options, layer="all", layer_idx=None)
self.dtypes.add(dtype_llama)
operations = self.gemma3_12b.operations # TODO
self.text_embedding_projection = operations.Linear(3840 * 49, 3840, bias=False, dtype=dtype, device=device)
self.audio_embeddings_connector = Embeddings1DConnector(
split_rope=True,
double_precision_rope=True,
dtype=dtype,
device=device,
operations=operations,
)
self.video_embeddings_connector = Embeddings1DConnector(
split_rope=True,
double_precision_rope=True,
dtype=dtype,
device=device,
operations=operations,
)
def set_clip_options(self, options):
self.execution_device = options.get("execution_device", self.execution_device)
self.gemma3_12b.set_clip_options(options)
def reset_clip_options(self):
self.gemma3_12b.reset_clip_options()
self.execution_device = None
def encode_token_weights(self, token_weight_pairs):
token_weight_pairs = token_weight_pairs["gemma3_12b"]
out, pooled, extra = self.gemma3_12b.encode_token_weights(token_weight_pairs)
out_device = out.device
out = out.movedim(1, -1).to(self.execution_device)
out = 8.0 * (out - out.mean(dim=(1, 2), keepdim=True)) / (out.amax(dim=(1, 2), keepdim=True) - out.amin(dim=(1, 2), keepdim=True) + 1e-6)
out = out.reshape((out.shape[0], out.shape[1], -1))
out = self.text_embedding_projection(out)
out_vid = self.video_embeddings_connector(out)[0]
out_audio = self.audio_embeddings_connector(out)[0]
out = torch.concat((out_vid, out_audio), dim=-1)
return out.to(out_device), pooled
def load_sd(self, sd):
if "model.layers.47.self_attn.q_norm.weight" in sd:
return self.gemma3_12b.load_sd(sd)
else:
sdo = comfy.utils.state_dict_prefix_replace(sd, {"text_embedding_projection.aggregate_embed.weight": "text_embedding_projection.weight", "model.diffusion_model.video_embeddings_connector.": "video_embeddings_connector.", "model.diffusion_model.audio_embeddings_connector.": "audio_embeddings_connector."}, filter_keys=True)
if len(sdo) == 0:
sdo = sd
return self.load_state_dict(sdo, strict=False)
def ltxav_te(dtype_llama=None, llama_quantization_metadata=None):
class LTXAVTEModel_(LTXAVTEModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
if llama_quantization_metadata is not None:
model_options = model_options.copy()
model_options["llama_quantization_metadata"] = llama_quantization_metadata
if dtype_llama is not None:
dtype = dtype_llama
super().__init__(dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options)
return LTXAVTEModel_


@ -14,7 +14,7 @@ class Gemma2BTokenizer(sd1_clip.SDTokenizer):
class Gemma3_4BTokenizer(sd1_clip.SDTokenizer): class Gemma3_4BTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}): def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer = tokenizer_data.get("spiece_model", None) tokenizer = tokenizer_data.get("spiece_model", None)
super().__init__(tokenizer, pad_with_end=False, embedding_size=2560, embedding_key='gemma3_4b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data) super().__init__(tokenizer, pad_with_end=False, embedding_size=2560, embedding_key='gemma3_4b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}, disable_weights=True, tokenizer_data=tokenizer_data)
def state_dict(self): def state_dict(self):
return {"spiece_model": self.tokenizer.serialize_model()} return {"spiece_model": self.tokenizer.serialize_model()}
@ -33,6 +33,11 @@ class Gemma2_2BModel(sd1_clip.SDClipModel):
class Gemma3_4BModel(sd1_clip.SDClipModel): class Gemma3_4BModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}): def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}):
llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
if llama_quantization_metadata is not None:
model_options = model_options.copy()
model_options["quantization_metadata"] = llama_quantization_metadata
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_4B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_4B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
class LuminaModel(sd1_clip.SD1ClipModel): class LuminaModel(sd1_clip.SD1ClipModel):


@ -0,0 +1,62 @@
import torch
import comfy.model_management
import comfy.text_encoders.jina_clip_2
import comfy.text_encoders.lumina2
class NewBieTokenizer:
def __init__(self, embedding_directory=None, tokenizer_data={}):
self.gemma = comfy.text_encoders.lumina2.Gemma3_4BTokenizer(embedding_directory=embedding_directory, tokenizer_data={"spiece_model": tokenizer_data["gemma_spiece_model"]})
self.jina = comfy.text_encoders.jina_clip_2.JinaClip2Tokenizer(embedding_directory=embedding_directory, tokenizer_data={"spiece_model": tokenizer_data["jina_spiece_model"]})
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
out = {}
out["gemma"] = self.gemma.tokenize_with_weights(text, return_word_ids, **kwargs)
out["jina"] = self.jina.tokenize_with_weights(text, return_word_ids, **kwargs)
return out
def untokenize(self, token_weight_pair):
raise NotImplementedError
def state_dict(self):
return {}
class NewBieTEModel(torch.nn.Module):
def __init__(self, dtype_gemma=None, device="cpu", dtype=None, model_options={}):
super().__init__()
dtype_gemma = comfy.model_management.pick_weight_dtype(dtype_gemma, dtype, device)
self.gemma = comfy.text_encoders.lumina2.Gemma3_4BModel(device=device, dtype=dtype_gemma, model_options=model_options)
self.jina = comfy.text_encoders.jina_clip_2.JinaClip2TextModel(device=device, dtype=dtype, model_options=model_options)
self.dtypes = {dtype, dtype_gemma}
def set_clip_options(self, options):
self.gemma.set_clip_options(options)
self.jina.set_clip_options(options)
def reset_clip_options(self):
self.gemma.reset_clip_options()
self.jina.reset_clip_options()
def encode_token_weights(self, token_weight_pairs):
token_weight_pairs_gemma = token_weight_pairs["gemma"]
token_weight_pairs_jina = token_weight_pairs["jina"]
gemma_out, gemma_pooled, gemma_extra = self.gemma.encode_token_weights(token_weight_pairs_gemma)
jina_out, jina_pooled, jina_extra = self.jina.encode_token_weights(token_weight_pairs_jina)
return gemma_out, jina_pooled, gemma_extra
def load_sd(self, sd):
if "model.layers.0.self_attn.q_norm.weight" in sd:
return self.gemma.load_sd(sd)
else:
return self.jina.load_sd(sd)
def te(dtype_llama=None, llama_quantization_metadata=None):
class NewBieTEModel_(NewBieTEModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
if llama_quantization_metadata is not None:
model_options = model_options.copy()
model_options["llama_quantization_metadata"] = llama_quantization_metadata
super().__init__(dtype_gemma=dtype_llama, device=device, dtype=dtype, model_options=model_options)
return NewBieTEModel_


@ -53,7 +53,7 @@ if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in
ALWAYS_SAFE_LOAD = True ALWAYS_SAFE_LOAD = True
logging.info("Checkpoint files will always be loaded safely.") logging.info("Checkpoint files will always be loaded safely.")
else: else:
logging.info("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended.") logging.warning("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended as older versions of pytorch are no longer supported.")
def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False): def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
# Try GDS loading first if available and device is GPU # Try GDS loading first if available and device is GPU
@ -815,12 +815,17 @@ def safetensors_header(safetensors_path, max_size=100*1024*1024):
return None return None
return f.read(length_of_header) return f.read(length_of_header)
ATTR_UNSET={}
def set_attr(obj, attr, value): def set_attr(obj, attr, value):
attrs = attr.split(".") attrs = attr.split(".")
for name in attrs[:-1]: for name in attrs[:-1]:
obj = getattr(obj, name) obj = getattr(obj, name)
prev = getattr(obj, attrs[-1]) prev = getattr(obj, attrs[-1], ATTR_UNSET)
setattr(obj, attrs[-1], value) if value is ATTR_UNSET:
delattr(obj, attrs[-1])
else:
setattr(obj, attrs[-1], value)
return prev return prev
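# A minimal usage sketch of the ATTR_UNSET sentinel introduced above (the dotted path is illustrative):
#   prev = set_attr(model, "blocks.0.attn.weight", ATTR_UNSET)  # deletes the attribute and returns the old value
#   set_attr(model, "blocks.0.attn.weight", prev)               # restores the original value later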
def set_attr_param(obj, attr, value): def set_attr_param(obj, attr, value):
@ -1205,7 +1210,7 @@ def unpack_latents(combined_latent, latent_shapes):
combined_latent = combined_latent[:, :, cut:] combined_latent = combined_latent[:, :, cut:]
output_tensors.append(tens.reshape([tens.shape[0]] + list(shape)[1:])) output_tensors.append(tens.reshape([tens.shape[0]] + list(shape)[1:]))
else: else:
output_tensors = combined_latent output_tensors = [combined_latent]
return output_tensors return output_tensors
def detect_layer_quantization(state_dict, prefix): def detect_layer_quantization(state_dict, prefix):
@ -1237,6 +1242,8 @@ def convert_old_quants(state_dict, model_prefix="", metadata={}):
out_sd = {} out_sd = {}
layers = {} layers = {}
for k in list(state_dict.keys()): for k in list(state_dict.keys()):
if k == scaled_fp8_key:
continue
if not k.startswith(model_prefix): if not k.startswith(model_prefix):
out_sd[k] = state_dict[k] out_sd[k] = state_dict[k]
continue continue
@ -1269,6 +1276,6 @@ def convert_old_quants(state_dict, model_prefix="", metadata={}):
if quant_metadata is not None: if quant_metadata is not None:
layers = quant_metadata["layers"] layers = quant_metadata["layers"]
for k, v in layers.items(): for k, v in layers.items():
state_dict["{}.comfy_quant".format(k)] = torch.frombuffer(json.dumps(v).encode('utf-8'), dtype=torch.uint8) state_dict["{}.comfy_quant".format(k)] = torch.tensor(list(json.dumps(v).encode('utf-8')), dtype=torch.uint8)
return state_dict, metadata return state_dict, metadata


@ -5,12 +5,12 @@ This module handles capability negotiation between frontend and backend,
allowing graceful protocol evolution while maintaining backward compatibility. allowing graceful protocol evolution while maintaining backward compatibility.
""" """
from typing import Any, Dict from typing import Any
from comfy.cli_args import args from comfy.cli_args import args
# Default server capabilities # Default server capabilities
SERVER_FEATURE_FLAGS: Dict[str, Any] = { SERVER_FEATURE_FLAGS: dict[str, Any] = {
"supports_preview_metadata": True, "supports_preview_metadata": True,
"max_upload_size": args.max_upload_size * 1024 * 1024, # Convert MB to bytes "max_upload_size": args.max_upload_size * 1024 * 1024, # Convert MB to bytes
"extension": {"manager": {"supports_v4": True}}, "extension": {"manager": {"supports_v4": True}},
@ -18,7 +18,7 @@ SERVER_FEATURE_FLAGS: Dict[str, Any] = {
def get_connection_feature( def get_connection_feature(
sockets_metadata: Dict[str, Dict[str, Any]], sockets_metadata: dict[str, dict[str, Any]],
sid: str, sid: str,
feature_name: str, feature_name: str,
default: Any = False default: Any = False
@ -42,7 +42,7 @@ def get_connection_feature(
def supports_feature( def supports_feature(
sockets_metadata: Dict[str, Dict[str, Any]], sockets_metadata: dict[str, dict[str, Any]],
sid: str, sid: str,
feature_name: str feature_name: str
) -> bool: ) -> bool:
@ -60,7 +60,7 @@ def supports_feature(
return get_connection_feature(sockets_metadata, sid, feature_name, False) is True return get_connection_feature(sockets_metadata, sid, feature_name, False) is True
def get_server_features() -> Dict[str, Any]: def get_server_features() -> dict[str, Any]:
""" """
Get the server's feature flags. Get the server's feature flags.


@ -1,4 +1,4 @@
from typing import Type, List, NamedTuple from typing import NamedTuple
from comfy_api.internal.singleton import ProxiedSingleton from comfy_api.internal.singleton import ProxiedSingleton
from packaging import version as packaging_version from packaging import version as packaging_version
@ -10,7 +10,7 @@ class ComfyAPIBase(ProxiedSingleton):
class ComfyAPIWithVersion(NamedTuple): class ComfyAPIWithVersion(NamedTuple):
version: str version: str
api_class: Type[ComfyAPIBase] api_class: type[ComfyAPIBase]
def parse_version(version_str: str) -> packaging_version.Version: def parse_version(version_str: str) -> packaging_version.Version:
@ -23,16 +23,16 @@ def parse_version(version_str: str) -> packaging_version.Version:
return packaging_version.parse(version_str) return packaging_version.parse(version_str)
registered_versions: List[ComfyAPIWithVersion] = [] registered_versions: list[ComfyAPIWithVersion] = []
def register_versions(versions: List[ComfyAPIWithVersion]): def register_versions(versions: list[ComfyAPIWithVersion]):
versions.sort(key=lambda x: parse_version(x.version)) versions.sort(key=lambda x: parse_version(x.version))
global registered_versions global registered_versions
registered_versions = versions registered_versions = versions
def get_all_versions() -> List[ComfyAPIWithVersion]: def get_all_versions() -> list[ComfyAPIWithVersion]:
""" """
Returns a list of all registered ComfyAPI versions. Returns a list of all registered ComfyAPI versions.
""" """


@ -8,7 +8,7 @@ import os
import textwrap import textwrap
import threading import threading
from enum import Enum from enum import Enum
from typing import Optional, Type, get_origin, get_args, get_type_hints from typing import Optional, get_origin, get_args, get_type_hints
class TypeTracker: class TypeTracker:
@ -193,7 +193,7 @@ class AsyncToSyncConverter:
return result_container["result"] return result_container["result"]
@classmethod @classmethod
def create_sync_class(cls, async_class: Type, thread_pool_size=10) -> Type: def create_sync_class(cls, async_class: type, thread_pool_size=10) -> type:
""" """
Creates a new class with synchronous versions of all async methods. Creates a new class with synchronous versions of all async methods.
@ -563,7 +563,7 @@ class AsyncToSyncConverter:
@classmethod @classmethod
def _generate_imports( def _generate_imports(
cls, async_class: Type, type_tracker: TypeTracker cls, async_class: type, type_tracker: TypeTracker
) -> list[str]: ) -> list[str]:
"""Generate import statements for the stub file.""" """Generate import statements for the stub file."""
imports = [] imports = []
@ -628,7 +628,7 @@ class AsyncToSyncConverter:
return imports return imports
@classmethod @classmethod
def _get_class_attributes(cls, async_class: Type) -> list[tuple[str, Type]]: def _get_class_attributes(cls, async_class: type) -> list[tuple[str, type]]:
"""Extract class attributes that are classes themselves.""" """Extract class attributes that are classes themselves."""
class_attributes = [] class_attributes = []
@ -654,7 +654,7 @@ class AsyncToSyncConverter:
def _generate_inner_class_stub( def _generate_inner_class_stub(
cls, cls,
name: str, name: str,
attr: Type, attr: type,
indent: str = " ", indent: str = " ",
type_tracker: Optional[TypeTracker] = None, type_tracker: Optional[TypeTracker] = None,
) -> list[str]: ) -> list[str]:
@ -782,7 +782,7 @@ class AsyncToSyncConverter:
return processed return processed
@classmethod @classmethod
def generate_stub_file(cls, async_class: Type, sync_class: Type) -> None: def generate_stub_file(cls, async_class: type, sync_class: type) -> None:
""" """
Generate a .pyi stub file for the sync class to help IDEs with type checking. Generate a .pyi stub file for the sync class to help IDEs with type checking.
""" """
@ -988,7 +988,7 @@ class AsyncToSyncConverter:
logging.error(traceback.format_exc()) logging.error(traceback.format_exc())
def create_sync_class(async_class: Type, thread_pool_size=10) -> Type: def create_sync_class(async_class: type, thread_pool_size=10) -> type:
""" """
Creates a sync version of an async class Creates a sync version of an async class


@ -1,4 +1,4 @@
from typing import Type, TypeVar from typing import TypeVar
class SingletonMetaclass(type): class SingletonMetaclass(type):
T = TypeVar("T", bound="SingletonMetaclass") T = TypeVar("T", bound="SingletonMetaclass")
@ -11,13 +11,13 @@ class SingletonMetaclass(type):
) )
return cls._instances[cls] return cls._instances[cls]
def inject_instance(cls: Type[T], instance: T) -> None: def inject_instance(cls: type[T], instance: T) -> None:
assert cls not in SingletonMetaclass._instances, ( assert cls not in SingletonMetaclass._instances, (
"Cannot inject instance after first instantiation" "Cannot inject instance after first instantiation"
) )
SingletonMetaclass._instances[cls] = instance SingletonMetaclass._instances[cls] = instance
def get_instance(cls: Type[T], *args, **kwargs) -> T: def get_instance(cls: type[T], *args, **kwargs) -> T:
""" """
Gets the singleton instance of the class, creating it if it doesn't exist. Gets the singleton instance of the class, creating it if it doesn't exist.
""" """


@ -1,16 +1,15 @@
from __future__ import annotations from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Type, TYPE_CHECKING from typing import TYPE_CHECKING
from comfy_api.internal import ComfyAPIBase from comfy_api.internal import ComfyAPIBase
from comfy_api.internal.singleton import ProxiedSingleton from comfy_api.internal.singleton import ProxiedSingleton
from comfy_api.internal.async_to_sync import create_sync_class from comfy_api.internal.async_to_sync import create_sync_class
from comfy_api.latest._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput from ._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput
from comfy_api.latest._input_impl import VideoFromFile, VideoFromComponents from ._input_impl import VideoFromFile, VideoFromComponents
from comfy_api.latest._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL from ._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL
from . import _io_public as io from . import _io_public as io
from . import _ui_public as ui from . import _ui_public as ui
# from comfy_api.latest._resources import _RESOURCES as resources #noqa: F401
from comfy_execution.utils import get_executing_context from comfy_execution.utils import get_executing_context
from comfy_execution.progress import get_progress_state, PreviewImageTuple from comfy_execution.progress import get_progress_state, PreviewImageTuple
from PIL import Image from PIL import Image
@ -80,7 +79,7 @@ class ComfyExtension(ABC):
async def on_load(self) -> None: async def on_load(self) -> None:
""" """
Called when an extension is loaded. Called when an extension is loaded.
This should be used to initialize any global resources neeeded by the extension. This should be used to initialize any global resources needed by the extension.
""" """
@abstractmethod @abstractmethod
@ -113,7 +112,7 @@ ComfyAPI = ComfyAPI_latest
if TYPE_CHECKING: if TYPE_CHECKING:
import comfy_api.latest.generated.ComfyAPISyncStub # type: ignore import comfy_api.latest.generated.ComfyAPISyncStub # type: ignore
ComfyAPISync: Type[comfy_api.latest.generated.ComfyAPISyncStub.ComfyAPISyncStub] ComfyAPISync: type[comfy_api.latest.generated.ComfyAPISyncStub.ComfyAPISyncStub]
ComfyAPISync = create_sync_class(ComfyAPI_latest) ComfyAPISync = create_sync_class(ComfyAPI_latest)
# create new aliases for io and ui # create new aliases for io and ui


@ -1,5 +1,5 @@
import torch import torch
from typing import TypedDict, List, Optional from typing import TypedDict, Optional
ImageInput = torch.Tensor ImageInput = torch.Tensor
""" """
@ -39,4 +39,4 @@ class LatentInput(TypedDict):
Optional noise mask tensor in the same format as samples. Optional noise mask tensor in the same format as samples.
""" """
batch_index: Optional[List[int]] batch_index: Optional[list[int]]


@ -4,7 +4,7 @@ from fractions import Fraction
from typing import Optional, Union, IO from typing import Optional, Union, IO
import io import io
import av import av
from comfy_api.util import VideoContainer, VideoCodec, VideoComponents from .._util import VideoContainer, VideoCodec, VideoComponents
class VideoInput(ABC): class VideoInput(ABC):
""" """


@ -3,14 +3,14 @@ from av.container import InputContainer
from av.subtitles.stream import SubtitleStream from av.subtitles.stream import SubtitleStream
from fractions import Fraction from fractions import Fraction
from typing import Optional from typing import Optional
from comfy_api.latest._input import AudioInput, VideoInput from .._input import AudioInput, VideoInput
import av import av
import io import io
import json import json
import numpy as np import numpy as np
import math import math
import torch import torch
from comfy_api.latest._util import VideoContainer, VideoCodec, VideoComponents from .._util import VideoContainer, VideoCodec, VideoComponents
def container_to_output_format(container_format: str | None) -> str | None: def container_to_output_format(container_format: str | None) -> str | None:


@ -26,11 +26,9 @@ if TYPE_CHECKING:
from comfy_api.input import VideoInput from comfy_api.input import VideoInput
from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classproperty, copy_class, first_real_override, is_class, from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classproperty, copy_class, first_real_override, is_class,
prune_dict, shallow_clone_class) prune_dict, shallow_clone_class)
from comfy_api.latest._resources import Resources, ResourcesLocal
from comfy_execution.graph_utils import ExecutionBlocker from comfy_execution.graph_utils import ExecutionBlocker
from ._util import MESH, VOXEL from ._util import MESH, VOXEL, SVG as _SVG
# from comfy_extras.nodes_images import SVG as SVG_ # NOTE: needs to be moved before can be imported due to circular reference
class FolderType(str, Enum): class FolderType(str, Enum):
input = "input" input = "input"
@ -77,16 +75,6 @@ class NumberDisplay(str, Enum):
slider = "slider" slider = "slider"
class _StringIOType(str):
def __ne__(self, value: object) -> bool:
if self == "*" or value == "*":
return False
if not isinstance(value, str):
return True
a = frozenset(self.split(","))
b = frozenset(value.split(","))
return not (b.issubset(a) or a.issubset(b))
class _ComfyType(ABC): class _ComfyType(ABC):
Type = Any Type = Any
io_type: str = None io_type: str = None
@ -126,8 +114,7 @@ def comfytype(io_type: str, **kwargs):
new_cls.__module__ = cls.__module__ new_cls.__module__ = cls.__module__
new_cls.__doc__ = cls.__doc__ new_cls.__doc__ = cls.__doc__
# assign ComfyType attributes, if needed # assign ComfyType attributes, if needed
# NOTE: use __ne__ trick for io_type (see node_typing.IO.__ne__ for details) new_cls.io_type = io_type
new_cls.io_type = _StringIOType(io_type)
if hasattr(new_cls, "Input") and new_cls.Input is not None: if hasattr(new_cls, "Input") and new_cls.Input is not None:
new_cls.Input.Parent = new_cls new_cls.Input.Parent = new_cls
if hasattr(new_cls, "Output") and new_cls.Output is not None: if hasattr(new_cls, "Output") and new_cls.Output is not None:
@ -166,7 +153,7 @@ class Input(_IO_V3):
''' '''
Base class for a V3 Input. Base class for a V3 Input.
''' '''
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None): def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None, raw_link: bool=None):
super().__init__() super().__init__()
self.id = id self.id = id
self.display_name = display_name self.display_name = display_name
@ -174,6 +161,7 @@ class Input(_IO_V3):
self.tooltip = tooltip self.tooltip = tooltip
self.lazy = lazy self.lazy = lazy
self.extra_dict = extra_dict if extra_dict is not None else {} self.extra_dict = extra_dict if extra_dict is not None else {}
self.rawLink = raw_link
def as_dict(self): def as_dict(self):
return prune_dict({ return prune_dict({
@ -181,10 +169,11 @@ class Input(_IO_V3):
"optional": self.optional, "optional": self.optional,
"tooltip": self.tooltip, "tooltip": self.tooltip,
"lazy": self.lazy, "lazy": self.lazy,
"rawLink": self.rawLink,
}) | prune_dict(self.extra_dict) }) | prune_dict(self.extra_dict)
def get_io_type(self): def get_io_type(self):
return _StringIOType(self.io_type) return self.io_type
def get_all(self) -> list[Input]: def get_all(self) -> list[Input]:
return [self] return [self]
@ -195,8 +184,8 @@ class WidgetInput(Input):
''' '''
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
default: Any=None, default: Any=None,
socketless: bool=None, widget_type: str=None, force_input: bool=None, extra_dict=None): socketless: bool=None, widget_type: str=None, force_input: bool=None, extra_dict=None, raw_link: bool=None):
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict) super().__init__(id, display_name, optional, tooltip, lazy, extra_dict, raw_link)
self.default = default self.default = default
self.socketless = socketless self.socketless = socketless
self.widget_type = widget_type self.widget_type = widget_type
@ -218,13 +207,14 @@ class Output(_IO_V3):
def __init__(self, id: str=None, display_name: str=None, tooltip: str=None, def __init__(self, id: str=None, display_name: str=None, tooltip: str=None,
is_output_list=False): is_output_list=False):
self.id = id self.id = id
self.display_name = display_name self.display_name = display_name if display_name else id
self.tooltip = tooltip self.tooltip = tooltip
self.is_output_list = is_output_list self.is_output_list = is_output_list
def as_dict(self): def as_dict(self):
display_name = self.display_name if self.display_name else self.id
return prune_dict({ return prune_dict({
"display_name": self.display_name, "display_name": display_name,
"tooltip": self.tooltip, "tooltip": self.tooltip,
"is_output_list": self.is_output_list, "is_output_list": self.is_output_list,
}) })
@ -252,8 +242,8 @@ class Boolean(ComfyTypeIO):
'''Boolean input.''' '''Boolean input.'''
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
default: bool=None, label_on: str=None, label_off: str=None, default: bool=None, label_on: str=None, label_off: str=None,
socketless: bool=None, force_input: bool=None): socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None):
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input) super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link)
self.label_on = label_on self.label_on = label_on
self.label_off = label_off self.label_off = label_off
self.default: bool self.default: bool
@ -272,8 +262,8 @@ class Int(ComfyTypeIO):
'''Integer input.''' '''Integer input.'''
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
default: int=None, min: int=None, max: int=None, step: int=None, control_after_generate: bool=None, default: int=None, min: int=None, max: int=None, step: int=None, control_after_generate: bool=None,
display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None): display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None):
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input) super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link)
self.min = min self.min = min
self.max = max self.max = max
self.step = step self.step = step
@ -298,8 +288,8 @@ class Float(ComfyTypeIO):
'''Float input.''' '''Float input.'''
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
default: float=None, min: float=None, max: float=None, step: float=None, round: float=None, default: float=None, min: float=None, max: float=None, step: float=None, round: float=None,
display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None): display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None):
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input) super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link)
self.min = min self.min = min
self.max = max self.max = max
self.step = step self.step = step
@ -324,8 +314,8 @@ class String(ComfyTypeIO):
'''String input.''' '''String input.'''
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
multiline=False, placeholder: str=None, default: str=None, dynamic_prompts: bool=None, multiline=False, placeholder: str=None, default: str=None, dynamic_prompts: bool=None,
socketless: bool=None, force_input: bool=None): socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None):
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input) super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link)
self.multiline = multiline self.multiline = multiline
self.placeholder = placeholder self.placeholder = placeholder
self.dynamic_prompts = dynamic_prompts self.dynamic_prompts = dynamic_prompts
@ -358,12 +348,14 @@ class Combo(ComfyTypeIO):
image_folder: FolderType=None, image_folder: FolderType=None,
remote: RemoteOptions=None, remote: RemoteOptions=None,
socketless: bool=None, socketless: bool=None,
extra_dict=None,
raw_link: bool=None,
): ):
if isinstance(options, type) and issubclass(options, Enum): if isinstance(options, type) and issubclass(options, Enum):
options = [v.value for v in options] options = [v.value for v in options]
if isinstance(default, Enum): if isinstance(default, Enum):
default = default.value default = default.value
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless) super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, None, extra_dict, raw_link)
self.multiselect = False self.multiselect = False
self.options = options self.options = options
self.control_after_generate = control_after_generate self.control_after_generate = control_after_generate
@ -387,10 +379,6 @@ class Combo(ComfyTypeIO):
super().__init__(id, display_name, tooltip, is_output_list) super().__init__(id, display_name, tooltip, is_output_list)
self.options = options if options is not None else [] self.options = options if options is not None else []
@property
def io_type(self):
return self.options
@comfytype(io_type="COMBO") @comfytype(io_type="COMBO")
class MultiCombo(ComfyTypeI): class MultiCombo(ComfyTypeI):
'''Multiselect Combo input (dropdown for selecting potentially more than one value).''' '''Multiselect Combo input (dropdown for selecting potentially more than one value).'''
@ -399,8 +387,8 @@ class MultiCombo(ComfyTypeI):
class Input(Combo.Input): class Input(Combo.Input):
def __init__(self, id: str, options: list[str], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, def __init__(self, id: str, options: list[str], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
default: list[str]=None, placeholder: str=None, chip: bool=None, control_after_generate: bool=None, default: list[str]=None, placeholder: str=None, chip: bool=None, control_after_generate: bool=None,
socketless: bool=None): socketless: bool=None, extra_dict=None, raw_link: bool=None):
super().__init__(id, options, display_name, optional, tooltip, lazy, default, control_after_generate, socketless=socketless) super().__init__(id, options, display_name, optional, tooltip, lazy, default, control_after_generate, socketless=socketless, extra_dict=extra_dict, raw_link=raw_link)
self.multiselect = True self.multiselect = True
self.placeholder = placeholder self.placeholder = placeholder
self.chip = chip self.chip = chip
@ -433,9 +421,9 @@ class Webcam(ComfyTypeIO):
Type = str Type = str
def __init__( def __init__(
self, id: str, display_name: str=None, optional=False, self, id: str, display_name: str=None, optional=False,
tooltip: str=None, lazy: bool=None, default: str=None, socketless: bool=None tooltip: str=None, lazy: bool=None, default: str=None, socketless: bool=None, extra_dict=None, raw_link: bool=None
): ):
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless) super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, None, extra_dict, raw_link)
@comfytype(io_type="MASK") @comfytype(io_type="MASK")
@ -656,7 +644,7 @@ class Video(ComfyTypeIO):
@comfytype(io_type="SVG") @comfytype(io_type="SVG")
class SVG(ComfyTypeIO): class SVG(ComfyTypeIO):
Type = Any # TODO: SVG class is defined in comfy_extras/nodes_images.py, causing circular reference; should be moved to somewhere else before referenced directly in v3 Type = _SVG
@comfytype(io_type="LORA_MODEL") @comfytype(io_type="LORA_MODEL")
class LoraModel(ComfyTypeIO): class LoraModel(ComfyTypeIO):
@ -774,6 +762,13 @@ class AudioEncoder(ComfyTypeIO):
class AudioEncoderOutput(ComfyTypeIO): class AudioEncoderOutput(ComfyTypeIO):
Type = Any Type = Any
@comfytype(io_type="TRACKS")
class Tracks(ComfyTypeIO):
class TrackDict(TypedDict):
track_path: torch.Tensor
track_visibility: torch.Tensor
Type = TrackDict
@comfytype(io_type="COMFY_MULTITYPED_V3") @comfytype(io_type="COMFY_MULTITYPED_V3")
class MultiType: class MultiType:
Type = Any Type = Any
@ -781,7 +776,7 @@ class MultiType:
''' '''
Input that permits more than one input type; if `id` is an instance of `ComfyType.Input`, then that input will be used to create a widget (if applicable) with overridden values. Input that permits more than one input type; if `id` is an instance of `ComfyType.Input`, then that input will be used to create a widget (if applicable) with overridden values.
''' '''
def __init__(self, id: str | Input, types: list[type[_ComfyType] | _ComfyType], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None): def __init__(self, id: str | Input, types: list[type[_ComfyType] | _ComfyType], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None, raw_link: bool=None):
# if id is an Input, then use that Input with overridden values # if id is an Input, then use that Input with overridden values
self.input_override = None self.input_override = None
if isinstance(id, Input): if isinstance(id, Input):
@ -794,7 +789,7 @@ class MultiType:
# if is a widget input, make sure widget_type is set appropriately # if is a widget input, make sure widget_type is set appropriately
if isinstance(self.input_override, WidgetInput): if isinstance(self.input_override, WidgetInput):
self.input_override.widget_type = self.input_override.get_io_type() self.input_override.widget_type = self.input_override.get_io_type()
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict) super().__init__(id, display_name, optional, tooltip, lazy, extra_dict, raw_link)
self._io_types = types self._io_types = types
@property @property
@ -848,8 +843,8 @@ class MatchType(ComfyTypeIO):
class Input(Input): class Input(Input):
def __init__(self, id: str, template: MatchType.Template, def __init__(self, id: str, template: MatchType.Template,
display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None): display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None, raw_link: bool=None):
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict) super().__init__(id, display_name, optional, tooltip, lazy, extra_dict, raw_link)
self.template = template self.template = template
def as_dict(self): def as_dict(self):
@ -860,6 +855,8 @@ class MatchType(ComfyTypeIO):
class Output(Output): class Output(Output):
def __init__(self, template: MatchType.Template, id: str=None, display_name: str=None, tooltip: str=None, def __init__(self, template: MatchType.Template, id: str=None, display_name: str=None, tooltip: str=None,
is_output_list=False): is_output_list=False):
if not id and not display_name:
display_name = "MATCHTYPE"
super().__init__(id, display_name, tooltip, is_output_list) super().__init__(id, display_name, tooltip, is_output_list)
self.template = template self.template = template
@ -872,24 +869,30 @@ class DynamicInput(Input, ABC):
''' '''
Abstract class for dynamic input registration. Abstract class for dynamic input registration.
''' '''
def get_dynamic(self) -> list[Input]: pass
return []
def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''):
pass
class DynamicOutput(Output, ABC): class DynamicOutput(Output, ABC):
''' '''
Abstract class for dynamic output registration. Abstract class for dynamic output registration.
''' '''
def __init__(self, id: str=None, display_name: str=None, tooltip: str=None, pass
is_output_list=False):
super().__init__(id, display_name, tooltip, is_output_list)
def get_dynamic(self) -> list[Output]:
return []
def handle_prefix(prefix_list: list[str] | None, id: str | None = None) -> list[str]:
if prefix_list is None:
prefix_list = []
if id is not None:
prefix_list = prefix_list + [id]
return prefix_list
def finalize_prefix(prefix_list: list[str] | None, id: str | None = None) -> str:
assert not (prefix_list is None and id is None)
if prefix_list is None:
return id
elif id is not None:
prefix_list = prefix_list + [id]
return ".".join(prefix_list)
@comfytype(io_type="COMFY_AUTOGROW_V3") @comfytype(io_type="COMFY_AUTOGROW_V3")
class Autogrow(ComfyTypeI): class Autogrow(ComfyTypeI):
@ -926,14 +929,6 @@ class Autogrow(ComfyTypeI):
def validate(self): def validate(self):
self.input.validate() self.input.validate()
def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''):
real_inputs = []
for name, input in self.cached_inputs.items():
if name in live_inputs:
real_inputs.append(input)
add_to_input_dict_v1(d, real_inputs, live_inputs, curr_prefix)
add_dynamic_id_mapping(d, real_inputs, curr_prefix)
class TemplatePrefix(_AutogrowTemplate): class TemplatePrefix(_AutogrowTemplate):
def __init__(self, input: Input, prefix: str, min: int=1, max: int=10): def __init__(self, input: Input, prefix: str, min: int=1, max: int=10):
super().__init__(input) super().__init__(input)
@ -978,22 +973,45 @@ class Autogrow(ComfyTypeI):
"template": self.template.as_dict(), "template": self.template.as_dict(),
}) })
def get_dynamic(self) -> list[Input]:
return self.template.get_all()
def get_all(self) -> list[Input]: def get_all(self) -> list[Input]:
return [self] + self.template.get_all() return [self] + self.template.get_all()
def validate(self): def validate(self):
self.template.validate() self.template.validate()
def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''): @staticmethod
curr_prefix = f"{curr_prefix}{self.id}." def _expand_schema_for_dynamic(out_dict: dict[str, Any], live_inputs: dict[str, Any], value: tuple[str, dict[str, Any]], input_type: str, curr_prefix: list[str] | None):
# need to remove self from expected inputs dictionary; replaced by template inputs in frontend # NOTE: purposely do not include self in out_dict; instead use only the template inputs
for inner_dict in d.values(): # need to figure out names based on template type
if self.id in inner_dict: is_names = ("names" in value[1]["template"])
del inner_dict[self.id] is_prefix = ("prefix" in value[1]["template"])
self.template.expand_schema_for_dynamic(d, live_inputs, curr_prefix) input = value[1]["template"]["input"]
if is_names:
min = value[1]["template"]["min"]
names = value[1]["template"]["names"]
max = len(names)
elif is_prefix:
prefix = value[1]["template"]["prefix"]
min = value[1]["template"]["min"]
max = value[1]["template"]["max"]
names = [f"{prefix}{i}" for i in range(max)]
# need to create a new input based on the contents of input
template_input = None
for _, dict_input in input.items():
# for now, get just the first value from dict_input
template_input = list(dict_input.values())[0]
new_dict = {}
for i, name in enumerate(names):
expected_id = finalize_prefix(curr_prefix, name)
if expected_id in live_inputs:
# required
if i < min:
type_dict = new_dict.setdefault("required", {})
# optional
else:
type_dict = new_dict.setdefault("optional", {})
type_dict[name] = template_input
parse_class_inputs(out_dict, live_inputs, new_dict, curr_prefix)
@comfytype(io_type="COMFY_DYNAMICCOMBO_V3") @comfytype(io_type="COMFY_DYNAMICCOMBO_V3")
class DynamicCombo(ComfyTypeI): class DynamicCombo(ComfyTypeI):
@ -1016,23 +1034,6 @@ class DynamicCombo(ComfyTypeI):
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict) super().__init__(id, display_name, optional, tooltip, lazy, extra_dict)
self.options = options self.options = options
def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''):
# check if dynamic input's id is in live_inputs
if self.id in live_inputs:
curr_prefix = f"{curr_prefix}{self.id}."
key = live_inputs[self.id]
selected_option = None
for option in self.options:
if option.key == key:
selected_option = option
break
if selected_option is not None:
add_to_input_dict_v1(d, selected_option.inputs, live_inputs, curr_prefix)
add_dynamic_id_mapping(d, selected_option.inputs, curr_prefix, self)
def get_dynamic(self) -> list[Input]:
return [input for option in self.options for input in option.inputs]
def get_all(self) -> list[Input]: def get_all(self) -> list[Input]:
return [self] + [input for option in self.options for input in option.inputs] return [self] + [input for option in self.options for input in option.inputs]
@ -1047,6 +1048,24 @@ class DynamicCombo(ComfyTypeI):
for input in option.inputs: for input in option.inputs:
input.validate() input.validate()
@staticmethod
def _expand_schema_for_dynamic(out_dict: dict[str, Any], live_inputs: dict[str, Any], value: tuple[str, dict[str, Any]], input_type: str, curr_prefix: list[str] | None):
finalized_id = finalize_prefix(curr_prefix)
if finalized_id in live_inputs:
key = live_inputs[finalized_id]
selected_option = None
# get options from dict
options: list[dict[str, str | dict[str, Any]]] = value[1]["options"]
for option in options:
if option["key"] == key:
selected_option = option
break
if selected_option is not None:
parse_class_inputs(out_dict, live_inputs, selected_option["inputs"], curr_prefix)
# add self to inputs
out_dict[input_type][finalized_id] = value
out_dict["dynamic_paths"][finalized_id] = finalize_prefix(curr_prefix, curr_prefix[-1])
@comfytype(io_type="COMFY_DYNAMICSLOT_V3") @comfytype(io_type="COMFY_DYNAMICSLOT_V3")
class DynamicSlot(ComfyTypeI): class DynamicSlot(ComfyTypeI):
Type = dict[str, Any] Type = dict[str, Any]
@ -1069,17 +1088,8 @@ class DynamicSlot(ComfyTypeI):
self.force_input = True self.force_input = True
self.slot.force_input = True self.slot.force_input = True
def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''):
if self.id in live_inputs:
curr_prefix = f"{curr_prefix}{self.id}."
add_to_input_dict_v1(d, self.inputs, live_inputs, curr_prefix)
add_dynamic_id_mapping(d, [self.slot] + self.inputs, curr_prefix)
def get_dynamic(self) -> list[Input]:
return [self.slot] + self.inputs
def get_all(self) -> list[Input]: def get_all(self) -> list[Input]:
return [self] + [self.slot] + self.inputs return [self.slot] + self.inputs
def as_dict(self): def as_dict(self):
return super().as_dict() | prune_dict({ return super().as_dict() | prune_dict({
@ -1093,17 +1103,41 @@ class DynamicSlot(ComfyTypeI):
for input in self.inputs: for input in self.inputs:
input.validate() input.validate()
def add_dynamic_id_mapping(d: dict[str, Any], inputs: list[Input], curr_prefix: str, self: DynamicInput=None): @staticmethod
dynamic = d.setdefault("dynamic_paths", {}) def _expand_schema_for_dynamic(out_dict: dict[str, Any], live_inputs: dict[str, Any], value: tuple[str, dict[str, Any]], input_type: str, curr_prefix: list[str] | None):
if self is not None: finalized_id = finalize_prefix(curr_prefix)
dynamic[self.id] = f"{curr_prefix}{self.id}" if finalized_id in live_inputs:
for i in inputs: inputs = value[1]["inputs"]
if not isinstance(i, DynamicInput): parse_class_inputs(out_dict, live_inputs, inputs, curr_prefix)
dynamic[f"{i.id}"] = f"{curr_prefix}{i.id}" # add self to inputs
out_dict[input_type][finalized_id] = value
out_dict["dynamic_paths"][finalized_id] = finalize_prefix(curr_prefix, curr_prefix[-1])
DYNAMIC_INPUT_LOOKUP: dict[str, Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]] = {}
def register_dynamic_input_func(io_type: str, func: Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]):
DYNAMIC_INPUT_LOOKUP[io_type] = func
def get_dynamic_input_func(io_type: str) -> Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]:
return DYNAMIC_INPUT_LOOKUP[io_type]
def setup_dynamic_input_funcs():
# Autogrow.Input
register_dynamic_input_func(Autogrow.io_type, Autogrow._expand_schema_for_dynamic)
# DynamicCombo.Input
register_dynamic_input_func(DynamicCombo.io_type, DynamicCombo._expand_schema_for_dynamic)
# DynamicSlot.Input
register_dynamic_input_func(DynamicSlot.io_type, DynamicSlot._expand_schema_for_dynamic)
if len(DYNAMIC_INPUT_LOOKUP) == 0:
setup_dynamic_input_funcs()
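For orientation, a registered expansion function takes the accumulating out_dict, the live_inputs mapping, the (io_type, options) value tuple, the "required"/"optional" bucket name, and the current prefix list. A minimal hypothetical registration might look like the sketch below; "COMFY_MYDYNAMIC_V3" and _expand_my_dynamic are illustrative names only, while finalize_prefix and register_dynamic_input_func are the helpers defined in this module.
# Hypothetical sketch, not one of this module's real dynamic types.
def _expand_my_dynamic(out_dict, live_inputs, value, input_type, curr_prefix):
    finalized_id = finalize_prefix(curr_prefix)
    # simplest possible behaviour: pass the input through unchanged
    out_dict[input_type][finalized_id] = value
    out_dict["dynamic_paths"][finalized_id] = finalized_id
register_dynamic_input_func("COMFY_MYDYNAMIC_V3", _expand_my_dynamic)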
class V3Data(TypedDict): class V3Data(TypedDict):
hidden_inputs: dict[str, Any] hidden_inputs: dict[str, Any]
'Dictionary where the keys are the hidden input ids and the values are the values of the hidden inputs.'
dynamic_paths: dict[str, Any] dynamic_paths: dict[str, Any]
'Dictionary where the keys are the input ids and the values dictate how to turn the inputs into a nested dictionary.'
create_dynamic_tuple: bool
'When True, the value of the dynamic input will be in the format (value, path_key).'
class HiddenHolder: class HiddenHolder:
def __init__(self, unique_id: str, prompt: Any, def __init__(self, unique_id: str, prompt: Any,
@ -1139,6 +1173,10 @@ class HiddenHolder:
api_key_comfy_org=d.get(Hidden.api_key_comfy_org, None), api_key_comfy_org=d.get(Hidden.api_key_comfy_org, None),
) )
@classmethod
def from_v3_data(cls, v3_data: V3Data | None) -> HiddenHolder:
return cls.from_dict(v3_data["hidden_inputs"] if v3_data else None)
class Hidden(str, Enum): class Hidden(str, Enum):
''' '''
Enumerator for requesting hidden variables in nodes. Enumerator for requesting hidden variables in nodes.
@ -1244,61 +1282,56 @@ class Schema:
- verify ids on inputs and outputs are unique - both internally and in relation to each other - verify ids on inputs and outputs are unique - both internally and in relation to each other
''' '''
nested_inputs: list[Input] = [] nested_inputs: list[Input] = []
if self.inputs is not None: for input in self.inputs:
for input in self.inputs: if not isinstance(input, DynamicInput):
nested_inputs.extend(input.get_all()) nested_inputs.extend(input.get_all())
input_ids = [i.id for i in nested_inputs] if nested_inputs is not None else [] input_ids = [i.id for i in nested_inputs]
output_ids = [o.id for o in self.outputs] if self.outputs is not None else [] output_ids = [o.id for o in self.outputs]
input_set = set(input_ids) input_set = set(input_ids)
output_set = set(output_ids) output_set = set(output_ids)
issues = [] issues: list[str] = []
# verify ids are unique per list # verify ids are unique per list
if len(input_set) != len(input_ids): if len(input_set) != len(input_ids):
issues.append(f"Input ids must be unique, but {[item for item, count in Counter(input_ids).items() if count > 1]} are not.") issues.append(f"Input ids must be unique, but {[item for item, count in Counter(input_ids).items() if count > 1]} are not.")
if len(output_set) != len(output_ids): if len(output_set) != len(output_ids):
issues.append(f"Output ids must be unique, but {[item for item, count in Counter(output_ids).items() if count > 1]} are not.") issues.append(f"Output ids must be unique, but {[item for item, count in Counter(output_ids).items() if count > 1]} are not.")
# verify ids are unique between lists
intersection = input_set & output_set
if len(intersection) > 0:
issues.append(f"Ids must be unique between inputs and outputs, but {intersection} are not.")
if len(issues) > 0: if len(issues) > 0:
raise ValueError("\n".join(issues)) raise ValueError("\n".join(issues))
# validate inputs and outputs # validate inputs and outputs
if self.inputs is not None: for input in self.inputs:
for input in self.inputs: input.validate()
input.validate() for output in self.outputs:
if self.outputs is not None: output.validate()
for output in self.outputs:
output.validate()
def finalize(self): def finalize(self):
"""Add hidden based on selected schema options, and give outputs without ids default ids.""" """Add hidden based on selected schema options, and give outputs without ids default ids."""
# ensure inputs, outputs, and hidden are lists
if self.inputs is None:
self.inputs = []
if self.outputs is None:
self.outputs = []
if self.hidden is None:
self.hidden = []
# if is an api_node, will need key-related hidden # if is an api_node, will need key-related hidden
if self.is_api_node: if self.is_api_node:
if self.hidden is None:
self.hidden = []
if Hidden.auth_token_comfy_org not in self.hidden: if Hidden.auth_token_comfy_org not in self.hidden:
self.hidden.append(Hidden.auth_token_comfy_org) self.hidden.append(Hidden.auth_token_comfy_org)
if Hidden.api_key_comfy_org not in self.hidden: if Hidden.api_key_comfy_org not in self.hidden:
self.hidden.append(Hidden.api_key_comfy_org) self.hidden.append(Hidden.api_key_comfy_org)
# if is an output_node, will need prompt and extra_pnginfo # if is an output_node, will need prompt and extra_pnginfo
if self.is_output_node: if self.is_output_node:
if self.hidden is None:
self.hidden = []
if Hidden.prompt not in self.hidden: if Hidden.prompt not in self.hidden:
self.hidden.append(Hidden.prompt) self.hidden.append(Hidden.prompt)
if Hidden.extra_pnginfo not in self.hidden: if Hidden.extra_pnginfo not in self.hidden:
self.hidden.append(Hidden.extra_pnginfo) self.hidden.append(Hidden.extra_pnginfo)
# give outputs without ids default ids # give outputs without ids default ids
if self.outputs is not None: for i, output in enumerate(self.outputs):
for i, output in enumerate(self.outputs): if output.id is None:
if output.id is None: output.id = f"_{i}_{output.io_type}_"
output.id = f"_{i}_{output.io_type}_"
def get_v1_info(self, cls, live_inputs: dict[str, Any]=None) -> NodeInfoV1: def get_v1_info(self, cls) -> NodeInfoV1:
# NOTE: live_inputs will not be used anymore very soon and this will be done another way
# get V1 inputs # get V1 inputs
input = create_input_dict_v1(self.inputs, live_inputs) input = create_input_dict_v1(self.inputs)
if self.hidden: if self.hidden:
for hidden in self.hidden: for hidden in self.hidden:
input.setdefault("hidden", {})[hidden.name] = (hidden.value,) input.setdefault("hidden", {})[hidden.name] = (hidden.value,)
@ -1378,33 +1411,54 @@ class Schema:
) )
return info return info
def get_finalized_class_inputs(d: dict[str, Any], live_inputs: dict[str, Any], include_hidden=False) -> tuple[dict[str, Any], dict[str, Any] | None, V3Data]:
out_dict = {
"required": {},
"optional": {},
"dynamic_paths": {},
}
d = d.copy()
# ignore hidden for parsing
hidden = d.pop("hidden", None)
parse_class_inputs(out_dict, live_inputs, d)
if hidden is not None and include_hidden:
out_dict["hidden"] = hidden
v3_data = {}
dynamic_paths = out_dict.pop("dynamic_paths", None)
if dynamic_paths is not None:
v3_data["dynamic_paths"] = dynamic_paths
return out_dict, hidden, v3_data
def create_input_dict_v1(inputs: list[Input], live_inputs: dict[str, Any]=None) -> dict: def parse_class_inputs(out_dict: dict[str, Any], live_inputs: dict[str, Any], curr_dict: dict[str, Any], curr_prefix: list[str] | None=None) -> None:
for input_type, inner_d in curr_dict.items():
for id, value in inner_d.items():
io_type = value[0]
if io_type in DYNAMIC_INPUT_LOOKUP:
# dynamic inputs need to be handled with lookup functions
dynamic_input_func = get_dynamic_input_func(io_type)
new_prefix = handle_prefix(curr_prefix, id)
dynamic_input_func(out_dict, live_inputs, value, input_type, new_prefix)
else:
# non-dynamic inputs get directly transferred
finalized_id = finalize_prefix(curr_prefix, id)
out_dict[input_type][finalized_id] = value
if curr_prefix:
out_dict["dynamic_paths"][finalized_id] = finalized_id
def create_input_dict_v1(inputs: list[Input]) -> dict:
input = { input = {
"required": {} "required": {}
} }
add_to_input_dict_v1(input, inputs, live_inputs) for i in inputs:
add_to_dict_v1(i, input)
return input return input
def add_to_input_dict_v1(d: dict[str, Any], inputs: list[Input], live_inputs: dict[str, Any]=None, curr_prefix=''): def add_to_dict_v1(i: Input, d: dict):
for i in inputs:
if isinstance(i, DynamicInput):
add_to_dict_v1(i, d)
if live_inputs is not None:
i.expand_schema_for_dynamic(d, live_inputs, curr_prefix)
else:
add_to_dict_v1(i, d)
def add_to_dict_v1(i: Input, d: dict, dynamic_dict: dict=None):
key = "optional" if i.optional else "required" key = "optional" if i.optional else "required"
as_dict = i.as_dict() as_dict = i.as_dict()
# for v1, we don't want to include the optional key # for v1, we don't want to include the optional key
as_dict.pop("optional", None) as_dict.pop("optional", None)
if dynamic_dict is None: d.setdefault(key, {})[i.id] = (i.get_io_type(), as_dict)
value = (i.get_io_type(), as_dict)
else:
value = (i.get_io_type(), as_dict, dynamic_dict)
d.setdefault(key, {})[i.id] = value
def add_to_dict_v3(io: Input | Output, d: dict): def add_to_dict_v3(io: Input | Output, d: dict):
d[io.id] = (io.get_io_type(), io.as_dict()) d[io.id] = (io.get_io_type(), io.as_dict())
@ -1416,6 +1470,8 @@ def build_nested_inputs(values: dict[str, Any], v3_data: V3Data):
values = values.copy() values = values.copy()
result = {} result = {}
create_tuple = v3_data.get("create_dynamic_tuple", False)
for key, path in paths.items(): for key, path in paths.items():
parts = path.split(".") parts = path.split(".")
current = result current = result
@ -1424,7 +1480,10 @@ def build_nested_inputs(values: dict[str, Any], v3_data: V3Data):
is_last = (i == len(parts) - 1) is_last = (i == len(parts) - 1)
if is_last: if is_last:
current[p] = values.pop(key, None) value = values.pop(key, None)
if create_tuple:
value = (value, key)
current[p] = value
else: else:
current = current.setdefault(p, {}) current = current.setdefault(p, {})
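A minimal worked example of the path expansion, using made-up ids and only the behaviour visible in this hunk:
# v3_data = {"dynamic_paths": {"opt.strength": "opt.strength"}, "create_dynamic_tuple": False}
# values  = {"opt.strength": 0.5}
# the loop pops "opt.strength" and nests it, so result contains {"opt": {"strength": 0.5}}
# with create_dynamic_tuple=True the leaf becomes (0.5, "opt.strength") instead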
@ -1439,7 +1498,6 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
SCHEMA = None SCHEMA = None
# filled in during execution # filled in during execution
resources: Resources = None
hidden: HiddenHolder = None hidden: HiddenHolder = None
@classmethod @classmethod
@ -1486,7 +1544,6 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
return [name for name in kwargs if kwargs[name] is None] return [name for name in kwargs if kwargs[name] is None]
def __init__(self): def __init__(self):
self.local_resources: ResourcesLocal = None
self.__class__.VALIDATE_CLASS() self.__class__.VALIDATE_CLASS()
@classmethod @classmethod
@ -1549,12 +1606,12 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
@final @final
@classmethod @classmethod
def PREPARE_CLASS_CLONE(cls, v3_data: V3Data) -> type[ComfyNode]: def PREPARE_CLASS_CLONE(cls, v3_data: V3Data | None) -> type[ComfyNode]:
"""Creates clone of real node class to prevent monkey-patching.""" """Creates clone of real node class to prevent monkey-patching."""
c_type: type[ComfyNode] = cls if is_class(cls) else type(cls) c_type: type[ComfyNode] = cls if is_class(cls) else type(cls)
type_clone: type[ComfyNode] = shallow_clone_class(c_type) type_clone: type[ComfyNode] = shallow_clone_class(c_type)
# set hidden # set hidden
type_clone.hidden = HiddenHolder.from_dict(v3_data["hidden_inputs"]) type_clone.hidden = HiddenHolder.from_v3_data(v3_data)
return type_clone return type_clone
@final @final
@ -1671,19 +1728,10 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
@final @final
@classmethod @classmethod
def INPUT_TYPES(cls, include_hidden=True, return_schema=False, live_inputs=None) -> dict[str, dict] | tuple[dict[str, dict], Schema, V3Data]: def INPUT_TYPES(cls) -> dict[str, dict]:
schema = cls.FINALIZE_SCHEMA() schema = cls.FINALIZE_SCHEMA()
info = schema.get_v1_info(cls, live_inputs) info = schema.get_v1_info(cls)
input = info.input return info.input
if not include_hidden:
input.pop("hidden", None)
if return_schema:
v3_data: V3Data = {}
dynamic = input.pop("dynamic_paths", None)
if dynamic is not None:
v3_data["dynamic_paths"] = dynamic
return input, schema, v3_data
return input
@final @final
@classmethod @classmethod
@ -1802,7 +1850,7 @@ class NodeOutput(_NodeOutputInternal):
return self.args if len(self.args) > 0 else None return self.args if len(self.args) > 0 else None
@classmethod @classmethod
def from_dict(cls, data: dict[str, Any]) -> "NodeOutput": def from_dict(cls, data: dict[str, Any]) -> NodeOutput:
args = () args = ()
ui = None ui = None
expand = None expand = None
@ -1815,7 +1863,7 @@ class NodeOutput(_NodeOutputInternal):
ui = data["ui"] ui = data["ui"]
if "expand" in data: if "expand" in data:
expand = data["expand"] expand = data["expand"]
return cls(args=args, ui=ui, expand=expand) return cls(*args, ui=ui, expand=expand)
def __getitem__(self, index) -> Any: def __getitem__(self, index) -> Any:
return self.args[index] return self.args[index]
@ -1894,10 +1942,11 @@ __all__ = [
"SEGS", "SEGS",
"AnyType", "AnyType",
"MultiType", "MultiType",
"Tracks",
# Dynamic Types # Dynamic Types
"MatchType", "MatchType",
# "DynamicCombo", "DynamicCombo",
# "Autogrow", "Autogrow",
# Other classes # Other classes
"HiddenHolder", "HiddenHolder",
"Hidden", "Hidden",

View File

@ -1,72 +0,0 @@
from __future__ import annotations
import comfy.utils
import folder_paths
import logging
from abc import ABC, abstractmethod
from typing import Any
import torch
class ResourceKey(ABC):
Type = Any
def __init__(self):
...
class TorchDictFolderFilename(ResourceKey):
'''Key for requesting a torch file via file_name from a folder category.'''
Type = dict[str, torch.Tensor]
def __init__(self, folder_name: str, file_name: str):
self.folder_name = folder_name
self.file_name = file_name
def __hash__(self):
return hash((self.folder_name, self.file_name))
def __eq__(self, other: object) -> bool:
if not isinstance(other, TorchDictFolderFilename):
return False
return self.folder_name == other.folder_name and self.file_name == other.file_name
def __str__(self):
return f"{self.folder_name} -> {self.file_name}"
class Resources(ABC):
def __init__(self):
...
@abstractmethod
def get(self, key: ResourceKey, default: Any=...) -> Any:
pass
class ResourcesLocal(Resources):
def __init__(self):
super().__init__()
self.local_resources: dict[ResourceKey, Any] = {}
def get(self, key: ResourceKey, default: Any=...) -> Any:
cached = self.local_resources.get(key, None)
if cached is not None:
logging.info(f"Using cached resource '{key}'")
return cached
logging.info(f"Loading resource '{key}'")
to_return = None
if isinstance(key, TorchDictFolderFilename):
if default is ...:
to_return = comfy.utils.load_torch_file(folder_paths.get_full_path_or_raise(key.folder_name, key.file_name), safe_load=True)
else:
full_path = folder_paths.get_full_path(key.folder_name, key.file_name)
if full_path is not None:
to_return = comfy.utils.load_torch_file(full_path, safe_load=True)
if to_return is not None:
self.local_resources[key] = to_return
return to_return
if default is not ...:
return default
raise Exception(f"Unsupported resource key type: {type(key)}")
class _RESOURCES:
ResourceKey = ResourceKey
TorchDictFolderFilename = TorchDictFolderFilename
Resources = Resources
ResourcesLocal = ResourcesLocal

View File

@ -5,7 +5,6 @@ import os
import random import random
import uuid import uuid
from io import BytesIO from io import BytesIO
from typing import Type
import av import av
import numpy as np import numpy as np
@ -22,7 +21,7 @@ import folder_paths
# used for image preview # used for image preview
from comfy.cli_args import args from comfy.cli_args import args
from comfy_api.latest._io import ComfyNode, FolderType, Image, _UIOutput from ._io import ComfyNode, FolderType, Image, _UIOutput
class SavedResult(dict): class SavedResult(dict):
@ -83,7 +82,7 @@ class ImageSaveHelper:
return PILImage.fromarray(np.clip(255.0 * image_tensor.cpu().numpy(), 0, 255).astype(np.uint8)) return PILImage.fromarray(np.clip(255.0 * image_tensor.cpu().numpy(), 0, 255).astype(np.uint8))
@staticmethod @staticmethod
def _create_png_metadata(cls: Type[ComfyNode] | None) -> PngInfo | None: def _create_png_metadata(cls: type[ComfyNode] | None) -> PngInfo | None:
"""Creates a PngInfo object with prompt and extra_pnginfo.""" """Creates a PngInfo object with prompt and extra_pnginfo."""
if args.disable_metadata or cls is None or not cls.hidden: if args.disable_metadata or cls is None or not cls.hidden:
return None return None
@ -96,7 +95,7 @@ class ImageSaveHelper:
return metadata return metadata
@staticmethod @staticmethod
def _create_animated_png_metadata(cls: Type[ComfyNode] | None) -> PngInfo | None: def _create_animated_png_metadata(cls: type[ComfyNode] | None) -> PngInfo | None:
"""Creates a PngInfo object with prompt and extra_pnginfo for animated PNGs (APNG).""" """Creates a PngInfo object with prompt and extra_pnginfo for animated PNGs (APNG)."""
if args.disable_metadata or cls is None or not cls.hidden: if args.disable_metadata or cls is None or not cls.hidden:
return None return None
@ -121,7 +120,7 @@ class ImageSaveHelper:
return metadata return metadata
@staticmethod @staticmethod
def _create_webp_metadata(pil_image: PILImage.Image, cls: Type[ComfyNode] | None) -> PILImage.Exif: def _create_webp_metadata(pil_image: PILImage.Image, cls: type[ComfyNode] | None) -> PILImage.Exif:
"""Creates EXIF metadata bytes for WebP images.""" """Creates EXIF metadata bytes for WebP images."""
exif_data = pil_image.getexif() exif_data = pil_image.getexif()
if args.disable_metadata or cls is None or cls.hidden is None: if args.disable_metadata or cls is None or cls.hidden is None:
@ -137,7 +136,7 @@ class ImageSaveHelper:
@staticmethod @staticmethod
def save_images( def save_images(
images, filename_prefix: str, folder_type: FolderType, cls: Type[ComfyNode] | None, compress_level = 4, images, filename_prefix: str, folder_type: FolderType, cls: type[ComfyNode] | None, compress_level = 4,
) -> list[SavedResult]: ) -> list[SavedResult]:
"""Saves a batch of images as individual PNG files.""" """Saves a batch of images as individual PNG files."""
full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path( full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path(
@ -155,7 +154,7 @@ class ImageSaveHelper:
return results return results
@staticmethod @staticmethod
def get_save_images_ui(images, filename_prefix: str, cls: Type[ComfyNode] | None, compress_level=4) -> SavedImages: def get_save_images_ui(images, filename_prefix: str, cls: type[ComfyNode] | None, compress_level=4) -> SavedImages:
"""Saves a batch of images and returns a UI object for the node output.""" """Saves a batch of images and returns a UI object for the node output."""
return SavedImages( return SavedImages(
ImageSaveHelper.save_images( ImageSaveHelper.save_images(
@ -169,7 +168,7 @@ class ImageSaveHelper:
@staticmethod @staticmethod
def save_animated_png( def save_animated_png(
images, filename_prefix: str, folder_type: FolderType, cls: Type[ComfyNode] | None, fps: float, compress_level: int images, filename_prefix: str, folder_type: FolderType, cls: type[ComfyNode] | None, fps: float, compress_level: int
) -> SavedResult: ) -> SavedResult:
"""Saves a batch of images as a single animated PNG.""" """Saves a batch of images as a single animated PNG."""
full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path( full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path(
@ -191,7 +190,7 @@ class ImageSaveHelper:
@staticmethod @staticmethod
def get_save_animated_png_ui( def get_save_animated_png_ui(
images, filename_prefix: str, cls: Type[ComfyNode] | None, fps: float, compress_level: int images, filename_prefix: str, cls: type[ComfyNode] | None, fps: float, compress_level: int
) -> SavedImages: ) -> SavedImages:
"""Saves an animated PNG and returns a UI object for the node output.""" """Saves an animated PNG and returns a UI object for the node output."""
result = ImageSaveHelper.save_animated_png( result = ImageSaveHelper.save_animated_png(
@ -209,7 +208,7 @@ class ImageSaveHelper:
images, images,
filename_prefix: str, filename_prefix: str,
folder_type: FolderType, folder_type: FolderType,
cls: Type[ComfyNode] | None, cls: type[ComfyNode] | None,
fps: float, fps: float,
lossless: bool, lossless: bool,
quality: int, quality: int,
@ -238,7 +237,7 @@ class ImageSaveHelper:
def get_save_animated_webp_ui( def get_save_animated_webp_ui(
images, images,
filename_prefix: str, filename_prefix: str,
cls: Type[ComfyNode] | None, cls: type[ComfyNode] | None,
fps: float, fps: float,
lossless: bool, lossless: bool,
quality: int, quality: int,
@ -267,7 +266,7 @@ class AudioSaveHelper:
audio: dict, audio: dict,
filename_prefix: str, filename_prefix: str,
folder_type: FolderType, folder_type: FolderType,
cls: Type[ComfyNode] | None, cls: type[ComfyNode] | None,
format: str = "flac", format: str = "flac",
quality: str = "128k", quality: str = "128k",
) -> list[SavedResult]: ) -> list[SavedResult]:
@ -372,7 +371,7 @@ class AudioSaveHelper:
@staticmethod @staticmethod
def get_save_audio_ui( def get_save_audio_ui(
audio, filename_prefix: str, cls: Type[ComfyNode] | None, format: str = "flac", quality: str = "128k", audio, filename_prefix: str, cls: type[ComfyNode] | None, format: str = "flac", quality: str = "128k",
) -> SavedAudios: ) -> SavedAudios:
"""Save and instantly wrap for UI.""" """Save and instantly wrap for UI."""
return SavedAudios( return SavedAudios(
@ -388,7 +387,7 @@ class AudioSaveHelper:
class PreviewImage(_UIOutput): class PreviewImage(_UIOutput):
def __init__(self, image: Image.Type, animated: bool = False, cls: Type[ComfyNode] = None, **kwargs): def __init__(self, image: Image.Type, animated: bool = False, cls: type[ComfyNode] = None, **kwargs):
self.values = ImageSaveHelper.save_images( self.values = ImageSaveHelper.save_images(
image, image,
filename_prefix="ComfyUI_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for _ in range(5)), filename_prefix="ComfyUI_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for _ in range(5)),
@ -412,7 +411,7 @@ class PreviewMask(PreviewImage):
class PreviewAudio(_UIOutput): class PreviewAudio(_UIOutput):
def __init__(self, audio: dict, cls: Type[ComfyNode] = None, **kwargs): def __init__(self, audio: dict, cls: type[ComfyNode] = None, **kwargs):
self.values = AudioSaveHelper.save_audio( self.values = AudioSaveHelper.save_audio(
audio, audio,
filename_prefix="ComfyUI_temp_" + "".join(random.choice("abcdefghijklmnopqrstuvwxyz") for _ in range(5)), filename_prefix="ComfyUI_temp_" + "".join(random.choice("abcdefghijklmnopqrstuvwxyz") for _ in range(5)),

View File

@ -1,5 +1,6 @@
from .video_types import VideoContainer, VideoCodec, VideoComponents from .video_types import VideoContainer, VideoCodec, VideoComponents
from .geometry_types import VOXEL, MESH from .geometry_types import VOXEL, MESH
from .image_types import SVG
__all__ = [ __all__ = [
# Utility Types # Utility Types
@ -8,4 +9,5 @@ __all__ = [
"VideoComponents", "VideoComponents",
"VOXEL", "VOXEL",
"MESH", "MESH",
"SVG",
] ]

View File

@ -0,0 +1,18 @@
from io import BytesIO
class SVG:
"""Stores SVG representations via a list of BytesIO objects."""
def __init__(self, data: list[BytesIO]):
self.data = data
def combine(self, other: 'SVG') -> 'SVG':
return SVG(self.data + other.data)
@staticmethod
def combine_all(svgs: list['SVG']) -> 'SVG':
all_svgs_list: list[BytesIO] = []
for svg_item in svgs:
all_svgs_list.extend(svg_item.data)
return SVG(all_svgs_list)
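A brief usage sketch of the new container, with placeholder buffers:
# a = SVG([BytesIO(b"<svg>...</svg>")])
# b = SVG([BytesIO(b"<svg>...</svg>")])
# merged = SVG.combine_all([a, b])  # merged.data holds both BytesIO objects, in order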

View File

@ -3,7 +3,7 @@ from dataclasses import dataclass
from enum import Enum from enum import Enum
from fractions import Fraction from fractions import Fraction
from typing import Optional from typing import Optional
from comfy_api.latest._input import ImageInput, AudioInput from .._input import ImageInput, AudioInput
class VideoCodec(str, Enum): class VideoCodec(str, Enum):
AUTO = "auto" AUTO = "auto"

View File

@ -2,9 +2,8 @@ from comfy_api.latest import ComfyAPI_latest
from comfy_api.v0_0_2 import ComfyAPIAdapter_v0_0_2 from comfy_api.v0_0_2 import ComfyAPIAdapter_v0_0_2
from comfy_api.v0_0_1 import ComfyAPIAdapter_v0_0_1 from comfy_api.v0_0_1 import ComfyAPIAdapter_v0_0_1
from comfy_api.internal import ComfyAPIBase from comfy_api.internal import ComfyAPIBase
from typing import List, Type
supported_versions: List[Type[ComfyAPIBase]] = [ supported_versions: list[type[ComfyAPIBase]] = [
ComfyAPI_latest, ComfyAPI_latest,
ComfyAPIAdapter_v0_0_2, ComfyAPIAdapter_v0_0_2,
ComfyAPIAdapter_v0_0_1, ComfyAPIAdapter_v0_0_1,

View File

@ -0,0 +1,144 @@
from typing import Literal
from pydantic import BaseModel, Field
class Text2ImageTaskCreationRequest(BaseModel):
model: str = Field(...)
prompt: str = Field(...)
response_format: str | None = Field("url")
size: str | None = Field(None)
seed: int | None = Field(0, ge=0, le=2147483647)
guidance_scale: float | None = Field(..., ge=1.0, le=10.0)
watermark: bool | None = Field(False)
class Image2ImageTaskCreationRequest(BaseModel):
model: str = Field(...)
prompt: str = Field(...)
response_format: str | None = Field("url")
image: str = Field(..., description="Base64 encoded string or image URL")
size: str | None = Field("adaptive")
seed: int | None = Field(..., ge=0, le=2147483647)
guidance_scale: float | None = Field(..., ge=1.0, le=10.0)
watermark: bool | None = Field(False)
class Seedream4Options(BaseModel):
max_images: int = Field(15)
class Seedream4TaskCreationRequest(BaseModel):
model: str = Field(...)
prompt: str = Field(...)
response_format: str = Field("url")
image: list[str] | None = Field(None, description="Image URLs")
size: str = Field(...)
seed: int = Field(..., ge=0, le=2147483647)
sequential_image_generation: str = Field("disabled")
sequential_image_generation_options: Seedream4Options = Field(Seedream4Options(max_images=15))
watermark: bool = Field(False)
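An illustrative request construction; the field values are placeholders, and the model name is the default that the older in-tree request model used.
# req = Seedream4TaskCreationRequest(
#     model="seedream-4-0-250828",
#     prompt="a lighthouse at dusk",
#     size="2048x2048",
#     seed=42,
# )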
class ImageTaskCreationResponse(BaseModel):
model: str = Field(...)
created: int = Field(..., description="Unix timestamp (in seconds) indicating the time when the request was created.")
data: list = Field([], description="Contains information about the generated image(s).")
error: dict = Field({}, description="Contains `code` and `message` fields in case of error.")
class TaskTextContent(BaseModel):
type: str = Field("text")
text: str = Field(...)
class TaskImageContentUrl(BaseModel):
url: str = Field(...)
class TaskImageContent(BaseModel):
type: str = Field("image_url")
image_url: TaskImageContentUrl = Field(...)
role: Literal["first_frame", "last_frame", "reference_image"] | None = Field(None)
class Text2VideoTaskCreationRequest(BaseModel):
model: str = Field(...)
content: list[TaskTextContent] = Field(..., min_length=1)
class Image2VideoTaskCreationRequest(BaseModel):
model: str = Field(...)
content: list[TaskTextContent | TaskImageContent] = Field(..., min_length=2)
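An illustrative image-to-video payload; the URL and prompt are placeholders, and the model name comes from the timing table below.
# req = Image2VideoTaskCreationRequest(
#     model="seedance-1-0-lite-i2v-250428",
#     content=[
#         TaskTextContent(text="a cat running through tall grass"),
#         TaskImageContent(
#             image_url=TaskImageContentUrl(url="https://example.com/first_frame.png"),
#             role="first_frame",
#         ),
#     ],
# )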
class TaskCreationResponse(BaseModel):
id: str = Field(...)
class TaskStatusError(BaseModel):
code: str = Field(...)
message: str = Field(...)
class TaskStatusResult(BaseModel):
video_url: str = Field(...)
class TaskStatusResponse(BaseModel):
id: str = Field(...)
model: str = Field(...)
status: Literal["queued", "running", "cancelled", "succeeded", "failed"] = Field(...)
error: TaskStatusError | None = Field(None)
content: TaskStatusResult | None = Field(None)
RECOMMENDED_PRESETS = [
("1024x1024 (1:1)", 1024, 1024),
("864x1152 (3:4)", 864, 1152),
("1152x864 (4:3)", 1152, 864),
("1280x720 (16:9)", 1280, 720),
("720x1280 (9:16)", 720, 1280),
("832x1248 (2:3)", 832, 1248),
("1248x832 (3:2)", 1248, 832),
("1512x648 (21:9)", 1512, 648),
("2048x2048 (1:1)", 2048, 2048),
("Custom", None, None),
]
RECOMMENDED_PRESETS_SEEDREAM_4 = [
("2048x2048 (1:1)", 2048, 2048),
("2304x1728 (4:3)", 2304, 1728),
("1728x2304 (3:4)", 1728, 2304),
("2560x1440 (16:9)", 2560, 1440),
("1440x2560 (9:16)", 1440, 2560),
("2496x1664 (3:2)", 2496, 1664),
("1664x2496 (2:3)", 1664, 2496),
("3024x1296 (21:9)", 3024, 1296),
("4096x4096 (1:1)", 4096, 4096),
("Custom", None, None),
]
# The times in this dictionary are given for a 10 second duration.
VIDEO_TASKS_EXECUTION_TIME = {
"seedance-1-0-lite-t2v-250428": {
"480p": 40,
"720p": 60,
"1080p": 90,
},
"seedance-1-0-lite-i2v-250428": {
"480p": 40,
"720p": 60,
"1080p": 90,
},
"seedance-1-0-pro-250528": {
"480p": 70,
"720p": 85,
"1080p": 115,
},
"seedance-1-0-pro-fast-251015": {
"480p": 50,
"720p": 65,
"1080p": 100,
},
}
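Since the table is normalised to a 10 second clip, a caller would presumably scale the estimate for other durations; linear scaling is an assumption here, not something stated by the commit.
# Assumption: execution time scales roughly linearly with clip duration.
def estimate_execution_time(model: str, resolution: str, duration_seconds: float) -> float:
    base = VIDEO_TASKS_EXECUTION_TIME[model][resolution]  # seconds, for a 10 s clip
    return base * duration_seconds / 10.0
# estimate_execution_time("seedance-1-0-pro-250528", "720p", 5) -> 42.5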

View File

@ -84,15 +84,7 @@ class GeminiSystemInstructionContent(BaseModel):
description="A list of ordered parts that make up a single message. " description="A list of ordered parts that make up a single message. "
"Different parts may have different IANA MIME types.", "Different parts may have different IANA MIME types.",
) )
role: GeminiRole = Field( role: GeminiRole | None = Field(..., description="The role field of systemInstruction may be ignored.")
...,
description="The identity of the entity that creates the message. "
"The following values are supported: "
"user: This indicates that the message is sent by a real person, typically a user-generated message. "
"model: This indicates that the message is generated by the model. "
"The model value is used to insert messages from model into the conversation during multi-turn conversations. "
"For non-multi-turn conversations, this field can be left blank or unset.",
)
class GeminiFunctionDeclaration(BaseModel): class GeminiFunctionDeclaration(BaseModel):
@ -141,6 +133,7 @@ class GeminiImageGenerateContentRequest(BaseModel):
systemInstruction: GeminiSystemInstructionContent | None = Field(None) systemInstruction: GeminiSystemInstructionContent | None = Field(None)
tools: list[GeminiTool] | None = Field(None) tools: list[GeminiTool] | None = Field(None)
videoMetadata: GeminiVideoMetadata | None = Field(None) videoMetadata: GeminiVideoMetadata | None = Field(None)
uploadImagesToStorage: bool = Field(True)
class GeminiGenerateContentRequest(BaseModel): class GeminiGenerateContentRequest(BaseModel):

View File

@ -51,25 +51,25 @@ class TaskStatusImageResult(BaseModel):
url: str = Field(..., description="URL for generated image") url: str = Field(..., description="URL for generated image")
class OmniTaskStatusResults(BaseModel): class TaskStatusResults(BaseModel):
videos: list[TaskStatusVideoResult] | None = Field(None) videos: list[TaskStatusVideoResult] | None = Field(None)
images: list[TaskStatusImageResult] | None = Field(None) images: list[TaskStatusImageResult] | None = Field(None)
class OmniTaskStatusResponseData(BaseModel): class TaskStatusResponseData(BaseModel):
created_at: int | None = Field(None, description="Task creation time") created_at: int | None = Field(None, description="Task creation time")
updated_at: int | None = Field(None, description="Task update time") updated_at: int | None = Field(None, description="Task update time")
task_status: str | None = None task_status: str | None = None
task_status_msg: str | None = Field(None, description="Additional failure reason. Only for polling endpoint.") task_status_msg: str | None = Field(None, description="Additional failure reason. Only for polling endpoint.")
task_id: str | None = Field(None, description="Task ID") task_id: str | None = Field(None, description="Task ID")
task_result: OmniTaskStatusResults | None = Field(None) task_result: TaskStatusResults | None = Field(None)
class OmniTaskStatusResponse(BaseModel): class TaskStatusResponse(BaseModel):
code: int | None = Field(None, description="Error code") code: int | None = Field(None, description="Error code")
message: str | None = Field(None, description="Error message") message: str | None = Field(None, description="Error message")
request_id: str | None = Field(None, description="Request ID") request_id: str | None = Field(None, description="Request ID")
data: OmniTaskStatusResponseData | None = Field(None) data: TaskStatusResponseData | None = Field(None)
class OmniImageParamImage(BaseModel): class OmniImageParamImage(BaseModel):
@ -84,3 +84,30 @@ class OmniProImageRequest(BaseModel):
mode: str = Field("pro") mode: str = Field("pro")
n: int | None = Field(1, le=9) n: int | None = Field(1, le=9)
image_list: list[OmniImageParamImage] | None = Field(..., max_length=10) image_list: list[OmniImageParamImage] | None = Field(..., max_length=10)
class TextToVideoWithAudioRequest(BaseModel):
model_name: str = Field(..., description="kling-v2-6")
aspect_ratio: str = Field(..., description="'16:9', '9:16' or '1:1'")
duration: str = Field(..., description="'5' or '10'")
prompt: str = Field(...)
mode: str = Field("pro")
sound: str = Field(..., description="'on' or 'off'")
class ImageToVideoWithAudioRequest(BaseModel):
model_name: str = Field(..., description="kling-v2-6")
image: str = Field(...)
duration: str = Field(..., description="'5' or '10'")
prompt: str = Field(...)
mode: str = Field("pro")
sound: str = Field(..., description="'on' or 'off'")
class MotionControlRequest(BaseModel):
prompt: str = Field(...)
image_url: str = Field(...)
video_url: str = Field(...)
keep_original_sound: str = Field(...)
character_orientation: str = Field(...)
mode: str = Field(..., description="'pro' or 'std'")

View File

@ -0,0 +1,52 @@
from pydantic import BaseModel, Field
class Datum2(BaseModel):
b64_json: str | None = Field(None, description="Base64 encoded image data")
revised_prompt: str | None = Field(None, description="Revised prompt")
url: str | None = Field(None, description="URL of the image")
class InputTokensDetails(BaseModel):
image_tokens: int | None = None
text_tokens: int | None = None
class Usage(BaseModel):
input_tokens: int | None = None
input_tokens_details: InputTokensDetails | None = None
output_tokens: int | None = None
total_tokens: int | None = None
class OpenAIImageGenerationResponse(BaseModel):
data: list[Datum2] | None = None
usage: Usage | None = None
class OpenAIImageEditRequest(BaseModel):
background: str | None = Field(None, description="Background transparency")
model: str = Field(...)
moderation: str | None = Field(None)
n: int | None = Field(None, description="The number of images to generate")
output_compression: int | None = Field(None, description="Compression level for JPEG or WebP (0-100)")
output_format: str | None = Field(None)
prompt: str = Field(...)
quality: str | None = Field(None, description="The quality of the generated image")
size: str | None = Field(None, description="Size of the output image")
class OpenAIImageGenerationRequest(BaseModel):
background: str | None = Field(None, description="Background transparency")
model: str | None = Field(None)
moderation: str | None = Field(None)
n: int | None = Field(
None,
description="The number of images to generate.",
)
output_compression: int | None = Field(None, description="Compression level for JPEG or WebP (0-100)")
output_format: str | None = Field(None)
prompt: str = Field(...)
quality: str | None = Field(None, description="The quality of the generated image")
size: str | None = Field(None, description="Size of the image (e.g., 1024x1024, 1536x1024, auto)")
style: str | None = Field(None, description="Style of the image (only for dall-e-3)")

View File

@ -1,100 +0,0 @@
from typing import Optional
from enum import Enum
from pydantic import BaseModel, Field
class Pikaffect(str, Enum):
Cake_ify = "Cake-ify"
Crumble = "Crumble"
Crush = "Crush"
Decapitate = "Decapitate"
Deflate = "Deflate"
Dissolve = "Dissolve"
Explode = "Explode"
Eye_pop = "Eye-pop"
Inflate = "Inflate"
Levitate = "Levitate"
Melt = "Melt"
Peel = "Peel"
Poke = "Poke"
Squish = "Squish"
Ta_da = "Ta-da"
Tear = "Tear"
class PikaBodyGenerate22C2vGenerate22PikascenesPost(BaseModel):
aspectRatio: Optional[float] = Field(None, description='Aspect ratio (width / height)')
duration: Optional[int] = Field(5)
ingredientsMode: str = Field(...)
negativePrompt: Optional[str] = Field(None)
promptText: Optional[str] = Field(None)
resolution: Optional[str] = Field('1080p')
seed: Optional[int] = Field(None)
class PikaGenerateResponse(BaseModel):
video_id: str = Field(...)
class PikaBodyGenerate22I2vGenerate22I2vPost(BaseModel):
duration: Optional[int] = 5
negativePrompt: Optional[str] = Field(None)
promptText: Optional[str] = Field(None)
resolution: Optional[str] = '1080p'
seed: Optional[int] = Field(None)
class PikaBodyGenerate22KeyframeGenerate22PikaframesPost(BaseModel):
duration: Optional[int] = Field(None, ge=5, le=10)
negativePrompt: Optional[str] = Field(None)
promptText: str = Field(...)
resolution: Optional[str] = '1080p'
seed: Optional[int] = Field(None)
class PikaBodyGenerate22T2vGenerate22T2vPost(BaseModel):
aspectRatio: Optional[float] = Field(
1.7777777777777777,
description='Aspect ratio (width / height)',
ge=0.4,
le=2.5,
)
duration: Optional[int] = 5
negativePrompt: Optional[str] = Field(None)
promptText: str = Field(...)
resolution: Optional[str] = '1080p'
seed: Optional[int] = Field(None)
class PikaBodyGeneratePikadditionsGeneratePikadditionsPost(BaseModel):
negativePrompt: Optional[str] = Field(None)
promptText: Optional[str] = Field(None)
seed: Optional[int] = Field(None)
class PikaBodyGeneratePikaffectsGeneratePikaffectsPost(BaseModel):
negativePrompt: Optional[str] = Field(None)
pikaffect: Optional[str] = None
promptText: Optional[str] = Field(None)
seed: Optional[int] = Field(None)
class PikaBodyGeneratePikaswapsGeneratePikaswapsPost(BaseModel):
negativePrompt: Optional[str] = Field(None)
promptText: Optional[str] = Field(None)
seed: Optional[int] = Field(None)
modifyRegionRoi: Optional[str] = Field(None)
class PikaStatusEnum(str, Enum):
queued = "queued"
started = "started"
finished = "finished"
failed = "failed"
class PikaVideoResponse(BaseModel):
id: str = Field(...)
progress: Optional[int] = Field(None)
status: PikaStatusEnum
url: Optional[str] = Field(None)

View File

@ -5,11 +5,17 @@ from typing import Optional, List, Dict, Any, Union
from pydantic import BaseModel, Field, RootModel from pydantic import BaseModel, Field, RootModel
class TripoModelVersion(str, Enum): class TripoModelVersion(str, Enum):
v3_0_20250812 = 'v3.0-20250812'
v2_5_20250123 = 'v2.5-20250123' v2_5_20250123 = 'v2.5-20250123'
v2_0_20240919 = 'v2.0-20240919' v2_0_20240919 = 'v2.0-20240919'
v1_4_20240625 = 'v1.4-20240625' v1_4_20240625 = 'v1.4-20240625'
class TripoGeometryQuality(str, Enum):
standard = 'standard'
detailed = 'detailed'
class TripoTextureQuality(str, Enum): class TripoTextureQuality(str, Enum):
standard = 'standard' standard = 'standard'
detailed = 'detailed' detailed = 'detailed'
@ -61,14 +67,20 @@ class TripoSpec(str, Enum):
class TripoAnimation(str, Enum): class TripoAnimation(str, Enum):
IDLE = "preset:idle" IDLE = "preset:idle"
WALK = "preset:walk" WALK = "preset:walk"
RUN = "preset:run"
DIVE = "preset:dive"
CLIMB = "preset:climb" CLIMB = "preset:climb"
JUMP = "preset:jump" JUMP = "preset:jump"
RUN = "preset:run"
SLASH = "preset:slash" SLASH = "preset:slash"
SHOOT = "preset:shoot" SHOOT = "preset:shoot"
HURT = "preset:hurt" HURT = "preset:hurt"
FALL = "preset:fall" FALL = "preset:fall"
TURN = "preset:turn" TURN = "preset:turn"
QUADRUPED_WALK = "preset:quadruped:walk"
HEXAPOD_WALK = "preset:hexapod:walk"
OCTOPOD_WALK = "preset:octopod:walk"
SERPENTINE_MARCH = "preset:serpentine:march"
AQUATIC_MARCH = "preset:aquatic:march"
class TripoStylizeStyle(str, Enum): class TripoStylizeStyle(str, Enum):
LEGO = "lego" LEGO = "lego"
@ -105,6 +117,11 @@ class TripoTaskStatus(str, Enum):
BANNED = "banned" BANNED = "banned"
EXPIRED = "expired" EXPIRED = "expired"
class TripoFbxPreset(str, Enum):
BLENDER = "blender"
MIXAMO = "mixamo"
_3DSMAX = "3dsmax"
class TripoFileTokenReference(BaseModel): class TripoFileTokenReference(BaseModel):
type: Optional[str] = Field(None, description='The type of the reference') type: Optional[str] = Field(None, description='The type of the reference')
file_token: str file_token: str
@ -142,6 +159,7 @@ class TripoTextToModelRequest(BaseModel):
model_seed: Optional[int] = Field(None, description='The seed for the model') model_seed: Optional[int] = Field(None, description='The seed for the model')
texture_seed: Optional[int] = Field(None, description='The seed for the texture') texture_seed: Optional[int] = Field(None, description='The seed for the texture')
texture_quality: Optional[TripoTextureQuality] = TripoTextureQuality.standard texture_quality: Optional[TripoTextureQuality] = TripoTextureQuality.standard
geometry_quality: Optional[TripoGeometryQuality] = TripoGeometryQuality.standard
style: Optional[TripoStyle] = None style: Optional[TripoStyle] = None
auto_size: Optional[bool] = Field(False, description='Whether to auto-size the model') auto_size: Optional[bool] = Field(False, description='Whether to auto-size the model')
quad: Optional[bool] = Field(False, description='Whether to apply quad to the generated model') quad: Optional[bool] = Field(False, description='Whether to apply quad to the generated model')
@ -156,6 +174,7 @@ class TripoImageToModelRequest(BaseModel):
model_seed: Optional[int] = Field(None, description='The seed for the model') model_seed: Optional[int] = Field(None, description='The seed for the model')
texture_seed: Optional[int] = Field(None, description='The seed for the texture') texture_seed: Optional[int] = Field(None, description='The seed for the texture')
texture_quality: Optional[TripoTextureQuality] = TripoTextureQuality.standard texture_quality: Optional[TripoTextureQuality] = TripoTextureQuality.standard
geometry_quality: Optional[TripoGeometryQuality] = TripoGeometryQuality.standard
texture_alignment: Optional[TripoTextureAlignment] = Field(TripoTextureAlignment.ORIGINAL_IMAGE, description='The texture alignment method') texture_alignment: Optional[TripoTextureAlignment] = Field(TripoTextureAlignment.ORIGINAL_IMAGE, description='The texture alignment method')
style: Optional[TripoStyle] = Field(None, description='The style to apply to the generated model') style: Optional[TripoStyle] = Field(None, description='The style to apply to the generated model')
auto_size: Optional[bool] = Field(False, description='Whether to auto-size the model') auto_size: Optional[bool] = Field(False, description='Whether to auto-size the model')
@ -173,6 +192,7 @@ class TripoMultiviewToModelRequest(BaseModel):
model_seed: Optional[int] = Field(None, description='The seed for the model') model_seed: Optional[int] = Field(None, description='The seed for the model')
texture_seed: Optional[int] = Field(None, description='The seed for the texture') texture_seed: Optional[int] = Field(None, description='The seed for the texture')
texture_quality: Optional[TripoTextureQuality] = TripoTextureQuality.standard texture_quality: Optional[TripoTextureQuality] = TripoTextureQuality.standard
geometry_quality: Optional[TripoGeometryQuality] = TripoGeometryQuality.standard
texture_alignment: Optional[TripoTextureAlignment] = TripoTextureAlignment.ORIGINAL_IMAGE texture_alignment: Optional[TripoTextureAlignment] = TripoTextureAlignment.ORIGINAL_IMAGE
auto_size: Optional[bool] = Field(False, description='Whether to auto-size the model') auto_size: Optional[bool] = Field(False, description='Whether to auto-size the model')
orientation: Optional[TripoOrientation] = Field(TripoOrientation.DEFAULT, description='The orientation for the model') orientation: Optional[TripoOrientation] = Field(TripoOrientation.DEFAULT, description='The orientation for the model')
@ -219,14 +239,24 @@ class TripoConvertModelRequest(BaseModel):
type: TripoTaskType = Field(TripoTaskType.CONVERT_MODEL, description='Type of task') type: TripoTaskType = Field(TripoTaskType.CONVERT_MODEL, description='Type of task')
format: TripoConvertFormat = Field(..., description='The format to convert to') format: TripoConvertFormat = Field(..., description='The format to convert to')
original_model_task_id: str = Field(..., description='The task ID of the original model') original_model_task_id: str = Field(..., description='The task ID of the original model')
quad: Optional[bool] = Field(False, description='Whether to apply quad to the model') quad: Optional[bool] = Field(None, description='Whether to apply quad to the model')
force_symmetry: Optional[bool] = Field(False, description='Whether to force symmetry') force_symmetry: Optional[bool] = Field(None, description='Whether to force symmetry')
face_limit: Optional[int] = Field(10000, description='The number of faces to limit the conversion to') face_limit: Optional[int] = Field(None, description='The number of faces to limit the conversion to')
flatten_bottom: Optional[bool] = Field(False, description='Whether to flatten the bottom of the model') flatten_bottom: Optional[bool] = Field(None, description='Whether to flatten the bottom of the model')
flatten_bottom_threshold: Optional[float] = Field(0.01, description='The threshold for flattening the bottom') flatten_bottom_threshold: Optional[float] = Field(None, description='The threshold for flattening the bottom')
texture_size: Optional[int] = Field(4096, description='The size of the texture') texture_size: Optional[int] = Field(None, description='The size of the texture')
texture_format: Optional[TripoTextureFormat] = Field(TripoTextureFormat.JPEG, description='The format of the texture') texture_format: Optional[TripoTextureFormat] = Field(TripoTextureFormat.JPEG, description='The format of the texture')
pivot_to_center_bottom: Optional[bool] = Field(False, description='Whether to pivot to the center bottom') pivot_to_center_bottom: Optional[bool] = Field(None, description='Whether to pivot to the center bottom')
scale_factor: Optional[float] = Field(None, description='The scale factor for the model')
with_animation: Optional[bool] = Field(None, description='Whether to include animations')
pack_uv: Optional[bool] = Field(None, description='Whether to pack the UVs')
bake: Optional[bool] = Field(None, description='Whether to bake the model')
part_names: Optional[List[str]] = Field(None, description='The names of the parts to include')
fbx_preset: Optional[TripoFbxPreset] = Field(None, description='The preset for the FBX export')
export_vertex_colors: Optional[bool] = Field(None, description='Whether to export the vertex colors')
export_orientation: Optional[TripoOrientation] = Field(None, description='The orientation for the export')
animate_in_place: Optional[bool] = Field(None, description='Whether to animate in place')
class TripoTaskRequest(RootModel): class TripoTaskRequest(RootModel):
root: Union[ root: Union[

View File

@ -85,7 +85,7 @@ class Response1(BaseModel):
raiMediaFilteredReasons: Optional[list[str]] = Field( raiMediaFilteredReasons: Optional[list[str]] = Field(
None, description='Reasons why media was filtered by responsible AI policies' None, description='Reasons why media was filtered by responsible AI policies'
) )
videos: Optional[list[Video]] = None videos: Optional[list[Video]] = Field(None)
class VeoGenVidPollResponse(BaseModel): class VeoGenVidPollResponse(BaseModel):

View File

@ -1,10 +1,8 @@
from inspect import cleandoc
import torch import torch
from pydantic import BaseModel from pydantic import BaseModel
from typing_extensions import override from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api_nodes.apis.bfl_api import ( from comfy_api_nodes.apis.bfl_api import (
BFLFluxExpandImageRequest, BFLFluxExpandImageRequest,
BFLFluxFillImageRequest, BFLFluxFillImageRequest,
@ -28,7 +26,7 @@ from comfy_api_nodes.util import (
) )
def convert_mask_to_image(mask: torch.Tensor): def convert_mask_to_image(mask: Input.Image):
""" """
Make mask have the expected amount of dims (4) and channels (3) to be recognized as an image. Make mask have the expected amount of dims (4) and channels (3) to be recognized as an image.
""" """
@ -38,9 +36,6 @@ def convert_mask_to_image(mask: torch.Tensor):
class FluxProUltraImageNode(IO.ComfyNode): class FluxProUltraImageNode(IO.ComfyNode):
"""
Generates images using Flux Pro 1.1 Ultra via api based on prompt and resolution.
"""
@classmethod @classmethod
def define_schema(cls) -> IO.Schema: def define_schema(cls) -> IO.Schema:
@ -48,7 +43,7 @@ class FluxProUltraImageNode(IO.ComfyNode):
node_id="FluxProUltraImageNode", node_id="FluxProUltraImageNode",
display_name="Flux 1.1 [pro] Ultra Image", display_name="Flux 1.1 [pro] Ultra Image",
category="api node/image/BFL", category="api node/image/BFL",
description=cleandoc(cls.__doc__ or ""), description="Generates images using Flux Pro 1.1 Ultra via api based on prompt and resolution.",
inputs=[ inputs=[
IO.String.Input( IO.String.Input(
"prompt", "prompt",
@ -117,7 +112,7 @@ class FluxProUltraImageNode(IO.ComfyNode):
prompt_upsampling: bool = False, prompt_upsampling: bool = False,
raw: bool = False, raw: bool = False,
seed: int = 0, seed: int = 0,
image_prompt: torch.Tensor | None = None, image_prompt: Input.Image | None = None,
image_prompt_strength: float = 0.1, image_prompt_strength: float = 0.1,
) -> IO.NodeOutput: ) -> IO.NodeOutput:
if image_prompt is None: if image_prompt is None:
@ -155,9 +150,6 @@ class FluxProUltraImageNode(IO.ComfyNode):
class FluxKontextProImageNode(IO.ComfyNode): class FluxKontextProImageNode(IO.ComfyNode):
"""
Edits images using Flux.1 Kontext [pro] via api based on prompt and aspect ratio.
"""
@classmethod @classmethod
def define_schema(cls) -> IO.Schema: def define_schema(cls) -> IO.Schema:
@ -165,7 +157,7 @@ class FluxKontextProImageNode(IO.ComfyNode):
node_id=cls.NODE_ID, node_id=cls.NODE_ID,
display_name=cls.DISPLAY_NAME, display_name=cls.DISPLAY_NAME,
category="api node/image/BFL", category="api node/image/BFL",
description=cleandoc(cls.__doc__ or ""), description="Edits images using Flux.1 Kontext [pro] via api based on prompt and aspect ratio.",
inputs=[ inputs=[
IO.String.Input( IO.String.Input(
"prompt", "prompt",
@ -231,7 +223,7 @@ class FluxKontextProImageNode(IO.ComfyNode):
aspect_ratio: str, aspect_ratio: str,
guidance: float, guidance: float,
steps: int, steps: int,
input_image: torch.Tensor | None = None, input_image: Input.Image | None = None,
seed=0, seed=0,
prompt_upsampling=False, prompt_upsampling=False,
) -> IO.NodeOutput: ) -> IO.NodeOutput:
@ -271,20 +263,14 @@ class FluxKontextProImageNode(IO.ComfyNode):
class FluxKontextMaxImageNode(FluxKontextProImageNode): class FluxKontextMaxImageNode(FluxKontextProImageNode):
"""
Edits images using Flux.1 Kontext [max] via api based on prompt and aspect ratio.
"""
DESCRIPTION = cleandoc(__doc__ or "") DESCRIPTION = "Edits images using Flux.1 Kontext [max] via api based on prompt and aspect ratio."
BFL_PATH = "/proxy/bfl/flux-kontext-max/generate" BFL_PATH = "/proxy/bfl/flux-kontext-max/generate"
NODE_ID = "FluxKontextMaxImageNode" NODE_ID = "FluxKontextMaxImageNode"
DISPLAY_NAME = "Flux.1 Kontext [max] Image" DISPLAY_NAME = "Flux.1 Kontext [max] Image"
class FluxProExpandNode(IO.ComfyNode): class FluxProExpandNode(IO.ComfyNode):
"""
Outpaints image based on prompt.
"""
@classmethod @classmethod
def define_schema(cls) -> IO.Schema: def define_schema(cls) -> IO.Schema:
@ -292,7 +278,7 @@ class FluxProExpandNode(IO.ComfyNode):
node_id="FluxProExpandNode", node_id="FluxProExpandNode",
display_name="Flux.1 Expand Image", display_name="Flux.1 Expand Image",
category="api node/image/BFL", category="api node/image/BFL",
description=cleandoc(cls.__doc__ or ""), description="Outpaints image based on prompt.",
inputs=[ inputs=[
IO.Image.Input("image"), IO.Image.Input("image"),
IO.String.Input( IO.String.Input(
@ -371,7 +357,7 @@ class FluxProExpandNode(IO.ComfyNode):
@classmethod @classmethod
async def execute( async def execute(
cls, cls,
image: torch.Tensor, image: Input.Image,
prompt: str, prompt: str,
prompt_upsampling: bool, prompt_upsampling: bool,
top: int, top: int,
@ -418,9 +404,6 @@ class FluxProExpandNode(IO.ComfyNode):
class FluxProFillNode(IO.ComfyNode): class FluxProFillNode(IO.ComfyNode):
"""
Inpaints image based on mask and prompt.
"""
@classmethod @classmethod
def define_schema(cls) -> IO.Schema: def define_schema(cls) -> IO.Schema:
@ -428,7 +411,7 @@ class FluxProFillNode(IO.ComfyNode):
node_id="FluxProFillNode", node_id="FluxProFillNode",
display_name="Flux.1 Fill Image", display_name="Flux.1 Fill Image",
category="api node/image/BFL", category="api node/image/BFL",
description=cleandoc(cls.__doc__ or ""), description="Inpaints image based on mask and prompt.",
inputs=[ inputs=[
IO.Image.Input("image"), IO.Image.Input("image"),
IO.Mask.Input("mask"), IO.Mask.Input("mask"),
@ -480,8 +463,8 @@ class FluxProFillNode(IO.ComfyNode):
@classmethod @classmethod
async def execute( async def execute(
cls, cls,
image: torch.Tensor, image: Input.Image,
mask: torch.Tensor, mask: Input.Image,
prompt: str, prompt: str,
prompt_upsampling: bool, prompt_upsampling: bool,
steps: int, steps: int,
@ -525,11 +508,15 @@ class FluxProFillNode(IO.ComfyNode):
class Flux2ProImageNode(IO.ComfyNode): class Flux2ProImageNode(IO.ComfyNode):
NODE_ID = "Flux2ProImageNode"
DISPLAY_NAME = "Flux.2 [pro] Image"
API_ENDPOINT = "/proxy/bfl/flux-2-pro/generate"
@classmethod @classmethod
def define_schema(cls) -> IO.Schema: def define_schema(cls) -> IO.Schema:
return IO.Schema( return IO.Schema(
node_id="Flux2ProImageNode", node_id=cls.NODE_ID,
display_name="Flux.2 [pro] Image", display_name=cls.DISPLAY_NAME,
category="api node/image/BFL", category="api node/image/BFL",
description="Generates images synchronously based on prompt and resolution.", description="Generates images synchronously based on prompt and resolution.",
inputs=[ inputs=[
@ -563,12 +550,11 @@ class Flux2ProImageNode(IO.ComfyNode):
), ),
IO.Boolean.Input( IO.Boolean.Input(
"prompt_upsampling", "prompt_upsampling",
default=False, default=True,
tooltip="Whether to perform upsampling on the prompt. " tooltip="Whether to perform upsampling on the prompt. "
"If active, automatically modifies the prompt for more creative generation, " "If active, automatically modifies the prompt for more creative generation.",
"but results are nondeterministic (same seed will not produce exactly the same result).",
), ),
IO.Image.Input("images", optional=True, tooltip="Up to 4 images to be used as references."), IO.Image.Input("images", optional=True, tooltip="Up to 9 images to be used as references."),
], ],
outputs=[IO.Image.Output()], outputs=[IO.Image.Output()],
hidden=[ hidden=[
@ -587,7 +573,7 @@ class Flux2ProImageNode(IO.ComfyNode):
height: int, height: int,
seed: int, seed: int,
prompt_upsampling: bool, prompt_upsampling: bool,
images: torch.Tensor | None = None, images: Input.Image | None = None,
) -> IO.NodeOutput: ) -> IO.NodeOutput:
reference_images = {} reference_images = {}
if images is not None: if images is not None:
@ -598,7 +584,7 @@ class Flux2ProImageNode(IO.ComfyNode):
reference_images[key_name] = tensor_to_base64_string(images[image_index], total_pixels=2048 * 2048) reference_images[key_name] = tensor_to_base64_string(images[image_index], total_pixels=2048 * 2048)
initial_response = await sync_op( initial_response = await sync_op(
cls, cls,
ApiEndpoint(path="/proxy/bfl/flux-2-pro/generate", method="POST"), ApiEndpoint(path=cls.API_ENDPOINT, method="POST"),
response_model=BFLFluxProGenerateResponse, response_model=BFLFluxProGenerateResponse,
data=Flux2ProGenerateRequest( data=Flux2ProGenerateRequest(
prompt=prompt, prompt=prompt,
@ -632,6 +618,13 @@ class Flux2ProImageNode(IO.ComfyNode):
return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"])) return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"]))
class Flux2MaxImageNode(Flux2ProImageNode):
NODE_ID = "Flux2MaxImageNode"
DISPLAY_NAME = "Flux.2 [max] Image"
API_ENDPOINT = "/proxy/bfl/flux-2-max/generate"
class BFLExtension(ComfyExtension): class BFLExtension(ComfyExtension):
@override @override
async def get_node_list(self) -> list[type[IO.ComfyNode]]: async def get_node_list(self) -> list[type[IO.ComfyNode]]:
@ -642,6 +635,7 @@ class BFLExtension(ComfyExtension):
FluxProExpandNode, FluxProExpandNode,
FluxProFillNode, FluxProFillNode,
Flux2ProImageNode, Flux2ProImageNode,
Flux2MaxImageNode,
] ]

View File

@ -1,13 +1,27 @@
import logging import logging
import math import math
from enum import Enum
from typing import Literal, Optional, Union
import torch import torch
from pydantic import BaseModel, Field
from typing_extensions import override from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api_nodes.apis.bytedance_api import (
RECOMMENDED_PRESETS,
RECOMMENDED_PRESETS_SEEDREAM_4,
VIDEO_TASKS_EXECUTION_TIME,
Image2ImageTaskCreationRequest,
Image2VideoTaskCreationRequest,
ImageTaskCreationResponse,
Seedream4Options,
Seedream4TaskCreationRequest,
TaskCreationResponse,
TaskImageContent,
TaskImageContentUrl,
TaskStatusResponse,
TaskTextContent,
Text2ImageTaskCreationRequest,
Text2VideoTaskCreationRequest,
)
from comfy_api_nodes.util import ( from comfy_api_nodes.util import (
ApiEndpoint, ApiEndpoint,
download_url_to_image_tensor, download_url_to_image_tensor,
@ -29,162 +43,6 @@ BYTEPLUS_TASK_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks"
BYTEPLUS_TASK_STATUS_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks" # + /{task_id} BYTEPLUS_TASK_STATUS_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks" # + /{task_id}
class Text2ImageModelName(str, Enum):
seedream_3 = "seedream-3-0-t2i-250415"
class Image2ImageModelName(str, Enum):
seededit_3 = "seededit-3-0-i2i-250628"
class Text2VideoModelName(str, Enum):
seedance_1_pro = "seedance-1-0-pro-250528"
seedance_1_lite = "seedance-1-0-lite-t2v-250428"
class Image2VideoModelName(str, Enum):
"""note(August 31): Pro model only supports FirstFrame: https://docs.byteplus.com/en/docs/ModelArk/1520757"""
seedance_1_pro = "seedance-1-0-pro-250528"
seedance_1_lite = "seedance-1-0-lite-i2v-250428"
class Text2ImageTaskCreationRequest(BaseModel):
model: Text2ImageModelName = Text2ImageModelName.seedream_3
prompt: str = Field(...)
response_format: Optional[str] = Field("url")
size: Optional[str] = Field(None)
seed: Optional[int] = Field(0, ge=0, le=2147483647)
guidance_scale: Optional[float] = Field(..., ge=1.0, le=10.0)
watermark: Optional[bool] = Field(True)
class Image2ImageTaskCreationRequest(BaseModel):
model: Image2ImageModelName = Image2ImageModelName.seededit_3
prompt: str = Field(...)
response_format: Optional[str] = Field("url")
image: str = Field(..., description="Base64 encoded string or image URL")
size: Optional[str] = Field("adaptive")
seed: Optional[int] = Field(..., ge=0, le=2147483647)
guidance_scale: Optional[float] = Field(..., ge=1.0, le=10.0)
watermark: Optional[bool] = Field(True)
class Seedream4Options(BaseModel):
max_images: int = Field(15)
class Seedream4TaskCreationRequest(BaseModel):
model: str = Field("seedream-4-0-250828")
prompt: str = Field(...)
response_format: str = Field("url")
image: Optional[list[str]] = Field(None, description="Image URLs")
size: str = Field(...)
seed: int = Field(..., ge=0, le=2147483647)
sequential_image_generation: str = Field("disabled")
sequential_image_generation_options: Seedream4Options = Field(Seedream4Options(max_images=15))
watermark: bool = Field(True)
class ImageTaskCreationResponse(BaseModel):
model: str = Field(...)
created: int = Field(..., description="Unix timestamp (in seconds) indicating time when the request was created.")
data: list = Field([], description="Contains information about the generated image(s).")
error: dict = Field({}, description="Contains `code` and `message` fields in case of error.")
class TaskTextContent(BaseModel):
type: str = Field("text")
text: str = Field(...)
class TaskImageContentUrl(BaseModel):
url: str = Field(...)
class TaskImageContent(BaseModel):
type: str = Field("image_url")
image_url: TaskImageContentUrl = Field(...)
role: Optional[Literal["first_frame", "last_frame", "reference_image"]] = Field(None)
class Text2VideoTaskCreationRequest(BaseModel):
model: Text2VideoModelName = Text2VideoModelName.seedance_1_pro
content: list[TaskTextContent] = Field(..., min_length=1)
class Image2VideoTaskCreationRequest(BaseModel):
model: Image2VideoModelName = Image2VideoModelName.seedance_1_pro
content: list[Union[TaskTextContent, TaskImageContent]] = Field(..., min_length=2)
class TaskCreationResponse(BaseModel):
id: str = Field(...)
class TaskStatusError(BaseModel):
code: str = Field(...)
message: str = Field(...)
class TaskStatusResult(BaseModel):
video_url: str = Field(...)
class TaskStatusResponse(BaseModel):
id: str = Field(...)
model: str = Field(...)
status: Literal["queued", "running", "cancelled", "succeeded", "failed"] = Field(...)
error: Optional[TaskStatusError] = Field(None)
content: Optional[TaskStatusResult] = Field(None)
RECOMMENDED_PRESETS = [
("1024x1024 (1:1)", 1024, 1024),
("864x1152 (3:4)", 864, 1152),
("1152x864 (4:3)", 1152, 864),
("1280x720 (16:9)", 1280, 720),
("720x1280 (9:16)", 720, 1280),
("832x1248 (2:3)", 832, 1248),
("1248x832 (3:2)", 1248, 832),
("1512x648 (21:9)", 1512, 648),
("2048x2048 (1:1)", 2048, 2048),
("Custom", None, None),
]
RECOMMENDED_PRESETS_SEEDREAM_4 = [
("2048x2048 (1:1)", 2048, 2048),
("2304x1728 (4:3)", 2304, 1728),
("1728x2304 (3:4)", 1728, 2304),
("2560x1440 (16:9)", 2560, 1440),
("1440x2560 (9:16)", 1440, 2560),
("2496x1664 (3:2)", 2496, 1664),
("1664x2496 (2:3)", 1664, 2496),
("3024x1296 (21:9)", 3024, 1296),
("4096x4096 (1:1)", 4096, 4096),
("Custom", None, None),
]
# The times in this dictionary are given for a 10-second duration.
VIDEO_TASKS_EXECUTION_TIME = {
"seedance-1-0-lite-t2v-250428": {
"480p": 40,
"720p": 60,
"1080p": 90,
},
"seedance-1-0-lite-i2v-250428": {
"480p": 40,
"720p": 60,
"1080p": 90,
},
"seedance-1-0-pro-250528": {
"480p": 70,
"720p": 85,
"1080p": 115,
},
}
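The preset tables and per-model timing estimates deleted here now live in comfy_api_nodes.apis.bytedance_api and are imported at the top of the file. For orientation, a hedged sketch of how a (label, width, height) preset list such as RECOMMENDED_PRESETS_SEEDREAM_4 can be resolved back to dimensions, with "Custom" falling through to the user-supplied width/height; resolve_preset is a hypothetical helper, not part of the node code:

PRESETS = [
    ("2048x2048 (1:1)", 2048, 2048),
    ("2304x1728 (4:3)", 2304, 1728),
    ("Custom", None, None),
]

def resolve_preset(label: str, custom_w: int, custom_h: int) -> tuple[int, int]:
    # Look the label up in the preset table; "Custom" entries carry None and
    # defer to the width/height inputs on the node.
    for name, w, h in PRESETS:
        if name == label:
            return (custom_w, custom_h) if w is None else (w, h)
    raise ValueError(f"Unknown preset: {label}")

assert resolve_preset("2304x1728 (4:3)", 1024, 1024) == (2304, 1728)
assert resolve_preset("Custom", 1536, 1536) == (1536, 1536)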
def get_image_url_from_response(response: ImageTaskCreationResponse) -> str: def get_image_url_from_response(response: ImageTaskCreationResponse) -> str:
if response.error: if response.error:
error_msg = f"ByteDance request failed. Code: {response.error['code']}, message: {response.error['message']}" error_msg = f"ByteDance request failed. Code: {response.error['code']}, message: {response.error['message']}"
@ -194,13 +52,6 @@ def get_image_url_from_response(response: ImageTaskCreationResponse) -> str:
return response.data[0]["url"] return response.data[0]["url"]
def get_video_url_from_task_status(response: TaskStatusResponse) -> Union[str, None]:
"""Returns the video URL from the task status response if it exists."""
if hasattr(response, "content") and response.content:
return response.content.video_url
return None
class ByteDanceImageNode(IO.ComfyNode): class ByteDanceImageNode(IO.ComfyNode):
@classmethod @classmethod
@ -211,12 +62,7 @@ class ByteDanceImageNode(IO.ComfyNode):
category="api node/image/ByteDance", category="api node/image/ByteDance",
description="Generate images using ByteDance models via api based on prompt", description="Generate images using ByteDance models via api based on prompt",
inputs=[ inputs=[
IO.Combo.Input( IO.Combo.Input("model", options=["seedream-3-0-t2i-250415"]),
"model",
options=Text2ImageModelName,
default=Text2ImageModelName.seedream_3,
tooltip="Model name",
),
IO.String.Input( IO.String.Input(
"prompt", "prompt",
multiline=True, multiline=True,
@ -266,7 +112,7 @@ class ByteDanceImageNode(IO.ComfyNode):
), ),
IO.Boolean.Input( IO.Boolean.Input(
"watermark", "watermark",
default=True, default=False,
tooltip='Whether to add an "AI generated" watermark to the image', tooltip='Whether to add an "AI generated" watermark to the image',
optional=True, optional=True,
), ),
@ -335,12 +181,7 @@ class ByteDanceImageEditNode(IO.ComfyNode):
category="api node/image/ByteDance", category="api node/image/ByteDance",
description="Edit images using ByteDance models via api based on prompt", description="Edit images using ByteDance models via api based on prompt",
inputs=[ inputs=[
IO.Combo.Input( IO.Combo.Input("model", options=["seededit-3-0-i2i-250628"]),
"model",
options=Image2ImageModelName,
default=Image2ImageModelName.seededit_3,
tooltip="Model name",
),
IO.Image.Input( IO.Image.Input(
"image", "image",
tooltip="The base image to edit", tooltip="The base image to edit",
@ -374,7 +215,7 @@ class ByteDanceImageEditNode(IO.ComfyNode):
), ),
IO.Boolean.Input( IO.Boolean.Input(
"watermark", "watermark",
default=True, default=False,
tooltip='Whether to add an "AI generated" watermark to the image', tooltip='Whether to add an "AI generated" watermark to the image',
optional=True, optional=True,
), ),
@ -388,13 +229,14 @@ class ByteDanceImageEditNode(IO.ComfyNode):
IO.Hidden.unique_id, IO.Hidden.unique_id,
], ],
is_api_node=True, is_api_node=True,
is_deprecated=True,
) )
@classmethod @classmethod
async def execute( async def execute(
cls, cls,
model: str, model: str,
image: torch.Tensor, image: Input.Image,
prompt: str, prompt: str,
seed: int, seed: int,
guidance_scale: float, guidance_scale: float,
@ -428,13 +270,13 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
def define_schema(cls): def define_schema(cls):
return IO.Schema( return IO.Schema(
node_id="ByteDanceSeedreamNode", node_id="ByteDanceSeedreamNode",
display_name="ByteDance Seedream 4", display_name="ByteDance Seedream 4.5",
category="api node/image/ByteDance", category="api node/image/ByteDance",
description="Unified text-to-image generation and precise single-sentence editing at up to 4K resolution.", description="Unified text-to-image generation and precise single-sentence editing at up to 4K resolution.",
inputs=[ inputs=[
IO.Combo.Input( IO.Combo.Input(
"model", "model",
options=["seedream-4-0-250828"], options=["seedream-4-5-251128", "seedream-4-0-250828"],
tooltip="Model name", tooltip="Model name",
), ),
IO.String.Input( IO.String.Input(
@ -459,7 +301,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
default=2048, default=2048,
min=1024, min=1024,
max=4096, max=4096,
step=64, step=8,
tooltip="Custom width for image. Value is working only if `size_preset` is set to `Custom`", tooltip="Custom width for image. Value is working only if `size_preset` is set to `Custom`",
optional=True, optional=True,
), ),
@ -468,7 +310,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
default=2048, default=2048,
min=1024, min=1024,
max=4096, max=4096,
step=64, step=8,
tooltip="Custom height for image. Value is working only if `size_preset` is set to `Custom`", tooltip="Custom height for image. Value is working only if `size_preset` is set to `Custom`",
optional=True, optional=True,
), ),
@ -505,7 +347,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
), ),
IO.Boolean.Input( IO.Boolean.Input(
"watermark", "watermark",
default=True, default=False,
tooltip='Whether to add an "AI generated" watermark to the image.', tooltip='Whether to add an "AI generated" watermark to the image.',
optional=True, optional=True,
), ),
@ -532,14 +374,14 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
cls, cls,
model: str, model: str,
prompt: str, prompt: str,
image: torch.Tensor = None, image: Input.Image | None = None,
size_preset: str = RECOMMENDED_PRESETS_SEEDREAM_4[0][0], size_preset: str = RECOMMENDED_PRESETS_SEEDREAM_4[0][0],
width: int = 2048, width: int = 2048,
height: int = 2048, height: int = 2048,
sequential_image_generation: str = "disabled", sequential_image_generation: str = "disabled",
max_images: int = 1, max_images: int = 1,
seed: int = 0, seed: int = 0,
watermark: bool = True, watermark: bool = False,
fail_on_partial: bool = True, fail_on_partial: bool = True,
) -> IO.NodeOutput: ) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=True, min_length=1) validate_string(prompt, strip_whitespace=True, min_length=1)
@ -555,6 +397,18 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
raise ValueError( raise ValueError(
f"Custom size out of range: {w}x{h}. " "Both width and height must be between 1024 and 4096 pixels." f"Custom size out of range: {w}x{h}. " "Both width and height must be between 1024 and 4096 pixels."
) )
out_num_pixels = w * h
mp_provided = out_num_pixels / 1_000_000.0
if "seedream-4-5" in model and out_num_pixels < 3686400:
raise ValueError(
f"Minimum image resolution that Seedream 4.5 can generate is 3.68MP, "
f"but {mp_provided:.2f}MP provided."
)
if "seedream-4-0" in model and out_num_pixels < 921600:
raise ValueError(
f"Minimum image resolution that the selected model can generate is 0.92MP, "
f"but {mp_provided:.2f}MP provided."
)
n_input_images = get_number_of_images(image) if image is not None else 0 n_input_images = get_number_of_images(image) if image is not None else 0
if n_input_images > 10: if n_input_images > 10:
raise ValueError(f"Maximum of 10 reference images are supported, but {n_input_images} received.") raise ValueError(f"Maximum of 10 reference images are supported, but {n_input_images} received.")
@ -607,9 +461,8 @@ class ByteDanceTextToVideoNode(IO.ComfyNode):
inputs=[ inputs=[
IO.Combo.Input( IO.Combo.Input(
"model", "model",
options=Text2VideoModelName, options=["seedance-1-0-pro-250528", "seedance-1-0-lite-t2v-250428", "seedance-1-0-pro-fast-251015"],
default=Text2VideoModelName.seedance_1_pro, default="seedance-1-0-pro-fast-251015",
tooltip="Model name",
), ),
IO.String.Input( IO.String.Input(
"prompt", "prompt",
@ -655,7 +508,7 @@ class ByteDanceTextToVideoNode(IO.ComfyNode):
), ),
IO.Boolean.Input( IO.Boolean.Input(
"watermark", "watermark",
default=True, default=False,
tooltip='Whether to add an "AI generated" watermark to the video.', tooltip='Whether to add an "AI generated" watermark to the video.',
optional=True, optional=True,
), ),
@ -714,9 +567,8 @@ class ByteDanceImageToVideoNode(IO.ComfyNode):
inputs=[ inputs=[
IO.Combo.Input( IO.Combo.Input(
"model", "model",
options=Image2VideoModelName, options=["seedance-1-0-pro-250528", "seedance-1-0-lite-t2v-250428", "seedance-1-0-pro-fast-251015"],
default=Image2VideoModelName.seedance_1_pro, default="seedance-1-0-pro-fast-251015",
tooltip="Model name",
), ),
IO.String.Input( IO.String.Input(
"prompt", "prompt",
@ -766,7 +618,7 @@ class ByteDanceImageToVideoNode(IO.ComfyNode):
), ),
IO.Boolean.Input( IO.Boolean.Input(
"watermark", "watermark",
default=True, default=False,
tooltip='Whether to add an "AI generated" watermark to the video.', tooltip='Whether to add an "AI generated" watermark to the video.',
optional=True, optional=True,
), ),
@ -787,7 +639,7 @@ class ByteDanceImageToVideoNode(IO.ComfyNode):
cls, cls,
model: str, model: str,
prompt: str, prompt: str,
image: torch.Tensor, image: Input.Image,
resolution: str, resolution: str,
aspect_ratio: str, aspect_ratio: str,
duration: int, duration: int,
@ -833,9 +685,8 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode):
inputs=[ inputs=[
IO.Combo.Input( IO.Combo.Input(
"model", "model",
options=[model.value for model in Image2VideoModelName], options=["seedance-1-0-pro-250528", "seedance-1-0-lite-i2v-250428"],
default=Image2VideoModelName.seedance_1_lite.value, default="seedance-1-0-lite-i2v-250428",
tooltip="Model name",
), ),
IO.String.Input( IO.String.Input(
"prompt", "prompt",
@ -889,7 +740,7 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode):
), ),
IO.Boolean.Input( IO.Boolean.Input(
"watermark", "watermark",
default=True, default=False,
tooltip='Whether to add an "AI generated" watermark to the video.', tooltip='Whether to add an "AI generated" watermark to the video.',
optional=True, optional=True,
), ),
@ -910,8 +761,8 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode):
cls, cls,
model: str, model: str,
prompt: str, prompt: str,
first_frame: torch.Tensor, first_frame: Input.Image,
last_frame: torch.Tensor, last_frame: Input.Image,
resolution: str, resolution: str,
aspect_ratio: str, aspect_ratio: str,
duration: int, duration: int,
@ -968,9 +819,8 @@ class ByteDanceImageReferenceNode(IO.ComfyNode):
inputs=[ inputs=[
IO.Combo.Input( IO.Combo.Input(
"model", "model",
options=[Image2VideoModelName.seedance_1_lite.value], options=["seedance-1-0-pro-250528", "seedance-1-0-lite-i2v-250428"],
default=Image2VideoModelName.seedance_1_lite.value, default="seedance-1-0-lite-i2v-250428",
tooltip="Model name",
), ),
IO.String.Input( IO.String.Input(
"prompt", "prompt",
@ -1013,7 +863,7 @@ class ByteDanceImageReferenceNode(IO.ComfyNode):
), ),
IO.Boolean.Input( IO.Boolean.Input(
"watermark", "watermark",
default=True, default=False,
tooltip='Whether to add an "AI generated" watermark to the video.', tooltip='Whether to add an "AI generated" watermark to the video.',
optional=True, optional=True,
), ),
@ -1034,7 +884,7 @@ class ByteDanceImageReferenceNode(IO.ComfyNode):
cls, cls,
model: str, model: str,
prompt: str, prompt: str,
images: torch.Tensor, images: Input.Image,
resolution: str, resolution: str,
aspect_ratio: str, aspect_ratio: str,
duration: int, duration: int,
@ -1069,8 +919,8 @@ class ByteDanceImageReferenceNode(IO.ComfyNode):
async def process_video_task( async def process_video_task(
cls: type[IO.ComfyNode], cls: type[IO.ComfyNode],
payload: Union[Text2VideoTaskCreationRequest, Image2VideoTaskCreationRequest], payload: Text2VideoTaskCreationRequest | Image2VideoTaskCreationRequest,
estimated_duration: Optional[int], estimated_duration: int | None,
) -> IO.NodeOutput: ) -> IO.NodeOutput:
initial_response = await sync_op( initial_response = await sync_op(
cls, cls,
@ -1085,7 +935,7 @@ async def process_video_task(
estimated_duration=estimated_duration, estimated_duration=estimated_duration,
response_model=TaskStatusResponse, response_model=TaskStatusResponse,
) )
return IO.NodeOutput(await download_url_to_video_output(get_video_url_from_task_status(response))) return IO.NodeOutput(await download_url_to_video_output(response.content.video_url))
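All of the ByteDance video nodes funnel into process_video_task: sync_op creates the task (TaskCreationResponse carries only the task id), poll_op polls the status endpoint until the task reaches a terminal state, and the resulting content.video_url is downloaded into a video output. A generic, hedged restatement of that create-then-poll loop; create_task and get_status are placeholders, not the sync_op/poll_op helpers used here:

import asyncio

async def run_task(create_task, get_status, interval: float = 5.0, max_attempts: int = 160) -> str:
    task_id = await create_task()                      # POST .../generations/tasks -> task id
    for _ in range(max_attempts):
        status, video_url = await get_status(task_id)  # GET .../generations/tasks/{task_id}
        if status == "succeeded":
            return video_url
        if status in ("failed", "cancelled"):
            raise RuntimeError(f"task {task_id} ended with status {status}")
        await asyncio.sleep(interval)
    raise TimeoutError(f"task {task_id} did not finish in time")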
def raise_if_text_params(prompt: str, text_params: list[str]) -> None: def raise_if_text_params(prompt: str, text_params: list[str]) -> None:

View File

@ -13,8 +13,7 @@ import torch
from typing_extensions import override from typing_extensions import override
import folder_paths import folder_paths
from comfy_api.latest import IO, ComfyExtension, Input from comfy_api.latest import IO, ComfyExtension, Input, Types
from comfy_api.util import VideoCodec, VideoContainer
from comfy_api_nodes.apis.gemini_api import ( from comfy_api_nodes.apis.gemini_api import (
GeminiContent, GeminiContent,
GeminiFileData, GeminiFileData,
@ -27,12 +26,15 @@ from comfy_api_nodes.apis.gemini_api import (
GeminiMimeType, GeminiMimeType,
GeminiPart, GeminiPart,
GeminiRole, GeminiRole,
GeminiSystemInstructionContent,
GeminiTextPart,
Modality, Modality,
) )
from comfy_api_nodes.util import ( from comfy_api_nodes.util import (
ApiEndpoint, ApiEndpoint,
audio_to_base64_string, audio_to_base64_string,
bytesio_to_image_tensor, bytesio_to_image_tensor,
download_url_to_image_tensor,
get_number_of_images, get_number_of_images,
sync_op, sync_op,
tensor_to_base64_string, tensor_to_base64_string,
@ -43,6 +45,14 @@ from comfy_api_nodes.util import (
GEMINI_BASE_ENDPOINT = "/proxy/vertexai/gemini" GEMINI_BASE_ENDPOINT = "/proxy/vertexai/gemini"
GEMINI_MAX_INPUT_FILE_SIZE = 20 * 1024 * 1024 # 20 MB GEMINI_MAX_INPUT_FILE_SIZE = 20 * 1024 * 1024 # 20 MB
GEMINI_IMAGE_SYS_PROMPT = (
"You are an expert image-generation engine. You must ALWAYS produce an image.\n"
"Interpret all user input—regardless of "
"format, intent, or abstraction—as literal visual directives for image composition.\n"
"If a prompt is conversational or lacks specific visual details, "
"you must creatively invent a concrete visual scenario that depicts the concept.\n"
"Prioritize generating the visual representation above any text, formatting, or conversational requests."
)
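This default system prompt is attached to image requests later in the file by wrapping it in a GeminiSystemInstructionContent with a single GeminiTextPart, but only when the node's system_prompt input is non-empty. Roughly the request-body shape that produces, with field names taken from the pydantic models in this diff (the exact serialized payload may differ in detail):

payload = {
    "contents": [
        {"role": "user", "parts": [{"text": "a red bicycle leaning against a fence"}]},
    ],
    # Only present when system_prompt is non-empty:
    "systemInstruction": {"parts": [{"text": GEMINI_IMAGE_SYS_PROMPT}]},
}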
class GeminiModel(str, Enum): class GeminiModel(str, Enum):
@ -68,7 +78,7 @@ class GeminiImageModel(str, Enum):
async def create_image_parts( async def create_image_parts(
cls: type[IO.ComfyNode], cls: type[IO.ComfyNode],
images: torch.Tensor, images: Input.Image,
image_limit: int = 0, image_limit: int = 0,
) -> list[GeminiPart]: ) -> list[GeminiPart]:
image_parts: list[GeminiPart] = [] image_parts: list[GeminiPart] = []
@ -132,9 +142,11 @@ def get_parts_by_type(response: GeminiGenerateContentResponse, part_type: Litera
) )
parts = [] parts = []
for part in response.candidates[0].content.parts: for part in response.candidates[0].content.parts:
if part_type == "text" and hasattr(part, "text") and part.text: if part_type == "text" and part.text:
parts.append(part) parts.append(part)
elif hasattr(part, "inlineData") and part.inlineData and part.inlineData.mimeType == part_type: elif part.inlineData and part.inlineData.mimeType == part_type:
parts.append(part)
elif part.fileData and part.fileData.mimeType == part_type:
parts.append(part) parts.append(part)
# Skip parts that don't match the requested type # Skip parts that don't match the requested type
return parts return parts
@ -154,12 +166,15 @@ def get_text_from_response(response: GeminiGenerateContentResponse) -> str:
return "\n".join([part.text for part in parts]) return "\n".join([part.text for part in parts])
def get_image_from_response(response: GeminiGenerateContentResponse) -> torch.Tensor: async def get_image_from_response(response: GeminiGenerateContentResponse) -> Input.Image:
image_tensors: list[torch.Tensor] = [] image_tensors: list[Input.Image] = []
parts = get_parts_by_type(response, "image/png") parts = get_parts_by_type(response, "image/png")
for part in parts: for part in parts:
image_data = base64.b64decode(part.inlineData.data) if part.inlineData:
returned_image = bytesio_to_image_tensor(BytesIO(image_data)) image_data = base64.b64decode(part.inlineData.data)
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
else:
returned_image = await download_url_to_image_tensor(part.fileData.fileUri)
image_tensors.append(returned_image) image_tensors.append(returned_image)
if len(image_tensors) == 0: if len(image_tensors) == 0:
return torch.zeros((1, 1024, 1024, 4)) return torch.zeros((1, 1024, 1024, 4))
@ -277,6 +292,13 @@ class GeminiNode(IO.ComfyNode):
tooltip="Optional file(s) to use as context for the model. " tooltip="Optional file(s) to use as context for the model. "
"Accepts inputs from the Gemini Generate Content Input Files node.", "Accepts inputs from the Gemini Generate Content Input Files node.",
), ),
IO.String.Input(
"system_prompt",
multiline=True,
default="",
optional=True,
tooltip="Foundational instructions that dictate an AI's behavior.",
),
], ],
outputs=[ outputs=[
IO.String.Output(), IO.String.Output(),
@ -293,7 +315,9 @@ class GeminiNode(IO.ComfyNode):
def create_video_parts(cls, video_input: Input.Video) -> list[GeminiPart]: def create_video_parts(cls, video_input: Input.Video) -> list[GeminiPart]:
"""Convert video input to Gemini API compatible parts.""" """Convert video input to Gemini API compatible parts."""
base_64_string = video_to_base64_string(video_input, container_format=VideoContainer.MP4, codec=VideoCodec.H264) base_64_string = video_to_base64_string(
video_input, container_format=Types.VideoContainer.MP4, codec=Types.VideoCodec.H264
)
return [ return [
GeminiPart( GeminiPart(
inlineData=GeminiInlineData( inlineData=GeminiInlineData(
@ -343,10 +367,11 @@ class GeminiNode(IO.ComfyNode):
prompt: str, prompt: str,
model: str, model: str,
seed: int, seed: int,
images: torch.Tensor | None = None, images: Input.Image | None = None,
audio: Input.Audio | None = None, audio: Input.Audio | None = None,
video: Input.Video | None = None, video: Input.Video | None = None,
files: list[GeminiPart] | None = None, files: list[GeminiPart] | None = None,
system_prompt: str = "",
) -> IO.NodeOutput: ) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=False) validate_string(prompt, strip_whitespace=False)
@ -363,7 +388,10 @@ class GeminiNode(IO.ComfyNode):
if files is not None: if files is not None:
parts.extend(files) parts.extend(files)
# Create response gemini_system_prompt = None
if system_prompt:
gemini_system_prompt = GeminiSystemInstructionContent(parts=[GeminiTextPart(text=system_prompt)], role=None)
response = await sync_op( response = await sync_op(
cls, cls,
endpoint=ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"), endpoint=ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"),
@ -373,7 +401,8 @@ class GeminiNode(IO.ComfyNode):
role=GeminiRole.user, role=GeminiRole.user,
parts=parts, parts=parts,
) )
] ],
systemInstruction=gemini_system_prompt,
), ),
response_model=GeminiGenerateContentResponse, response_model=GeminiGenerateContentResponse,
price_extractor=calculate_tokens_price, price_extractor=calculate_tokens_price,
@ -523,6 +552,13 @@ class GeminiImage(IO.ComfyNode):
"'IMAGE+TEXT' to return both the generated image and a text response.", "'IMAGE+TEXT' to return both the generated image and a text response.",
optional=True, optional=True,
), ),
IO.String.Input(
"system_prompt",
multiline=True,
default=GEMINI_IMAGE_SYS_PROMPT,
optional=True,
tooltip="Foundational instructions that dictate an AI's behavior.",
),
], ],
outputs=[ outputs=[
IO.Image.Output(), IO.Image.Output(),
@ -542,10 +578,11 @@ class GeminiImage(IO.ComfyNode):
prompt: str, prompt: str,
model: str, model: str,
seed: int, seed: int,
images: torch.Tensor | None = None, images: Input.Image | None = None,
files: list[GeminiPart] | None = None, files: list[GeminiPart] | None = None,
aspect_ratio: str = "auto", aspect_ratio: str = "auto",
response_modalities: str = "IMAGE+TEXT", response_modalities: str = "IMAGE+TEXT",
system_prompt: str = "",
) -> IO.NodeOutput: ) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=True, min_length=1) validate_string(prompt, strip_whitespace=True, min_length=1)
parts: list[GeminiPart] = [GeminiPart(text=prompt)] parts: list[GeminiPart] = [GeminiPart(text=prompt)]
@ -559,9 +596,13 @@ class GeminiImage(IO.ComfyNode):
if files is not None: if files is not None:
parts.extend(files) parts.extend(files)
gemini_system_prompt = None
if system_prompt:
gemini_system_prompt = GeminiSystemInstructionContent(parts=[GeminiTextPart(text=system_prompt)], role=None)
response = await sync_op( response = await sync_op(
cls, cls,
endpoint=ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"), ApiEndpoint(path=f"/proxy/vertexai/gemini/{model}", method="POST"),
data=GeminiImageGenerateContentRequest( data=GeminiImageGenerateContentRequest(
contents=[ contents=[
GeminiContent(role=GeminiRole.user, parts=parts), GeminiContent(role=GeminiRole.user, parts=parts),
@ -570,11 +611,12 @@ class GeminiImage(IO.ComfyNode):
responseModalities=(["IMAGE"] if response_modalities == "IMAGE" else ["TEXT", "IMAGE"]), responseModalities=(["IMAGE"] if response_modalities == "IMAGE" else ["TEXT", "IMAGE"]),
imageConfig=None if aspect_ratio == "auto" else image_config, imageConfig=None if aspect_ratio == "auto" else image_config,
), ),
systemInstruction=gemini_system_prompt,
), ),
response_model=GeminiGenerateContentResponse, response_model=GeminiGenerateContentResponse,
price_extractor=calculate_tokens_price, price_extractor=calculate_tokens_price,
) )
return IO.NodeOutput(get_image_from_response(response), get_text_from_response(response)) return IO.NodeOutput(await get_image_from_response(response), get_text_from_response(response))
class GeminiImage2(IO.ComfyNode): class GeminiImage2(IO.ComfyNode):
@ -640,6 +682,13 @@ class GeminiImage2(IO.ComfyNode):
tooltip="Optional file(s) to use as context for the model. " tooltip="Optional file(s) to use as context for the model. "
"Accepts inputs from the Gemini Generate Content Input Files node.", "Accepts inputs from the Gemini Generate Content Input Files node.",
), ),
IO.String.Input(
"system_prompt",
multiline=True,
default=GEMINI_IMAGE_SYS_PROMPT,
optional=True,
tooltip="Foundational instructions that dictate an AI's behavior.",
),
], ],
outputs=[ outputs=[
IO.Image.Output(), IO.Image.Output(),
@ -662,8 +711,9 @@ class GeminiImage2(IO.ComfyNode):
aspect_ratio: str, aspect_ratio: str,
resolution: str, resolution: str,
response_modalities: str, response_modalities: str,
images: torch.Tensor | None = None, images: Input.Image | None = None,
files: list[GeminiPart] | None = None, files: list[GeminiPart] | None = None,
system_prompt: str = "",
) -> IO.NodeOutput: ) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=True, min_length=1) validate_string(prompt, strip_whitespace=True, min_length=1)
@ -679,9 +729,13 @@ class GeminiImage2(IO.ComfyNode):
if aspect_ratio != "auto": if aspect_ratio != "auto":
image_config.aspectRatio = aspect_ratio image_config.aspectRatio = aspect_ratio
gemini_system_prompt = None
if system_prompt:
gemini_system_prompt = GeminiSystemInstructionContent(parts=[GeminiTextPart(text=system_prompt)], role=None)
response = await sync_op( response = await sync_op(
cls, cls,
ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"), ApiEndpoint(path=f"/proxy/vertexai/gemini/{model}", method="POST"),
data=GeminiImageGenerateContentRequest( data=GeminiImageGenerateContentRequest(
contents=[ contents=[
GeminiContent(role=GeminiRole.user, parts=parts), GeminiContent(role=GeminiRole.user, parts=parts),
@ -690,11 +744,12 @@ class GeminiImage2(IO.ComfyNode):
responseModalities=(["IMAGE"] if response_modalities == "IMAGE" else ["TEXT", "IMAGE"]), responseModalities=(["IMAGE"] if response_modalities == "IMAGE" else ["TEXT", "IMAGE"]),
imageConfig=image_config, imageConfig=image_config,
), ),
systemInstruction=gemini_system_prompt,
), ),
response_model=GeminiGenerateContentResponse, response_model=GeminiGenerateContentResponse,
price_extractor=calculate_tokens_price, price_extractor=calculate_tokens_price,
) )
return IO.NodeOutput(get_image_from_response(response), get_text_from_response(response)) return IO.NodeOutput(await get_image_from_response(response), get_text_from_response(response))
class GeminiExtension(ComfyExtension): class GeminiExtension(ComfyExtension):

View File

@ -50,6 +50,8 @@ from comfy_api_nodes.apis import (
KlingSingleImageEffectModelName, KlingSingleImageEffectModelName,
) )
from comfy_api_nodes.apis.kling_api import ( from comfy_api_nodes.apis.kling_api import (
ImageToVideoWithAudioRequest,
MotionControlRequest,
OmniImageParamImage, OmniImageParamImage,
OmniParamImage, OmniParamImage,
OmniParamVideo, OmniParamVideo,
@ -57,7 +59,8 @@ from comfy_api_nodes.apis.kling_api import (
OmniProImageRequest, OmniProImageRequest,
OmniProReferences2VideoRequest, OmniProReferences2VideoRequest,
OmniProText2VideoRequest, OmniProText2VideoRequest,
OmniTaskStatusResponse, TaskStatusResponse,
TextToVideoWithAudioRequest,
) )
from comfy_api_nodes.util import ( from comfy_api_nodes.util import (
ApiEndpoint, ApiEndpoint,
@ -103,10 +106,6 @@ AVERAGE_DURATION_VIDEO_EXTEND = 320
MODE_TEXT2VIDEO = { MODE_TEXT2VIDEO = {
"standard mode / 5s duration / kling-v1": ("std", "5", "kling-v1"),
"standard mode / 10s duration / kling-v1": ("std", "10", "kling-v1"),
"pro mode / 5s duration / kling-v1": ("pro", "5", "kling-v1"),
"pro mode / 10s duration / kling-v1": ("pro", "10", "kling-v1"),
"standard mode / 5s duration / kling-v1-6": ("std", "5", "kling-v1-6"), "standard mode / 5s duration / kling-v1-6": ("std", "5", "kling-v1-6"),
"standard mode / 10s duration / kling-v1-6": ("std", "10", "kling-v1-6"), "standard mode / 10s duration / kling-v1-6": ("std", "10", "kling-v1-6"),
"pro mode / 5s duration / kling-v2-master": ("pro", "5", "kling-v2-master"), "pro mode / 5s duration / kling-v2-master": ("pro", "5", "kling-v2-master"),
@ -127,8 +126,6 @@ See: [Kling API Docs Capability Map](https://app.klingai.com/global/dev/document
MODE_START_END_FRAME = { MODE_START_END_FRAME = {
"standard mode / 5s duration / kling-v1": ("std", "5", "kling-v1"),
"pro mode / 5s duration / kling-v1": ("pro", "5", "kling-v1"),
"pro mode / 5s duration / kling-v1-5": ("pro", "5", "kling-v1-5"), "pro mode / 5s duration / kling-v1-5": ("pro", "5", "kling-v1-5"),
"pro mode / 10s duration / kling-v1-5": ("pro", "10", "kling-v1-5"), "pro mode / 10s duration / kling-v1-5": ("pro", "10", "kling-v1-5"),
"pro mode / 5s duration / kling-v1-6": ("pro", "5", "kling-v1-6"), "pro mode / 5s duration / kling-v1-6": ("pro", "5", "kling-v1-6"),
@ -242,7 +239,7 @@ def normalize_omni_prompt_references(prompt: str) -> str:
return re.sub(r"(?<!\w)@video(?P<idx>\d*)(?!\w)", _video_repl, prompt) return re.sub(r"(?<!\w)@video(?P<idx>\d*)(?!\w)", _video_repl, prompt)
async def finish_omni_video_task(cls: type[IO.ComfyNode], response: OmniTaskStatusResponse) -> IO.NodeOutput: async def finish_omni_video_task(cls: type[IO.ComfyNode], response: TaskStatusResponse) -> IO.NodeOutput:
if response.code: if response.code:
raise RuntimeError( raise RuntimeError(
f"Kling request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}" f"Kling request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}"
@ -250,7 +247,7 @@ async def finish_omni_video_task(cls: type[IO.ComfyNode], response: OmniTaskStat
final_response = await poll_op( final_response = await poll_op(
cls, cls,
ApiEndpoint(path=f"/proxy/kling/v1/videos/omni-video/{response.data.task_id}"), ApiEndpoint(path=f"/proxy/kling/v1/videos/omni-video/{response.data.task_id}"),
response_model=OmniTaskStatusResponse, response_model=TaskStatusResponse,
status_extractor=lambda r: (r.data.task_status if r.data else None), status_extractor=lambda r: (r.data.task_status if r.data else None),
max_poll_attempts=160, max_poll_attempts=160,
) )
@ -483,12 +480,12 @@ async def execute_image2video(
task_id = task_creation_response.data.task_id task_id = task_creation_response.data.task_id
final_response = await poll_op( final_response = await poll_op(
cls, cls,
ApiEndpoint(path=f"{PATH_IMAGE_TO_VIDEO}/{task_id}"), ApiEndpoint(path=f"{PATH_IMAGE_TO_VIDEO}/{task_id}"),
response_model=KlingImage2VideoResponse, response_model=KlingImage2VideoResponse,
estimated_duration=AVERAGE_DURATION_I2V, estimated_duration=AVERAGE_DURATION_I2V,
status_extractor=lambda r: (r.data.task_status.value if r.data and r.data.task_status else None), status_extractor=lambda r: (r.data.task_status.value if r.data and r.data.task_status else None),
) )
validate_video_result_response(final_response) validate_video_result_response(final_response)
video = get_video_from_response(final_response) video = get_video_from_response(final_response)
@ -752,7 +749,7 @@ class KlingTextToVideoNode(IO.ComfyNode):
IO.Combo.Input( IO.Combo.Input(
"mode", "mode",
options=modes, options=modes,
default=modes[4], default=modes[8],
tooltip="The configuration to use for the video generation following the format: mode / duration / model_name.", tooltip="The configuration to use for the video generation following the format: mode / duration / model_name.",
), ),
], ],
@ -810,6 +807,7 @@ class OmniProTextToVideoNode(IO.ComfyNode):
), ),
IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]), IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]),
IO.Combo.Input("duration", options=[5, 10]), IO.Combo.Input("duration", options=[5, 10]),
IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True),
], ],
outputs=[ outputs=[
IO.Video.Output(), IO.Video.Output(),
@ -829,17 +827,19 @@ class OmniProTextToVideoNode(IO.ComfyNode):
prompt: str, prompt: str,
aspect_ratio: str, aspect_ratio: str,
duration: int, duration: int,
resolution: str = "1080p",
) -> IO.NodeOutput: ) -> IO.NodeOutput:
validate_string(prompt, min_length=1, max_length=2500) validate_string(prompt, min_length=1, max_length=2500)
response = await sync_op( response = await sync_op(
cls, cls,
ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"), ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
response_model=OmniTaskStatusResponse, response_model=TaskStatusResponse,
data=OmniProText2VideoRequest( data=OmniProText2VideoRequest(
model_name=model_name, model_name=model_name,
prompt=prompt, prompt=prompt,
aspect_ratio=aspect_ratio, aspect_ratio=aspect_ratio,
duration=str(duration), duration=str(duration),
mode="pro" if resolution == "1080p" else "std",
), ),
) )
return await finish_omni_video_task(cls, response) return await finish_omni_video_task(cls, response)
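The optional resolution input added to the Omni nodes is translated into the API's mode field with the same ternary in every request builder in this file: 1080p maps to "pro" and 720p to "std". Pulled out as a hypothetical helper purely for readability, not code that exists in the node:

def omni_mode(resolution: str) -> str:
    # "pro" renders at 1080p, "std" at 720p (the mapping used by the Omni request builders).
    return "pro" if resolution == "1080p" else "std"

assert omni_mode("1080p") == "pro"
assert omni_mode("720p") == "std"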
@ -862,7 +862,7 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
tooltip="A text prompt describing the video content. " tooltip="A text prompt describing the video content. "
"This can include both positive and negative descriptions.", "This can include both positive and negative descriptions.",
), ),
IO.Combo.Input("duration", options=["5", "10"]), IO.Int.Input("duration", default=5, min=3, max=10, display_mode=IO.NumberDisplay.slider),
IO.Image.Input("first_frame"), IO.Image.Input("first_frame"),
IO.Image.Input( IO.Image.Input(
"end_frame", "end_frame",
@ -875,6 +875,7 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
optional=True, optional=True,
tooltip="Up to 6 additional reference images.", tooltip="Up to 6 additional reference images.",
), ),
IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True),
], ],
outputs=[ outputs=[
IO.Video.Output(), IO.Video.Output(),
@ -896,11 +897,16 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
first_frame: Input.Image, first_frame: Input.Image,
end_frame: Input.Image | None = None, end_frame: Input.Image | None = None,
reference_images: Input.Image | None = None, reference_images: Input.Image | None = None,
resolution: str = "1080p",
) -> IO.NodeOutput: ) -> IO.NodeOutput:
prompt = normalize_omni_prompt_references(prompt) prompt = normalize_omni_prompt_references(prompt)
validate_string(prompt, min_length=1, max_length=2500) validate_string(prompt, min_length=1, max_length=2500)
if end_frame is not None and reference_images is not None: if end_frame is not None and reference_images is not None:
raise ValueError("The 'end_frame' input cannot be used simultaneously with 'reference_images'.") raise ValueError("The 'end_frame' input cannot be used simultaneously with 'reference_images'.")
if duration not in (5, 10) and end_frame is None and reference_images is None:
raise ValueError(
"Duration is only supported for 5 or 10 seconds if there is no end frame or reference images."
)
validate_image_dimensions(first_frame, min_width=300, min_height=300) validate_image_dimensions(first_frame, min_width=300, min_height=300)
validate_image_aspect_ratio(first_frame, (1, 2.5), (2.5, 1)) validate_image_aspect_ratio(first_frame, (1, 2.5), (2.5, 1))
image_list: list[OmniParamImage] = [ image_list: list[OmniParamImage] = [
@ -929,12 +935,13 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
response = await sync_op( response = await sync_op(
cls, cls,
ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"), ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
response_model=OmniTaskStatusResponse, response_model=TaskStatusResponse,
data=OmniProFirstLastFrameRequest( data=OmniProFirstLastFrameRequest(
model_name=model_name, model_name=model_name,
prompt=prompt, prompt=prompt,
duration=str(duration), duration=str(duration),
image_list=image_list, image_list=image_list,
mode="pro" if resolution == "1080p" else "std",
), ),
) )
return await finish_omni_video_task(cls, response) return await finish_omni_video_task(cls, response)
@ -963,6 +970,7 @@ class OmniProImageToVideoNode(IO.ComfyNode):
"reference_images", "reference_images",
tooltip="Up to 7 reference images.", tooltip="Up to 7 reference images.",
), ),
IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True),
], ],
outputs=[ outputs=[
IO.Video.Output(), IO.Video.Output(),
@ -983,6 +991,7 @@ class OmniProImageToVideoNode(IO.ComfyNode):
aspect_ratio: str, aspect_ratio: str,
duration: int, duration: int,
reference_images: Input.Image, reference_images: Input.Image,
resolution: str = "1080p",
) -> IO.NodeOutput: ) -> IO.NodeOutput:
prompt = normalize_omni_prompt_references(prompt) prompt = normalize_omni_prompt_references(prompt)
validate_string(prompt, min_length=1, max_length=2500) validate_string(prompt, min_length=1, max_length=2500)
@ -997,13 +1006,14 @@ class OmniProImageToVideoNode(IO.ComfyNode):
response = await sync_op( response = await sync_op(
cls, cls,
ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"), ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
response_model=OmniTaskStatusResponse, response_model=TaskStatusResponse,
data=OmniProReferences2VideoRequest( data=OmniProReferences2VideoRequest(
model_name=model_name, model_name=model_name,
prompt=prompt, prompt=prompt,
aspect_ratio=aspect_ratio, aspect_ratio=aspect_ratio,
duration=str(duration), duration=str(duration),
image_list=image_list, image_list=image_list,
mode="pro" if resolution == "1080p" else "std",
), ),
) )
return await finish_omni_video_task(cls, response) return await finish_omni_video_task(cls, response)
@ -1035,6 +1045,7 @@ class OmniProVideoToVideoNode(IO.ComfyNode):
tooltip="Up to 4 additional reference images.", tooltip="Up to 4 additional reference images.",
optional=True, optional=True,
), ),
IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True),
], ],
outputs=[ outputs=[
IO.Video.Output(), IO.Video.Output(),
@ -1057,6 +1068,7 @@ class OmniProVideoToVideoNode(IO.ComfyNode):
reference_video: Input.Video, reference_video: Input.Video,
keep_original_sound: bool, keep_original_sound: bool,
reference_images: Input.Image | None = None, reference_images: Input.Image | None = None,
resolution: str = "1080p",
) -> IO.NodeOutput: ) -> IO.NodeOutput:
prompt = normalize_omni_prompt_references(prompt) prompt = normalize_omni_prompt_references(prompt)
validate_string(prompt, min_length=1, max_length=2500) validate_string(prompt, min_length=1, max_length=2500)
@ -1081,7 +1093,7 @@ class OmniProVideoToVideoNode(IO.ComfyNode):
response = await sync_op( response = await sync_op(
cls, cls,
ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"), ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
response_model=OmniTaskStatusResponse, response_model=TaskStatusResponse,
data=OmniProReferences2VideoRequest( data=OmniProReferences2VideoRequest(
model_name=model_name, model_name=model_name,
prompt=prompt, prompt=prompt,
@ -1089,6 +1101,7 @@ class OmniProVideoToVideoNode(IO.ComfyNode):
duration=str(duration), duration=str(duration),
image_list=image_list if image_list else None, image_list=image_list if image_list else None,
video_list=video_list, video_list=video_list,
mode="pro" if resolution == "1080p" else "std",
), ),
) )
return await finish_omni_video_task(cls, response) return await finish_omni_video_task(cls, response)
@ -1118,6 +1131,7 @@ class OmniProEditVideoNode(IO.ComfyNode):
tooltip="Up to 4 additional reference images.", tooltip="Up to 4 additional reference images.",
optional=True, optional=True,
), ),
IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True),
], ],
outputs=[ outputs=[
IO.Video.Output(), IO.Video.Output(),
@ -1138,6 +1152,7 @@ class OmniProEditVideoNode(IO.ComfyNode):
video: Input.Video, video: Input.Video,
keep_original_sound: bool, keep_original_sound: bool,
reference_images: Input.Image | None = None, reference_images: Input.Image | None = None,
resolution: str = "1080p",
) -> IO.NodeOutput: ) -> IO.NodeOutput:
prompt = normalize_omni_prompt_references(prompt) prompt = normalize_omni_prompt_references(prompt)
validate_string(prompt, min_length=1, max_length=2500) validate_string(prompt, min_length=1, max_length=2500)
@ -1162,7 +1177,7 @@ class OmniProEditVideoNode(IO.ComfyNode):
response = await sync_op( response = await sync_op(
cls, cls,
ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"), ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
response_model=OmniTaskStatusResponse, response_model=TaskStatusResponse,
data=OmniProReferences2VideoRequest( data=OmniProReferences2VideoRequest(
model_name=model_name, model_name=model_name,
prompt=prompt, prompt=prompt,
@ -1170,6 +1185,7 @@ class OmniProEditVideoNode(IO.ComfyNode):
duration=None, duration=None,
image_list=image_list if image_list else None, image_list=image_list if image_list else None,
video_list=video_list, video_list=video_list,
mode="pro" if resolution == "1080p" else "std",
), ),
) )
return await finish_omni_video_task(cls, response) return await finish_omni_video_task(cls, response)
@ -1237,7 +1253,7 @@ class OmniProImageNode(IO.ComfyNode):
response = await sync_op( response = await sync_op(
cls, cls,
ApiEndpoint(path="/proxy/kling/v1/images/omni-image", method="POST"), ApiEndpoint(path="/proxy/kling/v1/images/omni-image", method="POST"),
response_model=OmniTaskStatusResponse, response_model=TaskStatusResponse,
data=OmniProImageRequest( data=OmniProImageRequest(
model_name=model_name, model_name=model_name,
prompt=prompt, prompt=prompt,
@ -1253,7 +1269,7 @@ class OmniProImageNode(IO.ComfyNode):
final_response = await poll_op( final_response = await poll_op(
cls, cls,
ApiEndpoint(path=f"/proxy/kling/v1/images/omni-image/{response.data.task_id}"), ApiEndpoint(path=f"/proxy/kling/v1/images/omni-image/{response.data.task_id}"),
response_model=OmniTaskStatusResponse, response_model=TaskStatusResponse,
status_extractor=lambda r: (r.data.task_status if r.data else None), status_extractor=lambda r: (r.data.task_status if r.data else None),
) )
return IO.NodeOutput(await download_url_to_image_tensor(final_response.data.task_result.images[0].url)) return IO.NodeOutput(await download_url_to_image_tensor(final_response.data.task_result.images[0].url))
@ -1328,9 +1344,8 @@ class KlingImage2VideoNode(IO.ComfyNode):
def define_schema(cls) -> IO.Schema: def define_schema(cls) -> IO.Schema:
return IO.Schema( return IO.Schema(
node_id="KlingImage2VideoNode", node_id="KlingImage2VideoNode",
display_name="Kling Image to Video", display_name="Kling Image(First Frame) to Video",
category="api node/video/Kling", category="api node/video/Kling",
description="Kling Image to Video Node",
inputs=[ inputs=[
IO.Image.Input("start_frame", tooltip="The reference image used to generate the video."), IO.Image.Input("start_frame", tooltip="The reference image used to generate the video."),
IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt"), IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt"),
@ -1488,7 +1503,7 @@ class KlingStartEndFrameNode(IO.ComfyNode):
IO.Combo.Input( IO.Combo.Input(
"mode", "mode",
options=modes, options=modes,
default=modes[8], default=modes[6],
tooltip="The configuration to use for the video generation following the format: mode / duration / model_name.", tooltip="The configuration to use for the video generation following the format: mode / duration / model_name.",
), ),
], ],
@ -1951,7 +1966,7 @@ class KlingImageGenerationNode(IO.ComfyNode):
IO.Combo.Input( IO.Combo.Input(
"model_name", "model_name",
options=[i.value for i in KlingImageGenModelName], options=[i.value for i in KlingImageGenModelName],
default="kling-v1", default="kling-v2",
), ),
IO.Combo.Input( IO.Combo.Input(
"aspect_ratio", "aspect_ratio",
@ -2034,6 +2049,221 @@ class KlingImageGenerationNode(IO.ComfyNode):
return IO.NodeOutput(await image_result_to_node_output(images)) return IO.NodeOutput(await image_result_to_node_output(images))
class TextToVideoWithAudio(IO.ComfyNode):
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="KlingTextToVideoWithAudio",
display_name="Kling Text to Video with Audio",
category="api node/video/Kling",
inputs=[
IO.Combo.Input("model_name", options=["kling-v2-6"]),
IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt."),
IO.Combo.Input("mode", options=["pro"]),
IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]),
IO.Combo.Input("duration", options=[5, 10]),
IO.Boolean.Input("generate_audio", default=True),
],
outputs=[
IO.Video.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
model_name: str,
prompt: str,
mode: str,
aspect_ratio: str,
duration: int,
generate_audio: bool,
) -> IO.NodeOutput:
validate_string(prompt, min_length=1, max_length=2500)
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/kling/v1/videos/text2video", method="POST"),
response_model=TaskStatusResponse,
data=TextToVideoWithAudioRequest(
model_name=model_name,
prompt=prompt,
mode=mode,
aspect_ratio=aspect_ratio,
duration=str(duration),
sound="on" if generate_audio else "off",
),
)
if response.code:
raise RuntimeError(
f"Kling request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}"
)
final_response = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/kling/v1/videos/text2video/{response.data.task_id}"),
response_model=TaskStatusResponse,
status_extractor=lambda r: (r.data.task_status if r.data else None),
)
return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url))
class ImageToVideoWithAudio(IO.ComfyNode):
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="KlingImageToVideoWithAudio",
display_name="Kling Image(First Frame) to Video with Audio",
category="api node/video/Kling",
inputs=[
IO.Combo.Input("model_name", options=["kling-v2-6"]),
IO.Image.Input("start_frame"),
IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt."),
IO.Combo.Input("mode", options=["pro"]),
IO.Combo.Input("duration", options=[5, 10]),
IO.Boolean.Input("generate_audio", default=True),
],
outputs=[
IO.Video.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
model_name: str,
start_frame: Input.Image,
prompt: str,
mode: str,
duration: int,
generate_audio: bool,
) -> IO.NodeOutput:
validate_string(prompt, min_length=1, max_length=2500)
validate_image_dimensions(start_frame, min_width=300, min_height=300)
validate_image_aspect_ratio(start_frame, (1, 2.5), (2.5, 1))
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/kling/v1/videos/image2video", method="POST"),
response_model=TaskStatusResponse,
data=ImageToVideoWithAudioRequest(
model_name=model_name,
image=(await upload_images_to_comfyapi(cls, start_frame))[0],
prompt=prompt,
mode=mode,
duration=str(duration),
sound="on" if generate_audio else "off",
),
)
if response.code:
raise RuntimeError(
f"Kling request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}"
)
final_response = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/kling/v1/videos/image2video/{response.data.task_id}"),
response_model=TaskStatusResponse,
status_extractor=lambda r: (r.data.task_status if r.data else None),
)
return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url))
class MotionControl(IO.ComfyNode):
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="KlingMotionControl",
display_name="Kling Motion Control",
category="api node/video/Kling",
inputs=[
IO.String.Input("prompt", multiline=True),
IO.Image.Input("reference_image"),
IO.Video.Input(
"reference_video",
tooltip="Motion reference video used to drive movement/expression.\n"
"Duration limits depend on character_orientation:\n"
" - image: 310s (max 10s)\n"
" - video: 330s (max 30s)",
),
IO.Boolean.Input("keep_original_sound", default=True),
IO.Combo.Input(
"character_orientation",
options=["video", "image"],
tooltip="Controls where the character's facing/orientation comes from.\n"
"video: movements, expressions, camera moves, and orientation "
"follow the motion reference video (other details via prompt).\n"
"image: movements and expressions still follow the motion reference video, "
"but the character orientation matches the reference image (camera/other details via prompt).",
),
IO.Combo.Input("mode", options=["pro", "std"]),
],
outputs=[
IO.Video.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
prompt: str,
reference_image: Input.Image,
reference_video: Input.Video,
keep_original_sound: bool,
character_orientation: str,
mode: str,
) -> IO.NodeOutput:
validate_string(prompt, max_length=2500)
validate_image_dimensions(reference_image, min_width=340, min_height=340)
validate_image_aspect_ratio(reference_image, (1, 2.5), (2.5, 1))
if character_orientation == "image":
validate_video_duration(reference_video, min_duration=3, max_duration=10)
else:
validate_video_duration(reference_video, min_duration=3, max_duration=30)
validate_video_dimensions(reference_video, min_width=340, min_height=340, max_width=3850, max_height=3850)
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/kling/v1/videos/motion-control", method="POST"),
response_model=TaskStatusResponse,
data=MotionControlRequest(
prompt=prompt,
image_url=(await upload_images_to_comfyapi(cls, reference_image))[0],
video_url=await upload_video_to_comfyapi(cls, reference_video),
keep_original_sound="yes" if keep_original_sound else "no",
character_orientation=character_orientation,
mode=mode,
),
)
if response.code:
raise RuntimeError(
f"Kling request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}"
)
final_response = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/kling/v1/videos/motion-control/{response.data.task_id}"),
response_model=TaskStatusResponse,
status_extractor=lambda r: (r.data.task_status if r.data else None),
)
return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url))
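The validation calls above encode Kling's motion-control input limits; summarized as data for quick reference (a hypothetical constant, the node itself performs these checks inline via the validate_* utilities):

MOTION_CONTROL_LIMITS = {
    "prompt": {"max_length": 2500},
    "reference_image": {"min_side_px": 340, "aspect_ratio_range": ("1:2.5", "2.5:1")},
    "reference_video": {
        "min_side_px": 340,
        "max_side_px": 3850,
        # duration bounds in seconds, keyed by character_orientation
        "duration_s": {"image": (3, 10), "video": (3, 30)},
    },
}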
class KlingExtension(ComfyExtension): class KlingExtension(ComfyExtension):
@override @override
async def get_node_list(self) -> list[type[IO.ComfyNode]]: async def get_node_list(self) -> list[type[IO.ComfyNode]]:
@ -2056,7 +2286,10 @@ class KlingExtension(ComfyExtension):
OmniProImageToVideoNode, OmniProImageToVideoNode,
OmniProVideoToVideoNode, OmniProVideoToVideoNode,
OmniProEditVideoNode, OmniProEditVideoNode,
# OmniProImageNode, # need support from backend OmniProImageNode,
TextToVideoWithAudio,
ImageToVideoWithAudio,
MotionControl,
] ]

View File

@ -1,12 +1,9 @@
from io import BytesIO from io import BytesIO
from typing import Optional
import torch
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from typing_extensions import override from typing_extensions import override
from comfy_api.input_impl import VideoFromFile from comfy_api.latest import IO, ComfyExtension, Input, InputImpl
from comfy_api.latest import IO, ComfyExtension
from comfy_api_nodes.util import ( from comfy_api_nodes.util import (
ApiEndpoint, ApiEndpoint,
get_number_of_images, get_number_of_images,
@ -26,9 +23,9 @@ class ExecuteTaskRequest(BaseModel):
model: str = Field(...) model: str = Field(...)
duration: int = Field(...) duration: int = Field(...)
resolution: str = Field(...) resolution: str = Field(...)
fps: Optional[int] = Field(25) fps: int | None = Field(25)
generate_audio: Optional[bool] = Field(True) generate_audio: bool | None = Field(True)
image_uri: Optional[str] = Field(None) image_uri: str | None = Field(None)
class TextToVideoNode(IO.ComfyNode): class TextToVideoNode(IO.ComfyNode):
@ -103,7 +100,7 @@ class TextToVideoNode(IO.ComfyNode):
as_binary=True, as_binary=True,
max_retries=1, max_retries=1,
) )
return IO.NodeOutput(VideoFromFile(BytesIO(response))) return IO.NodeOutput(InputImpl.VideoFromFile(BytesIO(response)))
class ImageToVideoNode(IO.ComfyNode): class ImageToVideoNode(IO.ComfyNode):
@ -153,7 +150,7 @@ class ImageToVideoNode(IO.ComfyNode):
@classmethod @classmethod
async def execute( async def execute(
cls, cls,
image: torch.Tensor, image: Input.Image,
model: str, model: str,
prompt: str, prompt: str,
duration: int, duration: int,
@ -183,7 +180,7 @@ class ImageToVideoNode(IO.ComfyNode):
as_binary=True, as_binary=True,
max_retries=1, max_retries=1,
) )
return IO.NodeOutput(VideoFromFile(BytesIO(response))) return IO.NodeOutput(InputImpl.VideoFromFile(BytesIO(response)))
class LtxvApiExtension(ComfyExtension): class LtxvApiExtension(ComfyExtension):

View File

@ -1,11 +1,8 @@
import logging import logging
from typing import Optional
import torch
from typing_extensions import override from typing_extensions import override
from comfy_api.input import VideoInput from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api.latest import IO, ComfyExtension
from comfy_api_nodes.apis import ( from comfy_api_nodes.apis import (
MoonvalleyPromptResponse, MoonvalleyPromptResponse,
MoonvalleyTextToVideoInferenceParams, MoonvalleyTextToVideoInferenceParams,
@ -61,7 +58,7 @@ def validate_task_creation_response(response) -> None:
raise RuntimeError(error_msg) raise RuntimeError(error_msg)
def validate_video_to_video_input(video: VideoInput) -> VideoInput: def validate_video_to_video_input(video: Input.Video) -> Input.Video:
""" """
Validates and processes video input for Moonvalley Video-to-Video generation. Validates and processes video input for Moonvalley Video-to-Video generation.
@ -82,7 +79,7 @@ def validate_video_to_video_input(video: VideoInput) -> VideoInput:
return _validate_and_trim_duration(video) return _validate_and_trim_duration(video)
def _get_video_dimensions(video: VideoInput) -> tuple[int, int]: def _get_video_dimensions(video: Input.Video) -> tuple[int, int]:
"""Extracts video dimensions with error handling.""" """Extracts video dimensions with error handling."""
try: try:
return video.get_dimensions() return video.get_dimensions()
@ -106,7 +103,7 @@ def _validate_video_dimensions(width: int, height: int) -> None:
raise ValueError(f"Resolution {width}x{height} not supported. Supported: {supported_list}") raise ValueError(f"Resolution {width}x{height} not supported. Supported: {supported_list}")
def _validate_and_trim_duration(video: VideoInput) -> VideoInput: def _validate_and_trim_duration(video: Input.Video) -> Input.Video:
"""Validates video duration and trims to 5 seconds if needed.""" """Validates video duration and trims to 5 seconds if needed."""
duration = video.get_duration() duration = video.get_duration()
_validate_minimum_duration(duration) _validate_minimum_duration(duration)
@ -119,7 +116,7 @@ def _validate_minimum_duration(duration: float) -> None:
raise ValueError("Input video must be at least 5 seconds long.") raise ValueError("Input video must be at least 5 seconds long.")
def _trim_if_too_long(video: VideoInput, duration: float) -> VideoInput: def _trim_if_too_long(video: Input.Video, duration: float) -> Input.Video:
"""Trims video to 5 seconds if longer.""" """Trims video to 5 seconds if longer."""
if duration > 5: if duration > 5:
return trim_video(video, 5) return trim_video(video, 5)
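Taken together, the helpers above implement a small validate-then-trim pipeline for the video-to-video input: check the resolution, require at least 5 seconds, and trim anything longer down to 5 seconds. A rough standalone sketch of that flow, with a simplified Video stand-in and hypothetical supported resolutions:

# Hedged sketch of the validate -> trim flow above; the Video dataclass, the
# supported-resolution set, and trim_video are stand-ins for illustration.
from dataclasses import dataclass

SUPPORTED_RESOLUTIONS = {(1920, 1080), (1080, 1920), (1280, 720), (720, 1280)}

@dataclass
class Video:
    width: int
    height: int
    duration: float  # seconds

def trim_video(video: Video, seconds: float) -> Video:
    return Video(video.width, video.height, min(video.duration, seconds))

def validate_video_to_video_input(video: Video) -> Video:
    if (video.width, video.height) not in SUPPORTED_RESOLUTIONS:
        raise ValueError(f"Resolution {video.width}x{video.height} not supported")
    if video.duration < 5:
        raise ValueError("Input video must be at least 5 seconds long.")
    return trim_video(video, 5) if video.duration > 5 else video

print(validate_video_to_video_input(Video(1920, 1080, 7.5)))  # trimmed to 5.0s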
@ -241,7 +238,7 @@ class MoonvalleyImg2VideoNode(IO.ComfyNode):
@classmethod @classmethod
async def execute( async def execute(
cls, cls,
image: torch.Tensor, image: Input.Image,
prompt: str, prompt: str,
negative_prompt: str, negative_prompt: str,
resolution: str, resolution: str,
@ -362,9 +359,9 @@ class MoonvalleyVideo2VideoNode(IO.ComfyNode):
prompt: str, prompt: str,
negative_prompt: str, negative_prompt: str,
seed: int, seed: int,
video: Optional[VideoInput] = None, video: Input.Video | None = None,
control_type: str = "Motion Transfer", control_type: str = "Motion Transfer",
motion_intensity: Optional[int] = 100, motion_intensity: int | None = 100,
steps=33, steps=33,
prompt_adherence=4.5, prompt_adherence=4.5,
) -> IO.NodeOutput: ) -> IO.NodeOutput:

View File

@ -1,46 +1,45 @@
from io import BytesIO import base64
import os import os
from enum import Enum from enum import Enum
from inspect import cleandoc from io import BytesIO
import numpy as np import numpy as np
import torch import torch
from PIL import Image from PIL import Image
import folder_paths
import base64
from comfy_api.latest import IO, ComfyExtension
from typing_extensions import override from typing_extensions import override
import folder_paths
from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api_nodes.apis import ( from comfy_api_nodes.apis import (
OpenAIImageGenerationRequest,
OpenAIImageEditRequest,
OpenAIImageGenerationResponse,
OpenAICreateResponse,
OpenAIResponse,
CreateModelResponseProperties, CreateModelResponseProperties,
Item,
OutputContent,
InputImageContent,
Detail, Detail,
InputTextContent,
InputMessage,
InputMessageContentList,
InputContent, InputContent,
InputFileContent, InputFileContent,
InputImageContent,
InputMessage,
InputMessageContentList,
InputTextContent,
Item,
OpenAICreateResponse,
OpenAIResponse,
OutputContent,
)
from comfy_api_nodes.apis.openai_api import (
OpenAIImageEditRequest,
OpenAIImageGenerationRequest,
OpenAIImageGenerationResponse,
) )
from comfy_api_nodes.util import ( from comfy_api_nodes.util import (
downscale_image_tensor,
download_url_to_bytesio,
validate_string,
tensor_to_base64_string,
ApiEndpoint, ApiEndpoint,
sync_op, download_url_to_bytesio,
downscale_image_tensor,
poll_op, poll_op,
sync_op,
tensor_to_base64_string,
text_filepath_to_data_uri, text_filepath_to_data_uri,
validate_string,
) )
RESPONSES_ENDPOINT = "/proxy/openai/v1/responses" RESPONSES_ENDPOINT = "/proxy/openai/v1/responses"
STARTING_POINT_ID_PATTERN = r"<starting_point_id:(.*)>" STARTING_POINT_ID_PATTERN = r"<starting_point_id:(.*)>"
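STARTING_POINT_ID_PATTERN above is an ordinary regex used to carry a marker inside prompt text; a quick hedged usage example (the surrounding string is made up, not taken from the node code):

# Illustrative use of the marker regex; the input string is an assumption.
import re

STARTING_POINT_ID_PATTERN = r"<starting_point_id:(.*)>"

history = "continue from here <starting_point_id:resp_abc123>"
match = re.search(STARTING_POINT_ID_PATTERN, history)
if match:
    print(match.group(1))  # resp_abc123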
@ -98,9 +97,6 @@ async def validate_and_cast_response(response, timeout: int = None) -> torch.Ten
class OpenAIDalle2(IO.ComfyNode): class OpenAIDalle2(IO.ComfyNode):
"""
Generates images synchronously via OpenAI's DALL·E 2 endpoint.
"""
@classmethod @classmethod
def define_schema(cls): def define_schema(cls):
@ -108,7 +104,7 @@ class OpenAIDalle2(IO.ComfyNode):
node_id="OpenAIDalle2", node_id="OpenAIDalle2",
display_name="OpenAI DALL·E 2", display_name="OpenAI DALL·E 2",
category="api node/image/OpenAI", category="api node/image/OpenAI",
description=cleandoc(cls.__doc__ or ""), description="Generates images synchronously via OpenAI's DALL·E 2 endpoint.",
inputs=[ inputs=[
IO.String.Input( IO.String.Input(
"prompt", "prompt",
@ -234,9 +230,6 @@ class OpenAIDalle2(IO.ComfyNode):
class OpenAIDalle3(IO.ComfyNode): class OpenAIDalle3(IO.ComfyNode):
"""
Generates images synchronously via OpenAI's DALL·E 3 endpoint.
"""
@classmethod @classmethod
def define_schema(cls): def define_schema(cls):
@ -244,7 +237,7 @@ class OpenAIDalle3(IO.ComfyNode):
node_id="OpenAIDalle3", node_id="OpenAIDalle3",
display_name="OpenAI DALL·E 3", display_name="OpenAI DALL·E 3",
category="api node/image/OpenAI", category="api node/image/OpenAI",
description=cleandoc(cls.__doc__ or ""), description="Generates images synchronously via OpenAI's DALL·E 3 endpoint.",
inputs=[ inputs=[
IO.String.Input( IO.String.Input(
"prompt", "prompt",
@ -326,10 +319,16 @@ class OpenAIDalle3(IO.ComfyNode):
return IO.NodeOutput(await validate_and_cast_response(response)) return IO.NodeOutput(await validate_and_cast_response(response))
def calculate_tokens_price_image_1(response: OpenAIImageGenerationResponse) -> float | None:
# https://platform.openai.com/docs/pricing
return ((response.usage.input_tokens * 10.0) + (response.usage.output_tokens * 40.0)) / 1_000_000.0
def calculate_tokens_price_image_1_5(response: OpenAIImageGenerationResponse) -> float | None:
return ((response.usage.input_tokens * 8.0) + (response.usage.output_tokens * 32.0)) / 1_000_000.0
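The two helpers above turn token usage into a dollar price per million tokens ($10 in / $40 out for gpt-image-1, $8 in / $32 out for gpt-image-1.5, as encoded in the code rather than independently verified against the linked pricing page). A quick worked example of the arithmetic:

# Worked example of the token-pricing arithmetic above, using plain ints
# instead of the response object.
def price_image_1(input_tokens: int, output_tokens: int) -> float:
    return ((input_tokens * 10.0) + (output_tokens * 40.0)) / 1_000_000.0

def price_image_1_5(input_tokens: int, output_tokens: int) -> float:
    return ((input_tokens * 8.0) + (output_tokens * 32.0)) / 1_000_000.0

# 1,000 prompt tokens and 4,000 image output tokens:
print(price_image_1(1_000, 4_000))    # 0.17   -> $0.17
print(price_image_1_5(1_000, 4_000))  # 0.136  -> ~$0.14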
class OpenAIGPTImage1(IO.ComfyNode): class OpenAIGPTImage1(IO.ComfyNode):
"""
Generates images synchronously via OpenAI's GPT Image 1 endpoint.
"""
@classmethod @classmethod
def define_schema(cls): def define_schema(cls):
@ -337,13 +336,13 @@ class OpenAIGPTImage1(IO.ComfyNode):
node_id="OpenAIGPTImage1", node_id="OpenAIGPTImage1",
display_name="OpenAI GPT Image 1", display_name="OpenAI GPT Image 1",
category="api node/image/OpenAI", category="api node/image/OpenAI",
description=cleandoc(cls.__doc__ or ""), description="Generates images synchronously via OpenAI's GPT Image 1 endpoint.",
inputs=[ inputs=[
IO.String.Input( IO.String.Input(
"prompt", "prompt",
default="", default="",
multiline=True, multiline=True,
tooltip="Text prompt for GPT Image 1", tooltip="Text prompt for GPT Image",
), ),
IO.Int.Input( IO.Int.Input(
"seed", "seed",
@ -365,8 +364,8 @@ class OpenAIGPTImage1(IO.ComfyNode):
), ),
IO.Combo.Input( IO.Combo.Input(
"background", "background",
default="opaque", default="auto",
options=["opaque", "transparent"], options=["auto", "opaque", "transparent"],
tooltip="Return image with or without background", tooltip="Return image with or without background",
optional=True, optional=True,
), ),
@ -397,6 +396,11 @@ class OpenAIGPTImage1(IO.ComfyNode):
tooltip="Optional mask for inpainting (white areas will be replaced)", tooltip="Optional mask for inpainting (white areas will be replaced)",
optional=True, optional=True,
), ),
IO.Combo.Input(
"model",
options=["gpt-image-1", "gpt-image-1.5"],
optional=True,
),
], ],
outputs=[ outputs=[
IO.Image.Output(), IO.Image.Output(),
@ -412,32 +416,34 @@ class OpenAIGPTImage1(IO.ComfyNode):
@classmethod @classmethod
async def execute( async def execute(
cls, cls,
prompt, prompt: str,
seed=0, seed: int = 0,
quality="low", quality: str = "low",
background="opaque", background: str = "opaque",
image=None, image: Input.Image | None = None,
mask=None, mask: Input.Image | None = None,
n=1, n: int = 1,
size="1024x1024", size: str = "1024x1024",
model: str = "gpt-image-1",
) -> IO.NodeOutput: ) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=False) validate_string(prompt, strip_whitespace=False)
model = "gpt-image-1"
path = "/proxy/openai/images/generations" if mask is not None and image is None:
content_type = "application/json" raise ValueError("Cannot use a mask without an input image")
request_class = OpenAIImageGenerationRequest
files = [] if model == "gpt-image-1":
price_extractor = calculate_tokens_price_image_1
elif model == "gpt-image-1.5":
price_extractor = calculate_tokens_price_image_1_5
else:
raise ValueError(f"Unknown model: {model}")
if image is not None: if image is not None:
path = "/proxy/openai/images/edits" files = []
request_class = OpenAIImageEditRequest
content_type = "multipart/form-data"
batch_size = image.shape[0] batch_size = image.shape[0]
for i in range(batch_size): for i in range(batch_size):
single_image = image[i : i + 1] single_image = image[i: i + 1]
scaled_image = downscale_image_tensor(single_image).squeeze() scaled_image = downscale_image_tensor(single_image, total_pixels=2048*2048).squeeze()
image_np = (scaled_image.numpy() * 255).astype(np.uint8) image_np = (scaled_image.numpy() * 255).astype(np.uint8)
img = Image.fromarray(image_np) img = Image.fromarray(image_np)
@ -450,44 +456,59 @@ class OpenAIGPTImage1(IO.ComfyNode):
else: else:
files.append(("image[]", (f"image_{i}.png", img_byte_arr, "image/png"))) files.append(("image[]", (f"image_{i}.png", img_byte_arr, "image/png")))
if mask is not None: if mask is not None:
if image is None: if image.shape[0] != 1:
raise Exception("Cannot use a mask without an input image") raise Exception("Cannot use a mask with multiple image")
if image.shape[0] != 1: if mask.shape[1:] != image.shape[1:-1]:
raise Exception("Cannot use a mask with multiple image") raise Exception("Mask and Image must be the same size")
if mask.shape[1:] != image.shape[1:-1]: _, height, width = mask.shape
raise Exception("Mask and Image must be the same size") rgba_mask = torch.zeros(height, width, 4, device="cpu")
batch, height, width = mask.shape rgba_mask[:, :, 3] = 1 - mask.squeeze().cpu()
rgba_mask = torch.zeros(height, width, 4, device="cpu")
rgba_mask[:, :, 3] = 1 - mask.squeeze().cpu()
scaled_mask = downscale_image_tensor(rgba_mask.unsqueeze(0)).squeeze() scaled_mask = downscale_image_tensor(rgba_mask.unsqueeze(0), total_pixels=2048*2048).squeeze()
mask_np = (scaled_mask.numpy() * 255).astype(np.uint8) mask_np = (scaled_mask.numpy() * 255).astype(np.uint8)
mask_img = Image.fromarray(mask_np) mask_img = Image.fromarray(mask_np)
mask_img_byte_arr = BytesIO() mask_img_byte_arr = BytesIO()
mask_img.save(mask_img_byte_arr, format="PNG") mask_img.save(mask_img_byte_arr, format="PNG")
mask_img_byte_arr.seek(0) mask_img_byte_arr.seek(0)
files.append(("mask", ("mask.png", mask_img_byte_arr, "image/png"))) files.append(("mask", ("mask.png", mask_img_byte_arr, "image/png")))
# Build the operation
response = await sync_op(
cls,
ApiEndpoint(path=path, method="POST"),
response_model=OpenAIImageGenerationResponse,
data=request_class(
model=model,
prompt=prompt,
quality=quality,
background=background,
n=n,
seed=seed,
size=size,
),
files=files if files else None,
content_type=content_type,
)
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/openai/images/edits", method="POST"),
response_model=OpenAIImageGenerationResponse,
data=OpenAIImageEditRequest(
model=model,
prompt=prompt,
quality=quality,
background=background,
n=n,
seed=seed,
size=size,
moderation="low",
),
content_type="multipart/form-data",
files=files,
price_extractor=price_extractor,
)
else:
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/openai/images/generations", method="POST"),
response_model=OpenAIImageGenerationResponse,
data=OpenAIImageGenerationRequest(
model=model,
prompt=prompt,
quality=quality,
background=background,
n=n,
seed=seed,
size=size,
moderation="low",
),
price_extractor=price_extractor,
)
return IO.NodeOutput(await validate_and_cast_response(response)) return IO.NodeOutput(await validate_and_cast_response(response))
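One detail worth calling out from the edit path above: the ComfyUI mask is inverted into the alpha channel of an RGBA PNG, so masked (white) areas become transparent and are the pixels the edits endpoint replaces. A standalone sketch of just that conversion, outside the node and request machinery:

# Sketch of the mask -> transparent-alpha PNG conversion; tensor layout
# (1, H, W) in [0, 1] is assumed, matching ComfyUI masks.
from io import BytesIO

import numpy as np
import torch
from PIL import Image

def mask_to_rgba_png(mask: torch.Tensor) -> BytesIO:
    _, height, width = mask.shape
    rgba = torch.zeros(height, width, 4)
    rgba[:, :, 3] = 1 - mask.squeeze()          # masked pixels -> transparent
    arr = (rgba.numpy() * 255).astype(np.uint8)
    buf = BytesIO()
    Image.fromarray(arr).save(buf, format="PNG")
    buf.seek(0)
    return buf

buf = mask_to_rgba_png(torch.ones(1, 64, 64))   # fully masked 64x64 mask
print(len(buf.getvalue()), "bytes of PNG")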

View File

@ -1,575 +0,0 @@
"""
Pika x ComfyUI API Nodes
Pika API docs: https://pika-827374fb.mintlify.app/api-reference
"""
from __future__ import annotations
from io import BytesIO
import logging
from typing import Optional
import torch
from typing_extensions import override
from comfy_api.latest import ComfyExtension, IO
from comfy_api.input_impl.video_types import VideoCodec, VideoContainer, VideoInput
from comfy_api_nodes.apis import pika_api as pika_defs
from comfy_api_nodes.util import (
validate_string,
download_url_to_video_output,
tensor_to_bytesio,
ApiEndpoint,
sync_op,
poll_op,
)
PATH_PIKADDITIONS = "/proxy/pika/generate/pikadditions"
PATH_PIKASWAPS = "/proxy/pika/generate/pikaswaps"
PATH_PIKAFFECTS = "/proxy/pika/generate/pikaffects"
PIKA_API_VERSION = "2.2"
PATH_TEXT_TO_VIDEO = f"/proxy/pika/generate/{PIKA_API_VERSION}/t2v"
PATH_IMAGE_TO_VIDEO = f"/proxy/pika/generate/{PIKA_API_VERSION}/i2v"
PATH_PIKAFRAMES = f"/proxy/pika/generate/{PIKA_API_VERSION}/pikaframes"
PATH_PIKASCENES = f"/proxy/pika/generate/{PIKA_API_VERSION}/pikascenes"
PATH_VIDEO_GET = "/proxy/pika/videos"
async def execute_task(
task_id: str,
cls: type[IO.ComfyNode],
) -> IO.NodeOutput:
final_response: pika_defs.PikaVideoResponse = await poll_op(
cls,
ApiEndpoint(path=f"{PATH_VIDEO_GET}/{task_id}"),
response_model=pika_defs.PikaVideoResponse,
status_extractor=lambda response: (response.status.value if response.status else None),
progress_extractor=lambda response: (response.progress if hasattr(response, "progress") else None),
estimated_duration=60,
max_poll_attempts=240,
)
if not final_response.url:
error_msg = f"Pika task {task_id} succeeded but no video data found in response:\n{final_response}"
logging.error(error_msg)
raise Exception(error_msg)
video_url = final_response.url
logging.info("Pika task %s succeeded. Video URL: %s", task_id, video_url)
return IO.NodeOutput(await download_url_to_video_output(video_url))
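execute_task above hands poll_op two callables, status_extractor and progress_extractor, instead of hard-coding field access, so one poller can serve differently shaped responses. A minimal sketch of that callback pattern with a synchronous stand-in poller (not the comfy_api_nodes helper):

# Minimal sketch of the extractor-callback pattern; the poller below is a
# synchronous stand-in, not the real poll_op helper.
import time
from typing import Callable

def poll(
    fetch: Callable[[], dict],
    status_extractor: Callable[[dict], str | None],
    progress_extractor: Callable[[dict], float | None] = lambda r: None,
    interval: float = 1.0,
    max_attempts: int = 240,
) -> dict:
    for _ in range(max_attempts):
        response = fetch()
        progress = progress_extractor(response)
        if progress is not None:
            print(f"progress: {progress:.0f}%")
        if status_extractor(response) in ("finished", "succeed"):
            return response
        time.sleep(interval)
    raise TimeoutError("task did not finish in time")

# Fake task that finishes on the third fetch:
responses = iter([
    {"status": "queued"},
    {"status": "started", "progress": 50},
    {"status": "finished", "progress": 100, "url": "https://example.invalid/video.mp4"},
])
result = poll(lambda: next(responses), lambda r: r.get("status"),
              lambda r: r.get("progress"), interval=0.0)
print(result["url"])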
def get_base_inputs_types() -> list[IO.Input]:
"""Get the base required inputs types common to all Pika nodes."""
return [
IO.String.Input("prompt_text", multiline=True),
IO.String.Input("negative_prompt", multiline=True),
IO.Int.Input("seed", min=0, max=0xFFFFFFFF, control_after_generate=True),
IO.Combo.Input("resolution", options=["1080p", "720p"], default="1080p"),
IO.Combo.Input("duration", options=[5, 10], default=5),
]
class PikaImageToVideo(IO.ComfyNode):
"""Pika 2.2 Image to Video Node."""
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="PikaImageToVideoNode2_2",
display_name="Pika Image to Video",
description="Sends an image and prompt to the Pika API v2.2 to generate a video.",
category="api node/video/Pika",
inputs=[
IO.Image.Input("image", tooltip="The image to convert to video"),
*get_base_inputs_types(),
],
outputs=[IO.Video.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
is_deprecated=True,
)
@classmethod
async def execute(
cls,
image: torch.Tensor,
prompt_text: str,
negative_prompt: str,
seed: int,
resolution: str,
duration: int,
) -> IO.NodeOutput:
image_bytes_io = tensor_to_bytesio(image)
pika_files = {"image": ("image.png", image_bytes_io, "image/png")}
pika_request_data = pika_defs.PikaBodyGenerate22I2vGenerate22I2vPost(
promptText=prompt_text,
negativePrompt=negative_prompt,
seed=seed,
resolution=resolution,
duration=duration,
)
initial_operation = await sync_op(
cls,
ApiEndpoint(path=PATH_IMAGE_TO_VIDEO, method="POST"),
response_model=pika_defs.PikaGenerateResponse,
data=pika_request_data,
files=pika_files,
content_type="multipart/form-data",
)
return await execute_task(initial_operation.video_id, cls)
class PikaTextToVideoNode(IO.ComfyNode):
"""Pika Text2Video v2.2 Node."""
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="PikaTextToVideoNode2_2",
display_name="Pika Text to Video",
description="Sends a text prompt to the Pika API v2.2 to generate a video.",
category="api node/video/Pika",
inputs=[
*get_base_inputs_types(),
IO.Float.Input(
"aspect_ratio",
step=0.001,
min=0.4,
max=2.5,
default=1.7777777777777777,
tooltip="Aspect ratio (width / height)",
)
],
outputs=[IO.Video.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
is_deprecated=True,
)
@classmethod
async def execute(
cls,
prompt_text: str,
negative_prompt: str,
seed: int,
resolution: str,
duration: int,
aspect_ratio: float,
) -> IO.NodeOutput:
initial_operation = await sync_op(
cls,
ApiEndpoint(path=PATH_TEXT_TO_VIDEO, method="POST"),
response_model=pika_defs.PikaGenerateResponse,
data=pika_defs.PikaBodyGenerate22T2vGenerate22T2vPost(
promptText=prompt_text,
negativePrompt=negative_prompt,
seed=seed,
resolution=resolution,
duration=duration,
aspectRatio=aspect_ratio,
),
content_type="application/x-www-form-urlencoded",
)
return await execute_task(initial_operation.video_id, cls)
class PikaScenes(IO.ComfyNode):
"""PikaScenes v2.2 Node."""
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="PikaScenesV2_2",
display_name="Pika Scenes (Video Image Composition)",
description="Combine your images to create a video with the objects in them. Upload multiple images as ingredients and generate a high-quality video that incorporates all of them.",
category="api node/video/Pika",
inputs=[
*get_base_inputs_types(),
IO.Combo.Input(
"ingredients_mode",
options=["creative", "precise"],
default="creative",
),
IO.Float.Input(
"aspect_ratio",
step=0.001,
min=0.4,
max=2.5,
default=1.7777777777777777,
tooltip="Aspect ratio (width / height)",
),
IO.Image.Input(
"image_ingredient_1",
optional=True,
tooltip="Image that will be used as ingredient to create a video.",
),
IO.Image.Input(
"image_ingredient_2",
optional=True,
tooltip="Image that will be used as ingredient to create a video.",
),
IO.Image.Input(
"image_ingredient_3",
optional=True,
tooltip="Image that will be used as ingredient to create a video.",
),
IO.Image.Input(
"image_ingredient_4",
optional=True,
tooltip="Image that will be used as ingredient to create a video.",
),
IO.Image.Input(
"image_ingredient_5",
optional=True,
tooltip="Image that will be used as ingredient to create a video.",
),
],
outputs=[IO.Video.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
is_deprecated=True,
)
@classmethod
async def execute(
cls,
prompt_text: str,
negative_prompt: str,
seed: int,
resolution: str,
duration: int,
ingredients_mode: str,
aspect_ratio: float,
image_ingredient_1: Optional[torch.Tensor] = None,
image_ingredient_2: Optional[torch.Tensor] = None,
image_ingredient_3: Optional[torch.Tensor] = None,
image_ingredient_4: Optional[torch.Tensor] = None,
image_ingredient_5: Optional[torch.Tensor] = None,
) -> IO.NodeOutput:
all_image_bytes_io = []
for image in [
image_ingredient_1,
image_ingredient_2,
image_ingredient_3,
image_ingredient_4,
image_ingredient_5,
]:
if image is not None:
all_image_bytes_io.append(tensor_to_bytesio(image))
pika_files = [
("images", (f"image_{i}.png", image_bytes_io, "image/png"))
for i, image_bytes_io in enumerate(all_image_bytes_io)
]
pika_request_data = pika_defs.PikaBodyGenerate22C2vGenerate22PikascenesPost(
ingredientsMode=ingredients_mode,
promptText=prompt_text,
negativePrompt=negative_prompt,
seed=seed,
resolution=resolution,
duration=duration,
aspectRatio=aspect_ratio,
)
initial_operation = await sync_op(
cls,
ApiEndpoint(path=PATH_PIKASCENES, method="POST"),
response_model=pika_defs.PikaGenerateResponse,
data=pika_request_data,
files=pika_files,
content_type="multipart/form-data",
)
return await execute_task(initial_operation.video_id, cls)
class PikAdditionsNode(IO.ComfyNode):
"""Pika Pikadditions Node. Add an image into a video."""
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="Pikadditions",
display_name="Pikadditions (Video Object Insertion)",
description="Add any object or image into your video. Upload a video and specify what you'd like to add to create a seamlessly integrated result.",
category="api node/video/Pika",
inputs=[
IO.Video.Input("video", tooltip="The video to add an image to."),
IO.Image.Input("image", tooltip="The image to add to the video."),
IO.String.Input("prompt_text", multiline=True),
IO.String.Input("negative_prompt", multiline=True),
IO.Int.Input(
"seed",
min=0,
max=0xFFFFFFFF,
control_after_generate=True,
),
],
outputs=[IO.Video.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
is_deprecated=True,
)
@classmethod
async def execute(
cls,
video: VideoInput,
image: torch.Tensor,
prompt_text: str,
negative_prompt: str,
seed: int,
) -> IO.NodeOutput:
video_bytes_io = BytesIO()
video.save_to(video_bytes_io, format=VideoContainer.MP4, codec=VideoCodec.H264)
video_bytes_io.seek(0)
image_bytes_io = tensor_to_bytesio(image)
pika_files = {
"video": ("video.mp4", video_bytes_io, "video/mp4"),
"image": ("image.png", image_bytes_io, "image/png"),
}
pika_request_data = pika_defs.PikaBodyGeneratePikadditionsGeneratePikadditionsPost(
promptText=prompt_text,
negativePrompt=negative_prompt,
seed=seed,
)
initial_operation = await sync_op(
cls,
ApiEndpoint(path=PATH_PIKADDITIONS, method="POST"),
response_model=pika_defs.PikaGenerateResponse,
data=pika_request_data,
files=pika_files,
content_type="multipart/form-data",
)
return await execute_task(initial_operation.video_id, cls)
class PikaSwapsNode(IO.ComfyNode):
"""Pika Pikaswaps Node."""
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="Pikaswaps",
display_name="Pika Swaps (Video Object Replacement)",
description="Swap out any object or region of your video with a new image or object. Define areas to replace either with a mask or coordinates.",
category="api node/video/Pika",
inputs=[
IO.Video.Input("video", tooltip="The video to swap an object in."),
IO.Image.Input(
"image",
tooltip="The image used to replace the masked object in the video.",
optional=True,
),
IO.Mask.Input(
"mask",
tooltip="Use the mask to define areas in the video to replace.",
optional=True,
),
IO.String.Input("prompt_text", multiline=True, optional=True),
IO.String.Input("negative_prompt", multiline=True, optional=True),
IO.Int.Input("seed", min=0, max=0xFFFFFFFF, control_after_generate=True, optional=True),
IO.String.Input(
"region_to_modify",
multiline=True,
optional=True,
tooltip="Plaintext description of the object / region to modify.",
),
],
outputs=[IO.Video.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
is_deprecated=True,
)
@classmethod
async def execute(
cls,
video: VideoInput,
image: Optional[torch.Tensor] = None,
mask: Optional[torch.Tensor] = None,
prompt_text: str = "",
negative_prompt: str = "",
seed: int = 0,
region_to_modify: str = "",
) -> IO.NodeOutput:
video_bytes_io = BytesIO()
video.save_to(video_bytes_io, format=VideoContainer.MP4, codec=VideoCodec.H264)
video_bytes_io.seek(0)
pika_files = {
"video": ("video.mp4", video_bytes_io, "video/mp4"),
}
if mask is not None:
pika_files["modifyRegionMask"] = ("mask.png", tensor_to_bytesio(mask), "image/png")
if image is not None:
pika_files["image"] = ("image.png", tensor_to_bytesio(image), "image/png")
pika_request_data = pika_defs.PikaBodyGeneratePikaswapsGeneratePikaswapsPost(
promptText=prompt_text,
negativePrompt=negative_prompt,
seed=seed,
modifyRegionRoi=region_to_modify if region_to_modify else None,
)
initial_operation = await sync_op(
cls,
ApiEndpoint(path=PATH_PIKASWAPS, method="POST"),
response_model=pika_defs.PikaGenerateResponse,
data=pika_request_data,
files=pika_files,
content_type="multipart/form-data",
)
return await execute_task(initial_operation.video_id, cls)
class PikaffectsNode(IO.ComfyNode):
"""Pika Pikaffects Node."""
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="Pikaffects",
display_name="Pikaffects (Video Effects)",
description="Generate a video with a specific Pikaffect. Supported Pikaffects: Cake-ify, Crumble, Crush, Decapitate, Deflate, Dissolve, Explode, Eye-pop, Inflate, Levitate, Melt, Peel, Poke, Squish, Ta-da, Tear",
category="api node/video/Pika",
inputs=[
IO.Image.Input("image", tooltip="The reference image to apply the Pikaffect to."),
IO.Combo.Input(
"pikaffect", options=pika_defs.Pikaffect, default="Cake-ify"
),
IO.String.Input("prompt_text", multiline=True),
IO.String.Input("negative_prompt", multiline=True),
IO.Int.Input("seed", min=0, max=0xFFFFFFFF, control_after_generate=True),
],
outputs=[IO.Video.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
is_deprecated=True,
)
@classmethod
async def execute(
cls,
image: torch.Tensor,
pikaffect: str,
prompt_text: str,
negative_prompt: str,
seed: int,
) -> IO.NodeOutput:
initial_operation = await sync_op(
cls,
ApiEndpoint(path=PATH_PIKAFFECTS, method="POST"),
response_model=pika_defs.PikaGenerateResponse,
data=pika_defs.PikaBodyGeneratePikaffectsGeneratePikaffectsPost(
pikaffect=pikaffect,
promptText=prompt_text,
negativePrompt=negative_prompt,
seed=seed,
),
files={"image": ("image.png", tensor_to_bytesio(image), "image/png")},
content_type="multipart/form-data",
)
return await execute_task(initial_operation.video_id, cls)
class PikaStartEndFrameNode(IO.ComfyNode):
"""PikaFrames v2.2 Node."""
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="PikaStartEndFrameNode2_2",
display_name="Pika Start and End Frame to Video",
description="Generate a video by combining your first and last frame. Upload two images to define the start and end points, and let the AI create a smooth transition between them.",
category="api node/video/Pika",
inputs=[
IO.Image.Input("image_start", tooltip="The first image to combine."),
IO.Image.Input("image_end", tooltip="The last image to combine."),
*get_base_inputs_types(),
],
outputs=[IO.Video.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
is_deprecated=True,
)
@classmethod
async def execute(
cls,
image_start: torch.Tensor,
image_end: torch.Tensor,
prompt_text: str,
negative_prompt: str,
seed: int,
resolution: str,
duration: int,
) -> IO.NodeOutput:
validate_string(prompt_text, field_name="prompt_text", min_length=1)
pika_files = [
("keyFrames", ("image_start.png", tensor_to_bytesio(image_start), "image/png")),
("keyFrames", ("image_end.png", tensor_to_bytesio(image_end), "image/png")),
]
initial_operation = await sync_op(
cls,
ApiEndpoint(path=PATH_PIKAFRAMES, method="POST"),
response_model=pika_defs.PikaGenerateResponse,
data=pika_defs.PikaBodyGenerate22KeyframeGenerate22PikaframesPost(
promptText=prompt_text,
negativePrompt=negative_prompt,
seed=seed,
resolution=resolution,
duration=duration,
),
files=pika_files,
content_type="multipart/form-data",
)
return await execute_task(initial_operation.video_id, cls)
class PikaApiNodesExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
PikaImageToVideo,
PikaTextToVideoNode,
PikaScenes,
PikAdditionsNode,
PikaSwapsNode,
PikaffectsNode,
PikaStartEndFrameNode,
]
async def comfy_entrypoint() -> PikaApiNodesExtension:
return PikaApiNodesExtension()

View File

@ -11,12 +11,11 @@ User Guides:
""" """
from typing import Union, Optional
from typing_extensions import override
from enum import Enum from enum import Enum
import torch from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension, Input, InputImpl
from comfy_api_nodes.apis import ( from comfy_api_nodes.apis import (
RunwayImageToVideoRequest, RunwayImageToVideoRequest,
RunwayImageToVideoResponse, RunwayImageToVideoResponse,
@ -44,8 +43,6 @@ from comfy_api_nodes.util import (
sync_op, sync_op,
poll_op, poll_op,
) )
from comfy_api.input_impl import VideoFromFile
from comfy_api.latest import ComfyExtension, IO
PATH_IMAGE_TO_VIDEO = "/proxy/runway/image_to_video" PATH_IMAGE_TO_VIDEO = "/proxy/runway/image_to_video"
PATH_TEXT_TO_IMAGE = "/proxy/runway/text_to_image" PATH_TEXT_TO_IMAGE = "/proxy/runway/text_to_image"
@ -80,7 +77,7 @@ class RunwayGen3aAspectRatio(str, Enum):
field_1280_768 = "1280:768" field_1280_768 = "1280:768"
def get_video_url_from_task_status(response: TaskStatusResponse) -> Union[str, None]: def get_video_url_from_task_status(response: TaskStatusResponse) -> str | None:
"""Returns the video URL from the task status response if it exists.""" """Returns the video URL from the task status response if it exists."""
if hasattr(response, "output") and len(response.output) > 0: if hasattr(response, "output") and len(response.output) > 0:
return response.output[0] return response.output[0]
@ -89,13 +86,13 @@ def get_video_url_from_task_status(response: TaskStatusResponse) -> Union[str, N
def extract_progress_from_task_status( def extract_progress_from_task_status(
response: TaskStatusResponse, response: TaskStatusResponse,
) -> Union[float, None]: ) -> float | None:
if hasattr(response, "progress") and response.progress is not None: if hasattr(response, "progress") and response.progress is not None:
return response.progress * 100 return response.progress * 100
return None return None
def get_image_url_from_task_status(response: TaskStatusResponse) -> Union[str, None]: def get_image_url_from_task_status(response: TaskStatusResponse) -> str | None:
"""Returns the image URL from the task status response if it exists.""" """Returns the image URL from the task status response if it exists."""
if hasattr(response, "output") and len(response.output) > 0: if hasattr(response, "output") and len(response.output) > 0:
return response.output[0] return response.output[0]
@ -103,7 +100,7 @@ def get_image_url_from_task_status(response: TaskStatusResponse) -> Union[str, N
async def get_response( async def get_response(
cls: type[IO.ComfyNode], task_id: str, estimated_duration: Optional[int] = None cls: type[IO.ComfyNode], task_id: str, estimated_duration: int | None = None
) -> TaskStatusResponse: ) -> TaskStatusResponse:
"""Poll the task status until it is finished then get the response.""" """Poll the task status until it is finished then get the response."""
return await poll_op( return await poll_op(
@ -119,8 +116,8 @@ async def get_response(
async def generate_video( async def generate_video(
cls: type[IO.ComfyNode], cls: type[IO.ComfyNode],
request: RunwayImageToVideoRequest, request: RunwayImageToVideoRequest,
estimated_duration: Optional[int] = None, estimated_duration: int | None = None,
) -> VideoFromFile: ) -> InputImpl.VideoFromFile:
initial_response = await sync_op( initial_response = await sync_op(
cls, cls,
endpoint=ApiEndpoint(path=PATH_IMAGE_TO_VIDEO, method="POST"), endpoint=ApiEndpoint(path=PATH_IMAGE_TO_VIDEO, method="POST"),
@ -193,7 +190,7 @@ class RunwayImageToVideoNodeGen3a(IO.ComfyNode):
async def execute( async def execute(
cls, cls,
prompt: str, prompt: str,
start_frame: torch.Tensor, start_frame: Input.Image,
duration: str, duration: str,
ratio: str, ratio: str,
seed: int, seed: int,
@ -283,7 +280,7 @@ class RunwayImageToVideoNodeGen4(IO.ComfyNode):
async def execute( async def execute(
cls, cls,
prompt: str, prompt: str,
start_frame: torch.Tensor, start_frame: Input.Image,
duration: str, duration: str,
ratio: str, ratio: str,
seed: int, seed: int,
@ -381,8 +378,8 @@ class RunwayFirstLastFrameNode(IO.ComfyNode):
async def execute( async def execute(
cls, cls,
prompt: str, prompt: str,
start_frame: torch.Tensor, start_frame: Input.Image,
end_frame: torch.Tensor, end_frame: Input.Image,
duration: str, duration: str,
ratio: str, ratio: str,
seed: int, seed: int,
@ -467,7 +464,7 @@ class RunwayTextToImageNode(IO.ComfyNode):
cls, cls,
prompt: str, prompt: str,
ratio: str, ratio: str,
reference_image: Optional[torch.Tensor] = None, reference_image: Input.Image | None = None,
) -> IO.NodeOutput: ) -> IO.NodeOutput:
validate_string(prompt, min_length=1) validate_string(prompt, min_length=1)

View File

@ -23,10 +23,6 @@ UPSCALER_MODELS_MAP = {
"Starlight (Astra) Fast": "slf-1", "Starlight (Astra) Fast": "slf-1",
"Starlight (Astra) Creative": "slc-1", "Starlight (Astra) Creative": "slc-1",
} }
UPSCALER_VALUES_MAP = {
"FullHD (1080p)": 1920,
"4K (2160p)": 3840,
}
class TopazImageEnhance(IO.ComfyNode): class TopazImageEnhance(IO.ComfyNode):
@ -214,7 +210,7 @@ class TopazVideoEnhance(IO.ComfyNode):
IO.Video.Input("video"), IO.Video.Input("video"),
IO.Boolean.Input("upscaler_enabled", default=True), IO.Boolean.Input("upscaler_enabled", default=True),
IO.Combo.Input("upscaler_model", options=list(UPSCALER_MODELS_MAP.keys())), IO.Combo.Input("upscaler_model", options=list(UPSCALER_MODELS_MAP.keys())),
IO.Combo.Input("upscaler_resolution", options=list(UPSCALER_VALUES_MAP.keys())), IO.Combo.Input("upscaler_resolution", options=["FullHD (1080p)", "4K (2160p)"]),
IO.Combo.Input( IO.Combo.Input(
"upscaler_creativity", "upscaler_creativity",
options=["low", "middle", "high"], options=["low", "middle", "high"],
@ -306,8 +302,33 @@ class TopazVideoEnhance(IO.ComfyNode):
target_frame_rate = src_frame_rate target_frame_rate = src_frame_rate
filters = [] filters = []
if upscaler_enabled: if upscaler_enabled:
target_width = UPSCALER_VALUES_MAP[upscaler_resolution] if "1080p" in upscaler_resolution:
target_height = UPSCALER_VALUES_MAP[upscaler_resolution] target_pixel_p = 1080
max_long_side = 1920
else:
target_pixel_p = 2160
max_long_side = 3840
ar = src_width / src_height
if src_width >= src_height:
# Landscape or Square; Attempt to set height to target (e.g., 2160), calculate width
target_height = target_pixel_p
target_width = int(target_height * ar)
# Check if width exceeds standard bounds (for ultra-wide e.g., 21:9 ARs)
if target_width > max_long_side:
target_width = max_long_side
target_height = int(target_width / ar)
else:
# Portrait; Attempt to set width to target (e.g., 2160), calculate height
target_width = target_pixel_p
target_height = int(target_width / ar)
# Check if height exceeds standard bounds
if target_height > max_long_side:
target_height = max_long_side
target_width = int(target_height * ar)
if target_width % 2 != 0:
target_width += 1
if target_height % 2 != 0:
target_height += 1
filters.append( filters.append(
topaz_api.VideoEnhancementFilter( topaz_api.VideoEnhancementFilter(
model=UPSCALER_MODELS_MAP[upscaler_model], model=UPSCALER_MODELS_MAP[upscaler_model],
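The replacement for UPSCALER_VALUES_MAP above no longer forces a square target: it sets the short side to 1080 or 2160, derives the other dimension from the source aspect ratio, caps the long side at 1920 or 3840 for ultra-wide inputs, and bumps odd dimensions up to even values. A standalone sketch of that arithmetic:

# Standalone sketch of the target-resolution arithmetic above.
def topaz_target_resolution(src_width: int, src_height: int, resolution: str) -> tuple[int, int]:
    if "1080p" in resolution:
        target_pixel_p, max_long_side = 1080, 1920
    else:
        target_pixel_p, max_long_side = 2160, 3840
    ar = src_width / src_height
    if src_width >= src_height:                      # landscape or square
        target_height = target_pixel_p
        target_width = int(target_height * ar)
        if target_width > max_long_side:             # ultra-wide, e.g. 21:9
            target_width = max_long_side
            target_height = int(target_width / ar)
    else:                                            # portrait
        target_width = target_pixel_p
        target_height = int(target_width / ar)
        if target_height > max_long_side:
            target_height = max_long_side
            target_width = int(target_height * ar)
    target_width += target_width % 2                 # encoders want even sizes
    target_height += target_height % 2
    return target_width, target_height

print(topaz_target_resolution(1280, 720, "4K (2160p)"))   # (3840, 2160)
print(topaz_target_resolution(720, 1280, "4K (2160p)"))   # (2160, 3840)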

View File

@ -102,8 +102,9 @@ class TripoTextToModelNode(IO.ComfyNode):
IO.Int.Input("model_seed", default=42, optional=True), IO.Int.Input("model_seed", default=42, optional=True),
IO.Int.Input("texture_seed", default=42, optional=True), IO.Int.Input("texture_seed", default=42, optional=True),
IO.Combo.Input("texture_quality", default="standard", options=["standard", "detailed"], optional=True), IO.Combo.Input("texture_quality", default="standard", options=["standard", "detailed"], optional=True),
IO.Int.Input("face_limit", default=-1, min=-1, max=500000, optional=True), IO.Int.Input("face_limit", default=-1, min=-1, max=2000000, optional=True),
IO.Boolean.Input("quad", default=False, optional=True), IO.Boolean.Input("quad", default=False, optional=True),
IO.Combo.Input("geometry_quality", default="standard", options=["standard", "detailed"], optional=True),
], ],
outputs=[ outputs=[
IO.String.Output(display_name="model_file"), IO.String.Output(display_name="model_file"),
@ -131,6 +132,7 @@ class TripoTextToModelNode(IO.ComfyNode):
model_seed: Optional[int] = None, model_seed: Optional[int] = None,
texture_seed: Optional[int] = None, texture_seed: Optional[int] = None,
texture_quality: Optional[str] = None, texture_quality: Optional[str] = None,
geometry_quality: Optional[str] = None,
face_limit: Optional[int] = None, face_limit: Optional[int] = None,
quad: Optional[bool] = None, quad: Optional[bool] = None,
) -> IO.NodeOutput: ) -> IO.NodeOutput:
@ -153,7 +155,8 @@ class TripoTextToModelNode(IO.ComfyNode):
model_seed=model_seed, model_seed=model_seed,
texture_seed=texture_seed, texture_seed=texture_seed,
texture_quality=texture_quality, texture_quality=texture_quality,
face_limit=face_limit, face_limit=face_limit if face_limit != -1 else None,
geometry_quality=geometry_quality,
auto_size=True, auto_size=True,
quad=quad, quad=quad,
), ),
@ -194,6 +197,7 @@ class TripoImageToModelNode(IO.ComfyNode):
), ),
IO.Int.Input("face_limit", default=-1, min=-1, max=500000, optional=True), IO.Int.Input("face_limit", default=-1, min=-1, max=500000, optional=True),
IO.Boolean.Input("quad", default=False, optional=True), IO.Boolean.Input("quad", default=False, optional=True),
IO.Combo.Input("geometry_quality", default="standard", options=["standard", "detailed"], optional=True),
], ],
outputs=[ outputs=[
IO.String.Output(display_name="model_file"), IO.String.Output(display_name="model_file"),
@ -220,6 +224,7 @@ class TripoImageToModelNode(IO.ComfyNode):
orientation=None, orientation=None,
texture_seed: Optional[int] = None, texture_seed: Optional[int] = None,
texture_quality: Optional[str] = None, texture_quality: Optional[str] = None,
geometry_quality: Optional[str] = None,
texture_alignment: Optional[str] = None, texture_alignment: Optional[str] = None,
face_limit: Optional[int] = None, face_limit: Optional[int] = None,
quad: Optional[bool] = None, quad: Optional[bool] = None,
@ -246,10 +251,11 @@ class TripoImageToModelNode(IO.ComfyNode):
pbr=pbr, pbr=pbr,
model_seed=model_seed, model_seed=model_seed,
orientation=orientation, orientation=orientation,
geometry_quality=geometry_quality,
texture_alignment=texture_alignment, texture_alignment=texture_alignment,
texture_seed=texture_seed, texture_seed=texture_seed,
texture_quality=texture_quality, texture_quality=texture_quality,
face_limit=face_limit, face_limit=face_limit if face_limit != -1 else None,
auto_size=True, auto_size=True,
quad=quad, quad=quad,
), ),
@ -295,6 +301,7 @@ class TripoMultiviewToModelNode(IO.ComfyNode):
), ),
IO.Int.Input("face_limit", default=-1, min=-1, max=500000, optional=True), IO.Int.Input("face_limit", default=-1, min=-1, max=500000, optional=True),
IO.Boolean.Input("quad", default=False, optional=True), IO.Boolean.Input("quad", default=False, optional=True),
IO.Combo.Input("geometry_quality", default="standard", options=["standard", "detailed"], optional=True),
], ],
outputs=[ outputs=[
IO.String.Output(display_name="model_file"), IO.String.Output(display_name="model_file"),
@ -323,6 +330,7 @@ class TripoMultiviewToModelNode(IO.ComfyNode):
model_seed: Optional[int] = None, model_seed: Optional[int] = None,
texture_seed: Optional[int] = None, texture_seed: Optional[int] = None,
texture_quality: Optional[str] = None, texture_quality: Optional[str] = None,
geometry_quality: Optional[str] = None,
texture_alignment: Optional[str] = None, texture_alignment: Optional[str] = None,
face_limit: Optional[int] = None, face_limit: Optional[int] = None,
quad: Optional[bool] = None, quad: Optional[bool] = None,
@ -359,8 +367,9 @@ class TripoMultiviewToModelNode(IO.ComfyNode):
model_seed=model_seed, model_seed=model_seed,
texture_seed=texture_seed, texture_seed=texture_seed,
texture_quality=texture_quality, texture_quality=texture_quality,
geometry_quality=geometry_quality,
texture_alignment=texture_alignment, texture_alignment=texture_alignment,
face_limit=face_limit, face_limit=face_limit if face_limit != -1 else None,
quad=quad, quad=quad,
), ),
) )
@ -508,6 +517,8 @@ class TripoRetargetNode(IO.ComfyNode):
options=[ options=[
"preset:idle", "preset:idle",
"preset:walk", "preset:walk",
"preset:run",
"preset:dive",
"preset:climb", "preset:climb",
"preset:jump", "preset:jump",
"preset:slash", "preset:slash",
@ -515,6 +526,11 @@ class TripoRetargetNode(IO.ComfyNode):
"preset:hurt", "preset:hurt",
"preset:fall", "preset:fall",
"preset:turn", "preset:turn",
"preset:quadruped:walk",
"preset:hexapod:walk",
"preset:octopod:walk",
"preset:serpentine:march",
"preset:aquatic:march"
], ],
), ),
], ],
@ -563,7 +579,7 @@ class TripoConversionNode(IO.ComfyNode):
"face_limit", "face_limit",
default=-1, default=-1,
min=-1, min=-1,
max=500000, max=2000000,
optional=True, optional=True,
), ),
IO.Int.Input( IO.Int.Input(
@ -579,6 +595,40 @@ class TripoConversionNode(IO.ComfyNode):
default="JPEG", default="JPEG",
optional=True, optional=True,
), ),
IO.Boolean.Input("force_symmetry", default=False, optional=True),
IO.Boolean.Input("flatten_bottom", default=False, optional=True),
IO.Float.Input(
"flatten_bottom_threshold",
default=0.0,
min=0.0,
max=1.0,
optional=True,
),
IO.Boolean.Input("pivot_to_center_bottom", default=False, optional=True),
IO.Float.Input(
"scale_factor",
default=1.0,
min=0.0,
optional=True,
),
IO.Boolean.Input("with_animation", default=False, optional=True),
IO.Boolean.Input("pack_uv", default=False, optional=True),
IO.Boolean.Input("bake", default=False, optional=True),
IO.String.Input("part_names", default="", optional=True), # comma-separated list
IO.Combo.Input(
"fbx_preset",
options=["blender", "mixamo", "3dsmax"],
default="blender",
optional=True,
),
IO.Boolean.Input("export_vertex_colors", default=False, optional=True),
IO.Combo.Input(
"export_orientation",
options=["align_image", "default"],
default="default",
optional=True,
),
IO.Boolean.Input("animate_in_place", default=False, optional=True),
], ],
outputs=[], outputs=[],
hidden=[ hidden=[
@ -604,12 +654,31 @@ class TripoConversionNode(IO.ComfyNode):
original_model_task_id, original_model_task_id,
format: str, format: str,
quad: bool, quad: bool,
force_symmetry: bool,
face_limit: int, face_limit: int,
flatten_bottom: bool,
flatten_bottom_threshold: float,
texture_size: int, texture_size: int,
texture_format: str, texture_format: str,
pivot_to_center_bottom: bool,
scale_factor: float,
with_animation: bool,
pack_uv: bool,
bake: bool,
part_names: str,
fbx_preset: str,
export_vertex_colors: bool,
export_orientation: str,
animate_in_place: bool,
) -> IO.NodeOutput: ) -> IO.NodeOutput:
if not original_model_task_id: if not original_model_task_id:
raise RuntimeError("original_model_task_id is required") raise RuntimeError("original_model_task_id is required")
# Parse part_names from comma-separated string to list
part_names_list = None
if part_names and part_names.strip():
part_names_list = [name.strip() for name in part_names.split(',') if name.strip()]
response = await sync_op( response = await sync_op(
cls, cls,
endpoint=ApiEndpoint(path="/proxy/tripo/v2/openapi/task", method="POST"), endpoint=ApiEndpoint(path="/proxy/tripo/v2/openapi/task", method="POST"),
@ -618,9 +687,22 @@ class TripoConversionNode(IO.ComfyNode):
original_model_task_id=original_model_task_id, original_model_task_id=original_model_task_id,
format=format, format=format,
quad=quad if quad else None, quad=quad if quad else None,
force_symmetry=force_symmetry if force_symmetry else None,
face_limit=face_limit if face_limit != -1 else None, face_limit=face_limit if face_limit != -1 else None,
flatten_bottom=flatten_bottom if flatten_bottom else None,
flatten_bottom_threshold=flatten_bottom_threshold if flatten_bottom_threshold != 0.0 else None,
texture_size=texture_size if texture_size != 4096 else None, texture_size=texture_size if texture_size != 4096 else None,
texture_format=texture_format if texture_format != "JPEG" else None, texture_format=texture_format if texture_format != "JPEG" else None,
pivot_to_center_bottom=pivot_to_center_bottom if pivot_to_center_bottom else None,
scale_factor=scale_factor if scale_factor != 1.0 else None,
with_animation=with_animation if with_animation else None,
pack_uv=pack_uv if pack_uv else None,
bake=bake if bake else None,
part_names=part_names_list,
fbx_preset=fbx_preset if fbx_preset != "blender" else None,
export_vertex_colors=export_vertex_colors if export_vertex_colors else None,
export_orientation=export_orientation if export_orientation != "default" else None,
animate_in_place=animate_in_place if animate_in_place else None,
), ),
) )
return await poll_until_finished(cls, response, average_duration=30) return await poll_until_finished(cls, response, average_duration=30)
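Two small patterns in the conversion request above are easy to miss in the diff: part_names arrives as a comma-separated string and is split into a cleaned list, and every optional field still at its UI default is sent as None so the payload only carries values the user actually changed. A hedged sketch of both:

# Sketch of the two request-building patterns above: parsing part_names and
# dropping fields that are still at their UI defaults.
def parse_part_names(part_names: str) -> list[str] | None:
    if not part_names or not part_names.strip():
        return None
    return [name.strip() for name in part_names.split(",") if name.strip()]

def omit_defaults(values: dict, defaults: dict) -> dict:
    return {k: v for k, v in values.items() if v != defaults.get(k)}

print(parse_part_names(" head, torso,, legs "))   # ['head', 'torso', 'legs']
print(omit_defaults(
    {"face_limit": -1, "scale_factor": 2.0, "fbx_preset": "blender"},
    {"face_limit": -1, "scale_factor": 1.0, "fbx_preset": "blender"},
))                                                # {'scale_factor': 2.0}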

View File

@ -1,11 +1,9 @@
import base64 import base64
from io import BytesIO from io import BytesIO
import torch
from typing_extensions import override from typing_extensions import override
from comfy_api.input_impl.video_types import VideoFromFile from comfy_api.latest import IO, ComfyExtension, Input, InputImpl
from comfy_api.latest import IO, ComfyExtension
from comfy_api_nodes.apis.veo_api import ( from comfy_api_nodes.apis.veo_api import (
VeoGenVidPollRequest, VeoGenVidPollRequest,
VeoGenVidPollResponse, VeoGenVidPollResponse,
@ -170,6 +168,8 @@ class VeoVideoGenerationNode(IO.ComfyNode):
# Only add generateAudio for Veo 3 models # Only add generateAudio for Veo 3 models
if model.find("veo-2.0") == -1: if model.find("veo-2.0") == -1:
parameters["generateAudio"] = generate_audio parameters["generateAudio"] = generate_audio
# force "enhance_prompt" to True for Veo3 models
parameters["enhancePrompt"] = True
initial_response = await sync_op( initial_response = await sync_op(
cls, cls,
@ -232,7 +232,7 @@ class VeoVideoGenerationNode(IO.ComfyNode):
# Check if video is provided as base64 or URL # Check if video is provided as base64 or URL
if hasattr(video, "bytesBase64Encoded") and video.bytesBase64Encoded: if hasattr(video, "bytesBase64Encoded") and video.bytesBase64Encoded:
return IO.NodeOutput(VideoFromFile(BytesIO(base64.b64decode(video.bytesBase64Encoded)))) return IO.NodeOutput(InputImpl.VideoFromFile(BytesIO(base64.b64decode(video.bytesBase64Encoded))))
if hasattr(video, "gcsUri") and video.gcsUri: if hasattr(video, "gcsUri") and video.gcsUri:
return IO.NodeOutput(await download_url_to_video_output(video.gcsUri)) return IO.NodeOutput(await download_url_to_video_output(video.gcsUri))
@ -293,7 +293,7 @@ class Veo3VideoGenerationNode(VeoVideoGenerationNode):
IO.Boolean.Input( IO.Boolean.Input(
"enhance_prompt", "enhance_prompt",
default=True, default=True,
tooltip="Whether to enhance the prompt with AI assistance", tooltip="This parameter is deprecated and ignored.",
optional=True, optional=True,
), ),
IO.Combo.Input( IO.Combo.Input(
@ -431,8 +431,8 @@ class Veo3FirstLastFrameNode(IO.ComfyNode):
aspect_ratio: str, aspect_ratio: str,
duration: int, duration: int,
seed: int, seed: int,
first_frame: torch.Tensor, first_frame: Input.Image,
last_frame: torch.Tensor, last_frame: Input.Image,
model: str, model: str,
generate_audio: bool, generate_audio: bool,
): ):
@ -493,7 +493,7 @@ class Veo3FirstLastFrameNode(IO.ComfyNode):
if response.videos: if response.videos:
video = response.videos[0] video = response.videos[0]
if video.bytesBase64Encoded: if video.bytesBase64Encoded:
return IO.NodeOutput(VideoFromFile(BytesIO(base64.b64decode(video.bytesBase64Encoded)))) return IO.NodeOutput(InputImpl.VideoFromFile(BytesIO(base64.b64decode(video.bytesBase64Encoded))))
if video.gcsUri: if video.gcsUri:
return IO.NodeOutput(await download_url_to_video_output(video.gcsUri)) return IO.NodeOutput(await download_url_to_video_output(video.gcsUri))
raise Exception("Video returned but no data or URL was provided") raise Exception("Video returned but no data or URL was provided")

View File

@ -1,7 +1,5 @@
import re import re
from typing import Optional
import torch
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from typing_extensions import override from typing_extensions import override
@ -15,32 +13,40 @@ from comfy_api_nodes.util import (
poll_op, poll_op,
sync_op, sync_op,
tensor_to_base64_string, tensor_to_base64_string,
upload_video_to_comfyapi,
validate_audio_duration, validate_audio_duration,
validate_video_duration,
) )
class Text2ImageInputField(BaseModel): class Text2ImageInputField(BaseModel):
prompt: str = Field(...) prompt: str = Field(...)
negative_prompt: Optional[str] = Field(None) negative_prompt: str | None = Field(None)
class Image2ImageInputField(BaseModel): class Image2ImageInputField(BaseModel):
prompt: str = Field(...) prompt: str = Field(...)
negative_prompt: Optional[str] = Field(None) negative_prompt: str | None = Field(None)
images: list[str] = Field(..., min_length=1, max_length=2) images: list[str] = Field(..., min_length=1, max_length=2)
class Text2VideoInputField(BaseModel): class Text2VideoInputField(BaseModel):
prompt: str = Field(...) prompt: str = Field(...)
negative_prompt: Optional[str] = Field(None) negative_prompt: str | None = Field(None)
audio_url: Optional[str] = Field(None) audio_url: str | None = Field(None)
class Image2VideoInputField(BaseModel): class Image2VideoInputField(BaseModel):
prompt: str = Field(...) prompt: str = Field(...)
negative_prompt: Optional[str] = Field(None) negative_prompt: str | None = Field(None)
img_url: str = Field(...) img_url: str = Field(...)
audio_url: Optional[str] = Field(None) audio_url: str | None = Field(None)
class Reference2VideoInputField(BaseModel):
prompt: str = Field(...)
negative_prompt: str | None = Field(None)
reference_video_urls: list[str] = Field(...)
class Txt2ImageParametersField(BaseModel): class Txt2ImageParametersField(BaseModel):
@ -48,32 +54,42 @@ class Txt2ImageParametersField(BaseModel):
n: int = Field(1, description="Number of images to generate.") # we support only value=1 n: int = Field(1, description="Number of images to generate.") # we support only value=1
seed: int = Field(..., ge=0, le=2147483647) seed: int = Field(..., ge=0, le=2147483647)
prompt_extend: bool = Field(True) prompt_extend: bool = Field(True)
watermark: bool = Field(True) watermark: bool = Field(False)
class Image2ImageParametersField(BaseModel): class Image2ImageParametersField(BaseModel):
size: Optional[str] = Field(None) size: str | None = Field(None)
n: int = Field(1, description="Number of images to generate.") # we support only value=1 n: int = Field(1, description="Number of images to generate.") # we support only value=1
seed: int = Field(..., ge=0, le=2147483647) seed: int = Field(..., ge=0, le=2147483647)
watermark: bool = Field(True) watermark: bool = Field(False)
class Text2VideoParametersField(BaseModel): class Text2VideoParametersField(BaseModel):
size: str = Field(...) size: str = Field(...)
seed: int = Field(..., ge=0, le=2147483647) seed: int = Field(..., ge=0, le=2147483647)
duration: int = Field(5, ge=5, le=10) duration: int = Field(5, ge=5, le=15)
prompt_extend: bool = Field(True) prompt_extend: bool = Field(True)
watermark: bool = Field(True) watermark: bool = Field(False)
audio: bool = Field(False, description="Should be audio generated automatically") audio: bool = Field(False, description="Whether to generate audio automatically.")
shot_type: str = Field("single")
class Image2VideoParametersField(BaseModel): class Image2VideoParametersField(BaseModel):
resolution: str = Field(...) resolution: str = Field(...)
seed: int = Field(..., ge=0, le=2147483647) seed: int = Field(..., ge=0, le=2147483647)
duration: int = Field(5, ge=5, le=10) duration: int = Field(5, ge=5, le=15)
prompt_extend: bool = Field(True) prompt_extend: bool = Field(True)
watermark: bool = Field(True) watermark: bool = Field(False)
audio: bool = Field(False, description="Should be audio generated automatically") audio: bool = Field(False, description="Whether to generate audio automatically.")
shot_type: str = Field("single")
class Reference2VideoParametersField(BaseModel):
size: str = Field(...)
duration: int = Field(5, ge=5, le=15)
shot_type: str = Field("single")
seed: int = Field(..., ge=0, le=2147483647)
watermark: bool = Field(False)
class Text2ImageTaskCreationRequest(BaseModel): class Text2ImageTaskCreationRequest(BaseModel):
@ -100,45 +116,51 @@ class Image2VideoTaskCreationRequest(BaseModel):
parameters: Image2VideoParametersField = Field(...) parameters: Image2VideoParametersField = Field(...)
class Reference2VideoTaskCreationRequest(BaseModel):
model: str = Field(...)
input: Reference2VideoInputField = Field(...)
parameters: Reference2VideoParametersField = Field(...)
class TaskCreationOutputField(BaseModel): class TaskCreationOutputField(BaseModel):
task_id: str = Field(...) task_id: str = Field(...)
task_status: str = Field(...) task_status: str = Field(...)
class TaskCreationResponse(BaseModel): class TaskCreationResponse(BaseModel):
output: Optional[TaskCreationOutputField] = Field(None) output: TaskCreationOutputField | None = Field(None)
request_id: str = Field(...) request_id: str = Field(...)
code: Optional[str] = Field(None, description="The error code of the failed request.") code: str | None = Field(None, description="Error code for the failed request.")
message: Optional[str] = Field(None, description="Details of the failed request.") message: str | None = Field(None, description="Details about the failed request.")
class TaskResult(BaseModel): class TaskResult(BaseModel):
url: Optional[str] = Field(None) url: str | None = Field(None)
code: Optional[str] = Field(None) code: str | None = Field(None)
message: Optional[str] = Field(None) message: str | None = Field(None)
class ImageTaskStatusOutputField(TaskCreationOutputField): class ImageTaskStatusOutputField(TaskCreationOutputField):
task_id: str = Field(...) task_id: str = Field(...)
task_status: str = Field(...) task_status: str = Field(...)
results: Optional[list[TaskResult]] = Field(None) results: list[TaskResult] | None = Field(None)
class VideoTaskStatusOutputField(TaskCreationOutputField): class VideoTaskStatusOutputField(TaskCreationOutputField):
task_id: str = Field(...) task_id: str = Field(...)
task_status: str = Field(...) task_status: str = Field(...)
video_url: Optional[str] = Field(None) video_url: str | None = Field(None)
code: Optional[str] = Field(None) code: str | None = Field(None)
message: Optional[str] = Field(None) message: str | None = Field(None)
class ImageTaskStatusResponse(BaseModel): class ImageTaskStatusResponse(BaseModel):
output: Optional[ImageTaskStatusOutputField] = Field(None) output: ImageTaskStatusOutputField | None = Field(None)
request_id: str = Field(...) request_id: str = Field(...)
class VideoTaskStatusResponse(BaseModel): class VideoTaskStatusResponse(BaseModel):
output: Optional[VideoTaskStatusOutputField] = Field(None) output: VideoTaskStatusOutputField | None = Field(None)
request_id: str = Field(...) request_id: str = Field(...)
@ -152,7 +174,7 @@ class WanTextToImageApi(IO.ComfyNode):
node_id="WanTextToImageApi", node_id="WanTextToImageApi",
display_name="Wan Text to Image", display_name="Wan Text to Image",
category="api node/image/Wan", category="api node/image/Wan",
description="Generates image based on text prompt.", description="Generates an image based on a text prompt.",
inputs=[ inputs=[
IO.Combo.Input( IO.Combo.Input(
"model", "model",
@ -164,13 +186,13 @@ class WanTextToImageApi(IO.ComfyNode):
"prompt", "prompt",
multiline=True, multiline=True,
default="", default="",
tooltip="Prompt used to describe the elements and visual features, supports English/Chinese.", tooltip="Prompt describing the elements and visual features. Supports English and Chinese.",
), ),
IO.String.Input( IO.String.Input(
"negative_prompt", "negative_prompt",
multiline=True, multiline=True,
default="", default="",
tooltip="Negative text prompt to guide what to avoid.", tooltip="Negative prompt describing what to avoid.",
optional=True, optional=True,
), ),
IO.Int.Input( IO.Int.Input(
@ -208,8 +230,8 @@ class WanTextToImageApi(IO.ComfyNode):
), ),
IO.Boolean.Input( IO.Boolean.Input(
"watermark", "watermark",
default=True, default=False,
tooltip='Whether to add an "AI generated" watermark to the result.', tooltip="Whether to add an AI-generated watermark to the result.",
optional=True, optional=True,
), ),
], ],
@ -234,7 +256,7 @@ class WanTextToImageApi(IO.ComfyNode):
height: int = 1024, height: int = 1024,
seed: int = 0, seed: int = 0,
prompt_extend: bool = True, prompt_extend: bool = True,
watermark: bool = True, watermark: bool = False,
): ):
initial_response = await sync_op( initial_response = await sync_op(
cls, cls,
@ -252,7 +274,7 @@ class WanTextToImageApi(IO.ComfyNode):
), ),
) )
if not initial_response.output: if not initial_response.output:
raise Exception(f"Unknown error occurred: {initial_response.code} - {initial_response.message}") raise Exception(f"An unknown error occurred: {initial_response.code} - {initial_response.message}")
response = await poll_op( response = await poll_op(
cls, cls,
ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"), ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"),
@ -272,7 +294,7 @@ class WanImageToImageApi(IO.ComfyNode):
display_name="Wan Image to Image", display_name="Wan Image to Image",
category="api node/image/Wan", category="api node/image/Wan",
description="Generates an image from one or two input images and a text prompt. " description="Generates an image from one or two input images and a text prompt. "
"The output image is currently fixed at 1.6 MP; its aspect ratio matches the input image(s).", "The output image is currently fixed at 1.6 MP, and its aspect ratio matches the input image(s).",
inputs=[ inputs=[
IO.Combo.Input( IO.Combo.Input(
"model", "model",
@ -282,19 +304,19 @@ class WanImageToImageApi(IO.ComfyNode):
), ),
IO.Image.Input( IO.Image.Input(
"image", "image",
tooltip="Single-image editing or multi-image fusion, maximum 2 images.", tooltip="Single-image editing or multi-image fusion. Maximum 2 images.",
), ),
IO.String.Input( IO.String.Input(
"prompt", "prompt",
multiline=True, multiline=True,
default="", default="",
tooltip="Prompt used to describe the elements and visual features, supports English/Chinese.", tooltip="Prompt describing the elements and visual features. Supports English and Chinese.",
), ),
IO.String.Input( IO.String.Input(
"negative_prompt", "negative_prompt",
multiline=True, multiline=True,
default="", default="",
tooltip="Negative text prompt to guide what to avoid.", tooltip="Negative prompt describing what to avoid.",
optional=True, optional=True,
), ),
# redo this later as an optional combo of recommended resolutions # redo this later as an optional combo of recommended resolutions
@ -327,8 +349,8 @@ class WanImageToImageApi(IO.ComfyNode):
), ),
IO.Boolean.Input( IO.Boolean.Input(
"watermark", "watermark",
default=True, default=False,
tooltip='Whether to add an "AI generated" watermark to the result.', tooltip="Whether to add an AI-generated watermark to the result.",
optional=True, optional=True,
), ),
], ],
@ -347,17 +369,17 @@ class WanImageToImageApi(IO.ComfyNode):
async def execute( async def execute(
cls, cls,
model: str, model: str,
image: torch.Tensor, image: Input.Image,
prompt: str, prompt: str,
negative_prompt: str = "", negative_prompt: str = "",
# width: int = 1024, # width: int = 1024,
# height: int = 1024, # height: int = 1024,
seed: int = 0, seed: int = 0,
watermark: bool = True, watermark: bool = False,
): ):
n_images = get_number_of_images(image) n_images = get_number_of_images(image)
if n_images not in (1, 2): if n_images not in (1, 2):
raise ValueError(f"Expected 1 or 2 input images, got {n_images}.") raise ValueError(f"Expected 1 or 2 input images, but got {n_images}.")
images = [] images = []
for i in image: for i in image:
images.append("data:image/png;base64," + tensor_to_base64_string(i, total_pixels=4096 * 4096)) images.append("data:image/png;base64," + tensor_to_base64_string(i, total_pixels=4096 * 4096))
@ -376,7 +398,7 @@ class WanImageToImageApi(IO.ComfyNode):
), ),
) )
if not initial_response.output: if not initial_response.output:
raise Exception(f"Unknown error occurred: {initial_response.code} - {initial_response.message}") raise Exception(f"An unknown error occurred: {initial_response.code} - {initial_response.message}")
response = await poll_op( response = await poll_op(
cls, cls,
ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"), ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"),
@ -395,25 +417,25 @@ class WanTextToVideoApi(IO.ComfyNode):
node_id="WanTextToVideoApi", node_id="WanTextToVideoApi",
display_name="Wan Text to Video", display_name="Wan Text to Video",
category="api node/video/Wan", category="api node/video/Wan",
description="Generates video based on text prompt.", description="Generates a video based on a text prompt.",
inputs=[ inputs=[
IO.Combo.Input( IO.Combo.Input(
"model", "model",
options=["wan2.5-t2v-preview"], options=["wan2.5-t2v-preview", "wan2.6-t2v"],
default="wan2.5-t2v-preview", default="wan2.6-t2v",
tooltip="Model to use.", tooltip="Model to use.",
), ),
IO.String.Input( IO.String.Input(
"prompt", "prompt",
multiline=True, multiline=True,
default="", default="",
tooltip="Prompt used to describe the elements and visual features, supports English/Chinese.", tooltip="Prompt describing the elements and visual features. Supports English and Chinese.",
), ),
IO.String.Input( IO.String.Input(
"negative_prompt", "negative_prompt",
multiline=True, multiline=True,
default="", default="",
tooltip="Negative text prompt to guide what to avoid.", tooltip="Negative prompt describing what to avoid.",
optional=True, optional=True,
), ),
IO.Combo.Input( IO.Combo.Input(
@ -433,23 +455,23 @@ class WanTextToVideoApi(IO.ComfyNode):
"1080p: 4:3 (1632x1248)", "1080p: 4:3 (1632x1248)",
"1080p: 3:4 (1248x1632)", "1080p: 3:4 (1248x1632)",
], ],
default="480p: 1:1 (624x624)", default="720p: 1:1 (960x960)",
optional=True, optional=True,
), ),
IO.Int.Input( IO.Int.Input(
"duration", "duration",
default=5, default=5,
min=5, min=5,
max=10, max=15,
step=5, step=5,
display_mode=IO.NumberDisplay.number, display_mode=IO.NumberDisplay.number,
tooltip="Available durations: 5 and 10 seconds", tooltip="A 15-second duration is available only for the Wan 2.6 model.",
optional=True, optional=True,
), ),
IO.Audio.Input( IO.Audio.Input(
"audio", "audio",
optional=True, optional=True,
tooltip="Audio must contain a clear, loud voice, without extraneous noise, background music.", tooltip="Audio must contain a clear, loud voice, without extraneous noise or background music.",
), ),
IO.Int.Input( IO.Int.Input(
"seed", "seed",
@ -466,7 +488,7 @@ class WanTextToVideoApi(IO.ComfyNode):
"generate_audio", "generate_audio",
default=False, default=False,
optional=True, optional=True,
tooltip="If there is no audio input, generate audio automatically.", tooltip="If no audio input is provided, generate audio automatically.",
), ),
IO.Boolean.Input( IO.Boolean.Input(
"prompt_extend", "prompt_extend",
@ -476,8 +498,16 @@ class WanTextToVideoApi(IO.ComfyNode):
), ),
IO.Boolean.Input( IO.Boolean.Input(
"watermark", "watermark",
default=True, default=False,
tooltip='Whether to add an "AI generated" watermark to the result.', tooltip="Whether to add an AI-generated watermark to the result.",
optional=True,
),
IO.Combo.Input(
"shot_type",
options=["single", "multi"],
tooltip="Specifies the shot type for the generated video, that is, whether the video is a "
"single continuous shot or multiple shots with cuts. "
"This parameter takes effect only when prompt_extend is True.",
optional=True, optional=True,
), ),
], ],
@ -498,14 +528,19 @@ class WanTextToVideoApi(IO.ComfyNode):
model: str, model: str,
prompt: str, prompt: str,
negative_prompt: str = "", negative_prompt: str = "",
size: str = "480p: 1:1 (624x624)", size: str = "720p: 1:1 (960x960)",
duration: int = 5, duration: int = 5,
audio: Optional[Input.Audio] = None, audio: Input.Audio | None = None,
seed: int = 0, seed: int = 0,
generate_audio: bool = False, generate_audio: bool = False,
prompt_extend: bool = True, prompt_extend: bool = True,
watermark: bool = True, watermark: bool = False,
shot_type: str = "single",
): ):
if "480p" in size and model == "wan2.6-t2v":
raise ValueError("The Wan 2.6 model does not support 480p.")
if duration == 15 and model == "wan2.5-t2v-preview":
raise ValueError("A 15-second duration is supported only by the Wan 2.6 model.")
width, height = RES_IN_PARENS.search(size).groups() width, height = RES_IN_PARENS.search(size).groups()
audio_url = None audio_url = None
if audio is not None: if audio is not None:
@ -526,11 +561,12 @@ class WanTextToVideoApi(IO.ComfyNode):
audio=generate_audio, audio=generate_audio,
prompt_extend=prompt_extend, prompt_extend=prompt_extend,
watermark=watermark, watermark=watermark,
shot_type=shot_type,
), ),
), ),
) )
if not initial_response.output: if not initial_response.output:
raise Exception(f"Unknown error occurred: {initial_response.code} - {initial_response.message}") raise Exception(f"An unknown error occurred: {initial_response.code} - {initial_response.message}")
response = await poll_op( response = await poll_op(
cls, cls,
ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"), ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"),
@ -549,12 +585,12 @@ class WanImageToVideoApi(IO.ComfyNode):
node_id="WanImageToVideoApi", node_id="WanImageToVideoApi",
display_name="Wan Image to Video", display_name="Wan Image to Video",
category="api node/video/Wan", category="api node/video/Wan",
description="Generates video based on the first frame and text prompt.", description="Generates a video from the first frame and a text prompt.",
inputs=[ inputs=[
IO.Combo.Input( IO.Combo.Input(
"model", "model",
options=["wan2.5-i2v-preview"], options=["wan2.5-i2v-preview", "wan2.6-i2v"],
default="wan2.5-i2v-preview", default="wan2.6-i2v",
tooltip="Model to use.", tooltip="Model to use.",
), ),
IO.Image.Input( IO.Image.Input(
@ -564,13 +600,13 @@ class WanImageToVideoApi(IO.ComfyNode):
"prompt", "prompt",
multiline=True, multiline=True,
default="", default="",
tooltip="Prompt used to describe the elements and visual features, supports English/Chinese.", tooltip="Prompt describing the elements and visual features. Supports English and Chinese.",
), ),
IO.String.Input( IO.String.Input(
"negative_prompt", "negative_prompt",
multiline=True, multiline=True,
default="", default="",
tooltip="Negative text prompt to guide what to avoid.", tooltip="Negative prompt describing what to avoid.",
optional=True, optional=True,
), ),
IO.Combo.Input( IO.Combo.Input(
@ -580,23 +616,23 @@ class WanImageToVideoApi(IO.ComfyNode):
"720P", "720P",
"1080P", "1080P",
], ],
default="480P", default="720P",
optional=True, optional=True,
), ),
IO.Int.Input( IO.Int.Input(
"duration", "duration",
default=5, default=5,
min=5, min=5,
max=10, max=15,
step=5, step=5,
display_mode=IO.NumberDisplay.number, display_mode=IO.NumberDisplay.number,
tooltip="Available durations: 5 and 10 seconds", tooltip="Duration 15 available only for WAN2.6 model.",
optional=True, optional=True,
), ),
IO.Audio.Input( IO.Audio.Input(
"audio", "audio",
optional=True, optional=True,
tooltip="Audio must contain a clear, loud voice, without extraneous noise, background music.", tooltip="Audio must contain a clear, loud voice, without extraneous noise or background music.",
), ),
IO.Int.Input( IO.Int.Input(
"seed", "seed",
@ -613,7 +649,7 @@ class WanImageToVideoApi(IO.ComfyNode):
"generate_audio", "generate_audio",
default=False, default=False,
optional=True, optional=True,
tooltip="If there is no audio input, generate audio automatically.", tooltip="If no audio input is provided, generate audio automatically.",
), ),
IO.Boolean.Input( IO.Boolean.Input(
"prompt_extend", "prompt_extend",
@ -623,8 +659,16 @@ class WanImageToVideoApi(IO.ComfyNode):
), ),
IO.Boolean.Input( IO.Boolean.Input(
"watermark", "watermark",
default=True, default=False,
tooltip='Whether to add an "AI generated" watermark to the result.', tooltip="Whether to add an AI-generated watermark to the result.",
optional=True,
),
IO.Combo.Input(
"shot_type",
options=["single", "multi"],
tooltip="Specifies the shot type for the generated video, that is, whether the video is a "
"single continuous shot or multiple shots with cuts. "
"This parameter takes effect only when prompt_extend is True.",
optional=True, optional=True,
), ),
], ],
@ -643,19 +687,24 @@ class WanImageToVideoApi(IO.ComfyNode):
async def execute( async def execute(
cls, cls,
model: str, model: str,
image: torch.Tensor, image: Input.Image,
prompt: str, prompt: str,
negative_prompt: str = "", negative_prompt: str = "",
resolution: str = "480P", resolution: str = "720P",
duration: int = 5, duration: int = 5,
audio: Optional[Input.Audio] = None, audio: Input.Audio | None = None,
seed: int = 0, seed: int = 0,
generate_audio: bool = False, generate_audio: bool = False,
prompt_extend: bool = True, prompt_extend: bool = True,
watermark: bool = True, watermark: bool = False,
shot_type: str = "single",
): ):
if get_number_of_images(image) != 1: if get_number_of_images(image) != 1:
raise ValueError("Exactly one input image is required.") raise ValueError("Exactly one input image is required.")
if "480P" in resolution and model == "wan2.6-i2v":
raise ValueError("The Wan 2.6 model does not support 480P.")
if duration == 15 and model == "wan2.5-i2v-preview":
raise ValueError("A 15-second duration is supported only by the Wan 2.6 model.")
image_url = "data:image/png;base64," + tensor_to_base64_string(image, total_pixels=2000 * 2000) image_url = "data:image/png;base64," + tensor_to_base64_string(image, total_pixels=2000 * 2000)
audio_url = None audio_url = None
if audio is not None: if audio is not None:
@ -677,11 +726,12 @@ class WanImageToVideoApi(IO.ComfyNode):
audio=generate_audio, audio=generate_audio,
prompt_extend=prompt_extend, prompt_extend=prompt_extend,
watermark=watermark, watermark=watermark,
shot_type=shot_type,
), ),
), ),
) )
if not initial_response.output: if not initial_response.output:
raise Exception(f"Unknown error occurred: {initial_response.code} - {initial_response.message}") raise Exception(f"An unknown error occurred: {initial_response.code} - {initial_response.message}")
response = await poll_op( response = await poll_op(
cls, cls,
ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"), ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"),
@ -693,6 +743,143 @@ class WanImageToVideoApi(IO.ComfyNode):
return IO.NodeOutput(await download_url_to_video_output(response.output.video_url)) return IO.NodeOutput(await download_url_to_video_output(response.output.video_url))
class WanReferenceVideoApi(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="WanReferenceVideoApi",
display_name="Wan Reference to Video",
category="api node/video/Wan",
description="Use the character and voice from input videos, combined with a prompt, "
"to generate a new video that maintains character consistency.",
inputs=[
IO.Combo.Input("model", options=["wan2.6-r2v"]),
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Prompt describing the elements and visual features. Supports English and Chinese. "
"Use identifiers such as `character1` and `character2` to refer to the reference characters.",
),
IO.String.Input(
"negative_prompt",
multiline=True,
default="",
tooltip="Negative prompt describing what to avoid.",
),
IO.Autogrow.Input(
"reference_videos",
template=IO.Autogrow.TemplateNames(
IO.Video.Input("reference_video"),
names=["character1", "character2", "character3"],
min=1,
),
),
IO.Combo.Input(
"size",
options=[
"720p: 1:1 (960x960)",
"720p: 16:9 (1280x720)",
"720p: 9:16 (720x1280)",
"720p: 4:3 (1088x832)",
"720p: 3:4 (832x1088)",
"1080p: 1:1 (1440x1440)",
"1080p: 16:9 (1920x1080)",
"1080p: 9:16 (1080x1920)",
"1080p: 4:3 (1632x1248)",
"1080p: 3:4 (1248x1632)",
],
),
IO.Int.Input(
"duration",
default=5,
min=5,
max=10,
step=5,
display_mode=IO.NumberDisplay.slider,
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
step=1,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
),
IO.Combo.Input(
"shot_type",
options=["single", "multi"],
tooltip="Specifies the shot type for the generated video, that is, whether the video is a "
"single continuous shot or multiple shots with cuts.",
),
IO.Boolean.Input(
"watermark",
default=False,
tooltip="Whether to add an AI-generated watermark to the result.",
),
],
outputs=[
IO.Video.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
model: str,
prompt: str,
negative_prompt: str,
reference_videos: IO.Autogrow.Type,
size: str,
duration: int,
seed: int,
shot_type: str,
watermark: bool,
):
reference_video_urls = []
for i in reference_videos:
validate_video_duration(reference_videos[i], min_duration=2, max_duration=30)
for i in reference_videos:
reference_video_urls.append(await upload_video_to_comfyapi(cls, reference_videos[i]))
width, height = RES_IN_PARENS.search(size).groups()
initial_response = await sync_op(
cls,
ApiEndpoint(path="/proxy/wan/api/v1/services/aigc/video-generation/video-synthesis", method="POST"),
response_model=TaskCreationResponse,
data=Reference2VideoTaskCreationRequest(
model=model,
input=Reference2VideoInputField(
prompt=prompt, negative_prompt=negative_prompt, reference_video_urls=reference_video_urls
),
parameters=Reference2VideoParametersField(
size=f"{width}*{height}",
duration=duration,
shot_type=shot_type,
watermark=watermark,
seed=seed,
),
),
)
if not initial_response.output:
raise Exception(f"An unknown error occurred: {initial_response.code} - {initial_response.message}")
response = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"),
response_model=VideoTaskStatusResponse,
status_extractor=lambda x: x.output.task_status,
poll_interval=6,
max_poll_attempts=280,
)
return IO.NodeOutput(await download_url_to_video_output(response.output.video_url))
class WanApiExtension(ComfyExtension): class WanApiExtension(ComfyExtension):
@override @override
async def get_node_list(self) -> list[type[IO.ComfyNode]]: async def get_node_list(self) -> list[type[IO.ComfyNode]]:
@ -701,6 +888,7 @@ class WanApiExtension(ComfyExtension):
WanImageToImageApi, WanImageToImageApi,
WanTextToVideoApi, WanTextToVideoApi,
WanImageToVideoApi, WanImageToVideoApi,
WanReferenceVideoApi,
] ]

View File

@ -1,16 +1,22 @@
import asyncio import asyncio
import contextlib import contextlib
import os import os
import re
import time import time
from collections.abc import Callable from collections.abc import Callable
from io import BytesIO from io import BytesIO
from yarl import URL
from comfy.cli_args import args from comfy.cli_args import args
from comfy.model_management import processing_interrupted from comfy.model_management import processing_interrupted
from comfy_api.latest import IO from comfy_api.latest import IO
from .common_exceptions import ProcessingInterrupted from .common_exceptions import ProcessingInterrupted
_HAS_PCT_ESC = re.compile(r"%[0-9A-Fa-f]{2}") # any % followed by 2 hex digits
_HAS_BAD_PCT = re.compile(r"%(?![0-9A-Fa-f]{2})") # any % not followed by 2 hex digits
def is_processing_interrupted() -> bool: def is_processing_interrupted() -> bool:
"""Return True if user/runtime requested interruption.""" """Return True if user/runtime requested interruption."""
@ -69,3 +75,17 @@ def get_fs_object_size(path_or_object: str | BytesIO) -> int:
if isinstance(path_or_object, str): if isinstance(path_or_object, str):
return os.path.getsize(path_or_object) return os.path.getsize(path_or_object)
return len(path_or_object.getvalue()) return len(path_or_object.getvalue())
def to_aiohttp_url(url: str) -> URL:
"""If `url` appears to be already percent-encoded (contains at least one valid %HH
escape and no malformed '%' sequences) and contains no raw whitespace/control
characters, preserve the original encoding byte-for-byte (important for signed/presigned URLs).
Otherwise, return `URL(url)` and allow yarl to normalize/quote as needed."""
if any(c.isspace() for c in url) or any(ord(c) < 0x20 for c in url):
# Avoid encoded=True if URL contains raw whitespace/control chars
return URL(url)
if _HAS_PCT_ESC.search(url) and not _HAS_BAD_PCT.search(url):
# Preserve encoding only if it appears pre-encoded AND has no invalid % sequences
return URL(url, encoded=True)
return URL(url)
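A minimal usage sketch for this helper (the URLs are hypothetical and the comments only note which branch each input takes; yarl's own quoting behaviour is not asserted here):
signed = "https://bucket.example.com/file?sig=ab%2Fcd"  # valid %HH escape, no malformed '%' -> URL(url, encoded=True), bytes preserved
spaced = "https://example.com/a b"                       # raw whitespace -> URL(url)
broken = "https://example.com/100%off"                   # '%of' is not a valid escape -> URL(url)
to_aiohttp_url(signed)
to_aiohttp_url(spaced)
to_aiohttp_url(broken)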

View File

@ -430,9 +430,9 @@ def _display_text(
if status: if status:
display_lines.append(f"Status: {status.capitalize() if isinstance(status, str) else status}") display_lines.append(f"Status: {status.capitalize() if isinstance(status, str) else status}")
if price is not None: if price is not None:
p = f"{float(price):,.4f}".rstrip("0").rstrip(".") p = f"{float(price) * 211:,.1f}".rstrip("0").rstrip(".")
if p != "0": if p != "0":
display_lines.append(f"Price: ${p}") display_lines.append(f"Price: {p} credits")
if text is not None: if text is not None:
display_lines.append(text) display_lines.append(text)
if display_lines: if display_lines:
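For reference, a small worked example of the credit display above; the 211 credits-per-dollar factor comes straight from the changed line, while the price value is hypothetical:
price = 1.0                                               # hypothetical dollar price reported by the backend
p = f"{float(price) * 211:,.1f}".rstrip("0").rstrip(".")  # "211.0" -> "211." -> "211"
print(f"Price: {p} credits")                              # Price: 211 credits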

View File

@ -129,7 +129,7 @@ def pil_to_bytesio(img: Image.Image, mime_type: str = "image/png") -> BytesIO:
return img_byte_arr return img_byte_arr
def downscale_image_tensor(image, total_pixels=1536 * 1024) -> torch.Tensor: def downscale_image_tensor(image: torch.Tensor, total_pixels: int = 1536 * 1024) -> torch.Tensor:
"""Downscale input image tensor to roughly the specified total pixels.""" """Downscale input image tensor to roughly the specified total pixels."""
samples = image.movedim(-1, 1) samples = image.movedim(-1, 1)
total = int(total_pixels) total = int(total_pixels)
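A quick usage sketch for this helper, assuming ComfyUI's usual [B, H, W, C] image layout; the exact output size depends on the function's internal rounding, so it is only approximate:
import torch
img = torch.rand(1, 2048, 2048, 3)    # hypothetical batch with one 2048x2048 RGB image
small = downscale_image_tensor(img)   # default budget of 1536 * 1024 total pixels
# `small` should contain roughly 1536 * 1024 pixels per image after downscaling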

Some files were not shown because too many files have changed in this diff