mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-09 22:00:49 +08:00
Compare commits
48 Commits
d40347e373
...
b91e3f9859
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b91e3f9859 | ||
|
|
efac5a7075 | ||
|
|
3cd7b32f1b | ||
|
|
c0c9720d77 | ||
|
|
fc0cb10bcb | ||
|
|
b7d7cc1d49 | ||
|
|
79e94544bd | ||
|
|
ce0000c4f2 | ||
|
|
c5cfb34c07 | ||
|
|
edee33f55e | ||
|
|
2c03884f5f | ||
|
|
6e9ee55cdd | ||
|
|
023cf13721 | ||
|
|
c3566c0d76 | ||
|
|
c3c3e93c5b | ||
|
|
6ffc159bdd | ||
|
|
96e0d0924e | ||
|
|
e14f3b6610 | ||
|
|
1618002411 | ||
|
|
6ef85c4915 | ||
|
|
6da00dd899 | ||
|
|
4f3f9e72a9 | ||
|
|
d157c3299d | ||
|
|
d1b9822f74 | ||
|
|
f2b002372b | ||
|
|
38d0493825 | ||
|
|
acbf08cd60 | ||
|
|
53e762a3af | ||
|
|
9a552df898 | ||
|
|
f2fda021ab | ||
|
|
303b1735f8 | ||
|
|
9e5f677746 | ||
|
|
65cfcf5b1b | ||
|
|
1bdc9a947f | ||
|
|
d622a61874 | ||
|
|
236b9e211d | ||
|
|
6ca3d5c011 | ||
|
|
0be8a76c93 | ||
|
|
0357ed7ec4 | ||
|
|
f59f71cf34 | ||
|
|
178bdc5e14 | ||
|
|
25a1bfab4e | ||
|
|
d7111e426a | ||
|
|
0e6221cc79 | ||
|
|
9ca7e143af | ||
|
|
8fd07170f1 | ||
|
|
5c565e9125 | ||
|
|
083d8aa330 |
@ -1,3 +1,3 @@
|
|||||||
..\python_embeded\python.exe -s ..\ComfyUI\main.py --windows-standalone-build --disable-api-nodes
|
..\python_embeded\python.exe -s ..\ComfyUI\main.py --windows-standalone-build --disable-api-nodes
|
||||||
echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest.
|
echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest. If you get a c10.dll error you need to install vc redist that you can find: https://aka.ms/vc14/vc_redist.x64.exe
|
||||||
pause
|
pause
|
||||||
|
|||||||
@ -1,3 +1,3 @@
|
|||||||
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build
|
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build
|
||||||
echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest.
|
echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest. If you get a c10.dll error you need to install vc redist that you can find: https://aka.ms/vc14/vc_redist.x64.exe
|
||||||
pause
|
pause
|
||||||
|
|||||||
@ -1,3 +1,3 @@
|
|||||||
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --fast fp16_accumulation
|
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --fast fp16_accumulation
|
||||||
echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest.
|
echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest. If you get a c10.dll error you need to install vc redist that you can find: https://aka.ms/vc14/vc_redist.x64.exe
|
||||||
pause
|
pause
|
||||||
|
|||||||
2
.github/workflows/stable-release.yml
vendored
2
.github/workflows/stable-release.yml
vendored
@ -117,7 +117,7 @@ jobs:
|
|||||||
./python.exe get-pip.py
|
./python.exe get-pip.py
|
||||||
./python.exe -s -m pip install ../${{ inputs.cache_tag }}_python_deps/*
|
./python.exe -s -m pip install ../${{ inputs.cache_tag }}_python_deps/*
|
||||||
|
|
||||||
grep comfyui ../ComfyUI/requirements.txt > ./requirements_comfyui.txt
|
grep comfy ../ComfyUI/requirements.txt > ./requirements_comfyui.txt
|
||||||
./python.exe -s -m pip install -r requirements_comfyui.txt
|
./python.exe -s -m pip install -r requirements_comfyui.txt
|
||||||
rm requirements_comfyui.txt
|
rm requirements_comfyui.txt
|
||||||
|
|
||||||
|
|||||||
2
.github/workflows/test-build.yml
vendored
2
.github/workflows/test-build.yml
vendored
@ -18,7 +18,7 @@ jobs:
|
|||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
|
python-version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
- name: Set up Python ${{ matrix.python-version }}
|
- name: Set up Python ${{ matrix.python-version }}
|
||||||
|
|||||||
2
.github/workflows/test-ci.yml
vendored
2
.github/workflows/test-ci.yml
vendored
@ -20,6 +20,7 @@ jobs:
|
|||||||
test-stable:
|
test-stable:
|
||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
|
max-parallel: 1 # This forces sequential execution
|
||||||
matrix:
|
matrix:
|
||||||
# os: [macos, linux, windows]
|
# os: [macos, linux, windows]
|
||||||
# os: [macos, linux]
|
# os: [macos, linux]
|
||||||
@ -74,6 +75,7 @@ jobs:
|
|||||||
test-unix-nightly:
|
test-unix-nightly:
|
||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
|
max-parallel: 1 # This forces sequential execution
|
||||||
matrix:
|
matrix:
|
||||||
# os: [macos, linux]
|
# os: [macos, linux]
|
||||||
os: [linux]
|
os: [linux]
|
||||||
|
|||||||
4
.github/workflows/test-launch.yml
vendored
4
.github/workflows/test-launch.yml
vendored
@ -32,7 +32,9 @@ jobs:
|
|||||||
working-directory: ComfyUI
|
working-directory: ComfyUI
|
||||||
- name: Check for unhandled exceptions in server log
|
- name: Check for unhandled exceptions in server log
|
||||||
run: |
|
run: |
|
||||||
if grep -qE "Exception|Error" console_output.log; then
|
grep -v "Found comfy_kitchen backend triton: {'available': False, 'disabled': True, 'unavailable_reason': \"ImportError: No module named 'triton'\", 'capabilities': \[\]}" console_output.log | grep -v "Found comfy_kitchen backend triton: {'available': False, 'disabled': False, 'unavailable_reason': \"ImportError: No module named 'triton'\", 'capabilities': \[\]}" > console_output_filtered.log
|
||||||
|
cat console_output_filtered.log
|
||||||
|
if grep -qE "Exception|Error" console_output_filtered.log; then
|
||||||
echo "Unhandled exception/error found in server log."
|
echo "Unhandled exception/error found in server log."
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
|
|||||||
@ -44,7 +44,7 @@ class ModelFileManager:
|
|||||||
@routes.get("/experiment/models/{folder}")
|
@routes.get("/experiment/models/{folder}")
|
||||||
async def get_all_models(request):
|
async def get_all_models(request):
|
||||||
folder = request.match_info.get("folder", None)
|
folder = request.match_info.get("folder", None)
|
||||||
if not folder in folder_paths.folder_names_and_paths:
|
if folder not in folder_paths.folder_names_and_paths:
|
||||||
return web.Response(status=404)
|
return web.Response(status=404)
|
||||||
files = self.get_model_file_list(folder)
|
files = self.get_model_file_list(folder)
|
||||||
return web.json_response(files)
|
return web.json_response(files)
|
||||||
@ -55,7 +55,7 @@ class ModelFileManager:
|
|||||||
path_index = int(request.match_info.get("path_index", None))
|
path_index = int(request.match_info.get("path_index", None))
|
||||||
filename = request.match_info.get("filename", None)
|
filename = request.match_info.get("filename", None)
|
||||||
|
|
||||||
if not folder_name in folder_paths.folder_names_and_paths:
|
if folder_name not in folder_paths.folder_names_and_paths:
|
||||||
return web.Response(status=404)
|
return web.Response(status=404)
|
||||||
|
|
||||||
folders = folder_paths.folder_names_and_paths[folder_name]
|
folders = folder_paths.folder_names_and_paths[folder_name]
|
||||||
|
|||||||
@ -2,6 +2,25 @@ import torch
|
|||||||
from comfy.ldm.modules.attention import optimized_attention_for_device
|
from comfy.ldm.modules.attention import optimized_attention_for_device
|
||||||
import comfy.ops
|
import comfy.ops
|
||||||
|
|
||||||
|
def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], crop=True):
|
||||||
|
image = image[:, :, :, :3] if image.shape[3] > 3 else image
|
||||||
|
mean = torch.tensor(mean, device=image.device, dtype=image.dtype)
|
||||||
|
std = torch.tensor(std, device=image.device, dtype=image.dtype)
|
||||||
|
image = image.movedim(-1, 1)
|
||||||
|
if not (image.shape[2] == size and image.shape[3] == size):
|
||||||
|
if crop:
|
||||||
|
scale = (size / min(image.shape[2], image.shape[3]))
|
||||||
|
scale_size = (round(scale * image.shape[2]), round(scale * image.shape[3]))
|
||||||
|
else:
|
||||||
|
scale_size = (size, size)
|
||||||
|
|
||||||
|
image = torch.nn.functional.interpolate(image, size=scale_size, mode="bicubic", antialias=True)
|
||||||
|
h = (image.shape[2] - size)//2
|
||||||
|
w = (image.shape[3] - size)//2
|
||||||
|
image = image[:,:,h:h+size,w:w+size]
|
||||||
|
image = torch.clip((255. * image), 0, 255).round() / 255.0
|
||||||
|
return (image - mean.view([3,1,1])) / std.view([3,1,1])
|
||||||
|
|
||||||
class CLIPAttention(torch.nn.Module):
|
class CLIPAttention(torch.nn.Module):
|
||||||
def __init__(self, embed_dim, heads, dtype, device, operations):
|
def __init__(self, embed_dim, heads, dtype, device, operations):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|||||||
@ -1,6 +1,5 @@
|
|||||||
from .utils import load_torch_file, transformers_convert, state_dict_prefix_replace
|
from .utils import load_torch_file, transformers_convert, state_dict_prefix_replace
|
||||||
import os
|
import os
|
||||||
import torch
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
@ -17,24 +16,7 @@ class Output:
|
|||||||
def __setitem__(self, key, item):
|
def __setitem__(self, key, item):
|
||||||
setattr(self, key, item)
|
setattr(self, key, item)
|
||||||
|
|
||||||
def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], crop=True):
|
clip_preprocess = comfy.clip_model.clip_preprocess # Prevent some stuff from breaking, TODO: remove eventually
|
||||||
image = image[:, :, :, :3] if image.shape[3] > 3 else image
|
|
||||||
mean = torch.tensor(mean, device=image.device, dtype=image.dtype)
|
|
||||||
std = torch.tensor(std, device=image.device, dtype=image.dtype)
|
|
||||||
image = image.movedim(-1, 1)
|
|
||||||
if not (image.shape[2] == size and image.shape[3] == size):
|
|
||||||
if crop:
|
|
||||||
scale = (size / min(image.shape[2], image.shape[3]))
|
|
||||||
scale_size = (round(scale * image.shape[2]), round(scale * image.shape[3]))
|
|
||||||
else:
|
|
||||||
scale_size = (size, size)
|
|
||||||
|
|
||||||
image = torch.nn.functional.interpolate(image, size=scale_size, mode="bicubic", antialias=True)
|
|
||||||
h = (image.shape[2] - size)//2
|
|
||||||
w = (image.shape[3] - size)//2
|
|
||||||
image = image[:,:,h:h+size,w:w+size]
|
|
||||||
image = torch.clip((255. * image), 0, 255).round() / 255.0
|
|
||||||
return (image - mean.view([3,1,1])) / std.view([3,1,1])
|
|
||||||
|
|
||||||
IMAGE_ENCODERS = {
|
IMAGE_ENCODERS = {
|
||||||
"clip_vision_model": comfy.clip_model.CLIPVisionModelProjection,
|
"clip_vision_model": comfy.clip_model.CLIPVisionModelProjection,
|
||||||
@ -73,7 +55,7 @@ class ClipVisionModel():
|
|||||||
|
|
||||||
def encode_image(self, image, crop=True):
|
def encode_image(self, image, crop=True):
|
||||||
comfy.model_management.load_model_gpu(self.patcher)
|
comfy.model_management.load_model_gpu(self.patcher)
|
||||||
pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float()
|
pixel_values = comfy.clip_model.clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float()
|
||||||
out = self.model(pixel_values=pixel_values, intermediate_output='all' if self.return_all_hidden_states else -2)
|
out = self.model(pixel_values=pixel_values, intermediate_output='all' if self.return_all_hidden_states else -2)
|
||||||
|
|
||||||
outputs = Output()
|
outputs = Output()
|
||||||
|
|||||||
@ -188,6 +188,12 @@ class IndexListContextHandler(ContextHandlerABC):
|
|||||||
audio_cond = cond_value.cond
|
audio_cond = cond_value.cond
|
||||||
if audio_cond.ndim > 1 and audio_cond.size(1) == x_in.size(self.dim):
|
if audio_cond.ndim > 1 and audio_cond.size(1) == x_in.size(self.dim):
|
||||||
new_cond_item[cond_key] = cond_value._copy_with(window.get_tensor(audio_cond, device, dim=1))
|
new_cond_item[cond_key] = cond_value._copy_with(window.get_tensor(audio_cond, device, dim=1))
|
||||||
|
# Handle vace_context (temporal dim is 3)
|
||||||
|
elif cond_key == "vace_context" and hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
|
||||||
|
vace_cond = cond_value.cond
|
||||||
|
if vace_cond.ndim >= 4 and vace_cond.size(3) == x_in.size(self.dim):
|
||||||
|
sliced_vace = window.get_tensor(vace_cond, device, dim=3, retain_index_list=self.cond_retain_index_list)
|
||||||
|
new_cond_item[cond_key] = cond_value._copy_with(sliced_vace)
|
||||||
# if has cond that is a Tensor, check if needs to be subset
|
# if has cond that is a Tensor, check if needs to be subset
|
||||||
elif hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
|
elif hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
|
||||||
if (self.dim < cond_value.cond.ndim and cond_value.cond.size(self.dim) == x_in.size(self.dim)) or \
|
if (self.dim < cond_value.cond.ndim and cond_value.cond.size(self.dim) == x_in.size(self.dim)) or \
|
||||||
|
|||||||
@ -527,7 +527,8 @@ class HookKeyframeGroup:
|
|||||||
if self._current_keyframe.get_effective_guarantee_steps(max_sigma) > 0:
|
if self._current_keyframe.get_effective_guarantee_steps(max_sigma) > 0:
|
||||||
break
|
break
|
||||||
# if eval_c is outside the percent range, stop looking further
|
# if eval_c is outside the percent range, stop looking further
|
||||||
else: break
|
else:
|
||||||
|
break
|
||||||
# update steps current context is used
|
# update steps current context is used
|
||||||
self._current_used_steps += 1
|
self._current_used_steps += 1
|
||||||
# update current timestep this was performed on
|
# update current timestep this was performed on
|
||||||
|
|||||||
@ -407,6 +407,11 @@ class LTXV(LatentFormat):
|
|||||||
|
|
||||||
self.latent_rgb_factors_bias = [-0.0571, -0.1657, -0.2512]
|
self.latent_rgb_factors_bias = [-0.0571, -0.1657, -0.2512]
|
||||||
|
|
||||||
|
class LTXAV(LTXV):
|
||||||
|
def __init__(self):
|
||||||
|
self.latent_rgb_factors = None
|
||||||
|
self.latent_rgb_factors_bias = None
|
||||||
|
|
||||||
class HunyuanVideo(LatentFormat):
|
class HunyuanVideo(LatentFormat):
|
||||||
latent_channels = 16
|
latent_channels = 16
|
||||||
latent_dimensions = 3
|
latent_dimensions = 3
|
||||||
|
|||||||
@ -270,7 +270,7 @@ class ChromaRadiance(Chroma):
|
|||||||
bad_keys = tuple(
|
bad_keys = tuple(
|
||||||
k
|
k
|
||||||
for k, v in overrides.items()
|
for k, v in overrides.items()
|
||||||
if type(v) != type(getattr(params, k)) and (v is not None or k not in nullable_keys)
|
if not isinstance(v, type(getattr(params, k))) and (v is not None or k not in nullable_keys)
|
||||||
)
|
)
|
||||||
if bad_keys:
|
if bad_keys:
|
||||||
e = f"Invalid value(s) in transformer_options chroma_radiance_options: {', '.join(bad_keys)}"
|
e = f"Invalid value(s) in transformer_options chroma_radiance_options: {', '.join(bad_keys)}"
|
||||||
|
|||||||
@ -4,6 +4,7 @@ from torch import Tensor
|
|||||||
|
|
||||||
from comfy.ldm.modules.attention import optimized_attention
|
from comfy.ldm.modules.attention import optimized_attention
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
|
import logging
|
||||||
|
|
||||||
|
|
||||||
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transformer_options={}) -> Tensor:
|
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transformer_options={}) -> Tensor:
|
||||||
@ -13,7 +14,6 @@ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transforme
|
|||||||
x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask, transformer_options=transformer_options)
|
x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask, transformer_options=transformer_options)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
|
def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
|
||||||
assert dim % 2 == 0
|
assert dim % 2 == 0
|
||||||
if comfy.model_management.is_device_mps(pos.device) or comfy.model_management.is_intel_xpu() or comfy.model_management.is_directml_enabled():
|
if comfy.model_management.is_device_mps(pos.device) or comfy.model_management.is_intel_xpu() or comfy.model_management.is_directml_enabled():
|
||||||
@ -28,13 +28,20 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
|
|||||||
out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
|
out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
|
||||||
return out.to(dtype=torch.float32, device=pos.device)
|
return out.to(dtype=torch.float32, device=pos.device)
|
||||||
|
|
||||||
def apply_rope1(x: Tensor, freqs_cis: Tensor):
|
|
||||||
x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
|
|
||||||
|
|
||||||
x_out = freqs_cis[..., 0] * x_[..., 0]
|
try:
|
||||||
x_out.addcmul_(freqs_cis[..., 1], x_[..., 1])
|
import comfy.quant_ops
|
||||||
|
apply_rope = comfy.quant_ops.ck.apply_rope
|
||||||
|
apply_rope1 = comfy.quant_ops.ck.apply_rope1
|
||||||
|
except:
|
||||||
|
logging.warning("No comfy kitchen, using old apply_rope functions.")
|
||||||
|
def apply_rope1(x: Tensor, freqs_cis: Tensor):
|
||||||
|
x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
|
||||||
|
|
||||||
return x_out.reshape(*x.shape).type_as(x)
|
x_out = freqs_cis[..., 0] * x_[..., 0]
|
||||||
|
x_out.addcmul_(freqs_cis[..., 1], x_[..., 1])
|
||||||
|
|
||||||
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
|
return x_out.reshape(*x.shape).type_as(x)
|
||||||
return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)
|
|
||||||
|
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
|
||||||
|
return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)
|
||||||
|
|||||||
@ -3,7 +3,8 @@ import torch.nn as nn
|
|||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, VideoConv3d
|
from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, VideoConv3d
|
||||||
from comfy.ldm.hunyuan_video.vae_refiner import RMS_norm
|
from comfy.ldm.hunyuan_video.vae_refiner import RMS_norm
|
||||||
import model_management, model_patcher
|
import model_management
|
||||||
|
import model_patcher
|
||||||
|
|
||||||
class SRResidualCausalBlock3D(nn.Module):
|
class SRResidualCausalBlock3D(nn.Module):
|
||||||
def __init__(self, channels: int):
|
def __init__(self, channels: int):
|
||||||
|
|||||||
837
comfy/ldm/lightricks/av_model.py
Normal file
837
comfy/ldm/lightricks/av_model.py
Normal file
@ -0,0 +1,837 @@
|
|||||||
|
from typing import Tuple
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from comfy.ldm.lightricks.model import (
|
||||||
|
CrossAttention,
|
||||||
|
FeedForward,
|
||||||
|
AdaLayerNormSingle,
|
||||||
|
PixArtAlphaTextProjection,
|
||||||
|
LTXVModel,
|
||||||
|
)
|
||||||
|
from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier
|
||||||
|
import comfy.ldm.common_dit
|
||||||
|
|
||||||
|
class BasicAVTransformerBlock(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
v_dim,
|
||||||
|
a_dim,
|
||||||
|
v_heads,
|
||||||
|
a_heads,
|
||||||
|
vd_head,
|
||||||
|
ad_head,
|
||||||
|
v_context_dim=None,
|
||||||
|
a_context_dim=None,
|
||||||
|
attn_precision=None,
|
||||||
|
dtype=None,
|
||||||
|
device=None,
|
||||||
|
operations=None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.attn_precision = attn_precision
|
||||||
|
|
||||||
|
self.attn1 = CrossAttention(
|
||||||
|
query_dim=v_dim,
|
||||||
|
heads=v_heads,
|
||||||
|
dim_head=vd_head,
|
||||||
|
context_dim=None,
|
||||||
|
attn_precision=self.attn_precision,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
operations=operations,
|
||||||
|
)
|
||||||
|
self.audio_attn1 = CrossAttention(
|
||||||
|
query_dim=a_dim,
|
||||||
|
heads=a_heads,
|
||||||
|
dim_head=ad_head,
|
||||||
|
context_dim=None,
|
||||||
|
attn_precision=self.attn_precision,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
operations=operations,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.attn2 = CrossAttention(
|
||||||
|
query_dim=v_dim,
|
||||||
|
context_dim=v_context_dim,
|
||||||
|
heads=v_heads,
|
||||||
|
dim_head=vd_head,
|
||||||
|
attn_precision=self.attn_precision,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
operations=operations,
|
||||||
|
)
|
||||||
|
self.audio_attn2 = CrossAttention(
|
||||||
|
query_dim=a_dim,
|
||||||
|
context_dim=a_context_dim,
|
||||||
|
heads=a_heads,
|
||||||
|
dim_head=ad_head,
|
||||||
|
attn_precision=self.attn_precision,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
operations=operations,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Q: Video, K,V: Audio
|
||||||
|
self.audio_to_video_attn = CrossAttention(
|
||||||
|
query_dim=v_dim,
|
||||||
|
context_dim=a_dim,
|
||||||
|
heads=a_heads,
|
||||||
|
dim_head=ad_head,
|
||||||
|
attn_precision=self.attn_precision,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
operations=operations,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Q: Audio, K,V: Video
|
||||||
|
self.video_to_audio_attn = CrossAttention(
|
||||||
|
query_dim=a_dim,
|
||||||
|
context_dim=v_dim,
|
||||||
|
heads=a_heads,
|
||||||
|
dim_head=ad_head,
|
||||||
|
attn_precision=self.attn_precision,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
operations=operations,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.ff = FeedForward(
|
||||||
|
v_dim, dim_out=v_dim, glu=True, dtype=dtype, device=device, operations=operations
|
||||||
|
)
|
||||||
|
self.audio_ff = FeedForward(
|
||||||
|
a_dim, dim_out=a_dim, glu=True, dtype=dtype, device=device, operations=operations
|
||||||
|
)
|
||||||
|
|
||||||
|
self.scale_shift_table = nn.Parameter(torch.empty(6, v_dim, device=device, dtype=dtype))
|
||||||
|
self.audio_scale_shift_table = nn.Parameter(
|
||||||
|
torch.empty(6, a_dim, device=device, dtype=dtype)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.scale_shift_table_a2v_ca_audio = nn.Parameter(
|
||||||
|
torch.empty(5, a_dim, device=device, dtype=dtype)
|
||||||
|
)
|
||||||
|
self.scale_shift_table_a2v_ca_video = nn.Parameter(
|
||||||
|
torch.empty(5, v_dim, device=device, dtype=dtype)
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_ada_values(
|
||||||
|
self, scale_shift_table: torch.Tensor, batch_size: int, timestep: torch.Tensor, indices: slice = slice(None, None)
|
||||||
|
):
|
||||||
|
num_ada_params = scale_shift_table.shape[0]
|
||||||
|
|
||||||
|
ada_values = (
|
||||||
|
scale_shift_table[indices].unsqueeze(0).unsqueeze(0).to(device=timestep.device, dtype=timestep.dtype)
|
||||||
|
+ timestep.reshape(batch_size, timestep.shape[1], num_ada_params, -1)[:, :, indices, :]
|
||||||
|
).unbind(dim=2)
|
||||||
|
return ada_values
|
||||||
|
|
||||||
|
def get_av_ca_ada_values(
|
||||||
|
self,
|
||||||
|
scale_shift_table: torch.Tensor,
|
||||||
|
batch_size: int,
|
||||||
|
scale_shift_timestep: torch.Tensor,
|
||||||
|
gate_timestep: torch.Tensor,
|
||||||
|
num_scale_shift_values: int = 4,
|
||||||
|
):
|
||||||
|
scale_shift_ada_values = self.get_ada_values(
|
||||||
|
scale_shift_table[:num_scale_shift_values, :],
|
||||||
|
batch_size,
|
||||||
|
scale_shift_timestep,
|
||||||
|
)
|
||||||
|
gate_ada_values = self.get_ada_values(
|
||||||
|
scale_shift_table[num_scale_shift_values:, :],
|
||||||
|
batch_size,
|
||||||
|
gate_timestep,
|
||||||
|
)
|
||||||
|
|
||||||
|
scale_shift_chunks = [t.squeeze(2) for t in scale_shift_ada_values]
|
||||||
|
gate_ada_values = [t.squeeze(2) for t in gate_ada_values]
|
||||||
|
|
||||||
|
return (*scale_shift_chunks, *gate_ada_values)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: Tuple[torch.Tensor, torch.Tensor],
|
||||||
|
v_context=None,
|
||||||
|
a_context=None,
|
||||||
|
attention_mask=None,
|
||||||
|
v_timestep=None,
|
||||||
|
a_timestep=None,
|
||||||
|
v_pe=None,
|
||||||
|
a_pe=None,
|
||||||
|
v_cross_pe=None,
|
||||||
|
a_cross_pe=None,
|
||||||
|
v_cross_scale_shift_timestep=None,
|
||||||
|
a_cross_scale_shift_timestep=None,
|
||||||
|
v_cross_gate_timestep=None,
|
||||||
|
a_cross_gate_timestep=None,
|
||||||
|
transformer_options=None,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
run_vx = transformer_options.get("run_vx", True)
|
||||||
|
run_ax = transformer_options.get("run_ax", True)
|
||||||
|
|
||||||
|
vx, ax = x
|
||||||
|
run_ax = run_ax and ax.numel() > 0
|
||||||
|
run_a2v = run_vx and transformer_options.get("a2v_cross_attn", True) and ax.numel() > 0
|
||||||
|
run_v2a = run_ax and transformer_options.get("v2a_cross_attn", True)
|
||||||
|
|
||||||
|
if run_vx:
|
||||||
|
vshift_msa, vscale_msa, vgate_msa = (
|
||||||
|
self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(0, 3))
|
||||||
|
)
|
||||||
|
|
||||||
|
norm_vx = comfy.ldm.common_dit.rms_norm(vx) * (1 + vscale_msa) + vshift_msa
|
||||||
|
vx += self.attn1(norm_vx, pe=v_pe, transformer_options=transformer_options) * vgate_msa
|
||||||
|
vx += self.attn2(
|
||||||
|
comfy.ldm.common_dit.rms_norm(vx),
|
||||||
|
context=v_context,
|
||||||
|
mask=attention_mask,
|
||||||
|
transformer_options=transformer_options,
|
||||||
|
)
|
||||||
|
|
||||||
|
del vshift_msa, vscale_msa, vgate_msa
|
||||||
|
|
||||||
|
if run_ax:
|
||||||
|
ashift_msa, ascale_msa, agate_msa = (
|
||||||
|
self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(0, 3))
|
||||||
|
)
|
||||||
|
|
||||||
|
norm_ax = comfy.ldm.common_dit.rms_norm(ax) * (1 + ascale_msa) + ashift_msa
|
||||||
|
ax += (
|
||||||
|
self.audio_attn1(norm_ax, pe=a_pe, transformer_options=transformer_options)
|
||||||
|
* agate_msa
|
||||||
|
)
|
||||||
|
ax += self.audio_attn2(
|
||||||
|
comfy.ldm.common_dit.rms_norm(ax),
|
||||||
|
context=a_context,
|
||||||
|
mask=attention_mask,
|
||||||
|
transformer_options=transformer_options,
|
||||||
|
)
|
||||||
|
|
||||||
|
del ashift_msa, ascale_msa, agate_msa
|
||||||
|
|
||||||
|
# Audio - Video cross attention.
|
||||||
|
if run_a2v or run_v2a:
|
||||||
|
# norm3
|
||||||
|
vx_norm3 = comfy.ldm.common_dit.rms_norm(vx)
|
||||||
|
ax_norm3 = comfy.ldm.common_dit.rms_norm(ax)
|
||||||
|
|
||||||
|
(
|
||||||
|
scale_ca_audio_hidden_states_a2v,
|
||||||
|
shift_ca_audio_hidden_states_a2v,
|
||||||
|
scale_ca_audio_hidden_states_v2a,
|
||||||
|
shift_ca_audio_hidden_states_v2a,
|
||||||
|
gate_out_v2a,
|
||||||
|
) = self.get_av_ca_ada_values(
|
||||||
|
self.scale_shift_table_a2v_ca_audio,
|
||||||
|
ax.shape[0],
|
||||||
|
a_cross_scale_shift_timestep,
|
||||||
|
a_cross_gate_timestep,
|
||||||
|
)
|
||||||
|
|
||||||
|
(
|
||||||
|
scale_ca_video_hidden_states_a2v,
|
||||||
|
shift_ca_video_hidden_states_a2v,
|
||||||
|
scale_ca_video_hidden_states_v2a,
|
||||||
|
shift_ca_video_hidden_states_v2a,
|
||||||
|
gate_out_a2v,
|
||||||
|
) = self.get_av_ca_ada_values(
|
||||||
|
self.scale_shift_table_a2v_ca_video,
|
||||||
|
vx.shape[0],
|
||||||
|
v_cross_scale_shift_timestep,
|
||||||
|
v_cross_gate_timestep,
|
||||||
|
)
|
||||||
|
|
||||||
|
if run_a2v:
|
||||||
|
vx_scaled = (
|
||||||
|
vx_norm3 * (1 + scale_ca_video_hidden_states_a2v)
|
||||||
|
+ shift_ca_video_hidden_states_a2v
|
||||||
|
)
|
||||||
|
ax_scaled = (
|
||||||
|
ax_norm3 * (1 + scale_ca_audio_hidden_states_a2v)
|
||||||
|
+ shift_ca_audio_hidden_states_a2v
|
||||||
|
)
|
||||||
|
vx += (
|
||||||
|
self.audio_to_video_attn(
|
||||||
|
vx_scaled,
|
||||||
|
context=ax_scaled,
|
||||||
|
pe=v_cross_pe,
|
||||||
|
k_pe=a_cross_pe,
|
||||||
|
transformer_options=transformer_options,
|
||||||
|
)
|
||||||
|
* gate_out_a2v
|
||||||
|
)
|
||||||
|
|
||||||
|
del gate_out_a2v
|
||||||
|
del scale_ca_video_hidden_states_a2v,\
|
||||||
|
shift_ca_video_hidden_states_a2v,\
|
||||||
|
scale_ca_audio_hidden_states_a2v,\
|
||||||
|
shift_ca_audio_hidden_states_a2v,\
|
||||||
|
|
||||||
|
if run_v2a:
|
||||||
|
ax_scaled = (
|
||||||
|
ax_norm3 * (1 + scale_ca_audio_hidden_states_v2a)
|
||||||
|
+ shift_ca_audio_hidden_states_v2a
|
||||||
|
)
|
||||||
|
vx_scaled = (
|
||||||
|
vx_norm3 * (1 + scale_ca_video_hidden_states_v2a)
|
||||||
|
+ shift_ca_video_hidden_states_v2a
|
||||||
|
)
|
||||||
|
ax += (
|
||||||
|
self.video_to_audio_attn(
|
||||||
|
ax_scaled,
|
||||||
|
context=vx_scaled,
|
||||||
|
pe=a_cross_pe,
|
||||||
|
k_pe=v_cross_pe,
|
||||||
|
transformer_options=transformer_options,
|
||||||
|
)
|
||||||
|
* gate_out_v2a
|
||||||
|
)
|
||||||
|
|
||||||
|
del gate_out_v2a
|
||||||
|
del scale_ca_video_hidden_states_v2a,\
|
||||||
|
shift_ca_video_hidden_states_v2a,\
|
||||||
|
scale_ca_audio_hidden_states_v2a,\
|
||||||
|
shift_ca_audio_hidden_states_v2a
|
||||||
|
|
||||||
|
if run_vx:
|
||||||
|
vshift_mlp, vscale_mlp, vgate_mlp = (
|
||||||
|
self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(3, None))
|
||||||
|
)
|
||||||
|
|
||||||
|
vx_scaled = comfy.ldm.common_dit.rms_norm(vx) * (1 + vscale_mlp) + vshift_mlp
|
||||||
|
vx += self.ff(vx_scaled) * vgate_mlp
|
||||||
|
del vshift_mlp, vscale_mlp, vgate_mlp
|
||||||
|
|
||||||
|
if run_ax:
|
||||||
|
ashift_mlp, ascale_mlp, agate_mlp = (
|
||||||
|
self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(3, None))
|
||||||
|
)
|
||||||
|
|
||||||
|
ax_scaled = comfy.ldm.common_dit.rms_norm(ax) * (1 + ascale_mlp) + ashift_mlp
|
||||||
|
ax += self.audio_ff(ax_scaled) * agate_mlp
|
||||||
|
|
||||||
|
del ashift_mlp, ascale_mlp, agate_mlp
|
||||||
|
|
||||||
|
|
||||||
|
return vx, ax
|
||||||
|
|
||||||
|
|
||||||
|
class LTXAVModel(LTXVModel):
|
||||||
|
"""LTXAV model for audio-video generation."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels=128,
|
||||||
|
audio_in_channels=128,
|
||||||
|
cross_attention_dim=4096,
|
||||||
|
audio_cross_attention_dim=2048,
|
||||||
|
attention_head_dim=128,
|
||||||
|
audio_attention_head_dim=64,
|
||||||
|
num_attention_heads=32,
|
||||||
|
audio_num_attention_heads=32,
|
||||||
|
caption_channels=3840,
|
||||||
|
num_layers=48,
|
||||||
|
positional_embedding_theta=10000.0,
|
||||||
|
positional_embedding_max_pos=[20, 2048, 2048],
|
||||||
|
audio_positional_embedding_max_pos=[20],
|
||||||
|
causal_temporal_positioning=False,
|
||||||
|
vae_scale_factors=(8, 32, 32),
|
||||||
|
use_middle_indices_grid=False,
|
||||||
|
timestep_scale_multiplier=1000.0,
|
||||||
|
av_ca_timestep_scale_multiplier=1.0,
|
||||||
|
dtype=None,
|
||||||
|
device=None,
|
||||||
|
operations=None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
# Store audio-specific parameters
|
||||||
|
self.audio_in_channels = audio_in_channels
|
||||||
|
self.audio_cross_attention_dim = audio_cross_attention_dim
|
||||||
|
self.audio_attention_head_dim = audio_attention_head_dim
|
||||||
|
self.audio_num_attention_heads = audio_num_attention_heads
|
||||||
|
self.audio_positional_embedding_max_pos = audio_positional_embedding_max_pos
|
||||||
|
|
||||||
|
# Calculate audio dimensions
|
||||||
|
self.audio_inner_dim = audio_num_attention_heads * audio_attention_head_dim
|
||||||
|
self.audio_out_channels = audio_in_channels
|
||||||
|
|
||||||
|
# Audio-specific constants
|
||||||
|
self.num_audio_channels = 8
|
||||||
|
self.audio_frequency_bins = 16
|
||||||
|
|
||||||
|
self.av_ca_timestep_scale_multiplier = av_ca_timestep_scale_multiplier
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
in_channels=in_channels,
|
||||||
|
cross_attention_dim=cross_attention_dim,
|
||||||
|
attention_head_dim=attention_head_dim,
|
||||||
|
num_attention_heads=num_attention_heads,
|
||||||
|
caption_channels=caption_channels,
|
||||||
|
num_layers=num_layers,
|
||||||
|
positional_embedding_theta=positional_embedding_theta,
|
||||||
|
positional_embedding_max_pos=positional_embedding_max_pos,
|
||||||
|
causal_temporal_positioning=causal_temporal_positioning,
|
||||||
|
vae_scale_factors=vae_scale_factors,
|
||||||
|
use_middle_indices_grid=use_middle_indices_grid,
|
||||||
|
timestep_scale_multiplier=timestep_scale_multiplier,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
operations=operations,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _init_model_components(self, device, dtype, **kwargs):
|
||||||
|
"""Initialize LTXAV-specific components."""
|
||||||
|
# Audio-specific projections
|
||||||
|
self.audio_patchify_proj = self.operations.Linear(
|
||||||
|
self.audio_in_channels, self.audio_inner_dim, bias=True, dtype=dtype, device=device
|
||||||
|
)
|
||||||
|
|
||||||
|
# Audio-specific AdaLN
|
||||||
|
self.audio_adaln_single = AdaLayerNormSingle(
|
||||||
|
self.audio_inner_dim,
|
||||||
|
use_additional_conditions=False,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
operations=self.operations,
|
||||||
|
)
|
||||||
|
|
||||||
|
num_scale_shift_values = 4
|
||||||
|
self.av_ca_video_scale_shift_adaln_single = AdaLayerNormSingle(
|
||||||
|
self.inner_dim,
|
||||||
|
use_additional_conditions=False,
|
||||||
|
embedding_coefficient=num_scale_shift_values,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
operations=self.operations,
|
||||||
|
)
|
||||||
|
self.av_ca_a2v_gate_adaln_single = AdaLayerNormSingle(
|
||||||
|
self.inner_dim,
|
||||||
|
use_additional_conditions=False,
|
||||||
|
embedding_coefficient=1,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
operations=self.operations,
|
||||||
|
)
|
||||||
|
self.av_ca_audio_scale_shift_adaln_single = AdaLayerNormSingle(
|
||||||
|
self.audio_inner_dim,
|
||||||
|
use_additional_conditions=False,
|
||||||
|
embedding_coefficient=num_scale_shift_values,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
operations=self.operations,
|
||||||
|
)
|
||||||
|
self.av_ca_v2a_gate_adaln_single = AdaLayerNormSingle(
|
||||||
|
self.audio_inner_dim,
|
||||||
|
use_additional_conditions=False,
|
||||||
|
embedding_coefficient=1,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
operations=self.operations,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Audio caption projection
|
||||||
|
self.audio_caption_projection = PixArtAlphaTextProjection(
|
||||||
|
in_features=self.caption_channels,
|
||||||
|
hidden_size=self.audio_inner_dim,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
operations=self.operations,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _init_transformer_blocks(self, device, dtype, **kwargs):
|
||||||
|
"""Initialize transformer blocks for LTXAV."""
|
||||||
|
self.transformer_blocks = nn.ModuleList(
|
||||||
|
[
|
||||||
|
BasicAVTransformerBlock(
|
||||||
|
v_dim=self.inner_dim,
|
||||||
|
a_dim=self.audio_inner_dim,
|
||||||
|
v_heads=self.num_attention_heads,
|
||||||
|
a_heads=self.audio_num_attention_heads,
|
||||||
|
vd_head=self.attention_head_dim,
|
||||||
|
ad_head=self.audio_attention_head_dim,
|
||||||
|
v_context_dim=self.cross_attention_dim,
|
||||||
|
a_context_dim=self.audio_cross_attention_dim,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
operations=self.operations,
|
||||||
|
)
|
||||||
|
for _ in range(self.num_layers)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
def _init_output_components(self, device, dtype):
|
||||||
|
"""Initialize output components for LTXAV."""
|
||||||
|
# Video output components
|
||||||
|
super()._init_output_components(device, dtype)
|
||||||
|
# Audio output components
|
||||||
|
self.audio_scale_shift_table = nn.Parameter(
|
||||||
|
torch.empty(2, self.audio_inner_dim, dtype=dtype, device=device)
|
||||||
|
)
|
||||||
|
self.audio_norm_out = self.operations.LayerNorm(
|
||||||
|
self.audio_inner_dim, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device
|
||||||
|
)
|
||||||
|
self.audio_proj_out = self.operations.Linear(
|
||||||
|
self.audio_inner_dim, self.audio_out_channels, dtype=dtype, device=device
|
||||||
|
)
|
||||||
|
self.a_patchifier = AudioPatchifier(1, start_end=True)
|
||||||
|
|
||||||
|
def separate_audio_and_video_latents(self, x, audio_length):
|
||||||
|
"""Separate audio and video latents from combined input."""
|
||||||
|
# vx = x[:, : self.in_channels]
|
||||||
|
# ax = x[:, self.in_channels :]
|
||||||
|
#
|
||||||
|
# ax = ax.reshape(ax.shape[0], -1)
|
||||||
|
# ax = ax[:, : audio_length * self.num_audio_channels * self.audio_frequency_bins]
|
||||||
|
#
|
||||||
|
# ax = ax.reshape(
|
||||||
|
# ax.shape[0], self.num_audio_channels, audio_length, self.audio_frequency_bins
|
||||||
|
# )
|
||||||
|
|
||||||
|
vx = x[0]
|
||||||
|
ax = x[1] if len(x) > 1 else torch.zeros(
|
||||||
|
(vx.shape[0], self.num_audio_channels, 0, self.audio_frequency_bins),
|
||||||
|
device=vx.device, dtype=vx.dtype
|
||||||
|
)
|
||||||
|
return vx, ax
|
||||||
|
|
||||||
|
def recombine_audio_and_video_latents(self, vx, ax, target_shape=None):
|
||||||
|
if ax.numel() == 0:
|
||||||
|
return vx
|
||||||
|
else:
|
||||||
|
return [vx, ax]
|
||||||
|
"""Recombine audio and video latents for output."""
|
||||||
|
# if ax.device != vx.device or ax.dtype != vx.dtype:
|
||||||
|
# logging.warning("Audio and video latents are on different devices or dtypes.")
|
||||||
|
# ax = ax.to(device=vx.device, dtype=vx.dtype)
|
||||||
|
# logging.warning(f"Audio audio latent moved to device: {ax.device}, dtype: {ax.dtype}")
|
||||||
|
#
|
||||||
|
# ax = ax.reshape(ax.shape[0], -1)
|
||||||
|
# # pad to f x h x w of the video latents
|
||||||
|
# divisor = vx.shape[-1] * vx.shape[-2] * vx.shape[-3]
|
||||||
|
# if target_shape is None:
|
||||||
|
# repetitions = math.ceil(ax.shape[-1] / divisor)
|
||||||
|
# else:
|
||||||
|
# repetitions = target_shape[1] - vx.shape[1]
|
||||||
|
# padded_len = repetitions * divisor
|
||||||
|
# ax = F.pad(ax, (0, padded_len - ax.shape[-1]))
|
||||||
|
# ax = ax.reshape(ax.shape[0], -1, vx.shape[-3], vx.shape[-2], vx.shape[-1])
|
||||||
|
# return torch.cat([vx, ax], dim=1)
|
||||||
|
|
||||||
|
def _process_input(self, x, keyframe_idxs, denoise_mask, **kwargs):
|
||||||
|
"""Process input for LTXAV - separate audio and video, then patchify."""
|
||||||
|
audio_length = kwargs.get("audio_length", 0)
|
||||||
|
# Separate audio and video latents
|
||||||
|
vx, ax = self.separate_audio_and_video_latents(x, audio_length)
|
||||||
|
[vx, v_pixel_coords, additional_args] = super()._process_input(
|
||||||
|
vx, keyframe_idxs, denoise_mask, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
ax, a_latent_coords = self.a_patchifier.patchify(ax)
|
||||||
|
ax = self.audio_patchify_proj(ax)
|
||||||
|
|
||||||
|
# additional_args.update({"av_orig_shape": list(x.shape)})
|
||||||
|
return [vx, ax], [v_pixel_coords, a_latent_coords], additional_args
|
||||||
|
|
||||||
|
def _prepare_timestep(self, timestep, batch_size, hidden_dtype, **kwargs):
|
||||||
|
"""Prepare timestep embeddings."""
|
||||||
|
# TODO: some code reuse is needed here.
|
||||||
|
grid_mask = kwargs.get("grid_mask", None)
|
||||||
|
if grid_mask is not None:
|
||||||
|
timestep = timestep[:, grid_mask]
|
||||||
|
|
||||||
|
timestep = timestep * self.timestep_scale_multiplier
|
||||||
|
v_timestep, v_embedded_timestep = self.adaln_single(
|
||||||
|
timestep.flatten(),
|
||||||
|
{"resolution": None, "aspect_ratio": None},
|
||||||
|
batch_size=batch_size,
|
||||||
|
hidden_dtype=hidden_dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Second dimension is 1 or number of tokens (if timestep_per_token)
|
||||||
|
v_timestep = v_timestep.view(batch_size, -1, v_timestep.shape[-1])
|
||||||
|
v_embedded_timestep = v_embedded_timestep.view(
|
||||||
|
batch_size, -1, v_embedded_timestep.shape[-1]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Prepare audio timestep
|
||||||
|
a_timestep = kwargs.get("a_timestep")
|
||||||
|
if a_timestep is not None:
|
||||||
|
a_timestep = a_timestep * self.timestep_scale_multiplier
|
||||||
|
av_ca_factor = self.av_ca_timestep_scale_multiplier / self.timestep_scale_multiplier
|
||||||
|
|
||||||
|
av_ca_audio_scale_shift_timestep, _ = self.av_ca_audio_scale_shift_adaln_single(
|
||||||
|
a_timestep.flatten(),
|
||||||
|
{"resolution": None, "aspect_ratio": None},
|
||||||
|
batch_size=batch_size,
|
||||||
|
hidden_dtype=hidden_dtype,
|
||||||
|
)
|
||||||
|
av_ca_video_scale_shift_timestep, _ = self.av_ca_video_scale_shift_adaln_single(
|
||||||
|
timestep.flatten(),
|
||||||
|
{"resolution": None, "aspect_ratio": None},
|
||||||
|
batch_size=batch_size,
|
||||||
|
hidden_dtype=hidden_dtype,
|
||||||
|
)
|
||||||
|
av_ca_a2v_gate_noise_timestep, _ = self.av_ca_a2v_gate_adaln_single(
|
||||||
|
timestep.flatten() * av_ca_factor,
|
||||||
|
{"resolution": None, "aspect_ratio": None},
|
||||||
|
batch_size=batch_size,
|
||||||
|
hidden_dtype=hidden_dtype,
|
||||||
|
)
|
||||||
|
av_ca_v2a_gate_noise_timestep, _ = self.av_ca_v2a_gate_adaln_single(
|
||||||
|
a_timestep.flatten() * av_ca_factor,
|
||||||
|
{"resolution": None, "aspect_ratio": None},
|
||||||
|
batch_size=batch_size,
|
||||||
|
hidden_dtype=hidden_dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
a_timestep, a_embedded_timestep = self.audio_adaln_single(
|
||||||
|
a_timestep.flatten(),
|
||||||
|
{"resolution": None, "aspect_ratio": None},
|
||||||
|
batch_size=batch_size,
|
||||||
|
hidden_dtype=hidden_dtype,
|
||||||
|
)
|
||||||
|
a_timestep = a_timestep.view(batch_size, -1, a_timestep.shape[-1])
|
||||||
|
a_embedded_timestep = a_embedded_timestep.view(
|
||||||
|
batch_size, -1, a_embedded_timestep.shape[-1]
|
||||||
|
)
|
||||||
|
cross_av_timestep_ss = [
|
||||||
|
av_ca_audio_scale_shift_timestep,
|
||||||
|
av_ca_video_scale_shift_timestep,
|
||||||
|
av_ca_a2v_gate_noise_timestep,
|
||||||
|
av_ca_v2a_gate_noise_timestep,
|
||||||
|
]
|
||||||
|
cross_av_timestep_ss = list(
|
||||||
|
[t.view(batch_size, -1, t.shape[-1]) for t in cross_av_timestep_ss]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
a_timestep = timestep
|
||||||
|
a_embedded_timestep = kwargs.get("embedded_timestep")
|
||||||
|
cross_av_timestep_ss = []
|
||||||
|
|
||||||
|
return [v_timestep, a_timestep, cross_av_timestep_ss], [
|
||||||
|
v_embedded_timestep,
|
||||||
|
a_embedded_timestep,
|
||||||
|
]
|
||||||
|
|
||||||
|
def _prepare_context(self, context, batch_size, x, attention_mask=None):
|
||||||
|
vx = x[0]
|
||||||
|
ax = x[1]
|
||||||
|
v_context, a_context = torch.split(
|
||||||
|
context, int(context.shape[-1] / 2), len(context.shape) - 1
|
||||||
|
)
|
||||||
|
|
||||||
|
v_context, attention_mask = super()._prepare_context(
|
||||||
|
v_context, batch_size, vx, attention_mask
|
||||||
|
)
|
||||||
|
if self.audio_caption_projection is not None:
|
||||||
|
a_context = self.audio_caption_projection(a_context)
|
||||||
|
a_context = a_context.view(batch_size, -1, ax.shape[-1])
|
||||||
|
|
||||||
|
return [v_context, a_context], attention_mask
|
||||||
|
|
||||||
|
def _prepare_positional_embeddings(self, pixel_coords, frame_rate, x_dtype):
|
||||||
|
v_pixel_coords = pixel_coords[0]
|
||||||
|
v_pe = super()._prepare_positional_embeddings(v_pixel_coords, frame_rate, x_dtype)
|
||||||
|
|
||||||
|
a_latent_coords = pixel_coords[1]
|
||||||
|
a_pe = self._precompute_freqs_cis(
|
||||||
|
a_latent_coords,
|
||||||
|
dim=self.audio_inner_dim,
|
||||||
|
out_dtype=x_dtype,
|
||||||
|
max_pos=self.audio_positional_embedding_max_pos,
|
||||||
|
use_middle_indices_grid=self.use_middle_indices_grid,
|
||||||
|
num_attention_heads=self.audio_num_attention_heads,
|
||||||
|
)
|
||||||
|
|
||||||
|
# calculate positional embeddings for the middle of the token duration, to use in av cross attention layers.
|
||||||
|
max_pos = max(
|
||||||
|
self.positional_embedding_max_pos[0], self.audio_positional_embedding_max_pos[0]
|
||||||
|
)
|
||||||
|
v_pixel_coords = v_pixel_coords.to(torch.float32)
|
||||||
|
v_pixel_coords[:, 0] = v_pixel_coords[:, 0] * (1.0 / frame_rate)
|
||||||
|
av_cross_video_freq_cis = self._precompute_freqs_cis(
|
||||||
|
v_pixel_coords[:, 0:1, :],
|
||||||
|
dim=self.audio_cross_attention_dim,
|
||||||
|
out_dtype=x_dtype,
|
||||||
|
max_pos=[max_pos],
|
||||||
|
use_middle_indices_grid=True,
|
||||||
|
num_attention_heads=self.audio_num_attention_heads,
|
||||||
|
)
|
||||||
|
av_cross_audio_freq_cis = self._precompute_freqs_cis(
|
||||||
|
a_latent_coords[:, 0:1, :],
|
||||||
|
dim=self.audio_cross_attention_dim,
|
||||||
|
out_dtype=x_dtype,
|
||||||
|
max_pos=[max_pos],
|
||||||
|
use_middle_indices_grid=True,
|
||||||
|
num_attention_heads=self.audio_num_attention_heads,
|
||||||
|
)
|
||||||
|
|
||||||
|
return [(v_pe, av_cross_video_freq_cis), (a_pe, av_cross_audio_freq_cis)]
|
||||||
|
|
||||||
|
def _process_transformer_blocks(
|
||||||
|
self, x, context, attention_mask, timestep, pe, transformer_options={}, **kwargs
|
||||||
|
):
|
||||||
|
vx = x[0]
|
||||||
|
ax = x[1]
|
||||||
|
v_context = context[0]
|
||||||
|
a_context = context[1]
|
||||||
|
v_timestep = timestep[0]
|
||||||
|
a_timestep = timestep[1]
|
||||||
|
v_pe, av_cross_video_freq_cis = pe[0]
|
||||||
|
a_pe, av_cross_audio_freq_cis = pe[1]
|
||||||
|
|
||||||
|
(
|
||||||
|
av_ca_audio_scale_shift_timestep,
|
||||||
|
av_ca_video_scale_shift_timestep,
|
||||||
|
av_ca_a2v_gate_noise_timestep,
|
||||||
|
av_ca_v2a_gate_noise_timestep,
|
||||||
|
) = timestep[2]
|
||||||
|
|
||||||
|
"""Process transformer blocks for LTXAV."""
|
||||||
|
patches_replace = transformer_options.get("patches_replace", {})
|
||||||
|
blocks_replace = patches_replace.get("dit", {})
|
||||||
|
|
||||||
|
# Process transformer blocks
|
||||||
|
for i, block in enumerate(self.transformer_blocks):
|
||||||
|
if ("double_block", i) in blocks_replace:
|
||||||
|
|
||||||
|
def block_wrap(args):
|
||||||
|
out = {}
|
||||||
|
out["img"] = block(
|
||||||
|
args["img"],
|
||||||
|
v_context=args["v_context"],
|
||||||
|
a_context=args["a_context"],
|
||||||
|
attention_mask=args["attention_mask"],
|
||||||
|
v_timestep=args["v_timestep"],
|
||||||
|
a_timestep=args["a_timestep"],
|
||||||
|
v_pe=args["v_pe"],
|
||||||
|
a_pe=args["a_pe"],
|
||||||
|
v_cross_pe=args["v_cross_pe"],
|
||||||
|
a_cross_pe=args["a_cross_pe"],
|
||||||
|
v_cross_scale_shift_timestep=args["v_cross_scale_shift_timestep"],
|
||||||
|
a_cross_scale_shift_timestep=args["a_cross_scale_shift_timestep"],
|
||||||
|
v_cross_gate_timestep=args["v_cross_gate_timestep"],
|
||||||
|
a_cross_gate_timestep=args["a_cross_gate_timestep"],
|
||||||
|
transformer_options=args["transformer_options"],
|
||||||
|
)
|
||||||
|
return out
|
||||||
|
|
||||||
|
out = blocks_replace[("double_block", i)](
|
||||||
|
{
|
||||||
|
"img": (vx, ax),
|
||||||
|
"v_context": v_context,
|
||||||
|
"a_context": a_context,
|
||||||
|
"attention_mask": attention_mask,
|
||||||
|
"v_timestep": v_timestep,
|
||||||
|
"a_timestep": a_timestep,
|
||||||
|
"v_pe": v_pe,
|
||||||
|
"a_pe": a_pe,
|
||||||
|
"v_cross_pe": av_cross_video_freq_cis,
|
||||||
|
"a_cross_pe": av_cross_audio_freq_cis,
|
||||||
|
"v_cross_scale_shift_timestep": av_ca_video_scale_shift_timestep,
|
||||||
|
"a_cross_scale_shift_timestep": av_ca_audio_scale_shift_timestep,
|
||||||
|
"v_cross_gate_timestep": av_ca_a2v_gate_noise_timestep,
|
||||||
|
"a_cross_gate_timestep": av_ca_v2a_gate_noise_timestep,
|
||||||
|
"transformer_options": transformer_options,
|
||||||
|
},
|
||||||
|
{"original_block": block_wrap},
|
||||||
|
)
|
||||||
|
vx, ax = out["img"]
|
||||||
|
else:
|
||||||
|
vx, ax = block(
|
||||||
|
(vx, ax),
|
||||||
|
v_context=v_context,
|
||||||
|
a_context=a_context,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
v_timestep=v_timestep,
|
||||||
|
a_timestep=a_timestep,
|
||||||
|
v_pe=v_pe,
|
||||||
|
a_pe=a_pe,
|
||||||
|
v_cross_pe=av_cross_video_freq_cis,
|
||||||
|
a_cross_pe=av_cross_audio_freq_cis,
|
||||||
|
v_cross_scale_shift_timestep=av_ca_video_scale_shift_timestep,
|
||||||
|
a_cross_scale_shift_timestep=av_ca_audio_scale_shift_timestep,
|
||||||
|
v_cross_gate_timestep=av_ca_a2v_gate_noise_timestep,
|
||||||
|
a_cross_gate_timestep=av_ca_v2a_gate_noise_timestep,
|
||||||
|
transformer_options=transformer_options,
|
||||||
|
)
|
||||||
|
|
||||||
|
return [vx, ax]
|
||||||
|
|
||||||
|
def _process_output(self, x, embedded_timestep, keyframe_idxs, **kwargs):
|
||||||
|
vx = x[0]
|
||||||
|
ax = x[1]
|
||||||
|
v_embedded_timestep = embedded_timestep[0]
|
||||||
|
a_embedded_timestep = embedded_timestep[1]
|
||||||
|
vx = super()._process_output(vx, v_embedded_timestep, keyframe_idxs, **kwargs)
|
||||||
|
|
||||||
|
# Process audio output
|
||||||
|
a_scale_shift_values = (
|
||||||
|
self.audio_scale_shift_table[None, None].to(device=a_embedded_timestep.device, dtype=a_embedded_timestep.dtype)
|
||||||
|
+ a_embedded_timestep[:, :, None]
|
||||||
|
)
|
||||||
|
a_shift, a_scale = a_scale_shift_values[:, :, 0], a_scale_shift_values[:, :, 1]
|
||||||
|
|
||||||
|
ax = self.audio_norm_out(ax)
|
||||||
|
ax = ax * (1 + a_scale) + a_shift
|
||||||
|
ax = self.audio_proj_out(ax)
|
||||||
|
|
||||||
|
# Unpatchify audio
|
||||||
|
ax = self.a_patchifier.unpatchify(
|
||||||
|
ax, channels=self.num_audio_channels, freq=self.audio_frequency_bins
|
||||||
|
)
|
||||||
|
|
||||||
|
# Recombine audio and video
|
||||||
|
original_shape = kwargs.get("av_orig_shape")
|
||||||
|
return self.recombine_audio_and_video_latents(vx, ax, original_shape)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x,
|
||||||
|
timestep,
|
||||||
|
context,
|
||||||
|
attention_mask=None,
|
||||||
|
frame_rate=25,
|
||||||
|
transformer_options={},
|
||||||
|
keyframe_idxs=None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Forward pass for LTXAV model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Combined audio-video input tensor
|
||||||
|
timestep: Tuple of (video_timestep, audio_timestep) or single timestep
|
||||||
|
context: Context tensor (e.g., text embeddings)
|
||||||
|
attention_mask: Attention mask tensor
|
||||||
|
frame_rate: Frame rate for temporal processing
|
||||||
|
transformer_options: Additional options for transformer blocks
|
||||||
|
keyframe_idxs: Keyframe indices for temporal processing
|
||||||
|
**kwargs: Additional keyword arguments including audio_length
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Combined audio-video output tensor
|
||||||
|
"""
|
||||||
|
# Handle timestep format
|
||||||
|
if isinstance(timestep, (tuple, list)) and len(timestep) == 2:
|
||||||
|
v_timestep, a_timestep = timestep
|
||||||
|
kwargs["a_timestep"] = a_timestep
|
||||||
|
timestep = v_timestep
|
||||||
|
else:
|
||||||
|
kwargs["a_timestep"] = timestep
|
||||||
|
|
||||||
|
# Call parent forward method
|
||||||
|
return super().forward(
|
||||||
|
x,
|
||||||
|
timestep,
|
||||||
|
context,
|
||||||
|
attention_mask,
|
||||||
|
frame_rate,
|
||||||
|
transformer_options,
|
||||||
|
keyframe_idxs,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
305
comfy/ldm/lightricks/embeddings_connector.py
Normal file
305
comfy/ldm/lightricks/embeddings_connector.py
Normal file
@ -0,0 +1,305 @@
|
|||||||
|
import math
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import comfy.ldm.common_dit
|
||||||
|
import torch
|
||||||
|
from comfy.ldm.lightricks.model import (
|
||||||
|
CrossAttention,
|
||||||
|
FeedForward,
|
||||||
|
generate_freq_grid_np,
|
||||||
|
interleaved_freqs_cis,
|
||||||
|
split_freqs_cis,
|
||||||
|
)
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
|
||||||
|
class BasicTransformerBlock1D(nn.Module):
|
||||||
|
r"""
|
||||||
|
A basic Transformer block.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
|
||||||
|
dim (`int`): The number of channels in the input and output.
|
||||||
|
num_attention_heads (`int`): The number of heads to use for multi-head attention.
|
||||||
|
attention_head_dim (`int`): The number of channels in each head.
|
||||||
|
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
||||||
|
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
||||||
|
attention_bias (:
|
||||||
|
obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
|
||||||
|
upcast_attention (`bool`, *optional*):
|
||||||
|
Whether to upcast the attention computation to float32. This is useful for mixed precision training.
|
||||||
|
norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether to use learnable elementwise affine parameters for normalization.
|
||||||
|
standardization_norm (`str`, *optional*, defaults to `"layer_norm"`): The type of pre-normalization to use. Can be `"layer_norm"` or `"rms_norm"`.
|
||||||
|
norm_eps (`float`, *optional*, defaults to 1e-5): Epsilon value for normalization layers.
|
||||||
|
qk_norm (`str`, *optional*, defaults to None):
|
||||||
|
Set to 'layer_norm' or `rms_norm` to perform query and key normalization.
|
||||||
|
final_dropout (`bool` *optional*, defaults to False):
|
||||||
|
Whether to apply a final dropout after the last feed-forward layer.
|
||||||
|
ff_inner_dim (`int`, *optional*): Dimension of the inner feed-forward layer. If not provided, defaults to `dim * 4`.
|
||||||
|
ff_bias (`bool`, *optional*, defaults to `True`): Whether to use bias in the feed-forward layer.
|
||||||
|
attention_out_bias (`bool`, *optional*, defaults to `True`): Whether to use bias in the attention output layer.
|
||||||
|
use_rope (`bool`, *optional*, defaults to `False`): Whether to use Rotary Position Embeddings (RoPE).
|
||||||
|
ffn_dim_mult (`int`, *optional*, defaults to 4): Multiplier for the inner dimension of the feed-forward layer.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim,
|
||||||
|
n_heads,
|
||||||
|
d_head,
|
||||||
|
context_dim=None,
|
||||||
|
attn_precision=None,
|
||||||
|
dtype=None,
|
||||||
|
device=None,
|
||||||
|
operations=None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
# Define 3 blocks. Each block has its own normalization layer.
|
||||||
|
# 1. Self-Attn
|
||||||
|
self.attn1 = CrossAttention(
|
||||||
|
query_dim=dim,
|
||||||
|
heads=n_heads,
|
||||||
|
dim_head=d_head,
|
||||||
|
context_dim=None,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
operations=operations,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 3. Feed-forward
|
||||||
|
self.ff = FeedForward(
|
||||||
|
dim,
|
||||||
|
dim_out=dim,
|
||||||
|
glu=True,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
operations=operations,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, hidden_states, attention_mask=None, pe=None) -> torch.FloatTensor:
|
||||||
|
|
||||||
|
# Notice that normalization is always applied before the real computation in the following blocks.
|
||||||
|
|
||||||
|
# 1. Normalization Before Self-Attention
|
||||||
|
norm_hidden_states = comfy.ldm.common_dit.rms_norm(hidden_states)
|
||||||
|
|
||||||
|
norm_hidden_states = norm_hidden_states.squeeze(1)
|
||||||
|
|
||||||
|
# 2. Self-Attention
|
||||||
|
attn_output = self.attn1(norm_hidden_states, mask=attention_mask, pe=pe)
|
||||||
|
|
||||||
|
hidden_states = attn_output + hidden_states
|
||||||
|
if hidden_states.ndim == 4:
|
||||||
|
hidden_states = hidden_states.squeeze(1)
|
||||||
|
|
||||||
|
# 3. Normalization before Feed-Forward
|
||||||
|
norm_hidden_states = comfy.ldm.common_dit.rms_norm(hidden_states)
|
||||||
|
|
||||||
|
# 4. Feed-forward
|
||||||
|
ff_output = self.ff(norm_hidden_states)
|
||||||
|
|
||||||
|
hidden_states = ff_output + hidden_states
|
||||||
|
if hidden_states.ndim == 4:
|
||||||
|
hidden_states = hidden_states.squeeze(1)
|
||||||
|
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class Embeddings1DConnector(nn.Module):
|
||||||
|
_supports_gradient_checkpointing = True
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels=128,
|
||||||
|
cross_attention_dim=2048,
|
||||||
|
attention_head_dim=128,
|
||||||
|
num_attention_heads=30,
|
||||||
|
num_layers=2,
|
||||||
|
positional_embedding_theta=10000.0,
|
||||||
|
positional_embedding_max_pos=[4096],
|
||||||
|
causal_temporal_positioning=False,
|
||||||
|
num_learnable_registers: Optional[int] = 128,
|
||||||
|
dtype=None,
|
||||||
|
device=None,
|
||||||
|
operations=None,
|
||||||
|
split_rope=False,
|
||||||
|
double_precision_rope=False,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.dtype = dtype
|
||||||
|
self.out_channels = in_channels
|
||||||
|
self.num_attention_heads = num_attention_heads
|
||||||
|
self.inner_dim = num_attention_heads * attention_head_dim
|
||||||
|
self.causal_temporal_positioning = causal_temporal_positioning
|
||||||
|
self.positional_embedding_theta = positional_embedding_theta
|
||||||
|
self.positional_embedding_max_pos = positional_embedding_max_pos
|
||||||
|
self.split_rope = split_rope
|
||||||
|
self.double_precision_rope = double_precision_rope
|
||||||
|
self.transformer_1d_blocks = nn.ModuleList(
|
||||||
|
[
|
||||||
|
BasicTransformerBlock1D(
|
||||||
|
self.inner_dim,
|
||||||
|
num_attention_heads,
|
||||||
|
attention_head_dim,
|
||||||
|
context_dim=cross_attention_dim,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
operations=operations,
|
||||||
|
)
|
||||||
|
for _ in range(num_layers)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
inner_dim = num_attention_heads * attention_head_dim
|
||||||
|
self.num_learnable_registers = num_learnable_registers
|
||||||
|
if self.num_learnable_registers:
|
||||||
|
self.learnable_registers = nn.Parameter(
|
||||||
|
torch.rand(
|
||||||
|
self.num_learnable_registers, inner_dim, dtype=dtype, device=device
|
||||||
|
)
|
||||||
|
* 2.0
|
||||||
|
- 1.0
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_fractional_positions(self, indices_grid):
|
||||||
|
fractional_positions = torch.stack(
|
||||||
|
[
|
||||||
|
indices_grid[:, i] / self.positional_embedding_max_pos[i]
|
||||||
|
for i in range(1)
|
||||||
|
],
|
||||||
|
dim=-1,
|
||||||
|
)
|
||||||
|
return fractional_positions
|
||||||
|
|
||||||
|
def precompute_freqs(self, indices_grid, spacing):
|
||||||
|
source_dtype = indices_grid.dtype
|
||||||
|
dtype = (
|
||||||
|
torch.float32
|
||||||
|
if source_dtype in (torch.bfloat16, torch.float16)
|
||||||
|
else source_dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
fractional_positions = self.get_fractional_positions(indices_grid)
|
||||||
|
indices = (
|
||||||
|
generate_freq_grid_np(
|
||||||
|
self.positional_embedding_theta,
|
||||||
|
indices_grid.shape[1],
|
||||||
|
self.inner_dim,
|
||||||
|
)
|
||||||
|
if self.double_precision_rope
|
||||||
|
else self.generate_freq_grid(spacing, dtype, fractional_positions.device)
|
||||||
|
).to(device=fractional_positions.device)
|
||||||
|
|
||||||
|
if spacing == "exp_2":
|
||||||
|
freqs = (
|
||||||
|
(indices * fractional_positions.unsqueeze(-1))
|
||||||
|
.transpose(-1, -2)
|
||||||
|
.flatten(2)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
freqs = (
|
||||||
|
(indices * (fractional_positions.unsqueeze(-1) * 2 - 1))
|
||||||
|
.transpose(-1, -2)
|
||||||
|
.flatten(2)
|
||||||
|
)
|
||||||
|
return freqs
|
||||||
|
|
||||||
|
def generate_freq_grid(self, spacing, dtype, device):
|
||||||
|
dim = self.inner_dim
|
||||||
|
theta = self.positional_embedding_theta
|
||||||
|
n_pos_dims = 1
|
||||||
|
n_elem = 2 * n_pos_dims # 2 for cos and sin e.g. x 3 = 6
|
||||||
|
start = 1
|
||||||
|
end = theta
|
||||||
|
|
||||||
|
if spacing == "exp":
|
||||||
|
indices = theta ** (torch.arange(0, dim, n_elem, device="cpu", dtype=torch.float32) / (dim - n_elem))
|
||||||
|
indices = indices.to(dtype=dtype, device=device)
|
||||||
|
elif spacing == "exp_2":
|
||||||
|
indices = 1.0 / theta ** (torch.arange(0, dim, n_elem, device=device) / dim)
|
||||||
|
indices = indices.to(dtype=dtype)
|
||||||
|
elif spacing == "linear":
|
||||||
|
indices = torch.linspace(
|
||||||
|
start, end, dim // n_elem, device=device, dtype=dtype
|
||||||
|
)
|
||||||
|
elif spacing == "sqrt":
|
||||||
|
indices = torch.linspace(
|
||||||
|
start**2, end**2, dim // n_elem, device=device, dtype=dtype
|
||||||
|
).sqrt()
|
||||||
|
|
||||||
|
indices = indices * math.pi / 2
|
||||||
|
|
||||||
|
return indices
|
||||||
|
|
||||||
|
def precompute_freqs_cis(self, indices_grid, spacing="exp"):
|
||||||
|
dim = self.inner_dim
|
||||||
|
n_elem = 2 # 2 because of cos and sin
|
||||||
|
freqs = self.precompute_freqs(indices_grid, spacing)
|
||||||
|
if self.split_rope:
|
||||||
|
expected_freqs = dim // 2
|
||||||
|
current_freqs = freqs.shape[-1]
|
||||||
|
pad_size = expected_freqs - current_freqs
|
||||||
|
cos_freq, sin_freq = split_freqs_cis(
|
||||||
|
freqs, pad_size, self.num_attention_heads
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
cos_freq, sin_freq = interleaved_freqs_cis(freqs, dim % n_elem)
|
||||||
|
return cos_freq.to(self.dtype), sin_freq.to(self.dtype), self.split_rope
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
The [`Transformer2DModel`] forward method.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
|
||||||
|
Input `hidden_states`.
|
||||||
|
indices_grid (`torch.LongTensor` of shape `(batch size, 3, num latent pixels)`):
|
||||||
|
attention_mask ( `torch.Tensor`, *optional*):
|
||||||
|
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
|
||||||
|
is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
|
||||||
|
negative values to the attention scores corresponding to "discard" tokens.
|
||||||
|
Returns:
|
||||||
|
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
|
||||||
|
`tuple` where the first element is the sample tensor.
|
||||||
|
"""
|
||||||
|
# 1. Input
|
||||||
|
|
||||||
|
if self.num_learnable_registers:
|
||||||
|
num_registers_duplications = math.ceil(
|
||||||
|
max(1024, hidden_states.shape[1]) / self.num_learnable_registers
|
||||||
|
)
|
||||||
|
learnable_registers = torch.tile(
|
||||||
|
self.learnable_registers.to(hidden_states), (num_registers_duplications, 1)
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = torch.cat((hidden_states, learnable_registers[hidden_states.shape[1]:].unsqueeze(0).repeat(hidden_states.shape[0], 1, 1)), dim=1)
|
||||||
|
|
||||||
|
if attention_mask is not None:
|
||||||
|
attention_mask = torch.zeros([1, 1, 1, hidden_states.shape[1]], dtype=attention_mask.dtype, device=attention_mask.device)
|
||||||
|
|
||||||
|
indices_grid = torch.arange(
|
||||||
|
hidden_states.shape[1], dtype=torch.float32, device=hidden_states.device
|
||||||
|
)
|
||||||
|
indices_grid = indices_grid[None, None, :]
|
||||||
|
freqs_cis = self.precompute_freqs_cis(indices_grid)
|
||||||
|
|
||||||
|
# 2. Blocks
|
||||||
|
for block_idx, block in enumerate(self.transformer_1d_blocks):
|
||||||
|
hidden_states = block(
|
||||||
|
hidden_states, attention_mask=attention_mask, pe=freqs_cis
|
||||||
|
)
|
||||||
|
|
||||||
|
# 3. Output
|
||||||
|
# if self.output_scale is not None:
|
||||||
|
# hidden_states = hidden_states / self.output_scale
|
||||||
|
|
||||||
|
hidden_states = comfy.ldm.common_dit.rms_norm(hidden_states)
|
||||||
|
|
||||||
|
return hidden_states, attention_mask
|
||||||
292
comfy/ldm/lightricks/latent_upsampler.py
Normal file
292
comfy/ldm/lightricks/latent_upsampler.py
Normal file
@ -0,0 +1,292 @@
|
|||||||
|
from typing import Optional, Tuple
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from einops import rearrange
|
||||||
|
|
||||||
|
|
||||||
|
def _rational_for_scale(scale: float) -> Tuple[int, int]:
|
||||||
|
mapping = {0.75: (3, 4), 1.5: (3, 2), 2.0: (2, 1), 4.0: (4, 1)}
|
||||||
|
if float(scale) not in mapping:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unsupported spatial_scale {scale}. Choose from {list(mapping.keys())}"
|
||||||
|
)
|
||||||
|
return mapping[float(scale)]
|
||||||
|
|
||||||
|
|
||||||
|
class PixelShuffleND(nn.Module):
|
||||||
|
def __init__(self, dims, upscale_factors=(2, 2, 2)):
|
||||||
|
super().__init__()
|
||||||
|
assert dims in [1, 2, 3], "dims must be 1, 2, or 3"
|
||||||
|
self.dims = dims
|
||||||
|
self.upscale_factors = upscale_factors
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
if self.dims == 3:
|
||||||
|
return rearrange(
|
||||||
|
x,
|
||||||
|
"b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)",
|
||||||
|
p1=self.upscale_factors[0],
|
||||||
|
p2=self.upscale_factors[1],
|
||||||
|
p3=self.upscale_factors[2],
|
||||||
|
)
|
||||||
|
elif self.dims == 2:
|
||||||
|
return rearrange(
|
||||||
|
x,
|
||||||
|
"b (c p1 p2) h w -> b c (h p1) (w p2)",
|
||||||
|
p1=self.upscale_factors[0],
|
||||||
|
p2=self.upscale_factors[1],
|
||||||
|
)
|
||||||
|
elif self.dims == 1:
|
||||||
|
return rearrange(
|
||||||
|
x,
|
||||||
|
"b (c p1) f h w -> b c (f p1) h w",
|
||||||
|
p1=self.upscale_factors[0],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class BlurDownsample(nn.Module):
|
||||||
|
"""
|
||||||
|
Anti-aliased spatial downsampling by integer stride using a fixed separable binomial kernel.
|
||||||
|
Applies only on H,W. Works for dims=2 or dims=3 (per-frame).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, dims: int, stride: int):
|
||||||
|
super().__init__()
|
||||||
|
assert dims in (2, 3)
|
||||||
|
assert stride >= 1 and isinstance(stride, int)
|
||||||
|
self.dims = dims
|
||||||
|
self.stride = stride
|
||||||
|
|
||||||
|
# 5x5 separable binomial kernel [1,4,6,4,1] (outer product), normalized
|
||||||
|
k = torch.tensor([1.0, 4.0, 6.0, 4.0, 1.0])
|
||||||
|
k2d = k[:, None] @ k[None, :]
|
||||||
|
k2d = (k2d / k2d.sum()).float() # shape (5,5)
|
||||||
|
self.register_buffer("kernel", k2d[None, None, :, :]) # (1,1,5,5)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
if self.stride == 1:
|
||||||
|
return x
|
||||||
|
|
||||||
|
def _apply_2d(x2d: torch.Tensor) -> torch.Tensor:
|
||||||
|
# x2d: (B, C, H, W)
|
||||||
|
B, C, H, W = x2d.shape
|
||||||
|
weight = self.kernel.expand(C, 1, 5, 5) # depthwise
|
||||||
|
x2d = F.conv2d(
|
||||||
|
x2d, weight=weight, bias=None, stride=self.stride, padding=2, groups=C
|
||||||
|
)
|
||||||
|
return x2d
|
||||||
|
|
||||||
|
if self.dims == 2:
|
||||||
|
return _apply_2d(x)
|
||||||
|
else:
|
||||||
|
# dims == 3: apply per-frame on H,W
|
||||||
|
b, c, f, h, w = x.shape
|
||||||
|
x = rearrange(x, "b c f h w -> (b f) c h w")
|
||||||
|
x = _apply_2d(x)
|
||||||
|
h2, w2 = x.shape[-2:]
|
||||||
|
x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f, h=h2, w=w2)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class SpatialRationalResampler(nn.Module):
|
||||||
|
"""
|
||||||
|
Fully-learned rational spatial scaling: up by 'num' via PixelShuffle, then anti-aliased
|
||||||
|
downsample by 'den' using fixed blur + stride. Operates on H,W only.
|
||||||
|
|
||||||
|
For dims==3, work per-frame for spatial scaling (temporal axis untouched).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, mid_channels: int, scale: float):
|
||||||
|
super().__init__()
|
||||||
|
self.scale = float(scale)
|
||||||
|
self.num, self.den = _rational_for_scale(self.scale)
|
||||||
|
self.conv = nn.Conv2d(
|
||||||
|
mid_channels, (self.num**2) * mid_channels, kernel_size=3, padding=1
|
||||||
|
)
|
||||||
|
self.pixel_shuffle = PixelShuffleND(2, upscale_factors=(self.num, self.num))
|
||||||
|
self.blur_down = BlurDownsample(dims=2, stride=self.den)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
b, c, f, h, w = x.shape
|
||||||
|
x = rearrange(x, "b c f h w -> (b f) c h w")
|
||||||
|
x = self.conv(x)
|
||||||
|
x = self.pixel_shuffle(x)
|
||||||
|
x = self.blur_down(x)
|
||||||
|
x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class ResBlock(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self, channels: int, mid_channels: Optional[int] = None, dims: int = 3
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
if mid_channels is None:
|
||||||
|
mid_channels = channels
|
||||||
|
|
||||||
|
Conv = nn.Conv2d if dims == 2 else nn.Conv3d
|
||||||
|
|
||||||
|
self.conv1 = Conv(channels, mid_channels, kernel_size=3, padding=1)
|
||||||
|
self.norm1 = nn.GroupNorm(32, mid_channels)
|
||||||
|
self.conv2 = Conv(mid_channels, channels, kernel_size=3, padding=1)
|
||||||
|
self.norm2 = nn.GroupNorm(32, channels)
|
||||||
|
self.activation = nn.SiLU()
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
residual = x
|
||||||
|
x = self.conv1(x)
|
||||||
|
x = self.norm1(x)
|
||||||
|
x = self.activation(x)
|
||||||
|
x = self.conv2(x)
|
||||||
|
x = self.norm2(x)
|
||||||
|
x = self.activation(x + residual)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class LatentUpsampler(nn.Module):
|
||||||
|
"""
|
||||||
|
Model to spatially upsample VAE latents.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
in_channels (`int`): Number of channels in the input latent
|
||||||
|
mid_channels (`int`): Number of channels in the middle layers
|
||||||
|
num_blocks_per_stage (`int`): Number of ResBlocks to use in each stage (pre/post upsampling)
|
||||||
|
dims (`int`): Number of dimensions for convolutions (2 or 3)
|
||||||
|
spatial_upsample (`bool`): Whether to spatially upsample the latent
|
||||||
|
temporal_upsample (`bool`): Whether to temporally upsample the latent
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int = 128,
|
||||||
|
mid_channels: int = 512,
|
||||||
|
num_blocks_per_stage: int = 4,
|
||||||
|
dims: int = 3,
|
||||||
|
spatial_upsample: bool = True,
|
||||||
|
temporal_upsample: bool = False,
|
||||||
|
spatial_scale: float = 2.0,
|
||||||
|
rational_resampler: bool = False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.mid_channels = mid_channels
|
||||||
|
self.num_blocks_per_stage = num_blocks_per_stage
|
||||||
|
self.dims = dims
|
||||||
|
self.spatial_upsample = spatial_upsample
|
||||||
|
self.temporal_upsample = temporal_upsample
|
||||||
|
self.spatial_scale = float(spatial_scale)
|
||||||
|
self.rational_resampler = rational_resampler
|
||||||
|
|
||||||
|
Conv = nn.Conv2d if dims == 2 else nn.Conv3d
|
||||||
|
|
||||||
|
self.initial_conv = Conv(in_channels, mid_channels, kernel_size=3, padding=1)
|
||||||
|
self.initial_norm = nn.GroupNorm(32, mid_channels)
|
||||||
|
self.initial_activation = nn.SiLU()
|
||||||
|
|
||||||
|
self.res_blocks = nn.ModuleList(
|
||||||
|
[ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)]
|
||||||
|
)
|
||||||
|
|
||||||
|
if spatial_upsample and temporal_upsample:
|
||||||
|
self.upsampler = nn.Sequential(
|
||||||
|
nn.Conv3d(mid_channels, 8 * mid_channels, kernel_size=3, padding=1),
|
||||||
|
PixelShuffleND(3),
|
||||||
|
)
|
||||||
|
elif spatial_upsample:
|
||||||
|
if rational_resampler:
|
||||||
|
self.upsampler = SpatialRationalResampler(
|
||||||
|
mid_channels=mid_channels, scale=self.spatial_scale
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.upsampler = nn.Sequential(
|
||||||
|
nn.Conv2d(mid_channels, 4 * mid_channels, kernel_size=3, padding=1),
|
||||||
|
PixelShuffleND(2),
|
||||||
|
)
|
||||||
|
elif temporal_upsample:
|
||||||
|
self.upsampler = nn.Sequential(
|
||||||
|
nn.Conv3d(mid_channels, 2 * mid_channels, kernel_size=3, padding=1),
|
||||||
|
PixelShuffleND(1),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"Either spatial_upsample or temporal_upsample must be True"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.post_upsample_res_blocks = nn.ModuleList(
|
||||||
|
[ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)]
|
||||||
|
)
|
||||||
|
|
||||||
|
self.final_conv = Conv(mid_channels, in_channels, kernel_size=3, padding=1)
|
||||||
|
|
||||||
|
def forward(self, latent: torch.Tensor) -> torch.Tensor:
|
||||||
|
b, c, f, h, w = latent.shape
|
||||||
|
|
||||||
|
if self.dims == 2:
|
||||||
|
x = rearrange(latent, "b c f h w -> (b f) c h w")
|
||||||
|
x = self.initial_conv(x)
|
||||||
|
x = self.initial_norm(x)
|
||||||
|
x = self.initial_activation(x)
|
||||||
|
|
||||||
|
for block in self.res_blocks:
|
||||||
|
x = block(x)
|
||||||
|
|
||||||
|
x = self.upsampler(x)
|
||||||
|
|
||||||
|
for block in self.post_upsample_res_blocks:
|
||||||
|
x = block(x)
|
||||||
|
|
||||||
|
x = self.final_conv(x)
|
||||||
|
x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f)
|
||||||
|
else:
|
||||||
|
x = self.initial_conv(latent)
|
||||||
|
x = self.initial_norm(x)
|
||||||
|
x = self.initial_activation(x)
|
||||||
|
|
||||||
|
for block in self.res_blocks:
|
||||||
|
x = block(x)
|
||||||
|
|
||||||
|
if self.temporal_upsample:
|
||||||
|
x = self.upsampler(x)
|
||||||
|
x = x[:, :, 1:, :, :]
|
||||||
|
else:
|
||||||
|
if isinstance(self.upsampler, SpatialRationalResampler):
|
||||||
|
x = self.upsampler(x)
|
||||||
|
else:
|
||||||
|
x = rearrange(x, "b c f h w -> (b f) c h w")
|
||||||
|
x = self.upsampler(x)
|
||||||
|
x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f)
|
||||||
|
|
||||||
|
for block in self.post_upsample_res_blocks:
|
||||||
|
x = block(x)
|
||||||
|
|
||||||
|
x = self.final_conv(x)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, config):
|
||||||
|
return cls(
|
||||||
|
in_channels=config.get("in_channels", 4),
|
||||||
|
mid_channels=config.get("mid_channels", 128),
|
||||||
|
num_blocks_per_stage=config.get("num_blocks_per_stage", 4),
|
||||||
|
dims=config.get("dims", 2),
|
||||||
|
spatial_upsample=config.get("spatial_upsample", True),
|
||||||
|
temporal_upsample=config.get("temporal_upsample", False),
|
||||||
|
spatial_scale=config.get("spatial_scale", 2.0),
|
||||||
|
rational_resampler=config.get("rational_resampler", False),
|
||||||
|
)
|
||||||
|
|
||||||
|
def config(self):
|
||||||
|
return {
|
||||||
|
"_class_name": "LatentUpsampler",
|
||||||
|
"in_channels": self.in_channels,
|
||||||
|
"mid_channels": self.mid_channels,
|
||||||
|
"num_blocks_per_stage": self.num_blocks_per_stage,
|
||||||
|
"dims": self.dims,
|
||||||
|
"spatial_upsample": self.spatial_upsample,
|
||||||
|
"temporal_upsample": self.temporal_upsample,
|
||||||
|
"spatial_scale": self.spatial_scale,
|
||||||
|
"rational_resampler": self.rational_resampler,
|
||||||
|
}
|
||||||
@ -1,13 +1,47 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from enum import Enum
|
||||||
|
import functools
|
||||||
|
import math
|
||||||
|
from typing import Dict, Optional, Tuple
|
||||||
|
|
||||||
|
from einops import rearrange
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
import comfy.patcher_extension
|
import comfy.patcher_extension
|
||||||
import comfy.ldm.modules.attention
|
import comfy.ldm.modules.attention
|
||||||
import comfy.ldm.common_dit
|
import comfy.ldm.common_dit
|
||||||
import math
|
|
||||||
from typing import Dict, Optional, Tuple
|
|
||||||
|
|
||||||
from .symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords
|
from .symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords
|
||||||
from comfy.ldm.flux.math import apply_rope1
|
|
||||||
|
def _log_base(x, base):
|
||||||
|
return np.log(x) / np.log(base)
|
||||||
|
|
||||||
|
class LTXRopeType(str, Enum):
|
||||||
|
INTERLEAVED = "interleaved"
|
||||||
|
SPLIT = "split"
|
||||||
|
|
||||||
|
KEY = "rope_type"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, kwargs, default=None):
|
||||||
|
if default is None:
|
||||||
|
default = cls.INTERLEAVED
|
||||||
|
return cls(kwargs.get(cls.KEY, default))
|
||||||
|
|
||||||
|
|
||||||
|
class LTXFrequenciesPrecision(str, Enum):
|
||||||
|
FLOAT32 = "float32"
|
||||||
|
FLOAT64 = "float64"
|
||||||
|
|
||||||
|
KEY = "frequencies_precision"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, kwargs, default=None):
|
||||||
|
if default is None:
|
||||||
|
default = cls.FLOAT32
|
||||||
|
return cls(kwargs.get(cls.KEY, default))
|
||||||
|
|
||||||
|
|
||||||
def get_timestep_embedding(
|
def get_timestep_embedding(
|
||||||
timesteps: torch.Tensor,
|
timesteps: torch.Tensor,
|
||||||
@ -39,9 +73,7 @@ def get_timestep_embedding(
|
|||||||
assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
|
assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
|
||||||
|
|
||||||
half_dim = embedding_dim // 2
|
half_dim = embedding_dim // 2
|
||||||
exponent = -math.log(max_period) * torch.arange(
|
exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device)
|
||||||
start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
|
|
||||||
)
|
|
||||||
exponent = exponent / (half_dim - downscale_freq_shift)
|
exponent = exponent / (half_dim - downscale_freq_shift)
|
||||||
|
|
||||||
emb = torch.exp(exponent)
|
emb = torch.exp(exponent)
|
||||||
@ -73,7 +105,9 @@ class TimestepEmbedding(nn.Module):
|
|||||||
post_act_fn: Optional[str] = None,
|
post_act_fn: Optional[str] = None,
|
||||||
cond_proj_dim=None,
|
cond_proj_dim=None,
|
||||||
sample_proj_bias=True,
|
sample_proj_bias=True,
|
||||||
dtype=None, device=None, operations=None,
|
dtype=None,
|
||||||
|
device=None,
|
||||||
|
operations=None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@ -90,7 +124,9 @@ class TimestepEmbedding(nn.Module):
|
|||||||
time_embed_dim_out = out_dim
|
time_embed_dim_out = out_dim
|
||||||
else:
|
else:
|
||||||
time_embed_dim_out = time_embed_dim
|
time_embed_dim_out = time_embed_dim
|
||||||
self.linear_2 = operations.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias, dtype=dtype, device=device)
|
self.linear_2 = operations.Linear(
|
||||||
|
time_embed_dim, time_embed_dim_out, sample_proj_bias, dtype=dtype, device=device
|
||||||
|
)
|
||||||
|
|
||||||
if post_act_fn is None:
|
if post_act_fn is None:
|
||||||
self.post_act = None
|
self.post_act = None
|
||||||
@ -139,12 +175,22 @@ class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
|
|||||||
https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29
|
https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool = False, dtype=None, device=None, operations=None):
|
def __init__(
|
||||||
|
self,
|
||||||
|
embedding_dim,
|
||||||
|
size_emb_dim,
|
||||||
|
use_additional_conditions: bool = False,
|
||||||
|
dtype=None,
|
||||||
|
device=None,
|
||||||
|
operations=None,
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.outdim = size_emb_dim
|
self.outdim = size_emb_dim
|
||||||
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
|
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
|
||||||
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim, dtype=dtype, device=device, operations=operations)
|
self.timestep_embedder = TimestepEmbedding(
|
||||||
|
in_channels=256, time_embed_dim=embedding_dim, dtype=dtype, device=device, operations=operations
|
||||||
|
)
|
||||||
|
|
||||||
def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype):
|
def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype):
|
||||||
timesteps_proj = self.time_proj(timestep)
|
timesteps_proj = self.time_proj(timestep)
|
||||||
@ -163,15 +209,22 @@ class AdaLayerNormSingle(nn.Module):
|
|||||||
use_additional_conditions (`bool`): To use additional conditions for normalization or not.
|
use_additional_conditions (`bool`): To use additional conditions for normalization or not.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, embedding_dim: int, use_additional_conditions: bool = False, dtype=None, device=None, operations=None):
|
def __init__(
|
||||||
|
self, embedding_dim: int, embedding_coefficient: int = 6, use_additional_conditions: bool = False, dtype=None, device=None, operations=None
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings(
|
self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings(
|
||||||
embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions, dtype=dtype, device=device, operations=operations
|
embedding_dim,
|
||||||
|
size_emb_dim=embedding_dim // 3,
|
||||||
|
use_additional_conditions=use_additional_conditions,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
operations=operations,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.silu = nn.SiLU()
|
self.silu = nn.SiLU()
|
||||||
self.linear = operations.Linear(embedding_dim, 6 * embedding_dim, bias=True, dtype=dtype, device=device)
|
self.linear = operations.Linear(embedding_dim, embedding_coefficient * embedding_dim, bias=True, dtype=dtype, device=device)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -185,6 +238,7 @@ class AdaLayerNormSingle(nn.Module):
|
|||||||
embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype)
|
embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype)
|
||||||
return self.linear(self.silu(embedded_timestep)), embedded_timestep
|
return self.linear(self.silu(embedded_timestep)), embedded_timestep
|
||||||
|
|
||||||
|
|
||||||
class PixArtAlphaTextProjection(nn.Module):
|
class PixArtAlphaTextProjection(nn.Module):
|
||||||
"""
|
"""
|
||||||
Projects caption embeddings. Also handles dropout for classifier-free guidance.
|
Projects caption embeddings. Also handles dropout for classifier-free guidance.
|
||||||
@ -192,18 +246,24 @@ class PixArtAlphaTextProjection(nn.Module):
|
|||||||
Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
|
Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh", dtype=None, device=None, operations=None):
|
def __init__(
|
||||||
|
self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh", dtype=None, device=None, operations=None
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
if out_features is None:
|
if out_features is None:
|
||||||
out_features = hidden_size
|
out_features = hidden_size
|
||||||
self.linear_1 = operations.Linear(in_features=in_features, out_features=hidden_size, bias=True, dtype=dtype, device=device)
|
self.linear_1 = operations.Linear(
|
||||||
|
in_features=in_features, out_features=hidden_size, bias=True, dtype=dtype, device=device
|
||||||
|
)
|
||||||
if act_fn == "gelu_tanh":
|
if act_fn == "gelu_tanh":
|
||||||
self.act_1 = nn.GELU(approximate="tanh")
|
self.act_1 = nn.GELU(approximate="tanh")
|
||||||
elif act_fn == "silu":
|
elif act_fn == "silu":
|
||||||
self.act_1 = nn.SiLU()
|
self.act_1 = nn.SiLU()
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown activation function: {act_fn}")
|
raise ValueError(f"Unknown activation function: {act_fn}")
|
||||||
self.linear_2 = operations.Linear(in_features=hidden_size, out_features=out_features, bias=True, dtype=dtype, device=device)
|
self.linear_2 = operations.Linear(
|
||||||
|
in_features=hidden_size, out_features=out_features, bias=True, dtype=dtype, device=device
|
||||||
|
)
|
||||||
|
|
||||||
def forward(self, caption):
|
def forward(self, caption):
|
||||||
hidden_states = self.linear_1(caption)
|
hidden_states = self.linear_1(caption)
|
||||||
@ -222,23 +282,68 @@ class GELU_approx(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class FeedForward(nn.Module):
|
class FeedForward(nn.Module):
|
||||||
def __init__(self, dim, dim_out, mult=4, glu=False, dropout=0., dtype=None, device=None, operations=None):
|
def __init__(self, dim, dim_out, mult=4, glu=False, dropout=0.0, dtype=None, device=None, operations=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
inner_dim = int(dim * mult)
|
inner_dim = int(dim * mult)
|
||||||
project_in = GELU_approx(dim, inner_dim, dtype=dtype, device=device, operations=operations)
|
project_in = GELU_approx(dim, inner_dim, dtype=dtype, device=device, operations=operations)
|
||||||
|
|
||||||
self.net = nn.Sequential(
|
self.net = nn.Sequential(
|
||||||
project_in,
|
project_in, nn.Dropout(dropout), operations.Linear(inner_dim, dim_out, dtype=dtype, device=device)
|
||||||
nn.Dropout(dropout),
|
|
||||||
operations.Linear(inner_dim, dim_out, dtype=dtype, device=device)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return self.net(x)
|
return self.net(x)
|
||||||
|
|
||||||
|
def apply_rotary_emb(input_tensor, freqs_cis):
|
||||||
|
cos_freqs, sin_freqs = freqs_cis[0], freqs_cis[1]
|
||||||
|
split_pe = freqs_cis[2] if len(freqs_cis) > 2 else False
|
||||||
|
return (
|
||||||
|
apply_split_rotary_emb(input_tensor, cos_freqs, sin_freqs)
|
||||||
|
if split_pe else
|
||||||
|
apply_interleaved_rotary_emb(input_tensor, cos_freqs, sin_freqs)
|
||||||
|
)
|
||||||
|
|
||||||
|
def apply_interleaved_rotary_emb(input_tensor, cos_freqs, sin_freqs): # TODO: remove duplicate funcs and pick the best/fastest one
|
||||||
|
t_dup = rearrange(input_tensor, "... (d r) -> ... d r", r=2)
|
||||||
|
t1, t2 = t_dup.unbind(dim=-1)
|
||||||
|
t_dup = torch.stack((-t2, t1), dim=-1)
|
||||||
|
input_tensor_rot = rearrange(t_dup, "... d r -> ... (d r)")
|
||||||
|
|
||||||
|
out = input_tensor * cos_freqs + input_tensor_rot * sin_freqs
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
def apply_split_rotary_emb(input_tensor, cos, sin):
|
||||||
|
needs_reshape = False
|
||||||
|
if input_tensor.ndim != 4 and cos.ndim == 4:
|
||||||
|
B, H, T, _ = cos.shape
|
||||||
|
input_tensor = input_tensor.reshape(B, T, H, -1).swapaxes(1, 2)
|
||||||
|
needs_reshape = True
|
||||||
|
split_input = rearrange(input_tensor, "... (d r) -> ... d r", d=2)
|
||||||
|
first_half_input = split_input[..., :1, :]
|
||||||
|
second_half_input = split_input[..., 1:, :]
|
||||||
|
output = split_input * cos.unsqueeze(-2)
|
||||||
|
first_half_output = output[..., :1, :]
|
||||||
|
second_half_output = output[..., 1:, :]
|
||||||
|
first_half_output.addcmul_(-sin.unsqueeze(-2), second_half_input)
|
||||||
|
second_half_output.addcmul_(sin.unsqueeze(-2), first_half_input)
|
||||||
|
output = rearrange(output, "... d r -> ... (d r)")
|
||||||
|
return output.swapaxes(1, 2).reshape(B, T, -1) if needs_reshape else output
|
||||||
|
|
||||||
|
|
||||||
class CrossAttention(nn.Module):
|
class CrossAttention(nn.Module):
|
||||||
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., attn_precision=None, dtype=None, device=None, operations=None):
|
def __init__(
|
||||||
|
self,
|
||||||
|
query_dim,
|
||||||
|
context_dim=None,
|
||||||
|
heads=8,
|
||||||
|
dim_head=64,
|
||||||
|
dropout=0.0,
|
||||||
|
attn_precision=None,
|
||||||
|
dtype=None,
|
||||||
|
device=None,
|
||||||
|
operations=None,
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
inner_dim = dim_head * heads
|
inner_dim = dim_head * heads
|
||||||
context_dim = query_dim if context_dim is None else context_dim
|
context_dim = query_dim if context_dim is None else context_dim
|
||||||
@ -254,9 +359,11 @@ class CrossAttention(nn.Module):
|
|||||||
self.to_k = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)
|
self.to_k = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)
|
||||||
self.to_v = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)
|
self.to_v = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)
|
||||||
|
|
||||||
self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
|
self.to_out = nn.Sequential(
|
||||||
|
operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout)
|
||||||
|
)
|
||||||
|
|
||||||
def forward(self, x, context=None, mask=None, pe=None, transformer_options={}):
|
def forward(self, x, context=None, mask=None, pe=None, k_pe=None, transformer_options={}):
|
||||||
q = self.to_q(x)
|
q = self.to_q(x)
|
||||||
context = x if context is None else context
|
context = x if context is None else context
|
||||||
k = self.to_k(context)
|
k = self.to_k(context)
|
||||||
@ -266,8 +373,8 @@ class CrossAttention(nn.Module):
|
|||||||
k = self.k_norm(k)
|
k = self.k_norm(k)
|
||||||
|
|
||||||
if pe is not None:
|
if pe is not None:
|
||||||
q = apply_rope1(q.unsqueeze(1), pe).squeeze(1)
|
q = apply_rotary_emb(q, pe)
|
||||||
k = apply_rope1(k.unsqueeze(1), pe).squeeze(1)
|
k = apply_rotary_emb(k, pe if k_pe is None else k_pe)
|
||||||
|
|
||||||
if mask is None:
|
if mask is None:
|
||||||
out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options)
|
out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options)
|
||||||
@ -277,14 +384,34 @@ class CrossAttention(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class BasicTransformerBlock(nn.Module):
|
class BasicTransformerBlock(nn.Module):
|
||||||
def __init__(self, dim, n_heads, d_head, context_dim=None, attn_precision=None, dtype=None, device=None, operations=None):
|
def __init__(
|
||||||
|
self, dim, n_heads, d_head, context_dim=None, attn_precision=None, dtype=None, device=None, operations=None
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.attn_precision = attn_precision
|
self.attn_precision = attn_precision
|
||||||
self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, context_dim=None, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations)
|
self.attn1 = CrossAttention(
|
||||||
|
query_dim=dim,
|
||||||
|
heads=n_heads,
|
||||||
|
dim_head=d_head,
|
||||||
|
context_dim=None,
|
||||||
|
attn_precision=self.attn_precision,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
operations=operations,
|
||||||
|
)
|
||||||
self.ff = FeedForward(dim, dim_out=dim, glu=True, dtype=dtype, device=device, operations=operations)
|
self.ff = FeedForward(dim, dim_out=dim, glu=True, dtype=dtype, device=device, operations=operations)
|
||||||
|
|
||||||
self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations)
|
self.attn2 = CrossAttention(
|
||||||
|
query_dim=dim,
|
||||||
|
context_dim=context_dim,
|
||||||
|
heads=n_heads,
|
||||||
|
dim_head=d_head,
|
||||||
|
attn_precision=self.attn_precision,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
operations=operations,
|
||||||
|
)
|
||||||
|
|
||||||
self.scale_shift_table = nn.Parameter(torch.empty(6, dim, device=device, dtype=dtype))
|
self.scale_shift_table = nn.Parameter(torch.empty(6, dim, device=device, dtype=dtype))
|
||||||
|
|
||||||
@ -306,116 +433,446 @@ class BasicTransformerBlock(nn.Module):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
def get_fractional_positions(indices_grid, max_pos):
|
def get_fractional_positions(indices_grid, max_pos):
|
||||||
|
n_pos_dims = indices_grid.shape[1]
|
||||||
|
assert n_pos_dims == len(max_pos), f'Number of position dimensions ({n_pos_dims}) must match max_pos length ({len(max_pos)})'
|
||||||
fractional_positions = torch.stack(
|
fractional_positions = torch.stack(
|
||||||
[
|
[indices_grid[:, i] / max_pos[i] for i in range(n_pos_dims)],
|
||||||
indices_grid[:, i] / max_pos[i]
|
axis=-1,
|
||||||
for i in range(3)
|
|
||||||
],
|
|
||||||
dim=-1,
|
|
||||||
)
|
)
|
||||||
return fractional_positions
|
return fractional_positions
|
||||||
|
|
||||||
|
|
||||||
def precompute_freqs_cis(indices_grid, dim, out_dtype, theta=10000.0, max_pos=[20, 2048, 2048]):
|
@functools.lru_cache(maxsize=5)
|
||||||
dtype = torch.float32
|
def generate_freq_grid_np(positional_embedding_theta, positional_embedding_max_pos_count, inner_dim, _ = None):
|
||||||
device = indices_grid.device
|
theta = positional_embedding_theta
|
||||||
|
start = 1
|
||||||
|
end = theta
|
||||||
|
|
||||||
|
n_elem = 2 * positional_embedding_max_pos_count
|
||||||
|
pow_indices = np.power(
|
||||||
|
theta,
|
||||||
|
np.linspace(
|
||||||
|
_log_base(start, theta),
|
||||||
|
_log_base(end, theta),
|
||||||
|
inner_dim // n_elem,
|
||||||
|
dtype=np.float64,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
return torch.tensor(pow_indices * math.pi / 2, dtype=torch.float32)
|
||||||
|
|
||||||
|
def generate_freq_grid_pytorch(positional_embedding_theta, positional_embedding_max_pos_count, inner_dim, device):
|
||||||
|
theta = positional_embedding_theta
|
||||||
|
start = 1
|
||||||
|
end = theta
|
||||||
|
n_elem = 2 * positional_embedding_max_pos_count
|
||||||
|
|
||||||
|
indices = theta ** (
|
||||||
|
torch.linspace(
|
||||||
|
math.log(start, theta),
|
||||||
|
math.log(end, theta),
|
||||||
|
inner_dim // n_elem,
|
||||||
|
device=device,
|
||||||
|
dtype=torch.float32,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
indices = indices.to(dtype=torch.float32)
|
||||||
|
|
||||||
|
indices = indices * math.pi / 2
|
||||||
|
|
||||||
|
return indices
|
||||||
|
|
||||||
|
def generate_freqs(indices, indices_grid, max_pos, use_middle_indices_grid):
|
||||||
|
if use_middle_indices_grid:
|
||||||
|
assert(len(indices_grid.shape) == 4 and indices_grid.shape[-1] ==2)
|
||||||
|
indices_grid_start, indices_grid_end = indices_grid[..., 0], indices_grid[..., 1]
|
||||||
|
indices_grid = (indices_grid_start + indices_grid_end) / 2.0
|
||||||
|
elif len(indices_grid.shape) == 4:
|
||||||
|
indices_grid = indices_grid[..., 0]
|
||||||
|
|
||||||
# Get fractional positions and compute frequency indices
|
# Get fractional positions and compute frequency indices
|
||||||
fractional_positions = get_fractional_positions(indices_grid, max_pos)
|
fractional_positions = get_fractional_positions(indices_grid, max_pos)
|
||||||
indices = theta ** torch.linspace(0, 1, dim // 6, device=device, dtype=dtype) * math.pi / 2
|
indices = indices.to(device=fractional_positions.device)
|
||||||
|
|
||||||
# Compute frequencies and apply cos/sin
|
freqs = (
|
||||||
freqs = (indices * (fractional_positions.unsqueeze(-1) * 2 - 1)).transpose(-1, -2).flatten(2)
|
(indices * (fractional_positions.unsqueeze(-1) * 2 - 1))
|
||||||
cos_vals = freqs.cos().repeat_interleave(2, dim=-1)
|
.transpose(-1, -2)
|
||||||
sin_vals = freqs.sin().repeat_interleave(2, dim=-1)
|
.flatten(2)
|
||||||
|
)
|
||||||
|
return freqs
|
||||||
|
|
||||||
# Pad if dim is not divisible by 6
|
def interleaved_freqs_cis(freqs, pad_size):
|
||||||
if dim % 6 != 0:
|
cos_freq = freqs.cos().repeat_interleave(2, dim=-1)
|
||||||
padding_size = dim % 6
|
sin_freq = freqs.sin().repeat_interleave(2, dim=-1)
|
||||||
cos_vals = torch.cat([torch.ones_like(cos_vals[:, :, :padding_size]), cos_vals], dim=-1)
|
if pad_size != 0:
|
||||||
sin_vals = torch.cat([torch.zeros_like(sin_vals[:, :, :padding_size]), sin_vals], dim=-1)
|
cos_padding = torch.ones_like(cos_freq[:, :, : pad_size])
|
||||||
|
sin_padding = torch.zeros_like(cos_freq[:, :, : pad_size])
|
||||||
|
cos_freq = torch.cat([cos_padding, cos_freq], dim=-1)
|
||||||
|
sin_freq = torch.cat([sin_padding, sin_freq], dim=-1)
|
||||||
|
return cos_freq, sin_freq
|
||||||
|
|
||||||
# Reshape and extract one value per pair (since repeat_interleave duplicates each value)
|
def split_freqs_cis(freqs, pad_size, num_attention_heads):
|
||||||
cos_vals = cos_vals.reshape(*cos_vals.shape[:2], -1, 2)[..., 0].to(out_dtype) # [B, N, dim//2]
|
cos_freq = freqs.cos()
|
||||||
sin_vals = sin_vals.reshape(*sin_vals.shape[:2], -1, 2)[..., 0].to(out_dtype) # [B, N, dim//2]
|
sin_freq = freqs.sin()
|
||||||
|
|
||||||
# Build rotation matrix [[cos, -sin], [sin, cos]] and add heads dimension
|
if pad_size != 0:
|
||||||
freqs_cis = torch.stack([
|
cos_padding = torch.ones_like(cos_freq[:, :, :pad_size])
|
||||||
torch.stack([cos_vals, -sin_vals], dim=-1),
|
sin_padding = torch.zeros_like(sin_freq[:, :, :pad_size])
|
||||||
torch.stack([sin_vals, cos_vals], dim=-1)
|
|
||||||
], dim=-2).unsqueeze(1) # [B, 1, N, dim//2, 2, 2]
|
|
||||||
|
|
||||||
return freqs_cis
|
cos_freq = torch.concatenate([cos_padding, cos_freq], axis=-1)
|
||||||
|
sin_freq = torch.concatenate([sin_padding, sin_freq], axis=-1)
|
||||||
|
|
||||||
|
# Reshape freqs to be compatible with multi-head attention
|
||||||
|
B , T, half_HD = cos_freq.shape
|
||||||
|
|
||||||
class LTXVModel(torch.nn.Module):
|
cos_freq = cos_freq.reshape(B, T, num_attention_heads, half_HD // num_attention_heads)
|
||||||
def __init__(self,
|
sin_freq = sin_freq.reshape(B, T, num_attention_heads, half_HD // num_attention_heads)
|
||||||
in_channels=128,
|
|
||||||
cross_attention_dim=2048,
|
|
||||||
attention_head_dim=64,
|
|
||||||
num_attention_heads=32,
|
|
||||||
|
|
||||||
caption_channels=4096,
|
cos_freq = torch.swapaxes(cos_freq, 1, 2) # (B,H,T,D//2)
|
||||||
num_layers=28,
|
sin_freq = torch.swapaxes(sin_freq, 1, 2) # (B,H,T,D//2)
|
||||||
|
return cos_freq, sin_freq
|
||||||
|
|
||||||
|
class LTXBaseModel(torch.nn.Module, ABC):
|
||||||
|
"""
|
||||||
|
Abstract base class for LTX models (Lightricks Transformer models).
|
||||||
|
|
||||||
positional_embedding_theta=10000.0,
|
This class defines the common interface and shared functionality for all LTX models,
|
||||||
positional_embedding_max_pos=[20, 2048, 2048],
|
including LTXV (video) and LTXAV (audio-video) variants.
|
||||||
causal_temporal_positioning=False,
|
"""
|
||||||
vae_scale_factors=(8, 32, 32),
|
|
||||||
dtype=None, device=None, operations=None, **kwargs):
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int,
|
||||||
|
cross_attention_dim: int,
|
||||||
|
attention_head_dim: int,
|
||||||
|
num_attention_heads: int,
|
||||||
|
caption_channels: int,
|
||||||
|
num_layers: int,
|
||||||
|
positional_embedding_theta: float = 10000.0,
|
||||||
|
positional_embedding_max_pos: list = [20, 2048, 2048],
|
||||||
|
causal_temporal_positioning: bool = False,
|
||||||
|
vae_scale_factors: tuple = (8, 32, 32),
|
||||||
|
use_middle_indices_grid=False,
|
||||||
|
timestep_scale_multiplier = 1000.0,
|
||||||
|
dtype=None,
|
||||||
|
device=None,
|
||||||
|
operations=None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.generator = None
|
self.generator = None
|
||||||
self.vae_scale_factors = vae_scale_factors
|
self.vae_scale_factors = vae_scale_factors
|
||||||
|
self.use_middle_indices_grid = use_middle_indices_grid
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
self.out_channels = in_channels
|
self.in_channels = in_channels
|
||||||
self.inner_dim = num_attention_heads * attention_head_dim
|
self.cross_attention_dim = cross_attention_dim
|
||||||
|
self.attention_head_dim = attention_head_dim
|
||||||
|
self.num_attention_heads = num_attention_heads
|
||||||
|
self.caption_channels = caption_channels
|
||||||
|
self.num_layers = num_layers
|
||||||
|
self.positional_embedding_theta = positional_embedding_theta
|
||||||
|
self.positional_embedding_max_pos = positional_embedding_max_pos
|
||||||
|
self.split_positional_embedding = LTXRopeType.from_dict(kwargs)
|
||||||
|
self.freq_grid_generator = (
|
||||||
|
generate_freq_grid_np if LTXFrequenciesPrecision.from_dict(kwargs) == LTXFrequenciesPrecision.FLOAT64
|
||||||
|
else generate_freq_grid_pytorch
|
||||||
|
)
|
||||||
self.causal_temporal_positioning = causal_temporal_positioning
|
self.causal_temporal_positioning = causal_temporal_positioning
|
||||||
|
self.operations = operations
|
||||||
|
self.timestep_scale_multiplier = timestep_scale_multiplier
|
||||||
|
|
||||||
self.patchify_proj = operations.Linear(in_channels, self.inner_dim, bias=True, dtype=dtype, device=device)
|
# Common dimensions
|
||||||
|
self.inner_dim = num_attention_heads * attention_head_dim
|
||||||
|
self.out_channels = in_channels
|
||||||
|
|
||||||
|
# Initialize common components
|
||||||
|
self._init_common_components(device, dtype)
|
||||||
|
|
||||||
|
# Initialize model-specific components
|
||||||
|
self._init_model_components(device, dtype, **kwargs)
|
||||||
|
|
||||||
|
# Initialize transformer blocks
|
||||||
|
self._init_transformer_blocks(device, dtype, **kwargs)
|
||||||
|
|
||||||
|
# Initialize output components
|
||||||
|
self._init_output_components(device, dtype)
|
||||||
|
|
||||||
|
def _init_common_components(self, device, dtype):
|
||||||
|
"""Initialize components common to all LTX models
|
||||||
|
- patchify_proj: Linear projection for patchifying input
|
||||||
|
- adaln_single: AdaLN layer for timestep embedding
|
||||||
|
- caption_projection: Linear projection for caption embedding
|
||||||
|
"""
|
||||||
|
self.patchify_proj = self.operations.Linear(
|
||||||
|
self.in_channels, self.inner_dim, bias=True, dtype=dtype, device=device
|
||||||
|
)
|
||||||
|
|
||||||
self.adaln_single = AdaLayerNormSingle(
|
self.adaln_single = AdaLayerNormSingle(
|
||||||
self.inner_dim, use_additional_conditions=False, dtype=dtype, device=device, operations=operations
|
self.inner_dim, use_additional_conditions=False, dtype=dtype, device=device, operations=self.operations
|
||||||
)
|
)
|
||||||
|
|
||||||
# self.adaln_single.linear = operations.Linear(self.inner_dim, 4 * self.inner_dim, bias=True, dtype=dtype, device=device)
|
|
||||||
|
|
||||||
self.caption_projection = PixArtAlphaTextProjection(
|
self.caption_projection = PixArtAlphaTextProjection(
|
||||||
in_features=caption_channels, hidden_size=self.inner_dim, dtype=dtype, device=device, operations=operations
|
in_features=self.caption_channels,
|
||||||
|
hidden_size=self.inner_dim,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
operations=self.operations,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _init_model_components(self, device, dtype, **kwargs):
|
||||||
|
"""Initialize model-specific components. Must be implemented by subclasses."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _init_transformer_blocks(self, device, dtype, **kwargs):
|
||||||
|
"""Initialize transformer blocks. Must be implemented by subclasses."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _init_output_components(self, device, dtype):
|
||||||
|
"""Initialize output components. Must be implemented by subclasses."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _process_input(self, x, keyframe_idxs, denoise_mask, **kwargs):
|
||||||
|
"""Process input data. Must be implemented by subclasses."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, **kwargs):
|
||||||
|
"""Process transformer blocks. Must be implemented by subclasses."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _process_output(self, x, embedded_timestep, keyframe_idxs, **kwargs):
|
||||||
|
"""Process output data. Must be implemented by subclasses."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _prepare_timestep(self, timestep, batch_size, hidden_dtype, **kwargs):
|
||||||
|
"""Prepare timestep embeddings."""
|
||||||
|
grid_mask = kwargs.get("grid_mask", None)
|
||||||
|
if grid_mask is not None:
|
||||||
|
timestep = timestep[:, grid_mask]
|
||||||
|
|
||||||
|
timestep = timestep * self.timestep_scale_multiplier
|
||||||
|
timestep, embedded_timestep = self.adaln_single(
|
||||||
|
timestep.flatten(),
|
||||||
|
{"resolution": None, "aspect_ratio": None},
|
||||||
|
batch_size=batch_size,
|
||||||
|
hidden_dtype=hidden_dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Second dimension is 1 or number of tokens (if timestep_per_token)
|
||||||
|
timestep = timestep.view(batch_size, -1, timestep.shape[-1])
|
||||||
|
embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.shape[-1])
|
||||||
|
|
||||||
|
return timestep, embedded_timestep
|
||||||
|
|
||||||
|
def _prepare_context(self, context, batch_size, x, attention_mask=None):
|
||||||
|
"""Prepare context for transformer blocks."""
|
||||||
|
if self.caption_projection is not None:
|
||||||
|
context = self.caption_projection(context)
|
||||||
|
context = context.view(batch_size, -1, x.shape[-1])
|
||||||
|
|
||||||
|
return context, attention_mask
|
||||||
|
|
||||||
|
def _precompute_freqs_cis(
|
||||||
|
self,
|
||||||
|
indices_grid,
|
||||||
|
dim,
|
||||||
|
out_dtype,
|
||||||
|
theta=10000.0,
|
||||||
|
max_pos=[20, 2048, 2048],
|
||||||
|
use_middle_indices_grid=False,
|
||||||
|
num_attention_heads=32,
|
||||||
|
):
|
||||||
|
split_mode = self.split_positional_embedding == LTXRopeType.SPLIT
|
||||||
|
indices = self.freq_grid_generator(theta, indices_grid.shape[1], dim, indices_grid.device)
|
||||||
|
freqs = generate_freqs(indices, indices_grid, max_pos, use_middle_indices_grid)
|
||||||
|
|
||||||
|
if split_mode:
|
||||||
|
expected_freqs = dim // 2
|
||||||
|
current_freqs = freqs.shape[-1]
|
||||||
|
pad_size = expected_freqs - current_freqs
|
||||||
|
cos_freq, sin_freq = split_freqs_cis(freqs, pad_size, num_attention_heads)
|
||||||
|
else:
|
||||||
|
# 2 because of cos and sin by 3 for (t, x, y), 1 for temporal only
|
||||||
|
n_elem = 2 * indices_grid.shape[1]
|
||||||
|
cos_freq, sin_freq = interleaved_freqs_cis(freqs, dim % n_elem)
|
||||||
|
return cos_freq.to(out_dtype), sin_freq.to(out_dtype), split_mode
|
||||||
|
|
||||||
|
def _prepare_positional_embeddings(self, pixel_coords, frame_rate, x_dtype):
|
||||||
|
"""Prepare positional embeddings."""
|
||||||
|
fractional_coords = pixel_coords.to(torch.float32)
|
||||||
|
fractional_coords[:, 0] = fractional_coords[:, 0] * (1.0 / frame_rate)
|
||||||
|
pe = self._precompute_freqs_cis(
|
||||||
|
fractional_coords,
|
||||||
|
dim=self.inner_dim,
|
||||||
|
out_dtype=x_dtype,
|
||||||
|
max_pos=self.positional_embedding_max_pos,
|
||||||
|
use_middle_indices_grid=self.use_middle_indices_grid,
|
||||||
|
num_attention_heads=self.num_attention_heads,
|
||||||
|
)
|
||||||
|
return pe
|
||||||
|
|
||||||
|
def _prepare_attention_mask(self, attention_mask, x_dtype):
|
||||||
|
"""Prepare attention mask."""
|
||||||
|
if attention_mask is not None and not torch.is_floating_point(attention_mask):
|
||||||
|
attention_mask = (attention_mask - 1).to(x_dtype).reshape(
|
||||||
|
(attention_mask.shape[0], 1, -1, attention_mask.shape[-1])
|
||||||
|
) * torch.finfo(x_dtype).max
|
||||||
|
return attention_mask
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, denoise_mask=None, **kwargs
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Forward pass for LTX models.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Input tensor
|
||||||
|
timestep: Timestep tensor
|
||||||
|
context: Context tensor (e.g., text embeddings)
|
||||||
|
attention_mask: Attention mask tensor
|
||||||
|
frame_rate: Frame rate for temporal processing
|
||||||
|
transformer_options: Additional options for transformer blocks
|
||||||
|
keyframe_idxs: Keyframe indices for temporal processing
|
||||||
|
**kwargs: Additional keyword arguments
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Processed output tensor
|
||||||
|
"""
|
||||||
|
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||||
|
self._forward,
|
||||||
|
self,
|
||||||
|
comfy.patcher_extension.get_all_wrappers(
|
||||||
|
comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options
|
||||||
|
),
|
||||||
|
).execute(x, timestep, context, attention_mask, frame_rate, transformer_options, keyframe_idxs, denoise_mask=denoise_mask, **kwargs)
|
||||||
|
|
||||||
|
def _forward(
|
||||||
|
self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, denoise_mask=None, **kwargs
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Internal forward pass for LTX models.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Input tensor
|
||||||
|
timestep: Timestep tensor
|
||||||
|
context: Context tensor (e.g., text embeddings)
|
||||||
|
attention_mask: Attention mask tensor
|
||||||
|
frame_rate: Frame rate for temporal processing
|
||||||
|
transformer_options: Additional options for transformer blocks
|
||||||
|
keyframe_idxs: Keyframe indices for temporal processing
|
||||||
|
**kwargs: Additional keyword arguments
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Processed output tensor
|
||||||
|
"""
|
||||||
|
if isinstance(x, list):
|
||||||
|
input_dtype = x[0].dtype
|
||||||
|
batch_size = x[0].shape[0]
|
||||||
|
else:
|
||||||
|
input_dtype = x.dtype
|
||||||
|
batch_size = x.shape[0]
|
||||||
|
# Process input
|
||||||
|
merged_args = {**transformer_options, **kwargs}
|
||||||
|
x, pixel_coords, additional_args = self._process_input(x, keyframe_idxs, denoise_mask, **merged_args)
|
||||||
|
merged_args.update(additional_args)
|
||||||
|
|
||||||
|
# Prepare timestep and context
|
||||||
|
timestep, embedded_timestep = self._prepare_timestep(timestep, batch_size, input_dtype, **merged_args)
|
||||||
|
context, attention_mask = self._prepare_context(context, batch_size, x, attention_mask)
|
||||||
|
|
||||||
|
# Prepare attention mask and positional embeddings
|
||||||
|
attention_mask = self._prepare_attention_mask(attention_mask, input_dtype)
|
||||||
|
pe = self._prepare_positional_embeddings(pixel_coords, frame_rate, input_dtype)
|
||||||
|
|
||||||
|
# Process transformer blocks
|
||||||
|
x = self._process_transformer_blocks(
|
||||||
|
x, context, attention_mask, timestep, pe, transformer_options=transformer_options, **merged_args
|
||||||
|
)
|
||||||
|
|
||||||
|
# Process output
|
||||||
|
x = self._process_output(x, embedded_timestep, keyframe_idxs, **merged_args)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class LTXVModel(LTXBaseModel):
|
||||||
|
"""LTXV model for video generation."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels=128,
|
||||||
|
cross_attention_dim=2048,
|
||||||
|
attention_head_dim=64,
|
||||||
|
num_attention_heads=32,
|
||||||
|
caption_channels=4096,
|
||||||
|
num_layers=28,
|
||||||
|
positional_embedding_theta=10000.0,
|
||||||
|
positional_embedding_max_pos=[20, 2048, 2048],
|
||||||
|
causal_temporal_positioning=False,
|
||||||
|
vae_scale_factors=(8, 32, 32),
|
||||||
|
use_middle_indices_grid=False,
|
||||||
|
timestep_scale_multiplier = 1000.0,
|
||||||
|
dtype=None,
|
||||||
|
device=None,
|
||||||
|
operations=None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
in_channels=in_channels,
|
||||||
|
cross_attention_dim=cross_attention_dim,
|
||||||
|
attention_head_dim=attention_head_dim,
|
||||||
|
num_attention_heads=num_attention_heads,
|
||||||
|
caption_channels=caption_channels,
|
||||||
|
num_layers=num_layers,
|
||||||
|
positional_embedding_theta=positional_embedding_theta,
|
||||||
|
positional_embedding_max_pos=positional_embedding_max_pos,
|
||||||
|
causal_temporal_positioning=causal_temporal_positioning,
|
||||||
|
vae_scale_factors=vae_scale_factors,
|
||||||
|
use_middle_indices_grid=use_middle_indices_grid,
|
||||||
|
timestep_scale_multiplier=timestep_scale_multiplier,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
operations=operations,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _init_model_components(self, device, dtype, **kwargs):
|
||||||
|
"""Initialize LTXV-specific components."""
|
||||||
|
# No additional components needed for LTXV beyond base class
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _init_transformer_blocks(self, device, dtype, **kwargs):
|
||||||
|
"""Initialize transformer blocks for LTXV."""
|
||||||
self.transformer_blocks = nn.ModuleList(
|
self.transformer_blocks = nn.ModuleList(
|
||||||
[
|
[
|
||||||
BasicTransformerBlock(
|
BasicTransformerBlock(
|
||||||
self.inner_dim,
|
self.inner_dim,
|
||||||
num_attention_heads,
|
self.num_attention_heads,
|
||||||
attention_head_dim,
|
self.attention_head_dim,
|
||||||
context_dim=cross_attention_dim,
|
context_dim=self.cross_attention_dim,
|
||||||
# attn_precision=attn_precision,
|
dtype=dtype,
|
||||||
dtype=dtype, device=device, operations=operations
|
device=device,
|
||||||
|
operations=self.operations,
|
||||||
)
|
)
|
||||||
for d in range(num_layers)
|
for _ in range(self.num_layers)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _init_output_components(self, device, dtype):
|
||||||
|
"""Initialize output components for LTXV."""
|
||||||
self.scale_shift_table = nn.Parameter(torch.empty(2, self.inner_dim, dtype=dtype, device=device))
|
self.scale_shift_table = nn.Parameter(torch.empty(2, self.inner_dim, dtype=dtype, device=device))
|
||||||
self.norm_out = operations.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
self.norm_out = self.operations.LayerNorm(
|
||||||
self.proj_out = operations.Linear(self.inner_dim, self.out_channels, dtype=dtype, device=device)
|
self.inner_dim, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device
|
||||||
|
)
|
||||||
self.patchifier = SymmetricPatchifier(1)
|
self.proj_out = self.operations.Linear(self.inner_dim, self.out_channels, dtype=dtype, device=device)
|
||||||
|
self.patchifier = SymmetricPatchifier(1, start_end=True)
|
||||||
def forward(self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, **kwargs):
|
|
||||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
|
||||||
self._forward,
|
|
||||||
self,
|
|
||||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
|
|
||||||
).execute(x, timestep, context, attention_mask, frame_rate, transformer_options, keyframe_idxs, **kwargs)
|
|
||||||
|
|
||||||
def _forward(self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, **kwargs):
|
|
||||||
patches_replace = transformer_options.get("patches_replace", {})
|
|
||||||
|
|
||||||
orig_shape = list(x.shape)
|
|
||||||
|
|
||||||
|
def _process_input(self, x, keyframe_idxs, denoise_mask, **kwargs):
|
||||||
|
"""Process input for LTXV."""
|
||||||
|
additional_args = {"orig_shape": list(x.shape)}
|
||||||
x, latent_coords = self.patchifier.patchify(x)
|
x, latent_coords = self.patchifier.patchify(x)
|
||||||
pixel_coords = latent_to_pixel_coords(
|
pixel_coords = latent_to_pixel_coords(
|
||||||
latent_coords=latent_coords,
|
latent_coords=latent_coords,
|
||||||
@ -423,44 +880,30 @@ class LTXVModel(torch.nn.Module):
|
|||||||
causal_fix=self.causal_temporal_positioning,
|
causal_fix=self.causal_temporal_positioning,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
grid_mask = None
|
||||||
if keyframe_idxs is not None:
|
if keyframe_idxs is not None:
|
||||||
pixel_coords[:, :, -keyframe_idxs.shape[2]:] = keyframe_idxs
|
additional_args.update({ "orig_patchified_shape": list(x.shape)})
|
||||||
|
denoise_mask = self.patchifier.patchify(denoise_mask)[0]
|
||||||
|
grid_mask = ~torch.any(denoise_mask < 0, dim=-1)[0]
|
||||||
|
additional_args.update({"grid_mask": grid_mask})
|
||||||
|
x = x[:, grid_mask, :]
|
||||||
|
pixel_coords = pixel_coords[:, :, grid_mask, ...]
|
||||||
|
|
||||||
fractional_coords = pixel_coords.to(torch.float32)
|
kf_grid_mask = grid_mask[-keyframe_idxs.shape[2]:]
|
||||||
fractional_coords[:, 0] = fractional_coords[:, 0] * (1.0 / frame_rate)
|
keyframe_idxs = keyframe_idxs[..., kf_grid_mask, :]
|
||||||
|
pixel_coords[:, :, -keyframe_idxs.shape[2]:, :] = keyframe_idxs
|
||||||
|
|
||||||
x = self.patchify_proj(x)
|
x = self.patchify_proj(x)
|
||||||
timestep = timestep * 1000.0
|
return x, pixel_coords, additional_args
|
||||||
|
|
||||||
if attention_mask is not None and not torch.is_floating_point(attention_mask):
|
|
||||||
attention_mask = (attention_mask - 1).to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])) * torch.finfo(x.dtype).max
|
|
||||||
|
|
||||||
pe = precompute_freqs_cis(fractional_coords, dim=self.inner_dim, out_dtype=x.dtype)
|
|
||||||
|
|
||||||
batch_size = x.shape[0]
|
|
||||||
timestep, embedded_timestep = self.adaln_single(
|
|
||||||
timestep.flatten(),
|
|
||||||
{"resolution": None, "aspect_ratio": None},
|
|
||||||
batch_size=batch_size,
|
|
||||||
hidden_dtype=x.dtype,
|
|
||||||
)
|
|
||||||
# Second dimension is 1 or number of tokens (if timestep_per_token)
|
|
||||||
timestep = timestep.view(batch_size, -1, timestep.shape[-1])
|
|
||||||
embedded_timestep = embedded_timestep.view(
|
|
||||||
batch_size, -1, embedded_timestep.shape[-1]
|
|
||||||
)
|
|
||||||
|
|
||||||
# 2. Blocks
|
|
||||||
if self.caption_projection is not None:
|
|
||||||
batch_size = x.shape[0]
|
|
||||||
context = self.caption_projection(context)
|
|
||||||
context = context.view(
|
|
||||||
batch_size, -1, x.shape[-1]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, transformer_options={}, **kwargs):
|
||||||
|
"""Process transformer blocks for LTXV."""
|
||||||
|
patches_replace = transformer_options.get("patches_replace", {})
|
||||||
blocks_replace = patches_replace.get("dit", {})
|
blocks_replace = patches_replace.get("dit", {})
|
||||||
|
|
||||||
for i, block in enumerate(self.transformer_blocks):
|
for i, block in enumerate(self.transformer_blocks):
|
||||||
if ("double_block", i) in blocks_replace:
|
if ("double_block", i) in blocks_replace:
|
||||||
|
|
||||||
def block_wrap(args):
|
def block_wrap(args):
|
||||||
out = {}
|
out = {}
|
||||||
out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"], transformer_options=args["transformer_options"])
|
out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"], transformer_options=args["transformer_options"])
|
||||||
@ -478,16 +921,28 @@ class LTXVModel(torch.nn.Module):
|
|||||||
transformer_options=transformer_options,
|
transformer_options=transformer_options,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 3. Output
|
return x
|
||||||
|
|
||||||
|
def _process_output(self, x, embedded_timestep, keyframe_idxs, **kwargs):
|
||||||
|
"""Process output for LTXV."""
|
||||||
|
# Apply scale-shift modulation
|
||||||
scale_shift_values = (
|
scale_shift_values = (
|
||||||
self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + embedded_timestep[:, :, None]
|
self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + embedded_timestep[:, :, None]
|
||||||
)
|
)
|
||||||
shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
|
shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
|
||||||
|
|
||||||
x = self.norm_out(x)
|
x = self.norm_out(x)
|
||||||
# Modulation
|
x = x * (1 + scale) + shift
|
||||||
x = torch.addcmul(x, x, scale).add_(shift)
|
|
||||||
x = self.proj_out(x)
|
x = self.proj_out(x)
|
||||||
|
|
||||||
|
if keyframe_idxs is not None:
|
||||||
|
grid_mask = kwargs["grid_mask"]
|
||||||
|
orig_patchified_shape = kwargs["orig_patchified_shape"]
|
||||||
|
full_x = torch.zeros(orig_patchified_shape, dtype=x.dtype, device=x.device)
|
||||||
|
full_x[:, grid_mask, :] = x
|
||||||
|
x = full_x
|
||||||
|
# Unpatchify to restore original dimensions
|
||||||
|
orig_shape = kwargs["orig_shape"]
|
||||||
x = self.patchifier.unpatchify(
|
x = self.patchifier.unpatchify(
|
||||||
latents=x,
|
latents=x,
|
||||||
output_height=orig_shape[3],
|
output_height=orig_shape[3],
|
||||||
|
|||||||
@ -21,20 +21,23 @@ def latent_to_pixel_coords(
|
|||||||
Returns:
|
Returns:
|
||||||
Tensor: A tensor of pixel coordinates corresponding to the input latent coordinates.
|
Tensor: A tensor of pixel coordinates corresponding to the input latent coordinates.
|
||||||
"""
|
"""
|
||||||
|
shape = [1] * latent_coords.ndim
|
||||||
|
shape[1] = -1
|
||||||
pixel_coords = (
|
pixel_coords = (
|
||||||
latent_coords
|
latent_coords
|
||||||
* torch.tensor(scale_factors, device=latent_coords.device)[None, :, None]
|
* torch.tensor(scale_factors, device=latent_coords.device).view(*shape)
|
||||||
)
|
)
|
||||||
if causal_fix:
|
if causal_fix:
|
||||||
# Fix temporal scale for first frame to 1 due to causality
|
# Fix temporal scale for first frame to 1 due to causality
|
||||||
pixel_coords[:, 0] = (pixel_coords[:, 0] + 1 - scale_factors[0]).clamp(min=0)
|
pixel_coords[:, 0, ...] = (pixel_coords[:, 0, ...] + 1 - scale_factors[0]).clamp(min=0)
|
||||||
return pixel_coords
|
return pixel_coords
|
||||||
|
|
||||||
|
|
||||||
class Patchifier(ABC):
|
class Patchifier(ABC):
|
||||||
def __init__(self, patch_size: int):
|
def __init__(self, patch_size: int, start_end: bool=False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._patch_size = (1, patch_size, patch_size)
|
self._patch_size = (1, patch_size, patch_size)
|
||||||
|
self.start_end = start_end
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def patchify(
|
def patchify(
|
||||||
@ -71,11 +74,23 @@ class Patchifier(ABC):
|
|||||||
torch.arange(0, latent_width, self._patch_size[2], device=device),
|
torch.arange(0, latent_width, self._patch_size[2], device=device),
|
||||||
indexing="ij",
|
indexing="ij",
|
||||||
)
|
)
|
||||||
latent_sample_coords = torch.stack(latent_sample_coords, dim=0)
|
latent_sample_coords_start = torch.stack(latent_sample_coords, dim=0)
|
||||||
latent_coords = latent_sample_coords.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
|
delta = torch.tensor(self._patch_size, device=latent_sample_coords_start.device, dtype=latent_sample_coords_start.dtype)[:, None, None, None]
|
||||||
latent_coords = rearrange(
|
latent_sample_coords_end = latent_sample_coords_start + delta
|
||||||
latent_coords, "b c f h w -> b c (f h w)", b=batch_size
|
|
||||||
|
latent_sample_coords_start = latent_sample_coords_start.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
|
||||||
|
latent_sample_coords_start = rearrange(
|
||||||
|
latent_sample_coords_start, "b c f h w -> b c (f h w)", b=batch_size
|
||||||
)
|
)
|
||||||
|
if self.start_end:
|
||||||
|
latent_sample_coords_end = latent_sample_coords_end.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
|
||||||
|
latent_sample_coords_end = rearrange(
|
||||||
|
latent_sample_coords_end, "b c f h w -> b c (f h w)", b=batch_size
|
||||||
|
)
|
||||||
|
|
||||||
|
latent_coords = torch.stack((latent_sample_coords_start, latent_sample_coords_end), dim=-1)
|
||||||
|
else:
|
||||||
|
latent_coords = latent_sample_coords_start
|
||||||
return latent_coords
|
return latent_coords
|
||||||
|
|
||||||
|
|
||||||
@ -115,3 +130,61 @@ class SymmetricPatchifier(Patchifier):
|
|||||||
q=self._patch_size[2],
|
q=self._patch_size[2],
|
||||||
)
|
)
|
||||||
return latents
|
return latents
|
||||||
|
|
||||||
|
|
||||||
|
class AudioPatchifier(Patchifier):
|
||||||
|
def __init__(self, patch_size: int,
|
||||||
|
sample_rate=16000,
|
||||||
|
hop_length=160,
|
||||||
|
audio_latent_downsample_factor=4,
|
||||||
|
is_causal=True,
|
||||||
|
start_end=False,
|
||||||
|
shift = 0
|
||||||
|
):
|
||||||
|
super().__init__(patch_size, start_end=start_end)
|
||||||
|
self.hop_length = hop_length
|
||||||
|
self.sample_rate = sample_rate
|
||||||
|
self.audio_latent_downsample_factor = audio_latent_downsample_factor
|
||||||
|
self.is_causal = is_causal
|
||||||
|
self.shift = shift
|
||||||
|
|
||||||
|
def copy_with_shift(self, shift):
|
||||||
|
return AudioPatchifier(
|
||||||
|
self.patch_size, self.sample_rate, self.hop_length, self.audio_latent_downsample_factor,
|
||||||
|
self.is_causal, self.start_end, shift
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_audio_latent_time_in_sec(self, start_latent, end_latent: int, dtype: torch.dtype, device=torch.device):
|
||||||
|
audio_latent_frame = torch.arange(start_latent, end_latent, dtype=dtype, device=device)
|
||||||
|
audio_mel_frame = audio_latent_frame * self.audio_latent_downsample_factor
|
||||||
|
if self.is_causal:
|
||||||
|
audio_mel_frame = (audio_mel_frame + 1 - self.audio_latent_downsample_factor).clip(min=0)
|
||||||
|
return audio_mel_frame * self.hop_length / self.sample_rate
|
||||||
|
|
||||||
|
|
||||||
|
def patchify(self, audio_latents: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
# audio_latents: (batch, channels, time, freq)
|
||||||
|
b, _, t, _ = audio_latents.shape
|
||||||
|
audio_latents = rearrange(
|
||||||
|
audio_latents,
|
||||||
|
"b c t f -> b t (c f)",
|
||||||
|
)
|
||||||
|
|
||||||
|
audio_latents_start_timings = self._get_audio_latent_time_in_sec(self.shift, t + self.shift, torch.float32, audio_latents.device)
|
||||||
|
audio_latents_start_timings = audio_latents_start_timings.unsqueeze(0).expand(b, -1).unsqueeze(1)
|
||||||
|
|
||||||
|
if self.start_end:
|
||||||
|
audio_latents_end_timings = self._get_audio_latent_time_in_sec(self.shift + 1, t + self.shift + 1, torch.float32, audio_latents.device)
|
||||||
|
audio_latents_end_timings = audio_latents_end_timings.unsqueeze(0).expand(b, -1).unsqueeze(1)
|
||||||
|
|
||||||
|
audio_latents_timings = torch.stack([audio_latents_start_timings, audio_latents_end_timings], dim=-1)
|
||||||
|
else:
|
||||||
|
audio_latents_timings = audio_latents_start_timings
|
||||||
|
return audio_latents, audio_latents_timings
|
||||||
|
|
||||||
|
def unpatchify(self, audio_latents: torch.Tensor, channels: int, freq: int) -> torch.Tensor:
|
||||||
|
# audio_latents: (batch, time, freq * channels)
|
||||||
|
audio_latents = rearrange(
|
||||||
|
audio_latents, "b t (c f) -> b c t f", c=channels, f=freq
|
||||||
|
)
|
||||||
|
return audio_latents
|
||||||
|
|||||||
286
comfy/ldm/lightricks/vae/audio_vae.py
Normal file
286
comfy/ldm/lightricks/vae/audio_vae.py
Normal file
@ -0,0 +1,286 @@
|
|||||||
|
import json
|
||||||
|
from dataclasses import dataclass
|
||||||
|
import math
|
||||||
|
import torch
|
||||||
|
import torchaudio
|
||||||
|
|
||||||
|
import comfy.model_management
|
||||||
|
import comfy.model_patcher
|
||||||
|
import comfy.utils as utils
|
||||||
|
from comfy.ldm.mmaudio.vae.distributions import DiagonalGaussianDistribution
|
||||||
|
from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier
|
||||||
|
from comfy.ldm.lightricks.vae.causal_audio_autoencoder import (
|
||||||
|
CausalityAxis,
|
||||||
|
CausalAudioAutoencoder,
|
||||||
|
)
|
||||||
|
from comfy.ldm.lightricks.vocoders.vocoder import Vocoder
|
||||||
|
|
||||||
|
LATENT_DOWNSAMPLE_FACTOR = 4
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class AudioVAEComponentConfig:
|
||||||
|
"""Container for model component configuration extracted from metadata."""
|
||||||
|
|
||||||
|
autoencoder: dict
|
||||||
|
vocoder: dict
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_metadata(cls, metadata: dict) -> "AudioVAEComponentConfig":
|
||||||
|
assert metadata is not None and "config" in metadata, "Metadata is required for audio VAE"
|
||||||
|
|
||||||
|
raw_config = metadata["config"]
|
||||||
|
if isinstance(raw_config, str):
|
||||||
|
parsed_config = json.loads(raw_config)
|
||||||
|
else:
|
||||||
|
parsed_config = raw_config
|
||||||
|
|
||||||
|
audio_config = parsed_config.get("audio_vae")
|
||||||
|
vocoder_config = parsed_config.get("vocoder")
|
||||||
|
|
||||||
|
assert audio_config is not None, "Audio VAE config is required for audio VAE"
|
||||||
|
assert vocoder_config is not None, "Vocoder config is required for audio VAE"
|
||||||
|
|
||||||
|
return cls(autoencoder=audio_config, vocoder=vocoder_config)
|
||||||
|
|
||||||
|
|
||||||
|
class ModelDeviceManager:
|
||||||
|
"""Manages device placement and GPU residency for the composed model."""
|
||||||
|
|
||||||
|
def __init__(self, module: torch.nn.Module):
|
||||||
|
load_device = comfy.model_management.get_torch_device()
|
||||||
|
offload_device = comfy.model_management.vae_offload_device()
|
||||||
|
self.patcher = comfy.model_patcher.ModelPatcher(module, load_device, offload_device)
|
||||||
|
|
||||||
|
def ensure_model_loaded(self) -> None:
|
||||||
|
comfy.model_management.free_memory(
|
||||||
|
self.patcher.model_size(),
|
||||||
|
self.patcher.load_device,
|
||||||
|
)
|
||||||
|
comfy.model_management.load_model_gpu(self.patcher)
|
||||||
|
|
||||||
|
def move_to_load_device(self, tensor: torch.Tensor) -> torch.Tensor:
|
||||||
|
return tensor.to(self.patcher.load_device)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def load_device(self):
|
||||||
|
return self.patcher.load_device
|
||||||
|
|
||||||
|
|
||||||
|
class AudioLatentNormalizer:
|
||||||
|
"""Applies per-channel statistics in patch space and restores original layout."""
|
||||||
|
|
||||||
|
def __init__(self, patchfier: AudioPatchifier, statistics_processor: torch.nn.Module):
|
||||||
|
self.patchifier = patchfier
|
||||||
|
self.statistics = statistics_processor
|
||||||
|
|
||||||
|
def normalize(self, latents: torch.Tensor) -> torch.Tensor:
|
||||||
|
channels = latents.shape[1]
|
||||||
|
freq = latents.shape[3]
|
||||||
|
patched, _ = self.patchifier.patchify(latents)
|
||||||
|
normalized = self.statistics.normalize(patched)
|
||||||
|
return self.patchifier.unpatchify(normalized, channels=channels, freq=freq)
|
||||||
|
|
||||||
|
def denormalize(self, latents: torch.Tensor) -> torch.Tensor:
|
||||||
|
channels = latents.shape[1]
|
||||||
|
freq = latents.shape[3]
|
||||||
|
patched, _ = self.patchifier.patchify(latents)
|
||||||
|
denormalized = self.statistics.un_normalize(patched)
|
||||||
|
return self.patchifier.unpatchify(denormalized, channels=channels, freq=freq)
|
||||||
|
|
||||||
|
|
||||||
|
class AudioPreprocessor:
|
||||||
|
"""Prepares raw waveforms for the autoencoder by matching training conditions."""
|
||||||
|
|
||||||
|
def __init__(self, target_sample_rate: int, mel_bins: int, mel_hop_length: int, n_fft: int):
|
||||||
|
self.target_sample_rate = target_sample_rate
|
||||||
|
self.mel_bins = mel_bins
|
||||||
|
self.mel_hop_length = mel_hop_length
|
||||||
|
self.n_fft = n_fft
|
||||||
|
|
||||||
|
def resample(self, waveform: torch.Tensor, source_rate: int) -> torch.Tensor:
|
||||||
|
if source_rate == self.target_sample_rate:
|
||||||
|
return waveform
|
||||||
|
return torchaudio.functional.resample(waveform, source_rate, self.target_sample_rate)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def normalize_amplitude(
|
||||||
|
waveform: torch.Tensor, max_amplitude: float = 0.5, eps: float = 1e-5
|
||||||
|
) -> torch.Tensor:
|
||||||
|
waveform = waveform - waveform.mean(dim=2, keepdim=True)
|
||||||
|
peak = torch.max(torch.abs(waveform)) + eps
|
||||||
|
scale = peak.clamp(max=max_amplitude) / peak
|
||||||
|
return waveform * scale
|
||||||
|
|
||||||
|
def waveform_to_mel(
|
||||||
|
self, waveform: torch.Tensor, waveform_sample_rate: int, device
|
||||||
|
) -> torch.Tensor:
|
||||||
|
waveform = self.resample(waveform, waveform_sample_rate)
|
||||||
|
waveform = self.normalize_amplitude(waveform)
|
||||||
|
|
||||||
|
mel_transform = torchaudio.transforms.MelSpectrogram(
|
||||||
|
sample_rate=self.target_sample_rate,
|
||||||
|
n_fft=self.n_fft,
|
||||||
|
win_length=self.n_fft,
|
||||||
|
hop_length=self.mel_hop_length,
|
||||||
|
f_min=0.0,
|
||||||
|
f_max=self.target_sample_rate / 2.0,
|
||||||
|
n_mels=self.mel_bins,
|
||||||
|
window_fn=torch.hann_window,
|
||||||
|
center=True,
|
||||||
|
pad_mode="reflect",
|
||||||
|
power=1.0,
|
||||||
|
mel_scale="slaney",
|
||||||
|
norm="slaney",
|
||||||
|
).to(device)
|
||||||
|
|
||||||
|
mel = mel_transform(waveform)
|
||||||
|
mel = torch.log(torch.clamp(mel, min=1e-5))
|
||||||
|
return mel.permute(0, 1, 3, 2).contiguous()
|
||||||
|
|
||||||
|
|
||||||
|
class AudioVAE(torch.nn.Module):
|
||||||
|
"""High-level Audio VAE wrapper exposing encode and decode entry points."""
|
||||||
|
|
||||||
|
def __init__(self, state_dict: dict, metadata: dict):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
component_config = AudioVAEComponentConfig.from_metadata(metadata)
|
||||||
|
|
||||||
|
vae_sd = utils.state_dict_prefix_replace(state_dict, {"audio_vae.": ""}, filter_keys=True)
|
||||||
|
vocoder_sd = utils.state_dict_prefix_replace(state_dict, {"vocoder.": ""}, filter_keys=True)
|
||||||
|
|
||||||
|
self.autoencoder = CausalAudioAutoencoder(config=component_config.autoencoder)
|
||||||
|
self.vocoder = Vocoder(config=component_config.vocoder)
|
||||||
|
|
||||||
|
self.autoencoder.load_state_dict(vae_sd, strict=False)
|
||||||
|
self.vocoder.load_state_dict(vocoder_sd, strict=False)
|
||||||
|
|
||||||
|
autoencoder_config = self.autoencoder.get_config()
|
||||||
|
self.normalizer = AudioLatentNormalizer(
|
||||||
|
AudioPatchifier(
|
||||||
|
patch_size=1,
|
||||||
|
audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR,
|
||||||
|
sample_rate=autoencoder_config["sampling_rate"],
|
||||||
|
hop_length=autoencoder_config["mel_hop_length"],
|
||||||
|
is_causal=autoencoder_config["is_causal"],
|
||||||
|
),
|
||||||
|
self.autoencoder.per_channel_statistics,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.preprocessor = AudioPreprocessor(
|
||||||
|
target_sample_rate=autoencoder_config["sampling_rate"],
|
||||||
|
mel_bins=autoencoder_config["mel_bins"],
|
||||||
|
mel_hop_length=autoencoder_config["mel_hop_length"],
|
||||||
|
n_fft=autoencoder_config["n_fft"],
|
||||||
|
)
|
||||||
|
|
||||||
|
self.device_manager = ModelDeviceManager(self)
|
||||||
|
|
||||||
|
def encode(self, audio: dict) -> torch.Tensor:
|
||||||
|
"""Encode a waveform dictionary into normalized latent tensors."""
|
||||||
|
|
||||||
|
waveform = audio["waveform"]
|
||||||
|
waveform_sample_rate = audio["sample_rate"]
|
||||||
|
input_device = waveform.device
|
||||||
|
# Ensure that Audio VAE is loaded on the correct device.
|
||||||
|
self.device_manager.ensure_model_loaded()
|
||||||
|
|
||||||
|
waveform = self.device_manager.move_to_load_device(waveform)
|
||||||
|
expected_channels = self.autoencoder.encoder.in_channels
|
||||||
|
if waveform.shape[1] != expected_channels:
|
||||||
|
raise ValueError(
|
||||||
|
f"Input audio must have {expected_channels} channels, got {waveform.shape[1]}"
|
||||||
|
)
|
||||||
|
|
||||||
|
mel_spec = self.preprocessor.waveform_to_mel(
|
||||||
|
waveform, waveform_sample_rate, device=self.device_manager.load_device
|
||||||
|
)
|
||||||
|
|
||||||
|
latents = self.autoencoder.encode(mel_spec)
|
||||||
|
posterior = DiagonalGaussianDistribution(latents)
|
||||||
|
latent_mode = posterior.mode()
|
||||||
|
|
||||||
|
normalized = self.normalizer.normalize(latent_mode)
|
||||||
|
return normalized.to(input_device)
|
||||||
|
|
||||||
|
def decode(self, latents: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""Decode normalized latent tensors into an audio waveform."""
|
||||||
|
original_shape = latents.shape
|
||||||
|
|
||||||
|
# Ensure that Audio VAE is loaded on the correct device.
|
||||||
|
self.device_manager.ensure_model_loaded()
|
||||||
|
|
||||||
|
latents = self.device_manager.move_to_load_device(latents)
|
||||||
|
latents = self.normalizer.denormalize(latents)
|
||||||
|
|
||||||
|
target_shape = self.target_shape_from_latents(original_shape)
|
||||||
|
mel_spec = self.autoencoder.decode(latents, target_shape=target_shape)
|
||||||
|
|
||||||
|
waveform = self.run_vocoder(mel_spec)
|
||||||
|
return self.device_manager.move_to_load_device(waveform)
|
||||||
|
|
||||||
|
def target_shape_from_latents(self, latents_shape):
|
||||||
|
batch, _, time, _ = latents_shape
|
||||||
|
target_length = time * LATENT_DOWNSAMPLE_FACTOR
|
||||||
|
if self.autoencoder.causality_axis != CausalityAxis.NONE:
|
||||||
|
target_length -= LATENT_DOWNSAMPLE_FACTOR - 1
|
||||||
|
return (
|
||||||
|
batch,
|
||||||
|
self.autoencoder.decoder.out_ch,
|
||||||
|
target_length,
|
||||||
|
self.autoencoder.mel_bins,
|
||||||
|
)
|
||||||
|
|
||||||
|
def num_of_latents_from_frames(self, frames_number: int, frame_rate: int) -> int:
|
||||||
|
return math.ceil((float(frames_number) / frame_rate) * self.latents_per_second)
|
||||||
|
|
||||||
|
def run_vocoder(self, mel_spec: torch.Tensor) -> torch.Tensor:
|
||||||
|
audio_channels = self.autoencoder.decoder.out_ch
|
||||||
|
vocoder_input = mel_spec.transpose(2, 3)
|
||||||
|
|
||||||
|
if audio_channels == 1:
|
||||||
|
vocoder_input = vocoder_input.squeeze(1)
|
||||||
|
elif audio_channels != 2:
|
||||||
|
raise ValueError(f"Unsupported audio_channels: {audio_channels}")
|
||||||
|
|
||||||
|
return self.vocoder(vocoder_input)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def sample_rate(self) -> int:
|
||||||
|
return int(self.autoencoder.sampling_rate)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def mel_hop_length(self) -> int:
|
||||||
|
return int(self.autoencoder.mel_hop_length)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def mel_bins(self) -> int:
|
||||||
|
return int(self.autoencoder.mel_bins)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def latent_channels(self) -> int:
|
||||||
|
return int(self.autoencoder.decoder.z_channels)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def latent_frequency_bins(self) -> int:
|
||||||
|
return int(self.mel_bins // LATENT_DOWNSAMPLE_FACTOR)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def latents_per_second(self) -> float:
|
||||||
|
return self.sample_rate / self.mel_hop_length / LATENT_DOWNSAMPLE_FACTOR
|
||||||
|
|
||||||
|
@property
|
||||||
|
def output_sample_rate(self) -> int:
|
||||||
|
output_rate = getattr(self.vocoder, "output_sample_rate", None)
|
||||||
|
if output_rate is not None:
|
||||||
|
return int(output_rate)
|
||||||
|
upsample_factor = getattr(self.vocoder, "upsample_factor", None)
|
||||||
|
if upsample_factor is None:
|
||||||
|
raise AttributeError(
|
||||||
|
"Vocoder is missing upsample_factor; cannot infer output sample rate"
|
||||||
|
)
|
||||||
|
return int(self.sample_rate * upsample_factor / self.mel_hop_length)
|
||||||
|
|
||||||
|
def memory_required(self, input_shape):
|
||||||
|
return self.device_manager.patcher.model_size()
|
||||||
909
comfy/ldm/lightricks/vae/causal_audio_autoencoder.py
Normal file
909
comfy/ldm/lightricks/vae/causal_audio_autoencoder.py
Normal file
@ -0,0 +1,909 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from torch.nn import functional as F
|
||||||
|
from typing import Optional
|
||||||
|
from enum import Enum
|
||||||
|
from .pixel_norm import PixelNorm
|
||||||
|
import comfy.ops
|
||||||
|
import logging
|
||||||
|
|
||||||
|
ops = comfy.ops.disable_weight_init
|
||||||
|
|
||||||
|
|
||||||
|
class StringConvertibleEnum(Enum):
|
||||||
|
"""
|
||||||
|
Base enum class that provides string-to-enum conversion functionality.
|
||||||
|
|
||||||
|
This mixin adds a str_to_enum() class method that handles conversion from
|
||||||
|
strings, None, or existing enum instances with case-insensitive matching.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def str_to_enum(cls, value):
|
||||||
|
"""
|
||||||
|
Convert a string, enum instance, or None to the appropriate enum member.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
value: Can be an enum instance of this class, a string, or None
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Enum member of this class
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the value cannot be converted to a valid enum member
|
||||||
|
"""
|
||||||
|
# Already an enum instance of this class
|
||||||
|
if isinstance(value, cls):
|
||||||
|
return value
|
||||||
|
|
||||||
|
# None maps to NONE member if it exists
|
||||||
|
if value is None:
|
||||||
|
if hasattr(cls, "NONE"):
|
||||||
|
return cls.NONE
|
||||||
|
raise ValueError(f"{cls.__name__} does not have a NONE member to map None to")
|
||||||
|
|
||||||
|
# String conversion (case-insensitive)
|
||||||
|
if isinstance(value, str):
|
||||||
|
value_lower = value.lower()
|
||||||
|
|
||||||
|
# Try to match against enum values
|
||||||
|
for member in cls:
|
||||||
|
# Handle members with None values
|
||||||
|
if member.value is None:
|
||||||
|
if value_lower == "none":
|
||||||
|
return member
|
||||||
|
# Handle members with string values
|
||||||
|
elif isinstance(member.value, str) and member.value.lower() == value_lower:
|
||||||
|
return member
|
||||||
|
|
||||||
|
# Build helpful error message with valid values
|
||||||
|
valid_values = []
|
||||||
|
for member in cls:
|
||||||
|
if member.value is None:
|
||||||
|
valid_values.append("none")
|
||||||
|
elif isinstance(member.value, str):
|
||||||
|
valid_values.append(member.value)
|
||||||
|
|
||||||
|
raise ValueError(f"Invalid {cls.__name__} string: '{value}'. " f"Valid values are: {valid_values}")
|
||||||
|
|
||||||
|
raise ValueError(
|
||||||
|
f"Cannot convert type {type(value).__name__} to {cls.__name__} enum. "
|
||||||
|
f"Expected string, None, or {cls.__name__} instance."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class AttentionType(StringConvertibleEnum):
|
||||||
|
"""Enum for specifying the attention mechanism type."""
|
||||||
|
|
||||||
|
VANILLA = "vanilla"
|
||||||
|
LINEAR = "linear"
|
||||||
|
NONE = "none"
|
||||||
|
|
||||||
|
|
||||||
|
class CausalityAxis(StringConvertibleEnum):
|
||||||
|
"""Enum for specifying the causality axis in causal convolutions."""
|
||||||
|
|
||||||
|
NONE = None
|
||||||
|
WIDTH = "width"
|
||||||
|
HEIGHT = "height"
|
||||||
|
WIDTH_COMPATIBILITY = "width-compatibility"
|
||||||
|
|
||||||
|
|
||||||
|
def Normalize(in_channels, *, num_groups=32, normtype="group"):
|
||||||
|
if normtype == "group":
|
||||||
|
return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
||||||
|
elif normtype == "pixel":
|
||||||
|
return PixelNorm(dim=1, eps=1e-6)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid normalization type: {normtype}")
|
||||||
|
|
||||||
|
|
||||||
|
class CausalConv2d(nn.Module):
|
||||||
|
"""
|
||||||
|
A causal 2D convolution.
|
||||||
|
|
||||||
|
This layer ensures that the output at time `t` only depends on inputs
|
||||||
|
at time `t` and earlier. It achieves this by applying asymmetric padding
|
||||||
|
to the time dimension (width) before the convolution.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size,
|
||||||
|
stride=1,
|
||||||
|
dilation=1,
|
||||||
|
groups=1,
|
||||||
|
bias=True,
|
||||||
|
causality_axis: CausalityAxis = CausalityAxis.HEIGHT,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.causality_axis = causality_axis
|
||||||
|
|
||||||
|
# Ensure kernel_size and dilation are tuples
|
||||||
|
kernel_size = nn.modules.utils._pair(kernel_size)
|
||||||
|
dilation = nn.modules.utils._pair(dilation)
|
||||||
|
|
||||||
|
# Calculate padding dimensions
|
||||||
|
pad_h = (kernel_size[0] - 1) * dilation[0]
|
||||||
|
pad_w = (kernel_size[1] - 1) * dilation[1]
|
||||||
|
|
||||||
|
# The padding tuple for F.pad is (pad_left, pad_right, pad_top, pad_bottom)
|
||||||
|
match self.causality_axis:
|
||||||
|
case CausalityAxis.NONE:
|
||||||
|
self.padding = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)
|
||||||
|
case CausalityAxis.WIDTH | CausalityAxis.WIDTH_COMPATIBILITY:
|
||||||
|
self.padding = (pad_w, 0, pad_h // 2, pad_h - pad_h // 2)
|
||||||
|
case CausalityAxis.HEIGHT:
|
||||||
|
self.padding = (pad_w // 2, pad_w - pad_w // 2, pad_h, 0)
|
||||||
|
case _:
|
||||||
|
raise ValueError(f"Invalid causality_axis: {causality_axis}")
|
||||||
|
|
||||||
|
# The internal convolution layer uses no padding, as we handle it manually
|
||||||
|
self.conv = ops.Conv2d(
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size,
|
||||||
|
stride=stride,
|
||||||
|
padding=0,
|
||||||
|
dilation=dilation,
|
||||||
|
groups=groups,
|
||||||
|
bias=bias,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
# Apply causal padding before convolution
|
||||||
|
x = F.pad(x, self.padding)
|
||||||
|
return self.conv(x)
|
||||||
|
|
||||||
|
|
||||||
|
def make_conv2d(
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size,
|
||||||
|
stride=1,
|
||||||
|
padding=None,
|
||||||
|
dilation=1,
|
||||||
|
groups=1,
|
||||||
|
bias=True,
|
||||||
|
causality_axis: Optional[CausalityAxis] = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Create a 2D convolution layer that can be either causal or non-causal.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
in_channels: Number of input channels
|
||||||
|
out_channels: Number of output channels
|
||||||
|
kernel_size: Size of the convolution kernel
|
||||||
|
stride: Convolution stride
|
||||||
|
padding: Padding (if None, will be calculated based on causal flag)
|
||||||
|
dilation: Dilation rate
|
||||||
|
groups: Number of groups for grouped convolution
|
||||||
|
bias: Whether to use bias
|
||||||
|
causality_axis: Dimension along which to apply causality.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Either a regular Conv2d or CausalConv2d layer
|
||||||
|
"""
|
||||||
|
if causality_axis is not None:
|
||||||
|
# For causal convolution, padding is handled internally by CausalConv2d
|
||||||
|
return CausalConv2d(in_channels, out_channels, kernel_size, stride, dilation, groups, bias, causality_axis)
|
||||||
|
else:
|
||||||
|
# For non-causal convolution, use symmetric padding if not specified
|
||||||
|
if padding is None:
|
||||||
|
if isinstance(kernel_size, int):
|
||||||
|
padding = kernel_size // 2
|
||||||
|
else:
|
||||||
|
padding = tuple(k // 2 for k in kernel_size)
|
||||||
|
return ops.Conv2d(
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size,
|
||||||
|
stride,
|
||||||
|
padding,
|
||||||
|
dilation,
|
||||||
|
groups,
|
||||||
|
bias,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Upsample(nn.Module):
|
||||||
|
def __init__(self, in_channels, with_conv, causality_axis: CausalityAxis = CausalityAxis.HEIGHT):
|
||||||
|
super().__init__()
|
||||||
|
self.with_conv = with_conv
|
||||||
|
self.causality_axis = causality_axis
|
||||||
|
if self.with_conv:
|
||||||
|
self.conv = make_conv2d(in_channels, in_channels, kernel_size=3, stride=1, causality_axis=causality_axis)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
||||||
|
if self.with_conv:
|
||||||
|
x = self.conv(x)
|
||||||
|
# Drop FIRST element in the causal axis to undo encoder's padding, while keeping the length 1 + 2 * n.
|
||||||
|
# For example, if the input is [0, 1, 2], after interpolation, the output is [0, 0, 1, 1, 2, 2].
|
||||||
|
# The causal convolution will pad the first element as [-, -, 0, 0, 1, 1, 2, 2],
|
||||||
|
# So the output elements rely on the following windows:
|
||||||
|
# 0: [-,-,0]
|
||||||
|
# 1: [-,0,0]
|
||||||
|
# 2: [0,0,1]
|
||||||
|
# 3: [0,1,1]
|
||||||
|
# 4: [1,1,2]
|
||||||
|
# 5: [1,2,2]
|
||||||
|
# Notice that the first and second elements in the output rely only on the first element in the input,
|
||||||
|
# while all other elements rely on two elements in the input.
|
||||||
|
# So we can drop the first element to undo the padding (rather than the last element).
|
||||||
|
# This is a no-op for non-causal convolutions.
|
||||||
|
match self.causality_axis:
|
||||||
|
case CausalityAxis.NONE:
|
||||||
|
pass # x remains unchanged
|
||||||
|
case CausalityAxis.HEIGHT:
|
||||||
|
x = x[:, :, 1:, :]
|
||||||
|
case CausalityAxis.WIDTH:
|
||||||
|
x = x[:, :, :, 1:]
|
||||||
|
case CausalityAxis.WIDTH_COMPATIBILITY:
|
||||||
|
pass # x remains unchanged
|
||||||
|
case _:
|
||||||
|
raise ValueError(f"Invalid causality_axis: {self.causality_axis}")
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class Downsample(nn.Module):
|
||||||
|
"""
|
||||||
|
A downsampling layer that can use either a strided convolution
|
||||||
|
or average pooling. Supports standard and causal padding for the
|
||||||
|
convolutional mode.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, in_channels, with_conv, causality_axis: CausalityAxis = CausalityAxis.WIDTH):
|
||||||
|
super().__init__()
|
||||||
|
self.with_conv = with_conv
|
||||||
|
self.causality_axis = causality_axis
|
||||||
|
|
||||||
|
if self.causality_axis != CausalityAxis.NONE and not self.with_conv:
|
||||||
|
raise ValueError("causality is only supported when `with_conv=True`.")
|
||||||
|
|
||||||
|
if self.with_conv:
|
||||||
|
# Do time downsampling here
|
||||||
|
# no asymmetric padding in torch conv, must do it ourselves
|
||||||
|
self.conv = ops.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
if self.with_conv:
|
||||||
|
# (pad_left, pad_right, pad_top, pad_bottom)
|
||||||
|
match self.causality_axis:
|
||||||
|
case CausalityAxis.NONE:
|
||||||
|
pad = (0, 1, 0, 1)
|
||||||
|
case CausalityAxis.WIDTH:
|
||||||
|
pad = (2, 0, 0, 1)
|
||||||
|
case CausalityAxis.HEIGHT:
|
||||||
|
pad = (0, 1, 2, 0)
|
||||||
|
case CausalityAxis.WIDTH_COMPATIBILITY:
|
||||||
|
pad = (1, 0, 0, 1)
|
||||||
|
case _:
|
||||||
|
raise ValueError(f"Invalid causality_axis: {self.causality_axis}")
|
||||||
|
|
||||||
|
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
|
||||||
|
x = self.conv(x)
|
||||||
|
else:
|
||||||
|
# This branch is only taken if with_conv=False, which implies causality_axis is NONE.
|
||||||
|
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class ResnetBlock(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
in_channels,
|
||||||
|
out_channels=None,
|
||||||
|
conv_shortcut=False,
|
||||||
|
dropout,
|
||||||
|
temb_channels=512,
|
||||||
|
norm_type="group",
|
||||||
|
causality_axis: CausalityAxis = CausalityAxis.HEIGHT,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.causality_axis = causality_axis
|
||||||
|
|
||||||
|
if self.causality_axis != CausalityAxis.NONE and norm_type == "group":
|
||||||
|
raise ValueError("Causal ResnetBlock with GroupNorm is not supported.")
|
||||||
|
self.in_channels = in_channels
|
||||||
|
out_channels = in_channels if out_channels is None else out_channels
|
||||||
|
self.out_channels = out_channels
|
||||||
|
self.use_conv_shortcut = conv_shortcut
|
||||||
|
|
||||||
|
self.norm1 = Normalize(in_channels, normtype=norm_type)
|
||||||
|
self.non_linearity = nn.SiLU()
|
||||||
|
self.conv1 = make_conv2d(in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis)
|
||||||
|
if temb_channels > 0:
|
||||||
|
self.temb_proj = ops.Linear(temb_channels, out_channels)
|
||||||
|
self.norm2 = Normalize(out_channels, normtype=norm_type)
|
||||||
|
self.dropout = torch.nn.Dropout(dropout)
|
||||||
|
self.conv2 = make_conv2d(out_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis)
|
||||||
|
if self.in_channels != self.out_channels:
|
||||||
|
if self.use_conv_shortcut:
|
||||||
|
self.conv_shortcut = make_conv2d(
|
||||||
|
in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.nin_shortcut = make_conv2d(
|
||||||
|
in_channels, out_channels, kernel_size=1, stride=1, causality_axis=causality_axis
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x, temb):
|
||||||
|
h = x
|
||||||
|
h = self.norm1(h)
|
||||||
|
h = self.non_linearity(h)
|
||||||
|
h = self.conv1(h)
|
||||||
|
|
||||||
|
if temb is not None:
|
||||||
|
h = h + self.temb_proj(self.non_linearity(temb))[:, :, None, None]
|
||||||
|
|
||||||
|
h = self.norm2(h)
|
||||||
|
h = self.non_linearity(h)
|
||||||
|
h = self.dropout(h)
|
||||||
|
h = self.conv2(h)
|
||||||
|
|
||||||
|
if self.in_channels != self.out_channels:
|
||||||
|
if self.use_conv_shortcut:
|
||||||
|
x = self.conv_shortcut(x)
|
||||||
|
else:
|
||||||
|
x = self.nin_shortcut(x)
|
||||||
|
|
||||||
|
return x + h
|
||||||
|
|
||||||
|
|
||||||
|
class AttnBlock(nn.Module):
|
||||||
|
def __init__(self, in_channels, norm_type="group"):
|
||||||
|
super().__init__()
|
||||||
|
self.in_channels = in_channels
|
||||||
|
|
||||||
|
self.norm = Normalize(in_channels, normtype=norm_type)
|
||||||
|
self.q = ops.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||||
|
self.k = ops.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||||
|
self.v = ops.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||||
|
self.proj_out = ops.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
h_ = x
|
||||||
|
h_ = self.norm(h_)
|
||||||
|
q = self.q(h_)
|
||||||
|
k = self.k(h_)
|
||||||
|
v = self.v(h_)
|
||||||
|
|
||||||
|
# compute attention
|
||||||
|
b, c, h, w = q.shape
|
||||||
|
q = q.reshape(b, c, h * w).contiguous()
|
||||||
|
q = q.permute(0, 2, 1).contiguous() # b,hw,c
|
||||||
|
k = k.reshape(b, c, h * w).contiguous() # b,c,hw
|
||||||
|
w_ = torch.bmm(q, k).contiguous() # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
|
||||||
|
w_ = w_ * (int(c) ** (-0.5))
|
||||||
|
w_ = torch.nn.functional.softmax(w_, dim=2)
|
||||||
|
|
||||||
|
# attend to values
|
||||||
|
v = v.reshape(b, c, h * w).contiguous()
|
||||||
|
w_ = w_.permute(0, 2, 1).contiguous() # b,hw,hw (first hw of k, second of q)
|
||||||
|
h_ = torch.bmm(v, w_).contiguous() # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
|
||||||
|
h_ = h_.reshape(b, c, h, w).contiguous()
|
||||||
|
|
||||||
|
h_ = self.proj_out(h_)
|
||||||
|
|
||||||
|
return x + h_
|
||||||
|
|
||||||
|
|
||||||
|
def make_attn(in_channels, attn_type="vanilla", norm_type="group"):
|
||||||
|
# Convert string to enum if needed
|
||||||
|
attn_type = AttentionType.str_to_enum(attn_type)
|
||||||
|
|
||||||
|
if attn_type != AttentionType.NONE:
|
||||||
|
logging.info(f"making attention of type '{attn_type.value}' with {in_channels} in_channels")
|
||||||
|
else:
|
||||||
|
logging.info(f"making identity attention with {in_channels} in_channels")
|
||||||
|
|
||||||
|
match attn_type:
|
||||||
|
case AttentionType.VANILLA:
|
||||||
|
return AttnBlock(in_channels, norm_type=norm_type)
|
||||||
|
case AttentionType.NONE:
|
||||||
|
return nn.Identity(in_channels)
|
||||||
|
case AttentionType.LINEAR:
|
||||||
|
raise NotImplementedError(f"Attention type {attn_type.value} is not supported yet.")
|
||||||
|
case _:
|
||||||
|
raise ValueError(f"Unknown attention type: {attn_type}")
|
||||||
|
|
||||||
|
|
||||||
|
class Encoder(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
ch,
|
||||||
|
out_ch,
|
||||||
|
ch_mult=(1, 2, 4, 8),
|
||||||
|
num_res_blocks,
|
||||||
|
attn_resolutions,
|
||||||
|
dropout=0.0,
|
||||||
|
resamp_with_conv=True,
|
||||||
|
in_channels,
|
||||||
|
resolution,
|
||||||
|
z_channels,
|
||||||
|
double_z=True,
|
||||||
|
attn_type="vanilla",
|
||||||
|
mid_block_add_attention=True,
|
||||||
|
norm_type="group",
|
||||||
|
causality_axis=CausalityAxis.WIDTH.value,
|
||||||
|
**ignore_kwargs,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.ch = ch
|
||||||
|
self.temb_ch = 0
|
||||||
|
self.num_resolutions = len(ch_mult)
|
||||||
|
self.num_res_blocks = num_res_blocks
|
||||||
|
self.resolution = resolution
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.z_channels = z_channels
|
||||||
|
self.double_z = double_z
|
||||||
|
self.norm_type = norm_type
|
||||||
|
# Convert string to enum if needed (for config loading)
|
||||||
|
causality_axis = CausalityAxis.str_to_enum(causality_axis)
|
||||||
|
self.attn_type = AttentionType.str_to_enum(attn_type)
|
||||||
|
|
||||||
|
# downsampling
|
||||||
|
self.conv_in = make_conv2d(
|
||||||
|
in_channels,
|
||||||
|
self.ch,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=1,
|
||||||
|
causality_axis=causality_axis,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.non_linearity = nn.SiLU()
|
||||||
|
|
||||||
|
curr_res = resolution
|
||||||
|
in_ch_mult = (1,) + tuple(ch_mult)
|
||||||
|
self.in_ch_mult = in_ch_mult
|
||||||
|
self.down = nn.ModuleList()
|
||||||
|
|
||||||
|
for i_level in range(self.num_resolutions):
|
||||||
|
block = nn.ModuleList()
|
||||||
|
attn = nn.ModuleList()
|
||||||
|
block_in = ch * in_ch_mult[i_level]
|
||||||
|
block_out = ch * ch_mult[i_level]
|
||||||
|
|
||||||
|
for _ in range(self.num_res_blocks):
|
||||||
|
block.append(
|
||||||
|
ResnetBlock(
|
||||||
|
in_channels=block_in,
|
||||||
|
out_channels=block_out,
|
||||||
|
temb_channels=self.temb_ch,
|
||||||
|
dropout=dropout,
|
||||||
|
norm_type=self.norm_type,
|
||||||
|
causality_axis=causality_axis,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
block_in = block_out
|
||||||
|
if curr_res in attn_resolutions:
|
||||||
|
attn.append(make_attn(block_in, attn_type=self.attn_type, norm_type=self.norm_type))
|
||||||
|
|
||||||
|
down = nn.Module()
|
||||||
|
down.block = block
|
||||||
|
down.attn = attn
|
||||||
|
if i_level != self.num_resolutions - 1:
|
||||||
|
down.downsample = Downsample(block_in, resamp_with_conv, causality_axis=causality_axis)
|
||||||
|
curr_res = curr_res // 2
|
||||||
|
self.down.append(down)
|
||||||
|
|
||||||
|
# middle
|
||||||
|
self.mid = nn.Module()
|
||||||
|
self.mid.block_1 = ResnetBlock(
|
||||||
|
in_channels=block_in,
|
||||||
|
out_channels=block_in,
|
||||||
|
temb_channels=self.temb_ch,
|
||||||
|
dropout=dropout,
|
||||||
|
norm_type=self.norm_type,
|
||||||
|
causality_axis=causality_axis,
|
||||||
|
)
|
||||||
|
if mid_block_add_attention:
|
||||||
|
self.mid.attn_1 = make_attn(block_in, attn_type=self.attn_type, norm_type=self.norm_type)
|
||||||
|
else:
|
||||||
|
self.mid.attn_1 = nn.Identity()
|
||||||
|
self.mid.block_2 = ResnetBlock(
|
||||||
|
in_channels=block_in,
|
||||||
|
out_channels=block_in,
|
||||||
|
temb_channels=self.temb_ch,
|
||||||
|
dropout=dropout,
|
||||||
|
norm_type=self.norm_type,
|
||||||
|
causality_axis=causality_axis,
|
||||||
|
)
|
||||||
|
|
||||||
|
# end
|
||||||
|
self.norm_out = Normalize(block_in, normtype=self.norm_type)
|
||||||
|
self.conv_out = make_conv2d(
|
||||||
|
block_in,
|
||||||
|
2 * z_channels if double_z else z_channels,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=1,
|
||||||
|
causality_axis=causality_axis,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
"""
|
||||||
|
Forward pass through the encoder.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Input tensor of shape [batch, channels, time, n_mels]
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Encoded latent representation
|
||||||
|
"""
|
||||||
|
feature_maps = [self.conv_in(x)]
|
||||||
|
|
||||||
|
# Process each resolution level (from high to low resolution)
|
||||||
|
for resolution_level in range(self.num_resolutions):
|
||||||
|
# Apply residual blocks at current resolution level
|
||||||
|
for block_idx in range(self.num_res_blocks):
|
||||||
|
# Apply ResNet block with optional timestep embedding
|
||||||
|
current_features = self.down[resolution_level].block[block_idx](feature_maps[-1], temb=None)
|
||||||
|
|
||||||
|
# Apply attention if configured for this resolution level
|
||||||
|
if len(self.down[resolution_level].attn) > 0:
|
||||||
|
current_features = self.down[resolution_level].attn[block_idx](current_features)
|
||||||
|
|
||||||
|
# Store processed features
|
||||||
|
feature_maps.append(current_features)
|
||||||
|
|
||||||
|
# Downsample spatial dimensions (except at the final resolution level)
|
||||||
|
if resolution_level != self.num_resolutions - 1:
|
||||||
|
downsampled_features = self.down[resolution_level].downsample(feature_maps[-1])
|
||||||
|
feature_maps.append(downsampled_features)
|
||||||
|
|
||||||
|
# === MIDDLE PROCESSING PHASE ===
|
||||||
|
# Take the lowest resolution features for middle processing
|
||||||
|
bottleneck_features = feature_maps[-1]
|
||||||
|
|
||||||
|
# Apply first middle ResNet block
|
||||||
|
bottleneck_features = self.mid.block_1(bottleneck_features, temb=None)
|
||||||
|
|
||||||
|
# Apply middle attention block
|
||||||
|
bottleneck_features = self.mid.attn_1(bottleneck_features)
|
||||||
|
|
||||||
|
# Apply second middle ResNet block
|
||||||
|
bottleneck_features = self.mid.block_2(bottleneck_features, temb=None)
|
||||||
|
|
||||||
|
# === OUTPUT PHASE ===
|
||||||
|
# Normalize the bottleneck features
|
||||||
|
output_features = self.norm_out(bottleneck_features)
|
||||||
|
|
||||||
|
# Apply non-linearity (SiLU activation)
|
||||||
|
output_features = self.non_linearity(output_features)
|
||||||
|
|
||||||
|
# Final convolution to produce latent representation
|
||||||
|
# [batch, channels, time, n_mels] -> [batch, 2 * z_channels if double_z else z_channels, time, n_mels]
|
||||||
|
return self.conv_out(output_features)
|
||||||
|
|
||||||
|
|
||||||
|
class Decoder(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
ch,
|
||||||
|
out_ch,
|
||||||
|
ch_mult=(1, 2, 4, 8),
|
||||||
|
num_res_blocks,
|
||||||
|
attn_resolutions,
|
||||||
|
dropout=0.0,
|
||||||
|
resamp_with_conv=True,
|
||||||
|
in_channels,
|
||||||
|
resolution,
|
||||||
|
z_channels,
|
||||||
|
give_pre_end=False,
|
||||||
|
tanh_out=False,
|
||||||
|
attn_type="vanilla",
|
||||||
|
mid_block_add_attention=True,
|
||||||
|
norm_type="group",
|
||||||
|
causality_axis=CausalityAxis.WIDTH.value,
|
||||||
|
**ignorekwargs,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.ch = ch
|
||||||
|
self.temb_ch = 0
|
||||||
|
self.num_resolutions = len(ch_mult)
|
||||||
|
self.num_res_blocks = num_res_blocks
|
||||||
|
self.resolution = resolution
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.out_ch = out_ch
|
||||||
|
self.give_pre_end = give_pre_end
|
||||||
|
self.tanh_out = tanh_out
|
||||||
|
self.norm_type = norm_type
|
||||||
|
self.z_channels = z_channels
|
||||||
|
# Convert string to enum if needed (for config loading)
|
||||||
|
causality_axis = CausalityAxis.str_to_enum(causality_axis)
|
||||||
|
self.attn_type = AttentionType.str_to_enum(attn_type)
|
||||||
|
|
||||||
|
# compute block_in and curr_res at lowest res
|
||||||
|
block_in = ch * ch_mult[self.num_resolutions - 1]
|
||||||
|
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
||||||
|
self.z_shape = (1, z_channels, curr_res, curr_res)
|
||||||
|
|
||||||
|
# z to block_in
|
||||||
|
self.conv_in = make_conv2d(z_channels, block_in, kernel_size=3, stride=1, causality_axis=causality_axis)
|
||||||
|
|
||||||
|
self.non_linearity = nn.SiLU()
|
||||||
|
|
||||||
|
# middle
|
||||||
|
self.mid = nn.Module()
|
||||||
|
self.mid.block_1 = ResnetBlock(
|
||||||
|
in_channels=block_in,
|
||||||
|
out_channels=block_in,
|
||||||
|
temb_channels=self.temb_ch,
|
||||||
|
dropout=dropout,
|
||||||
|
norm_type=self.norm_type,
|
||||||
|
causality_axis=causality_axis,
|
||||||
|
)
|
||||||
|
if mid_block_add_attention:
|
||||||
|
self.mid.attn_1 = make_attn(block_in, attn_type=self.attn_type, norm_type=self.norm_type)
|
||||||
|
else:
|
||||||
|
self.mid.attn_1 = nn.Identity()
|
||||||
|
self.mid.block_2 = ResnetBlock(
|
||||||
|
in_channels=block_in,
|
||||||
|
out_channels=block_in,
|
||||||
|
temb_channels=self.temb_ch,
|
||||||
|
dropout=dropout,
|
||||||
|
norm_type=self.norm_type,
|
||||||
|
causality_axis=causality_axis,
|
||||||
|
)
|
||||||
|
|
||||||
|
# upsampling
|
||||||
|
self.up = nn.ModuleList()
|
||||||
|
for i_level in reversed(range(self.num_resolutions)):
|
||||||
|
block = nn.ModuleList()
|
||||||
|
attn = nn.ModuleList()
|
||||||
|
block_out = ch * ch_mult[i_level]
|
||||||
|
for _ in range(self.num_res_blocks + 1):
|
||||||
|
block.append(
|
||||||
|
ResnetBlock(
|
||||||
|
in_channels=block_in,
|
||||||
|
out_channels=block_out,
|
||||||
|
temb_channels=self.temb_ch,
|
||||||
|
dropout=dropout,
|
||||||
|
norm_type=self.norm_type,
|
||||||
|
causality_axis=causality_axis,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
block_in = block_out
|
||||||
|
if curr_res in attn_resolutions:
|
||||||
|
attn.append(make_attn(block_in, attn_type=self.attn_type, norm_type=self.norm_type))
|
||||||
|
up = nn.Module()
|
||||||
|
up.block = block
|
||||||
|
up.attn = attn
|
||||||
|
if i_level != 0:
|
||||||
|
up.upsample = Upsample(block_in, resamp_with_conv, causality_axis=causality_axis)
|
||||||
|
curr_res = curr_res * 2
|
||||||
|
self.up.insert(0, up) # prepend to get consistent order
|
||||||
|
|
||||||
|
# end
|
||||||
|
self.norm_out = Normalize(block_in, normtype=self.norm_type)
|
||||||
|
self.conv_out = make_conv2d(block_in, out_ch, kernel_size=3, stride=1, causality_axis=causality_axis)
|
||||||
|
|
||||||
|
def _adjust_output_shape(self, decoded_output, target_shape):
|
||||||
|
"""
|
||||||
|
Adjust output shape to match target dimensions for variable-length audio.
|
||||||
|
|
||||||
|
This function handles the common case where decoded audio spectrograms need to be
|
||||||
|
resized to match a specific target shape.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
decoded_output: Tensor of shape (batch, channels, time, frequency)
|
||||||
|
target_shape: Target shape tuple (batch, channels, time, frequency)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor adjusted to match target_shape exactly
|
||||||
|
"""
|
||||||
|
# Current output shape: (batch, channels, time, frequency)
|
||||||
|
_, _, current_time, current_freq = decoded_output.shape
|
||||||
|
_, target_channels, target_time, target_freq = target_shape
|
||||||
|
|
||||||
|
# Step 1: Crop first to avoid exceeding target dimensions
|
||||||
|
decoded_output = decoded_output[
|
||||||
|
:, :target_channels, : min(current_time, target_time), : min(current_freq, target_freq)
|
||||||
|
]
|
||||||
|
|
||||||
|
# Step 2: Calculate padding needed for time and frequency dimensions
|
||||||
|
time_padding_needed = target_time - decoded_output.shape[2]
|
||||||
|
freq_padding_needed = target_freq - decoded_output.shape[3]
|
||||||
|
|
||||||
|
# Step 3: Apply padding if needed
|
||||||
|
if time_padding_needed > 0 or freq_padding_needed > 0:
|
||||||
|
# PyTorch padding format: (pad_left, pad_right, pad_top, pad_bottom)
|
||||||
|
# For audio: pad_left/right = frequency, pad_top/bottom = time
|
||||||
|
padding = (
|
||||||
|
0,
|
||||||
|
max(freq_padding_needed, 0), # frequency padding (left, right)
|
||||||
|
0,
|
||||||
|
max(time_padding_needed, 0), # time padding (top, bottom)
|
||||||
|
)
|
||||||
|
decoded_output = F.pad(decoded_output, padding)
|
||||||
|
|
||||||
|
# Step 4: Final safety crop to ensure exact target shape
|
||||||
|
decoded_output = decoded_output[:, :target_channels, :target_time, :target_freq]
|
||||||
|
|
||||||
|
return decoded_output
|
||||||
|
|
||||||
|
def get_config(self):
|
||||||
|
return {
|
||||||
|
"ch": self.ch,
|
||||||
|
"out_ch": self.out_ch,
|
||||||
|
"ch_mult": self.ch_mult,
|
||||||
|
"num_res_blocks": self.num_res_blocks,
|
||||||
|
"in_channels": self.in_channels,
|
||||||
|
"resolution": self.resolution,
|
||||||
|
"z_channels": self.z_channels,
|
||||||
|
}
|
||||||
|
|
||||||
|
def forward(self, latent_features, target_shape=None):
|
||||||
|
"""
|
||||||
|
Decode latent features back to audio spectrograms.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
latent_features: Encoded latent representation of shape (batch, channels, height, width)
|
||||||
|
target_shape: Optional target output shape (batch, channels, time, frequency)
|
||||||
|
If provided, output will be cropped/padded to match this shape
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Reconstructed audio spectrogram of shape (batch, channels, time, frequency)
|
||||||
|
"""
|
||||||
|
assert target_shape is not None, "Target shape is required for CausalAudioAutoencoder Decoder"
|
||||||
|
|
||||||
|
# Transform latent features to decoder's internal feature dimension
|
||||||
|
hidden_features = self.conv_in(latent_features)
|
||||||
|
|
||||||
|
# Middle processing
|
||||||
|
hidden_features = self.mid.block_1(hidden_features, temb=None)
|
||||||
|
hidden_features = self.mid.attn_1(hidden_features)
|
||||||
|
hidden_features = self.mid.block_2(hidden_features, temb=None)
|
||||||
|
|
||||||
|
# Upsampling
|
||||||
|
# Progressively increase spatial resolution from lowest to highest
|
||||||
|
for resolution_level in reversed(range(self.num_resolutions)):
|
||||||
|
# Apply residual blocks at current resolution level
|
||||||
|
for block_index in range(self.num_res_blocks + 1):
|
||||||
|
hidden_features = self.up[resolution_level].block[block_index](hidden_features, temb=None)
|
||||||
|
|
||||||
|
if len(self.up[resolution_level].attn) > 0:
|
||||||
|
hidden_features = self.up[resolution_level].attn[block_index](hidden_features)
|
||||||
|
|
||||||
|
if resolution_level != 0:
|
||||||
|
hidden_features = self.up[resolution_level].upsample(hidden_features)
|
||||||
|
|
||||||
|
# Output
|
||||||
|
if self.give_pre_end:
|
||||||
|
# Return intermediate features before final processing (for debugging/analysis)
|
||||||
|
decoded_output = hidden_features
|
||||||
|
else:
|
||||||
|
# Standard output path: normalize, activate, and convert to output channels
|
||||||
|
# Final normalization layer
|
||||||
|
hidden_features = self.norm_out(hidden_features)
|
||||||
|
|
||||||
|
# Apply SiLU (Swish) activation function
|
||||||
|
hidden_features = self.non_linearity(hidden_features)
|
||||||
|
|
||||||
|
# Final convolution to map to output channels (typically 2 for stereo audio)
|
||||||
|
decoded_output = self.conv_out(hidden_features)
|
||||||
|
|
||||||
|
# Optional tanh activation to bound output values to [-1, 1] range
|
||||||
|
if self.tanh_out:
|
||||||
|
decoded_output = torch.tanh(decoded_output)
|
||||||
|
|
||||||
|
# Adjust shape for audio data
|
||||||
|
if target_shape is not None:
|
||||||
|
decoded_output = self._adjust_output_shape(decoded_output, target_shape)
|
||||||
|
|
||||||
|
return decoded_output
|
||||||
|
|
||||||
|
|
||||||
|
class processor(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.register_buffer("std-of-means", torch.empty(128))
|
||||||
|
self.register_buffer("mean-of-means", torch.empty(128))
|
||||||
|
|
||||||
|
def un_normalize(self, x):
|
||||||
|
return (x * self.get_buffer("std-of-means").to(x)) + self.get_buffer("mean-of-means").to(x)
|
||||||
|
|
||||||
|
def normalize(self, x):
|
||||||
|
return (x - self.get_buffer("mean-of-means").to(x)) / self.get_buffer("std-of-means").to(x)
|
||||||
|
|
||||||
|
|
||||||
|
class CausalAudioAutoencoder(nn.Module):
|
||||||
|
def __init__(self, config=None):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
if config is None:
|
||||||
|
config = self._guess_config()
|
||||||
|
|
||||||
|
# Extract encoder and decoder configs from the new format
|
||||||
|
model_config = config.get("model", {}).get("params", {})
|
||||||
|
variables_config = config.get("variables", {})
|
||||||
|
|
||||||
|
self.sampling_rate = variables_config.get(
|
||||||
|
"sampling_rate",
|
||||||
|
model_config.get("sampling_rate", config.get("sampling_rate", 16000)),
|
||||||
|
)
|
||||||
|
encoder_config = model_config.get("encoder", model_config.get("ddconfig", {}))
|
||||||
|
decoder_config = model_config.get("decoder", encoder_config)
|
||||||
|
|
||||||
|
# Load mel spectrogram parameters
|
||||||
|
self.mel_bins = encoder_config.get("mel_bins", 64)
|
||||||
|
self.mel_hop_length = model_config.get("preprocessing", {}).get("stft", {}).get("hop_length", 160)
|
||||||
|
self.n_fft = model_config.get("preprocessing", {}).get("stft", {}).get("filter_length", 1024)
|
||||||
|
|
||||||
|
# Store causality configuration at VAE level (not just in encoder internals)
|
||||||
|
causality_axis_value = encoder_config.get("causality_axis", CausalityAxis.WIDTH.value)
|
||||||
|
self.causality_axis = CausalityAxis.str_to_enum(causality_axis_value)
|
||||||
|
self.is_causal = self.causality_axis == CausalityAxis.HEIGHT
|
||||||
|
|
||||||
|
self.encoder = Encoder(**encoder_config)
|
||||||
|
self.decoder = Decoder(**decoder_config)
|
||||||
|
|
||||||
|
self.per_channel_statistics = processor()
|
||||||
|
|
||||||
|
def _guess_config(self):
|
||||||
|
encoder_config = {
|
||||||
|
# Required parameters - based on ltx-video-av-1679000 model metadata
|
||||||
|
"ch": 128,
|
||||||
|
"out_ch": 8,
|
||||||
|
"ch_mult": [1, 2, 4], # Based on metadata: [1, 2, 4] not [1, 2, 4, 8]
|
||||||
|
"num_res_blocks": 2,
|
||||||
|
"attn_resolutions": [], # Based on metadata: empty list, no attention
|
||||||
|
"dropout": 0.0,
|
||||||
|
"resamp_with_conv": True,
|
||||||
|
"in_channels": 2, # stereo
|
||||||
|
"resolution": 256,
|
||||||
|
"z_channels": 8,
|
||||||
|
"double_z": True,
|
||||||
|
"attn_type": "vanilla",
|
||||||
|
"mid_block_add_attention": False, # Based on metadata: false
|
||||||
|
"norm_type": "pixel",
|
||||||
|
"causality_axis": "height", # Based on metadata
|
||||||
|
"mel_bins": 64, # Based on metadata: mel_bins = 64
|
||||||
|
}
|
||||||
|
|
||||||
|
decoder_config = {
|
||||||
|
# Inherits encoder config, can override specific params
|
||||||
|
**encoder_config,
|
||||||
|
"out_ch": 2, # Stereo audio output (2 channels)
|
||||||
|
"give_pre_end": False,
|
||||||
|
"tanh_out": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
config = {
|
||||||
|
"_class_name": "CausalAudioAutoencoder",
|
||||||
|
"sampling_rate": 16000,
|
||||||
|
"model": {
|
||||||
|
"params": {
|
||||||
|
"encoder": encoder_config,
|
||||||
|
"decoder": decoder_config,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
return config
|
||||||
|
|
||||||
|
def get_config(self):
|
||||||
|
return {
|
||||||
|
"sampling_rate": self.sampling_rate,
|
||||||
|
"mel_bins": self.mel_bins,
|
||||||
|
"mel_hop_length": self.mel_hop_length,
|
||||||
|
"n_fft": self.n_fft,
|
||||||
|
"causality_axis": self.causality_axis.value,
|
||||||
|
"is_causal": self.is_causal,
|
||||||
|
}
|
||||||
|
|
||||||
|
def encode(self, x):
|
||||||
|
return self.encoder(x)
|
||||||
|
|
||||||
|
def decode(self, x, target_shape=None):
|
||||||
|
return self.decoder(x, target_shape=target_shape)
|
||||||
213
comfy/ldm/lightricks/vocoders/vocoder.py
Normal file
213
comfy/ldm/lightricks/vocoders/vocoder.py
Normal file
@ -0,0 +1,213 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import torch.nn as nn
|
||||||
|
import comfy.ops
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
ops = comfy.ops.disable_weight_init
|
||||||
|
|
||||||
|
LRELU_SLOPE = 0.1
|
||||||
|
|
||||||
|
def get_padding(kernel_size, dilation=1):
|
||||||
|
return int((kernel_size * dilation - dilation) / 2)
|
||||||
|
|
||||||
|
|
||||||
|
class ResBlock1(torch.nn.Module):
|
||||||
|
def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
|
||||||
|
super(ResBlock1, self).__init__()
|
||||||
|
self.convs1 = nn.ModuleList(
|
||||||
|
[
|
||||||
|
ops.Conv1d(
|
||||||
|
channels,
|
||||||
|
channels,
|
||||||
|
kernel_size,
|
||||||
|
1,
|
||||||
|
dilation=dilation[0],
|
||||||
|
padding=get_padding(kernel_size, dilation[0]),
|
||||||
|
),
|
||||||
|
ops.Conv1d(
|
||||||
|
channels,
|
||||||
|
channels,
|
||||||
|
kernel_size,
|
||||||
|
1,
|
||||||
|
dilation=dilation[1],
|
||||||
|
padding=get_padding(kernel_size, dilation[1]),
|
||||||
|
),
|
||||||
|
ops.Conv1d(
|
||||||
|
channels,
|
||||||
|
channels,
|
||||||
|
kernel_size,
|
||||||
|
1,
|
||||||
|
dilation=dilation[2],
|
||||||
|
padding=get_padding(kernel_size, dilation[2]),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
self.convs2 = nn.ModuleList(
|
||||||
|
[
|
||||||
|
ops.Conv1d(
|
||||||
|
channels,
|
||||||
|
channels,
|
||||||
|
kernel_size,
|
||||||
|
1,
|
||||||
|
dilation=1,
|
||||||
|
padding=get_padding(kernel_size, 1),
|
||||||
|
),
|
||||||
|
ops.Conv1d(
|
||||||
|
channels,
|
||||||
|
channels,
|
||||||
|
kernel_size,
|
||||||
|
1,
|
||||||
|
dilation=1,
|
||||||
|
padding=get_padding(kernel_size, 1),
|
||||||
|
),
|
||||||
|
ops.Conv1d(
|
||||||
|
channels,
|
||||||
|
channels,
|
||||||
|
kernel_size,
|
||||||
|
1,
|
||||||
|
dilation=1,
|
||||||
|
padding=get_padding(kernel_size, 1),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
for c1, c2 in zip(self.convs1, self.convs2):
|
||||||
|
xt = F.leaky_relu(x, LRELU_SLOPE)
|
||||||
|
xt = c1(xt)
|
||||||
|
xt = F.leaky_relu(xt, LRELU_SLOPE)
|
||||||
|
xt = c2(xt)
|
||||||
|
x = xt + x
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class ResBlock2(torch.nn.Module):
|
||||||
|
def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
|
||||||
|
super(ResBlock2, self).__init__()
|
||||||
|
self.convs = nn.ModuleList(
|
||||||
|
[
|
||||||
|
ops.Conv1d(
|
||||||
|
channels,
|
||||||
|
channels,
|
||||||
|
kernel_size,
|
||||||
|
1,
|
||||||
|
dilation=dilation[0],
|
||||||
|
padding=get_padding(kernel_size, dilation[0]),
|
||||||
|
),
|
||||||
|
ops.Conv1d(
|
||||||
|
channels,
|
||||||
|
channels,
|
||||||
|
kernel_size,
|
||||||
|
1,
|
||||||
|
dilation=dilation[1],
|
||||||
|
padding=get_padding(kernel_size, dilation[1]),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
for c in self.convs:
|
||||||
|
xt = F.leaky_relu(x, LRELU_SLOPE)
|
||||||
|
xt = c(xt)
|
||||||
|
x = xt + x
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class Vocoder(torch.nn.Module):
|
||||||
|
"""
|
||||||
|
Vocoder model for synthesizing audio from spectrograms, based on: https://github.com/jik876/hifi-gan.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config=None):
|
||||||
|
super(Vocoder, self).__init__()
|
||||||
|
|
||||||
|
if config is None:
|
||||||
|
config = self.get_default_config()
|
||||||
|
|
||||||
|
resblock_kernel_sizes = config.get("resblock_kernel_sizes", [3, 7, 11])
|
||||||
|
upsample_rates = config.get("upsample_rates", [6, 5, 2, 2, 2])
|
||||||
|
upsample_kernel_sizes = config.get("upsample_kernel_sizes", [16, 15, 8, 4, 4])
|
||||||
|
resblock_dilation_sizes = config.get("resblock_dilation_sizes", [[1, 3, 5], [1, 3, 5], [1, 3, 5]])
|
||||||
|
upsample_initial_channel = config.get("upsample_initial_channel", 1024)
|
||||||
|
stereo = config.get("stereo", True)
|
||||||
|
resblock = config.get("resblock", "1")
|
||||||
|
|
||||||
|
self.output_sample_rate = config.get("output_sample_rate")
|
||||||
|
self.num_kernels = len(resblock_kernel_sizes)
|
||||||
|
self.num_upsamples = len(upsample_rates)
|
||||||
|
in_channels = 128 if stereo else 64
|
||||||
|
self.conv_pre = ops.Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3)
|
||||||
|
resblock_class = ResBlock1 if resblock == "1" else ResBlock2
|
||||||
|
|
||||||
|
self.ups = nn.ModuleList()
|
||||||
|
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
|
||||||
|
self.ups.append(
|
||||||
|
ops.ConvTranspose1d(
|
||||||
|
upsample_initial_channel // (2**i),
|
||||||
|
upsample_initial_channel // (2 ** (i + 1)),
|
||||||
|
k,
|
||||||
|
u,
|
||||||
|
padding=(k - u) // 2,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.resblocks = nn.ModuleList()
|
||||||
|
for i in range(len(self.ups)):
|
||||||
|
ch = upsample_initial_channel // (2 ** (i + 1))
|
||||||
|
for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
|
||||||
|
self.resblocks.append(resblock_class(ch, k, d))
|
||||||
|
|
||||||
|
out_channels = 2 if stereo else 1
|
||||||
|
self.conv_post = ops.Conv1d(ch, out_channels, 7, 1, padding=3)
|
||||||
|
|
||||||
|
self.upsample_factor = np.prod([self.ups[i].stride[0] for i in range(len(self.ups))])
|
||||||
|
|
||||||
|
def get_default_config(self):
|
||||||
|
"""Generate default configuration for the vocoder."""
|
||||||
|
|
||||||
|
config = {
|
||||||
|
"resblock_kernel_sizes": [3, 7, 11],
|
||||||
|
"upsample_rates": [6, 5, 2, 2, 2],
|
||||||
|
"upsample_kernel_sizes": [16, 15, 8, 4, 4],
|
||||||
|
"resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
||||||
|
"upsample_initial_channel": 1024,
|
||||||
|
"stereo": True,
|
||||||
|
"resblock": "1",
|
||||||
|
}
|
||||||
|
|
||||||
|
return config
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
"""
|
||||||
|
Forward pass of the vocoder.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Input spectrogram tensor. Can be:
|
||||||
|
- 3D: (batch_size, channels, time_steps) for mono
|
||||||
|
- 4D: (batch_size, 2, channels, time_steps) for stereo
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Audio tensor of shape (batch_size, out_channels, audio_length)
|
||||||
|
"""
|
||||||
|
if x.dim() == 4: # stereo
|
||||||
|
assert x.shape[1] == 2, "Input must have 2 channels for stereo"
|
||||||
|
x = torch.cat((x[:, 0, :, :], x[:, 1, :, :]), dim=1)
|
||||||
|
x = self.conv_pre(x)
|
||||||
|
for i in range(self.num_upsamples):
|
||||||
|
x = F.leaky_relu(x, LRELU_SLOPE)
|
||||||
|
x = self.ups[i](x)
|
||||||
|
xs = None
|
||||||
|
for j in range(self.num_kernels):
|
||||||
|
if xs is None:
|
||||||
|
xs = self.resblocks[i * self.num_kernels + j](x)
|
||||||
|
else:
|
||||||
|
xs += self.resblocks[i * self.num_kernels + j](x)
|
||||||
|
x = xs / self.num_kernels
|
||||||
|
x = F.leaky_relu(x)
|
||||||
|
x = self.conv_post(x)
|
||||||
|
x = torch.tanh(x)
|
||||||
|
|
||||||
|
return x
|
||||||
@ -491,7 +491,8 @@ class NextDiT(nn.Module):
|
|||||||
for layer_id in range(n_layers)
|
for layer_id in range(n_layers)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
self.norm_final = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
# This norm final is in the lumina 2.0 code but isn't actually used for anything.
|
||||||
|
# self.norm_final = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||||
self.final_layer = FinalLayer(dim, patch_size, self.out_channels, z_image_modulation=z_image_modulation, operation_settings=operation_settings)
|
self.final_layer = FinalLayer(dim, patch_size, self.out_channels, z_image_modulation=z_image_modulation, operation_settings=operation_settings)
|
||||||
|
|
||||||
if self.pad_tokens_multiple is not None:
|
if self.pad_tokens_multiple is not None:
|
||||||
|
|||||||
@ -30,6 +30,13 @@ except ImportError as e:
|
|||||||
raise e
|
raise e
|
||||||
exit(-1)
|
exit(-1)
|
||||||
|
|
||||||
|
SAGE_ATTENTION3_IS_AVAILABLE = False
|
||||||
|
try:
|
||||||
|
from sageattn3 import sageattn3_blackwell
|
||||||
|
SAGE_ATTENTION3_IS_AVAILABLE = True
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
FLASH_ATTENTION_IS_AVAILABLE = False
|
FLASH_ATTENTION_IS_AVAILABLE = False
|
||||||
try:
|
try:
|
||||||
from flash_attn import flash_attn_func
|
from flash_attn import flash_attn_func
|
||||||
@ -563,6 +570,93 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
|
|||||||
out = out.reshape(b, -1, heads * dim_head)
|
out = out.reshape(b, -1, heads * dim_head)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
@wrap_attn
|
||||||
|
def attention3_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
|
||||||
|
exception_fallback = False
|
||||||
|
if (q.device.type != "cuda" or
|
||||||
|
q.dtype not in (torch.float16, torch.bfloat16) or
|
||||||
|
mask is not None):
|
||||||
|
return attention_pytorch(
|
||||||
|
q, k, v, heads,
|
||||||
|
mask=mask,
|
||||||
|
attn_precision=attn_precision,
|
||||||
|
skip_reshape=skip_reshape,
|
||||||
|
skip_output_reshape=skip_output_reshape,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
if skip_reshape:
|
||||||
|
B, H, L, D = q.shape
|
||||||
|
if H != heads:
|
||||||
|
return attention_pytorch(
|
||||||
|
q, k, v, heads,
|
||||||
|
mask=mask,
|
||||||
|
attn_precision=attn_precision,
|
||||||
|
skip_reshape=True,
|
||||||
|
skip_output_reshape=skip_output_reshape,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
q_s, k_s, v_s = q, k, v
|
||||||
|
N = q.shape[2]
|
||||||
|
dim_head = D
|
||||||
|
else:
|
||||||
|
B, N, inner_dim = q.shape
|
||||||
|
if inner_dim % heads != 0:
|
||||||
|
return attention_pytorch(
|
||||||
|
q, k, v, heads,
|
||||||
|
mask=mask,
|
||||||
|
attn_precision=attn_precision,
|
||||||
|
skip_reshape=False,
|
||||||
|
skip_output_reshape=skip_output_reshape,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
dim_head = inner_dim // heads
|
||||||
|
|
||||||
|
if dim_head >= 256 or N <= 1024:
|
||||||
|
return attention_pytorch(
|
||||||
|
q, k, v, heads,
|
||||||
|
mask=mask,
|
||||||
|
attn_precision=attn_precision,
|
||||||
|
skip_reshape=skip_reshape,
|
||||||
|
skip_output_reshape=skip_output_reshape,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
if not skip_reshape:
|
||||||
|
q_s, k_s, v_s = map(
|
||||||
|
lambda t: t.view(B, -1, heads, dim_head).permute(0, 2, 1, 3).contiguous(),
|
||||||
|
(q, k, v),
|
||||||
|
)
|
||||||
|
B, H, L, D = q_s.shape
|
||||||
|
|
||||||
|
try:
|
||||||
|
out = sageattn3_blackwell(q_s, k_s, v_s, is_causal=False)
|
||||||
|
except Exception as e:
|
||||||
|
exception_fallback = True
|
||||||
|
logging.error("Error running SageAttention3: %s, falling back to pytorch attention.", e)
|
||||||
|
|
||||||
|
if exception_fallback:
|
||||||
|
if not skip_reshape:
|
||||||
|
del q_s, k_s, v_s
|
||||||
|
return attention_pytorch(
|
||||||
|
q, k, v, heads,
|
||||||
|
mask=mask,
|
||||||
|
attn_precision=attn_precision,
|
||||||
|
skip_reshape=False,
|
||||||
|
skip_output_reshape=skip_output_reshape,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
if skip_reshape:
|
||||||
|
if not skip_output_reshape:
|
||||||
|
out = out.permute(0, 2, 1, 3).reshape(B, L, H * D)
|
||||||
|
else:
|
||||||
|
if skip_output_reshape:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
out = out.permute(0, 2, 1, 3).reshape(B, L, H * D)
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@torch.library.custom_op("flash_attention::flash_attn", mutates_args=())
|
@torch.library.custom_op("flash_attention::flash_attn", mutates_args=())
|
||||||
@ -650,6 +744,8 @@ optimized_attention_masked = optimized_attention
|
|||||||
# register core-supported attention functions
|
# register core-supported attention functions
|
||||||
if SAGE_ATTENTION_IS_AVAILABLE:
|
if SAGE_ATTENTION_IS_AVAILABLE:
|
||||||
register_attention_function("sage", attention_sage)
|
register_attention_function("sage", attention_sage)
|
||||||
|
if SAGE_ATTENTION3_IS_AVAILABLE:
|
||||||
|
register_attention_function("sage3", attention3_sage)
|
||||||
if FLASH_ATTENTION_IS_AVAILABLE:
|
if FLASH_ATTENTION_IS_AVAILABLE:
|
||||||
register_attention_function("flash", attention_flash)
|
register_attention_function("flash", attention_flash)
|
||||||
if model_management.xformers_enabled():
|
if model_management.xformers_enabled():
|
||||||
|
|||||||
@ -394,7 +394,8 @@ class Model(nn.Module):
|
|||||||
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
|
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
|
||||||
resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"):
|
resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
if use_linear_attn: attn_type = "linear"
|
if use_linear_attn:
|
||||||
|
attn_type = "linear"
|
||||||
self.ch = ch
|
self.ch = ch
|
||||||
self.temb_ch = self.ch*4
|
self.temb_ch = self.ch*4
|
||||||
self.num_resolutions = len(ch_mult)
|
self.num_resolutions = len(ch_mult)
|
||||||
@ -548,7 +549,8 @@ class Encoder(nn.Module):
|
|||||||
conv3d=False, time_compress=None,
|
conv3d=False, time_compress=None,
|
||||||
**ignore_kwargs):
|
**ignore_kwargs):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
if use_linear_attn: attn_type = "linear"
|
if use_linear_attn:
|
||||||
|
attn_type = "linear"
|
||||||
self.ch = ch
|
self.ch = ch
|
||||||
self.temb_ch = 0
|
self.temb_ch = 0
|
||||||
self.num_resolutions = len(ch_mult)
|
self.num_resolutions = len(ch_mult)
|
||||||
|
|||||||
@ -45,7 +45,7 @@ class LitEma(nn.Module):
|
|||||||
shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
|
shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
|
||||||
shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
|
shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
|
||||||
else:
|
else:
|
||||||
assert not key in self.m_name2s_name
|
assert key not in self.m_name2s_name
|
||||||
|
|
||||||
def copy_to(self, model):
|
def copy_to(self, model):
|
||||||
m_param = dict(model.named_parameters())
|
m_param = dict(model.named_parameters())
|
||||||
@ -54,7 +54,7 @@ class LitEma(nn.Module):
|
|||||||
if m_param[key].requires_grad:
|
if m_param[key].requires_grad:
|
||||||
m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
|
m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
|
||||||
else:
|
else:
|
||||||
assert not key in self.m_name2s_name
|
assert key not in self.m_name2s_name
|
||||||
|
|
||||||
def store(self, parameters):
|
def store(self, parameters):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -71,7 +71,7 @@ def count_params(model, verbose=False):
|
|||||||
|
|
||||||
|
|
||||||
def instantiate_from_config(config):
|
def instantiate_from_config(config):
|
||||||
if not "target" in config:
|
if "target" not in config:
|
||||||
if config == '__is_first_stage__':
|
if config == '__is_first_stage__':
|
||||||
return None
|
return None
|
||||||
elif config == "__is_unconditional__":
|
elif config == "__is_unconditional__":
|
||||||
|
|||||||
@ -20,6 +20,7 @@ import comfy.ldm.hunyuan3dv2_1
|
|||||||
import comfy.ldm.hunyuan3dv2_1.hunyuandit
|
import comfy.ldm.hunyuan3dv2_1.hunyuandit
|
||||||
import torch
|
import torch
|
||||||
import logging
|
import logging
|
||||||
|
import comfy.ldm.lightricks.av_model
|
||||||
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
|
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
|
||||||
from comfy.ldm.cascade.stage_c import StageC
|
from comfy.ldm.cascade.stage_c import StageC
|
||||||
from comfy.ldm.cascade.stage_b import StageB
|
from comfy.ldm.cascade.stage_b import StageB
|
||||||
@ -946,7 +947,7 @@ class GenmoMochi(BaseModel):
|
|||||||
|
|
||||||
class LTXV(BaseModel):
|
class LTXV(BaseModel):
|
||||||
def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
|
def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
|
||||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lightricks.model.LTXVModel) #TODO
|
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lightricks.model.LTXVModel)
|
||||||
|
|
||||||
def extra_conds(self, **kwargs):
|
def extra_conds(self, **kwargs):
|
||||||
out = super().extra_conds(**kwargs)
|
out = super().extra_conds(**kwargs)
|
||||||
@ -977,6 +978,60 @@ class LTXV(BaseModel):
|
|||||||
def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
|
def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
|
||||||
return latent_image
|
return latent_image
|
||||||
|
|
||||||
|
class LTXAV(BaseModel):
|
||||||
|
def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
|
||||||
|
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lightricks.av_model.LTXAVModel) #TODO
|
||||||
|
|
||||||
|
def extra_conds(self, **kwargs):
|
||||||
|
out = super().extra_conds(**kwargs)
|
||||||
|
attention_mask = kwargs.get("attention_mask", None)
|
||||||
|
if attention_mask is not None:
|
||||||
|
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
|
||||||
|
cross_attn = kwargs.get("cross_attn", None)
|
||||||
|
if cross_attn is not None:
|
||||||
|
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||||
|
|
||||||
|
out['frame_rate'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", 25))
|
||||||
|
|
||||||
|
denoise_mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
|
||||||
|
|
||||||
|
audio_denoise_mask = None
|
||||||
|
if denoise_mask is not None and "latent_shapes" in kwargs:
|
||||||
|
denoise_mask = utils.unpack_latents(denoise_mask, kwargs["latent_shapes"])
|
||||||
|
if len(denoise_mask) > 1:
|
||||||
|
audio_denoise_mask = denoise_mask[1]
|
||||||
|
denoise_mask = denoise_mask[0]
|
||||||
|
|
||||||
|
if denoise_mask is not None:
|
||||||
|
out["denoise_mask"] = comfy.conds.CONDRegular(denoise_mask)
|
||||||
|
|
||||||
|
if audio_denoise_mask is not None:
|
||||||
|
out["audio_denoise_mask"] = comfy.conds.CONDRegular(audio_denoise_mask)
|
||||||
|
|
||||||
|
keyframe_idxs = kwargs.get("keyframe_idxs", None)
|
||||||
|
if keyframe_idxs is not None:
|
||||||
|
out['keyframe_idxs'] = comfy.conds.CONDRegular(keyframe_idxs)
|
||||||
|
|
||||||
|
latent_shapes = kwargs.get("latent_shapes", None)
|
||||||
|
if latent_shapes is not None:
|
||||||
|
out['latent_shapes'] = comfy.conds.CONDConstant(latent_shapes)
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
def process_timestep(self, timestep, x, denoise_mask=None, audio_denoise_mask=None, **kwargs):
|
||||||
|
v_timestep = timestep
|
||||||
|
a_timestep = timestep
|
||||||
|
|
||||||
|
if denoise_mask is not None:
|
||||||
|
v_timestep = self.diffusion_model.patchifier.patchify(((denoise_mask) * timestep.view([timestep.shape[0]] + [1] * (denoise_mask.ndim - 1)))[:, :1])[0]
|
||||||
|
if audio_denoise_mask is not None:
|
||||||
|
a_timestep = self.diffusion_model.a_patchifier.patchify(((audio_denoise_mask) * timestep.view([timestep.shape[0]] + [1] * (audio_denoise_mask.ndim - 1)))[:, :1, :, :1])[0]
|
||||||
|
|
||||||
|
return v_timestep, a_timestep
|
||||||
|
|
||||||
|
def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
|
||||||
|
return latent_image
|
||||||
|
|
||||||
class HunyuanVideo(BaseModel):
|
class HunyuanVideo(BaseModel):
|
||||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan_video.model.HunyuanVideo)
|
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan_video.model.HunyuanVideo)
|
||||||
|
|||||||
@ -305,7 +305,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
|||||||
|
|
||||||
if '{}adaln_single.emb.timestep_embedder.linear_1.bias'.format(key_prefix) in state_dict_keys: #Lightricks ltxv
|
if '{}adaln_single.emb.timestep_embedder.linear_1.bias'.format(key_prefix) in state_dict_keys: #Lightricks ltxv
|
||||||
dit_config = {}
|
dit_config = {}
|
||||||
dit_config["image_model"] = "ltxv"
|
dit_config["image_model"] = "ltxav" if f'{key_prefix}audio_adaln_single.linear.weight' in state_dict_keys else "ltxv"
|
||||||
dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.')
|
dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.')
|
||||||
shape = state_dict['{}transformer_blocks.0.attn2.to_k.weight'.format(key_prefix)].shape
|
shape = state_dict['{}transformer_blocks.0.attn2.to_k.weight'.format(key_prefix)].shape
|
||||||
dit_config["attention_head_dim"] = shape[0] // 32
|
dit_config["attention_head_dim"] = shape[0] // 32
|
||||||
|
|||||||
@ -123,6 +123,7 @@ except:
|
|||||||
try:
|
try:
|
||||||
import torch_npu # noqa: F401
|
import torch_npu # noqa: F401
|
||||||
_ = torch.npu.device_count()
|
_ = torch.npu.device_count()
|
||||||
|
torch_npu.npu.set_compile_mode(jit_compile = False)
|
||||||
npu_available = torch.npu.is_available()
|
npu_available = torch.npu.is_available()
|
||||||
except:
|
except:
|
||||||
npu_available = False
|
npu_available = False
|
||||||
@ -456,7 +457,7 @@ def module_size(module):
|
|||||||
sd = module.state_dict()
|
sd = module.state_dict()
|
||||||
for k in sd:
|
for k in sd:
|
||||||
t = sd[k]
|
t = sd[k]
|
||||||
module_mem += t.nelement() * t.element_size()
|
module_mem += t.nbytes
|
||||||
return module_mem
|
return module_mem
|
||||||
|
|
||||||
class LoadedModel:
|
class LoadedModel:
|
||||||
@ -1126,6 +1127,16 @@ if not args.disable_pinned_memory:
|
|||||||
|
|
||||||
PINNING_ALLOWED_TYPES = set(["Parameter", "QuantizedTensor"])
|
PINNING_ALLOWED_TYPES = set(["Parameter", "QuantizedTensor"])
|
||||||
|
|
||||||
|
def discard_cuda_async_error():
|
||||||
|
try:
|
||||||
|
a = torch.tensor([1], dtype=torch.uint8, device=get_torch_device())
|
||||||
|
b = torch.tensor([1], dtype=torch.uint8, device=get_torch_device())
|
||||||
|
_ = a + b
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
except torch.AcceleratorError:
|
||||||
|
#Dump it! We already know about it from the synchronous return
|
||||||
|
pass
|
||||||
|
|
||||||
def pin_memory(tensor):
|
def pin_memory(tensor):
|
||||||
global TOTAL_PINNED_MEMORY
|
global TOTAL_PINNED_MEMORY
|
||||||
if MAX_PINNED_MEMORY <= 0:
|
if MAX_PINNED_MEMORY <= 0:
|
||||||
@ -1146,7 +1157,7 @@ def pin_memory(tensor):
|
|||||||
if not tensor.is_contiguous():
|
if not tensor.is_contiguous():
|
||||||
return False
|
return False
|
||||||
|
|
||||||
size = tensor.numel() * tensor.element_size()
|
size = tensor.nbytes
|
||||||
if (TOTAL_PINNED_MEMORY + size) > MAX_PINNED_MEMORY:
|
if (TOTAL_PINNED_MEMORY + size) > MAX_PINNED_MEMORY:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@ -1158,6 +1169,9 @@ def pin_memory(tensor):
|
|||||||
PINNED_MEMORY[ptr] = size
|
PINNED_MEMORY[ptr] = size
|
||||||
TOTAL_PINNED_MEMORY += size
|
TOTAL_PINNED_MEMORY += size
|
||||||
return True
|
return True
|
||||||
|
else:
|
||||||
|
logging.warning("Pin error.")
|
||||||
|
discard_cuda_async_error()
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@ -1170,7 +1184,7 @@ def unpin_memory(tensor):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
ptr = tensor.data_ptr()
|
ptr = tensor.data_ptr()
|
||||||
size = tensor.numel() * tensor.element_size()
|
size = tensor.nbytes
|
||||||
|
|
||||||
size_stored = PINNED_MEMORY.get(ptr, None)
|
size_stored = PINNED_MEMORY.get(ptr, None)
|
||||||
if size_stored is None:
|
if size_stored is None:
|
||||||
@ -1186,6 +1200,9 @@ def unpin_memory(tensor):
|
|||||||
if len(PINNED_MEMORY) == 0:
|
if len(PINNED_MEMORY) == 0:
|
||||||
TOTAL_PINNED_MEMORY = 0
|
TOTAL_PINNED_MEMORY = 0
|
||||||
return True
|
return True
|
||||||
|
else:
|
||||||
|
logging.warning("Unpin error.")
|
||||||
|
discard_cuda_async_error()
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@ -1488,6 +1505,16 @@ def supports_fp8_compute(device=None):
|
|||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
def supports_nvfp4_compute(device=None):
|
||||||
|
if not is_nvidia():
|
||||||
|
return False
|
||||||
|
|
||||||
|
props = torch.cuda.get_device_properties(device)
|
||||||
|
if props.major < 10:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
def extended_fp16_support():
|
def extended_fp16_support():
|
||||||
# TODO: check why some models work with fp16 on newer torch versions but not on older
|
# TODO: check why some models work with fp16 on newer torch versions but not on older
|
||||||
if torch_version_numeric < (2, 7):
|
if torch_version_numeric < (2, 7):
|
||||||
@ -1526,6 +1553,10 @@ def soft_empty_cache(force=False):
|
|||||||
def unload_all_models():
|
def unload_all_models():
|
||||||
free_memory(1e30, get_torch_device())
|
free_memory(1e30, get_torch_device())
|
||||||
|
|
||||||
|
def debug_memory_summary():
|
||||||
|
if is_amd() or is_nvidia():
|
||||||
|
return torch.cuda.memory.memory_summary()
|
||||||
|
return ""
|
||||||
|
|
||||||
#TODO: might be cleaner to put this somewhere else
|
#TODO: might be cleaner to put this somewhere else
|
||||||
import threading
|
import threading
|
||||||
|
|||||||
185
comfy/ops.py
185
comfy/ops.py
@ -79,7 +79,7 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
|
|||||||
if input is not None:
|
if input is not None:
|
||||||
if dtype is None:
|
if dtype is None:
|
||||||
if isinstance(input, QuantizedTensor):
|
if isinstance(input, QuantizedTensor):
|
||||||
dtype = input._layout_params["orig_dtype"]
|
dtype = input.params.orig_dtype
|
||||||
else:
|
else:
|
||||||
dtype = input.dtype
|
dtype = input.dtype
|
||||||
if bias_dtype is None:
|
if bias_dtype is None:
|
||||||
@ -412,26 +412,34 @@ def fp8_linear(self, input):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
input_dtype = input.dtype
|
input_dtype = input.dtype
|
||||||
|
input_shape = input.shape
|
||||||
|
tensor_3d = input.ndim == 3
|
||||||
|
|
||||||
if input.ndim == 3 or input.ndim == 2:
|
if tensor_3d:
|
||||||
w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True)
|
input = input.reshape(-1, input_shape[2])
|
||||||
scale_weight = torch.ones((), device=input.device, dtype=torch.float32)
|
|
||||||
|
|
||||||
scale_input = torch.ones((), device=input.device, dtype=torch.float32)
|
if input.ndim != 2:
|
||||||
input = torch.clamp(input, min=-448, max=448, out=input)
|
return None
|
||||||
layout_params_weight = {'scale': scale_input, 'orig_dtype': input_dtype}
|
w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True)
|
||||||
quantized_input = QuantizedTensor(input.to(dtype).contiguous(), "TensorCoreFP8Layout", layout_params_weight)
|
scale_weight = torch.ones((), device=input.device, dtype=torch.float32)
|
||||||
|
|
||||||
# Wrap weight in QuantizedTensor - this enables unified dispatch
|
scale_input = torch.ones((), device=input.device, dtype=torch.float32)
|
||||||
# Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py!
|
input = torch.clamp(input, min=-448, max=448, out=input)
|
||||||
layout_params_weight = {'scale': scale_weight, 'orig_dtype': input_dtype}
|
input_fp8 = input.to(dtype).contiguous()
|
||||||
quantized_weight = QuantizedTensor(w, "TensorCoreFP8Layout", layout_params_weight)
|
layout_params_input = TensorCoreFP8Layout.Params(scale=scale_input, orig_dtype=input_dtype, orig_shape=tuple(input_fp8.shape))
|
||||||
o = torch.nn.functional.linear(quantized_input, quantized_weight, bias)
|
quantized_input = QuantizedTensor(input_fp8, "TensorCoreFP8Layout", layout_params_input)
|
||||||
|
|
||||||
uncast_bias_weight(self, w, bias, offload_stream)
|
# Wrap weight in QuantizedTensor - this enables unified dispatch
|
||||||
return o
|
# Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py!
|
||||||
|
layout_params_weight = TensorCoreFP8Layout.Params(scale=scale_weight, orig_dtype=input_dtype, orig_shape=tuple(w.shape))
|
||||||
|
quantized_weight = QuantizedTensor(w, "TensorCoreFP8Layout", layout_params_weight)
|
||||||
|
o = torch.nn.functional.linear(quantized_input, quantized_weight, bias)
|
||||||
|
|
||||||
return None
|
uncast_bias_weight(self, w, bias, offload_stream)
|
||||||
|
if tensor_3d:
|
||||||
|
o = o.reshape((input_shape[0], input_shape[1], w.shape[0]))
|
||||||
|
|
||||||
|
return o
|
||||||
|
|
||||||
class fp8_ops(manual_cast):
|
class fp8_ops(manual_cast):
|
||||||
class Linear(manual_cast.Linear):
|
class Linear(manual_cast.Linear):
|
||||||
@ -477,14 +485,20 @@ if CUBLAS_IS_AVAILABLE:
|
|||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
# Mixed Precision Operations
|
# Mixed Precision Operations
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
from .quant_ops import QuantizedTensor, QUANT_ALGOS
|
from .quant_ops import (
|
||||||
|
QuantizedTensor,
|
||||||
|
QUANT_ALGOS,
|
||||||
|
TensorCoreFP8Layout,
|
||||||
|
get_layout_class,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False):
|
def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False, disabled=[]):
|
||||||
class MixedPrecisionOps(manual_cast):
|
class MixedPrecisionOps(manual_cast):
|
||||||
_quant_config = quant_config
|
_quant_config = quant_config
|
||||||
_compute_dtype = compute_dtype
|
_compute_dtype = compute_dtype
|
||||||
_full_precision_mm = full_precision_mm
|
_full_precision_mm = full_precision_mm
|
||||||
|
_disabled = disabled
|
||||||
|
|
||||||
class Linear(torch.nn.Module, CastWeightBiasOp):
|
class Linear(torch.nn.Module, CastWeightBiasOp):
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -497,21 +511,33 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
|||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
if dtype is None:
|
self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype}
|
||||||
dtype = MixedPrecisionOps._compute_dtype
|
# self.factory_kwargs = {"device": device, "dtype": dtype}
|
||||||
|
|
||||||
self.factory_kwargs = {"device": device, "dtype": dtype}
|
|
||||||
|
|
||||||
self.in_features = in_features
|
self.in_features = in_features
|
||||||
self.out_features = out_features
|
self.out_features = out_features
|
||||||
self._has_bias = bias
|
if bias:
|
||||||
|
self.bias = torch.nn.Parameter(torch.empty(out_features, **self.factory_kwargs))
|
||||||
|
else:
|
||||||
|
self.register_parameter("bias", None)
|
||||||
|
|
||||||
self.tensor_class = None
|
self.tensor_class = None
|
||||||
self._full_precision_mm = MixedPrecisionOps._full_precision_mm
|
self._full_precision_mm = MixedPrecisionOps._full_precision_mm
|
||||||
|
self._full_precision_mm_config = False
|
||||||
|
|
||||||
def reset_parameters(self):
|
def reset_parameters(self):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def _load_scale_param(self, state_dict, prefix, param_name, device, manually_loaded_keys, dtype=None):
|
||||||
|
key = f"{prefix}{param_name}"
|
||||||
|
value = state_dict.pop(key, None)
|
||||||
|
if value is not None:
|
||||||
|
value = value.to(device=device)
|
||||||
|
if dtype is not None:
|
||||||
|
value = value.view(dtype=dtype)
|
||||||
|
manually_loaded_keys.append(key)
|
||||||
|
return value
|
||||||
|
|
||||||
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
|
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
|
||||||
strict, missing_keys, unexpected_keys, error_msgs):
|
strict, missing_keys, unexpected_keys, error_msgs):
|
||||||
|
|
||||||
@ -529,49 +555,61 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
|||||||
layer_conf = json.loads(layer_conf.numpy().tobytes())
|
layer_conf = json.loads(layer_conf.numpy().tobytes())
|
||||||
|
|
||||||
if layer_conf is None:
|
if layer_conf is None:
|
||||||
dtype = self.factory_kwargs["dtype"]
|
self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
|
||||||
self.weight = torch.nn.Parameter(weight.to(device=device, dtype=dtype), requires_grad=False)
|
|
||||||
if dtype != MixedPrecisionOps._compute_dtype:
|
|
||||||
self.comfy_cast_weights = True
|
|
||||||
if self._has_bias:
|
|
||||||
self.bias = torch.nn.Parameter(torch.empty(self.out_features, device=device, dtype=dtype))
|
|
||||||
else:
|
|
||||||
self.register_parameter("bias", None)
|
|
||||||
else:
|
else:
|
||||||
self.quant_format = layer_conf.get("format", None)
|
self.quant_format = layer_conf.get("format", None)
|
||||||
|
self._full_precision_mm_config = layer_conf.get("full_precision_matrix_mult", False)
|
||||||
if not self._full_precision_mm:
|
if not self._full_precision_mm:
|
||||||
self._full_precision_mm = layer_conf.get("full_precision_matrix_mult", False)
|
self._full_precision_mm = self._full_precision_mm_config
|
||||||
|
|
||||||
|
if self.quant_format in MixedPrecisionOps._disabled:
|
||||||
|
self._full_precision_mm = True
|
||||||
|
|
||||||
if self.quant_format is None:
|
if self.quant_format is None:
|
||||||
raise ValueError(f"Unknown quantization format for layer {layer_name}")
|
raise ValueError(f"Unknown quantization format for layer {layer_name}")
|
||||||
|
|
||||||
qconfig = QUANT_ALGOS[self.quant_format]
|
qconfig = QUANT_ALGOS[self.quant_format]
|
||||||
self.layout_type = qconfig["comfy_tensor_layout"]
|
self.layout_type = qconfig["comfy_tensor_layout"]
|
||||||
|
layout_cls = get_layout_class(self.layout_type)
|
||||||
|
|
||||||
weight_scale_key = f"{prefix}weight_scale"
|
# Load format-specific parameters
|
||||||
scale = state_dict.pop(weight_scale_key, None)
|
if self.quant_format in ["float8_e4m3fn", "float8_e5m2"]:
|
||||||
if scale is not None:
|
# FP8: single tensor scale
|
||||||
scale = scale.to(device)
|
scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys)
|
||||||
layout_params = {
|
|
||||||
'scale': scale,
|
|
||||||
'orig_dtype': MixedPrecisionOps._compute_dtype,
|
|
||||||
'block_size': qconfig.get("group_size", None),
|
|
||||||
}
|
|
||||||
|
|
||||||
if scale is not None:
|
params = layout_cls.Params(
|
||||||
manually_loaded_keys.append(weight_scale_key)
|
scale=scale,
|
||||||
|
orig_dtype=MixedPrecisionOps._compute_dtype,
|
||||||
|
orig_shape=(self.out_features, self.in_features),
|
||||||
|
)
|
||||||
|
|
||||||
|
elif self.quant_format == "nvfp4":
|
||||||
|
# NVFP4: tensor_scale (weight_scale_2) + block_scale (weight_scale)
|
||||||
|
tensor_scale = self._load_scale_param(state_dict, prefix, "weight_scale_2", device, manually_loaded_keys)
|
||||||
|
block_scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys,
|
||||||
|
dtype=torch.float8_e4m3fn)
|
||||||
|
|
||||||
|
if tensor_scale is None or block_scale is None:
|
||||||
|
raise ValueError(f"Missing NVFP4 scales for layer {layer_name}")
|
||||||
|
|
||||||
|
params = layout_cls.Params(
|
||||||
|
scale=tensor_scale,
|
||||||
|
block_scale=block_scale,
|
||||||
|
orig_dtype=MixedPrecisionOps._compute_dtype,
|
||||||
|
orig_shape=(self.out_features, self.in_features),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported quantization format: {self.quant_format}")
|
||||||
|
|
||||||
self.weight = torch.nn.Parameter(
|
self.weight = torch.nn.Parameter(
|
||||||
QuantizedTensor(weight.to(device=device, dtype=qconfig.get("storage_t", None)), self.layout_type, layout_params),
|
QuantizedTensor(weight.to(device=device, dtype=qconfig["storage_t"]), self.layout_type, params),
|
||||||
requires_grad=False
|
requires_grad=False
|
||||||
)
|
)
|
||||||
|
|
||||||
if self._has_bias:
|
|
||||||
self.bias = torch.nn.Parameter(torch.empty(self.out_features, device=device, dtype=MixedPrecisionOps._compute_dtype))
|
|
||||||
else:
|
|
||||||
self.register_parameter("bias", None)
|
|
||||||
|
|
||||||
for param_name in qconfig["parameters"]:
|
for param_name in qconfig["parameters"]:
|
||||||
|
if param_name in {"weight_scale", "weight_scale_2"}:
|
||||||
|
continue # Already handled above
|
||||||
|
|
||||||
param_key = f"{prefix}{param_name}"
|
param_key = f"{prefix}{param_name}"
|
||||||
_v = state_dict.pop(param_key, None)
|
_v = state_dict.pop(param_key, None)
|
||||||
if _v is None:
|
if _v is None:
|
||||||
@ -588,9 +626,17 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
|||||||
def state_dict(self, *args, destination=None, prefix="", **kwargs):
|
def state_dict(self, *args, destination=None, prefix="", **kwargs):
|
||||||
sd = super().state_dict(*args, destination=destination, prefix=prefix, **kwargs)
|
sd = super().state_dict(*args, destination=destination, prefix=prefix, **kwargs)
|
||||||
if isinstance(self.weight, QuantizedTensor):
|
if isinstance(self.weight, QuantizedTensor):
|
||||||
sd["{}weight_scale".format(prefix)] = self.weight._layout_params['scale']
|
layout_cls = self.weight._layout_cls
|
||||||
|
|
||||||
|
# Check if it's any FP8 variant (E4M3 or E5M2)
|
||||||
|
if layout_cls in ("TensorCoreFP8E4M3Layout", "TensorCoreFP8E5M2Layout", "TensorCoreFP8Layout"):
|
||||||
|
sd["{}weight_scale".format(prefix)] = self.weight._params.scale
|
||||||
|
elif layout_cls == "TensorCoreNVFP4Layout":
|
||||||
|
sd["{}weight_scale_2".format(prefix)] = self.weight._params.scale
|
||||||
|
sd["{}weight_scale".format(prefix)] = self.weight._params.block_scale
|
||||||
|
|
||||||
quant_conf = {"format": self.quant_format}
|
quant_conf = {"format": self.quant_format}
|
||||||
if self._full_precision_mm:
|
if self._full_precision_mm_config:
|
||||||
quant_conf["full_precision_matrix_mult"] = True
|
quant_conf["full_precision_matrix_mult"] = True
|
||||||
sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8)
|
sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8)
|
||||||
return sd
|
return sd
|
||||||
@ -607,12 +653,33 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
|||||||
def forward(self, input, *args, **kwargs):
|
def forward(self, input, *args, **kwargs):
|
||||||
run_every_op()
|
run_every_op()
|
||||||
|
|
||||||
|
input_shape = input.shape
|
||||||
|
tensor_3d = input.ndim == 3
|
||||||
|
|
||||||
if self._full_precision_mm or self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
if self._full_precision_mm or self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||||
return self.forward_comfy_cast_weights(input, *args, **kwargs)
|
return self.forward_comfy_cast_weights(input, *args, **kwargs)
|
||||||
|
|
||||||
if (getattr(self, 'layout_type', None) is not None and
|
if (getattr(self, 'layout_type', None) is not None and
|
||||||
not isinstance(input, QuantizedTensor)):
|
not isinstance(input, QuantizedTensor)):
|
||||||
input = QuantizedTensor.from_float(input, self.layout_type, scale=getattr(self, 'input_scale', None), dtype=self.weight.dtype)
|
|
||||||
return self._forward(input, self.weight, self.bias)
|
# Reshape 3D tensors to 2D for quantization (needed for NVFP4 and others)
|
||||||
|
if tensor_3d:
|
||||||
|
input = input.reshape(-1, input_shape[2])
|
||||||
|
|
||||||
|
if input.ndim != 2:
|
||||||
|
# Fall back to comfy_cast_weights for non-2D tensors
|
||||||
|
return self.forward_comfy_cast_weights(input.reshape(input_shape), *args, **kwargs)
|
||||||
|
|
||||||
|
# dtype is now implicit in the layout class
|
||||||
|
input = QuantizedTensor.from_float(input, self.layout_type, scale=getattr(self, 'input_scale', None))
|
||||||
|
|
||||||
|
output = self._forward(input, self.weight, self.bias)
|
||||||
|
|
||||||
|
# Reshape output back to 3D if input was 3D
|
||||||
|
if tensor_3d:
|
||||||
|
output = output.reshape((input_shape[0], input_shape[1], self.weight.shape[0]))
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
def convert_weight(self, weight, inplace=False, **kwargs):
|
def convert_weight(self, weight, inplace=False, **kwargs):
|
||||||
if isinstance(weight, QuantizedTensor):
|
if isinstance(weight, QuantizedTensor):
|
||||||
@ -622,7 +689,8 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
|||||||
|
|
||||||
def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
|
def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
|
||||||
if getattr(self, 'layout_type', None) is not None:
|
if getattr(self, 'layout_type', None) is not None:
|
||||||
weight = QuantizedTensor.from_float(weight, self.layout_type, scale="recalculate", dtype=self.weight.dtype, stochastic_rounding=seed, inplace_ops=True)
|
# dtype is now implicit in the layout class
|
||||||
|
weight = QuantizedTensor.from_float(weight, self.layout_type, scale="recalculate", stochastic_rounding=seed, inplace_ops=True)
|
||||||
else:
|
else:
|
||||||
weight = weight.to(self.weight.dtype)
|
weight = weight.to(self.weight.dtype)
|
||||||
if return_weight:
|
if return_weight:
|
||||||
@ -649,10 +717,17 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
|||||||
|
|
||||||
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None):
|
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None):
|
||||||
fp8_compute = comfy.model_management.supports_fp8_compute(load_device) # TODO: if we support more ops this needs to be more granular
|
fp8_compute = comfy.model_management.supports_fp8_compute(load_device) # TODO: if we support more ops this needs to be more granular
|
||||||
|
nvfp4_compute = comfy.model_management.supports_nvfp4_compute(load_device)
|
||||||
|
|
||||||
if model_config and hasattr(model_config, 'quant_config') and model_config.quant_config:
|
if model_config and hasattr(model_config, 'quant_config') and model_config.quant_config:
|
||||||
logging.info("Using mixed precision operations")
|
logging.info("Using mixed precision operations")
|
||||||
return mixed_precision_ops(model_config.quant_config, compute_dtype, full_precision_mm=not fp8_compute)
|
disabled = set()
|
||||||
|
if not nvfp4_compute:
|
||||||
|
disabled.add("nvfp4")
|
||||||
|
if not fp8_compute:
|
||||||
|
disabled.add("float8_e4m3fn")
|
||||||
|
disabled.add("float8_e5m2")
|
||||||
|
return mixed_precision_ops(model_config.quant_config, compute_dtype, disabled=disabled)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
fp8_compute and
|
fp8_compute and
|
||||||
|
|||||||
@ -1,580 +1,140 @@
|
|||||||
import torch
|
import torch
|
||||||
import logging
|
import logging
|
||||||
from typing import Tuple, Dict
|
|
||||||
|
try:
|
||||||
|
import comfy_kitchen as ck
|
||||||
|
from comfy_kitchen.tensor import (
|
||||||
|
QuantizedTensor,
|
||||||
|
QuantizedLayout,
|
||||||
|
TensorCoreFP8Layout as _CKFp8Layout,
|
||||||
|
TensorCoreNVFP4Layout, # Direct import, no wrapper needed
|
||||||
|
register_layout_op,
|
||||||
|
register_layout_class,
|
||||||
|
get_layout_class,
|
||||||
|
)
|
||||||
|
_CK_AVAILABLE = True
|
||||||
|
if torch.version.cuda is None:
|
||||||
|
ck.registry.disable("cuda")
|
||||||
|
else:
|
||||||
|
cuda_version = tuple(map(int, str(torch.version.cuda).split('.')))
|
||||||
|
if cuda_version < (13,):
|
||||||
|
ck.registry.disable("cuda")
|
||||||
|
|
||||||
|
ck.registry.disable("triton")
|
||||||
|
for k, v in ck.list_backends().items():
|
||||||
|
logging.info(f"Found comfy_kitchen backend {k}: {v}")
|
||||||
|
except ImportError as e:
|
||||||
|
logging.error(f"Failed to import comfy_kitchen, Error: {e}, fp8 and fp4 support will not be available.")
|
||||||
|
_CK_AVAILABLE = False
|
||||||
|
|
||||||
|
class QuantizedTensor:
|
||||||
|
pass
|
||||||
|
|
||||||
|
class _CKFp8Layout:
|
||||||
|
pass
|
||||||
|
|
||||||
|
class TensorCoreNVFP4Layout:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def register_layout_class(name, cls):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def get_layout_class(name):
|
||||||
|
return None
|
||||||
|
|
||||||
import comfy.float
|
import comfy.float
|
||||||
|
|
||||||
_LAYOUT_REGISTRY = {}
|
|
||||||
_GENERIC_UTILS = {}
|
|
||||||
|
|
||||||
|
|
||||||
def register_layout_op(torch_op, layout_type):
|
|
||||||
"""
|
|
||||||
Decorator to register a layout-specific operation handler.
|
|
||||||
Args:
|
|
||||||
torch_op: PyTorch operation (e.g., torch.ops.aten.linear.default)
|
|
||||||
layout_type: Layout class (e.g., TensorCoreFP8Layout)
|
|
||||||
Example:
|
|
||||||
@register_layout_op(torch.ops.aten.linear.default, TensorCoreFP8Layout)
|
|
||||||
def fp8_linear(func, args, kwargs):
|
|
||||||
# FP8-specific linear implementation
|
|
||||||
...
|
|
||||||
"""
|
|
||||||
def decorator(handler_func):
|
|
||||||
if torch_op not in _LAYOUT_REGISTRY:
|
|
||||||
_LAYOUT_REGISTRY[torch_op] = {}
|
|
||||||
_LAYOUT_REGISTRY[torch_op][layout_type] = handler_func
|
|
||||||
return handler_func
|
|
||||||
return decorator
|
|
||||||
|
|
||||||
|
|
||||||
def register_generic_util(torch_op):
|
|
||||||
"""
|
|
||||||
Decorator to register a generic utility that works for all layouts.
|
|
||||||
Args:
|
|
||||||
torch_op: PyTorch operation (e.g., torch.ops.aten.detach.default)
|
|
||||||
|
|
||||||
Example:
|
|
||||||
@register_generic_util(torch.ops.aten.detach.default)
|
|
||||||
def generic_detach(func, args, kwargs):
|
|
||||||
# Works for any layout
|
|
||||||
...
|
|
||||||
"""
|
|
||||||
def decorator(handler_func):
|
|
||||||
_GENERIC_UTILS[torch_op] = handler_func
|
|
||||||
return handler_func
|
|
||||||
return decorator
|
|
||||||
|
|
||||||
|
|
||||||
def _get_layout_from_args(args):
|
|
||||||
for arg in args:
|
|
||||||
if isinstance(arg, QuantizedTensor):
|
|
||||||
return arg._layout_type
|
|
||||||
elif isinstance(arg, (list, tuple)):
|
|
||||||
for item in arg:
|
|
||||||
if isinstance(item, QuantizedTensor):
|
|
||||||
return item._layout_type
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def _move_layout_params_to_device(params, device):
|
|
||||||
new_params = {}
|
|
||||||
for k, v in params.items():
|
|
||||||
if isinstance(v, torch.Tensor):
|
|
||||||
new_params[k] = v.to(device=device)
|
|
||||||
else:
|
|
||||||
new_params[k] = v
|
|
||||||
return new_params
|
|
||||||
|
|
||||||
|
|
||||||
def _copy_layout_params(params):
|
|
||||||
new_params = {}
|
|
||||||
for k, v in params.items():
|
|
||||||
if isinstance(v, torch.Tensor):
|
|
||||||
new_params[k] = v.clone()
|
|
||||||
else:
|
|
||||||
new_params[k] = v
|
|
||||||
return new_params
|
|
||||||
|
|
||||||
def _copy_layout_params_inplace(src, dst, non_blocking=False):
|
|
||||||
for k, v in src.items():
|
|
||||||
if isinstance(v, torch.Tensor):
|
|
||||||
dst[k].copy_(v, non_blocking=non_blocking)
|
|
||||||
else:
|
|
||||||
dst[k] = v
|
|
||||||
|
|
||||||
class QuantizedLayout:
|
|
||||||
"""
|
|
||||||
Base class for quantization layouts.
|
|
||||||
|
|
||||||
A layout encapsulates the format-specific logic for quantization/dequantization
|
|
||||||
and provides a uniform interface for extracting raw tensors needed for computation.
|
|
||||||
|
|
||||||
New quantization formats should subclass this and implement the required methods.
|
|
||||||
"""
|
|
||||||
@classmethod
|
|
||||||
def quantize(cls, tensor, **kwargs) -> Tuple[torch.Tensor, Dict]:
|
|
||||||
raise NotImplementedError(f"{cls.__name__} must implement quantize()")
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def dequantize(qdata, **layout_params) -> torch.Tensor:
|
|
||||||
raise NotImplementedError("TensorLayout must implement dequantize()")
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_plain_tensors(cls, qtensor) -> torch.Tensor:
|
|
||||||
raise NotImplementedError(f"{cls.__name__} must implement get_plain_tensors()")
|
|
||||||
|
|
||||||
|
|
||||||
class QuantizedTensor(torch.Tensor):
|
|
||||||
"""
|
|
||||||
Universal quantized tensor that works with any layout.
|
|
||||||
|
|
||||||
This tensor subclass uses a pluggable layout system to support multiple
|
|
||||||
quantization formats (FP8, INT4, INT8, etc.) without code duplication.
|
|
||||||
|
|
||||||
The layout_type determines format-specific behavior, while common operations
|
|
||||||
(detach, clone, to) are handled generically.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
_qdata: The quantized tensor data
|
|
||||||
_layout_type: Layout class (e.g., TensorCoreFP8Layout)
|
|
||||||
_layout_params: Dict with layout-specific params (scale, zero_point, etc.)
|
|
||||||
"""
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def __new__(cls, qdata, layout_type, layout_params):
|
|
||||||
"""
|
|
||||||
Create a quantized tensor.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
qdata: The quantized data tensor
|
|
||||||
layout_type: Layout class (subclass of QuantizedLayout)
|
|
||||||
layout_params: Dict with layout-specific parameters
|
|
||||||
"""
|
|
||||||
return torch.Tensor._make_wrapper_subclass(cls, qdata.shape, device=qdata.device, dtype=qdata.dtype, requires_grad=False)
|
|
||||||
|
|
||||||
def __init__(self, qdata, layout_type, layout_params):
|
|
||||||
self._qdata = qdata
|
|
||||||
self._layout_type = layout_type
|
|
||||||
self._layout_params = layout_params
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
layout_name = self._layout_type
|
|
||||||
param_str = ", ".join(f"{k}={v}" for k, v in list(self._layout_params.items())[:2])
|
|
||||||
return f"QuantizedTensor(shape={self.shape}, layout={layout_name}, {param_str})"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def layout_type(self):
|
|
||||||
return self._layout_type
|
|
||||||
|
|
||||||
def __tensor_flatten__(self):
|
|
||||||
"""
|
|
||||||
Tensor flattening protocol for proper device movement.
|
|
||||||
"""
|
|
||||||
inner_tensors = ["_qdata"]
|
|
||||||
ctx = {
|
|
||||||
"layout_type": self._layout_type,
|
|
||||||
}
|
|
||||||
|
|
||||||
tensor_params = {}
|
|
||||||
non_tensor_params = {}
|
|
||||||
for k, v in self._layout_params.items():
|
|
||||||
if isinstance(v, torch.Tensor):
|
|
||||||
tensor_params[k] = v
|
|
||||||
else:
|
|
||||||
non_tensor_params[k] = v
|
|
||||||
|
|
||||||
ctx["tensor_param_keys"] = list(tensor_params.keys())
|
|
||||||
ctx["non_tensor_params"] = non_tensor_params
|
|
||||||
|
|
||||||
for k, v in tensor_params.items():
|
|
||||||
attr_name = f"_layout_param_{k}"
|
|
||||||
object.__setattr__(self, attr_name, v)
|
|
||||||
inner_tensors.append(attr_name)
|
|
||||||
|
|
||||||
return inner_tensors, ctx
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def __tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride):
|
|
||||||
"""
|
|
||||||
Tensor unflattening protocol for proper device movement.
|
|
||||||
Reconstructs the QuantizedTensor after device movement.
|
|
||||||
"""
|
|
||||||
layout_type = ctx["layout_type"]
|
|
||||||
layout_params = dict(ctx["non_tensor_params"])
|
|
||||||
|
|
||||||
for key in ctx["tensor_param_keys"]:
|
|
||||||
attr_name = f"_layout_param_{key}"
|
|
||||||
layout_params[key] = inner_tensors[attr_name]
|
|
||||||
|
|
||||||
return QuantizedTensor(inner_tensors["_qdata"], layout_type, layout_params)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_float(cls, tensor, layout_type, **quantize_kwargs) -> 'QuantizedTensor':
|
|
||||||
qdata, layout_params = LAYOUTS[layout_type].quantize(tensor, **quantize_kwargs)
|
|
||||||
return cls(qdata, layout_type, layout_params)
|
|
||||||
|
|
||||||
def dequantize(self) -> torch.Tensor:
|
|
||||||
return LAYOUTS[self._layout_type].dequantize(self._qdata, **self._layout_params)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
|
|
||||||
kwargs = kwargs or {}
|
|
||||||
|
|
||||||
# Step 1: Check generic utilities first (detach, clone, to, etc.)
|
|
||||||
if func in _GENERIC_UTILS:
|
|
||||||
return _GENERIC_UTILS[func](func, args, kwargs)
|
|
||||||
|
|
||||||
# Step 2: Check layout-specific handlers (linear, matmul, etc.)
|
|
||||||
layout_type = _get_layout_from_args(args)
|
|
||||||
if layout_type and func in _LAYOUT_REGISTRY:
|
|
||||||
handler = _LAYOUT_REGISTRY[func].get(layout_type)
|
|
||||||
if handler:
|
|
||||||
return handler(func, args, kwargs)
|
|
||||||
|
|
||||||
# Step 3: Fallback to dequantization
|
|
||||||
if isinstance(args[0] if args else None, QuantizedTensor):
|
|
||||||
logging.info(f"QuantizedTensor: Unhandled operation {func}, falling back to dequantization. kwargs={kwargs}")
|
|
||||||
return cls._dequant_and_fallback(func, args, kwargs)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _dequant_and_fallback(cls, func, args, kwargs):
|
|
||||||
def dequant_arg(arg):
|
|
||||||
if isinstance(arg, QuantizedTensor):
|
|
||||||
return arg.dequantize()
|
|
||||||
elif isinstance(arg, (list, tuple)):
|
|
||||||
return type(arg)(dequant_arg(a) for a in arg)
|
|
||||||
return arg
|
|
||||||
|
|
||||||
new_args = dequant_arg(args)
|
|
||||||
new_kwargs = dequant_arg(kwargs)
|
|
||||||
return func(*new_args, **new_kwargs)
|
|
||||||
|
|
||||||
def data_ptr(self):
|
|
||||||
return self._qdata.data_ptr()
|
|
||||||
|
|
||||||
def is_pinned(self):
|
|
||||||
return self._qdata.is_pinned()
|
|
||||||
|
|
||||||
def is_contiguous(self, *arg, **kwargs):
|
|
||||||
return self._qdata.is_contiguous(*arg, **kwargs)
|
|
||||||
|
|
||||||
def storage(self):
|
|
||||||
return self._qdata.storage()
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
# Generic Utilities (Layout-Agnostic Operations)
|
# FP8 Layouts with Comfy-Specific Extensions
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
def _create_transformed_qtensor(qt, transform_fn):
|
class _TensorCoreFP8LayoutBase(_CKFp8Layout):
|
||||||
new_data = transform_fn(qt._qdata)
|
FP8_DTYPE = None # Must be overridden in subclass
|
||||||
new_params = _copy_layout_params(qt._layout_params)
|
|
||||||
return QuantizedTensor(new_data, qt._layout_type, new_params)
|
|
||||||
|
|
||||||
|
|
||||||
def _handle_device_transfer(qt, target_device, target_dtype=None, target_layout=None, op_name="to"):
|
|
||||||
if target_layout is not None and target_layout != torch.strided:
|
|
||||||
logging.warning(
|
|
||||||
f"QuantizedTensor: layout change requested to {target_layout}, "
|
|
||||||
f"but not supported. Ignoring layout."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Handle device transfer
|
|
||||||
current_device = qt._qdata.device
|
|
||||||
if target_device is not None:
|
|
||||||
# Normalize device for comparison
|
|
||||||
if isinstance(target_device, str):
|
|
||||||
target_device = torch.device(target_device)
|
|
||||||
if isinstance(current_device, str):
|
|
||||||
current_device = torch.device(current_device)
|
|
||||||
|
|
||||||
if target_device != current_device:
|
|
||||||
logging.debug(f"QuantizedTensor.{op_name}: Moving from {current_device} to {target_device}")
|
|
||||||
new_q_data = qt._qdata.to(device=target_device)
|
|
||||||
new_params = _move_layout_params_to_device(qt._layout_params, target_device)
|
|
||||||
if target_dtype is not None:
|
|
||||||
new_params["orig_dtype"] = target_dtype
|
|
||||||
new_qt = QuantizedTensor(new_q_data, qt._layout_type, new_params)
|
|
||||||
logging.debug(f"QuantizedTensor.{op_name}: Created new tensor on {target_device}")
|
|
||||||
return new_qt
|
|
||||||
|
|
||||||
logging.debug(f"QuantizedTensor.{op_name}: No device change needed, returning original")
|
|
||||||
return qt
|
|
||||||
|
|
||||||
|
|
||||||
@register_generic_util(torch.ops.aten.detach.default)
|
|
||||||
def generic_detach(func, args, kwargs):
|
|
||||||
"""Detach operation - creates a detached copy of the quantized tensor."""
|
|
||||||
qt = args[0]
|
|
||||||
if isinstance(qt, QuantizedTensor):
|
|
||||||
return _create_transformed_qtensor(qt, lambda x: x.detach())
|
|
||||||
return func(*args, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
@register_generic_util(torch.ops.aten.clone.default)
|
|
||||||
def generic_clone(func, args, kwargs):
|
|
||||||
"""Clone operation - creates a deep copy of the quantized tensor."""
|
|
||||||
qt = args[0]
|
|
||||||
if isinstance(qt, QuantizedTensor):
|
|
||||||
return _create_transformed_qtensor(qt, lambda x: x.clone())
|
|
||||||
return func(*args, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
@register_generic_util(torch.ops.aten._to_copy.default)
|
|
||||||
def generic_to_copy(func, args, kwargs):
|
|
||||||
"""Device/dtype transfer operation - handles .to(device) calls."""
|
|
||||||
qt = args[0]
|
|
||||||
if isinstance(qt, QuantizedTensor):
|
|
||||||
return _handle_device_transfer(
|
|
||||||
qt,
|
|
||||||
target_device=kwargs.get('device', None),
|
|
||||||
target_dtype=kwargs.get('dtype', None),
|
|
||||||
op_name="_to_copy"
|
|
||||||
)
|
|
||||||
return func(*args, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
@register_generic_util(torch.ops.aten.to.dtype_layout)
|
|
||||||
def generic_to_dtype_layout(func, args, kwargs):
|
|
||||||
"""Handle .to(device) calls using the dtype_layout variant."""
|
|
||||||
qt = args[0]
|
|
||||||
if isinstance(qt, QuantizedTensor):
|
|
||||||
return _handle_device_transfer(
|
|
||||||
qt,
|
|
||||||
target_device=kwargs.get('device', None),
|
|
||||||
target_dtype=kwargs.get('dtype', None),
|
|
||||||
target_layout=kwargs.get('layout', None),
|
|
||||||
op_name="to"
|
|
||||||
)
|
|
||||||
return func(*args, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
@register_generic_util(torch.ops.aten.copy_.default)
|
|
||||||
def generic_copy_(func, args, kwargs):
|
|
||||||
qt_dest = args[0]
|
|
||||||
src = args[1]
|
|
||||||
non_blocking = args[2] if len(args) > 2 else False
|
|
||||||
if isinstance(qt_dest, QuantizedTensor):
|
|
||||||
if isinstance(src, QuantizedTensor):
|
|
||||||
# Copy from another quantized tensor
|
|
||||||
qt_dest._qdata.copy_(src._qdata, non_blocking=non_blocking)
|
|
||||||
qt_dest._layout_type = src._layout_type
|
|
||||||
orig_dtype = qt_dest._layout_params["orig_dtype"]
|
|
||||||
_copy_layout_params_inplace(src._layout_params, qt_dest._layout_params, non_blocking=non_blocking)
|
|
||||||
qt_dest._layout_params["orig_dtype"] = orig_dtype
|
|
||||||
else:
|
|
||||||
# Copy from regular tensor - just copy raw data
|
|
||||||
qt_dest._qdata.copy_(src)
|
|
||||||
return qt_dest
|
|
||||||
return func(*args, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
@register_generic_util(torch.ops.aten.to.dtype)
|
|
||||||
def generic_to_dtype(func, args, kwargs):
|
|
||||||
"""Handle .to(dtype) calls - dtype conversion only."""
|
|
||||||
src = args[0]
|
|
||||||
if isinstance(src, QuantizedTensor):
|
|
||||||
# For dtype-only conversion, just change the orig_dtype, no real cast is needed
|
|
||||||
target_dtype = args[1] if len(args) > 1 else kwargs.get('dtype')
|
|
||||||
src._layout_params["orig_dtype"] = target_dtype
|
|
||||||
return src
|
|
||||||
return func(*args, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
@register_generic_util(torch.ops.aten._has_compatible_shallow_copy_type.default)
|
|
||||||
def generic_has_compatible_shallow_copy_type(func, args, kwargs):
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
@register_generic_util(torch.ops.aten.empty_like.default)
|
|
||||||
def generic_empty_like(func, args, kwargs):
|
|
||||||
"""Empty_like operation - creates an empty tensor with the same quantized structure."""
|
|
||||||
qt = args[0]
|
|
||||||
if isinstance(qt, QuantizedTensor):
|
|
||||||
# Create empty tensor with same shape and dtype as the quantized data
|
|
||||||
hp_dtype = kwargs.pop('dtype', qt._layout_params["orig_dtype"])
|
|
||||||
new_qdata = torch.empty_like(qt._qdata, **kwargs)
|
|
||||||
|
|
||||||
# Handle device transfer for layout params
|
|
||||||
target_device = kwargs.get('device', new_qdata.device)
|
|
||||||
new_params = _move_layout_params_to_device(qt._layout_params, target_device)
|
|
||||||
|
|
||||||
# Update orig_dtype if dtype is specified
|
|
||||||
new_params['orig_dtype'] = hp_dtype
|
|
||||||
|
|
||||||
return QuantizedTensor(new_qdata, qt._layout_type, new_params)
|
|
||||||
return func(*args, **kwargs)
|
|
||||||
|
|
||||||
# ==============================================================================
|
|
||||||
# FP8 Layout + Operation Handlers
|
|
||||||
# ==============================================================================
|
|
||||||
class TensorCoreFP8Layout(QuantizedLayout):
|
|
||||||
"""
|
|
||||||
Storage format:
|
|
||||||
- qdata: FP8 tensor (torch.float8_e4m3fn or torch.float8_e5m2)
|
|
||||||
- scale: Scalar tensor (float32) for dequantization
|
|
||||||
- orig_dtype: Original dtype before quantization (for casting back)
|
|
||||||
"""
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def quantize(cls, tensor, scale=None, dtype=torch.float8_e4m3fn, stochastic_rounding=0, inplace_ops=False):
|
def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False):
|
||||||
|
if cls.FP8_DTYPE is None:
|
||||||
|
raise NotImplementedError(f"{cls.__name__} must define FP8_DTYPE")
|
||||||
|
|
||||||
orig_dtype = tensor.dtype
|
orig_dtype = tensor.dtype
|
||||||
|
orig_shape = tuple(tensor.shape)
|
||||||
|
|
||||||
if isinstance(scale, str) and scale == "recalculate":
|
if isinstance(scale, str) and scale == "recalculate":
|
||||||
scale = torch.amax(tensor.abs()).to(dtype=torch.float32) / torch.finfo(dtype).max
|
scale = torch.amax(tensor.abs()).to(dtype=torch.float32) / torch.finfo(cls.FP8_DTYPE).max
|
||||||
if tensor.dtype not in [torch.float32, torch.bfloat16]: # Prevent scale from being too small
|
if tensor.dtype not in [torch.float32, torch.bfloat16]: # Prevent scale from being too small
|
||||||
tensor_info = torch.finfo(tensor.dtype)
|
tensor_info = torch.finfo(tensor.dtype)
|
||||||
scale = (1.0 / torch.clamp((1.0 / scale), min=tensor_info.min, max=tensor_info.max))
|
scale = (1.0 / torch.clamp((1.0 / scale), min=tensor_info.min, max=tensor_info.max))
|
||||||
|
|
||||||
if scale is not None:
|
if scale is None:
|
||||||
if not isinstance(scale, torch.Tensor):
|
scale = torch.ones((), device=tensor.device, dtype=torch.float32)
|
||||||
scale = torch.tensor(scale)
|
if not isinstance(scale, torch.Tensor):
|
||||||
scale = scale.to(device=tensor.device, dtype=torch.float32)
|
scale = torch.tensor(scale, device=tensor.device, dtype=torch.float32)
|
||||||
|
|
||||||
|
if stochastic_rounding > 0:
|
||||||
if inplace_ops:
|
if inplace_ops:
|
||||||
tensor *= (1.0 / scale).to(tensor.dtype)
|
tensor *= (1.0 / scale).to(tensor.dtype)
|
||||||
else:
|
else:
|
||||||
tensor = tensor * (1.0 / scale).to(tensor.dtype)
|
tensor = tensor * (1.0 / scale).to(tensor.dtype)
|
||||||
|
qdata = comfy.float.stochastic_rounding(tensor, dtype=cls.FP8_DTYPE, seed=stochastic_rounding)
|
||||||
else:
|
else:
|
||||||
scale = torch.ones((), device=tensor.device, dtype=torch.float32)
|
qdata = ck.quantize_per_tensor_fp8(tensor, scale, cls.FP8_DTYPE)
|
||||||
|
|
||||||
if stochastic_rounding > 0:
|
params = cls.Params(scale=scale.float(), orig_dtype=orig_dtype, orig_shape=orig_shape)
|
||||||
tensor = comfy.float.stochastic_rounding(tensor, dtype=dtype, seed=stochastic_rounding)
|
return qdata, params
|
||||||
else:
|
|
||||||
lp_amax = torch.finfo(dtype).max
|
|
||||||
torch.clamp(tensor, min=-lp_amax, max=lp_amax, out=tensor)
|
|
||||||
tensor = tensor.to(dtype, memory_format=torch.contiguous_format)
|
|
||||||
|
|
||||||
layout_params = {
|
|
||||||
'scale': scale,
|
|
||||||
'orig_dtype': orig_dtype
|
|
||||||
}
|
|
||||||
return tensor, layout_params
|
|
||||||
|
|
||||||
@staticmethod
|
class TensorCoreFP8E4M3Layout(_TensorCoreFP8LayoutBase):
|
||||||
def dequantize(qdata, scale, orig_dtype, **kwargs):
|
FP8_DTYPE = torch.float8_e4m3fn
|
||||||
plain_tensor = torch.ops.aten._to_copy.default(qdata, dtype=orig_dtype)
|
|
||||||
plain_tensor.mul_(scale)
|
|
||||||
return plain_tensor
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_plain_tensors(cls, qtensor):
|
class TensorCoreFP8E5M2Layout(_TensorCoreFP8LayoutBase):
|
||||||
return qtensor._qdata, qtensor._layout_params['scale']
|
FP8_DTYPE = torch.float8_e5m2
|
||||||
|
|
||||||
|
|
||||||
|
# Backward compatibility alias - default to E4M3
|
||||||
|
TensorCoreFP8Layout = TensorCoreFP8E4M3Layout
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
# Registry
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
register_layout_class("TensorCoreFP8Layout", TensorCoreFP8Layout)
|
||||||
|
register_layout_class("TensorCoreFP8E4M3Layout", TensorCoreFP8E4M3Layout)
|
||||||
|
register_layout_class("TensorCoreFP8E5M2Layout", TensorCoreFP8E5M2Layout)
|
||||||
|
register_layout_class("TensorCoreNVFP4Layout", TensorCoreNVFP4Layout)
|
||||||
|
|
||||||
QUANT_ALGOS = {
|
QUANT_ALGOS = {
|
||||||
"float8_e4m3fn": {
|
"float8_e4m3fn": {
|
||||||
"storage_t": torch.float8_e4m3fn,
|
"storage_t": torch.float8_e4m3fn,
|
||||||
"parameters": {"weight_scale", "input_scale"},
|
"parameters": {"weight_scale", "input_scale"},
|
||||||
"comfy_tensor_layout": "TensorCoreFP8Layout",
|
"comfy_tensor_layout": "TensorCoreFP8E4M3Layout",
|
||||||
|
},
|
||||||
|
"float8_e5m2": {
|
||||||
|
"storage_t": torch.float8_e5m2,
|
||||||
|
"parameters": {"weight_scale", "input_scale"},
|
||||||
|
"comfy_tensor_layout": "TensorCoreFP8E5M2Layout",
|
||||||
|
},
|
||||||
|
"nvfp4": {
|
||||||
|
"storage_t": torch.uint8,
|
||||||
|
"parameters": {"weight_scale", "weight_scale_2", "input_scale"},
|
||||||
|
"comfy_tensor_layout": "TensorCoreNVFP4Layout",
|
||||||
|
"group_size": 16,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
LAYOUTS = {
|
|
||||||
"TensorCoreFP8Layout": TensorCoreFP8Layout,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
# Re-exports for backward compatibility
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
@register_layout_op(torch.ops.aten.linear.default, "TensorCoreFP8Layout")
|
__all__ = [
|
||||||
def fp8_linear(func, args, kwargs):
|
"QuantizedTensor",
|
||||||
input_tensor = args[0]
|
"QuantizedLayout",
|
||||||
weight = args[1]
|
"TensorCoreFP8Layout",
|
||||||
bias = args[2] if len(args) > 2 else None
|
"TensorCoreFP8E4M3Layout",
|
||||||
|
"TensorCoreFP8E5M2Layout",
|
||||||
if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor):
|
"TensorCoreNVFP4Layout",
|
||||||
plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor)
|
"QUANT_ALGOS",
|
||||||
plain_weight, scale_b = TensorCoreFP8Layout.get_plain_tensors(weight)
|
"register_layout_op",
|
||||||
|
]
|
||||||
out_dtype = kwargs.get("out_dtype")
|
|
||||||
if out_dtype is None:
|
|
||||||
out_dtype = input_tensor._layout_params['orig_dtype']
|
|
||||||
|
|
||||||
weight_t = plain_weight.t()
|
|
||||||
|
|
||||||
tensor_2d = False
|
|
||||||
if len(plain_input.shape) == 2:
|
|
||||||
tensor_2d = True
|
|
||||||
plain_input = plain_input.unsqueeze(1)
|
|
||||||
|
|
||||||
input_shape = plain_input.shape
|
|
||||||
if len(input_shape) != 3:
|
|
||||||
return None
|
|
||||||
|
|
||||||
try:
|
|
||||||
output = torch._scaled_mm(
|
|
||||||
plain_input.reshape(-1, input_shape[2]).contiguous(),
|
|
||||||
weight_t,
|
|
||||||
bias=bias,
|
|
||||||
scale_a=scale_a,
|
|
||||||
scale_b=scale_b,
|
|
||||||
out_dtype=out_dtype,
|
|
||||||
)
|
|
||||||
|
|
||||||
if isinstance(output, tuple): # TODO: remove when we drop support for torch 2.4
|
|
||||||
output = output[0]
|
|
||||||
|
|
||||||
if not tensor_2d:
|
|
||||||
output = output.reshape((-1, input_shape[1], weight.shape[0]))
|
|
||||||
|
|
||||||
if output.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
|
|
||||||
output_scale = scale_a * scale_b
|
|
||||||
output_params = {
|
|
||||||
'scale': output_scale,
|
|
||||||
'orig_dtype': input_tensor._layout_params['orig_dtype']
|
|
||||||
}
|
|
||||||
return QuantizedTensor(output, "TensorCoreFP8Layout", output_params)
|
|
||||||
else:
|
|
||||||
return output
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
raise RuntimeError(f"FP8 _scaled_mm failed, falling back to dequantization: {e}")
|
|
||||||
|
|
||||||
# Case 2: DQ Fallback
|
|
||||||
if isinstance(weight, QuantizedTensor):
|
|
||||||
weight = weight.dequantize()
|
|
||||||
if isinstance(input_tensor, QuantizedTensor):
|
|
||||||
input_tensor = input_tensor.dequantize()
|
|
||||||
|
|
||||||
return torch.nn.functional.linear(input_tensor, weight, bias)
|
|
||||||
|
|
||||||
def fp8_mm_(input_tensor, weight, bias=None, out_dtype=None):
|
|
||||||
if out_dtype is None:
|
|
||||||
out_dtype = input_tensor._layout_params['orig_dtype']
|
|
||||||
|
|
||||||
plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor)
|
|
||||||
plain_weight, scale_b = TensorCoreFP8Layout.get_plain_tensors(weight)
|
|
||||||
|
|
||||||
output = torch._scaled_mm(
|
|
||||||
plain_input.contiguous(),
|
|
||||||
plain_weight,
|
|
||||||
bias=bias,
|
|
||||||
scale_a=scale_a,
|
|
||||||
scale_b=scale_b,
|
|
||||||
out_dtype=out_dtype,
|
|
||||||
)
|
|
||||||
|
|
||||||
if isinstance(output, tuple): # TODO: remove when we drop support for torch 2.4
|
|
||||||
output = output[0]
|
|
||||||
return output
|
|
||||||
|
|
||||||
@register_layout_op(torch.ops.aten.addmm.default, "TensorCoreFP8Layout")
|
|
||||||
def fp8_addmm(func, args, kwargs):
|
|
||||||
input_tensor = args[1]
|
|
||||||
weight = args[2]
|
|
||||||
bias = args[0]
|
|
||||||
|
|
||||||
if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor):
|
|
||||||
return fp8_mm_(input_tensor, weight, bias=bias, out_dtype=kwargs.get("out_dtype", None))
|
|
||||||
|
|
||||||
a = list(args)
|
|
||||||
if isinstance(args[0], QuantizedTensor):
|
|
||||||
a[0] = args[0].dequantize()
|
|
||||||
if isinstance(args[1], QuantizedTensor):
|
|
||||||
a[1] = args[1].dequantize()
|
|
||||||
if isinstance(args[2], QuantizedTensor):
|
|
||||||
a[2] = args[2].dequantize()
|
|
||||||
|
|
||||||
return func(*a, **kwargs)
|
|
||||||
|
|
||||||
@register_layout_op(torch.ops.aten.mm.default, "TensorCoreFP8Layout")
|
|
||||||
def fp8_mm(func, args, kwargs):
|
|
||||||
input_tensor = args[0]
|
|
||||||
weight = args[1]
|
|
||||||
|
|
||||||
if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor):
|
|
||||||
return fp8_mm_(input_tensor, weight, bias=None, out_dtype=kwargs.get("out_dtype", None))
|
|
||||||
|
|
||||||
a = list(args)
|
|
||||||
if isinstance(args[0], QuantizedTensor):
|
|
||||||
a[0] = args[0].dequantize()
|
|
||||||
if isinstance(args[1], QuantizedTensor):
|
|
||||||
a[1] = args[1].dequantize()
|
|
||||||
return func(*a, **kwargs)
|
|
||||||
|
|
||||||
@register_layout_op(torch.ops.aten.view.default, "TensorCoreFP8Layout")
|
|
||||||
@register_layout_op(torch.ops.aten.t.default, "TensorCoreFP8Layout")
|
|
||||||
def fp8_func(func, args, kwargs):
|
|
||||||
input_tensor = args[0]
|
|
||||||
if isinstance(input_tensor, QuantizedTensor):
|
|
||||||
plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor)
|
|
||||||
ar = list(args)
|
|
||||||
ar[0] = plain_input
|
|
||||||
return QuantizedTensor(func(*ar, **kwargs), "TensorCoreFP8Layout", input_tensor._layout_params)
|
|
||||||
return func(*args, **kwargs)
|
|
||||||
|
|||||||
@ -1041,7 +1041,8 @@ class TEModel(Enum):
|
|||||||
MISTRAL3_24B_PRUNED_FLUX2 = 15
|
MISTRAL3_24B_PRUNED_FLUX2 = 15
|
||||||
QWEN3_4B = 16
|
QWEN3_4B = 16
|
||||||
QWEN3_2B = 17
|
QWEN3_2B = 17
|
||||||
JINA_CLIP_2 = 18
|
GEMMA_3_12B = 18
|
||||||
|
JINA_CLIP_2 = 19
|
||||||
|
|
||||||
|
|
||||||
def detect_te_model(sd):
|
def detect_te_model(sd):
|
||||||
@ -1067,6 +1068,8 @@ def detect_te_model(sd):
|
|||||||
return TEModel.BYT5_SMALL_GLYPH
|
return TEModel.BYT5_SMALL_GLYPH
|
||||||
return TEModel.T5_BASE
|
return TEModel.T5_BASE
|
||||||
if 'model.layers.0.post_feedforward_layernorm.weight' in sd:
|
if 'model.layers.0.post_feedforward_layernorm.weight' in sd:
|
||||||
|
if 'model.layers.47.self_attn.q_norm.weight' in sd:
|
||||||
|
return TEModel.GEMMA_3_12B
|
||||||
if 'model.layers.0.self_attn.q_norm.weight' in sd:
|
if 'model.layers.0.self_attn.q_norm.weight' in sd:
|
||||||
return TEModel.GEMMA_3_4B
|
return TEModel.GEMMA_3_4B
|
||||||
return TEModel.GEMMA_2_2B
|
return TEModel.GEMMA_2_2B
|
||||||
@ -1271,6 +1274,10 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
|||||||
elif clip_type == CLIPType.KANDINSKY5_IMAGE:
|
elif clip_type == CLIPType.KANDINSKY5_IMAGE:
|
||||||
clip_target.clip = comfy.text_encoders.kandinsky5.te(**llama_detect(clip_data))
|
clip_target.clip = comfy.text_encoders.kandinsky5.te(**llama_detect(clip_data))
|
||||||
clip_target.tokenizer = comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage
|
clip_target.tokenizer = comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage
|
||||||
|
elif clip_type == CLIPType.LTXV:
|
||||||
|
clip_target.clip = comfy.text_encoders.lt.ltxav_te(**llama_detect(clip_data))
|
||||||
|
clip_target.tokenizer = comfy.text_encoders.lt.LTXAVGemmaTokenizer
|
||||||
|
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
|
||||||
elif clip_type == CLIPType.NEWBIE:
|
elif clip_type == CLIPType.NEWBIE:
|
||||||
clip_target.clip = comfy.text_encoders.newbie.te(**llama_detect(clip_data))
|
clip_target.clip = comfy.text_encoders.newbie.te(**llama_detect(clip_data))
|
||||||
clip_target.tokenizer = comfy.text_encoders.newbie.NewBieTokenizer
|
clip_target.tokenizer = comfy.text_encoders.newbie.NewBieTokenizer
|
||||||
|
|||||||
@ -836,6 +836,21 @@ class LTXV(supported_models_base.BASE):
|
|||||||
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
|
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
|
||||||
return supported_models_base.ClipTarget(comfy.text_encoders.lt.LTXVT5Tokenizer, comfy.text_encoders.lt.ltxv_te(**t5_detect))
|
return supported_models_base.ClipTarget(comfy.text_encoders.lt.LTXVT5Tokenizer, comfy.text_encoders.lt.ltxv_te(**t5_detect))
|
||||||
|
|
||||||
|
class LTXAV(LTXV):
|
||||||
|
unet_config = {
|
||||||
|
"image_model": "ltxav",
|
||||||
|
}
|
||||||
|
|
||||||
|
latent_format = latent_formats.LTXAV
|
||||||
|
|
||||||
|
def __init__(self, unet_config):
|
||||||
|
super().__init__(unet_config)
|
||||||
|
self.memory_usage_factor = 0.055 # TODO
|
||||||
|
|
||||||
|
def get_model(self, state_dict, prefix="", device=None):
|
||||||
|
out = model_base.LTXAV(self, device=device)
|
||||||
|
return out
|
||||||
|
|
||||||
class HunyuanVideo(supported_models_base.BASE):
|
class HunyuanVideo(supported_models_base.BASE):
|
||||||
unet_config = {
|
unet_config = {
|
||||||
"image_model": "hunyuan_video",
|
"image_model": "hunyuan_video",
|
||||||
@ -1536,6 +1551,6 @@ class Kandinsky5Image(Kandinsky5):
|
|||||||
return supported_models_base.ClipTarget(comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage, comfy.text_encoders.kandinsky5.te(**hunyuan_detect))
|
return supported_models_base.ClipTarget(comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage, comfy.text_encoders.kandinsky5.te(**hunyuan_detect))
|
||||||
|
|
||||||
|
|
||||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5]
|
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5]
|
||||||
|
|
||||||
models += [SVD_img2vid]
|
models += [SVD_img2vid]
|
||||||
|
|||||||
@ -154,7 +154,8 @@ class TAEHV(nn.Module):
|
|||||||
self._show_progress_bar = value
|
self._show_progress_bar = value
|
||||||
|
|
||||||
def encode(self, x, **kwargs):
|
def encode(self, x, **kwargs):
|
||||||
if self.patch_size > 1: x = F.pixel_unshuffle(x, self.patch_size)
|
if self.patch_size > 1:
|
||||||
|
x = F.pixel_unshuffle(x, self.patch_size)
|
||||||
x = x.movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W]
|
x = x.movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W]
|
||||||
if x.shape[1] % 4 != 0:
|
if x.shape[1] % 4 != 0:
|
||||||
# pad at end to multiple of 4
|
# pad at end to multiple of 4
|
||||||
@ -167,5 +168,6 @@ class TAEHV(nn.Module):
|
|||||||
def decode(self, x, **kwargs):
|
def decode(self, x, **kwargs):
|
||||||
x = self.process_in(x).movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W]
|
x = self.process_in(x).movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W]
|
||||||
x = apply_model_with_memblocks(self.decoder, x, self.parallel, self.show_progress_bar)
|
x = apply_model_with_memblocks(self.decoder, x, self.parallel, self.show_progress_bar)
|
||||||
if self.patch_size > 1: x = F.pixel_shuffle(x, self.patch_size)
|
if self.patch_size > 1:
|
||||||
|
x = F.pixel_shuffle(x, self.patch_size)
|
||||||
return x[:, self.frames_to_trim:].movedim(2, 1)
|
return x[:, self.frames_to_trim:].movedim(2, 1)
|
||||||
|
|||||||
@ -7,8 +7,8 @@ import math
|
|||||||
from comfy.ldm.modules.attention import optimized_attention_for_device
|
from comfy.ldm.modules.attention import optimized_attention_for_device
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
import comfy.ldm.common_dit
|
import comfy.ldm.common_dit
|
||||||
|
import comfy.clip_model
|
||||||
|
|
||||||
import comfy.model_management
|
|
||||||
from . import qwen_vl
|
from . import qwen_vl
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -189,6 +189,31 @@ class Gemma3_4B_Config:
|
|||||||
rope_scale = [8.0, 1.0]
|
rope_scale = [8.0, 1.0]
|
||||||
final_norm: bool = True
|
final_norm: bool = True
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Gemma3_12B_Config:
|
||||||
|
vocab_size: int = 262208
|
||||||
|
hidden_size: int = 3840
|
||||||
|
intermediate_size: int = 15360
|
||||||
|
num_hidden_layers: int = 48
|
||||||
|
num_attention_heads: int = 16
|
||||||
|
num_key_value_heads: int = 8
|
||||||
|
max_position_embeddings: int = 131072
|
||||||
|
rms_norm_eps: float = 1e-6
|
||||||
|
rope_theta = [1000000.0, 10000.0]
|
||||||
|
transformer_type: str = "gemma3"
|
||||||
|
head_dim = 256
|
||||||
|
rms_norm_add = True
|
||||||
|
mlp_activation = "gelu_pytorch_tanh"
|
||||||
|
qkv_bias = False
|
||||||
|
rope_dims = None
|
||||||
|
q_norm = "gemma3"
|
||||||
|
k_norm = "gemma3"
|
||||||
|
sliding_attention = [1024, 1024, 1024, 1024, 1024, False]
|
||||||
|
rope_scale = [8.0, 1.0]
|
||||||
|
final_norm: bool = True
|
||||||
|
vision_config = {"num_channels": 3, "hidden_act": "gelu_pytorch_tanh", "hidden_size": 1152, "image_size": 896, "intermediate_size": 4304, "model_type": "siglip_vision_model", "num_attention_heads": 16, "num_hidden_layers": 27, "patch_size": 14}
|
||||||
|
mm_tokens_per_image = 256
|
||||||
|
|
||||||
class RMSNorm(nn.Module):
|
class RMSNorm(nn.Module):
|
||||||
def __init__(self, dim: int, eps: float = 1e-5, add=False, device=None, dtype=None):
|
def __init__(self, dim: int, eps: float = 1e-5, add=False, device=None, dtype=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -521,6 +546,41 @@ class Llama2_(nn.Module):
|
|||||||
|
|
||||||
return x, intermediate
|
return x, intermediate
|
||||||
|
|
||||||
|
|
||||||
|
class Gemma3MultiModalProjector(torch.nn.Module):
|
||||||
|
def __init__(self, config, dtype, device, operations):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.mm_input_projection_weight = nn.Parameter(
|
||||||
|
torch.empty(config.vision_config["hidden_size"], config.hidden_size, device=device, dtype=dtype)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.mm_soft_emb_norm = RMSNorm(config.vision_config["hidden_size"], eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
|
||||||
|
|
||||||
|
self.patches_per_image = int(config.vision_config["image_size"] // config.vision_config["patch_size"])
|
||||||
|
self.tokens_per_side = int(config.mm_tokens_per_image**0.5)
|
||||||
|
self.kernel_size = self.patches_per_image // self.tokens_per_side
|
||||||
|
self.avg_pool = nn.AvgPool2d(kernel_size=self.kernel_size, stride=self.kernel_size)
|
||||||
|
|
||||||
|
def forward(self, vision_outputs: torch.Tensor):
|
||||||
|
batch_size, _, seq_length = vision_outputs.shape
|
||||||
|
|
||||||
|
reshaped_vision_outputs = vision_outputs.transpose(1, 2)
|
||||||
|
reshaped_vision_outputs = reshaped_vision_outputs.reshape(
|
||||||
|
batch_size, seq_length, self.patches_per_image, self.patches_per_image
|
||||||
|
)
|
||||||
|
reshaped_vision_outputs = reshaped_vision_outputs.contiguous()
|
||||||
|
|
||||||
|
pooled_vision_outputs = self.avg_pool(reshaped_vision_outputs)
|
||||||
|
pooled_vision_outputs = pooled_vision_outputs.flatten(2)
|
||||||
|
pooled_vision_outputs = pooled_vision_outputs.transpose(1, 2)
|
||||||
|
|
||||||
|
normed_vision_outputs = self.mm_soft_emb_norm(pooled_vision_outputs)
|
||||||
|
|
||||||
|
projected_vision_outputs = torch.matmul(normed_vision_outputs, comfy.model_management.cast_to_device(self.mm_input_projection_weight, device=normed_vision_outputs.device, dtype=normed_vision_outputs.dtype))
|
||||||
|
return projected_vision_outputs.type_as(vision_outputs)
|
||||||
|
|
||||||
|
|
||||||
class BaseLlama:
|
class BaseLlama:
|
||||||
def get_input_embeddings(self):
|
def get_input_embeddings(self):
|
||||||
return self.model.embed_tokens
|
return self.model.embed_tokens
|
||||||
@ -637,3 +697,21 @@ class Gemma3_4B(BaseLlama, torch.nn.Module):
|
|||||||
|
|
||||||
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
|
|
||||||
|
class Gemma3_12B(BaseLlama, torch.nn.Module):
|
||||||
|
def __init__(self, config_dict, dtype, device, operations):
|
||||||
|
super().__init__()
|
||||||
|
config = Gemma3_12B_Config(**config_dict)
|
||||||
|
self.num_layers = config.num_hidden_layers
|
||||||
|
|
||||||
|
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
||||||
|
self.multi_modal_projector = Gemma3MultiModalProjector(config, dtype, device, operations)
|
||||||
|
self.vision_model = comfy.clip_model.CLIPVision(config.vision_config, dtype, device, operations)
|
||||||
|
self.dtype = dtype
|
||||||
|
self.image_size = config.vision_config["image_size"]
|
||||||
|
|
||||||
|
def preprocess_embed(self, embed, device):
|
||||||
|
if embed["type"] == "image":
|
||||||
|
image = comfy.clip_model.clip_preprocess(embed["data"], size=self.image_size, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], crop=True)
|
||||||
|
return self.multi_modal_projector(self.vision_model(image.to(device, dtype=torch.float32))[0]), None
|
||||||
|
return None, None
|
||||||
|
|||||||
@ -1,7 +1,11 @@
|
|||||||
from comfy import sd1_clip
|
from comfy import sd1_clip
|
||||||
import os
|
import os
|
||||||
from transformers import T5TokenizerFast
|
from transformers import T5TokenizerFast
|
||||||
|
from .spiece_tokenizer import SPieceTokenizer
|
||||||
import comfy.text_encoders.genmo
|
import comfy.text_encoders.genmo
|
||||||
|
from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector
|
||||||
|
import torch
|
||||||
|
import comfy.utils
|
||||||
|
|
||||||
class T5XXLTokenizer(sd1_clip.SDTokenizer):
|
class T5XXLTokenizer(sd1_clip.SDTokenizer):
|
||||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||||
@ -16,3 +20,112 @@ class LTXVT5Tokenizer(sd1_clip.SD1Tokenizer):
|
|||||||
|
|
||||||
def ltxv_te(*args, **kwargs):
|
def ltxv_te(*args, **kwargs):
|
||||||
return comfy.text_encoders.genmo.mochi_te(*args, **kwargs)
|
return comfy.text_encoders.genmo.mochi_te(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class Gemma3_12BTokenizer(sd1_clip.SDTokenizer):
|
||||||
|
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||||
|
tokenizer = tokenizer_data.get("spiece_model", None)
|
||||||
|
super().__init__(tokenizer, pad_with_end=False, embedding_size=3840, embedding_key='gemma3_12b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data)
|
||||||
|
|
||||||
|
def state_dict(self):
|
||||||
|
return {"spiece_model": self.tokenizer.serialize_model()}
|
||||||
|
|
||||||
|
class LTXAVGemmaTokenizer(sd1_clip.SD1Tokenizer):
|
||||||
|
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||||
|
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="gemma3_12b", tokenizer=Gemma3_12BTokenizer)
|
||||||
|
|
||||||
|
class Gemma3_12BModel(sd1_clip.SDClipModel):
|
||||||
|
def __init__(self, device="cpu", layer="all", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
|
||||||
|
llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
|
||||||
|
if llama_quantization_metadata is not None:
|
||||||
|
model_options = model_options.copy()
|
||||||
|
model_options["quantization_metadata"] = llama_quantization_metadata
|
||||||
|
|
||||||
|
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_12B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
|
||||||
|
|
||||||
|
def tokenize_with_weights(self, text, return_word_ids=False, llama_template="{}", image_embeds=None, **kwargs):
|
||||||
|
text = llama_template.format(text)
|
||||||
|
text_tokens = super().tokenize_with_weights(text, return_word_ids)
|
||||||
|
embed_count = 0
|
||||||
|
for k in text_tokens:
|
||||||
|
tt = text_tokens[k]
|
||||||
|
for r in tt:
|
||||||
|
for i in range(len(r)):
|
||||||
|
if r[i][0] == 262144:
|
||||||
|
if image_embeds is not None and embed_count < image_embeds.shape[0]:
|
||||||
|
r[i] = ({"type": "embedding", "data": image_embeds[embed_count], "original_type": "image"},) + r[i][1:]
|
||||||
|
embed_count += 1
|
||||||
|
return text_tokens
|
||||||
|
|
||||||
|
class LTXAVTEModel(torch.nn.Module):
|
||||||
|
def __init__(self, dtype_llama=None, device="cpu", dtype=None, model_options={}):
|
||||||
|
super().__init__()
|
||||||
|
self.dtypes = set()
|
||||||
|
self.dtypes.add(dtype)
|
||||||
|
|
||||||
|
self.gemma3_12b = Gemma3_12BModel(device=device, dtype=dtype_llama, model_options=model_options, layer="all", layer_idx=None)
|
||||||
|
self.dtypes.add(dtype_llama)
|
||||||
|
|
||||||
|
operations = self.gemma3_12b.operations # TODO
|
||||||
|
self.text_embedding_projection = operations.Linear(3840 * 49, 3840, bias=False, dtype=dtype, device=device)
|
||||||
|
|
||||||
|
self.audio_embeddings_connector = Embeddings1DConnector(
|
||||||
|
split_rope=True,
|
||||||
|
double_precision_rope=True,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
operations=operations,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.video_embeddings_connector = Embeddings1DConnector(
|
||||||
|
split_rope=True,
|
||||||
|
double_precision_rope=True,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
operations=operations,
|
||||||
|
)
|
||||||
|
|
||||||
|
def set_clip_options(self, options):
|
||||||
|
self.execution_device = options.get("execution_device", self.execution_device)
|
||||||
|
self.gemma3_12b.set_clip_options(options)
|
||||||
|
|
||||||
|
def reset_clip_options(self):
|
||||||
|
self.gemma3_12b.reset_clip_options()
|
||||||
|
self.execution_device = None
|
||||||
|
|
||||||
|
def encode_token_weights(self, token_weight_pairs):
|
||||||
|
token_weight_pairs = token_weight_pairs["gemma3_12b"]
|
||||||
|
|
||||||
|
out, pooled, extra = self.gemma3_12b.encode_token_weights(token_weight_pairs)
|
||||||
|
out_device = out.device
|
||||||
|
out = out.movedim(1, -1).to(self.execution_device)
|
||||||
|
out = 8.0 * (out - out.mean(dim=(1, 2), keepdim=True)) / (out.amax(dim=(1, 2), keepdim=True) - out.amin(dim=(1, 2), keepdim=True) + 1e-6)
|
||||||
|
out = out.reshape((out.shape[0], out.shape[1], -1))
|
||||||
|
out = self.text_embedding_projection(out)
|
||||||
|
out_vid = self.video_embeddings_connector(out)[0]
|
||||||
|
out_audio = self.audio_embeddings_connector(out)[0]
|
||||||
|
out = torch.concat((out_vid, out_audio), dim=-1)
|
||||||
|
|
||||||
|
return out.to(out_device), pooled
|
||||||
|
|
||||||
|
def load_sd(self, sd):
|
||||||
|
if "model.layers.47.self_attn.q_norm.weight" in sd:
|
||||||
|
return self.gemma3_12b.load_sd(sd)
|
||||||
|
else:
|
||||||
|
sdo = comfy.utils.state_dict_prefix_replace(sd, {"text_embedding_projection.aggregate_embed.weight": "text_embedding_projection.weight", "model.diffusion_model.video_embeddings_connector.": "video_embeddings_connector.", "model.diffusion_model.audio_embeddings_connector.": "audio_embeddings_connector."}, filter_keys=True)
|
||||||
|
if len(sdo) == 0:
|
||||||
|
sdo = sd
|
||||||
|
|
||||||
|
return self.load_state_dict(sdo, strict=False)
|
||||||
|
|
||||||
|
|
||||||
|
def ltxav_te(dtype_llama=None, llama_quantization_metadata=None):
|
||||||
|
class LTXAVTEModel_(LTXAVTEModel):
|
||||||
|
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||||
|
if llama_quantization_metadata is not None:
|
||||||
|
model_options = model_options.copy()
|
||||||
|
model_options["llama_quantization_metadata"] = llama_quantization_metadata
|
||||||
|
if dtype_llama is not None:
|
||||||
|
dtype = dtype_llama
|
||||||
|
super().__init__(dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options)
|
||||||
|
return LTXAVTEModel_
|
||||||
|
|||||||
@ -1198,7 +1198,7 @@ def unpack_latents(combined_latent, latent_shapes):
|
|||||||
combined_latent = combined_latent[:, :, cut:]
|
combined_latent = combined_latent[:, :, cut:]
|
||||||
output_tensors.append(tens.reshape([tens.shape[0]] + list(shape)[1:]))
|
output_tensors.append(tens.reshape([tens.shape[0]] + list(shape)[1:]))
|
||||||
else:
|
else:
|
||||||
output_tensors = combined_latent
|
output_tensors = [combined_latent]
|
||||||
return output_tensors
|
return output_tensors
|
||||||
|
|
||||||
def detect_layer_quantization(state_dict, prefix):
|
def detect_layer_quantization(state_dict, prefix):
|
||||||
@ -1230,6 +1230,8 @@ def convert_old_quants(state_dict, model_prefix="", metadata={}):
|
|||||||
out_sd = {}
|
out_sd = {}
|
||||||
layers = {}
|
layers = {}
|
||||||
for k in list(state_dict.keys()):
|
for k in list(state_dict.keys()):
|
||||||
|
if k == scaled_fp8_key:
|
||||||
|
continue
|
||||||
if not k.startswith(model_prefix):
|
if not k.startswith(model_prefix):
|
||||||
out_sd[k] = state_dict[k]
|
out_sd[k] = state_dict[k]
|
||||||
continue
|
continue
|
||||||
|
|||||||
@ -10,7 +10,6 @@ from ._input_impl import VideoFromFile, VideoFromComponents
|
|||||||
from ._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL
|
from ._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL
|
||||||
from . import _io_public as io
|
from . import _io_public as io
|
||||||
from . import _ui_public as ui
|
from . import _ui_public as ui
|
||||||
# from comfy_api.latest._resources import _RESOURCES as resources #noqa: F401
|
|
||||||
from comfy_execution.utils import get_executing_context
|
from comfy_execution.utils import get_executing_context
|
||||||
from comfy_execution.progress import get_progress_state, PreviewImageTuple
|
from comfy_execution.progress import get_progress_state, PreviewImageTuple
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|||||||
@ -26,7 +26,6 @@ if TYPE_CHECKING:
|
|||||||
from comfy_api.input import VideoInput
|
from comfy_api.input import VideoInput
|
||||||
from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classproperty, copy_class, first_real_override, is_class,
|
from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classproperty, copy_class, first_real_override, is_class,
|
||||||
prune_dict, shallow_clone_class)
|
prune_dict, shallow_clone_class)
|
||||||
from ._resources import Resources, ResourcesLocal
|
|
||||||
from comfy_execution.graph_utils import ExecutionBlocker
|
from comfy_execution.graph_utils import ExecutionBlocker
|
||||||
from ._util import MESH, VOXEL, SVG as _SVG
|
from ._util import MESH, VOXEL, SVG as _SVG
|
||||||
|
|
||||||
@ -76,16 +75,6 @@ class NumberDisplay(str, Enum):
|
|||||||
slider = "slider"
|
slider = "slider"
|
||||||
|
|
||||||
|
|
||||||
class _StringIOType(str):
|
|
||||||
def __ne__(self, value: object) -> bool:
|
|
||||||
if self == "*" or value == "*":
|
|
||||||
return False
|
|
||||||
if not isinstance(value, str):
|
|
||||||
return True
|
|
||||||
a = frozenset(self.split(","))
|
|
||||||
b = frozenset(value.split(","))
|
|
||||||
return not (b.issubset(a) or a.issubset(b))
|
|
||||||
|
|
||||||
class _ComfyType(ABC):
|
class _ComfyType(ABC):
|
||||||
Type = Any
|
Type = Any
|
||||||
io_type: str = None
|
io_type: str = None
|
||||||
@ -125,8 +114,7 @@ def comfytype(io_type: str, **kwargs):
|
|||||||
new_cls.__module__ = cls.__module__
|
new_cls.__module__ = cls.__module__
|
||||||
new_cls.__doc__ = cls.__doc__
|
new_cls.__doc__ = cls.__doc__
|
||||||
# assign ComfyType attributes, if needed
|
# assign ComfyType attributes, if needed
|
||||||
# NOTE: use __ne__ trick for io_type (see node_typing.IO.__ne__ for details)
|
new_cls.io_type = io_type
|
||||||
new_cls.io_type = _StringIOType(io_type)
|
|
||||||
if hasattr(new_cls, "Input") and new_cls.Input is not None:
|
if hasattr(new_cls, "Input") and new_cls.Input is not None:
|
||||||
new_cls.Input.Parent = new_cls
|
new_cls.Input.Parent = new_cls
|
||||||
if hasattr(new_cls, "Output") and new_cls.Output is not None:
|
if hasattr(new_cls, "Output") and new_cls.Output is not None:
|
||||||
@ -165,7 +153,7 @@ class Input(_IO_V3):
|
|||||||
'''
|
'''
|
||||||
Base class for a V3 Input.
|
Base class for a V3 Input.
|
||||||
'''
|
'''
|
||||||
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None):
|
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None, raw_link: bool=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.id = id
|
self.id = id
|
||||||
self.display_name = display_name
|
self.display_name = display_name
|
||||||
@ -173,6 +161,7 @@ class Input(_IO_V3):
|
|||||||
self.tooltip = tooltip
|
self.tooltip = tooltip
|
||||||
self.lazy = lazy
|
self.lazy = lazy
|
||||||
self.extra_dict = extra_dict if extra_dict is not None else {}
|
self.extra_dict = extra_dict if extra_dict is not None else {}
|
||||||
|
self.rawLink = raw_link
|
||||||
|
|
||||||
def as_dict(self):
|
def as_dict(self):
|
||||||
return prune_dict({
|
return prune_dict({
|
||||||
@ -180,10 +169,11 @@ class Input(_IO_V3):
|
|||||||
"optional": self.optional,
|
"optional": self.optional,
|
||||||
"tooltip": self.tooltip,
|
"tooltip": self.tooltip,
|
||||||
"lazy": self.lazy,
|
"lazy": self.lazy,
|
||||||
|
"rawLink": self.rawLink,
|
||||||
}) | prune_dict(self.extra_dict)
|
}) | prune_dict(self.extra_dict)
|
||||||
|
|
||||||
def get_io_type(self):
|
def get_io_type(self):
|
||||||
return _StringIOType(self.io_type)
|
return self.io_type
|
||||||
|
|
||||||
def get_all(self) -> list[Input]:
|
def get_all(self) -> list[Input]:
|
||||||
return [self]
|
return [self]
|
||||||
@ -194,8 +184,8 @@ class WidgetInput(Input):
|
|||||||
'''
|
'''
|
||||||
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
|
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
|
||||||
default: Any=None,
|
default: Any=None,
|
||||||
socketless: bool=None, widget_type: str=None, force_input: bool=None, extra_dict=None):
|
socketless: bool=None, widget_type: str=None, force_input: bool=None, extra_dict=None, raw_link: bool=None):
|
||||||
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict)
|
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict, raw_link)
|
||||||
self.default = default
|
self.default = default
|
||||||
self.socketless = socketless
|
self.socketless = socketless
|
||||||
self.widget_type = widget_type
|
self.widget_type = widget_type
|
||||||
@ -217,13 +207,14 @@ class Output(_IO_V3):
|
|||||||
def __init__(self, id: str=None, display_name: str=None, tooltip: str=None,
|
def __init__(self, id: str=None, display_name: str=None, tooltip: str=None,
|
||||||
is_output_list=False):
|
is_output_list=False):
|
||||||
self.id = id
|
self.id = id
|
||||||
self.display_name = display_name
|
self.display_name = display_name if display_name else id
|
||||||
self.tooltip = tooltip
|
self.tooltip = tooltip
|
||||||
self.is_output_list = is_output_list
|
self.is_output_list = is_output_list
|
||||||
|
|
||||||
def as_dict(self):
|
def as_dict(self):
|
||||||
|
display_name = self.display_name if self.display_name else self.id
|
||||||
return prune_dict({
|
return prune_dict({
|
||||||
"display_name": self.display_name,
|
"display_name": display_name,
|
||||||
"tooltip": self.tooltip,
|
"tooltip": self.tooltip,
|
||||||
"is_output_list": self.is_output_list,
|
"is_output_list": self.is_output_list,
|
||||||
})
|
})
|
||||||
@ -251,8 +242,8 @@ class Boolean(ComfyTypeIO):
|
|||||||
'''Boolean input.'''
|
'''Boolean input.'''
|
||||||
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
|
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
|
||||||
default: bool=None, label_on: str=None, label_off: str=None,
|
default: bool=None, label_on: str=None, label_off: str=None,
|
||||||
socketless: bool=None, force_input: bool=None):
|
socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None):
|
||||||
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input)
|
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link)
|
||||||
self.label_on = label_on
|
self.label_on = label_on
|
||||||
self.label_off = label_off
|
self.label_off = label_off
|
||||||
self.default: bool
|
self.default: bool
|
||||||
@ -271,8 +262,8 @@ class Int(ComfyTypeIO):
|
|||||||
'''Integer input.'''
|
'''Integer input.'''
|
||||||
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
|
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
|
||||||
default: int=None, min: int=None, max: int=None, step: int=None, control_after_generate: bool=None,
|
default: int=None, min: int=None, max: int=None, step: int=None, control_after_generate: bool=None,
|
||||||
display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None):
|
display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None):
|
||||||
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input)
|
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link)
|
||||||
self.min = min
|
self.min = min
|
||||||
self.max = max
|
self.max = max
|
||||||
self.step = step
|
self.step = step
|
||||||
@ -297,8 +288,8 @@ class Float(ComfyTypeIO):
|
|||||||
'''Float input.'''
|
'''Float input.'''
|
||||||
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
|
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
|
||||||
default: float=None, min: float=None, max: float=None, step: float=None, round: float=None,
|
default: float=None, min: float=None, max: float=None, step: float=None, round: float=None,
|
||||||
display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None):
|
display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None):
|
||||||
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input)
|
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link)
|
||||||
self.min = min
|
self.min = min
|
||||||
self.max = max
|
self.max = max
|
||||||
self.step = step
|
self.step = step
|
||||||
@ -323,8 +314,8 @@ class String(ComfyTypeIO):
|
|||||||
'''String input.'''
|
'''String input.'''
|
||||||
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
|
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
|
||||||
multiline=False, placeholder: str=None, default: str=None, dynamic_prompts: bool=None,
|
multiline=False, placeholder: str=None, default: str=None, dynamic_prompts: bool=None,
|
||||||
socketless: bool=None, force_input: bool=None):
|
socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None):
|
||||||
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input)
|
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link)
|
||||||
self.multiline = multiline
|
self.multiline = multiline
|
||||||
self.placeholder = placeholder
|
self.placeholder = placeholder
|
||||||
self.dynamic_prompts = dynamic_prompts
|
self.dynamic_prompts = dynamic_prompts
|
||||||
@ -357,12 +348,14 @@ class Combo(ComfyTypeIO):
|
|||||||
image_folder: FolderType=None,
|
image_folder: FolderType=None,
|
||||||
remote: RemoteOptions=None,
|
remote: RemoteOptions=None,
|
||||||
socketless: bool=None,
|
socketless: bool=None,
|
||||||
|
extra_dict=None,
|
||||||
|
raw_link: bool=None,
|
||||||
):
|
):
|
||||||
if isinstance(options, type) and issubclass(options, Enum):
|
if isinstance(options, type) and issubclass(options, Enum):
|
||||||
options = [v.value for v in options]
|
options = [v.value for v in options]
|
||||||
if isinstance(default, Enum):
|
if isinstance(default, Enum):
|
||||||
default = default.value
|
default = default.value
|
||||||
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless)
|
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, None, extra_dict, raw_link)
|
||||||
self.multiselect = False
|
self.multiselect = False
|
||||||
self.options = options
|
self.options = options
|
||||||
self.control_after_generate = control_after_generate
|
self.control_after_generate = control_after_generate
|
||||||
@ -386,10 +379,6 @@ class Combo(ComfyTypeIO):
|
|||||||
super().__init__(id, display_name, tooltip, is_output_list)
|
super().__init__(id, display_name, tooltip, is_output_list)
|
||||||
self.options = options if options is not None else []
|
self.options = options if options is not None else []
|
||||||
|
|
||||||
@property
|
|
||||||
def io_type(self):
|
|
||||||
return self.options
|
|
||||||
|
|
||||||
@comfytype(io_type="COMBO")
|
@comfytype(io_type="COMBO")
|
||||||
class MultiCombo(ComfyTypeI):
|
class MultiCombo(ComfyTypeI):
|
||||||
'''Multiselect Combo input (dropdown for selecting potentially more than one value).'''
|
'''Multiselect Combo input (dropdown for selecting potentially more than one value).'''
|
||||||
@ -398,8 +387,8 @@ class MultiCombo(ComfyTypeI):
|
|||||||
class Input(Combo.Input):
|
class Input(Combo.Input):
|
||||||
def __init__(self, id: str, options: list[str], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
|
def __init__(self, id: str, options: list[str], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
|
||||||
default: list[str]=None, placeholder: str=None, chip: bool=None, control_after_generate: bool=None,
|
default: list[str]=None, placeholder: str=None, chip: bool=None, control_after_generate: bool=None,
|
||||||
socketless: bool=None):
|
socketless: bool=None, extra_dict=None, raw_link: bool=None):
|
||||||
super().__init__(id, options, display_name, optional, tooltip, lazy, default, control_after_generate, socketless=socketless)
|
super().__init__(id, options, display_name, optional, tooltip, lazy, default, control_after_generate, socketless=socketless, extra_dict=extra_dict, raw_link=raw_link)
|
||||||
self.multiselect = True
|
self.multiselect = True
|
||||||
self.placeholder = placeholder
|
self.placeholder = placeholder
|
||||||
self.chip = chip
|
self.chip = chip
|
||||||
@ -432,9 +421,9 @@ class Webcam(ComfyTypeIO):
|
|||||||
Type = str
|
Type = str
|
||||||
def __init__(
|
def __init__(
|
||||||
self, id: str, display_name: str=None, optional=False,
|
self, id: str, display_name: str=None, optional=False,
|
||||||
tooltip: str=None, lazy: bool=None, default: str=None, socketless: bool=None
|
tooltip: str=None, lazy: bool=None, default: str=None, socketless: bool=None, extra_dict=None, raw_link: bool=None
|
||||||
):
|
):
|
||||||
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless)
|
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, None, extra_dict, raw_link)
|
||||||
|
|
||||||
|
|
||||||
@comfytype(io_type="MASK")
|
@comfytype(io_type="MASK")
|
||||||
@ -787,7 +776,7 @@ class MultiType:
|
|||||||
'''
|
'''
|
||||||
Input that permits more than one input type; if `id` is an instance of `ComfyType.Input`, then that input will be used to create a widget (if applicable) with overridden values.
|
Input that permits more than one input type; if `id` is an instance of `ComfyType.Input`, then that input will be used to create a widget (if applicable) with overridden values.
|
||||||
'''
|
'''
|
||||||
def __init__(self, id: str | Input, types: list[type[_ComfyType] | _ComfyType], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None):
|
def __init__(self, id: str | Input, types: list[type[_ComfyType] | _ComfyType], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None, raw_link: bool=None):
|
||||||
# if id is an Input, then use that Input with overridden values
|
# if id is an Input, then use that Input with overridden values
|
||||||
self.input_override = None
|
self.input_override = None
|
||||||
if isinstance(id, Input):
|
if isinstance(id, Input):
|
||||||
@ -800,7 +789,7 @@ class MultiType:
|
|||||||
# if is a widget input, make sure widget_type is set appropriately
|
# if is a widget input, make sure widget_type is set appropriately
|
||||||
if isinstance(self.input_override, WidgetInput):
|
if isinstance(self.input_override, WidgetInput):
|
||||||
self.input_override.widget_type = self.input_override.get_io_type()
|
self.input_override.widget_type = self.input_override.get_io_type()
|
||||||
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict)
|
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict, raw_link)
|
||||||
self._io_types = types
|
self._io_types = types
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -854,8 +843,8 @@ class MatchType(ComfyTypeIO):
|
|||||||
|
|
||||||
class Input(Input):
|
class Input(Input):
|
||||||
def __init__(self, id: str, template: MatchType.Template,
|
def __init__(self, id: str, template: MatchType.Template,
|
||||||
display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None):
|
display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None, raw_link: bool=None):
|
||||||
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict)
|
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict, raw_link)
|
||||||
self.template = template
|
self.template = template
|
||||||
|
|
||||||
def as_dict(self):
|
def as_dict(self):
|
||||||
@ -866,6 +855,8 @@ class MatchType(ComfyTypeIO):
|
|||||||
class Output(Output):
|
class Output(Output):
|
||||||
def __init__(self, template: MatchType.Template, id: str=None, display_name: str=None, tooltip: str=None,
|
def __init__(self, template: MatchType.Template, id: str=None, display_name: str=None, tooltip: str=None,
|
||||||
is_output_list=False):
|
is_output_list=False):
|
||||||
|
if not id and not display_name:
|
||||||
|
display_name = "MATCHTYPE"
|
||||||
super().__init__(id, display_name, tooltip, is_output_list)
|
super().__init__(id, display_name, tooltip, is_output_list)
|
||||||
self.template = template
|
self.template = template
|
||||||
|
|
||||||
@ -878,24 +869,30 @@ class DynamicInput(Input, ABC):
|
|||||||
'''
|
'''
|
||||||
Abstract class for dynamic input registration.
|
Abstract class for dynamic input registration.
|
||||||
'''
|
'''
|
||||||
def get_dynamic(self) -> list[Input]:
|
pass
|
||||||
return []
|
|
||||||
|
|
||||||
def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class DynamicOutput(Output, ABC):
|
class DynamicOutput(Output, ABC):
|
||||||
'''
|
'''
|
||||||
Abstract class for dynamic output registration.
|
Abstract class for dynamic output registration.
|
||||||
'''
|
'''
|
||||||
def __init__(self, id: str=None, display_name: str=None, tooltip: str=None,
|
pass
|
||||||
is_output_list=False):
|
|
||||||
super().__init__(id, display_name, tooltip, is_output_list)
|
|
||||||
|
|
||||||
def get_dynamic(self) -> list[Output]:
|
|
||||||
return []
|
|
||||||
|
|
||||||
|
def handle_prefix(prefix_list: list[str] | None, id: str | None = None) -> list[str]:
|
||||||
|
if prefix_list is None:
|
||||||
|
prefix_list = []
|
||||||
|
if id is not None:
|
||||||
|
prefix_list = prefix_list + [id]
|
||||||
|
return prefix_list
|
||||||
|
|
||||||
|
def finalize_prefix(prefix_list: list[str] | None, id: str | None = None) -> str:
|
||||||
|
assert not (prefix_list is None and id is None)
|
||||||
|
if prefix_list is None:
|
||||||
|
return id
|
||||||
|
elif id is not None:
|
||||||
|
prefix_list = prefix_list + [id]
|
||||||
|
return ".".join(prefix_list)
|
||||||
|
|
||||||
@comfytype(io_type="COMFY_AUTOGROW_V3")
|
@comfytype(io_type="COMFY_AUTOGROW_V3")
|
||||||
class Autogrow(ComfyTypeI):
|
class Autogrow(ComfyTypeI):
|
||||||
@ -932,14 +929,6 @@ class Autogrow(ComfyTypeI):
|
|||||||
def validate(self):
|
def validate(self):
|
||||||
self.input.validate()
|
self.input.validate()
|
||||||
|
|
||||||
def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''):
|
|
||||||
real_inputs = []
|
|
||||||
for name, input in self.cached_inputs.items():
|
|
||||||
if name in live_inputs:
|
|
||||||
real_inputs.append(input)
|
|
||||||
add_to_input_dict_v1(d, real_inputs, live_inputs, curr_prefix)
|
|
||||||
add_dynamic_id_mapping(d, real_inputs, curr_prefix)
|
|
||||||
|
|
||||||
class TemplatePrefix(_AutogrowTemplate):
|
class TemplatePrefix(_AutogrowTemplate):
|
||||||
def __init__(self, input: Input, prefix: str, min: int=1, max: int=10):
|
def __init__(self, input: Input, prefix: str, min: int=1, max: int=10):
|
||||||
super().__init__(input)
|
super().__init__(input)
|
||||||
@ -984,22 +973,45 @@ class Autogrow(ComfyTypeI):
|
|||||||
"template": self.template.as_dict(),
|
"template": self.template.as_dict(),
|
||||||
})
|
})
|
||||||
|
|
||||||
def get_dynamic(self) -> list[Input]:
|
|
||||||
return self.template.get_all()
|
|
||||||
|
|
||||||
def get_all(self) -> list[Input]:
|
def get_all(self) -> list[Input]:
|
||||||
return [self] + self.template.get_all()
|
return [self] + self.template.get_all()
|
||||||
|
|
||||||
def validate(self):
|
def validate(self):
|
||||||
self.template.validate()
|
self.template.validate()
|
||||||
|
|
||||||
def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''):
|
@staticmethod
|
||||||
curr_prefix = f"{curr_prefix}{self.id}."
|
def _expand_schema_for_dynamic(out_dict: dict[str, Any], live_inputs: dict[str, Any], value: tuple[str, dict[str, Any]], input_type: str, curr_prefix: list[str] | None):
|
||||||
# need to remove self from expected inputs dictionary; replaced by template inputs in frontend
|
# NOTE: purposely do not include self in out_dict; instead use only the template inputs
|
||||||
for inner_dict in d.values():
|
# need to figure out names based on template type
|
||||||
if self.id in inner_dict:
|
is_names = ("names" in value[1]["template"])
|
||||||
del inner_dict[self.id]
|
is_prefix = ("prefix" in value[1]["template"])
|
||||||
self.template.expand_schema_for_dynamic(d, live_inputs, curr_prefix)
|
input = value[1]["template"]["input"]
|
||||||
|
if is_names:
|
||||||
|
min = value[1]["template"]["min"]
|
||||||
|
names = value[1]["template"]["names"]
|
||||||
|
max = len(names)
|
||||||
|
elif is_prefix:
|
||||||
|
prefix = value[1]["template"]["prefix"]
|
||||||
|
min = value[1]["template"]["min"]
|
||||||
|
max = value[1]["template"]["max"]
|
||||||
|
names = [f"{prefix}{i}" for i in range(max)]
|
||||||
|
# need to create a new input based on the contents of input
|
||||||
|
template_input = None
|
||||||
|
for _, dict_input in input.items():
|
||||||
|
# for now, get just the first value from dict_input
|
||||||
|
template_input = list(dict_input.values())[0]
|
||||||
|
new_dict = {}
|
||||||
|
for i, name in enumerate(names):
|
||||||
|
expected_id = finalize_prefix(curr_prefix, name)
|
||||||
|
if expected_id in live_inputs:
|
||||||
|
# required
|
||||||
|
if i < min:
|
||||||
|
type_dict = new_dict.setdefault("required", {})
|
||||||
|
# optional
|
||||||
|
else:
|
||||||
|
type_dict = new_dict.setdefault("optional", {})
|
||||||
|
type_dict[name] = template_input
|
||||||
|
parse_class_inputs(out_dict, live_inputs, new_dict, curr_prefix)
|
||||||
|
|
||||||
@comfytype(io_type="COMFY_DYNAMICCOMBO_V3")
|
@comfytype(io_type="COMFY_DYNAMICCOMBO_V3")
|
||||||
class DynamicCombo(ComfyTypeI):
|
class DynamicCombo(ComfyTypeI):
|
||||||
@ -1022,23 +1034,6 @@ class DynamicCombo(ComfyTypeI):
|
|||||||
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict)
|
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict)
|
||||||
self.options = options
|
self.options = options
|
||||||
|
|
||||||
def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''):
|
|
||||||
# check if dynamic input's id is in live_inputs
|
|
||||||
if self.id in live_inputs:
|
|
||||||
curr_prefix = f"{curr_prefix}{self.id}."
|
|
||||||
key = live_inputs[self.id]
|
|
||||||
selected_option = None
|
|
||||||
for option in self.options:
|
|
||||||
if option.key == key:
|
|
||||||
selected_option = option
|
|
||||||
break
|
|
||||||
if selected_option is not None:
|
|
||||||
add_to_input_dict_v1(d, selected_option.inputs, live_inputs, curr_prefix)
|
|
||||||
add_dynamic_id_mapping(d, selected_option.inputs, curr_prefix, self)
|
|
||||||
|
|
||||||
def get_dynamic(self) -> list[Input]:
|
|
||||||
return [input for option in self.options for input in option.inputs]
|
|
||||||
|
|
||||||
def get_all(self) -> list[Input]:
|
def get_all(self) -> list[Input]:
|
||||||
return [self] + [input for option in self.options for input in option.inputs]
|
return [self] + [input for option in self.options for input in option.inputs]
|
||||||
|
|
||||||
@ -1053,6 +1048,24 @@ class DynamicCombo(ComfyTypeI):
|
|||||||
for input in option.inputs:
|
for input in option.inputs:
|
||||||
input.validate()
|
input.validate()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _expand_schema_for_dynamic(out_dict: dict[str, Any], live_inputs: dict[str, Any], value: tuple[str, dict[str, Any]], input_type: str, curr_prefix: list[str] | None):
|
||||||
|
finalized_id = finalize_prefix(curr_prefix)
|
||||||
|
if finalized_id in live_inputs:
|
||||||
|
key = live_inputs[finalized_id]
|
||||||
|
selected_option = None
|
||||||
|
# get options from dict
|
||||||
|
options: list[dict[str, str | dict[str, Any]]] = value[1]["options"]
|
||||||
|
for option in options:
|
||||||
|
if option["key"] == key:
|
||||||
|
selected_option = option
|
||||||
|
break
|
||||||
|
if selected_option is not None:
|
||||||
|
parse_class_inputs(out_dict, live_inputs, selected_option["inputs"], curr_prefix)
|
||||||
|
# add self to inputs
|
||||||
|
out_dict[input_type][finalized_id] = value
|
||||||
|
out_dict["dynamic_paths"][finalized_id] = finalize_prefix(curr_prefix, curr_prefix[-1])
|
||||||
|
|
||||||
@comfytype(io_type="COMFY_DYNAMICSLOT_V3")
|
@comfytype(io_type="COMFY_DYNAMICSLOT_V3")
|
||||||
class DynamicSlot(ComfyTypeI):
|
class DynamicSlot(ComfyTypeI):
|
||||||
Type = dict[str, Any]
|
Type = dict[str, Any]
|
||||||
@ -1075,17 +1088,8 @@ class DynamicSlot(ComfyTypeI):
|
|||||||
self.force_input = True
|
self.force_input = True
|
||||||
self.slot.force_input = True
|
self.slot.force_input = True
|
||||||
|
|
||||||
def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''):
|
|
||||||
if self.id in live_inputs:
|
|
||||||
curr_prefix = f"{curr_prefix}{self.id}."
|
|
||||||
add_to_input_dict_v1(d, self.inputs, live_inputs, curr_prefix)
|
|
||||||
add_dynamic_id_mapping(d, [self.slot] + self.inputs, curr_prefix)
|
|
||||||
|
|
||||||
def get_dynamic(self) -> list[Input]:
|
|
||||||
return [self.slot] + self.inputs
|
|
||||||
|
|
||||||
def get_all(self) -> list[Input]:
|
def get_all(self) -> list[Input]:
|
||||||
return [self] + [self.slot] + self.inputs
|
return [self.slot] + self.inputs
|
||||||
|
|
||||||
def as_dict(self):
|
def as_dict(self):
|
||||||
return super().as_dict() | prune_dict({
|
return super().as_dict() | prune_dict({
|
||||||
@ -1099,17 +1103,41 @@ class DynamicSlot(ComfyTypeI):
|
|||||||
for input in self.inputs:
|
for input in self.inputs:
|
||||||
input.validate()
|
input.validate()
|
||||||
|
|
||||||
def add_dynamic_id_mapping(d: dict[str, Any], inputs: list[Input], curr_prefix: str, self: DynamicInput=None):
|
@staticmethod
|
||||||
dynamic = d.setdefault("dynamic_paths", {})
|
def _expand_schema_for_dynamic(out_dict: dict[str, Any], live_inputs: dict[str, Any], value: tuple[str, dict[str, Any]], input_type: str, curr_prefix: list[str] | None):
|
||||||
if self is not None:
|
finalized_id = finalize_prefix(curr_prefix)
|
||||||
dynamic[self.id] = f"{curr_prefix}{self.id}"
|
if finalized_id in live_inputs:
|
||||||
for i in inputs:
|
inputs = value[1]["inputs"]
|
||||||
if not isinstance(i, DynamicInput):
|
parse_class_inputs(out_dict, live_inputs, inputs, curr_prefix)
|
||||||
dynamic[f"{i.id}"] = f"{curr_prefix}{i.id}"
|
# add self to inputs
|
||||||
|
out_dict[input_type][finalized_id] = value
|
||||||
|
out_dict["dynamic_paths"][finalized_id] = finalize_prefix(curr_prefix, curr_prefix[-1])
|
||||||
|
|
||||||
|
DYNAMIC_INPUT_LOOKUP: dict[str, Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]] = {}
|
||||||
|
def register_dynamic_input_func(io_type: str, func: Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]):
|
||||||
|
DYNAMIC_INPUT_LOOKUP[io_type] = func
|
||||||
|
|
||||||
|
def get_dynamic_input_func(io_type: str) -> Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]:
|
||||||
|
return DYNAMIC_INPUT_LOOKUP[io_type]
|
||||||
|
|
||||||
|
def setup_dynamic_input_funcs():
|
||||||
|
# Autogrow.Input
|
||||||
|
register_dynamic_input_func(Autogrow.io_type, Autogrow._expand_schema_for_dynamic)
|
||||||
|
# DynamicCombo.Input
|
||||||
|
register_dynamic_input_func(DynamicCombo.io_type, DynamicCombo._expand_schema_for_dynamic)
|
||||||
|
# DynamicSlot.Input
|
||||||
|
register_dynamic_input_func(DynamicSlot.io_type, DynamicSlot._expand_schema_for_dynamic)
|
||||||
|
|
||||||
|
if len(DYNAMIC_INPUT_LOOKUP) == 0:
|
||||||
|
setup_dynamic_input_funcs()
|
||||||
|
|
||||||
class V3Data(TypedDict):
|
class V3Data(TypedDict):
|
||||||
hidden_inputs: dict[str, Any]
|
hidden_inputs: dict[str, Any]
|
||||||
|
'Dictionary where the keys are the hidden input ids and the values are the values of the hidden inputs.'
|
||||||
dynamic_paths: dict[str, Any]
|
dynamic_paths: dict[str, Any]
|
||||||
|
'Dictionary where the keys are the input ids and the values dictate how to turn the inputs into a nested dictionary.'
|
||||||
|
create_dynamic_tuple: bool
|
||||||
|
'When True, the value of the dynamic input will be in the format (value, path_key).'
|
||||||
|
|
||||||
class HiddenHolder:
|
class HiddenHolder:
|
||||||
def __init__(self, unique_id: str, prompt: Any,
|
def __init__(self, unique_id: str, prompt: Any,
|
||||||
@ -1145,6 +1173,10 @@ class HiddenHolder:
|
|||||||
api_key_comfy_org=d.get(Hidden.api_key_comfy_org, None),
|
api_key_comfy_org=d.get(Hidden.api_key_comfy_org, None),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_v3_data(cls, v3_data: V3Data | None) -> HiddenHolder:
|
||||||
|
return cls.from_dict(v3_data["hidden_inputs"] if v3_data else None)
|
||||||
|
|
||||||
class Hidden(str, Enum):
|
class Hidden(str, Enum):
|
||||||
'''
|
'''
|
||||||
Enumerator for requesting hidden variables in nodes.
|
Enumerator for requesting hidden variables in nodes.
|
||||||
@ -1250,61 +1282,56 @@ class Schema:
|
|||||||
- verify ids on inputs and outputs are unique - both internally and in relation to each other
|
- verify ids on inputs and outputs are unique - both internally and in relation to each other
|
||||||
'''
|
'''
|
||||||
nested_inputs: list[Input] = []
|
nested_inputs: list[Input] = []
|
||||||
if self.inputs is not None:
|
for input in self.inputs:
|
||||||
for input in self.inputs:
|
if not isinstance(input, DynamicInput):
|
||||||
nested_inputs.extend(input.get_all())
|
nested_inputs.extend(input.get_all())
|
||||||
input_ids = [i.id for i in nested_inputs] if nested_inputs is not None else []
|
input_ids = [i.id for i in nested_inputs]
|
||||||
output_ids = [o.id for o in self.outputs] if self.outputs is not None else []
|
output_ids = [o.id for o in self.outputs]
|
||||||
input_set = set(input_ids)
|
input_set = set(input_ids)
|
||||||
output_set = set(output_ids)
|
output_set = set(output_ids)
|
||||||
issues = []
|
issues: list[str] = []
|
||||||
# verify ids are unique per list
|
# verify ids are unique per list
|
||||||
if len(input_set) != len(input_ids):
|
if len(input_set) != len(input_ids):
|
||||||
issues.append(f"Input ids must be unique, but {[item for item, count in Counter(input_ids).items() if count > 1]} are not.")
|
issues.append(f"Input ids must be unique, but {[item for item, count in Counter(input_ids).items() if count > 1]} are not.")
|
||||||
if len(output_set) != len(output_ids):
|
if len(output_set) != len(output_ids):
|
||||||
issues.append(f"Output ids must be unique, but {[item for item, count in Counter(output_ids).items() if count > 1]} are not.")
|
issues.append(f"Output ids must be unique, but {[item for item, count in Counter(output_ids).items() if count > 1]} are not.")
|
||||||
# verify ids are unique between lists
|
|
||||||
intersection = input_set & output_set
|
|
||||||
if len(intersection) > 0:
|
|
||||||
issues.append(f"Ids must be unique between inputs and outputs, but {intersection} are not.")
|
|
||||||
if len(issues) > 0:
|
if len(issues) > 0:
|
||||||
raise ValueError("\n".join(issues))
|
raise ValueError("\n".join(issues))
|
||||||
# validate inputs and outputs
|
# validate inputs and outputs
|
||||||
if self.inputs is not None:
|
for input in self.inputs:
|
||||||
for input in self.inputs:
|
input.validate()
|
||||||
input.validate()
|
for output in self.outputs:
|
||||||
if self.outputs is not None:
|
output.validate()
|
||||||
for output in self.outputs:
|
|
||||||
output.validate()
|
|
||||||
|
|
||||||
def finalize(self):
|
def finalize(self):
|
||||||
"""Add hidden based on selected schema options, and give outputs without ids default ids."""
|
"""Add hidden based on selected schema options, and give outputs without ids default ids."""
|
||||||
|
# ensure inputs, outputs, and hidden are lists
|
||||||
|
if self.inputs is None:
|
||||||
|
self.inputs = []
|
||||||
|
if self.outputs is None:
|
||||||
|
self.outputs = []
|
||||||
|
if self.hidden is None:
|
||||||
|
self.hidden = []
|
||||||
# if is an api_node, will need key-related hidden
|
# if is an api_node, will need key-related hidden
|
||||||
if self.is_api_node:
|
if self.is_api_node:
|
||||||
if self.hidden is None:
|
|
||||||
self.hidden = []
|
|
||||||
if Hidden.auth_token_comfy_org not in self.hidden:
|
if Hidden.auth_token_comfy_org not in self.hidden:
|
||||||
self.hidden.append(Hidden.auth_token_comfy_org)
|
self.hidden.append(Hidden.auth_token_comfy_org)
|
||||||
if Hidden.api_key_comfy_org not in self.hidden:
|
if Hidden.api_key_comfy_org not in self.hidden:
|
||||||
self.hidden.append(Hidden.api_key_comfy_org)
|
self.hidden.append(Hidden.api_key_comfy_org)
|
||||||
# if is an output_node, will need prompt and extra_pnginfo
|
# if is an output_node, will need prompt and extra_pnginfo
|
||||||
if self.is_output_node:
|
if self.is_output_node:
|
||||||
if self.hidden is None:
|
|
||||||
self.hidden = []
|
|
||||||
if Hidden.prompt not in self.hidden:
|
if Hidden.prompt not in self.hidden:
|
||||||
self.hidden.append(Hidden.prompt)
|
self.hidden.append(Hidden.prompt)
|
||||||
if Hidden.extra_pnginfo not in self.hidden:
|
if Hidden.extra_pnginfo not in self.hidden:
|
||||||
self.hidden.append(Hidden.extra_pnginfo)
|
self.hidden.append(Hidden.extra_pnginfo)
|
||||||
# give outputs without ids default ids
|
# give outputs without ids default ids
|
||||||
if self.outputs is not None:
|
for i, output in enumerate(self.outputs):
|
||||||
for i, output in enumerate(self.outputs):
|
if output.id is None:
|
||||||
if output.id is None:
|
output.id = f"_{i}_{output.io_type}_"
|
||||||
output.id = f"_{i}_{output.io_type}_"
|
|
||||||
|
|
||||||
def get_v1_info(self, cls, live_inputs: dict[str, Any]=None) -> NodeInfoV1:
|
def get_v1_info(self, cls) -> NodeInfoV1:
|
||||||
# NOTE: live_inputs will not be used anymore very soon and this will be done another way
|
|
||||||
# get V1 inputs
|
# get V1 inputs
|
||||||
input = create_input_dict_v1(self.inputs, live_inputs)
|
input = create_input_dict_v1(self.inputs)
|
||||||
if self.hidden:
|
if self.hidden:
|
||||||
for hidden in self.hidden:
|
for hidden in self.hidden:
|
||||||
input.setdefault("hidden", {})[hidden.name] = (hidden.value,)
|
input.setdefault("hidden", {})[hidden.name] = (hidden.value,)
|
||||||
@ -1384,33 +1411,54 @@ class Schema:
|
|||||||
)
|
)
|
||||||
return info
|
return info
|
||||||
|
|
||||||
|
def get_finalized_class_inputs(d: dict[str, Any], live_inputs: dict[str, Any], include_hidden=False) -> tuple[dict[str, Any], V3Data]:
|
||||||
|
out_dict = {
|
||||||
|
"required": {},
|
||||||
|
"optional": {},
|
||||||
|
"dynamic_paths": {},
|
||||||
|
}
|
||||||
|
d = d.copy()
|
||||||
|
# ignore hidden for parsing
|
||||||
|
hidden = d.pop("hidden", None)
|
||||||
|
parse_class_inputs(out_dict, live_inputs, d)
|
||||||
|
if hidden is not None and include_hidden:
|
||||||
|
out_dict["hidden"] = hidden
|
||||||
|
v3_data = {}
|
||||||
|
dynamic_paths = out_dict.pop("dynamic_paths", None)
|
||||||
|
if dynamic_paths is not None:
|
||||||
|
v3_data["dynamic_paths"] = dynamic_paths
|
||||||
|
return out_dict, hidden, v3_data
|
||||||
|
|
||||||
def create_input_dict_v1(inputs: list[Input], live_inputs: dict[str, Any]=None) -> dict:
|
def parse_class_inputs(out_dict: dict[str, Any], live_inputs: dict[str, Any], curr_dict: dict[str, Any], curr_prefix: list[str] | None=None) -> None:
|
||||||
|
for input_type, inner_d in curr_dict.items():
|
||||||
|
for id, value in inner_d.items():
|
||||||
|
io_type = value[0]
|
||||||
|
if io_type in DYNAMIC_INPUT_LOOKUP:
|
||||||
|
# dynamic inputs need to be handled with lookup functions
|
||||||
|
dynamic_input_func = get_dynamic_input_func(io_type)
|
||||||
|
new_prefix = handle_prefix(curr_prefix, id)
|
||||||
|
dynamic_input_func(out_dict, live_inputs, value, input_type, new_prefix)
|
||||||
|
else:
|
||||||
|
# non-dynamic inputs get directly transferred
|
||||||
|
finalized_id = finalize_prefix(curr_prefix, id)
|
||||||
|
out_dict[input_type][finalized_id] = value
|
||||||
|
if curr_prefix:
|
||||||
|
out_dict["dynamic_paths"][finalized_id] = finalized_id
|
||||||
|
|
||||||
|
def create_input_dict_v1(inputs: list[Input]) -> dict:
|
||||||
input = {
|
input = {
|
||||||
"required": {}
|
"required": {}
|
||||||
}
|
}
|
||||||
add_to_input_dict_v1(input, inputs, live_inputs)
|
for i in inputs:
|
||||||
|
add_to_dict_v1(i, input)
|
||||||
return input
|
return input
|
||||||
|
|
||||||
def add_to_input_dict_v1(d: dict[str, Any], inputs: list[Input], live_inputs: dict[str, Any]=None, curr_prefix=''):
|
def add_to_dict_v1(i: Input, d: dict):
|
||||||
for i in inputs:
|
|
||||||
if isinstance(i, DynamicInput):
|
|
||||||
add_to_dict_v1(i, d)
|
|
||||||
if live_inputs is not None:
|
|
||||||
i.expand_schema_for_dynamic(d, live_inputs, curr_prefix)
|
|
||||||
else:
|
|
||||||
add_to_dict_v1(i, d)
|
|
||||||
|
|
||||||
def add_to_dict_v1(i: Input, d: dict, dynamic_dict: dict=None):
|
|
||||||
key = "optional" if i.optional else "required"
|
key = "optional" if i.optional else "required"
|
||||||
as_dict = i.as_dict()
|
as_dict = i.as_dict()
|
||||||
# for v1, we don't want to include the optional key
|
# for v1, we don't want to include the optional key
|
||||||
as_dict.pop("optional", None)
|
as_dict.pop("optional", None)
|
||||||
if dynamic_dict is None:
|
d.setdefault(key, {})[i.id] = (i.get_io_type(), as_dict)
|
||||||
value = (i.get_io_type(), as_dict)
|
|
||||||
else:
|
|
||||||
value = (i.get_io_type(), as_dict, dynamic_dict)
|
|
||||||
d.setdefault(key, {})[i.id] = value
|
|
||||||
|
|
||||||
def add_to_dict_v3(io: Input | Output, d: dict):
|
def add_to_dict_v3(io: Input | Output, d: dict):
|
||||||
d[io.id] = (io.get_io_type(), io.as_dict())
|
d[io.id] = (io.get_io_type(), io.as_dict())
|
||||||
@ -1422,6 +1470,8 @@ def build_nested_inputs(values: dict[str, Any], v3_data: V3Data):
|
|||||||
values = values.copy()
|
values = values.copy()
|
||||||
result = {}
|
result = {}
|
||||||
|
|
||||||
|
create_tuple = v3_data.get("create_dynamic_tuple", False)
|
||||||
|
|
||||||
for key, path in paths.items():
|
for key, path in paths.items():
|
||||||
parts = path.split(".")
|
parts = path.split(".")
|
||||||
current = result
|
current = result
|
||||||
@ -1430,7 +1480,10 @@ def build_nested_inputs(values: dict[str, Any], v3_data: V3Data):
|
|||||||
is_last = (i == len(parts) - 1)
|
is_last = (i == len(parts) - 1)
|
||||||
|
|
||||||
if is_last:
|
if is_last:
|
||||||
current[p] = values.pop(key, None)
|
value = values.pop(key, None)
|
||||||
|
if create_tuple:
|
||||||
|
value = (value, key)
|
||||||
|
current[p] = value
|
||||||
else:
|
else:
|
||||||
current = current.setdefault(p, {})
|
current = current.setdefault(p, {})
|
||||||
|
|
||||||
@ -1445,7 +1498,6 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
|
|||||||
SCHEMA = None
|
SCHEMA = None
|
||||||
|
|
||||||
# filled in during execution
|
# filled in during execution
|
||||||
resources: Resources = None
|
|
||||||
hidden: HiddenHolder = None
|
hidden: HiddenHolder = None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -1492,7 +1544,6 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
|
|||||||
return [name for name in kwargs if kwargs[name] is None]
|
return [name for name in kwargs if kwargs[name] is None]
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.local_resources: ResourcesLocal = None
|
|
||||||
self.__class__.VALIDATE_CLASS()
|
self.__class__.VALIDATE_CLASS()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -1560,7 +1611,7 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
|
|||||||
c_type: type[ComfyNode] = cls if is_class(cls) else type(cls)
|
c_type: type[ComfyNode] = cls if is_class(cls) else type(cls)
|
||||||
type_clone: type[ComfyNode] = shallow_clone_class(c_type)
|
type_clone: type[ComfyNode] = shallow_clone_class(c_type)
|
||||||
# set hidden
|
# set hidden
|
||||||
type_clone.hidden = HiddenHolder.from_dict(v3_data["hidden_inputs"] if v3_data else None)
|
type_clone.hidden = HiddenHolder.from_v3_data(v3_data)
|
||||||
return type_clone
|
return type_clone
|
||||||
|
|
||||||
@final
|
@final
|
||||||
@ -1677,19 +1728,10 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
|
|||||||
|
|
||||||
@final
|
@final
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(cls, include_hidden=True, return_schema=False, live_inputs=None) -> dict[str, dict] | tuple[dict[str, dict], Schema, V3Data]:
|
def INPUT_TYPES(cls) -> dict[str, dict]:
|
||||||
schema = cls.FINALIZE_SCHEMA()
|
schema = cls.FINALIZE_SCHEMA()
|
||||||
info = schema.get_v1_info(cls, live_inputs)
|
info = schema.get_v1_info(cls)
|
||||||
input = info.input
|
return info.input
|
||||||
if not include_hidden:
|
|
||||||
input.pop("hidden", None)
|
|
||||||
if return_schema:
|
|
||||||
v3_data: V3Data = {}
|
|
||||||
dynamic = input.pop("dynamic_paths", None)
|
|
||||||
if dynamic is not None:
|
|
||||||
v3_data["dynamic_paths"] = dynamic
|
|
||||||
return input, schema, v3_data
|
|
||||||
return input
|
|
||||||
|
|
||||||
@final
|
@final
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -1808,7 +1850,7 @@ class NodeOutput(_NodeOutputInternal):
|
|||||||
return self.args if len(self.args) > 0 else None
|
return self.args if len(self.args) > 0 else None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_dict(cls, data: dict[str, Any]) -> "NodeOutput":
|
def from_dict(cls, data: dict[str, Any]) -> NodeOutput:
|
||||||
args = ()
|
args = ()
|
||||||
ui = None
|
ui = None
|
||||||
expand = None
|
expand = None
|
||||||
@ -1903,8 +1945,8 @@ __all__ = [
|
|||||||
"Tracks",
|
"Tracks",
|
||||||
# Dynamic Types
|
# Dynamic Types
|
||||||
"MatchType",
|
"MatchType",
|
||||||
# "DynamicCombo",
|
"DynamicCombo",
|
||||||
# "Autogrow",
|
"Autogrow",
|
||||||
# Other classes
|
# Other classes
|
||||||
"HiddenHolder",
|
"HiddenHolder",
|
||||||
"Hidden",
|
"Hidden",
|
||||||
|
|||||||
@ -1,72 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
import comfy.utils
|
|
||||||
import folder_paths
|
|
||||||
import logging
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from typing import Any
|
|
||||||
import torch
|
|
||||||
|
|
||||||
class ResourceKey(ABC):
|
|
||||||
Type = Any
|
|
||||||
def __init__(self):
|
|
||||||
...
|
|
||||||
|
|
||||||
class TorchDictFolderFilename(ResourceKey):
|
|
||||||
'''Key for requesting a torch file via file_name from a folder category.'''
|
|
||||||
Type = dict[str, torch.Tensor]
|
|
||||||
def __init__(self, folder_name: str, file_name: str):
|
|
||||||
self.folder_name = folder_name
|
|
||||||
self.file_name = file_name
|
|
||||||
|
|
||||||
def __hash__(self):
|
|
||||||
return hash((self.folder_name, self.file_name))
|
|
||||||
|
|
||||||
def __eq__(self, other: object) -> bool:
|
|
||||||
if not isinstance(other, TorchDictFolderFilename):
|
|
||||||
return False
|
|
||||||
return self.folder_name == other.folder_name and self.file_name == other.file_name
|
|
||||||
|
|
||||||
def __str__(self):
|
|
||||||
return f"{self.folder_name} -> {self.file_name}"
|
|
||||||
|
|
||||||
class Resources(ABC):
|
|
||||||
def __init__(self):
|
|
||||||
...
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get(self, key: ResourceKey, default: Any=...) -> Any:
|
|
||||||
pass
|
|
||||||
|
|
||||||
class ResourcesLocal(Resources):
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
self.local_resources: dict[ResourceKey, Any] = {}
|
|
||||||
|
|
||||||
def get(self, key: ResourceKey, default: Any=...) -> Any:
|
|
||||||
cached = self.local_resources.get(key, None)
|
|
||||||
if cached is not None:
|
|
||||||
logging.info(f"Using cached resource '{key}'")
|
|
||||||
return cached
|
|
||||||
logging.info(f"Loading resource '{key}'")
|
|
||||||
to_return = None
|
|
||||||
if isinstance(key, TorchDictFolderFilename):
|
|
||||||
if default is ...:
|
|
||||||
to_return = comfy.utils.load_torch_file(folder_paths.get_full_path_or_raise(key.folder_name, key.file_name), safe_load=True)
|
|
||||||
else:
|
|
||||||
full_path = folder_paths.get_full_path(key.folder_name, key.file_name)
|
|
||||||
if full_path is not None:
|
|
||||||
to_return = comfy.utils.load_torch_file(full_path, safe_load=True)
|
|
||||||
|
|
||||||
if to_return is not None:
|
|
||||||
self.local_resources[key] = to_return
|
|
||||||
return to_return
|
|
||||||
if default is not ...:
|
|
||||||
return default
|
|
||||||
raise Exception(f"Unsupported resource key type: {type(key)}")
|
|
||||||
|
|
||||||
|
|
||||||
class _RESOURCES:
|
|
||||||
ResourceKey = ResourceKey
|
|
||||||
TorchDictFolderFilename = TorchDictFolderFilename
|
|
||||||
Resources = Resources
|
|
||||||
ResourcesLocal = ResourcesLocal
|
|
||||||
@ -229,6 +229,7 @@ class ByteDanceImageEditNode(IO.ComfyNode):
|
|||||||
IO.Hidden.unique_id,
|
IO.Hidden.unique_id,
|
||||||
],
|
],
|
||||||
is_api_node=True,
|
is_api_node=True,
|
||||||
|
is_deprecated=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -269,7 +270,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
|
|||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
return IO.Schema(
|
return IO.Schema(
|
||||||
node_id="ByteDanceSeedreamNode",
|
node_id="ByteDanceSeedreamNode",
|
||||||
display_name="ByteDance Seedream 4",
|
display_name="ByteDance Seedream 4.5",
|
||||||
category="api node/image/ByteDance",
|
category="api node/image/ByteDance",
|
||||||
description="Unified text-to-image generation and precise single-sentence editing at up to 4K resolution.",
|
description="Unified text-to-image generation and precise single-sentence editing at up to 4K resolution.",
|
||||||
inputs=[
|
inputs=[
|
||||||
|
|||||||
@ -807,6 +807,7 @@ class OmniProTextToVideoNode(IO.ComfyNode):
|
|||||||
),
|
),
|
||||||
IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]),
|
IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]),
|
||||||
IO.Combo.Input("duration", options=[5, 10]),
|
IO.Combo.Input("duration", options=[5, 10]),
|
||||||
|
IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True),
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
IO.Video.Output(),
|
IO.Video.Output(),
|
||||||
@ -826,6 +827,7 @@ class OmniProTextToVideoNode(IO.ComfyNode):
|
|||||||
prompt: str,
|
prompt: str,
|
||||||
aspect_ratio: str,
|
aspect_ratio: str,
|
||||||
duration: int,
|
duration: int,
|
||||||
|
resolution: str = "1080p",
|
||||||
) -> IO.NodeOutput:
|
) -> IO.NodeOutput:
|
||||||
validate_string(prompt, min_length=1, max_length=2500)
|
validate_string(prompt, min_length=1, max_length=2500)
|
||||||
response = await sync_op(
|
response = await sync_op(
|
||||||
@ -837,6 +839,7 @@ class OmniProTextToVideoNode(IO.ComfyNode):
|
|||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
aspect_ratio=aspect_ratio,
|
aspect_ratio=aspect_ratio,
|
||||||
duration=str(duration),
|
duration=str(duration),
|
||||||
|
mode="pro" if resolution == "1080p" else "std",
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
return await finish_omni_video_task(cls, response)
|
return await finish_omni_video_task(cls, response)
|
||||||
@ -872,6 +875,7 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
|
|||||||
optional=True,
|
optional=True,
|
||||||
tooltip="Up to 6 additional reference images.",
|
tooltip="Up to 6 additional reference images.",
|
||||||
),
|
),
|
||||||
|
IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True),
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
IO.Video.Output(),
|
IO.Video.Output(),
|
||||||
@ -893,6 +897,7 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
|
|||||||
first_frame: Input.Image,
|
first_frame: Input.Image,
|
||||||
end_frame: Input.Image | None = None,
|
end_frame: Input.Image | None = None,
|
||||||
reference_images: Input.Image | None = None,
|
reference_images: Input.Image | None = None,
|
||||||
|
resolution: str = "1080p",
|
||||||
) -> IO.NodeOutput:
|
) -> IO.NodeOutput:
|
||||||
prompt = normalize_omni_prompt_references(prompt)
|
prompt = normalize_omni_prompt_references(prompt)
|
||||||
validate_string(prompt, min_length=1, max_length=2500)
|
validate_string(prompt, min_length=1, max_length=2500)
|
||||||
@ -936,6 +941,7 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
|
|||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
duration=str(duration),
|
duration=str(duration),
|
||||||
image_list=image_list,
|
image_list=image_list,
|
||||||
|
mode="pro" if resolution == "1080p" else "std",
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
return await finish_omni_video_task(cls, response)
|
return await finish_omni_video_task(cls, response)
|
||||||
@ -964,6 +970,7 @@ class OmniProImageToVideoNode(IO.ComfyNode):
|
|||||||
"reference_images",
|
"reference_images",
|
||||||
tooltip="Up to 7 reference images.",
|
tooltip="Up to 7 reference images.",
|
||||||
),
|
),
|
||||||
|
IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True),
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
IO.Video.Output(),
|
IO.Video.Output(),
|
||||||
@ -984,6 +991,7 @@ class OmniProImageToVideoNode(IO.ComfyNode):
|
|||||||
aspect_ratio: str,
|
aspect_ratio: str,
|
||||||
duration: int,
|
duration: int,
|
||||||
reference_images: Input.Image,
|
reference_images: Input.Image,
|
||||||
|
resolution: str = "1080p",
|
||||||
) -> IO.NodeOutput:
|
) -> IO.NodeOutput:
|
||||||
prompt = normalize_omni_prompt_references(prompt)
|
prompt = normalize_omni_prompt_references(prompt)
|
||||||
validate_string(prompt, min_length=1, max_length=2500)
|
validate_string(prompt, min_length=1, max_length=2500)
|
||||||
@ -1005,6 +1013,7 @@ class OmniProImageToVideoNode(IO.ComfyNode):
|
|||||||
aspect_ratio=aspect_ratio,
|
aspect_ratio=aspect_ratio,
|
||||||
duration=str(duration),
|
duration=str(duration),
|
||||||
image_list=image_list,
|
image_list=image_list,
|
||||||
|
mode="pro" if resolution == "1080p" else "std",
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
return await finish_omni_video_task(cls, response)
|
return await finish_omni_video_task(cls, response)
|
||||||
@ -1036,6 +1045,7 @@ class OmniProVideoToVideoNode(IO.ComfyNode):
|
|||||||
tooltip="Up to 4 additional reference images.",
|
tooltip="Up to 4 additional reference images.",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
|
IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True),
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
IO.Video.Output(),
|
IO.Video.Output(),
|
||||||
@ -1058,6 +1068,7 @@ class OmniProVideoToVideoNode(IO.ComfyNode):
|
|||||||
reference_video: Input.Video,
|
reference_video: Input.Video,
|
||||||
keep_original_sound: bool,
|
keep_original_sound: bool,
|
||||||
reference_images: Input.Image | None = None,
|
reference_images: Input.Image | None = None,
|
||||||
|
resolution: str = "1080p",
|
||||||
) -> IO.NodeOutput:
|
) -> IO.NodeOutput:
|
||||||
prompt = normalize_omni_prompt_references(prompt)
|
prompt = normalize_omni_prompt_references(prompt)
|
||||||
validate_string(prompt, min_length=1, max_length=2500)
|
validate_string(prompt, min_length=1, max_length=2500)
|
||||||
@ -1090,6 +1101,7 @@ class OmniProVideoToVideoNode(IO.ComfyNode):
|
|||||||
duration=str(duration),
|
duration=str(duration),
|
||||||
image_list=image_list if image_list else None,
|
image_list=image_list if image_list else None,
|
||||||
video_list=video_list,
|
video_list=video_list,
|
||||||
|
mode="pro" if resolution == "1080p" else "std",
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
return await finish_omni_video_task(cls, response)
|
return await finish_omni_video_task(cls, response)
|
||||||
@ -1119,6 +1131,7 @@ class OmniProEditVideoNode(IO.ComfyNode):
|
|||||||
tooltip="Up to 4 additional reference images.",
|
tooltip="Up to 4 additional reference images.",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
|
IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True),
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
IO.Video.Output(),
|
IO.Video.Output(),
|
||||||
@ -1139,6 +1152,7 @@ class OmniProEditVideoNode(IO.ComfyNode):
|
|||||||
video: Input.Video,
|
video: Input.Video,
|
||||||
keep_original_sound: bool,
|
keep_original_sound: bool,
|
||||||
reference_images: Input.Image | None = None,
|
reference_images: Input.Image | None = None,
|
||||||
|
resolution: str = "1080p",
|
||||||
) -> IO.NodeOutput:
|
) -> IO.NodeOutput:
|
||||||
prompt = normalize_omni_prompt_references(prompt)
|
prompt = normalize_omni_prompt_references(prompt)
|
||||||
validate_string(prompt, min_length=1, max_length=2500)
|
validate_string(prompt, min_length=1, max_length=2500)
|
||||||
@ -1171,6 +1185,7 @@ class OmniProEditVideoNode(IO.ComfyNode):
|
|||||||
duration=None,
|
duration=None,
|
||||||
image_list=image_list if image_list else None,
|
image_list=image_list if image_list else None,
|
||||||
video_list=video_list,
|
video_list=video_list,
|
||||||
|
mode="pro" if resolution == "1080p" else "std",
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
return await finish_omni_video_task(cls, response)
|
return await finish_omni_video_task(cls, response)
|
||||||
|
|||||||
@ -155,7 +155,7 @@ class TripoTextToModelNode(IO.ComfyNode):
|
|||||||
model_seed=model_seed,
|
model_seed=model_seed,
|
||||||
texture_seed=texture_seed,
|
texture_seed=texture_seed,
|
||||||
texture_quality=texture_quality,
|
texture_quality=texture_quality,
|
||||||
face_limit=face_limit,
|
face_limit=face_limit if face_limit != -1 else None,
|
||||||
geometry_quality=geometry_quality,
|
geometry_quality=geometry_quality,
|
||||||
auto_size=True,
|
auto_size=True,
|
||||||
quad=quad,
|
quad=quad,
|
||||||
@ -255,7 +255,7 @@ class TripoImageToModelNode(IO.ComfyNode):
|
|||||||
texture_alignment=texture_alignment,
|
texture_alignment=texture_alignment,
|
||||||
texture_seed=texture_seed,
|
texture_seed=texture_seed,
|
||||||
texture_quality=texture_quality,
|
texture_quality=texture_quality,
|
||||||
face_limit=face_limit,
|
face_limit=face_limit if face_limit != -1 else None,
|
||||||
auto_size=True,
|
auto_size=True,
|
||||||
quad=quad,
|
quad=quad,
|
||||||
),
|
),
|
||||||
@ -369,7 +369,7 @@ class TripoMultiviewToModelNode(IO.ComfyNode):
|
|||||||
texture_quality=texture_quality,
|
texture_quality=texture_quality,
|
||||||
geometry_quality=geometry_quality,
|
geometry_quality=geometry_quality,
|
||||||
texture_alignment=texture_alignment,
|
texture_alignment=texture_alignment,
|
||||||
face_limit=face_limit,
|
face_limit=face_limit if face_limit != -1 else None,
|
||||||
quad=quad,
|
quad=quad,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|||||||
@ -13,7 +13,9 @@ from comfy_api_nodes.util import (
|
|||||||
poll_op,
|
poll_op,
|
||||||
sync_op,
|
sync_op,
|
||||||
tensor_to_base64_string,
|
tensor_to_base64_string,
|
||||||
|
upload_video_to_comfyapi,
|
||||||
validate_audio_duration,
|
validate_audio_duration,
|
||||||
|
validate_video_duration,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -41,6 +43,12 @@ class Image2VideoInputField(BaseModel):
|
|||||||
audio_url: str | None = Field(None)
|
audio_url: str | None = Field(None)
|
||||||
|
|
||||||
|
|
||||||
|
class Reference2VideoInputField(BaseModel):
|
||||||
|
prompt: str = Field(...)
|
||||||
|
negative_prompt: str | None = Field(None)
|
||||||
|
reference_video_urls: list[str] = Field(...)
|
||||||
|
|
||||||
|
|
||||||
class Txt2ImageParametersField(BaseModel):
|
class Txt2ImageParametersField(BaseModel):
|
||||||
size: str = Field(...)
|
size: str = Field(...)
|
||||||
n: int = Field(1, description="Number of images to generate.") # we support only value=1
|
n: int = Field(1, description="Number of images to generate.") # we support only value=1
|
||||||
@ -76,6 +84,14 @@ class Image2VideoParametersField(BaseModel):
|
|||||||
shot_type: str = Field("single")
|
shot_type: str = Field("single")
|
||||||
|
|
||||||
|
|
||||||
|
class Reference2VideoParametersField(BaseModel):
|
||||||
|
size: str = Field(...)
|
||||||
|
duration: int = Field(5, ge=5, le=15)
|
||||||
|
shot_type: str = Field("single")
|
||||||
|
seed: int = Field(..., ge=0, le=2147483647)
|
||||||
|
watermark: bool = Field(False)
|
||||||
|
|
||||||
|
|
||||||
class Text2ImageTaskCreationRequest(BaseModel):
|
class Text2ImageTaskCreationRequest(BaseModel):
|
||||||
model: str = Field(...)
|
model: str = Field(...)
|
||||||
input: Text2ImageInputField = Field(...)
|
input: Text2ImageInputField = Field(...)
|
||||||
@ -100,6 +116,12 @@ class Image2VideoTaskCreationRequest(BaseModel):
|
|||||||
parameters: Image2VideoParametersField = Field(...)
|
parameters: Image2VideoParametersField = Field(...)
|
||||||
|
|
||||||
|
|
||||||
|
class Reference2VideoTaskCreationRequest(BaseModel):
|
||||||
|
model: str = Field(...)
|
||||||
|
input: Reference2VideoInputField = Field(...)
|
||||||
|
parameters: Reference2VideoParametersField = Field(...)
|
||||||
|
|
||||||
|
|
||||||
class TaskCreationOutputField(BaseModel):
|
class TaskCreationOutputField(BaseModel):
|
||||||
task_id: str = Field(...)
|
task_id: str = Field(...)
|
||||||
task_status: str = Field(...)
|
task_status: str = Field(...)
|
||||||
@ -721,6 +743,143 @@ class WanImageToVideoApi(IO.ComfyNode):
|
|||||||
return IO.NodeOutput(await download_url_to_video_output(response.output.video_url))
|
return IO.NodeOutput(await download_url_to_video_output(response.output.video_url))
|
||||||
|
|
||||||
|
|
||||||
|
class WanReferenceVideoApi(IO.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return IO.Schema(
|
||||||
|
node_id="WanReferenceVideoApi",
|
||||||
|
display_name="Wan Reference to Video",
|
||||||
|
category="api node/video/Wan",
|
||||||
|
description="Use the character and voice from input videos, combined with a prompt, "
|
||||||
|
"to generate a new video that maintains character consistency.",
|
||||||
|
inputs=[
|
||||||
|
IO.Combo.Input("model", options=["wan2.6-r2v"]),
|
||||||
|
IO.String.Input(
|
||||||
|
"prompt",
|
||||||
|
multiline=True,
|
||||||
|
default="",
|
||||||
|
tooltip="Prompt describing the elements and visual features. Supports English and Chinese. "
|
||||||
|
"Use identifiers such as `character1` and `character2` to refer to the reference characters.",
|
||||||
|
),
|
||||||
|
IO.String.Input(
|
||||||
|
"negative_prompt",
|
||||||
|
multiline=True,
|
||||||
|
default="",
|
||||||
|
tooltip="Negative prompt describing what to avoid.",
|
||||||
|
),
|
||||||
|
IO.Autogrow.Input(
|
||||||
|
"reference_videos",
|
||||||
|
template=IO.Autogrow.TemplateNames(
|
||||||
|
IO.Video.Input("reference_video"),
|
||||||
|
names=["character1", "character2", "character3"],
|
||||||
|
min=1,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
IO.Combo.Input(
|
||||||
|
"size",
|
||||||
|
options=[
|
||||||
|
"720p: 1:1 (960x960)",
|
||||||
|
"720p: 16:9 (1280x720)",
|
||||||
|
"720p: 9:16 (720x1280)",
|
||||||
|
"720p: 4:3 (1088x832)",
|
||||||
|
"720p: 3:4 (832x1088)",
|
||||||
|
"1080p: 1:1 (1440x1440)",
|
||||||
|
"1080p: 16:9 (1920x1080)",
|
||||||
|
"1080p: 9:16 (1080x1920)",
|
||||||
|
"1080p: 4:3 (1632x1248)",
|
||||||
|
"1080p: 3:4 (1248x1632)",
|
||||||
|
],
|
||||||
|
),
|
||||||
|
IO.Int.Input(
|
||||||
|
"duration",
|
||||||
|
default=5,
|
||||||
|
min=5,
|
||||||
|
max=10,
|
||||||
|
step=5,
|
||||||
|
display_mode=IO.NumberDisplay.slider,
|
||||||
|
),
|
||||||
|
IO.Int.Input(
|
||||||
|
"seed",
|
||||||
|
default=0,
|
||||||
|
min=0,
|
||||||
|
max=2147483647,
|
||||||
|
step=1,
|
||||||
|
display_mode=IO.NumberDisplay.number,
|
||||||
|
control_after_generate=True,
|
||||||
|
),
|
||||||
|
IO.Combo.Input(
|
||||||
|
"shot_type",
|
||||||
|
options=["single", "multi"],
|
||||||
|
tooltip="Specifies the shot type for the generated video, that is, whether the video is a "
|
||||||
|
"single continuous shot or multiple shots with cuts.",
|
||||||
|
),
|
||||||
|
IO.Boolean.Input(
|
||||||
|
"watermark",
|
||||||
|
default=False,
|
||||||
|
tooltip="Whether to add an AI-generated watermark to the result.",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
IO.Video.Output(),
|
||||||
|
],
|
||||||
|
hidden=[
|
||||||
|
IO.Hidden.auth_token_comfy_org,
|
||||||
|
IO.Hidden.api_key_comfy_org,
|
||||||
|
IO.Hidden.unique_id,
|
||||||
|
],
|
||||||
|
is_api_node=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def execute(
|
||||||
|
cls,
|
||||||
|
model: str,
|
||||||
|
prompt: str,
|
||||||
|
negative_prompt: str,
|
||||||
|
reference_videos: IO.Autogrow.Type,
|
||||||
|
size: str,
|
||||||
|
duration: int,
|
||||||
|
seed: int,
|
||||||
|
shot_type: str,
|
||||||
|
watermark: bool,
|
||||||
|
):
|
||||||
|
reference_video_urls = []
|
||||||
|
for i in reference_videos:
|
||||||
|
validate_video_duration(reference_videos[i], min_duration=2, max_duration=30)
|
||||||
|
for i in reference_videos:
|
||||||
|
reference_video_urls.append(await upload_video_to_comfyapi(cls, reference_videos[i]))
|
||||||
|
width, height = RES_IN_PARENS.search(size).groups()
|
||||||
|
initial_response = await sync_op(
|
||||||
|
cls,
|
||||||
|
ApiEndpoint(path="/proxy/wan/api/v1/services/aigc/video-generation/video-synthesis", method="POST"),
|
||||||
|
response_model=TaskCreationResponse,
|
||||||
|
data=Reference2VideoTaskCreationRequest(
|
||||||
|
model=model,
|
||||||
|
input=Reference2VideoInputField(
|
||||||
|
prompt=prompt, negative_prompt=negative_prompt, reference_video_urls=reference_video_urls
|
||||||
|
),
|
||||||
|
parameters=Reference2VideoParametersField(
|
||||||
|
size=f"{width}*{height}",
|
||||||
|
duration=duration,
|
||||||
|
shot_type=shot_type,
|
||||||
|
watermark=watermark,
|
||||||
|
seed=seed,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
if not initial_response.output:
|
||||||
|
raise Exception(f"An unknown error occurred: {initial_response.code} - {initial_response.message}")
|
||||||
|
response = await poll_op(
|
||||||
|
cls,
|
||||||
|
ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"),
|
||||||
|
response_model=VideoTaskStatusResponse,
|
||||||
|
status_extractor=lambda x: x.output.task_status,
|
||||||
|
poll_interval=6,
|
||||||
|
max_poll_attempts=280,
|
||||||
|
)
|
||||||
|
return IO.NodeOutput(await download_url_to_video_output(response.output.video_url))
|
||||||
|
|
||||||
|
|
||||||
class WanApiExtension(ComfyExtension):
|
class WanApiExtension(ComfyExtension):
|
||||||
@override
|
@override
|
||||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||||
@ -729,6 +888,7 @@ class WanApiExtension(ComfyExtension):
|
|||||||
WanImageToImageApi,
|
WanImageToImageApi,
|
||||||
WanTextToVideoApi,
|
WanTextToVideoApi,
|
||||||
WanImageToVideoApi,
|
WanImageToVideoApi,
|
||||||
|
WanReferenceVideoApi,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -1,16 +1,22 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import contextlib
|
import contextlib
|
||||||
import os
|
import os
|
||||||
|
import re
|
||||||
import time
|
import time
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
|
|
||||||
|
from yarl import URL
|
||||||
|
|
||||||
from comfy.cli_args import args
|
from comfy.cli_args import args
|
||||||
from comfy.model_management import processing_interrupted
|
from comfy.model_management import processing_interrupted
|
||||||
from comfy_api.latest import IO
|
from comfy_api.latest import IO
|
||||||
|
|
||||||
from .common_exceptions import ProcessingInterrupted
|
from .common_exceptions import ProcessingInterrupted
|
||||||
|
|
||||||
|
_HAS_PCT_ESC = re.compile(r"%[0-9A-Fa-f]{2}") # any % followed by 2 hex digits
|
||||||
|
_HAS_BAD_PCT = re.compile(r"%(?![0-9A-Fa-f]{2})") # any % not followed by 2 hex digits
|
||||||
|
|
||||||
|
|
||||||
def is_processing_interrupted() -> bool:
|
def is_processing_interrupted() -> bool:
|
||||||
"""Return True if user/runtime requested interruption."""
|
"""Return True if user/runtime requested interruption."""
|
||||||
@ -69,3 +75,17 @@ def get_fs_object_size(path_or_object: str | BytesIO) -> int:
|
|||||||
if isinstance(path_or_object, str):
|
if isinstance(path_or_object, str):
|
||||||
return os.path.getsize(path_or_object)
|
return os.path.getsize(path_or_object)
|
||||||
return len(path_or_object.getvalue())
|
return len(path_or_object.getvalue())
|
||||||
|
|
||||||
|
|
||||||
|
def to_aiohttp_url(url: str) -> URL:
|
||||||
|
"""If `url` appears to be already percent-encoded (contains at least one valid %HH
|
||||||
|
escape and no malformed '%' sequences) and contains no raw whitespace/control
|
||||||
|
characters preserve the original encoding byte-for-byte (important for signed/presigned URLs).
|
||||||
|
Otherwise, return `URL(url)` and allow yarl to normalize/quote as needed."""
|
||||||
|
if any(c.isspace() for c in url) or any(ord(c) < 0x20 for c in url):
|
||||||
|
# Avoid encoded=True if URL contains raw whitespace/control chars
|
||||||
|
return URL(url)
|
||||||
|
if _HAS_PCT_ESC.search(url) and not _HAS_BAD_PCT.search(url):
|
||||||
|
# Preserve encoding only if it appears pre-encoded AND has no invalid % sequences
|
||||||
|
return URL(url, encoded=True)
|
||||||
|
return URL(url)
|
||||||
|
|||||||
@ -19,6 +19,7 @@ from ._helpers import (
|
|||||||
get_auth_header,
|
get_auth_header,
|
||||||
is_processing_interrupted,
|
is_processing_interrupted,
|
||||||
sleep_with_interrupt,
|
sleep_with_interrupt,
|
||||||
|
to_aiohttp_url,
|
||||||
)
|
)
|
||||||
from .client import _diagnose_connectivity
|
from .client import _diagnose_connectivity
|
||||||
from .common_exceptions import ApiServerError, LocalNetworkError, ProcessingInterrupted
|
from .common_exceptions import ApiServerError, LocalNetworkError, ProcessingInterrupted
|
||||||
@ -94,7 +95,7 @@ async def download_url_to_bytesio(
|
|||||||
|
|
||||||
monitor_task = asyncio.create_task(_monitor())
|
monitor_task = asyncio.create_task(_monitor())
|
||||||
|
|
||||||
req_task = asyncio.create_task(session.get(url, headers=headers))
|
req_task = asyncio.create_task(session.get(to_aiohttp_url(url), headers=headers))
|
||||||
done, pending = await asyncio.wait({req_task, monitor_task}, return_when=asyncio.FIRST_COMPLETED)
|
done, pending = await asyncio.wait({req_task, monitor_task}, return_when=asyncio.FIRST_COMPLETED)
|
||||||
|
|
||||||
if monitor_task in done and req_task in pending:
|
if monitor_task in done and req_task in pending:
|
||||||
|
|||||||
@ -119,7 +119,7 @@ async def upload_video_to_comfyapi(
|
|||||||
raise ValueError(f"Could not verify video duration from source: {e}") from e
|
raise ValueError(f"Could not verify video duration from source: {e}") from e
|
||||||
|
|
||||||
upload_mime_type = f"video/{container.value.lower()}"
|
upload_mime_type = f"video/{container.value.lower()}"
|
||||||
filename = f"uploaded_video.{container.value.lower()}"
|
filename = f"{uuid.uuid4()}.{container.value.lower()}"
|
||||||
|
|
||||||
# Convert VideoInput to BytesIO using specified container/codec
|
# Convert VideoInput to BytesIO using specified container/codec
|
||||||
video_bytes_io = BytesIO()
|
video_bytes_io = BytesIO()
|
||||||
|
|||||||
@ -97,6 +97,11 @@ def get_input_info(
|
|||||||
extra_info = input_info[1]
|
extra_info = input_info[1]
|
||||||
else:
|
else:
|
||||||
extra_info = {}
|
extra_info = {}
|
||||||
|
# if input_type is a list, it is a Combo defined in outdated format; convert it.
|
||||||
|
# NOTE: uncomment this when we are confident old format going away won't cause too much trouble.
|
||||||
|
# if isinstance(input_type, list):
|
||||||
|
# extra_info["options"] = input_type
|
||||||
|
# input_type = IO.Combo.io_type
|
||||||
return input_type, input_category, extra_info
|
return input_type, input_category, extra_info
|
||||||
|
|
||||||
class TopologicalSort:
|
class TopologicalSort:
|
||||||
@ -202,15 +207,15 @@ class ExecutionList(TopologicalSort):
|
|||||||
return self.output_cache.get(node_id) is not None
|
return self.output_cache.get(node_id) is not None
|
||||||
|
|
||||||
def cache_link(self, from_node_id, to_node_id):
|
def cache_link(self, from_node_id, to_node_id):
|
||||||
if not to_node_id in self.execution_cache:
|
if to_node_id not in self.execution_cache:
|
||||||
self.execution_cache[to_node_id] = {}
|
self.execution_cache[to_node_id] = {}
|
||||||
self.execution_cache[to_node_id][from_node_id] = self.output_cache.get(from_node_id)
|
self.execution_cache[to_node_id][from_node_id] = self.output_cache.get(from_node_id)
|
||||||
if not from_node_id in self.execution_cache_listeners:
|
if from_node_id not in self.execution_cache_listeners:
|
||||||
self.execution_cache_listeners[from_node_id] = set()
|
self.execution_cache_listeners[from_node_id] = set()
|
||||||
self.execution_cache_listeners[from_node_id].add(to_node_id)
|
self.execution_cache_listeners[from_node_id].add(to_node_id)
|
||||||
|
|
||||||
def get_cache(self, from_node_id, to_node_id):
|
def get_cache(self, from_node_id, to_node_id):
|
||||||
if not to_node_id in self.execution_cache:
|
if to_node_id not in self.execution_cache:
|
||||||
return None
|
return None
|
||||||
value = self.execution_cache[to_node_id].get(from_node_id)
|
value = self.execution_cache[to_node_id].get(from_node_id)
|
||||||
if value is None:
|
if value is None:
|
||||||
|
|||||||
@ -21,14 +21,24 @@ def validate_node_input(
|
|||||||
"""
|
"""
|
||||||
# If the types are exactly the same, we can return immediately
|
# If the types are exactly the same, we can return immediately
|
||||||
# Use pre-union behaviour: inverse of `__ne__`
|
# Use pre-union behaviour: inverse of `__ne__`
|
||||||
|
# NOTE: this lets legacy '*' Any types work that override the __ne__ method of the str class.
|
||||||
if not received_type != input_type:
|
if not received_type != input_type:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
# If one of the types is '*', we can return True immediately; this is the 'Any' type.
|
||||||
|
if received_type == IO.AnyType.io_type or input_type == IO.AnyType.io_type:
|
||||||
|
return True
|
||||||
|
|
||||||
# If the received type or input_type is a MatchType, we can return True immediately;
|
# If the received type or input_type is a MatchType, we can return True immediately;
|
||||||
# validation for this is handled by the frontend
|
# validation for this is handled by the frontend
|
||||||
if received_type == IO.MatchType.io_type or input_type == IO.MatchType.io_type:
|
if received_type == IO.MatchType.io_type or input_type == IO.MatchType.io_type:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
# This accounts for some custom nodes that output lists of options as the type;
|
||||||
|
# if we ever want to break them on purpose, this can be removed
|
||||||
|
if isinstance(received_type, list) and input_type == IO.Combo.io_type:
|
||||||
|
return True
|
||||||
|
|
||||||
# Not equal, and not strings
|
# Not equal, and not strings
|
||||||
if not isinstance(received_type, str) or not isinstance(input_type, str):
|
if not isinstance(received_type, str) or not isinstance(input_type, str):
|
||||||
return False
|
return False
|
||||||
@ -37,6 +47,10 @@ def validate_node_input(
|
|||||||
received_types = set(t.strip() for t in received_type.split(","))
|
received_types = set(t.strip() for t in received_type.split(","))
|
||||||
input_types = set(t.strip() for t in input_type.split(","))
|
input_types = set(t.strip() for t in input_type.split(","))
|
||||||
|
|
||||||
|
# If any of the types is '*', we can return True immediately; this is the 'Any' type.
|
||||||
|
if IO.AnyType.io_type in received_types or IO.AnyType.io_type in input_types:
|
||||||
|
return True
|
||||||
|
|
||||||
if strict:
|
if strict:
|
||||||
# In strict mode, all received types must be in the input types
|
# In strict mode, all received types must be in the input types
|
||||||
return received_types.issubset(input_types)
|
return received_types.issubset(input_types)
|
||||||
|
|||||||
@ -55,7 +55,8 @@ class APG(io.ComfyNode):
|
|||||||
def pre_cfg_function(args):
|
def pre_cfg_function(args):
|
||||||
nonlocal running_avg, prev_sigma
|
nonlocal running_avg, prev_sigma
|
||||||
|
|
||||||
if len(args["conds_out"]) == 1: return args["conds_out"]
|
if len(args["conds_out"]) == 1:
|
||||||
|
return args["conds_out"]
|
||||||
|
|
||||||
cond = args["conds_out"][0]
|
cond = args["conds_out"][0]
|
||||||
uncond = args["conds_out"][1]
|
uncond = args["conds_out"][1]
|
||||||
|
|||||||
@ -112,7 +112,7 @@ class VAEDecodeAudio(IO.ComfyNode):
|
|||||||
std = torch.std(audio, dim=[1,2], keepdim=True) * 5.0
|
std = torch.std(audio, dim=[1,2], keepdim=True) * 5.0
|
||||||
std[std < 1.0] = 1.0
|
std[std < 1.0] = 1.0
|
||||||
audio /= std
|
audio /= std
|
||||||
return IO.NodeOutput({"waveform": audio, "sample_rate": 44100})
|
return IO.NodeOutput({"waveform": audio, "sample_rate": 44100 if "sample_rate" not in samples else samples["sample_rate"]})
|
||||||
|
|
||||||
decode = execute # TODO: remove
|
decode = execute # TODO: remove
|
||||||
|
|
||||||
|
|||||||
@ -667,16 +667,19 @@ class ResizeImagesByLongerEdgeNode(ImageProcessingNode):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _process(cls, image, longer_edge):
|
def _process(cls, image, longer_edge):
|
||||||
img = tensor_to_pil(image)
|
resized_images = []
|
||||||
w, h = img.size
|
for image_i in image:
|
||||||
if w > h:
|
img = tensor_to_pil(image_i)
|
||||||
new_w = longer_edge
|
w, h = img.size
|
||||||
new_h = int(h * (longer_edge / w))
|
if w > h:
|
||||||
else:
|
new_w = longer_edge
|
||||||
new_h = longer_edge
|
new_h = int(h * (longer_edge / w))
|
||||||
new_w = int(w * (longer_edge / h))
|
else:
|
||||||
img = img.resize((new_w, new_h), Image.Resampling.LANCZOS)
|
new_h = longer_edge
|
||||||
return pil_to_tensor(img)
|
new_w = int(w * (longer_edge / h))
|
||||||
|
img = img.resize((new_w, new_h), Image.Resampling.LANCZOS)
|
||||||
|
resized_images.append(pil_to_tensor(img))
|
||||||
|
return torch.cat(resized_images, dim=0)
|
||||||
|
|
||||||
|
|
||||||
class CenterCropImagesNode(ImageProcessingNode):
|
class CenterCropImagesNode(ImageProcessingNode):
|
||||||
|
|||||||
@ -5,7 +5,9 @@ import comfy.model_management
|
|||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
from comfy_api.latest import ComfyExtension, io
|
from comfy_api.latest import ComfyExtension, io
|
||||||
from comfy.ldm.hunyuan_video.upsampler import HunyuanVideo15SRModel
|
from comfy.ldm.hunyuan_video.upsampler import HunyuanVideo15SRModel
|
||||||
|
from comfy.ldm.lightricks.latent_upsampler import LatentUpsampler
|
||||||
import folder_paths
|
import folder_paths
|
||||||
|
import json
|
||||||
|
|
||||||
class CLIPTextEncodeHunyuanDiT(io.ComfyNode):
|
class CLIPTextEncodeHunyuanDiT(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -186,7 +188,7 @@ class LatentUpscaleModelLoader(io.ComfyNode):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, model_name) -> io.NodeOutput:
|
def execute(cls, model_name) -> io.NodeOutput:
|
||||||
model_path = folder_paths.get_full_path_or_raise("latent_upscale_models", model_name)
|
model_path = folder_paths.get_full_path_or_raise("latent_upscale_models", model_name)
|
||||||
sd = comfy.utils.load_torch_file(model_path, safe_load=True)
|
sd, metadata = comfy.utils.load_torch_file(model_path, safe_load=True, return_metadata=True)
|
||||||
|
|
||||||
if "blocks.0.block.0.conv.weight" in sd:
|
if "blocks.0.block.0.conv.weight" in sd:
|
||||||
config = {
|
config = {
|
||||||
@ -197,6 +199,8 @@ class LatentUpscaleModelLoader(io.ComfyNode):
|
|||||||
"global_residual": False,
|
"global_residual": False,
|
||||||
}
|
}
|
||||||
model_type = "720p"
|
model_type = "720p"
|
||||||
|
model = HunyuanVideo15SRModel(model_type, config)
|
||||||
|
model.load_sd(sd)
|
||||||
elif "up.0.block.0.conv1.conv.weight" in sd:
|
elif "up.0.block.0.conv1.conv.weight" in sd:
|
||||||
sd = {key.replace("nin_shortcut", "nin_shortcut.conv", 1): value for key, value in sd.items()}
|
sd = {key.replace("nin_shortcut", "nin_shortcut.conv", 1): value for key, value in sd.items()}
|
||||||
config = {
|
config = {
|
||||||
@ -205,9 +209,12 @@ class LatentUpscaleModelLoader(io.ComfyNode):
|
|||||||
"block_out_channels": tuple(sd[f"up.{i}.block.0.conv1.conv.weight"].shape[0] for i in range(len([k for k in sd.keys() if k.startswith("up.") and k.endswith(".block.0.conv1.conv.weight")]))),
|
"block_out_channels": tuple(sd[f"up.{i}.block.0.conv1.conv.weight"].shape[0] for i in range(len([k for k in sd.keys() if k.startswith("up.") and k.endswith(".block.0.conv1.conv.weight")]))),
|
||||||
}
|
}
|
||||||
model_type = "1080p"
|
model_type = "1080p"
|
||||||
|
model = HunyuanVideo15SRModel(model_type, config)
|
||||||
model = HunyuanVideo15SRModel(model_type, config)
|
model.load_sd(sd)
|
||||||
model.load_sd(sd)
|
elif "post_upsample_res_blocks.0.conv2.bias" in sd:
|
||||||
|
config = json.loads(metadata["config"])
|
||||||
|
model = LatentUpsampler.from_config(config).to(dtype=comfy.model_management.vae_dtype(allowed_dtypes=[torch.bfloat16, torch.float32]))
|
||||||
|
model.load_state_dict(sd)
|
||||||
|
|
||||||
return io.NodeOutput(model)
|
return io.NodeOutput(model)
|
||||||
|
|
||||||
|
|||||||
@ -255,6 +255,7 @@ class LatentBatch(io.ComfyNode):
|
|||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="LatentBatch",
|
node_id="LatentBatch",
|
||||||
category="latent/batch",
|
category="latent/batch",
|
||||||
|
is_deprecated=True,
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Latent.Input("samples1"),
|
io.Latent.Input("samples1"),
|
||||||
io.Latent.Input("samples2"),
|
io.Latent.Input("samples2"),
|
||||||
|
|||||||
@ -1,8 +1,11 @@
|
|||||||
|
from __future__ import annotations
|
||||||
from typing import TypedDict
|
from typing import TypedDict
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
from comfy_api.latest import ComfyExtension, io
|
from comfy_api.latest import ComfyExtension, io
|
||||||
from comfy_api.latest import _io
|
from comfy_api.latest import _io
|
||||||
|
|
||||||
|
# sentinel for missing inputs
|
||||||
|
MISSING = object()
|
||||||
|
|
||||||
|
|
||||||
class SwitchNode(io.ComfyNode):
|
class SwitchNode(io.ComfyNode):
|
||||||
@ -14,6 +17,37 @@ class SwitchNode(io.ComfyNode):
|
|||||||
display_name="Switch",
|
display_name="Switch",
|
||||||
category="logic",
|
category="logic",
|
||||||
is_experimental=True,
|
is_experimental=True,
|
||||||
|
inputs=[
|
||||||
|
io.Boolean.Input("switch"),
|
||||||
|
io.MatchType.Input("on_false", template=template, lazy=True),
|
||||||
|
io.MatchType.Input("on_true", template=template, lazy=True),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.MatchType.Output(template=template, display_name="output"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def check_lazy_status(cls, switch, on_false=None, on_true=None):
|
||||||
|
if switch and on_true is None:
|
||||||
|
return ["on_true"]
|
||||||
|
if not switch and on_false is None:
|
||||||
|
return ["on_false"]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, switch, on_true, on_false) -> io.NodeOutput:
|
||||||
|
return io.NodeOutput(on_true if switch else on_false)
|
||||||
|
|
||||||
|
|
||||||
|
class SoftSwitchNode(io.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
template = io.MatchType.Template("switch")
|
||||||
|
return io.Schema(
|
||||||
|
node_id="ComfySoftSwitchNode",
|
||||||
|
display_name="Soft Switch",
|
||||||
|
category="logic",
|
||||||
|
is_experimental=True,
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Boolean.Input("switch"),
|
io.Boolean.Input("switch"),
|
||||||
io.MatchType.Input("on_false", template=template, lazy=True, optional=True),
|
io.MatchType.Input("on_false", template=template, lazy=True, optional=True),
|
||||||
@ -25,14 +59,14 @@ class SwitchNode(io.ComfyNode):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_lazy_status(cls, switch, on_false=..., on_true=...):
|
def check_lazy_status(cls, switch, on_false=MISSING, on_true=MISSING):
|
||||||
# We use ... instead of None, as None is passed for connected-but-unevaluated inputs.
|
# We use MISSING instead of None, as None is passed for connected-but-unevaluated inputs.
|
||||||
# This trick allows us to ignore the value of the switch and still be able to run execute().
|
# This trick allows us to ignore the value of the switch and still be able to run execute().
|
||||||
|
|
||||||
# One of the inputs may be missing, in which case we need to evaluate the other input
|
# One of the inputs may be missing, in which case we need to evaluate the other input
|
||||||
if on_false is ...:
|
if on_false is MISSING:
|
||||||
return ["on_true"]
|
return ["on_true"]
|
||||||
if on_true is ...:
|
if on_true is MISSING:
|
||||||
return ["on_false"]
|
return ["on_false"]
|
||||||
# Normal lazy switch operation
|
# Normal lazy switch operation
|
||||||
if switch and on_true is None:
|
if switch and on_true is None:
|
||||||
@ -41,22 +75,50 @@ class SwitchNode(io.ComfyNode):
|
|||||||
return ["on_false"]
|
return ["on_false"]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_inputs(cls, switch, on_false=..., on_true=...):
|
def validate_inputs(cls, switch, on_false=MISSING, on_true=MISSING):
|
||||||
# This check happens before check_lazy_status(), so we can eliminate the case where
|
# This check happens before check_lazy_status(), so we can eliminate the case where
|
||||||
# both inputs are missing.
|
# both inputs are missing.
|
||||||
if on_false is ... and on_true is ...:
|
if on_false is MISSING and on_true is MISSING:
|
||||||
return "At least one of on_false or on_true must be connected to Switch node"
|
return "At least one of on_false or on_true must be connected to Switch node"
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, switch, on_true=..., on_false=...) -> io.NodeOutput:
|
def execute(cls, switch, on_true=MISSING, on_false=MISSING) -> io.NodeOutput:
|
||||||
if on_true is ...:
|
if on_true is MISSING:
|
||||||
return io.NodeOutput(on_false)
|
return io.NodeOutput(on_false)
|
||||||
if on_false is ...:
|
if on_false is MISSING:
|
||||||
return io.NodeOutput(on_true)
|
return io.NodeOutput(on_true)
|
||||||
return io.NodeOutput(on_true if switch else on_false)
|
return io.NodeOutput(on_true if switch else on_false)
|
||||||
|
|
||||||
|
|
||||||
|
class CustomComboNode(io.ComfyNode):
|
||||||
|
"""
|
||||||
|
Frontend node that allows user to write their own options for a combo.
|
||||||
|
This is here to make sure the node has a backend-representation to avoid some annoyances.
|
||||||
|
"""
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return io.Schema(
|
||||||
|
node_id="CustomCombo",
|
||||||
|
display_name="Custom Combo",
|
||||||
|
category="utils",
|
||||||
|
is_experimental=True,
|
||||||
|
inputs=[io.Combo.Input("choice", options=[])],
|
||||||
|
outputs=[io.String.Output()]
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def validate_inputs(cls, choice: io.Combo.Type) -> bool:
|
||||||
|
# NOTE: DO NOT DO THIS unless you want to skip validation entirely on the node's inputs.
|
||||||
|
# I am doing that here because the widgets (besides the combo dropdown) on this node are fully frontend defined.
|
||||||
|
# I need to skip checking that the chosen combo option is in the options list, since those are defined by the user.
|
||||||
|
return True
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, choice: io.Combo.Type) -> io.NodeOutput:
|
||||||
|
return io.NodeOutput(choice)
|
||||||
|
|
||||||
|
|
||||||
class DCTestNode(io.ComfyNode):
|
class DCTestNode(io.ComfyNode):
|
||||||
class DCValues(TypedDict):
|
class DCValues(TypedDict):
|
||||||
combo: str
|
combo: str
|
||||||
@ -72,14 +134,14 @@ class DCTestNode(io.ComfyNode):
|
|||||||
display_name="DCTest",
|
display_name="DCTest",
|
||||||
category="logic",
|
category="logic",
|
||||||
is_output_node=True,
|
is_output_node=True,
|
||||||
inputs=[_io.DynamicCombo.Input("combo", options=[
|
inputs=[io.DynamicCombo.Input("combo", options=[
|
||||||
_io.DynamicCombo.Option("option1", [io.String.Input("string")]),
|
io.DynamicCombo.Option("option1", [io.String.Input("string")]),
|
||||||
_io.DynamicCombo.Option("option2", [io.Int.Input("integer")]),
|
io.DynamicCombo.Option("option2", [io.Int.Input("integer")]),
|
||||||
_io.DynamicCombo.Option("option3", [io.Image.Input("image")]),
|
io.DynamicCombo.Option("option3", [io.Image.Input("image")]),
|
||||||
_io.DynamicCombo.Option("option4", [
|
io.DynamicCombo.Option("option4", [
|
||||||
_io.DynamicCombo.Input("subcombo", options=[
|
io.DynamicCombo.Input("subcombo", options=[
|
||||||
_io.DynamicCombo.Option("opt1", [io.Float.Input("float_x"), io.Float.Input("float_y")]),
|
io.DynamicCombo.Option("opt1", [io.Float.Input("float_x"), io.Float.Input("float_y")]),
|
||||||
_io.DynamicCombo.Option("opt2", [io.Mask.Input("mask1", optional=True)]),
|
io.DynamicCombo.Option("opt2", [io.Mask.Input("mask1", optional=True)]),
|
||||||
])
|
])
|
||||||
])]
|
])]
|
||||||
)],
|
)],
|
||||||
@ -141,14 +203,65 @@ class AutogrowPrefixTestNode(io.ComfyNode):
|
|||||||
combined = ",".join([str(x) for x in vals])
|
combined = ",".join([str(x) for x in vals])
|
||||||
return io.NodeOutput(combined)
|
return io.NodeOutput(combined)
|
||||||
|
|
||||||
|
class ComboOutputTestNode(io.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return io.Schema(
|
||||||
|
node_id="ComboOptionTestNode",
|
||||||
|
display_name="ComboOptionTest",
|
||||||
|
category="logic",
|
||||||
|
inputs=[io.Combo.Input("combo", options=["option1", "option2", "option3"]),
|
||||||
|
io.Combo.Input("combo2", options=["option4", "option5", "option6"])],
|
||||||
|
outputs=[io.Combo.Output(), io.Combo.Output()],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, combo: io.Combo.Type, combo2: io.Combo.Type) -> io.NodeOutput:
|
||||||
|
return io.NodeOutput(combo, combo2)
|
||||||
|
|
||||||
|
class ConvertStringToComboNode(io.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return io.Schema(
|
||||||
|
node_id="ConvertStringToComboNode",
|
||||||
|
display_name="Convert String to Combo",
|
||||||
|
category="logic",
|
||||||
|
inputs=[io.String.Input("string")],
|
||||||
|
outputs=[io.Combo.Output()],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, string: str) -> io.NodeOutput:
|
||||||
|
return io.NodeOutput(string)
|
||||||
|
|
||||||
|
class InvertBooleanNode(io.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return io.Schema(
|
||||||
|
node_id="InvertBooleanNode",
|
||||||
|
display_name="Invert Boolean",
|
||||||
|
category="logic",
|
||||||
|
inputs=[io.Boolean.Input("boolean")],
|
||||||
|
outputs=[io.Boolean.Output()],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, boolean: bool) -> io.NodeOutput:
|
||||||
|
return io.NodeOutput(not boolean)
|
||||||
|
|
||||||
class LogicExtension(ComfyExtension):
|
class LogicExtension(ComfyExtension):
|
||||||
@override
|
@override
|
||||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
return [
|
return [
|
||||||
# SwitchNode,
|
SwitchNode,
|
||||||
|
CustomComboNode,
|
||||||
|
# SoftSwitchNode,
|
||||||
|
# ConvertStringToComboNode,
|
||||||
# DCTestNode,
|
# DCTestNode,
|
||||||
# AutogrowNamesTestNode,
|
# AutogrowNamesTestNode,
|
||||||
# AutogrowPrefixTestNode,
|
# AutogrowPrefixTestNode,
|
||||||
|
# ComboOutputTestNode,
|
||||||
|
# InvertBooleanNode,
|
||||||
]
|
]
|
||||||
|
|
||||||
async def comfy_entrypoint() -> LogicExtension:
|
async def comfy_entrypoint() -> LogicExtension:
|
||||||
|
|||||||
@ -81,6 +81,59 @@ class LTXVImgToVideo(io.ComfyNode):
|
|||||||
generate = execute # TODO: remove
|
generate = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
|
class LTXVImgToVideoInplace(io.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return io.Schema(
|
||||||
|
node_id="LTXVImgToVideoInplace",
|
||||||
|
category="conditioning/video_models",
|
||||||
|
inputs=[
|
||||||
|
io.Vae.Input("vae"),
|
||||||
|
io.Image.Input("image"),
|
||||||
|
io.Latent.Input("latent"),
|
||||||
|
io.Float.Input("strength", default=1.0, min=0.0, max=1.0),
|
||||||
|
io.Boolean.Input("bypass", default=False, tooltip="Bypass the conditioning.")
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Latent.Output(display_name="latent"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, vae, image, latent, strength, bypass=False) -> io.NodeOutput:
|
||||||
|
if bypass:
|
||||||
|
return (latent,)
|
||||||
|
|
||||||
|
samples = latent["samples"]
|
||||||
|
_, height_scale_factor, width_scale_factor = (
|
||||||
|
vae.downscale_index_formula
|
||||||
|
)
|
||||||
|
|
||||||
|
batch, _, latent_frames, latent_height, latent_width = samples.shape
|
||||||
|
width = latent_width * width_scale_factor
|
||||||
|
height = latent_height * height_scale_factor
|
||||||
|
|
||||||
|
if image.shape[1] != height or image.shape[2] != width:
|
||||||
|
pixels = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
||||||
|
else:
|
||||||
|
pixels = image
|
||||||
|
encode_pixels = pixels[:, :, :, :3]
|
||||||
|
t = vae.encode(encode_pixels)
|
||||||
|
|
||||||
|
samples[:, :, :t.shape[2]] = t
|
||||||
|
|
||||||
|
conditioning_latent_frames_mask = torch.ones(
|
||||||
|
(batch, 1, latent_frames, 1, 1),
|
||||||
|
dtype=torch.float32,
|
||||||
|
device=samples.device,
|
||||||
|
)
|
||||||
|
conditioning_latent_frames_mask[:, :, :t.shape[2]] = 1.0 - strength
|
||||||
|
|
||||||
|
return io.NodeOutput({"samples": samples, "noise_mask": conditioning_latent_frames_mask})
|
||||||
|
|
||||||
|
generate = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
def conditioning_get_any_value(conditioning, key, default=None):
|
def conditioning_get_any_value(conditioning, key, default=None):
|
||||||
for t in conditioning:
|
for t in conditioning:
|
||||||
if key in t[1]:
|
if key in t[1]:
|
||||||
@ -106,12 +159,12 @@ def get_keyframe_idxs(cond):
|
|||||||
keyframe_idxs = conditioning_get_any_value(cond, "keyframe_idxs", None)
|
keyframe_idxs = conditioning_get_any_value(cond, "keyframe_idxs", None)
|
||||||
if keyframe_idxs is None:
|
if keyframe_idxs is None:
|
||||||
return None, 0
|
return None, 0
|
||||||
num_keyframes = torch.unique(keyframe_idxs[:, 0]).shape[0]
|
# keyframe_idxs contains start/end positions (last dimension), checking for unqiue values only for start
|
||||||
|
num_keyframes = torch.unique(keyframe_idxs[:, 0, :, 0]).shape[0]
|
||||||
return keyframe_idxs, num_keyframes
|
return keyframe_idxs, num_keyframes
|
||||||
|
|
||||||
class LTXVAddGuide(io.ComfyNode):
|
class LTXVAddGuide(io.ComfyNode):
|
||||||
NUM_PREFIX_FRAMES = 2
|
PATCHIFIER = SymmetricPatchifier(1, start_end=True)
|
||||||
PATCHIFIER = SymmetricPatchifier(1)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
@ -182,26 +235,35 @@ class LTXVAddGuide(io.ComfyNode):
|
|||||||
return node_helpers.conditioning_set_values(cond, {"keyframe_idxs": keyframe_idxs})
|
return node_helpers.conditioning_set_values(cond, {"keyframe_idxs": keyframe_idxs})
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def append_keyframe(cls, positive, negative, frame_idx, latent_image, noise_mask, guiding_latent, strength, scale_factors):
|
def append_keyframe(cls, positive, negative, frame_idx, latent_image, noise_mask, guiding_latent, strength, scale_factors, guide_mask=None, in_channels=128):
|
||||||
_, latent_idx = cls.get_latent_index(
|
if latent_image.shape[1] != in_channels or guiding_latent.shape[1] != in_channels:
|
||||||
cond=positive,
|
raise ValueError("Adding guide to a combined AV latent is not supported.")
|
||||||
latent_length=latent_image.shape[2],
|
|
||||||
guide_length=guiding_latent.shape[2],
|
|
||||||
frame_idx=frame_idx,
|
|
||||||
scale_factors=scale_factors,
|
|
||||||
)
|
|
||||||
noise_mask[:, :, latent_idx:latent_idx + guiding_latent.shape[2]] = 1.0
|
|
||||||
|
|
||||||
positive = cls.add_keyframe_index(positive, frame_idx, guiding_latent, scale_factors)
|
positive = cls.add_keyframe_index(positive, frame_idx, guiding_latent, scale_factors)
|
||||||
negative = cls.add_keyframe_index(negative, frame_idx, guiding_latent, scale_factors)
|
negative = cls.add_keyframe_index(negative, frame_idx, guiding_latent, scale_factors)
|
||||||
|
|
||||||
mask = torch.full(
|
if guide_mask is not None:
|
||||||
(noise_mask.shape[0], 1, guiding_latent.shape[2], noise_mask.shape[3], noise_mask.shape[4]),
|
target_h = max(noise_mask.shape[3], guide_mask.shape[3])
|
||||||
1.0 - strength,
|
target_w = max(noise_mask.shape[4], guide_mask.shape[4])
|
||||||
dtype=noise_mask.dtype,
|
|
||||||
device=noise_mask.device,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
if noise_mask.shape[3] == 1 or noise_mask.shape[4] == 1:
|
||||||
|
noise_mask = noise_mask.expand(-1, -1, -1, target_h, target_w)
|
||||||
|
|
||||||
|
if guide_mask.shape[3] == 1 or guide_mask.shape[4] == 1:
|
||||||
|
guide_mask = guide_mask.expand(-1, -1, -1, target_h, target_w)
|
||||||
|
mask = guide_mask - strength
|
||||||
|
else:
|
||||||
|
mask = torch.full(
|
||||||
|
(noise_mask.shape[0], 1, guiding_latent.shape[2], noise_mask.shape[3], noise_mask.shape[4]),
|
||||||
|
1.0 - strength,
|
||||||
|
dtype=noise_mask.dtype,
|
||||||
|
device=noise_mask.device,
|
||||||
|
)
|
||||||
|
# This solves audio video combined latent case where latent_image has audio latent concatenated
|
||||||
|
# in channel dimension with video latent. The solution is to pad guiding latent accordingly.
|
||||||
|
if latent_image.shape[1] > guiding_latent.shape[1]:
|
||||||
|
pad_len = latent_image.shape[1] - guiding_latent.shape[1]
|
||||||
|
guiding_latent = torch.nn.functional.pad(guiding_latent, pad=(0, 0, 0, 0, 0, 0, 0, pad_len), value=0)
|
||||||
latent_image = torch.cat([latent_image, guiding_latent], dim=2)
|
latent_image = torch.cat([latent_image, guiding_latent], dim=2)
|
||||||
noise_mask = torch.cat([noise_mask, mask], dim=2)
|
noise_mask = torch.cat([noise_mask, mask], dim=2)
|
||||||
return positive, negative, latent_image, noise_mask
|
return positive, negative, latent_image, noise_mask
|
||||||
@ -238,33 +300,17 @@ class LTXVAddGuide(io.ComfyNode):
|
|||||||
frame_idx, latent_idx = cls.get_latent_index(positive, latent_length, len(image), frame_idx, scale_factors)
|
frame_idx, latent_idx = cls.get_latent_index(positive, latent_length, len(image), frame_idx, scale_factors)
|
||||||
assert latent_idx + t.shape[2] <= latent_length, "Conditioning frames exceed the length of the latent sequence."
|
assert latent_idx + t.shape[2] <= latent_length, "Conditioning frames exceed the length of the latent sequence."
|
||||||
|
|
||||||
num_prefix_frames = min(cls.NUM_PREFIX_FRAMES, t.shape[2])
|
|
||||||
|
|
||||||
positive, negative, latent_image, noise_mask = cls.append_keyframe(
|
positive, negative, latent_image, noise_mask = cls.append_keyframe(
|
||||||
positive,
|
positive,
|
||||||
negative,
|
negative,
|
||||||
frame_idx,
|
frame_idx,
|
||||||
latent_image,
|
latent_image,
|
||||||
noise_mask,
|
noise_mask,
|
||||||
t[:, :, :num_prefix_frames],
|
t,
|
||||||
strength,
|
strength,
|
||||||
scale_factors,
|
scale_factors,
|
||||||
)
|
)
|
||||||
|
|
||||||
latent_idx += num_prefix_frames
|
|
||||||
|
|
||||||
t = t[:, :, num_prefix_frames:]
|
|
||||||
if t.shape[2] == 0:
|
|
||||||
return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask})
|
|
||||||
|
|
||||||
latent_image, noise_mask = cls.replace_latent_frames(
|
|
||||||
latent_image,
|
|
||||||
noise_mask,
|
|
||||||
t,
|
|
||||||
latent_idx,
|
|
||||||
strength,
|
|
||||||
)
|
|
||||||
|
|
||||||
return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask})
|
return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask})
|
||||||
|
|
||||||
generate = execute # TODO: remove
|
generate = execute # TODO: remove
|
||||||
@ -507,18 +553,90 @@ class LTXVPreprocess(io.ComfyNode):
|
|||||||
|
|
||||||
preprocess = execute # TODO: remove
|
preprocess = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
|
import comfy.nested_tensor
|
||||||
|
class LTXVConcatAVLatent(io.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return io.Schema(
|
||||||
|
node_id="LTXVConcatAVLatent",
|
||||||
|
category="latent/video/ltxv",
|
||||||
|
inputs=[
|
||||||
|
io.Latent.Input("video_latent"),
|
||||||
|
io.Latent.Input("audio_latent"),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Latent.Output(display_name="latent"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, video_latent, audio_latent) -> io.NodeOutput:
|
||||||
|
output = {}
|
||||||
|
output.update(video_latent)
|
||||||
|
output.update(audio_latent)
|
||||||
|
video_noise_mask = video_latent.get("noise_mask", None)
|
||||||
|
audio_noise_mask = audio_latent.get("noise_mask", None)
|
||||||
|
|
||||||
|
if video_noise_mask is not None or audio_noise_mask is not None:
|
||||||
|
if video_noise_mask is None:
|
||||||
|
video_noise_mask = torch.ones_like(video_latent["samples"])
|
||||||
|
if audio_noise_mask is None:
|
||||||
|
audio_noise_mask = torch.ones_like(audio_latent["samples"])
|
||||||
|
output["noise_mask"] = comfy.nested_tensor.NestedTensor((video_noise_mask, audio_noise_mask))
|
||||||
|
|
||||||
|
output["samples"] = comfy.nested_tensor.NestedTensor((video_latent["samples"], audio_latent["samples"]))
|
||||||
|
|
||||||
|
return io.NodeOutput(output)
|
||||||
|
|
||||||
|
|
||||||
|
class LTXVSeparateAVLatent(io.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return io.Schema(
|
||||||
|
node_id="LTXVSeparateAVLatent",
|
||||||
|
category="latent/video/ltxv",
|
||||||
|
description="LTXV Separate AV Latent",
|
||||||
|
inputs=[
|
||||||
|
io.Latent.Input("av_latent"),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Latent.Output(display_name="video_latent"),
|
||||||
|
io.Latent.Output(display_name="audio_latent"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, av_latent) -> io.NodeOutput:
|
||||||
|
latents = av_latent["samples"].unbind()
|
||||||
|
video_latent = av_latent.copy()
|
||||||
|
video_latent["samples"] = latents[0]
|
||||||
|
audio_latent = av_latent.copy()
|
||||||
|
audio_latent["samples"] = latents[1]
|
||||||
|
if "noise_mask" in av_latent:
|
||||||
|
masks = av_latent["noise_mask"]
|
||||||
|
if masks is not None:
|
||||||
|
masks = masks.unbind()
|
||||||
|
video_latent["noise_mask"] = masks[0]
|
||||||
|
audio_latent["noise_mask"] = masks[1]
|
||||||
|
return io.NodeOutput(video_latent, audio_latent)
|
||||||
|
|
||||||
|
|
||||||
class LtxvExtension(ComfyExtension):
|
class LtxvExtension(ComfyExtension):
|
||||||
@override
|
@override
|
||||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
return [
|
return [
|
||||||
EmptyLTXVLatentVideo,
|
EmptyLTXVLatentVideo,
|
||||||
LTXVImgToVideo,
|
LTXVImgToVideo,
|
||||||
|
LTXVImgToVideoInplace,
|
||||||
ModelSamplingLTXV,
|
ModelSamplingLTXV,
|
||||||
LTXVConditioning,
|
LTXVConditioning,
|
||||||
LTXVScheduler,
|
LTXVScheduler,
|
||||||
LTXVAddGuide,
|
LTXVAddGuide,
|
||||||
LTXVPreprocess,
|
LTXVPreprocess,
|
||||||
LTXVCropGuides,
|
LTXVCropGuides,
|
||||||
|
LTXVConcatAVLatent,
|
||||||
|
LTXVSeparateAVLatent,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
216
comfy_extras/nodes_lt_audio.py
Normal file
216
comfy_extras/nodes_lt_audio.py
Normal file
@ -0,0 +1,216 @@
|
|||||||
|
import folder_paths
|
||||||
|
import comfy.utils
|
||||||
|
import comfy.model_management
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from comfy.ldm.lightricks.vae.audio_vae import AudioVAE
|
||||||
|
from comfy_api.latest import ComfyExtension, io
|
||||||
|
|
||||||
|
|
||||||
|
class LTXVAudioVAELoader(io.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls) -> io.Schema:
|
||||||
|
return io.Schema(
|
||||||
|
node_id="LTXVAudioVAELoader",
|
||||||
|
display_name="LTXV Audio VAE Loader",
|
||||||
|
category="audio",
|
||||||
|
inputs=[
|
||||||
|
io.Combo.Input(
|
||||||
|
"ckpt_name",
|
||||||
|
options=folder_paths.get_filename_list("checkpoints"),
|
||||||
|
tooltip="Audio VAE checkpoint to load.",
|
||||||
|
)
|
||||||
|
],
|
||||||
|
outputs=[io.Vae.Output(display_name="Audio VAE")],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, ckpt_name: str) -> io.NodeOutput:
|
||||||
|
ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
|
||||||
|
sd, metadata = comfy.utils.load_torch_file(ckpt_path, return_metadata=True)
|
||||||
|
return io.NodeOutput(AudioVAE(sd, metadata))
|
||||||
|
|
||||||
|
|
||||||
|
class LTXVAudioVAEEncode(io.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls) -> io.Schema:
|
||||||
|
return io.Schema(
|
||||||
|
node_id="LTXVAudioVAEEncode",
|
||||||
|
display_name="LTXV Audio VAE Encode",
|
||||||
|
category="audio",
|
||||||
|
inputs=[
|
||||||
|
io.Audio.Input("audio", tooltip="The audio to be encoded."),
|
||||||
|
io.Vae.Input(
|
||||||
|
id="audio_vae",
|
||||||
|
display_name="Audio VAE",
|
||||||
|
tooltip="The Audio VAE model to use for encoding.",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
outputs=[io.Latent.Output(display_name="Audio Latent")],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, audio, audio_vae: AudioVAE) -> io.NodeOutput:
|
||||||
|
audio_latents = audio_vae.encode(audio)
|
||||||
|
return io.NodeOutput(
|
||||||
|
{
|
||||||
|
"samples": audio_latents,
|
||||||
|
"sample_rate": int(audio_vae.sample_rate),
|
||||||
|
"type": "audio",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class LTXVAudioVAEDecode(io.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls) -> io.Schema:
|
||||||
|
return io.Schema(
|
||||||
|
node_id="LTXVAudioVAEDecode",
|
||||||
|
display_name="LTXV Audio VAE Decode",
|
||||||
|
category="audio",
|
||||||
|
inputs=[
|
||||||
|
io.Latent.Input("samples", tooltip="The latent to be decoded."),
|
||||||
|
io.Vae.Input(
|
||||||
|
id="audio_vae",
|
||||||
|
display_name="Audio VAE",
|
||||||
|
tooltip="The Audio VAE model used for decoding the latent.",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
outputs=[io.Audio.Output(display_name="Audio")],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, samples, audio_vae: AudioVAE) -> io.NodeOutput:
|
||||||
|
audio_latent = samples["samples"]
|
||||||
|
if audio_latent.is_nested:
|
||||||
|
audio_latent = audio_latent.unbind()[-1]
|
||||||
|
audio = audio_vae.decode(audio_latent).to(audio_latent.device)
|
||||||
|
output_audio_sample_rate = audio_vae.output_sample_rate
|
||||||
|
return io.NodeOutput(
|
||||||
|
{
|
||||||
|
"waveform": audio,
|
||||||
|
"sample_rate": int(output_audio_sample_rate),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class LTXVEmptyLatentAudio(io.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls) -> io.Schema:
|
||||||
|
return io.Schema(
|
||||||
|
node_id="LTXVEmptyLatentAudio",
|
||||||
|
display_name="LTXV Empty Latent Audio",
|
||||||
|
category="latent/audio",
|
||||||
|
inputs=[
|
||||||
|
io.Int.Input(
|
||||||
|
"frames_number",
|
||||||
|
default=97,
|
||||||
|
min=1,
|
||||||
|
max=1000,
|
||||||
|
step=1,
|
||||||
|
display_mode=io.NumberDisplay.number,
|
||||||
|
tooltip="Number of frames.",
|
||||||
|
),
|
||||||
|
io.Int.Input(
|
||||||
|
"frame_rate",
|
||||||
|
default=25,
|
||||||
|
min=1,
|
||||||
|
max=1000,
|
||||||
|
step=1,
|
||||||
|
display_mode=io.NumberDisplay.number,
|
||||||
|
tooltip="Number of frames per second.",
|
||||||
|
),
|
||||||
|
io.Int.Input(
|
||||||
|
"batch_size",
|
||||||
|
default=1,
|
||||||
|
min=1,
|
||||||
|
max=4096,
|
||||||
|
display_mode=io.NumberDisplay.number,
|
||||||
|
tooltip="The number of latent audio samples in the batch.",
|
||||||
|
),
|
||||||
|
io.Vae.Input(
|
||||||
|
id="audio_vae",
|
||||||
|
display_name="Audio VAE",
|
||||||
|
tooltip="The Audio VAE model to get configuration from.",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
outputs=[io.Latent.Output(display_name="Latent")],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(
|
||||||
|
cls,
|
||||||
|
frames_number: int,
|
||||||
|
frame_rate: int,
|
||||||
|
batch_size: int,
|
||||||
|
audio_vae: AudioVAE,
|
||||||
|
) -> io.NodeOutput:
|
||||||
|
"""Generate empty audio latents matching the reference pipeline structure."""
|
||||||
|
|
||||||
|
assert audio_vae is not None, "Audio VAE model is required"
|
||||||
|
|
||||||
|
z_channels = audio_vae.latent_channels
|
||||||
|
audio_freq = audio_vae.latent_frequency_bins
|
||||||
|
sampling_rate = int(audio_vae.sample_rate)
|
||||||
|
|
||||||
|
num_audio_latents = audio_vae.num_of_latents_from_frames(frames_number, frame_rate)
|
||||||
|
|
||||||
|
audio_latents = torch.zeros(
|
||||||
|
(batch_size, z_channels, num_audio_latents, audio_freq),
|
||||||
|
device=comfy.model_management.intermediate_device(),
|
||||||
|
)
|
||||||
|
|
||||||
|
return io.NodeOutput(
|
||||||
|
{
|
||||||
|
"samples": audio_latents,
|
||||||
|
"sample_rate": sampling_rate,
|
||||||
|
"type": "audio",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class LTXAVTextEncoderLoader(io.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls) -> io.Schema:
|
||||||
|
return io.Schema(
|
||||||
|
node_id="LTXAVTextEncoderLoader",
|
||||||
|
display_name="LTXV Audio Text Encoder Loader",
|
||||||
|
category="advanced/loaders",
|
||||||
|
description="[Recipes]\n\nltxav: gemma 3 12B",
|
||||||
|
inputs=[
|
||||||
|
io.Combo.Input(
|
||||||
|
"text_encoder",
|
||||||
|
options=folder_paths.get_filename_list("text_encoders"),
|
||||||
|
),
|
||||||
|
io.Combo.Input(
|
||||||
|
"ckpt_name",
|
||||||
|
options=folder_paths.get_filename_list("checkpoints"),
|
||||||
|
)
|
||||||
|
],
|
||||||
|
outputs=[io.Clip.Output()],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, text_encoder, ckpt_name, device="default"):
|
||||||
|
clip_type = comfy.sd.CLIPType.LTXV
|
||||||
|
|
||||||
|
clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", text_encoder)
|
||||||
|
clip_path2 = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
|
||||||
|
|
||||||
|
clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type)
|
||||||
|
return io.NodeOutput(clip)
|
||||||
|
|
||||||
|
|
||||||
|
class LTXVAudioExtension(ComfyExtension):
|
||||||
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
|
return [
|
||||||
|
LTXVAudioVAELoader,
|
||||||
|
LTXVAudioVAEEncode,
|
||||||
|
LTXVAudioVAEDecode,
|
||||||
|
LTXVEmptyLatentAudio,
|
||||||
|
LTXAVTextEncoderLoader,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
async def comfy_entrypoint() -> ComfyExtension:
|
||||||
|
return LTXVAudioExtension()
|
||||||
75
comfy_extras/nodes_lt_upsampler.py
Normal file
75
comfy_extras/nodes_lt_upsampler.py
Normal file
@ -0,0 +1,75 @@
|
|||||||
|
from comfy import model_management
|
||||||
|
import math
|
||||||
|
|
||||||
|
class LTXVLatentUpsampler:
|
||||||
|
"""
|
||||||
|
Upsamples a video latent by a factor of 2.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"samples": ("LATENT",),
|
||||||
|
"upscale_model": ("LATENT_UPSCALE_MODEL",),
|
||||||
|
"vae": ("VAE",),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("LATENT",)
|
||||||
|
FUNCTION = "upsample_latent"
|
||||||
|
CATEGORY = "latent/video"
|
||||||
|
EXPERIMENTAL = True
|
||||||
|
|
||||||
|
def upsample_latent(
|
||||||
|
self,
|
||||||
|
samples: dict,
|
||||||
|
upscale_model,
|
||||||
|
vae,
|
||||||
|
) -> tuple:
|
||||||
|
"""
|
||||||
|
Upsample the input latent using the provided model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
samples (dict): Input latent samples
|
||||||
|
upscale_model (LatentUpsampler): Loaded upscale model
|
||||||
|
vae: VAE model for normalization
|
||||||
|
auto_tiling (bool): Whether to automatically tile the input for processing
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple: Tuple containing the upsampled latent
|
||||||
|
"""
|
||||||
|
device = model_management.get_torch_device()
|
||||||
|
memory_required = model_management.module_size(upscale_model)
|
||||||
|
|
||||||
|
model_dtype = next(upscale_model.parameters()).dtype
|
||||||
|
latents = samples["samples"]
|
||||||
|
input_dtype = latents.dtype
|
||||||
|
|
||||||
|
memory_required += math.prod(latents.shape) * 3000.0 # TODO: more accurate
|
||||||
|
model_management.free_memory(memory_required, device)
|
||||||
|
|
||||||
|
try:
|
||||||
|
upscale_model.to(device) # TODO: use the comfy model management system.
|
||||||
|
|
||||||
|
latents = latents.to(dtype=model_dtype, device=device)
|
||||||
|
|
||||||
|
"""Upsample latents without tiling."""
|
||||||
|
latents = vae.first_stage_model.per_channel_statistics.un_normalize(latents)
|
||||||
|
upsampled_latents = upscale_model(latents)
|
||||||
|
finally:
|
||||||
|
upscale_model.cpu()
|
||||||
|
|
||||||
|
upsampled_latents = vae.first_stage_model.per_channel_statistics.normalize(
|
||||||
|
upsampled_latents
|
||||||
|
)
|
||||||
|
upsampled_latents = upsampled_latents.to(dtype=input_dtype, device=model_management.intermediate_device())
|
||||||
|
return_dict = samples.copy()
|
||||||
|
return_dict["samples"] = upsampled_latents
|
||||||
|
return_dict.pop("noise_mask", None)
|
||||||
|
return (return_dict,)
|
||||||
|
|
||||||
|
|
||||||
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
"LTXVLatentUpsampler": LTXVLatentUpsampler,
|
||||||
|
}
|
||||||
@ -10,7 +10,7 @@ class Mahiro(io.ComfyNode):
|
|||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="Mahiro",
|
node_id="Mahiro",
|
||||||
display_name="Mahiro is so cute that she deserves a better guidance function!! (。・ω・。)",
|
display_name="Mahiro CFG",
|
||||||
category="_for_testing",
|
category="_for_testing",
|
||||||
description="Modify the guidance to scale more on the 'direction' of the positive prompt rather than the difference between the negative prompt.",
|
description="Modify the guidance to scale more on the 'direction' of the positive prompt rather than the difference between the negative prompt.",
|
||||||
inputs=[
|
inputs=[
|
||||||
|
|||||||
@ -4,11 +4,15 @@ import torch
|
|||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
import math
|
import math
|
||||||
|
from enum import Enum
|
||||||
|
from typing import TypedDict, Literal
|
||||||
|
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
|
from comfy_extras.nodes_latent import reshape_latent_to
|
||||||
import node_helpers
|
import node_helpers
|
||||||
from comfy_api.latest import ComfyExtension, io
|
from comfy_api.latest import ComfyExtension, io
|
||||||
|
from nodes import MAX_RESOLUTION
|
||||||
|
|
||||||
class Blend(io.ComfyNode):
|
class Blend(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -241,6 +245,353 @@ class ImageScaleToTotalPixels(io.ComfyNode):
|
|||||||
s = s.movedim(1,-1)
|
s = s.movedim(1,-1)
|
||||||
return io.NodeOutput(s)
|
return io.NodeOutput(s)
|
||||||
|
|
||||||
|
class ResizeType(str, Enum):
|
||||||
|
SCALE_BY = "scale by multiplier"
|
||||||
|
SCALE_DIMENSIONS = "scale dimensions"
|
||||||
|
SCALE_LONGER_DIMENSION = "scale longer dimension"
|
||||||
|
SCALE_SHORTER_DIMENSION = "scale shorter dimension"
|
||||||
|
SCALE_WIDTH = "scale width"
|
||||||
|
SCALE_HEIGHT = "scale height"
|
||||||
|
SCALE_TOTAL_PIXELS = "scale total pixels"
|
||||||
|
MATCH_SIZE = "match size"
|
||||||
|
|
||||||
|
def is_image(input: torch.Tensor) -> bool:
|
||||||
|
# images have 4 dimensions: [batch, height, width, channels]
|
||||||
|
# masks have 3 dimensions: [batch, height, width]
|
||||||
|
return len(input.shape) == 4
|
||||||
|
|
||||||
|
def init_image_mask_input(input: torch.Tensor, is_type_image: bool) -> torch.Tensor:
|
||||||
|
if is_type_image:
|
||||||
|
input = input.movedim(-1, 1)
|
||||||
|
else:
|
||||||
|
input = input.unsqueeze(1)
|
||||||
|
return input
|
||||||
|
|
||||||
|
def finalize_image_mask_input(input: torch.Tensor, is_type_image: bool) -> torch.Tensor:
|
||||||
|
if is_type_image:
|
||||||
|
input = input.movedim(1, -1)
|
||||||
|
else:
|
||||||
|
input = input.squeeze(1)
|
||||||
|
return input
|
||||||
|
|
||||||
|
def scale_by(input: torch.Tensor, multiplier: float, scale_method: str) -> torch.Tensor:
|
||||||
|
is_type_image = is_image(input)
|
||||||
|
input = init_image_mask_input(input, is_type_image)
|
||||||
|
width = round(input.shape[-1] * multiplier)
|
||||||
|
height = round(input.shape[-2] * multiplier)
|
||||||
|
|
||||||
|
input = comfy.utils.common_upscale(input, width, height, scale_method, "disabled")
|
||||||
|
input = finalize_image_mask_input(input, is_type_image)
|
||||||
|
return input
|
||||||
|
|
||||||
|
def scale_dimensions(input: torch.Tensor, width: int, height: int, scale_method: str, crop: str="disabled") -> torch.Tensor:
|
||||||
|
if width == 0 and height == 0:
|
||||||
|
return input
|
||||||
|
is_type_image = is_image(input)
|
||||||
|
input = init_image_mask_input(input, is_type_image)
|
||||||
|
|
||||||
|
if width == 0:
|
||||||
|
width = max(1, round(input.shape[-1] * height / input.shape[-2]))
|
||||||
|
elif height == 0:
|
||||||
|
height = max(1, round(input.shape[-2] * width / input.shape[-1]))
|
||||||
|
|
||||||
|
input = comfy.utils.common_upscale(input, width, height, scale_method, crop)
|
||||||
|
input = finalize_image_mask_input(input, is_type_image)
|
||||||
|
return input
|
||||||
|
|
||||||
|
def scale_longer_dimension(input: torch.Tensor, longer_size: int, scale_method: str) -> torch.Tensor:
|
||||||
|
is_type_image = is_image(input)
|
||||||
|
input = init_image_mask_input(input, is_type_image)
|
||||||
|
width = input.shape[-1]
|
||||||
|
height = input.shape[-2]
|
||||||
|
|
||||||
|
if height > width:
|
||||||
|
width = round((width / height) * longer_size)
|
||||||
|
height = longer_size
|
||||||
|
elif width > height:
|
||||||
|
height = round((height / width) * longer_size)
|
||||||
|
width = longer_size
|
||||||
|
else:
|
||||||
|
height = longer_size
|
||||||
|
width = longer_size
|
||||||
|
|
||||||
|
input = comfy.utils.common_upscale(input, width, height, scale_method, "disabled")
|
||||||
|
input = finalize_image_mask_input(input, is_type_image)
|
||||||
|
return input
|
||||||
|
|
||||||
|
def scale_shorter_dimension(input: torch.Tensor, shorter_size: int, scale_method: str) -> torch.Tensor:
|
||||||
|
is_type_image = is_image(input)
|
||||||
|
input = init_image_mask_input(input, is_type_image)
|
||||||
|
width = input.shape[-1]
|
||||||
|
height = input.shape[-2]
|
||||||
|
|
||||||
|
if height < width:
|
||||||
|
width = round((width / height) * shorter_size)
|
||||||
|
height = shorter_size
|
||||||
|
elif width > height:
|
||||||
|
height = round((height / width) * shorter_size)
|
||||||
|
width = shorter_size
|
||||||
|
else:
|
||||||
|
height = shorter_size
|
||||||
|
width = shorter_size
|
||||||
|
|
||||||
|
input = comfy.utils.common_upscale(input, width, height, scale_method, "disabled")
|
||||||
|
input = finalize_image_mask_input(input, is_type_image)
|
||||||
|
return input
|
||||||
|
|
||||||
|
def scale_total_pixels(input: torch.Tensor, megapixels: float, scale_method: str) -> torch.Tensor:
|
||||||
|
is_type_image = is_image(input)
|
||||||
|
input = init_image_mask_input(input, is_type_image)
|
||||||
|
total = int(megapixels * 1024 * 1024)
|
||||||
|
|
||||||
|
scale_by = math.sqrt(total / (input.shape[-1] * input.shape[-2]))
|
||||||
|
width = round(input.shape[-1] * scale_by)
|
||||||
|
height = round(input.shape[-2] * scale_by)
|
||||||
|
|
||||||
|
input = comfy.utils.common_upscale(input, width, height, scale_method, "disabled")
|
||||||
|
input = finalize_image_mask_input(input, is_type_image)
|
||||||
|
return input
|
||||||
|
|
||||||
|
def scale_match_size(input: torch.Tensor, match: torch.Tensor, scale_method: str, crop: str) -> torch.Tensor:
|
||||||
|
is_type_image = is_image(input)
|
||||||
|
input = init_image_mask_input(input, is_type_image)
|
||||||
|
match = init_image_mask_input(match, is_image(match))
|
||||||
|
|
||||||
|
width = match.shape[-1]
|
||||||
|
height = match.shape[-2]
|
||||||
|
input = comfy.utils.common_upscale(input, width, height, scale_method, crop)
|
||||||
|
input = finalize_image_mask_input(input, is_type_image)
|
||||||
|
return input
|
||||||
|
|
||||||
|
class ResizeImageMaskNode(io.ComfyNode):
|
||||||
|
|
||||||
|
scale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"]
|
||||||
|
crop_methods = ["disabled", "center"]
|
||||||
|
|
||||||
|
class ResizeTypedDict(TypedDict):
|
||||||
|
resize_type: ResizeType
|
||||||
|
scale_method: Literal["nearest-exact", "bilinear", "area", "bicubic", "lanczos"]
|
||||||
|
crop: Literal["disabled", "center"]
|
||||||
|
multiplier: float
|
||||||
|
width: int
|
||||||
|
height: int
|
||||||
|
longer_size: int
|
||||||
|
shorter_size: int
|
||||||
|
megapixels: float
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
template = io.MatchType.Template("input_type", [io.Image, io.Mask])
|
||||||
|
crop_combo = io.Combo.Input("crop", options=cls.crop_methods, default="center")
|
||||||
|
return io.Schema(
|
||||||
|
node_id="ResizeImageMaskNode",
|
||||||
|
display_name="Resize Image/Mask",
|
||||||
|
category="transform",
|
||||||
|
inputs=[
|
||||||
|
io.MatchType.Input("input", template=template),
|
||||||
|
io.DynamicCombo.Input("resize_type", options=[
|
||||||
|
io.DynamicCombo.Option(ResizeType.SCALE_BY, [
|
||||||
|
io.Float.Input("multiplier", default=1.00, min=0.01, max=8.0, step=0.01),
|
||||||
|
]),
|
||||||
|
io.DynamicCombo.Option(ResizeType.SCALE_DIMENSIONS, [
|
||||||
|
io.Int.Input("width", default=512, min=0, max=MAX_RESOLUTION, step=1),
|
||||||
|
io.Int.Input("height", default=512, min=0, max=MAX_RESOLUTION, step=1),
|
||||||
|
crop_combo,
|
||||||
|
]),
|
||||||
|
io.DynamicCombo.Option(ResizeType.SCALE_LONGER_DIMENSION, [
|
||||||
|
io.Int.Input("longer_size", default=512, min=0, max=MAX_RESOLUTION, step=1),
|
||||||
|
]),
|
||||||
|
io.DynamicCombo.Option(ResizeType.SCALE_SHORTER_DIMENSION, [
|
||||||
|
io.Int.Input("shorter_size", default=512, min=0, max=MAX_RESOLUTION, step=1),
|
||||||
|
]),
|
||||||
|
io.DynamicCombo.Option(ResizeType.SCALE_WIDTH, [
|
||||||
|
io.Int.Input("width", default=512, min=0, max=MAX_RESOLUTION, step=1),
|
||||||
|
]),
|
||||||
|
io.DynamicCombo.Option(ResizeType.SCALE_HEIGHT, [
|
||||||
|
io.Int.Input("height", default=512, min=0, max=MAX_RESOLUTION, step=1),
|
||||||
|
]),
|
||||||
|
io.DynamicCombo.Option(ResizeType.SCALE_TOTAL_PIXELS, [
|
||||||
|
io.Float.Input("megapixels", default=1.0, min=0.01, max=16.0, step=0.01),
|
||||||
|
]),
|
||||||
|
io.DynamicCombo.Option(ResizeType.MATCH_SIZE, [
|
||||||
|
io.MultiType.Input("match", [io.Image, io.Mask]),
|
||||||
|
crop_combo,
|
||||||
|
]),
|
||||||
|
]),
|
||||||
|
io.Combo.Input("scale_method", options=cls.scale_methods, default="area"),
|
||||||
|
],
|
||||||
|
outputs=[io.MatchType.Output(template=template, display_name="resized")]
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, input: io.Image.Type | io.Mask.Type, scale_method: io.Combo.Type, resize_type: ResizeTypedDict) -> io.NodeOutput:
|
||||||
|
selected_type = resize_type["resize_type"]
|
||||||
|
if selected_type == ResizeType.SCALE_BY:
|
||||||
|
return io.NodeOutput(scale_by(input, resize_type["multiplier"], scale_method))
|
||||||
|
elif selected_type == ResizeType.SCALE_DIMENSIONS:
|
||||||
|
return io.NodeOutput(scale_dimensions(input, resize_type["width"], resize_type["height"], scale_method, resize_type["crop"]))
|
||||||
|
elif selected_type == ResizeType.SCALE_LONGER_DIMENSION:
|
||||||
|
return io.NodeOutput(scale_longer_dimension(input, resize_type["longer_size"], scale_method))
|
||||||
|
elif selected_type == ResizeType.SCALE_SHORTER_DIMENSION:
|
||||||
|
return io.NodeOutput(scale_shorter_dimension(input, resize_type["shorter_size"], scale_method))
|
||||||
|
elif selected_type == ResizeType.SCALE_WIDTH:
|
||||||
|
return io.NodeOutput(scale_dimensions(input, resize_type["width"], 0, scale_method))
|
||||||
|
elif selected_type == ResizeType.SCALE_HEIGHT:
|
||||||
|
return io.NodeOutput(scale_dimensions(input, 0, resize_type["height"], scale_method))
|
||||||
|
elif selected_type == ResizeType.SCALE_TOTAL_PIXELS:
|
||||||
|
return io.NodeOutput(scale_total_pixels(input, resize_type["megapixels"], scale_method))
|
||||||
|
elif selected_type == ResizeType.MATCH_SIZE:
|
||||||
|
return io.NodeOutput(scale_match_size(input, resize_type["match"], scale_method, resize_type["crop"]))
|
||||||
|
raise ValueError(f"Unsupported resize type: {selected_type}")
|
||||||
|
|
||||||
|
def batch_images(images: list[torch.Tensor]) -> torch.Tensor | None:
|
||||||
|
if len(images) == 0:
|
||||||
|
return None
|
||||||
|
# first, get the max channels count
|
||||||
|
max_channels = max(image.shape[-1] for image in images)
|
||||||
|
# then, pad all images to have the same channels count
|
||||||
|
padded_images: list[torch.Tensor] = []
|
||||||
|
for image in images:
|
||||||
|
if image.shape[-1] < max_channels:
|
||||||
|
padded_images.append(torch.nn.functional.pad(image, (0,1), mode='constant', value=1.0))
|
||||||
|
else:
|
||||||
|
padded_images.append(image)
|
||||||
|
# resize all images to be the same size as the first image
|
||||||
|
resized_images: list[torch.Tensor] = []
|
||||||
|
first_image_shape = padded_images[0].shape
|
||||||
|
for image in padded_images:
|
||||||
|
if image.shape[1:] != first_image_shape[1:]:
|
||||||
|
resized_images.append(comfy.utils.common_upscale(image.movedim(-1,1), first_image_shape[2], first_image_shape[1], "bilinear", "center").movedim(1,-1))
|
||||||
|
else:
|
||||||
|
resized_images.append(image)
|
||||||
|
# batch the images in the format [b, h, w, c]
|
||||||
|
return torch.cat(resized_images, dim=0)
|
||||||
|
|
||||||
|
def batch_masks(masks: list[torch.Tensor]) -> torch.Tensor | None:
|
||||||
|
if len(masks) == 0:
|
||||||
|
return None
|
||||||
|
# resize all masks to be the same size as the first mask
|
||||||
|
resized_masks: list[torch.Tensor] = []
|
||||||
|
first_mask_shape = masks[0].shape
|
||||||
|
for mask in masks:
|
||||||
|
if mask.shape[1:] != first_mask_shape[1:]:
|
||||||
|
mask = init_image_mask_input(mask, is_type_image=False)
|
||||||
|
mask = comfy.utils.common_upscale(mask, first_mask_shape[2], first_mask_shape[1], "bilinear", "center")
|
||||||
|
resized_masks.append(finalize_image_mask_input(mask, is_type_image=False))
|
||||||
|
else:
|
||||||
|
resized_masks.append(mask)
|
||||||
|
# batch the masks in the format [b, h, w]
|
||||||
|
return torch.cat(resized_masks, dim=0)
|
||||||
|
|
||||||
|
def batch_latents(latents: list[dict[str, torch.Tensor]]) -> dict[str, torch.Tensor] | None:
|
||||||
|
if len(latents) == 0:
|
||||||
|
return None
|
||||||
|
samples_out = latents[0].copy()
|
||||||
|
samples_out["batch_index"] = []
|
||||||
|
first_samples = latents[0]["samples"]
|
||||||
|
tensors: list[torch.Tensor] = []
|
||||||
|
for latent in latents:
|
||||||
|
# first, deal with latent tensors
|
||||||
|
tensors.append(reshape_latent_to(first_samples.shape, latent["samples"], repeat_batch=False))
|
||||||
|
# next, deal with batch_index
|
||||||
|
samples_out["batch_index"].extend(latent.get("batch_index", [x for x in range(0, latent["samples"].shape[0])]))
|
||||||
|
samples_out["samples"] = torch.cat(tensors, dim=0)
|
||||||
|
return samples_out
|
||||||
|
|
||||||
|
class BatchImagesNode(io.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
autogrow_template = io.Autogrow.TemplatePrefix(io.Image.Input("image"), prefix="image", min=2, max=50)
|
||||||
|
return io.Schema(
|
||||||
|
node_id="BatchImagesNode",
|
||||||
|
display_name="Batch Images",
|
||||||
|
category="image",
|
||||||
|
inputs=[
|
||||||
|
io.Autogrow.Input("images", template=autogrow_template)
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Image.Output()
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, images: io.Autogrow.Type) -> io.NodeOutput:
|
||||||
|
return io.NodeOutput(batch_images(list(images.values())))
|
||||||
|
|
||||||
|
class BatchMasksNode(io.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
autogrow_template = io.Autogrow.TemplatePrefix(io.Mask.Input("mask"), prefix="mask", min=2, max=50)
|
||||||
|
return io.Schema(
|
||||||
|
node_id="BatchMasksNode",
|
||||||
|
display_name="Batch Masks",
|
||||||
|
category="mask",
|
||||||
|
inputs=[
|
||||||
|
io.Autogrow.Input("masks", template=autogrow_template)
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Mask.Output()
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, masks: io.Autogrow.Type) -> io.NodeOutput:
|
||||||
|
return io.NodeOutput(batch_masks(list(masks.values())))
|
||||||
|
|
||||||
|
class BatchLatentsNode(io.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
autogrow_template = io.Autogrow.TemplatePrefix(io.Latent.Input("latent"), prefix="latent", min=2, max=50)
|
||||||
|
return io.Schema(
|
||||||
|
node_id="BatchLatentsNode",
|
||||||
|
display_name="Batch Latents",
|
||||||
|
category="latent",
|
||||||
|
inputs=[
|
||||||
|
io.Autogrow.Input("latents", template=autogrow_template)
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Latent.Output()
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, latents: io.Autogrow.Type) -> io.NodeOutput:
|
||||||
|
return io.NodeOutput(batch_latents(list(latents.values())))
|
||||||
|
|
||||||
|
class BatchImagesMasksLatentsNode(io.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
matchtype_template = io.MatchType.Template("input", allowed_types=[io.Image, io.Mask, io.Latent])
|
||||||
|
autogrow_template = io.Autogrow.TemplatePrefix(
|
||||||
|
io.MatchType.Input("input", matchtype_template),
|
||||||
|
prefix="input", min=1, max=50)
|
||||||
|
return io.Schema(
|
||||||
|
node_id="BatchImagesMasksLatentsNode",
|
||||||
|
display_name="Batch Images/Masks/Latents",
|
||||||
|
category="util",
|
||||||
|
inputs=[
|
||||||
|
io.Autogrow.Input("inputs", template=autogrow_template)
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.MatchType.Output(id=None, template=matchtype_template)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, inputs: io.Autogrow.Type) -> io.NodeOutput:
|
||||||
|
batched = None
|
||||||
|
values = list(inputs.values())
|
||||||
|
# latents
|
||||||
|
if isinstance(values[0], dict):
|
||||||
|
batched = batch_latents(values)
|
||||||
|
# images
|
||||||
|
elif is_image(values[0]):
|
||||||
|
batched = batch_images(values)
|
||||||
|
# masks
|
||||||
|
else:
|
||||||
|
batched = batch_masks(values)
|
||||||
|
return io.NodeOutput(batched)
|
||||||
|
|
||||||
class PostProcessingExtension(ComfyExtension):
|
class PostProcessingExtension(ComfyExtension):
|
||||||
@override
|
@override
|
||||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
@ -250,6 +601,11 @@ class PostProcessingExtension(ComfyExtension):
|
|||||||
Quantize,
|
Quantize,
|
||||||
Sharpen,
|
Sharpen,
|
||||||
ImageScaleToTotalPixels,
|
ImageScaleToTotalPixels,
|
||||||
|
ResizeImageMaskNode,
|
||||||
|
BatchImagesNode,
|
||||||
|
BatchMasksNode,
|
||||||
|
BatchLatentsNode,
|
||||||
|
# BatchImagesMasksLatentsNode,
|
||||||
]
|
]
|
||||||
|
|
||||||
async def comfy_entrypoint() -> PostProcessingExtension:
|
async def comfy_entrypoint() -> PostProcessingExtension:
|
||||||
|
|||||||
@ -66,7 +66,7 @@ class Float(io.ComfyNode):
|
|||||||
display_name="Float",
|
display_name="Float",
|
||||||
category="utils/primitive",
|
category="utils/primitive",
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Float.Input("value", min=-sys.maxsize, max=sys.maxsize),
|
io.Float.Input("value", min=-sys.maxsize, max=sys.maxsize, step=0.1),
|
||||||
],
|
],
|
||||||
outputs=[io.Float.Output()],
|
outputs=[io.Float.Output()],
|
||||||
)
|
)
|
||||||
|
|||||||
@ -78,18 +78,20 @@ class ImageUpscaleWithModel(io.ComfyNode):
|
|||||||
overlap = 32
|
overlap = 32
|
||||||
|
|
||||||
oom = True
|
oom = True
|
||||||
while oom:
|
try:
|
||||||
try:
|
while oom:
|
||||||
steps = in_img.shape[0] * comfy.utils.get_tiled_scale_steps(in_img.shape[3], in_img.shape[2], tile_x=tile, tile_y=tile, overlap=overlap)
|
try:
|
||||||
pbar = comfy.utils.ProgressBar(steps)
|
steps = in_img.shape[0] * comfy.utils.get_tiled_scale_steps(in_img.shape[3], in_img.shape[2], tile_x=tile, tile_y=tile, overlap=overlap)
|
||||||
s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar)
|
pbar = comfy.utils.ProgressBar(steps)
|
||||||
oom = False
|
s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar)
|
||||||
except model_management.OOM_EXCEPTION as e:
|
oom = False
|
||||||
tile //= 2
|
except model_management.OOM_EXCEPTION as e:
|
||||||
if tile < 128:
|
tile //= 2
|
||||||
raise e
|
if tile < 128:
|
||||||
|
raise e
|
||||||
|
finally:
|
||||||
|
upscale_model.to("cpu")
|
||||||
|
|
||||||
upscale_model.to("cpu")
|
|
||||||
s = torch.clamp(s.movedim(-3,-1), min=0, max=1.0)
|
s = torch.clamp(s.movedim(-3,-1), min=0, max=1.0)
|
||||||
return io.NodeOutput(s)
|
return io.NodeOutput(s)
|
||||||
|
|
||||||
|
|||||||
@ -817,7 +817,7 @@ def get_sample_indices(original_fps,
|
|||||||
if required_duration > total_frames / original_fps:
|
if required_duration > total_frames / original_fps:
|
||||||
raise ValueError("required_duration must be less than video length")
|
raise ValueError("required_duration must be less than video length")
|
||||||
|
|
||||||
if not fixed_start is None and fixed_start >= 0:
|
if fixed_start is not None and fixed_start >= 0:
|
||||||
start_frame = fixed_start
|
start_frame = fixed_start
|
||||||
else:
|
else:
|
||||||
max_start = total_frames - required_origin_frames
|
max_start = total_frames - required_origin_frames
|
||||||
|
|||||||
@ -1,3 +1,3 @@
|
|||||||
# This file is automatically generated by the build process when version is
|
# This file is automatically generated by the build process when version is
|
||||||
# updated in pyproject.toml.
|
# updated in pyproject.toml.
|
||||||
__version__ = "0.6.0"
|
__version__ = "0.8.0"
|
||||||
|
|||||||
46
execution.py
46
execution.py
@ -79,7 +79,7 @@ class IsChangedCache:
|
|||||||
# Intentionally do not use cached outputs here. We only want constants in IS_CHANGED
|
# Intentionally do not use cached outputs here. We only want constants in IS_CHANGED
|
||||||
input_data_all, _, v3_data = get_input_data(node["inputs"], class_def, node_id, None)
|
input_data_all, _, v3_data = get_input_data(node["inputs"], class_def, node_id, None)
|
||||||
try:
|
try:
|
||||||
is_changed = await _async_map_node_over_list(self.prompt_id, node_id, class_def, input_data_all, is_changed_name)
|
is_changed = await _async_map_node_over_list(self.prompt_id, node_id, class_def, input_data_all, is_changed_name, v3_data=v3_data)
|
||||||
is_changed = await resolve_map_node_over_list_results(is_changed)
|
is_changed = await resolve_map_node_over_list_results(is_changed)
|
||||||
node["is_changed"] = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed]
|
node["is_changed"] = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed]
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -148,13 +148,12 @@ SENSITIVE_EXTRA_DATA_KEYS = ("auth_token_comfy_org", "api_key_comfy_org")
|
|||||||
def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=None, extra_data={}):
|
def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=None, extra_data={}):
|
||||||
is_v3 = issubclass(class_def, _ComfyNodeInternal)
|
is_v3 = issubclass(class_def, _ComfyNodeInternal)
|
||||||
v3_data: io.V3Data = {}
|
v3_data: io.V3Data = {}
|
||||||
|
hidden_inputs_v3 = {}
|
||||||
|
valid_inputs = class_def.INPUT_TYPES()
|
||||||
if is_v3:
|
if is_v3:
|
||||||
valid_inputs, schema, v3_data = class_def.INPUT_TYPES(include_hidden=False, return_schema=True, live_inputs=inputs)
|
valid_inputs, hidden, v3_data = _io.get_finalized_class_inputs(valid_inputs, inputs)
|
||||||
else:
|
|
||||||
valid_inputs = class_def.INPUT_TYPES()
|
|
||||||
input_data_all = {}
|
input_data_all = {}
|
||||||
missing_keys = {}
|
missing_keys = {}
|
||||||
hidden_inputs_v3 = {}
|
|
||||||
for x in inputs:
|
for x in inputs:
|
||||||
input_data = inputs[x]
|
input_data = inputs[x]
|
||||||
_, input_category, input_info = get_input_info(class_def, x, valid_inputs)
|
_, input_category, input_info = get_input_info(class_def, x, valid_inputs)
|
||||||
@ -180,18 +179,18 @@ def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=
|
|||||||
input_data_all[x] = [input_data]
|
input_data_all[x] = [input_data]
|
||||||
|
|
||||||
if is_v3:
|
if is_v3:
|
||||||
if schema.hidden:
|
if hidden is not None:
|
||||||
if io.Hidden.prompt in schema.hidden:
|
if io.Hidden.prompt.name in hidden:
|
||||||
hidden_inputs_v3[io.Hidden.prompt] = dynprompt.get_original_prompt() if dynprompt is not None else {}
|
hidden_inputs_v3[io.Hidden.prompt] = dynprompt.get_original_prompt() if dynprompt is not None else {}
|
||||||
if io.Hidden.dynprompt in schema.hidden:
|
if io.Hidden.dynprompt.name in hidden:
|
||||||
hidden_inputs_v3[io.Hidden.dynprompt] = dynprompt
|
hidden_inputs_v3[io.Hidden.dynprompt] = dynprompt
|
||||||
if io.Hidden.extra_pnginfo in schema.hidden:
|
if io.Hidden.extra_pnginfo.name in hidden:
|
||||||
hidden_inputs_v3[io.Hidden.extra_pnginfo] = extra_data.get('extra_pnginfo', None)
|
hidden_inputs_v3[io.Hidden.extra_pnginfo] = extra_data.get('extra_pnginfo', None)
|
||||||
if io.Hidden.unique_id in schema.hidden:
|
if io.Hidden.unique_id.name in hidden:
|
||||||
hidden_inputs_v3[io.Hidden.unique_id] = unique_id
|
hidden_inputs_v3[io.Hidden.unique_id] = unique_id
|
||||||
if io.Hidden.auth_token_comfy_org in schema.hidden:
|
if io.Hidden.auth_token_comfy_org.name in hidden:
|
||||||
hidden_inputs_v3[io.Hidden.auth_token_comfy_org] = extra_data.get("auth_token_comfy_org", None)
|
hidden_inputs_v3[io.Hidden.auth_token_comfy_org] = extra_data.get("auth_token_comfy_org", None)
|
||||||
if io.Hidden.api_key_comfy_org in schema.hidden:
|
if io.Hidden.api_key_comfy_org.name in hidden:
|
||||||
hidden_inputs_v3[io.Hidden.api_key_comfy_org] = extra_data.get("api_key_comfy_org", None)
|
hidden_inputs_v3[io.Hidden.api_key_comfy_org] = extra_data.get("api_key_comfy_org", None)
|
||||||
else:
|
else:
|
||||||
if "hidden" in valid_inputs:
|
if "hidden" in valid_inputs:
|
||||||
@ -258,7 +257,7 @@ async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, f
|
|||||||
pre_execute_cb(index)
|
pre_execute_cb(index)
|
||||||
# V3
|
# V3
|
||||||
if isinstance(obj, _ComfyNodeInternal) or (is_class(obj) and issubclass(obj, _ComfyNodeInternal)):
|
if isinstance(obj, _ComfyNodeInternal) or (is_class(obj) and issubclass(obj, _ComfyNodeInternal)):
|
||||||
# if is just a class, then assign no resources or state, just create clone
|
# if is just a class, then assign no state, just create clone
|
||||||
if is_class(obj):
|
if is_class(obj):
|
||||||
type_obj = obj
|
type_obj = obj
|
||||||
obj.VALIDATE_CLASS()
|
obj.VALIDATE_CLASS()
|
||||||
@ -481,7 +480,10 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
|||||||
else:
|
else:
|
||||||
lazy_status_present = getattr(obj, "check_lazy_status", None) is not None
|
lazy_status_present = getattr(obj, "check_lazy_status", None) is not None
|
||||||
if lazy_status_present:
|
if lazy_status_present:
|
||||||
required_inputs = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, "check_lazy_status", allow_interrupt=True, v3_data=v3_data)
|
# for check_lazy_status, the returned data should include the original key of the input
|
||||||
|
v3_data_lazy = v3_data.copy()
|
||||||
|
v3_data_lazy["create_dynamic_tuple"] = True
|
||||||
|
required_inputs = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, "check_lazy_status", allow_interrupt=True, v3_data=v3_data_lazy)
|
||||||
required_inputs = await resolve_map_node_over_list_results(required_inputs)
|
required_inputs = await resolve_map_node_over_list_results(required_inputs)
|
||||||
required_inputs = set(sum([r for r in required_inputs if isinstance(r,list)], []))
|
required_inputs = set(sum([r for r in required_inputs if isinstance(r,list)], []))
|
||||||
required_inputs = [x for x in required_inputs if isinstance(x,str) and (
|
required_inputs = [x for x in required_inputs if isinstance(x,str) and (
|
||||||
@ -599,6 +601,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
|||||||
|
|
||||||
if isinstance(ex, comfy.model_management.OOM_EXCEPTION):
|
if isinstance(ex, comfy.model_management.OOM_EXCEPTION):
|
||||||
tips = "This error means you ran out of memory on your GPU.\n\nTIPS: If the workflow worked before you might have accidentally set the batch_size to a large number."
|
tips = "This error means you ran out of memory on your GPU.\n\nTIPS: If the workflow worked before you might have accidentally set the batch_size to a large number."
|
||||||
|
logging.info("Memory summary: {}".format(comfy.model_management.debug_memory_summary()))
|
||||||
logging.error("Got an OOM, unloading all loaded models.")
|
logging.error("Got an OOM, unloading all loaded models.")
|
||||||
comfy.model_management.unload_all_models()
|
comfy.model_management.unload_all_models()
|
||||||
|
|
||||||
@ -756,10 +759,13 @@ async def validate_inputs(prompt_id, prompt, item, validated):
|
|||||||
errors = []
|
errors = []
|
||||||
valid = True
|
valid = True
|
||||||
|
|
||||||
|
v3_data = None
|
||||||
validate_function_inputs = []
|
validate_function_inputs = []
|
||||||
validate_has_kwargs = False
|
validate_has_kwargs = False
|
||||||
if issubclass(obj_class, _ComfyNodeInternal):
|
if issubclass(obj_class, _ComfyNodeInternal):
|
||||||
class_inputs, _, _ = obj_class.INPUT_TYPES(include_hidden=False, return_schema=True, live_inputs=inputs)
|
obj_class: _io._ComfyNodeBaseInternal
|
||||||
|
class_inputs = obj_class.INPUT_TYPES()
|
||||||
|
class_inputs, _, v3_data = _io.get_finalized_class_inputs(class_inputs, inputs)
|
||||||
validate_function_name = "validate_inputs"
|
validate_function_name = "validate_inputs"
|
||||||
validate_function = first_real_override(obj_class, validate_function_name)
|
validate_function = first_real_override(obj_class, validate_function_name)
|
||||||
else:
|
else:
|
||||||
@ -779,10 +785,11 @@ async def validate_inputs(prompt_id, prompt, item, validated):
|
|||||||
assert extra_info is not None
|
assert extra_info is not None
|
||||||
if x not in inputs:
|
if x not in inputs:
|
||||||
if input_category == "required":
|
if input_category == "required":
|
||||||
|
details = f"{x}" if not v3_data else x.split(".")[-1]
|
||||||
error = {
|
error = {
|
||||||
"type": "required_input_missing",
|
"type": "required_input_missing",
|
||||||
"message": "Required input is missing",
|
"message": "Required input is missing",
|
||||||
"details": f"{x}",
|
"details": details,
|
||||||
"extra_info": {
|
"extra_info": {
|
||||||
"input_name": x
|
"input_name": x
|
||||||
}
|
}
|
||||||
@ -916,8 +923,11 @@ async def validate_inputs(prompt_id, prompt, item, validated):
|
|||||||
errors.append(error)
|
errors.append(error)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if isinstance(input_type, list):
|
if isinstance(input_type, list) or input_type == io.Combo.io_type:
|
||||||
combo_options = input_type
|
if input_type == io.Combo.io_type:
|
||||||
|
combo_options = extra_info.get("options", [])
|
||||||
|
else:
|
||||||
|
combo_options = input_type
|
||||||
if val not in combo_options:
|
if val not in combo_options:
|
||||||
input_config = info
|
input_config = info
|
||||||
list_info = ""
|
list_info = ""
|
||||||
|
|||||||
24
nodes.py
24
nodes.py
@ -295,7 +295,11 @@ class VAEDecode:
|
|||||||
DESCRIPTION = "Decodes latent images back into pixel space images."
|
DESCRIPTION = "Decodes latent images back into pixel space images."
|
||||||
|
|
||||||
def decode(self, vae, samples):
|
def decode(self, vae, samples):
|
||||||
images = vae.decode(samples["samples"])
|
latent = samples["samples"]
|
||||||
|
if latent.is_nested:
|
||||||
|
latent = latent.unbind()[0]
|
||||||
|
|
||||||
|
images = vae.decode(latent)
|
||||||
if len(images.shape) == 5: #Combine batches
|
if len(images.shape) == 5: #Combine batches
|
||||||
images = images.reshape(-1, images.shape[-3], images.shape[-2], images.shape[-1])
|
images = images.reshape(-1, images.shape[-3], images.shape[-2], images.shape[-1])
|
||||||
return (images, )
|
return (images, )
|
||||||
@ -970,7 +974,7 @@ class DualCLIPLoader:
|
|||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ),
|
return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ),
|
||||||
"clip_name2": (folder_paths.get_filename_list("text_encoders"), ),
|
"clip_name2": (folder_paths.get_filename_list("text_encoders"), ),
|
||||||
"type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream", "hunyuan_image", "hunyuan_video_15", "kandinsky5", "kandinsky5_image", "newbie"], ),
|
"type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream", "hunyuan_image", "hunyuan_video_15", "kandinsky5", "kandinsky5_image", "ltxv", "newbie"], ),
|
||||||
},
|
},
|
||||||
"optional": {
|
"optional": {
|
||||||
"device": (["default", "cpu"], {"advanced": True}),
|
"device": (["default", "cpu"], {"advanced": True}),
|
||||||
@ -1663,8 +1667,6 @@ class LoadImage:
|
|||||||
output_masks = []
|
output_masks = []
|
||||||
w, h = None, None
|
w, h = None, None
|
||||||
|
|
||||||
excluded_formats = ['MPO']
|
|
||||||
|
|
||||||
for i in ImageSequence.Iterator(img):
|
for i in ImageSequence.Iterator(img):
|
||||||
i = node_helpers.pillow(ImageOps.exif_transpose, i)
|
i = node_helpers.pillow(ImageOps.exif_transpose, i)
|
||||||
|
|
||||||
@ -1692,7 +1694,10 @@ class LoadImage:
|
|||||||
output_images.append(image)
|
output_images.append(image)
|
||||||
output_masks.append(mask.unsqueeze(0))
|
output_masks.append(mask.unsqueeze(0))
|
||||||
|
|
||||||
if len(output_images) > 1 and img.format not in excluded_formats:
|
if img.format == "MPO":
|
||||||
|
break # ignore all frames except the first one for MPO format
|
||||||
|
|
||||||
|
if len(output_images) > 1:
|
||||||
output_image = torch.cat(output_images, dim=0)
|
output_image = torch.cat(output_images, dim=0)
|
||||||
output_mask = torch.cat(output_masks, dim=0)
|
output_mask = torch.cat(output_masks, dim=0)
|
||||||
else:
|
else:
|
||||||
@ -1863,6 +1868,7 @@ class ImageBatch:
|
|||||||
FUNCTION = "batch"
|
FUNCTION = "batch"
|
||||||
|
|
||||||
CATEGORY = "image"
|
CATEGORY = "image"
|
||||||
|
DEPRECATED = True
|
||||||
|
|
||||||
def batch(self, image1, image2):
|
def batch(self, image1, image2):
|
||||||
if image1.shape[-1] != image2.shape[-1]:
|
if image1.shape[-1] != image2.shape[-1]:
|
||||||
@ -2241,8 +2247,10 @@ async def init_external_custom_nodes():
|
|||||||
|
|
||||||
for possible_module in possible_modules:
|
for possible_module in possible_modules:
|
||||||
module_path = os.path.join(custom_node_path, possible_module)
|
module_path = os.path.join(custom_node_path, possible_module)
|
||||||
if os.path.isfile(module_path) and os.path.splitext(module_path)[1] != ".py": continue
|
if os.path.isfile(module_path) and os.path.splitext(module_path)[1] != ".py":
|
||||||
if module_path.endswith(".disabled"): continue
|
continue
|
||||||
|
if module_path.endswith(".disabled"):
|
||||||
|
continue
|
||||||
if args.disable_all_custom_nodes and possible_module not in args.whitelist_custom_nodes:
|
if args.disable_all_custom_nodes and possible_module not in args.whitelist_custom_nodes:
|
||||||
logging.info(f"Skipping {possible_module} due to disable_all_custom_nodes and whitelist_custom_nodes")
|
logging.info(f"Skipping {possible_module} due to disable_all_custom_nodes and whitelist_custom_nodes")
|
||||||
continue
|
continue
|
||||||
@ -2327,6 +2335,8 @@ async def init_builtin_extra_nodes():
|
|||||||
"nodes_mochi.py",
|
"nodes_mochi.py",
|
||||||
"nodes_slg.py",
|
"nodes_slg.py",
|
||||||
"nodes_mahiro.py",
|
"nodes_mahiro.py",
|
||||||
|
"nodes_lt_upsampler.py",
|
||||||
|
"nodes_lt_audio.py",
|
||||||
"nodes_lt.py",
|
"nodes_lt.py",
|
||||||
"nodes_hooks.py",
|
"nodes_hooks.py",
|
||||||
"nodes_load_3d.py",
|
"nodes_load_3d.py",
|
||||||
|
|||||||
@ -1,9 +1,9 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "ComfyUI"
|
name = "ComfyUI"
|
||||||
version = "0.6.0"
|
version = "0.8.0"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
license = { file = "LICENSE" }
|
license = { file = "LICENSE" }
|
||||||
requires-python = ">=3.9"
|
requires-python = ">=3.10"
|
||||||
|
|
||||||
[project.urls]
|
[project.urls]
|
||||||
homepage = "https://www.comfy.org/"
|
homepage = "https://www.comfy.org/"
|
||||||
@ -15,12 +15,16 @@ lint.select = [
|
|||||||
"N805", # invalid-first-argument-name-for-method
|
"N805", # invalid-first-argument-name-for-method
|
||||||
"S307", # suspicious-eval-usage
|
"S307", # suspicious-eval-usage
|
||||||
"S102", # exec
|
"S102", # exec
|
||||||
|
"E",
|
||||||
"T", # print-usage
|
"T", # print-usage
|
||||||
"W",
|
"W",
|
||||||
# The "F" series in Ruff stands for "Pyflakes" rules, which catch various Python syntax errors and undefined names.
|
# The "F" series in Ruff stands for "Pyflakes" rules, which catch various Python syntax errors and undefined names.
|
||||||
# See all rules here: https://docs.astral.sh/ruff/rules/#pyflakes-f
|
# See all rules here: https://docs.astral.sh/ruff/rules/#pyflakes-f
|
||||||
"F",
|
"F",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
lint.ignore = ["E501", "E722", "E731", "E712", "E402", "E741"]
|
||||||
|
|
||||||
exclude = ["*.ipynb", "**/generated/*.pyi"]
|
exclude = ["*.ipynb", "**/generated/*.pyi"]
|
||||||
|
|
||||||
[tool.pylint]
|
[tool.pylint]
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
comfyui-frontend-package==1.35.9
|
comfyui-frontend-package==1.35.9
|
||||||
comfyui-workflow-templates==0.7.64
|
comfyui-workflow-templates==0.7.67
|
||||||
comfyui-embedded-docs==0.3.1
|
comfyui-embedded-docs==0.3.1
|
||||||
torch
|
torch
|
||||||
torchsde
|
torchsde
|
||||||
@ -21,6 +21,7 @@ psutil
|
|||||||
alembic
|
alembic
|
||||||
SQLAlchemy
|
SQLAlchemy
|
||||||
av>=14.2.0
|
av>=14.2.0
|
||||||
|
comfy-kitchen>=0.2.3
|
||||||
|
|
||||||
#non essential dependencies:
|
#non essential dependencies:
|
||||||
kornia>=0.7.1
|
kornia>=0.7.1
|
||||||
|
|||||||
@ -324,7 +324,7 @@ class PromptServer():
|
|||||||
@routes.get("/models/{folder}")
|
@routes.get("/models/{folder}")
|
||||||
async def get_models(request):
|
async def get_models(request):
|
||||||
folder = request.match_info.get("folder", None)
|
folder = request.match_info.get("folder", None)
|
||||||
if not folder in folder_paths.folder_names_and_paths:
|
if folder not in folder_paths.folder_names_and_paths:
|
||||||
return web.Response(status=404)
|
return web.Response(status=404)
|
||||||
files = folder_paths.get_filename_list(folder)
|
files = folder_paths.get_filename_list(folder)
|
||||||
return web.json_response(files)
|
return web.json_response(files)
|
||||||
@ -579,7 +579,7 @@ class PromptServer():
|
|||||||
folder_name = request.match_info.get("folder_name", None)
|
folder_name = request.match_info.get("folder_name", None)
|
||||||
if folder_name is None:
|
if folder_name is None:
|
||||||
return web.Response(status=404)
|
return web.Response(status=404)
|
||||||
if not "filename" in request.rel_url.query:
|
if "filename" not in request.rel_url.query:
|
||||||
return web.Response(status=404)
|
return web.Response(status=404)
|
||||||
|
|
||||||
filename = request.rel_url.query["filename"]
|
filename = request.rel_url.query["filename"]
|
||||||
@ -593,7 +593,7 @@ class PromptServer():
|
|||||||
if out is None:
|
if out is None:
|
||||||
return web.Response(status=404)
|
return web.Response(status=404)
|
||||||
dt = json.loads(out)
|
dt = json.loads(out)
|
||||||
if not "__metadata__" in dt:
|
if "__metadata__" not in dt:
|
||||||
return web.Response(status=404)
|
return web.Response(status=404)
|
||||||
return web.json_response(dt["__metadata__"])
|
return web.json_response(dt["__metadata__"])
|
||||||
|
|
||||||
|
|||||||
@ -103,18 +103,18 @@ class TestMixedPrecisionOps(unittest.TestCase):
|
|||||||
|
|
||||||
# Verify weights are wrapped in QuantizedTensor
|
# Verify weights are wrapped in QuantizedTensor
|
||||||
self.assertIsInstance(model.layer1.weight, QuantizedTensor)
|
self.assertIsInstance(model.layer1.weight, QuantizedTensor)
|
||||||
self.assertEqual(model.layer1.weight._layout_type, "TensorCoreFP8Layout")
|
self.assertEqual(model.layer1.weight._layout_cls, "TensorCoreFP8E4M3Layout")
|
||||||
|
|
||||||
# Layer 2 should NOT be quantized
|
# Layer 2 should NOT be quantized
|
||||||
self.assertNotIsInstance(model.layer2.weight, QuantizedTensor)
|
self.assertNotIsInstance(model.layer2.weight, QuantizedTensor)
|
||||||
|
|
||||||
# Layer 3 should be quantized
|
# Layer 3 should be quantized
|
||||||
self.assertIsInstance(model.layer3.weight, QuantizedTensor)
|
self.assertIsInstance(model.layer3.weight, QuantizedTensor)
|
||||||
self.assertEqual(model.layer3.weight._layout_type, "TensorCoreFP8Layout")
|
self.assertEqual(model.layer3.weight._layout_cls, "TensorCoreFP8E4M3Layout")
|
||||||
|
|
||||||
# Verify scales were loaded
|
# Verify scales were loaded
|
||||||
self.assertEqual(model.layer1.weight._layout_params['scale'].item(), 2.0)
|
self.assertEqual(model.layer1.weight._params.scale.item(), 2.0)
|
||||||
self.assertEqual(model.layer3.weight._layout_params['scale'].item(), 1.5)
|
self.assertEqual(model.layer3.weight._params.scale.item(), 1.5)
|
||||||
|
|
||||||
# Forward pass
|
# Forward pass
|
||||||
input_tensor = torch.randn(5, 10, dtype=torch.bfloat16)
|
input_tensor = torch.randn(5, 10, dtype=torch.bfloat16)
|
||||||
@ -154,8 +154,8 @@ class TestMixedPrecisionOps(unittest.TestCase):
|
|||||||
|
|
||||||
# Verify layer1.weight is a QuantizedTensor with scale preserved
|
# Verify layer1.weight is a QuantizedTensor with scale preserved
|
||||||
self.assertIsInstance(state_dict2["layer1.weight"], QuantizedTensor)
|
self.assertIsInstance(state_dict2["layer1.weight"], QuantizedTensor)
|
||||||
self.assertEqual(state_dict2["layer1.weight"]._layout_params['scale'].item(), 3.0)
|
self.assertEqual(state_dict2["layer1.weight"]._params.scale.item(), 3.0)
|
||||||
self.assertEqual(state_dict2["layer1.weight"]._layout_type, "TensorCoreFP8Layout")
|
self.assertEqual(state_dict2["layer1.weight"]._layout_cls, "TensorCoreFP8E4M3Layout")
|
||||||
|
|
||||||
# Verify non-quantized layers are standard tensors
|
# Verify non-quantized layers are standard tensors
|
||||||
self.assertNotIsInstance(state_dict2["layer2.weight"], QuantizedTensor)
|
self.assertNotIsInstance(state_dict2["layer2.weight"], QuantizedTensor)
|
||||||
|
|||||||
@ -1,190 +0,0 @@
|
|||||||
import unittest
|
|
||||||
import torch
|
|
||||||
import sys
|
|
||||||
import os
|
|
||||||
|
|
||||||
# Add comfy to path
|
|
||||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
|
|
||||||
|
|
||||||
def has_gpu():
|
|
||||||
return torch.cuda.is_available()
|
|
||||||
|
|
||||||
from comfy.cli_args import args
|
|
||||||
if not has_gpu():
|
|
||||||
args.cpu = True
|
|
||||||
|
|
||||||
from comfy.quant_ops import QuantizedTensor, TensorCoreFP8Layout
|
|
||||||
|
|
||||||
|
|
||||||
class TestQuantizedTensor(unittest.TestCase):
|
|
||||||
"""Test the QuantizedTensor subclass with FP8 layout"""
|
|
||||||
|
|
||||||
def test_creation(self):
|
|
||||||
"""Test creating a QuantizedTensor with TensorCoreFP8Layout"""
|
|
||||||
fp8_data = torch.randn(256, 128, dtype=torch.float32).to(torch.float8_e4m3fn)
|
|
||||||
scale = torch.tensor(2.0)
|
|
||||||
layout_params = {'scale': scale, 'orig_dtype': torch.bfloat16}
|
|
||||||
|
|
||||||
qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)
|
|
||||||
|
|
||||||
self.assertIsInstance(qt, QuantizedTensor)
|
|
||||||
self.assertEqual(qt.shape, (256, 128))
|
|
||||||
self.assertEqual(qt.dtype, torch.float8_e4m3fn)
|
|
||||||
self.assertEqual(qt._layout_params['scale'], scale)
|
|
||||||
self.assertEqual(qt._layout_params['orig_dtype'], torch.bfloat16)
|
|
||||||
self.assertEqual(qt._layout_type, "TensorCoreFP8Layout")
|
|
||||||
|
|
||||||
def test_dequantize(self):
|
|
||||||
"""Test explicit dequantization"""
|
|
||||||
|
|
||||||
fp8_data = torch.ones(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
|
|
||||||
scale = torch.tensor(3.0)
|
|
||||||
layout_params = {'scale': scale, 'orig_dtype': torch.float32}
|
|
||||||
|
|
||||||
qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)
|
|
||||||
dequantized = qt.dequantize()
|
|
||||||
|
|
||||||
self.assertEqual(dequantized.dtype, torch.float32)
|
|
||||||
self.assertTrue(torch.allclose(dequantized, torch.ones(10, 20) * 3.0, rtol=0.1))
|
|
||||||
|
|
||||||
def test_from_float(self):
|
|
||||||
"""Test creating QuantizedTensor from float tensor"""
|
|
||||||
float_tensor = torch.randn(64, 32, dtype=torch.float32)
|
|
||||||
scale = torch.tensor(1.5)
|
|
||||||
|
|
||||||
qt = QuantizedTensor.from_float(
|
|
||||||
float_tensor,
|
|
||||||
"TensorCoreFP8Layout",
|
|
||||||
scale=scale,
|
|
||||||
dtype=torch.float8_e4m3fn
|
|
||||||
)
|
|
||||||
|
|
||||||
self.assertIsInstance(qt, QuantizedTensor)
|
|
||||||
self.assertEqual(qt.dtype, torch.float8_e4m3fn)
|
|
||||||
self.assertEqual(qt.shape, (64, 32))
|
|
||||||
|
|
||||||
# Verify dequantization gives approximately original values
|
|
||||||
dequantized = qt.dequantize()
|
|
||||||
mean_rel_error = ((dequantized - float_tensor).abs() / (float_tensor.abs() + 1e-6)).mean()
|
|
||||||
self.assertLess(mean_rel_error, 0.1)
|
|
||||||
|
|
||||||
|
|
||||||
class TestGenericUtilities(unittest.TestCase):
|
|
||||||
"""Test generic utility operations"""
|
|
||||||
|
|
||||||
def test_detach(self):
|
|
||||||
"""Test detach operation on quantized tensor"""
|
|
||||||
fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
|
|
||||||
scale = torch.tensor(1.5)
|
|
||||||
layout_params = {'scale': scale, 'orig_dtype': torch.float32}
|
|
||||||
qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)
|
|
||||||
|
|
||||||
# Detach should return a new QuantizedTensor
|
|
||||||
qt_detached = qt.detach()
|
|
||||||
|
|
||||||
self.assertIsInstance(qt_detached, QuantizedTensor)
|
|
||||||
self.assertEqual(qt_detached.shape, qt.shape)
|
|
||||||
self.assertEqual(qt_detached._layout_type, "TensorCoreFP8Layout")
|
|
||||||
|
|
||||||
def test_clone(self):
|
|
||||||
"""Test clone operation on quantized tensor"""
|
|
||||||
fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
|
|
||||||
scale = torch.tensor(1.5)
|
|
||||||
layout_params = {'scale': scale, 'orig_dtype': torch.float32}
|
|
||||||
qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)
|
|
||||||
|
|
||||||
# Clone should return a new QuantizedTensor
|
|
||||||
qt_cloned = qt.clone()
|
|
||||||
|
|
||||||
self.assertIsInstance(qt_cloned, QuantizedTensor)
|
|
||||||
self.assertEqual(qt_cloned.shape, qt.shape)
|
|
||||||
self.assertEqual(qt_cloned._layout_type, "TensorCoreFP8Layout")
|
|
||||||
|
|
||||||
# Verify it's a deep copy
|
|
||||||
self.assertIsNot(qt_cloned._qdata, qt._qdata)
|
|
||||||
|
|
||||||
@unittest.skipUnless(has_gpu(), "GPU not available")
|
|
||||||
def test_to_device(self):
|
|
||||||
"""Test device transfer"""
|
|
||||||
fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
|
|
||||||
scale = torch.tensor(1.5)
|
|
||||||
layout_params = {'scale': scale, 'orig_dtype': torch.float32}
|
|
||||||
qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)
|
|
||||||
|
|
||||||
# Moving to same device should work (CPU to CPU)
|
|
||||||
qt_cpu = qt.to('cpu')
|
|
||||||
|
|
||||||
self.assertIsInstance(qt_cpu, QuantizedTensor)
|
|
||||||
self.assertEqual(qt_cpu.device.type, 'cpu')
|
|
||||||
self.assertEqual(qt_cpu._layout_params['scale'].device.type, 'cpu')
|
|
||||||
|
|
||||||
|
|
||||||
class TestTensorCoreFP8Layout(unittest.TestCase):
|
|
||||||
"""Test the TensorCoreFP8Layout implementation"""
|
|
||||||
|
|
||||||
def test_quantize(self):
|
|
||||||
"""Test quantization method"""
|
|
||||||
float_tensor = torch.randn(32, 64, dtype=torch.float32)
|
|
||||||
scale = torch.tensor(1.5)
|
|
||||||
|
|
||||||
qdata, layout_params = TensorCoreFP8Layout.quantize(
|
|
||||||
float_tensor,
|
|
||||||
scale=scale,
|
|
||||||
dtype=torch.float8_e4m3fn
|
|
||||||
)
|
|
||||||
|
|
||||||
self.assertEqual(qdata.dtype, torch.float8_e4m3fn)
|
|
||||||
self.assertEqual(qdata.shape, float_tensor.shape)
|
|
||||||
self.assertIn('scale', layout_params)
|
|
||||||
self.assertIn('orig_dtype', layout_params)
|
|
||||||
self.assertEqual(layout_params['orig_dtype'], torch.float32)
|
|
||||||
|
|
||||||
def test_dequantize(self):
|
|
||||||
"""Test dequantization method"""
|
|
||||||
float_tensor = torch.ones(10, 20, dtype=torch.float32) * 3.0
|
|
||||||
scale = torch.tensor(1.0)
|
|
||||||
|
|
||||||
qdata, layout_params = TensorCoreFP8Layout.quantize(
|
|
||||||
float_tensor,
|
|
||||||
scale=scale,
|
|
||||||
dtype=torch.float8_e4m3fn
|
|
||||||
)
|
|
||||||
|
|
||||||
dequantized = TensorCoreFP8Layout.dequantize(qdata, **layout_params)
|
|
||||||
|
|
||||||
# Should approximately match original
|
|
||||||
self.assertTrue(torch.allclose(dequantized, float_tensor, rtol=0.1, atol=0.1))
|
|
||||||
|
|
||||||
|
|
||||||
class TestFallbackMechanism(unittest.TestCase):
|
|
||||||
"""Test fallback for unsupported operations"""
|
|
||||||
|
|
||||||
def test_unsupported_op_dequantizes(self):
|
|
||||||
"""Test that unsupported operations fall back to dequantization"""
|
|
||||||
# Set seed for reproducibility
|
|
||||||
torch.manual_seed(42)
|
|
||||||
|
|
||||||
# Create quantized tensor
|
|
||||||
a_fp32 = torch.randn(10, 20, dtype=torch.float32)
|
|
||||||
scale = torch.tensor(1.0)
|
|
||||||
a_q = QuantizedTensor.from_float(
|
|
||||||
a_fp32,
|
|
||||||
"TensorCoreFP8Layout",
|
|
||||||
scale=scale,
|
|
||||||
dtype=torch.float8_e4m3fn
|
|
||||||
)
|
|
||||||
|
|
||||||
# Call an operation that doesn't have a registered handler
|
|
||||||
# For example, torch.abs
|
|
||||||
result = torch.abs(a_q)
|
|
||||||
|
|
||||||
# Should work via fallback (dequantize → abs → return)
|
|
||||||
self.assertNotIsInstance(result, QuantizedTensor)
|
|
||||||
expected = torch.abs(a_fp32)
|
|
||||||
# FP8 introduces quantization error, so use loose tolerance
|
|
||||||
mean_error = (result - expected).abs().mean()
|
|
||||||
self.assertLess(mean_error, 0.05, f"Mean error {mean_error:.4f} is too large")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
unittest.main()
|
|
||||||
Loading…
Reference in New Issue
Block a user