Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-01-10 14:20:49 +08:00)

Commit ffa7a369ba — Merge branch 'Comfy-Org:master' into offloader-maifee
@@ -53,6 +53,16 @@ try:
     repo.stash(ident)
 except KeyError:
     print("nothing to stash") # noqa: T201
+except:
+    print("Could not stash, cleaning index and trying again.") # noqa: T201
+    repo.state_cleanup()
+    repo.index.read_tree(repo.head.peel().tree)
+    repo.index.write()
+    try:
+        repo.stash(ident)
+    except KeyError:
+        print("nothing to stash.") # noqa: T201
+
 backup_branch_name = 'backup_branch_{}'.format(datetime.today().strftime('%Y-%m-%d_%H_%M_%S'))
 print("creating backup branch: {}".format(backup_branch_name)) # noqa: T201
 try:
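Note: for reference, a minimal standalone sketch of the same stash-and-retry pattern with pygit2. The repository path and stasher signature below are illustrative placeholders, not values from the update script itself.

import pygit2

repo = pygit2.Repository("ComfyUI")  # hypothetical path
ident = pygit2.Signature("comfyui", "comfy@example.com")  # hypothetical stasher identity
try:
    repo.stash(ident)
except KeyError:
    print("nothing to stash")
except Exception:
    # Stash failed: clear any in-progress state, reset the index to HEAD's tree, retry.
    repo.state_cleanup()
    repo.index.read_tree(repo.head.peel().tree)
    repo.index.write()
    repo.stash(ident)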
@@ -1,3 +1,3 @@
 ..\python_embeded\python.exe -s ..\ComfyUI\main.py --windows-standalone-build --disable-api-nodes
-echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest.
+echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest. If you get a c10.dll error you need to install vc redist that you can find: https://aka.ms/vc14/vc_redist.x64.exe
 pause

@@ -1,3 +1,3 @@
 .\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build
-echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest.
+echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest. If you get a c10.dll error you need to install vc redist that you can find: https://aka.ms/vc14/vc_redist.x64.exe
 pause

@@ -1,3 +1,3 @@
 .\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --fast fp16_accumulation
-echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest.
+echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest. If you get a c10.dll error you need to install vc redist that you can find: https://aka.ms/vc14/vc_redist.x64.exe
 pause
.github/workflows/stable-release.yml (2 changed lines, vendored)
@@ -117,7 +117,7 @@ jobs:
          ./python.exe get-pip.py
          ./python.exe -s -m pip install ../${{ inputs.cache_tag }}_python_deps/*

-         grep comfyui ../ComfyUI/requirements.txt > ./requirements_comfyui.txt
+         grep comfy ../ComfyUI/requirements.txt > ./requirements_comfyui.txt
          ./python.exe -s -m pip install -r requirements_comfyui.txt
          rm requirements_comfyui.txt
.github/workflows/test-build.yml (2 changed lines, vendored)
@@ -18,7 +18,7 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
+        python-version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
     steps:
       - uses: actions/checkout@v4
       - name: Set up Python ${{ matrix.python-version }}
.github/workflows/test-ci.yml (3 changed lines, vendored)
@@ -5,6 +5,7 @@ on:
   push:
     branches:
       - master
+      - release/**
     paths-ignore:
       - 'app/**'
       - 'input/**'
@@ -19,6 +20,7 @@ jobs:
   test-stable:
     strategy:
       fail-fast: false
+      max-parallel: 1 # This forces sequential execution
       matrix:
         # os: [macos, linux, windows]
         # os: [macos, linux]
@@ -73,6 +75,7 @@ jobs:
   test-unix-nightly:
     strategy:
       fail-fast: false
+      max-parallel: 1 # This forces sequential execution
       matrix:
         # os: [macos, linux]
         os: [linux]
.github/workflows/test-execution.yml (4 changed lines, vendored)
@@ -2,9 +2,9 @@ name: Execution Tests

 on:
   push:
-    branches: [ main, master ]
+    branches: [ main, master, release/** ]
   pull_request:
-    branches: [ main, master ]
+    branches: [ main, master, release/** ]

 jobs:
   test:
.github/workflows/test-launch.yml (8 changed lines, vendored)
@@ -2,9 +2,9 @@ name: Test server launches without errors

 on:
   push:
-    branches: [ main, master ]
+    branches: [ main, master, release/** ]
   pull_request:
-    branches: [ main, master ]
+    branches: [ main, master, release/** ]

 jobs:
   test:
@@ -32,7 +32,9 @@ jobs:
         working-directory: ComfyUI
       - name: Check for unhandled exceptions in server log
         run: |
-          if grep -qE "Exception|Error" console_output.log; then
+          grep -v "Found comfy_kitchen backend triton: {'available': False, 'disabled': True, 'unavailable_reason': \"ImportError: No module named 'triton'\", 'capabilities': \[\]}" console_output.log | grep -v "Found comfy_kitchen backend triton: {'available': False, 'disabled': False, 'unavailable_reason': \"ImportError: No module named 'triton'\", 'capabilities': \[\]}" > console_output_filtered.log
+          cat console_output_filtered.log
+          if grep -qE "Exception|Error" console_output_filtered.log; then
            echo "Unhandled exception/error found in server log."
            exit 1
          fi
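Note: the same check can be sketched in Python instead of a grep pipeline, which makes the intent of the workflow change easier to see: drop the known-harmless comfy_kitchen/triton startup message before scanning for errors. File names mirror the workflow; the noise marker is shortened here and is an assumption about what counts as noise.

import re
import sys

NOISE = "Found comfy_kitchen backend triton:"  # known harmless startup message (shortened)

with open("console_output.log") as f:
    lines = [line for line in f if NOISE not in line]

with open("console_output_filtered.log", "w") as f:
    f.writelines(lines)

if any(re.search(r"Exception|Error", line) for line in lines):
    print("Unhandled exception/error found in server log.")
    sys.exit(1)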
.github/workflows/test-unit.yml (4 changed lines, vendored)
@@ -2,9 +2,9 @@ name: Unit Tests

 on:
   push:
-    branches: [ main, master ]
+    branches: [ main, master, release/** ]
   pull_request:
-    branches: [ main, master ]
+    branches: [ main, master, release/** ]

 jobs:
   test:
.github/workflows/update-version.yml (1 changed line, vendored)
@@ -6,6 +6,7 @@ on:
       - "pyproject.toml"
     branches:
       - master
+      - release/**

 jobs:
   update-version:
@@ -119,6 +119,9 @@ ComfyUI follows a weekly release cycle targeting Monday but this regularly changes

 1. **[ComfyUI Core](https://github.com/comfyanonymous/ComfyUI)**
    - Releases a new stable version (e.g., v0.7.0) roughly every week.
+   - Starting from v0.4.0 patch versions will be used for fixes backported onto the current stable release.
+   - Minor versions will be used for releases off the master branch.
+   - Patch versions may still be used for releases on the master branch in cases where a backport would not make sense.
    - Commits outside of the stable release tags may be very unstable and break many custom nodes.
    - Serves as the foundation for the desktop release

@@ -209,6 +212,8 @@ Python 3.14 works but you may encounter issues with the torch compile node. The

 Python 3.13 is very well supported. If you have trouble with some custom node dependencies on 3.13 you can try 3.12

+torch 2.4 and above is supported but some features might only work on newer versions. We generally recommend using the latest major version of pytorch unless it is less than 2 weeks old.
+
 ### Instructions:

 Git clone this repo.
@@ -58,8 +58,13 @@ class InternalRoutes:
                 return web.json_response({"error": "Invalid directory type"}, status=400)

             directory = get_directory_by_type(directory_type)

+            def is_visible_file(entry: os.DirEntry) -> bool:
+                """Filter out hidden files (e.g., .DS_Store on macOS)."""
+                return entry.is_file() and not entry.name.startswith('.')
+
             sorted_files = sorted(
-                (entry for entry in os.scandir(directory) if entry.is_file()),
+                (entry for entry in os.scandir(directory) if is_visible_file(entry)),
                 key=lambda entry: -entry.stat().st_mtime
             )
             return web.json_response([entry.name for entry in sorted_files], status=200)
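Note: a standalone sketch of the same listing logic, outside the route handler: skip dotfiles such as .DS_Store, then sort newest-first by modification time. The directory name is arbitrary.

import os

def list_visible_files(directory: str) -> list[str]:
    def is_visible_file(entry: os.DirEntry) -> bool:
        # Hidden files (e.g., .DS_Store on macOS) start with a dot.
        return entry.is_file() and not entry.name.startswith('.')

    entries = sorted(
        (e for e in os.scandir(directory) if is_visible_file(e)),
        key=lambda e: -e.stat().st_mtime,
    )
    return [e.name for e in entries]

print(list_visible_files("output"))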
@@ -44,7 +44,7 @@ class ModelFileManager:
         @routes.get("/experiment/models/{folder}")
         async def get_all_models(request):
             folder = request.match_info.get("folder", None)
-            if not folder in folder_paths.folder_names_and_paths:
+            if folder not in folder_paths.folder_names_and_paths:
                 return web.Response(status=404)
             files = self.get_model_file_list(folder)
             return web.json_response(files)
@@ -55,7 +55,7 @@ class ModelFileManager:
             path_index = int(request.match_info.get("path_index", None))
             filename = request.match_info.get("filename", None)

-            if not folder_name in folder_paths.folder_names_and_paths:
+            if folder_name not in folder_paths.folder_names_and_paths:
                 return web.Response(status=404)

             folders = folder_paths.folder_names_and_paths[folder_name]
@@ -97,6 +97,13 @@ class LatentPreviewMethod(enum.Enum):
     Latent2RGB = "latent2rgb"
     TAESD = "taesd"

+    @classmethod
+    def from_string(cls, value: str):
+        for member in cls:
+            if member.value == value:
+                return member
+        return None
+
 parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews, help="Default preview method for sampler nodes.", action=EnumAction)

 parser.add_argument("--preview-size", type=int, default=512, help="Sets the maximum preview size for sampler nodes.")
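Note: usage sketch for the new from_string helper. This is the generic Enum lookup pattern; the Latent2RGB and TAESD values come from the diff, while the NoPreviews value is assumed for illustration.

import enum

class LatentPreviewMethod(enum.Enum):
    NoPreviews = "none"        # value assumed for illustration
    Latent2RGB = "latent2rgb"
    TAESD = "taesd"

    @classmethod
    def from_string(cls, value: str):
        for member in cls:
            if member.value == value:
                return member
        return None

assert LatentPreviewMethod.from_string("taesd") is LatentPreviewMethod.TAESD
assert LatentPreviewMethod.from_string("bogus") is None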
@@ -2,6 +2,25 @@ import torch
 from comfy.ldm.modules.attention import optimized_attention_for_device
 import comfy.ops

+
+def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], crop=True):
+    image = image[:, :, :, :3] if image.shape[3] > 3 else image
+    mean = torch.tensor(mean, device=image.device, dtype=image.dtype)
+    std = torch.tensor(std, device=image.device, dtype=image.dtype)
+    image = image.movedim(-1, 1)
+    if not (image.shape[2] == size and image.shape[3] == size):
+        if crop:
+            scale = (size / min(image.shape[2], image.shape[3]))
+            scale_size = (round(scale * image.shape[2]), round(scale * image.shape[3]))
+        else:
+            scale_size = (size, size)
+
+        image = torch.nn.functional.interpolate(image, size=scale_size, mode="bicubic", antialias=True)
+        h = (image.shape[2] - size)//2
+        w = (image.shape[3] - size)//2
+        image = image[:,:,h:h+size,w:w+size]
+    image = torch.clip((255. * image), 0, 255).round() / 255.0
+    return (image - mean.view([3,1,1])) / std.view([3,1,1])
+
 class CLIPAttention(torch.nn.Module):
     def __init__(self, embed_dim, heads, dtype, device, operations):
         super().__init__()
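Note: usage sketch for the relocated clip_preprocess helper. ComfyUI image tensors are batched NHWC floats in [0, 1]; the helper resizes (and optionally center-crops) to the CLIP input size and returns normalized NCHW tensors. The input shape here is arbitrary.

import torch
import comfy.clip_model

image = torch.rand(1, 480, 640, 3)                          # batch, height, width, RGB in [0, 1]
pixels = comfy.clip_model.clip_preprocess(image, size=224)  # -> torch.Size([1, 3, 224, 224])
print(pixels.shape)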
@@ -1,6 +1,5 @@
 from .utils import load_torch_file, transformers_convert, state_dict_prefix_replace
 import os
-import torch
 import json
 import logging

@@ -17,24 +16,7 @@ class Output:
     def __setitem__(self, key, item):
         setattr(self, key, item)

-def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], crop=True):
-    image = image[:, :, :, :3] if image.shape[3] > 3 else image
-    mean = torch.tensor(mean, device=image.device, dtype=image.dtype)
-    std = torch.tensor(std, device=image.device, dtype=image.dtype)
-    image = image.movedim(-1, 1)
-    if not (image.shape[2] == size and image.shape[3] == size):
-        if crop:
-            scale = (size / min(image.shape[2], image.shape[3]))
-            scale_size = (round(scale * image.shape[2]), round(scale * image.shape[3]))
-        else:
-            scale_size = (size, size)
-
-        image = torch.nn.functional.interpolate(image, size=scale_size, mode="bicubic", antialias=True)
-        h = (image.shape[2] - size)//2
-        w = (image.shape[3] - size)//2
-        image = image[:,:,h:h+size,w:w+size]
-        image = torch.clip((255. * image), 0, 255).round() / 255.0
-    return (image - mean.view([3,1,1])) / std.view([3,1,1])
+clip_preprocess = comfy.clip_model.clip_preprocess # Prevent some stuff from breaking, TODO: remove eventually

 IMAGE_ENCODERS = {
     "clip_vision_model": comfy.clip_model.CLIPVisionModelProjection,
@@ -73,7 +55,7 @@ class ClipVisionModel():

     def encode_image(self, image, crop=True):
         comfy.model_management.load_model_gpu(self.patcher)
-        pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float()
+        pixel_values = comfy.clip_model.clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float()
         out = self.model(pixel_values=pixel_values, intermediate_output='all' if self.return_all_hidden_states else -2)

         outputs = Output()
@@ -87,6 +87,7 @@ class IndexListCallbacks:
     COMBINE_CONTEXT_WINDOW_RESULTS = "combine_context_window_results"
     EXECUTE_START = "execute_start"
     EXECUTE_CLEANUP = "execute_cleanup"
+    RESIZE_COND_ITEM = "resize_cond_item"

     def init_callbacks(self):
         return {}
@@ -142,7 +143,7 @@ class IndexListContextHandler(ContextHandlerABC):
         # if multiple conds, split based on primary region
         if self.split_conds_to_windows and len(cond_in) > 1:
             region = window.get_region_index(len(cond_in))
-            logging.info(f"Splitting conds to windows; using region {region} for window {window[0]}-{window[-1]} with center ratio {window.center_ratio:.3f}")
+            logging.info(f"Splitting conds to windows; using region {region} for window {window.index_list[0]}-{window.index_list[-1]} with center ratio {window.center_ratio:.3f}")
             cond_in = [cond_in[region]]
         # cond object is a list containing a dict - outer list is irrelevant, so just loop through it
         for actual_cond in cond_in:
@@ -166,6 +167,18 @@ class IndexListContextHandler(ContextHandlerABC):
             new_cond_item = cond_item.copy()
             # when in dictionary, look for tensors and CONDCrossAttn [comfy/conds.py] (has cond attr that is a tensor)
             for cond_key, cond_value in new_cond_item.items():
+                # Allow callbacks to handle custom conditioning items
+                handled = False
+                for callback in comfy.patcher_extension.get_all_callbacks(
+                    IndexListCallbacks.RESIZE_COND_ITEM, self.callbacks
+                ):
+                    result = callback(cond_key, cond_value, window, x_in, device, new_cond_item)
+                    if result is not None:
+                        new_cond_item[cond_key] = result
+                        handled = True
+                        break
+                if handled:
+                    continue
                 if isinstance(cond_value, torch.Tensor):
                     if (self.dim < cond_value.ndim and cond_value.size(self.dim) == x_in.size(self.dim)) or \
                             (cond_value.ndim < self.dim and cond_value.size(0) == x_in.size(self.dim)):
@@ -175,6 +188,12 @@ class IndexListContextHandler(ContextHandlerABC):
                     audio_cond = cond_value.cond
                     if audio_cond.ndim > 1 and audio_cond.size(1) == x_in.size(self.dim):
                         new_cond_item[cond_key] = cond_value._copy_with(window.get_tensor(audio_cond, device, dim=1))
+                # Handle vace_context (temporal dim is 3)
+                elif cond_key == "vace_context" and hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
+                    vace_cond = cond_value.cond
+                    if vace_cond.ndim >= 4 and vace_cond.size(3) == x_in.size(self.dim):
+                        sliced_vace = window.get_tensor(vace_cond, device, dim=3, retain_index_list=self.cond_retain_index_list)
+                        new_cond_item[cond_key] = cond_value._copy_with(sliced_vace)
                 # if has cond that is a Tensor, check if needs to be subset
                 elif hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
                     if (self.dim < cond_value.cond.ndim and cond_value.cond.size(self.dim) == x_in.size(self.dim)) or \
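Note: based on the call site above, a RESIZE_COND_ITEM callback receives (cond_key, cond_value, window, x_in, device, new_cond_item) and returns the resized value, or None to fall through to the built-in handling. A hedged sketch follows; the conditioning key and the dim are made up, and how the callback gets attached to self.callbacks is not shown in this diff.

import torch

def resize_my_cond(cond_key, cond_value, window, x_in, device, new_cond_item):
    if cond_key != "my_custom_cond":          # hypothetical conditioning key
        return None                           # let the default logic handle it
    if isinstance(cond_value, torch.Tensor):
        # window.get_tensor slices a tensor down to the current context window
        return window.get_tensor(cond_value, device, dim=2)  # dim=2 is an assumption
    return None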
@@ -527,7 +527,8 @@ class HookKeyframeGroup:
                 if self._current_keyframe.get_effective_guarantee_steps(max_sigma) > 0:
                     break
             # if eval_c is outside the percent range, stop looking further
-            else: break
+            else:
+                break
         # update steps current context is used
         self._current_used_steps += 1
         # update current timestep this was performed on
@@ -74,6 +74,9 @@ def get_ancestral_step(sigma_from, sigma_to, eta=1.):

 def default_noise_sampler(x, seed=None):
     if seed is not None:
+        if x.device == torch.device("cpu"):
+            seed += 1
+
         generator = torch.Generator(device=x.device)
         generator.manual_seed(seed)
     else:
|
|||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5):
|
def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5, solver_type="phi_1"):
|
||||||
"""SEEDS-2 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 2.
|
"""SEEDS-2 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 2.
|
||||||
arXiv: https://arxiv.org/abs/2305.14267 (NeurIPS 2023)
|
arXiv: https://arxiv.org/abs/2305.14267 (NeurIPS 2023)
|
||||||
"""
|
"""
|
||||||
|
if solver_type not in {"phi_1", "phi_2"}:
|
||||||
|
raise ValueError("solver_type must be 'phi_1' or 'phi_2'")
|
||||||
|
|
||||||
extra_args = {} if extra_args is None else extra_args
|
extra_args = {} if extra_args is None else extra_args
|
||||||
seed = extra_args.get("seed", None)
|
seed = extra_args.get("seed", None)
|
||||||
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
|
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
|
||||||
@ -1600,8 +1606,14 @@ def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=Non
|
|||||||
denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
|
denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
|
||||||
|
|
||||||
# Step 2
|
# Step 2
|
||||||
denoised_d = torch.lerp(denoised, denoised_2, fac)
|
if solver_type == "phi_1":
|
||||||
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * ei_h_phi_1(-h_eta) * denoised_d
|
denoised_d = torch.lerp(denoised, denoised_2, fac)
|
||||||
|
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * ei_h_phi_1(-h_eta) * denoised_d
|
||||||
|
elif solver_type == "phi_2":
|
||||||
|
b2 = ei_h_phi_2(-h_eta) / r
|
||||||
|
b1 = ei_h_phi_1(-h_eta) - b2
|
||||||
|
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * (b1 * denoised + b2 * denoised_2)
|
||||||
|
|
||||||
if inject_noise:
|
if inject_noise:
|
||||||
segment_factor = (r - 1) * h * eta
|
segment_factor = (r - 1) * h * eta
|
||||||
sde_noise = sde_noise * segment_factor.exp()
|
sde_noise = sde_noise * segment_factor.exp()
|
||||||
@ -1609,6 +1621,17 @@ def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=Non
|
|||||||
x = x + sde_noise * sigmas[i + 1] * s_noise
|
x = x + sde_noise * sigmas[i + 1] * s_noise
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def sample_exp_heun_2_x0(model, x, sigmas, extra_args=None, callback=None, disable=None, solver_type="phi_2"):
|
||||||
|
"""Deterministic exponential Heun second order method in data prediction (x0) and logSNR time."""
|
||||||
|
return sample_seeds_2(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=0.0, s_noise=0.0, noise_sampler=None, r=1.0, solver_type=solver_type)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def sample_exp_heun_2_x0_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type="phi_2"):
|
||||||
|
"""Stochastic exponential Heun second order method in data prediction (x0) and logSNR time."""
|
||||||
|
return sample_seeds_2(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, r=1.0, solver_type=solver_type)
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r_1=1./3, r_2=2./3):
|
def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r_1=1./3, r_2=2./3):
|
||||||
@ -1756,7 +1779,7 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F
|
|||||||
# Predictor
|
# Predictor
|
||||||
if sigmas[i + 1] == 0:
|
if sigmas[i + 1] == 0:
|
||||||
# Denoising step
|
# Denoising step
|
||||||
x = denoised
|
x_pred = denoised
|
||||||
else:
|
else:
|
||||||
tau_t = tau_func(sigmas[i + 1])
|
tau_t = tau_func(sigmas[i + 1])
|
||||||
curr_lambdas = lambdas[i - predictor_order_used + 1:i + 1]
|
curr_lambdas = lambdas[i - predictor_order_used + 1:i + 1]
|
||||||
@ -1777,7 +1800,7 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F
|
|||||||
if tau_t > 0 and s_noise > 0:
|
if tau_t > 0 and s_noise > 0:
|
||||||
noise = noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * tau_t ** 2 * h).expm1().neg().sqrt() * s_noise
|
noise = noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * tau_t ** 2 * h).expm1().neg().sqrt() * s_noise
|
||||||
x_pred = x_pred + noise
|
x_pred = x_pred + noise
|
||||||
return x
|
return x_pred
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
|
@@ -407,6 +407,11 @@ class LTXV(LatentFormat):

         self.latent_rgb_factors_bias = [-0.0571, -0.1657, -0.2512]

+
+class LTXAV(LTXV):
+    def __init__(self):
+        self.latent_rgb_factors = None
+        self.latent_rgb_factors_bias = None
+
 class HunyuanVideo(LatentFormat):
     latent_channels = 16
     latent_dimensions = 3
@@ -37,7 +37,7 @@ class ChromaRadianceParams(ChromaParams):
     nerf_final_head_type: str
     # None means use the same dtype as the model.
     nerf_embedder_dtype: Optional[torch.dtype]
+    use_x0: bool

 class ChromaRadiance(Chroma):
     """
@@ -159,6 +159,9 @@ class ChromaRadiance(Chroma):
         self.skip_dit = []
         self.lite = False

+        if params.use_x0:
+            self.register_buffer("__x0__", torch.tensor([]))
+
     @property
     def _nerf_final_layer(self) -> nn.Module:
         if self.params.nerf_final_head_type == "linear":
@@ -267,7 +270,7 @@ class ChromaRadiance(Chroma):
         bad_keys = tuple(
             k
             for k, v in overrides.items()
-            if type(v) != type(getattr(params, k)) and (v is not None or k not in nullable_keys)
+            if not isinstance(v, type(getattr(params, k))) and (v is not None or k not in nullable_keys)
         )
         if bad_keys:
             e = f"Invalid value(s) in transformer_options chroma_radiance_options: {', '.join(bad_keys)}"
@@ -276,6 +279,12 @@ class ChromaRadiance(Chroma):
         params_dict |= overrides
         return params.__class__(**params_dict)

+    def _apply_x0_residual(self, predicted, noisy, timesteps):
+        # non zero during training to prevent 0 div
+        eps = 0.0
+        return (noisy - predicted) / (timesteps.view(-1,1,1,1) + eps)
+
     def _forward(
         self,
         x: Tensor,
@@ -316,4 +325,11 @@ class ChromaRadiance(Chroma):
             transformer_options,
             attn_mask=kwargs.get("attention_mask", None),
         )
-        return self.forward_nerf(img, img_out, params)[:, :, :h, :w]
+
+        out = self.forward_nerf(img, img_out, params)[:, :, :h, :w]
+
+        # If x0 variant → v-pred, just return this instead
+        if hasattr(self, "__x0__"):
+            out = self._apply_x0_residual(out, img, timestep)
+        return out
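Note on _apply_x0_residual: assuming the rectified-flow parameterization used by Flux-style models (an assumption, not stated in this diff), dividing the residual by the timestep converts an x0 prediction into the velocity-style output the rest of the pipeline expects, matching the "x0 variant → v-pred" comment above:

$$
x_t = (1-t)\,x_0 + t\,\varepsilon
\quad\Rightarrow\quad
\frac{x_t - \hat{x}_0}{t} = \frac{(1-t)\,x_0 + t\,\varepsilon - \hat{x}_0}{t} \approx \varepsilon - \hat{x}_0 = \hat{v}.
$$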
@@ -4,6 +4,7 @@ from torch import Tensor

 from comfy.ldm.modules.attention import optimized_attention
 import comfy.model_management
+import logging


 def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transformer_options={}) -> Tensor:
@@ -13,7 +14,6 @@ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transforme
     x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask, transformer_options=transformer_options)
     return x

-
 def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
     assert dim % 2 == 0
     if comfy.model_management.is_device_mps(pos.device) or comfy.model_management.is_intel_xpu() or comfy.model_management.is_directml_enabled():
@@ -28,13 +28,20 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
     out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
     return out.to(dtype=torch.float32, device=pos.device)

-def apply_rope1(x: Tensor, freqs_cis: Tensor):
-    x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
-
-    x_out = freqs_cis[..., 0] * x_[..., 0]
-    x_out.addcmul_(freqs_cis[..., 1], x_[..., 1])
-
-    return x_out.reshape(*x.shape).type_as(x)
-
-def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
-    return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)
+try:
+    import comfy.quant_ops
+    apply_rope = comfy.quant_ops.ck.apply_rope
+    apply_rope1 = comfy.quant_ops.ck.apply_rope1
+except:
+    logging.warning("No comfy kitchen, using old apply_rope functions.")
+
+    def apply_rope1(x: Tensor, freqs_cis: Tensor):
+        x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
+
+        x_out = freqs_cis[..., 0] * x_[..., 0]
+        x_out.addcmul_(freqs_cis[..., 1], x_[..., 1])
+
+        return x_out.reshape(*x.shape).type_as(x)
+
+    def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
+        return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)
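Note: the change above is an optional fast-path import: bind the accelerated kernels if the comfy_kitchen backend imports, otherwise fall back to the reference implementation (the diff uses a bare except; a narrower form catches ImportError only). A generic sketch of the pattern, with a placeholder module name and the reference body taken from the diff:

import logging
import torch
from torch import Tensor

try:
    from fast_rope import apply_rope, apply_rope1      # hypothetical accelerated module
except ImportError:
    logging.warning("fast_rope not available, using reference apply_rope functions.")

    def apply_rope1(x: Tensor, freqs_cis: Tensor):
        # reference rotary embedding in (real, imag) pairs
        x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
        x_out = freqs_cis[..., 0] * x_[..., 0] + freqs_cis[..., 1] * x_[..., 1]
        return x_out.reshape(*x.shape).type_as(x)

    def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
        return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)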
@@ -43,6 +43,7 @@ class HunyuanVideoParams:
     meanflow: bool
     use_cond_type_embedding: bool
     vision_in_dim: int
+    meanflow_sum: bool


 class SelfAttentionRef(nn.Module):
@@ -317,7 +318,7 @@ class HunyuanVideo(nn.Module):
             timesteps_r = transformer_options['sample_sigmas'][w[0] + 1]
             timesteps_r = timesteps_r.unsqueeze(0).to(device=timesteps.device, dtype=timesteps.dtype)
             vec_r = self.time_r_in(timestep_embedding(timesteps_r, 256, time_factor=1000.0).to(img.dtype))
-            vec = (vec + vec_r) / 2
+            vec = (vec + vec_r) if self.params.meanflow_sum else (vec + vec_r) / 2

         if ref_latent is not None:
             ref_latent_ids = self.img_ids(ref_latent)
@@ -3,7 +3,8 @@ import torch.nn as nn
 import torch.nn.functional as F
 from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, VideoConv3d
 from comfy.ldm.hunyuan_video.vae_refiner import RMS_norm
-import model_management, model_patcher
+import model_management
+import model_patcher

 class SRResidualCausalBlock3D(nn.Module):
     def __init__(self, channels: int):
comfy/ldm/lightricks/av_model.py (new file, 837 lines)
@@ -0,0 +1,837 @@
from typing import Tuple
import torch
import torch.nn as nn
from comfy.ldm.lightricks.model import (
    CrossAttention,
    FeedForward,
    AdaLayerNormSingle,
    PixArtAlphaTextProjection,
    LTXVModel,
)
from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier
import comfy.ldm.common_dit


class BasicAVTransformerBlock(nn.Module):
    def __init__(
        self,
        v_dim,
        a_dim,
        v_heads,
        a_heads,
        vd_head,
        ad_head,
        v_context_dim=None,
        a_context_dim=None,
        attn_precision=None,
        dtype=None,
        device=None,
        operations=None,
    ):
        super().__init__()

        self.attn_precision = attn_precision

        self.attn1 = CrossAttention(
            query_dim=v_dim,
            heads=v_heads,
            dim_head=vd_head,
            context_dim=None,
            attn_precision=self.attn_precision,
            dtype=dtype,
            device=device,
            operations=operations,
        )
        self.audio_attn1 = CrossAttention(
            query_dim=a_dim,
            heads=a_heads,
            dim_head=ad_head,
            context_dim=None,
            attn_precision=self.attn_precision,
            dtype=dtype,
            device=device,
            operations=operations,
        )

        self.attn2 = CrossAttention(
            query_dim=v_dim,
            context_dim=v_context_dim,
            heads=v_heads,
            dim_head=vd_head,
            attn_precision=self.attn_precision,
            dtype=dtype,
            device=device,
            operations=operations,
        )
        self.audio_attn2 = CrossAttention(
            query_dim=a_dim,
            context_dim=a_context_dim,
            heads=a_heads,
            dim_head=ad_head,
            attn_precision=self.attn_precision,
            dtype=dtype,
            device=device,
            operations=operations,
        )

        # Q: Video, K,V: Audio
        self.audio_to_video_attn = CrossAttention(
            query_dim=v_dim,
            context_dim=a_dim,
            heads=a_heads,
            dim_head=ad_head,
            attn_precision=self.attn_precision,
            dtype=dtype,
            device=device,
            operations=operations,
        )

        # Q: Audio, K,V: Video
        self.video_to_audio_attn = CrossAttention(
            query_dim=a_dim,
            context_dim=v_dim,
            heads=a_heads,
            dim_head=ad_head,
            attn_precision=self.attn_precision,
            dtype=dtype,
            device=device,
            operations=operations,
        )

        self.ff = FeedForward(
            v_dim, dim_out=v_dim, glu=True, dtype=dtype, device=device, operations=operations
        )
        self.audio_ff = FeedForward(
            a_dim, dim_out=a_dim, glu=True, dtype=dtype, device=device, operations=operations
        )

        self.scale_shift_table = nn.Parameter(torch.empty(6, v_dim, device=device, dtype=dtype))
        self.audio_scale_shift_table = nn.Parameter(
            torch.empty(6, a_dim, device=device, dtype=dtype)
        )

        self.scale_shift_table_a2v_ca_audio = nn.Parameter(
            torch.empty(5, a_dim, device=device, dtype=dtype)
        )
        self.scale_shift_table_a2v_ca_video = nn.Parameter(
            torch.empty(5, v_dim, device=device, dtype=dtype)
        )

    def get_ada_values(
        self, scale_shift_table: torch.Tensor, batch_size: int, timestep: torch.Tensor, indices: slice = slice(None, None)
    ):
        num_ada_params = scale_shift_table.shape[0]

        ada_values = (
            scale_shift_table[indices].unsqueeze(0).unsqueeze(0).to(device=timestep.device, dtype=timestep.dtype)
            + timestep.reshape(batch_size, timestep.shape[1], num_ada_params, -1)[:, :, indices, :]
        ).unbind(dim=2)
        return ada_values

    def get_av_ca_ada_values(
        self,
        scale_shift_table: torch.Tensor,
        batch_size: int,
        scale_shift_timestep: torch.Tensor,
        gate_timestep: torch.Tensor,
        num_scale_shift_values: int = 4,
    ):
        scale_shift_ada_values = self.get_ada_values(
            scale_shift_table[:num_scale_shift_values, :],
            batch_size,
            scale_shift_timestep,
        )
        gate_ada_values = self.get_ada_values(
            scale_shift_table[num_scale_shift_values:, :],
            batch_size,
            gate_timestep,
        )

        scale_shift_chunks = [t.squeeze(2) for t in scale_shift_ada_values]
        gate_ada_values = [t.squeeze(2) for t in gate_ada_values]

        return (*scale_shift_chunks, *gate_ada_values)

    def forward(
        self,
        x: Tuple[torch.Tensor, torch.Tensor],
        v_context=None,
        a_context=None,
        attention_mask=None,
        v_timestep=None,
        a_timestep=None,
        v_pe=None,
        a_pe=None,
        v_cross_pe=None,
        a_cross_pe=None,
        v_cross_scale_shift_timestep=None,
        a_cross_scale_shift_timestep=None,
        v_cross_gate_timestep=None,
        a_cross_gate_timestep=None,
        transformer_options=None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        run_vx = transformer_options.get("run_vx", True)
        run_ax = transformer_options.get("run_ax", True)

        vx, ax = x
        run_ax = run_ax and ax.numel() > 0
        run_a2v = run_vx and transformer_options.get("a2v_cross_attn", True) and ax.numel() > 0
        run_v2a = run_ax and transformer_options.get("v2a_cross_attn", True)

        if run_vx:
            vshift_msa, vscale_msa, vgate_msa = (
                self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(0, 3))
            )

            norm_vx = comfy.ldm.common_dit.rms_norm(vx) * (1 + vscale_msa) + vshift_msa
            vx += self.attn1(norm_vx, pe=v_pe, transformer_options=transformer_options) * vgate_msa
            vx += self.attn2(
                comfy.ldm.common_dit.rms_norm(vx),
                context=v_context,
                mask=attention_mask,
                transformer_options=transformer_options,
            )

            del vshift_msa, vscale_msa, vgate_msa

        if run_ax:
            ashift_msa, ascale_msa, agate_msa = (
                self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(0, 3))
            )

            norm_ax = comfy.ldm.common_dit.rms_norm(ax) * (1 + ascale_msa) + ashift_msa
            ax += (
                self.audio_attn1(norm_ax, pe=a_pe, transformer_options=transformer_options)
                * agate_msa
            )
            ax += self.audio_attn2(
                comfy.ldm.common_dit.rms_norm(ax),
                context=a_context,
                mask=attention_mask,
                transformer_options=transformer_options,
            )

            del ashift_msa, ascale_msa, agate_msa

        # Audio - Video cross attention.
        if run_a2v or run_v2a:
            # norm3
            vx_norm3 = comfy.ldm.common_dit.rms_norm(vx)
            ax_norm3 = comfy.ldm.common_dit.rms_norm(ax)

            (
                scale_ca_audio_hidden_states_a2v,
                shift_ca_audio_hidden_states_a2v,
                scale_ca_audio_hidden_states_v2a,
                shift_ca_audio_hidden_states_v2a,
                gate_out_v2a,
            ) = self.get_av_ca_ada_values(
                self.scale_shift_table_a2v_ca_audio,
                ax.shape[0],
                a_cross_scale_shift_timestep,
                a_cross_gate_timestep,
            )

            (
                scale_ca_video_hidden_states_a2v,
                shift_ca_video_hidden_states_a2v,
                scale_ca_video_hidden_states_v2a,
                shift_ca_video_hidden_states_v2a,
                gate_out_a2v,
            ) = self.get_av_ca_ada_values(
                self.scale_shift_table_a2v_ca_video,
                vx.shape[0],
                v_cross_scale_shift_timestep,
                v_cross_gate_timestep,
            )

            if run_a2v:
                vx_scaled = (
                    vx_norm3 * (1 + scale_ca_video_hidden_states_a2v)
                    + shift_ca_video_hidden_states_a2v
                )
                ax_scaled = (
                    ax_norm3 * (1 + scale_ca_audio_hidden_states_a2v)
                    + shift_ca_audio_hidden_states_a2v
                )
                vx += (
                    self.audio_to_video_attn(
                        vx_scaled,
                        context=ax_scaled,
                        pe=v_cross_pe,
                        k_pe=a_cross_pe,
                        transformer_options=transformer_options,
                    )
                    * gate_out_a2v
                )

                del gate_out_a2v
                del scale_ca_video_hidden_states_a2v,\
                    shift_ca_video_hidden_states_a2v,\
                    scale_ca_audio_hidden_states_a2v,\
                    shift_ca_audio_hidden_states_a2v

            if run_v2a:
                ax_scaled = (
                    ax_norm3 * (1 + scale_ca_audio_hidden_states_v2a)
                    + shift_ca_audio_hidden_states_v2a
                )
                vx_scaled = (
                    vx_norm3 * (1 + scale_ca_video_hidden_states_v2a)
                    + shift_ca_video_hidden_states_v2a
                )
                ax += (
                    self.video_to_audio_attn(
                        ax_scaled,
                        context=vx_scaled,
                        pe=a_cross_pe,
                        k_pe=v_cross_pe,
                        transformer_options=transformer_options,
                    )
                    * gate_out_v2a
                )

                del gate_out_v2a
                del scale_ca_video_hidden_states_v2a,\
                    shift_ca_video_hidden_states_v2a,\
                    scale_ca_audio_hidden_states_v2a,\
                    shift_ca_audio_hidden_states_v2a

        if run_vx:
            vshift_mlp, vscale_mlp, vgate_mlp = (
                self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(3, None))
            )

            vx_scaled = comfy.ldm.common_dit.rms_norm(vx) * (1 + vscale_mlp) + vshift_mlp
            vx += self.ff(vx_scaled) * vgate_mlp
            del vshift_mlp, vscale_mlp, vgate_mlp

        if run_ax:
            ashift_mlp, ascale_mlp, agate_mlp = (
                self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(3, None))
            )

            ax_scaled = comfy.ldm.common_dit.rms_norm(ax) * (1 + ascale_mlp) + ashift_mlp
            ax += self.audio_ff(ax_scaled) * agate_mlp

            del ashift_mlp, ascale_mlp, agate_mlp

        return vx, ax
class LTXAVModel(LTXVModel):
    """LTXAV model for audio-video generation."""

    def __init__(
        self,
        in_channels=128,
        audio_in_channels=128,
        cross_attention_dim=4096,
        audio_cross_attention_dim=2048,
        attention_head_dim=128,
        audio_attention_head_dim=64,
        num_attention_heads=32,
        audio_num_attention_heads=32,
        caption_channels=3840,
        num_layers=48,
        positional_embedding_theta=10000.0,
        positional_embedding_max_pos=[20, 2048, 2048],
        audio_positional_embedding_max_pos=[20],
        causal_temporal_positioning=False,
        vae_scale_factors=(8, 32, 32),
        use_middle_indices_grid=False,
        timestep_scale_multiplier=1000.0,
        av_ca_timestep_scale_multiplier=1.0,
        dtype=None,
        device=None,
        operations=None,
        **kwargs,
    ):
        # Store audio-specific parameters
        self.audio_in_channels = audio_in_channels
        self.audio_cross_attention_dim = audio_cross_attention_dim
        self.audio_attention_head_dim = audio_attention_head_dim
        self.audio_num_attention_heads = audio_num_attention_heads
        self.audio_positional_embedding_max_pos = audio_positional_embedding_max_pos

        # Calculate audio dimensions
        self.audio_inner_dim = audio_num_attention_heads * audio_attention_head_dim
        self.audio_out_channels = audio_in_channels

        # Audio-specific constants
        self.num_audio_channels = 8
        self.audio_frequency_bins = 16

        self.av_ca_timestep_scale_multiplier = av_ca_timestep_scale_multiplier

        super().__init__(
            in_channels=in_channels,
            cross_attention_dim=cross_attention_dim,
            attention_head_dim=attention_head_dim,
            num_attention_heads=num_attention_heads,
            caption_channels=caption_channels,
            num_layers=num_layers,
            positional_embedding_theta=positional_embedding_theta,
            positional_embedding_max_pos=positional_embedding_max_pos,
            causal_temporal_positioning=causal_temporal_positioning,
            vae_scale_factors=vae_scale_factors,
            use_middle_indices_grid=use_middle_indices_grid,
            timestep_scale_multiplier=timestep_scale_multiplier,
            dtype=dtype,
            device=device,
            operations=operations,
            **kwargs,
        )

    def _init_model_components(self, device, dtype, **kwargs):
        """Initialize LTXAV-specific components."""
        # Audio-specific projections
        self.audio_patchify_proj = self.operations.Linear(
            self.audio_in_channels, self.audio_inner_dim, bias=True, dtype=dtype, device=device
        )

        # Audio-specific AdaLN
        self.audio_adaln_single = AdaLayerNormSingle(
            self.audio_inner_dim,
            use_additional_conditions=False,
            dtype=dtype,
            device=device,
            operations=self.operations,
        )

        num_scale_shift_values = 4
        self.av_ca_video_scale_shift_adaln_single = AdaLayerNormSingle(
            self.inner_dim,
            use_additional_conditions=False,
            embedding_coefficient=num_scale_shift_values,
            dtype=dtype,
            device=device,
            operations=self.operations,
        )
        self.av_ca_a2v_gate_adaln_single = AdaLayerNormSingle(
            self.inner_dim,
            use_additional_conditions=False,
            embedding_coefficient=1,
            dtype=dtype,
            device=device,
            operations=self.operations,
        )
        self.av_ca_audio_scale_shift_adaln_single = AdaLayerNormSingle(
            self.audio_inner_dim,
            use_additional_conditions=False,
            embedding_coefficient=num_scale_shift_values,
            dtype=dtype,
            device=device,
            operations=self.operations,
        )
        self.av_ca_v2a_gate_adaln_single = AdaLayerNormSingle(
            self.audio_inner_dim,
            use_additional_conditions=False,
            embedding_coefficient=1,
            dtype=dtype,
            device=device,
            operations=self.operations,
        )

        # Audio caption projection
        self.audio_caption_projection = PixArtAlphaTextProjection(
            in_features=self.caption_channels,
            hidden_size=self.audio_inner_dim,
            dtype=dtype,
            device=device,
            operations=self.operations,
        )

    def _init_transformer_blocks(self, device, dtype, **kwargs):
        """Initialize transformer blocks for LTXAV."""
        self.transformer_blocks = nn.ModuleList(
            [
                BasicAVTransformerBlock(
                    v_dim=self.inner_dim,
                    a_dim=self.audio_inner_dim,
                    v_heads=self.num_attention_heads,
                    a_heads=self.audio_num_attention_heads,
                    vd_head=self.attention_head_dim,
                    ad_head=self.audio_attention_head_dim,
                    v_context_dim=self.cross_attention_dim,
                    a_context_dim=self.audio_cross_attention_dim,
                    dtype=dtype,
                    device=device,
                    operations=self.operations,
                )
                for _ in range(self.num_layers)
            ]
        )

    def _init_output_components(self, device, dtype):
        """Initialize output components for LTXAV."""
        # Video output components
        super()._init_output_components(device, dtype)
        # Audio output components
        self.audio_scale_shift_table = nn.Parameter(
            torch.empty(2, self.audio_inner_dim, dtype=dtype, device=device)
        )
        self.audio_norm_out = self.operations.LayerNorm(
            self.audio_inner_dim, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device
        )
        self.audio_proj_out = self.operations.Linear(
            self.audio_inner_dim, self.audio_out_channels, dtype=dtype, device=device
        )
        self.a_patchifier = AudioPatchifier(1, start_end=True)
    def separate_audio_and_video_latents(self, x, audio_length):
        """Separate audio and video latents from combined input."""
        # vx = x[:, : self.in_channels]
        # ax = x[:, self.in_channels :]
        #
        # ax = ax.reshape(ax.shape[0], -1)
        # ax = ax[:, : audio_length * self.num_audio_channels * self.audio_frequency_bins]
        #
        # ax = ax.reshape(
        #     ax.shape[0], self.num_audio_channels, audio_length, self.audio_frequency_bins
        # )

        vx = x[0]
        ax = x[1] if len(x) > 1 else torch.zeros(
            (vx.shape[0], self.num_audio_channels, 0, self.audio_frequency_bins),
            device=vx.device, dtype=vx.dtype
        )
        return vx, ax

    def recombine_audio_and_video_latents(self, vx, ax, target_shape=None):
        if ax.numel() == 0:
            return vx
        else:
            return [vx, ax]
        """Recombine audio and video latents for output."""
        # if ax.device != vx.device or ax.dtype != vx.dtype:
        #     logging.warning("Audio and video latents are on different devices or dtypes.")
        #     ax = ax.to(device=vx.device, dtype=vx.dtype)
        #     logging.warning(f"Audio audio latent moved to device: {ax.device}, dtype: {ax.dtype}")
        #
        # ax = ax.reshape(ax.shape[0], -1)
        # # pad to f x h x w of the video latents
        # divisor = vx.shape[-1] * vx.shape[-2] * vx.shape[-3]
        # if target_shape is None:
        #     repetitions = math.ceil(ax.shape[-1] / divisor)
        # else:
        #     repetitions = target_shape[1] - vx.shape[1]
        # padded_len = repetitions * divisor
        # ax = F.pad(ax, (0, padded_len - ax.shape[-1]))
        # ax = ax.reshape(ax.shape[0], -1, vx.shape[-3], vx.shape[-2], vx.shape[-1])
        # return torch.cat([vx, ax], dim=1)

    def _process_input(self, x, keyframe_idxs, denoise_mask, **kwargs):
        """Process input for LTXAV - separate audio and video, then patchify."""
        audio_length = kwargs.get("audio_length", 0)
        # Separate audio and video latents
        vx, ax = self.separate_audio_and_video_latents(x, audio_length)
        [vx, v_pixel_coords, additional_args] = super()._process_input(
            vx, keyframe_idxs, denoise_mask, **kwargs
        )

        ax, a_latent_coords = self.a_patchifier.patchify(ax)
        ax = self.audio_patchify_proj(ax)

        # additional_args.update({"av_orig_shape": list(x.shape)})
        return [vx, ax], [v_pixel_coords, a_latent_coords], additional_args

    def _prepare_timestep(self, timestep, batch_size, hidden_dtype, **kwargs):
        """Prepare timestep embeddings."""
        # TODO: some code reuse is needed here.
        grid_mask = kwargs.get("grid_mask", None)
        if grid_mask is not None:
            timestep = timestep[:, grid_mask]

        timestep = timestep * self.timestep_scale_multiplier
        v_timestep, v_embedded_timestep = self.adaln_single(
            timestep.flatten(),
            {"resolution": None, "aspect_ratio": None},
            batch_size=batch_size,
            hidden_dtype=hidden_dtype,
        )

        # Second dimension is 1 or number of tokens (if timestep_per_token)
        v_timestep = v_timestep.view(batch_size, -1, v_timestep.shape[-1])
        v_embedded_timestep = v_embedded_timestep.view(
            batch_size, -1, v_embedded_timestep.shape[-1]
        )

        # Prepare audio timestep
        a_timestep = kwargs.get("a_timestep")
        if a_timestep is not None:
            a_timestep = a_timestep * self.timestep_scale_multiplier
            av_ca_factor = self.av_ca_timestep_scale_multiplier / self.timestep_scale_multiplier

            av_ca_audio_scale_shift_timestep, _ = self.av_ca_audio_scale_shift_adaln_single(
                a_timestep.flatten(),
                {"resolution": None, "aspect_ratio": None},
                batch_size=batch_size,
                hidden_dtype=hidden_dtype,
            )
            av_ca_video_scale_shift_timestep, _ = self.av_ca_video_scale_shift_adaln_single(
                timestep.flatten(),
                {"resolution": None, "aspect_ratio": None},
                batch_size=batch_size,
                hidden_dtype=hidden_dtype,
            )
            av_ca_a2v_gate_noise_timestep, _ = self.av_ca_a2v_gate_adaln_single(
                timestep.flatten() * av_ca_factor,
                {"resolution": None, "aspect_ratio": None},
                batch_size=batch_size,
                hidden_dtype=hidden_dtype,
            )
            av_ca_v2a_gate_noise_timestep, _ = self.av_ca_v2a_gate_adaln_single(
                a_timestep.flatten() * av_ca_factor,
                {"resolution": None, "aspect_ratio": None},
                batch_size=batch_size,
                hidden_dtype=hidden_dtype,
            )

            a_timestep, a_embedded_timestep = self.audio_adaln_single(
                a_timestep.flatten(),
                {"resolution": None, "aspect_ratio": None},
                batch_size=batch_size,
                hidden_dtype=hidden_dtype,
            )
            a_timestep = a_timestep.view(batch_size, -1, a_timestep.shape[-1])
            a_embedded_timestep = a_embedded_timestep.view(
                batch_size, -1, a_embedded_timestep.shape[-1]
            )
            cross_av_timestep_ss = [
                av_ca_audio_scale_shift_timestep,
                av_ca_video_scale_shift_timestep,
                av_ca_a2v_gate_noise_timestep,
                av_ca_v2a_gate_noise_timestep,
            ]
            cross_av_timestep_ss = list(
                [t.view(batch_size, -1, t.shape[-1]) for t in cross_av_timestep_ss]
            )
        else:
            a_timestep = timestep
            a_embedded_timestep = kwargs.get("embedded_timestep")
            cross_av_timestep_ss = []

        return [v_timestep, a_timestep, cross_av_timestep_ss], [
            v_embedded_timestep,
            a_embedded_timestep,
        ]

    def _prepare_context(self, context, batch_size, x, attention_mask=None):
        vx = x[0]
        ax = x[1]
        v_context, a_context = torch.split(
            context, int(context.shape[-1] / 2), len(context.shape) - 1
        )

        v_context, attention_mask = super()._prepare_context(
            v_context, batch_size, vx, attention_mask
|
||||||
|
)
|
||||||
|
if self.audio_caption_projection is not None:
|
||||||
|
a_context = self.audio_caption_projection(a_context)
|
||||||
|
a_context = a_context.view(batch_size, -1, ax.shape[-1])
|
||||||
|
|
||||||
|
return [v_context, a_context], attention_mask
|
||||||
|
|
||||||
|
def _prepare_positional_embeddings(self, pixel_coords, frame_rate, x_dtype):
|
||||||
|
v_pixel_coords = pixel_coords[0]
|
||||||
|
v_pe = super()._prepare_positional_embeddings(v_pixel_coords, frame_rate, x_dtype)
|
||||||
|
|
||||||
|
a_latent_coords = pixel_coords[1]
|
||||||
|
a_pe = self._precompute_freqs_cis(
|
||||||
|
a_latent_coords,
|
||||||
|
dim=self.audio_inner_dim,
|
||||||
|
out_dtype=x_dtype,
|
||||||
|
max_pos=self.audio_positional_embedding_max_pos,
|
||||||
|
use_middle_indices_grid=self.use_middle_indices_grid,
|
||||||
|
num_attention_heads=self.audio_num_attention_heads,
|
||||||
|
)
|
||||||
|
|
||||||
|
# calculate positional embeddings for the middle of the token duration, to use in av cross attention layers.
|
||||||
|
max_pos = max(
|
||||||
|
self.positional_embedding_max_pos[0], self.audio_positional_embedding_max_pos[0]
|
||||||
|
)
|
||||||
|
v_pixel_coords = v_pixel_coords.to(torch.float32)
|
||||||
|
v_pixel_coords[:, 0] = v_pixel_coords[:, 0] * (1.0 / frame_rate)
|
||||||
|
av_cross_video_freq_cis = self._precompute_freqs_cis(
|
||||||
|
v_pixel_coords[:, 0:1, :],
|
||||||
|
dim=self.audio_cross_attention_dim,
|
||||||
|
out_dtype=x_dtype,
|
||||||
|
max_pos=[max_pos],
|
||||||
|
use_middle_indices_grid=True,
|
||||||
|
num_attention_heads=self.audio_num_attention_heads,
|
||||||
|
)
|
||||||
|
av_cross_audio_freq_cis = self._precompute_freqs_cis(
|
||||||
|
a_latent_coords[:, 0:1, :],
|
||||||
|
dim=self.audio_cross_attention_dim,
|
||||||
|
out_dtype=x_dtype,
|
||||||
|
max_pos=[max_pos],
|
||||||
|
use_middle_indices_grid=True,
|
||||||
|
num_attention_heads=self.audio_num_attention_heads,
|
||||||
|
)
|
||||||
|
|
||||||
|
return [(v_pe, av_cross_video_freq_cis), (a_pe, av_cross_audio_freq_cis)]
|
||||||
|
|
||||||
|
def _process_transformer_blocks(
|
||||||
|
self, x, context, attention_mask, timestep, pe, transformer_options={}, **kwargs
|
||||||
|
):
|
||||||
|
vx = x[0]
|
||||||
|
ax = x[1]
|
||||||
|
v_context = context[0]
|
||||||
|
a_context = context[1]
|
||||||
|
v_timestep = timestep[0]
|
||||||
|
a_timestep = timestep[1]
|
||||||
|
v_pe, av_cross_video_freq_cis = pe[0]
|
||||||
|
a_pe, av_cross_audio_freq_cis = pe[1]
|
||||||
|
|
||||||
|
(
|
||||||
|
av_ca_audio_scale_shift_timestep,
|
||||||
|
av_ca_video_scale_shift_timestep,
|
||||||
|
av_ca_a2v_gate_noise_timestep,
|
||||||
|
av_ca_v2a_gate_noise_timestep,
|
||||||
|
) = timestep[2]
|
||||||
|
|
||||||
|
"""Process transformer blocks for LTXAV."""
|
||||||
|
patches_replace = transformer_options.get("patches_replace", {})
|
||||||
|
blocks_replace = patches_replace.get("dit", {})
|
||||||
|
|
||||||
|
# Process transformer blocks
|
||||||
|
for i, block in enumerate(self.transformer_blocks):
|
||||||
|
if ("double_block", i) in blocks_replace:
|
||||||
|
|
||||||
|
def block_wrap(args):
|
||||||
|
out = {}
|
||||||
|
out["img"] = block(
|
||||||
|
args["img"],
|
||||||
|
v_context=args["v_context"],
|
||||||
|
a_context=args["a_context"],
|
||||||
|
attention_mask=args["attention_mask"],
|
||||||
|
v_timestep=args["v_timestep"],
|
||||||
|
a_timestep=args["a_timestep"],
|
||||||
|
v_pe=args["v_pe"],
|
||||||
|
a_pe=args["a_pe"],
|
||||||
|
v_cross_pe=args["v_cross_pe"],
|
||||||
|
a_cross_pe=args["a_cross_pe"],
|
||||||
|
v_cross_scale_shift_timestep=args["v_cross_scale_shift_timestep"],
|
||||||
|
a_cross_scale_shift_timestep=args["a_cross_scale_shift_timestep"],
|
||||||
|
v_cross_gate_timestep=args["v_cross_gate_timestep"],
|
||||||
|
a_cross_gate_timestep=args["a_cross_gate_timestep"],
|
||||||
|
transformer_options=args["transformer_options"],
|
||||||
|
)
|
||||||
|
return out
|
||||||
|
|
||||||
|
out = blocks_replace[("double_block", i)](
|
||||||
|
{
|
||||||
|
"img": (vx, ax),
|
||||||
|
"v_context": v_context,
|
||||||
|
"a_context": a_context,
|
||||||
|
"attention_mask": attention_mask,
|
||||||
|
"v_timestep": v_timestep,
|
||||||
|
"a_timestep": a_timestep,
|
||||||
|
"v_pe": v_pe,
|
||||||
|
"a_pe": a_pe,
|
||||||
|
"v_cross_pe": av_cross_video_freq_cis,
|
||||||
|
"a_cross_pe": av_cross_audio_freq_cis,
|
||||||
|
"v_cross_scale_shift_timestep": av_ca_video_scale_shift_timestep,
|
||||||
|
"a_cross_scale_shift_timestep": av_ca_audio_scale_shift_timestep,
|
||||||
|
"v_cross_gate_timestep": av_ca_a2v_gate_noise_timestep,
|
||||||
|
"a_cross_gate_timestep": av_ca_v2a_gate_noise_timestep,
|
||||||
|
"transformer_options": transformer_options,
|
||||||
|
},
|
||||||
|
{"original_block": block_wrap},
|
||||||
|
)
|
||||||
|
vx, ax = out["img"]
|
||||||
|
else:
|
||||||
|
vx, ax = block(
|
||||||
|
(vx, ax),
|
||||||
|
v_context=v_context,
|
||||||
|
a_context=a_context,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
v_timestep=v_timestep,
|
||||||
|
a_timestep=a_timestep,
|
||||||
|
v_pe=v_pe,
|
||||||
|
a_pe=a_pe,
|
||||||
|
v_cross_pe=av_cross_video_freq_cis,
|
||||||
|
a_cross_pe=av_cross_audio_freq_cis,
|
||||||
|
v_cross_scale_shift_timestep=av_ca_video_scale_shift_timestep,
|
||||||
|
a_cross_scale_shift_timestep=av_ca_audio_scale_shift_timestep,
|
||||||
|
v_cross_gate_timestep=av_ca_a2v_gate_noise_timestep,
|
||||||
|
a_cross_gate_timestep=av_ca_v2a_gate_noise_timestep,
|
||||||
|
transformer_options=transformer_options,
|
||||||
|
)
|
||||||
|
|
||||||
|
return [vx, ax]
|
||||||
|
|
||||||
|
def _process_output(self, x, embedded_timestep, keyframe_idxs, **kwargs):
|
||||||
|
vx = x[0]
|
||||||
|
ax = x[1]
|
||||||
|
v_embedded_timestep = embedded_timestep[0]
|
||||||
|
a_embedded_timestep = embedded_timestep[1]
|
||||||
|
vx = super()._process_output(vx, v_embedded_timestep, keyframe_idxs, **kwargs)
|
||||||
|
|
||||||
|
# Process audio output
|
||||||
|
a_scale_shift_values = (
|
||||||
|
self.audio_scale_shift_table[None, None].to(device=a_embedded_timestep.device, dtype=a_embedded_timestep.dtype)
|
||||||
|
+ a_embedded_timestep[:, :, None]
|
||||||
|
)
|
||||||
|
a_shift, a_scale = a_scale_shift_values[:, :, 0], a_scale_shift_values[:, :, 1]
|
||||||
|
|
||||||
|
ax = self.audio_norm_out(ax)
|
||||||
|
ax = ax * (1 + a_scale) + a_shift
|
||||||
|
ax = self.audio_proj_out(ax)
|
||||||
|
|
||||||
|
# Unpatchify audio
|
||||||
|
ax = self.a_patchifier.unpatchify(
|
||||||
|
ax, channels=self.num_audio_channels, freq=self.audio_frequency_bins
|
||||||
|
)
|
||||||
|
|
||||||
|
# Recombine audio and video
|
||||||
|
original_shape = kwargs.get("av_orig_shape")
|
||||||
|
return self.recombine_audio_and_video_latents(vx, ax, original_shape)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x,
|
||||||
|
timestep,
|
||||||
|
context,
|
||||||
|
attention_mask=None,
|
||||||
|
frame_rate=25,
|
||||||
|
transformer_options={},
|
||||||
|
keyframe_idxs=None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Forward pass for LTXAV model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Combined audio-video input tensor
|
||||||
|
timestep: Tuple of (video_timestep, audio_timestep) or single timestep
|
||||||
|
context: Context tensor (e.g., text embeddings)
|
||||||
|
attention_mask: Attention mask tensor
|
||||||
|
frame_rate: Frame rate for temporal processing
|
||||||
|
transformer_options: Additional options for transformer blocks
|
||||||
|
keyframe_idxs: Keyframe indices for temporal processing
|
||||||
|
**kwargs: Additional keyword arguments including audio_length
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Combined audio-video output tensor
|
||||||
|
"""
|
||||||
|
# Handle timestep format
|
||||||
|
if isinstance(timestep, (tuple, list)) and len(timestep) == 2:
|
||||||
|
v_timestep, a_timestep = timestep
|
||||||
|
kwargs["a_timestep"] = a_timestep
|
||||||
|
timestep = v_timestep
|
||||||
|
else:
|
||||||
|
kwargs["a_timestep"] = timestep
|
||||||
|
|
||||||
|
# Call parent forward method
|
||||||
|
return super().forward(
|
||||||
|
x,
|
||||||
|
timestep,
|
||||||
|
context,
|
||||||
|
attention_mask,
|
||||||
|
frame_rate,
|
||||||
|
transformer_options,
|
||||||
|
keyframe_idxs,
|
||||||
|
**kwargs,
|
||||||
|
)
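    # Illustrative sketch (not part of the diff): the forward above accepts either a single
    # timestep, which is reused as the audio timestep, or a (video_timestep, audio_timestep)
    # pair. With placeholder tensors this looks roughly like:
    #
    #   out = model([video_latent, audio_latent], (video_timestep, audio_timestep),
    #               context, frame_rate=25, audio_length=audio_length)
    #
    # where `out` is [video_out, audio_out] when an audio latent is present.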
|
||||||
305 comfy/ldm/lightricks/embeddings_connector.py Normal file
@ -0,0 +1,305 @@
|
|||||||
|
import math
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import comfy.ldm.common_dit
|
||||||
|
import torch
|
||||||
|
from comfy.ldm.lightricks.model import (
|
||||||
|
CrossAttention,
|
||||||
|
FeedForward,
|
||||||
|
generate_freq_grid_np,
|
||||||
|
interleaved_freqs_cis,
|
||||||
|
split_freqs_cis,
|
||||||
|
)
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
|
||||||
|
class BasicTransformerBlock1D(nn.Module):
|
||||||
|
r"""
|
||||||
|
A basic Transformer block.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
|
||||||
|
dim (`int`): The number of channels in the input and output.
|
||||||
|
num_attention_heads (`int`): The number of heads to use for multi-head attention.
|
||||||
|
attention_head_dim (`int`): The number of channels in each head.
|
||||||
|
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
||||||
|
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
||||||
|
attention_bias (:
|
||||||
|
obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
|
||||||
|
upcast_attention (`bool`, *optional*):
|
||||||
|
Whether to upcast the attention computation to float32. This is useful for mixed precision training.
|
||||||
|
norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether to use learnable elementwise affine parameters for normalization.
|
||||||
|
standardization_norm (`str`, *optional*, defaults to `"layer_norm"`): The type of pre-normalization to use. Can be `"layer_norm"` or `"rms_norm"`.
|
||||||
|
norm_eps (`float`, *optional*, defaults to 1e-5): Epsilon value for normalization layers.
|
||||||
|
qk_norm (`str`, *optional*, defaults to None):
|
||||||
|
Set to 'layer_norm' or `rms_norm` to perform query and key normalization.
|
||||||
|
final_dropout (`bool` *optional*, defaults to False):
|
||||||
|
Whether to apply a final dropout after the last feed-forward layer.
|
||||||
|
ff_inner_dim (`int`, *optional*): Dimension of the inner feed-forward layer. If not provided, defaults to `dim * 4`.
|
||||||
|
ff_bias (`bool`, *optional*, defaults to `True`): Whether to use bias in the feed-forward layer.
|
||||||
|
attention_out_bias (`bool`, *optional*, defaults to `True`): Whether to use bias in the attention output layer.
|
||||||
|
use_rope (`bool`, *optional*, defaults to `False`): Whether to use Rotary Position Embeddings (RoPE).
|
||||||
|
ffn_dim_mult (`int`, *optional*, defaults to 4): Multiplier for the inner dimension of the feed-forward layer.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim,
|
||||||
|
n_heads,
|
||||||
|
d_head,
|
||||||
|
context_dim=None,
|
||||||
|
attn_precision=None,
|
||||||
|
dtype=None,
|
||||||
|
device=None,
|
||||||
|
operations=None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
# Define 3 blocks. Each block has its own normalization layer.
|
||||||
|
# 1. Self-Attn
|
||||||
|
self.attn1 = CrossAttention(
|
||||||
|
query_dim=dim,
|
||||||
|
heads=n_heads,
|
||||||
|
dim_head=d_head,
|
||||||
|
context_dim=None,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
operations=operations,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 3. Feed-forward
|
||||||
|
self.ff = FeedForward(
|
||||||
|
dim,
|
||||||
|
dim_out=dim,
|
||||||
|
glu=True,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
operations=operations,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, hidden_states, attention_mask=None, pe=None) -> torch.FloatTensor:
|
||||||
|
|
||||||
|
# Notice that normalization is always applied before the real computation in the following blocks.
|
||||||
|
|
||||||
|
# 1. Normalization Before Self-Attention
|
||||||
|
norm_hidden_states = comfy.ldm.common_dit.rms_norm(hidden_states)
|
||||||
|
|
||||||
|
norm_hidden_states = norm_hidden_states.squeeze(1)
|
||||||
|
|
||||||
|
# 2. Self-Attention
|
||||||
|
attn_output = self.attn1(norm_hidden_states, mask=attention_mask, pe=pe)
|
||||||
|
|
||||||
|
hidden_states = attn_output + hidden_states
|
||||||
|
if hidden_states.ndim == 4:
|
||||||
|
hidden_states = hidden_states.squeeze(1)
|
||||||
|
|
||||||
|
# 3. Normalization before Feed-Forward
|
||||||
|
norm_hidden_states = comfy.ldm.common_dit.rms_norm(hidden_states)
|
||||||
|
|
||||||
|
# 4. Feed-forward
|
||||||
|
ff_output = self.ff(norm_hidden_states)
|
||||||
|
|
||||||
|
hidden_states = ff_output + hidden_states
|
||||||
|
if hidden_states.ndim == 4:
|
||||||
|
hidden_states = hidden_states.squeeze(1)
|
||||||
|
|
||||||
|
return hidden_states
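    # The block follows a pre-norm residual layout:
    #   h = h + SelfAttn(RMSNorm(h));  h = h + FeedForward(RMSNorm(h))
    # with rotary position embeddings applied inside the attention via `pe`.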
|
||||||
|
|
||||||
|
|
||||||
|
class Embeddings1DConnector(nn.Module):
|
||||||
|
_supports_gradient_checkpointing = True
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels=128,
|
||||||
|
cross_attention_dim=2048,
|
||||||
|
attention_head_dim=128,
|
||||||
|
num_attention_heads=30,
|
||||||
|
num_layers=2,
|
||||||
|
positional_embedding_theta=10000.0,
|
||||||
|
positional_embedding_max_pos=[4096],
|
||||||
|
causal_temporal_positioning=False,
|
||||||
|
num_learnable_registers: Optional[int] = 128,
|
||||||
|
dtype=None,
|
||||||
|
device=None,
|
||||||
|
operations=None,
|
||||||
|
split_rope=False,
|
||||||
|
double_precision_rope=False,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.dtype = dtype
|
||||||
|
self.out_channels = in_channels
|
||||||
|
self.num_attention_heads = num_attention_heads
|
||||||
|
self.inner_dim = num_attention_heads * attention_head_dim
|
||||||
|
self.causal_temporal_positioning = causal_temporal_positioning
|
||||||
|
self.positional_embedding_theta = positional_embedding_theta
|
||||||
|
self.positional_embedding_max_pos = positional_embedding_max_pos
|
||||||
|
self.split_rope = split_rope
|
||||||
|
self.double_precision_rope = double_precision_rope
|
||||||
|
self.transformer_1d_blocks = nn.ModuleList(
|
||||||
|
[
|
||||||
|
BasicTransformerBlock1D(
|
||||||
|
self.inner_dim,
|
||||||
|
num_attention_heads,
|
||||||
|
attention_head_dim,
|
||||||
|
context_dim=cross_attention_dim,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
operations=operations,
|
||||||
|
)
|
||||||
|
for _ in range(num_layers)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
inner_dim = num_attention_heads * attention_head_dim
|
||||||
|
self.num_learnable_registers = num_learnable_registers
|
||||||
|
if self.num_learnable_registers:
|
||||||
|
self.learnable_registers = nn.Parameter(
|
||||||
|
torch.rand(
|
||||||
|
self.num_learnable_registers, inner_dim, dtype=dtype, device=device
|
||||||
|
)
|
||||||
|
* 2.0
|
||||||
|
- 1.0
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_fractional_positions(self, indices_grid):
|
||||||
|
fractional_positions = torch.stack(
|
||||||
|
[
|
||||||
|
indices_grid[:, i] / self.positional_embedding_max_pos[i]
|
||||||
|
for i in range(1)
|
||||||
|
],
|
||||||
|
dim=-1,
|
||||||
|
)
|
||||||
|
return fractional_positions
|
||||||
|
|
||||||
|
def precompute_freqs(self, indices_grid, spacing):
|
||||||
|
source_dtype = indices_grid.dtype
|
||||||
|
dtype = (
|
||||||
|
torch.float32
|
||||||
|
if source_dtype in (torch.bfloat16, torch.float16)
|
||||||
|
else source_dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
fractional_positions = self.get_fractional_positions(indices_grid)
|
||||||
|
indices = (
|
||||||
|
generate_freq_grid_np(
|
||||||
|
self.positional_embedding_theta,
|
||||||
|
indices_grid.shape[1],
|
||||||
|
self.inner_dim,
|
||||||
|
)
|
||||||
|
if self.double_precision_rope
|
||||||
|
else self.generate_freq_grid(spacing, dtype, fractional_positions.device)
|
||||||
|
).to(device=fractional_positions.device)
|
||||||
|
|
||||||
|
if spacing == "exp_2":
|
||||||
|
freqs = (
|
||||||
|
(indices * fractional_positions.unsqueeze(-1))
|
||||||
|
.transpose(-1, -2)
|
||||||
|
.flatten(2)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
freqs = (
|
||||||
|
(indices * (fractional_positions.unsqueeze(-1) * 2 - 1))
|
||||||
|
.transpose(-1, -2)
|
||||||
|
.flatten(2)
|
||||||
|
)
|
||||||
|
return freqs
|
||||||
|
|
||||||
|
def generate_freq_grid(self, spacing, dtype, device):
|
||||||
|
dim = self.inner_dim
|
||||||
|
theta = self.positional_embedding_theta
|
||||||
|
n_pos_dims = 1
|
||||||
|
        n_elem = 2 * n_pos_dims  # 2 (cos and sin) per positional dimension, e.g. 2 x 3 = 6 in the 3D case
|
||||||
|
start = 1
|
||||||
|
end = theta
|
||||||
|
|
||||||
|
if spacing == "exp":
|
||||||
|
indices = theta ** (torch.arange(0, dim, n_elem, device="cpu", dtype=torch.float32) / (dim - n_elem))
|
||||||
|
indices = indices.to(dtype=dtype, device=device)
|
||||||
|
elif spacing == "exp_2":
|
||||||
|
indices = 1.0 / theta ** (torch.arange(0, dim, n_elem, device=device) / dim)
|
||||||
|
indices = indices.to(dtype=dtype)
|
||||||
|
elif spacing == "linear":
|
||||||
|
indices = torch.linspace(
|
||||||
|
start, end, dim // n_elem, device=device, dtype=dtype
|
||||||
|
)
|
||||||
|
elif spacing == "sqrt":
|
||||||
|
indices = torch.linspace(
|
||||||
|
start**2, end**2, dim // n_elem, device=device, dtype=dtype
|
||||||
|
).sqrt()
|
||||||
|
|
||||||
|
indices = indices * math.pi / 2
|
||||||
|
|
||||||
|
return indices
|
||||||
|
|
||||||
|
def precompute_freqs_cis(self, indices_grid, spacing="exp"):
|
||||||
|
dim = self.inner_dim
|
||||||
|
n_elem = 2 # 2 because of cos and sin
|
||||||
|
freqs = self.precompute_freqs(indices_grid, spacing)
|
||||||
|
if self.split_rope:
|
||||||
|
expected_freqs = dim // 2
|
||||||
|
current_freqs = freqs.shape[-1]
|
||||||
|
pad_size = expected_freqs - current_freqs
|
||||||
|
cos_freq, sin_freq = split_freqs_cis(
|
||||||
|
freqs, pad_size, self.num_attention_heads
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
cos_freq, sin_freq = interleaved_freqs_cis(freqs, dim % n_elem)
|
||||||
|
return cos_freq.to(self.dtype), sin_freq.to(self.dtype), self.split_rope
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
The [`Transformer2DModel`] forward method.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
|
||||||
|
Input `hidden_states`.
|
||||||
|
indices_grid (`torch.LongTensor` of shape `(batch size, 3, num latent pixels)`):
|
||||||
|
attention_mask ( `torch.Tensor`, *optional*):
|
||||||
|
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
|
||||||
|
is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
|
||||||
|
negative values to the attention scores corresponding to "discard" tokens.
|
||||||
|
Returns:
|
||||||
|
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
|
||||||
|
`tuple` where the first element is the sample tensor.
|
||||||
|
"""
|
||||||
|
# 1. Input
|
||||||
|
|
||||||
|
if self.num_learnable_registers:
|
||||||
|
num_registers_duplications = math.ceil(
|
||||||
|
max(1024, hidden_states.shape[1]) / self.num_learnable_registers
|
||||||
|
)
|
||||||
|
learnable_registers = torch.tile(
|
||||||
|
self.learnable_registers.to(hidden_states), (num_registers_duplications, 1)
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = torch.cat((hidden_states, learnable_registers[hidden_states.shape[1]:].unsqueeze(0).repeat(hidden_states.shape[0], 1, 1)), dim=1)
|
||||||
|
|
||||||
|
if attention_mask is not None:
|
||||||
|
attention_mask = torch.zeros([1, 1, 1, hidden_states.shape[1]], dtype=attention_mask.dtype, device=attention_mask.device)
|
||||||
|
|
||||||
|
indices_grid = torch.arange(
|
||||||
|
hidden_states.shape[1], dtype=torch.float32, device=hidden_states.device
|
||||||
|
)
|
||||||
|
indices_grid = indices_grid[None, None, :]
|
||||||
|
freqs_cis = self.precompute_freqs_cis(indices_grid)
|
||||||
|
|
||||||
|
# 2. Blocks
|
||||||
|
for block_idx, block in enumerate(self.transformer_1d_blocks):
|
||||||
|
hidden_states = block(
|
||||||
|
hidden_states, attention_mask=attention_mask, pe=freqs_cis
|
||||||
|
)
|
||||||
|
|
||||||
|
# 3. Output
|
||||||
|
# if self.output_scale is not None:
|
||||||
|
# hidden_states = hidden_states / self.output_scale
|
||||||
|
|
||||||
|
hidden_states = comfy.ldm.common_dit.rms_norm(hidden_states)
|
||||||
|
|
||||||
|
return hidden_states, attention_mask
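# Illustrative sketch (not part of the diff): standalone use of the connector. The
# `comfy.ops.disable_weight_init` operations provider is an assumption borrowed from how
# other comfy modules are instantiated.
#
#   import torch
#   import comfy.ops
#   connector = Embeddings1DConnector(dtype=torch.float32, device="cpu",
#                                     operations=comfy.ops.disable_weight_init)
#   tokens = torch.randn(1, 77, connector.inner_dim)
#   out, mask = connector(tokens, attention_mask=None)
#   # `out` is padded with tiled learnable registers up to at least 1024 tokens
#   # and RMS-normalized before being returned.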
|
||||||
292 comfy/ldm/lightricks/latent_upsampler.py Normal file
@ -0,0 +1,292 @@
|
|||||||
|
from typing import Optional, Tuple
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from einops import rearrange
|
||||||
|
|
||||||
|
|
||||||
|
def _rational_for_scale(scale: float) -> Tuple[int, int]:
|
||||||
|
mapping = {0.75: (3, 4), 1.5: (3, 2), 2.0: (2, 1), 4.0: (4, 1)}
|
||||||
|
if float(scale) not in mapping:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unsupported spatial_scale {scale}. Choose from {list(mapping.keys())}"
|
||||||
|
)
|
||||||
|
return mapping[float(scale)]
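# For example, _rational_for_scale(1.5) returns (3, 2): upsample by 3 with a pixel shuffle,
# then blur-downsample with a stride of 2, for a net 1.5x spatial scale.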
|
||||||
|
|
||||||
|
|
||||||
|
class PixelShuffleND(nn.Module):
|
||||||
|
def __init__(self, dims, upscale_factors=(2, 2, 2)):
|
||||||
|
super().__init__()
|
||||||
|
assert dims in [1, 2, 3], "dims must be 1, 2, or 3"
|
||||||
|
self.dims = dims
|
||||||
|
self.upscale_factors = upscale_factors
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
if self.dims == 3:
|
||||||
|
return rearrange(
|
||||||
|
x,
|
||||||
|
"b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)",
|
||||||
|
p1=self.upscale_factors[0],
|
||||||
|
p2=self.upscale_factors[1],
|
||||||
|
p3=self.upscale_factors[2],
|
||||||
|
)
|
||||||
|
elif self.dims == 2:
|
||||||
|
return rearrange(
|
||||||
|
x,
|
||||||
|
"b (c p1 p2) h w -> b c (h p1) (w p2)",
|
||||||
|
p1=self.upscale_factors[0],
|
||||||
|
p2=self.upscale_factors[1],
|
||||||
|
)
|
||||||
|
elif self.dims == 1:
|
||||||
|
return rearrange(
|
||||||
|
x,
|
||||||
|
"b (c p1) f h w -> b c (f p1) h w",
|
||||||
|
p1=self.upscale_factors[0],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class BlurDownsample(nn.Module):
|
||||||
|
"""
|
||||||
|
Anti-aliased spatial downsampling by integer stride using a fixed separable binomial kernel.
|
||||||
|
Applies only on H,W. Works for dims=2 or dims=3 (per-frame).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, dims: int, stride: int):
|
||||||
|
super().__init__()
|
||||||
|
assert dims in (2, 3)
|
||||||
|
assert stride >= 1 and isinstance(stride, int)
|
||||||
|
self.dims = dims
|
||||||
|
self.stride = stride
|
||||||
|
|
||||||
|
# 5x5 separable binomial kernel [1,4,6,4,1] (outer product), normalized
|
||||||
|
k = torch.tensor([1.0, 4.0, 6.0, 4.0, 1.0])
|
||||||
|
k2d = k[:, None] @ k[None, :]
|
||||||
|
k2d = (k2d / k2d.sum()).float() # shape (5,5)
|
||||||
|
self.register_buffer("kernel", k2d[None, None, :, :]) # (1,1,5,5)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
if self.stride == 1:
|
||||||
|
return x
|
||||||
|
|
||||||
|
def _apply_2d(x2d: torch.Tensor) -> torch.Tensor:
|
||||||
|
# x2d: (B, C, H, W)
|
||||||
|
B, C, H, W = x2d.shape
|
||||||
|
weight = self.kernel.expand(C, 1, 5, 5) # depthwise
|
||||||
|
x2d = F.conv2d(
|
||||||
|
x2d, weight=weight, bias=None, stride=self.stride, padding=2, groups=C
|
||||||
|
)
|
||||||
|
return x2d
|
||||||
|
|
||||||
|
if self.dims == 2:
|
||||||
|
return _apply_2d(x)
|
||||||
|
else:
|
||||||
|
# dims == 3: apply per-frame on H,W
|
||||||
|
b, c, f, h, w = x.shape
|
||||||
|
x = rearrange(x, "b c f h w -> (b f) c h w")
|
||||||
|
x = _apply_2d(x)
|
||||||
|
h2, w2 = x.shape[-2:]
|
||||||
|
x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f, h=h2, w=w2)
|
||||||
|
return x
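    # The fixed 5x5 kernel is the outer product of the binomial row [1, 4, 6, 4, 1]
    # (summing to 256 before normalization) and is applied depthwise (groups=C), so each
    # channel is low-pass filtered independently before the strided subsampling.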
|
||||||
|
|
||||||
|
|
||||||
|
class SpatialRationalResampler(nn.Module):
|
||||||
|
"""
|
||||||
|
Fully-learned rational spatial scaling: up by 'num' via PixelShuffle, then anti-aliased
|
||||||
|
downsample by 'den' using fixed blur + stride. Operates on H,W only.
|
||||||
|
|
||||||
|
For dims==3, work per-frame for spatial scaling (temporal axis untouched).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, mid_channels: int, scale: float):
|
||||||
|
super().__init__()
|
||||||
|
self.scale = float(scale)
|
||||||
|
self.num, self.den = _rational_for_scale(self.scale)
|
||||||
|
self.conv = nn.Conv2d(
|
||||||
|
mid_channels, (self.num**2) * mid_channels, kernel_size=3, padding=1
|
||||||
|
)
|
||||||
|
self.pixel_shuffle = PixelShuffleND(2, upscale_factors=(self.num, self.num))
|
||||||
|
self.blur_down = BlurDownsample(dims=2, stride=self.den)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
b, c, f, h, w = x.shape
|
||||||
|
x = rearrange(x, "b c f h w -> (b f) c h w")
|
||||||
|
x = self.conv(x)
|
||||||
|
x = self.pixel_shuffle(x)
|
||||||
|
x = self.blur_down(x)
|
||||||
|
x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f)
|
||||||
|
return x
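    # Shape sketch (illustrative sizes): with scale=1.5 (num=3, den=2), an input of
    # (b, c, f, 32, 32) becomes (b*f, 9c, 32, 32) after the conv, (b*f, c, 96, 96) after
    # the pixel shuffle, and (b*f, c, 48, 48) after the blur-downsample, i.e. a net
    # (b, c, f, 48, 48) output.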
|
||||||
|
|
||||||
|
|
||||||
|
class ResBlock(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self, channels: int, mid_channels: Optional[int] = None, dims: int = 3
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
if mid_channels is None:
|
||||||
|
mid_channels = channels
|
||||||
|
|
||||||
|
Conv = nn.Conv2d if dims == 2 else nn.Conv3d
|
||||||
|
|
||||||
|
self.conv1 = Conv(channels, mid_channels, kernel_size=3, padding=1)
|
||||||
|
self.norm1 = nn.GroupNorm(32, mid_channels)
|
||||||
|
self.conv2 = Conv(mid_channels, channels, kernel_size=3, padding=1)
|
||||||
|
self.norm2 = nn.GroupNorm(32, channels)
|
||||||
|
self.activation = nn.SiLU()
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
residual = x
|
||||||
|
x = self.conv1(x)
|
||||||
|
x = self.norm1(x)
|
||||||
|
x = self.activation(x)
|
||||||
|
x = self.conv2(x)
|
||||||
|
x = self.norm2(x)
|
||||||
|
x = self.activation(x + residual)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class LatentUpsampler(nn.Module):
|
||||||
|
"""
|
||||||
|
Model to spatially upsample VAE latents.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
in_channels (`int`): Number of channels in the input latent
|
||||||
|
mid_channels (`int`): Number of channels in the middle layers
|
||||||
|
num_blocks_per_stage (`int`): Number of ResBlocks to use in each stage (pre/post upsampling)
|
||||||
|
dims (`int`): Number of dimensions for convolutions (2 or 3)
|
||||||
|
spatial_upsample (`bool`): Whether to spatially upsample the latent
|
||||||
|
temporal_upsample (`bool`): Whether to temporally upsample the latent
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int = 128,
|
||||||
|
mid_channels: int = 512,
|
||||||
|
num_blocks_per_stage: int = 4,
|
||||||
|
dims: int = 3,
|
||||||
|
spatial_upsample: bool = True,
|
||||||
|
temporal_upsample: bool = False,
|
||||||
|
spatial_scale: float = 2.0,
|
||||||
|
rational_resampler: bool = False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.mid_channels = mid_channels
|
||||||
|
self.num_blocks_per_stage = num_blocks_per_stage
|
||||||
|
self.dims = dims
|
||||||
|
self.spatial_upsample = spatial_upsample
|
||||||
|
self.temporal_upsample = temporal_upsample
|
||||||
|
self.spatial_scale = float(spatial_scale)
|
||||||
|
self.rational_resampler = rational_resampler
|
||||||
|
|
||||||
|
Conv = nn.Conv2d if dims == 2 else nn.Conv3d
|
||||||
|
|
||||||
|
self.initial_conv = Conv(in_channels, mid_channels, kernel_size=3, padding=1)
|
||||||
|
self.initial_norm = nn.GroupNorm(32, mid_channels)
|
||||||
|
self.initial_activation = nn.SiLU()
|
||||||
|
|
||||||
|
self.res_blocks = nn.ModuleList(
|
||||||
|
[ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)]
|
||||||
|
)
|
||||||
|
|
||||||
|
if spatial_upsample and temporal_upsample:
|
||||||
|
self.upsampler = nn.Sequential(
|
||||||
|
nn.Conv3d(mid_channels, 8 * mid_channels, kernel_size=3, padding=1),
|
||||||
|
PixelShuffleND(3),
|
||||||
|
)
|
||||||
|
elif spatial_upsample:
|
||||||
|
if rational_resampler:
|
||||||
|
self.upsampler = SpatialRationalResampler(
|
||||||
|
mid_channels=mid_channels, scale=self.spatial_scale
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.upsampler = nn.Sequential(
|
||||||
|
nn.Conv2d(mid_channels, 4 * mid_channels, kernel_size=3, padding=1),
|
||||||
|
PixelShuffleND(2),
|
||||||
|
)
|
||||||
|
elif temporal_upsample:
|
||||||
|
self.upsampler = nn.Sequential(
|
||||||
|
nn.Conv3d(mid_channels, 2 * mid_channels, kernel_size=3, padding=1),
|
||||||
|
PixelShuffleND(1),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"Either spatial_upsample or temporal_upsample must be True"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.post_upsample_res_blocks = nn.ModuleList(
|
||||||
|
[ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)]
|
||||||
|
)
|
||||||
|
|
||||||
|
self.final_conv = Conv(mid_channels, in_channels, kernel_size=3, padding=1)
|
||||||
|
|
||||||
|
def forward(self, latent: torch.Tensor) -> torch.Tensor:
|
||||||
|
b, c, f, h, w = latent.shape
|
||||||
|
|
||||||
|
if self.dims == 2:
|
||||||
|
x = rearrange(latent, "b c f h w -> (b f) c h w")
|
||||||
|
x = self.initial_conv(x)
|
||||||
|
x = self.initial_norm(x)
|
||||||
|
x = self.initial_activation(x)
|
||||||
|
|
||||||
|
for block in self.res_blocks:
|
||||||
|
x = block(x)
|
||||||
|
|
||||||
|
x = self.upsampler(x)
|
||||||
|
|
||||||
|
for block in self.post_upsample_res_blocks:
|
||||||
|
x = block(x)
|
||||||
|
|
||||||
|
x = self.final_conv(x)
|
||||||
|
x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f)
|
||||||
|
else:
|
||||||
|
x = self.initial_conv(latent)
|
||||||
|
x = self.initial_norm(x)
|
||||||
|
x = self.initial_activation(x)
|
||||||
|
|
||||||
|
for block in self.res_blocks:
|
||||||
|
x = block(x)
|
||||||
|
|
||||||
|
if self.temporal_upsample:
|
||||||
|
x = self.upsampler(x)
|
||||||
|
x = x[:, :, 1:, :, :]
|
||||||
|
else:
|
||||||
|
if isinstance(self.upsampler, SpatialRationalResampler):
|
||||||
|
x = self.upsampler(x)
|
||||||
|
else:
|
||||||
|
x = rearrange(x, "b c f h w -> (b f) c h w")
|
||||||
|
x = self.upsampler(x)
|
||||||
|
x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f)
|
||||||
|
|
||||||
|
for block in self.post_upsample_res_blocks:
|
||||||
|
x = block(x)
|
||||||
|
|
||||||
|
x = self.final_conv(x)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, config):
|
||||||
|
return cls(
|
||||||
|
in_channels=config.get("in_channels", 4),
|
||||||
|
mid_channels=config.get("mid_channels", 128),
|
||||||
|
num_blocks_per_stage=config.get("num_blocks_per_stage", 4),
|
||||||
|
dims=config.get("dims", 2),
|
||||||
|
spatial_upsample=config.get("spatial_upsample", True),
|
||||||
|
temporal_upsample=config.get("temporal_upsample", False),
|
||||||
|
spatial_scale=config.get("spatial_scale", 2.0),
|
||||||
|
rational_resampler=config.get("rational_resampler", False),
|
||||||
|
)
|
||||||
|
|
||||||
|
def config(self):
|
||||||
|
return {
|
||||||
|
"_class_name": "LatentUpsampler",
|
||||||
|
"in_channels": self.in_channels,
|
||||||
|
"mid_channels": self.mid_channels,
|
||||||
|
"num_blocks_per_stage": self.num_blocks_per_stage,
|
||||||
|
"dims": self.dims,
|
||||||
|
"spatial_upsample": self.spatial_upsample,
|
||||||
|
"temporal_upsample": self.temporal_upsample,
|
||||||
|
"spatial_scale": self.spatial_scale,
|
||||||
|
"rational_resampler": self.rational_resampler,
|
||||||
|
}
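# Illustrative sketch (not part of the diff): typical 3D (video) usage of the upsampler.
#
#   import torch
#   upsampler = LatentUpsampler(in_channels=128, mid_channels=512, dims=3,
#                               spatial_upsample=True, temporal_upsample=False)
#   latent = torch.randn(1, 128, 9, 16, 16)   # (b, c, f, h, w) VAE latent
#   up = upsampler(latent)                    # -> (1, 128, 9, 32, 32) via the default 2x pixel-shuffle path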
|
||||||
@ -1,13 +1,47 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from enum import Enum
|
||||||
|
import functools
|
||||||
|
import math
|
||||||
|
from typing import Dict, Optional, Tuple
|
||||||
|
|
||||||
|
from einops import rearrange
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
import comfy.patcher_extension
|
import comfy.patcher_extension
|
||||||
import comfy.ldm.modules.attention
|
import comfy.ldm.modules.attention
|
||||||
import comfy.ldm.common_dit
|
import comfy.ldm.common_dit
|
||||||
import math
|
|
||||||
from typing import Dict, Optional, Tuple
|
|
||||||
|
|
||||||
from .symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords
|
from .symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords
|
||||||
from comfy.ldm.flux.math import apply_rope1
|
|
||||||
|
def _log_base(x, base):
|
||||||
|
return np.log(x) / np.log(base)
|
||||||
|
|
||||||
|
class LTXRopeType(str, Enum):
|
||||||
|
INTERLEAVED = "interleaved"
|
||||||
|
SPLIT = "split"
|
||||||
|
|
||||||
|
KEY = "rope_type"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, kwargs, default=None):
|
||||||
|
if default is None:
|
||||||
|
default = cls.INTERLEAVED
|
||||||
|
return cls(kwargs.get(cls.KEY, default))
|
||||||
|
|
||||||
|
|
||||||
|
class LTXFrequenciesPrecision(str, Enum):
|
||||||
|
FLOAT32 = "float32"
|
||||||
|
FLOAT64 = "float64"
|
||||||
|
|
||||||
|
KEY = "frequencies_precision"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, kwargs, default=None):
|
||||||
|
if default is None:
|
||||||
|
default = cls.FLOAT32
|
||||||
|
return cls(kwargs.get(cls.KEY, default))
|
||||||
|
|
||||||
|
|
||||||
def get_timestep_embedding(
|
def get_timestep_embedding(
|
||||||
timesteps: torch.Tensor,
|
timesteps: torch.Tensor,
|
||||||
@ -39,9 +73,7 @@ def get_timestep_embedding(
|
|||||||
assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
|
assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
|
||||||
|
|
||||||
half_dim = embedding_dim // 2
|
half_dim = embedding_dim // 2
|
||||||
exponent = -math.log(max_period) * torch.arange(
|
exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device)
|
||||||
start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
|
|
||||||
)
|
|
||||||
exponent = exponent / (half_dim - downscale_freq_shift)
|
exponent = exponent / (half_dim - downscale_freq_shift)
|
||||||
|
|
||||||
emb = torch.exp(exponent)
|
emb = torch.exp(exponent)
|
||||||
@ -73,7 +105,9 @@ class TimestepEmbedding(nn.Module):
|
|||||||
post_act_fn: Optional[str] = None,
|
post_act_fn: Optional[str] = None,
|
||||||
cond_proj_dim=None,
|
cond_proj_dim=None,
|
||||||
sample_proj_bias=True,
|
sample_proj_bias=True,
|
||||||
dtype=None, device=None, operations=None,
|
dtype=None,
|
||||||
|
device=None,
|
||||||
|
operations=None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@ -90,7 +124,9 @@ class TimestepEmbedding(nn.Module):
|
|||||||
time_embed_dim_out = out_dim
|
time_embed_dim_out = out_dim
|
||||||
else:
|
else:
|
||||||
time_embed_dim_out = time_embed_dim
|
time_embed_dim_out = time_embed_dim
|
||||||
self.linear_2 = operations.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias, dtype=dtype, device=device)
|
self.linear_2 = operations.Linear(
|
||||||
|
time_embed_dim, time_embed_dim_out, sample_proj_bias, dtype=dtype, device=device
|
||||||
|
)
|
||||||
|
|
||||||
if post_act_fn is None:
|
if post_act_fn is None:
|
||||||
self.post_act = None
|
self.post_act = None
|
||||||
@ -139,12 +175,22 @@ class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
|
|||||||
https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29
|
https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool = False, dtype=None, device=None, operations=None):
|
def __init__(
|
||||||
|
self,
|
||||||
|
embedding_dim,
|
||||||
|
size_emb_dim,
|
||||||
|
use_additional_conditions: bool = False,
|
||||||
|
dtype=None,
|
||||||
|
device=None,
|
||||||
|
operations=None,
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.outdim = size_emb_dim
|
self.outdim = size_emb_dim
|
||||||
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
|
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
|
||||||
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim, dtype=dtype, device=device, operations=operations)
|
self.timestep_embedder = TimestepEmbedding(
|
||||||
|
in_channels=256, time_embed_dim=embedding_dim, dtype=dtype, device=device, operations=operations
|
||||||
|
)
|
||||||
|
|
||||||
def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype):
|
def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype):
|
||||||
timesteps_proj = self.time_proj(timestep)
|
timesteps_proj = self.time_proj(timestep)
|
||||||
@ -163,15 +209,22 @@ class AdaLayerNormSingle(nn.Module):
|
|||||||
use_additional_conditions (`bool`): To use additional conditions for normalization or not.
|
use_additional_conditions (`bool`): To use additional conditions for normalization or not.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, embedding_dim: int, use_additional_conditions: bool = False, dtype=None, device=None, operations=None):
|
def __init__(
|
||||||
|
self, embedding_dim: int, embedding_coefficient: int = 6, use_additional_conditions: bool = False, dtype=None, device=None, operations=None
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings(
|
self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings(
|
||||||
embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions, dtype=dtype, device=device, operations=operations
|
embedding_dim,
|
||||||
|
size_emb_dim=embedding_dim // 3,
|
||||||
|
use_additional_conditions=use_additional_conditions,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
operations=operations,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.silu = nn.SiLU()
|
self.silu = nn.SiLU()
|
||||||
self.linear = operations.Linear(embedding_dim, 6 * embedding_dim, bias=True, dtype=dtype, device=device)
|
self.linear = operations.Linear(embedding_dim, embedding_coefficient * embedding_dim, bias=True, dtype=dtype, device=device)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -185,6 +238,7 @@ class AdaLayerNormSingle(nn.Module):
|
|||||||
embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype)
|
embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype)
|
||||||
return self.linear(self.silu(embedded_timestep)), embedded_timestep
|
return self.linear(self.silu(embedded_timestep)), embedded_timestep
|
||||||
|
|
||||||
|
|
||||||
class PixArtAlphaTextProjection(nn.Module):
|
class PixArtAlphaTextProjection(nn.Module):
|
||||||
"""
|
"""
|
||||||
Projects caption embeddings. Also handles dropout for classifier-free guidance.
|
Projects caption embeddings. Also handles dropout for classifier-free guidance.
|
||||||
@ -192,18 +246,24 @@ class PixArtAlphaTextProjection(nn.Module):
|
|||||||
Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
|
Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh", dtype=None, device=None, operations=None):
|
def __init__(
|
||||||
|
self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh", dtype=None, device=None, operations=None
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
if out_features is None:
|
if out_features is None:
|
||||||
out_features = hidden_size
|
out_features = hidden_size
|
||||||
self.linear_1 = operations.Linear(in_features=in_features, out_features=hidden_size, bias=True, dtype=dtype, device=device)
|
self.linear_1 = operations.Linear(
|
||||||
|
in_features=in_features, out_features=hidden_size, bias=True, dtype=dtype, device=device
|
||||||
|
)
|
||||||
if act_fn == "gelu_tanh":
|
if act_fn == "gelu_tanh":
|
||||||
self.act_1 = nn.GELU(approximate="tanh")
|
self.act_1 = nn.GELU(approximate="tanh")
|
||||||
elif act_fn == "silu":
|
elif act_fn == "silu":
|
||||||
self.act_1 = nn.SiLU()
|
self.act_1 = nn.SiLU()
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown activation function: {act_fn}")
|
raise ValueError(f"Unknown activation function: {act_fn}")
|
||||||
self.linear_2 = operations.Linear(in_features=hidden_size, out_features=out_features, bias=True, dtype=dtype, device=device)
|
self.linear_2 = operations.Linear(
|
||||||
|
in_features=hidden_size, out_features=out_features, bias=True, dtype=dtype, device=device
|
||||||
|
)
|
||||||
|
|
||||||
def forward(self, caption):
|
def forward(self, caption):
|
||||||
hidden_states = self.linear_1(caption)
|
hidden_states = self.linear_1(caption)
|
||||||
@ -222,23 +282,68 @@ class GELU_approx(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class FeedForward(nn.Module):
|
class FeedForward(nn.Module):
|
||||||
def __init__(self, dim, dim_out, mult=4, glu=False, dropout=0., dtype=None, device=None, operations=None):
|
def __init__(self, dim, dim_out, mult=4, glu=False, dropout=0.0, dtype=None, device=None, operations=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
inner_dim = int(dim * mult)
|
inner_dim = int(dim * mult)
|
||||||
project_in = GELU_approx(dim, inner_dim, dtype=dtype, device=device, operations=operations)
|
project_in = GELU_approx(dim, inner_dim, dtype=dtype, device=device, operations=operations)
|
||||||
|
|
||||||
self.net = nn.Sequential(
|
self.net = nn.Sequential(
|
||||||
project_in,
|
project_in, nn.Dropout(dropout), operations.Linear(inner_dim, dim_out, dtype=dtype, device=device)
|
||||||
nn.Dropout(dropout),
|
|
||||||
operations.Linear(inner_dim, dim_out, dtype=dtype, device=device)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return self.net(x)
|
return self.net(x)
|
||||||
|
|
||||||
|
def apply_rotary_emb(input_tensor, freqs_cis):
|
||||||
|
cos_freqs, sin_freqs = freqs_cis[0], freqs_cis[1]
|
||||||
|
split_pe = freqs_cis[2] if len(freqs_cis) > 2 else False
|
||||||
|
return (
|
||||||
|
apply_split_rotary_emb(input_tensor, cos_freqs, sin_freqs)
|
||||||
|
if split_pe else
|
||||||
|
apply_interleaved_rotary_emb(input_tensor, cos_freqs, sin_freqs)
|
||||||
|
)
|
||||||
|
|
||||||
|
def apply_interleaved_rotary_emb(input_tensor, cos_freqs, sin_freqs): # TODO: remove duplicate funcs and pick the best/fastest one
|
||||||
|
t_dup = rearrange(input_tensor, "... (d r) -> ... d r", r=2)
|
||||||
|
t1, t2 = t_dup.unbind(dim=-1)
|
||||||
|
t_dup = torch.stack((-t2, t1), dim=-1)
|
||||||
|
input_tensor_rot = rearrange(t_dup, "... d r -> ... (d r)")
|
||||||
|
|
||||||
|
out = input_tensor * cos_freqs + input_tensor_rot * sin_freqs
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
def apply_split_rotary_emb(input_tensor, cos, sin):
|
||||||
|
needs_reshape = False
|
||||||
|
if input_tensor.ndim != 4 and cos.ndim == 4:
|
||||||
|
B, H, T, _ = cos.shape
|
||||||
|
input_tensor = input_tensor.reshape(B, T, H, -1).swapaxes(1, 2)
|
||||||
|
needs_reshape = True
|
||||||
|
split_input = rearrange(input_tensor, "... (d r) -> ... d r", d=2)
|
||||||
|
first_half_input = split_input[..., :1, :]
|
||||||
|
second_half_input = split_input[..., 1:, :]
|
||||||
|
output = split_input * cos.unsqueeze(-2)
|
||||||
|
first_half_output = output[..., :1, :]
|
||||||
|
second_half_output = output[..., 1:, :]
|
||||||
|
first_half_output.addcmul_(-sin.unsqueeze(-2), second_half_input)
|
||||||
|
second_half_output.addcmul_(sin.unsqueeze(-2), first_half_input)
|
||||||
|
output = rearrange(output, "... d r -> ... (d r)")
|
||||||
|
return output.swapaxes(1, 2).reshape(B, T, -1) if needs_reshape else output
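# In both rotary variants the underlying rotation of each feature pair (a, b) by angle t is
#   (a', b') = (a*cos(t) - b*sin(t), a*sin(t) + b*cos(t));
# the interleaved path stores the pair adjacently along the last dimension, while the split
# path keeps the two halves on a separate axis of size 2 and rotates them in place.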
|
||||||
|
|
||||||
|
|
||||||
class CrossAttention(nn.Module):
|
class CrossAttention(nn.Module):
|
||||||
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., attn_precision=None, dtype=None, device=None, operations=None):
|
def __init__(
|
||||||
|
self,
|
||||||
|
query_dim,
|
||||||
|
context_dim=None,
|
||||||
|
heads=8,
|
||||||
|
dim_head=64,
|
||||||
|
dropout=0.0,
|
||||||
|
attn_precision=None,
|
||||||
|
dtype=None,
|
||||||
|
device=None,
|
||||||
|
operations=None,
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
inner_dim = dim_head * heads
|
inner_dim = dim_head * heads
|
||||||
context_dim = query_dim if context_dim is None else context_dim
|
context_dim = query_dim if context_dim is None else context_dim
|
||||||
@ -254,9 +359,11 @@ class CrossAttention(nn.Module):
|
|||||||
self.to_k = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)
|
self.to_k = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)
|
||||||
self.to_v = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)
|
self.to_v = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)
|
||||||
|
|
||||||
self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
|
self.to_out = nn.Sequential(
|
||||||
|
operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout)
|
||||||
|
)
|
||||||
|
|
||||||
def forward(self, x, context=None, mask=None, pe=None, transformer_options={}):
|
def forward(self, x, context=None, mask=None, pe=None, k_pe=None, transformer_options={}):
|
||||||
q = self.to_q(x)
|
q = self.to_q(x)
|
||||||
context = x if context is None else context
|
context = x if context is None else context
|
||||||
k = self.to_k(context)
|
k = self.to_k(context)
|
||||||
@ -266,8 +373,8 @@ class CrossAttention(nn.Module):
|
|||||||
k = self.k_norm(k)
|
k = self.k_norm(k)
|
||||||
|
|
||||||
if pe is not None:
|
if pe is not None:
|
||||||
q = apply_rope1(q.unsqueeze(1), pe).squeeze(1)
|
q = apply_rotary_emb(q, pe)
|
||||||
k = apply_rope1(k.unsqueeze(1), pe).squeeze(1)
|
k = apply_rotary_emb(k, pe if k_pe is None else k_pe)
|
||||||
|
|
||||||
if mask is None:
|
if mask is None:
|
||||||
out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options)
|
out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options)
|
||||||
@ -277,14 +384,34 @@ class CrossAttention(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class BasicTransformerBlock(nn.Module):
|
class BasicTransformerBlock(nn.Module):
|
||||||
def __init__(self, dim, n_heads, d_head, context_dim=None, attn_precision=None, dtype=None, device=None, operations=None):
|
def __init__(
|
||||||
|
self, dim, n_heads, d_head, context_dim=None, attn_precision=None, dtype=None, device=None, operations=None
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.attn_precision = attn_precision
|
self.attn_precision = attn_precision
|
||||||
self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, context_dim=None, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations)
|
self.attn1 = CrossAttention(
|
||||||
|
query_dim=dim,
|
||||||
|
heads=n_heads,
|
||||||
|
dim_head=d_head,
|
||||||
|
context_dim=None,
|
||||||
|
attn_precision=self.attn_precision,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
operations=operations,
|
||||||
|
)
|
||||||
self.ff = FeedForward(dim, dim_out=dim, glu=True, dtype=dtype, device=device, operations=operations)
|
self.ff = FeedForward(dim, dim_out=dim, glu=True, dtype=dtype, device=device, operations=operations)
|
||||||
|
|
||||||
self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations)
|
self.attn2 = CrossAttention(
|
||||||
|
query_dim=dim,
|
||||||
|
context_dim=context_dim,
|
||||||
|
heads=n_heads,
|
||||||
|
dim_head=d_head,
|
||||||
|
attn_precision=self.attn_precision,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
operations=operations,
|
||||||
|
)
|
||||||
|
|
||||||
self.scale_shift_table = nn.Parameter(torch.empty(6, dim, device=device, dtype=dtype))
|
self.scale_shift_table = nn.Parameter(torch.empty(6, dim, device=device, dtype=dtype))
|
||||||
|
|
||||||
@ -306,116 +433,446 @@ class BasicTransformerBlock(nn.Module):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
def get_fractional_positions(indices_grid, max_pos):
|
def get_fractional_positions(indices_grid, max_pos):
|
||||||
|
n_pos_dims = indices_grid.shape[1]
|
||||||
|
assert n_pos_dims == len(max_pos), f'Number of position dimensions ({n_pos_dims}) must match max_pos length ({len(max_pos)})'
|
||||||
fractional_positions = torch.stack(
|
fractional_positions = torch.stack(
|
||||||
[
|
[indices_grid[:, i] / max_pos[i] for i in range(n_pos_dims)],
|
||||||
indices_grid[:, i] / max_pos[i]
|
axis=-1,
|
||||||
for i in range(3)
|
|
||||||
],
|
|
||||||
dim=-1,
|
|
||||||
)
|
)
|
||||||
return fractional_positions
|
return fractional_positions
|
||||||
|
|
||||||
|
|
||||||
def precompute_freqs_cis(indices_grid, dim, out_dtype, theta=10000.0, max_pos=[20, 2048, 2048]):
|
@functools.lru_cache(maxsize=5)
|
||||||
dtype = torch.float32
|
def generate_freq_grid_np(positional_embedding_theta, positional_embedding_max_pos_count, inner_dim, _ = None):
|
||||||
device = indices_grid.device
|
theta = positional_embedding_theta
|
||||||
|
start = 1
|
||||||
|
end = theta
|
||||||
|
|
||||||
|
n_elem = 2 * positional_embedding_max_pos_count
|
||||||
|
pow_indices = np.power(
|
||||||
|
theta,
|
||||||
|
np.linspace(
|
||||||
|
_log_base(start, theta),
|
||||||
|
_log_base(end, theta),
|
||||||
|
inner_dim // n_elem,
|
||||||
|
dtype=np.float64,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
return torch.tensor(pow_indices * math.pi / 2, dtype=torch.float32)
|
||||||
|
|
||||||
|
def generate_freq_grid_pytorch(positional_embedding_theta, positional_embedding_max_pos_count, inner_dim, device):
|
||||||
|
theta = positional_embedding_theta
|
||||||
|
start = 1
|
||||||
|
end = theta
|
||||||
|
n_elem = 2 * positional_embedding_max_pos_count
|
||||||
|
|
||||||
|
indices = theta ** (
|
||||||
|
torch.linspace(
|
||||||
|
math.log(start, theta),
|
||||||
|
math.log(end, theta),
|
||||||
|
inner_dim // n_elem,
|
||||||
|
device=device,
|
||||||
|
dtype=torch.float32,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
indices = indices.to(dtype=torch.float32)
|
||||||
|
|
||||||
|
indices = indices * math.pi / 2
|
||||||
|
|
||||||
|
return indices
|
||||||
|
|
||||||
|
def generate_freqs(indices, indices_grid, max_pos, use_middle_indices_grid):
|
||||||
|
if use_middle_indices_grid:
|
||||||
|
assert(len(indices_grid.shape) == 4 and indices_grid.shape[-1] ==2)
|
||||||
|
indices_grid_start, indices_grid_end = indices_grid[..., 0], indices_grid[..., 1]
|
||||||
|
indices_grid = (indices_grid_start + indices_grid_end) / 2.0
|
||||||
|
elif len(indices_grid.shape) == 4:
|
||||||
|
indices_grid = indices_grid[..., 0]
|
||||||
|
|
||||||
# Get fractional positions and compute frequency indices
|
# Get fractional positions and compute frequency indices
|
||||||
fractional_positions = get_fractional_positions(indices_grid, max_pos)
|
fractional_positions = get_fractional_positions(indices_grid, max_pos)
|
||||||
indices = theta ** torch.linspace(0, 1, dim // 6, device=device, dtype=dtype) * math.pi / 2
|
indices = indices.to(device=fractional_positions.device)
|
||||||
|
|
||||||
# Compute frequencies and apply cos/sin
|
freqs = (
|
||||||
freqs = (indices * (fractional_positions.unsqueeze(-1) * 2 - 1)).transpose(-1, -2).flatten(2)
|
(indices * (fractional_positions.unsqueeze(-1) * 2 - 1))
|
||||||
cos_vals = freqs.cos().repeat_interleave(2, dim=-1)
|
.transpose(-1, -2)
|
||||||
sin_vals = freqs.sin().repeat_interleave(2, dim=-1)
|
.flatten(2)
|
||||||
|
)
|
||||||
|
return freqs
|
||||||
|
|
||||||
# Pad if dim is not divisible by 6
|
def interleaved_freqs_cis(freqs, pad_size):
|
||||||
if dim % 6 != 0:
|
cos_freq = freqs.cos().repeat_interleave(2, dim=-1)
|
||||||
padding_size = dim % 6
|
sin_freq = freqs.sin().repeat_interleave(2, dim=-1)
|
||||||
cos_vals = torch.cat([torch.ones_like(cos_vals[:, :, :padding_size]), cos_vals], dim=-1)
|
if pad_size != 0:
|
||||||
sin_vals = torch.cat([torch.zeros_like(sin_vals[:, :, :padding_size]), sin_vals], dim=-1)
|
cos_padding = torch.ones_like(cos_freq[:, :, : pad_size])
|
||||||
|
sin_padding = torch.zeros_like(cos_freq[:, :, : pad_size])
|
||||||
|
cos_freq = torch.cat([cos_padding, cos_freq], dim=-1)
|
||||||
|
sin_freq = torch.cat([sin_padding, sin_freq], dim=-1)
|
||||||
|
return cos_freq, sin_freq
|
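For context, a hedged sketch of how interleaved cos/sin tables of this shape are typically consumed by a rotary embedding; the actual application lives elsewhere in this file and may differ in detail, so this is a generic RoPE illustration only.

import torch

def rotate_pairs(t):
    # swap each (even, odd) pair to (-odd, even), matching repeat_interleave'd cos/sin tables
    t = t.reshape(*t.shape[:-1], -1, 2)
    return torch.stack((-t[..., 1], t[..., 0]), dim=-1).reshape(*t.shape[:-2], -1)

q = torch.randn(1, 16, 64)          # (batch, tokens, dim), made-up sizes
freqs = torch.randn(1, 16, 32)      # half-dim frequency table
cos = freqs.cos().repeat_interleave(2, dim=-1)
sin = freqs.sin().repeat_interleave(2, dim=-1)
q_rot = q * cos + rotate_pairs(q) * sin
print(q_rot.shape)                   # torch.Size([1, 16, 64])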
||||||
|
|
||||||
# Reshape and extract one value per pair (since repeat_interleave duplicates each value)
|
def split_freqs_cis(freqs, pad_size, num_attention_heads):
|
||||||
cos_vals = cos_vals.reshape(*cos_vals.shape[:2], -1, 2)[..., 0].to(out_dtype) # [B, N, dim//2]
|
cos_freq = freqs.cos()
|
||||||
sin_vals = sin_vals.reshape(*sin_vals.shape[:2], -1, 2)[..., 0].to(out_dtype) # [B, N, dim//2]
|
sin_freq = freqs.sin()
|
||||||
|
|
||||||
# Build rotation matrix [[cos, -sin], [sin, cos]] and add heads dimension
|
if pad_size != 0:
|
||||||
freqs_cis = torch.stack([
|
cos_padding = torch.ones_like(cos_freq[:, :, :pad_size])
|
||||||
torch.stack([cos_vals, -sin_vals], dim=-1),
|
sin_padding = torch.zeros_like(sin_freq[:, :, :pad_size])
|
||||||
torch.stack([sin_vals, cos_vals], dim=-1)
|
|
||||||
], dim=-2).unsqueeze(1) # [B, 1, N, dim//2, 2, 2]
|
|
||||||
|
|
||||||
return freqs_cis
|
cos_freq = torch.concatenate([cos_padding, cos_freq], axis=-1)
|
||||||
|
sin_freq = torch.concatenate([sin_padding, sin_freq], axis=-1)
|
||||||
|
|
||||||
|
# Reshape freqs to be compatible with multi-head attention
|
||||||
|
B , T, half_HD = cos_freq.shape
|
||||||
|
|
||||||
class LTXVModel(torch.nn.Module):
|
cos_freq = cos_freq.reshape(B, T, num_attention_heads, half_HD // num_attention_heads)
|
||||||
def __init__(self,
|
sin_freq = sin_freq.reshape(B, T, num_attention_heads, half_HD // num_attention_heads)
|
||||||
in_channels=128,
|
|
||||||
cross_attention_dim=2048,
|
|
||||||
attention_head_dim=64,
|
|
||||||
num_attention_heads=32,
|
|
||||||
|
|
||||||
caption_channels=4096,
|
cos_freq = torch.swapaxes(cos_freq, 1, 2) # (B,H,T,D//2)
|
||||||
num_layers=28,
|
sin_freq = torch.swapaxes(sin_freq, 1, 2) # (B,H,T,D//2)
|
||||||
|
return cos_freq, sin_freq
|
||||||
|
|
||||||
|
class LTXBaseModel(torch.nn.Module, ABC):
|
||||||
|
"""
|
||||||
|
Abstract base class for LTX models (Lightricks Transformer models).
|
||||||
|
|
||||||
positional_embedding_theta=10000.0,
|
This class defines the common interface and shared functionality for all LTX models,
|
||||||
positional_embedding_max_pos=[20, 2048, 2048],
|
including LTXV (video) and LTXAV (audio-video) variants.
|
||||||
causal_temporal_positioning=False,
|
"""
|
||||||
vae_scale_factors=(8, 32, 32),
|
|
||||||
dtype=None, device=None, operations=None, **kwargs):
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int,
|
||||||
|
cross_attention_dim: int,
|
||||||
|
attention_head_dim: int,
|
||||||
|
num_attention_heads: int,
|
||||||
|
caption_channels: int,
|
||||||
|
num_layers: int,
|
||||||
|
positional_embedding_theta: float = 10000.0,
|
||||||
|
positional_embedding_max_pos: list = [20, 2048, 2048],
|
||||||
|
causal_temporal_positioning: bool = False,
|
||||||
|
vae_scale_factors: tuple = (8, 32, 32),
|
||||||
|
use_middle_indices_grid=False,
|
||||||
|
timestep_scale_multiplier = 1000.0,
|
||||||
|
dtype=None,
|
||||||
|
device=None,
|
||||||
|
operations=None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.generator = None
|
self.generator = None
|
||||||
self.vae_scale_factors = vae_scale_factors
|
self.vae_scale_factors = vae_scale_factors
|
||||||
|
self.use_middle_indices_grid = use_middle_indices_grid
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
self.out_channels = in_channels
|
self.in_channels = in_channels
|
||||||
self.inner_dim = num_attention_heads * attention_head_dim
|
self.cross_attention_dim = cross_attention_dim
|
||||||
|
self.attention_head_dim = attention_head_dim
|
||||||
|
self.num_attention_heads = num_attention_heads
|
||||||
|
self.caption_channels = caption_channels
|
||||||
|
self.num_layers = num_layers
|
||||||
|
self.positional_embedding_theta = positional_embedding_theta
|
||||||
|
self.positional_embedding_max_pos = positional_embedding_max_pos
|
||||||
|
self.split_positional_embedding = LTXRopeType.from_dict(kwargs)
|
||||||
|
self.freq_grid_generator = (
|
||||||
|
generate_freq_grid_np if LTXFrequenciesPrecision.from_dict(kwargs) == LTXFrequenciesPrecision.FLOAT64
|
||||||
|
else generate_freq_grid_pytorch
|
||||||
|
)
|
||||||
self.causal_temporal_positioning = causal_temporal_positioning
|
self.causal_temporal_positioning = causal_temporal_positioning
|
||||||
|
self.operations = operations
|
||||||
|
self.timestep_scale_multiplier = timestep_scale_multiplier
|
||||||
|
|
||||||
self.patchify_proj = operations.Linear(in_channels, self.inner_dim, bias=True, dtype=dtype, device=device)
|
# Common dimensions
|
||||||
|
self.inner_dim = num_attention_heads * attention_head_dim
|
||||||
|
self.out_channels = in_channels
|
||||||
|
|
||||||
|
# Initialize common components
|
||||||
|
self._init_common_components(device, dtype)
|
||||||
|
|
||||||
|
# Initialize model-specific components
|
||||||
|
self._init_model_components(device, dtype, **kwargs)
|
||||||
|
|
||||||
|
# Initialize transformer blocks
|
||||||
|
self._init_transformer_blocks(device, dtype, **kwargs)
|
||||||
|
|
||||||
|
# Initialize output components
|
||||||
|
self._init_output_components(device, dtype)
|
||||||
|
|
||||||
|
def _init_common_components(self, device, dtype):
|
||||||
|
"""Initialize components common to all LTX models
|
||||||
|
- patchify_proj: Linear projection for patchifying input
|
||||||
|
- adaln_single: AdaLN layer for timestep embedding
|
||||||
|
- caption_projection: Linear projection for caption embedding
|
||||||
|
"""
|
||||||
|
self.patchify_proj = self.operations.Linear(
|
||||||
|
self.in_channels, self.inner_dim, bias=True, dtype=dtype, device=device
|
||||||
|
)
|
||||||
|
|
||||||
self.adaln_single = AdaLayerNormSingle(
|
self.adaln_single = AdaLayerNormSingle(
|
||||||
self.inner_dim, use_additional_conditions=False, dtype=dtype, device=device, operations=operations
|
self.inner_dim, use_additional_conditions=False, dtype=dtype, device=device, operations=self.operations
|
||||||
)
|
)
|
||||||
|
|
||||||
# self.adaln_single.linear = operations.Linear(self.inner_dim, 4 * self.inner_dim, bias=True, dtype=dtype, device=device)
|
|
||||||
|
|
||||||
self.caption_projection = PixArtAlphaTextProjection(
|
self.caption_projection = PixArtAlphaTextProjection(
|
||||||
in_features=caption_channels, hidden_size=self.inner_dim, dtype=dtype, device=device, operations=operations
|
in_features=self.caption_channels,
|
||||||
|
hidden_size=self.inner_dim,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
operations=self.operations,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _init_model_components(self, device, dtype, **kwargs):
|
||||||
|
"""Initialize model-specific components. Must be implemented by subclasses."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _init_transformer_blocks(self, device, dtype, **kwargs):
|
||||||
|
"""Initialize transformer blocks. Must be implemented by subclasses."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _init_output_components(self, device, dtype):
|
||||||
|
"""Initialize output components. Must be implemented by subclasses."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _process_input(self, x, keyframe_idxs, denoise_mask, **kwargs):
|
||||||
|
"""Process input data. Must be implemented by subclasses."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, **kwargs):
|
||||||
|
"""Process transformer blocks. Must be implemented by subclasses."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _process_output(self, x, embedded_timestep, keyframe_idxs, **kwargs):
|
||||||
|
"""Process output data. Must be implemented by subclasses."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _prepare_timestep(self, timestep, batch_size, hidden_dtype, **kwargs):
|
||||||
|
"""Prepare timestep embeddings."""
|
||||||
|
grid_mask = kwargs.get("grid_mask", None)
|
||||||
|
if grid_mask is not None:
|
||||||
|
timestep = timestep[:, grid_mask]
|
||||||
|
|
||||||
|
timestep = timestep * self.timestep_scale_multiplier
|
||||||
|
timestep, embedded_timestep = self.adaln_single(
|
||||||
|
timestep.flatten(),
|
||||||
|
{"resolution": None, "aspect_ratio": None},
|
||||||
|
batch_size=batch_size,
|
||||||
|
hidden_dtype=hidden_dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Second dimension is 1 or number of tokens (if timestep_per_token)
|
||||||
|
timestep = timestep.view(batch_size, -1, timestep.shape[-1])
|
||||||
|
embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.shape[-1])
|
||||||
|
|
||||||
|
return timestep, embedded_timestep
|
||||||
|
|
||||||
|
def _prepare_context(self, context, batch_size, x, attention_mask=None):
|
||||||
|
"""Prepare context for transformer blocks."""
|
||||||
|
if self.caption_projection is not None:
|
||||||
|
context = self.caption_projection(context)
|
||||||
|
context = context.view(batch_size, -1, x.shape[-1])
|
||||||
|
|
||||||
|
return context, attention_mask
|
||||||
|
|
||||||
|
def _precompute_freqs_cis(
|
||||||
|
self,
|
||||||
|
indices_grid,
|
||||||
|
dim,
|
||||||
|
out_dtype,
|
||||||
|
theta=10000.0,
|
||||||
|
max_pos=[20, 2048, 2048],
|
||||||
|
use_middle_indices_grid=False,
|
||||||
|
num_attention_heads=32,
|
||||||
|
):
|
||||||
|
split_mode = self.split_positional_embedding == LTXRopeType.SPLIT
|
||||||
|
indices = self.freq_grid_generator(theta, indices_grid.shape[1], dim, indices_grid.device)
|
||||||
|
freqs = generate_freqs(indices, indices_grid, max_pos, use_middle_indices_grid)
|
||||||
|
|
||||||
|
if split_mode:
|
||||||
|
expected_freqs = dim // 2
|
||||||
|
current_freqs = freqs.shape[-1]
|
||||||
|
pad_size = expected_freqs - current_freqs
|
||||||
|
cos_freq, sin_freq = split_freqs_cis(freqs, pad_size, num_attention_heads)
|
||||||
|
else:
|
||||||
|
# n_elem = 2 (cos and sin) times the number of positional axes: 3 for (t, x, y), 1 for temporal only
|
||||||
|
n_elem = 2 * indices_grid.shape[1]
|
||||||
|
cos_freq, sin_freq = interleaved_freqs_cis(freqs, dim % n_elem)
|
||||||
|
return cos_freq.to(out_dtype), sin_freq.to(out_dtype), split_mode
|
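Worked numbers for the interleaved (non-split) branch, assuming the default LTXV dimensions: with inner_dim = 32 heads x 64 = 2048 and three positional axes, n_elem = 6 and the cos/sin tables are padded on the left by dim % n_elem columns.

num_attention_heads, attention_head_dim, n_pos_dims = 32, 64, 3
dim = num_attention_heads * attention_head_dim   # 2048
n_elem = 2 * n_pos_dims                          # cos + sin per (t, x, y) axis
print(dim // n_elem, dim % n_elem)               # 341 frequencies per axis, pad_size == 2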
||||||
|
|
||||||
|
def _prepare_positional_embeddings(self, pixel_coords, frame_rate, x_dtype):
|
||||||
|
"""Prepare positional embeddings."""
|
||||||
|
fractional_coords = pixel_coords.to(torch.float32)
|
||||||
|
fractional_coords[:, 0] = fractional_coords[:, 0] * (1.0 / frame_rate)
|
||||||
|
pe = self._precompute_freqs_cis(
|
||||||
|
fractional_coords,
|
||||||
|
dim=self.inner_dim,
|
||||||
|
out_dtype=x_dtype,
|
||||||
|
max_pos=self.positional_embedding_max_pos,
|
||||||
|
use_middle_indices_grid=self.use_middle_indices_grid,
|
||||||
|
num_attention_heads=self.num_attention_heads,
|
||||||
|
)
|
||||||
|
return pe
|
||||||
|
|
||||||
|
def _prepare_attention_mask(self, attention_mask, x_dtype):
|
||||||
|
"""Prepare attention mask."""
|
||||||
|
if attention_mask is not None and not torch.is_floating_point(attention_mask):
|
||||||
|
attention_mask = (attention_mask - 1).to(x_dtype).reshape(
|
||||||
|
(attention_mask.shape[0], 1, -1, attention_mask.shape[-1])
|
||||||
|
) * torch.finfo(x_dtype).max
|
||||||
|
return attention_mask
|
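A small sketch of the mask conversion above (illustrative values only): positions marked 0 become a large negative bias, positions marked 1 become 0, so masked tokens vanish after softmax.

import torch

attention_mask = torch.tensor([[1, 1, 0]])                 # (batch, tokens), 0 = masked
x_dtype = torch.float32
bias = (attention_mask - 1).to(x_dtype).reshape(1, 1, -1, 3) * torch.finfo(x_dtype).max
print(bias)  # [[[[0., 0., -3.4028e+38]]]]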
||||||
|
|
||||||
|
def forward(
|
||||||
|
self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, denoise_mask=None, **kwargs
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Forward pass for LTX models.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Input tensor
|
||||||
|
timestep: Timestep tensor
|
||||||
|
context: Context tensor (e.g., text embeddings)
|
||||||
|
attention_mask: Attention mask tensor
|
||||||
|
frame_rate: Frame rate for temporal processing
|
||||||
|
transformer_options: Additional options for transformer blocks
|
||||||
|
keyframe_idxs: Keyframe indices for temporal processing
denoise_mask: Optional per-token denoise mask; when keyframe_idxs is given, tokens whose mask contains negative values are excluded from processing
|
||||||
|
**kwargs: Additional keyword arguments
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Processed output tensor
|
||||||
|
"""
|
||||||
|
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||||
|
self._forward,
|
||||||
|
self,
|
||||||
|
comfy.patcher_extension.get_all_wrappers(
|
||||||
|
comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options
|
||||||
|
),
|
||||||
|
).execute(x, timestep, context, attention_mask, frame_rate, transformer_options, keyframe_idxs, denoise_mask=denoise_mask, **kwargs)
|
||||||
|
|
||||||
|
def _forward(
|
||||||
|
self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, denoise_mask=None, **kwargs
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Internal forward pass for LTX models.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Input tensor
|
||||||
|
timestep: Timestep tensor
|
||||||
|
context: Context tensor (e.g., text embeddings)
|
||||||
|
attention_mask: Attention mask tensor
|
||||||
|
frame_rate: Frame rate for temporal processing
|
||||||
|
transformer_options: Additional options for transformer blocks
|
||||||
|
keyframe_idxs: Keyframe indices for temporal processing
|
||||||
|
**kwargs: Additional keyword arguments
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Processed output tensor
|
||||||
|
"""
|
||||||
|
if isinstance(x, list):
|
||||||
|
input_dtype = x[0].dtype
|
||||||
|
batch_size = x[0].shape[0]
|
||||||
|
else:
|
||||||
|
input_dtype = x.dtype
|
||||||
|
batch_size = x.shape[0]
|
||||||
|
# Process input
|
||||||
|
merged_args = {**transformer_options, **kwargs}
|
||||||
|
x, pixel_coords, additional_args = self._process_input(x, keyframe_idxs, denoise_mask, **merged_args)
|
||||||
|
merged_args.update(additional_args)
|
||||||
|
|
||||||
|
# Prepare timestep and context
|
||||||
|
timestep, embedded_timestep = self._prepare_timestep(timestep, batch_size, input_dtype, **merged_args)
|
||||||
|
context, attention_mask = self._prepare_context(context, batch_size, x, attention_mask)
|
||||||
|
|
||||||
|
# Prepare attention mask and positional embeddings
|
||||||
|
attention_mask = self._prepare_attention_mask(attention_mask, input_dtype)
|
||||||
|
pe = self._prepare_positional_embeddings(pixel_coords, frame_rate, input_dtype)
|
||||||
|
|
||||||
|
# Process transformer blocks
|
||||||
|
x = self._process_transformer_blocks(
|
||||||
|
x, context, attention_mask, timestep, pe, transformer_options=transformer_options, **merged_args
|
||||||
|
)
|
||||||
|
|
||||||
|
# Process output
|
||||||
|
x = self._process_output(x, embedded_timestep, keyframe_idxs, **merged_args)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class LTXVModel(LTXBaseModel):
|
||||||
|
"""LTXV model for video generation."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels=128,
|
||||||
|
cross_attention_dim=2048,
|
||||||
|
attention_head_dim=64,
|
||||||
|
num_attention_heads=32,
|
||||||
|
caption_channels=4096,
|
||||||
|
num_layers=28,
|
||||||
|
positional_embedding_theta=10000.0,
|
||||||
|
positional_embedding_max_pos=[20, 2048, 2048],
|
||||||
|
causal_temporal_positioning=False,
|
||||||
|
vae_scale_factors=(8, 32, 32),
|
||||||
|
use_middle_indices_grid=False,
|
||||||
|
timestep_scale_multiplier = 1000.0,
|
||||||
|
dtype=None,
|
||||||
|
device=None,
|
||||||
|
operations=None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
in_channels=in_channels,
|
||||||
|
cross_attention_dim=cross_attention_dim,
|
||||||
|
attention_head_dim=attention_head_dim,
|
||||||
|
num_attention_heads=num_attention_heads,
|
||||||
|
caption_channels=caption_channels,
|
||||||
|
num_layers=num_layers,
|
||||||
|
positional_embedding_theta=positional_embedding_theta,
|
||||||
|
positional_embedding_max_pos=positional_embedding_max_pos,
|
||||||
|
causal_temporal_positioning=causal_temporal_positioning,
|
||||||
|
vae_scale_factors=vae_scale_factors,
|
||||||
|
use_middle_indices_grid=use_middle_indices_grid,
|
||||||
|
timestep_scale_multiplier=timestep_scale_multiplier,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
operations=operations,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _init_model_components(self, device, dtype, **kwargs):
|
||||||
|
"""Initialize LTXV-specific components."""
|
||||||
|
# No additional components needed for LTXV beyond base class
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _init_transformer_blocks(self, device, dtype, **kwargs):
|
||||||
|
"""Initialize transformer blocks for LTXV."""
|
||||||
self.transformer_blocks = nn.ModuleList(
|
self.transformer_blocks = nn.ModuleList(
|
||||||
[
|
[
|
||||||
BasicTransformerBlock(
|
BasicTransformerBlock(
|
||||||
self.inner_dim,
|
self.inner_dim,
|
||||||
num_attention_heads,
|
self.num_attention_heads,
|
||||||
attention_head_dim,
|
self.attention_head_dim,
|
||||||
context_dim=cross_attention_dim,
|
context_dim=self.cross_attention_dim,
|
||||||
# attn_precision=attn_precision,
|
dtype=dtype,
|
||||||
dtype=dtype, device=device, operations=operations
|
device=device,
|
||||||
|
operations=self.operations,
|
||||||
)
|
)
|
||||||
for d in range(num_layers)
|
for _ in range(self.num_layers)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _init_output_components(self, device, dtype):
|
||||||
|
"""Initialize output components for LTXV."""
|
||||||
self.scale_shift_table = nn.Parameter(torch.empty(2, self.inner_dim, dtype=dtype, device=device))
|
self.scale_shift_table = nn.Parameter(torch.empty(2, self.inner_dim, dtype=dtype, device=device))
|
||||||
self.norm_out = operations.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
self.norm_out = self.operations.LayerNorm(
|
||||||
self.proj_out = operations.Linear(self.inner_dim, self.out_channels, dtype=dtype, device=device)
|
self.inner_dim, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device
|
||||||
|
)
|
||||||
self.patchifier = SymmetricPatchifier(1)
|
self.proj_out = self.operations.Linear(self.inner_dim, self.out_channels, dtype=dtype, device=device)
|
||||||
|
self.patchifier = SymmetricPatchifier(1, start_end=True)
|
||||||
def forward(self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, **kwargs):
|
|
||||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
|
||||||
self._forward,
|
|
||||||
self,
|
|
||||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
|
|
||||||
).execute(x, timestep, context, attention_mask, frame_rate, transformer_options, keyframe_idxs, **kwargs)
|
|
||||||
|
|
||||||
def _forward(self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, **kwargs):
|
|
||||||
patches_replace = transformer_options.get("patches_replace", {})
|
|
||||||
|
|
||||||
orig_shape = list(x.shape)
|
|
||||||
|
|
||||||
|
def _process_input(self, x, keyframe_idxs, denoise_mask, **kwargs):
|
||||||
|
"""Process input for LTXV."""
|
||||||
|
additional_args = {"orig_shape": list(x.shape)}
|
||||||
x, latent_coords = self.patchifier.patchify(x)
|
x, latent_coords = self.patchifier.patchify(x)
|
||||||
pixel_coords = latent_to_pixel_coords(
|
pixel_coords = latent_to_pixel_coords(
|
||||||
latent_coords=latent_coords,
|
latent_coords=latent_coords,
|
||||||
@ -423,44 +880,30 @@ class LTXVModel(torch.nn.Module):
|
|||||||
causal_fix=self.causal_temporal_positioning,
|
causal_fix=self.causal_temporal_positioning,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
grid_mask = None
|
||||||
if keyframe_idxs is not None:
|
if keyframe_idxs is not None:
|
||||||
pixel_coords[:, :, -keyframe_idxs.shape[2]:] = keyframe_idxs
|
additional_args.update({ "orig_patchified_shape": list(x.shape)})
|
||||||
|
denoise_mask = self.patchifier.patchify(denoise_mask)[0]
|
||||||
|
grid_mask = ~torch.any(denoise_mask < 0, dim=-1)[0]
|
||||||
|
additional_args.update({"grid_mask": grid_mask})
|
||||||
|
x = x[:, grid_mask, :]
|
||||||
|
pixel_coords = pixel_coords[:, :, grid_mask, ...]
|
||||||
|
|
||||||
fractional_coords = pixel_coords.to(torch.float32)
|
kf_grid_mask = grid_mask[-keyframe_idxs.shape[2]:]
|
||||||
fractional_coords[:, 0] = fractional_coords[:, 0] * (1.0 / frame_rate)
|
keyframe_idxs = keyframe_idxs[..., kf_grid_mask, :]
|
||||||
|
pixel_coords[:, :, -keyframe_idxs.shape[2]:, :] = keyframe_idxs
|
||||||
|
|
||||||
x = self.patchify_proj(x)
|
x = self.patchify_proj(x)
|
||||||
timestep = timestep * 1000.0
|
return x, pixel_coords, additional_args
|
||||||
|
|
||||||
if attention_mask is not None and not torch.is_floating_point(attention_mask):
|
|
||||||
attention_mask = (attention_mask - 1).to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])) * torch.finfo(x.dtype).max
|
|
||||||
|
|
||||||
pe = precompute_freqs_cis(fractional_coords, dim=self.inner_dim, out_dtype=x.dtype)
|
|
||||||
|
|
||||||
batch_size = x.shape[0]
|
|
||||||
timestep, embedded_timestep = self.adaln_single(
|
|
||||||
timestep.flatten(),
|
|
||||||
{"resolution": None, "aspect_ratio": None},
|
|
||||||
batch_size=batch_size,
|
|
||||||
hidden_dtype=x.dtype,
|
|
||||||
)
|
|
||||||
# Second dimension is 1 or number of tokens (if timestep_per_token)
|
|
||||||
timestep = timestep.view(batch_size, -1, timestep.shape[-1])
|
|
||||||
embedded_timestep = embedded_timestep.view(
|
|
||||||
batch_size, -1, embedded_timestep.shape[-1]
|
|
||||||
)
|
|
||||||
|
|
||||||
# 2. Blocks
|
|
||||||
if self.caption_projection is not None:
|
|
||||||
batch_size = x.shape[0]
|
|
||||||
context = self.caption_projection(context)
|
|
||||||
context = context.view(
|
|
||||||
batch_size, -1, x.shape[-1]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, transformer_options={}, **kwargs):
|
||||||
|
"""Process transformer blocks for LTXV."""
|
||||||
|
patches_replace = transformer_options.get("patches_replace", {})
|
||||||
blocks_replace = patches_replace.get("dit", {})
|
blocks_replace = patches_replace.get("dit", {})
|
||||||
|
|
||||||
for i, block in enumerate(self.transformer_blocks):
|
for i, block in enumerate(self.transformer_blocks):
|
||||||
if ("double_block", i) in blocks_replace:
|
if ("double_block", i) in blocks_replace:
|
||||||
|
|
||||||
def block_wrap(args):
|
def block_wrap(args):
|
||||||
out = {}
|
out = {}
|
||||||
out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"], transformer_options=args["transformer_options"])
|
out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"], transformer_options=args["transformer_options"])
|
||||||
@ -478,16 +921,28 @@ class LTXVModel(torch.nn.Module):
|
|||||||
transformer_options=transformer_options,
|
transformer_options=transformer_options,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 3. Output
|
return x
|
||||||
|
|
||||||
|
def _process_output(self, x, embedded_timestep, keyframe_idxs, **kwargs):
|
||||||
|
"""Process output for LTXV."""
|
||||||
|
# Apply scale-shift modulation
|
||||||
scale_shift_values = (
|
scale_shift_values = (
|
||||||
self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + embedded_timestep[:, :, None]
|
self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + embedded_timestep[:, :, None]
|
||||||
)
|
)
|
||||||
shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
|
shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
|
||||||
|
|
||||||
x = self.norm_out(x)
|
x = self.norm_out(x)
|
||||||
# Modulation
|
x = x * (1 + scale) + shift
|
||||||
x = torch.addcmul(x, x, scale).add_(shift)
|
|
||||||
x = self.proj_out(x)
|
x = self.proj_out(x)
|
||||||
|
|
||||||
|
if keyframe_idxs is not None:
|
||||||
|
grid_mask = kwargs["grid_mask"]
|
||||||
|
orig_patchified_shape = kwargs["orig_patchified_shape"]
|
||||||
|
full_x = torch.zeros(orig_patchified_shape, dtype=x.dtype, device=x.device)
|
||||||
|
full_x[:, grid_mask, :] = x
|
||||||
|
x = full_x
|
||||||
|
# Unpatchify to restore original dimensions
|
||||||
|
orig_shape = kwargs["orig_shape"]
|
||||||
x = self.patchifier.unpatchify(
|
x = self.patchifier.unpatchify(
|
||||||
latents=x,
|
latents=x,
|
||||||
output_height=orig_shape[3],
|
output_height=orig_shape[3],
|
||||||
|
|||||||
@ -21,20 +21,23 @@ def latent_to_pixel_coords(
|
|||||||
Returns:
|
Returns:
|
||||||
Tensor: A tensor of pixel coordinates corresponding to the input latent coordinates.
|
Tensor: A tensor of pixel coordinates corresponding to the input latent coordinates.
|
||||||
"""
|
"""
|
||||||
|
shape = [1] * latent_coords.ndim
|
||||||
|
shape[1] = -1
|
||||||
pixel_coords = (
|
pixel_coords = (
|
||||||
latent_coords
|
latent_coords
|
||||||
* torch.tensor(scale_factors, device=latent_coords.device)[None, :, None]
|
* torch.tensor(scale_factors, device=latent_coords.device).view(*shape)
|
||||||
)
|
)
|
||||||
if causal_fix:
|
if causal_fix:
|
||||||
# Fix temporal scale for first frame to 1 due to causality
|
# Fix temporal scale for first frame to 1 due to causality
|
||||||
pixel_coords[:, 0] = (pixel_coords[:, 0] + 1 - scale_factors[0]).clamp(min=0)
|
pixel_coords[:, 0, ...] = (pixel_coords[:, 0, ...] + 1 - scale_factors[0]).clamp(min=0)
|
||||||
return pixel_coords
|
return pixel_coords
|
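Worked numbers for the causal fix, assuming the default video scale factors (8, 32, 32): temporal latent 0 stays at pixel frame 0, latent 1 maps to frame 1*8 + 1 - 8 = 1, so the first latent covers a single frame while later ones cover eight.

import torch

scale_factors = (8, 32, 32)
latent_t = torch.tensor([0, 1, 2, 3])
pixel_t = latent_t * scale_factors[0]
pixel_t_causal = (pixel_t + 1 - scale_factors[0]).clamp(min=0)
print(pixel_t.tolist(), pixel_t_causal.tolist())  # [0, 8, 16, 24] [0, 1, 9, 17]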
||||||
|
|
||||||
|
|
||||||
class Patchifier(ABC):
|
class Patchifier(ABC):
|
||||||
def __init__(self, patch_size: int):
|
def __init__(self, patch_size: int, start_end: bool=False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._patch_size = (1, patch_size, patch_size)
|
self._patch_size = (1, patch_size, patch_size)
|
||||||
|
self.start_end = start_end
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def patchify(
|
def patchify(
|
||||||
@ -71,11 +74,23 @@ class Patchifier(ABC):
|
|||||||
torch.arange(0, latent_width, self._patch_size[2], device=device),
|
torch.arange(0, latent_width, self._patch_size[2], device=device),
|
||||||
indexing="ij",
|
indexing="ij",
|
||||||
)
|
)
|
||||||
latent_sample_coords = torch.stack(latent_sample_coords, dim=0)
|
latent_sample_coords_start = torch.stack(latent_sample_coords, dim=0)
|
||||||
latent_coords = latent_sample_coords.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
|
delta = torch.tensor(self._patch_size, device=latent_sample_coords_start.device, dtype=latent_sample_coords_start.dtype)[:, None, None, None]
|
||||||
latent_coords = rearrange(
|
latent_sample_coords_end = latent_sample_coords_start + delta
|
||||||
latent_coords, "b c f h w -> b c (f h w)", b=batch_size
|
|
||||||
|
latent_sample_coords_start = latent_sample_coords_start.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
|
||||||
|
latent_sample_coords_start = rearrange(
|
||||||
|
latent_sample_coords_start, "b c f h w -> b c (f h w)", b=batch_size
|
||||||
)
|
)
|
||||||
|
if self.start_end:
|
||||||
|
latent_sample_coords_end = latent_sample_coords_end.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
|
||||||
|
latent_sample_coords_end = rearrange(
|
||||||
|
latent_sample_coords_end, "b c f h w -> b c (f h w)", b=batch_size
|
||||||
|
)
|
||||||
|
|
||||||
|
latent_coords = torch.stack((latent_sample_coords_start, latent_sample_coords_end), dim=-1)
|
||||||
|
else:
|
||||||
|
latent_coords = latent_sample_coords_start
|
||||||
return latent_coords
|
return latent_coords
|
||||||
|
|
||||||
|
|
||||||
@ -115,3 +130,61 @@ class SymmetricPatchifier(Patchifier):
|
|||||||
q=self._patch_size[2],
|
q=self._patch_size[2],
|
||||||
)
|
)
|
||||||
return latents
|
return latents
|
||||||
|
|
||||||
|
|
||||||
|
class AudioPatchifier(Patchifier):
|
||||||
|
def __init__(self, patch_size: int,
|
||||||
|
sample_rate=16000,
|
||||||
|
hop_length=160,
|
||||||
|
audio_latent_downsample_factor=4,
|
||||||
|
is_causal=True,
|
||||||
|
start_end=False,
|
||||||
|
shift = 0
|
||||||
|
):
|
||||||
|
super().__init__(patch_size, start_end=start_end)
|
||||||
|
self.hop_length = hop_length
|
||||||
|
self.sample_rate = sample_rate
|
||||||
|
self.audio_latent_downsample_factor = audio_latent_downsample_factor
|
||||||
|
self.is_causal = is_causal
|
||||||
|
self.shift = shift
|
||||||
|
|
||||||
|
def copy_with_shift(self, shift):
|
||||||
|
return AudioPatchifier(
|
||||||
|
self.patch_size, self.sample_rate, self.hop_length, self.audio_latent_downsample_factor,
|
||||||
|
self.is_causal, self.start_end, shift
|
||||||
|
)
|
||||||
|
|
||||||
|
    def _get_audio_latent_time_in_sec(self, start_latent, end_latent: int, dtype: torch.dtype, device=torch.device):
        audio_latent_frame = torch.arange(start_latent, end_latent, dtype=dtype, device=device)
        audio_mel_frame = audio_latent_frame * self.audio_latent_downsample_factor
        if self.is_causal:
            audio_mel_frame = (audio_mel_frame + 1 - self.audio_latent_downsample_factor).clip(min=0)
        return audio_mel_frame * self.hop_length / self.sample_rate
|
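Plugging in the defaults used in this file (16 kHz sample rate, hop length 160, downsample factor 4), each audio latent step covers 4 * 160 / 16000 = 40 ms, and the causal shift pins the first latent to t = 0; a small sketch:

import torch

hop_length, sample_rate, down = 160, 16000, 4
latent_idx = torch.arange(0, 4, dtype=torch.float32)
mel_frame = (latent_idx * down + 1 - down).clip(min=0)    # causal variant
print((mel_frame * hop_length / sample_rate).tolist())    # [0.0, 0.01, 0.05, 0.09] seconds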
||||||
|
|
||||||
|
|
||||||
|
def patchify(self, audio_latents: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
# audio_latents: (batch, channels, time, freq)
|
||||||
|
b, _, t, _ = audio_latents.shape
|
||||||
|
audio_latents = rearrange(
|
||||||
|
audio_latents,
|
||||||
|
"b c t f -> b t (c f)",
|
||||||
|
)
|
||||||
|
|
||||||
|
audio_latents_start_timings = self._get_audio_latent_time_in_sec(self.shift, t + self.shift, torch.float32, audio_latents.device)
|
||||||
|
audio_latents_start_timings = audio_latents_start_timings.unsqueeze(0).expand(b, -1).unsqueeze(1)
|
||||||
|
|
||||||
|
if self.start_end:
|
||||||
|
audio_latents_end_timings = self._get_audio_latent_time_in_sec(self.shift + 1, t + self.shift + 1, torch.float32, audio_latents.device)
|
||||||
|
audio_latents_end_timings = audio_latents_end_timings.unsqueeze(0).expand(b, -1).unsqueeze(1)
|
||||||
|
|
||||||
|
audio_latents_timings = torch.stack([audio_latents_start_timings, audio_latents_end_timings], dim=-1)
|
||||||
|
else:
|
||||||
|
audio_latents_timings = audio_latents_start_timings
|
||||||
|
return audio_latents, audio_latents_timings
|
||||||
|
|
||||||
|
def unpatchify(self, audio_latents: torch.Tensor, channels: int, freq: int) -> torch.Tensor:
|
||||||
|
# audio_latents: (batch, time, freq * channels)
|
||||||
|
audio_latents = rearrange(
|
||||||
|
audio_latents, "b t (c f) -> b c t f", c=channels, f=freq
|
||||||
|
)
|
||||||
|
return audio_latents
|
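A quick shape check of the patchify/unpatchify pair (illustrative tensor sizes, not from the diff): folding channels into the frequency axis and back is lossless.

import torch
from einops import rearrange

latents = torch.randn(2, 8, 50, 16)                      # (batch, channels, time, freq)
patched = rearrange(latents, "b c t f -> b t (c f)")     # (2, 50, 128)
restored = rearrange(patched, "b t (c f) -> b c t f", c=8, f=16)
print(torch.equal(latents, restored))                    # True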
||||||
|
|||||||
286
comfy/ldm/lightricks/vae/audio_vae.py
Normal file
@ -0,0 +1,286 @@
|
|||||||
|
import json
|
||||||
|
from dataclasses import dataclass
|
||||||
|
import math
|
||||||
|
import torch
|
||||||
|
import torchaudio
|
||||||
|
|
||||||
|
import comfy.model_management
|
||||||
|
import comfy.model_patcher
|
||||||
|
import comfy.utils as utils
|
||||||
|
from comfy.ldm.mmaudio.vae.distributions import DiagonalGaussianDistribution
|
||||||
|
from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier
|
||||||
|
from comfy.ldm.lightricks.vae.causal_audio_autoencoder import (
|
||||||
|
CausalityAxis,
|
||||||
|
CausalAudioAutoencoder,
|
||||||
|
)
|
||||||
|
from comfy.ldm.lightricks.vocoders.vocoder import Vocoder
|
||||||
|
|
||||||
|
LATENT_DOWNSAMPLE_FACTOR = 4
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class AudioVAEComponentConfig:
|
||||||
|
"""Container for model component configuration extracted from metadata."""
|
||||||
|
|
||||||
|
autoencoder: dict
|
||||||
|
vocoder: dict
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_metadata(cls, metadata: dict) -> "AudioVAEComponentConfig":
|
||||||
|
assert metadata is not None and "config" in metadata, "Metadata is required for audio VAE"
|
||||||
|
|
||||||
|
raw_config = metadata["config"]
|
||||||
|
if isinstance(raw_config, str):
|
||||||
|
parsed_config = json.loads(raw_config)
|
||||||
|
else:
|
||||||
|
parsed_config = raw_config
|
||||||
|
|
||||||
|
audio_config = parsed_config.get("audio_vae")
|
||||||
|
vocoder_config = parsed_config.get("vocoder")
|
||||||
|
|
||||||
|
assert audio_config is not None, "Audio VAE config is required for audio VAE"
|
||||||
|
assert vocoder_config is not None, "Vocoder config is required for audio VAE"
|
||||||
|
|
||||||
|
return cls(autoencoder=audio_config, vocoder=vocoder_config)
|
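The expected metadata layout, sketched with placeholder configs (the real dictionaries carry the full autoencoder/vocoder hyperparameters); this assumes only the dataclass defined above.

import json

metadata = {
    "config": json.dumps({
        "audio_vae": {"sampling_rate": 16000},   # placeholder fields, not the real config
        "vocoder": {"upsample_factor": 160},     # placeholder fields, not the real config
    })
}
component_config = AudioVAEComponentConfig.from_metadata(metadata)
print(component_config.vocoder)  # {'upsample_factor': 160}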
||||||
|
|
||||||
|
|
||||||
|
class ModelDeviceManager:
|
||||||
|
"""Manages device placement and GPU residency for the composed model."""
|
||||||
|
|
||||||
|
def __init__(self, module: torch.nn.Module):
|
||||||
|
load_device = comfy.model_management.get_torch_device()
|
||||||
|
offload_device = comfy.model_management.vae_offload_device()
|
||||||
|
self.patcher = comfy.model_patcher.ModelPatcher(module, load_device, offload_device)
|
||||||
|
|
||||||
|
def ensure_model_loaded(self) -> None:
|
||||||
|
comfy.model_management.free_memory(
|
||||||
|
self.patcher.model_size(),
|
||||||
|
self.patcher.load_device,
|
||||||
|
)
|
||||||
|
comfy.model_management.load_model_gpu(self.patcher)
|
||||||
|
|
||||||
|
def move_to_load_device(self, tensor: torch.Tensor) -> torch.Tensor:
|
||||||
|
return tensor.to(self.patcher.load_device)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def load_device(self):
|
||||||
|
return self.patcher.load_device
|
||||||
|
|
||||||
|
|
||||||
|
class AudioLatentNormalizer:
    """Applies per-channel statistics in patch space and restores original layout."""

    def __init__(self, patchfier: AudioPatchifier, statistics_processor: torch.nn.Module):
        self.patchifier = patchfier
        self.statistics = statistics_processor

    def normalize(self, latents: torch.Tensor) -> torch.Tensor:
        channels = latents.shape[1]
        freq = latents.shape[3]
        patched, _ = self.patchifier.patchify(latents)
        normalized = self.statistics.normalize(patched)
        return self.patchifier.unpatchify(normalized, channels=channels, freq=freq)

    def denormalize(self, latents: torch.Tensor) -> torch.Tensor:
        channels = latents.shape[1]
        freq = latents.shape[3]
        patched, _ = self.patchifier.patchify(latents)
        denormalized = self.statistics.un_normalize(patched)
        return self.patchifier.unpatchify(denormalized, channels=channels, freq=freq)
|
||||||
|
|
||||||
|
|
||||||
|
class AudioPreprocessor:
|
||||||
|
"""Prepares raw waveforms for the autoencoder by matching training conditions."""
|
||||||
|
|
||||||
|
def __init__(self, target_sample_rate: int, mel_bins: int, mel_hop_length: int, n_fft: int):
|
||||||
|
self.target_sample_rate = target_sample_rate
|
||||||
|
self.mel_bins = mel_bins
|
||||||
|
self.mel_hop_length = mel_hop_length
|
||||||
|
self.n_fft = n_fft
|
||||||
|
|
||||||
|
def resample(self, waveform: torch.Tensor, source_rate: int) -> torch.Tensor:
|
||||||
|
if source_rate == self.target_sample_rate:
|
||||||
|
return waveform
|
||||||
|
return torchaudio.functional.resample(waveform, source_rate, self.target_sample_rate)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def normalize_amplitude(
|
||||||
|
waveform: torch.Tensor, max_amplitude: float = 0.5, eps: float = 1e-5
|
||||||
|
) -> torch.Tensor:
|
||||||
|
waveform = waveform - waveform.mean(dim=2, keepdim=True)
|
||||||
|
peak = torch.max(torch.abs(waveform)) + eps
|
||||||
|
scale = peak.clamp(max=max_amplitude) / peak
|
||||||
|
return waveform * scale
|
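What the amplitude step does in isolation (a sketch with made-up numbers): remove the DC offset, then scale so the peak is at most 0.5 while quieter signals are left untouched.

import torch

waveform = torch.tensor([[[0.2, 0.9, -0.7]]])             # (batch, channels, samples)
waveform = waveform - waveform.mean(dim=2, keepdim=True)
peak = torch.max(torch.abs(waveform)) + 1e-5
scale = peak.clamp(max=0.5) / peak
print((waveform * scale).abs().max().item())              # ~0.5 after rescaling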
||||||
|
|
||||||
|
def waveform_to_mel(
|
||||||
|
self, waveform: torch.Tensor, waveform_sample_rate: int, device
|
||||||
|
) -> torch.Tensor:
|
||||||
|
waveform = self.resample(waveform, waveform_sample_rate)
|
||||||
|
waveform = self.normalize_amplitude(waveform)
|
||||||
|
|
||||||
|
mel_transform = torchaudio.transforms.MelSpectrogram(
|
||||||
|
sample_rate=self.target_sample_rate,
|
||||||
|
n_fft=self.n_fft,
|
||||||
|
win_length=self.n_fft,
|
||||||
|
hop_length=self.mel_hop_length,
|
||||||
|
f_min=0.0,
|
||||||
|
f_max=self.target_sample_rate / 2.0,
|
||||||
|
n_mels=self.mel_bins,
|
||||||
|
window_fn=torch.hann_window,
|
||||||
|
center=True,
|
||||||
|
pad_mode="reflect",
|
||||||
|
power=1.0,
|
||||||
|
mel_scale="slaney",
|
||||||
|
norm="slaney",
|
||||||
|
).to(device)
|
||||||
|
|
||||||
|
mel = mel_transform(waveform)
|
||||||
|
mel = torch.log(torch.clamp(mel, min=1e-5))
|
||||||
|
return mel.permute(0, 1, 3, 2).contiguous()
|
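A self-contained sketch of the mel path using the same torchaudio transform family; the parameters here are illustrative, not the model's exact configuration.

import torch
import torchaudio

mel_transform = torchaudio.transforms.MelSpectrogram(
    sample_rate=16000, n_fft=1024, win_length=1024, hop_length=160,
    n_mels=64, power=1.0, mel_scale="slaney", norm="slaney",
)
waveform = torch.randn(1, 1, 16000)                       # one second of fake audio
mel = torch.log(torch.clamp(mel_transform(waveform), min=1e-5))
print(mel.shape)                                           # (1, 1, 64, 101): batch, channel, mel bins, frames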
||||||
|
|
||||||
|
|
||||||
|
class AudioVAE(torch.nn.Module):
|
||||||
|
"""High-level Audio VAE wrapper exposing encode and decode entry points."""
|
||||||
|
|
||||||
|
def __init__(self, state_dict: dict, metadata: dict):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
component_config = AudioVAEComponentConfig.from_metadata(metadata)
|
||||||
|
|
||||||
|
vae_sd = utils.state_dict_prefix_replace(state_dict, {"audio_vae.": ""}, filter_keys=True)
|
||||||
|
vocoder_sd = utils.state_dict_prefix_replace(state_dict, {"vocoder.": ""}, filter_keys=True)
|
||||||
|
|
||||||
|
self.autoencoder = CausalAudioAutoencoder(config=component_config.autoencoder)
|
||||||
|
self.vocoder = Vocoder(config=component_config.vocoder)
|
||||||
|
|
||||||
|
self.autoencoder.load_state_dict(vae_sd, strict=False)
|
||||||
|
self.vocoder.load_state_dict(vocoder_sd, strict=False)
|
||||||
|
|
||||||
|
autoencoder_config = self.autoencoder.get_config()
|
||||||
|
self.normalizer = AudioLatentNormalizer(
|
||||||
|
AudioPatchifier(
|
||||||
|
patch_size=1,
|
||||||
|
audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR,
|
||||||
|
sample_rate=autoencoder_config["sampling_rate"],
|
||||||
|
hop_length=autoencoder_config["mel_hop_length"],
|
||||||
|
is_causal=autoencoder_config["is_causal"],
|
||||||
|
),
|
||||||
|
self.autoencoder.per_channel_statistics,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.preprocessor = AudioPreprocessor(
|
||||||
|
target_sample_rate=autoencoder_config["sampling_rate"],
|
||||||
|
mel_bins=autoencoder_config["mel_bins"],
|
||||||
|
mel_hop_length=autoencoder_config["mel_hop_length"],
|
||||||
|
n_fft=autoencoder_config["n_fft"],
|
||||||
|
)
|
||||||
|
|
||||||
|
self.device_manager = ModelDeviceManager(self)
|
||||||
|
|
||||||
|
def encode(self, audio: dict) -> torch.Tensor:
|
||||||
|
"""Encode a waveform dictionary into normalized latent tensors."""
|
||||||
|
|
||||||
|
waveform = audio["waveform"]
|
||||||
|
waveform_sample_rate = audio["sample_rate"]
|
||||||
|
input_device = waveform.device
|
||||||
|
# Ensure that Audio VAE is loaded on the correct device.
|
||||||
|
self.device_manager.ensure_model_loaded()
|
||||||
|
|
||||||
|
waveform = self.device_manager.move_to_load_device(waveform)
|
||||||
|
expected_channels = self.autoencoder.encoder.in_channels
|
||||||
|
if waveform.shape[1] != expected_channels:
|
||||||
|
raise ValueError(
|
||||||
|
f"Input audio must have {expected_channels} channels, got {waveform.shape[1]}"
|
||||||
|
)
|
||||||
|
|
||||||
|
mel_spec = self.preprocessor.waveform_to_mel(
|
||||||
|
waveform, waveform_sample_rate, device=self.device_manager.load_device
|
||||||
|
)
|
||||||
|
|
||||||
|
latents = self.autoencoder.encode(mel_spec)
|
||||||
|
posterior = DiagonalGaussianDistribution(latents)
|
||||||
|
latent_mode = posterior.mode()
|
||||||
|
|
||||||
|
normalized = self.normalizer.normalize(latent_mode)
|
||||||
|
return normalized.to(input_device)
|
||||||
|
|
||||||
|
def decode(self, latents: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""Decode normalized latent tensors into an audio waveform."""
|
||||||
|
original_shape = latents.shape
|
||||||
|
|
||||||
|
# Ensure that Audio VAE is loaded on the correct device.
|
||||||
|
self.device_manager.ensure_model_loaded()
|
||||||
|
|
||||||
|
latents = self.device_manager.move_to_load_device(latents)
|
||||||
|
latents = self.normalizer.denormalize(latents)
|
||||||
|
|
||||||
|
target_shape = self.target_shape_from_latents(original_shape)
|
||||||
|
mel_spec = self.autoencoder.decode(latents, target_shape=target_shape)
|
||||||
|
|
||||||
|
waveform = self.run_vocoder(mel_spec)
|
||||||
|
return self.device_manager.move_to_load_device(waveform)
|
||||||
|
|
||||||
|
    def target_shape_from_latents(self, latents_shape):
        batch, _, time, _ = latents_shape
        target_length = time * LATENT_DOWNSAMPLE_FACTOR
        if self.autoencoder.causality_axis != CausalityAxis.NONE:
            target_length -= LATENT_DOWNSAMPLE_FACTOR - 1
        return (
            batch,
            self.autoencoder.decoder.out_ch,
            target_length,
            self.autoencoder.mel_bins,
        )
|
||||||
|
|
||||||
|
def num_of_latents_from_frames(self, frames_number: int, frame_rate: int) -> int:
|
||||||
|
return math.ceil((float(frames_number) / frame_rate) * self.latents_per_second)
|
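With the 16 kHz / hop 160 / downsample 4 defaults shown elsewhere in this diff, this works out to 16000 / 160 / 4 = 25 audio latents per second, so for example 90 video frames at 30 fps need ceil(3 s * 25) = 75 latents.

import math

sample_rate, hop_length, downsample = 16000, 160, 4
latents_per_second = sample_rate / hop_length / downsample   # 25.0
print(math.ceil((90 / 30) * latents_per_second))             # 75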
||||||
|
|
||||||
|
def run_vocoder(self, mel_spec: torch.Tensor) -> torch.Tensor:
|
||||||
|
audio_channels = self.autoencoder.decoder.out_ch
|
||||||
|
vocoder_input = mel_spec.transpose(2, 3)
|
||||||
|
|
||||||
|
if audio_channels == 1:
|
||||||
|
vocoder_input = vocoder_input.squeeze(1)
|
||||||
|
elif audio_channels != 2:
|
||||||
|
raise ValueError(f"Unsupported audio_channels: {audio_channels}")
|
||||||
|
|
||||||
|
return self.vocoder(vocoder_input)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def sample_rate(self) -> int:
|
||||||
|
return int(self.autoencoder.sampling_rate)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def mel_hop_length(self) -> int:
|
||||||
|
return int(self.autoencoder.mel_hop_length)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def mel_bins(self) -> int:
|
||||||
|
return int(self.autoencoder.mel_bins)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def latent_channels(self) -> int:
|
||||||
|
return int(self.autoencoder.decoder.z_channels)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def latent_frequency_bins(self) -> int:
|
||||||
|
return int(self.mel_bins // LATENT_DOWNSAMPLE_FACTOR)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def latents_per_second(self) -> float:
|
||||||
|
return self.sample_rate / self.mel_hop_length / LATENT_DOWNSAMPLE_FACTOR
|
||||||
|
|
||||||
|
@property
|
||||||
|
def output_sample_rate(self) -> int:
|
||||||
|
output_rate = getattr(self.vocoder, "output_sample_rate", None)
|
||||||
|
if output_rate is not None:
|
||||||
|
return int(output_rate)
|
||||||
|
upsample_factor = getattr(self.vocoder, "upsample_factor", None)
|
||||||
|
if upsample_factor is None:
|
||||||
|
raise AttributeError(
|
||||||
|
"Vocoder is missing upsample_factor; cannot infer output sample rate"
|
||||||
|
)
|
||||||
|
return int(self.sample_rate * upsample_factor / self.mel_hop_length)
|
||||||
|
|
||||||
|
def memory_required(self, input_shape):
|
||||||
|
return self.device_manager.patcher.model_size()
|
||||||
909
comfy/ldm/lightricks/vae/causal_audio_autoencoder.py
Normal file
@ -0,0 +1,909 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from torch.nn import functional as F
|
||||||
|
from typing import Optional
|
||||||
|
from enum import Enum
|
||||||
|
from .pixel_norm import PixelNorm
|
||||||
|
import comfy.ops
|
||||||
|
import logging
|
||||||
|
|
||||||
|
ops = comfy.ops.disable_weight_init
|
||||||
|
|
||||||
|
|
||||||
|
class StringConvertibleEnum(Enum):
|
||||||
|
"""
|
||||||
|
Base enum class that provides string-to-enum conversion functionality.
|
||||||
|
|
||||||
|
This mixin adds a str_to_enum() class method that handles conversion from
|
||||||
|
strings, None, or existing enum instances with case-insensitive matching.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def str_to_enum(cls, value):
|
||||||
|
"""
|
||||||
|
Convert a string, enum instance, or None to the appropriate enum member.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
value: Can be an enum instance of this class, a string, or None
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Enum member of this class
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the value cannot be converted to a valid enum member
|
||||||
|
"""
|
||||||
|
# Already an enum instance of this class
|
||||||
|
if isinstance(value, cls):
|
||||||
|
return value
|
||||||
|
|
||||||
|
# None maps to NONE member if it exists
|
||||||
|
if value is None:
|
||||||
|
if hasattr(cls, "NONE"):
|
||||||
|
return cls.NONE
|
||||||
|
raise ValueError(f"{cls.__name__} does not have a NONE member to map None to")
|
||||||
|
|
||||||
|
# String conversion (case-insensitive)
|
||||||
|
if isinstance(value, str):
|
||||||
|
value_lower = value.lower()
|
||||||
|
|
||||||
|
# Try to match against enum values
|
||||||
|
for member in cls:
|
||||||
|
# Handle members with None values
|
||||||
|
if member.value is None:
|
||||||
|
if value_lower == "none":
|
||||||
|
return member
|
||||||
|
# Handle members with string values
|
||||||
|
elif isinstance(member.value, str) and member.value.lower() == value_lower:
|
||||||
|
return member
|
||||||
|
|
||||||
|
# Build helpful error message with valid values
|
||||||
|
valid_values = []
|
||||||
|
for member in cls:
|
||||||
|
if member.value is None:
|
||||||
|
valid_values.append("none")
|
||||||
|
elif isinstance(member.value, str):
|
||||||
|
valid_values.append(member.value)
|
||||||
|
|
||||||
|
raise ValueError(f"Invalid {cls.__name__} string: '{value}'. " f"Valid values are: {valid_values}")
|
||||||
|
|
||||||
|
raise ValueError(
|
||||||
|
f"Cannot convert type {type(value).__name__} to {cls.__name__} enum. "
|
||||||
|
f"Expected string, None, or {cls.__name__} instance."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class AttentionType(StringConvertibleEnum):
|
||||||
|
"""Enum for specifying the attention mechanism type."""
|
||||||
|
|
||||||
|
VANILLA = "vanilla"
|
||||||
|
LINEAR = "linear"
|
||||||
|
NONE = "none"
|
||||||
|
|
||||||
|
|
||||||
|
class CausalityAxis(StringConvertibleEnum):
|
||||||
|
"""Enum for specifying the causality axis in causal convolutions."""
|
||||||
|
|
||||||
|
NONE = None
|
||||||
|
WIDTH = "width"
|
||||||
|
HEIGHT = "height"
|
||||||
|
WIDTH_COMPATIBILITY = "width-compatibility"
|
||||||
|
|
||||||
|
|
||||||
|
def Normalize(in_channels, *, num_groups=32, normtype="group"):
|
||||||
|
if normtype == "group":
|
||||||
|
return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
||||||
|
elif normtype == "pixel":
|
||||||
|
return PixelNorm(dim=1, eps=1e-6)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid normalization type: {normtype}")
|
||||||
|
|
||||||
|
|
||||||
|
class CausalConv2d(nn.Module):
|
||||||
|
"""
|
||||||
|
A causal 2D convolution.
|
||||||
|
|
||||||
|
This layer ensures that the output at time `t` only depends on inputs
|
||||||
|
at time `t` and earlier. It achieves this by applying asymmetric padding
|
||||||
|
to the time dimension (width) before the convolution.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size,
|
||||||
|
stride=1,
|
||||||
|
dilation=1,
|
||||||
|
groups=1,
|
||||||
|
bias=True,
|
||||||
|
causality_axis: CausalityAxis = CausalityAxis.HEIGHT,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.causality_axis = causality_axis
|
||||||
|
|
||||||
|
# Ensure kernel_size and dilation are tuples
|
||||||
|
kernel_size = nn.modules.utils._pair(kernel_size)
|
||||||
|
dilation = nn.modules.utils._pair(dilation)
|
||||||
|
|
||||||
|
# Calculate padding dimensions
|
||||||
|
pad_h = (kernel_size[0] - 1) * dilation[0]
|
||||||
|
pad_w = (kernel_size[1] - 1) * dilation[1]
|
||||||
|
|
||||||
|
# The padding tuple for F.pad is (pad_left, pad_right, pad_top, pad_bottom)
|
||||||
|
match self.causality_axis:
|
||||||
|
case CausalityAxis.NONE:
|
||||||
|
self.padding = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)
|
||||||
|
case CausalityAxis.WIDTH | CausalityAxis.WIDTH_COMPATIBILITY:
|
||||||
|
self.padding = (pad_w, 0, pad_h // 2, pad_h - pad_h // 2)
|
||||||
|
case CausalityAxis.HEIGHT:
|
||||||
|
self.padding = (pad_w // 2, pad_w - pad_w // 2, pad_h, 0)
|
||||||
|
case _:
|
||||||
|
raise ValueError(f"Invalid causality_axis: {causality_axis}")
|
||||||
|
|
||||||
|
# The internal convolution layer uses no padding, as we handle it manually
|
||||||
|
self.conv = ops.Conv2d(
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size,
|
||||||
|
stride=stride,
|
||||||
|
padding=0,
|
||||||
|
dilation=dilation,
|
||||||
|
groups=groups,
|
||||||
|
bias=bias,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
# Apply causal padding before convolution
|
||||||
|
x = F.pad(x, self.padding)
|
||||||
|
return self.conv(x)
|
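A tiny demonstration of the asymmetric padding for the HEIGHT-causal case with a 3x1 kernel (made-up weights, plain F.conv2d rather than the wrapper above): output row t only sees input rows <= t, because two rows of zeros are prepended at the top and none at the bottom.

import torch
import torch.nn.functional as F

x = torch.arange(1.0, 5.0).view(1, 1, 4, 1)              # 4 "time" rows, 1 column
kernel = torch.ones(1, 1, 3, 1)                           # sums a 3-row causal window
padded = F.pad(x, (0, 0, 2, 0))                           # (left, right, top, bottom) for HEIGHT causality
y = F.conv2d(padded, kernel)
print(y.flatten().tolist())                               # [1.0, 3.0, 6.0, 9.0] -- never peeks ahead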
||||||
|
|
||||||
|
|
||||||
|
def make_conv2d(
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size,
|
||||||
|
stride=1,
|
||||||
|
padding=None,
|
||||||
|
dilation=1,
|
||||||
|
groups=1,
|
||||||
|
bias=True,
|
||||||
|
causality_axis: Optional[CausalityAxis] = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Create a 2D convolution layer that can be either causal or non-causal.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
in_channels: Number of input channels
|
||||||
|
out_channels: Number of output channels
|
||||||
|
kernel_size: Size of the convolution kernel
|
||||||
|
stride: Convolution stride
|
||||||
|
padding: Padding (if None, will be calculated based on causal flag)
|
||||||
|
dilation: Dilation rate
|
||||||
|
groups: Number of groups for grouped convolution
|
||||||
|
bias: Whether to use bias
|
||||||
|
causality_axis: Dimension along which to apply causality.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Either a regular Conv2d or CausalConv2d layer
|
||||||
|
"""
|
||||||
|
if causality_axis is not None:
|
||||||
|
# For causal convolution, padding is handled internally by CausalConv2d
|
||||||
|
return CausalConv2d(in_channels, out_channels, kernel_size, stride, dilation, groups, bias, causality_axis)
|
||||||
|
else:
|
||||||
|
# For non-causal convolution, use symmetric padding if not specified
|
||||||
|
if padding is None:
|
||||||
|
if isinstance(kernel_size, int):
|
||||||
|
padding = kernel_size // 2
|
||||||
|
else:
|
||||||
|
padding = tuple(k // 2 for k in kernel_size)
|
||||||
|
return ops.Conv2d(
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size,
|
||||||
|
stride,
|
||||||
|
padding,
|
||||||
|
dilation,
|
||||||
|
groups,
|
||||||
|
bias,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Upsample(nn.Module):
|
||||||
|
def __init__(self, in_channels, with_conv, causality_axis: CausalityAxis = CausalityAxis.HEIGHT):
|
||||||
|
super().__init__()
|
||||||
|
self.with_conv = with_conv
|
||||||
|
self.causality_axis = causality_axis
|
||||||
|
if self.with_conv:
|
||||||
|
self.conv = make_conv2d(in_channels, in_channels, kernel_size=3, stride=1, causality_axis=causality_axis)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
||||||
|
if self.with_conv:
|
||||||
|
x = self.conv(x)
|
||||||
|
# Drop FIRST element in the causal axis to undo encoder's padding, while keeping the length 1 + 2 * n.
|
||||||
|
# For example, if the input is [0, 1, 2], after interpolation, the output is [0, 0, 1, 1, 2, 2].
|
||||||
|
# The causal convolution will pad the first element as [-, -, 0, 0, 1, 1, 2, 2],
|
||||||
|
# So the output elements rely on the following windows:
|
||||||
|
# 0: [-,-,0]
|
||||||
|
# 1: [-,0,0]
|
||||||
|
# 2: [0,0,1]
|
||||||
|
# 3: [0,1,1]
|
||||||
|
# 4: [1,1,2]
|
||||||
|
# 5: [1,2,2]
|
||||||
|
# Notice that the first and second elements in the output rely only on the first element in the input,
|
||||||
|
# while all other elements rely on two elements in the input.
|
||||||
|
# So we can drop the first element to undo the padding (rather than the last element).
|
||||||
|
# This is a no-op for non-causal convolutions.
|
||||||
|
match self.causality_axis:
|
||||||
|
case CausalityAxis.NONE:
|
||||||
|
pass # x remains unchanged
|
||||||
|
case CausalityAxis.HEIGHT:
|
||||||
|
x = x[:, :, 1:, :]
|
||||||
|
case CausalityAxis.WIDTH:
|
||||||
|
x = x[:, :, :, 1:]
|
||||||
|
case CausalityAxis.WIDTH_COMPATIBILITY:
|
||||||
|
pass # x remains unchanged
|
||||||
|
case _:
|
||||||
|
raise ValueError(f"Invalid causality_axis: {self.causality_axis}")
|
||||||
|
|
||||||
|
return x
|
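The comment above in code form, for the HEIGHT-causal case (an illustrative sketch): nearest upsampling of 3 rows yields 6, and dropping the first row leaves 5 = 1 + 2*2, the length the causal decoder expects.

import torch

x = torch.tensor([[[[0.0], [1.0], [2.0]]]])               # (N, C, H=3, W=1)
up = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
print(up[:, :, :, 0].flatten().tolist())                   # [0, 0, 1, 1, 2, 2] along the causal (H) axis
print(up[:, :, 1:, 0].flatten().tolist())                  # [0, 1, 1, 2, 2] after dropping the first row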
||||||
|
|
||||||
|
|
||||||
|
class Downsample(nn.Module):
|
||||||
|
"""
|
||||||
|
A downsampling layer that can use either a strided convolution
|
||||||
|
or average pooling. Supports standard and causal padding for the
|
||||||
|
convolutional mode.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, in_channels, with_conv, causality_axis: CausalityAxis = CausalityAxis.WIDTH):
|
||||||
|
super().__init__()
|
||||||
|
self.with_conv = with_conv
|
||||||
|
self.causality_axis = causality_axis
|
||||||
|
|
||||||
|
if self.causality_axis != CausalityAxis.NONE and not self.with_conv:
|
||||||
|
raise ValueError("causality is only supported when `with_conv=True`.")
|
||||||
|
|
||||||
|
if self.with_conv:
|
||||||
|
# Do time downsampling here
|
||||||
|
# no asymmetric padding in torch conv, must do it ourselves
|
||||||
|
self.conv = ops.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
if self.with_conv:
|
||||||
|
# (pad_left, pad_right, pad_top, pad_bottom)
|
||||||
|
match self.causality_axis:
|
||||||
|
case CausalityAxis.NONE:
|
||||||
|
pad = (0, 1, 0, 1)
|
||||||
|
case CausalityAxis.WIDTH:
|
||||||
|
pad = (2, 0, 0, 1)
|
||||||
|
case CausalityAxis.HEIGHT:
|
||||||
|
pad = (0, 1, 2, 0)
|
||||||
|
case CausalityAxis.WIDTH_COMPATIBILITY:
|
||||||
|
pad = (1, 0, 0, 1)
|
||||||
|
case _:
|
||||||
|
raise ValueError(f"Invalid causality_axis: {self.causality_axis}")
|
||||||
|
|
||||||
|
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
|
||||||
|
x = self.conv(x)
|
||||||
|
else:
|
||||||
|
# This branch is only taken if with_conv=False, which implies causality_axis is NONE.
|
||||||
|
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class ResnetBlock(nn.Module):
    def __init__(
        self,
        *,
        in_channels,
        out_channels=None,
        conv_shortcut=False,
        dropout,
        temb_channels=512,
        norm_type="group",
        causality_axis: CausalityAxis = CausalityAxis.HEIGHT,
    ):
        super().__init__()
        self.causality_axis = causality_axis

        if self.causality_axis != CausalityAxis.NONE and norm_type == "group":
            raise ValueError("Causal ResnetBlock with GroupNorm is not supported.")
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels, normtype=norm_type)
        self.non_linearity = nn.SiLU()
        self.conv1 = make_conv2d(in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis)
        if temb_channels > 0:
            self.temb_proj = ops.Linear(temb_channels, out_channels)
        self.norm2 = Normalize(out_channels, normtype=norm_type)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = make_conv2d(out_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis)
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = make_conv2d(
                    in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
                )
            else:
                self.nin_shortcut = make_conv2d(
                    in_channels, out_channels, kernel_size=1, stride=1, causality_axis=causality_axis
                )

    def forward(self, x, temb):
        h = x
        h = self.norm1(h)
        h = self.non_linearity(h)
        h = self.conv1(h)

        if temb is not None:
            h = h + self.temb_proj(self.non_linearity(temb))[:, :, None, None]

        h = self.norm2(h)
        h = self.non_linearity(h)
        h = self.dropout(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)

        return x + h

class AttnBlock(nn.Module):
    def __init__(self, in_channels, norm_type="group"):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels, normtype=norm_type)
        self.q = ops.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.k = ops.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.v = ops.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.proj_out = ops.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        b, c, h, w = q.shape
        q = q.reshape(b, c, h * w).contiguous()
        q = q.permute(0, 2, 1).contiguous()  # b,hw,c
        k = k.reshape(b, c, h * w).contiguous()  # b,c,hw
        w_ = torch.bmm(q, k).contiguous()  # b,hw,hw    w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
        w_ = w_ * (int(c) ** (-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # attend to values
        v = v.reshape(b, c, h * w).contiguous()
        w_ = w_.permute(0, 2, 1).contiguous()  # b,hw,hw (first hw of k, second of q)
        h_ = torch.bmm(v, w_).contiguous()  # b,c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
        h_ = h_.reshape(b, c, h, w).contiguous()

        h_ = self.proj_out(h_)

        return x + h_

def make_attn(in_channels, attn_type="vanilla", norm_type="group"):
    # Convert string to enum if needed
    attn_type = AttentionType.str_to_enum(attn_type)

    if attn_type != AttentionType.NONE:
        logging.info(f"making attention of type '{attn_type.value}' with {in_channels} in_channels")
    else:
        logging.info(f"making identity attention with {in_channels} in_channels")

    match attn_type:
        case AttentionType.VANILLA:
            return AttnBlock(in_channels, norm_type=norm_type)
        case AttentionType.NONE:
            return nn.Identity(in_channels)
        case AttentionType.LINEAR:
            raise NotImplementedError(f"Attention type {attn_type.value} is not supported yet.")
        case _:
            raise ValueError(f"Unknown attention type: {attn_type}")

class Encoder(nn.Module):
    def __init__(
        self,
        *,
        ch,
        out_ch,
        ch_mult=(1, 2, 4, 8),
        num_res_blocks,
        attn_resolutions,
        dropout=0.0,
        resamp_with_conv=True,
        in_channels,
        resolution,
        z_channels,
        double_z=True,
        attn_type="vanilla",
        mid_block_add_attention=True,
        norm_type="group",
        causality_axis=CausalityAxis.WIDTH.value,
        **ignore_kwargs,
    ):
        super().__init__()
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.z_channels = z_channels
        self.double_z = double_z
        self.norm_type = norm_type
        # Convert string to enum if needed (for config loading)
        causality_axis = CausalityAxis.str_to_enum(causality_axis)
        self.attn_type = AttentionType.str_to_enum(attn_type)

        # downsampling
        self.conv_in = make_conv2d(
            in_channels,
            self.ch,
            kernel_size=3,
            stride=1,
            causality_axis=causality_axis,
        )

        self.non_linearity = nn.SiLU()

        curr_res = resolution
        in_ch_mult = (1,) + tuple(ch_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()

        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]

            for _ in range(self.num_res_blocks):
                block.append(
                    ResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout,
                        norm_type=self.norm_type,
                        causality_axis=causality_axis,
                    )
                )
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=self.attn_type, norm_type=self.norm_type))

            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample(block_in, resamp_with_conv, causality_axis=causality_axis)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
            norm_type=self.norm_type,
            causality_axis=causality_axis,
        )
        if mid_block_add_attention:
            self.mid.attn_1 = make_attn(block_in, attn_type=self.attn_type, norm_type=self.norm_type)
        else:
            self.mid.attn_1 = nn.Identity()
        self.mid.block_2 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
            norm_type=self.norm_type,
            causality_axis=causality_axis,
        )

        # end
        self.norm_out = Normalize(block_in, normtype=self.norm_type)
        self.conv_out = make_conv2d(
            block_in,
            2 * z_channels if double_z else z_channels,
            kernel_size=3,
            stride=1,
            causality_axis=causality_axis,
        )

    def forward(self, x):
        """
        Forward pass through the encoder.

        Args:
            x: Input tensor of shape [batch, channels, time, n_mels]

        Returns:
            Encoded latent representation
        """
        feature_maps = [self.conv_in(x)]

        # Process each resolution level (from high to low resolution)
        for resolution_level in range(self.num_resolutions):
            # Apply residual blocks at current resolution level
            for block_idx in range(self.num_res_blocks):
                # Apply ResNet block with optional timestep embedding
                current_features = self.down[resolution_level].block[block_idx](feature_maps[-1], temb=None)

                # Apply attention if configured for this resolution level
                if len(self.down[resolution_level].attn) > 0:
                    current_features = self.down[resolution_level].attn[block_idx](current_features)

                # Store processed features
                feature_maps.append(current_features)

            # Downsample spatial dimensions (except at the final resolution level)
            if resolution_level != self.num_resolutions - 1:
                downsampled_features = self.down[resolution_level].downsample(feature_maps[-1])
                feature_maps.append(downsampled_features)

        # === MIDDLE PROCESSING PHASE ===
        # Take the lowest resolution features for middle processing
        bottleneck_features = feature_maps[-1]

        # Apply first middle ResNet block
        bottleneck_features = self.mid.block_1(bottleneck_features, temb=None)

        # Apply middle attention block
        bottleneck_features = self.mid.attn_1(bottleneck_features)

        # Apply second middle ResNet block
        bottleneck_features = self.mid.block_2(bottleneck_features, temb=None)

        # === OUTPUT PHASE ===
        # Normalize the bottleneck features
        output_features = self.norm_out(bottleneck_features)

        # Apply non-linearity (SiLU activation)
        output_features = self.non_linearity(output_features)

        # Final convolution to produce latent representation
        # [batch, channels, time, n_mels] -> [batch, 2 * z_channels if double_z else z_channels, time, n_mels]
        return self.conv_out(output_features)

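A hedged usage sketch for the Encoder, using the defaults that `_guess_config()` supplies further down. `ltx_audio_vae` is only a placeholder for wherever this module actually lives, and the 160-frame input length is arbitrary.

import torch
from ltx_audio_vae import Encoder  # placeholder import path, not confirmed by the diff

encoder = Encoder(
    ch=128, out_ch=8, ch_mult=[1, 2, 4], num_res_blocks=2, attn_resolutions=[],
    in_channels=2, resolution=256, z_channels=8, double_z=True,
    mid_block_add_attention=False, norm_type="pixel", causality_axis="height",
)
mel = torch.randn(1, 2, 160, 64)   # [batch, channels, time, n_mels]
z = encoder(mel)
print(z.shape)                     # 2 * z_channels = 16 channels; time and n_mels reduced 4x
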
class Decoder(nn.Module):
    def __init__(
        self,
        *,
        ch,
        out_ch,
        ch_mult=(1, 2, 4, 8),
        num_res_blocks,
        attn_resolutions,
        dropout=0.0,
        resamp_with_conv=True,
        in_channels,
        resolution,
        z_channels,
        give_pre_end=False,
        tanh_out=False,
        attn_type="vanilla",
        mid_block_add_attention=True,
        norm_type="group",
        causality_axis=CausalityAxis.WIDTH.value,
        **ignorekwargs,
    ):
        super().__init__()
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.out_ch = out_ch
        self.give_pre_end = give_pre_end
        self.tanh_out = tanh_out
        self.norm_type = norm_type
        self.z_channels = z_channels
        self.ch_mult = ch_mult  # stored so get_config() below can report it
        # Convert string to enum if needed (for config loading)
        causality_axis = CausalityAxis.str_to_enum(causality_axis)
        self.attn_type = AttentionType.str_to_enum(attn_type)

        # compute block_in and curr_res at lowest res
        block_in = ch * ch_mult[self.num_resolutions - 1]
        curr_res = resolution // 2 ** (self.num_resolutions - 1)
        self.z_shape = (1, z_channels, curr_res, curr_res)

        # z to block_in
        self.conv_in = make_conv2d(z_channels, block_in, kernel_size=3, stride=1, causality_axis=causality_axis)

        self.non_linearity = nn.SiLU()

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
            norm_type=self.norm_type,
            causality_axis=causality_axis,
        )
        if mid_block_add_attention:
            self.mid.attn_1 = make_attn(block_in, attn_type=self.attn_type, norm_type=self.norm_type)
        else:
            self.mid.attn_1 = nn.Identity()
        self.mid.block_2 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
            norm_type=self.norm_type,
            causality_axis=causality_axis,
        )

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks + 1):
                block.append(
                    ResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout,
                        norm_type=self.norm_type,
                        causality_axis=causality_axis,
                    )
                )
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=self.attn_type, norm_type=self.norm_type))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in, resamp_with_conv, causality_axis=causality_axis)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in, normtype=self.norm_type)
        self.conv_out = make_conv2d(block_in, out_ch, kernel_size=3, stride=1, causality_axis=causality_axis)

    def _adjust_output_shape(self, decoded_output, target_shape):
        """
        Adjust output shape to match target dimensions for variable-length audio.

        This function handles the common case where decoded audio spectrograms need to be
        resized to match a specific target shape.

        Args:
            decoded_output: Tensor of shape (batch, channels, time, frequency)
            target_shape: Target shape tuple (batch, channels, time, frequency)

        Returns:
            Tensor adjusted to match target_shape exactly
        """
        # Current output shape: (batch, channels, time, frequency)
        _, _, current_time, current_freq = decoded_output.shape
        _, target_channels, target_time, target_freq = target_shape

        # Step 1: Crop first to avoid exceeding target dimensions
        decoded_output = decoded_output[
            :, :target_channels, : min(current_time, target_time), : min(current_freq, target_freq)
        ]

        # Step 2: Calculate padding needed for time and frequency dimensions
        time_padding_needed = target_time - decoded_output.shape[2]
        freq_padding_needed = target_freq - decoded_output.shape[3]

        # Step 3: Apply padding if needed
        if time_padding_needed > 0 or freq_padding_needed > 0:
            # PyTorch padding format: (pad_left, pad_right, pad_top, pad_bottom)
            # For audio: pad_left/right = frequency, pad_top/bottom = time
            padding = (
                0,
                max(freq_padding_needed, 0),  # frequency padding (left, right)
                0,
                max(time_padding_needed, 0),  # time padding (top, bottom)
            )
            decoded_output = F.pad(decoded_output, padding)

        # Step 4: Final safety crop to ensure exact target shape
        decoded_output = decoded_output[:, :target_channels, :target_time, :target_freq]

        return decoded_output

    def get_config(self):
        return {
            "ch": self.ch,
            "out_ch": self.out_ch,
            "ch_mult": self.ch_mult,
            "num_res_blocks": self.num_res_blocks,
            "in_channels": self.in_channels,
            "resolution": self.resolution,
            "z_channels": self.z_channels,
        }

    def forward(self, latent_features, target_shape=None):
        """
        Decode latent features back to audio spectrograms.

        Args:
            latent_features: Encoded latent representation of shape (batch, channels, height, width)
            target_shape: Target output shape (batch, channels, time, frequency); required here,
                and the output is cropped/padded to match it

        Returns:
            Reconstructed audio spectrogram of shape (batch, channels, time, frequency)
        """
        assert target_shape is not None, "Target shape is required for CausalAudioAutoencoder Decoder"

        # Transform latent features to decoder's internal feature dimension
        hidden_features = self.conv_in(latent_features)

        # Middle processing
        hidden_features = self.mid.block_1(hidden_features, temb=None)
        hidden_features = self.mid.attn_1(hidden_features)
        hidden_features = self.mid.block_2(hidden_features, temb=None)

        # Upsampling
        # Progressively increase spatial resolution from lowest to highest
        for resolution_level in reversed(range(self.num_resolutions)):
            # Apply residual blocks at current resolution level
            for block_index in range(self.num_res_blocks + 1):
                hidden_features = self.up[resolution_level].block[block_index](hidden_features, temb=None)

                if len(self.up[resolution_level].attn) > 0:
                    hidden_features = self.up[resolution_level].attn[block_index](hidden_features)

            if resolution_level != 0:
                hidden_features = self.up[resolution_level].upsample(hidden_features)

        # Output
        if self.give_pre_end:
            # Return intermediate features before final processing (for debugging/analysis)
            decoded_output = hidden_features
        else:
            # Standard output path: normalize, activate, and convert to output channels
            # Final normalization layer
            hidden_features = self.norm_out(hidden_features)

            # Apply SiLU (Swish) activation function
            hidden_features = self.non_linearity(hidden_features)

            # Final convolution to map to output channels (typically 2 for stereo audio)
            decoded_output = self.conv_out(hidden_features)

            # Optional tanh activation to bound output values to [-1, 1] range
            if self.tanh_out:
                decoded_output = torch.tanh(decoded_output)

        # Adjust shape for audio data
        if target_shape is not None:
            decoded_output = self._adjust_output_shape(decoded_output, target_shape)

        return decoded_output

class processor(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("std-of-means", torch.empty(128))
        self.register_buffer("mean-of-means", torch.empty(128))

    def un_normalize(self, x):
        return (x * self.get_buffer("std-of-means").to(x)) + self.get_buffer("mean-of-means").to(x)

    def normalize(self, x):
        return (x - self.get_buffer("mean-of-means").to(x)) / self.get_buffer("std-of-means").to(x)

class CausalAudioAutoencoder(nn.Module):
    def __init__(self, config=None):
        super().__init__()

        if config is None:
            config = self._guess_config()

        # Extract encoder and decoder configs from the new format
        model_config = config.get("model", {}).get("params", {})
        variables_config = config.get("variables", {})

        self.sampling_rate = variables_config.get(
            "sampling_rate",
            model_config.get("sampling_rate", config.get("sampling_rate", 16000)),
        )
        encoder_config = model_config.get("encoder", model_config.get("ddconfig", {}))
        decoder_config = model_config.get("decoder", encoder_config)

        # Load mel spectrogram parameters
        self.mel_bins = encoder_config.get("mel_bins", 64)
        self.mel_hop_length = model_config.get("preprocessing", {}).get("stft", {}).get("hop_length", 160)
        self.n_fft = model_config.get("preprocessing", {}).get("stft", {}).get("filter_length", 1024)

        # Store causality configuration at VAE level (not just in encoder internals)
        causality_axis_value = encoder_config.get("causality_axis", CausalityAxis.WIDTH.value)
        self.causality_axis = CausalityAxis.str_to_enum(causality_axis_value)
        self.is_causal = self.causality_axis == CausalityAxis.HEIGHT

        self.encoder = Encoder(**encoder_config)
        self.decoder = Decoder(**decoder_config)

        self.per_channel_statistics = processor()

    def _guess_config(self):
        encoder_config = {
            # Required parameters - based on ltx-video-av-1679000 model metadata
            "ch": 128,
            "out_ch": 8,
            "ch_mult": [1, 2, 4],  # Based on metadata: [1, 2, 4] not [1, 2, 4, 8]
            "num_res_blocks": 2,
            "attn_resolutions": [],  # Based on metadata: empty list, no attention
            "dropout": 0.0,
            "resamp_with_conv": True,
            "in_channels": 2,  # stereo
            "resolution": 256,
            "z_channels": 8,
            "double_z": True,
            "attn_type": "vanilla",
            "mid_block_add_attention": False,  # Based on metadata: false
            "norm_type": "pixel",
            "causality_axis": "height",  # Based on metadata
            "mel_bins": 64,  # Based on metadata: mel_bins = 64
        }

        decoder_config = {
            # Inherits encoder config, can override specific params
            **encoder_config,
            "out_ch": 2,  # Stereo audio output (2 channels)
            "give_pre_end": False,
            "tanh_out": False,
        }

        config = {
            "_class_name": "CausalAudioAutoencoder",
            "sampling_rate": 16000,
            "model": {
                "params": {
                    "encoder": encoder_config,
                    "decoder": decoder_config,
                }
            },
        }

        return config

    def get_config(self):
        return {
            "sampling_rate": self.sampling_rate,
            "mel_bins": self.mel_bins,
            "mel_hop_length": self.mel_hop_length,
            "n_fft": self.n_fft,
            "causality_axis": self.causality_axis.value,
            "is_causal": self.is_causal,
        }

    def encode(self, x):
        return self.encoder(x)

    def decode(self, x, target_shape=None):
        return self.decoder(x, target_shape=target_shape)

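A hedged round-trip sketch for the autoencoder with its guessed default config. The import path is again a placeholder, and taking the first `z_channels` channels of the encoder output as the latent (the mean half of a `double_z` encoding) is an assumption rather than something the diff spells out.

import torch
from ltx_audio_vae import CausalAudioAutoencoder  # placeholder import path, not confirmed by the diff

vae = CausalAudioAutoencoder()                     # no config -> _guess_config()
mel = torch.randn(1, 2, 160, 64)                   # [batch, stereo, time, n_mels]
z = vae.encode(mel)                                # 16 channels with double_z=True
latent = z[:, :8]                                  # assumed: first z_channels channels as the latent
recon = vae.decode(latent, target_shape=mel.shape)
print(recon.shape)                                 # matches mel.shape exactly
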
213
comfy/ldm/lightricks/vocoders/vocoder.py
Normal file
@ -0,0 +1,213 @@
import torch
import torch.nn.functional as F
import torch.nn as nn
import comfy.ops
import numpy as np

ops = comfy.ops.disable_weight_init

LRELU_SLOPE = 0.1

def get_padding(kernel_size, dilation=1):
    return int((kernel_size * dilation - dilation) / 2)


class ResBlock1(torch.nn.Module):
    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
        super(ResBlock1, self).__init__()
        self.convs1 = nn.ModuleList(
            [
                ops.Conv1d(
                    channels,
                    channels,
                    kernel_size,
                    1,
                    dilation=dilation[0],
                    padding=get_padding(kernel_size, dilation[0]),
                ),
                ops.Conv1d(
                    channels,
                    channels,
                    kernel_size,
                    1,
                    dilation=dilation[1],
                    padding=get_padding(kernel_size, dilation[1]),
                ),
                ops.Conv1d(
                    channels,
                    channels,
                    kernel_size,
                    1,
                    dilation=dilation[2],
                    padding=get_padding(kernel_size, dilation[2]),
                ),
            ]
        )

        self.convs2 = nn.ModuleList(
            [
                ops.Conv1d(
                    channels,
                    channels,
                    kernel_size,
                    1,
                    dilation=1,
                    padding=get_padding(kernel_size, 1),
                ),
                ops.Conv1d(
                    channels,
                    channels,
                    kernel_size,
                    1,
                    dilation=1,
                    padding=get_padding(kernel_size, 1),
                ),
                ops.Conv1d(
                    channels,
                    channels,
                    kernel_size,
                    1,
                    dilation=1,
                    padding=get_padding(kernel_size, 1),
                ),
            ]
        )

    def forward(self, x):
        for c1, c2 in zip(self.convs1, self.convs2):
            xt = F.leaky_relu(x, LRELU_SLOPE)
            xt = c1(xt)
            xt = F.leaky_relu(xt, LRELU_SLOPE)
            xt = c2(xt)
            x = xt + x
        return x


class ResBlock2(torch.nn.Module):
    def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
        super(ResBlock2, self).__init__()
        self.convs = nn.ModuleList(
            [
                ops.Conv1d(
                    channels,
                    channels,
                    kernel_size,
                    1,
                    dilation=dilation[0],
                    padding=get_padding(kernel_size, dilation[0]),
                ),
                ops.Conv1d(
                    channels,
                    channels,
                    kernel_size,
                    1,
                    dilation=dilation[1],
                    padding=get_padding(kernel_size, dilation[1]),
                ),
            ]
        )

    def forward(self, x):
        for c in self.convs:
            xt = F.leaky_relu(x, LRELU_SLOPE)
            xt = c(xt)
            x = xt + x
        return x

class Vocoder(torch.nn.Module):
    """
    Vocoder model for synthesizing audio from spectrograms, based on: https://github.com/jik876/hifi-gan.
    """

    def __init__(self, config=None):
        super(Vocoder, self).__init__()

        if config is None:
            config = self.get_default_config()

        resblock_kernel_sizes = config.get("resblock_kernel_sizes", [3, 7, 11])
        upsample_rates = config.get("upsample_rates", [6, 5, 2, 2, 2])
        upsample_kernel_sizes = config.get("upsample_kernel_sizes", [16, 15, 8, 4, 4])
        resblock_dilation_sizes = config.get("resblock_dilation_sizes", [[1, 3, 5], [1, 3, 5], [1, 3, 5]])
        upsample_initial_channel = config.get("upsample_initial_channel", 1024)
        stereo = config.get("stereo", True)
        resblock = config.get("resblock", "1")

        self.output_sample_rate = config.get("output_sample_rate")
        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)
        in_channels = 128 if stereo else 64
        self.conv_pre = ops.Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3)
        resblock_class = ResBlock1 if resblock == "1" else ResBlock2

        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            self.ups.append(
                ops.ConvTranspose1d(
                    upsample_initial_channel // (2**i),
                    upsample_initial_channel // (2 ** (i + 1)),
                    k,
                    u,
                    padding=(k - u) // 2,
                )
            )

        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = upsample_initial_channel // (2 ** (i + 1))
            for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
                self.resblocks.append(resblock_class(ch, k, d))

        out_channels = 2 if stereo else 1
        self.conv_post = ops.Conv1d(ch, out_channels, 7, 1, padding=3)

        self.upsample_factor = np.prod([self.ups[i].stride[0] for i in range(len(self.ups))])

    def get_default_config(self):
        """Generate default configuration for the vocoder."""

        config = {
            "resblock_kernel_sizes": [3, 7, 11],
            "upsample_rates": [6, 5, 2, 2, 2],
            "upsample_kernel_sizes": [16, 15, 8, 4, 4],
            "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
            "upsample_initial_channel": 1024,
            "stereo": True,
            "resblock": "1",
        }

        return config

    def forward(self, x):
        """
        Forward pass of the vocoder.

        Args:
            x: Input spectrogram tensor. Can be:
                - 3D: (batch_size, channels, time_steps) for mono
                - 4D: (batch_size, 2, channels, time_steps) for stereo

        Returns:
            Audio tensor of shape (batch_size, out_channels, audio_length)
        """
        if x.dim() == 4:  # stereo
            assert x.shape[1] == 2, "Input must have 2 channels for stereo"
            x = torch.cat((x[:, 0, :, :], x[:, 1, :, :]), dim=1)
        x = self.conv_pre(x)
        for i in range(self.num_upsamples):
            x = F.leaky_relu(x, LRELU_SLOPE)
            x = self.ups[i](x)
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)
            x = xs / self.num_kernels
        x = F.leaky_relu(x)
        x = self.conv_post(x)
        x = torch.tanh(x)

        return x
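A usage sketch for the Vocoder above: with the default config the time axis is upsampled by prod([6, 5, 2, 2, 2]) = 240, and stereo input is expected as (batch, 2, mel_channels, time_steps). Weights come from comfy.ops.disable_weight_init, so this only checks shapes, not audio quality.

import torch
from comfy.ldm.lightricks.vocoders.vocoder import Vocoder

vocoder = Vocoder()                    # default config, stereo
spec = torch.randn(1, 2, 64, 50)       # (batch, 2, mel_channels, time_steps)
audio = vocoder(spec)
print(audio.shape)                     # (1, 2, 50 * 240) == (1, 2, 12000)
print(vocoder.upsample_factor)         # 240
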
@ -41,6 +41,11 @@ class ZImage_Control(torch.nn.Module):
        ffn_dim_multiplier: float = (8.0 / 3.0),
        norm_eps: float = 1e-5,
        qk_norm: bool = True,
        n_control_layers=6,
        control_in_dim=16,
        additional_in_dim=0,
        broken=False,
        refiner_control=False,
        dtype=None,
        device=None,
        operations=None,
@ -49,10 +54,11 @@ class ZImage_Control(torch.nn.Module):
        super().__init__()
        operation_settings = {"operations": operations, "device": device, "dtype": dtype}

        self.additional_in_dim = 0
        self.control_in_dim = 16
        self.broken = broken
        self.additional_in_dim = additional_in_dim
        self.control_in_dim = control_in_dim
        n_refiner_layers = 2
        self.n_control_layers = 6
        self.n_control_layers = n_control_layers
        self.control_layers = nn.ModuleList(
            [
                ZImageControlTransformerBlock(
@ -74,28 +80,49 @@ class ZImage_Control(torch.nn.Module):
        all_x_embedder = {}
        patch_size = 2
        f_patch_size = 1
        x_embedder = operations.Linear(f_patch_size * patch_size * patch_size * self.control_in_dim, dim, bias=True, device=device, dtype=dtype)
        x_embedder = operations.Linear(f_patch_size * patch_size * patch_size * (self.control_in_dim + self.additional_in_dim), dim, bias=True, device=device, dtype=dtype)
        all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder

        self.refiner_control = refiner_control

        self.control_all_x_embedder = nn.ModuleDict(all_x_embedder)
        self.control_noise_refiner = nn.ModuleList(
            [
                JointTransformerBlock(
                    layer_id,
                    dim,
                    n_heads,
                    n_kv_heads,
                    multiple_of,
                    ffn_dim_multiplier,
                    norm_eps,
                    qk_norm,
                    modulation=True,
                    z_image_modulation=True,
                    operation_settings=operation_settings,
                )
                for layer_id in range(n_refiner_layers)
            ]
        )
        if self.refiner_control:
            self.control_noise_refiner = nn.ModuleList(
                [
                    ZImageControlTransformerBlock(
                        layer_id,
                        dim,
                        n_heads,
                        n_kv_heads,
                        multiple_of,
                        ffn_dim_multiplier,
                        norm_eps,
                        qk_norm,
                        block_id=layer_id,
                        operation_settings=operation_settings,
                    )
                    for layer_id in range(n_refiner_layers)
                ]
            )
        else:
            self.control_noise_refiner = nn.ModuleList(
                [
                    JointTransformerBlock(
                        layer_id,
                        dim,
                        n_heads,
                        n_kv_heads,
                        multiple_of,
                        ffn_dim_multiplier,
                        norm_eps,
                        qk_norm,
                        modulation=True,
                        z_image_modulation=True,
                        operation_settings=operation_settings,
                    )
                    for layer_id in range(n_refiner_layers)
                ]
            )

    def forward(self, cap_feats, control_context, x_freqs_cis, adaln_input):
        patch_size = 2
@ -105,9 +132,29 @@ class ZImage_Control(torch.nn.Module):
        control_context = self.control_all_x_embedder[f"{patch_size}-{f_patch_size}"](control_context.view(B, C, H // pH, pH, W // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2))

        x_attn_mask = None
        for layer in self.control_noise_refiner:
            control_context = layer(control_context, x_attn_mask, x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input)
        if not self.refiner_control:
            for layer in self.control_noise_refiner:
                control_context = layer(control_context, x_attn_mask, x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input)

        return control_context

    def forward_noise_refiner_block(self, layer_id, control_context, x, x_attn_mask, x_freqs_cis, adaln_input):
        if self.refiner_control:
            if self.broken:
                if layer_id == 0:
                    return self.control_layers[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)
                if layer_id > 0:
                    out = None
                    for i in range(1, len(self.control_layers)):
                        o, control_context = self.control_layers[i](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)
                        if out is None:
                            out = o

                    return (out, control_context)
            else:
                return self.control_noise_refiner[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)
        else:
            return (None, control_context)

    def forward_control_block(self, layer_id, control_context, x, x_attn_mask, x_freqs_cis, adaln_input):
        return self.control_layers[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)
@ -377,6 +377,7 @@ class NextDiT(nn.Module):
        z_image_modulation=False,
        time_scale=1.0,
        pad_tokens_multiple=None,
        clip_text_dim=None,
        image_model=None,
        device=None,
        dtype=None,
@ -447,6 +448,31 @@ class NextDiT(nn.Module):
            ),
        )

        self.clip_text_pooled_proj = None

        if clip_text_dim is not None:
            self.clip_text_dim = clip_text_dim
            self.clip_text_pooled_proj = nn.Sequential(
                operation_settings.get("operations").RMSNorm(clip_text_dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")),
                operation_settings.get("operations").Linear(
                    clip_text_dim,
                    clip_text_dim,
                    bias=True,
                    device=operation_settings.get("device"),
                    dtype=operation_settings.get("dtype"),
                ),
            )
            self.time_text_embed = nn.Sequential(
                nn.SiLU(),
                operation_settings.get("operations").Linear(
                    min(dim, 1024) + clip_text_dim,
                    min(dim, 1024),
                    bias=True,
                    device=operation_settings.get("device"),
                    dtype=operation_settings.get("dtype"),
                ),
            )

        self.layers = nn.ModuleList(
            [
                JointTransformerBlock(
@ -465,7 +491,8 @@ class NextDiT(nn.Module):
                for layer_id in range(n_layers)
            ]
        )
        self.norm_final = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
        # This norm final is in the lumina 2.0 code but isn't actually used for anything.
        # self.norm_final = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
        self.final_layer = FinalLayer(dim, patch_size, self.out_channels, z_image_modulation=z_image_modulation, operation_settings=operation_settings)

        if self.pad_tokens_multiple is not None:
@ -510,6 +537,7 @@ class NextDiT(nn.Module):
        bsz = len(x)
        pH = pW = self.patch_size
        device = x[0].device
        orig_x = x

        if self.pad_tokens_multiple is not None:
            pad_extra = (-cap_feats.shape[1]) % self.pad_tokens_multiple
@ -546,13 +574,21 @@ class NextDiT(nn.Module):

        freqs_cis = self.rope_embedder(torch.cat((cap_pos_ids, x_pos_ids), dim=1)).movedim(1, 2)

        patches = transformer_options.get("patches", {})

        # refine context
        for layer in self.context_refiner:
            cap_feats = layer(cap_feats, cap_mask, freqs_cis[:, :cap_pos_ids.shape[1]], transformer_options=transformer_options)

        padded_img_mask = None
        for layer in self.noise_refiner:
        x_input = x
        for i, layer in enumerate(self.noise_refiner):
            x = layer(x, padded_img_mask, freqs_cis[:, cap_pos_ids.shape[1]:], t, transformer_options=transformer_options)
            if "noise_refiner" in patches:
                for p in patches["noise_refiner"]:
                    out = p({"img": x, "img_input": x_input, "txt": cap_feats, "pe": freqs_cis[:, cap_pos_ids.shape[1]:], "vec": t, "x": orig_x, "block_index": i, "transformer_options": transformer_options, "block_type": "noise_refiner"})
                    if "img" in out:
                        x = out["img"]

        padded_full_embed = torch.cat((cap_feats, x), dim=1)
        mask = None
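A hedged sketch of what a "noise_refiner" patch looks like from the caller's side, based on the dictionary keys passed in the loop above; the blend itself is a made-up example, not something the PR ships.

def my_noise_refiner_patch(args):
    # args carries "img", "img_input", "txt", "pe", "vec", "x", "block_index", "block_type", ...
    img = args["img"]
    return {"img": 0.9 * img + 0.1 * args["img_input"]}  # example: blend with the block input

transformer_options = {"patches": {"noise_refiner": [my_noise_refiner_patch]}}
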
@ -585,16 +621,29 @@ class NextDiT(nn.Module):

        cap_feats = self.cap_embedder(cap_feats)  # (N, L, D)  # todo check if able to batchify w.o. redundant compute

        if self.clip_text_pooled_proj is not None:
            pooled = kwargs.get("clip_text_pooled", None)
            if pooled is not None:
                pooled = self.clip_text_pooled_proj(pooled)
            else:
                pooled = torch.zeros((x.shape[0], self.clip_text_dim), device=x.device, dtype=x.dtype)

            adaln_input = self.time_text_embed(torch.cat((t, pooled), dim=-1))

        patches = transformer_options.get("patches", {})
        x_is_tensor = isinstance(x, torch.Tensor)
        img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens, transformer_options=transformer_options)
        img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, adaln_input, num_tokens, transformer_options=transformer_options)
        freqs_cis = freqs_cis.to(img.device)

        transformer_options["total_blocks"] = len(self.layers)
        transformer_options["block_type"] = "double"
        img_input = img
        for i, layer in enumerate(self.layers):
            transformer_options["block_index"] = i
            img = layer(img, mask, freqs_cis, adaln_input, transformer_options=transformer_options)
            if "double_block" in patches:
                for p in patches["double_block"]:
                    out = p({"img": img[:, cap_size[0]:], "txt": img[:, :cap_size[0]], "pe": freqs_cis[:, cap_size[0]:], "vec": adaln_input, "x": x, "block_index": i, "transformer_options": transformer_options})
                    out = p({"img": img[:, cap_size[0]:], "img_input": img_input[:, cap_size[0]:], "txt": img[:, :cap_size[0]], "pe": freqs_cis[:, cap_size[0]:], "vec": adaln_input, "x": x, "block_index": i, "transformer_options": transformer_options})
                    if "img" in out:
                        img[:, cap_size[0]:] = out["img"]
                    if "txt" in out:
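A minimal sketch of the conditioning path added above: the projected pooled CLIP text vector is concatenated with the timestep embedding and mapped back to the adaLN width. dim = 2304 and clip_text_dim = 768 are illustrative values, not taken from the diff.

import torch
import torch.nn as nn

dim, clip_text_dim = 2304, 768                        # assumed sizes for illustration
t = torch.randn(1, min(dim, 1024))                    # timestep embedding
pooled = torch.randn(1, clip_text_dim)                # output of clip_text_pooled_proj
time_text_embed = nn.Sequential(
    nn.SiLU(),
    nn.Linear(min(dim, 1024) + clip_text_dim, min(dim, 1024)),
)
adaln_input = time_text_embed(torch.cat((t, pooled), dim=-1))
print(adaln_input.shape)                              # (1, 1024)
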
@ -30,6 +30,13 @@ except ImportError as e:
    raise e
    exit(-1)

SAGE_ATTENTION3_IS_AVAILABLE = False
try:
    from sageattn3 import sageattn3_blackwell
    SAGE_ATTENTION3_IS_AVAILABLE = True
except ImportError:
    pass

FLASH_ATTENTION_IS_AVAILABLE = False
try:
    from flash_attn import flash_attn_func
@ -563,6 +570,93 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
        out = out.reshape(b, -1, heads * dim_head)
    return out

@wrap_attn
def attention3_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
    exception_fallback = False
    if (q.device.type != "cuda" or
        q.dtype not in (torch.float16, torch.bfloat16) or
        mask is not None):
        return attention_pytorch(
            q, k, v, heads,
            mask=mask,
            attn_precision=attn_precision,
            skip_reshape=skip_reshape,
            skip_output_reshape=skip_output_reshape,
            **kwargs
        )

    if skip_reshape:
        B, H, L, D = q.shape
        if H != heads:
            return attention_pytorch(
                q, k, v, heads,
                mask=mask,
                attn_precision=attn_precision,
                skip_reshape=True,
                skip_output_reshape=skip_output_reshape,
                **kwargs
            )
        q_s, k_s, v_s = q, k, v
        N = q.shape[2]
        dim_head = D
    else:
        B, N, inner_dim = q.shape
        if inner_dim % heads != 0:
            return attention_pytorch(
                q, k, v, heads,
                mask=mask,
                attn_precision=attn_precision,
                skip_reshape=False,
                skip_output_reshape=skip_output_reshape,
                **kwargs
            )
        dim_head = inner_dim // heads

    if dim_head >= 256 or N <= 1024:
        return attention_pytorch(
            q, k, v, heads,
            mask=mask,
            attn_precision=attn_precision,
            skip_reshape=skip_reshape,
            skip_output_reshape=skip_output_reshape,
            **kwargs
        )

    if not skip_reshape:
        q_s, k_s, v_s = map(
            lambda t: t.view(B, -1, heads, dim_head).permute(0, 2, 1, 3).contiguous(),
            (q, k, v),
        )
        B, H, L, D = q_s.shape

    try:
        out = sageattn3_blackwell(q_s, k_s, v_s, is_causal=False)
    except Exception as e:
        exception_fallback = True
        logging.error("Error running SageAttention3: %s, falling back to pytorch attention.", e)

    if exception_fallback:
        if not skip_reshape:
            del q_s, k_s, v_s
        return attention_pytorch(
            q, k, v, heads,
            mask=mask,
            attn_precision=attn_precision,
            skip_reshape=False,
            skip_output_reshape=skip_output_reshape,
            **kwargs
        )

    if skip_reshape:
        if not skip_output_reshape:
            out = out.permute(0, 2, 1, 3).reshape(B, L, H * D)
    else:
        if skip_output_reshape:
            pass
        else:
            out = out.permute(0, 2, 1, 3).reshape(B, L, H * D)

    return out

try:
    @torch.library.custom_op("flash_attention::flash_attn", mutates_args=())
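A small helper that restates the fallback conditions of attention3_sage for the common un-reshaped (B, N, heads * dim_head) layout; it mirrors the checks above and is only an illustration, not code from the PR.

import torch

def would_use_sage3(q: torch.Tensor, heads: int, mask=None) -> bool:
    if q.device.type != "cuda" or q.dtype not in (torch.float16, torch.bfloat16) or mask is not None:
        return False
    _, n, inner_dim = q.shape
    if inner_dim % heads != 0:
        return False
    dim_head = inner_dim // heads
    return dim_head < 256 and n > 1024   # otherwise attention_pytorch is used instead
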
@ -650,6 +744,8 @@ optimized_attention_masked = optimized_attention
# register core-supported attention functions
if SAGE_ATTENTION_IS_AVAILABLE:
    register_attention_function("sage", attention_sage)
if SAGE_ATTENTION3_IS_AVAILABLE:
    register_attention_function("sage3", attention3_sage)
if FLASH_ATTENTION_IS_AVAILABLE:
    register_attention_function("flash", attention_flash)
if model_management.xformers_enabled():
@ -394,7 +394,8 @@ class Model(nn.Module):
                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
                 resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"):
        super().__init__()
        if use_linear_attn: attn_type = "linear"
        if use_linear_attn:
            attn_type = "linear"
        self.ch = ch
        self.temb_ch = self.ch*4
        self.num_resolutions = len(ch_mult)
@ -548,7 +549,8 @@ class Encoder(nn.Module):
                 conv3d=False, time_compress=None,
                 **ignore_kwargs):
        super().__init__()
        if use_linear_attn: attn_type = "linear"
        if use_linear_attn:
            attn_type = "linear"
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
@ -45,7 +45,7 @@ class LitEma(nn.Module):
                    shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
                    shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
                else:
                    assert not key in self.m_name2s_name
                    assert key not in self.m_name2s_name

    def copy_to(self, model):
        m_param = dict(model.named_parameters())
@ -54,7 +54,7 @@ class LitEma(nn.Module):
            if m_param[key].requires_grad:
                m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
            else:
                assert not key in self.m_name2s_name
                assert key not in self.m_name2s_name

    def store(self, parameters):
        """
@ -61,7 +61,7 @@ def apply_rotary_emb(x, freqs_cis):


class QwenTimestepProjEmbeddings(nn.Module):
    def __init__(self, embedding_dim, pooled_projection_dim, dtype=None, device=None, operations=None):
    def __init__(self, embedding_dim, pooled_projection_dim, use_additional_t_cond=False, dtype=None, device=None, operations=None):
        super().__init__()
        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1000)
        self.timestep_embedder = TimestepEmbedding(
@ -72,9 +72,19 @@ class QwenTimestepProjEmbeddings(nn.Module):
            operations=operations
        )

    def forward(self, timestep, hidden_states):
        self.use_additional_t_cond = use_additional_t_cond
        if self.use_additional_t_cond:
            self.addition_t_embedding = operations.Embedding(2, embedding_dim, device=device, dtype=dtype)

    def forward(self, timestep, hidden_states, addition_t_cond=None):
        timesteps_proj = self.time_proj(timestep)
        timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_states.dtype))

        if self.use_additional_t_cond:
            if addition_t_cond is None:
                addition_t_cond = torch.zeros((timesteps_emb.shape[0]), device=timesteps_emb.device, dtype=torch.long)
            timesteps_emb += self.addition_t_embedding(addition_t_cond, out_dtype=timesteps_emb.dtype)

        return timesteps_emb

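A toy sketch of the extra conditioning term above: with use_additional_t_cond enabled, a two-entry embedding table adds a learned offset to the timestep embedding, selected by a 0/1 condition per sample. Standard nn.Embedding stands in for operations.Embedding here, and the sizes are arbitrary.

import torch
import torch.nn as nn

embedding_dim = 8
addition_t_embedding = nn.Embedding(2, embedding_dim)
timesteps_emb = torch.randn(4, embedding_dim)
addition_t_cond = torch.tensor([0, 1, 0, 1])
timesteps_emb = timesteps_emb + addition_t_embedding(addition_t_cond)
print(timesteps_emb.shape)            # (4, 8)
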
@ -218,9 +228,24 @@ class QwenImageTransformerBlock(nn.Module):
            operations=operations,
        )

    def _modulate(self, x: torch.Tensor, mod_params: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    def _apply_gate(self, x, y, gate, timestep_zero_index=None):
        if timestep_zero_index is not None:
            return y + torch.cat((x[:, :timestep_zero_index] * gate[0], x[:, timestep_zero_index:] * gate[1]), dim=1)
        else:
            return torch.addcmul(y, gate, x)

    def _modulate(self, x: torch.Tensor, mod_params: torch.Tensor, timestep_zero_index=None) -> Tuple[torch.Tensor, torch.Tensor]:
        shift, scale, gate = torch.chunk(mod_params, 3, dim=-1)
        return torch.addcmul(shift.unsqueeze(1), x, 1 + scale.unsqueeze(1)), gate.unsqueeze(1)
        if timestep_zero_index is not None:
            actual_batch = shift.size(0) // 2
            shift, shift_0 = shift[:actual_batch], shift[actual_batch:]
            scale, scale_0 = scale[:actual_batch], scale[actual_batch:]
            gate, gate_0 = gate[:actual_batch], gate[actual_batch:]
            reg = torch.addcmul(shift.unsqueeze(1), x[:, :timestep_zero_index], 1 + scale.unsqueeze(1))
            zero = torch.addcmul(shift_0.unsqueeze(1), x[:, timestep_zero_index:], 1 + scale_0.unsqueeze(1))
            return torch.cat((reg, zero), dim=1), (gate.unsqueeze(1), gate_0.unsqueeze(1))
        else:
            return torch.addcmul(shift.unsqueeze(1), x, 1 + scale.unsqueeze(1)), gate.unsqueeze(1)

    def forward(
        self,
@ -229,14 +254,19 @@ class QwenImageTransformerBlock(nn.Module):
        encoder_hidden_states_mask: torch.Tensor,
        temb: torch.Tensor,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        timestep_zero_index=None,
        transformer_options={},
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        img_mod_params = self.img_mod(temb)

        if timestep_zero_index is not None:
            temb = temb.chunk(2, dim=0)[0]

        txt_mod_params = self.txt_mod(temb)
        img_mod1, img_mod2 = img_mod_params.chunk(2, dim=-1)
        txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1)

        img_modulated, img_gate1 = self._modulate(self.img_norm1(hidden_states), img_mod1)
        img_modulated, img_gate1 = self._modulate(self.img_norm1(hidden_states), img_mod1, timestep_zero_index)
        del img_mod1
        txt_modulated, txt_gate1 = self._modulate(self.txt_norm1(encoder_hidden_states), txt_mod1)
        del txt_mod1
@ -251,15 +281,15 @@ class QwenImageTransformerBlock(nn.Module):
        del img_modulated
        del txt_modulated

        hidden_states = hidden_states + img_gate1 * img_attn_output
        hidden_states = self._apply_gate(img_attn_output, hidden_states, img_gate1, timestep_zero_index)
        encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output
        del img_attn_output
        del txt_attn_output
        del img_gate1
        del txt_gate1

        img_modulated2, img_gate2 = self._modulate(self.img_norm2(hidden_states), img_mod2)
        img_modulated2, img_gate2 = self._modulate(self.img_norm2(hidden_states), img_mod2, timestep_zero_index)
        hidden_states = torch.addcmul(hidden_states, img_gate2, self.img_mlp(img_modulated2))
        hidden_states = self._apply_gate(self.img_mlp(img_modulated2), hidden_states, img_gate2, timestep_zero_index)

        txt_modulated2, txt_gate2 = self._modulate(self.txt_norm2(encoder_hidden_states), txt_mod2)
        encoder_hidden_states = torch.addcmul(encoder_hidden_states, txt_gate2, self.txt_mlp(txt_modulated2))
@ -300,10 +330,11 @@ class QwenImageTransformer2DModel(nn.Module):
|
|||||||
num_attention_heads: int = 24,
|
num_attention_heads: int = 24,
|
||||||
joint_attention_dim: int = 3584,
|
joint_attention_dim: int = 3584,
|
||||||
pooled_projection_dim: int = 768,
|
pooled_projection_dim: int = 768,
|
||||||
guidance_embeds: bool = False,
|
|
||||||
axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
|
axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
|
||||||
|
default_ref_method="index",
|
||||||
image_model=None,
|
image_model=None,
|
||||||
final_layer=True,
|
final_layer=True,
|
||||||
|
use_additional_t_cond=False,
|
||||||
dtype=None,
|
dtype=None,
|
||||||
device=None,
|
device=None,
|
||||||
operations=None,
|
operations=None,
|
||||||
@ -314,12 +345,14 @@ class QwenImageTransformer2DModel(nn.Module):
|
|||||||
self.in_channels = in_channels
|
self.in_channels = in_channels
|
||||||
self.out_channels = out_channels or in_channels
|
self.out_channels = out_channels or in_channels
|
||||||
self.inner_dim = num_attention_heads * attention_head_dim
|
self.inner_dim = num_attention_heads * attention_head_dim
|
||||||
|
self.default_ref_method = default_ref_method
|
||||||
|
|
||||||
self.pe_embedder = EmbedND(dim=attention_head_dim, theta=10000, axes_dim=list(axes_dims_rope))
|
self.pe_embedder = EmbedND(dim=attention_head_dim, theta=10000, axes_dim=list(axes_dims_rope))
|
||||||
|
|
||||||
self.time_text_embed = QwenTimestepProjEmbeddings(
|
self.time_text_embed = QwenTimestepProjEmbeddings(
|
||||||
embedding_dim=self.inner_dim,
|
embedding_dim=self.inner_dim,
|
||||||
pooled_projection_dim=pooled_projection_dim,
|
pooled_projection_dim=pooled_projection_dim,
|
||||||
|
use_additional_t_cond=use_additional_t_cond,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
device=device,
|
device=device,
|
||||||
operations=operations
|
operations=operations
|
||||||
@ -341,6 +374,9 @@ class QwenImageTransformer2DModel(nn.Module):
|
|||||||
for _ in range(num_layers)
|
for _ in range(num_layers)
|
||||||
])
|
])
|
||||||
|
|
||||||
|
if self.default_ref_method == "index_timestep_zero":
|
||||||
|
self.register_buffer("__index_timestep_zero__", torch.tensor([]))
|
||||||
|
|
||||||
if final_layer:
|
if final_layer:
|
||||||
self.norm_out = LastLayer(self.inner_dim, self.inner_dim, dtype=dtype, device=device, operations=operations)
|
self.norm_out = LastLayer(self.inner_dim, self.inner_dim, dtype=dtype, device=device, operations=operations)
|
||||||
self.proj_out = operations.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True, dtype=dtype, device=device)
|
self.proj_out = operations.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True, dtype=dtype, device=device)
|
||||||
@ -350,27 +386,33 @@ class QwenImageTransformer2DModel(nn.Module):
|
|||||||
patch_size = self.patch_size
|
patch_size = self.patch_size
|
||||||
hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (1, self.patch_size, self.patch_size))
|
hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (1, self.patch_size, self.patch_size))
|
||||||
orig_shape = hidden_states.shape
|
orig_shape = hidden_states.shape
|
||||||
hidden_states = hidden_states.view(orig_shape[0], orig_shape[1], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2)
|
hidden_states = hidden_states.view(orig_shape[0], orig_shape[1], orig_shape[-3], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2)
|
||||||
hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5)
|
hidden_states = hidden_states.permute(0, 2, 3, 5, 1, 4, 6)
|
||||||
hidden_states = hidden_states.reshape(orig_shape[0], (orig_shape[-2] // 2) * (orig_shape[-1] // 2), orig_shape[1] * 4)
|
hidden_states = hidden_states.reshape(orig_shape[0], orig_shape[-3] * (orig_shape[-2] // 2) * (orig_shape[-1] // 2), orig_shape[1] * 4)
|
||||||
|
t_len = t
|
||||||
h_len = ((h + (patch_size // 2)) // patch_size)
|
h_len = ((h + (patch_size // 2)) // patch_size)
|
||||||
w_len = ((w + (patch_size // 2)) // patch_size)
|
w_len = ((w + (patch_size // 2)) // patch_size)
|
||||||
|
|
||||||
h_offset = ((h_offset + (patch_size // 2)) // patch_size)
|
h_offset = ((h_offset + (patch_size // 2)) // patch_size)
|
||||||
w_offset = ((w_offset + (patch_size // 2)) // patch_size)
|
w_offset = ((w_offset + (patch_size // 2)) // patch_size)
|
||||||
|
|
||||||
img_ids = torch.zeros((h_len, w_len, 3), device=x.device)
|
img_ids = torch.zeros((t_len, h_len, w_len, 3), device=x.device)
|
||||||
img_ids[:, :, 0] = img_ids[:, :, 1] + index
|
|
||||||
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1) - (h_len // 2)
|
|
||||||
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0) - (w_len // 2)
|
|
||||||
return hidden_states, repeat(img_ids, "h w c -> b (h w) c", b=bs), orig_shape
|
|
||||||
|
|
||||||
def forward(self, x, timestep, context, attention_mask=None, guidance=None, ref_latents=None, transformer_options={}, **kwargs):
|
if t_len > 1:
|
||||||
|
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).unsqueeze(1).unsqueeze(1)
|
||||||
|
else:
|
||||||
|
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + index
|
||||||
|
|
||||||
|
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1).unsqueeze(0) - (h_len // 2)
|
||||||
|
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0).unsqueeze(0) - (w_len // 2)
|
||||||
|
return hidden_states, repeat(img_ids, "t h w c -> b (t h w) c", b=bs), orig_shape
|
||||||
|
|
||||||
|
def forward(self, x, timestep, context, attention_mask=None, ref_latents=None, additional_t_cond=None, transformer_options={}, **kwargs):
|
||||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||||
self._forward,
|
self._forward,
|
||||||
self,
|
self,
|
||||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
|
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
|
||||||
).execute(x, timestep, context, attention_mask, guidance, ref_latents, transformer_options, **kwargs)
|
).execute(x, timestep, context, attention_mask, ref_latents, additional_t_cond, transformer_options, **kwargs)
|
||||||
|
|
||||||
def _forward(
|
def _forward(
|
||||||
self,
|
self,
|
||||||
@ -378,8 +420,8 @@ class QwenImageTransformer2DModel(nn.Module):
|
|||||||
timesteps,
|
timesteps,
|
||||||
context,
|
context,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
guidance: torch.Tensor = None,
|
|
||||||
ref_latents=None,
|
ref_latents=None,
|
||||||
|
additional_t_cond=None,
|
||||||
transformer_options={},
|
transformer_options={},
|
||||||
control=None,
|
control=None,
|
||||||
**kwargs
|
**kwargs
|
||||||
@ -391,16 +433,24 @@ class QwenImageTransformer2DModel(nn.Module):
|
|||||||
hidden_states, img_ids, orig_shape = self.process_img(x)
|
hidden_states, img_ids, orig_shape = self.process_img(x)
|
||||||
num_embeds = hidden_states.shape[1]
|
num_embeds = hidden_states.shape[1]
|
||||||
|
|
||||||
|
timestep_zero_index = None
|
||||||
if ref_latents is not None:
|
if ref_latents is not None:
|
||||||
h = 0
|
h = 0
|
||||||
w = 0
|
w = 0
|
||||||
index = 0
|
index = 0
|
||||||
index_ref_method = kwargs.get("ref_latents_method", "index") == "index"
|
ref_method = kwargs.get("ref_latents_method", self.default_ref_method)
|
||||||
|
index_ref_method = (ref_method == "index") or (ref_method == "index_timestep_zero")
|
||||||
|
negative_ref_method = ref_method == "negative_index"
|
||||||
|
timestep_zero = ref_method == "index_timestep_zero"
|
||||||
for ref in ref_latents:
|
for ref in ref_latents:
|
||||||
if index_ref_method:
|
if index_ref_method:
|
||||||
index += 1
|
index += 1
|
||||||
h_offset = 0
|
h_offset = 0
|
||||||
w_offset = 0
|
w_offset = 0
|
||||||
|
elif negative_ref_method:
|
||||||
|
index -= 1
|
||||||
|
h_offset = 0
|
||||||
|
w_offset = 0
|
||||||
else:
|
else:
|
||||||
index = 1
|
index = 1
|
||||||
h_offset = 0
|
h_offset = 0
|
||||||
@ -415,6 +465,10 @@ class QwenImageTransformer2DModel(nn.Module):
|
|||||||
kontext, kontext_ids, _ = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset)
|
kontext, kontext_ids, _ = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset)
|
||||||
hidden_states = torch.cat([hidden_states, kontext], dim=1)
|
hidden_states = torch.cat([hidden_states, kontext], dim=1)
|
||||||
img_ids = torch.cat([img_ids, kontext_ids], dim=1)
|
img_ids = torch.cat([img_ids, kontext_ids], dim=1)
|
||||||
|
if timestep_zero:
|
||||||
|
if index > 0:
|
||||||
|
timestep = torch.cat([timestep, timestep * 0], dim=0)
|
||||||
|
timestep_zero_index = num_embeds
|
||||||
|
|
||||||
txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
|
txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
|
||||||
txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
|
txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
|
||||||
@ -426,14 +480,7 @@ class QwenImageTransformer2DModel(nn.Module):
|
|||||||
encoder_hidden_states = self.txt_norm(encoder_hidden_states)
|
encoder_hidden_states = self.txt_norm(encoder_hidden_states)
|
||||||
encoder_hidden_states = self.txt_in(encoder_hidden_states)
|
encoder_hidden_states = self.txt_in(encoder_hidden_states)
|
||||||
|
|
||||||
if guidance is not None:
|
temb = self.time_text_embed(timestep, hidden_states, additional_t_cond)
|
||||||
guidance = guidance * 1000
|
|
||||||
|
|
||||||
temb = (
|
|
||||||
self.time_text_embed(timestep, hidden_states)
|
|
||||||
if guidance is None
|
|
||||||
else self.time_text_embed(timestep, guidance, hidden_states)
|
|
||||||
)
|
|
||||||
|
|
||||||
patches_replace = transformer_options.get("patches_replace", {})
|
patches_replace = transformer_options.get("patches_replace", {})
|
||||||
patches = transformer_options.get("patches", {})
|
patches = transformer_options.get("patches", {})
|
||||||
@ -446,7 +493,7 @@ class QwenImageTransformer2DModel(nn.Module):
|
|||||||
if ("double_block", i) in blocks_replace:
|
if ("double_block", i) in blocks_replace:
|
||||||
def block_wrap(args):
|
def block_wrap(args):
|
||||||
out = {}
|
out = {}
|
||||||
out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"], transformer_options=args["transformer_options"])
|
out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"], timestep_zero_index=timestep_zero_index, transformer_options=args["transformer_options"])
|
||||||
return out
|
return out
|
||||||
out = blocks_replace[("double_block", i)]({"img": hidden_states, "txt": encoder_hidden_states, "vec": temb, "pe": image_rotary_emb, "transformer_options": transformer_options}, {"original_block": block_wrap})
|
out = blocks_replace[("double_block", i)]({"img": hidden_states, "txt": encoder_hidden_states, "vec": temb, "pe": image_rotary_emb, "transformer_options": transformer_options}, {"original_block": block_wrap})
|
||||||
hidden_states = out["img"]
|
hidden_states = out["img"]
|
||||||
@ -458,6 +505,7 @@ class QwenImageTransformer2DModel(nn.Module):
|
|||||||
encoder_hidden_states_mask=encoder_hidden_states_mask,
|
encoder_hidden_states_mask=encoder_hidden_states_mask,
|
||||||
temb=temb,
|
temb=temb,
|
||||||
image_rotary_emb=image_rotary_emb,
|
image_rotary_emb=image_rotary_emb,
|
||||||
|
timestep_zero_index=timestep_zero_index,
|
||||||
transformer_options=transformer_options,
|
transformer_options=transformer_options,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -474,9 +522,12 @@ class QwenImageTransformer2DModel(nn.Module):
|
|||||||
if add is not None:
|
if add is not None:
|
||||||
hidden_states[:, :add.shape[1]] += add
|
hidden_states[:, :add.shape[1]] += add
|
||||||
|
|
||||||
|
if timestep_zero_index is not None:
|
||||||
|
temb = temb.chunk(2, dim=0)[0]
|
||||||
|
|
||||||
hidden_states = self.norm_out(hidden_states, temb)
|
hidden_states = self.norm_out(hidden_states, temb)
|
||||||
hidden_states = self.proj_out(hidden_states)
|
hidden_states = self.proj_out(hidden_states)
|
||||||
|
|
||||||
hidden_states = hidden_states[:, :num_embeds].view(orig_shape[0], orig_shape[-2] // 2, orig_shape[-1] // 2, orig_shape[1], 2, 2)
|
hidden_states = hidden_states[:, :num_embeds].view(orig_shape[0], orig_shape[-3], orig_shape[-2] // 2, orig_shape[-1] // 2, orig_shape[1], 2, 2)
|
||||||
hidden_states = hidden_states.permute(0, 3, 1, 4, 2, 5)
|
hidden_states = hidden_states.permute(0, 4, 1, 2, 5, 3, 6)
|
||||||
return hidden_states.reshape(orig_shape)[:, :, :, :x.shape[-2], :x.shape[-1]]
|
return hidden_states.reshape(orig_shape)[:, :, :, :x.shape[-2], :x.shape[-1]]
|
||||||
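Note: a minimal sketch (not part of the patch) of what the split modulation above computes when a timestep-zero tail is present. The function name and the concrete shapes are illustrative only:

    import torch

    def modulate_with_zero_tail(x, shift, scale, shift_0, scale_0, split):
        # Tokens before `split` use the regular (shift, scale); tokens after it
        # use the timestep-zero (shift_0, scale_0) pair, mirroring _modulate().
        reg = torch.addcmul(shift.unsqueeze(1), x[:, :split], 1 + scale.unsqueeze(1))
        zero = torch.addcmul(shift_0.unsqueeze(1), x[:, split:], 1 + scale_0.unsqueeze(1))
        return torch.cat((reg, zero), dim=1)

    x = torch.randn(2, 10, 8)                      # batch, tokens, channels
    shift, scale = torch.randn(2, 8), torch.randn(2, 8)
    shift_0, scale_0 = torch.zeros(2, 8), torch.zeros(2, 8)
    out = modulate_with_zero_tail(x, shift, scale, shift_0, scale_0, split=6)
    print(out.shape)                               # torch.Size([2, 10, 8])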
@@ -71,7 +71,7 @@ def count_params(model, verbose=False):


 def instantiate_from_config(config):
-    if not "target" in config:
+    if "target" not in config:
         if config == '__is_first_stage__':
             return None
         elif config == "__is_unconditional__":
@@ -568,7 +568,10 @@ class WanModel(torch.nn.Module):

         patches_replace = transformer_options.get("patches_replace", {})
         blocks_replace = patches_replace.get("dit", {})
+        transformer_options["total_blocks"] = len(self.blocks)
+        transformer_options["block_type"] = "double"
         for i, block in enumerate(self.blocks):
+            transformer_options["block_index"] = i
             if ("double_block", i) in blocks_replace:
                 def block_wrap(args):
                     out = {}
@@ -763,7 +766,10 @@ class VaceWanModel(WanModel):

         patches_replace = transformer_options.get("patches_replace", {})
         blocks_replace = patches_replace.get("dit", {})
+        transformer_options["total_blocks"] = len(self.blocks)
+        transformer_options["block_type"] = "double"
         for i, block in enumerate(self.blocks):
+            transformer_options["block_index"] = i
             if ("double_block", i) in blocks_replace:
                 def block_wrap(args):
                     out = {}
@@ -862,7 +868,10 @@ class CameraWanModel(WanModel):

         patches_replace = transformer_options.get("patches_replace", {})
         blocks_replace = patches_replace.get("dit", {})
+        transformer_options["total_blocks"] = len(self.blocks)
+        transformer_options["block_type"] = "double"
         for i, block in enumerate(self.blocks):
+            transformer_options["block_index"] = i
             if ("double_block", i) in blocks_replace:
                 def block_wrap(args):
                     out = {}
@@ -1326,16 +1335,19 @@ class WanModel_S2V(WanModel):

         patches_replace = transformer_options.get("patches_replace", {})
         blocks_replace = patches_replace.get("dit", {})
+        transformer_options["total_blocks"] = len(self.blocks)
+        transformer_options["block_type"] = "double"
         for i, block in enumerate(self.blocks):
+            transformer_options["block_index"] = i
             if ("double_block", i) in blocks_replace:
                 def block_wrap(args):
                     out = {}
-                    out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"])
+                    out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], transformer_options=args["transformer_options"])
                     return out
-                out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
+                out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})
                 x = out["img"]
             else:
-                x = block(x, e=e0, freqs=freqs, context=context)
+                x = block(x, e=e0, freqs=freqs, context=context, transformer_options=transformer_options)
             if audio_emb is not None:
                 x = self.audio_injector(x, i, audio_emb, audio_emb_global, seq_len)
         # head
@@ -1574,7 +1586,10 @@ class HumoWanModel(WanModel):

         patches_replace = transformer_options.get("patches_replace", {})
         blocks_replace = patches_replace.get("dit", {})
+        transformer_options["total_blocks"] = len(self.blocks)
+        transformer_options["block_type"] = "double"
         for i, block in enumerate(self.blocks):
+            transformer_options["block_index"] = i
             if ("double_block", i) in blocks_replace:
                 def block_wrap(args):
                     out = {}
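Note: the total_blocks/block_type/block_index entries written into transformer_options above let a "dit" patches_replace callback know where it sits in the block stack. A hedged sketch of such a callback, assuming the call convention shown in the hunks (args dict plus an "original_block" entry); the skip behaviour itself is made up for illustration:

    def skip_last_double_block(args, extra):
        opts = args["transformer_options"]
        i = opts.get("block_index", 0)
        n = opts.get("total_blocks", 1)
        if i == n - 1:
            # Pass the final block through untouched (Wan-style blocks return {"img": ...}).
            return {"img": args["img"]}
        return extra["original_block"](args)

    # Registered roughly as:
    # transformer_options["patches_replace"]["dit"][("double_block", i)] = skip_last_double_block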
@@ -523,7 +523,10 @@ class AnimateWanModel(WanModel):

         patches_replace = transformer_options.get("patches_replace", {})
         blocks_replace = patches_replace.get("dit", {})
+        transformer_options["total_blocks"] = len(self.blocks)
+        transformer_options["block_type"] = "double"
         for i, block in enumerate(self.blocks):
+            transformer_options["block_index"] = i
             if ("double_block", i) in blocks_replace:
                 def block_wrap(args):
                     out = {}
@@ -227,6 +227,7 @@ class Encoder3d(nn.Module):
     def __init__(self,
                  dim=128,
                  z_dim=4,
+                 input_channels=3,
                  dim_mult=[1, 2, 4, 4],
                  num_res_blocks=2,
                  attn_scales=[],
@@ -245,7 +246,7 @@ class Encoder3d(nn.Module):
         scale = 1.0

         # init block
-        self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)
+        self.conv1 = CausalConv3d(input_channels, dims[0], 3, padding=1)

         # downsample blocks
         downsamples = []
@@ -331,6 +332,7 @@ class Decoder3d(nn.Module):
     def __init__(self,
                  dim=128,
                  z_dim=4,
+                 output_channels=3,
                  dim_mult=[1, 2, 4, 4],
                  num_res_blocks=2,
                  attn_scales=[],
@@ -378,7 +380,7 @@ class Decoder3d(nn.Module):
         # output blocks
         self.head = nn.Sequential(
             RMS_norm(out_dim, images=False), nn.SiLU(),
-            CausalConv3d(out_dim, 3, 3, padding=1))
+            CausalConv3d(out_dim, output_channels, 3, padding=1))

     def forward(self, x, feat_cache=None, feat_idx=[0]):
         ## conv1
@@ -449,6 +451,7 @@ class WanVAE(nn.Module):
                  num_res_blocks=2,
                  attn_scales=[],
                  temperal_downsample=[True, True, False],
+                 image_channels=3,
                  dropout=0.0):
         super().__init__()
         self.dim = dim
@@ -460,11 +463,11 @@ class WanVAE(nn.Module):
         self.temperal_upsample = temperal_downsample[::-1]

         # modules
-        self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks,
+        self.encoder = Encoder3d(dim, z_dim * 2, image_channels, dim_mult, num_res_blocks,
                                  attn_scales, self.temperal_downsample, dropout)
         self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
         self.conv2 = CausalConv3d(z_dim, z_dim, 1)
-        self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks,
+        self.decoder = Decoder3d(dim, z_dim, image_channels, dim_mult, num_res_blocks,
                                  attn_scales, self.temperal_upsample, dropout)

     def encode(self, x):
@@ -320,6 +320,7 @@ def model_lora_keys_unet(model, key_map={}):
             to = diffusers_keys[k]
             key_lora = k[:-len(".weight")]
             key_map["diffusion_model.{}".format(key_lora)] = to
+            key_map["transformer.{}".format(key_lora)] = to
             key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = to

     if isinstance(model, comfy.model_base.Kandinsky5):
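Note: the added "transformer.{}" alias means LoRA files whose keys use a diffusers-style "transformer." prefix resolve to the same internal weight as the "diffusion_model." form. A toy illustration of the lookup; the concrete key names here are invented:

    key_map = {}
    to = "diffusion_model.transformer_blocks.0.attn.to_q.weight"   # hypothetical target key
    key_lora = "transformer_blocks.0.attn.to_q"
    key_map["diffusion_model.{}".format(key_lora)] = to
    key_map["transformer.{}".format(key_lora)] = to
    key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = to

    # A LoRA shipped with either naming convention now finds the same target:
    assert key_map["transformer.transformer_blocks.0.attn.to_q"] == to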
@@ -20,6 +20,7 @@ import comfy.ldm.hunyuan3dv2_1
 import comfy.ldm.hunyuan3dv2_1.hunyuandit
 import torch
 import logging
+import comfy.ldm.lightricks.av_model
 from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
 from comfy.ldm.cascade.stage_c import StageC
 from comfy.ldm.cascade.stage_b import StageB
@@ -946,7 +947,7 @@ class GenmoMochi(BaseModel):

 class LTXV(BaseModel):
     def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
-        super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lightricks.model.LTXVModel) #TODO
+        super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lightricks.model.LTXVModel)

     def extra_conds(self, **kwargs):
         out = super().extra_conds(**kwargs)
@@ -977,6 +978,60 @@ class LTXV(BaseModel):
     def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
         return latent_image

+class LTXAV(BaseModel):
+    def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
+        super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lightricks.av_model.LTXAVModel) #TODO
+
+    def extra_conds(self, **kwargs):
+        out = super().extra_conds(**kwargs)
+        attention_mask = kwargs.get("attention_mask", None)
+        if attention_mask is not None:
+            out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
+        cross_attn = kwargs.get("cross_attn", None)
+        if cross_attn is not None:
+            out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
+
+        out['frame_rate'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", 25))
+
+        denoise_mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
+
+        audio_denoise_mask = None
+        if denoise_mask is not None and "latent_shapes" in kwargs:
+            denoise_mask = utils.unpack_latents(denoise_mask, kwargs["latent_shapes"])
+            if len(denoise_mask) > 1:
+                audio_denoise_mask = denoise_mask[1]
+            denoise_mask = denoise_mask[0]
+
+        if denoise_mask is not None:
+            out["denoise_mask"] = comfy.conds.CONDRegular(denoise_mask)
+
+        if audio_denoise_mask is not None:
+            out["audio_denoise_mask"] = comfy.conds.CONDRegular(audio_denoise_mask)
+
+        keyframe_idxs = kwargs.get("keyframe_idxs", None)
+        if keyframe_idxs is not None:
+            out['keyframe_idxs'] = comfy.conds.CONDRegular(keyframe_idxs)
+
+        latent_shapes = kwargs.get("latent_shapes", None)
+        if latent_shapes is not None:
+            out['latent_shapes'] = comfy.conds.CONDConstant(latent_shapes)
+
+        return out
+
+    def process_timestep(self, timestep, x, denoise_mask=None, audio_denoise_mask=None, **kwargs):
+        v_timestep = timestep
+        a_timestep = timestep
+
+        if denoise_mask is not None:
+            v_timestep = self.diffusion_model.patchifier.patchify(((denoise_mask) * timestep.view([timestep.shape[0]] + [1] * (denoise_mask.ndim - 1)))[:, :1])[0]
+        if audio_denoise_mask is not None:
+            a_timestep = self.diffusion_model.a_patchifier.patchify(((audio_denoise_mask) * timestep.view([timestep.shape[0]] + [1] * (audio_denoise_mask.ndim - 1)))[:, :1, :, :1])[0]
+
+        return v_timestep, a_timestep
+
+    def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
+        return latent_image
+
 class HunyuanVideo(BaseModel):
     def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
         super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan_video.model.HunyuanVideo)
@@ -1110,6 +1165,10 @@ class Lumina2(BaseModel):
         if 'num_tokens' not in out:
             out['num_tokens'] = comfy.conds.CONDConstant(cross_attn.shape[1])

+        clip_text_pooled = kwargs.get("pooled_output", None) # NewBie
+        if clip_text_pooled is not None:
+            out['clip_text_pooled'] = comfy.conds.CONDRegular(clip_text_pooled)

         return out

 class WAN21(BaseModel):
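Note: a rough sketch of the idea behind LTXAV.process_timestep above: the per-batch timestep is broadcast against a denoise mask, so fully denoised tokens keep the timestep while conditioning tokens end up at 0. The patchifier call is replaced here by a plain flatten, so this is illustrative only:

    import torch

    timestep = torch.tensor([0.7, 0.7])                        # one value per batch element
    denoise_mask = torch.ones(2, 1, 4, 8, 8)
    denoise_mask[:, :, 0] = 0.0                                # first frame is conditioning

    per_token_t = denoise_mask * timestep.view(-1, 1, 1, 1, 1) # broadcast like in process_timestep
    v_timestep = per_token_t[:, :1].flatten(1)                 # stand-in for patchifier.patchify(...)[0]
    print(v_timestep.shape, v_timestep.min().item(), v_timestep.max().item())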
@@ -180,8 +180,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
         dit_config["use_cond_type_embedding"] = False
         if '{}vision_in.proj.0.weight'.format(key_prefix) in state_dict_keys:
             dit_config["vision_in_dim"] = state_dict['{}vision_in.proj.0.weight'.format(key_prefix)].shape[0]
+            dit_config["meanflow_sum"] = True
         else:
             dit_config["vision_in_dim"] = None
+            dit_config["meanflow_sum"] = False
         return dit_config

     if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or f"{key_prefix}distilled_guidance_layer.norms.0.scale" in state_dict_keys): #Flux, Chroma or Chroma Radiance (has no img_in.weight)
@@ -257,6 +259,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
             dit_config["nerf_tile_size"] = 512
             dit_config["nerf_final_head_type"] = "conv" if f"{key_prefix}nerf_final_layer_conv.norm.scale" in state_dict_keys else "linear"
             dit_config["nerf_embedder_dtype"] = torch.float32
+            if "{}__x0__".format(key_prefix) in state_dict_keys: # x0 pred
+                dit_config["use_x0"] = True
+            else:
+                dit_config["use_x0"] = False
         else:
             dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
             dit_config["yak_mlp"] = '{}double_blocks.0.img_mlp.gate_proj.weight'.format(key_prefix) in state_dict_keys
@@ -299,7 +305,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):

     if '{}adaln_single.emb.timestep_embedder.linear_1.bias'.format(key_prefix) in state_dict_keys: #Lightricks ltxv
         dit_config = {}
-        dit_config["image_model"] = "ltxv"
+        dit_config["image_model"] = "ltxav" if f'{key_prefix}audio_adaln_single.linear.weight' in state_dict_keys else "ltxv"
         dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.')
         shape = state_dict['{}transformer_blocks.0.attn2.to_k.weight'.format(key_prefix)].shape
         dit_config["attention_head_dim"] = shape[0] // 32
@@ -423,6 +429,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
             dit_config["axes_lens"] = [300, 512, 512]
             dit_config["rope_theta"] = 10000.0
             dit_config["ffn_dim_multiplier"] = 4.0
+            ctd_weight = state_dict.get('{}clip_text_pooled_proj.0.weight'.format(key_prefix), None)
+            if ctd_weight is not None: # NewBie
+                dit_config["clip_text_dim"] = ctd_weight.shape[0]
+                # NewBie also sets axes_lens = [1024, 512, 512] but it's not used in ComfyUI
         elif dit_config["dim"] == 3840: # Z image
             dit_config["n_heads"] = 30
             dit_config["n_kv_heads"] = 30
@@ -609,6 +619,11 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
         dit_config["image_model"] = "qwen_image"
         dit_config["in_channels"] = state_dict['{}img_in.weight'.format(key_prefix)].shape[1]
         dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.')
+        if "{}__index_timestep_zero__".format(key_prefix) in state_dict_keys: # 2511
+            dit_config["default_ref_method"] = "index_timestep_zero"
+        if "{}time_text_embed.addition_t_embedding.weight".format(key_prefix) in state_dict_keys: # Layered
+            dit_config["use_additional_t_cond"] = True
+            dit_config["default_ref_method"] = "negative_index"
         return dit_config

     if '{}visual_transformer_blocks.0.cross_attention.key_norm.weight'.format(key_prefix) in state_dict_keys: # Kandinsky 5
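Note: all of the detection added above follows the same pattern: the presence of a marker key in the checkpoint's state dict toggles a config flag. A minimal, self-contained illustration using the same checks with invented key names:

    state_dict_keys = {"model.__index_timestep_zero__", "model.img_in.weight"}
    key_prefix = "model."

    dit_config = {}
    if "{}__index_timestep_zero__".format(key_prefix) in state_dict_keys:
        dit_config["default_ref_method"] = "index_timestep_zero"
    if "{}time_text_embed.addition_t_embedding.weight".format(key_prefix) in state_dict_keys:
        dit_config["use_additional_t_cond"] = True
        dit_config["default_ref_method"] = "negative_index"
    print(dit_config)   # {'default_ref_method': 'index_timestep_zero'}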
@@ -26,6 +26,7 @@ import importlib
 import platform
 import weakref
 import gc
+import os

 class VRAMState(Enum):
     DISABLED = 0 #No vram present: no need to move models to vram
@@ -333,13 +334,15 @@ except:
     SUPPORT_FP8_OPS = args.supports_fp8_compute

 AMD_RDNA2_AND_OLDER_ARCH = ["gfx1030", "gfx1031", "gfx1010", "gfx1011", "gfx1012", "gfx906", "gfx900", "gfx803"]
+AMD_ENABLE_MIOPEN_ENV = 'COMFYUI_ENABLE_MIOPEN'

 try:
     if is_amd():
         arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName
         if not (any((a in arch) for a in AMD_RDNA2_AND_OLDER_ARCH)):
-            torch.backends.cudnn.enabled = False # Seems to improve things a lot on AMD
-            logging.info("Set: torch.backends.cudnn.enabled = False for better AMD performance.")
+            if os.getenv(AMD_ENABLE_MIOPEN_ENV) != '1':
+                torch.backends.cudnn.enabled = False # Seems to improve things a lot on AMD
+                logging.info("Set: torch.backends.cudnn.enabled = False for better AMD performance.")

         try:
             rocm_version = tuple(map(int, str(torch.version.hip).split(".")[:2]))
@@ -453,7 +456,7 @@ def module_size(module):
     sd = module.state_dict()
     for k in sd:
         t = sd[k]
-        module_mem += t.nelement() * t.element_size()
+        module_mem += t.nbytes
     return module_mem

 class LoadedModel:
@@ -1016,8 +1019,8 @@ NUM_STREAMS = 0
 if args.async_offload is not None:
     NUM_STREAMS = args.async_offload
 else:
-    # Enable by default on Nvidia
-    if is_nvidia():
+    # Enable by default on Nvidia and AMD
+    if is_nvidia() or is_amd():
         NUM_STREAMS = 2

 if args.disable_async_offload:
@@ -1123,6 +1126,16 @@ if not args.disable_pinned_memory:

 PINNING_ALLOWED_TYPES = set(["Parameter", "QuantizedTensor"])

+def discard_cuda_async_error():
+    try:
+        a = torch.tensor([1], dtype=torch.uint8, device=get_torch_device())
+        b = torch.tensor([1], dtype=torch.uint8, device=get_torch_device())
+        _ = a + b
+        torch.cuda.synchronize()
+    except torch.AcceleratorError:
+        #Dump it! We already know about it from the synchronous return
+        pass

 def pin_memory(tensor):
     global TOTAL_PINNED_MEMORY
     if MAX_PINNED_MEMORY <= 0:
@@ -1143,7 +1156,7 @@ def pin_memory(tensor):
     if not tensor.is_contiguous():
         return False

-    size = tensor.numel() * tensor.element_size()
+    size = tensor.nbytes
     if (TOTAL_PINNED_MEMORY + size) > MAX_PINNED_MEMORY:
         return False

@@ -1155,6 +1168,9 @@ def pin_memory(tensor):
         PINNED_MEMORY[ptr] = size
         TOTAL_PINNED_MEMORY += size
         return True
+    else:
+        logging.warning("Pin error.")
+        discard_cuda_async_error()

     return False

@@ -1167,7 +1183,7 @@ def unpin_memory(tensor):
         return False

     ptr = tensor.data_ptr()
-    size = tensor.numel() * tensor.element_size()
+    size = tensor.nbytes

     size_stored = PINNED_MEMORY.get(ptr, None)
     if size_stored is None:
@@ -1183,6 +1199,9 @@ def unpin_memory(tensor):
         if len(PINNED_MEMORY) == 0:
             TOTAL_PINNED_MEMORY = 0
         return True
+    else:
+        logging.warning("Unpin error.")
+        discard_cuda_async_error()

     return False

@@ -1485,6 +1504,16 @@ def supports_fp8_compute(device=None):

     return True

+def supports_nvfp4_compute(device=None):
+    if not is_nvidia():
+        return False
+
+    props = torch.cuda.get_device_properties(device)
+    if props.major < 10:
+        return False
+
+    return True

 def extended_fp16_support():
     # TODO: check why some models work with fp16 on newer torch versions but not on older
     if torch_version_numeric < (2, 7):
@@ -1523,6 +1552,10 @@ def soft_empty_cache(force=False):
 def unload_all_models():
     free_memory(1e30, get_torch_device())

+def debug_memory_summary():
+    if is_amd() or is_nvidia():
+        return torch.cuda.memory.memory_summary()
+    return ""

 #TODO: might be cleaner to put this somewhere else
 import threading
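Note: the switch to tensor.nbytes in module_size/pin_memory/unpin_memory above does not change the accounted value; nbytes is simply numel() times element_size(), read in one attribute. Quick sanity check:

    import torch

    t = torch.empty(3, 5, dtype=torch.float16)
    assert t.nbytes == t.numel() * t.element_size() == 30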
@@ -35,6 +35,7 @@ import comfy.model_management
 import comfy.patcher_extension
 import comfy.utils
 from comfy.comfy_types import UnetWrapperFunction
+from comfy.quant_ops import QuantizedTensor
 from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP


@@ -132,14 +133,17 @@ class LowVramPatch:
     def __call__(self, weight):
         return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=weight.dtype)

-#The above patch logic may cast up the weight to fp32, and do math. Go with fp32 x 3
-LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 3
+LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 2

 def low_vram_patch_estimate_vram(model, key):
     weight, set_func, convert_func = get_key_weight(model, key)
     if weight is None:
         return 0
-    return weight.numel() * torch.float32.itemsize * LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR
+    model_dtype = getattr(model, "manual_cast_dtype", torch.float32)
+    if model_dtype is None:
+        model_dtype = weight.dtype
+
+    return weight.numel() * model_dtype.itemsize * LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR

 def get_key_weight(model, key):
     set_func = None
@@ -450,6 +454,9 @@ class ModelPatcher:
     def set_model_post_input_patch(self, patch):
         self.set_model_patch(patch, "post_input")

+    def set_model_noise_refiner_patch(self, patch):
+        self.set_model_patch(patch, "noise_refiner")

     def set_model_rope_options(self, scale_x, shift_x, scale_y, shift_y, scale_t, shift_t, **kwargs):
         rope_options = self.model_options["transformer_options"].get("rope_options", {})
         rope_options["scale_x"] = scale_x
@@ -662,12 +669,18 @@ class ModelPatcher:
                 module_mem = comfy.model_management.module_size(m)
                 module_offload_mem = module_mem
                 if hasattr(m, "comfy_cast_weights"):
-                    weight_key = "{}.weight".format(n)
-                    bias_key = "{}.bias".format(n)
-                    if weight_key in self.patches:
-                        module_offload_mem += low_vram_patch_estimate_vram(self.model, weight_key)
-                    if bias_key in self.patches:
-                        module_offload_mem += low_vram_patch_estimate_vram(self.model, bias_key)
+                    def check_module_offload_mem(key):
+                        if key in self.patches:
+                            return low_vram_patch_estimate_vram(self.model, key)
+                        model_dtype = getattr(self.model, "manual_cast_dtype", None)
+                        weight, _, _ = get_key_weight(self.model, key)
+                        if model_dtype is None or weight is None:
+                            return 0
+                        if (weight.dtype != model_dtype or isinstance(weight, QuantizedTensor)):
+                            return weight.numel() * model_dtype.itemsize
+                        return 0
+                    module_offload_mem += check_module_offload_mem("{}.weight".format(n))
+                    module_offload_mem += check_module_offload_mem("{}.bias".format(n))
                 loading.append((module_offload_mem, module_mem, n, m, params))
         return loading

@@ -920,7 +933,7 @@ class ModelPatcher:
                     patch_counter += 1
                     cast_weight = True

-                if cast_weight:
+                if cast_weight and hasattr(m, "comfy_cast_weights"):
                     m.prev_comfy_cast_weights = m.comfy_cast_weights
                     m.comfy_cast_weights = True
                 m.comfy_patched_weights = False
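Note: the revised estimate above sizes the low-vram patch temporary in the model's compute dtype instead of always assuming fp32 times three. A back-of-the-envelope sketch (the shape and dtype are illustrative, not taken from the patch):

    import torch

    LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 2
    weight = torch.empty(4096, 4096, dtype=torch.float16)   # stand-in weight; only numel matters
    model_dtype = torch.bfloat16                            # e.g. manual_cast_dtype

    estimate = weight.numel() * model_dtype.itemsize * LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR
    print(estimate / (1024 ** 2), "MiB")   # 64.0 MiB, versus 192 MiB under the old fp32 x 3 rule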
170 comfy/ops.py
@@ -22,7 +22,6 @@ import comfy.model_management
 from comfy.cli_args import args, PerformanceFeature
 import comfy.float
 import comfy.rmsnorm
-import contextlib
 import json

 def run_every_op():
@@ -80,7 +79,7 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
     if input is not None:
         if dtype is None:
             if isinstance(input, QuantizedTensor):
-                dtype = input._layout_params["orig_dtype"]
+                dtype = input.params.orig_dtype
             else:
                 dtype = input.dtype
         if bias_dtype is None:
@@ -94,13 +93,6 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
     else:
         offload_stream = None

-    if offload_stream is not None:
-        wf_context = offload_stream
-        if hasattr(wf_context, "as_context"):
-            wf_context = wf_context.as_context(offload_stream)
-    else:
-        wf_context = contextlib.nullcontext()
-
     non_blocking = comfy.model_management.device_supports_non_blocking(device)

     weight_has_function = len(s.weight_function) > 0
@@ -420,26 +412,34 @@ def fp8_linear(self, input):
         return None

     input_dtype = input.dtype
+    input_shape = input.shape
+    tensor_3d = input.ndim == 3

-    if input.ndim == 3 or input.ndim == 2:
-        w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True)
-        scale_weight = torch.ones((), device=input.device, dtype=torch.float32)
-        scale_input = torch.ones((), device=input.device, dtype=torch.float32)
-        input = torch.clamp(input, min=-448, max=448, out=input)
-        layout_params_weight = {'scale': scale_input, 'orig_dtype': input_dtype}
-        quantized_input = QuantizedTensor(input.to(dtype).contiguous(), "TensorCoreFP8Layout", layout_params_weight)
-
-        # Wrap weight in QuantizedTensor - this enables unified dispatch
-        # Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py!
-        layout_params_weight = {'scale': scale_weight, 'orig_dtype': input_dtype}
-        quantized_weight = QuantizedTensor(w, "TensorCoreFP8Layout", layout_params_weight)
-        o = torch.nn.functional.linear(quantized_input, quantized_weight, bias)
-
-        uncast_bias_weight(self, w, bias, offload_stream)
-        return o
-
-    return None
+    if tensor_3d:
+        input = input.reshape(-1, input_shape[2])
+
+    if input.ndim != 2:
+        return None
+    w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True)
+    scale_weight = torch.ones((), device=input.device, dtype=torch.float32)
+
+    scale_input = torch.ones((), device=input.device, dtype=torch.float32)
+    input = torch.clamp(input, min=-448, max=448, out=input)
+    input_fp8 = input.to(dtype).contiguous()
+    layout_params_input = TensorCoreFP8Layout.Params(scale=scale_input, orig_dtype=input_dtype, orig_shape=tuple(input_fp8.shape))
+    quantized_input = QuantizedTensor(input_fp8, "TensorCoreFP8Layout", layout_params_input)
+
+    # Wrap weight in QuantizedTensor - this enables unified dispatch
+    # Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py!
+    layout_params_weight = TensorCoreFP8Layout.Params(scale=scale_weight, orig_dtype=input_dtype, orig_shape=tuple(w.shape))
+    quantized_weight = QuantizedTensor(w, "TensorCoreFP8Layout", layout_params_weight)
+    o = torch.nn.functional.linear(quantized_input, quantized_weight, bias)
+
+    uncast_bias_weight(self, w, bias, offload_stream)
+    if tensor_3d:
+        o = o.reshape((input_shape[0], input_shape[1], w.shape[0]))
+
+    return o
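Note: the reshape added to fp8_linear above is the usual trick for matmul paths that only accept 2D operands: flatten (batch, tokens, features) to (batch*tokens, features), run the linear, then restore the leading dimensions. A shape-only sketch in plain floating point, with made-up sizes:

    import torch

    x = torch.randn(2, 77, 64)                 # (batch, tokens, in_features)
    w = torch.randn(128, 64)                   # (out_features, in_features)

    input_shape = x.shape
    x2d = x.reshape(-1, input_shape[2])        # (154, 64)
    o = torch.nn.functional.linear(x2d, w)     # (154, 128)
    o = o.reshape(input_shape[0], input_shape[1], w.shape[0])
    print(o.shape)                             # torch.Size([2, 77, 128])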
class fp8_ops(manual_cast):
|
class fp8_ops(manual_cast):
|
||||||
class Linear(manual_cast.Linear):
|
class Linear(manual_cast.Linear):
|
||||||
@ -485,14 +485,20 @@ if CUBLAS_IS_AVAILABLE:
|
|||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
# Mixed Precision Operations
|
# Mixed Precision Operations
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
from .quant_ops import QuantizedTensor, QUANT_ALGOS
|
from .quant_ops import (
|
||||||
|
QuantizedTensor,
|
||||||
|
QUANT_ALGOS,
|
||||||
|
TensorCoreFP8Layout,
|
||||||
|
get_layout_class,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False):
|
def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False, disabled=[]):
|
||||||
class MixedPrecisionOps(manual_cast):
|
class MixedPrecisionOps(manual_cast):
|
||||||
_quant_config = quant_config
|
_quant_config = quant_config
|
||||||
_compute_dtype = compute_dtype
|
_compute_dtype = compute_dtype
|
||||||
_full_precision_mm = full_precision_mm
|
_full_precision_mm = full_precision_mm
|
||||||
|
_disabled = disabled
|
||||||
|
|
||||||
class Linear(torch.nn.Module, CastWeightBiasOp):
|
class Linear(torch.nn.Module, CastWeightBiasOp):
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -517,10 +523,21 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
|||||||
|
|
||||||
self.tensor_class = None
|
self.tensor_class = None
|
||||||
self._full_precision_mm = MixedPrecisionOps._full_precision_mm
|
self._full_precision_mm = MixedPrecisionOps._full_precision_mm
|
||||||
|
self._full_precision_mm_config = False
|
||||||
|
|
||||||
def reset_parameters(self):
|
def reset_parameters(self):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def _load_scale_param(self, state_dict, prefix, param_name, device, manually_loaded_keys, dtype=None):
|
||||||
|
key = f"{prefix}{param_name}"
|
||||||
|
value = state_dict.pop(key, None)
|
||||||
|
if value is not None:
|
||||||
|
value = value.to(device=device)
|
||||||
|
if dtype is not None:
|
||||||
|
value = value.view(dtype=dtype)
|
||||||
|
manually_loaded_keys.append(key)
|
||||||
|
return value
|
||||||
|
|
||||||
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
|
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
|
||||||
strict, missing_keys, unexpected_keys, error_msgs):
|
strict, missing_keys, unexpected_keys, error_msgs):
|
||||||
|
|
||||||
@ -541,34 +558,58 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
|||||||
self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
else:
self.quant_format = layer_conf.get("format", None)
+ self._full_precision_mm_config = layer_conf.get("full_precision_matrix_mult", False)
if not self._full_precision_mm:
- self._full_precision_mm = layer_conf.get("full_precision_matrix_mult", False)
+ self._full_precision_mm = self._full_precision_mm_config

+ if self.quant_format in MixedPrecisionOps._disabled:
+ self._full_precision_mm = True

if self.quant_format is None:
raise ValueError(f"Unknown quantization format for layer {layer_name}")

qconfig = QUANT_ALGOS[self.quant_format]
self.layout_type = qconfig["comfy_tensor_layout"]
+ layout_cls = get_layout_class(self.layout_type)

- weight_scale_key = f"{prefix}weight_scale"
- scale = state_dict.pop(weight_scale_key, None)
- if scale is not None:
- scale = scale.to(device)
- layout_params = {
- 'scale': scale,
- 'orig_dtype': MixedPrecisionOps._compute_dtype,
- 'block_size': qconfig.get("group_size", None),
- }

- if scale is not None:
- manually_loaded_keys.append(weight_scale_key)
+ # Load format-specific parameters
+ if self.quant_format in ["float8_e4m3fn", "float8_e5m2"]:
+ # FP8: single tensor scale
+ scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys)
+ params = layout_cls.Params(
+ scale=scale,
+ orig_dtype=MixedPrecisionOps._compute_dtype,
+ orig_shape=(self.out_features, self.in_features),
+ )

+ elif self.quant_format == "nvfp4":
+ # NVFP4: tensor_scale (weight_scale_2) + block_scale (weight_scale)
+ tensor_scale = self._load_scale_param(state_dict, prefix, "weight_scale_2", device, manually_loaded_keys)
+ block_scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys,
+ dtype=torch.float8_e4m3fn)

+ if tensor_scale is None or block_scale is None:
+ raise ValueError(f"Missing NVFP4 scales for layer {layer_name}")

+ params = layout_cls.Params(
+ scale=tensor_scale,
+ block_scale=block_scale,
+ orig_dtype=MixedPrecisionOps._compute_dtype,
+ orig_shape=(self.out_features, self.in_features),
+ )
+ else:
+ raise ValueError(f"Unsupported quantization format: {self.quant_format}")

self.weight = torch.nn.Parameter(
- QuantizedTensor(weight.to(device=device, dtype=qconfig.get("storage_t", None)), self.layout_type, layout_params),
+ QuantizedTensor(weight.to(device=device, dtype=qconfig["storage_t"]), self.layout_type, params),
requires_grad=False
)

for param_name in qconfig["parameters"]:
+ if param_name in {"weight_scale", "weight_scale_2"}:
+ continue # Already handled above

param_key = f"{prefix}{param_name}"
_v = state_dict.pop(param_key, None)
if _v is None:
@@ -585,11 +626,19 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
def state_dict(self, *args, destination=None, prefix="", **kwargs):
sd = super().state_dict(*args, destination=destination, prefix=prefix, **kwargs)
if isinstance(self.weight, QuantizedTensor):
- sd["{}weight_scale".format(prefix)] = self.weight._layout_params['scale']
+ layout_cls = self.weight._layout_cls

+ # Check if it's any FP8 variant (E4M3 or E5M2)
+ if layout_cls in ("TensorCoreFP8E4M3Layout", "TensorCoreFP8E5M2Layout", "TensorCoreFP8Layout"):
+ sd["{}weight_scale".format(prefix)] = self.weight._params.scale
+ elif layout_cls == "TensorCoreNVFP4Layout":
+ sd["{}weight_scale_2".format(prefix)] = self.weight._params.scale
+ sd["{}weight_scale".format(prefix)] = self.weight._params.block_scale

quant_conf = {"format": self.quant_format}
- if self._full_precision_mm:
+ if self._full_precision_mm_config:
quant_conf["full_precision_matrix_mult"] = True
- sd["{}comfy_quant".format(prefix)] = torch.frombuffer(json.dumps(quant_conf).encode('utf-8'), dtype=torch.uint8)
+ sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8)
return sd

def _forward(self, input, weight, bias):
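Editor's note: the serialization change above now builds the `comfy_quant` metadata tensor from the JSON bytes directly rather than through `torch.frombuffer`, so the result does not hold a view into a temporary buffer. A minimal round-trip sketch of that encoding (illustrative only, not part of the diff):

```python
import json
import torch

# Pack a quant-config dict into a uint8 tensor so it can travel inside a state dict.
quant_conf = {"format": "float8_e4m3fn", "full_precision_matrix_mult": True}
packed = torch.tensor(list(json.dumps(quant_conf).encode("utf-8")), dtype=torch.uint8)

# Decoding reverses the steps: uint8 tensor -> bytes -> JSON -> dict.
decoded = json.loads(bytes(packed.tolist()).decode("utf-8"))
assert decoded == quant_conf
```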
@@ -604,12 +653,33 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
def forward(self, input, *args, **kwargs):
run_every_op()

+ input_shape = input.shape
+ tensor_3d = input.ndim == 3

if self._full_precision_mm or self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(input, *args, **kwargs)

if (getattr(self, 'layout_type', None) is not None and
not isinstance(input, QuantizedTensor)):
- input = QuantizedTensor.from_float(input, self.layout_type, scale=getattr(self, 'input_scale', None), dtype=self.weight.dtype)
- return self._forward(input, self.weight, self.bias)
+ # Reshape 3D tensors to 2D for quantization (needed for NVFP4 and others)
+ if tensor_3d:
+ input = input.reshape(-1, input_shape[2])

+ if input.ndim != 2:
+ # Fall back to comfy_cast_weights for non-2D tensors
+ return self.forward_comfy_cast_weights(input.reshape(input_shape), *args, **kwargs)

+ # dtype is now implicit in the layout class
+ input = QuantizedTensor.from_float(input, self.layout_type, scale=getattr(self, 'input_scale', None))

+ output = self._forward(input, self.weight, self.bias)

+ # Reshape output back to 3D if input was 3D
+ if tensor_3d:
+ output = output.reshape((input_shape[0], input_shape[1], self.weight.shape[0]))

+ return output

def convert_weight(self, weight, inplace=False, **kwargs):
if isinstance(weight, QuantizedTensor):
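Editor's note: the reworked `forward` flattens a 3D activation to 2D before quantizing and restores the batch/sequence dimensions afterwards. A self-contained sketch of that reshape pattern, with plain `torch.nn.functional.linear` standing in for the quantized matmul:

```python
import torch
import torch.nn.functional as F

def linear_2d_roundtrip(x: torch.Tensor, weight: torch.Tensor, bias=None) -> torch.Tensor:
    # Flatten (batch, seq, in_features) -> (batch*seq, in_features) so the kernel
    # only ever sees a 2D matmul, then restore the leading dimensions.
    orig_shape = x.shape
    tensor_3d = x.ndim == 3
    if tensor_3d:
        x = x.reshape(-1, orig_shape[2])
    out = F.linear(x, weight, bias)  # stand-in for the quantized matmul
    if tensor_3d:
        out = out.reshape(orig_shape[0], orig_shape[1], weight.shape[0])
    return out

x = torch.randn(2, 5, 8)
w = torch.randn(4, 8)
assert linear_2d_roundtrip(x, w).shape == (2, 5, 4)
```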
@@ -619,7 +689,8 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec

def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
if getattr(self, 'layout_type', None) is not None:
- weight = QuantizedTensor.from_float(weight, self.layout_type, scale="recalculate", dtype=self.weight.dtype, stochastic_rounding=seed, inplace_ops=True)
+ # dtype is now implicit in the layout class
+ weight = QuantizedTensor.from_float(weight, self.layout_type, scale="recalculate", stochastic_rounding=seed, inplace_ops=True)
else:
weight = weight.to(self.weight.dtype)
if return_weight:
@@ -646,10 +717,17 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec

def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None):
fp8_compute = comfy.model_management.supports_fp8_compute(load_device) # TODO: if we support more ops this needs to be more granular
+ nvfp4_compute = comfy.model_management.supports_nvfp4_compute(load_device)

if model_config and hasattr(model_config, 'quant_config') and model_config.quant_config:
logging.info("Using mixed precision operations")
- return mixed_precision_ops(model_config.quant_config, compute_dtype, full_precision_mm=not fp8_compute)
+ disabled = set()
+ if not nvfp4_compute:
+ disabled.add("nvfp4")
+ if not fp8_compute:
+ disabled.add("float8_e4m3fn")
+ disabled.add("float8_e5m2")
+ return mixed_precision_ops(model_config.quant_config, compute_dtype, disabled=disabled)

if (
fp8_compute and
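Editor's note: taken together, the changes above replace the single `full_precision_mm=not fp8_compute` switch with a per-format `disabled` set: each missing device capability adds a format name, and any layer whose quant format lands in that set falls back to full-precision matrix multiplication while keeping its quantized storage. A small standalone sketch of that gating logic (the capability flags are hypothetical stand-ins for the model_management probes):

```python
def build_disabled_formats(fp8_compute: bool, nvfp4_compute: bool) -> set:
    disabled = set()
    if not nvfp4_compute:
        disabled.add("nvfp4")
    if not fp8_compute:
        disabled.update({"float8_e4m3fn", "float8_e5m2"})
    return disabled

def layer_uses_full_precision_mm(layer_format: str, config_override: bool, disabled: set) -> bool:
    # A layer runs the full-precision path if the checkpoint asked for it
    # or if its quant format is not computable on this device.
    return config_override or layer_format in disabled

disabled = build_disabled_formats(fp8_compute=False, nvfp4_compute=False)
assert layer_uses_full_precision_mm("float8_e4m3fn", False, disabled) is True
```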
@@ -1,577 +1,140 @@
import torch
import logging
- from typing import Tuple, Dict
+ try:
+ import comfy_kitchen as ck
+ from comfy_kitchen.tensor import (
+ QuantizedTensor,
+ QuantizedLayout,
+ TensorCoreFP8Layout as _CKFp8Layout,
+ TensorCoreNVFP4Layout, # Direct import, no wrapper needed
+ register_layout_op,
+ register_layout_class,
+ get_layout_class,
+ )
+ _CK_AVAILABLE = True
+ if torch.version.cuda is None:
+ ck.registry.disable("cuda")
+ else:
+ cuda_version = tuple(map(int, str(torch.version.cuda).split('.')))
+ if cuda_version < (13,):
+ ck.registry.disable("cuda")

+ ck.registry.disable("triton")
+ for k, v in ck.list_backends().items():
+ logging.info(f"Found comfy_kitchen backend {k}: {v}")
+ except ImportError as e:
+ logging.error(f"Failed to import comfy_kitchen, Error: {e}, fp8 and fp4 support will not be available.")
+ _CK_AVAILABLE = False

+ class QuantizedTensor:
+ pass

+ class _CKFp8Layout:
+ pass

+ class TensorCoreNVFP4Layout:
+ pass

+ def register_layout_class(name, cls):
+ pass

+ def get_layout_class(name):
+ return None

import comfy.float

- _LAYOUT_REGISTRY = {}
- _GENERIC_UTILS = {}

- def register_layout_op(torch_op, layout_type):
- """
- Decorator to register a layout-specific operation handler.
- Args:
- torch_op: PyTorch operation (e.g., torch.ops.aten.linear.default)
- layout_type: Layout class (e.g., TensorCoreFP8Layout)
- Example:
- @register_layout_op(torch.ops.aten.linear.default, TensorCoreFP8Layout)
- def fp8_linear(func, args, kwargs):
- # FP8-specific linear implementation
- ...
- """
- def decorator(handler_func):
- if torch_op not in _LAYOUT_REGISTRY:
- _LAYOUT_REGISTRY[torch_op] = {}
- _LAYOUT_REGISTRY[torch_op][layout_type] = handler_func
- return handler_func
- return decorator

- def register_generic_util(torch_op):
- """
- Decorator to register a generic utility that works for all layouts.
- Args:
- torch_op: PyTorch operation (e.g., torch.ops.aten.detach.default)

- Example:
- @register_generic_util(torch.ops.aten.detach.default)
- def generic_detach(func, args, kwargs):
- # Works for any layout
- ...
- """
- def decorator(handler_func):
- _GENERIC_UTILS[torch_op] = handler_func
- return handler_func
- return decorator

- def _get_layout_from_args(args):
- for arg in args:
- if isinstance(arg, QuantizedTensor):
- return arg._layout_type
- elif isinstance(arg, (list, tuple)):
- for item in arg:
- if isinstance(item, QuantizedTensor):
- return item._layout_type
- return None

- def _move_layout_params_to_device(params, device):
- new_params = {}
- for k, v in params.items():
- if isinstance(v, torch.Tensor):
- new_params[k] = v.to(device=device)
- else:
- new_params[k] = v
- return new_params

- def _copy_layout_params(params):
- new_params = {}
- for k, v in params.items():
- if isinstance(v, torch.Tensor):
- new_params[k] = v.clone()
- else:
- new_params[k] = v
- return new_params

- def _copy_layout_params_inplace(src, dst, non_blocking=False):
- for k, v in src.items():
- if isinstance(v, torch.Tensor):
- dst[k].copy_(v, non_blocking=non_blocking)
- else:
- dst[k] = v

- class QuantizedLayout:
- """
- Base class for quantization layouts.

- A layout encapsulates the format-specific logic for quantization/dequantization
- and provides a uniform interface for extracting raw tensors needed for computation.

- New quantization formats should subclass this and implement the required methods.
- """
- @classmethod
- def quantize(cls, tensor, **kwargs) -> Tuple[torch.Tensor, Dict]:
- raise NotImplementedError(f"{cls.__name__} must implement quantize()")

- @staticmethod
- def dequantize(qdata, **layout_params) -> torch.Tensor:
- raise NotImplementedError("TensorLayout must implement dequantize()")

- @classmethod
- def get_plain_tensors(cls, qtensor) -> torch.Tensor:
- raise NotImplementedError(f"{cls.__name__} must implement get_plain_tensors()")

- class QuantizedTensor(torch.Tensor):
- """
- Universal quantized tensor that works with any layout.

- This tensor subclass uses a pluggable layout system to support multiple
- quantization formats (FP8, INT4, INT8, etc.) without code duplication.

- The layout_type determines format-specific behavior, while common operations
- (detach, clone, to) are handled generically.

- Attributes:
- _qdata: The quantized tensor data
- _layout_type: Layout class (e.g., TensorCoreFP8Layout)
- _layout_params: Dict with layout-specific params (scale, zero_point, etc.)
- """

- @staticmethod
- def __new__(cls, qdata, layout_type, layout_params):
- """
- Create a quantized tensor.

- Args:
- qdata: The quantized data tensor
- layout_type: Layout class (subclass of QuantizedLayout)
- layout_params: Dict with layout-specific parameters
- """
- return torch.Tensor._make_wrapper_subclass(cls, qdata.shape, device=qdata.device, dtype=qdata.dtype, requires_grad=False)

- def __init__(self, qdata, layout_type, layout_params):
- self._qdata = qdata
- self._layout_type = layout_type
- self._layout_params = layout_params

- def __repr__(self):
- layout_name = self._layout_type
- param_str = ", ".join(f"{k}={v}" for k, v in list(self._layout_params.items())[:2])
- return f"QuantizedTensor(shape={self.shape}, layout={layout_name}, {param_str})"

- @property
- def layout_type(self):
- return self._layout_type

- def __tensor_flatten__(self):
- """
- Tensor flattening protocol for proper device movement.
- """
- inner_tensors = ["_qdata"]
- ctx = {
- "layout_type": self._layout_type,
- }

- tensor_params = {}
- non_tensor_params = {}
- for k, v in self._layout_params.items():
- if isinstance(v, torch.Tensor):
- tensor_params[k] = v
- else:
- non_tensor_params[k] = v

- ctx["tensor_param_keys"] = list(tensor_params.keys())
- ctx["non_tensor_params"] = non_tensor_params

- for k, v in tensor_params.items():
- attr_name = f"_layout_param_{k}"
- object.__setattr__(self, attr_name, v)
- inner_tensors.append(attr_name)

- return inner_tensors, ctx

- @staticmethod
- def __tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride):
- """
- Tensor unflattening protocol for proper device movement.
- Reconstructs the QuantizedTensor after device movement.
- """
- layout_type = ctx["layout_type"]
- layout_params = dict(ctx["non_tensor_params"])

- for key in ctx["tensor_param_keys"]:
- attr_name = f"_layout_param_{key}"
- layout_params[key] = inner_tensors[attr_name]

- return QuantizedTensor(inner_tensors["_qdata"], layout_type, layout_params)

- @classmethod
- def from_float(cls, tensor, layout_type, **quantize_kwargs) -> 'QuantizedTensor':
- qdata, layout_params = LAYOUTS[layout_type].quantize(tensor, **quantize_kwargs)
- return cls(qdata, layout_type, layout_params)

- def dequantize(self) -> torch.Tensor:
- return LAYOUTS[self._layout_type].dequantize(self._qdata, **self._layout_params)

- @classmethod
- def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
- kwargs = kwargs or {}

- # Step 1: Check generic utilities first (detach, clone, to, etc.)
- if func in _GENERIC_UTILS:
- return _GENERIC_UTILS[func](func, args, kwargs)

- # Step 2: Check layout-specific handlers (linear, matmul, etc.)
- layout_type = _get_layout_from_args(args)
- if layout_type and func in _LAYOUT_REGISTRY:
- handler = _LAYOUT_REGISTRY[func].get(layout_type)
- if handler:
- return handler(func, args, kwargs)

- # Step 3: Fallback to dequantization
- if isinstance(args[0] if args else None, QuantizedTensor):
- logging.info(f"QuantizedTensor: Unhandled operation {func}, falling back to dequantization. kwargs={kwargs}")
- return cls._dequant_and_fallback(func, args, kwargs)

- @classmethod
- def _dequant_and_fallback(cls, func, args, kwargs):
- def dequant_arg(arg):
- if isinstance(arg, QuantizedTensor):
- return arg.dequantize()
- elif isinstance(arg, (list, tuple)):
- return type(arg)(dequant_arg(a) for a in arg)
- return arg

- new_args = dequant_arg(args)
- new_kwargs = dequant_arg(kwargs)
- return func(*new_args, **new_kwargs)

- def data_ptr(self):
- return self._qdata.data_ptr()

- def is_pinned(self):
- return self._qdata.is_pinned()

- def is_contiguous(self, *arg, **kwargs):
- return self._qdata.is_contiguous(*arg, **kwargs)

- def storage(self):
- return self._qdata.storage()

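Editor's note: the block removed above is the in-tree dispatch machinery that this PR replaces with imports from the external `comfy_kitchen` package: a tensor subclass whose `__torch_dispatch__` consults a per-op, per-layout handler registry and otherwise dequantizes and retries. A minimal standalone sketch of that registry pattern (illustrative names, not the comfy_kitchen API):

```python
import torch

_HANDLERS = {}  # (torch op, layout name) -> handler

def register(op, layout):
    def deco(fn):
        _HANDLERS[(op, layout)] = fn
        return fn
    return deco

class TinyQuantTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, qdata, scale, layout):
        return torch.Tensor._make_wrapper_subclass(cls, qdata.shape, dtype=qdata.dtype, device=qdata.device)

    def __init__(self, qdata, scale, layout):
        self._qdata, self._scale, self._layout = qdata, scale, layout

    def dequantize(self):
        return self._qdata.float() * self._scale

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        qt = next(a for a in args if isinstance(a, cls))
        handler = _HANDLERS.get((func, qt._layout))
        if handler is not None:
            return handler(func, args, kwargs)
        # Fallback: dequantize every quantized argument and retry the op.
        args = [a.dequantize() if isinstance(a, cls) else a for a in args]
        return func(*args, **kwargs)

@register(torch.ops.aten.abs.default, "int8_demo")
def int8_abs(func, args, kwargs):
    qt = args[0]
    return TinyQuantTensor(qt._qdata.abs(), qt._scale, qt._layout)

q = TinyQuantTensor(torch.tensor([-3, 2], dtype=torch.int8), 0.5, "int8_demo")
print(torch.abs(q).dequantize())  # routed through the registered handler
print((q + 0).sum())              # unhandled op -> dequantize fallback
```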
# ==============================================================================
- # Generic Utilities (Layout-Agnostic Operations)
+ # FP8 Layouts with Comfy-Specific Extensions
# ==============================================================================

- def _create_transformed_qtensor(qt, transform_fn):
- new_data = transform_fn(qt._qdata)
- new_params = _copy_layout_params(qt._layout_params)
- return QuantizedTensor(new_data, qt._layout_type, new_params)
+ class _TensorCoreFP8LayoutBase(_CKFp8Layout):
+ FP8_DTYPE = None # Must be overridden in subclass

- def _handle_device_transfer(qt, target_device, target_dtype=None, target_layout=None, op_name="to"):
- if target_layout is not None and target_layout != torch.strided:
- logging.warning(
- f"QuantizedTensor: layout change requested to {target_layout}, "
- f"but not supported. Ignoring layout."
- )

- # Handle device transfer
- current_device = qt._qdata.device
- if target_device is not None:
- # Normalize device for comparison
- if isinstance(target_device, str):
- target_device = torch.device(target_device)
- if isinstance(current_device, str):
- current_device = torch.device(current_device)

- if target_device != current_device:
- logging.debug(f"QuantizedTensor.{op_name}: Moving from {current_device} to {target_device}")
- new_q_data = qt._qdata.to(device=target_device)
- new_params = _move_layout_params_to_device(qt._layout_params, target_device)
- if target_dtype is not None:
- new_params["orig_dtype"] = target_dtype
- new_qt = QuantizedTensor(new_q_data, qt._layout_type, new_params)
- logging.debug(f"QuantizedTensor.{op_name}: Created new tensor on {target_device}")
- return new_qt

- logging.debug(f"QuantizedTensor.{op_name}: No device change needed, returning original")
- return qt

- @register_generic_util(torch.ops.aten.detach.default)
- def generic_detach(func, args, kwargs):
- """Detach operation - creates a detached copy of the quantized tensor."""
- qt = args[0]
- if isinstance(qt, QuantizedTensor):
- return _create_transformed_qtensor(qt, lambda x: x.detach())
- return func(*args, **kwargs)

- @register_generic_util(torch.ops.aten.clone.default)
- def generic_clone(func, args, kwargs):
- """Clone operation - creates a deep copy of the quantized tensor."""
- qt = args[0]
- if isinstance(qt, QuantizedTensor):
- return _create_transformed_qtensor(qt, lambda x: x.clone())
- return func(*args, **kwargs)

- @register_generic_util(torch.ops.aten._to_copy.default)
- def generic_to_copy(func, args, kwargs):
- """Device/dtype transfer operation - handles .to(device) calls."""
- qt = args[0]
- if isinstance(qt, QuantizedTensor):
- return _handle_device_transfer(
- qt,
- target_device=kwargs.get('device', None),
- target_dtype=kwargs.get('dtype', None),
- op_name="_to_copy"
- )
- return func(*args, **kwargs)

- @register_generic_util(torch.ops.aten.to.dtype_layout)
- def generic_to_dtype_layout(func, args, kwargs):
- """Handle .to(device) calls using the dtype_layout variant."""
- qt = args[0]
- if isinstance(qt, QuantizedTensor):
- return _handle_device_transfer(
- qt,
- target_device=kwargs.get('device', None),
- target_dtype=kwargs.get('dtype', None),
- target_layout=kwargs.get('layout', None),
- op_name="to"
- )
- return func(*args, **kwargs)

- @register_generic_util(torch.ops.aten.copy_.default)
- def generic_copy_(func, args, kwargs):
- qt_dest = args[0]
- src = args[1]
- non_blocking = args[2] if len(args) > 2 else False
- if isinstance(qt_dest, QuantizedTensor):
- if isinstance(src, QuantizedTensor):
- # Copy from another quantized tensor
- qt_dest._qdata.copy_(src._qdata, non_blocking=non_blocking)
- qt_dest._layout_type = src._layout_type
- orig_dtype = qt_dest._layout_params["orig_dtype"]
- _copy_layout_params_inplace(src._layout_params, qt_dest._layout_params, non_blocking=non_blocking)
- qt_dest._layout_params["orig_dtype"] = orig_dtype
- else:
- # Copy from regular tensor - just copy raw data
- qt_dest._qdata.copy_(src)
- return qt_dest
- return func(*args, **kwargs)

- @register_generic_util(torch.ops.aten.to.dtype)
- def generic_to_dtype(func, args, kwargs):
- """Handle .to(dtype) calls - dtype conversion only."""
- src = args[0]
- if isinstance(src, QuantizedTensor):
- # For dtype-only conversion, just change the orig_dtype, no real cast is needed
- target_dtype = args[1] if len(args) > 1 else kwargs.get('dtype')
- src._layout_params["orig_dtype"] = target_dtype
- return src
- return func(*args, **kwargs)

- @register_generic_util(torch.ops.aten._has_compatible_shallow_copy_type.default)
- def generic_has_compatible_shallow_copy_type(func, args, kwargs):
- return True

- @register_generic_util(torch.ops.aten.empty_like.default)
- def generic_empty_like(func, args, kwargs):
- """Empty_like operation - creates an empty tensor with the same quantized structure."""
- qt = args[0]
- if isinstance(qt, QuantizedTensor):
- # Create empty tensor with same shape and dtype as the quantized data
- hp_dtype = kwargs.pop('dtype', qt._layout_params["orig_dtype"])
- new_qdata = torch.empty_like(qt._qdata, **kwargs)

- # Handle device transfer for layout params
- target_device = kwargs.get('device', new_qdata.device)
- new_params = _move_layout_params_to_device(qt._layout_params, target_device)

- # Update orig_dtype if dtype is specified
- new_params['orig_dtype'] = hp_dtype

- return QuantizedTensor(new_qdata, qt._layout_type, new_params)
- return func(*args, **kwargs)

- # ==============================================================================
- # FP8 Layout + Operation Handlers
- # ==============================================================================
- class TensorCoreFP8Layout(QuantizedLayout):
- """
- Storage format:
- - qdata: FP8 tensor (torch.float8_e4m3fn or torch.float8_e5m2)
- - scale: Scalar tensor (float32) for dequantization
- - orig_dtype: Original dtype before quantization (for casting back)
- """
@classmethod
- def quantize(cls, tensor, scale=None, dtype=torch.float8_e4m3fn, stochastic_rounding=0, inplace_ops=False):
+ def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False):
+ if cls.FP8_DTYPE is None:
+ raise NotImplementedError(f"{cls.__name__} must define FP8_DTYPE")

orig_dtype = tensor.dtype
+ orig_shape = tuple(tensor.shape)

if isinstance(scale, str) and scale == "recalculate":
- scale = torch.amax(tensor.abs()) / torch.finfo(dtype).max
+ scale = torch.amax(tensor.abs()).to(dtype=torch.float32) / torch.finfo(cls.FP8_DTYPE).max
+ if tensor.dtype not in [torch.float32, torch.bfloat16]: # Prevent scale from being too small
+ tensor_info = torch.finfo(tensor.dtype)
+ scale = (1.0 / torch.clamp((1.0 / scale), min=tensor_info.min, max=tensor_info.max))

- if scale is not None:
- if not isinstance(scale, torch.Tensor):
- scale = torch.tensor(scale)
- scale = scale.to(device=tensor.device, dtype=torch.float32)
+ if scale is None:
+ scale = torch.ones((), device=tensor.device, dtype=torch.float32)
+ if not isinstance(scale, torch.Tensor):
+ scale = torch.tensor(scale, device=tensor.device, dtype=torch.float32)

+ if stochastic_rounding > 0:
if inplace_ops:
tensor *= (1.0 / scale).to(tensor.dtype)
else:
tensor = tensor * (1.0 / scale).to(tensor.dtype)
+ qdata = comfy.float.stochastic_rounding(tensor, dtype=cls.FP8_DTYPE, seed=stochastic_rounding)
else:
- scale = torch.ones((), device=tensor.device, dtype=torch.float32)
+ qdata = ck.quantize_per_tensor_fp8(tensor, scale, cls.FP8_DTYPE)

- if stochastic_rounding > 0:
- tensor = comfy.float.stochastic_rounding(tensor, dtype=dtype, seed=stochastic_rounding)
- else:
- lp_amax = torch.finfo(dtype).max
- torch.clamp(tensor, min=-lp_amax, max=lp_amax, out=tensor)
- tensor = tensor.to(dtype, memory_format=torch.contiguous_format)

- layout_params = {
- 'scale': scale,
- 'orig_dtype': orig_dtype
- }
- return tensor, layout_params
+ params = cls.Params(scale=scale.float(), orig_dtype=orig_dtype, orig_shape=orig_shape)
+ return qdata, params

- @staticmethod
- def dequantize(qdata, scale, orig_dtype, **kwargs):
- plain_tensor = torch.ops.aten._to_copy.default(qdata, dtype=orig_dtype)
- plain_tensor.mul_(scale)
- return plain_tensor

- @classmethod
- def get_plain_tensors(cls, qtensor):
- return qtensor._qdata, qtensor._layout_params['scale']
+ class TensorCoreFP8E4M3Layout(_TensorCoreFP8LayoutBase):
+ FP8_DTYPE = torch.float8_e4m3fn

+ class TensorCoreFP8E5M2Layout(_TensorCoreFP8LayoutBase):
+ FP8_DTYPE = torch.float8_e5m2

+ # Backward compatibility alias - default to E4M3
+ TensorCoreFP8Layout = TensorCoreFP8E4M3Layout

+ # ==============================================================================
+ # Registry
+ # ==============================================================================

+ register_layout_class("TensorCoreFP8Layout", TensorCoreFP8Layout)
+ register_layout_class("TensorCoreFP8E4M3Layout", TensorCoreFP8E4M3Layout)
+ register_layout_class("TensorCoreFP8E5M2Layout", TensorCoreFP8E5M2Layout)
+ register_layout_class("TensorCoreNVFP4Layout", TensorCoreNVFP4Layout)

QUANT_ALGOS = {
"float8_e4m3fn": {
"storage_t": torch.float8_e4m3fn,
"parameters": {"weight_scale", "input_scale"},
- "comfy_tensor_layout": "TensorCoreFP8Layout",
+ "comfy_tensor_layout": "TensorCoreFP8E4M3Layout",
+ },
+ "float8_e5m2": {
+ "storage_t": torch.float8_e5m2,
+ "parameters": {"weight_scale", "input_scale"},
+ "comfy_tensor_layout": "TensorCoreFP8E5M2Layout",
+ },
+ "nvfp4": {
+ "storage_t": torch.uint8,
+ "parameters": {"weight_scale", "weight_scale_2", "input_scale"},
+ "comfy_tensor_layout": "TensorCoreNVFP4Layout",
+ "group_size": 16,
},
}
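Editor's note: the new FP8 layout subclasses keep essentially the same per-tensor scaling idea as the removed in-tree layout: scale = amax / finfo(fp8).max, divide, cast to FP8, and dequantize by multiplying the scale back in. A self-contained sketch of that round trip in plain PyTorch (no comfy_kitchen, names are illustrative):

```python
import torch

def fp8_quantize_per_tensor(t: torch.Tensor, fp8_dtype=torch.float8_e4m3fn):
    # Per-tensor scale so the largest magnitude maps to the FP8 max value.
    scale = torch.amax(t.abs()).to(torch.float32) / torch.finfo(fp8_dtype).max
    scale = torch.clamp(scale, min=1e-12)  # guard against an all-zero tensor
    qdata = (t / scale).clamp(torch.finfo(fp8_dtype).min, torch.finfo(fp8_dtype).max).to(fp8_dtype)
    return qdata, scale

def fp8_dequantize(qdata: torch.Tensor, scale: torch.Tensor, orig_dtype=torch.bfloat16):
    return qdata.to(orig_dtype) * scale.to(orig_dtype)

w = torch.randn(128, 64, dtype=torch.bfloat16)
q, s = fp8_quantize_per_tensor(w)
w_hat = fp8_dequantize(q, s)
print((w - w_hat).abs().max())  # quantization error, small relative to amax
```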

- LAYOUTS = {
- "TensorCoreFP8Layout": TensorCoreFP8Layout,
- }

+ # ==============================================================================
+ # Re-exports for backward compatibility
+ # ==============================================================================

- @register_layout_op(torch.ops.aten.linear.default, "TensorCoreFP8Layout")
- def fp8_linear(func, args, kwargs):
- input_tensor = args[0]
- weight = args[1]
- bias = args[2] if len(args) > 2 else None

- if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor):
- plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor)
- plain_weight, scale_b = TensorCoreFP8Layout.get_plain_tensors(weight)
+ __all__ = [
+ "QuantizedTensor",
+ "QuantizedLayout",
+ "TensorCoreFP8Layout",
+ "TensorCoreFP8E4M3Layout",
+ "TensorCoreFP8E5M2Layout",
+ "TensorCoreNVFP4Layout",
+ "QUANT_ALGOS",
+ "register_layout_op",
+ ]

- out_dtype = kwargs.get("out_dtype")
- if out_dtype is None:
- out_dtype = input_tensor._layout_params['orig_dtype']

- weight_t = plain_weight.t()

- tensor_2d = False
- if len(plain_input.shape) == 2:
- tensor_2d = True
- plain_input = plain_input.unsqueeze(1)

- input_shape = plain_input.shape
- if len(input_shape) != 3:
- return None

- try:
- output = torch._scaled_mm(
- plain_input.reshape(-1, input_shape[2]).contiguous(),
- weight_t,
- bias=bias,
- scale_a=scale_a,
- scale_b=scale_b,
- out_dtype=out_dtype,
- )

- if isinstance(output, tuple): # TODO: remove when we drop support for torch 2.4
- output = output[0]

- if not tensor_2d:
- output = output.reshape((-1, input_shape[1], weight.shape[0]))

- if output.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
- output_scale = scale_a * scale_b
- output_params = {
- 'scale': output_scale,
- 'orig_dtype': input_tensor._layout_params['orig_dtype']
- }
- return QuantizedTensor(output, "TensorCoreFP8Layout", output_params)
- else:
- return output

- except Exception as e:
- raise RuntimeError(f"FP8 _scaled_mm failed, falling back to dequantization: {e}")

- # Case 2: DQ Fallback
- if isinstance(weight, QuantizedTensor):
- weight = weight.dequantize()
- if isinstance(input_tensor, QuantizedTensor):
- input_tensor = input_tensor.dequantize()

- return torch.nn.functional.linear(input_tensor, weight, bias)

- def fp8_mm_(input_tensor, weight, bias=None, out_dtype=None):
- if out_dtype is None:
- out_dtype = input_tensor._layout_params['orig_dtype']

- plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor)
- plain_weight, scale_b = TensorCoreFP8Layout.get_plain_tensors(weight)

- output = torch._scaled_mm(
- plain_input.contiguous(),
- plain_weight,
- bias=bias,
- scale_a=scale_a,
- scale_b=scale_b,
- out_dtype=out_dtype,
- )

- if isinstance(output, tuple): # TODO: remove when we drop support for torch 2.4
- output = output[0]
- return output

- @register_layout_op(torch.ops.aten.addmm.default, "TensorCoreFP8Layout")
- def fp8_addmm(func, args, kwargs):
- input_tensor = args[1]
- weight = args[2]
- bias = args[0]

- if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor):
- return fp8_mm_(input_tensor, weight, bias=bias, out_dtype=kwargs.get("out_dtype", None))

- a = list(args)
- if isinstance(args[0], QuantizedTensor):
- a[0] = args[0].dequantize()
- if isinstance(args[1], QuantizedTensor):
- a[1] = args[1].dequantize()
- if isinstance(args[2], QuantizedTensor):
- a[2] = args[2].dequantize()

- return func(*a, **kwargs)

- @register_layout_op(torch.ops.aten.mm.default, "TensorCoreFP8Layout")
- def fp8_mm(func, args, kwargs):
- input_tensor = args[0]
- weight = args[1]

- if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor):
- return fp8_mm_(input_tensor, weight, bias=None, out_dtype=kwargs.get("out_dtype", None))

- a = list(args)
- if isinstance(args[0], QuantizedTensor):
- a[0] = args[0].dequantize()
- if isinstance(args[1], QuantizedTensor):
- a[1] = args[1].dequantize()
- return func(*a, **kwargs)

- @register_layout_op(torch.ops.aten.view.default, "TensorCoreFP8Layout")
- @register_layout_op(torch.ops.aten.t.default, "TensorCoreFP8Layout")
- def fp8_func(func, args, kwargs):
- input_tensor = args[0]
- if isinstance(input_tensor, QuantizedTensor):
- plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor)
- ar = list(args)
- ar[0] = plain_input
- return QuantizedTensor(func(*ar, **kwargs), "TensorCoreFP8Layout", input_tensor._layout_params)
- return func(*args, **kwargs)

@@ -122,20 +122,20 @@ def estimate_memory(model, noise_shape, conds):
minimum_memory_required = model.model.memory_required([noise_shape[0]] + list(noise_shape[1:]), cond_shapes=cond_shapes_min)
return memory_required, minimum_memory_required

- def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None):
+ def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False):
executor = comfy.patcher_extension.WrapperExecutor.new_executor(
_prepare_sampling,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.PREPARE_SAMPLING, model_options, is_model_options=True)
)
- return executor.execute(model, noise_shape, conds, model_options=model_options)
+ return executor.execute(model, noise_shape, conds, model_options=model_options, force_full_load=force_full_load)

- def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None):
+ def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False):
real_model: BaseModel = None
models, inference_memory = get_additional_models(conds, model.model_dtype())
models += get_additional_models_from_model_options(model_options)
models += model.get_nested_additional_models() # TODO: does this require inference_memory update?
memory_required, minimum_memory_required = estimate_memory(model, noise_shape, conds)
- comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required + inference_memory, minimum_memory_required=minimum_memory_required + inference_memory)
+ comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required + inference_memory, minimum_memory_required=minimum_memory_required + inference_memory, force_full_load=force_full_load)
real_model = model.model

return real_model, conds, models

@@ -720,7 +720,7 @@ class Sampler:
sigma = float(sigmas[0])
return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma

- KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2","dpm_2", "dpm_2_ancestral",
+ KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2", "exp_heun_2_x0", "exp_heun_2_x0_sde", "dpm_2", "dpm_2_ancestral",
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
"dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_2m_sde_heun", "dpmpp_2m_sde_heun_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
"ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp", "res_multistep_ancestral", "res_multistep_ancestral_cfg_pp",
@@ -984,9 +984,6 @@ class CFGGuider:
self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options)
device = self.model_patcher.load_device

- if denoise_mask is not None:
- denoise_mask = comfy.sampler_helpers.prepare_mask(denoise_mask, noise.shape, device)

noise = noise.to(device)
latent_image = latent_image.to(device)
sigmas = sigmas.to(device)
@@ -1013,6 +1010,24 @@ class CFGGuider:
else:
latent_shapes = [latent_image.shape]

+ if denoise_mask is not None:
+ if denoise_mask.is_nested:
+ denoise_masks = denoise_mask.unbind()
+ denoise_masks = denoise_masks[:len(latent_shapes)]
+ else:
+ denoise_masks = [denoise_mask]

+ for i in range(len(denoise_masks), len(latent_shapes)):
+ denoise_masks.append(torch.ones(latent_shapes[i]))

+ for i in range(len(denoise_masks)):
+ denoise_masks[i] = comfy.sampler_helpers.prepare_mask(denoise_masks[i], latent_shapes[i], self.model_patcher.load_device)

+ if len(denoise_masks) > 1:
+ denoise_mask, _ = comfy.utils.pack_latents(denoise_masks)
+ else:
+ denoise_mask = denoise_masks[0]

self.conds = {}
for k in self.original_conds:
self.conds[k] = list(map(lambda a: a.copy(), self.original_conds[k]))
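Editor's note: the new CFGGuider block prepares one denoise mask per packed latent, filling any missing masks with all-ones before packing them back together. A standalone sketch of that pad-to-match idea with plain tensors (no nested tensors or comfy.utils packing):

```python
import torch

def pad_masks_to_latents(masks, latent_shapes):
    """Return one mask per latent shape, defaulting missing masks to ones."""
    masks = list(masks)
    for shape in latent_shapes[len(masks):]:
        masks.append(torch.ones(shape))
    return masks

latent_shapes = [(1, 4, 8, 8), (1, 4, 16, 16)]
masks = pad_masks_to_latents([torch.zeros(latent_shapes[0])], latent_shapes)
print([m.shape for m in masks])  # second entry is an all-ones mask
```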
71 comfy/sd.py
@@ -55,6 +55,8 @@ import comfy.text_encoders.hunyuan_image
import comfy.text_encoders.z_image
import comfy.text_encoders.ovis
import comfy.text_encoders.kandinsky5
+ import comfy.text_encoders.jina_clip_2
+ import comfy.text_encoders.newbie

import comfy.model_patcher
import comfy.lora
@@ -127,6 +129,8 @@ class CLIP:

self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
self.patcher = comfy.model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
+ #Match torch.float32 hardcode upcast in TE implemention
+ self.patcher.set_model_compute_dtype(torch.float32)
self.patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram
self.patcher.is_clip = True
self.apply_hooks_to_conds = None
@@ -319,6 +323,7 @@ class VAE:
self.latent_channels = 4
self.latent_dim = 2
self.output_channels = 3
+ self.pad_channel_value = None
self.process_input = lambda image: image * 2.0 - 1.0
self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
self.working_dtypes = [torch.bfloat16, torch.float32]
@@ -433,6 +438,7 @@ class VAE:
self.memory_used_decode = lambda shape, dtype: (1000 * shape[2] * 2048) * model_management.dtype_size(dtype)
self.latent_channels = 64
self.output_channels = 2
+ self.pad_channel_value = "replicate"
self.upscale_ratio = 2048
self.downscale_ratio = 2048
self.latent_dim = 1
@@ -544,11 +550,15 @@ class VAE:
self.downscale_index_formula = (4, 8, 8)
self.latent_dim = 3
self.latent_channels = 16
- ddconfig = {"dim": dim, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "dropout": 0.0}
+ self.output_channels = sd["encoder.conv1.weight"].shape[1]
+ self.pad_channel_value = 1.0
+ ddconfig = {"dim": dim, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "image_channels": self.output_channels, "dropout": 0.0}
self.first_stage_model = comfy.ldm.wan.vae.WanVAE(**ddconfig)
self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
- self.memory_used_encode = lambda shape, dtype: 6000 * shape[3] * shape[4] * model_management.dtype_size(dtype)
+ self.memory_used_encode = lambda shape, dtype: (1500 if shape[2]<=4 else 6000) * shape[3] * shape[4] * model_management.dtype_size(dtype)
- self.memory_used_decode = lambda shape, dtype: 7000 * shape[3] * shape[4] * (8 * 8) * model_management.dtype_size(dtype)
+ self.memory_used_decode = lambda shape, dtype: (2200 if shape[2]<=4 else 7000) * shape[3] * shape[4] * (8*8) * model_management.dtype_size(dtype)

# Hunyuan 3d v2 2.0 & 2.1
elif "geo_decoder.cross_attn_decoder.ln_1.bias" in sd:

@@ -578,6 +588,7 @@ class VAE:
self.memory_used_decode = lambda shape, dtype: (shape[2] * shape[3] * 87000) * model_management.dtype_size(dtype)
self.latent_channels = 8
self.output_channels = 2
+ self.pad_channel_value = "replicate"
self.upscale_ratio = 4096
self.downscale_ratio = 4096
self.latent_dim = 2
@@ -686,17 +697,28 @@ class VAE:
raise RuntimeError("ERROR: VAE is invalid: None\n\nIf the VAE is from a checkpoint loader node your checkpoint does not contain a valid VAE.")

def vae_encode_crop_pixels(self, pixels):
- if not self.crop_input:
- return pixels

- downscale_ratio = self.spacial_compression_encode()

- dims = pixels.shape[1:-1]
- for d in range(len(dims)):
- x = (dims[d] // downscale_ratio) * downscale_ratio
- x_offset = (dims[d] % downscale_ratio) // 2
- if x != dims[d]:
- pixels = pixels.narrow(d + 1, x_offset, x)
+ if self.crop_input:
+ downscale_ratio = self.spacial_compression_encode()
+ dims = pixels.shape[1:-1]
+ for d in range(len(dims)):
+ x = (dims[d] // downscale_ratio) * downscale_ratio
+ x_offset = (dims[d] % downscale_ratio) // 2
+ if x != dims[d]:
+ pixels = pixels.narrow(d + 1, x_offset, x)

+ if pixels.shape[-1] > self.output_channels:
+ pixels = pixels[..., :self.output_channels]
+ elif pixels.shape[-1] < self.output_channels:
+ if self.pad_channel_value is not None:
+ if isinstance(self.pad_channel_value, str):
+ mode = self.pad_channel_value
+ value = None
+ else:
+ mode = "constant"
+ value = self.pad_channel_value

+ pixels = torch.nn.functional.pad(pixels, (0, self.output_channels - pixels.shape[-1]), mode=mode, value=value)
return pixels

def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16):
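Editor's note: the reworked `vae_encode_crop_pixels` now also reconciles the channel count with what the VAE expects: extra channels are dropped, missing ones are padded with either a constant value or a padding mode such as "replicate". A standalone sketch of that channel-fixup step (illustrative helper, not the ComfyUI method):

```python
import torch
import torch.nn.functional as F

def match_channels(x: torch.Tensor, out_channels: int, pad_value=0.0) -> torch.Tensor:
    """x has channels last; trim or pad the final dimension to out_channels."""
    c = x.shape[-1]
    if c > out_channels:
        return x[..., :out_channels]
    if c < out_channels:
        if isinstance(pad_value, str):
            # Modes like "replicate" repeat the last existing channel.
            # (1D replicate padding expects a 2D or 3D input.)
            return F.pad(x, (0, out_channels - c), mode=pad_value)
        return F.pad(x, (0, out_channels - c), mode="constant", value=pad_value)
    return x

rgb = torch.rand(2, 64, 64, 3)    # image batch, channels last
mono = torch.rand(2, 1000, 1)     # audio-like batch, channels last
print(match_channels(rgb, 4, pad_value=1.0).shape)           # (2, 64, 64, 4), new channel = 1.0
print(match_channels(mono, 2, pad_value="replicate").shape)  # (2, 1000, 2)
```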
@ -988,6 +1010,7 @@ class CLIPType(Enum):
|
|||||||
OVIS = 21
|
OVIS = 21
|
||||||
KANDINSKY5 = 22
|
KANDINSKY5 = 22
|
||||||
KANDINSKY5_IMAGE = 23
|
KANDINSKY5_IMAGE = 23
|
||||||
|
NEWBIE = 24
|
||||||
|
|
||||||
|
|
||||||
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
|
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
|
||||||
@ -1018,6 +1041,8 @@ class TEModel(Enum):
|
|||||||
MISTRAL3_24B_PRUNED_FLUX2 = 15
|
MISTRAL3_24B_PRUNED_FLUX2 = 15
|
||||||
QWEN3_4B = 16
|
QWEN3_4B = 16
|
||||||
QWEN3_2B = 17
|
QWEN3_2B = 17
|
||||||
|
GEMMA_3_12B = 18
|
||||||
|
JINA_CLIP_2 = 19
|
||||||
|
|
||||||
|
|
||||||
def detect_te_model(sd):
|
def detect_te_model(sd):
|
||||||
@ -1027,6 +1052,8 @@ def detect_te_model(sd):
|
|||||||
return TEModel.CLIP_H
|
return TEModel.CLIP_H
|
||||||
if "text_model.encoder.layers.0.mlp.fc1.weight" in sd:
|
if "text_model.encoder.layers.0.mlp.fc1.weight" in sd:
|
||||||
return TEModel.CLIP_L
|
return TEModel.CLIP_L
|
||||||
|
if "model.encoder.layers.0.mixer.Wqkv.weight" in sd:
|
||||||
|
return TEModel.JINA_CLIP_2
|
||||||
if "encoder.block.23.layer.1.DenseReluDense.wi_1.weight" in sd:
|
if "encoder.block.23.layer.1.DenseReluDense.wi_1.weight" in sd:
|
||||||
weight = sd["encoder.block.23.layer.1.DenseReluDense.wi_1.weight"]
|
weight = sd["encoder.block.23.layer.1.DenseReluDense.wi_1.weight"]
|
||||||
if weight.shape[-1] == 4096:
|
if weight.shape[-1] == 4096:
|
||||||
@ -1041,6 +1068,8 @@ def detect_te_model(sd):
|
|||||||
return TEModel.BYT5_SMALL_GLYPH
|
return TEModel.BYT5_SMALL_GLYPH
|
||||||
return TEModel.T5_BASE
|
return TEModel.T5_BASE
|
||||||
if 'model.layers.0.post_feedforward_layernorm.weight' in sd:
|
if 'model.layers.0.post_feedforward_layernorm.weight' in sd:
|
||||||
|
if 'model.layers.47.self_attn.q_norm.weight' in sd:
|
||||||
|
return TEModel.GEMMA_3_12B
|
||||||
if 'model.layers.0.self_attn.q_norm.weight' in sd:
|
if 'model.layers.0.self_attn.q_norm.weight' in sd:
|
||||||
return TEModel.GEMMA_3_4B
|
return TEModel.GEMMA_3_4B
|
||||||
return TEModel.GEMMA_2_2B
|
return TEModel.GEMMA_2_2B
|
||||||
@ -1187,6 +1216,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
|||||||
elif te_model == TEModel.QWEN3_2B:
|
elif te_model == TEModel.QWEN3_2B:
|
||||||
clip_target.clip = comfy.text_encoders.ovis.te(**llama_detect(clip_data))
|
clip_target.clip = comfy.text_encoders.ovis.te(**llama_detect(clip_data))
|
||||||
clip_target.tokenizer = comfy.text_encoders.ovis.OvisTokenizer
|
clip_target.tokenizer = comfy.text_encoders.ovis.OvisTokenizer
|
||||||
|
elif te_model == TEModel.JINA_CLIP_2:
|
||||||
|
clip_target.clip = comfy.text_encoders.jina_clip_2.JinaClip2TextModelWrapper
|
||||||
|
clip_target.tokenizer = comfy.text_encoders.jina_clip_2.JinaClip2TokenizerWrapper
|
||||||
else:
|
else:
|
||||||
# clip_l
|
# clip_l
|
||||||
if clip_type == CLIPType.SD3:
|
if clip_type == CLIPType.SD3:
|
||||||
@ -1242,6 +1274,21 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
|||||||
elif clip_type == CLIPType.KANDINSKY5_IMAGE:
|
elif clip_type == CLIPType.KANDINSKY5_IMAGE:
|
||||||
clip_target.clip = comfy.text_encoders.kandinsky5.te(**llama_detect(clip_data))
|
clip_target.clip = comfy.text_encoders.kandinsky5.te(**llama_detect(clip_data))
|
||||||
clip_target.tokenizer = comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage
|
clip_target.tokenizer = comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage
|
||||||
|
elif clip_type == CLIPType.LTXV:
|
||||||
|
clip_target.clip = comfy.text_encoders.lt.ltxav_te(**llama_detect(clip_data))
|
||||||
|
clip_target.tokenizer = comfy.text_encoders.lt.LTXAVGemmaTokenizer
|
||||||
|
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
|
||||||
|
elif clip_type == CLIPType.NEWBIE:
|
||||||
|
clip_target.clip = comfy.text_encoders.newbie.te(**llama_detect(clip_data))
|
||||||
|
clip_target.tokenizer = comfy.text_encoders.newbie.NewBieTokenizer
|
||||||
|
if "model.layers.0.self_attn.q_norm.weight" in clip_data[0]:
|
||||||
|
clip_data_gemma = clip_data[0]
|
||||||
|
clip_data_jina = clip_data[1]
|
||||||
|
else:
|
||||||
|
clip_data_gemma = clip_data[1]
|
||||||
|
clip_data_jina = clip_data[0]
|
||||||
|
tokenizer_data["gemma_spiece_model"] = clip_data_gemma.get("spiece_model", None)
|
||||||
|
tokenizer_data["jina_spiece_model"] = clip_data_jina.get("spiece_model", None)
|
||||||
else:
|
else:
|
||||||
clip_target.clip = sdxl_clip.SDXLClipModel
|
clip_target.clip = sdxl_clip.SDXLClipModel
|
||||||
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
|
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
|
||||||
|
|||||||
@ -466,7 +466,7 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
|
|||||||
return embed_out
|
return embed_out
|
||||||
|
|
||||||
class SDTokenizer:
|
class SDTokenizer:
|
||||||
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, min_padding=None, pad_left=False, tokenizer_data={}, tokenizer_args={}):
|
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, min_padding=None, pad_left=False, disable_weights=False, tokenizer_data={}, tokenizer_args={}):
|
||||||
if tokenizer_path is None:
|
if tokenizer_path is None:
|
||||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
|
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
|
||||||
self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path, **tokenizer_args)
|
self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path, **tokenizer_args)
|
||||||
@ -513,6 +513,8 @@ class SDTokenizer:
|
|||||||
self.embedding_size = embedding_size
|
self.embedding_size = embedding_size
|
||||||
self.embedding_key = embedding_key
|
self.embedding_key = embedding_key
|
||||||
|
|
||||||
|
self.disable_weights = disable_weights
|
||||||
|
|
||||||
def _try_get_embedding(self, embedding_name:str):
|
def _try_get_embedding(self, embedding_name:str):
|
||||||
'''
|
'''
|
||||||
Takes a potential embedding name and tries to retrieve it.
|
Takes a potential embedding name and tries to retrieve it.
|
||||||
@ -547,7 +549,7 @@ class SDTokenizer:
|
|||||||
min_padding = tokenizer_options.get("{}_min_padding".format(self.embedding_key), self.min_padding)
|
min_padding = tokenizer_options.get("{}_min_padding".format(self.embedding_key), self.min_padding)
|
||||||
|
|
||||||
text = escape_important(text)
|
text = escape_important(text)
|
||||||
if kwargs.get("disable_weights", False):
|
if kwargs.get("disable_weights", self.disable_weights):
|
||||||
parsed_weights = [(text, 1.0)]
|
parsed_weights = [(text, 1.0)]
|
||||||
else:
|
else:
|
||||||
parsed_weights = token_weights(text, 1.0)
|
parsed_weights = token_weights(text, 1.0)
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ from . import supported_models_base
 from . import latent_formats

 from . import diffusers_convert
+import comfy.model_management

 class SD15(supported_models_base.BASE):
     unet_config = {
@@ -541,7 +542,7 @@ class SD3(supported_models_base.BASE):
     unet_extra_config = {}
     latent_format = latent_formats.SD3

-    memory_usage_factor = 1.2
+    memory_usage_factor = 1.6

     text_encoder_key_prefix = ["text_encoders."]

@@ -835,6 +836,21 @@ class LTXV(supported_models_base.BASE):
         t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
         return supported_models_base.ClipTarget(comfy.text_encoders.lt.LTXVT5Tokenizer, comfy.text_encoders.lt.ltxv_te(**t5_detect))

+class LTXAV(LTXV):
+    unet_config = {
+        "image_model": "ltxav",
+    }
+
+    latent_format = latent_formats.LTXAV
+
+    def __init__(self, unet_config):
+        super().__init__(unet_config)
+        self.memory_usage_factor = 0.055 # TODO
+
+    def get_model(self, state_dict, prefix="", device=None):
+        out = model_base.LTXAV(self, device=device)
+        return out
+
 class HunyuanVideo(supported_models_base.BASE):
     unet_config = {
         "image_model": "hunyuan_video",
@@ -965,7 +981,7 @@ class CosmosT2IPredict2(supported_models_base.BASE):

     def __init__(self, unet_config):
         super().__init__(unet_config)
-        self.memory_usage_factor = (unet_config.get("model_channels", 2048) / 2048) * 0.9
+        self.memory_usage_factor = (unet_config.get("model_channels", 2048) / 2048) * 0.95

     def get_model(self, state_dict, prefix="", device=None):
         out = model_base.CosmosPredict2(self, device=device)
@@ -1026,9 +1042,15 @@ class ZImage(Lumina2):
         "shift": 3.0,
     }

-    memory_usage_factor = 1.7
+    memory_usage_factor = 2.0

-    supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
+    supported_inference_dtypes = [torch.bfloat16, torch.float32]

+    def __init__(self, unet_config):
+        super().__init__(unet_config)
+        if comfy.model_management.extended_fp16_support():
+            self.supported_inference_dtypes = self.supported_inference_dtypes.copy()
+            self.supported_inference_dtypes.insert(1, torch.float16)
+
     def clip_target(self, state_dict={}):
         pref = self.text_encoder_key_prefix[0]
@@ -1289,7 +1311,7 @@ class ChromaRadiance(Chroma):
     latent_format = comfy.latent_formats.ChromaRadiance

     # Pixel-space model, no spatial compression for model input.
-    memory_usage_factor = 0.038
+    memory_usage_factor = 0.044

     def get_model(self, state_dict, prefix="", device=None):
         return model_base.ChromaRadiance(self, device=device)
@@ -1332,7 +1354,7 @@ class Omnigen2(supported_models_base.BASE):
         "shift": 2.6,
     }

-    memory_usage_factor = 1.65 #TODO
+    memory_usage_factor = 1.95 #TODO

     unet_extra_config = {}
     latent_format = latent_formats.Flux
@@ -1397,7 +1419,7 @@ class HunyuanImage21(HunyuanVideo):

     latent_format = latent_formats.HunyuanImage21

-    memory_usage_factor = 7.7
+    memory_usage_factor = 8.7

     supported_inference_dtypes = [torch.bfloat16, torch.float32]

@@ -1488,7 +1510,7 @@ class Kandinsky5(supported_models_base.BASE):
     unet_extra_config = {}
     latent_format = latent_formats.HunyuanVideo

-    memory_usage_factor = 1.1 #TODO
+    memory_usage_factor = 1.25 #TODO

     supported_inference_dtypes = [torch.bfloat16, torch.float32]

@@ -1517,7 +1539,7 @@ class Kandinsky5Image(Kandinsky5):
     }

     latent_format = latent_formats.Flux
-    memory_usage_factor = 1.1 #TODO
+    memory_usage_factor = 1.25 #TODO

     def get_model(self, state_dict, prefix="", device=None):
         out = model_base.Kandinsky5Image(self, device=device)
@@ -1529,6 +1551,6 @@ class Kandinsky5Image(Kandinsky5):
         return supported_models_base.ClipTarget(comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage, comfy.text_encoders.kandinsky5.te(**hunyuan_detect))


-models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5]
+models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5]

 models += [SVD_img2vid]
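
Note: in the ZImage hunk above, the new __init__ copies supported_inference_dtypes before inserting torch.float16. The copy matters because that attribute is a class-level list; mutating it in place would leak the fp16 opt-in into every other instance and subclass. A sketch of the difference, with illustrative classes rather than the real ones:

class Base:
    supported_dtypes = ["bf16", "fp32"]  # class-level default shared by all instances

class Leaky(Base):
    def __init__(self, fp16_ok):
        if fp16_ok:
            # Bug: insert() mutates the shared class-level list for everyone.
            self.supported_dtypes.insert(1, "fp16")

class Safe(Base):
    def __init__(self, fp16_ok):
        if fp16_ok:
            # Copy first, then mutate the per-instance list (what the diff does).
            self.supported_dtypes = self.supported_dtypes.copy()
            self.supported_dtypes.insert(1, "fp16")

Safe(True)
print(Base.supported_dtypes)   # ['bf16', 'fp32'] - class default untouched
Leaky(True)
print(Base.supported_dtypes)   # ['bf16', 'fp16', 'fp32'] - class default was mutated

This keeps the fp16 entry local to instances created where extended_fp16_support() holds.
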
@@ -154,7 +154,8 @@ class TAEHV(nn.Module):
         self._show_progress_bar = value

     def encode(self, x, **kwargs):
-        if self.patch_size > 1: x = F.pixel_unshuffle(x, self.patch_size)
+        if self.patch_size > 1:
+            x = F.pixel_unshuffle(x, self.patch_size)
         x = x.movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W]
         if x.shape[1] % 4 != 0:
             # pad at end to multiple of 4
@@ -167,5 +168,6 @@ class TAEHV(nn.Module):
     def decode(self, x, **kwargs):
         x = self.process_in(x).movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W]
         x = apply_model_with_memblocks(self.decoder, x, self.parallel, self.show_progress_bar)
-        if self.patch_size > 1: x = F.pixel_shuffle(x, self.patch_size)
+        if self.patch_size > 1:
+            x = F.pixel_shuffle(x, self.patch_size)
         return x[:, self.frames_to_trim:].movedim(2, 1)
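
Note: the TAEHV hunks only reflow the single-line if statements; pixel_unshuffle and pixel_shuffle themselves are the standard space-to-depth pair and invert each other exactly. A quick shape check in plain PyTorch, independent of TAEHV:

import torch
import torch.nn.functional as F

patch_size = 2
x = torch.randn(1, 3, 64, 64)            # (N, C, H, W)
down = F.pixel_unshuffle(x, patch_size)  # space-to-depth
print(down.shape)                        # torch.Size([1, 12, 32, 32]) -> C*p*p, H/p, W/p
up = F.pixel_shuffle(down, patch_size)   # depth-to-space, exact inverse
print(torch.equal(up, x))                # True
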
219  comfy/text_encoders/jina_clip_2.py  Normal file
@@ -0,0 +1,219 @@
+# Jina CLIP v2 and Jina Embeddings v3 both use their modified XLM-RoBERTa architecture. Reference implementation:
+# Jina CLIP v2 (both text and vision): https://huggingface.co/jinaai/jina-clip-implementation/blob/39e6a55ae971b59bea6e44675d237c99762e7ee2/modeling_clip.py
+# Jina XLM-RoBERTa (text only): http://huggingface.co/jinaai/xlm-roberta-flash-implementation/blob/2b6bc3f30750b3a9648fe9b63448c09920efe9be/modeling_xlm_roberta.py
+
+from dataclasses import dataclass
+
+import torch
+from torch import nn as nn
+from torch.nn import functional as F
+
+import comfy.model_management
+import comfy.ops
+from comfy import sd1_clip
+from .spiece_tokenizer import SPieceTokenizer
+
+class JinaClip2Tokenizer(sd1_clip.SDTokenizer):
+    def __init__(self, embedding_directory=None, tokenizer_data={}):
+        tokenizer = tokenizer_data.get("spiece_model", None)
+        # The official NewBie uses max_length=8000, but Jina Embeddings v3 actually supports 8192
+        super().__init__(tokenizer, pad_with_end=False, embedding_size=1024, embedding_key='jina_clip_2', tokenizer_class=SPieceTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=False, max_length=8192, min_length=1, pad_token=1, end_token=2, tokenizer_args={"add_bos": True, "add_eos": True}, tokenizer_data=tokenizer_data)
+
+    def state_dict(self):
+        return {"spiece_model": self.tokenizer.serialize_model()}
+
+class JinaClip2TokenizerWrapper(sd1_clip.SD1Tokenizer):
+    def __init__(self, embedding_directory=None, tokenizer_data={}):
+        super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, tokenizer=JinaClip2Tokenizer, name="jina_clip_2")
+
+# https://huggingface.co/jinaai/jina-embeddings-v3/blob/343dbf534c76fe845f304fa5c2d1fd87e1e78918/config.json
+@dataclass
+class XLMRobertaConfig:
+    vocab_size: int = 250002
+    type_vocab_size: int = 1
+    hidden_size: int = 1024
+    num_hidden_layers: int = 24
+    num_attention_heads: int = 16
+    rotary_emb_base: float = 20000.0
+    intermediate_size: int = 4096
+    hidden_act: str = "gelu"
+    hidden_dropout_prob: float = 0.1
+    attention_probs_dropout_prob: float = 0.1
+    layer_norm_eps: float = 1e-05
+    bos_token_id: int = 0
+    eos_token_id: int = 2
+    pad_token_id: int = 1
+
+class XLMRobertaEmbeddings(nn.Module):
+    def __init__(self, config, device=None, dtype=None, ops=None):
+        super().__init__()
+        embed_dim = config.hidden_size
+        self.word_embeddings = ops.Embedding(config.vocab_size, embed_dim, padding_idx=config.pad_token_id, device=device, dtype=dtype)
+        self.token_type_embeddings = ops.Embedding(config.type_vocab_size, embed_dim, device=device, dtype=dtype)
+
+    def forward(self, input_ids=None, embeddings=None):
+        if input_ids is not None and embeddings is None:
+            embeddings = self.word_embeddings(input_ids)
+
+        if embeddings is not None:
+            token_type_ids = torch.zeros(embeddings.shape[1], device=embeddings.device, dtype=torch.int32)
+            token_type_embeddings = self.token_type_embeddings(token_type_ids)
+            embeddings = embeddings + token_type_embeddings
+        return embeddings
+
+class RotaryEmbedding(nn.Module):
+    def __init__(self, dim, base, device=None):
+        super().__init__()
+        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim))
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+        self._seq_len_cached = 0
+        self._cos_cached = None
+        self._sin_cached = None
+
+    def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
+        if seqlen > self._seq_len_cached or self._cos_cached is None or self._cos_cached.device != device or self._cos_cached.dtype != dtype:
+            self._seq_len_cached = seqlen
+            t = torch.arange(seqlen, device=device, dtype=torch.float32)
+            freqs = torch.outer(t, self.inv_freq.to(device=t.device))
+            emb = torch.cat((freqs, freqs), dim=-1)
+            self._cos_cached = emb.cos().to(dtype)
+            self._sin_cached = emb.sin().to(dtype)
+
+    def forward(self, q, k):
+        batch, seqlen, heads, head_dim = q.shape
+        self._update_cos_sin_cache(seqlen, device=q.device, dtype=q.dtype)
+
+        cos = self._cos_cached[:seqlen].view(1, seqlen, 1, head_dim)
+        sin = self._sin_cached[:seqlen].view(1, seqlen, 1, head_dim)
+
+        def rotate_half(x):
+            size = x.shape[-1] // 2
+            x1, x2 = x[..., :size], x[..., size:]
+            return torch.cat((-x2, x1), dim=-1)
+
+        q_embed = (q * cos) + (rotate_half(q) * sin)
+        k_embed = (k * cos) + (rotate_half(k) * sin)
+        return q_embed, k_embed
+
+class MHA(nn.Module):
+    def __init__(self, config, device=None, dtype=None, ops=None):
+        super().__init__()
+        embed_dim = config.hidden_size
+        self.num_heads = config.num_attention_heads
+        self.head_dim = embed_dim // config.num_attention_heads
+
+        self.rotary_emb = RotaryEmbedding(self.head_dim, config.rotary_emb_base, device=device)
+        self.Wqkv = ops.Linear(embed_dim, 3 * embed_dim, device=device, dtype=dtype)
+        self.out_proj = ops.Linear(embed_dim, embed_dim, device=device, dtype=dtype)
+
+    def forward(self, x, mask=None, optimized_attention=None):
+        qkv = self.Wqkv(x)
+        batch_size, seq_len, _ = qkv.shape
+        qkv = qkv.view(batch_size, seq_len, 3, self.num_heads, self.head_dim)
+        q, k, v = qkv.unbind(2)
+
+        q, k = self.rotary_emb(q, k)
+
+        # NHD -> HND
+        q = q.transpose(1, 2)
+        k = k.transpose(1, 2)
+        v = v.transpose(1, 2)
+
+        out = optimized_attention(q, k, v, heads=self.num_heads, mask=mask, skip_reshape=True)
+        return self.out_proj(out)
+
+class MLP(nn.Module):
+    def __init__(self, config, device=None, dtype=None, ops=None):
+        super().__init__()
+        self.fc1 = ops.Linear(config.hidden_size, config.intermediate_size, device=device, dtype=dtype)
+        self.activation = F.gelu
+        self.fc2 = ops.Linear(config.intermediate_size, config.hidden_size, device=device, dtype=dtype)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.activation(x)
+        x = self.fc2(x)
+        return x
+
+class Block(nn.Module):
+    def __init__(self, config, device=None, dtype=None, ops=None):
+        super().__init__()
+        self.mixer = MHA(config, device=device, dtype=dtype, ops=ops)
+        self.dropout1 = nn.Dropout(config.hidden_dropout_prob)
+        self.norm1 = ops.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, device=device, dtype=dtype)
+        self.mlp = MLP(config, device=device, dtype=dtype, ops=ops)
+        self.dropout2 = nn.Dropout(config.hidden_dropout_prob)
+        self.norm2 = ops.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, device=device, dtype=dtype)
+
+    def forward(self, hidden_states, mask=None, optimized_attention=None):
+        mixer_out = self.mixer(hidden_states, mask=mask, optimized_attention=optimized_attention)
+        hidden_states = self.norm1(self.dropout1(mixer_out) + hidden_states)
+        mlp_out = self.mlp(hidden_states)
+        hidden_states = self.norm2(self.dropout2(mlp_out) + hidden_states)
+        return hidden_states
+
+class XLMRobertaEncoder(nn.Module):
+    def __init__(self, config, device=None, dtype=None, ops=None):
+        super().__init__()
+        self.layers = nn.ModuleList([Block(config, device=device, dtype=dtype, ops=ops) for _ in range(config.num_hidden_layers)])
+
+    def forward(self, hidden_states, attention_mask=None):
+        optimized_attention = comfy.ldm.modules.attention.optimized_attention_for_device(hidden_states.device, mask=attention_mask is not None, small_input=True)
+        for layer in self.layers:
+            hidden_states = layer(hidden_states, mask=attention_mask, optimized_attention=optimized_attention)
+        return hidden_states
+
+class XLMRobertaModel_(nn.Module):
+    def __init__(self, config, device=None, dtype=None, ops=None):
+        super().__init__()
+        self.embeddings = XLMRobertaEmbeddings(config, device=device, dtype=dtype, ops=ops)
+        self.emb_ln = ops.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, device=device, dtype=dtype)
+        self.emb_drop = nn.Dropout(config.hidden_dropout_prob)
+        self.encoder = XLMRobertaEncoder(config, device=device, dtype=dtype, ops=ops)
+
+    def forward(self, input_ids, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, embeds_info=[]):
+        x = self.embeddings(input_ids=input_ids, embeddings=embeds)
+        x = self.emb_ln(x)
+        x = self.emb_drop(x)
+
+        mask = None
+        if attention_mask is not None:
+            mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, 1, attention_mask.shape[-1]))
+            mask = mask.masked_fill(mask.to(torch.bool), -torch.finfo(x.dtype).max)
+
+        sequence_output = self.encoder(x, attention_mask=mask)
+
+        # Mean pool, see https://huggingface.co/jinaai/jina-clip-implementation/blob/39e6a55ae971b59bea6e44675d237c99762e7ee2/hf_model.py
+        pooled_output = None
+        if attention_mask is None:
+            pooled_output = sequence_output.mean(dim=1)
+        else:
+            attention_mask = attention_mask.to(sequence_output.dtype)
+            pooled_output = (sequence_output * attention_mask.unsqueeze(-1)).sum(dim=1) / attention_mask.sum(dim=-1, keepdim=True)
+
+        # Intermediate output is not yet implemented, use None for placeholder
+        return sequence_output, None, pooled_output
+
+class XLMRobertaModel(nn.Module):
+    def __init__(self, config_dict, dtype, device, operations):
+        super().__init__()
+        self.config = XLMRobertaConfig(**config_dict)
+        self.model = XLMRobertaModel_(self.config, device=device, dtype=dtype, ops=operations)
+        self.num_layers = self.config.num_hidden_layers
+
+    def get_input_embeddings(self):
+        return self.model.embeddings.word_embeddings
+
+    def set_input_embeddings(self, embeddings):
+        self.model.embeddings.word_embeddings = embeddings
+
+    def forward(self, *args, **kwargs):
+        return self.model(*args, **kwargs)
+
+class JinaClip2TextModel(sd1_clip.SDClipModel):
+    def __init__(self, device="cpu", dtype=None, model_options={}):
+        super().__init__(device=device, dtype=dtype, textmodel_json_config={}, model_class=XLMRobertaModel, special_tokens={"start": 0, "end": 2, "pad": 1}, enable_attention_masks=True, return_attention_masks=True, model_options=model_options)
+
+class JinaClip2TextModelWrapper(sd1_clip.SD1ClipModel):
+    def __init__(self, device="cpu", dtype=None, model_options={}):
+        super().__init__(device=device, dtype=dtype, clip_model=JinaClip2TextModel, name="jina_clip_2", model_options=model_options)
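
Note: RotaryEmbedding above is the rotate-half (GPT-NeoX style) RoPE formulation. A self-contained sketch of the same math on toy shapes, checking that the rotation leaves per-position vector norms unchanged; the sizes here are arbitrary, not the Jina config:

import torch

def rotate_half(x):
    half = x.shape[-1] // 2
    x1, x2 = x[..., :half], x[..., half:]
    return torch.cat((-x2, x1), dim=-1)

def rope(q, base=20000.0):
    # q: [batch, seq, heads, head_dim], the layout the MHA above feeds in
    head_dim, seqlen = q.shape[-1], q.shape[1]
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
    freqs = torch.outer(torch.arange(seqlen, dtype=torch.float32), inv_freq)
    emb = torch.cat((freqs, freqs), dim=-1)          # [seq, head_dim]
    cos = emb.cos().view(1, seqlen, 1, head_dim)
    sin = emb.sin().view(1, seqlen, 1, head_dim)
    return (q * cos) + (rotate_half(q) * sin)

q = torch.randn(2, 8, 4, 16)
q_rot = rope(q)
# Each (position, head) vector is only rotated, so its norm is preserved.
print(torch.allclose(q.norm(dim=-1), q_rot.norm(dim=-1), atol=1e-4))  # True
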
@@ -3,13 +3,12 @@ import torch.nn as nn
 from dataclasses import dataclass
 from typing import Optional, Any
 import math
-import logging

 from comfy.ldm.modules.attention import optimized_attention_for_device
 import comfy.model_management
 import comfy.ldm.common_dit
+import comfy.clip_model

-import comfy.model_management
 from . import qwen_vl

 @dataclass
@@ -177,7 +176,7 @@ class Gemma3_4B_Config:
     num_key_value_heads: int = 4
     max_position_embeddings: int = 131072
     rms_norm_eps: float = 1e-6
-    rope_theta = [10000.0, 1000000.0]
+    rope_theta = [1000000.0, 10000.0]
     transformer_type: str = "gemma3"
     head_dim = 256
     rms_norm_add = True
@@ -186,10 +185,35 @@ class Gemma3_4B_Config:
     rope_dims = None
     q_norm = "gemma3"
     k_norm = "gemma3"
-    sliding_attention = [False, False, False, False, False, 1024]
+    sliding_attention = [1024, 1024, 1024, 1024, 1024, False]
-    rope_scale = [1.0, 8.0]
+    rope_scale = [8.0, 1.0]
     final_norm: bool = True

+@dataclass
+class Gemma3_12B_Config:
+    vocab_size: int = 262208
+    hidden_size: int = 3840
+    intermediate_size: int = 15360
+    num_hidden_layers: int = 48
+    num_attention_heads: int = 16
+    num_key_value_heads: int = 8
+    max_position_embeddings: int = 131072
+    rms_norm_eps: float = 1e-6
+    rope_theta = [1000000.0, 10000.0]
+    transformer_type: str = "gemma3"
+    head_dim = 256
+    rms_norm_add = True
+    mlp_activation = "gelu_pytorch_tanh"
+    qkv_bias = False
+    rope_dims = None
+    q_norm = "gemma3"
+    k_norm = "gemma3"
+    sliding_attention = [1024, 1024, 1024, 1024, 1024, False]
+    rope_scale = [8.0, 1.0]
+    final_norm: bool = True
+    vision_config = {"num_channels": 3, "hidden_act": "gelu_pytorch_tanh", "hidden_size": 1152, "image_size": 896, "intermediate_size": 4304, "model_type": "siglip_vision_model", "num_attention_heads": 16, "num_hidden_layers": 27, "patch_size": 14}
+    mm_tokens_per_image = 256
+
 class RMSNorm(nn.Module):
     def __init__(self, dim: int, eps: float = 1e-5, add=False, device=None, dtype=None):
         super().__init__()
@@ -370,7 +394,7 @@ class TransformerBlockGemma2(nn.Module):
         self.pre_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
         self.post_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)

-        if config.sliding_attention is not None: # TODO: implement. (Not that necessary since models are trained on less than 1024 tokens)
+        if config.sliding_attention is not None:
            self.sliding_attention = config.sliding_attention[index % len(config.sliding_attention)]
         else:
            self.sliding_attention = False
@@ -387,7 +411,12 @@ class TransformerBlockGemma2(nn.Module):
         if self.transformer_type == 'gemma3':
             if self.sliding_attention:
                 if x.shape[1] > self.sliding_attention:
-                    logging.warning("Warning: sliding attention not implemented, results may be incorrect")
+                    sliding_mask = torch.full((x.shape[1], x.shape[1]), float("-inf"), device=x.device, dtype=x.dtype)
+                    sliding_mask.tril_(diagonal=-self.sliding_attention)
+                    if attention_mask is not None:
+                        attention_mask = attention_mask + sliding_mask
+                    else:
+                        attention_mask = sliding_mask
                 freqs_cis = freqs_cis[1]
             else:
                 freqs_cis = freqs_cis[0]
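
Note: the sliding-attention hunk replaces the old "not implemented" warning with a real additive mask. What tril_(diagonal=-window) leaves behind is -inf exactly where a key sits window or more positions before the query and 0 everywhere else, so the existing causal mask still handles future positions. A toy-sized look at the mask:

import torch

seq, window = 6, 2
sliding_mask = torch.full((seq, seq), float("-inf"))
sliding_mask.tril_(diagonal=-window)   # keep -inf only where key_idx <= query_idx - window
print(sliding_mask)
# Row i has -inf in columns 0..i-window and 0.0 elsewhere, e.g. row 4 is
# [-inf, -inf, -inf, 0., 0., 0.]. Added to the attention scores together with
# the causal mask, this limits each query to the current token plus the
# previous window - 1 tokens.
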
@@ -517,6 +546,41 @@ class Llama2_(nn.Module):

         return x, intermediate

+
+class Gemma3MultiModalProjector(torch.nn.Module):
+    def __init__(self, config, dtype, device, operations):
+        super().__init__()
+
+        self.mm_input_projection_weight = nn.Parameter(
+            torch.empty(config.vision_config["hidden_size"], config.hidden_size, device=device, dtype=dtype)
+        )
+
+        self.mm_soft_emb_norm = RMSNorm(config.vision_config["hidden_size"], eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
+
+        self.patches_per_image = int(config.vision_config["image_size"] // config.vision_config["patch_size"])
+        self.tokens_per_side = int(config.mm_tokens_per_image**0.5)
+        self.kernel_size = self.patches_per_image // self.tokens_per_side
+        self.avg_pool = nn.AvgPool2d(kernel_size=self.kernel_size, stride=self.kernel_size)
+
+    def forward(self, vision_outputs: torch.Tensor):
+        batch_size, _, seq_length = vision_outputs.shape
+
+        reshaped_vision_outputs = vision_outputs.transpose(1, 2)
+        reshaped_vision_outputs = reshaped_vision_outputs.reshape(
+            batch_size, seq_length, self.patches_per_image, self.patches_per_image
+        )
+        reshaped_vision_outputs = reshaped_vision_outputs.contiguous()
+
+        pooled_vision_outputs = self.avg_pool(reshaped_vision_outputs)
+        pooled_vision_outputs = pooled_vision_outputs.flatten(2)
+        pooled_vision_outputs = pooled_vision_outputs.transpose(1, 2)
+
+        normed_vision_outputs = self.mm_soft_emb_norm(pooled_vision_outputs)
+
+        projected_vision_outputs = torch.matmul(normed_vision_outputs, comfy.model_management.cast_to_device(self.mm_input_projection_weight, device=normed_vision_outputs.device, dtype=normed_vision_outputs.dtype))
+        return projected_vision_outputs.type_as(vision_outputs)
+
+
 class BaseLlama:
     def get_input_embeddings(self):
         return self.model.embed_tokens
@@ -633,3 +697,21 @@ class Gemma3_4B(BaseLlama, torch.nn.Module):

         self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
         self.dtype = dtype
+
+class Gemma3_12B(BaseLlama, torch.nn.Module):
+    def __init__(self, config_dict, dtype, device, operations):
+        super().__init__()
+        config = Gemma3_12B_Config(**config_dict)
+        self.num_layers = config.num_hidden_layers
+
+        self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
+        self.multi_modal_projector = Gemma3MultiModalProjector(config, dtype, device, operations)
+        self.vision_model = comfy.clip_model.CLIPVision(config.vision_config, dtype, device, operations)
+        self.dtype = dtype
+        self.image_size = config.vision_config["image_size"]
+
+    def preprocess_embed(self, embed, device):
+        if embed["type"] == "image":
+            image = comfy.clip_model.clip_preprocess(embed["data"], size=self.image_size, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], crop=True)
+            return self.multi_modal_projector(self.vision_model(image.to(device, dtype=torch.float32))[0]), None
+        return None, None
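
Note: Gemma3MultiModalProjector above shrinks the SigLIP patch grid down to mm_tokens_per_image tokens with an average pool. For the config in this diff that is 896 / 14 = 64 patches per side, pooled 4x4 into a 16x16 grid of 256 tokens. A shape check of just the pooling path (the RMSNorm and the learned projection are left out):

import torch
import torch.nn as nn

image_size, patch_size, mm_tokens_per_image, vision_hidden = 896, 14, 256, 1152
patches_per_image = image_size // patch_size        # 64 patches per side -> 4096 patch tokens
tokens_per_side = int(mm_tokens_per_image ** 0.5)   # 16
kernel = patches_per_image // tokens_per_side       # 4: every 4x4 patch block becomes one token

vision_outputs = torch.randn(1, patches_per_image * patches_per_image, vision_hidden)
grid = vision_outputs.transpose(1, 2).reshape(1, vision_hidden, patches_per_image, patches_per_image)
pooled = nn.AvgPool2d(kernel_size=kernel, stride=kernel)(grid)
tokens = pooled.flatten(2).transpose(1, 2)
print(tokens.shape)  # torch.Size([1, 256, 1152]) -> 256 image tokens handed to the language model
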
@@ -1,7 +1,11 @@
 from comfy import sd1_clip
 import os
 from transformers import T5TokenizerFast
+from .spiece_tokenizer import SPieceTokenizer
 import comfy.text_encoders.genmo
+from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector
+import torch
+import comfy.utils

 class T5XXLTokenizer(sd1_clip.SDTokenizer):
     def __init__(self, embedding_directory=None, tokenizer_data={}):
@@ -16,3 +20,112 @@ class LTXVT5Tokenizer(sd1_clip.SD1Tokenizer):

 def ltxv_te(*args, **kwargs):
     return comfy.text_encoders.genmo.mochi_te(*args, **kwargs)
+
+
+class Gemma3_12BTokenizer(sd1_clip.SDTokenizer):
+    def __init__(self, embedding_directory=None, tokenizer_data={}):
+        tokenizer = tokenizer_data.get("spiece_model", None)
+        super().__init__(tokenizer, pad_with_end=False, embedding_size=3840, embedding_key='gemma3_12b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data)
+
+    def state_dict(self):
+        return {"spiece_model": self.tokenizer.serialize_model()}
+
+class LTXAVGemmaTokenizer(sd1_clip.SD1Tokenizer):
+    def __init__(self, embedding_directory=None, tokenizer_data={}):
+        super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="gemma3_12b", tokenizer=Gemma3_12BTokenizer)
+
+class Gemma3_12BModel(sd1_clip.SDClipModel):
+    def __init__(self, device="cpu", layer="all", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
+        llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
+        if llama_quantization_metadata is not None:
+            model_options = model_options.copy()
+            model_options["quantization_metadata"] = llama_quantization_metadata
+
+        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_12B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
+
+    def tokenize_with_weights(self, text, return_word_ids=False, llama_template="{}", image_embeds=None, **kwargs):
+        text = llama_template.format(text)
+        text_tokens = super().tokenize_with_weights(text, return_word_ids)
+        embed_count = 0
+        for k in text_tokens:
+            tt = text_tokens[k]
+            for r in tt:
+                for i in range(len(r)):
+                    if r[i][0] == 262144:
+                        if image_embeds is not None and embed_count < image_embeds.shape[0]:
+                            r[i] = ({"type": "embedding", "data": image_embeds[embed_count], "original_type": "image"},) + r[i][1:]
+                            embed_count += 1
+        return text_tokens
+
+class LTXAVTEModel(torch.nn.Module):
+    def __init__(self, dtype_llama=None, device="cpu", dtype=None, model_options={}):
+        super().__init__()
+        self.dtypes = set()
+        self.dtypes.add(dtype)
+
+        self.gemma3_12b = Gemma3_12BModel(device=device, dtype=dtype_llama, model_options=model_options, layer="all", layer_idx=None)
+        self.dtypes.add(dtype_llama)
+
+        operations = self.gemma3_12b.operations # TODO
+        self.text_embedding_projection = operations.Linear(3840 * 49, 3840, bias=False, dtype=dtype, device=device)
+
+        self.audio_embeddings_connector = Embeddings1DConnector(
+            split_rope=True,
+            double_precision_rope=True,
+            dtype=dtype,
+            device=device,
+            operations=operations,
+        )
+
+        self.video_embeddings_connector = Embeddings1DConnector(
+            split_rope=True,
+            double_precision_rope=True,
+            dtype=dtype,
+            device=device,
+            operations=operations,
+        )
+
+    def set_clip_options(self, options):
+        self.execution_device = options.get("execution_device", self.execution_device)
+        self.gemma3_12b.set_clip_options(options)
+
+    def reset_clip_options(self):
+        self.gemma3_12b.reset_clip_options()
+        self.execution_device = None
+
+    def encode_token_weights(self, token_weight_pairs):
+        token_weight_pairs = token_weight_pairs["gemma3_12b"]
+
+        out, pooled, extra = self.gemma3_12b.encode_token_weights(token_weight_pairs)
+        out_device = out.device
+        out = out.movedim(1, -1).to(self.execution_device)
+        out = 8.0 * (out - out.mean(dim=(1, 2), keepdim=True)) / (out.amax(dim=(1, 2), keepdim=True) - out.amin(dim=(1, 2), keepdim=True) + 1e-6)
+        out = out.reshape((out.shape[0], out.shape[1], -1))
+        out = self.text_embedding_projection(out)
+        out_vid = self.video_embeddings_connector(out)[0]
+        out_audio = self.audio_embeddings_connector(out)[0]
+        out = torch.concat((out_vid, out_audio), dim=-1)

+        return out.to(out_device), pooled
+
+    def load_sd(self, sd):
+        if "model.layers.47.self_attn.q_norm.weight" in sd:
+            return self.gemma3_12b.load_sd(sd)
+        else:
+            sdo = comfy.utils.state_dict_prefix_replace(sd, {"text_embedding_projection.aggregate_embed.weight": "text_embedding_projection.weight", "model.diffusion_model.video_embeddings_connector.": "video_embeddings_connector.", "model.diffusion_model.audio_embeddings_connector.": "audio_embeddings_connector."}, filter_keys=True)
+            if len(sdo) == 0:
+                sdo = sd
+
+            return self.load_state_dict(sdo, strict=False)
+
+
+def ltxav_te(dtype_llama=None, llama_quantization_metadata=None):
+    class LTXAVTEModel_(LTXAVTEModel):
+        def __init__(self, device="cpu", dtype=None, model_options={}):
+            if llama_quantization_metadata is not None:
+                model_options = model_options.copy()
+                model_options["llama_quantization_metadata"] = llama_quantization_metadata
+            if dtype_llama is not None:
+                dtype = dtype_llama
+            super().__init__(dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options)
+    return LTXAVTEModel_
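
Note: encode_token_weights above rescales the stacked Gemma hidden states so each sample ends up zero-mean with a fixed peak-to-peak range of 8 before the 3840 * 49 -> 3840 projection. A quick numerical check of that property on random data (shapes here are arbitrary, not the real layout):

import torch

out = torch.randn(2, 77, 3840) * 5.0 + 3.0
norm = 8.0 * (out - out.mean(dim=(1, 2), keepdim=True)) / (
    out.amax(dim=(1, 2), keepdim=True) - out.amin(dim=(1, 2), keepdim=True) + 1e-6)
ptp = norm.amax(dim=(1, 2)) - norm.amin(dim=(1, 2))
print(ptp)                                 # ~[8.0, 8.0] per sample
print(norm.mean(dim=(1, 2)).abs() < 1e-4)  # the per-sample mean is removed
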
@@ -14,7 +14,7 @@ class Gemma2BTokenizer(sd1_clip.SDTokenizer):
 class Gemma3_4BTokenizer(sd1_clip.SDTokenizer):
     def __init__(self, embedding_directory=None, tokenizer_data={}):
         tokenizer = tokenizer_data.get("spiece_model", None)
-        super().__init__(tokenizer, pad_with_end=False, embedding_size=2560, embedding_key='gemma3_4b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data)
+        super().__init__(tokenizer, pad_with_end=False, embedding_size=2560, embedding_key='gemma3_4b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}, disable_weights=True, tokenizer_data=tokenizer_data)

     def state_dict(self):
         return {"spiece_model": self.tokenizer.serialize_model()}
@@ -33,6 +33,11 @@ class Gemma2_2BModel(sd1_clip.SDClipModel):

 class Gemma3_4BModel(sd1_clip.SDClipModel):
     def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}):
+        llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
+        if llama_quantization_metadata is not None:
+            model_options = model_options.copy()
+            model_options["quantization_metadata"] = llama_quantization_metadata
+
         super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_4B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)

 class LuminaModel(sd1_clip.SD1ClipModel):
62  comfy/text_encoders/newbie.py  Normal file
@@ -0,0 +1,62 @@
+import torch
+
+import comfy.model_management
+import comfy.text_encoders.jina_clip_2
+import comfy.text_encoders.lumina2
+
+class NewBieTokenizer:
+    def __init__(self, embedding_directory=None, tokenizer_data={}):
+        self.gemma = comfy.text_encoders.lumina2.Gemma3_4BTokenizer(embedding_directory=embedding_directory, tokenizer_data={"spiece_model": tokenizer_data["gemma_spiece_model"]})
+        self.jina = comfy.text_encoders.jina_clip_2.JinaClip2Tokenizer(embedding_directory=embedding_directory, tokenizer_data={"spiece_model": tokenizer_data["jina_spiece_model"]})
+
+    def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
+        out = {}
+        out["gemma"] = self.gemma.tokenize_with_weights(text, return_word_ids, **kwargs)
+        out["jina"] = self.jina.tokenize_with_weights(text, return_word_ids, **kwargs)
+        return out
+
+    def untokenize(self, token_weight_pair):
+        raise NotImplementedError
+
+    def state_dict(self):
+        return {}
+
+class NewBieTEModel(torch.nn.Module):
+    def __init__(self, dtype_gemma=None, device="cpu", dtype=None, model_options={}):
+        super().__init__()
+        dtype_gemma = comfy.model_management.pick_weight_dtype(dtype_gemma, dtype, device)
+        self.gemma = comfy.text_encoders.lumina2.Gemma3_4BModel(device=device, dtype=dtype_gemma, model_options=model_options)
+        self.jina = comfy.text_encoders.jina_clip_2.JinaClip2TextModel(device=device, dtype=dtype, model_options=model_options)
+        self.dtypes = {dtype, dtype_gemma}
+
+    def set_clip_options(self, options):
+        self.gemma.set_clip_options(options)
+        self.jina.set_clip_options(options)
+
+    def reset_clip_options(self):
+        self.gemma.reset_clip_options()
+        self.jina.reset_clip_options()
+
+    def encode_token_weights(self, token_weight_pairs):
+        token_weight_pairs_gemma = token_weight_pairs["gemma"]
+        token_weight_pairs_jina = token_weight_pairs["jina"]
+
+        gemma_out, gemma_pooled, gemma_extra = self.gemma.encode_token_weights(token_weight_pairs_gemma)
+        jina_out, jina_pooled, jina_extra = self.jina.encode_token_weights(token_weight_pairs_jina)
+
+        return gemma_out, jina_pooled, gemma_extra
+
+    def load_sd(self, sd):
+        if "model.layers.0.self_attn.q_norm.weight" in sd:
+            return self.gemma.load_sd(sd)
+        else:
+            return self.jina.load_sd(sd)
+
+def te(dtype_llama=None, llama_quantization_metadata=None):
+    class NewBieTEModel_(NewBieTEModel):
+        def __init__(self, device="cpu", dtype=None, model_options={}):
+            if llama_quantization_metadata is not None:
+                model_options = model_options.copy()
+                model_options["llama_quantization_metadata"] = llama_quantization_metadata
+            super().__init__(dtype_gemma=dtype_llama, device=device, dtype=dtype, model_options=model_options)
+    return NewBieTEModel_
@@ -53,7 +53,7 @@ if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in
     ALWAYS_SAFE_LOAD = True
     logging.info("Checkpoint files will always be loaded safely.")
 else:
-    logging.info("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended.")
+    logging.warning("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended as older versions of pytorch are no longer supported.")

 def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
     # Try GDS loading first if available and device is GPU
@@ -815,12 +815,17 @@ def safetensors_header(safetensors_path, max_size=100*1024*1024):
             return None
         return f.read(length_of_header)

+ATTR_UNSET={}
+
 def set_attr(obj, attr, value):
     attrs = attr.split(".")
     for name in attrs[:-1]:
         obj = getattr(obj, name)
-    prev = getattr(obj, attrs[-1])
-    setattr(obj, attrs[-1], value)
+    prev = getattr(obj, attrs[-1], ATTR_UNSET)
+    if value is ATTR_UNSET:
+        delattr(obj, attrs[-1])
+    else:
+        setattr(obj, attrs[-1], value)
     return prev

 def set_attr_param(obj, attr, value):
@@ -1205,7 +1210,7 @@ def unpack_latents(combined_latent, latent_shapes):
             combined_latent = combined_latent[:, :, cut:]
             output_tensors.append(tens.reshape([tens.shape[0]] + list(shape)[1:]))
     else:
-        output_tensors = combined_latent
+        output_tensors = [combined_latent]
     return output_tensors

 def detect_layer_quantization(state_dict, prefix):
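
Note: set_attr now reports a missing attribute as ATTR_UNSET and deletes the attribute when ATTR_UNSET is passed back in, which makes temporary overrides reversible. A standalone sketch of that save/patch/restore pattern; the helper mirrors the hunk above, the Cfg object is just an illustration:

ATTR_UNSET = {}

def set_attr(obj, attr, value):
    # Returns the previous value, or ATTR_UNSET if the attribute was missing;
    # passing ATTR_UNSET back in deletes the attribute again.
    attrs = attr.split(".")
    for name in attrs[:-1]:
        obj = getattr(obj, name)
    prev = getattr(obj, attrs[-1], ATTR_UNSET)
    if value is ATTR_UNSET:
        delattr(obj, attrs[-1])
    else:
        setattr(obj, attrs[-1], value)
    return prev

class Cfg:
    pass

cfg = Cfg()
prev = set_attr(cfg, "scale", 2.0)    # attribute did not exist -> prev is ATTR_UNSET
print(prev is ATTR_UNSET, cfg.scale)  # True 2.0
set_attr(cfg, "scale", prev)          # restoring ATTR_UNSET deletes it again
print(hasattr(cfg, "scale"))          # False
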
@@ -1237,6 +1242,8 @@ def convert_old_quants(state_dict, model_prefix="", metadata={}):
     out_sd = {}
     layers = {}
     for k in list(state_dict.keys()):
+        if k == scaled_fp8_key:
+            continue
         if not k.startswith(model_prefix):
             out_sd[k] = state_dict[k]
             continue
@@ -1269,6 +1276,6 @@ def convert_old_quants(state_dict, model_prefix="", metadata={}):
     if quant_metadata is not None:
         layers = quant_metadata["layers"]
         for k, v in layers.items():
-            state_dict["{}.comfy_quant".format(k)] = torch.frombuffer(json.dumps(v).encode('utf-8'), dtype=torch.uint8)
+            state_dict["{}.comfy_quant".format(k)] = torch.tensor(list(json.dumps(v).encode('utf-8')), dtype=torch.uint8)

     return state_dict, metadata
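
Note: the last hunk swaps torch.frombuffer for torch.tensor(list(...)) when packing per-layer quant metadata into a uint8 tensor; frombuffer would share the read-only bytes object (which PyTorch flags as non-writable), while torch.tensor makes an owned copy. Either way the JSON survives a round trip; the metadata dict below is illustrative only:

import json
import torch

meta = {"format": "float8_e4m3fn", "group_size": 64}   # illustrative metadata
blob = json.dumps(meta).encode("utf-8")

# Owned copy (what the diff switches to); no shared, read-only buffer involved.
t = torch.tensor(list(blob), dtype=torch.uint8)

# Decoding recovers the original metadata exactly.
decoded = json.loads(bytes(t.tolist()).decode("utf-8"))
print(decoded == meta)  # True
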
@@ -5,12 +5,12 @@ This module handles capability negotiation between frontend and backend,
 allowing graceful protocol evolution while maintaining backward compatibility.
 """

-from typing import Any, Dict
+from typing import Any

 from comfy.cli_args import args

 # Default server capabilities
-SERVER_FEATURE_FLAGS: Dict[str, Any] = {
+SERVER_FEATURE_FLAGS: dict[str, Any] = {
     "supports_preview_metadata": True,
     "max_upload_size": args.max_upload_size * 1024 * 1024, # Convert MB to bytes
     "extension": {"manager": {"supports_v4": True}},
@@ -18,7 +18,7 @@ SERVER_FEATURE_FLAGS: Dict[str, Any] = {


 def get_connection_feature(
-    sockets_metadata: Dict[str, Dict[str, Any]],
+    sockets_metadata: dict[str, dict[str, Any]],
     sid: str,
     feature_name: str,
     default: Any = False
@@ -42,7 +42,7 @@ def get_connection_feature(


 def supports_feature(
-    sockets_metadata: Dict[str, Dict[str, Any]],
+    sockets_metadata: dict[str, dict[str, Any]],
     sid: str,
     feature_name: str
 ) -> bool:
@@ -60,7 +60,7 @@ def supports_feature(
     return get_connection_feature(sockets_metadata, sid, feature_name, False) is True


-def get_server_features() -> Dict[str, Any]:
+def get_server_features() -> dict[str, Any]:
     """
     Get the server's feature flags.

@@ -1,4 +1,4 @@
-from typing import Type, List, NamedTuple
+from typing import NamedTuple
 from comfy_api.internal.singleton import ProxiedSingleton
 from packaging import version as packaging_version

@@ -10,7 +10,7 @@ class ComfyAPIBase(ProxiedSingleton):

 class ComfyAPIWithVersion(NamedTuple):
     version: str
-    api_class: Type[ComfyAPIBase]
+    api_class: type[ComfyAPIBase]


 def parse_version(version_str: str) -> packaging_version.Version:
@@ -23,16 +23,16 @@ def parse_version(version_str: str) -> packaging_version.Version:
     return packaging_version.parse(version_str)


-registered_versions: List[ComfyAPIWithVersion] = []
+registered_versions: list[ComfyAPIWithVersion] = []


-def register_versions(versions: List[ComfyAPIWithVersion]):
+def register_versions(versions: list[ComfyAPIWithVersion]):
     versions.sort(key=lambda x: parse_version(x.version))
     global registered_versions
     registered_versions = versions


-def get_all_versions() -> List[ComfyAPIWithVersion]:
+def get_all_versions() -> list[ComfyAPIWithVersion]:
     """
     Returns a list of all registered ComfyAPI versions.
     """
@@ -8,7 +8,7 @@ import os
 import textwrap
 import threading
 from enum import Enum
-from typing import Optional, Type, get_origin, get_args, get_type_hints
+from typing import Optional, get_origin, get_args, get_type_hints


 class TypeTracker:
@@ -193,7 +193,7 @@ class AsyncToSyncConverter:
         return result_container["result"]

     @classmethod
-    def create_sync_class(cls, async_class: Type, thread_pool_size=10) -> Type:
+    def create_sync_class(cls, async_class: type, thread_pool_size=10) -> type:
         """
         Creates a new class with synchronous versions of all async methods.

@@ -563,7 +563,7 @@ class AsyncToSyncConverter:

     @classmethod
     def _generate_imports(
-        cls, async_class: Type, type_tracker: TypeTracker
+        cls, async_class: type, type_tracker: TypeTracker
     ) -> list[str]:
         """Generate import statements for the stub file."""
         imports = []
@@ -628,7 +628,7 @@ class AsyncToSyncConverter:
         return imports

     @classmethod
-    def _get_class_attributes(cls, async_class: Type) -> list[tuple[str, Type]]:
+    def _get_class_attributes(cls, async_class: type) -> list[tuple[str, type]]:
         """Extract class attributes that are classes themselves."""
         class_attributes = []

@@ -654,7 +654,7 @@ class AsyncToSyncConverter:
     def _generate_inner_class_stub(
         cls,
         name: str,
-        attr: Type,
+        attr: type,
         indent: str = " ",
         type_tracker: Optional[TypeTracker] = None,
     ) -> list[str]:
@@ -782,7 +782,7 @@ class AsyncToSyncConverter:
         return processed

     @classmethod
-    def generate_stub_file(cls, async_class: Type, sync_class: Type) -> None:
+    def generate_stub_file(cls, async_class: type, sync_class: type) -> None:
         """
         Generate a .pyi stub file for the sync class to help IDEs with type checking.
         """
@@ -988,7 +988,7 @@ class AsyncToSyncConverter:
             logging.error(traceback.format_exc())


-def create_sync_class(async_class: Type, thread_pool_size=10) -> Type:
+def create_sync_class(async_class: type, thread_pool_size=10) -> type:
    """
    Creates a sync version of an async class

@@ -1,4 +1,4 @@
-from typing import Type, TypeVar
+from typing import TypeVar

 class SingletonMetaclass(type):
     T = TypeVar("T", bound="SingletonMetaclass")
@@ -11,13 +11,13 @@ class SingletonMetaclass(type):
             )
         return cls._instances[cls]

-    def inject_instance(cls: Type[T], instance: T) -> None:
+    def inject_instance(cls: type[T], instance: T) -> None:
         assert cls not in SingletonMetaclass._instances, (
             "Cannot inject instance after first instantiation"
         )
         SingletonMetaclass._instances[cls] = instance

-    def get_instance(cls: Type[T], *args, **kwargs) -> T:
+    def get_instance(cls: type[T], *args, **kwargs) -> T:
         """
         Gets the singleton instance of the class, creating it if it doesn't exist.
         """
@@ -1,16 +1,15 @@
 from __future__ import annotations

 from abc import ABC, abstractmethod
-from typing import Type, TYPE_CHECKING
+from typing import TYPE_CHECKING
 from comfy_api.internal import ComfyAPIBase
 from comfy_api.internal.singleton import ProxiedSingleton
 from comfy_api.internal.async_to_sync import create_sync_class
-from comfy_api.latest._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput
+from ._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput
-from comfy_api.latest._input_impl import VideoFromFile, VideoFromComponents
+from ._input_impl import VideoFromFile, VideoFromComponents
-from comfy_api.latest._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL
+from ._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL
 from . import _io_public as io
 from . import _ui_public as ui
-# from comfy_api.latest._resources import _RESOURCES as resources #noqa: F401
 from comfy_execution.utils import get_executing_context
 from comfy_execution.progress import get_progress_state, PreviewImageTuple
 from PIL import Image
@@ -80,7 +79,7 @@ class ComfyExtension(ABC):
     async def on_load(self) -> None:
         """
         Called when an extension is loaded.
-        This should be used to initialize any global resources neeeded by the extension.
+        This should be used to initialize any global resources needed by the extension.
         """

     @abstractmethod
@@ -113,7 +112,7 @@ ComfyAPI = ComfyAPI_latest
 if TYPE_CHECKING:
     import comfy_api.latest.generated.ComfyAPISyncStub  # type: ignore

-    ComfyAPISync: Type[comfy_api.latest.generated.ComfyAPISyncStub.ComfyAPISyncStub]
+    ComfyAPISync: type[comfy_api.latest.generated.ComfyAPISyncStub.ComfyAPISyncStub]
 ComfyAPISync = create_sync_class(ComfyAPI_latest)

 # create new aliases for io and ui
@ -1,5 +1,5 @@
|
|||||||
import torch
|
import torch
|
||||||
from typing import TypedDict, List, Optional
|
from typing import TypedDict, Optional
|
||||||
|
|
||||||
ImageInput = torch.Tensor
|
ImageInput = torch.Tensor
|
||||||
"""
|
"""
|
||||||
@ -39,4 +39,4 @@ class LatentInput(TypedDict):
|
|||||||
Optional noise mask tensor in the same format as samples.
|
Optional noise mask tensor in the same format as samples.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
batch_index: Optional[List[int]]
|
batch_index: Optional[list[int]]
|
||||||
|
|||||||
@ -4,7 +4,7 @@ from fractions import Fraction
|
|||||||
from typing import Optional, Union, IO
|
from typing import Optional, Union, IO
|
||||||
import io
|
import io
|
||||||
import av
|
import av
|
||||||
from comfy_api.util import VideoContainer, VideoCodec, VideoComponents
|
from .._util import VideoContainer, VideoCodec, VideoComponents
|
||||||
|
|
||||||
class VideoInput(ABC):
|
class VideoInput(ABC):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -3,14 +3,14 @@ from av.container import InputContainer
|
|||||||
from av.subtitles.stream import SubtitleStream
|
from av.subtitles.stream import SubtitleStream
|
||||||
from fractions import Fraction
|
from fractions import Fraction
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from comfy_api.latest._input import AudioInput, VideoInput
|
from .._input import AudioInput, VideoInput
|
||||||
import av
|
import av
|
||||||
import io
|
import io
|
||||||
import json
|
import json
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import math
|
import math
|
||||||
import torch
|
import torch
|
||||||
from comfy_api.latest._util import VideoContainer, VideoCodec, VideoComponents
|
from .._util import VideoContainer, VideoCodec, VideoComponents
|
||||||
|
|
||||||
|
|
||||||
def container_to_output_format(container_format: str | None) -> str | None:
|
def container_to_output_format(container_format: str | None) -> str | None:
|
||||||
|
|||||||
@ -26,11 +26,9 @@ if TYPE_CHECKING:
|
|||||||
from comfy_api.input import VideoInput
|
from comfy_api.input import VideoInput
|
||||||
from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classproperty, copy_class, first_real_override, is_class,
|
from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classproperty, copy_class, first_real_override, is_class,
|
||||||
prune_dict, shallow_clone_class)
|
prune_dict, shallow_clone_class)
|
||||||
from comfy_api.latest._resources import Resources, ResourcesLocal
|
|
||||||
from comfy_execution.graph_utils import ExecutionBlocker
|
from comfy_execution.graph_utils import ExecutionBlocker
|
||||||
from ._util import MESH, VOXEL
|
from ._util import MESH, VOXEL, SVG as _SVG
|
||||||
|
|
||||||
# from comfy_extras.nodes_images import SVG as SVG_ # NOTE: needs to be moved before can be imported due to circular reference
|
|
||||||
|
|
||||||
class FolderType(str, Enum):
|
class FolderType(str, Enum):
|
||||||
input = "input"
|
input = "input"
|
||||||
@ -77,16 +75,6 @@ class NumberDisplay(str, Enum):
|
|||||||
slider = "slider"
|
slider = "slider"
|
||||||
|
|
||||||
|
|
||||||
class _StringIOType(str):
|
|
||||||
def __ne__(self, value: object) -> bool:
|
|
||||||
if self == "*" or value == "*":
|
|
||||||
return False
|
|
||||||
if not isinstance(value, str):
|
|
||||||
return True
|
|
||||||
a = frozenset(self.split(","))
|
|
||||||
b = frozenset(value.split(","))
|
|
||||||
return not (b.issubset(a) or a.issubset(b))
|
|
||||||
|
|
||||||
class _ComfyType(ABC):
|
class _ComfyType(ABC):
|
||||||
Type = Any
|
Type = Any
|
||||||
io_type: str = None
|
io_type: str = None
|
||||||
@ -126,8 +114,7 @@ def comfytype(io_type: str, **kwargs):
|
|||||||
new_cls.__module__ = cls.__module__
|
new_cls.__module__ = cls.__module__
|
||||||
new_cls.__doc__ = cls.__doc__
|
new_cls.__doc__ = cls.__doc__
|
||||||
# assign ComfyType attributes, if needed
|
# assign ComfyType attributes, if needed
|
||||||
# NOTE: use __ne__ trick for io_type (see node_typing.IO.__ne__ for details)
|
new_cls.io_type = io_type
|
||||||
new_cls.io_type = _StringIOType(io_type)
|
|
||||||
if hasattr(new_cls, "Input") and new_cls.Input is not None:
|
if hasattr(new_cls, "Input") and new_cls.Input is not None:
|
||||||
new_cls.Input.Parent = new_cls
|
new_cls.Input.Parent = new_cls
|
||||||
if hasattr(new_cls, "Output") and new_cls.Output is not None:
|
if hasattr(new_cls, "Output") and new_cls.Output is not None:
|
||||||
@ -166,7 +153,7 @@ class Input(_IO_V3):
|
|||||||
'''
|
'''
|
||||||
Base class for a V3 Input.
|
Base class for a V3 Input.
|
||||||
'''
|
'''
|
||||||
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None):
|
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None, raw_link: bool=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.id = id
|
self.id = id
|
||||||
self.display_name = display_name
|
self.display_name = display_name
|
||||||
@ -174,6 +161,7 @@ class Input(_IO_V3):
|
|||||||
self.tooltip = tooltip
|
self.tooltip = tooltip
|
||||||
self.lazy = lazy
|
self.lazy = lazy
|
||||||
self.extra_dict = extra_dict if extra_dict is not None else {}
|
self.extra_dict = extra_dict if extra_dict is not None else {}
|
||||||
|
self.rawLink = raw_link
|
||||||
|
|
||||||
def as_dict(self):
|
def as_dict(self):
|
||||||
return prune_dict({
|
return prune_dict({
|
||||||
@ -181,10 +169,11 @@ class Input(_IO_V3):
|
|||||||
"optional": self.optional,
|
"optional": self.optional,
|
||||||
"tooltip": self.tooltip,
|
"tooltip": self.tooltip,
|
||||||
"lazy": self.lazy,
|
"lazy": self.lazy,
|
||||||
|
"rawLink": self.rawLink,
|
||||||
}) | prune_dict(self.extra_dict)
|
}) | prune_dict(self.extra_dict)
|
||||||
|
|
||||||
def get_io_type(self):
|
def get_io_type(self):
|
||||||
return _StringIOType(self.io_type)
|
return self.io_type
|
||||||
|
|
||||||
def get_all(self) -> list[Input]:
|
def get_all(self) -> list[Input]:
|
||||||
return [self]
|
return [self]
|
||||||
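A hedged sketch of the new raw_link flag; the widget subclasses further down (Boolean, Int, Float, String, Combo, ...) forward the same flag, and it is serialized under the frontend's camelCase key:

    inp = Input("samples", optional=True, raw_link=True)  # "samples" is an illustrative id
    assert inp.as_dict().get("rawLink") is True  # fields left at None are pruned from the dict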
@ -195,8 +184,8 @@ class WidgetInput(Input):
|
|||||||
'''
|
'''
|
||||||
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
|
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
|
||||||
default: Any=None,
|
default: Any=None,
|
||||||
socketless: bool=None, widget_type: str=None, force_input: bool=None, extra_dict=None):
|
socketless: bool=None, widget_type: str=None, force_input: bool=None, extra_dict=None, raw_link: bool=None):
|
||||||
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict)
|
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict, raw_link)
|
||||||
self.default = default
|
self.default = default
|
||||||
self.socketless = socketless
|
self.socketless = socketless
|
||||||
self.widget_type = widget_type
|
self.widget_type = widget_type
|
||||||
@ -218,13 +207,14 @@ class Output(_IO_V3):
|
|||||||
def __init__(self, id: str=None, display_name: str=None, tooltip: str=None,
|
def __init__(self, id: str=None, display_name: str=None, tooltip: str=None,
|
||||||
is_output_list=False):
|
is_output_list=False):
|
||||||
self.id = id
|
self.id = id
|
||||||
self.display_name = display_name
|
self.display_name = display_name if display_name else id
|
||||||
self.tooltip = tooltip
|
self.tooltip = tooltip
|
||||||
self.is_output_list = is_output_list
|
self.is_output_list = is_output_list
|
||||||
|
|
||||||
def as_dict(self):
|
def as_dict(self):
|
||||||
|
display_name = self.display_name if self.display_name else self.id
|
||||||
return prune_dict({
|
return prune_dict({
|
||||||
"display_name": self.display_name,
|
"display_name": display_name,
|
||||||
"tooltip": self.tooltip,
|
"tooltip": self.tooltip,
|
||||||
"is_output_list": self.is_output_list,
|
"is_output_list": self.is_output_list,
|
||||||
})
|
})
|
||||||
@ -252,8 +242,8 @@ class Boolean(ComfyTypeIO):
|
|||||||
'''Boolean input.'''
|
'''Boolean input.'''
|
||||||
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
|
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
|
||||||
default: bool=None, label_on: str=None, label_off: str=None,
|
default: bool=None, label_on: str=None, label_off: str=None,
|
||||||
socketless: bool=None, force_input: bool=None):
|
socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None):
|
||||||
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input)
|
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link)
|
||||||
self.label_on = label_on
|
self.label_on = label_on
|
||||||
self.label_off = label_off
|
self.label_off = label_off
|
||||||
self.default: bool
|
self.default: bool
|
||||||
@ -272,8 +262,8 @@ class Int(ComfyTypeIO):
|
|||||||
'''Integer input.'''
|
'''Integer input.'''
|
||||||
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
|
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
|
||||||
default: int=None, min: int=None, max: int=None, step: int=None, control_after_generate: bool=None,
|
default: int=None, min: int=None, max: int=None, step: int=None, control_after_generate: bool=None,
|
||||||
display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None):
|
display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None):
|
||||||
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input)
|
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link)
|
||||||
self.min = min
|
self.min = min
|
||||||
self.max = max
|
self.max = max
|
||||||
self.step = step
|
self.step = step
|
||||||
@ -298,8 +288,8 @@ class Float(ComfyTypeIO):
|
|||||||
'''Float input.'''
|
'''Float input.'''
|
||||||
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
|
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
|
||||||
default: float=None, min: float=None, max: float=None, step: float=None, round: float=None,
|
default: float=None, min: float=None, max: float=None, step: float=None, round: float=None,
|
||||||
display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None):
|
display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None):
|
||||||
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input)
|
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link)
|
||||||
self.min = min
|
self.min = min
|
||||||
self.max = max
|
self.max = max
|
||||||
self.step = step
|
self.step = step
|
||||||
@ -324,8 +314,8 @@ class String(ComfyTypeIO):
|
|||||||
'''String input.'''
|
'''String input.'''
|
||||||
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
|
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
|
||||||
multiline=False, placeholder: str=None, default: str=None, dynamic_prompts: bool=None,
|
multiline=False, placeholder: str=None, default: str=None, dynamic_prompts: bool=None,
|
||||||
socketless: bool=None, force_input: bool=None):
|
socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None):
|
||||||
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input)
|
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link)
|
||||||
self.multiline = multiline
|
self.multiline = multiline
|
||||||
self.placeholder = placeholder
|
self.placeholder = placeholder
|
||||||
self.dynamic_prompts = dynamic_prompts
|
self.dynamic_prompts = dynamic_prompts
|
||||||
@ -358,12 +348,14 @@ class Combo(ComfyTypeIO):
|
|||||||
image_folder: FolderType=None,
|
image_folder: FolderType=None,
|
||||||
remote: RemoteOptions=None,
|
remote: RemoteOptions=None,
|
||||||
socketless: bool=None,
|
socketless: bool=None,
|
||||||
|
extra_dict=None,
|
||||||
|
raw_link: bool=None,
|
||||||
):
|
):
|
||||||
if isinstance(options, type) and issubclass(options, Enum):
|
if isinstance(options, type) and issubclass(options, Enum):
|
||||||
options = [v.value for v in options]
|
options = [v.value for v in options]
|
||||||
if isinstance(default, Enum):
|
if isinstance(default, Enum):
|
||||||
default = default.value
|
default = default.value
|
||||||
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless)
|
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, None, extra_dict, raw_link)
|
||||||
self.multiselect = False
|
self.multiselect = False
|
||||||
self.options = options
|
self.options = options
|
||||||
self.control_after_generate = control_after_generate
|
self.control_after_generate = control_after_generate
|
||||||
@ -387,10 +379,6 @@ class Combo(ComfyTypeIO):
|
|||||||
super().__init__(id, display_name, tooltip, is_output_list)
|
super().__init__(id, display_name, tooltip, is_output_list)
|
||||||
self.options = options if options is not None else []
|
self.options = options if options is not None else []
|
||||||
|
|
||||||
@property
|
|
||||||
def io_type(self):
|
|
||||||
return self.options
|
|
||||||
|
|
||||||
@comfytype(io_type="COMBO")
|
@comfytype(io_type="COMBO")
|
||||||
class MultiCombo(ComfyTypeI):
|
class MultiCombo(ComfyTypeI):
|
||||||
'''Multiselect Combo input (dropdown for selecting potentially more than one value).'''
|
'''Multiselect Combo input (dropdown for selecting potentially more than one value).'''
|
||||||
@ -399,8 +387,8 @@ class MultiCombo(ComfyTypeI):
|
|||||||
class Input(Combo.Input):
|
class Input(Combo.Input):
|
||||||
def __init__(self, id: str, options: list[str], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
|
def __init__(self, id: str, options: list[str], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
|
||||||
default: list[str]=None, placeholder: str=None, chip: bool=None, control_after_generate: bool=None,
|
default: list[str]=None, placeholder: str=None, chip: bool=None, control_after_generate: bool=None,
|
||||||
socketless: bool=None):
|
socketless: bool=None, extra_dict=None, raw_link: bool=None):
|
||||||
super().__init__(id, options, display_name, optional, tooltip, lazy, default, control_after_generate, socketless=socketless)
|
super().__init__(id, options, display_name, optional, tooltip, lazy, default, control_after_generate, socketless=socketless, extra_dict=extra_dict, raw_link=raw_link)
|
||||||
self.multiselect = True
|
self.multiselect = True
|
||||||
self.placeholder = placeholder
|
self.placeholder = placeholder
|
||||||
self.chip = chip
|
self.chip = chip
|
||||||
@ -433,9 +421,9 @@ class Webcam(ComfyTypeIO):
|
|||||||
Type = str
|
Type = str
|
||||||
def __init__(
|
def __init__(
|
||||||
self, id: str, display_name: str=None, optional=False,
|
self, id: str, display_name: str=None, optional=False,
|
||||||
tooltip: str=None, lazy: bool=None, default: str=None, socketless: bool=None
|
tooltip: str=None, lazy: bool=None, default: str=None, socketless: bool=None, extra_dict=None, raw_link: bool=None
|
||||||
):
|
):
|
||||||
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless)
|
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, None, extra_dict, raw_link)
|
||||||
|
|
||||||
|
|
||||||
@comfytype(io_type="MASK")
|
@comfytype(io_type="MASK")
|
||||||
@ -656,7 +644,7 @@ class Video(ComfyTypeIO):
|
|||||||
|
|
||||||
@comfytype(io_type="SVG")
|
@comfytype(io_type="SVG")
|
||||||
class SVG(ComfyTypeIO):
|
class SVG(ComfyTypeIO):
|
||||||
Type = Any # TODO: SVG class is defined in comfy_extras/nodes_images.py, causing circular reference; should be moved to somewhere else before referenced directly in v3
|
Type = _SVG
|
||||||
|
|
||||||
@comfytype(io_type="LORA_MODEL")
|
@comfytype(io_type="LORA_MODEL")
|
||||||
class LoraModel(ComfyTypeIO):
|
class LoraModel(ComfyTypeIO):
|
||||||
@ -774,6 +762,13 @@ class AudioEncoder(ComfyTypeIO):
|
|||||||
class AudioEncoderOutput(ComfyTypeIO):
|
class AudioEncoderOutput(ComfyTypeIO):
|
||||||
Type = Any
|
Type = Any
|
||||||
|
|
||||||
|
@comfytype(io_type="TRACKS")
|
||||||
|
class Tracks(ComfyTypeIO):
|
||||||
|
class TrackDict(TypedDict):
|
||||||
|
track_path: torch.Tensor
|
||||||
|
track_visibility: torch.Tensor
|
||||||
|
Type = TrackDict
|
||||||
|
|
||||||
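A brief sketch of the new Tracks type from a node's perspective; tensor shapes here are assumptions, not specified by this diff:

    import torch
    from comfy_api.latest import io  # the v3 io module re-exported by __init__

    tracks: io.Tracks.Type = {
        "track_path": torch.zeros(16, 2),    # assumed shape: per-point coordinates
        "track_visibility": torch.ones(16),  # assumed shape: per-point visibility
    }
    # Wired into a schema like any other ComfyTypeIO:
    #   inputs=[io.Tracks.Input("tracks")], outputs=[io.Tracks.Output()]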
@comfytype(io_type="COMFY_MULTITYPED_V3")
|
@comfytype(io_type="COMFY_MULTITYPED_V3")
|
||||||
class MultiType:
|
class MultiType:
|
||||||
Type = Any
|
Type = Any
|
||||||
@ -781,7 +776,7 @@ class MultiType:
|
|||||||
'''
|
'''
|
||||||
Input that permits more than one input type; if `id` is an instance of `ComfyType.Input`, then that input will be used to create a widget (if applicable) with overridden values.
|
Input that permits more than one input type; if `id` is an instance of `ComfyType.Input`, then that input will be used to create a widget (if applicable) with overridden values.
|
||||||
'''
|
'''
|
||||||
def __init__(self, id: str | Input, types: list[type[_ComfyType] | _ComfyType], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None):
|
def __init__(self, id: str | Input, types: list[type[_ComfyType] | _ComfyType], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None, raw_link: bool=None):
|
||||||
# if id is an Input, then use that Input with overridden values
|
# if id is an Input, then use that Input with overridden values
|
||||||
self.input_override = None
|
self.input_override = None
|
||||||
if isinstance(id, Input):
|
if isinstance(id, Input):
|
||||||
@ -794,7 +789,7 @@ class MultiType:
|
|||||||
# if is a widget input, make sure widget_type is set appropriately
|
# if is a widget input, make sure widget_type is set appropriately
|
||||||
if isinstance(self.input_override, WidgetInput):
|
if isinstance(self.input_override, WidgetInput):
|
||||||
self.input_override.widget_type = self.input_override.get_io_type()
|
self.input_override.widget_type = self.input_override.get_io_type()
|
||||||
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict)
|
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict, raw_link)
|
||||||
self._io_types = types
|
self._io_types = types
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -848,8 +843,8 @@ class MatchType(ComfyTypeIO):
|
|||||||
|
|
||||||
class Input(Input):
|
class Input(Input):
|
||||||
def __init__(self, id: str, template: MatchType.Template,
|
def __init__(self, id: str, template: MatchType.Template,
|
||||||
display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None):
|
display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None, raw_link: bool=None):
|
||||||
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict)
|
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict, raw_link)
|
||||||
self.template = template
|
self.template = template
|
||||||
|
|
||||||
def as_dict(self):
|
def as_dict(self):
|
||||||
@ -860,6 +855,8 @@ class MatchType(ComfyTypeIO):
|
|||||||
class Output(Output):
|
class Output(Output):
|
||||||
def __init__(self, template: MatchType.Template, id: str=None, display_name: str=None, tooltip: str=None,
|
def __init__(self, template: MatchType.Template, id: str=None, display_name: str=None, tooltip: str=None,
|
||||||
is_output_list=False):
|
is_output_list=False):
|
||||||
|
if not id and not display_name:
|
||||||
|
display_name = "MATCHTYPE"
|
||||||
super().__init__(id, display_name, tooltip, is_output_list)
|
super().__init__(id, display_name, tooltip, is_output_list)
|
||||||
self.template = template
|
self.template = template
|
||||||
|
|
||||||
@ -872,24 +869,30 @@ class DynamicInput(Input, ABC):
|
|||||||
'''
|
'''
|
||||||
Abstract class for dynamic input registration.
|
Abstract class for dynamic input registration.
|
||||||
'''
|
'''
|
||||||
def get_dynamic(self) -> list[Input]:
|
pass
|
||||||
return []
|
|
||||||
|
|
||||||
def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class DynamicOutput(Output, ABC):
|
class DynamicOutput(Output, ABC):
|
||||||
'''
|
'''
|
||||||
Abstract class for dynamic output registration.
|
Abstract class for dynamic output registration.
|
||||||
'''
|
'''
|
||||||
def __init__(self, id: str=None, display_name: str=None, tooltip: str=None,
|
pass
|
||||||
is_output_list=False):
|
|
||||||
super().__init__(id, display_name, tooltip, is_output_list)
|
|
||||||
|
|
||||||
def get_dynamic(self) -> list[Output]:
|
|
||||||
return []
|
|
||||||
|
|
||||||
|
def handle_prefix(prefix_list: list[str] | None, id: str | None = None) -> list[str]:
|
||||||
|
if prefix_list is None:
|
||||||
|
prefix_list = []
|
||||||
|
if id is not None:
|
||||||
|
prefix_list = prefix_list + [id]
|
||||||
|
return prefix_list
|
||||||
|
|
||||||
|
def finalize_prefix(prefix_list: list[str] | None, id: str | None = None) -> str:
|
||||||
|
assert not (prefix_list is None and id is None)
|
||||||
|
if prefix_list is None:
|
||||||
|
return id
|
||||||
|
elif id is not None:
|
||||||
|
prefix_list = prefix_list + [id]
|
||||||
|
return ".".join(prefix_list)
|
||||||
|
|
||||||
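The two prefix helpers compose dotted ids for dynamic inputs; their behaviour follows directly from the definitions above:

    handle_prefix(None, "combo")              # -> ["combo"]
    handle_prefix(["combo"], "strength")      # -> ["combo", "strength"]
    finalize_prefix(["combo", "strength"])    # -> "combo.strength"
    finalize_prefix(["combo"], "strength")    # -> "combo.strength"
    finalize_prefix(None, "strength")         # -> "strength"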
@comfytype(io_type="COMFY_AUTOGROW_V3")
|
@comfytype(io_type="COMFY_AUTOGROW_V3")
|
||||||
class Autogrow(ComfyTypeI):
|
class Autogrow(ComfyTypeI):
|
||||||
@ -926,14 +929,6 @@ class Autogrow(ComfyTypeI):
|
|||||||
def validate(self):
|
def validate(self):
|
||||||
self.input.validate()
|
self.input.validate()
|
||||||
|
|
||||||
def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''):
|
|
||||||
real_inputs = []
|
|
||||||
for name, input in self.cached_inputs.items():
|
|
||||||
if name in live_inputs:
|
|
||||||
real_inputs.append(input)
|
|
||||||
add_to_input_dict_v1(d, real_inputs, live_inputs, curr_prefix)
|
|
||||||
add_dynamic_id_mapping(d, real_inputs, curr_prefix)
|
|
||||||
|
|
||||||
class TemplatePrefix(_AutogrowTemplate):
|
class TemplatePrefix(_AutogrowTemplate):
|
||||||
def __init__(self, input: Input, prefix: str, min: int=1, max: int=10):
|
def __init__(self, input: Input, prefix: str, min: int=1, max: int=10):
|
||||||
super().__init__(input)
|
super().__init__(input)
|
||||||
@ -978,22 +973,45 @@ class Autogrow(ComfyTypeI):
|
|||||||
"template": self.template.as_dict(),
|
"template": self.template.as_dict(),
|
||||||
})
|
})
|
||||||
|
|
||||||
def get_dynamic(self) -> list[Input]:
|
|
||||||
return self.template.get_all()
|
|
||||||
|
|
||||||
def get_all(self) -> list[Input]:
|
def get_all(self) -> list[Input]:
|
||||||
return [self] + self.template.get_all()
|
return [self] + self.template.get_all()
|
||||||
|
|
||||||
def validate(self):
|
def validate(self):
|
||||||
self.template.validate()
|
self.template.validate()
|
||||||
|
|
||||||
def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''):
|
@staticmethod
|
||||||
curr_prefix = f"{curr_prefix}{self.id}."
|
def _expand_schema_for_dynamic(out_dict: dict[str, Any], live_inputs: dict[str, Any], value: tuple[str, dict[str, Any]], input_type: str, curr_prefix: list[str] | None):
|
||||||
# need to remove self from expected inputs dictionary; replaced by template inputs in frontend
|
# NOTE: purposely do not include self in out_dict; instead use only the template inputs
|
||||||
for inner_dict in d.values():
|
# need to figure out names based on template type
|
||||||
if self.id in inner_dict:
|
is_names = ("names" in value[1]["template"])
|
||||||
del inner_dict[self.id]
|
is_prefix = ("prefix" in value[1]["template"])
|
||||||
self.template.expand_schema_for_dynamic(d, live_inputs, curr_prefix)
|
input = value[1]["template"]["input"]
|
||||||
|
if is_names:
|
||||||
|
min = value[1]["template"]["min"]
|
||||||
|
names = value[1]["template"]["names"]
|
||||||
|
max = len(names)
|
||||||
|
elif is_prefix:
|
||||||
|
prefix = value[1]["template"]["prefix"]
|
||||||
|
min = value[1]["template"]["min"]
|
||||||
|
max = value[1]["template"]["max"]
|
||||||
|
names = [f"{prefix}{i}" for i in range(max)]
|
||||||
|
# need to create a new input based on the contents of input
|
||||||
|
template_input = None
|
||||||
|
for _, dict_input in input.items():
|
||||||
|
# for now, get just the first value from dict_input
|
||||||
|
template_input = list(dict_input.values())[0]
|
||||||
|
new_dict = {}
|
||||||
|
for i, name in enumerate(names):
|
||||||
|
expected_id = finalize_prefix(curr_prefix, name)
|
||||||
|
if expected_id in live_inputs:
|
||||||
|
# required
|
||||||
|
if i < min:
|
||||||
|
type_dict = new_dict.setdefault("required", {})
|
||||||
|
# optional
|
||||||
|
else:
|
||||||
|
type_dict = new_dict.setdefault("optional", {})
|
||||||
|
type_dict[name] = template_input
|
||||||
|
parse_class_inputs(out_dict, live_inputs, new_dict, curr_prefix)
|
||||||
|
|
||||||
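A hedged sketch of the serialized value this static method consumes for a prefix-style template; the nested "input" layout is inferred from the lookups above and may differ in detail:

    value = ("COMFY_AUTOGROW_V3", {
        "template": {
            "input": {"required": {"image": ("IMAGE", {})}},  # assumed serialized shape
            "prefix": "image",
            "min": 1,
            "max": 3,
        },
    })
    # names become ["image0", "image1", "image2"]; those present in live_inputs are re-emitted
    # as required (index < min) or optional inputs and fed back through parse_class_inputs.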
@comfytype(io_type="COMFY_DYNAMICCOMBO_V3")
|
@comfytype(io_type="COMFY_DYNAMICCOMBO_V3")
|
||||||
class DynamicCombo(ComfyTypeI):
|
class DynamicCombo(ComfyTypeI):
|
||||||
@ -1016,23 +1034,6 @@ class DynamicCombo(ComfyTypeI):
|
|||||||
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict)
|
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict)
|
||||||
self.options = options
|
self.options = options
|
||||||
|
|
||||||
def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''):
|
|
||||||
# check if dynamic input's id is in live_inputs
|
|
||||||
if self.id in live_inputs:
|
|
||||||
curr_prefix = f"{curr_prefix}{self.id}."
|
|
||||||
key = live_inputs[self.id]
|
|
||||||
selected_option = None
|
|
||||||
for option in self.options:
|
|
||||||
if option.key == key:
|
|
||||||
selected_option = option
|
|
||||||
break
|
|
||||||
if selected_option is not None:
|
|
||||||
add_to_input_dict_v1(d, selected_option.inputs, live_inputs, curr_prefix)
|
|
||||||
add_dynamic_id_mapping(d, selected_option.inputs, curr_prefix, self)
|
|
||||||
|
|
||||||
def get_dynamic(self) -> list[Input]:
|
|
||||||
return [input for option in self.options for input in option.inputs]
|
|
||||||
|
|
||||||
def get_all(self) -> list[Input]:
|
def get_all(self) -> list[Input]:
|
||||||
return [self] + [input for option in self.options for input in option.inputs]
|
return [self] + [input for option in self.options for input in option.inputs]
|
||||||
|
|
||||||
@ -1047,6 +1048,24 @@ class DynamicCombo(ComfyTypeI):
|
|||||||
for input in option.inputs:
|
for input in option.inputs:
|
||||||
input.validate()
|
input.validate()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _expand_schema_for_dynamic(out_dict: dict[str, Any], live_inputs: dict[str, Any], value: tuple[str, dict[str, Any]], input_type: str, curr_prefix: list[str] | None):
|
||||||
|
finalized_id = finalize_prefix(curr_prefix)
|
||||||
|
if finalized_id in live_inputs:
|
||||||
|
key = live_inputs[finalized_id]
|
||||||
|
selected_option = None
|
||||||
|
# get options from dict
|
||||||
|
options: list[dict[str, str | dict[str, Any]]] = value[1]["options"]
|
||||||
|
for option in options:
|
||||||
|
if option["key"] == key:
|
||||||
|
selected_option = option
|
||||||
|
break
|
||||||
|
if selected_option is not None:
|
||||||
|
parse_class_inputs(out_dict, live_inputs, selected_option["inputs"], curr_prefix)
|
||||||
|
# add self to inputs
|
||||||
|
out_dict[input_type][finalized_id] = value
|
||||||
|
out_dict["dynamic_paths"][finalized_id] = finalize_prefix(curr_prefix, curr_prefix[-1])
|
||||||
|
|
||||||
@comfytype(io_type="COMFY_DYNAMICSLOT_V3")
|
@comfytype(io_type="COMFY_DYNAMICSLOT_V3")
|
||||||
class DynamicSlot(ComfyTypeI):
|
class DynamicSlot(ComfyTypeI):
|
||||||
Type = dict[str, Any]
|
Type = dict[str, Any]
|
||||||
@ -1069,17 +1088,8 @@ class DynamicSlot(ComfyTypeI):
|
|||||||
self.force_input = True
|
self.force_input = True
|
||||||
self.slot.force_input = True
|
self.slot.force_input = True
|
||||||
|
|
||||||
def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''):
|
|
||||||
if self.id in live_inputs:
|
|
||||||
curr_prefix = f"{curr_prefix}{self.id}."
|
|
||||||
add_to_input_dict_v1(d, self.inputs, live_inputs, curr_prefix)
|
|
||||||
add_dynamic_id_mapping(d, [self.slot] + self.inputs, curr_prefix)
|
|
||||||
|
|
||||||
def get_dynamic(self) -> list[Input]:
|
|
||||||
return [self.slot] + self.inputs
|
|
||||||
|
|
||||||
def get_all(self) -> list[Input]:
|
def get_all(self) -> list[Input]:
|
||||||
return [self] + [self.slot] + self.inputs
|
return [self.slot] + self.inputs
|
||||||
|
|
||||||
def as_dict(self):
|
def as_dict(self):
|
||||||
return super().as_dict() | prune_dict({
|
return super().as_dict() | prune_dict({
|
||||||
@ -1093,17 +1103,41 @@ class DynamicSlot(ComfyTypeI):
|
|||||||
for input in self.inputs:
|
for input in self.inputs:
|
||||||
input.validate()
|
input.validate()
|
||||||
|
|
||||||
def add_dynamic_id_mapping(d: dict[str, Any], inputs: list[Input], curr_prefix: str, self: DynamicInput=None):
|
@staticmethod
|
||||||
dynamic = d.setdefault("dynamic_paths", {})
|
def _expand_schema_for_dynamic(out_dict: dict[str, Any], live_inputs: dict[str, Any], value: tuple[str, dict[str, Any]], input_type: str, curr_prefix: list[str] | None):
|
||||||
if self is not None:
|
finalized_id = finalize_prefix(curr_prefix)
|
||||||
dynamic[self.id] = f"{curr_prefix}{self.id}"
|
if finalized_id in live_inputs:
|
||||||
for i in inputs:
|
inputs = value[1]["inputs"]
|
||||||
if not isinstance(i, DynamicInput):
|
parse_class_inputs(out_dict, live_inputs, inputs, curr_prefix)
|
||||||
dynamic[f"{i.id}"] = f"{curr_prefix}{i.id}"
|
# add self to inputs
|
||||||
|
out_dict[input_type][finalized_id] = value
|
||||||
|
out_dict["dynamic_paths"][finalized_id] = finalize_prefix(curr_prefix, curr_prefix[-1])
|
||||||
|
|
||||||
|
DYNAMIC_INPUT_LOOKUP: dict[str, Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]] = {}
|
||||||
|
def register_dynamic_input_func(io_type: str, func: Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]):
|
||||||
|
DYNAMIC_INPUT_LOOKUP[io_type] = func
|
||||||
|
|
||||||
|
def get_dynamic_input_func(io_type: str) -> Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]:
|
||||||
|
return DYNAMIC_INPUT_LOOKUP[io_type]
|
||||||
|
|
||||||
|
def setup_dynamic_input_funcs():
|
||||||
|
# Autogrow.Input
|
||||||
|
register_dynamic_input_func(Autogrow.io_type, Autogrow._expand_schema_for_dynamic)
|
||||||
|
# DynamicCombo.Input
|
||||||
|
register_dynamic_input_func(DynamicCombo.io_type, DynamicCombo._expand_schema_for_dynamic)
|
||||||
|
# DynamicSlot.Input
|
||||||
|
register_dynamic_input_func(DynamicSlot.io_type, DynamicSlot._expand_schema_for_dynamic)
|
||||||
|
|
||||||
|
if len(DYNAMIC_INPUT_LOOKUP) == 0:
|
||||||
|
setup_dynamic_input_funcs()
|
||||||
|
|
||||||
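A sketch of how another dynamic type could hook into the lookup table; the io_type and handler below are hypothetical, only the registration API comes from this diff:

    def _expand_my_dynamic(out_dict: dict, live_inputs: dict, value: tuple, input_type: str, curr_prefix: list[str] | None) -> None:
        # hypothetical handler: expose the input unchanged under its finalized id
        finalized_id = finalize_prefix(curr_prefix)
        out_dict[input_type][finalized_id] = value
        out_dict["dynamic_paths"][finalized_id] = finalized_id

    register_dynamic_input_func("COMFY_MYDYNAMIC_V3", _expand_my_dynamic)
    # parse_class_inputs will then route any input whose io_type matches through this handler.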
class V3Data(TypedDict):
|
class V3Data(TypedDict):
|
||||||
hidden_inputs: dict[str, Any]
|
hidden_inputs: dict[str, Any]
|
||||||
|
'Dictionary where the keys are the hidden input ids and the values are the values of the hidden inputs.'
|
||||||
dynamic_paths: dict[str, Any]
|
dynamic_paths: dict[str, Any]
|
||||||
|
'Dictionary where the keys are the input ids and the values dictate how to turn the inputs into a nested dictionary.'
|
||||||
|
create_dynamic_tuple: bool
|
||||||
|
'When True, the value of the dynamic input will be in the format (value, path_key).'
|
||||||
|
|
||||||
class HiddenHolder:
|
class HiddenHolder:
|
||||||
def __init__(self, unique_id: str, prompt: Any,
|
def __init__(self, unique_id: str, prompt: Any,
|
||||||
@ -1139,6 +1173,10 @@ class HiddenHolder:
|
|||||||
api_key_comfy_org=d.get(Hidden.api_key_comfy_org, None),
|
api_key_comfy_org=d.get(Hidden.api_key_comfy_org, None),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_v3_data(cls, v3_data: V3Data | None) -> HiddenHolder:
|
||||||
|
return cls.from_dict(v3_data["hidden_inputs"] if v3_data else None)
|
||||||
|
|
||||||
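A one-line illustration of the new helper; it defers to from_dict and tolerates a missing payload:

    HiddenHolder.from_v3_data(None)  # every hidden field is expected to come back as None
    HiddenHolder.from_v3_data({"hidden_inputs": {Hidden.unique_id: "42"}, "dynamic_paths": {}})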
class Hidden(str, Enum):
|
class Hidden(str, Enum):
|
||||||
'''
|
'''
|
||||||
Enumerator for requesting hidden variables in nodes.
|
Enumerator for requesting hidden variables in nodes.
|
||||||
@ -1244,61 +1282,56 @@ class Schema:
|
|||||||
- verify ids on inputs and outputs are unique - both internally and in relation to each other
|
- verify ids on inputs and outputs are unique - both internally and in relation to each other
|
||||||
'''
|
'''
|
||||||
nested_inputs: list[Input] = []
|
nested_inputs: list[Input] = []
|
||||||
if self.inputs is not None:
|
for input in self.inputs:
|
||||||
for input in self.inputs:
|
if not isinstance(input, DynamicInput):
|
||||||
nested_inputs.extend(input.get_all())
|
nested_inputs.extend(input.get_all())
|
||||||
input_ids = [i.id for i in nested_inputs] if nested_inputs is not None else []
|
input_ids = [i.id for i in nested_inputs]
|
||||||
output_ids = [o.id for o in self.outputs] if self.outputs is not None else []
|
output_ids = [o.id for o in self.outputs]
|
||||||
input_set = set(input_ids)
|
input_set = set(input_ids)
|
||||||
output_set = set(output_ids)
|
output_set = set(output_ids)
|
||||||
issues = []
|
issues: list[str] = []
|
||||||
# verify ids are unique per list
|
# verify ids are unique per list
|
||||||
if len(input_set) != len(input_ids):
|
if len(input_set) != len(input_ids):
|
||||||
issues.append(f"Input ids must be unique, but {[item for item, count in Counter(input_ids).items() if count > 1]} are not.")
|
issues.append(f"Input ids must be unique, but {[item for item, count in Counter(input_ids).items() if count > 1]} are not.")
|
||||||
if len(output_set) != len(output_ids):
|
if len(output_set) != len(output_ids):
|
||||||
issues.append(f"Output ids must be unique, but {[item for item, count in Counter(output_ids).items() if count > 1]} are not.")
|
issues.append(f"Output ids must be unique, but {[item for item, count in Counter(output_ids).items() if count > 1]} are not.")
|
||||||
# verify ids are unique between lists
|
|
||||||
intersection = input_set & output_set
|
|
||||||
if len(intersection) > 0:
|
|
||||||
issues.append(f"Ids must be unique between inputs and outputs, but {intersection} are not.")
|
|
||||||
if len(issues) > 0:
|
if len(issues) > 0:
|
||||||
raise ValueError("\n".join(issues))
|
raise ValueError("\n".join(issues))
|
||||||
# validate inputs and outputs
|
# validate inputs and outputs
|
||||||
if self.inputs is not None:
|
for input in self.inputs:
|
||||||
for input in self.inputs:
|
input.validate()
|
||||||
input.validate()
|
for output in self.outputs:
|
||||||
if self.outputs is not None:
|
output.validate()
|
||||||
for output in self.outputs:
|
|
||||||
output.validate()
|
|
||||||
|
|
||||||
def finalize(self):
|
def finalize(self):
|
||||||
"""Add hidden based on selected schema options, and give outputs without ids default ids."""
|
"""Add hidden based on selected schema options, and give outputs without ids default ids."""
|
||||||
|
# ensure inputs, outputs, and hidden are lists
|
||||||
|
if self.inputs is None:
|
||||||
|
self.inputs = []
|
||||||
|
if self.outputs is None:
|
||||||
|
self.outputs = []
|
||||||
|
if self.hidden is None:
|
||||||
|
self.hidden = []
|
||||||
# if is an api_node, will need key-related hidden
|
# if is an api_node, will need key-related hidden
|
||||||
if self.is_api_node:
|
if self.is_api_node:
|
||||||
if self.hidden is None:
|
|
||||||
self.hidden = []
|
|
||||||
if Hidden.auth_token_comfy_org not in self.hidden:
|
if Hidden.auth_token_comfy_org not in self.hidden:
|
||||||
self.hidden.append(Hidden.auth_token_comfy_org)
|
self.hidden.append(Hidden.auth_token_comfy_org)
|
||||||
if Hidden.api_key_comfy_org not in self.hidden:
|
if Hidden.api_key_comfy_org not in self.hidden:
|
||||||
self.hidden.append(Hidden.api_key_comfy_org)
|
self.hidden.append(Hidden.api_key_comfy_org)
|
||||||
# if is an output_node, will need prompt and extra_pnginfo
|
# if is an output_node, will need prompt and extra_pnginfo
|
||||||
if self.is_output_node:
|
if self.is_output_node:
|
||||||
if self.hidden is None:
|
|
||||||
self.hidden = []
|
|
||||||
if Hidden.prompt not in self.hidden:
|
if Hidden.prompt not in self.hidden:
|
||||||
self.hidden.append(Hidden.prompt)
|
self.hidden.append(Hidden.prompt)
|
||||||
if Hidden.extra_pnginfo not in self.hidden:
|
if Hidden.extra_pnginfo not in self.hidden:
|
||||||
self.hidden.append(Hidden.extra_pnginfo)
|
self.hidden.append(Hidden.extra_pnginfo)
|
||||||
# give outputs without ids default ids
|
# give outputs without ids default ids
|
||||||
if self.outputs is not None:
|
for i, output in enumerate(self.outputs):
|
||||||
for i, output in enumerate(self.outputs):
|
if output.id is None:
|
||||||
if output.id is None:
|
output.id = f"_{i}_{output.io_type}_"
|
||||||
output.id = f"_{i}_{output.io_type}_"
|
|
||||||
|
|
||||||
def get_v1_info(self, cls, live_inputs: dict[str, Any]=None) -> NodeInfoV1:
|
def get_v1_info(self, cls) -> NodeInfoV1:
|
||||||
# NOTE: live_inputs will not be used anymore very soon and this will be done another way
|
|
||||||
# get V1 inputs
|
# get V1 inputs
|
||||||
input = create_input_dict_v1(self.inputs, live_inputs)
|
input = create_input_dict_v1(self.inputs)
|
||||||
if self.hidden:
|
if self.hidden:
|
||||||
for hidden in self.hidden:
|
for hidden in self.hidden:
|
||||||
input.setdefault("hidden", {})[hidden.name] = (hidden.value,)
|
input.setdefault("hidden", {})[hidden.name] = (hidden.value,)
|
||||||
@ -1378,33 +1411,54 @@ class Schema:
|
|||||||
)
|
)
|
||||||
return info
|
return info
|
||||||
|
|
||||||
|
def get_finalized_class_inputs(d: dict[str, Any], live_inputs: dict[str, Any], include_hidden=False) -> tuple[dict[str, Any], V3Data]:
|
||||||
|
out_dict = {
|
||||||
|
"required": {},
|
||||||
|
"optional": {},
|
||||||
|
"dynamic_paths": {},
|
||||||
|
}
|
||||||
|
d = d.copy()
|
||||||
|
# ignore hidden for parsing
|
||||||
|
hidden = d.pop("hidden", None)
|
||||||
|
parse_class_inputs(out_dict, live_inputs, d)
|
||||||
|
if hidden is not None and include_hidden:
|
||||||
|
out_dict["hidden"] = hidden
|
||||||
|
v3_data = {}
|
||||||
|
dynamic_paths = out_dict.pop("dynamic_paths", None)
|
||||||
|
if dynamic_paths is not None:
|
||||||
|
v3_data["dynamic_paths"] = dynamic_paths
|
||||||
|
return out_dict, hidden, v3_data
|
||||||
|
|
||||||
def create_input_dict_v1(inputs: list[Input], live_inputs: dict[str, Any]=None) -> dict:
|
def parse_class_inputs(out_dict: dict[str, Any], live_inputs: dict[str, Any], curr_dict: dict[str, Any], curr_prefix: list[str] | None=None) -> None:
|
||||||
|
for input_type, inner_d in curr_dict.items():
|
||||||
|
for id, value in inner_d.items():
|
||||||
|
io_type = value[0]
|
||||||
|
if io_type in DYNAMIC_INPUT_LOOKUP:
|
||||||
|
# dynamic inputs need to be handled with lookup functions
|
||||||
|
dynamic_input_func = get_dynamic_input_func(io_type)
|
||||||
|
new_prefix = handle_prefix(curr_prefix, id)
|
||||||
|
dynamic_input_func(out_dict, live_inputs, value, input_type, new_prefix)
|
||||||
|
else:
|
||||||
|
# non-dynamic inputs get directly transferred
|
||||||
|
finalized_id = finalize_prefix(curr_prefix, id)
|
||||||
|
out_dict[input_type][finalized_id] = value
|
||||||
|
if curr_prefix:
|
||||||
|
out_dict["dynamic_paths"][finalized_id] = finalized_id
|
||||||
|
|
||||||
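A hedged sketch of how the new entry point is meant to be driven; the node class and live values are hypothetical, and note that as written above the function returns out_dict, hidden, v3_data despite its two-element annotation:

    d = SomeV3Node.INPUT_TYPES()  # hypothetical node class; V1-style dict of (io_type, options) tuples
    live = {"opts": "fast", "opts.strength": 0.5}
    finalized, hidden, v3_data = get_finalized_class_inputs(d, live, include_hidden=True)
    # finalized -> {"required": {...}, "optional": {...}} (+ "hidden" when requested)
    # v3_data   -> {"dynamic_paths": {...}}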
|
def create_input_dict_v1(inputs: list[Input]) -> dict:
|
||||||
input = {
|
input = {
|
||||||
"required": {}
|
"required": {}
|
||||||
}
|
}
|
||||||
add_to_input_dict_v1(input, inputs, live_inputs)
|
for i in inputs:
|
||||||
|
add_to_dict_v1(i, input)
|
||||||
return input
|
return input
|
||||||
|
|
||||||
def add_to_input_dict_v1(d: dict[str, Any], inputs: list[Input], live_inputs: dict[str, Any]=None, curr_prefix=''):
|
def add_to_dict_v1(i: Input, d: dict):
|
||||||
for i in inputs:
|
|
||||||
if isinstance(i, DynamicInput):
|
|
||||||
add_to_dict_v1(i, d)
|
|
||||||
if live_inputs is not None:
|
|
||||||
i.expand_schema_for_dynamic(d, live_inputs, curr_prefix)
|
|
||||||
else:
|
|
||||||
add_to_dict_v1(i, d)
|
|
||||||
|
|
||||||
def add_to_dict_v1(i: Input, d: dict, dynamic_dict: dict=None):
|
|
||||||
key = "optional" if i.optional else "required"
|
key = "optional" if i.optional else "required"
|
||||||
as_dict = i.as_dict()
|
as_dict = i.as_dict()
|
||||||
# for v1, we don't want to include the optional key
|
# for v1, we don't want to include the optional key
|
||||||
as_dict.pop("optional", None)
|
as_dict.pop("optional", None)
|
||||||
if dynamic_dict is None:
|
d.setdefault(key, {})[i.id] = (i.get_io_type(), as_dict)
|
||||||
value = (i.get_io_type(), as_dict)
|
|
||||||
else:
|
|
||||||
value = (i.get_io_type(), as_dict, dynamic_dict)
|
|
||||||
d.setdefault(key, {})[i.id] = value
|
|
||||||
|
|
||||||
def add_to_dict_v3(io: Input | Output, d: dict):
|
def add_to_dict_v3(io: Input | Output, d: dict):
|
||||||
d[io.id] = (io.get_io_type(), io.as_dict())
|
d[io.id] = (io.get_io_type(), io.as_dict())
|
||||||
@ -1416,6 +1470,8 @@ def build_nested_inputs(values: dict[str, Any], v3_data: V3Data):
|
|||||||
values = values.copy()
|
values = values.copy()
|
||||||
result = {}
|
result = {}
|
||||||
|
|
||||||
|
create_tuple = v3_data.get("create_dynamic_tuple", False)
|
||||||
|
|
||||||
for key, path in paths.items():
|
for key, path in paths.items():
|
||||||
parts = path.split(".")
|
parts = path.split(".")
|
||||||
current = result
|
current = result
|
||||||
@ -1424,7 +1480,10 @@ def build_nested_inputs(values: dict[str, Any], v3_data: V3Data):
|
|||||||
is_last = (i == len(parts) - 1)
|
is_last = (i == len(parts) - 1)
|
||||||
|
|
||||||
if is_last:
|
if is_last:
|
||||||
current[p] = values.pop(key, None)
|
value = values.pop(key, None)
|
||||||
|
if create_tuple:
|
||||||
|
value = (value, key)
|
||||||
|
current[p] = value
|
||||||
else:
|
else:
|
||||||
current = current.setdefault(p, {})
|
current = current.setdefault(p, {})
|
||||||
|
|
||||||
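A worked example of the new tuple mode, assuming `paths` is taken from v3_data["dynamic_paths"] as in the surrounding code (the earlier lines of this function are not shown here):

    values = {"size": 512, "opts": "fast", "opts.strength": 0.5}
    v3_data = {"hidden_inputs": {}, "create_dynamic_tuple": True,
               "dynamic_paths": {"opts": "opts.opts", "opts.strength": "opts.strength"}}
    # result (roughly): {"opts": {"opts": ("fast", "opts"), "strength": (0.5, "opts.strength")}}
    # entries without a path entry (like "size") are handled by the unchanged rest of the function.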
@ -1439,7 +1498,6 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
|
|||||||
SCHEMA = None
|
SCHEMA = None
|
||||||
|
|
||||||
# filled in during execution
|
# filled in during execution
|
||||||
resources: Resources = None
|
|
||||||
hidden: HiddenHolder = None
|
hidden: HiddenHolder = None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -1486,7 +1544,6 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
|
|||||||
return [name for name in kwargs if kwargs[name] is None]
|
return [name for name in kwargs if kwargs[name] is None]
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.local_resources: ResourcesLocal = None
|
|
||||||
self.__class__.VALIDATE_CLASS()
|
self.__class__.VALIDATE_CLASS()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -1549,12 +1606,12 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
|
|||||||
|
|
||||||
@final
|
@final
|
||||||
@classmethod
|
@classmethod
|
||||||
def PREPARE_CLASS_CLONE(cls, v3_data: V3Data) -> type[ComfyNode]:
|
def PREPARE_CLASS_CLONE(cls, v3_data: V3Data | None) -> type[ComfyNode]:
|
||||||
"""Creates clone of real node class to prevent monkey-patching."""
|
"""Creates clone of real node class to prevent monkey-patching."""
|
||||||
c_type: type[ComfyNode] = cls if is_class(cls) else type(cls)
|
c_type: type[ComfyNode] = cls if is_class(cls) else type(cls)
|
||||||
type_clone: type[ComfyNode] = shallow_clone_class(c_type)
|
type_clone: type[ComfyNode] = shallow_clone_class(c_type)
|
||||||
# set hidden
|
# set hidden
|
||||||
type_clone.hidden = HiddenHolder.from_dict(v3_data["hidden_inputs"])
|
type_clone.hidden = HiddenHolder.from_v3_data(v3_data)
|
||||||
return type_clone
|
return type_clone
|
||||||
|
|
||||||
@final
|
@final
|
||||||
@ -1671,19 +1728,10 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
|
|||||||
|
|
||||||
@final
|
@final
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(cls, include_hidden=True, return_schema=False, live_inputs=None) -> dict[str, dict] | tuple[dict[str, dict], Schema, V3Data]:
|
def INPUT_TYPES(cls) -> dict[str, dict]:
|
||||||
schema = cls.FINALIZE_SCHEMA()
|
schema = cls.FINALIZE_SCHEMA()
|
||||||
info = schema.get_v1_info(cls, live_inputs)
|
info = schema.get_v1_info(cls)
|
||||||
input = info.input
|
return info.input
|
||||||
if not include_hidden:
|
|
||||||
input.pop("hidden", None)
|
|
||||||
if return_schema:
|
|
||||||
v3_data: V3Data = {}
|
|
||||||
dynamic = input.pop("dynamic_paths", None)
|
|
||||||
if dynamic is not None:
|
|
||||||
v3_data["dynamic_paths"] = dynamic
|
|
||||||
return input, schema, v3_data
|
|
||||||
return input
|
|
||||||
|
|
||||||
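With the extra flags gone, the V1 entry point has a single return shape; a quick sketch with an illustrative node class:

    d = MyV3Node.INPUT_TYPES()  # hypothetical node class
    # -> {"required": {...}} plus "optional"/"hidden" sections when present;
    # live-input expansion now goes through get_finalized_class_inputs instead of flags here.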
@final
|
@final
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -1802,7 +1850,7 @@ class NodeOutput(_NodeOutputInternal):
|
|||||||
return self.args if len(self.args) > 0 else None
|
return self.args if len(self.args) > 0 else None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_dict(cls, data: dict[str, Any]) -> "NodeOutput":
|
def from_dict(cls, data: dict[str, Any]) -> NodeOutput:
|
||||||
args = ()
|
args = ()
|
||||||
ui = None
|
ui = None
|
||||||
expand = None
|
expand = None
|
||||||
@ -1815,7 +1863,7 @@ class NodeOutput(_NodeOutputInternal):
|
|||||||
ui = data["ui"]
|
ui = data["ui"]
|
||||||
if "expand" in data:
|
if "expand" in data:
|
||||||
expand = data["expand"]
|
expand = data["expand"]
|
||||||
return cls(args=args, ui=ui, expand=expand)
|
return cls(*args, ui=ui, expand=expand)
|
||||||
|
|
||||||
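The constructor call now unpacks the stored args positionally; a small before/after sketch, where the payload keys and the latent/mask values are assumed from context rather than shown in this hunk:

    data = {"result": (latent, mask), "ui": {"images": []}}  # assumed payload shape
    out = NodeOutput.from_dict(data)
    # before: NodeOutput(args=(latent, mask), ...)   -> one tuple-valued positional arg
    # after:  NodeOutput(latent, mask, ui=..., expand=None) -> out[0] is latent, out[1] is mask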
def __getitem__(self, index) -> Any:
|
def __getitem__(self, index) -> Any:
|
||||||
return self.args[index]
|
return self.args[index]
|
||||||
@ -1894,10 +1942,11 @@ __all__ = [
|
|||||||
"SEGS",
|
"SEGS",
|
||||||
"AnyType",
|
"AnyType",
|
||||||
"MultiType",
|
"MultiType",
|
||||||
|
"Tracks",
|
||||||
# Dynamic Types
|
# Dynamic Types
|
||||||
"MatchType",
|
"MatchType",
|
||||||
# "DynamicCombo",
|
"DynamicCombo",
|
||||||
# "Autogrow",
|
"Autogrow",
|
||||||
# Other classes
|
# Other classes
|
||||||
"HiddenHolder",
|
"HiddenHolder",
|
||||||
"Hidden",
|
"Hidden",
|
||||||
|
|||||||
@ -1,72 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
import comfy.utils
|
|
||||||
import folder_paths
|
|
||||||
import logging
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from typing import Any
|
|
||||||
import torch
|
|
||||||
|
|
||||||
class ResourceKey(ABC):
|
|
||||||
Type = Any
|
|
||||||
def __init__(self):
|
|
||||||
...
|
|
||||||
|
|
||||||
class TorchDictFolderFilename(ResourceKey):
|
|
||||||
'''Key for requesting a torch file via file_name from a folder category.'''
|
|
||||||
Type = dict[str, torch.Tensor]
|
|
||||||
def __init__(self, folder_name: str, file_name: str):
|
|
||||||
self.folder_name = folder_name
|
|
||||||
self.file_name = file_name
|
|
||||||
|
|
||||||
def __hash__(self):
|
|
||||||
return hash((self.folder_name, self.file_name))
|
|
||||||
|
|
||||||
def __eq__(self, other: object) -> bool:
|
|
||||||
if not isinstance(other, TorchDictFolderFilename):
|
|
||||||
return False
|
|
||||||
return self.folder_name == other.folder_name and self.file_name == other.file_name
|
|
||||||
|
|
||||||
def __str__(self):
|
|
||||||
return f"{self.folder_name} -> {self.file_name}"
|
|
||||||
|
|
||||||
class Resources(ABC):
|
|
||||||
def __init__(self):
|
|
||||||
...
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get(self, key: ResourceKey, default: Any=...) -> Any:
|
|
||||||
pass
|
|
||||||
|
|
||||||
class ResourcesLocal(Resources):
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
self.local_resources: dict[ResourceKey, Any] = {}
|
|
||||||
|
|
||||||
def get(self, key: ResourceKey, default: Any=...) -> Any:
|
|
||||||
cached = self.local_resources.get(key, None)
|
|
||||||
if cached is not None:
|
|
||||||
logging.info(f"Using cached resource '{key}'")
|
|
||||||
return cached
|
|
||||||
logging.info(f"Loading resource '{key}'")
|
|
||||||
to_return = None
|
|
||||||
if isinstance(key, TorchDictFolderFilename):
|
|
||||||
if default is ...:
|
|
||||||
to_return = comfy.utils.load_torch_file(folder_paths.get_full_path_or_raise(key.folder_name, key.file_name), safe_load=True)
|
|
||||||
else:
|
|
||||||
full_path = folder_paths.get_full_path(key.folder_name, key.file_name)
|
|
||||||
if full_path is not None:
|
|
||||||
to_return = comfy.utils.load_torch_file(full_path, safe_load=True)
|
|
||||||
|
|
||||||
if to_return is not None:
|
|
||||||
self.local_resources[key] = to_return
|
|
||||||
return to_return
|
|
||||||
if default is not ...:
|
|
||||||
return default
|
|
||||||
raise Exception(f"Unsupported resource key type: {type(key)}")
|
|
||||||
|
|
||||||
|
|
||||||
class _RESOURCES:
|
|
||||||
ResourceKey = ResourceKey
|
|
||||||
TorchDictFolderFilename = TorchDictFolderFilename
|
|
||||||
Resources = Resources
|
|
||||||
ResourcesLocal = ResourcesLocal
|
|
||||||
@ -5,7 +5,6 @@ import os
|
|||||||
import random
|
import random
|
||||||
import uuid
|
import uuid
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from typing import Type
|
|
||||||
|
|
||||||
import av
|
import av
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -22,7 +21,7 @@ import folder_paths
|
|||||||
|
|
||||||
# used for image preview
|
# used for image preview
|
||||||
from comfy.cli_args import args
|
from comfy.cli_args import args
|
||||||
from comfy_api.latest._io import ComfyNode, FolderType, Image, _UIOutput
|
from ._io import ComfyNode, FolderType, Image, _UIOutput
|
||||||
|
|
||||||
|
|
||||||
class SavedResult(dict):
|
class SavedResult(dict):
|
||||||
@ -83,7 +82,7 @@ class ImageSaveHelper:
|
|||||||
return PILImage.fromarray(np.clip(255.0 * image_tensor.cpu().numpy(), 0, 255).astype(np.uint8))
|
return PILImage.fromarray(np.clip(255.0 * image_tensor.cpu().numpy(), 0, 255).astype(np.uint8))
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _create_png_metadata(cls: Type[ComfyNode] | None) -> PngInfo | None:
|
def _create_png_metadata(cls: type[ComfyNode] | None) -> PngInfo | None:
|
||||||
"""Creates a PngInfo object with prompt and extra_pnginfo."""
|
"""Creates a PngInfo object with prompt and extra_pnginfo."""
|
||||||
if args.disable_metadata or cls is None or not cls.hidden:
|
if args.disable_metadata or cls is None or not cls.hidden:
|
||||||
return None
|
return None
|
||||||
@ -96,7 +95,7 @@ class ImageSaveHelper:
|
|||||||
return metadata
|
return metadata
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _create_animated_png_metadata(cls: Type[ComfyNode] | None) -> PngInfo | None:
|
def _create_animated_png_metadata(cls: type[ComfyNode] | None) -> PngInfo | None:
|
||||||
"""Creates a PngInfo object with prompt and extra_pnginfo for animated PNGs (APNG)."""
|
"""Creates a PngInfo object with prompt and extra_pnginfo for animated PNGs (APNG)."""
|
||||||
        if args.disable_metadata or cls is None or not cls.hidden:
            return None
@@ -121,7 +120,7 @@ class ImageSaveHelper:
        return metadata

    @staticmethod
-    def _create_webp_metadata(pil_image: PILImage.Image, cls: Type[ComfyNode] | None) -> PILImage.Exif:
+    def _create_webp_metadata(pil_image: PILImage.Image, cls: type[ComfyNode] | None) -> PILImage.Exif:
        """Creates EXIF metadata bytes for WebP images."""
        exif_data = pil_image.getexif()
        if args.disable_metadata or cls is None or cls.hidden is None:
@@ -137,7 +136,7 @@ class ImageSaveHelper:

    @staticmethod
    def save_images(
-        images, filename_prefix: str, folder_type: FolderType, cls: Type[ComfyNode] | None, compress_level = 4,
+        images, filename_prefix: str, folder_type: FolderType, cls: type[ComfyNode] | None, compress_level = 4,
    ) -> list[SavedResult]:
        """Saves a batch of images as individual PNG files."""
        full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path(
@@ -155,7 +154,7 @@ class ImageSaveHelper:
        return results

    @staticmethod
-    def get_save_images_ui(images, filename_prefix: str, cls: Type[ComfyNode] | None, compress_level=4) -> SavedImages:
+    def get_save_images_ui(images, filename_prefix: str, cls: type[ComfyNode] | None, compress_level=4) -> SavedImages:
        """Saves a batch of images and returns a UI object for the node output."""
        return SavedImages(
            ImageSaveHelper.save_images(
@@ -169,7 +168,7 @@ class ImageSaveHelper:

    @staticmethod
    def save_animated_png(
-        images, filename_prefix: str, folder_type: FolderType, cls: Type[ComfyNode] | None, fps: float, compress_level: int
+        images, filename_prefix: str, folder_type: FolderType, cls: type[ComfyNode] | None, fps: float, compress_level: int
    ) -> SavedResult:
        """Saves a batch of images as a single animated PNG."""
        full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path(
@@ -191,7 +190,7 @@ class ImageSaveHelper:

    @staticmethod
    def get_save_animated_png_ui(
-        images, filename_prefix: str, cls: Type[ComfyNode] | None, fps: float, compress_level: int
+        images, filename_prefix: str, cls: type[ComfyNode] | None, fps: float, compress_level: int
    ) -> SavedImages:
        """Saves an animated PNG and returns a UI object for the node output."""
        result = ImageSaveHelper.save_animated_png(
@@ -209,7 +208,7 @@ class ImageSaveHelper:
        images,
        filename_prefix: str,
        folder_type: FolderType,
-        cls: Type[ComfyNode] | None,
+        cls: type[ComfyNode] | None,
        fps: float,
        lossless: bool,
        quality: int,
@@ -238,7 +237,7 @@ class ImageSaveHelper:
    def get_save_animated_webp_ui(
        images,
        filename_prefix: str,
-        cls: Type[ComfyNode] | None,
+        cls: type[ComfyNode] | None,
        fps: float,
        lossless: bool,
        quality: int,
@@ -267,7 +266,7 @@ class AudioSaveHelper:
        audio: dict,
        filename_prefix: str,
        folder_type: FolderType,
-        cls: Type[ComfyNode] | None,
+        cls: type[ComfyNode] | None,
        format: str = "flac",
        quality: str = "128k",
    ) -> list[SavedResult]:
@@ -372,7 +371,7 @@ class AudioSaveHelper:

    @staticmethod
    def get_save_audio_ui(
-        audio, filename_prefix: str, cls: Type[ComfyNode] | None, format: str = "flac", quality: str = "128k",
+        audio, filename_prefix: str, cls: type[ComfyNode] | None, format: str = "flac", quality: str = "128k",
    ) -> SavedAudios:
        """Save and instantly wrap for UI."""
        return SavedAudios(
@@ -388,7 +387,7 @@ class AudioSaveHelper:


class PreviewImage(_UIOutput):
-    def __init__(self, image: Image.Type, animated: bool = False, cls: Type[ComfyNode] = None, **kwargs):
+    def __init__(self, image: Image.Type, animated: bool = False, cls: type[ComfyNode] = None, **kwargs):
        self.values = ImageSaveHelper.save_images(
            image,
            filename_prefix="ComfyUI_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for _ in range(5)),
@@ -412,7 +411,7 @@ class PreviewMask(PreviewImage):


class PreviewAudio(_UIOutput):
-    def __init__(self, audio: dict, cls: Type[ComfyNode] = None, **kwargs):
+    def __init__(self, audio: dict, cls: type[ComfyNode] = None, **kwargs):
        self.values = AudioSaveHelper.save_audio(
            audio,
            filename_prefix="ComfyUI_temp_" + "".join(random.choice("abcdefghijklmnopqrstuvwxyz") for _ in range(5)),
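Note on the recurring `Type[ComfyNode]` -> `type[ComfyNode]` edits above: they swap the `typing.Type` alias for the builtin generic allowed since PEP 585 (Python 3.9+); the two spellings are interchangeable when annotating "a class object". A minimal sketch, with made-up function names and nothing beyond the standard library:

    from typing import Type

    class ComfyNode: ...  # stand-in for the real class

    def old_style(cls: Type[ComfyNode]) -> str:  # alias spelling, needs the typing import
        return cls.__name__

    def new_style(cls: type[ComfyNode]) -> str:  # builtin generic, no import required
        return cls.__name__

    print(old_style(ComfyNode), new_style(ComfyNode))  # prints: ComfyNode ComfyNode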
@@ -1,5 +1,6 @@
from .video_types import VideoContainer, VideoCodec, VideoComponents
from .geometry_types import VOXEL, MESH
+from .image_types import SVG

__all__ = [
    # Utility Types
@@ -8,4 +9,5 @@ __all__ = [
    "VideoComponents",
    "VOXEL",
    "MESH",
+    "SVG",
]

comfy_api/latest/_util/image_types.py (new file, 18 lines)
@@ -0,0 +1,18 @@
+from io import BytesIO
+
+
+class SVG:
+    """Stores SVG representations via a list of BytesIO objects."""
+
+    def __init__(self, data: list[BytesIO]):
+        self.data = data
+
+    def combine(self, other: 'SVG') -> 'SVG':
+        return SVG(self.data + other.data)
+
+    @staticmethod
+    def combine_all(svgs: list['SVG']) -> 'SVG':
+        all_svgs_list: list[BytesIO] = []
+        for svg_item in svgs:
+            all_svgs_list.extend(svg_item.data)
+        return SVG(all_svgs_list)
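For context, a minimal usage sketch of the new SVG container (the byte payloads below are placeholders; the import path is the one added above):

    from io import BytesIO

    from comfy_api.latest._util.image_types import SVG

    # Wrap two standalone SVG documents, then merge them into one container.
    first = SVG([BytesIO(b"<svg xmlns='http://www.w3.org/2000/svg'/>")])
    second = SVG([BytesIO(b"<svg xmlns='http://www.w3.org/2000/svg'/>")])

    merged = SVG.combine_all([first, second])
    print(len(merged.data))  # 2 -- combine_all simply concatenates the BytesIO lists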
@@ -3,7 +3,7 @@ from dataclasses import dataclass
from enum import Enum
from fractions import Fraction
from typing import Optional
-from comfy_api.latest._input import ImageInput, AudioInput
+from .._input import ImageInput, AudioInput

class VideoCodec(str, Enum):
    AUTO = "auto"
@@ -2,9 +2,8 @@ from comfy_api.latest import ComfyAPI_latest
from comfy_api.v0_0_2 import ComfyAPIAdapter_v0_0_2
from comfy_api.v0_0_1 import ComfyAPIAdapter_v0_0_1
from comfy_api.internal import ComfyAPIBase
-from typing import List, Type

-supported_versions: List[Type[ComfyAPIBase]] = [
+supported_versions: list[type[ComfyAPIBase]] = [
    ComfyAPI_latest,
    ComfyAPIAdapter_v0_0_2,
    ComfyAPIAdapter_v0_0_1,

comfy_api_nodes/apis/bytedance_api.py (new file, 144 lines)
@@ -0,0 +1,144 @@
+from typing import Literal
+
+from pydantic import BaseModel, Field
+
+
+class Text2ImageTaskCreationRequest(BaseModel):
+    model: str = Field(...)
+    prompt: str = Field(...)
+    response_format: str | None = Field("url")
+    size: str | None = Field(None)
+    seed: int | None = Field(0, ge=0, le=2147483647)
+    guidance_scale: float | None = Field(..., ge=1.0, le=10.0)
+    watermark: bool | None = Field(False)
+
+
+class Image2ImageTaskCreationRequest(BaseModel):
+    model: str = Field(...)
+    prompt: str = Field(...)
+    response_format: str | None = Field("url")
+    image: str = Field(..., description="Base64 encoded string or image URL")
+    size: str | None = Field("adaptive")
+    seed: int | None = Field(..., ge=0, le=2147483647)
+    guidance_scale: float | None = Field(..., ge=1.0, le=10.0)
+    watermark: bool | None = Field(False)
+
+
+class Seedream4Options(BaseModel):
+    max_images: int = Field(15)
+
+
+class Seedream4TaskCreationRequest(BaseModel):
+    model: str = Field(...)
+    prompt: str = Field(...)
+    response_format: str = Field("url")
+    image: list[str] | None = Field(None, description="Image URLs")
+    size: str = Field(...)
+    seed: int = Field(..., ge=0, le=2147483647)
+    sequential_image_generation: str = Field("disabled")
+    sequential_image_generation_options: Seedream4Options = Field(Seedream4Options(max_images=15))
+    watermark: bool = Field(False)
+
+
+class ImageTaskCreationResponse(BaseModel):
+    model: str = Field(...)
+    created: int = Field(..., description="Unix timestamp (in seconds) indicating time when the request was created.")
+    data: list = Field([], description="Contains information about the generated image(s).")
+    error: dict = Field({}, description="Contains `code` and `message` fields in case of error.")
+
+
+class TaskTextContent(BaseModel):
+    type: str = Field("text")
+    text: str = Field(...)
+
+
+class TaskImageContentUrl(BaseModel):
+    url: str = Field(...)
+
+
+class TaskImageContent(BaseModel):
+    type: str = Field("image_url")
+    image_url: TaskImageContentUrl = Field(...)
+    role: Literal["first_frame", "last_frame", "reference_image"] | None = Field(None)
+
+
+class Text2VideoTaskCreationRequest(BaseModel):
+    model: str = Field(...)
+    content: list[TaskTextContent] = Field(..., min_length=1)
+
+
+class Image2VideoTaskCreationRequest(BaseModel):
+    model: str = Field(...)
+    content: list[TaskTextContent | TaskImageContent] = Field(..., min_length=2)
+
+
+class TaskCreationResponse(BaseModel):
+    id: str = Field(...)
+
+
+class TaskStatusError(BaseModel):
+    code: str = Field(...)
+    message: str = Field(...)
+
+
+class TaskStatusResult(BaseModel):
+    video_url: str = Field(...)
+
+
+class TaskStatusResponse(BaseModel):
+    id: str = Field(...)
+    model: str = Field(...)
+    status: Literal["queued", "running", "cancelled", "succeeded", "failed"] = Field(...)
+    error: TaskStatusError | None = Field(None)
+    content: TaskStatusResult | None = Field(None)
+
+
+RECOMMENDED_PRESETS = [
+    ("1024x1024 (1:1)", 1024, 1024),
+    ("864x1152 (3:4)", 864, 1152),
+    ("1152x864 (4:3)", 1152, 864),
+    ("1280x720 (16:9)", 1280, 720),
+    ("720x1280 (9:16)", 720, 1280),
+    ("832x1248 (2:3)", 832, 1248),
+    ("1248x832 (3:2)", 1248, 832),
+    ("1512x648 (21:9)", 1512, 648),
+    ("2048x2048 (1:1)", 2048, 2048),
+    ("Custom", None, None),
+]
+
+RECOMMENDED_PRESETS_SEEDREAM_4 = [
+    ("2048x2048 (1:1)", 2048, 2048),
+    ("2304x1728 (4:3)", 2304, 1728),
+    ("1728x2304 (3:4)", 1728, 2304),
+    ("2560x1440 (16:9)", 2560, 1440),
+    ("1440x2560 (9:16)", 1440, 2560),
+    ("2496x1664 (3:2)", 2496, 1664),
+    ("1664x2496 (2:3)", 1664, 2496),
+    ("3024x1296 (21:9)", 3024, 1296),
+    ("4096x4096 (1:1)", 4096, 4096),
+    ("Custom", None, None),
+]
+
+# The time in this dictionary are given for 10 seconds duration.
+VIDEO_TASKS_EXECUTION_TIME = {
+    "seedance-1-0-lite-t2v-250428": {
+        "480p": 40,
+        "720p": 60,
+        "1080p": 90,
+    },
+    "seedance-1-0-lite-i2v-250428": {
+        "480p": 40,
+        "720p": 60,
+        "1080p": 90,
+    },
+    "seedance-1-0-pro-250528": {
+        "480p": 70,
+        "720p": 85,
+        "1080p": 115,
+    },
+    "seedance-1-0-pro-fast-251015": {
+        "480p": 50,
+        "720p": 65,
+        "1080p": 100,
+    },
+}
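A rough sketch of how these request models might be filled in and serialized before being sent to the task endpoint (assumes Pydantic v2's model_dump; the prompt text is a placeholder):

    from comfy_api_nodes.apis.bytedance_api import Seedream4Options, Seedream4TaskCreationRequest

    request = Seedream4TaskCreationRequest(
        model="seedream-4-0-250828",   # one of the model ids offered by the Seedream node
        prompt="a watercolor fox in the snow",
        size="2048x2048",
        seed=42,
        sequential_image_generation="disabled",
        sequential_image_generation_options=Seedream4Options(max_images=1),
        watermark=False,
    )
    payload = request.model_dump(exclude_none=True)  # plain dict ready for the HTTP layer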
@@ -84,15 +84,7 @@ class GeminiSystemInstructionContent(BaseModel):
        description="A list of ordered parts that make up a single message. "
        "Different parts may have different IANA MIME types.",
    )
-    role: GeminiRole = Field(
-        ...,
-        description="The identity of the entity that creates the message. "
-        "The following values are supported: "
-        "user: This indicates that the message is sent by a real person, typically a user-generated message. "
-        "model: This indicates that the message is generated by the model. "
-        "The model value is used to insert messages from model into the conversation during multi-turn conversations. "
-        "For non-multi-turn conversations, this field can be left blank or unset.",
-    )
+    role: GeminiRole | None = Field(..., description="The role field of systemInstruction may be ignored.")


class GeminiFunctionDeclaration(BaseModel):
@@ -141,6 +133,7 @@ class GeminiImageGenerateContentRequest(BaseModel):
    systemInstruction: GeminiSystemInstructionContent | None = Field(None)
    tools: list[GeminiTool] | None = Field(None)
    videoMetadata: GeminiVideoMetadata | None = Field(None)
+    uploadImagesToStorage: bool = Field(True)


class GeminiGenerateContentRequest(BaseModel):
@@ -51,25 +51,25 @@ class TaskStatusImageResult(BaseModel):
    url: str = Field(..., description="URL for generated image")


-class OmniTaskStatusResults(BaseModel):
+class TaskStatusResults(BaseModel):
    videos: list[TaskStatusVideoResult] | None = Field(None)
    images: list[TaskStatusImageResult] | None = Field(None)


-class OmniTaskStatusResponseData(BaseModel):
+class TaskStatusResponseData(BaseModel):
    created_at: int | None = Field(None, description="Task creation time")
    updated_at: int | None = Field(None, description="Task update time")
    task_status: str | None = None
    task_status_msg: str | None = Field(None, description="Additional failure reason. Only for polling endpoint.")
    task_id: str | None = Field(None, description="Task ID")
-    task_result: OmniTaskStatusResults | None = Field(None)
+    task_result: TaskStatusResults | None = Field(None)


-class OmniTaskStatusResponse(BaseModel):
+class TaskStatusResponse(BaseModel):
    code: int | None = Field(None, description="Error code")
    message: str | None = Field(None, description="Error message")
    request_id: str | None = Field(None, description="Request ID")
-    data: OmniTaskStatusResponseData | None = Field(None)
+    data: TaskStatusResponseData | None = Field(None)


class OmniImageParamImage(BaseModel):
@@ -84,3 +84,30 @@ class OmniProImageRequest(BaseModel):
    mode: str = Field("pro")
    n: int | None = Field(1, le=9)
    image_list: list[OmniImageParamImage] | None = Field(..., max_length=10)
+
+
+class TextToVideoWithAudioRequest(BaseModel):
+    model_name: str = Field(..., description="kling-v2-6")
+    aspect_ratio: str = Field(..., description="'16:9', '9:16' or '1:1'")
+    duration: str = Field(..., description="'5' or '10'")
+    prompt: str = Field(...)
+    mode: str = Field("pro")
+    sound: str = Field(..., description="'on' or 'off'")
+
+
+class ImageToVideoWithAudioRequest(BaseModel):
+    model_name: str = Field(..., description="kling-v2-6")
+    image: str = Field(...)
+    duration: str = Field(..., description="'5' or '10'")
+    prompt: str = Field(...)
+    mode: str = Field("pro")
+    sound: str = Field(..., description="'on' or 'off'")
+
+
+class MotionControlRequest(BaseModel):
+    prompt: str = Field(...)
+    image_url: str = Field(...)
+    video_url: str = Field(...)
+    keep_original_sound: str = Field(...)
+    character_orientation: str = Field(...)
+    mode: str = Field(..., description="'pro' or 'std'")

comfy_api_nodes/apis/openai_api.py (new file, 52 lines)
@@ -0,0 +1,52 @@
+from pydantic import BaseModel, Field
+
+
+class Datum2(BaseModel):
+    b64_json: str | None = Field(None, description="Base64 encoded image data")
+    revised_prompt: str | None = Field(None, description="Revised prompt")
+    url: str | None = Field(None, description="URL of the image")
+
+
+class InputTokensDetails(BaseModel):
+    image_tokens: int | None = None
+    text_tokens: int | None = None
+
+
+class Usage(BaseModel):
+    input_tokens: int | None = None
+    input_tokens_details: InputTokensDetails | None = None
+    output_tokens: int | None = None
+    total_tokens: int | None = None
+
+
+class OpenAIImageGenerationResponse(BaseModel):
+    data: list[Datum2] | None = None
+    usage: Usage | None = None
+
+
+class OpenAIImageEditRequest(BaseModel):
+    background: str | None = Field(None, description="Background transparency")
+    model: str = Field(...)
+    moderation: str | None = Field(None)
+    n: int | None = Field(None, description="The number of images to generate")
+    output_compression: int | None = Field(None, description="Compression level for JPEG or WebP (0-100)")
+    output_format: str | None = Field(None)
+    prompt: str = Field(...)
+    quality: str | None = Field(None, description="Size of the image (e.g., 1024x1024, 1536x1024, auto)")
+    size: str | None = Field(None, description="Size of the output image")
+
+
+class OpenAIImageGenerationRequest(BaseModel):
+    background: str | None = Field(None, description="Background transparency")
+    model: str | None = Field(None)
+    moderation: str | None = Field(None)
+    n: int | None = Field(
+        None,
+        description="The number of images to generate.",
+    )
+    output_compression: int | None = Field(None, description="Compression level for JPEG or WebP (0-100)")
+    output_format: str | None = Field(None)
+    prompt: str = Field(...)
+    quality: str | None = Field(None, description="The quality of the generated image")
+    size: str | None = Field(None, description="Size of the image (e.g., 1024x1024, 1536x1024, auto)")
+    style: str | None = Field(None, description="Style of the image (only for dall-e-3)")
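As a rough illustration of how a caller might pull image URLs out of this response model (the sample values are invented):

    from comfy_api_nodes.apis.openai_api import Datum2, OpenAIImageGenerationResponse

    response = OpenAIImageGenerationResponse(
        data=[Datum2(url="https://example.com/result.png"), Datum2(b64_json="aGVsbG8=")],
    )

    # Prefer URLs when present; b64_json entries would need base64 decoding instead.
    urls = [item.url for item in (response.data or []) if item.url]
    print(urls)  # ['https://example.com/result.png']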
@@ -1,100 +0,0 @@
-from typing import Optional
-from enum import Enum
-from pydantic import BaseModel, Field
-
-
-class Pikaffect(str, Enum):
-    Cake_ify = "Cake-ify"
-    Crumble = "Crumble"
-    Crush = "Crush"
-    Decapitate = "Decapitate"
-    Deflate = "Deflate"
-    Dissolve = "Dissolve"
-    Explode = "Explode"
-    Eye_pop = "Eye-pop"
-    Inflate = "Inflate"
-    Levitate = "Levitate"
-    Melt = "Melt"
-    Peel = "Peel"
-    Poke = "Poke"
-    Squish = "Squish"
-    Ta_da = "Ta-da"
-    Tear = "Tear"
-
-
-class PikaBodyGenerate22C2vGenerate22PikascenesPost(BaseModel):
-    aspectRatio: Optional[float] = Field(None, description='Aspect ratio (width / height)')
-    duration: Optional[int] = Field(5)
-    ingredientsMode: str = Field(...)
-    negativePrompt: Optional[str] = Field(None)
-    promptText: Optional[str] = Field(None)
-    resolution: Optional[str] = Field('1080p')
-    seed: Optional[int] = Field(None)
-
-
-class PikaGenerateResponse(BaseModel):
-    video_id: str = Field(...)
-
-
-class PikaBodyGenerate22I2vGenerate22I2vPost(BaseModel):
-    duration: Optional[int] = 5
-    negativePrompt: Optional[str] = Field(None)
-    promptText: Optional[str] = Field(None)
-    resolution: Optional[str] = '1080p'
-    seed: Optional[int] = Field(None)
-
-
-class PikaBodyGenerate22KeyframeGenerate22PikaframesPost(BaseModel):
-    duration: Optional[int] = Field(None, ge=5, le=10)
-    negativePrompt: Optional[str] = Field(None)
-    promptText: str = Field(...)
-    resolution: Optional[str] = '1080p'
-    seed: Optional[int] = Field(None)
-
-
-class PikaBodyGenerate22T2vGenerate22T2vPost(BaseModel):
-    aspectRatio: Optional[float] = Field(
-        1.7777777777777777,
-        description='Aspect ratio (width / height)',
-        ge=0.4,
-        le=2.5,
-    )
-    duration: Optional[int] = 5
-    negativePrompt: Optional[str] = Field(None)
-    promptText: str = Field(...)
-    resolution: Optional[str] = '1080p'
-    seed: Optional[int] = Field(None)
-
-
-class PikaBodyGeneratePikadditionsGeneratePikadditionsPost(BaseModel):
-    negativePrompt: Optional[str] = Field(None)
-    promptText: Optional[str] = Field(None)
-    seed: Optional[int] = Field(None)
-
-
-class PikaBodyGeneratePikaffectsGeneratePikaffectsPost(BaseModel):
-    negativePrompt: Optional[str] = Field(None)
-    pikaffect: Optional[str] = None
-    promptText: Optional[str] = Field(None)
-    seed: Optional[int] = Field(None)
-
-
-class PikaBodyGeneratePikaswapsGeneratePikaswapsPost(BaseModel):
-    negativePrompt: Optional[str] = Field(None)
-    promptText: Optional[str] = Field(None)
-    seed: Optional[int] = Field(None)
-    modifyRegionRoi: Optional[str] = Field(None)
-
-
-class PikaStatusEnum(str, Enum):
-    queued = "queued"
-    started = "started"
-    finished = "finished"
-    failed = "failed"
-
-
-class PikaVideoResponse(BaseModel):
-    id: str = Field(...)
-    progress: Optional[int] = Field(None)
-    status: PikaStatusEnum
-    url: Optional[str] = Field(None)
@@ -5,11 +5,17 @@ from typing import Optional, List, Dict, Any, Union
from pydantic import BaseModel, Field, RootModel

class TripoModelVersion(str, Enum):
+    v3_0_20250812 = 'v3.0-20250812'
    v2_5_20250123 = 'v2.5-20250123'
    v2_0_20240919 = 'v2.0-20240919'
    v1_4_20240625 = 'v1.4-20240625'


+class TripoGeometryQuality(str, Enum):
+    standard = 'standard'
+    detailed = 'detailed'
+
+
class TripoTextureQuality(str, Enum):
    standard = 'standard'
    detailed = 'detailed'
@@ -61,14 +67,20 @@ class TripoSpec(str, Enum):
class TripoAnimation(str, Enum):
    IDLE = "preset:idle"
    WALK = "preset:walk"
+    RUN = "preset:run"
+    DIVE = "preset:dive"
    CLIMB = "preset:climb"
    JUMP = "preset:jump"
-    RUN = "preset:run"
    SLASH = "preset:slash"
    SHOOT = "preset:shoot"
    HURT = "preset:hurt"
    FALL = "preset:fall"
    TURN = "preset:turn"
+    QUADRUPED_WALK = "preset:quadruped:walk"
+    HEXAPOD_WALK = "preset:hexapod:walk"
+    OCTOPOD_WALK = "preset:octopod:walk"
+    SERPENTINE_MARCH = "preset:serpentine:march"
+    AQUATIC_MARCH = "preset:aquatic:march"

class TripoStylizeStyle(str, Enum):
    LEGO = "lego"
@@ -105,6 +117,11 @@ class TripoTaskStatus(str, Enum):
    BANNED = "banned"
    EXPIRED = "expired"
+
+class TripoFbxPreset(str, Enum):
+    BLENDER = "blender"
+    MIXAMO = "mixamo"
+    _3DSMAX = "3dsmax"

class TripoFileTokenReference(BaseModel):
    type: Optional[str] = Field(None, description='The type of the reference')
    file_token: str
@@ -142,6 +159,7 @@ class TripoTextToModelRequest(BaseModel):
    model_seed: Optional[int] = Field(None, description='The seed for the model')
    texture_seed: Optional[int] = Field(None, description='The seed for the texture')
    texture_quality: Optional[TripoTextureQuality] = TripoTextureQuality.standard
+    geometry_quality: Optional[TripoGeometryQuality] = TripoGeometryQuality.standard
    style: Optional[TripoStyle] = None
    auto_size: Optional[bool] = Field(False, description='Whether to auto-size the model')
    quad: Optional[bool] = Field(False, description='Whether to apply quad to the generated model')
@@ -156,6 +174,7 @@ class TripoImageToModelRequest(BaseModel):
    model_seed: Optional[int] = Field(None, description='The seed for the model')
    texture_seed: Optional[int] = Field(None, description='The seed for the texture')
    texture_quality: Optional[TripoTextureQuality] = TripoTextureQuality.standard
+    geometry_quality: Optional[TripoGeometryQuality] = TripoGeometryQuality.standard
    texture_alignment: Optional[TripoTextureAlignment] = Field(TripoTextureAlignment.ORIGINAL_IMAGE, description='The texture alignment method')
    style: Optional[TripoStyle] = Field(None, description='The style to apply to the generated model')
    auto_size: Optional[bool] = Field(False, description='Whether to auto-size the model')
@@ -173,6 +192,7 @@ class TripoMultiviewToModelRequest(BaseModel):
    model_seed: Optional[int] = Field(None, description='The seed for the model')
    texture_seed: Optional[int] = Field(None, description='The seed for the texture')
    texture_quality: Optional[TripoTextureQuality] = TripoTextureQuality.standard
+    geometry_quality: Optional[TripoGeometryQuality] = TripoGeometryQuality.standard
    texture_alignment: Optional[TripoTextureAlignment] = TripoTextureAlignment.ORIGINAL_IMAGE
    auto_size: Optional[bool] = Field(False, description='Whether to auto-size the model')
    orientation: Optional[TripoOrientation] = Field(TripoOrientation.DEFAULT, description='The orientation for the model')
@@ -219,14 +239,24 @@ class TripoConvertModelRequest(BaseModel):
    type: TripoTaskType = Field(TripoTaskType.CONVERT_MODEL, description='Type of task')
    format: TripoConvertFormat = Field(..., description='The format to convert to')
    original_model_task_id: str = Field(..., description='The task ID of the original model')
-    quad: Optional[bool] = Field(False, description='Whether to apply quad to the model')
-    force_symmetry: Optional[bool] = Field(False, description='Whether to force symmetry')
-    face_limit: Optional[int] = Field(10000, description='The number of faces to limit the conversion to')
-    flatten_bottom: Optional[bool] = Field(False, description='Whether to flatten the bottom of the model')
-    flatten_bottom_threshold: Optional[float] = Field(0.01, description='The threshold for flattening the bottom')
-    texture_size: Optional[int] = Field(4096, description='The size of the texture')
+    quad: Optional[bool] = Field(None, description='Whether to apply quad to the model')
+    force_symmetry: Optional[bool] = Field(None, description='Whether to force symmetry')
+    face_limit: Optional[int] = Field(None, description='The number of faces to limit the conversion to')
+    flatten_bottom: Optional[bool] = Field(None, description='Whether to flatten the bottom of the model')
+    flatten_bottom_threshold: Optional[float] = Field(None, description='The threshold for flattening the bottom')
+    texture_size: Optional[int] = Field(None, description='The size of the texture')
    texture_format: Optional[TripoTextureFormat] = Field(TripoTextureFormat.JPEG, description='The format of the texture')
-    pivot_to_center_bottom: Optional[bool] = Field(False, description='Whether to pivot to the center bottom')
+    pivot_to_center_bottom: Optional[bool] = Field(None, description='Whether to pivot to the center bottom')
+    scale_factor: Optional[float] = Field(None, description='The scale factor for the model')
+    with_animation: Optional[bool] = Field(None, description='Whether to include animations')
+    pack_uv: Optional[bool] = Field(None, description='Whether to pack the UVs')
+    bake: Optional[bool] = Field(None, description='Whether to bake the model')
+    part_names: Optional[List[str]] = Field(None, description='The names of the parts to include')
+    fbx_preset: Optional[TripoFbxPreset] = Field(None, description='The preset for the FBX export')
+    export_vertex_colors: Optional[bool] = Field(None, description='Whether to export the vertex colors')
+    export_orientation: Optional[TripoOrientation] = Field(None, description='The orientation for the export')
+    animate_in_place: Optional[bool] = Field(None, description='Whether to animate in place')
+

class TripoTaskRequest(RootModel):
    root: Union[
@@ -85,7 +85,7 @@ class Response1(BaseModel):
    raiMediaFilteredReasons: Optional[list[str]] = Field(
        None, description='Reasons why media was filtered by responsible AI policies'
    )
-    videos: Optional[list[Video]] = None
+    videos: Optional[list[Video]] = Field(None)


class VeoGenVidPollResponse(BaseModel):
@@ -1,10 +1,8 @@
-from inspect import cleandoc
-
import torch
from pydantic import BaseModel
from typing_extensions import override

-from comfy_api.latest import IO, ComfyExtension
+from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api_nodes.apis.bfl_api import (
    BFLFluxExpandImageRequest,
    BFLFluxFillImageRequest,
@@ -28,7 +26,7 @@ from comfy_api_nodes.util import (
)


-def convert_mask_to_image(mask: torch.Tensor):
+def convert_mask_to_image(mask: Input.Image):
    """
    Make mask have the expected amount of dims (4) and channels (3) to be recognized as an image.
    """
@@ -38,9 +36,6 @@ def convert_mask_to_image(mask: torch.Tensor):


class FluxProUltraImageNode(IO.ComfyNode):
-    """
-    Generates images using Flux Pro 1.1 Ultra via api based on prompt and resolution.
-    """

    @classmethod
    def define_schema(cls) -> IO.Schema:
@@ -48,7 +43,7 @@ class FluxProUltraImageNode(IO.ComfyNode):
            node_id="FluxProUltraImageNode",
            display_name="Flux 1.1 [pro] Ultra Image",
            category="api node/image/BFL",
-            description=cleandoc(cls.__doc__ or ""),
+            description="Generates images using Flux Pro 1.1 Ultra via api based on prompt and resolution.",
            inputs=[
                IO.String.Input(
                    "prompt",
@@ -117,7 +112,7 @@ class FluxProUltraImageNode(IO.ComfyNode):
        prompt_upsampling: bool = False,
        raw: bool = False,
        seed: int = 0,
-        image_prompt: torch.Tensor | None = None,
+        image_prompt: Input.Image | None = None,
        image_prompt_strength: float = 0.1,
    ) -> IO.NodeOutput:
        if image_prompt is None:
@@ -155,9 +150,6 @@ class FluxProUltraImageNode(IO.ComfyNode):


class FluxKontextProImageNode(IO.ComfyNode):
-    """
-    Edits images using Flux.1 Kontext [pro] via api based on prompt and aspect ratio.
-    """

    @classmethod
    def define_schema(cls) -> IO.Schema:
@@ -165,7 +157,7 @@ class FluxKontextProImageNode(IO.ComfyNode):
            node_id=cls.NODE_ID,
            display_name=cls.DISPLAY_NAME,
            category="api node/image/BFL",
-            description=cleandoc(cls.__doc__ or ""),
+            description="Edits images using Flux.1 Kontext [pro] via api based on prompt and aspect ratio.",
            inputs=[
                IO.String.Input(
                    "prompt",
@@ -231,7 +223,7 @@ class FluxKontextProImageNode(IO.ComfyNode):
        aspect_ratio: str,
        guidance: float,
        steps: int,
-        input_image: torch.Tensor | None = None,
+        input_image: Input.Image | None = None,
        seed=0,
        prompt_upsampling=False,
    ) -> IO.NodeOutput:
@@ -271,20 +263,14 @@ class FluxKontextProImageNode(IO.ComfyNode):


class FluxKontextMaxImageNode(FluxKontextProImageNode):
-    """
-    Edits images using Flux.1 Kontext [max] via api based on prompt and aspect ratio.
-    """

-    DESCRIPTION = cleandoc(__doc__ or "")
+    DESCRIPTION = "Edits images using Flux.1 Kontext [max] via api based on prompt and aspect ratio."
    BFL_PATH = "/proxy/bfl/flux-kontext-max/generate"
    NODE_ID = "FluxKontextMaxImageNode"
    DISPLAY_NAME = "Flux.1 Kontext [max] Image"


class FluxProExpandNode(IO.ComfyNode):
-    """
-    Outpaints image based on prompt.
-    """

    @classmethod
    def define_schema(cls) -> IO.Schema:
@@ -292,7 +278,7 @@ class FluxProExpandNode(IO.ComfyNode):
            node_id="FluxProExpandNode",
            display_name="Flux.1 Expand Image",
            category="api node/image/BFL",
-            description=cleandoc(cls.__doc__ or ""),
+            description="Outpaints image based on prompt.",
            inputs=[
                IO.Image.Input("image"),
                IO.String.Input(
@@ -371,7 +357,7 @@ class FluxProExpandNode(IO.ComfyNode):
    @classmethod
    async def execute(
        cls,
-        image: torch.Tensor,
+        image: Input.Image,
        prompt: str,
        prompt_upsampling: bool,
        top: int,
@@ -418,9 +404,6 @@ class FluxProExpandNode(IO.ComfyNode):


class FluxProFillNode(IO.ComfyNode):
-    """
-    Inpaints image based on mask and prompt.
-    """

    @classmethod
    def define_schema(cls) -> IO.Schema:
@@ -428,7 +411,7 @@ class FluxProFillNode(IO.ComfyNode):
            node_id="FluxProFillNode",
            display_name="Flux.1 Fill Image",
            category="api node/image/BFL",
-            description=cleandoc(cls.__doc__ or ""),
+            description="Inpaints image based on mask and prompt.",
            inputs=[
                IO.Image.Input("image"),
                IO.Mask.Input("mask"),
@@ -480,8 +463,8 @@ class FluxProFillNode(IO.ComfyNode):
    @classmethod
    async def execute(
        cls,
-        image: torch.Tensor,
-        mask: torch.Tensor,
+        image: Input.Image,
+        mask: Input.Image,
        prompt: str,
        prompt_upsampling: bool,
        steps: int,
@@ -525,11 +508,15 @@ class FluxProFillNode(IO.ComfyNode):

class Flux2ProImageNode(IO.ComfyNode):

+    NODE_ID = "Flux2ProImageNode"
+    DISPLAY_NAME = "Flux.2 [pro] Image"
+    API_ENDPOINT = "/proxy/bfl/flux-2-pro/generate"
+
    @classmethod
    def define_schema(cls) -> IO.Schema:
        return IO.Schema(
-            node_id="Flux2ProImageNode",
-            display_name="Flux.2 [pro] Image",
+            node_id=cls.NODE_ID,
+            display_name=cls.DISPLAY_NAME,
            category="api node/image/BFL",
            description="Generates images synchronously based on prompt and resolution.",
            inputs=[
@@ -563,12 +550,11 @@ class Flux2ProImageNode(IO.ComfyNode):
                ),
                IO.Boolean.Input(
                    "prompt_upsampling",
-                    default=False,
+                    default=True,
                    tooltip="Whether to perform upsampling on the prompt. "
-                    "If active, automatically modifies the prompt for more creative generation, "
-                    "but results are nondeterministic (same seed will not produce exactly the same result).",
+                    "If active, automatically modifies the prompt for more creative generation.",
                ),
-                IO.Image.Input("images", optional=True, tooltip="Up to 4 images to be used as references."),
+                IO.Image.Input("images", optional=True, tooltip="Up to 9 images to be used as references."),
            ],
            outputs=[IO.Image.Output()],
            hidden=[
@@ -587,7 +573,7 @@ class Flux2ProImageNode(IO.ComfyNode):
        height: int,
        seed: int,
        prompt_upsampling: bool,
-        images: torch.Tensor | None = None,
+        images: Input.Image | None = None,
    ) -> IO.NodeOutput:
        reference_images = {}
        if images is not None:
@@ -598,7 +584,7 @@ class Flux2ProImageNode(IO.ComfyNode):
                reference_images[key_name] = tensor_to_base64_string(images[image_index], total_pixels=2048 * 2048)
        initial_response = await sync_op(
            cls,
-            ApiEndpoint(path="/proxy/bfl/flux-2-pro/generate", method="POST"),
+            ApiEndpoint(path=cls.API_ENDPOINT, method="POST"),
            response_model=BFLFluxProGenerateResponse,
            data=Flux2ProGenerateRequest(
                prompt=prompt,
@@ -632,6 +618,13 @@ class Flux2ProImageNode(IO.ComfyNode):
        return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"]))


+class Flux2MaxImageNode(Flux2ProImageNode):
+
+    NODE_ID = "Flux2MaxImageNode"
+    DISPLAY_NAME = "Flux.2 [max] Image"
+    API_ENDPOINT = "/proxy/bfl/flux-2-max/generate"
+
+
class BFLExtension(ComfyExtension):
    @override
    async def get_node_list(self) -> list[type[IO.ComfyNode]]:
@@ -642,6 +635,7 @@ class BFLExtension(ComfyExtension):
            FluxProExpandNode,
            FluxProFillNode,
            Flux2ProImageNode,
+            Flux2MaxImageNode,
        ]

@@ -1,13 +1,27 @@
import logging
import math
-from enum import Enum
-from typing import Literal, Optional, Union

import torch
-from pydantic import BaseModel, Field
from typing_extensions import override

-from comfy_api.latest import IO, ComfyExtension
+from comfy_api.latest import IO, ComfyExtension, Input
+from comfy_api_nodes.apis.bytedance_api import (
+    RECOMMENDED_PRESETS,
+    RECOMMENDED_PRESETS_SEEDREAM_4,
+    VIDEO_TASKS_EXECUTION_TIME,
+    Image2ImageTaskCreationRequest,
+    Image2VideoTaskCreationRequest,
+    ImageTaskCreationResponse,
+    Seedream4Options,
+    Seedream4TaskCreationRequest,
+    TaskCreationResponse,
+    TaskImageContent,
+    TaskImageContentUrl,
+    TaskStatusResponse,
+    TaskTextContent,
+    Text2ImageTaskCreationRequest,
+    Text2VideoTaskCreationRequest,
+)
from comfy_api_nodes.util import (
    ApiEndpoint,
    download_url_to_image_tensor,
@@ -29,162 +43,6 @@ BYTEPLUS_TASK_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks"
BYTEPLUS_TASK_STATUS_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks" # + /{task_id}


-class Text2ImageModelName(str, Enum):
-    seedream_3 = "seedream-3-0-t2i-250415"
-
-
-class Image2ImageModelName(str, Enum):
-    seededit_3 = "seededit-3-0-i2i-250628"
-
-
-class Text2VideoModelName(str, Enum):
-    seedance_1_pro = "seedance-1-0-pro-250528"
-    seedance_1_lite = "seedance-1-0-lite-t2v-250428"
-
-
-class Image2VideoModelName(str, Enum):
-    """note(August 31): Pro model only supports FirstFrame: https://docs.byteplus.com/en/docs/ModelArk/1520757"""
-
-    seedance_1_pro = "seedance-1-0-pro-250528"
-    seedance_1_lite = "seedance-1-0-lite-i2v-250428"
-
-
-class Text2ImageTaskCreationRequest(BaseModel):
-    model: Text2ImageModelName = Text2ImageModelName.seedream_3
-    prompt: str = Field(...)
-    response_format: Optional[str] = Field("url")
-    size: Optional[str] = Field(None)
-    seed: Optional[int] = Field(0, ge=0, le=2147483647)
-    guidance_scale: Optional[float] = Field(..., ge=1.0, le=10.0)
-    watermark: Optional[bool] = Field(True)
-
-
-class Image2ImageTaskCreationRequest(BaseModel):
-    model: Image2ImageModelName = Image2ImageModelName.seededit_3
-    prompt: str = Field(...)
-    response_format: Optional[str] = Field("url")
-    image: str = Field(..., description="Base64 encoded string or image URL")
-    size: Optional[str] = Field("adaptive")
-    seed: Optional[int] = Field(..., ge=0, le=2147483647)
-    guidance_scale: Optional[float] = Field(..., ge=1.0, le=10.0)
-    watermark: Optional[bool] = Field(True)
-
-
-class Seedream4Options(BaseModel):
-    max_images: int = Field(15)
-
-
-class Seedream4TaskCreationRequest(BaseModel):
-    model: str = Field("seedream-4-0-250828")
-    prompt: str = Field(...)
-    response_format: str = Field("url")
-    image: Optional[list[str]] = Field(None, description="Image URLs")
-    size: str = Field(...)
-    seed: int = Field(..., ge=0, le=2147483647)
-    sequential_image_generation: str = Field("disabled")
-    sequential_image_generation_options: Seedream4Options = Field(Seedream4Options(max_images=15))
-    watermark: bool = Field(True)
-
-
-class ImageTaskCreationResponse(BaseModel):
-    model: str = Field(...)
-    created: int = Field(..., description="Unix timestamp (in seconds) indicating time when the request was created.")
-    data: list = Field([], description="Contains information about the generated image(s).")
-    error: dict = Field({}, description="Contains `code` and `message` fields in case of error.")
-
-
-class TaskTextContent(BaseModel):
-    type: str = Field("text")
-    text: str = Field(...)
-
-
-class TaskImageContentUrl(BaseModel):
-    url: str = Field(...)
-
-
-class TaskImageContent(BaseModel):
-    type: str = Field("image_url")
-    image_url: TaskImageContentUrl = Field(...)
-    role: Optional[Literal["first_frame", "last_frame", "reference_image"]] = Field(None)
-
-
-class Text2VideoTaskCreationRequest(BaseModel):
-    model: Text2VideoModelName = Text2VideoModelName.seedance_1_pro
-    content: list[TaskTextContent] = Field(..., min_length=1)
-
-
-class Image2VideoTaskCreationRequest(BaseModel):
-    model: Image2VideoModelName = Image2VideoModelName.seedance_1_pro
-    content: list[Union[TaskTextContent, TaskImageContent]] = Field(..., min_length=2)
-
-
-class TaskCreationResponse(BaseModel):
-    id: str = Field(...)
-
-
-class TaskStatusError(BaseModel):
-    code: str = Field(...)
-    message: str = Field(...)
-
-
-class TaskStatusResult(BaseModel):
-    video_url: str = Field(...)
-
-
-class TaskStatusResponse(BaseModel):
-    id: str = Field(...)
-    model: str = Field(...)
-    status: Literal["queued", "running", "cancelled", "succeeded", "failed"] = Field(...)
-    error: Optional[TaskStatusError] = Field(None)
-    content: Optional[TaskStatusResult] = Field(None)
-
-
-RECOMMENDED_PRESETS = [
-    ("1024x1024 (1:1)", 1024, 1024),
-    ("864x1152 (3:4)", 864, 1152),
-    ("1152x864 (4:3)", 1152, 864),
-    ("1280x720 (16:9)", 1280, 720),
-    ("720x1280 (9:16)", 720, 1280),
-    ("832x1248 (2:3)", 832, 1248),
-    ("1248x832 (3:2)", 1248, 832),
-    ("1512x648 (21:9)", 1512, 648),
-    ("2048x2048 (1:1)", 2048, 2048),
-    ("Custom", None, None),
-]
-
-RECOMMENDED_PRESETS_SEEDREAM_4 = [
-    ("2048x2048 (1:1)", 2048, 2048),
-    ("2304x1728 (4:3)", 2304, 1728),
-    ("1728x2304 (3:4)", 1728, 2304),
-    ("2560x1440 (16:9)", 2560, 1440),
-    ("1440x2560 (9:16)", 1440, 2560),
-    ("2496x1664 (3:2)", 2496, 1664),
-    ("1664x2496 (2:3)", 1664, 2496),
-    ("3024x1296 (21:9)", 3024, 1296),
-    ("4096x4096 (1:1)", 4096, 4096),
-    ("Custom", None, None),
-]
-
-# The time in this dictionary are given for 10 seconds duration.
-VIDEO_TASKS_EXECUTION_TIME = {
-    "seedance-1-0-lite-t2v-250428": {
-        "480p": 40,
-        "720p": 60,
-        "1080p": 90,
-    },
-    "seedance-1-0-lite-i2v-250428": {
-        "480p": 40,
-        "720p": 60,
-        "1080p": 90,
-    },
-    "seedance-1-0-pro-250528": {
-        "480p": 70,
-        "720p": 85,
-        "1080p": 115,
-    },
-}
-
-
def get_image_url_from_response(response: ImageTaskCreationResponse) -> str:
    if response.error:
        error_msg = f"ByteDance request failed. Code: {response.error['code']}, message: {response.error['message']}"
@@ -194,13 +52,6 @@ def get_image_url_from_response(response: ImageTaskCreationResponse) -> str:
    return response.data[0]["url"]


-def get_video_url_from_task_status(response: TaskStatusResponse) -> Union[str, None]:
-    """Returns the video URL from the task status response if it exists."""
-    if hasattr(response, "content") and response.content:
-        return response.content.video_url
-    return None
-
-
class ByteDanceImageNode(IO.ComfyNode):

    @classmethod
@@ -211,12 +62,7 @@ class ByteDanceImageNode(IO.ComfyNode):
            category="api node/image/ByteDance",
            description="Generate images using ByteDance models via api based on prompt",
            inputs=[
-                IO.Combo.Input(
-                    "model",
-                    options=Text2ImageModelName,
-                    default=Text2ImageModelName.seedream_3,
-                    tooltip="Model name",
-                ),
+                IO.Combo.Input("model", options=["seedream-3-0-t2i-250415"]),
                IO.String.Input(
                    "prompt",
                    multiline=True,
@@ -266,7 +112,7 @@ class ByteDanceImageNode(IO.ComfyNode):
                ),
                IO.Boolean.Input(
                    "watermark",
-                    default=True,
+                    default=False,
                    tooltip='Whether to add an "AI generated" watermark to the image',
                    optional=True,
                ),
@@ -335,12 +181,7 @@ class ByteDanceImageEditNode(IO.ComfyNode):
            category="api node/image/ByteDance",
            description="Edit images using ByteDance models via api based on prompt",
            inputs=[
-                IO.Combo.Input(
-                    "model",
-                    options=Image2ImageModelName,
-                    default=Image2ImageModelName.seededit_3,
-                    tooltip="Model name",
-                ),
+                IO.Combo.Input("model", options=["seededit-3-0-i2i-250628"]),
                IO.Image.Input(
                    "image",
                    tooltip="The base image to edit",
@@ -374,7 +215,7 @@ class ByteDanceImageEditNode(IO.ComfyNode):
                ),
                IO.Boolean.Input(
                    "watermark",
-                    default=True,
+                    default=False,
                    tooltip='Whether to add an "AI generated" watermark to the image',
                    optional=True,
                ),
@@ -388,13 +229,14 @@ class ByteDanceImageEditNode(IO.ComfyNode):
                IO.Hidden.unique_id,
            ],
            is_api_node=True,
+            is_deprecated=True,
        )

    @classmethod
    async def execute(
        cls,
        model: str,
-        image: torch.Tensor,
+        image: Input.Image,
        prompt: str,
        seed: int,
        guidance_scale: float,
@@ -428,13 +270,13 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
    def define_schema(cls):
        return IO.Schema(
            node_id="ByteDanceSeedreamNode",
-            display_name="ByteDance Seedream 4",
+            display_name="ByteDance Seedream 4.5",
            category="api node/image/ByteDance",
            description="Unified text-to-image generation and precise single-sentence editing at up to 4K resolution.",
            inputs=[
                IO.Combo.Input(
                    "model",
-                    options=["seedream-4-0-250828"],
+                    options=["seedream-4-5-251128", "seedream-4-0-250828"],
                    tooltip="Model name",
                ),
                IO.String.Input(
@@ -459,7 +301,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
                    default=2048,
                    min=1024,
                    max=4096,
-                    step=64,
+                    step=8,
                    tooltip="Custom width for image. Value is working only if `size_preset` is set to `Custom`",
                    optional=True,
                ),
@@ -468,7 +310,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
                    default=2048,
                    min=1024,
                    max=4096,
-                    step=64,
+                    step=8,
                    tooltip="Custom height for image. Value is working only if `size_preset` is set to `Custom`",
                    optional=True,
                ),
@@ -505,7 +347,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
                ),
                IO.Boolean.Input(
                    "watermark",
-                    default=True,
+                    default=False,
                    tooltip='Whether to add an "AI generated" watermark to the image.',
                    optional=True,
                ),
@@ -532,14 +374,14 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
        cls,
        model: str,
        prompt: str,
-        image: torch.Tensor = None,
+        image: Input.Image | None = None,
        size_preset: str = RECOMMENDED_PRESETS_SEEDREAM_4[0][0],
        width: int = 2048,
        height: int = 2048,
        sequential_image_generation: str = "disabled",
        max_images: int = 1,
        seed: int = 0,
-        watermark: bool = True,
+        watermark: bool = False,
        fail_on_partial: bool = True,
    ) -> IO.NodeOutput:
        validate_string(prompt, strip_whitespace=True, min_length=1)
@@ -555,6 +397,18 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
                raise ValueError(
                    f"Custom size out of range: {w}x{h}. " "Both width and height must be between 1024 and 4096 pixels."
|
||||||
)
|
)
|
||||||
|
out_num_pixels = w * h
|
||||||
|
mp_provided = out_num_pixels / 1_000_000.0
|
||||||
|
if "seedream-4-5" in model and out_num_pixels < 3686400:
|
||||||
|
raise ValueError(
|
||||||
|
f"Minimum image resolution that Seedream 4.5 can generate is 3.68MP, "
|
||||||
|
f"but {mp_provided:.2f}MP provided."
|
||||||
|
)
|
||||||
|
if "seedream-4-0" in model and out_num_pixels < 921600:
|
||||||
|
raise ValueError(
|
||||||
|
f"Minimum image resolution that the selected model can generate is 0.92MP, "
|
||||||
|
f"but {mp_provided:.2f}MP provided."
|
||||||
|
)
|
||||||
n_input_images = get_number_of_images(image) if image is not None else 0
|
n_input_images = get_number_of_images(image) if image is not None else 0
|
||||||
if n_input_images > 10:
|
if n_input_images > 10:
|
||||||
raise ValueError(f"Maximum of 10 reference images are supported, but {n_input_images} received.")
|
raise ValueError(f"Maximum of 10 reference images are supported, but {n_input_images} received.")
|
||||||
@ -607,9 +461,8 @@ class ByteDanceTextToVideoNode(IO.ComfyNode):
|
|||||||
inputs=[
|
inputs=[
|
||||||
IO.Combo.Input(
|
IO.Combo.Input(
|
||||||
"model",
|
"model",
|
||||||
options=Text2VideoModelName,
|
options=["seedance-1-0-pro-250528", "seedance-1-0-lite-t2v-250428", "seedance-1-0-pro-fast-251015"],
|
||||||
default=Text2VideoModelName.seedance_1_pro,
|
default="seedance-1-0-pro-fast-251015",
|
||||||
tooltip="Model name",
|
|
||||||
),
|
),
|
||||||
IO.String.Input(
|
IO.String.Input(
|
||||||
"prompt",
|
"prompt",
|
||||||
@ -655,7 +508,7 @@ class ByteDanceTextToVideoNode(IO.ComfyNode):
|
|||||||
),
|
),
|
||||||
IO.Boolean.Input(
|
IO.Boolean.Input(
|
||||||
"watermark",
|
"watermark",
|
||||||
default=True,
|
default=False,
|
||||||
tooltip='Whether to add an "AI generated" watermark to the video.',
|
tooltip='Whether to add an "AI generated" watermark to the video.',
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
@ -714,9 +567,8 @@ class ByteDanceImageToVideoNode(IO.ComfyNode):
|
|||||||
inputs=[
|
inputs=[
|
||||||
IO.Combo.Input(
|
IO.Combo.Input(
|
||||||
"model",
|
"model",
|
||||||
options=Image2VideoModelName,
|
options=["seedance-1-0-pro-250528", "seedance-1-0-lite-t2v-250428", "seedance-1-0-pro-fast-251015"],
|
||||||
default=Image2VideoModelName.seedance_1_pro,
|
default="seedance-1-0-pro-fast-251015",
|
||||||
tooltip="Model name",
|
|
||||||
),
|
),
|
||||||
IO.String.Input(
|
IO.String.Input(
|
||||||
"prompt",
|
"prompt",
|
||||||
@ -766,7 +618,7 @@ class ByteDanceImageToVideoNode(IO.ComfyNode):
|
|||||||
),
|
),
|
||||||
IO.Boolean.Input(
|
IO.Boolean.Input(
|
||||||
"watermark",
|
"watermark",
|
||||||
default=True,
|
default=False,
|
||||||
tooltip='Whether to add an "AI generated" watermark to the video.',
|
tooltip='Whether to add an "AI generated" watermark to the video.',
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
@ -787,7 +639,7 @@ class ByteDanceImageToVideoNode(IO.ComfyNode):
|
|||||||
cls,
|
cls,
|
||||||
model: str,
|
model: str,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
image: torch.Tensor,
|
image: Input.Image,
|
||||||
resolution: str,
|
resolution: str,
|
||||||
aspect_ratio: str,
|
aspect_ratio: str,
|
||||||
duration: int,
|
duration: int,
|
||||||
@ -833,9 +685,8 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode):
|
|||||||
inputs=[
|
inputs=[
|
||||||
IO.Combo.Input(
|
IO.Combo.Input(
|
||||||
"model",
|
"model",
|
||||||
options=[model.value for model in Image2VideoModelName],
|
options=["seedance-1-0-pro-250528", "seedance-1-0-lite-i2v-250428"],
|
||||||
default=Image2VideoModelName.seedance_1_lite.value,
|
default="seedance-1-0-lite-i2v-250428",
|
||||||
tooltip="Model name",
|
|
||||||
),
|
),
|
||||||
IO.String.Input(
|
IO.String.Input(
|
||||||
"prompt",
|
"prompt",
|
||||||
@ -889,7 +740,7 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode):
|
|||||||
),
|
),
|
||||||
IO.Boolean.Input(
|
IO.Boolean.Input(
|
||||||
"watermark",
|
"watermark",
|
||||||
default=True,
|
default=False,
|
||||||
tooltip='Whether to add an "AI generated" watermark to the video.',
|
tooltip='Whether to add an "AI generated" watermark to the video.',
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
@ -910,8 +761,8 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode):
|
|||||||
cls,
|
cls,
|
||||||
model: str,
|
model: str,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
first_frame: torch.Tensor,
|
first_frame: Input.Image,
|
||||||
last_frame: torch.Tensor,
|
last_frame: Input.Image,
|
||||||
resolution: str,
|
resolution: str,
|
||||||
aspect_ratio: str,
|
aspect_ratio: str,
|
||||||
duration: int,
|
duration: int,
|
||||||
@ -968,9 +819,8 @@ class ByteDanceImageReferenceNode(IO.ComfyNode):
|
|||||||
inputs=[
|
inputs=[
|
||||||
IO.Combo.Input(
|
IO.Combo.Input(
|
||||||
"model",
|
"model",
|
||||||
options=[Image2VideoModelName.seedance_1_lite.value],
|
options=["seedance-1-0-pro-250528", "seedance-1-0-lite-i2v-250428"],
|
||||||
default=Image2VideoModelName.seedance_1_lite.value,
|
default="seedance-1-0-lite-i2v-250428",
|
||||||
tooltip="Model name",
|
|
||||||
),
|
),
|
||||||
IO.String.Input(
|
IO.String.Input(
|
||||||
"prompt",
|
"prompt",
|
||||||
@ -1013,7 +863,7 @@ class ByteDanceImageReferenceNode(IO.ComfyNode):
|
|||||||
),
|
),
|
||||||
IO.Boolean.Input(
|
IO.Boolean.Input(
|
||||||
"watermark",
|
"watermark",
|
||||||
default=True,
|
default=False,
|
||||||
tooltip='Whether to add an "AI generated" watermark to the video.',
|
tooltip='Whether to add an "AI generated" watermark to the video.',
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
@ -1034,7 +884,7 @@ class ByteDanceImageReferenceNode(IO.ComfyNode):
|
|||||||
cls,
|
cls,
|
||||||
model: str,
|
model: str,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
images: torch.Tensor,
|
images: Input.Image,
|
||||||
resolution: str,
|
resolution: str,
|
||||||
aspect_ratio: str,
|
aspect_ratio: str,
|
||||||
duration: int,
|
duration: int,
|
||||||
@ -1069,8 +919,8 @@ class ByteDanceImageReferenceNode(IO.ComfyNode):
|
|||||||
|
|
||||||
async def process_video_task(
|
async def process_video_task(
|
||||||
cls: type[IO.ComfyNode],
|
cls: type[IO.ComfyNode],
|
||||||
payload: Union[Text2VideoTaskCreationRequest, Image2VideoTaskCreationRequest],
|
payload: Text2VideoTaskCreationRequest | Image2VideoTaskCreationRequest,
|
||||||
estimated_duration: Optional[int],
|
estimated_duration: int | None,
|
||||||
) -> IO.NodeOutput:
|
) -> IO.NodeOutput:
|
||||||
initial_response = await sync_op(
|
initial_response = await sync_op(
|
||||||
cls,
|
cls,
|
||||||
@ -1085,7 +935,7 @@ async def process_video_task(
|
|||||||
estimated_duration=estimated_duration,
|
estimated_duration=estimated_duration,
|
||||||
response_model=TaskStatusResponse,
|
response_model=TaskStatusResponse,
|
||||||
)
|
)
|
||||||
return IO.NodeOutput(await download_url_to_video_output(get_video_url_from_task_status(response)))
|
return IO.NodeOutput(await download_url_to_video_output(response.content.video_url))
|
||||||
|
|
||||||
|
|
||||||
def raise_if_text_params(prompt: str, text_params: list[str]) -> None:
|
def raise_if_text_params(prompt: str, text_params: list[str]) -> None:
|
||||||
|
|||||||
@ -13,8 +13,7 @@ import torch
from typing_extensions import override

import folder_paths
from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api.latest import IO, ComfyExtension, Input, Types
from comfy_api.util import VideoCodec, VideoContainer
from comfy_api_nodes.apis.gemini_api import (
GeminiContent,
GeminiFileData,
@ -27,12 +26,15 @@ from comfy_api_nodes.apis.gemini_api import (
GeminiMimeType,
GeminiPart,
GeminiRole,
GeminiSystemInstructionContent,
GeminiTextPart,
Modality,
)
from comfy_api_nodes.util import (
ApiEndpoint,
audio_to_base64_string,
bytesio_to_image_tensor,
download_url_to_image_tensor,
get_number_of_images,
sync_op,
tensor_to_base64_string,
@ -43,6 +45,14 @@ from comfy_api_nodes.util import (

GEMINI_BASE_ENDPOINT = "/proxy/vertexai/gemini"
GEMINI_MAX_INPUT_FILE_SIZE = 20 * 1024 * 1024 # 20 MB
GEMINI_IMAGE_SYS_PROMPT = (
"You are an expert image-generation engine. You must ALWAYS produce an image.\n"
"Interpret all user input—regardless of "
"format, intent, or abstraction—as literal visual directives for image composition.\n"
"If a prompt is conversational or lacks specific visual details, "
"you must creatively invent a concrete visual scenario that depicts the concept.\n"
"Prioritize generating the visual representation above any text, formatting, or conversational requests."
)

class GeminiModel(str, Enum):
@ -68,7 +78,7 @@ class GeminiImageModel(str, Enum):

async def create_image_parts(
cls: type[IO.ComfyNode],
images: torch.Tensor,
images: Input.Image,
image_limit: int = 0,
) -> list[GeminiPart]:
image_parts: list[GeminiPart] = []
@ -132,9 +142,11 @@ def get_parts_by_type(response: GeminiGenerateContentResponse, part_type: Litera
)
parts = []
for part in response.candidates[0].content.parts:
if part_type == "text" and hasattr(part, "text") and part.text:
if part_type == "text" and part.text:
parts.append(part)
elif hasattr(part, "inlineData") and part.inlineData and part.inlineData.mimeType == part_type:
elif part.inlineData and part.inlineData.mimeType == part_type:
parts.append(part)
elif part.fileData and part.fileData.mimeType == part_type:
parts.append(part)
# Skip parts that don't match the requested type
return parts
@ -154,12 +166,15 @@ def get_text_from_response(response: GeminiGenerateContentResponse) -> str:
return "\n".join([part.text for part in parts])

def get_image_from_response(response: GeminiGenerateContentResponse) -> torch.Tensor:
async def get_image_from_response(response: GeminiGenerateContentResponse) -> Input.Image:
image_tensors: list[torch.Tensor] = []
image_tensors: list[Input.Image] = []
parts = get_parts_by_type(response, "image/png")
for part in parts:
image_data = base64.b64decode(part.inlineData.data)
if part.inlineData:
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
image_data = base64.b64decode(part.inlineData.data)
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
else:
returned_image = await download_url_to_image_tensor(part.fileData.fileUri)
image_tensors.append(returned_image)
if len(image_tensors) == 0:
return torch.zeros((1, 1024, 1024, 4))
@ -277,6 +292,13 @@ class GeminiNode(IO.ComfyNode):
tooltip="Optional file(s) to use as context for the model. "
"Accepts inputs from the Gemini Generate Content Input Files node.",
),
IO.String.Input(
"system_prompt",
multiline=True,
default="",
optional=True,
tooltip="Foundational instructions that dictate an AI's behavior.",
),
],
outputs=[
IO.String.Output(),
@ -293,7 +315,9 @@ class GeminiNode(IO.ComfyNode):
def create_video_parts(cls, video_input: Input.Video) -> list[GeminiPart]:
"""Convert video input to Gemini API compatible parts."""

base_64_string = video_to_base64_string(video_input, container_format=VideoContainer.MP4, codec=VideoCodec.H264)
base_64_string = video_to_base64_string(
video_input, container_format=Types.VideoContainer.MP4, codec=Types.VideoCodec.H264
)
return [
GeminiPart(
inlineData=GeminiInlineData(
@ -343,10 +367,11 @@ class GeminiNode(IO.ComfyNode):
prompt: str,
model: str,
seed: int,
images: torch.Tensor | None = None,
images: Input.Image | None = None,
audio: Input.Audio | None = None,
video: Input.Video | None = None,
files: list[GeminiPart] | None = None,
system_prompt: str = "",
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=False)

@ -363,7 +388,10 @@ class GeminiNode(IO.ComfyNode):
if files is not None:
parts.extend(files)

# Create response
gemini_system_prompt = None
if system_prompt:
gemini_system_prompt = GeminiSystemInstructionContent(parts=[GeminiTextPart(text=system_prompt)], role=None)

response = await sync_op(
cls,
endpoint=ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"),
@ -373,7 +401,8 @@ class GeminiNode(IO.ComfyNode):
role=GeminiRole.user,
parts=parts,
)
]
],
systemInstruction=gemini_system_prompt,
),
response_model=GeminiGenerateContentResponse,
price_extractor=calculate_tokens_price,
@ -523,6 +552,13 @@ class GeminiImage(IO.ComfyNode):
"'IMAGE+TEXT' to return both the generated image and a text response.",
optional=True,
),
IO.String.Input(
"system_prompt",
multiline=True,
default=GEMINI_IMAGE_SYS_PROMPT,
optional=True,
tooltip="Foundational instructions that dictate an AI's behavior.",
),
],
outputs=[
IO.Image.Output(),
@ -542,10 +578,11 @@ class GeminiImage(IO.ComfyNode):
prompt: str,
model: str,
seed: int,
images: torch.Tensor | None = None,
images: Input.Image | None = None,
files: list[GeminiPart] | None = None,
aspect_ratio: str = "auto",
response_modalities: str = "IMAGE+TEXT",
system_prompt: str = "",
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=True, min_length=1)
parts: list[GeminiPart] = [GeminiPart(text=prompt)]
@ -559,9 +596,13 @@ class GeminiImage(IO.ComfyNode):
if files is not None:
parts.extend(files)

gemini_system_prompt = None
if system_prompt:
gemini_system_prompt = GeminiSystemInstructionContent(parts=[GeminiTextPart(text=system_prompt)], role=None)

response = await sync_op(
cls,
endpoint=ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"),
ApiEndpoint(path=f"/proxy/vertexai/gemini/{model}", method="POST"),
data=GeminiImageGenerateContentRequest(
contents=[
GeminiContent(role=GeminiRole.user, parts=parts),
@ -570,11 +611,12 @@ class GeminiImage(IO.ComfyNode):
responseModalities=(["IMAGE"] if response_modalities == "IMAGE" else ["TEXT", "IMAGE"]),
imageConfig=None if aspect_ratio == "auto" else image_config,
),
systemInstruction=gemini_system_prompt,
),
response_model=GeminiGenerateContentResponse,
price_extractor=calculate_tokens_price,
)
return IO.NodeOutput(get_image_from_response(response), get_text_from_response(response))
return IO.NodeOutput(await get_image_from_response(response), get_text_from_response(response))

class GeminiImage2(IO.ComfyNode):
@ -640,6 +682,13 @@ class GeminiImage2(IO.ComfyNode):
tooltip="Optional file(s) to use as context for the model. "
"Accepts inputs from the Gemini Generate Content Input Files node.",
),
IO.String.Input(
"system_prompt",
multiline=True,
default=GEMINI_IMAGE_SYS_PROMPT,
optional=True,
tooltip="Foundational instructions that dictate an AI's behavior.",
),
],
outputs=[
IO.Image.Output(),
@ -662,8 +711,9 @@ class GeminiImage2(IO.ComfyNode):
aspect_ratio: str,
resolution: str,
response_modalities: str,
images: torch.Tensor | None = None,
images: Input.Image | None = None,
files: list[GeminiPart] | None = None,
system_prompt: str = "",
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=True, min_length=1)

@ -679,9 +729,13 @@ class GeminiImage2(IO.ComfyNode):
if aspect_ratio != "auto":
image_config.aspectRatio = aspect_ratio

gemini_system_prompt = None
if system_prompt:
gemini_system_prompt = GeminiSystemInstructionContent(parts=[GeminiTextPart(text=system_prompt)], role=None)

response = await sync_op(
cls,
ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"),
ApiEndpoint(path=f"/proxy/vertexai/gemini/{model}", method="POST"),
data=GeminiImageGenerateContentRequest(
contents=[
GeminiContent(role=GeminiRole.user, parts=parts),
@ -690,11 +744,12 @@ class GeminiImage2(IO.ComfyNode):
responseModalities=(["IMAGE"] if response_modalities == "IMAGE" else ["TEXT", "IMAGE"]),
imageConfig=image_config,
),
systemInstruction=gemini_system_prompt,
),
response_model=GeminiGenerateContentResponse,
price_extractor=calculate_tokens_price,
)
return IO.NodeOutput(get_image_from_response(response), get_text_from_response(response))
return IO.NodeOutput(await get_image_from_response(response), get_text_from_response(response))

class GeminiExtension(ComfyExtension):

@ -50,6 +50,8 @@ from comfy_api_nodes.apis import (
KlingSingleImageEffectModelName,
)
from comfy_api_nodes.apis.kling_api import (
ImageToVideoWithAudioRequest,
MotionControlRequest,
OmniImageParamImage,
OmniParamImage,
OmniParamVideo,
@ -57,7 +59,8 @@ from comfy_api_nodes.apis.kling_api import (
OmniProImageRequest,
OmniProReferences2VideoRequest,
OmniProText2VideoRequest,
OmniTaskStatusResponse,
TaskStatusResponse,
TextToVideoWithAudioRequest,
)
from comfy_api_nodes.util import (
ApiEndpoint,
@ -103,10 +106,6 @@ AVERAGE_DURATION_VIDEO_EXTEND = 320

MODE_TEXT2VIDEO = {
"standard mode / 5s duration / kling-v1": ("std", "5", "kling-v1"),
"standard mode / 10s duration / kling-v1": ("std", "10", "kling-v1"),
"pro mode / 5s duration / kling-v1": ("pro", "5", "kling-v1"),
"pro mode / 10s duration / kling-v1": ("pro", "10", "kling-v1"),
"standard mode / 5s duration / kling-v1-6": ("std", "5", "kling-v1-6"),
"standard mode / 10s duration / kling-v1-6": ("std", "10", "kling-v1-6"),
"pro mode / 5s duration / kling-v2-master": ("pro", "5", "kling-v2-master"),
@ -127,8 +126,6 @@ See: [Kling API Docs Capability Map](https://app.klingai.com/global/dev/document

MODE_START_END_FRAME = {
"standard mode / 5s duration / kling-v1": ("std", "5", "kling-v1"),
"pro mode / 5s duration / kling-v1": ("pro", "5", "kling-v1"),
"pro mode / 5s duration / kling-v1-5": ("pro", "5", "kling-v1-5"),
"pro mode / 10s duration / kling-v1-5": ("pro", "10", "kling-v1-5"),
"pro mode / 5s duration / kling-v1-6": ("pro", "5", "kling-v1-6"),
@ -242,7 +239,7 @@ def normalize_omni_prompt_references(prompt: str) -> str:
return re.sub(r"(?<!\w)@video(?P<idx>\d*)(?!\w)", _video_repl, prompt)

async def finish_omni_video_task(cls: type[IO.ComfyNode], response: OmniTaskStatusResponse) -> IO.NodeOutput:
async def finish_omni_video_task(cls: type[IO.ComfyNode], response: TaskStatusResponse) -> IO.NodeOutput:
if response.code:
raise RuntimeError(
f"Kling request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}"
@ -250,7 +247,7 @@ async def finish_omni_video_task(cls: type[IO.ComfyNode], response: OmniTaskStat
final_response = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/kling/v1/videos/omni-video/{response.data.task_id}"),
response_model=OmniTaskStatusResponse,
response_model=TaskStatusResponse,
status_extractor=lambda r: (r.data.task_status if r.data else None),
max_poll_attempts=160,
)
@ -483,12 +480,12 @@ async def execute_image2video(
task_id = task_creation_response.data.task_id

final_response = await poll_op(
cls,
ApiEndpoint(path=f"{PATH_IMAGE_TO_VIDEO}/{task_id}"),
response_model=KlingImage2VideoResponse,
estimated_duration=AVERAGE_DURATION_I2V,
status_extractor=lambda r: (r.data.task_status.value if r.data and r.data.task_status else None),
)
validate_video_result_response(final_response)

video = get_video_from_response(final_response)
@ -752,7 +749,7 @@ class KlingTextToVideoNode(IO.ComfyNode):
IO.Combo.Input(
"mode",
options=modes,
default=modes[4],
default=modes[8],
tooltip="The configuration to use for the video generation following the format: mode / duration / model_name.",
),
],
@ -810,6 +807,7 @@ class OmniProTextToVideoNode(IO.ComfyNode):
),
IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]),
IO.Combo.Input("duration", options=[5, 10]),
IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True),
],
outputs=[
IO.Video.Output(),
@ -829,17 +827,19 @@ class OmniProTextToVideoNode(IO.ComfyNode):
prompt: str,
aspect_ratio: str,
duration: int,
resolution: str = "1080p",
) -> IO.NodeOutput:
validate_string(prompt, min_length=1, max_length=2500)
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
response_model=OmniTaskStatusResponse,
response_model=TaskStatusResponse,
data=OmniProText2VideoRequest(
model_name=model_name,
prompt=prompt,
aspect_ratio=aspect_ratio,
duration=str(duration),
mode="pro" if resolution == "1080p" else "std",
),
)
return await finish_omni_video_task(cls, response)
@ -862,7 +862,7 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
tooltip="A text prompt describing the video content. "
"This can include both positive and negative descriptions.",
),
IO.Combo.Input("duration", options=["5", "10"]),
IO.Int.Input("duration", default=5, min=3, max=10, display_mode=IO.NumberDisplay.slider),
IO.Image.Input("first_frame"),
IO.Image.Input(
"end_frame",
@ -875,6 +875,7 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
optional=True,
tooltip="Up to 6 additional reference images.",
),
IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True),
],
outputs=[
IO.Video.Output(),
@ -896,11 +897,16 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
first_frame: Input.Image,
end_frame: Input.Image | None = None,
reference_images: Input.Image | None = None,
resolution: str = "1080p",
) -> IO.NodeOutput:
prompt = normalize_omni_prompt_references(prompt)
validate_string(prompt, min_length=1, max_length=2500)
if end_frame is not None and reference_images is not None:
raise ValueError("The 'end_frame' input cannot be used simultaneously with 'reference_images'.")
if duration not in (5, 10) and end_frame is None and reference_images is None:
raise ValueError(
"Duration is only supported for 5 or 10 seconds if there is no end frame or reference images."
)
validate_image_dimensions(first_frame, min_width=300, min_height=300)
validate_image_aspect_ratio(first_frame, (1, 2.5), (2.5, 1))
image_list: list[OmniParamImage] = [
@ -929,12 +935,13 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
response_model=OmniTaskStatusResponse,
response_model=TaskStatusResponse,
data=OmniProFirstLastFrameRequest(
model_name=model_name,
prompt=prompt,
duration=str(duration),
image_list=image_list,
mode="pro" if resolution == "1080p" else "std",
),
)
return await finish_omni_video_task(cls, response)
@ -963,6 +970,7 @@ class OmniProImageToVideoNode(IO.ComfyNode):
"reference_images",
tooltip="Up to 7 reference images.",
),
IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True),
],
outputs=[
IO.Video.Output(),
@ -983,6 +991,7 @@ class OmniProImageToVideoNode(IO.ComfyNode):
aspect_ratio: str,
duration: int,
reference_images: Input.Image,
resolution: str = "1080p",
) -> IO.NodeOutput:
prompt = normalize_omni_prompt_references(prompt)
validate_string(prompt, min_length=1, max_length=2500)
@ -997,13 +1006,14 @@ class OmniProImageToVideoNode(IO.ComfyNode):
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
response_model=OmniTaskStatusResponse,
response_model=TaskStatusResponse,
data=OmniProReferences2VideoRequest(
model_name=model_name,
prompt=prompt,
aspect_ratio=aspect_ratio,
duration=str(duration),
image_list=image_list,
mode="pro" if resolution == "1080p" else "std",
),
)
return await finish_omni_video_task(cls, response)
@ -1035,6 +1045,7 @@ class OmniProVideoToVideoNode(IO.ComfyNode):
tooltip="Up to 4 additional reference images.",
optional=True,
),
IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True),
],
outputs=[
IO.Video.Output(),
@ -1057,6 +1068,7 @@ class OmniProVideoToVideoNode(IO.ComfyNode):
reference_video: Input.Video,
keep_original_sound: bool,
reference_images: Input.Image | None = None,
resolution: str = "1080p",
) -> IO.NodeOutput:
prompt = normalize_omni_prompt_references(prompt)
validate_string(prompt, min_length=1, max_length=2500)
@ -1081,7 +1093,7 @@ class OmniProVideoToVideoNode(IO.ComfyNode):
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
response_model=OmniTaskStatusResponse,
response_model=TaskStatusResponse,
data=OmniProReferences2VideoRequest(
model_name=model_name,
prompt=prompt,
@ -1089,6 +1101,7 @@ class OmniProVideoToVideoNode(IO.ComfyNode):
duration=str(duration),
image_list=image_list if image_list else None,
video_list=video_list,
mode="pro" if resolution == "1080p" else "std",
),
)
return await finish_omni_video_task(cls, response)
@ -1118,6 +1131,7 @@ class OmniProEditVideoNode(IO.ComfyNode):
tooltip="Up to 4 additional reference images.",
optional=True,
),
IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True),
],
outputs=[
IO.Video.Output(),
@ -1138,6 +1152,7 @@ class OmniProEditVideoNode(IO.ComfyNode):
video: Input.Video,
keep_original_sound: bool,
reference_images: Input.Image | None = None,
resolution: str = "1080p",
) -> IO.NodeOutput:
prompt = normalize_omni_prompt_references(prompt)
validate_string(prompt, min_length=1, max_length=2500)
@ -1162,7 +1177,7 @@ class OmniProEditVideoNode(IO.ComfyNode):
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
response_model=OmniTaskStatusResponse,
response_model=TaskStatusResponse,
data=OmniProReferences2VideoRequest(
model_name=model_name,
prompt=prompt,
@ -1170,6 +1185,7 @@ class OmniProEditVideoNode(IO.ComfyNode):
duration=None,
image_list=image_list if image_list else None,
video_list=video_list,
mode="pro" if resolution == "1080p" else "std",
),
)
return await finish_omni_video_task(cls, response)
@ -1237,7 +1253,7 @@ class OmniProImageNode(IO.ComfyNode):
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/kling/v1/images/omni-image", method="POST"),
response_model=OmniTaskStatusResponse,
response_model=TaskStatusResponse,
data=OmniProImageRequest(
model_name=model_name,
prompt=prompt,
@ -1253,7 +1269,7 @@ class OmniProImageNode(IO.ComfyNode):
final_response = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/kling/v1/images/omni-image/{response.data.task_id}"),
response_model=OmniTaskStatusResponse,
response_model=TaskStatusResponse,
status_extractor=lambda r: (r.data.task_status if r.data else None),
)
return IO.NodeOutput(await download_url_to_image_tensor(final_response.data.task_result.images[0].url))
@ -1328,9 +1344,8 @@ class KlingImage2VideoNode(IO.ComfyNode):
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="KlingImage2VideoNode",
display_name="Kling Image to Video",
display_name="Kling Image(First Frame) to Video",
category="api node/video/Kling",
description="Kling Image to Video Node",
inputs=[
IO.Image.Input("start_frame", tooltip="The reference image used to generate the video."),
IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt"),
@ -1488,7 +1503,7 @@ class KlingStartEndFrameNode(IO.ComfyNode):
IO.Combo.Input(
"mode",
options=modes,
default=modes[8],
default=modes[6],
tooltip="The configuration to use for the video generation following the format: mode / duration / model_name.",
),
],
@ -1951,7 +1966,7 @@ class KlingImageGenerationNode(IO.ComfyNode):
IO.Combo.Input(
"model_name",
options=[i.value for i in KlingImageGenModelName],
default="kling-v1",
default="kling-v2",
),
IO.Combo.Input(
"aspect_ratio",
@ -2034,6 +2049,221 @@ class KlingImageGenerationNode(IO.ComfyNode):
return IO.NodeOutput(await image_result_to_node_output(images))


class TextToVideoWithAudio(IO.ComfyNode):

@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="KlingTextToVideoWithAudio",
display_name="Kling Text to Video with Audio",
category="api node/video/Kling",
inputs=[
IO.Combo.Input("model_name", options=["kling-v2-6"]),
IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt."),
IO.Combo.Input("mode", options=["pro"]),
IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]),
IO.Combo.Input("duration", options=[5, 10]),
IO.Boolean.Input("generate_audio", default=True),
],
outputs=[
IO.Video.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)

@classmethod
async def execute(
cls,
model_name: str,
prompt: str,
mode: str,
aspect_ratio: str,
duration: int,
generate_audio: bool,
) -> IO.NodeOutput:
validate_string(prompt, min_length=1, max_length=2500)
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/kling/v1/videos/text2video", method="POST"),
response_model=TaskStatusResponse,
data=TextToVideoWithAudioRequest(
model_name=model_name,
prompt=prompt,
mode=mode,
aspect_ratio=aspect_ratio,
duration=str(duration),
sound="on" if generate_audio else "off",
),
)
if response.code:
raise RuntimeError(
f"Kling request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}"
)
final_response = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/kling/v1/videos/text2video/{response.data.task_id}"),
response_model=TaskStatusResponse,
status_extractor=lambda r: (r.data.task_status if r.data else None),
)
return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url))


class ImageToVideoWithAudio(IO.ComfyNode):

@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="KlingImageToVideoWithAudio",
display_name="Kling Image(First Frame) to Video with Audio",
category="api node/video/Kling",
inputs=[
IO.Combo.Input("model_name", options=["kling-v2-6"]),
IO.Image.Input("start_frame"),
IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt."),
IO.Combo.Input("mode", options=["pro"]),
IO.Combo.Input("duration", options=[5, 10]),
IO.Boolean.Input("generate_audio", default=True),
],
outputs=[
IO.Video.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)

@classmethod
async def execute(
cls,
model_name: str,
start_frame: Input.Image,
prompt: str,
mode: str,
duration: int,
generate_audio: bool,
) -> IO.NodeOutput:
validate_string(prompt, min_length=1, max_length=2500)
validate_image_dimensions(start_frame, min_width=300, min_height=300)
validate_image_aspect_ratio(start_frame, (1, 2.5), (2.5, 1))
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/kling/v1/videos/image2video", method="POST"),
response_model=TaskStatusResponse,
data=ImageToVideoWithAudioRequest(
model_name=model_name,
image=(await upload_images_to_comfyapi(cls, start_frame))[0],
prompt=prompt,
mode=mode,
duration=str(duration),
sound="on" if generate_audio else "off",
),
)
if response.code:
raise RuntimeError(
f"Kling request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}"
)
final_response = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/kling/v1/videos/image2video/{response.data.task_id}"),
response_model=TaskStatusResponse,
status_extractor=lambda r: (r.data.task_status if r.data else None),
)
return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url))


class MotionControl(IO.ComfyNode):

@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="KlingMotionControl",
display_name="Kling Motion Control",
category="api node/video/Kling",
inputs=[
IO.String.Input("prompt", multiline=True),
IO.Image.Input("reference_image"),
IO.Video.Input(
"reference_video",
tooltip="Motion reference video used to drive movement/expression.\n"
"Duration limits depend on character_orientation:\n"
" - image: 3–10s (max 10s)\n"
" - video: 3–30s (max 30s)",
),
IO.Boolean.Input("keep_original_sound", default=True),
IO.Combo.Input(
"character_orientation",
options=["video", "image"],
tooltip="Controls where the character's facing/orientation comes from.\n"
"video: movements, expressions, camera moves, and orientation "
"follow the motion reference video (other details via prompt).\n"
"image: movements and expressions still follow the motion reference video, "
"but the character orientation matches the reference image (camera/other details via prompt).",
),
IO.Combo.Input("mode", options=["pro", "std"]),
],
outputs=[
IO.Video.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)

@classmethod
async def execute(
cls,
prompt: str,
reference_image: Input.Image,
reference_video: Input.Video,
keep_original_sound: bool,
character_orientation: str,
mode: str,
) -> IO.NodeOutput:
validate_string(prompt, max_length=2500)
validate_image_dimensions(reference_image, min_width=340, min_height=340)
validate_image_aspect_ratio(reference_image, (1, 2.5), (2.5, 1))
if character_orientation == "image":
validate_video_duration(reference_video, min_duration=3, max_duration=10)
else:
validate_video_duration(reference_video, min_duration=3, max_duration=30)
validate_video_dimensions(reference_video, min_width=340, min_height=340, max_width=3850, max_height=3850)
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/kling/v1/videos/motion-control", method="POST"),
response_model=TaskStatusResponse,
data=MotionControlRequest(
prompt=prompt,
image_url=(await upload_images_to_comfyapi(cls, reference_image))[0],
video_url=await upload_video_to_comfyapi(cls, reference_video),
keep_original_sound="yes" if keep_original_sound else "no",
character_orientation=character_orientation,
mode=mode,
),
)
if response.code:
raise RuntimeError(
f"Kling request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}"
)
final_response = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/kling/v1/videos/motion-control/{response.data.task_id}"),
response_model=TaskStatusResponse,
status_extractor=lambda r: (r.data.task_status if r.data else None),
)
return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url))


class KlingExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
@ -2056,7 +2286,10 @@ class KlingExtension(ComfyExtension):
OmniProImageToVideoNode,
OmniProVideoToVideoNode,
OmniProEditVideoNode,
# OmniProImageNode, # need support from backend
OmniProImageNode,
TextToVideoWithAudio,
ImageToVideoWithAudio,
MotionControl,
]

@ -1,12 +1,9 @@
from io import BytesIO
from typing import Optional

import torch
from pydantic import BaseModel, Field
from typing_extensions import override

from comfy_api.input_impl import VideoFromFile
from comfy_api.latest import IO, ComfyExtension, Input, InputImpl
from comfy_api.latest import IO, ComfyExtension
from comfy_api_nodes.util import (
ApiEndpoint,
get_number_of_images,
@ -26,9 +23,9 @@ class ExecuteTaskRequest(BaseModel):
model: str = Field(...)
duration: int = Field(...)
resolution: str = Field(...)
fps: Optional[int] = Field(25)
fps: int | None = Field(25)
generate_audio: Optional[bool] = Field(True)
generate_audio: bool | None = Field(True)
image_uri: Optional[str] = Field(None)
image_uri: str | None = Field(None)

class TextToVideoNode(IO.ComfyNode):
@ -103,7 +100,7 @@ class TextToVideoNode(IO.ComfyNode):
as_binary=True,
max_retries=1,
)
return IO.NodeOutput(VideoFromFile(BytesIO(response)))
return IO.NodeOutput(InputImpl.VideoFromFile(BytesIO(response)))

class ImageToVideoNode(IO.ComfyNode):
@ -153,7 +150,7 @@ class ImageToVideoNode(IO.ComfyNode):
@classmethod
async def execute(
cls,
image: torch.Tensor,
image: Input.Image,
model: str,
prompt: str,
duration: int,
@ -183,7 +180,7 @@ class ImageToVideoNode(IO.ComfyNode):
as_binary=True,
max_retries=1,
)
return IO.NodeOutput(VideoFromFile(BytesIO(response)))
return IO.NodeOutput(InputImpl.VideoFromFile(BytesIO(response)))

class LtxvApiExtension(ComfyExtension):

@ -1,11 +1,8 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
|
|
||||||
from comfy_api.input import VideoInput
|
from comfy_api.latest import IO, ComfyExtension, Input
|
||||||
from comfy_api.latest import IO, ComfyExtension
|
|
||||||
from comfy_api_nodes.apis import (
|
from comfy_api_nodes.apis import (
|
||||||
MoonvalleyPromptResponse,
|
MoonvalleyPromptResponse,
|
||||||
MoonvalleyTextToVideoInferenceParams,
|
MoonvalleyTextToVideoInferenceParams,
|
||||||
@ -61,7 +58,7 @@ def validate_task_creation_response(response) -> None:
|
|||||||
raise RuntimeError(error_msg)
|
raise RuntimeError(error_msg)
|
||||||
|
|
||||||
|
|
||||||
def validate_video_to_video_input(video: VideoInput) -> VideoInput:
|
def validate_video_to_video_input(video: Input.Video) -> Input.Video:
|
||||||
"""
|
"""
|
||||||
Validates and processes video input for Moonvalley Video-to-Video generation.
|
Validates and processes video input for Moonvalley Video-to-Video generation.
|
||||||
|
|
||||||
@ -82,7 +79,7 @@ def validate_video_to_video_input(video: VideoInput) -> VideoInput:
|
|||||||
return _validate_and_trim_duration(video)
|
return _validate_and_trim_duration(video)
|
||||||
|
|
||||||
|
|
||||||
def _get_video_dimensions(video: VideoInput) -> tuple[int, int]:
|
def _get_video_dimensions(video: Input.Video) -> tuple[int, int]:
|
||||||
"""Extracts video dimensions with error handling."""
|
"""Extracts video dimensions with error handling."""
|
||||||
try:
|
try:
|
||||||
return video.get_dimensions()
|
return video.get_dimensions()
|
||||||
@ -106,7 +103,7 @@ def _validate_video_dimensions(width: int, height: int) -> None:
|
|||||||
raise ValueError(f"Resolution {width}x{height} not supported. Supported: {supported_list}")
|
raise ValueError(f"Resolution {width}x{height} not supported. Supported: {supported_list}")
|
||||||
|
|
||||||
|
|
||||||
def _validate_and_trim_duration(video: VideoInput) -> VideoInput:
|
def _validate_and_trim_duration(video: Input.Video) -> Input.Video:
|
||||||
"""Validates video duration and trims to 5 seconds if needed."""
|
"""Validates video duration and trims to 5 seconds if needed."""
|
||||||
duration = video.get_duration()
|
duration = video.get_duration()
|
||||||
_validate_minimum_duration(duration)
|
_validate_minimum_duration(duration)
|
||||||
@ -119,7 +116,7 @@ def _validate_minimum_duration(duration: float) -> None:
|
|||||||
raise ValueError("Input video must be at least 5 seconds long.")
|
raise ValueError("Input video must be at least 5 seconds long.")
|
||||||
|
|
||||||
|
|
||||||
def _trim_if_too_long(video: VideoInput, duration: float) -> VideoInput:
|
def _trim_if_too_long(video: Input.Video, duration: float) -> Input.Video:
|
||||||
"""Trims video to 5 seconds if longer."""
|
"""Trims video to 5 seconds if longer."""
|
||||||
if duration > 5:
|
if duration > 5:
|
||||||
return trim_video(video, 5)
|
return trim_video(video, 5)
|
||||||
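Taken together, the helpers above enforce Moonvalley's video-to-video input contract: the resolution must be on the supported list, the clip must be at least 5 seconds long, and anything longer is trimmed down to 5 seconds before upload. The duration rule restated as a small sketch (trim_clip is a stand-in for the real trim_video utility):

def enforce_duration(duration_s: float, clip, trim_clip):
    # Reject too-short clips, silently shorten too-long ones.
    if duration_s < 5:
        raise ValueError("Input video must be at least 5 seconds long.")
    if duration_s > 5:
        return trim_clip(clip, 5)  # keep only the first 5 seconds
    return clip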
@ -241,7 +238,7 @@ class MoonvalleyImg2VideoNode(IO.ComfyNode):
|
|||||||
@classmethod
|
@classmethod
|
||||||
async def execute(
|
async def execute(
|
||||||
cls,
|
cls,
|
||||||
image: torch.Tensor,
|
image: Input.Image,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
negative_prompt: str,
|
negative_prompt: str,
|
||||||
resolution: str,
|
resolution: str,
|
||||||
@ -362,9 +359,9 @@ class MoonvalleyVideo2VideoNode(IO.ComfyNode):
|
|||||||
prompt: str,
|
prompt: str,
|
||||||
negative_prompt: str,
|
negative_prompt: str,
|
||||||
seed: int,
|
seed: int,
|
||||||
video: Optional[VideoInput] = None,
|
video: Input.Video | None = None,
|
||||||
control_type: str = "Motion Transfer",
|
control_type: str = "Motion Transfer",
|
||||||
motion_intensity: Optional[int] = 100,
|
motion_intensity: int | None = 100,
|
||||||
steps=33,
|
steps=33,
|
||||||
prompt_adherence=4.5,
|
prompt_adherence=4.5,
|
||||||
) -> IO.NodeOutput:
|
) -> IO.NodeOutput:
|
||||||
|
|||||||
@ -1,46 +1,45 @@
|
|||||||
from io import BytesIO
|
import base64
|
||||||
import os
|
import os
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from inspect import cleandoc
|
from io import BytesIO
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
import folder_paths
|
|
||||||
import base64
|
|
||||||
from comfy_api.latest import IO, ComfyExtension
|
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
|
|
||||||
|
import folder_paths
|
||||||
|
from comfy_api.latest import IO, ComfyExtension, Input
|
||||||
from comfy_api_nodes.apis import (
|
from comfy_api_nodes.apis import (
|
||||||
OpenAIImageGenerationRequest,
|
|
||||||
OpenAIImageEditRequest,
|
|
||||||
OpenAIImageGenerationResponse,
|
|
||||||
OpenAICreateResponse,
|
|
||||||
OpenAIResponse,
|
|
||||||
CreateModelResponseProperties,
|
CreateModelResponseProperties,
|
||||||
Item,
|
|
||||||
OutputContent,
|
|
||||||
InputImageContent,
|
|
||||||
Detail,
|
Detail,
|
||||||
InputTextContent,
|
|
||||||
InputMessage,
|
|
||||||
InputMessageContentList,
|
|
||||||
InputContent,
|
InputContent,
|
||||||
InputFileContent,
|
InputFileContent,
|
||||||
|
InputImageContent,
|
||||||
|
InputMessage,
|
||||||
|
InputMessageContentList,
|
||||||
|
InputTextContent,
|
||||||
|
Item,
|
||||||
|
OpenAICreateResponse,
|
||||||
|
OpenAIResponse,
|
||||||
|
OutputContent,
|
||||||
|
)
|
||||||
|
from comfy_api_nodes.apis.openai_api import (
|
||||||
|
OpenAIImageEditRequest,
|
||||||
|
OpenAIImageGenerationRequest,
|
||||||
|
OpenAIImageGenerationResponse,
|
||||||
)
|
)
|
||||||
|
|
||||||
from comfy_api_nodes.util import (
|
from comfy_api_nodes.util import (
|
||||||
downscale_image_tensor,
|
|
||||||
download_url_to_bytesio,
|
|
||||||
validate_string,
|
|
||||||
tensor_to_base64_string,
|
|
||||||
ApiEndpoint,
|
ApiEndpoint,
|
||||||
sync_op,
|
download_url_to_bytesio,
|
||||||
|
downscale_image_tensor,
|
||||||
poll_op,
|
poll_op,
|
||||||
|
sync_op,
|
||||||
|
tensor_to_base64_string,
|
||||||
text_filepath_to_data_uri,
|
text_filepath_to_data_uri,
|
||||||
|
validate_string,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
RESPONSES_ENDPOINT = "/proxy/openai/v1/responses"
|
RESPONSES_ENDPOINT = "/proxy/openai/v1/responses"
|
||||||
STARTING_POINT_ID_PATTERN = r"<starting_point_id:(.*)>"
|
STARTING_POINT_ID_PATTERN = r"<starting_point_id:(.*)>"
|
||||||
|
|
||||||
@ -98,9 +97,6 @@ async def validate_and_cast_response(response, timeout: int = None) -> torch.Ten
|
|||||||
|
|
||||||
|
|
||||||
class OpenAIDalle2(IO.ComfyNode):
|
class OpenAIDalle2(IO.ComfyNode):
|
||||||
"""
|
|
||||||
Generates images synchronously via OpenAI's DALL·E 2 endpoint.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
@ -108,7 +104,7 @@ class OpenAIDalle2(IO.ComfyNode):
|
|||||||
node_id="OpenAIDalle2",
|
node_id="OpenAIDalle2",
|
||||||
display_name="OpenAI DALL·E 2",
|
display_name="OpenAI DALL·E 2",
|
||||||
category="api node/image/OpenAI",
|
category="api node/image/OpenAI",
|
||||||
description=cleandoc(cls.__doc__ or ""),
|
description="Generates images synchronously via OpenAI's DALL·E 2 endpoint.",
|
||||||
inputs=[
|
inputs=[
|
||||||
IO.String.Input(
|
IO.String.Input(
|
||||||
"prompt",
|
"prompt",
|
||||||
@ -234,9 +230,6 @@ class OpenAIDalle2(IO.ComfyNode):
|
|||||||
|
|
||||||
|
|
||||||
class OpenAIDalle3(IO.ComfyNode):
|
class OpenAIDalle3(IO.ComfyNode):
|
||||||
"""
|
|
||||||
Generates images synchronously via OpenAI's DALL·E 3 endpoint.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
@ -244,7 +237,7 @@ class OpenAIDalle3(IO.ComfyNode):
|
|||||||
node_id="OpenAIDalle3",
|
node_id="OpenAIDalle3",
|
||||||
display_name="OpenAI DALL·E 3",
|
display_name="OpenAI DALL·E 3",
|
||||||
category="api node/image/OpenAI",
|
category="api node/image/OpenAI",
|
||||||
description=cleandoc(cls.__doc__ or ""),
|
description="Generates images synchronously via OpenAI's DALL·E 3 endpoint.",
|
||||||
inputs=[
|
inputs=[
|
||||||
IO.String.Input(
|
IO.String.Input(
|
||||||
"prompt",
|
"prompt",
|
||||||
@ -326,10 +319,16 @@ class OpenAIDalle3(IO.ComfyNode):
|
|||||||
return IO.NodeOutput(await validate_and_cast_response(response))
|
return IO.NodeOutput(await validate_and_cast_response(response))
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_tokens_price_image_1(response: OpenAIImageGenerationResponse) -> float | None:
|
||||||
|
# https://platform.openai.com/docs/pricing
|
||||||
|
return ((response.usage.input_tokens * 10.0) + (response.usage.output_tokens * 40.0)) / 1_000_000.0
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_tokens_price_image_1_5(response: OpenAIImageGenerationResponse) -> float | None:
|
||||||
|
return ((response.usage.input_tokens * 8.0) + (response.usage.output_tokens * 32.0)) / 1_000_000.0
|
||||||
|
|
||||||
|
|
||||||
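The two helpers above convert the token usage reported by the API into a dollar estimate: gpt-image-1 is billed at $10 per million input tokens and $40 per million output tokens, gpt-image-1.5 at $8 and $32. The same arithmetic as a self-contained example (illustrative only; the real extractors read the counts from response.usage):

def image_1_price(input_tokens: int, output_tokens: int) -> float:
    # gpt-image-1: $10 per 1M input tokens, $40 per 1M output tokens
    return (input_tokens * 10.0 + output_tokens * 40.0) / 1_000_000.0

print(image_1_price(500, 4000))  # 0.165 -> roughly 16.5 cents for one generation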
class OpenAIGPTImage1(IO.ComfyNode):
|
class OpenAIGPTImage1(IO.ComfyNode):
|
||||||
"""
|
|
||||||
Generates images synchronously via OpenAI's GPT Image 1 endpoint.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
@ -337,13 +336,13 @@ class OpenAIGPTImage1(IO.ComfyNode):
|
|||||||
node_id="OpenAIGPTImage1",
|
node_id="OpenAIGPTImage1",
|
||||||
display_name="OpenAI GPT Image 1",
|
display_name="OpenAI GPT Image 1",
|
||||||
category="api node/image/OpenAI",
|
category="api node/image/OpenAI",
|
||||||
description=cleandoc(cls.__doc__ or ""),
|
description="Generates images synchronously via OpenAI's GPT Image 1 endpoint.",
|
||||||
inputs=[
|
inputs=[
|
||||||
IO.String.Input(
|
IO.String.Input(
|
||||||
"prompt",
|
"prompt",
|
||||||
default="",
|
default="",
|
||||||
multiline=True,
|
multiline=True,
|
||||||
tooltip="Text prompt for GPT Image 1",
|
tooltip="Text prompt for GPT Image",
|
||||||
),
|
),
|
||||||
IO.Int.Input(
|
IO.Int.Input(
|
||||||
"seed",
|
"seed",
|
||||||
@ -365,8 +364,8 @@ class OpenAIGPTImage1(IO.ComfyNode):
|
|||||||
),
|
),
|
||||||
IO.Combo.Input(
|
IO.Combo.Input(
|
||||||
"background",
|
"background",
|
||||||
default="opaque",
|
default="auto",
|
||||||
options=["opaque", "transparent"],
|
options=["auto", "opaque", "transparent"],
|
||||||
tooltip="Return image with or without background",
|
tooltip="Return image with or without background",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
@ -397,6 +396,11 @@ class OpenAIGPTImage1(IO.ComfyNode):
|
|||||||
tooltip="Optional mask for inpainting (white areas will be replaced)",
|
tooltip="Optional mask for inpainting (white areas will be replaced)",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
|
IO.Combo.Input(
|
||||||
|
"model",
|
||||||
|
options=["gpt-image-1", "gpt-image-1.5"],
|
||||||
|
optional=True,
|
||||||
|
),
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
IO.Image.Output(),
|
IO.Image.Output(),
|
||||||
@ -412,32 +416,34 @@ class OpenAIGPTImage1(IO.ComfyNode):
|
|||||||
@classmethod
|
@classmethod
|
||||||
async def execute(
|
async def execute(
|
||||||
cls,
|
cls,
|
||||||
prompt,
|
prompt: str,
|
||||||
seed=0,
|
seed: int = 0,
|
||||||
quality="low",
|
quality: str = "low",
|
||||||
background="opaque",
|
background: str = "opaque",
|
||||||
image=None,
|
image: Input.Image | None = None,
|
||||||
mask=None,
|
mask: Input.Image | None = None,
|
||||||
n=1,
|
n: int = 1,
|
||||||
size="1024x1024",
|
size: str = "1024x1024",
|
||||||
|
model: str = "gpt-image-1",
|
||||||
) -> IO.NodeOutput:
|
) -> IO.NodeOutput:
|
||||||
validate_string(prompt, strip_whitespace=False)
|
validate_string(prompt, strip_whitespace=False)
|
||||||
model = "gpt-image-1"
|
|
||||||
path = "/proxy/openai/images/generations"
|
if mask is not None and image is None:
|
||||||
content_type = "application/json"
|
raise ValueError("Cannot use a mask without an input image")
|
||||||
request_class = OpenAIImageGenerationRequest
|
|
||||||
files = []
|
if model == "gpt-image-1":
|
||||||
|
price_extractor = calculate_tokens_price_image_1
|
||||||
|
elif model == "gpt-image-1.5":
|
||||||
|
price_extractor = calculate_tokens_price_image_1_5
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown model: {model}")
|
||||||
|
|
||||||
if image is not None:
|
if image is not None:
|
||||||
path = "/proxy/openai/images/edits"
|
files = []
|
||||||
request_class = OpenAIImageEditRequest
|
|
||||||
content_type = "multipart/form-data"
|
|
||||||
|
|
||||||
batch_size = image.shape[0]
|
batch_size = image.shape[0]
|
||||||
|
|
||||||
for i in range(batch_size):
|
for i in range(batch_size):
|
||||||
single_image = image[i : i + 1]
|
single_image = image[i: i + 1]
|
||||||
scaled_image = downscale_image_tensor(single_image).squeeze()
|
scaled_image = downscale_image_tensor(single_image, total_pixels=2048*2048).squeeze()
|
||||||
|
|
||||||
image_np = (scaled_image.numpy() * 255).astype(np.uint8)
|
image_np = (scaled_image.numpy() * 255).astype(np.uint8)
|
||||||
img = Image.fromarray(image_np)
|
img = Image.fromarray(image_np)
|
||||||
@ -450,44 +456,59 @@ class OpenAIGPTImage1(IO.ComfyNode):
|
|||||||
else:
|
else:
|
||||||
files.append(("image[]", (f"image_{i}.png", img_byte_arr, "image/png")))
|
files.append(("image[]", (f"image_{i}.png", img_byte_arr, "image/png")))
|
||||||
|
|
||||||
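Each frame is passed through downscale_image_tensor with a total_pixels budget of 2048*2048 (about 4.2 megapixels) before being PNG-encoded, so oversized inputs shrink while small ones pass through untouched. The implied scaling rule, as a sketch (the real helper lives in comfy_api_nodes.util and may differ in detail):

import math

def fit_to_pixel_budget(width: int, height: int, total_pixels: int = 2048 * 2048) -> tuple[int, int]:
    # Shrink (never enlarge) so that width * height <= total_pixels, preserving aspect ratio.
    scale = math.sqrt(total_pixels / (width * height))
    if scale >= 1.0:
        return width, height
    return max(1, int(width * scale)), max(1, int(height * scale))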
if mask is not None:
|
if mask is not None:
|
||||||
if image is None:
|
if image.shape[0] != 1:
|
||||||
raise Exception("Cannot use a mask without an input image")
|
raise Exception("Cannot use a mask with multiple image")
|
||||||
if image.shape[0] != 1:
|
if mask.shape[1:] != image.shape[1:-1]:
|
||||||
raise Exception("Cannot use a mask with multiple image")
|
raise Exception("Mask and Image must be the same size")
|
||||||
if mask.shape[1:] != image.shape[1:-1]:
|
_, height, width = mask.shape
|
||||||
raise Exception("Mask and Image must be the same size")
|
rgba_mask = torch.zeros(height, width, 4, device="cpu")
|
||||||
batch, height, width = mask.shape
|
rgba_mask[:, :, 3] = 1 - mask.squeeze().cpu()
|
||||||
rgba_mask = torch.zeros(height, width, 4, device="cpu")
|
|
||||||
rgba_mask[:, :, 3] = 1 - mask.squeeze().cpu()
|
|
||||||
|
|
||||||
scaled_mask = downscale_image_tensor(rgba_mask.unsqueeze(0)).squeeze()
|
scaled_mask = downscale_image_tensor(rgba_mask.unsqueeze(0), total_pixels=2048*2048).squeeze()
|
||||||
|
|
||||||
mask_np = (scaled_mask.numpy() * 255).astype(np.uint8)
|
mask_np = (scaled_mask.numpy() * 255).astype(np.uint8)
|
||||||
mask_img = Image.fromarray(mask_np)
|
mask_img = Image.fromarray(mask_np)
|
||||||
mask_img_byte_arr = BytesIO()
|
mask_img_byte_arr = BytesIO()
|
||||||
mask_img.save(mask_img_byte_arr, format="PNG")
|
mask_img.save(mask_img_byte_arr, format="PNG")
|
||||||
mask_img_byte_arr.seek(0)
|
mask_img_byte_arr.seek(0)
|
||||||
files.append(("mask", ("mask.png", mask_img_byte_arr, "image/png")))
|
files.append(("mask", ("mask.png", mask_img_byte_arr, "image/png")))
|
||||||
|
|
||||||
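The mask branch above packs the ComfyUI mask into the alpha channel of an RGBA image, inverted so that masked (value 1) pixels become fully transparent; the OpenAI edits endpoint treats transparent pixels as the region to repaint. The core transform in isolation, using the same (1, H, W) mask layout as the node:

import torch

def mask_to_rgba_alpha(mask: torch.Tensor) -> torch.Tensor:
    # mask: (1, H, W) float tensor, 1.0 where the image should be repainted.
    _, height, width = mask.shape
    rgba = torch.zeros(height, width, 4, device="cpu")
    rgba[:, :, 3] = 1 - mask.squeeze().cpu()  # transparent where mask == 1
    return rgba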
# Build the operation
|
|
||||||
response = await sync_op(
|
|
||||||
cls,
|
|
||||||
ApiEndpoint(path=path, method="POST"),
|
|
||||||
response_model=OpenAIImageGenerationResponse,
|
|
||||||
data=request_class(
|
|
||||||
model=model,
|
|
||||||
prompt=prompt,
|
|
||||||
quality=quality,
|
|
||||||
background=background,
|
|
||||||
n=n,
|
|
||||||
seed=seed,
|
|
||||||
size=size,
|
|
||||||
),
|
|
||||||
files=files if files else None,
|
|
||||||
content_type=content_type,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
response = await sync_op(
|
||||||
|
cls,
|
||||||
|
ApiEndpoint(path="/proxy/openai/images/edits", method="POST"),
|
||||||
|
response_model=OpenAIImageGenerationResponse,
|
||||||
|
data=OpenAIImageEditRequest(
|
||||||
|
model=model,
|
||||||
|
prompt=prompt,
|
||||||
|
quality=quality,
|
||||||
|
background=background,
|
||||||
|
n=n,
|
||||||
|
seed=seed,
|
||||||
|
size=size,
|
||||||
|
moderation="low",
|
||||||
|
),
|
||||||
|
content_type="multipart/form-data",
|
||||||
|
files=files,
|
||||||
|
price_extractor=price_extractor,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
response = await sync_op(
|
||||||
|
cls,
|
||||||
|
ApiEndpoint(path="/proxy/openai/images/generations", method="POST"),
|
||||||
|
response_model=OpenAIImageGenerationResponse,
|
||||||
|
data=OpenAIImageGenerationRequest(
|
||||||
|
model=model,
|
||||||
|
prompt=prompt,
|
||||||
|
quality=quality,
|
||||||
|
background=background,
|
||||||
|
n=n,
|
||||||
|
seed=seed,
|
||||||
|
size=size,
|
||||||
|
moderation="low",
|
||||||
|
),
|
||||||
|
price_extractor=price_extractor,
|
||||||
|
)
|
||||||
return IO.NodeOutput(await validate_and_cast_response(response))
|
return IO.NodeOutput(await validate_and_cast_response(response))
|
||||||
|
|
||||||
|
|
||||||
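With this rewrite the node no longer mutates a shared set of request variables; it simply branches on whether an input image is present. An attached image sends an OpenAIImageEditRequest to /proxy/openai/images/edits as multipart/form-data (with the encoded frames and optional mask), while a text-only prompt sends an OpenAIImageGenerationRequest to /proxy/openai/images/generations as a plain JSON body. The routing decision, reduced to a sketch:

def choose_endpoint(has_image: bool) -> tuple[str, str]:
    # Returns (path, content type) for the outgoing request.
    if has_image:
        return "/proxy/openai/images/edits", "multipart/form-data"
    return "/proxy/openai/images/generations", "application/json"

Either branch passes the model-specific price_extractor selected earlier, so cost reporting works for both gpt-image-1 and gpt-image-1.5.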
|
|||||||
@ -1,575 +0,0 @@
|
|||||||
"""
|
|
||||||
Pika x ComfyUI API Nodes
|
|
||||||
|
|
||||||
Pika API docs: https://pika-827374fb.mintlify.app/api-reference
|
|
||||||
"""
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from io import BytesIO
|
|
||||||
import logging
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from typing_extensions import override
|
|
||||||
from comfy_api.latest import ComfyExtension, IO
|
|
||||||
from comfy_api.input_impl.video_types import VideoCodec, VideoContainer, VideoInput
|
|
||||||
from comfy_api_nodes.apis import pika_api as pika_defs
|
|
||||||
from comfy_api_nodes.util import (
|
|
||||||
validate_string,
|
|
||||||
download_url_to_video_output,
|
|
||||||
tensor_to_bytesio,
|
|
||||||
ApiEndpoint,
|
|
||||||
sync_op,
|
|
||||||
poll_op,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
PATH_PIKADDITIONS = "/proxy/pika/generate/pikadditions"
|
|
||||||
PATH_PIKASWAPS = "/proxy/pika/generate/pikaswaps"
|
|
||||||
PATH_PIKAFFECTS = "/proxy/pika/generate/pikaffects"
|
|
||||||
|
|
||||||
PIKA_API_VERSION = "2.2"
|
|
||||||
PATH_TEXT_TO_VIDEO = f"/proxy/pika/generate/{PIKA_API_VERSION}/t2v"
|
|
||||||
PATH_IMAGE_TO_VIDEO = f"/proxy/pika/generate/{PIKA_API_VERSION}/i2v"
|
|
||||||
PATH_PIKAFRAMES = f"/proxy/pika/generate/{PIKA_API_VERSION}/pikaframes"
|
|
||||||
PATH_PIKASCENES = f"/proxy/pika/generate/{PIKA_API_VERSION}/pikascenes"
|
|
||||||
|
|
||||||
PATH_VIDEO_GET = "/proxy/pika/videos"
|
|
||||||
|
|
||||||
|
|
||||||
async def execute_task(
|
|
||||||
task_id: str,
|
|
||||||
cls: type[IO.ComfyNode],
|
|
||||||
) -> IO.NodeOutput:
|
|
||||||
final_response: pika_defs.PikaVideoResponse = await poll_op(
|
|
||||||
cls,
|
|
||||||
ApiEndpoint(path=f"{PATH_VIDEO_GET}/{task_id}"),
|
|
||||||
response_model=pika_defs.PikaVideoResponse,
|
|
||||||
status_extractor=lambda response: (response.status.value if response.status else None),
|
|
||||||
progress_extractor=lambda response: (response.progress if hasattr(response, "progress") else None),
|
|
||||||
estimated_duration=60,
|
|
||||||
max_poll_attempts=240,
|
|
||||||
)
|
|
||||||
if not final_response.url:
|
|
||||||
error_msg = f"Pika task {task_id} succeeded but no video data found in response:\n{final_response}"
|
|
||||||
logging.error(error_msg)
|
|
||||||
raise Exception(error_msg)
|
|
||||||
video_url = final_response.url
|
|
||||||
logging.info("Pika task %s succeeded. Video URL: %s", task_id, video_url)
|
|
||||||
return IO.NodeOutput(await download_url_to_video_output(video_url))
|
|
||||||
|
|
||||||
|
|
||||||
def get_base_inputs_types() -> list[IO.Input]:
|
|
||||||
"""Get the base required inputs types common to all Pika nodes."""
|
|
||||||
return [
|
|
||||||
IO.String.Input("prompt_text", multiline=True),
|
|
||||||
IO.String.Input("negative_prompt", multiline=True),
|
|
||||||
IO.Int.Input("seed", min=0, max=0xFFFFFFFF, control_after_generate=True),
|
|
||||||
IO.Combo.Input("resolution", options=["1080p", "720p"], default="1080p"),
|
|
||||||
IO.Combo.Input("duration", options=[5, 10], default=5),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
class PikaImageToVideo(IO.ComfyNode):
|
|
||||||
"""Pika 2.2 Image to Video Node."""
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def define_schema(cls) -> IO.Schema:
|
|
||||||
return IO.Schema(
|
|
||||||
node_id="PikaImageToVideoNode2_2",
|
|
||||||
display_name="Pika Image to Video",
|
|
||||||
description="Sends an image and prompt to the Pika API v2.2 to generate a video.",
|
|
||||||
category="api node/video/Pika",
|
|
||||||
inputs=[
|
|
||||||
IO.Image.Input("image", tooltip="The image to convert to video"),
|
|
||||||
*get_base_inputs_types(),
|
|
||||||
],
|
|
||||||
outputs=[IO.Video.Output()],
|
|
||||||
hidden=[
|
|
||||||
IO.Hidden.auth_token_comfy_org,
|
|
||||||
IO.Hidden.api_key_comfy_org,
|
|
||||||
IO.Hidden.unique_id,
|
|
||||||
],
|
|
||||||
is_api_node=True,
|
|
||||||
is_deprecated=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def execute(
|
|
||||||
cls,
|
|
||||||
image: torch.Tensor,
|
|
||||||
prompt_text: str,
|
|
||||||
negative_prompt: str,
|
|
||||||
seed: int,
|
|
||||||
resolution: str,
|
|
||||||
duration: int,
|
|
||||||
) -> IO.NodeOutput:
|
|
||||||
image_bytes_io = tensor_to_bytesio(image)
|
|
||||||
pika_files = {"image": ("image.png", image_bytes_io, "image/png")}
|
|
||||||
pika_request_data = pika_defs.PikaBodyGenerate22I2vGenerate22I2vPost(
|
|
||||||
promptText=prompt_text,
|
|
||||||
negativePrompt=negative_prompt,
|
|
||||||
seed=seed,
|
|
||||||
resolution=resolution,
|
|
||||||
duration=duration,
|
|
||||||
)
|
|
||||||
initial_operation = await sync_op(
|
|
||||||
cls,
|
|
||||||
ApiEndpoint(path=PATH_IMAGE_TO_VIDEO, method="POST"),
|
|
||||||
response_model=pika_defs.PikaGenerateResponse,
|
|
||||||
data=pika_request_data,
|
|
||||||
files=pika_files,
|
|
||||||
content_type="multipart/form-data",
|
|
||||||
)
|
|
||||||
return await execute_task(initial_operation.video_id, cls)
|
|
||||||
|
|
||||||
|
|
||||||
class PikaTextToVideoNode(IO.ComfyNode):
|
|
||||||
"""Pika Text2Video v2.2 Node."""
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def define_schema(cls) -> IO.Schema:
|
|
||||||
return IO.Schema(
|
|
||||||
node_id="PikaTextToVideoNode2_2",
|
|
||||||
display_name="Pika Text to Video",
|
|
||||||
description="Sends a text prompt to the Pika API v2.2 to generate a video.",
|
|
||||||
category="api node/video/Pika",
|
|
||||||
inputs=[
|
|
||||||
*get_base_inputs_types(),
|
|
||||||
IO.Float.Input(
|
|
||||||
"aspect_ratio",
|
|
||||||
step=0.001,
|
|
||||||
min=0.4,
|
|
||||||
max=2.5,
|
|
||||||
default=1.7777777777777777,
|
|
||||||
tooltip="Aspect ratio (width / height)",
|
|
||||||
)
|
|
||||||
],
|
|
||||||
outputs=[IO.Video.Output()],
|
|
||||||
hidden=[
|
|
||||||
IO.Hidden.auth_token_comfy_org,
|
|
||||||
IO.Hidden.api_key_comfy_org,
|
|
||||||
IO.Hidden.unique_id,
|
|
||||||
],
|
|
||||||
is_api_node=True,
|
|
||||||
is_deprecated=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def execute(
|
|
||||||
cls,
|
|
||||||
prompt_text: str,
|
|
||||||
negative_prompt: str,
|
|
||||||
seed: int,
|
|
||||||
resolution: str,
|
|
||||||
duration: int,
|
|
||||||
aspect_ratio: float,
|
|
||||||
) -> IO.NodeOutput:
|
|
||||||
initial_operation = await sync_op(
|
|
||||||
cls,
|
|
||||||
ApiEndpoint(path=PATH_TEXT_TO_VIDEO, method="POST"),
|
|
||||||
response_model=pika_defs.PikaGenerateResponse,
|
|
||||||
data=pika_defs.PikaBodyGenerate22T2vGenerate22T2vPost(
|
|
||||||
promptText=prompt_text,
|
|
||||||
negativePrompt=negative_prompt,
|
|
||||||
seed=seed,
|
|
||||||
resolution=resolution,
|
|
||||||
duration=duration,
|
|
||||||
aspectRatio=aspect_ratio,
|
|
||||||
),
|
|
||||||
content_type="application/x-www-form-urlencoded",
|
|
||||||
)
|
|
||||||
return await execute_task(initial_operation.video_id, cls)
|
|
||||||
|
|
||||||
|
|
||||||
class PikaScenes(IO.ComfyNode):
|
|
||||||
"""PikaScenes v2.2 Node."""
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def define_schema(cls) -> IO.Schema:
|
|
||||||
return IO.Schema(
|
|
||||||
node_id="PikaScenesV2_2",
|
|
||||||
display_name="Pika Scenes (Video Image Composition)",
|
|
||||||
description="Combine your images to create a video with the objects in them. Upload multiple images as ingredients and generate a high-quality video that incorporates all of them.",
|
|
||||||
category="api node/video/Pika",
|
|
||||||
inputs=[
|
|
||||||
*get_base_inputs_types(),
|
|
||||||
IO.Combo.Input(
|
|
||||||
"ingredients_mode",
|
|
||||||
options=["creative", "precise"],
|
|
||||||
default="creative",
|
|
||||||
),
|
|
||||||
IO.Float.Input(
|
|
||||||
"aspect_ratio",
|
|
||||||
step=0.001,
|
|
||||||
min=0.4,
|
|
||||||
max=2.5,
|
|
||||||
default=1.7777777777777777,
|
|
||||||
tooltip="Aspect ratio (width / height)",
|
|
||||||
),
|
|
||||||
IO.Image.Input(
|
|
||||||
"image_ingredient_1",
|
|
||||||
optional=True,
|
|
||||||
tooltip="Image that will be used as ingredient to create a video.",
|
|
||||||
),
|
|
||||||
IO.Image.Input(
|
|
||||||
"image_ingredient_2",
|
|
||||||
optional=True,
|
|
||||||
tooltip="Image that will be used as ingredient to create a video.",
|
|
||||||
),
|
|
||||||
IO.Image.Input(
|
|
||||||
"image_ingredient_3",
|
|
||||||
optional=True,
|
|
||||||
tooltip="Image that will be used as ingredient to create a video.",
|
|
||||||
),
|
|
||||||
IO.Image.Input(
|
|
||||||
"image_ingredient_4",
|
|
||||||
optional=True,
|
|
||||||
tooltip="Image that will be used as ingredient to create a video.",
|
|
||||||
),
|
|
||||||
IO.Image.Input(
|
|
||||||
"image_ingredient_5",
|
|
||||||
optional=True,
|
|
||||||
tooltip="Image that will be used as ingredient to create a video.",
|
|
||||||
),
|
|
||||||
],
|
|
||||||
outputs=[IO.Video.Output()],
|
|
||||||
hidden=[
|
|
||||||
IO.Hidden.auth_token_comfy_org,
|
|
||||||
IO.Hidden.api_key_comfy_org,
|
|
||||||
IO.Hidden.unique_id,
|
|
||||||
],
|
|
||||||
is_api_node=True,
|
|
||||||
is_deprecated=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def execute(
|
|
||||||
cls,
|
|
||||||
prompt_text: str,
|
|
||||||
negative_prompt: str,
|
|
||||||
seed: int,
|
|
||||||
resolution: str,
|
|
||||||
duration: int,
|
|
||||||
ingredients_mode: str,
|
|
||||||
aspect_ratio: float,
|
|
||||||
image_ingredient_1: Optional[torch.Tensor] = None,
|
|
||||||
image_ingredient_2: Optional[torch.Tensor] = None,
|
|
||||||
image_ingredient_3: Optional[torch.Tensor] = None,
|
|
||||||
image_ingredient_4: Optional[torch.Tensor] = None,
|
|
||||||
image_ingredient_5: Optional[torch.Tensor] = None,
|
|
||||||
) -> IO.NodeOutput:
|
|
||||||
all_image_bytes_io = []
|
|
||||||
for image in [
|
|
||||||
image_ingredient_1,
|
|
||||||
image_ingredient_2,
|
|
||||||
image_ingredient_3,
|
|
||||||
image_ingredient_4,
|
|
||||||
image_ingredient_5,
|
|
||||||
]:
|
|
||||||
if image is not None:
|
|
||||||
all_image_bytes_io.append(tensor_to_bytesio(image))
|
|
||||||
|
|
||||||
pika_files = [
|
|
||||||
("images", (f"image_{i}.png", image_bytes_io, "image/png"))
|
|
||||||
for i, image_bytes_io in enumerate(all_image_bytes_io)
|
|
||||||
]
|
|
||||||
|
|
||||||
pika_request_data = pika_defs.PikaBodyGenerate22C2vGenerate22PikascenesPost(
|
|
||||||
ingredientsMode=ingredients_mode,
|
|
||||||
promptText=prompt_text,
|
|
||||||
negativePrompt=negative_prompt,
|
|
||||||
seed=seed,
|
|
||||||
resolution=resolution,
|
|
||||||
duration=duration,
|
|
||||||
aspectRatio=aspect_ratio,
|
|
||||||
)
|
|
||||||
initial_operation = await sync_op(
|
|
||||||
cls,
|
|
||||||
ApiEndpoint(path=PATH_PIKASCENES, method="POST"),
|
|
||||||
response_model=pika_defs.PikaGenerateResponse,
|
|
||||||
data=pika_request_data,
|
|
||||||
files=pika_files,
|
|
||||||
content_type="multipart/form-data",
|
|
||||||
)
|
|
||||||
|
|
||||||
return await execute_task(initial_operation.video_id, cls)
|
|
||||||
|
|
||||||
|
|
||||||
class PikAdditionsNode(IO.ComfyNode):
|
|
||||||
"""Pika Pikadditions Node. Add an image into a video."""
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def define_schema(cls) -> IO.Schema:
|
|
||||||
return IO.Schema(
|
|
||||||
node_id="Pikadditions",
|
|
||||||
display_name="Pikadditions (Video Object Insertion)",
|
|
||||||
description="Add any object or image into your video. Upload a video and specify what you'd like to add to create a seamlessly integrated result.",
|
|
||||||
category="api node/video/Pika",
|
|
||||||
inputs=[
|
|
||||||
IO.Video.Input("video", tooltip="The video to add an image to."),
|
|
||||||
IO.Image.Input("image", tooltip="The image to add to the video."),
|
|
||||||
IO.String.Input("prompt_text", multiline=True),
|
|
||||||
IO.String.Input("negative_prompt", multiline=True),
|
|
||||||
IO.Int.Input(
|
|
||||||
"seed",
|
|
||||||
min=0,
|
|
||||||
max=0xFFFFFFFF,
|
|
||||||
control_after_generate=True,
|
|
||||||
),
|
|
||||||
],
|
|
||||||
outputs=[IO.Video.Output()],
|
|
||||||
hidden=[
|
|
||||||
IO.Hidden.auth_token_comfy_org,
|
|
||||||
IO.Hidden.api_key_comfy_org,
|
|
||||||
IO.Hidden.unique_id,
|
|
||||||
],
|
|
||||||
is_api_node=True,
|
|
||||||
is_deprecated=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def execute(
|
|
||||||
cls,
|
|
||||||
video: VideoInput,
|
|
||||||
image: torch.Tensor,
|
|
||||||
prompt_text: str,
|
|
||||||
negative_prompt: str,
|
|
||||||
seed: int,
|
|
||||||
) -> IO.NodeOutput:
|
|
||||||
video_bytes_io = BytesIO()
|
|
||||||
video.save_to(video_bytes_io, format=VideoContainer.MP4, codec=VideoCodec.H264)
|
|
||||||
video_bytes_io.seek(0)
|
|
||||||
|
|
||||||
image_bytes_io = tensor_to_bytesio(image)
|
|
||||||
pika_files = {
|
|
||||||
"video": ("video.mp4", video_bytes_io, "video/mp4"),
|
|
||||||
"image": ("image.png", image_bytes_io, "image/png"),
|
|
||||||
}
|
|
||||||
pika_request_data = pika_defs.PikaBodyGeneratePikadditionsGeneratePikadditionsPost(
|
|
||||||
promptText=prompt_text,
|
|
||||||
negativePrompt=negative_prompt,
|
|
||||||
seed=seed,
|
|
||||||
)
|
|
||||||
initial_operation = await sync_op(
|
|
||||||
cls,
|
|
||||||
ApiEndpoint(path=PATH_PIKADDITIONS, method="POST"),
|
|
||||||
response_model=pika_defs.PikaGenerateResponse,
|
|
||||||
data=pika_request_data,
|
|
||||||
files=pika_files,
|
|
||||||
content_type="multipart/form-data",
|
|
||||||
)
|
|
||||||
|
|
||||||
return await execute_task(initial_operation.video_id, cls)
|
|
||||||
|
|
||||||
|
|
||||||
class PikaSwapsNode(IO.ComfyNode):
|
|
||||||
"""Pika Pikaswaps Node."""
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def define_schema(cls) -> IO.Schema:
|
|
||||||
return IO.Schema(
|
|
||||||
node_id="Pikaswaps",
|
|
||||||
display_name="Pika Swaps (Video Object Replacement)",
|
|
||||||
description="Swap out any object or region of your video with a new image or object. Define areas to replace either with a mask or coordinates.",
|
|
||||||
category="api node/video/Pika",
|
|
||||||
inputs=[
|
|
||||||
IO.Video.Input("video", tooltip="The video to swap an object in."),
|
|
||||||
IO.Image.Input(
|
|
||||||
"image",
|
|
||||||
tooltip="The image used to replace the masked object in the video.",
|
|
||||||
optional=True,
|
|
||||||
),
|
|
||||||
IO.Mask.Input(
|
|
||||||
"mask",
|
|
||||||
tooltip="Use the mask to define areas in the video to replace.",
|
|
||||||
optional=True,
|
|
||||||
),
|
|
||||||
IO.String.Input("prompt_text", multiline=True, optional=True),
|
|
||||||
IO.String.Input("negative_prompt", multiline=True, optional=True),
|
|
||||||
IO.Int.Input("seed", min=0, max=0xFFFFFFFF, control_after_generate=True, optional=True),
|
|
||||||
IO.String.Input(
|
|
||||||
"region_to_modify",
|
|
||||||
multiline=True,
|
|
||||||
optional=True,
|
|
||||||
tooltip="Plaintext description of the object / region to modify.",
|
|
||||||
),
|
|
||||||
],
|
|
||||||
outputs=[IO.Video.Output()],
|
|
||||||
hidden=[
|
|
||||||
IO.Hidden.auth_token_comfy_org,
|
|
||||||
IO.Hidden.api_key_comfy_org,
|
|
||||||
IO.Hidden.unique_id,
|
|
||||||
],
|
|
||||||
is_api_node=True,
|
|
||||||
is_deprecated=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def execute(
|
|
||||||
cls,
|
|
||||||
video: VideoInput,
|
|
||||||
image: Optional[torch.Tensor] = None,
|
|
||||||
mask: Optional[torch.Tensor] = None,
|
|
||||||
prompt_text: str = "",
|
|
||||||
negative_prompt: str = "",
|
|
||||||
seed: int = 0,
|
|
||||||
region_to_modify: str = "",
|
|
||||||
) -> IO.NodeOutput:
|
|
||||||
video_bytes_io = BytesIO()
|
|
||||||
video.save_to(video_bytes_io, format=VideoContainer.MP4, codec=VideoCodec.H264)
|
|
||||||
video_bytes_io.seek(0)
|
|
||||||
pika_files = {
|
|
||||||
"video": ("video.mp4", video_bytes_io, "video/mp4"),
|
|
||||||
}
|
|
||||||
if mask is not None:
|
|
||||||
pika_files["modifyRegionMask"] = ("mask.png", tensor_to_bytesio(mask), "image/png")
|
|
||||||
if image is not None:
|
|
||||||
pika_files["image"] = ("image.png", tensor_to_bytesio(image), "image/png")
|
|
||||||
|
|
||||||
pika_request_data = pika_defs.PikaBodyGeneratePikaswapsGeneratePikaswapsPost(
|
|
||||||
promptText=prompt_text,
|
|
||||||
negativePrompt=negative_prompt,
|
|
||||||
seed=seed,
|
|
||||||
modifyRegionRoi=region_to_modify if region_to_modify else None,
|
|
||||||
)
|
|
||||||
initial_operation = await sync_op(
|
|
||||||
cls,
|
|
||||||
ApiEndpoint(path=PATH_PIKASWAPS, method="POST"),
|
|
||||||
response_model=pika_defs.PikaGenerateResponse,
|
|
||||||
data=pika_request_data,
|
|
||||||
files=pika_files,
|
|
||||||
content_type="multipart/form-data",
|
|
||||||
)
|
|
||||||
return await execute_task(initial_operation.video_id, cls)
|
|
||||||
|
|
||||||
|
|
||||||
class PikaffectsNode(IO.ComfyNode):
|
|
||||||
"""Pika Pikaffects Node."""
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def define_schema(cls) -> IO.Schema:
|
|
||||||
return IO.Schema(
|
|
||||||
node_id="Pikaffects",
|
|
||||||
display_name="Pikaffects (Video Effects)",
|
|
||||||
description="Generate a video with a specific Pikaffect. Supported Pikaffects: Cake-ify, Crumble, Crush, Decapitate, Deflate, Dissolve, Explode, Eye-pop, Inflate, Levitate, Melt, Peel, Poke, Squish, Ta-da, Tear",
|
|
||||||
category="api node/video/Pika",
|
|
||||||
inputs=[
|
|
||||||
IO.Image.Input("image", tooltip="The reference image to apply the Pikaffect to."),
|
|
||||||
IO.Combo.Input(
|
|
||||||
"pikaffect", options=pika_defs.Pikaffect, default="Cake-ify"
|
|
||||||
),
|
|
||||||
IO.String.Input("prompt_text", multiline=True),
|
|
||||||
IO.String.Input("negative_prompt", multiline=True),
|
|
||||||
IO.Int.Input("seed", min=0, max=0xFFFFFFFF, control_after_generate=True),
|
|
||||||
],
|
|
||||||
outputs=[IO.Video.Output()],
|
|
||||||
hidden=[
|
|
||||||
IO.Hidden.auth_token_comfy_org,
|
|
||||||
IO.Hidden.api_key_comfy_org,
|
|
||||||
IO.Hidden.unique_id,
|
|
||||||
],
|
|
||||||
is_api_node=True,
|
|
||||||
is_deprecated=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def execute(
|
|
||||||
cls,
|
|
||||||
image: torch.Tensor,
|
|
||||||
pikaffect: str,
|
|
||||||
prompt_text: str,
|
|
||||||
negative_prompt: str,
|
|
||||||
seed: int,
|
|
||||||
) -> IO.NodeOutput:
|
|
||||||
initial_operation = await sync_op(
|
|
||||||
cls,
|
|
||||||
ApiEndpoint(path=PATH_PIKAFFECTS, method="POST"),
|
|
||||||
response_model=pika_defs.PikaGenerateResponse,
|
|
||||||
data=pika_defs.PikaBodyGeneratePikaffectsGeneratePikaffectsPost(
|
|
||||||
pikaffect=pikaffect,
|
|
||||||
promptText=prompt_text,
|
|
||||||
negativePrompt=negative_prompt,
|
|
||||||
seed=seed,
|
|
||||||
),
|
|
||||||
files={"image": ("image.png", tensor_to_bytesio(image), "image/png")},
|
|
||||||
content_type="multipart/form-data",
|
|
||||||
)
|
|
||||||
return await execute_task(initial_operation.video_id, cls)
|
|
||||||
|
|
||||||
|
|
||||||
class PikaStartEndFrameNode(IO.ComfyNode):
|
|
||||||
"""PikaFrames v2.2 Node."""
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def define_schema(cls) -> IO.Schema:
|
|
||||||
return IO.Schema(
|
|
||||||
node_id="PikaStartEndFrameNode2_2",
|
|
||||||
display_name="Pika Start and End Frame to Video",
|
|
||||||
description="Generate a video by combining your first and last frame. Upload two images to define the start and end points, and let the AI create a smooth transition between them.",
|
|
||||||
category="api node/video/Pika",
|
|
||||||
inputs=[
|
|
||||||
IO.Image.Input("image_start", tooltip="The first image to combine."),
|
|
||||||
IO.Image.Input("image_end", tooltip="The last image to combine."),
|
|
||||||
*get_base_inputs_types(),
|
|
||||||
],
|
|
||||||
outputs=[IO.Video.Output()],
|
|
||||||
hidden=[
|
|
||||||
IO.Hidden.auth_token_comfy_org,
|
|
||||||
IO.Hidden.api_key_comfy_org,
|
|
||||||
IO.Hidden.unique_id,
|
|
||||||
],
|
|
||||||
is_api_node=True,
|
|
||||||
is_deprecated=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def execute(
|
|
||||||
cls,
|
|
||||||
image_start: torch.Tensor,
|
|
||||||
image_end: torch.Tensor,
|
|
||||||
prompt_text: str,
|
|
||||||
negative_prompt: str,
|
|
||||||
seed: int,
|
|
||||||
resolution: str,
|
|
||||||
duration: int,
|
|
||||||
) -> IO.NodeOutput:
|
|
||||||
validate_string(prompt_text, field_name="prompt_text", min_length=1)
|
|
||||||
pika_files = [
|
|
||||||
("keyFrames", ("image_start.png", tensor_to_bytesio(image_start), "image/png")),
|
|
||||||
("keyFrames", ("image_end.png", tensor_to_bytesio(image_end), "image/png")),
|
|
||||||
]
|
|
||||||
initial_operation = await sync_op(
|
|
||||||
cls,
|
|
||||||
ApiEndpoint(path=PATH_PIKAFRAMES, method="POST"),
|
|
||||||
response_model=pika_defs.PikaGenerateResponse,
|
|
||||||
data=pika_defs.PikaBodyGenerate22KeyframeGenerate22PikaframesPost(
|
|
||||||
promptText=prompt_text,
|
|
||||||
negativePrompt=negative_prompt,
|
|
||||||
seed=seed,
|
|
||||||
resolution=resolution,
|
|
||||||
duration=duration,
|
|
||||||
),
|
|
||||||
files=pika_files,
|
|
||||||
content_type="multipart/form-data",
|
|
||||||
)
|
|
||||||
return await execute_task(initial_operation.video_id, cls)
|
|
||||||
|
|
||||||
|
|
||||||
class PikaApiNodesExtension(ComfyExtension):
|
|
||||||
@override
|
|
||||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
|
||||||
return [
|
|
||||||
PikaImageToVideo,
|
|
||||||
PikaTextToVideoNode,
|
|
||||||
PikaScenes,
|
|
||||||
PikAdditionsNode,
|
|
||||||
PikaSwapsNode,
|
|
||||||
PikaffectsNode,
|
|
||||||
PikaStartEndFrameNode,
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
async def comfy_entrypoint() -> PikaApiNodesExtension:
|
|
||||||
return PikaApiNodesExtension()
|
|
||||||
@ -11,12 +11,11 @@ User Guides:
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Union, Optional
|
|
||||||
from typing_extensions import override
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
||||||
import torch
|
from typing_extensions import override
|
||||||
|
|
||||||
|
from comfy_api.latest import IO, ComfyExtension, Input, InputImpl
|
||||||
from comfy_api_nodes.apis import (
|
from comfy_api_nodes.apis import (
|
||||||
RunwayImageToVideoRequest,
|
RunwayImageToVideoRequest,
|
||||||
RunwayImageToVideoResponse,
|
RunwayImageToVideoResponse,
|
||||||
@ -44,8 +43,6 @@ from comfy_api_nodes.util import (
|
|||||||
sync_op,
|
sync_op,
|
||||||
poll_op,
|
poll_op,
|
||||||
)
|
)
|
||||||
from comfy_api.input_impl import VideoFromFile
|
|
||||||
from comfy_api.latest import ComfyExtension, IO
|
|
||||||
|
|
||||||
PATH_IMAGE_TO_VIDEO = "/proxy/runway/image_to_video"
|
PATH_IMAGE_TO_VIDEO = "/proxy/runway/image_to_video"
|
||||||
PATH_TEXT_TO_IMAGE = "/proxy/runway/text_to_image"
|
PATH_TEXT_TO_IMAGE = "/proxy/runway/text_to_image"
|
||||||
@ -80,7 +77,7 @@ class RunwayGen3aAspectRatio(str, Enum):
|
|||||||
field_1280_768 = "1280:768"
|
field_1280_768 = "1280:768"
|
||||||
|
|
||||||
|
|
||||||
def get_video_url_from_task_status(response: TaskStatusResponse) -> Union[str, None]:
|
def get_video_url_from_task_status(response: TaskStatusResponse) -> str | None:
|
||||||
"""Returns the video URL from the task status response if it exists."""
|
"""Returns the video URL from the task status response if it exists."""
|
||||||
if hasattr(response, "output") and len(response.output) > 0:
|
if hasattr(response, "output") and len(response.output) > 0:
|
||||||
return response.output[0]
|
return response.output[0]
|
||||||
@ -89,13 +86,13 @@ def get_video_url_from_task_status(response: TaskStatusResponse) -> Union[str, N
|
|||||||
|
|
||||||
def extract_progress_from_task_status(
|
def extract_progress_from_task_status(
|
||||||
response: TaskStatusResponse,
|
response: TaskStatusResponse,
|
||||||
) -> Union[float, None]:
|
) -> float | None:
|
||||||
if hasattr(response, "progress") and response.progress is not None:
|
if hasattr(response, "progress") and response.progress is not None:
|
||||||
return response.progress * 100
|
return response.progress * 100
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def get_image_url_from_task_status(response: TaskStatusResponse) -> Union[str, None]:
|
def get_image_url_from_task_status(response: TaskStatusResponse) -> str | None:
|
||||||
"""Returns the image URL from the task status response if it exists."""
|
"""Returns the image URL from the task status response if it exists."""
|
||||||
if hasattr(response, "output") and len(response.output) > 0:
|
if hasattr(response, "output") and len(response.output) > 0:
|
||||||
return response.output[0]
|
return response.output[0]
|
||||||
@ -103,7 +100,7 @@ def get_image_url_from_task_status(response: TaskStatusResponse) -> Union[str, N
|
|||||||
|
|
||||||
|
|
||||||
async def get_response(
|
async def get_response(
|
||||||
cls: type[IO.ComfyNode], task_id: str, estimated_duration: Optional[int] = None
|
cls: type[IO.ComfyNode], task_id: str, estimated_duration: int | None = None
|
||||||
) -> TaskStatusResponse:
|
) -> TaskStatusResponse:
|
||||||
"""Poll the task status until it is finished then get the response."""
|
"""Poll the task status until it is finished then get the response."""
|
||||||
return await poll_op(
|
return await poll_op(
|
||||||
@ -119,8 +116,8 @@ async def get_response(
|
|||||||
async def generate_video(
|
async def generate_video(
|
||||||
cls: type[IO.ComfyNode],
|
cls: type[IO.ComfyNode],
|
||||||
request: RunwayImageToVideoRequest,
|
request: RunwayImageToVideoRequest,
|
||||||
estimated_duration: Optional[int] = None,
|
estimated_duration: int | None = None,
|
||||||
) -> VideoFromFile:
|
) -> InputImpl.VideoFromFile:
|
||||||
initial_response = await sync_op(
|
initial_response = await sync_op(
|
||||||
cls,
|
cls,
|
||||||
endpoint=ApiEndpoint(path=PATH_IMAGE_TO_VIDEO, method="POST"),
|
endpoint=ApiEndpoint(path=PATH_IMAGE_TO_VIDEO, method="POST"),
|
||||||
@ -193,7 +190,7 @@ class RunwayImageToVideoNodeGen3a(IO.ComfyNode):
|
|||||||
async def execute(
|
async def execute(
|
||||||
cls,
|
cls,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
start_frame: torch.Tensor,
|
start_frame: Input.Image,
|
||||||
duration: str,
|
duration: str,
|
||||||
ratio: str,
|
ratio: str,
|
||||||
seed: int,
|
seed: int,
|
||||||
@ -283,7 +280,7 @@ class RunwayImageToVideoNodeGen4(IO.ComfyNode):
|
|||||||
async def execute(
|
async def execute(
|
||||||
cls,
|
cls,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
start_frame: torch.Tensor,
|
start_frame: Input.Image,
|
||||||
duration: str,
|
duration: str,
|
||||||
ratio: str,
|
ratio: str,
|
||||||
seed: int,
|
seed: int,
|
||||||
@ -381,8 +378,8 @@ class RunwayFirstLastFrameNode(IO.ComfyNode):
|
|||||||
async def execute(
|
async def execute(
|
||||||
cls,
|
cls,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
start_frame: torch.Tensor,
|
start_frame: Input.Image,
|
||||||
end_frame: torch.Tensor,
|
end_frame: Input.Image,
|
||||||
duration: str,
|
duration: str,
|
||||||
ratio: str,
|
ratio: str,
|
||||||
seed: int,
|
seed: int,
|
||||||
@ -467,7 +464,7 @@ class RunwayTextToImageNode(IO.ComfyNode):
|
|||||||
cls,
|
cls,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
ratio: str,
|
ratio: str,
|
||||||
reference_image: Optional[torch.Tensor] = None,
|
reference_image: Input.Image | None = None,
|
||||||
) -> IO.NodeOutput:
|
) -> IO.NodeOutput:
|
||||||
validate_string(prompt, min_length=1)
|
validate_string(prompt, min_length=1)
|
||||||
|
|
||||||
|
|||||||
@ -23,10 +23,6 @@ UPSCALER_MODELS_MAP = {
|
|||||||
"Starlight (Astra) Fast": "slf-1",
|
"Starlight (Astra) Fast": "slf-1",
|
||||||
"Starlight (Astra) Creative": "slc-1",
|
"Starlight (Astra) Creative": "slc-1",
|
||||||
}
|
}
|
||||||
UPSCALER_VALUES_MAP = {
|
|
||||||
"FullHD (1080p)": 1920,
|
|
||||||
"4K (2160p)": 3840,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class TopazImageEnhance(IO.ComfyNode):
|
class TopazImageEnhance(IO.ComfyNode):
|
||||||
@ -214,7 +210,7 @@ class TopazVideoEnhance(IO.ComfyNode):
|
|||||||
IO.Video.Input("video"),
|
IO.Video.Input("video"),
|
||||||
IO.Boolean.Input("upscaler_enabled", default=True),
|
IO.Boolean.Input("upscaler_enabled", default=True),
|
||||||
IO.Combo.Input("upscaler_model", options=list(UPSCALER_MODELS_MAP.keys())),
|
IO.Combo.Input("upscaler_model", options=list(UPSCALER_MODELS_MAP.keys())),
|
||||||
IO.Combo.Input("upscaler_resolution", options=list(UPSCALER_VALUES_MAP.keys())),
|
IO.Combo.Input("upscaler_resolution", options=["FullHD (1080p)", "4K (2160p)"]),
|
||||||
IO.Combo.Input(
|
IO.Combo.Input(
|
||||||
"upscaler_creativity",
|
"upscaler_creativity",
|
||||||
options=["low", "middle", "high"],
|
options=["low", "middle", "high"],
|
||||||
@ -306,8 +302,33 @@ class TopazVideoEnhance(IO.ComfyNode):
|
|||||||
target_frame_rate = src_frame_rate
|
target_frame_rate = src_frame_rate
|
||||||
filters = []
|
filters = []
|
||||||
if upscaler_enabled:
|
if upscaler_enabled:
|
||||||
target_width = UPSCALER_VALUES_MAP[upscaler_resolution]
|
if "1080p" in upscaler_resolution:
|
||||||
target_height = UPSCALER_VALUES_MAP[upscaler_resolution]
|
target_pixel_p = 1080
|
||||||
|
max_long_side = 1920
|
||||||
|
else:
|
||||||
|
target_pixel_p = 2160
|
||||||
|
max_long_side = 3840
|
||||||
|
ar = src_width / src_height
|
||||||
|
if src_width >= src_height:
|
||||||
|
# Landscape or square: set the height to the target (e.g. 2160) and derive the width from the aspect ratio
|
||||||
|
target_height = target_pixel_p
|
||||||
|
target_width = int(target_height * ar)
|
||||||
|
# Clamp the width if it exceeds the standard bound (ultra-wide sources, e.g. 21:9)
|
||||||
|
if target_width > max_long_side:
|
||||||
|
target_width = max_long_side
|
||||||
|
target_height = int(target_width / ar)
|
||||||
|
else:
|
||||||
|
# Portrait: set the width to the target (e.g. 2160) and derive the height from the aspect ratio
|
||||||
|
target_width = target_pixel_p
|
||||||
|
target_height = int(target_width / ar)
|
||||||
|
# Clamp the height if it exceeds the standard bound
|
||||||
|
if target_height > max_long_side:
|
||||||
|
target_height = max_long_side
|
||||||
|
target_width = int(target_height * ar)
|
||||||
|
if target_width % 2 != 0:
|
||||||
|
target_width += 1
|
||||||
|
if target_height % 2 != 0:
|
||||||
|
target_height += 1
|
||||||
filters.append(
|
filters.append(
|
||||||
topaz_api.VideoEnhancementFilter(
|
topaz_api.VideoEnhancementFilter(
|
||||||
model=UPSCALER_MODELS_MAP[upscaler_model],
|
model=UPSCALER_MODELS_MAP[upscaler_model],
|
||||||
|
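The resolution logic added above replaces the old fixed UPSCALER_VALUES_MAP lookup: it pins the short side to 1080 or 2160, caps the long side at 1920 or 3840 for unusually wide or tall sources, and rounds odd dimensions up to even values, which most video encoders require. The same rule as a standalone helper (a sketch of the logic in this diff, not the node code itself):

def topaz_target_size(src_w: int, src_h: int, resolution: str) -> tuple[int, int]:
    target_p, max_long = (1080, 1920) if "1080p" in resolution else (2160, 3840)
    ar = src_w / src_h
    if src_w >= src_h:  # landscape or square: pin the height, derive the width
        h, w = target_p, int(target_p * ar)
        if w > max_long:
            w, h = max_long, int(max_long / ar)
    else:  # portrait: pin the width, derive the height
        w, h = target_p, int(target_p / ar)
        if h > max_long:
            h, w = max_long, int(max_long * ar)
    return w + (w % 2), h + (h % 2)  # bump odd values up to even

print(topaz_target_size(720, 1280, "4K (2160p)"))  # (2160, 3840)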
|||||||
@ -102,8 +102,9 @@ class TripoTextToModelNode(IO.ComfyNode):
|
|||||||
IO.Int.Input("model_seed", default=42, optional=True),
|
IO.Int.Input("model_seed", default=42, optional=True),
|
||||||
IO.Int.Input("texture_seed", default=42, optional=True),
|
IO.Int.Input("texture_seed", default=42, optional=True),
|
||||||
IO.Combo.Input("texture_quality", default="standard", options=["standard", "detailed"], optional=True),
|
IO.Combo.Input("texture_quality", default="standard", options=["standard", "detailed"], optional=True),
|
||||||
IO.Int.Input("face_limit", default=-1, min=-1, max=500000, optional=True),
|
IO.Int.Input("face_limit", default=-1, min=-1, max=2000000, optional=True),
|
||||||
IO.Boolean.Input("quad", default=False, optional=True),
|
IO.Boolean.Input("quad", default=False, optional=True),
|
||||||
|
IO.Combo.Input("geometry_quality", default="standard", options=["standard", "detailed"], optional=True),
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
IO.String.Output(display_name="model_file"),
|
IO.String.Output(display_name="model_file"),
|
||||||
@ -131,6 +132,7 @@ class TripoTextToModelNode(IO.ComfyNode):
|
|||||||
model_seed: Optional[int] = None,
|
model_seed: Optional[int] = None,
|
||||||
texture_seed: Optional[int] = None,
|
texture_seed: Optional[int] = None,
|
||||||
texture_quality: Optional[str] = None,
|
texture_quality: Optional[str] = None,
|
||||||
|
geometry_quality: Optional[str] = None,
|
||||||
face_limit: Optional[int] = None,
|
face_limit: Optional[int] = None,
|
||||||
quad: Optional[bool] = None,
|
quad: Optional[bool] = None,
|
||||||
) -> IO.NodeOutput:
|
) -> IO.NodeOutput:
|
||||||
@ -153,7 +155,8 @@ class TripoTextToModelNode(IO.ComfyNode):
|
|||||||
model_seed=model_seed,
|
model_seed=model_seed,
|
||||||
texture_seed=texture_seed,
|
texture_seed=texture_seed,
|
||||||
texture_quality=texture_quality,
|
texture_quality=texture_quality,
|
||||||
face_limit=face_limit,
|
face_limit=face_limit if face_limit != -1 else None,
|
||||||
|
geometry_quality=geometry_quality,
|
||||||
auto_size=True,
|
auto_size=True,
|
||||||
quad=quad,
|
quad=quad,
|
||||||
),
|
),
|
||||||
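The face_limit handling above (and in the other Tripo nodes below) treats -1 as "no explicit limit": the conditional maps it to None so the field can be dropped from the request instead of sending -1 to the API (assuming the serializer skips None values). In isolation:

def optional_face_limit(face_limit: int) -> int | None:
    # -1 is the UI sentinel for "let the service decide"; map it to None so the field is omitted.
    return face_limit if face_limit != -1 else None

print(optional_face_limit(-1))      # None
print(optional_face_limit(200000))  # 200000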
@ -194,6 +197,7 @@ class TripoImageToModelNode(IO.ComfyNode):
|
|||||||
),
|
),
|
||||||
IO.Int.Input("face_limit", default=-1, min=-1, max=500000, optional=True),
|
IO.Int.Input("face_limit", default=-1, min=-1, max=500000, optional=True),
|
||||||
IO.Boolean.Input("quad", default=False, optional=True),
|
IO.Boolean.Input("quad", default=False, optional=True),
|
||||||
|
IO.Combo.Input("geometry_quality", default="standard", options=["standard", "detailed"], optional=True),
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
IO.String.Output(display_name="model_file"),
|
IO.String.Output(display_name="model_file"),
|
||||||
@ -220,6 +224,7 @@ class TripoImageToModelNode(IO.ComfyNode):
|
|||||||
orientation=None,
|
orientation=None,
|
||||||
texture_seed: Optional[int] = None,
|
texture_seed: Optional[int] = None,
|
||||||
texture_quality: Optional[str] = None,
|
texture_quality: Optional[str] = None,
|
||||||
|
geometry_quality: Optional[str] = None,
|
||||||
texture_alignment: Optional[str] = None,
|
texture_alignment: Optional[str] = None,
|
||||||
face_limit: Optional[int] = None,
|
face_limit: Optional[int] = None,
|
||||||
quad: Optional[bool] = None,
|
quad: Optional[bool] = None,
|
||||||
@ -246,10 +251,11 @@ class TripoImageToModelNode(IO.ComfyNode):
|
|||||||
pbr=pbr,
|
pbr=pbr,
|
||||||
model_seed=model_seed,
|
model_seed=model_seed,
|
||||||
orientation=orientation,
|
orientation=orientation,
|
||||||
|
geometry_quality=geometry_quality,
|
||||||
texture_alignment=texture_alignment,
|
texture_alignment=texture_alignment,
|
||||||
texture_seed=texture_seed,
|
texture_seed=texture_seed,
|
||||||
texture_quality=texture_quality,
|
texture_quality=texture_quality,
|
||||||
face_limit=face_limit,
|
face_limit=face_limit if face_limit != -1 else None,
|
||||||
auto_size=True,
|
auto_size=True,
|
||||||
quad=quad,
|
quad=quad,
|
||||||
),
|
),
|
||||||
@ -295,6 +301,7 @@ class TripoMultiviewToModelNode(IO.ComfyNode):
|
|||||||
),
|
),
|
||||||
IO.Int.Input("face_limit", default=-1, min=-1, max=500000, optional=True),
|
IO.Int.Input("face_limit", default=-1, min=-1, max=500000, optional=True),
|
||||||
IO.Boolean.Input("quad", default=False, optional=True),
|
IO.Boolean.Input("quad", default=False, optional=True),
|
||||||
|
IO.Combo.Input("geometry_quality", default="standard", options=["standard", "detailed"], optional=True),
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
IO.String.Output(display_name="model_file"),
|
IO.String.Output(display_name="model_file"),
|
||||||
@ -323,6 +330,7 @@ class TripoMultiviewToModelNode(IO.ComfyNode):
|
|||||||
model_seed: Optional[int] = None,
|
model_seed: Optional[int] = None,
|
||||||
texture_seed: Optional[int] = None,
|
texture_seed: Optional[int] = None,
|
||||||
texture_quality: Optional[str] = None,
|
texture_quality: Optional[str] = None,
|
||||||
|
geometry_quality: Optional[str] = None,
|
||||||
texture_alignment: Optional[str] = None,
|
texture_alignment: Optional[str] = None,
|
||||||
face_limit: Optional[int] = None,
|
face_limit: Optional[int] = None,
|
||||||
quad: Optional[bool] = None,
|
quad: Optional[bool] = None,
|
||||||
@ -359,8 +367,9 @@ class TripoMultiviewToModelNode(IO.ComfyNode):
|
|||||||
model_seed=model_seed,
|
model_seed=model_seed,
|
||||||
texture_seed=texture_seed,
|
texture_seed=texture_seed,
|
||||||
texture_quality=texture_quality,
|
texture_quality=texture_quality,
|
||||||
|
geometry_quality=geometry_quality,
|
||||||
texture_alignment=texture_alignment,
|
texture_alignment=texture_alignment,
|
||||||
face_limit=face_limit,
|
face_limit=face_limit if face_limit != -1 else None,
|
||||||
quad=quad,
|
quad=quad,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
@ -508,6 +517,8 @@ class TripoRetargetNode(IO.ComfyNode):
|
|||||||
options=[
|
options=[
|
||||||
"preset:idle",
|
"preset:idle",
|
||||||
"preset:walk",
|
"preset:walk",
|
||||||
|
"preset:run",
|
||||||
|
"preset:dive",
|
||||||
"preset:climb",
|
"preset:climb",
|
||||||
"preset:jump",
|
"preset:jump",
|
||||||
"preset:slash",
|
"preset:slash",
|
||||||
@ -515,6 +526,11 @@ class TripoRetargetNode(IO.ComfyNode):
|
|||||||
"preset:hurt",
|
"preset:hurt",
|
||||||
"preset:fall",
|
"preset:fall",
|
||||||
"preset:turn",
|
"preset:turn",
|
||||||
|
"preset:quadruped:walk",
|
||||||
|
"preset:hexapod:walk",
|
||||||
|
"preset:octopod:walk",
|
||||||
|
"preset:serpentine:march",
|
||||||
|
"preset:aquatic:march"
|
||||||
],
|
],
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
@ -563,7 +579,7 @@ class TripoConversionNode(IO.ComfyNode):
|
|||||||
"face_limit",
|
"face_limit",
|
||||||
default=-1,
|
default=-1,
|
||||||
min=-1,
|
min=-1,
|
||||||
max=500000,
|
max=2000000,
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
IO.Int.Input(
|
IO.Int.Input(
|
||||||
@ -579,6 +595,40 @@ class TripoConversionNode(IO.ComfyNode):
|
|||||||
default="JPEG",
|
default="JPEG",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
|
IO.Boolean.Input("force_symmetry", default=False, optional=True),
|
||||||
|
IO.Boolean.Input("flatten_bottom", default=False, optional=True),
|
||||||
|
IO.Float.Input(
|
||||||
|
"flatten_bottom_threshold",
|
||||||
|
default=0.0,
|
||||||
|
min=0.0,
|
||||||
|
max=1.0,
|
||||||
|
optional=True,
|
||||||
|
),
|
||||||
|
IO.Boolean.Input("pivot_to_center_bottom", default=False, optional=True),
|
||||||
|
IO.Float.Input(
|
||||||
|
"scale_factor",
|
||||||
|
default=1.0,
|
||||||
|
min=0.0,
|
||||||
|
optional=True,
|
||||||
|
),
|
||||||
|
IO.Boolean.Input("with_animation", default=False, optional=True),
|
||||||
|
IO.Boolean.Input("pack_uv", default=False, optional=True),
|
||||||
|
IO.Boolean.Input("bake", default=False, optional=True),
|
||||||
|
IO.String.Input("part_names", default="", optional=True), # comma-separated list
|
||||||
|
IO.Combo.Input(
|
||||||
|
"fbx_preset",
|
||||||
|
options=["blender", "mixamo", "3dsmax"],
|
||||||
|
default="blender",
|
||||||
|
optional=True,
|
||||||
|
),
|
||||||
|
IO.Boolean.Input("export_vertex_colors", default=False, optional=True),
|
||||||
|
IO.Combo.Input(
|
||||||
|
"export_orientation",
|
||||||
|
options=["align_image", "default"],
|
||||||
|
default="default",
|
||||||
|
optional=True,
|
||||||
|
),
|
||||||
|
IO.Boolean.Input("animate_in_place", default=False, optional=True),
|
||||||
],
|
],
|
||||||
outputs=[],
|
outputs=[],
|
||||||
hidden=[
|
hidden=[
|
||||||
@ -604,12 +654,31 @@ class TripoConversionNode(IO.ComfyNode):
|
|||||||
original_model_task_id,
|
original_model_task_id,
|
||||||
format: str,
|
format: str,
|
||||||
quad: bool,
|
quad: bool,
|
||||||
|
force_symmetry: bool,
|
||||||
face_limit: int,
|
face_limit: int,
|
||||||
|
flatten_bottom: bool,
|
||||||
|
flatten_bottom_threshold: float,
|
||||||
texture_size: int,
|
texture_size: int,
|
||||||
texture_format: str,
|
texture_format: str,
|
||||||
|
pivot_to_center_bottom: bool,
|
||||||
|
scale_factor: float,
|
||||||
|
with_animation: bool,
|
||||||
|
pack_uv: bool,
|
||||||
|
bake: bool,
|
||||||
|
part_names: str,
|
||||||
|
fbx_preset: str,
|
||||||
|
export_vertex_colors: bool,
|
||||||
|
export_orientation: str,
|
||||||
|
animate_in_place: bool,
|
||||||
) -> IO.NodeOutput:
|
) -> IO.NodeOutput:
|
||||||
if not original_model_task_id:
|
if not original_model_task_id:
|
||||||
raise RuntimeError("original_model_task_id is required")
|
raise RuntimeError("original_model_task_id is required")
|
||||||
|
|
||||||
|
# Parse part_names from comma-separated string to list
|
||||||
|
part_names_list = None
|
||||||
|
if part_names and part_names.strip():
|
||||||
|
part_names_list = [name.strip() for name in part_names.split(',') if name.strip()]
|
||||||
|
|
||||||
response = await sync_op(
|
response = await sync_op(
|
||||||
cls,
|
cls,
|
||||||
endpoint=ApiEndpoint(path="/proxy/tripo/v2/openapi/task", method="POST"),
|
endpoint=ApiEndpoint(path="/proxy/tripo/v2/openapi/task", method="POST"),
|
||||||
@ -618,9 +687,22 @@ class TripoConversionNode(IO.ComfyNode):
|
|||||||
original_model_task_id=original_model_task_id,
|
original_model_task_id=original_model_task_id,
|
||||||
format=format,
|
format=format,
|
||||||
quad=quad if quad else None,
|
quad=quad if quad else None,
|
||||||
|
force_symmetry=force_symmetry if force_symmetry else None,
|
||||||
face_limit=face_limit if face_limit != -1 else None,
|
face_limit=face_limit if face_limit != -1 else None,
|
||||||
|
flatten_bottom=flatten_bottom if flatten_bottom else None,
|
||||||
|
flatten_bottom_threshold=flatten_bottom_threshold if flatten_bottom_threshold != 0.0 else None,
|
||||||
texture_size=texture_size if texture_size != 4096 else None,
|
texture_size=texture_size if texture_size != 4096 else None,
|
||||||
texture_format=texture_format if texture_format != "JPEG" else None,
|
texture_format=texture_format if texture_format != "JPEG" else None,
|
||||||
|
pivot_to_center_bottom=pivot_to_center_bottom if pivot_to_center_bottom else None,
|
||||||
|
scale_factor=scale_factor if scale_factor != 1.0 else None,
|
||||||
|
with_animation=with_animation if with_animation else None,
|
||||||
|
pack_uv=pack_uv if pack_uv else None,
|
||||||
|
bake=bake if bake else None,
|
||||||
|
part_names=part_names_list,
|
||||||
|
fbx_preset=fbx_preset if fbx_preset != "blender" else None,
|
||||||
|
export_vertex_colors=export_vertex_colors if export_vertex_colors else None,
|
||||||
|
export_orientation=export_orientation if export_orientation != "default" else None,
|
||||||
|
animate_in_place=animate_in_place if animate_in_place else None,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
return await poll_until_finished(cls, response, average_duration=30)
|
return await poll_until_finished(cls, response, average_duration=30)
|
||||||
|
|||||||
@ -1,11 +1,9 @@
|
|||||||
import base64
|
import base64
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
|
|
||||||
import torch
|
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
|
|
||||||
from comfy_api.input_impl.video_types import VideoFromFile
|
from comfy_api.latest import IO, ComfyExtension, Input, InputImpl
|
||||||
from comfy_api.latest import IO, ComfyExtension
|
|
||||||
from comfy_api_nodes.apis.veo_api import (
|
from comfy_api_nodes.apis.veo_api import (
|
||||||
VeoGenVidPollRequest,
|
VeoGenVidPollRequest,
|
||||||
VeoGenVidPollResponse,
|
VeoGenVidPollResponse,
|
||||||
@ -170,6 +168,8 @@ class VeoVideoGenerationNode(IO.ComfyNode):
|
|||||||
# Only add generateAudio for Veo 3 models
|
# Only add generateAudio for Veo 3 models
|
||||||
if model.find("veo-2.0") == -1:
|
if model.find("veo-2.0") == -1:
|
||||||
parameters["generateAudio"] = generate_audio
|
parameters["generateAudio"] = generate_audio
|
||||||
|
# force "enhance_prompt" to True for Veo3 models
|
||||||
|
parameters["enhancePrompt"] = True
|
||||||
|
|
||||||
initial_response = await sync_op(
|
initial_response = await sync_op(
|
||||||
cls,
|
cls,
|
||||||
@ -232,7 +232,7 @@ class VeoVideoGenerationNode(IO.ComfyNode):
|
|||||||
|
|
||||||
# Check if video is provided as base64 or URL
|
# Check if video is provided as base64 or URL
|
||||||
if hasattr(video, "bytesBase64Encoded") and video.bytesBase64Encoded:
|
if hasattr(video, "bytesBase64Encoded") and video.bytesBase64Encoded:
|
||||||
return IO.NodeOutput(VideoFromFile(BytesIO(base64.b64decode(video.bytesBase64Encoded))))
|
return IO.NodeOutput(InputImpl.VideoFromFile(BytesIO(base64.b64decode(video.bytesBase64Encoded))))
|
||||||
|
|
||||||
if hasattr(video, "gcsUri") and video.gcsUri:
|
if hasattr(video, "gcsUri") and video.gcsUri:
|
||||||
return IO.NodeOutput(await download_url_to_video_output(video.gcsUri))
|
return IO.NodeOutput(await download_url_to_video_output(video.gcsUri))
|
||||||
@ -293,7 +293,7 @@ class Veo3VideoGenerationNode(VeoVideoGenerationNode):
|
|||||||
IO.Boolean.Input(
|
IO.Boolean.Input(
|
||||||
"enhance_prompt",
|
"enhance_prompt",
|
||||||
default=True,
|
default=True,
|
||||||
tooltip="Whether to enhance the prompt with AI assistance",
|
tooltip="This parameter is deprecated and ignored.",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
IO.Combo.Input(
|
IO.Combo.Input(
|
||||||
@ -431,8 +431,8 @@ class Veo3FirstLastFrameNode(IO.ComfyNode):
|
|||||||
aspect_ratio: str,
|
aspect_ratio: str,
|
||||||
duration: int,
|
duration: int,
|
||||||
seed: int,
|
seed: int,
|
||||||
first_frame: torch.Tensor,
|
first_frame: Input.Image,
|
||||||
last_frame: torch.Tensor,
|
last_frame: Input.Image,
|
||||||
model: str,
|
model: str,
|
||||||
generate_audio: bool,
|
generate_audio: bool,
|
||||||
):
|
):
|
||||||
@ -493,7 +493,7 @@ class Veo3FirstLastFrameNode(IO.ComfyNode):
|
|||||||
if response.videos:
|
if response.videos:
|
||||||
video = response.videos[0]
|
video = response.videos[0]
|
||||||
if video.bytesBase64Encoded:
|
if video.bytesBase64Encoded:
|
||||||
return IO.NodeOutput(VideoFromFile(BytesIO(base64.b64decode(video.bytesBase64Encoded))))
|
return IO.NodeOutput(InputImpl.VideoFromFile(BytesIO(base64.b64decode(video.bytesBase64Encoded))))
|
||||||
if video.gcsUri:
|
if video.gcsUri:
|
||||||
return IO.NodeOutput(await download_url_to_video_output(video.gcsUri))
|
return IO.NodeOutput(await download_url_to_video_output(video.gcsUri))
|
||||||
raise Exception("Video returned but no data or URL was provided")
|
raise Exception("Video returned but no data or URL was provided")
|
||||||
|
|||||||
@ -1,7 +1,5 @@
|
|||||||
import re
|
import re
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
|
|
||||||
@ -15,32 +13,40 @@ from comfy_api_nodes.util import (
|
|||||||
poll_op,
|
poll_op,
|
||||||
sync_op,
|
sync_op,
|
||||||
tensor_to_base64_string,
|
tensor_to_base64_string,
|
||||||
|
upload_video_to_comfyapi,
|
||||||
validate_audio_duration,
|
validate_audio_duration,
|
||||||
|
validate_video_duration,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class Text2ImageInputField(BaseModel):
|
class Text2ImageInputField(BaseModel):
|
||||||
prompt: str = Field(...)
|
prompt: str = Field(...)
|
||||||
negative_prompt: Optional[str] = Field(None)
|
negative_prompt: str | None = Field(None)
|
||||||
|
|
||||||
|
|
||||||
class Image2ImageInputField(BaseModel):
|
class Image2ImageInputField(BaseModel):
|
||||||
prompt: str = Field(...)
|
prompt: str = Field(...)
|
||||||
negative_prompt: Optional[str] = Field(None)
|
negative_prompt: str | None = Field(None)
|
||||||
images: list[str] = Field(..., min_length=1, max_length=2)
|
images: list[str] = Field(..., min_length=1, max_length=2)
|
||||||
|
|
||||||
|
|
||||||
class Text2VideoInputField(BaseModel):
|
class Text2VideoInputField(BaseModel):
|
||||||
prompt: str = Field(...)
|
prompt: str = Field(...)
|
||||||
negative_prompt: Optional[str] = Field(None)
|
negative_prompt: str | None = Field(None)
|
||||||
audio_url: Optional[str] = Field(None)
|
audio_url: str | None = Field(None)
|
||||||
|
|
||||||
|
|
||||||
class Image2VideoInputField(BaseModel):
|
class Image2VideoInputField(BaseModel):
|
||||||
prompt: str = Field(...)
|
prompt: str = Field(...)
|
||||||
negative_prompt: Optional[str] = Field(None)
|
negative_prompt: str | None = Field(None)
|
||||||
img_url: str = Field(...)
|
img_url: str = Field(...)
|
||||||
audio_url: Optional[str] = Field(None)
|
audio_url: str | None = Field(None)
|
||||||
|
|
||||||
|
|
||||||
|
class Reference2VideoInputField(BaseModel):
|
||||||
|
prompt: str = Field(...)
|
||||||
|
negative_prompt: str | None = Field(None)
|
||||||
|
reference_video_urls: list[str] = Field(...)
|
||||||
|
|
||||||
|
|
||||||
class Txt2ImageParametersField(BaseModel):
|
class Txt2ImageParametersField(BaseModel):
|
||||||
@ -48,32 +54,42 @@ class Txt2ImageParametersField(BaseModel):
|
|||||||
n: int = Field(1, description="Number of images to generate.") # we support only value=1
|
n: int = Field(1, description="Number of images to generate.") # we support only value=1
|
||||||
seed: int = Field(..., ge=0, le=2147483647)
|
seed: int = Field(..., ge=0, le=2147483647)
|
||||||
prompt_extend: bool = Field(True)
|
prompt_extend: bool = Field(True)
|
||||||
watermark: bool = Field(True)
|
watermark: bool = Field(False)
|
||||||
|
|
||||||
|
|
||||||
class Image2ImageParametersField(BaseModel):
|
class Image2ImageParametersField(BaseModel):
|
||||||
size: Optional[str] = Field(None)
|
size: str | None = Field(None)
|
||||||
n: int = Field(1, description="Number of images to generate.") # we support only value=1
|
n: int = Field(1, description="Number of images to generate.") # we support only value=1
|
||||||
seed: int = Field(..., ge=0, le=2147483647)
|
seed: int = Field(..., ge=0, le=2147483647)
|
||||||
watermark: bool = Field(True)
|
watermark: bool = Field(False)
|
||||||
|
|
||||||
|
|
||||||
class Text2VideoParametersField(BaseModel):
|
class Text2VideoParametersField(BaseModel):
|
||||||
size: str = Field(...)
|
size: str = Field(...)
|
||||||
seed: int = Field(..., ge=0, le=2147483647)
|
seed: int = Field(..., ge=0, le=2147483647)
|
||||||
duration: int = Field(5, ge=5, le=10)
|
duration: int = Field(5, ge=5, le=15)
|
||||||
prompt_extend: bool = Field(True)
|
prompt_extend: bool = Field(True)
|
||||||
watermark: bool = Field(True)
|
watermark: bool = Field(False)
|
||||||
audio: bool = Field(False, description="Should be audio generated automatically")
|
audio: bool = Field(False, description="Whether to generate audio automatically.")
|
||||||
|
shot_type: str = Field("single")
|
||||||
|
|
||||||
|
|
||||||
class Image2VideoParametersField(BaseModel):
|
class Image2VideoParametersField(BaseModel):
|
||||||
resolution: str = Field(...)
|
resolution: str = Field(...)
|
||||||
seed: int = Field(..., ge=0, le=2147483647)
|
seed: int = Field(..., ge=0, le=2147483647)
|
||||||
duration: int = Field(5, ge=5, le=10)
|
duration: int = Field(5, ge=5, le=15)
|
||||||
prompt_extend: bool = Field(True)
|
prompt_extend: bool = Field(True)
|
||||||
watermark: bool = Field(True)
|
watermark: bool = Field(False)
|
||||||
audio: bool = Field(False, description="Should be audio generated automatically")
|
audio: bool = Field(False, description="Whether to generate audio automatically.")
|
||||||
|
shot_type: str = Field("single")
|
||||||
|
|
||||||
|
|
||||||
|
class Reference2VideoParametersField(BaseModel):
|
||||||
|
size: str = Field(...)
|
||||||
|
duration: int = Field(5, ge=5, le=15)
|
||||||
|
shot_type: str = Field("single")
|
||||||
|
seed: int = Field(..., ge=0, le=2147483647)
|
||||||
|
watermark: bool = Field(False)
|
||||||
|
|
||||||
|
|
||||||
class Text2ImageTaskCreationRequest(BaseModel):
|
class Text2ImageTaskCreationRequest(BaseModel):
|
||||||
@ -100,45 +116,51 @@ class Image2VideoTaskCreationRequest(BaseModel):
|
|||||||
parameters: Image2VideoParametersField = Field(...)
|
parameters: Image2VideoParametersField = Field(...)
|
||||||
|
|
||||||
|
|
||||||
|
class Reference2VideoTaskCreationRequest(BaseModel):
|
||||||
|
model: str = Field(...)
|
||||||
|
input: Reference2VideoInputField = Field(...)
|
||||||
|
parameters: Reference2VideoParametersField = Field(...)
|
||||||
|
|
||||||
|
|
||||||
class TaskCreationOutputField(BaseModel):
|
class TaskCreationOutputField(BaseModel):
|
||||||
task_id: str = Field(...)
|
task_id: str = Field(...)
|
||||||
task_status: str = Field(...)
|
task_status: str = Field(...)
|
||||||
|
|
||||||
|
|
||||||
class TaskCreationResponse(BaseModel):
|
class TaskCreationResponse(BaseModel):
|
||||||
output: Optional[TaskCreationOutputField] = Field(None)
|
output: TaskCreationOutputField | None = Field(None)
|
||||||
request_id: str = Field(...)
|
request_id: str = Field(...)
|
||||||
code: Optional[str] = Field(None, description="The error code of the failed request.")
|
code: str | None = Field(None, description="Error code for the failed request.")
|
||||||
message: Optional[str] = Field(None, description="Details of the failed request.")
|
message: str | None = Field(None, description="Details about the failed request.")
|
||||||
|
|
||||||
|
|
||||||
class TaskResult(BaseModel):
|
class TaskResult(BaseModel):
|
||||||
url: Optional[str] = Field(None)
|
url: str | None = Field(None)
|
||||||
code: Optional[str] = Field(None)
|
code: str | None = Field(None)
|
||||||
message: Optional[str] = Field(None)
|
message: str | None = Field(None)
|
||||||
|
|
||||||
|
|
||||||
class ImageTaskStatusOutputField(TaskCreationOutputField):
|
class ImageTaskStatusOutputField(TaskCreationOutputField):
|
||||||
task_id: str = Field(...)
|
task_id: str = Field(...)
|
||||||
task_status: str = Field(...)
|
task_status: str = Field(...)
|
||||||
results: Optional[list[TaskResult]] = Field(None)
|
results: list[TaskResult] | None = Field(None)
|
||||||
|
|
||||||
|
|
||||||
class VideoTaskStatusOutputField(TaskCreationOutputField):
|
class VideoTaskStatusOutputField(TaskCreationOutputField):
|
||||||
task_id: str = Field(...)
|
task_id: str = Field(...)
|
||||||
task_status: str = Field(...)
|
task_status: str = Field(...)
|
||||||
video_url: Optional[str] = Field(None)
|
video_url: str | None = Field(None)
|
||||||
code: Optional[str] = Field(None)
|
code: str | None = Field(None)
|
||||||
message: Optional[str] = Field(None)
|
message: str | None = Field(None)
|
||||||
|
|
||||||
|
|
||||||
class ImageTaskStatusResponse(BaseModel):
|
class ImageTaskStatusResponse(BaseModel):
|
||||||
output: Optional[ImageTaskStatusOutputField] = Field(None)
|
output: ImageTaskStatusOutputField | None = Field(None)
|
||||||
request_id: str = Field(...)
|
request_id: str = Field(...)
|
||||||
|
|
||||||
|
|
||||||
class VideoTaskStatusResponse(BaseModel):
|
class VideoTaskStatusResponse(BaseModel):
|
||||||
output: Optional[VideoTaskStatusOutputField] = Field(None)
|
output: VideoTaskStatusOutputField | None = Field(None)
|
||||||
request_id: str = Field(...)
|
request_id: str = Field(...)
|
||||||
|
|
||||||
|
|
||||||
@ -152,7 +174,7 @@ class WanTextToImageApi(IO.ComfyNode):
|
|||||||
node_id="WanTextToImageApi",
|
node_id="WanTextToImageApi",
|
||||||
display_name="Wan Text to Image",
|
display_name="Wan Text to Image",
|
||||||
category="api node/image/Wan",
|
category="api node/image/Wan",
|
||||||
description="Generates image based on text prompt.",
|
description="Generates an image based on a text prompt.",
|
||||||
inputs=[
|
inputs=[
|
||||||
IO.Combo.Input(
|
IO.Combo.Input(
|
||||||
"model",
|
"model",
|
||||||
@ -164,13 +186,13 @@ class WanTextToImageApi(IO.ComfyNode):
|
|||||||
"prompt",
|
"prompt",
|
||||||
multiline=True,
|
multiline=True,
|
||||||
default="",
|
default="",
|
||||||
tooltip="Prompt used to describe the elements and visual features, supports English/Chinese.",
|
tooltip="Prompt describing the elements and visual features. Supports English and Chinese.",
|
||||||
),
|
),
|
||||||
IO.String.Input(
|
IO.String.Input(
|
||||||
"negative_prompt",
|
"negative_prompt",
|
||||||
multiline=True,
|
multiline=True,
|
||||||
default="",
|
default="",
|
||||||
tooltip="Negative text prompt to guide what to avoid.",
|
tooltip="Negative prompt describing what to avoid.",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
IO.Int.Input(
|
IO.Int.Input(
|
||||||
@ -208,8 +230,8 @@ class WanTextToImageApi(IO.ComfyNode):
|
|||||||
),
|
),
|
||||||
IO.Boolean.Input(
|
IO.Boolean.Input(
|
||||||
"watermark",
|
"watermark",
|
||||||
default=True,
|
default=False,
|
||||||
tooltip='Whether to add an "AI generated" watermark to the result.',
|
tooltip="Whether to add an AI-generated watermark to the result.",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
@ -234,7 +256,7 @@ class WanTextToImageApi(IO.ComfyNode):
|
|||||||
height: int = 1024,
|
height: int = 1024,
|
||||||
seed: int = 0,
|
seed: int = 0,
|
||||||
prompt_extend: bool = True,
|
prompt_extend: bool = True,
|
||||||
watermark: bool = True,
|
watermark: bool = False,
|
||||||
):
|
):
|
||||||
initial_response = await sync_op(
|
initial_response = await sync_op(
|
||||||
cls,
|
cls,
|
||||||
@ -252,7 +274,7 @@ class WanTextToImageApi(IO.ComfyNode):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
if not initial_response.output:
|
if not initial_response.output:
|
||||||
raise Exception(f"Unknown error occurred: {initial_response.code} - {initial_response.message}")
|
raise Exception(f"An unknown error occurred: {initial_response.code} - {initial_response.message}")
|
||||||
response = await poll_op(
|
response = await poll_op(
|
||||||
cls,
|
cls,
|
||||||
ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"),
|
ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"),
|
||||||
@ -272,7 +294,7 @@ class WanImageToImageApi(IO.ComfyNode):
|
|||||||
display_name="Wan Image to Image",
|
display_name="Wan Image to Image",
|
||||||
category="api node/image/Wan",
|
category="api node/image/Wan",
|
||||||
description="Generates an image from one or two input images and a text prompt. "
|
description="Generates an image from one or two input images and a text prompt. "
|
||||||
"The output image is currently fixed at 1.6 MP; its aspect ratio matches the input image(s).",
|
"The output image is currently fixed at 1.6 MP, and its aspect ratio matches the input image(s).",
|
||||||
inputs=[
|
inputs=[
|
||||||
IO.Combo.Input(
|
IO.Combo.Input(
|
||||||
"model",
|
"model",
|
||||||
@ -282,19 +304,19 @@ class WanImageToImageApi(IO.ComfyNode):
|
|||||||
),
|
),
|
||||||
IO.Image.Input(
|
IO.Image.Input(
|
||||||
"image",
|
"image",
|
||||||
tooltip="Single-image editing or multi-image fusion, maximum 2 images.",
|
tooltip="Single-image editing or multi-image fusion. Maximum 2 images.",
|
||||||
),
|
),
|
||||||
IO.String.Input(
|
IO.String.Input(
|
||||||
"prompt",
|
"prompt",
|
||||||
multiline=True,
|
multiline=True,
|
||||||
default="",
|
default="",
|
||||||
tooltip="Prompt used to describe the elements and visual features, supports English/Chinese.",
|
tooltip="Prompt describing the elements and visual features. Supports English and Chinese.",
|
||||||
),
|
),
|
||||||
IO.String.Input(
|
IO.String.Input(
|
||||||
"negative_prompt",
|
"negative_prompt",
|
||||||
multiline=True,
|
multiline=True,
|
||||||
default="",
|
default="",
|
||||||
tooltip="Negative text prompt to guide what to avoid.",
|
tooltip="Negative prompt describing what to avoid.",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
# redo this later as an optional combo of recommended resolutions
|
# redo this later as an optional combo of recommended resolutions
|
||||||
@ -327,8 +349,8 @@ class WanImageToImageApi(IO.ComfyNode):
|
|||||||
),
|
),
|
||||||
IO.Boolean.Input(
|
IO.Boolean.Input(
|
||||||
"watermark",
|
"watermark",
|
||||||
default=True,
|
default=False,
|
||||||
tooltip='Whether to add an "AI generated" watermark to the result.',
|
tooltip="Whether to add an AI-generated watermark to the result.",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
@ -347,17 +369,17 @@ class WanImageToImageApi(IO.ComfyNode):
|
|||||||
async def execute(
|
async def execute(
|
||||||
cls,
|
cls,
|
||||||
model: str,
|
model: str,
|
||||||
image: torch.Tensor,
|
image: Input.Image,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
negative_prompt: str = "",
|
negative_prompt: str = "",
|
||||||
# width: int = 1024,
|
# width: int = 1024,
|
||||||
# height: int = 1024,
|
# height: int = 1024,
|
||||||
seed: int = 0,
|
seed: int = 0,
|
||||||
watermark: bool = True,
|
watermark: bool = False,
|
||||||
):
|
):
|
||||||
n_images = get_number_of_images(image)
|
n_images = get_number_of_images(image)
|
||||||
if n_images not in (1, 2):
|
if n_images not in (1, 2):
|
||||||
raise ValueError(f"Expected 1 or 2 input images, got {n_images}.")
|
raise ValueError(f"Expected 1 or 2 input images, but got {n_images}.")
|
||||||
images = []
|
images = []
|
||||||
for i in image:
|
for i in image:
|
||||||
images.append("data:image/png;base64," + tensor_to_base64_string(i, total_pixels=4096 * 4096))
|
images.append("data:image/png;base64," + tensor_to_base64_string(i, total_pixels=4096 * 4096))
|
||||||
@ -376,7 +398,7 @@ class WanImageToImageApi(IO.ComfyNode):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
if not initial_response.output:
|
if not initial_response.output:
|
||||||
raise Exception(f"Unknown error occurred: {initial_response.code} - {initial_response.message}")
|
raise Exception(f"An unknown error occurred: {initial_response.code} - {initial_response.message}")
|
||||||
response = await poll_op(
|
response = await poll_op(
|
||||||
cls,
|
cls,
|
||||||
ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"),
|
ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"),
|
||||||
@ -395,25 +417,25 @@ class WanTextToVideoApi(IO.ComfyNode):
|
|||||||
node_id="WanTextToVideoApi",
|
node_id="WanTextToVideoApi",
|
||||||
display_name="Wan Text to Video",
|
display_name="Wan Text to Video",
|
||||||
category="api node/video/Wan",
|
category="api node/video/Wan",
|
||||||
description="Generates video based on text prompt.",
|
description="Generates a video based on a text prompt.",
|
||||||
inputs=[
|
inputs=[
|
||||||
IO.Combo.Input(
|
IO.Combo.Input(
|
||||||
"model",
|
"model",
|
||||||
options=["wan2.5-t2v-preview"],
|
options=["wan2.5-t2v-preview", "wan2.6-t2v"],
|
||||||
default="wan2.5-t2v-preview",
|
default="wan2.6-t2v",
|
||||||
tooltip="Model to use.",
|
tooltip="Model to use.",
|
||||||
),
|
),
|
||||||
IO.String.Input(
|
IO.String.Input(
|
||||||
"prompt",
|
"prompt",
|
||||||
multiline=True,
|
multiline=True,
|
||||||
default="",
|
default="",
|
||||||
tooltip="Prompt used to describe the elements and visual features, supports English/Chinese.",
|
tooltip="Prompt describing the elements and visual features. Supports English and Chinese.",
|
||||||
),
|
),
|
||||||
IO.String.Input(
|
IO.String.Input(
|
||||||
"negative_prompt",
|
"negative_prompt",
|
||||||
multiline=True,
|
multiline=True,
|
||||||
default="",
|
default="",
|
||||||
tooltip="Negative text prompt to guide what to avoid.",
|
tooltip="Negative prompt describing what to avoid.",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
IO.Combo.Input(
|
IO.Combo.Input(
|
||||||
@ -433,23 +455,23 @@ class WanTextToVideoApi(IO.ComfyNode):
|
|||||||
"1080p: 4:3 (1632x1248)",
|
"1080p: 4:3 (1632x1248)",
|
||||||
"1080p: 3:4 (1248x1632)",
|
"1080p: 3:4 (1248x1632)",
|
||||||
],
|
],
|
||||||
default="480p: 1:1 (624x624)",
|
default="720p: 1:1 (960x960)",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
IO.Int.Input(
|
IO.Int.Input(
|
||||||
"duration",
|
"duration",
|
||||||
default=5,
|
default=5,
|
||||||
min=5,
|
min=5,
|
||||||
max=10,
|
max=15,
|
||||||
step=5,
|
step=5,
|
||||||
display_mode=IO.NumberDisplay.number,
|
display_mode=IO.NumberDisplay.number,
|
||||||
tooltip="Available durations: 5 and 10 seconds",
|
tooltip="A 15-second duration is available only for the Wan 2.6 model.",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
IO.Audio.Input(
|
IO.Audio.Input(
|
||||||
"audio",
|
"audio",
|
||||||
optional=True,
|
optional=True,
|
||||||
tooltip="Audio must contain a clear, loud voice, without extraneous noise, background music.",
|
tooltip="Audio must contain a clear, loud voice, without extraneous noise or background music.",
|
||||||
),
|
),
|
||||||
IO.Int.Input(
|
IO.Int.Input(
|
||||||
"seed",
|
"seed",
|
||||||
@ -466,7 +488,7 @@ class WanTextToVideoApi(IO.ComfyNode):
|
|||||||
"generate_audio",
|
"generate_audio",
|
||||||
default=False,
|
default=False,
|
||||||
optional=True,
|
optional=True,
|
||||||
tooltip="If there is no audio input, generate audio automatically.",
|
tooltip="If no audio input is provided, generate audio automatically.",
|
||||||
),
|
),
|
||||||
IO.Boolean.Input(
|
IO.Boolean.Input(
|
||||||
"prompt_extend",
|
"prompt_extend",
|
||||||
@ -476,8 +498,16 @@ class WanTextToVideoApi(IO.ComfyNode):
|
|||||||
),
|
),
|
||||||
IO.Boolean.Input(
|
IO.Boolean.Input(
|
||||||
"watermark",
|
"watermark",
|
||||||
default=True,
|
default=False,
|
||||||
tooltip='Whether to add an "AI generated" watermark to the result.',
|
tooltip="Whether to add an AI-generated watermark to the result.",
|
||||||
|
optional=True,
|
||||||
|
),
|
||||||
|
IO.Combo.Input(
|
||||||
|
"shot_type",
|
||||||
|
options=["single", "multi"],
|
||||||
|
tooltip="Specifies the shot type for the generated video, that is, whether the video is a "
|
||||||
|
"single continuous shot or multiple shots with cuts. "
|
||||||
|
"This parameter takes effect only when prompt_extend is True.",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
@ -498,14 +528,19 @@ class WanTextToVideoApi(IO.ComfyNode):
|
|||||||
model: str,
|
model: str,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
negative_prompt: str = "",
|
negative_prompt: str = "",
|
||||||
size: str = "480p: 1:1 (624x624)",
|
size: str = "720p: 1:1 (960x960)",
|
||||||
duration: int = 5,
|
duration: int = 5,
|
||||||
audio: Optional[Input.Audio] = None,
|
audio: Input.Audio | None = None,
|
||||||
seed: int = 0,
|
seed: int = 0,
|
||||||
generate_audio: bool = False,
|
generate_audio: bool = False,
|
||||||
prompt_extend: bool = True,
|
prompt_extend: bool = True,
|
||||||
watermark: bool = True,
|
watermark: bool = False,
|
||||||
|
shot_type: str = "single",
|
||||||
):
|
):
|
||||||
|
if "480p" in size and model == "wan2.6-t2v":
|
||||||
|
raise ValueError("The Wan 2.6 model does not support 480p.")
|
||||||
|
if duration == 15 and model == "wan2.5-t2v-preview":
|
||||||
|
raise ValueError("A 15-second duration is supported only by the Wan 2.6 model.")
|
||||||
width, height = RES_IN_PARENS.search(size).groups()
|
width, height = RES_IN_PARENS.search(size).groups()
|
||||||
audio_url = None
|
audio_url = None
|
||||||
if audio is not None:
|
if audio is not None:
|
||||||
@ -526,11 +561,12 @@ class WanTextToVideoApi(IO.ComfyNode):
|
|||||||
audio=generate_audio,
|
audio=generate_audio,
|
||||||
prompt_extend=prompt_extend,
|
prompt_extend=prompt_extend,
|
||||||
watermark=watermark,
|
watermark=watermark,
|
||||||
|
shot_type=shot_type,
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
if not initial_response.output:
|
if not initial_response.output:
|
||||||
raise Exception(f"Unknown error occurred: {initial_response.code} - {initial_response.message}")
|
raise Exception(f"An unknown error occurred: {initial_response.code} - {initial_response.message}")
|
||||||
response = await poll_op(
|
response = await poll_op(
|
||||||
cls,
|
cls,
|
||||||
ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"),
|
ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"),
|
||||||
@ -549,12 +585,12 @@ class WanImageToVideoApi(IO.ComfyNode):
|
|||||||
node_id="WanImageToVideoApi",
|
node_id="WanImageToVideoApi",
|
||||||
display_name="Wan Image to Video",
|
display_name="Wan Image to Video",
|
||||||
category="api node/video/Wan",
|
category="api node/video/Wan",
|
||||||
description="Generates video based on the first frame and text prompt.",
|
description="Generates a video from the first frame and a text prompt.",
|
||||||
inputs=[
|
inputs=[
|
||||||
IO.Combo.Input(
|
IO.Combo.Input(
|
||||||
"model",
|
"model",
|
||||||
options=["wan2.5-i2v-preview"],
|
options=["wan2.5-i2v-preview", "wan2.6-i2v"],
|
||||||
default="wan2.5-i2v-preview",
|
default="wan2.6-i2v",
|
||||||
tooltip="Model to use.",
|
tooltip="Model to use.",
|
||||||
),
|
),
|
||||||
IO.Image.Input(
|
IO.Image.Input(
|
||||||
@ -564,13 +600,13 @@ class WanImageToVideoApi(IO.ComfyNode):
|
|||||||
"prompt",
|
"prompt",
|
||||||
multiline=True,
|
multiline=True,
|
||||||
default="",
|
default="",
|
||||||
tooltip="Prompt used to describe the elements and visual features, supports English/Chinese.",
|
tooltip="Prompt describing the elements and visual features. Supports English and Chinese.",
|
||||||
),
|
),
|
||||||
IO.String.Input(
|
IO.String.Input(
|
||||||
"negative_prompt",
|
"negative_prompt",
|
||||||
multiline=True,
|
multiline=True,
|
||||||
default="",
|
default="",
|
||||||
tooltip="Negative text prompt to guide what to avoid.",
|
tooltip="Negative prompt describing what to avoid.",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
IO.Combo.Input(
|
IO.Combo.Input(
|
||||||
@ -580,23 +616,23 @@ class WanImageToVideoApi(IO.ComfyNode):
|
|||||||
"720P",
|
"720P",
|
||||||
"1080P",
|
"1080P",
|
||||||
],
|
],
|
||||||
default="480P",
|
default="720P",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
IO.Int.Input(
|
IO.Int.Input(
|
||||||
"duration",
|
"duration",
|
||||||
default=5,
|
default=5,
|
||||||
min=5,
|
min=5,
|
||||||
max=10,
|
max=15,
|
||||||
step=5,
|
step=5,
|
||||||
display_mode=IO.NumberDisplay.number,
|
display_mode=IO.NumberDisplay.number,
|
||||||
tooltip="Available durations: 5 and 10 seconds",
|
tooltip="Duration 15 available only for WAN2.6 model.",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
IO.Audio.Input(
|
IO.Audio.Input(
|
||||||
"audio",
|
"audio",
|
||||||
optional=True,
|
optional=True,
|
||||||
tooltip="Audio must contain a clear, loud voice, without extraneous noise, background music.",
|
tooltip="Audio must contain a clear, loud voice, without extraneous noise or background music.",
|
||||||
),
|
),
|
||||||
IO.Int.Input(
|
IO.Int.Input(
|
||||||
"seed",
|
"seed",
|
||||||
@ -613,7 +649,7 @@ class WanImageToVideoApi(IO.ComfyNode):
|
|||||||
"generate_audio",
|
"generate_audio",
|
||||||
default=False,
|
default=False,
|
||||||
optional=True,
|
optional=True,
|
||||||
tooltip="If there is no audio input, generate audio automatically.",
|
tooltip="If no audio input is provided, generate audio automatically.",
|
||||||
),
|
),
|
||||||
IO.Boolean.Input(
|
IO.Boolean.Input(
|
||||||
"prompt_extend",
|
"prompt_extend",
|
||||||
@ -623,8 +659,16 @@ class WanImageToVideoApi(IO.ComfyNode):
|
|||||||
),
|
),
|
||||||
IO.Boolean.Input(
|
IO.Boolean.Input(
|
||||||
"watermark",
|
"watermark",
|
||||||
default=True,
|
default=False,
|
||||||
tooltip='Whether to add an "AI generated" watermark to the result.',
|
tooltip="Whether to add an AI-generated watermark to the result.",
|
||||||
|
optional=True,
|
||||||
|
),
|
||||||
|
IO.Combo.Input(
|
||||||
|
"shot_type",
|
||||||
|
options=["single", "multi"],
|
||||||
|
tooltip="Specifies the shot type for the generated video, that is, whether the video is a "
|
||||||
|
"single continuous shot or multiple shots with cuts. "
|
||||||
|
"This parameter takes effect only when prompt_extend is True.",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
@ -643,19 +687,24 @@ class WanImageToVideoApi(IO.ComfyNode):
|
|||||||
async def execute(
|
async def execute(
|
||||||
cls,
|
cls,
|
||||||
model: str,
|
model: str,
|
||||||
image: torch.Tensor,
|
image: Input.Image,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
negative_prompt: str = "",
|
negative_prompt: str = "",
|
||||||
resolution: str = "480P",
|
resolution: str = "720P",
|
||||||
duration: int = 5,
|
duration: int = 5,
|
||||||
audio: Optional[Input.Audio] = None,
|
audio: Input.Audio | None = None,
|
||||||
seed: int = 0,
|
seed: int = 0,
|
||||||
generate_audio: bool = False,
|
generate_audio: bool = False,
|
||||||
prompt_extend: bool = True,
|
prompt_extend: bool = True,
|
||||||
watermark: bool = True,
|
watermark: bool = False,
|
||||||
|
shot_type: str = "single",
|
||||||
):
|
):
|
||||||
if get_number_of_images(image) != 1:
|
if get_number_of_images(image) != 1:
|
||||||
raise ValueError("Exactly one input image is required.")
|
raise ValueError("Exactly one input image is required.")
|
||||||
|
if "480P" in resolution and model == "wan2.6-i2v":
|
||||||
|
raise ValueError("The Wan 2.6 model does not support 480P.")
|
||||||
|
if duration == 15 and model == "wan2.5-i2v-preview":
|
||||||
|
raise ValueError("A 15-second duration is supported only by the Wan 2.6 model.")
|
||||||
image_url = "data:image/png;base64," + tensor_to_base64_string(image, total_pixels=2000 * 2000)
|
image_url = "data:image/png;base64," + tensor_to_base64_string(image, total_pixels=2000 * 2000)
|
||||||
audio_url = None
|
audio_url = None
|
||||||
if audio is not None:
|
if audio is not None:
|
||||||
@ -677,11 +726,12 @@ class WanImageToVideoApi(IO.ComfyNode):
|
|||||||
audio=generate_audio,
|
audio=generate_audio,
|
||||||
prompt_extend=prompt_extend,
|
prompt_extend=prompt_extend,
|
||||||
watermark=watermark,
|
watermark=watermark,
|
||||||
|
shot_type=shot_type,
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
if not initial_response.output:
|
if not initial_response.output:
|
||||||
raise Exception(f"Unknown error occurred: {initial_response.code} - {initial_response.message}")
|
raise Exception(f"An unknown error occurred: {initial_response.code} - {initial_response.message}")
|
||||||
response = await poll_op(
|
response = await poll_op(
|
||||||
cls,
|
cls,
|
||||||
ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"),
|
ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"),
|
||||||
@ -693,6 +743,143 @@ class WanImageToVideoApi(IO.ComfyNode):
|
|||||||
return IO.NodeOutput(await download_url_to_video_output(response.output.video_url))
|
return IO.NodeOutput(await download_url_to_video_output(response.output.video_url))
|
||||||
|
|
||||||
|
|
||||||
|
class WanReferenceVideoApi(IO.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return IO.Schema(
|
||||||
|
node_id="WanReferenceVideoApi",
|
||||||
|
display_name="Wan Reference to Video",
|
||||||
|
category="api node/video/Wan",
|
||||||
|
description="Use the character and voice from input videos, combined with a prompt, "
|
||||||
|
"to generate a new video that maintains character consistency.",
|
||||||
|
inputs=[
|
||||||
|
IO.Combo.Input("model", options=["wan2.6-r2v"]),
|
||||||
|
IO.String.Input(
|
||||||
|
"prompt",
|
||||||
|
multiline=True,
|
||||||
|
default="",
|
||||||
|
tooltip="Prompt describing the elements and visual features. Supports English and Chinese. "
|
||||||
|
"Use identifiers such as `character1` and `character2` to refer to the reference characters.",
|
||||||
|
),
|
||||||
|
IO.String.Input(
|
||||||
|
"negative_prompt",
|
||||||
|
multiline=True,
|
||||||
|
default="",
|
||||||
|
tooltip="Negative prompt describing what to avoid.",
|
||||||
|
),
|
||||||
|
IO.Autogrow.Input(
|
||||||
|
"reference_videos",
|
||||||
|
template=IO.Autogrow.TemplateNames(
|
||||||
|
IO.Video.Input("reference_video"),
|
||||||
|
names=["character1", "character2", "character3"],
|
||||||
|
min=1,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
IO.Combo.Input(
|
||||||
|
"size",
|
||||||
|
options=[
|
||||||
|
"720p: 1:1 (960x960)",
|
||||||
|
"720p: 16:9 (1280x720)",
|
||||||
|
"720p: 9:16 (720x1280)",
|
||||||
|
"720p: 4:3 (1088x832)",
|
||||||
|
"720p: 3:4 (832x1088)",
|
||||||
|
"1080p: 1:1 (1440x1440)",
|
||||||
|
"1080p: 16:9 (1920x1080)",
|
||||||
|
"1080p: 9:16 (1080x1920)",
|
||||||
|
"1080p: 4:3 (1632x1248)",
|
||||||
|
"1080p: 3:4 (1248x1632)",
|
||||||
|
],
|
||||||
|
),
|
||||||
|
IO.Int.Input(
|
||||||
|
"duration",
|
||||||
|
default=5,
|
||||||
|
min=5,
|
||||||
|
max=10,
|
||||||
|
step=5,
|
||||||
|
display_mode=IO.NumberDisplay.slider,
|
||||||
|
),
|
||||||
|
IO.Int.Input(
|
||||||
|
"seed",
|
||||||
|
default=0,
|
||||||
|
min=0,
|
||||||
|
max=2147483647,
|
||||||
|
step=1,
|
||||||
|
display_mode=IO.NumberDisplay.number,
|
||||||
|
control_after_generate=True,
|
||||||
|
),
|
||||||
|
IO.Combo.Input(
|
||||||
|
"shot_type",
|
||||||
|
options=["single", "multi"],
|
||||||
|
tooltip="Specifies the shot type for the generated video, that is, whether the video is a "
|
||||||
|
"single continuous shot or multiple shots with cuts.",
|
||||||
|
),
|
||||||
|
IO.Boolean.Input(
|
||||||
|
"watermark",
|
||||||
|
default=False,
|
||||||
|
tooltip="Whether to add an AI-generated watermark to the result.",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
IO.Video.Output(),
|
||||||
|
],
|
||||||
|
hidden=[
|
||||||
|
IO.Hidden.auth_token_comfy_org,
|
||||||
|
IO.Hidden.api_key_comfy_org,
|
||||||
|
IO.Hidden.unique_id,
|
||||||
|
],
|
||||||
|
is_api_node=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def execute(
|
||||||
|
cls,
|
||||||
|
model: str,
|
||||||
|
prompt: str,
|
||||||
|
negative_prompt: str,
|
||||||
|
reference_videos: IO.Autogrow.Type,
|
||||||
|
size: str,
|
||||||
|
duration: int,
|
||||||
|
seed: int,
|
||||||
|
shot_type: str,
|
||||||
|
watermark: bool,
|
||||||
|
):
|
||||||
|
reference_video_urls = []
|
||||||
|
for i in reference_videos:
|
||||||
|
validate_video_duration(reference_videos[i], min_duration=2, max_duration=30)
|
||||||
|
for i in reference_videos:
|
||||||
|
reference_video_urls.append(await upload_video_to_comfyapi(cls, reference_videos[i]))
|
||||||
|
width, height = RES_IN_PARENS.search(size).groups()
|
||||||
|
initial_response = await sync_op(
|
||||||
|
cls,
|
||||||
|
ApiEndpoint(path="/proxy/wan/api/v1/services/aigc/video-generation/video-synthesis", method="POST"),
|
||||||
|
response_model=TaskCreationResponse,
|
||||||
|
data=Reference2VideoTaskCreationRequest(
|
||||||
|
model=model,
|
||||||
|
input=Reference2VideoInputField(
|
||||||
|
prompt=prompt, negative_prompt=negative_prompt, reference_video_urls=reference_video_urls
|
||||||
|
),
|
||||||
|
parameters=Reference2VideoParametersField(
|
||||||
|
size=f"{width}*{height}",
|
||||||
|
duration=duration,
|
||||||
|
shot_type=shot_type,
|
||||||
|
watermark=watermark,
|
||||||
|
seed=seed,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
if not initial_response.output:
|
||||||
|
raise Exception(f"An unknown error occurred: {initial_response.code} - {initial_response.message}")
|
||||||
|
response = await poll_op(
|
||||||
|
cls,
|
||||||
|
ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"),
|
||||||
|
response_model=VideoTaskStatusResponse,
|
||||||
|
status_extractor=lambda x: x.output.task_status,
|
||||||
|
poll_interval=6,
|
||||||
|
max_poll_attempts=280,
|
||||||
|
)
|
||||||
|
return IO.NodeOutput(await download_url_to_video_output(response.output.video_url))
|
||||||
|
|
||||||
|
|
||||||
class WanApiExtension(ComfyExtension):
|
class WanApiExtension(ComfyExtension):
|
||||||
@override
|
@override
|
||||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||||
@ -701,6 +888,7 @@ class WanApiExtension(ComfyExtension):
|
|||||||
WanImageToImageApi,
|
WanImageToImageApi,
|
||||||
WanTextToVideoApi,
|
WanTextToVideoApi,
|
||||||
WanImageToVideoApi,
|
WanImageToVideoApi,
|
||||||
|
WanReferenceVideoApi,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -1,16 +1,22 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import contextlib
|
import contextlib
|
||||||
import os
|
import os
|
||||||
|
import re
|
||||||
import time
|
import time
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
|
|
||||||
|
from yarl import URL
|
||||||
|
|
||||||
from comfy.cli_args import args
|
from comfy.cli_args import args
|
||||||
from comfy.model_management import processing_interrupted
|
from comfy.model_management import processing_interrupted
|
||||||
from comfy_api.latest import IO
|
from comfy_api.latest import IO
|
||||||
|
|
||||||
from .common_exceptions import ProcessingInterrupted
|
from .common_exceptions import ProcessingInterrupted
|
||||||
|
|
||||||
|
_HAS_PCT_ESC = re.compile(r"%[0-9A-Fa-f]{2}") # any % followed by 2 hex digits
|
||||||
|
_HAS_BAD_PCT = re.compile(r"%(?![0-9A-Fa-f]{2})") # any % not followed by 2 hex digits
|
||||||
|
|
||||||
|
|
||||||
def is_processing_interrupted() -> bool:
|
def is_processing_interrupted() -> bool:
|
||||||
"""Return True if user/runtime requested interruption."""
|
"""Return True if user/runtime requested interruption."""
|
||||||
@ -69,3 +75,17 @@ def get_fs_object_size(path_or_object: str | BytesIO) -> int:
|
|||||||
if isinstance(path_or_object, str):
|
if isinstance(path_or_object, str):
|
||||||
return os.path.getsize(path_or_object)
|
return os.path.getsize(path_or_object)
|
||||||
return len(path_or_object.getvalue())
|
return len(path_or_object.getvalue())
|
||||||
|
|
||||||
|
|
||||||
|
def to_aiohttp_url(url: str) -> URL:
|
||||||
|
"""If `url` appears to be already percent-encoded (contains at least one valid %HH
|
||||||
|
escape and no malformed '%' sequences) and contains no raw whitespace/control
|
||||||
|
characters, preserve the original encoding byte-for-byte (important for signed/presigned URLs).
|
||||||
|
Otherwise, return `URL(url)` and allow yarl to normalize/quote as needed."""
|
||||||
|
if any(c.isspace() for c in url) or any(ord(c) < 0x20 for c in url):
|
||||||
|
# Avoid encoded=True if URL contains raw whitespace/control chars
|
||||||
|
return URL(url)
|
||||||
|
if _HAS_PCT_ESC.search(url) and not _HAS_BAD_PCT.search(url):
|
||||||
|
# Preserve encoding only if it appears pre-encoded AND has no invalid % sequences
|
||||||
|
return URL(url, encoded=True)
|
||||||
|
return URL(url)
|
||||||
|
|||||||
@ -430,9 +430,9 @@ def _display_text(
|
|||||||
if status:
|
if status:
|
||||||
display_lines.append(f"Status: {status.capitalize() if isinstance(status, str) else status}")
|
display_lines.append(f"Status: {status.capitalize() if isinstance(status, str) else status}")
|
||||||
if price is not None:
|
if price is not None:
|
||||||
p = f"{float(price):,.4f}".rstrip("0").rstrip(".")
|
p = f"{float(price) * 211:,.1f}".rstrip("0").rstrip(".")
|
||||||
if p != "0":
|
if p != "0":
|
||||||
display_lines.append(f"Price: ${p}")
|
display_lines.append(f"Price: {p} credits")
|
||||||
if text is not None:
|
if text is not None:
|
||||||
display_lines.append(text)
|
display_lines.append(text)
|
||||||
if display_lines:
|
if display_lines:
|
||||||
|
|||||||
@ -129,7 +129,7 @@ def pil_to_bytesio(img: Image.Image, mime_type: str = "image/png") -> BytesIO:
|
|||||||
return img_byte_arr
|
return img_byte_arr
|
||||||
|
|
||||||
|
|
||||||
def downscale_image_tensor(image, total_pixels=1536 * 1024) -> torch.Tensor:
|
def downscale_image_tensor(image: torch.Tensor, total_pixels: int = 1536 * 1024) -> torch.Tensor:
|
||||||
"""Downscale input image tensor to roughly the specified total pixels."""
|
"""Downscale input image tensor to roughly the specified total pixels."""
|
||||||
samples = image.movedim(-1, 1)
|
samples = image.movedim(-1, 1)
|
||||||
total = int(total_pixels)
|
total = int(total_pixels)
|
||||||