mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-01 21:02:30 +08:00
Merge branch 'master' into ops-changes
This commit is contained in:
commit
4130be262e
@ -119,6 +119,9 @@ ComfyUI follows a weekly release cycle targeting Monday but this regularly chang
|
|||||||
|
|
||||||
1. **[ComfyUI Core](https://github.com/comfyanonymous/ComfyUI)**
|
1. **[ComfyUI Core](https://github.com/comfyanonymous/ComfyUI)**
|
||||||
- Releases a new stable version (e.g., v0.7.0) roughly every week.
|
- Releases a new stable version (e.g., v0.7.0) roughly every week.
|
||||||
|
- Starting from v0.4.0 patch versions will be used for fixes backported onto the current stable release.
|
||||||
|
- Minor versions will be used for releases off the master branch.
|
||||||
|
- Patch versions may still be used for releases on the master branch in cases where a backport would not make sense.
|
||||||
- Commits outside of the stable release tags may be very unstable and break many custom nodes.
|
- Commits outside of the stable release tags may be very unstable and break many custom nodes.
|
||||||
- Serves as the foundation for the desktop release
|
- Serves as the foundation for the desktop release
|
||||||
|
|
||||||
@ -209,6 +212,8 @@ Python 3.14 works but you may encounter issues with the torch compile node. The
|
|||||||
|
|
||||||
Python 3.13 is very well supported. If you have trouble with some custom node dependencies on 3.13 you can try 3.12
|
Python 3.13 is very well supported. If you have trouble with some custom node dependencies on 3.13 you can try 3.12
|
||||||
|
|
||||||
|
torch 2.4 and above is supported but some features might only work on newer versions. We generally recommend using the latest major version of pytorch unless it is less than 2 weeks old.
|
||||||
|
|
||||||
### Instructions:
|
### Instructions:
|
||||||
|
|
||||||
Git clone this repo.
|
Git clone this repo.
|
||||||
|
|||||||
@ -44,7 +44,7 @@ class ModelFileManager:
|
|||||||
@routes.get("/experiment/models/{folder}")
|
@routes.get("/experiment/models/{folder}")
|
||||||
async def get_all_models(request):
|
async def get_all_models(request):
|
||||||
folder = request.match_info.get("folder", None)
|
folder = request.match_info.get("folder", None)
|
||||||
if not folder in folder_paths.folder_names_and_paths:
|
if folder not in folder_paths.folder_names_and_paths:
|
||||||
return web.Response(status=404)
|
return web.Response(status=404)
|
||||||
files = self.get_model_file_list(folder)
|
files = self.get_model_file_list(folder)
|
||||||
return web.json_response(files)
|
return web.json_response(files)
|
||||||
@ -55,7 +55,7 @@ class ModelFileManager:
|
|||||||
path_index = int(request.match_info.get("path_index", None))
|
path_index = int(request.match_info.get("path_index", None))
|
||||||
filename = request.match_info.get("filename", None)
|
filename = request.match_info.get("filename", None)
|
||||||
|
|
||||||
if not folder_name in folder_paths.folder_names_and_paths:
|
if folder_name not in folder_paths.folder_names_and_paths:
|
||||||
return web.Response(status=404)
|
return web.Response(status=404)
|
||||||
|
|
||||||
folders = folder_paths.folder_names_and_paths[folder_name]
|
folders = folder_paths.folder_names_and_paths[folder_name]
|
||||||
|
|||||||
@ -2,6 +2,25 @@ import torch
|
|||||||
from comfy.ldm.modules.attention import optimized_attention_for_device
|
from comfy.ldm.modules.attention import optimized_attention_for_device
|
||||||
import comfy.ops
|
import comfy.ops
|
||||||
|
|
||||||
|
def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], crop=True):
|
||||||
|
image = image[:, :, :, :3] if image.shape[3] > 3 else image
|
||||||
|
mean = torch.tensor(mean, device=image.device, dtype=image.dtype)
|
||||||
|
std = torch.tensor(std, device=image.device, dtype=image.dtype)
|
||||||
|
image = image.movedim(-1, 1)
|
||||||
|
if not (image.shape[2] == size and image.shape[3] == size):
|
||||||
|
if crop:
|
||||||
|
scale = (size / min(image.shape[2], image.shape[3]))
|
||||||
|
scale_size = (round(scale * image.shape[2]), round(scale * image.shape[3]))
|
||||||
|
else:
|
||||||
|
scale_size = (size, size)
|
||||||
|
|
||||||
|
image = torch.nn.functional.interpolate(image, size=scale_size, mode="bicubic", antialias=True)
|
||||||
|
h = (image.shape[2] - size)//2
|
||||||
|
w = (image.shape[3] - size)//2
|
||||||
|
image = image[:,:,h:h+size,w:w+size]
|
||||||
|
image = torch.clip((255. * image), 0, 255).round() / 255.0
|
||||||
|
return (image - mean.view([3,1,1])) / std.view([3,1,1])
|
||||||
|
|
||||||
class CLIPAttention(torch.nn.Module):
|
class CLIPAttention(torch.nn.Module):
|
||||||
def __init__(self, embed_dim, heads, dtype, device, operations):
|
def __init__(self, embed_dim, heads, dtype, device, operations):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|||||||
@ -1,6 +1,5 @@
|
|||||||
from .utils import load_torch_file, transformers_convert, state_dict_prefix_replace
|
from .utils import load_torch_file, transformers_convert, state_dict_prefix_replace
|
||||||
import os
|
import os
|
||||||
import torch
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
@ -17,24 +16,7 @@ class Output:
|
|||||||
def __setitem__(self, key, item):
|
def __setitem__(self, key, item):
|
||||||
setattr(self, key, item)
|
setattr(self, key, item)
|
||||||
|
|
||||||
def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], crop=True):
|
clip_preprocess = comfy.clip_model.clip_preprocess # Prevent some stuff from breaking, TODO: remove eventually
|
||||||
image = image[:, :, :, :3] if image.shape[3] > 3 else image
|
|
||||||
mean = torch.tensor(mean, device=image.device, dtype=image.dtype)
|
|
||||||
std = torch.tensor(std, device=image.device, dtype=image.dtype)
|
|
||||||
image = image.movedim(-1, 1)
|
|
||||||
if not (image.shape[2] == size and image.shape[3] == size):
|
|
||||||
if crop:
|
|
||||||
scale = (size / min(image.shape[2], image.shape[3]))
|
|
||||||
scale_size = (round(scale * image.shape[2]), round(scale * image.shape[3]))
|
|
||||||
else:
|
|
||||||
scale_size = (size, size)
|
|
||||||
|
|
||||||
image = torch.nn.functional.interpolate(image, size=scale_size, mode="bicubic", antialias=True)
|
|
||||||
h = (image.shape[2] - size)//2
|
|
||||||
w = (image.shape[3] - size)//2
|
|
||||||
image = image[:,:,h:h+size,w:w+size]
|
|
||||||
image = torch.clip((255. * image), 0, 255).round() / 255.0
|
|
||||||
return (image - mean.view([3,1,1])) / std.view([3,1,1])
|
|
||||||
|
|
||||||
IMAGE_ENCODERS = {
|
IMAGE_ENCODERS = {
|
||||||
"clip_vision_model": comfy.clip_model.CLIPVisionModelProjection,
|
"clip_vision_model": comfy.clip_model.CLIPVisionModelProjection,
|
||||||
@ -73,7 +55,7 @@ class ClipVisionModel():
|
|||||||
|
|
||||||
def encode_image(self, image, crop=True):
|
def encode_image(self, image, crop=True):
|
||||||
comfy.model_management.load_model_gpu(self.patcher)
|
comfy.model_management.load_model_gpu(self.patcher)
|
||||||
pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float()
|
pixel_values = comfy.clip_model.clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float()
|
||||||
out = self.model(pixel_values=pixel_values, intermediate_output='all' if self.return_all_hidden_states else -2)
|
out = self.model(pixel_values=pixel_values, intermediate_output='all' if self.return_all_hidden_states else -2)
|
||||||
|
|
||||||
outputs = Output()
|
outputs = Output()
|
||||||
|
|||||||
@ -143,7 +143,7 @@ class IndexListContextHandler(ContextHandlerABC):
|
|||||||
# if multiple conds, split based on primary region
|
# if multiple conds, split based on primary region
|
||||||
if self.split_conds_to_windows and len(cond_in) > 1:
|
if self.split_conds_to_windows and len(cond_in) > 1:
|
||||||
region = window.get_region_index(len(cond_in))
|
region = window.get_region_index(len(cond_in))
|
||||||
logging.info(f"Splitting conds to windows; using region {region} for window {window[0]}-{window[-1]} with center ratio {window.center_ratio:.3f}")
|
logging.info(f"Splitting conds to windows; using region {region} for window {window.index_list[0]}-{window.index_list[-1]} with center ratio {window.center_ratio:.3f}")
|
||||||
cond_in = [cond_in[region]]
|
cond_in = [cond_in[region]]
|
||||||
# cond object is a list containing a dict - outer list is irrelevant, so just loop through it
|
# cond object is a list containing a dict - outer list is irrelevant, so just loop through it
|
||||||
for actual_cond in cond_in:
|
for actual_cond in cond_in:
|
||||||
@ -188,6 +188,12 @@ class IndexListContextHandler(ContextHandlerABC):
|
|||||||
audio_cond = cond_value.cond
|
audio_cond = cond_value.cond
|
||||||
if audio_cond.ndim > 1 and audio_cond.size(1) == x_in.size(self.dim):
|
if audio_cond.ndim > 1 and audio_cond.size(1) == x_in.size(self.dim):
|
||||||
new_cond_item[cond_key] = cond_value._copy_with(window.get_tensor(audio_cond, device, dim=1))
|
new_cond_item[cond_key] = cond_value._copy_with(window.get_tensor(audio_cond, device, dim=1))
|
||||||
|
# Handle vace_context (temporal dim is 3)
|
||||||
|
elif cond_key == "vace_context" and hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
|
||||||
|
vace_cond = cond_value.cond
|
||||||
|
if vace_cond.ndim >= 4 and vace_cond.size(3) == x_in.size(self.dim):
|
||||||
|
sliced_vace = window.get_tensor(vace_cond, device, dim=3, retain_index_list=self.cond_retain_index_list)
|
||||||
|
new_cond_item[cond_key] = cond_value._copy_with(sliced_vace)
|
||||||
# if has cond that is a Tensor, check if needs to be subset
|
# if has cond that is a Tensor, check if needs to be subset
|
||||||
elif hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
|
elif hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
|
||||||
if (self.dim < cond_value.cond.ndim and cond_value.cond.size(self.dim) == x_in.size(self.dim)) or \
|
if (self.dim < cond_value.cond.ndim and cond_value.cond.size(self.dim) == x_in.size(self.dim)) or \
|
||||||
|
|||||||
@ -527,7 +527,8 @@ class HookKeyframeGroup:
|
|||||||
if self._current_keyframe.get_effective_guarantee_steps(max_sigma) > 0:
|
if self._current_keyframe.get_effective_guarantee_steps(max_sigma) > 0:
|
||||||
break
|
break
|
||||||
# if eval_c is outside the percent range, stop looking further
|
# if eval_c is outside the percent range, stop looking further
|
||||||
else: break
|
else:
|
||||||
|
break
|
||||||
# update steps current context is used
|
# update steps current context is used
|
||||||
self._current_used_steps += 1
|
self._current_used_steps += 1
|
||||||
# update current timestep this was performed on
|
# update current timestep this was performed on
|
||||||
|
|||||||
@ -74,6 +74,9 @@ def get_ancestral_step(sigma_from, sigma_to, eta=1.):
|
|||||||
|
|
||||||
def default_noise_sampler(x, seed=None):
|
def default_noise_sampler(x, seed=None):
|
||||||
if seed is not None:
|
if seed is not None:
|
||||||
|
if x.device == torch.device("cpu"):
|
||||||
|
seed += 1
|
||||||
|
|
||||||
generator = torch.Generator(device=x.device)
|
generator = torch.Generator(device=x.device)
|
||||||
generator.manual_seed(seed)
|
generator.manual_seed(seed)
|
||||||
else:
|
else:
|
||||||
|
|||||||
@ -407,6 +407,9 @@ class LTXV(LatentFormat):
|
|||||||
|
|
||||||
self.latent_rgb_factors_bias = [-0.0571, -0.1657, -0.2512]
|
self.latent_rgb_factors_bias = [-0.0571, -0.1657, -0.2512]
|
||||||
|
|
||||||
|
class LTXAV(LTXV):
|
||||||
|
pass
|
||||||
|
|
||||||
class HunyuanVideo(LatentFormat):
|
class HunyuanVideo(LatentFormat):
|
||||||
latent_channels = 16
|
latent_channels = 16
|
||||||
latent_dimensions = 3
|
latent_dimensions = 3
|
||||||
|
|||||||
@ -270,7 +270,7 @@ class ChromaRadiance(Chroma):
|
|||||||
bad_keys = tuple(
|
bad_keys = tuple(
|
||||||
k
|
k
|
||||||
for k, v in overrides.items()
|
for k, v in overrides.items()
|
||||||
if type(v) != type(getattr(params, k)) and (v is not None or k not in nullable_keys)
|
if not isinstance(v, type(getattr(params, k))) and (v is not None or k not in nullable_keys)
|
||||||
)
|
)
|
||||||
if bad_keys:
|
if bad_keys:
|
||||||
e = f"Invalid value(s) in transformer_options chroma_radiance_options: {', '.join(bad_keys)}"
|
e = f"Invalid value(s) in transformer_options chroma_radiance_options: {', '.join(bad_keys)}"
|
||||||
|
|||||||
@ -3,7 +3,8 @@ import torch.nn as nn
|
|||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, VideoConv3d
|
from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, VideoConv3d
|
||||||
from comfy.ldm.hunyuan_video.vae_refiner import RMS_norm
|
from comfy.ldm.hunyuan_video.vae_refiner import RMS_norm
|
||||||
import model_management, model_patcher
|
import model_management
|
||||||
|
import model_patcher
|
||||||
|
|
||||||
class SRResidualCausalBlock3D(nn.Module):
|
class SRResidualCausalBlock3D(nn.Module):
|
||||||
def __init__(self, channels: int):
|
def __init__(self, channels: int):
|
||||||
|
|||||||
837
comfy/ldm/lightricks/av_model.py
Normal file
837
comfy/ldm/lightricks/av_model.py
Normal file
@ -0,0 +1,837 @@
|
|||||||
|
from typing import Tuple
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from comfy.ldm.lightricks.model import (
|
||||||
|
CrossAttention,
|
||||||
|
FeedForward,
|
||||||
|
AdaLayerNormSingle,
|
||||||
|
PixArtAlphaTextProjection,
|
||||||
|
LTXVModel,
|
||||||
|
)
|
||||||
|
from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier
|
||||||
|
import comfy.ldm.common_dit
|
||||||
|
|
||||||
|
class BasicAVTransformerBlock(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
v_dim,
|
||||||
|
a_dim,
|
||||||
|
v_heads,
|
||||||
|
a_heads,
|
||||||
|
vd_head,
|
||||||
|
ad_head,
|
||||||
|
v_context_dim=None,
|
||||||
|
a_context_dim=None,
|
||||||
|
attn_precision=None,
|
||||||
|
dtype=None,
|
||||||
|
device=None,
|
||||||
|
operations=None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.attn_precision = attn_precision
|
||||||
|
|
||||||
|
self.attn1 = CrossAttention(
|
||||||
|
query_dim=v_dim,
|
||||||
|
heads=v_heads,
|
||||||
|
dim_head=vd_head,
|
||||||
|
context_dim=None,
|
||||||
|
attn_precision=self.attn_precision,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
operations=operations,
|
||||||
|
)
|
||||||
|
self.audio_attn1 = CrossAttention(
|
||||||
|
query_dim=a_dim,
|
||||||
|
heads=a_heads,
|
||||||
|
dim_head=ad_head,
|
||||||
|
context_dim=None,
|
||||||
|
attn_precision=self.attn_precision,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
operations=operations,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.attn2 = CrossAttention(
|
||||||
|
query_dim=v_dim,
|
||||||
|
context_dim=v_context_dim,
|
||||||
|
heads=v_heads,
|
||||||
|
dim_head=vd_head,
|
||||||
|
attn_precision=self.attn_precision,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
operations=operations,
|
||||||
|
)
|
||||||
|
self.audio_attn2 = CrossAttention(
|
||||||
|
query_dim=a_dim,
|
||||||
|
context_dim=a_context_dim,
|
||||||
|
heads=a_heads,
|
||||||
|
dim_head=ad_head,
|
||||||
|
attn_precision=self.attn_precision,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
operations=operations,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Q: Video, K,V: Audio
|
||||||
|
self.audio_to_video_attn = CrossAttention(
|
||||||
|
query_dim=v_dim,
|
||||||
|
context_dim=a_dim,
|
||||||
|
heads=a_heads,
|
||||||
|
dim_head=ad_head,
|
||||||
|
attn_precision=self.attn_precision,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
operations=operations,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Q: Audio, K,V: Video
|
||||||
|
self.video_to_audio_attn = CrossAttention(
|
||||||
|
query_dim=a_dim,
|
||||||
|
context_dim=v_dim,
|
||||||
|
heads=a_heads,
|
||||||
|
dim_head=ad_head,
|
||||||
|
attn_precision=self.attn_precision,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
operations=operations,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.ff = FeedForward(
|
||||||
|
v_dim, dim_out=v_dim, glu=True, dtype=dtype, device=device, operations=operations
|
||||||
|
)
|
||||||
|
self.audio_ff = FeedForward(
|
||||||
|
a_dim, dim_out=a_dim, glu=True, dtype=dtype, device=device, operations=operations
|
||||||
|
)
|
||||||
|
|
||||||
|
self.scale_shift_table = nn.Parameter(torch.empty(6, v_dim, device=device, dtype=dtype))
|
||||||
|
self.audio_scale_shift_table = nn.Parameter(
|
||||||
|
torch.empty(6, a_dim, device=device, dtype=dtype)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.scale_shift_table_a2v_ca_audio = nn.Parameter(
|
||||||
|
torch.empty(5, a_dim, device=device, dtype=dtype)
|
||||||
|
)
|
||||||
|
self.scale_shift_table_a2v_ca_video = nn.Parameter(
|
||||||
|
torch.empty(5, v_dim, device=device, dtype=dtype)
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_ada_values(
|
||||||
|
self, scale_shift_table: torch.Tensor, batch_size: int, timestep: torch.Tensor, indices: slice = slice(None, None)
|
||||||
|
):
|
||||||
|
num_ada_params = scale_shift_table.shape[0]
|
||||||
|
|
||||||
|
ada_values = (
|
||||||
|
scale_shift_table[indices].unsqueeze(0).unsqueeze(0).to(device=timestep.device, dtype=timestep.dtype)
|
||||||
|
+ timestep.reshape(batch_size, timestep.shape[1], num_ada_params, -1)[:, :, indices, :]
|
||||||
|
).unbind(dim=2)
|
||||||
|
return ada_values
|
||||||
|
|
||||||
|
def get_av_ca_ada_values(
|
||||||
|
self,
|
||||||
|
scale_shift_table: torch.Tensor,
|
||||||
|
batch_size: int,
|
||||||
|
scale_shift_timestep: torch.Tensor,
|
||||||
|
gate_timestep: torch.Tensor,
|
||||||
|
num_scale_shift_values: int = 4,
|
||||||
|
):
|
||||||
|
scale_shift_ada_values = self.get_ada_values(
|
||||||
|
scale_shift_table[:num_scale_shift_values, :],
|
||||||
|
batch_size,
|
||||||
|
scale_shift_timestep,
|
||||||
|
)
|
||||||
|
gate_ada_values = self.get_ada_values(
|
||||||
|
scale_shift_table[num_scale_shift_values:, :],
|
||||||
|
batch_size,
|
||||||
|
gate_timestep,
|
||||||
|
)
|
||||||
|
|
||||||
|
scale_shift_chunks = [t.squeeze(2) for t in scale_shift_ada_values]
|
||||||
|
gate_ada_values = [t.squeeze(2) for t in gate_ada_values]
|
||||||
|
|
||||||
|
return (*scale_shift_chunks, *gate_ada_values)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: Tuple[torch.Tensor, torch.Tensor],
|
||||||
|
v_context=None,
|
||||||
|
a_context=None,
|
||||||
|
attention_mask=None,
|
||||||
|
v_timestep=None,
|
||||||
|
a_timestep=None,
|
||||||
|
v_pe=None,
|
||||||
|
a_pe=None,
|
||||||
|
v_cross_pe=None,
|
||||||
|
a_cross_pe=None,
|
||||||
|
v_cross_scale_shift_timestep=None,
|
||||||
|
a_cross_scale_shift_timestep=None,
|
||||||
|
v_cross_gate_timestep=None,
|
||||||
|
a_cross_gate_timestep=None,
|
||||||
|
transformer_options=None,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
run_vx = transformer_options.get("run_vx", True)
|
||||||
|
run_ax = transformer_options.get("run_ax", True)
|
||||||
|
|
||||||
|
vx, ax = x
|
||||||
|
run_ax = run_ax and ax.numel() > 0
|
||||||
|
run_a2v = run_vx and transformer_options.get("a2v_cross_attn", True) and ax.numel() > 0
|
||||||
|
run_v2a = run_ax and transformer_options.get("v2a_cross_attn", True)
|
||||||
|
|
||||||
|
if run_vx:
|
||||||
|
vshift_msa, vscale_msa, vgate_msa = (
|
||||||
|
self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(0, 3))
|
||||||
|
)
|
||||||
|
|
||||||
|
norm_vx = comfy.ldm.common_dit.rms_norm(vx) * (1 + vscale_msa) + vshift_msa
|
||||||
|
vx += self.attn1(norm_vx, pe=v_pe, transformer_options=transformer_options) * vgate_msa
|
||||||
|
vx += self.attn2(
|
||||||
|
comfy.ldm.common_dit.rms_norm(vx),
|
||||||
|
context=v_context,
|
||||||
|
mask=attention_mask,
|
||||||
|
transformer_options=transformer_options,
|
||||||
|
)
|
||||||
|
|
||||||
|
del vshift_msa, vscale_msa, vgate_msa
|
||||||
|
|
||||||
|
if run_ax:
|
||||||
|
ashift_msa, ascale_msa, agate_msa = (
|
||||||
|
self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(0, 3))
|
||||||
|
)
|
||||||
|
|
||||||
|
norm_ax = comfy.ldm.common_dit.rms_norm(ax) * (1 + ascale_msa) + ashift_msa
|
||||||
|
ax += (
|
||||||
|
self.audio_attn1(norm_ax, pe=a_pe, transformer_options=transformer_options)
|
||||||
|
* agate_msa
|
||||||
|
)
|
||||||
|
ax += self.audio_attn2(
|
||||||
|
comfy.ldm.common_dit.rms_norm(ax),
|
||||||
|
context=a_context,
|
||||||
|
mask=attention_mask,
|
||||||
|
transformer_options=transformer_options,
|
||||||
|
)
|
||||||
|
|
||||||
|
del ashift_msa, ascale_msa, agate_msa
|
||||||
|
|
||||||
|
# Audio - Video cross attention.
|
||||||
|
if run_a2v or run_v2a:
|
||||||
|
# norm3
|
||||||
|
vx_norm3 = comfy.ldm.common_dit.rms_norm(vx)
|
||||||
|
ax_norm3 = comfy.ldm.common_dit.rms_norm(ax)
|
||||||
|
|
||||||
|
(
|
||||||
|
scale_ca_audio_hidden_states_a2v,
|
||||||
|
shift_ca_audio_hidden_states_a2v,
|
||||||
|
scale_ca_audio_hidden_states_v2a,
|
||||||
|
shift_ca_audio_hidden_states_v2a,
|
||||||
|
gate_out_v2a,
|
||||||
|
) = self.get_av_ca_ada_values(
|
||||||
|
self.scale_shift_table_a2v_ca_audio,
|
||||||
|
ax.shape[0],
|
||||||
|
a_cross_scale_shift_timestep,
|
||||||
|
a_cross_gate_timestep,
|
||||||
|
)
|
||||||
|
|
||||||
|
(
|
||||||
|
scale_ca_video_hidden_states_a2v,
|
||||||
|
shift_ca_video_hidden_states_a2v,
|
||||||
|
scale_ca_video_hidden_states_v2a,
|
||||||
|
shift_ca_video_hidden_states_v2a,
|
||||||
|
gate_out_a2v,
|
||||||
|
) = self.get_av_ca_ada_values(
|
||||||
|
self.scale_shift_table_a2v_ca_video,
|
||||||
|
vx.shape[0],
|
||||||
|
v_cross_scale_shift_timestep,
|
||||||
|
v_cross_gate_timestep,
|
||||||
|
)
|
||||||
|
|
||||||
|
if run_a2v:
|
||||||
|
vx_scaled = (
|
||||||
|
vx_norm3 * (1 + scale_ca_video_hidden_states_a2v)
|
||||||
|
+ shift_ca_video_hidden_states_a2v
|
||||||
|
)
|
||||||
|
ax_scaled = (
|
||||||
|
ax_norm3 * (1 + scale_ca_audio_hidden_states_a2v)
|
||||||
|
+ shift_ca_audio_hidden_states_a2v
|
||||||
|
)
|
||||||
|
vx += (
|
||||||
|
self.audio_to_video_attn(
|
||||||
|
vx_scaled,
|
||||||
|
context=ax_scaled,
|
||||||
|
pe=v_cross_pe,
|
||||||
|
k_pe=a_cross_pe,
|
||||||
|
transformer_options=transformer_options,
|
||||||
|
)
|
||||||
|
* gate_out_a2v
|
||||||
|
)
|
||||||
|
|
||||||
|
del gate_out_a2v
|
||||||
|
del scale_ca_video_hidden_states_a2v,\
|
||||||
|
shift_ca_video_hidden_states_a2v,\
|
||||||
|
scale_ca_audio_hidden_states_a2v,\
|
||||||
|
shift_ca_audio_hidden_states_a2v,\
|
||||||
|
|
||||||
|
if run_v2a:
|
||||||
|
ax_scaled = (
|
||||||
|
ax_norm3 * (1 + scale_ca_audio_hidden_states_v2a)
|
||||||
|
+ shift_ca_audio_hidden_states_v2a
|
||||||
|
)
|
||||||
|
vx_scaled = (
|
||||||
|
vx_norm3 * (1 + scale_ca_video_hidden_states_v2a)
|
||||||
|
+ shift_ca_video_hidden_states_v2a
|
||||||
|
)
|
||||||
|
ax += (
|
||||||
|
self.video_to_audio_attn(
|
||||||
|
ax_scaled,
|
||||||
|
context=vx_scaled,
|
||||||
|
pe=a_cross_pe,
|
||||||
|
k_pe=v_cross_pe,
|
||||||
|
transformer_options=transformer_options,
|
||||||
|
)
|
||||||
|
* gate_out_v2a
|
||||||
|
)
|
||||||
|
|
||||||
|
del gate_out_v2a
|
||||||
|
del scale_ca_video_hidden_states_v2a,\
|
||||||
|
shift_ca_video_hidden_states_v2a,\
|
||||||
|
scale_ca_audio_hidden_states_v2a,\
|
||||||
|
shift_ca_audio_hidden_states_v2a
|
||||||
|
|
||||||
|
if run_vx:
|
||||||
|
vshift_mlp, vscale_mlp, vgate_mlp = (
|
||||||
|
self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(3, None))
|
||||||
|
)
|
||||||
|
|
||||||
|
vx_scaled = comfy.ldm.common_dit.rms_norm(vx) * (1 + vscale_mlp) + vshift_mlp
|
||||||
|
vx += self.ff(vx_scaled) * vgate_mlp
|
||||||
|
del vshift_mlp, vscale_mlp, vgate_mlp
|
||||||
|
|
||||||
|
if run_ax:
|
||||||
|
ashift_mlp, ascale_mlp, agate_mlp = (
|
||||||
|
self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(3, None))
|
||||||
|
)
|
||||||
|
|
||||||
|
ax_scaled = comfy.ldm.common_dit.rms_norm(ax) * (1 + ascale_mlp) + ashift_mlp
|
||||||
|
ax += self.audio_ff(ax_scaled) * agate_mlp
|
||||||
|
|
||||||
|
del ashift_mlp, ascale_mlp, agate_mlp
|
||||||
|
|
||||||
|
|
||||||
|
return vx, ax
|
||||||
|
|
||||||
|
|
||||||
|
class LTXAVModel(LTXVModel):
|
||||||
|
"""LTXAV model for audio-video generation."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels=128,
|
||||||
|
audio_in_channels=128,
|
||||||
|
cross_attention_dim=4096,
|
||||||
|
audio_cross_attention_dim=2048,
|
||||||
|
attention_head_dim=128,
|
||||||
|
audio_attention_head_dim=64,
|
||||||
|
num_attention_heads=32,
|
||||||
|
audio_num_attention_heads=32,
|
||||||
|
caption_channels=3840,
|
||||||
|
num_layers=48,
|
||||||
|
positional_embedding_theta=10000.0,
|
||||||
|
positional_embedding_max_pos=[20, 2048, 2048],
|
||||||
|
audio_positional_embedding_max_pos=[20],
|
||||||
|
causal_temporal_positioning=False,
|
||||||
|
vae_scale_factors=(8, 32, 32),
|
||||||
|
use_middle_indices_grid=False,
|
||||||
|
timestep_scale_multiplier=1000.0,
|
||||||
|
av_ca_timestep_scale_multiplier=1.0,
|
||||||
|
dtype=None,
|
||||||
|
device=None,
|
||||||
|
operations=None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
# Store audio-specific parameters
|
||||||
|
self.audio_in_channels = audio_in_channels
|
||||||
|
self.audio_cross_attention_dim = audio_cross_attention_dim
|
||||||
|
self.audio_attention_head_dim = audio_attention_head_dim
|
||||||
|
self.audio_num_attention_heads = audio_num_attention_heads
|
||||||
|
self.audio_positional_embedding_max_pos = audio_positional_embedding_max_pos
|
||||||
|
|
||||||
|
# Calculate audio dimensions
|
||||||
|
self.audio_inner_dim = audio_num_attention_heads * audio_attention_head_dim
|
||||||
|
self.audio_out_channels = audio_in_channels
|
||||||
|
|
||||||
|
# Audio-specific constants
|
||||||
|
self.num_audio_channels = 8
|
||||||
|
self.audio_frequency_bins = 16
|
||||||
|
|
||||||
|
self.av_ca_timestep_scale_multiplier = av_ca_timestep_scale_multiplier
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
in_channels=in_channels,
|
||||||
|
cross_attention_dim=cross_attention_dim,
|
||||||
|
attention_head_dim=attention_head_dim,
|
||||||
|
num_attention_heads=num_attention_heads,
|
||||||
|
caption_channels=caption_channels,
|
||||||
|
num_layers=num_layers,
|
||||||
|
positional_embedding_theta=positional_embedding_theta,
|
||||||
|
positional_embedding_max_pos=positional_embedding_max_pos,
|
||||||
|
causal_temporal_positioning=causal_temporal_positioning,
|
||||||
|
vae_scale_factors=vae_scale_factors,
|
||||||
|
use_middle_indices_grid=use_middle_indices_grid,
|
||||||
|
timestep_scale_multiplier=timestep_scale_multiplier,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
operations=operations,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _init_model_components(self, device, dtype, **kwargs):
|
||||||
|
"""Initialize LTXAV-specific components."""
|
||||||
|
# Audio-specific projections
|
||||||
|
self.audio_patchify_proj = self.operations.Linear(
|
||||||
|
self.audio_in_channels, self.audio_inner_dim, bias=True, dtype=dtype, device=device
|
||||||
|
)
|
||||||
|
|
||||||
|
# Audio-specific AdaLN
|
||||||
|
self.audio_adaln_single = AdaLayerNormSingle(
|
||||||
|
self.audio_inner_dim,
|
||||||
|
use_additional_conditions=False,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
operations=self.operations,
|
||||||
|
)
|
||||||
|
|
||||||
|
num_scale_shift_values = 4
|
||||||
|
self.av_ca_video_scale_shift_adaln_single = AdaLayerNormSingle(
|
||||||
|
self.inner_dim,
|
||||||
|
use_additional_conditions=False,
|
||||||
|
embedding_coefficient=num_scale_shift_values,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
operations=self.operations,
|
||||||
|
)
|
||||||
|
self.av_ca_a2v_gate_adaln_single = AdaLayerNormSingle(
|
||||||
|
self.inner_dim,
|
||||||
|
use_additional_conditions=False,
|
||||||
|
embedding_coefficient=1,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
operations=self.operations,
|
||||||
|
)
|
||||||
|
self.av_ca_audio_scale_shift_adaln_single = AdaLayerNormSingle(
|
||||||
|
self.audio_inner_dim,
|
||||||
|
use_additional_conditions=False,
|
||||||
|
embedding_coefficient=num_scale_shift_values,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
operations=self.operations,
|
||||||
|
)
|
||||||
|
self.av_ca_v2a_gate_adaln_single = AdaLayerNormSingle(
|
||||||
|
self.audio_inner_dim,
|
||||||
|
use_additional_conditions=False,
|
||||||
|
embedding_coefficient=1,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
operations=self.operations,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Audio caption projection
|
||||||
|
self.audio_caption_projection = PixArtAlphaTextProjection(
|
||||||
|
in_features=self.caption_channels,
|
||||||
|
hidden_size=self.audio_inner_dim,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
operations=self.operations,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _init_transformer_blocks(self, device, dtype, **kwargs):
|
||||||
|
"""Initialize transformer blocks for LTXAV."""
|
||||||
|
self.transformer_blocks = nn.ModuleList(
|
||||||
|
[
|
||||||
|
BasicAVTransformerBlock(
|
||||||
|
v_dim=self.inner_dim,
|
||||||
|
a_dim=self.audio_inner_dim,
|
||||||
|
v_heads=self.num_attention_heads,
|
||||||
|
a_heads=self.audio_num_attention_heads,
|
||||||
|
vd_head=self.attention_head_dim,
|
||||||
|
ad_head=self.audio_attention_head_dim,
|
||||||
|
v_context_dim=self.cross_attention_dim,
|
||||||
|
a_context_dim=self.audio_cross_attention_dim,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
operations=self.operations,
|
||||||
|
)
|
||||||
|
for _ in range(self.num_layers)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
def _init_output_components(self, device, dtype):
|
||||||
|
"""Initialize output components for LTXAV."""
|
||||||
|
# Video output components
|
||||||
|
super()._init_output_components(device, dtype)
|
||||||
|
# Audio output components
|
||||||
|
self.audio_scale_shift_table = nn.Parameter(
|
||||||
|
torch.empty(2, self.audio_inner_dim, dtype=dtype, device=device)
|
||||||
|
)
|
||||||
|
self.audio_norm_out = self.operations.LayerNorm(
|
||||||
|
self.audio_inner_dim, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device
|
||||||
|
)
|
||||||
|
self.audio_proj_out = self.operations.Linear(
|
||||||
|
self.audio_inner_dim, self.audio_out_channels, dtype=dtype, device=device
|
||||||
|
)
|
||||||
|
self.a_patchifier = AudioPatchifier(1, start_end=True)
|
||||||
|
|
||||||
|
def separate_audio_and_video_latents(self, x, audio_length):
|
||||||
|
"""Separate audio and video latents from combined input."""
|
||||||
|
# vx = x[:, : self.in_channels]
|
||||||
|
# ax = x[:, self.in_channels :]
|
||||||
|
#
|
||||||
|
# ax = ax.reshape(ax.shape[0], -1)
|
||||||
|
# ax = ax[:, : audio_length * self.num_audio_channels * self.audio_frequency_bins]
|
||||||
|
#
|
||||||
|
# ax = ax.reshape(
|
||||||
|
# ax.shape[0], self.num_audio_channels, audio_length, self.audio_frequency_bins
|
||||||
|
# )
|
||||||
|
|
||||||
|
vx = x[0]
|
||||||
|
ax = x[1] if len(x) > 1 else torch.zeros(
|
||||||
|
(vx.shape[0], self.num_audio_channels, 0, self.audio_frequency_bins),
|
||||||
|
device=vx.device, dtype=vx.dtype
|
||||||
|
)
|
||||||
|
return vx, ax
|
||||||
|
|
||||||
|
def recombine_audio_and_video_latents(self, vx, ax, target_shape=None):
|
||||||
|
if ax.numel() == 0:
|
||||||
|
return vx
|
||||||
|
else:
|
||||||
|
return [vx, ax]
|
||||||
|
"""Recombine audio and video latents for output."""
|
||||||
|
# if ax.device != vx.device or ax.dtype != vx.dtype:
|
||||||
|
# logging.warning("Audio and video latents are on different devices or dtypes.")
|
||||||
|
# ax = ax.to(device=vx.device, dtype=vx.dtype)
|
||||||
|
# logging.warning(f"Audio audio latent moved to device: {ax.device}, dtype: {ax.dtype}")
|
||||||
|
#
|
||||||
|
# ax = ax.reshape(ax.shape[0], -1)
|
||||||
|
# # pad to f x h x w of the video latents
|
||||||
|
# divisor = vx.shape[-1] * vx.shape[-2] * vx.shape[-3]
|
||||||
|
# if target_shape is None:
|
||||||
|
# repetitions = math.ceil(ax.shape[-1] / divisor)
|
||||||
|
# else:
|
||||||
|
# repetitions = target_shape[1] - vx.shape[1]
|
||||||
|
# padded_len = repetitions * divisor
|
||||||
|
# ax = F.pad(ax, (0, padded_len - ax.shape[-1]))
|
||||||
|
# ax = ax.reshape(ax.shape[0], -1, vx.shape[-3], vx.shape[-2], vx.shape[-1])
|
||||||
|
# return torch.cat([vx, ax], dim=1)
|
||||||
|
|
||||||
|
def _process_input(self, x, keyframe_idxs, denoise_mask, **kwargs):
|
||||||
|
"""Process input for LTXAV - separate audio and video, then patchify."""
|
||||||
|
audio_length = kwargs.get("audio_length", 0)
|
||||||
|
# Separate audio and video latents
|
||||||
|
vx, ax = self.separate_audio_and_video_latents(x, audio_length)
|
||||||
|
[vx, v_pixel_coords, additional_args] = super()._process_input(
|
||||||
|
vx, keyframe_idxs, denoise_mask, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
ax, a_latent_coords = self.a_patchifier.patchify(ax)
|
||||||
|
ax = self.audio_patchify_proj(ax)
|
||||||
|
|
||||||
|
# additional_args.update({"av_orig_shape": list(x.shape)})
|
||||||
|
return [vx, ax], [v_pixel_coords, a_latent_coords], additional_args
|
||||||
|
|
||||||
|
def _prepare_timestep(self, timestep, batch_size, hidden_dtype, **kwargs):
|
||||||
|
"""Prepare timestep embeddings."""
|
||||||
|
# TODO: some code reuse is needed here.
|
||||||
|
grid_mask = kwargs.get("grid_mask", None)
|
||||||
|
if grid_mask is not None:
|
||||||
|
timestep = timestep[:, grid_mask]
|
||||||
|
|
||||||
|
timestep = timestep * self.timestep_scale_multiplier
|
||||||
|
v_timestep, v_embedded_timestep = self.adaln_single(
|
||||||
|
timestep.flatten(),
|
||||||
|
{"resolution": None, "aspect_ratio": None},
|
||||||
|
batch_size=batch_size,
|
||||||
|
hidden_dtype=hidden_dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Second dimension is 1 or number of tokens (if timestep_per_token)
|
||||||
|
v_timestep = v_timestep.view(batch_size, -1, v_timestep.shape[-1])
|
||||||
|
v_embedded_timestep = v_embedded_timestep.view(
|
||||||
|
batch_size, -1, v_embedded_timestep.shape[-1]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Prepare audio timestep
|
||||||
|
a_timestep = kwargs.get("a_timestep")
|
||||||
|
if a_timestep is not None:
|
||||||
|
a_timestep = a_timestep * self.timestep_scale_multiplier
|
||||||
|
av_ca_factor = self.av_ca_timestep_scale_multiplier / self.timestep_scale_multiplier
|
||||||
|
|
||||||
|
av_ca_audio_scale_shift_timestep, _ = self.av_ca_audio_scale_shift_adaln_single(
|
||||||
|
a_timestep.flatten(),
|
||||||
|
{"resolution": None, "aspect_ratio": None},
|
||||||
|
batch_size=batch_size,
|
||||||
|
hidden_dtype=hidden_dtype,
|
||||||
|
)
|
||||||
|
av_ca_video_scale_shift_timestep, _ = self.av_ca_video_scale_shift_adaln_single(
|
||||||
|
timestep.flatten(),
|
||||||
|
{"resolution": None, "aspect_ratio": None},
|
||||||
|
batch_size=batch_size,
|
||||||
|
hidden_dtype=hidden_dtype,
|
||||||
|
)
|
||||||
|
av_ca_a2v_gate_noise_timestep, _ = self.av_ca_a2v_gate_adaln_single(
|
||||||
|
timestep.flatten() * av_ca_factor,
|
||||||
|
{"resolution": None, "aspect_ratio": None},
|
||||||
|
batch_size=batch_size,
|
||||||
|
hidden_dtype=hidden_dtype,
|
||||||
|
)
|
||||||
|
av_ca_v2a_gate_noise_timestep, _ = self.av_ca_v2a_gate_adaln_single(
|
||||||
|
a_timestep.flatten() * av_ca_factor,
|
||||||
|
{"resolution": None, "aspect_ratio": None},
|
||||||
|
batch_size=batch_size,
|
||||||
|
hidden_dtype=hidden_dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
a_timestep, a_embedded_timestep = self.audio_adaln_single(
|
||||||
|
a_timestep.flatten(),
|
||||||
|
{"resolution": None, "aspect_ratio": None},
|
||||||
|
batch_size=batch_size,
|
||||||
|
hidden_dtype=hidden_dtype,
|
||||||
|
)
|
||||||
|
a_timestep = a_timestep.view(batch_size, -1, a_timestep.shape[-1])
|
||||||
|
a_embedded_timestep = a_embedded_timestep.view(
|
||||||
|
batch_size, -1, a_embedded_timestep.shape[-1]
|
||||||
|
)
|
||||||
|
cross_av_timestep_ss = [
|
||||||
|
av_ca_audio_scale_shift_timestep,
|
||||||
|
av_ca_video_scale_shift_timestep,
|
||||||
|
av_ca_a2v_gate_noise_timestep,
|
||||||
|
av_ca_v2a_gate_noise_timestep,
|
||||||
|
]
|
||||||
|
cross_av_timestep_ss = list(
|
||||||
|
[t.view(batch_size, -1, t.shape[-1]) for t in cross_av_timestep_ss]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
a_timestep = timestep
|
||||||
|
a_embedded_timestep = kwargs.get("embedded_timestep")
|
||||||
|
cross_av_timestep_ss = []
|
||||||
|
|
||||||
|
return [v_timestep, a_timestep, cross_av_timestep_ss], [
|
||||||
|
v_embedded_timestep,
|
||||||
|
a_embedded_timestep,
|
||||||
|
]
|
||||||
|
|
||||||
|
def _prepare_context(self, context, batch_size, x, attention_mask=None):
|
||||||
|
vx = x[0]
|
||||||
|
ax = x[1]
|
||||||
|
v_context, a_context = torch.split(
|
||||||
|
context, int(context.shape[-1] / 2), len(context.shape) - 1
|
||||||
|
)
|
||||||
|
|
||||||
|
v_context, attention_mask = super()._prepare_context(
|
||||||
|
v_context, batch_size, vx, attention_mask
|
||||||
|
)
|
||||||
|
if self.audio_caption_projection is not None:
|
||||||
|
a_context = self.audio_caption_projection(a_context)
|
||||||
|
a_context = a_context.view(batch_size, -1, ax.shape[-1])
|
||||||
|
|
||||||
|
return [v_context, a_context], attention_mask
|
||||||
|
|
||||||
|
def _prepare_positional_embeddings(self, pixel_coords, frame_rate, x_dtype):
|
||||||
|
v_pixel_coords = pixel_coords[0]
|
||||||
|
v_pe = super()._prepare_positional_embeddings(v_pixel_coords, frame_rate, x_dtype)
|
||||||
|
|
||||||
|
a_latent_coords = pixel_coords[1]
|
||||||
|
a_pe = self._precompute_freqs_cis(
|
||||||
|
a_latent_coords,
|
||||||
|
dim=self.audio_inner_dim,
|
||||||
|
out_dtype=x_dtype,
|
||||||
|
max_pos=self.audio_positional_embedding_max_pos,
|
||||||
|
use_middle_indices_grid=self.use_middle_indices_grid,
|
||||||
|
num_attention_heads=self.audio_num_attention_heads,
|
||||||
|
)
|
||||||
|
|
||||||
|
# calculate positional embeddings for the middle of the token duration, to use in av cross attention layers.
|
||||||
|
max_pos = max(
|
||||||
|
self.positional_embedding_max_pos[0], self.audio_positional_embedding_max_pos[0]
|
||||||
|
)
|
||||||
|
v_pixel_coords = v_pixel_coords.to(torch.float32)
|
||||||
|
v_pixel_coords[:, 0] = v_pixel_coords[:, 0] * (1.0 / frame_rate)
|
||||||
|
av_cross_video_freq_cis = self._precompute_freqs_cis(
|
||||||
|
v_pixel_coords[:, 0:1, :],
|
||||||
|
dim=self.audio_cross_attention_dim,
|
||||||
|
out_dtype=x_dtype,
|
||||||
|
max_pos=[max_pos],
|
||||||
|
use_middle_indices_grid=True,
|
||||||
|
num_attention_heads=self.audio_num_attention_heads,
|
||||||
|
)
|
||||||
|
av_cross_audio_freq_cis = self._precompute_freqs_cis(
|
||||||
|
a_latent_coords[:, 0:1, :],
|
||||||
|
dim=self.audio_cross_attention_dim,
|
||||||
|
out_dtype=x_dtype,
|
||||||
|
max_pos=[max_pos],
|
||||||
|
use_middle_indices_grid=True,
|
||||||
|
num_attention_heads=self.audio_num_attention_heads,
|
||||||
|
)
|
||||||
|
|
||||||
|
return [(v_pe, av_cross_video_freq_cis), (a_pe, av_cross_audio_freq_cis)]
|
||||||
|
|
||||||
|
def _process_transformer_blocks(
|
||||||
|
self, x, context, attention_mask, timestep, pe, transformer_options={}, **kwargs
|
||||||
|
):
|
||||||
|
vx = x[0]
|
||||||
|
ax = x[1]
|
||||||
|
v_context = context[0]
|
||||||
|
a_context = context[1]
|
||||||
|
v_timestep = timestep[0]
|
||||||
|
a_timestep = timestep[1]
|
||||||
|
v_pe, av_cross_video_freq_cis = pe[0]
|
||||||
|
a_pe, av_cross_audio_freq_cis = pe[1]
|
||||||
|
|
||||||
|
(
|
||||||
|
av_ca_audio_scale_shift_timestep,
|
||||||
|
av_ca_video_scale_shift_timestep,
|
||||||
|
av_ca_a2v_gate_noise_timestep,
|
||||||
|
av_ca_v2a_gate_noise_timestep,
|
||||||
|
) = timestep[2]
|
||||||
|
|
||||||
|
"""Process transformer blocks for LTXAV."""
|
||||||
|
patches_replace = transformer_options.get("patches_replace", {})
|
||||||
|
blocks_replace = patches_replace.get("dit", {})
|
||||||
|
|
||||||
|
# Process transformer blocks
|
||||||
|
for i, block in enumerate(self.transformer_blocks):
|
||||||
|
if ("double_block", i) in blocks_replace:
|
||||||
|
|
||||||
|
def block_wrap(args):
|
||||||
|
out = {}
|
||||||
|
out["img"] = block(
|
||||||
|
args["img"],
|
||||||
|
v_context=args["v_context"],
|
||||||
|
a_context=args["a_context"],
|
||||||
|
attention_mask=args["attention_mask"],
|
||||||
|
v_timestep=args["v_timestep"],
|
||||||
|
a_timestep=args["a_timestep"],
|
||||||
|
v_pe=args["v_pe"],
|
||||||
|
a_pe=args["a_pe"],
|
||||||
|
v_cross_pe=args["v_cross_pe"],
|
||||||
|
a_cross_pe=args["a_cross_pe"],
|
||||||
|
v_cross_scale_shift_timestep=args["v_cross_scale_shift_timestep"],
|
||||||
|
a_cross_scale_shift_timestep=args["a_cross_scale_shift_timestep"],
|
||||||
|
v_cross_gate_timestep=args["v_cross_gate_timestep"],
|
||||||
|
a_cross_gate_timestep=args["a_cross_gate_timestep"],
|
||||||
|
transformer_options=args["transformer_options"],
|
||||||
|
)
|
||||||
|
return out
|
||||||
|
|
||||||
|
out = blocks_replace[("double_block", i)](
|
||||||
|
{
|
||||||
|
"img": (vx, ax),
|
||||||
|
"v_context": v_context,
|
||||||
|
"a_context": a_context,
|
||||||
|
"attention_mask": attention_mask,
|
||||||
|
"v_timestep": v_timestep,
|
||||||
|
"a_timestep": a_timestep,
|
||||||
|
"v_pe": v_pe,
|
||||||
|
"a_pe": a_pe,
|
||||||
|
"v_cross_pe": av_cross_video_freq_cis,
|
||||||
|
"a_cross_pe": av_cross_audio_freq_cis,
|
||||||
|
"v_cross_scale_shift_timestep": av_ca_video_scale_shift_timestep,
|
||||||
|
"a_cross_scale_shift_timestep": av_ca_audio_scale_shift_timestep,
|
||||||
|
"v_cross_gate_timestep": av_ca_a2v_gate_noise_timestep,
|
||||||
|
"a_cross_gate_timestep": av_ca_v2a_gate_noise_timestep,
|
||||||
|
"transformer_options": transformer_options,
|
||||||
|
},
|
||||||
|
{"original_block": block_wrap},
|
||||||
|
)
|
||||||
|
vx, ax = out["img"]
|
||||||
|
else:
|
||||||
|
vx, ax = block(
|
||||||
|
(vx, ax),
|
||||||
|
v_context=v_context,
|
||||||
|
a_context=a_context,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
v_timestep=v_timestep,
|
||||||
|
a_timestep=a_timestep,
|
||||||
|
v_pe=v_pe,
|
||||||
|
a_pe=a_pe,
|
||||||
|
v_cross_pe=av_cross_video_freq_cis,
|
||||||
|
a_cross_pe=av_cross_audio_freq_cis,
|
||||||
|
v_cross_scale_shift_timestep=av_ca_video_scale_shift_timestep,
|
||||||
|
a_cross_scale_shift_timestep=av_ca_audio_scale_shift_timestep,
|
||||||
|
v_cross_gate_timestep=av_ca_a2v_gate_noise_timestep,
|
||||||
|
a_cross_gate_timestep=av_ca_v2a_gate_noise_timestep,
|
||||||
|
transformer_options=transformer_options,
|
||||||
|
)
|
||||||
|
|
||||||
|
return [vx, ax]
|
||||||
|
|
||||||
|
def _process_output(self, x, embedded_timestep, keyframe_idxs, **kwargs):
|
||||||
|
vx = x[0]
|
||||||
|
ax = x[1]
|
||||||
|
v_embedded_timestep = embedded_timestep[0]
|
||||||
|
a_embedded_timestep = embedded_timestep[1]
|
||||||
|
vx = super()._process_output(vx, v_embedded_timestep, keyframe_idxs, **kwargs)
|
||||||
|
|
||||||
|
# Process audio output
|
||||||
|
a_scale_shift_values = (
|
||||||
|
self.audio_scale_shift_table[None, None].to(device=a_embedded_timestep.device, dtype=a_embedded_timestep.dtype)
|
||||||
|
+ a_embedded_timestep[:, :, None]
|
||||||
|
)
|
||||||
|
a_shift, a_scale = a_scale_shift_values[:, :, 0], a_scale_shift_values[:, :, 1]
|
||||||
|
|
||||||
|
ax = self.audio_norm_out(ax)
|
||||||
|
ax = ax * (1 + a_scale) + a_shift
|
||||||
|
ax = self.audio_proj_out(ax)
|
||||||
|
|
||||||
|
# Unpatchify audio
|
||||||
|
ax = self.a_patchifier.unpatchify(
|
||||||
|
ax, channels=self.num_audio_channels, freq=self.audio_frequency_bins
|
||||||
|
)
|
||||||
|
|
||||||
|
# Recombine audio and video
|
||||||
|
original_shape = kwargs.get("av_orig_shape")
|
||||||
|
return self.recombine_audio_and_video_latents(vx, ax, original_shape)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x,
|
||||||
|
timestep,
|
||||||
|
context,
|
||||||
|
attention_mask=None,
|
||||||
|
frame_rate=25,
|
||||||
|
transformer_options={},
|
||||||
|
keyframe_idxs=None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Forward pass for LTXAV model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Combined audio-video input tensor
|
||||||
|
timestep: Tuple of (video_timestep, audio_timestep) or single timestep
|
||||||
|
context: Context tensor (e.g., text embeddings)
|
||||||
|
attention_mask: Attention mask tensor
|
||||||
|
frame_rate: Frame rate for temporal processing
|
||||||
|
transformer_options: Additional options for transformer blocks
|
||||||
|
keyframe_idxs: Keyframe indices for temporal processing
|
||||||
|
**kwargs: Additional keyword arguments including audio_length
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Combined audio-video output tensor
|
||||||
|
"""
|
||||||
|
# Handle timestep format
|
||||||
|
if isinstance(timestep, (tuple, list)) and len(timestep) == 2:
|
||||||
|
v_timestep, a_timestep = timestep
|
||||||
|
kwargs["a_timestep"] = a_timestep
|
||||||
|
timestep = v_timestep
|
||||||
|
else:
|
||||||
|
kwargs["a_timestep"] = timestep
|
||||||
|
|
||||||
|
# Call parent forward method
|
||||||
|
return super().forward(
|
||||||
|
x,
|
||||||
|
timestep,
|
||||||
|
context,
|
||||||
|
attention_mask,
|
||||||
|
frame_rate,
|
||||||
|
transformer_options,
|
||||||
|
keyframe_idxs,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
305
comfy/ldm/lightricks/embeddings_connector.py
Normal file
305
comfy/ldm/lightricks/embeddings_connector.py
Normal file
@ -0,0 +1,305 @@
|
|||||||
|
import math
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import comfy.ldm.common_dit
|
||||||
|
import torch
|
||||||
|
from comfy.ldm.lightricks.model import (
|
||||||
|
CrossAttention,
|
||||||
|
FeedForward,
|
||||||
|
generate_freq_grid_np,
|
||||||
|
interleaved_freqs_cis,
|
||||||
|
split_freqs_cis,
|
||||||
|
)
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
|
||||||
|
class BasicTransformerBlock1D(nn.Module):
|
||||||
|
r"""
|
||||||
|
A basic Transformer block.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
|
||||||
|
dim (`int`): The number of channels in the input and output.
|
||||||
|
num_attention_heads (`int`): The number of heads to use for multi-head attention.
|
||||||
|
attention_head_dim (`int`): The number of channels in each head.
|
||||||
|
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
||||||
|
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
||||||
|
attention_bias (:
|
||||||
|
obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
|
||||||
|
upcast_attention (`bool`, *optional*):
|
||||||
|
Whether to upcast the attention computation to float32. This is useful for mixed precision training.
|
||||||
|
norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether to use learnable elementwise affine parameters for normalization.
|
||||||
|
standardization_norm (`str`, *optional*, defaults to `"layer_norm"`): The type of pre-normalization to use. Can be `"layer_norm"` or `"rms_norm"`.
|
||||||
|
norm_eps (`float`, *optional*, defaults to 1e-5): Epsilon value for normalization layers.
|
||||||
|
qk_norm (`str`, *optional*, defaults to None):
|
||||||
|
Set to 'layer_norm' or `rms_norm` to perform query and key normalization.
|
||||||
|
final_dropout (`bool` *optional*, defaults to False):
|
||||||
|
Whether to apply a final dropout after the last feed-forward layer.
|
||||||
|
ff_inner_dim (`int`, *optional*): Dimension of the inner feed-forward layer. If not provided, defaults to `dim * 4`.
|
||||||
|
ff_bias (`bool`, *optional*, defaults to `True`): Whether to use bias in the feed-forward layer.
|
||||||
|
attention_out_bias (`bool`, *optional*, defaults to `True`): Whether to use bias in the attention output layer.
|
||||||
|
use_rope (`bool`, *optional*, defaults to `False`): Whether to use Rotary Position Embeddings (RoPE).
|
||||||
|
ffn_dim_mult (`int`, *optional*, defaults to 4): Multiplier for the inner dimension of the feed-forward layer.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim,
|
||||||
|
n_heads,
|
||||||
|
d_head,
|
||||||
|
context_dim=None,
|
||||||
|
attn_precision=None,
|
||||||
|
dtype=None,
|
||||||
|
device=None,
|
||||||
|
operations=None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
# Define 3 blocks. Each block has its own normalization layer.
|
||||||
|
# 1. Self-Attn
|
||||||
|
self.attn1 = CrossAttention(
|
||||||
|
query_dim=dim,
|
||||||
|
heads=n_heads,
|
||||||
|
dim_head=d_head,
|
||||||
|
context_dim=None,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
operations=operations,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 3. Feed-forward
|
||||||
|
self.ff = FeedForward(
|
||||||
|
dim,
|
||||||
|
dim_out=dim,
|
||||||
|
glu=True,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
operations=operations,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, hidden_states, attention_mask=None, pe=None) -> torch.FloatTensor:
|
||||||
|
|
||||||
|
# Notice that normalization is always applied before the real computation in the following blocks.
|
||||||
|
|
||||||
|
# 1. Normalization Before Self-Attention
|
||||||
|
norm_hidden_states = comfy.ldm.common_dit.rms_norm(hidden_states)
|
||||||
|
|
||||||
|
norm_hidden_states = norm_hidden_states.squeeze(1)
|
||||||
|
|
||||||
|
# 2. Self-Attention
|
||||||
|
attn_output = self.attn1(norm_hidden_states, mask=attention_mask, pe=pe)
|
||||||
|
|
||||||
|
hidden_states = attn_output + hidden_states
|
||||||
|
if hidden_states.ndim == 4:
|
||||||
|
hidden_states = hidden_states.squeeze(1)
|
||||||
|
|
||||||
|
# 3. Normalization before Feed-Forward
|
||||||
|
norm_hidden_states = comfy.ldm.common_dit.rms_norm(hidden_states)
|
||||||
|
|
||||||
|
# 4. Feed-forward
|
||||||
|
ff_output = self.ff(norm_hidden_states)
|
||||||
|
|
||||||
|
hidden_states = ff_output + hidden_states
|
||||||
|
if hidden_states.ndim == 4:
|
||||||
|
hidden_states = hidden_states.squeeze(1)
|
||||||
|
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class Embeddings1DConnector(nn.Module):
|
||||||
|
_supports_gradient_checkpointing = True
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels=128,
|
||||||
|
cross_attention_dim=2048,
|
||||||
|
attention_head_dim=128,
|
||||||
|
num_attention_heads=30,
|
||||||
|
num_layers=2,
|
||||||
|
positional_embedding_theta=10000.0,
|
||||||
|
positional_embedding_max_pos=[4096],
|
||||||
|
causal_temporal_positioning=False,
|
||||||
|
num_learnable_registers: Optional[int] = 128,
|
||||||
|
dtype=None,
|
||||||
|
device=None,
|
||||||
|
operations=None,
|
||||||
|
split_rope=False,
|
||||||
|
double_precision_rope=False,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.dtype = dtype
|
||||||
|
self.out_channels = in_channels
|
||||||
|
self.num_attention_heads = num_attention_heads
|
||||||
|
self.inner_dim = num_attention_heads * attention_head_dim
|
||||||
|
self.causal_temporal_positioning = causal_temporal_positioning
|
||||||
|
self.positional_embedding_theta = positional_embedding_theta
|
||||||
|
self.positional_embedding_max_pos = positional_embedding_max_pos
|
||||||
|
self.split_rope = split_rope
|
||||||
|
self.double_precision_rope = double_precision_rope
|
||||||
|
self.transformer_1d_blocks = nn.ModuleList(
|
||||||
|
[
|
||||||
|
BasicTransformerBlock1D(
|
||||||
|
self.inner_dim,
|
||||||
|
num_attention_heads,
|
||||||
|
attention_head_dim,
|
||||||
|
context_dim=cross_attention_dim,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
operations=operations,
|
||||||
|
)
|
||||||
|
for _ in range(num_layers)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
inner_dim = num_attention_heads * attention_head_dim
|
||||||
|
self.num_learnable_registers = num_learnable_registers
|
||||||
|
if self.num_learnable_registers:
|
||||||
|
self.learnable_registers = nn.Parameter(
|
||||||
|
torch.rand(
|
||||||
|
self.num_learnable_registers, inner_dim, dtype=dtype, device=device
|
||||||
|
)
|
||||||
|
* 2.0
|
||||||
|
- 1.0
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_fractional_positions(self, indices_grid):
|
||||||
|
fractional_positions = torch.stack(
|
||||||
|
[
|
||||||
|
indices_grid[:, i] / self.positional_embedding_max_pos[i]
|
||||||
|
for i in range(1)
|
||||||
|
],
|
||||||
|
dim=-1,
|
||||||
|
)
|
||||||
|
return fractional_positions
|
||||||
|
|
||||||
|
def precompute_freqs(self, indices_grid, spacing):
|
||||||
|
source_dtype = indices_grid.dtype
|
||||||
|
dtype = (
|
||||||
|
torch.float32
|
||||||
|
if source_dtype in (torch.bfloat16, torch.float16)
|
||||||
|
else source_dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
fractional_positions = self.get_fractional_positions(indices_grid)
|
||||||
|
indices = (
|
||||||
|
generate_freq_grid_np(
|
||||||
|
self.positional_embedding_theta,
|
||||||
|
indices_grid.shape[1],
|
||||||
|
self.inner_dim,
|
||||||
|
)
|
||||||
|
if self.double_precision_rope
|
||||||
|
else self.generate_freq_grid(spacing, dtype, fractional_positions.device)
|
||||||
|
).to(device=fractional_positions.device)
|
||||||
|
|
||||||
|
if spacing == "exp_2":
|
||||||
|
freqs = (
|
||||||
|
(indices * fractional_positions.unsqueeze(-1))
|
||||||
|
.transpose(-1, -2)
|
||||||
|
.flatten(2)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
freqs = (
|
||||||
|
(indices * (fractional_positions.unsqueeze(-1) * 2 - 1))
|
||||||
|
.transpose(-1, -2)
|
||||||
|
.flatten(2)
|
||||||
|
)
|
||||||
|
return freqs
|
||||||
|
|
||||||
|
def generate_freq_grid(self, spacing, dtype, device):
|
||||||
|
dim = self.inner_dim
|
||||||
|
theta = self.positional_embedding_theta
|
||||||
|
n_pos_dims = 1
|
||||||
|
n_elem = 2 * n_pos_dims # 2 for cos and sin e.g. x 3 = 6
|
||||||
|
start = 1
|
||||||
|
end = theta
|
||||||
|
|
||||||
|
if spacing == "exp":
|
||||||
|
indices = theta ** (torch.arange(0, dim, n_elem, device="cpu", dtype=torch.float32) / (dim - n_elem))
|
||||||
|
indices = indices.to(dtype=dtype, device=device)
|
||||||
|
elif spacing == "exp_2":
|
||||||
|
indices = 1.0 / theta ** (torch.arange(0, dim, n_elem, device=device) / dim)
|
||||||
|
indices = indices.to(dtype=dtype)
|
||||||
|
elif spacing == "linear":
|
||||||
|
indices = torch.linspace(
|
||||||
|
start, end, dim // n_elem, device=device, dtype=dtype
|
||||||
|
)
|
||||||
|
elif spacing == "sqrt":
|
||||||
|
indices = torch.linspace(
|
||||||
|
start**2, end**2, dim // n_elem, device=device, dtype=dtype
|
||||||
|
).sqrt()
|
||||||
|
|
||||||
|
indices = indices * math.pi / 2
|
||||||
|
|
||||||
|
return indices
|
||||||
|
|
||||||
|
def precompute_freqs_cis(self, indices_grid, spacing="exp"):
|
||||||
|
dim = self.inner_dim
|
||||||
|
n_elem = 2 # 2 because of cos and sin
|
||||||
|
freqs = self.precompute_freqs(indices_grid, spacing)
|
||||||
|
if self.split_rope:
|
||||||
|
expected_freqs = dim // 2
|
||||||
|
current_freqs = freqs.shape[-1]
|
||||||
|
pad_size = expected_freqs - current_freqs
|
||||||
|
cos_freq, sin_freq = split_freqs_cis(
|
||||||
|
freqs, pad_size, self.num_attention_heads
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
cos_freq, sin_freq = interleaved_freqs_cis(freqs, dim % n_elem)
|
||||||
|
return cos_freq.to(self.dtype), sin_freq.to(self.dtype), self.split_rope
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
The [`Transformer2DModel`] forward method.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
|
||||||
|
Input `hidden_states`.
|
||||||
|
indices_grid (`torch.LongTensor` of shape `(batch size, 3, num latent pixels)`):
|
||||||
|
attention_mask ( `torch.Tensor`, *optional*):
|
||||||
|
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
|
||||||
|
is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
|
||||||
|
negative values to the attention scores corresponding to "discard" tokens.
|
||||||
|
Returns:
|
||||||
|
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
|
||||||
|
`tuple` where the first element is the sample tensor.
|
||||||
|
"""
|
||||||
|
# 1. Input
|
||||||
|
|
||||||
|
if self.num_learnable_registers:
|
||||||
|
num_registers_duplications = math.ceil(
|
||||||
|
max(1024, hidden_states.shape[1]) / self.num_learnable_registers
|
||||||
|
)
|
||||||
|
learnable_registers = torch.tile(
|
||||||
|
self.learnable_registers, (num_registers_duplications, 1)
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = torch.cat((hidden_states, learnable_registers[hidden_states.shape[1]:].unsqueeze(0).repeat(hidden_states.shape[0], 1, 1)), dim=1)
|
||||||
|
|
||||||
|
if attention_mask is not None:
|
||||||
|
attention_mask = torch.zeros([1, 1, 1, hidden_states.shape[1]], dtype=attention_mask.dtype, device=attention_mask.device)
|
||||||
|
|
||||||
|
indices_grid = torch.arange(
|
||||||
|
hidden_states.shape[1], dtype=torch.float32, device=hidden_states.device
|
||||||
|
)
|
||||||
|
indices_grid = indices_grid[None, None, :]
|
||||||
|
freqs_cis = self.precompute_freqs_cis(indices_grid)
|
||||||
|
|
||||||
|
# 2. Blocks
|
||||||
|
for block_idx, block in enumerate(self.transformer_1d_blocks):
|
||||||
|
hidden_states = block(
|
||||||
|
hidden_states, attention_mask=attention_mask, pe=freqs_cis
|
||||||
|
)
|
||||||
|
|
||||||
|
# 3. Output
|
||||||
|
# if self.output_scale is not None:
|
||||||
|
# hidden_states = hidden_states / self.output_scale
|
||||||
|
|
||||||
|
hidden_states = comfy.ldm.common_dit.rms_norm(hidden_states)
|
||||||
|
|
||||||
|
return hidden_states, attention_mask
|
||||||
292
comfy/ldm/lightricks/latent_upsampler.py
Normal file
292
comfy/ldm/lightricks/latent_upsampler.py
Normal file
@ -0,0 +1,292 @@
|
|||||||
|
from typing import Optional, Tuple
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from einops import rearrange
|
||||||
|
|
||||||
|
|
||||||
|
def _rational_for_scale(scale: float) -> Tuple[int, int]:
|
||||||
|
mapping = {0.75: (3, 4), 1.5: (3, 2), 2.0: (2, 1), 4.0: (4, 1)}
|
||||||
|
if float(scale) not in mapping:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unsupported spatial_scale {scale}. Choose from {list(mapping.keys())}"
|
||||||
|
)
|
||||||
|
return mapping[float(scale)]
|
||||||
|
|
||||||
|
|
||||||
|
class PixelShuffleND(nn.Module):
|
||||||
|
def __init__(self, dims, upscale_factors=(2, 2, 2)):
|
||||||
|
super().__init__()
|
||||||
|
assert dims in [1, 2, 3], "dims must be 1, 2, or 3"
|
||||||
|
self.dims = dims
|
||||||
|
self.upscale_factors = upscale_factors
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
if self.dims == 3:
|
||||||
|
return rearrange(
|
||||||
|
x,
|
||||||
|
"b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)",
|
||||||
|
p1=self.upscale_factors[0],
|
||||||
|
p2=self.upscale_factors[1],
|
||||||
|
p3=self.upscale_factors[2],
|
||||||
|
)
|
||||||
|
elif self.dims == 2:
|
||||||
|
return rearrange(
|
||||||
|
x,
|
||||||
|
"b (c p1 p2) h w -> b c (h p1) (w p2)",
|
||||||
|
p1=self.upscale_factors[0],
|
||||||
|
p2=self.upscale_factors[1],
|
||||||
|
)
|
||||||
|
elif self.dims == 1:
|
||||||
|
return rearrange(
|
||||||
|
x,
|
||||||
|
"b (c p1) f h w -> b c (f p1) h w",
|
||||||
|
p1=self.upscale_factors[0],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class BlurDownsample(nn.Module):
|
||||||
|
"""
|
||||||
|
Anti-aliased spatial downsampling by integer stride using a fixed separable binomial kernel.
|
||||||
|
Applies only on H,W. Works for dims=2 or dims=3 (per-frame).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, dims: int, stride: int):
|
||||||
|
super().__init__()
|
||||||
|
assert dims in (2, 3)
|
||||||
|
assert stride >= 1 and isinstance(stride, int)
|
||||||
|
self.dims = dims
|
||||||
|
self.stride = stride
|
||||||
|
|
||||||
|
# 5x5 separable binomial kernel [1,4,6,4,1] (outer product), normalized
|
||||||
|
k = torch.tensor([1.0, 4.0, 6.0, 4.0, 1.0])
|
||||||
|
k2d = k[:, None] @ k[None, :]
|
||||||
|
k2d = (k2d / k2d.sum()).float() # shape (5,5)
|
||||||
|
self.register_buffer("kernel", k2d[None, None, :, :]) # (1,1,5,5)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
if self.stride == 1:
|
||||||
|
return x
|
||||||
|
|
||||||
|
def _apply_2d(x2d: torch.Tensor) -> torch.Tensor:
|
||||||
|
# x2d: (B, C, H, W)
|
||||||
|
B, C, H, W = x2d.shape
|
||||||
|
weight = self.kernel.expand(C, 1, 5, 5) # depthwise
|
||||||
|
x2d = F.conv2d(
|
||||||
|
x2d, weight=weight, bias=None, stride=self.stride, padding=2, groups=C
|
||||||
|
)
|
||||||
|
return x2d
|
||||||
|
|
||||||
|
if self.dims == 2:
|
||||||
|
return _apply_2d(x)
|
||||||
|
else:
|
||||||
|
# dims == 3: apply per-frame on H,W
|
||||||
|
b, c, f, h, w = x.shape
|
||||||
|
x = rearrange(x, "b c f h w -> (b f) c h w")
|
||||||
|
x = _apply_2d(x)
|
||||||
|
h2, w2 = x.shape[-2:]
|
||||||
|
x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f, h=h2, w=w2)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class SpatialRationalResampler(nn.Module):
|
||||||
|
"""
|
||||||
|
Fully-learned rational spatial scaling: up by 'num' via PixelShuffle, then anti-aliased
|
||||||
|
downsample by 'den' using fixed blur + stride. Operates on H,W only.
|
||||||
|
|
||||||
|
For dims==3, work per-frame for spatial scaling (temporal axis untouched).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, mid_channels: int, scale: float):
|
||||||
|
super().__init__()
|
||||||
|
self.scale = float(scale)
|
||||||
|
self.num, self.den = _rational_for_scale(self.scale)
|
||||||
|
self.conv = nn.Conv2d(
|
||||||
|
mid_channels, (self.num**2) * mid_channels, kernel_size=3, padding=1
|
||||||
|
)
|
||||||
|
self.pixel_shuffle = PixelShuffleND(2, upscale_factors=(self.num, self.num))
|
||||||
|
self.blur_down = BlurDownsample(dims=2, stride=self.den)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
b, c, f, h, w = x.shape
|
||||||
|
x = rearrange(x, "b c f h w -> (b f) c h w")
|
||||||
|
x = self.conv(x)
|
||||||
|
x = self.pixel_shuffle(x)
|
||||||
|
x = self.blur_down(x)
|
||||||
|
x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class ResBlock(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self, channels: int, mid_channels: Optional[int] = None, dims: int = 3
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
if mid_channels is None:
|
||||||
|
mid_channels = channels
|
||||||
|
|
||||||
|
Conv = nn.Conv2d if dims == 2 else nn.Conv3d
|
||||||
|
|
||||||
|
self.conv1 = Conv(channels, mid_channels, kernel_size=3, padding=1)
|
||||||
|
self.norm1 = nn.GroupNorm(32, mid_channels)
|
||||||
|
self.conv2 = Conv(mid_channels, channels, kernel_size=3, padding=1)
|
||||||
|
self.norm2 = nn.GroupNorm(32, channels)
|
||||||
|
self.activation = nn.SiLU()
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
residual = x
|
||||||
|
x = self.conv1(x)
|
||||||
|
x = self.norm1(x)
|
||||||
|
x = self.activation(x)
|
||||||
|
x = self.conv2(x)
|
||||||
|
x = self.norm2(x)
|
||||||
|
x = self.activation(x + residual)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class LatentUpsampler(nn.Module):
|
||||||
|
"""
|
||||||
|
Model to spatially upsample VAE latents.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
in_channels (`int`): Number of channels in the input latent
|
||||||
|
mid_channels (`int`): Number of channels in the middle layers
|
||||||
|
num_blocks_per_stage (`int`): Number of ResBlocks to use in each stage (pre/post upsampling)
|
||||||
|
dims (`int`): Number of dimensions for convolutions (2 or 3)
|
||||||
|
spatial_upsample (`bool`): Whether to spatially upsample the latent
|
||||||
|
temporal_upsample (`bool`): Whether to temporally upsample the latent
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int = 128,
|
||||||
|
mid_channels: int = 512,
|
||||||
|
num_blocks_per_stage: int = 4,
|
||||||
|
dims: int = 3,
|
||||||
|
spatial_upsample: bool = True,
|
||||||
|
temporal_upsample: bool = False,
|
||||||
|
spatial_scale: float = 2.0,
|
||||||
|
rational_resampler: bool = False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.mid_channels = mid_channels
|
||||||
|
self.num_blocks_per_stage = num_blocks_per_stage
|
||||||
|
self.dims = dims
|
||||||
|
self.spatial_upsample = spatial_upsample
|
||||||
|
self.temporal_upsample = temporal_upsample
|
||||||
|
self.spatial_scale = float(spatial_scale)
|
||||||
|
self.rational_resampler = rational_resampler
|
||||||
|
|
||||||
|
Conv = nn.Conv2d if dims == 2 else nn.Conv3d
|
||||||
|
|
||||||
|
self.initial_conv = Conv(in_channels, mid_channels, kernel_size=3, padding=1)
|
||||||
|
self.initial_norm = nn.GroupNorm(32, mid_channels)
|
||||||
|
self.initial_activation = nn.SiLU()
|
||||||
|
|
||||||
|
self.res_blocks = nn.ModuleList(
|
||||||
|
[ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)]
|
||||||
|
)
|
||||||
|
|
||||||
|
if spatial_upsample and temporal_upsample:
|
||||||
|
self.upsampler = nn.Sequential(
|
||||||
|
nn.Conv3d(mid_channels, 8 * mid_channels, kernel_size=3, padding=1),
|
||||||
|
PixelShuffleND(3),
|
||||||
|
)
|
||||||
|
elif spatial_upsample:
|
||||||
|
if rational_resampler:
|
||||||
|
self.upsampler = SpatialRationalResampler(
|
||||||
|
mid_channels=mid_channels, scale=self.spatial_scale
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.upsampler = nn.Sequential(
|
||||||
|
nn.Conv2d(mid_channels, 4 * mid_channels, kernel_size=3, padding=1),
|
||||||
|
PixelShuffleND(2),
|
||||||
|
)
|
||||||
|
elif temporal_upsample:
|
||||||
|
self.upsampler = nn.Sequential(
|
||||||
|
nn.Conv3d(mid_channels, 2 * mid_channels, kernel_size=3, padding=1),
|
||||||
|
PixelShuffleND(1),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"Either spatial_upsample or temporal_upsample must be True"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.post_upsample_res_blocks = nn.ModuleList(
|
||||||
|
[ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)]
|
||||||
|
)
|
||||||
|
|
||||||
|
self.final_conv = Conv(mid_channels, in_channels, kernel_size=3, padding=1)
|
||||||
|
|
||||||
|
def forward(self, latent: torch.Tensor) -> torch.Tensor:
|
||||||
|
b, c, f, h, w = latent.shape
|
||||||
|
|
||||||
|
if self.dims == 2:
|
||||||
|
x = rearrange(latent, "b c f h w -> (b f) c h w")
|
||||||
|
x = self.initial_conv(x)
|
||||||
|
x = self.initial_norm(x)
|
||||||
|
x = self.initial_activation(x)
|
||||||
|
|
||||||
|
for block in self.res_blocks:
|
||||||
|
x = block(x)
|
||||||
|
|
||||||
|
x = self.upsampler(x)
|
||||||
|
|
||||||
|
for block in self.post_upsample_res_blocks:
|
||||||
|
x = block(x)
|
||||||
|
|
||||||
|
x = self.final_conv(x)
|
||||||
|
x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f)
|
||||||
|
else:
|
||||||
|
x = self.initial_conv(latent)
|
||||||
|
x = self.initial_norm(x)
|
||||||
|
x = self.initial_activation(x)
|
||||||
|
|
||||||
|
for block in self.res_blocks:
|
||||||
|
x = block(x)
|
||||||
|
|
||||||
|
if self.temporal_upsample:
|
||||||
|
x = self.upsampler(x)
|
||||||
|
x = x[:, :, 1:, :, :]
|
||||||
|
else:
|
||||||
|
if isinstance(self.upsampler, SpatialRationalResampler):
|
||||||
|
x = self.upsampler(x)
|
||||||
|
else:
|
||||||
|
x = rearrange(x, "b c f h w -> (b f) c h w")
|
||||||
|
x = self.upsampler(x)
|
||||||
|
x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f)
|
||||||
|
|
||||||
|
for block in self.post_upsample_res_blocks:
|
||||||
|
x = block(x)
|
||||||
|
|
||||||
|
x = self.final_conv(x)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, config):
|
||||||
|
return cls(
|
||||||
|
in_channels=config.get("in_channels", 4),
|
||||||
|
mid_channels=config.get("mid_channels", 128),
|
||||||
|
num_blocks_per_stage=config.get("num_blocks_per_stage", 4),
|
||||||
|
dims=config.get("dims", 2),
|
||||||
|
spatial_upsample=config.get("spatial_upsample", True),
|
||||||
|
temporal_upsample=config.get("temporal_upsample", False),
|
||||||
|
spatial_scale=config.get("spatial_scale", 2.0),
|
||||||
|
rational_resampler=config.get("rational_resampler", False),
|
||||||
|
)
|
||||||
|
|
||||||
|
def config(self):
|
||||||
|
return {
|
||||||
|
"_class_name": "LatentUpsampler",
|
||||||
|
"in_channels": self.in_channels,
|
||||||
|
"mid_channels": self.mid_channels,
|
||||||
|
"num_blocks_per_stage": self.num_blocks_per_stage,
|
||||||
|
"dims": self.dims,
|
||||||
|
"spatial_upsample": self.spatial_upsample,
|
||||||
|
"temporal_upsample": self.temporal_upsample,
|
||||||
|
"spatial_scale": self.spatial_scale,
|
||||||
|
"rational_resampler": self.rational_resampler,
|
||||||
|
}
|
||||||
@ -1,13 +1,47 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from enum import Enum
|
||||||
|
import functools
|
||||||
|
import math
|
||||||
|
from typing import Dict, Optional, Tuple
|
||||||
|
|
||||||
|
from einops import rearrange
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
import comfy.patcher_extension
|
import comfy.patcher_extension
|
||||||
import comfy.ldm.modules.attention
|
import comfy.ldm.modules.attention
|
||||||
import comfy.ldm.common_dit
|
import comfy.ldm.common_dit
|
||||||
import math
|
|
||||||
from typing import Dict, Optional, Tuple
|
|
||||||
|
|
||||||
from .symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords
|
from .symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords
|
||||||
from comfy.ldm.flux.math import apply_rope1
|
|
||||||
|
def _log_base(x, base):
|
||||||
|
return np.log(x) / np.log(base)
|
||||||
|
|
||||||
|
class LTXRopeType(str, Enum):
|
||||||
|
INTERLEAVED = "interleaved"
|
||||||
|
SPLIT = "split"
|
||||||
|
|
||||||
|
KEY = "rope_type"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, kwargs, default=None):
|
||||||
|
if default is None:
|
||||||
|
default = cls.INTERLEAVED
|
||||||
|
return cls(kwargs.get(cls.KEY, default))
|
||||||
|
|
||||||
|
|
||||||
|
class LTXFrequenciesPrecision(str, Enum):
|
||||||
|
FLOAT32 = "float32"
|
||||||
|
FLOAT64 = "float64"
|
||||||
|
|
||||||
|
KEY = "frequencies_precision"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, kwargs, default=None):
|
||||||
|
if default is None:
|
||||||
|
default = cls.FLOAT32
|
||||||
|
return cls(kwargs.get(cls.KEY, default))
|
||||||
|
|
||||||
|
|
||||||
def get_timestep_embedding(
|
def get_timestep_embedding(
|
||||||
timesteps: torch.Tensor,
|
timesteps: torch.Tensor,
|
||||||
@ -39,9 +73,7 @@ def get_timestep_embedding(
|
|||||||
assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
|
assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
|
||||||
|
|
||||||
half_dim = embedding_dim // 2
|
half_dim = embedding_dim // 2
|
||||||
exponent = -math.log(max_period) * torch.arange(
|
exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device)
|
||||||
start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
|
|
||||||
)
|
|
||||||
exponent = exponent / (half_dim - downscale_freq_shift)
|
exponent = exponent / (half_dim - downscale_freq_shift)
|
||||||
|
|
||||||
emb = torch.exp(exponent)
|
emb = torch.exp(exponent)
|
||||||
@ -73,7 +105,9 @@ class TimestepEmbedding(nn.Module):
|
|||||||
post_act_fn: Optional[str] = None,
|
post_act_fn: Optional[str] = None,
|
||||||
cond_proj_dim=None,
|
cond_proj_dim=None,
|
||||||
sample_proj_bias=True,
|
sample_proj_bias=True,
|
||||||
dtype=None, device=None, operations=None,
|
dtype=None,
|
||||||
|
device=None,
|
||||||
|
operations=None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@ -90,7 +124,9 @@ class TimestepEmbedding(nn.Module):
|
|||||||
time_embed_dim_out = out_dim
|
time_embed_dim_out = out_dim
|
||||||
else:
|
else:
|
||||||
time_embed_dim_out = time_embed_dim
|
time_embed_dim_out = time_embed_dim
|
||||||
self.linear_2 = operations.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias, dtype=dtype, device=device)
|
self.linear_2 = operations.Linear(
|
||||||
|
time_embed_dim, time_embed_dim_out, sample_proj_bias, dtype=dtype, device=device
|
||||||
|
)
|
||||||
|
|
||||||
if post_act_fn is None:
|
if post_act_fn is None:
|
||||||
self.post_act = None
|
self.post_act = None
|
||||||
@ -139,12 +175,22 @@ class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
|
|||||||
https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29
|
https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool = False, dtype=None, device=None, operations=None):
|
def __init__(
|
||||||
|
self,
|
||||||
|
embedding_dim,
|
||||||
|
size_emb_dim,
|
||||||
|
use_additional_conditions: bool = False,
|
||||||
|
dtype=None,
|
||||||
|
device=None,
|
||||||
|
operations=None,
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.outdim = size_emb_dim
|
self.outdim = size_emb_dim
|
||||||
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
|
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
|
||||||
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim, dtype=dtype, device=device, operations=operations)
|
self.timestep_embedder = TimestepEmbedding(
|
||||||
|
in_channels=256, time_embed_dim=embedding_dim, dtype=dtype, device=device, operations=operations
|
||||||
|
)
|
||||||
|
|
||||||
def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype):
|
def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype):
|
||||||
timesteps_proj = self.time_proj(timestep)
|
timesteps_proj = self.time_proj(timestep)
|
||||||
@ -163,15 +209,22 @@ class AdaLayerNormSingle(nn.Module):
|
|||||||
use_additional_conditions (`bool`): To use additional conditions for normalization or not.
|
use_additional_conditions (`bool`): To use additional conditions for normalization or not.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, embedding_dim: int, use_additional_conditions: bool = False, dtype=None, device=None, operations=None):
|
def __init__(
|
||||||
|
self, embedding_dim: int, embedding_coefficient: int = 6, use_additional_conditions: bool = False, dtype=None, device=None, operations=None
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings(
|
self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings(
|
||||||
embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions, dtype=dtype, device=device, operations=operations
|
embedding_dim,
|
||||||
|
size_emb_dim=embedding_dim // 3,
|
||||||
|
use_additional_conditions=use_additional_conditions,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
operations=operations,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.silu = nn.SiLU()
|
self.silu = nn.SiLU()
|
||||||
self.linear = operations.Linear(embedding_dim, 6 * embedding_dim, bias=True, dtype=dtype, device=device)
|
self.linear = operations.Linear(embedding_dim, embedding_coefficient * embedding_dim, bias=True, dtype=dtype, device=device)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -185,6 +238,7 @@ class AdaLayerNormSingle(nn.Module):
|
|||||||
embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype)
|
embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype)
|
||||||
return self.linear(self.silu(embedded_timestep)), embedded_timestep
|
return self.linear(self.silu(embedded_timestep)), embedded_timestep
|
||||||
|
|
||||||
|
|
||||||
class PixArtAlphaTextProjection(nn.Module):
|
class PixArtAlphaTextProjection(nn.Module):
|
||||||
"""
|
"""
|
||||||
Projects caption embeddings. Also handles dropout for classifier-free guidance.
|
Projects caption embeddings. Also handles dropout for classifier-free guidance.
|
||||||
@ -192,18 +246,24 @@ class PixArtAlphaTextProjection(nn.Module):
|
|||||||
Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
|
Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh", dtype=None, device=None, operations=None):
|
def __init__(
|
||||||
|
self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh", dtype=None, device=None, operations=None
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
if out_features is None:
|
if out_features is None:
|
||||||
out_features = hidden_size
|
out_features = hidden_size
|
||||||
self.linear_1 = operations.Linear(in_features=in_features, out_features=hidden_size, bias=True, dtype=dtype, device=device)
|
self.linear_1 = operations.Linear(
|
||||||
|
in_features=in_features, out_features=hidden_size, bias=True, dtype=dtype, device=device
|
||||||
|
)
|
||||||
if act_fn == "gelu_tanh":
|
if act_fn == "gelu_tanh":
|
||||||
self.act_1 = nn.GELU(approximate="tanh")
|
self.act_1 = nn.GELU(approximate="tanh")
|
||||||
elif act_fn == "silu":
|
elif act_fn == "silu":
|
||||||
self.act_1 = nn.SiLU()
|
self.act_1 = nn.SiLU()
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown activation function: {act_fn}")
|
raise ValueError(f"Unknown activation function: {act_fn}")
|
||||||
self.linear_2 = operations.Linear(in_features=hidden_size, out_features=out_features, bias=True, dtype=dtype, device=device)
|
self.linear_2 = operations.Linear(
|
||||||
|
in_features=hidden_size, out_features=out_features, bias=True, dtype=dtype, device=device
|
||||||
|
)
|
||||||
|
|
||||||
def forward(self, caption):
|
def forward(self, caption):
|
||||||
hidden_states = self.linear_1(caption)
|
hidden_states = self.linear_1(caption)
|
||||||
@ -222,23 +282,68 @@ class GELU_approx(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class FeedForward(nn.Module):
|
class FeedForward(nn.Module):
|
||||||
def __init__(self, dim, dim_out, mult=4, glu=False, dropout=0., dtype=None, device=None, operations=None):
|
def __init__(self, dim, dim_out, mult=4, glu=False, dropout=0.0, dtype=None, device=None, operations=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
inner_dim = int(dim * mult)
|
inner_dim = int(dim * mult)
|
||||||
project_in = GELU_approx(dim, inner_dim, dtype=dtype, device=device, operations=operations)
|
project_in = GELU_approx(dim, inner_dim, dtype=dtype, device=device, operations=operations)
|
||||||
|
|
||||||
self.net = nn.Sequential(
|
self.net = nn.Sequential(
|
||||||
project_in,
|
project_in, nn.Dropout(dropout), operations.Linear(inner_dim, dim_out, dtype=dtype, device=device)
|
||||||
nn.Dropout(dropout),
|
|
||||||
operations.Linear(inner_dim, dim_out, dtype=dtype, device=device)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return self.net(x)
|
return self.net(x)
|
||||||
|
|
||||||
|
def apply_rotary_emb(input_tensor, freqs_cis):
|
||||||
|
cos_freqs, sin_freqs = freqs_cis[0], freqs_cis[1]
|
||||||
|
split_pe = freqs_cis[2] if len(freqs_cis) > 2 else False
|
||||||
|
return (
|
||||||
|
apply_split_rotary_emb(input_tensor, cos_freqs, sin_freqs)
|
||||||
|
if split_pe else
|
||||||
|
apply_interleaved_rotary_emb(input_tensor, cos_freqs, sin_freqs)
|
||||||
|
)
|
||||||
|
|
||||||
|
def apply_interleaved_rotary_emb(input_tensor, cos_freqs, sin_freqs): # TODO: remove duplicate funcs and pick the best/fastest one
|
||||||
|
t_dup = rearrange(input_tensor, "... (d r) -> ... d r", r=2)
|
||||||
|
t1, t2 = t_dup.unbind(dim=-1)
|
||||||
|
t_dup = torch.stack((-t2, t1), dim=-1)
|
||||||
|
input_tensor_rot = rearrange(t_dup, "... d r -> ... (d r)")
|
||||||
|
|
||||||
|
out = input_tensor * cos_freqs + input_tensor_rot * sin_freqs
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
def apply_split_rotary_emb(input_tensor, cos, sin):
|
||||||
|
needs_reshape = False
|
||||||
|
if input_tensor.ndim != 4 and cos.ndim == 4:
|
||||||
|
B, H, T, _ = cos.shape
|
||||||
|
input_tensor = input_tensor.reshape(B, T, H, -1).swapaxes(1, 2)
|
||||||
|
needs_reshape = True
|
||||||
|
split_input = rearrange(input_tensor, "... (d r) -> ... d r", d=2)
|
||||||
|
first_half_input = split_input[..., :1, :]
|
||||||
|
second_half_input = split_input[..., 1:, :]
|
||||||
|
output = split_input * cos.unsqueeze(-2)
|
||||||
|
first_half_output = output[..., :1, :]
|
||||||
|
second_half_output = output[..., 1:, :]
|
||||||
|
first_half_output.addcmul_(-sin.unsqueeze(-2), second_half_input)
|
||||||
|
second_half_output.addcmul_(sin.unsqueeze(-2), first_half_input)
|
||||||
|
output = rearrange(output, "... d r -> ... (d r)")
|
||||||
|
return output.swapaxes(1, 2).reshape(B, T, -1) if needs_reshape else output
|
||||||
|
|
||||||
|
|
||||||
class CrossAttention(nn.Module):
|
class CrossAttention(nn.Module):
|
||||||
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., attn_precision=None, dtype=None, device=None, operations=None):
|
def __init__(
|
||||||
|
self,
|
||||||
|
query_dim,
|
||||||
|
context_dim=None,
|
||||||
|
heads=8,
|
||||||
|
dim_head=64,
|
||||||
|
dropout=0.0,
|
||||||
|
attn_precision=None,
|
||||||
|
dtype=None,
|
||||||
|
device=None,
|
||||||
|
operations=None,
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
inner_dim = dim_head * heads
|
inner_dim = dim_head * heads
|
||||||
context_dim = query_dim if context_dim is None else context_dim
|
context_dim = query_dim if context_dim is None else context_dim
|
||||||
@ -254,9 +359,11 @@ class CrossAttention(nn.Module):
|
|||||||
self.to_k = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)
|
self.to_k = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)
|
||||||
self.to_v = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)
|
self.to_v = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)
|
||||||
|
|
||||||
self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
|
self.to_out = nn.Sequential(
|
||||||
|
operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout)
|
||||||
|
)
|
||||||
|
|
||||||
def forward(self, x, context=None, mask=None, pe=None, transformer_options={}):
|
def forward(self, x, context=None, mask=None, pe=None, k_pe=None, transformer_options={}):
|
||||||
q = self.to_q(x)
|
q = self.to_q(x)
|
||||||
context = x if context is None else context
|
context = x if context is None else context
|
||||||
k = self.to_k(context)
|
k = self.to_k(context)
|
||||||
@ -266,8 +373,8 @@ class CrossAttention(nn.Module):
|
|||||||
k = self.k_norm(k)
|
k = self.k_norm(k)
|
||||||
|
|
||||||
if pe is not None:
|
if pe is not None:
|
||||||
q = apply_rope1(q.unsqueeze(1), pe).squeeze(1)
|
q = apply_rotary_emb(q, pe)
|
||||||
k = apply_rope1(k.unsqueeze(1), pe).squeeze(1)
|
k = apply_rotary_emb(k, pe if k_pe is None else k_pe)
|
||||||
|
|
||||||
if mask is None:
|
if mask is None:
|
||||||
out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options)
|
out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options)
|
||||||
@ -277,14 +384,34 @@ class CrossAttention(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class BasicTransformerBlock(nn.Module):
|
class BasicTransformerBlock(nn.Module):
|
||||||
def __init__(self, dim, n_heads, d_head, context_dim=None, attn_precision=None, dtype=None, device=None, operations=None):
|
def __init__(
|
||||||
|
self, dim, n_heads, d_head, context_dim=None, attn_precision=None, dtype=None, device=None, operations=None
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.attn_precision = attn_precision
|
self.attn_precision = attn_precision
|
||||||
self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, context_dim=None, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations)
|
self.attn1 = CrossAttention(
|
||||||
|
query_dim=dim,
|
||||||
|
heads=n_heads,
|
||||||
|
dim_head=d_head,
|
||||||
|
context_dim=None,
|
||||||
|
attn_precision=self.attn_precision,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
operations=operations,
|
||||||
|
)
|
||||||
self.ff = FeedForward(dim, dim_out=dim, glu=True, dtype=dtype, device=device, operations=operations)
|
self.ff = FeedForward(dim, dim_out=dim, glu=True, dtype=dtype, device=device, operations=operations)
|
||||||
|
|
||||||
self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations)
|
self.attn2 = CrossAttention(
|
||||||
|
query_dim=dim,
|
||||||
|
context_dim=context_dim,
|
||||||
|
heads=n_heads,
|
||||||
|
dim_head=d_head,
|
||||||
|
attn_precision=self.attn_precision,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
operations=operations,
|
||||||
|
)
|
||||||
|
|
||||||
self.scale_shift_table = nn.Parameter(torch.empty(6, dim, device=device, dtype=dtype))
|
self.scale_shift_table = nn.Parameter(torch.empty(6, dim, device=device, dtype=dtype))
|
||||||
|
|
||||||
@ -306,116 +433,446 @@ class BasicTransformerBlock(nn.Module):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
def get_fractional_positions(indices_grid, max_pos):
|
def get_fractional_positions(indices_grid, max_pos):
|
||||||
|
n_pos_dims = indices_grid.shape[1]
|
||||||
|
assert n_pos_dims == len(max_pos), f'Number of position dimensions ({n_pos_dims}) must match max_pos length ({len(max_pos)})'
|
||||||
fractional_positions = torch.stack(
|
fractional_positions = torch.stack(
|
||||||
[
|
[indices_grid[:, i] / max_pos[i] for i in range(n_pos_dims)],
|
||||||
indices_grid[:, i] / max_pos[i]
|
axis=-1,
|
||||||
for i in range(3)
|
|
||||||
],
|
|
||||||
dim=-1,
|
|
||||||
)
|
)
|
||||||
return fractional_positions
|
return fractional_positions
|
||||||
|
|
||||||
|
|
||||||
def precompute_freqs_cis(indices_grid, dim, out_dtype, theta=10000.0, max_pos=[20, 2048, 2048]):
|
@functools.lru_cache(maxsize=5)
|
||||||
dtype = torch.float32
|
def generate_freq_grid_np(positional_embedding_theta, positional_embedding_max_pos_count, inner_dim, _ = None):
|
||||||
device = indices_grid.device
|
theta = positional_embedding_theta
|
||||||
|
start = 1
|
||||||
|
end = theta
|
||||||
|
|
||||||
|
n_elem = 2 * positional_embedding_max_pos_count
|
||||||
|
pow_indices = np.power(
|
||||||
|
theta,
|
||||||
|
np.linspace(
|
||||||
|
_log_base(start, theta),
|
||||||
|
_log_base(end, theta),
|
||||||
|
inner_dim // n_elem,
|
||||||
|
dtype=np.float64,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
return torch.tensor(pow_indices * math.pi / 2, dtype=torch.float32)
|
||||||
|
|
||||||
|
def generate_freq_grid_pytorch(positional_embedding_theta, positional_embedding_max_pos_count, inner_dim, device):
|
||||||
|
theta = positional_embedding_theta
|
||||||
|
start = 1
|
||||||
|
end = theta
|
||||||
|
n_elem = 2 * positional_embedding_max_pos_count
|
||||||
|
|
||||||
|
indices = theta ** (
|
||||||
|
torch.linspace(
|
||||||
|
math.log(start, theta),
|
||||||
|
math.log(end, theta),
|
||||||
|
inner_dim // n_elem,
|
||||||
|
device=device,
|
||||||
|
dtype=torch.float32,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
indices = indices.to(dtype=torch.float32)
|
||||||
|
|
||||||
|
indices = indices * math.pi / 2
|
||||||
|
|
||||||
|
return indices
|
||||||
|
|
||||||
|
def generate_freqs(indices, indices_grid, max_pos, use_middle_indices_grid):
|
||||||
|
if use_middle_indices_grid:
|
||||||
|
assert(len(indices_grid.shape) == 4 and indices_grid.shape[-1] ==2)
|
||||||
|
indices_grid_start, indices_grid_end = indices_grid[..., 0], indices_grid[..., 1]
|
||||||
|
indices_grid = (indices_grid_start + indices_grid_end) / 2.0
|
||||||
|
elif len(indices_grid.shape) == 4:
|
||||||
|
indices_grid = indices_grid[..., 0]
|
||||||
|
|
||||||
# Get fractional positions and compute frequency indices
|
# Get fractional positions and compute frequency indices
|
||||||
fractional_positions = get_fractional_positions(indices_grid, max_pos)
|
fractional_positions = get_fractional_positions(indices_grid, max_pos)
|
||||||
indices = theta ** torch.linspace(0, 1, dim // 6, device=device, dtype=dtype) * math.pi / 2
|
indices = indices.to(device=fractional_positions.device)
|
||||||
|
|
||||||
# Compute frequencies and apply cos/sin
|
freqs = (
|
||||||
freqs = (indices * (fractional_positions.unsqueeze(-1) * 2 - 1)).transpose(-1, -2).flatten(2)
|
(indices * (fractional_positions.unsqueeze(-1) * 2 - 1))
|
||||||
cos_vals = freqs.cos().repeat_interleave(2, dim=-1)
|
.transpose(-1, -2)
|
||||||
sin_vals = freqs.sin().repeat_interleave(2, dim=-1)
|
.flatten(2)
|
||||||
|
)
|
||||||
|
return freqs
|
||||||
|
|
||||||
# Pad if dim is not divisible by 6
|
def interleaved_freqs_cis(freqs, pad_size):
|
||||||
if dim % 6 != 0:
|
cos_freq = freqs.cos().repeat_interleave(2, dim=-1)
|
||||||
padding_size = dim % 6
|
sin_freq = freqs.sin().repeat_interleave(2, dim=-1)
|
||||||
cos_vals = torch.cat([torch.ones_like(cos_vals[:, :, :padding_size]), cos_vals], dim=-1)
|
if pad_size != 0:
|
||||||
sin_vals = torch.cat([torch.zeros_like(sin_vals[:, :, :padding_size]), sin_vals], dim=-1)
|
cos_padding = torch.ones_like(cos_freq[:, :, : pad_size])
|
||||||
|
sin_padding = torch.zeros_like(cos_freq[:, :, : pad_size])
|
||||||
|
cos_freq = torch.cat([cos_padding, cos_freq], dim=-1)
|
||||||
|
sin_freq = torch.cat([sin_padding, sin_freq], dim=-1)
|
||||||
|
return cos_freq, sin_freq
|
||||||
|
|
||||||
# Reshape and extract one value per pair (since repeat_interleave duplicates each value)
|
def split_freqs_cis(freqs, pad_size, num_attention_heads):
|
||||||
cos_vals = cos_vals.reshape(*cos_vals.shape[:2], -1, 2)[..., 0].to(out_dtype) # [B, N, dim//2]
|
cos_freq = freqs.cos()
|
||||||
sin_vals = sin_vals.reshape(*sin_vals.shape[:2], -1, 2)[..., 0].to(out_dtype) # [B, N, dim//2]
|
sin_freq = freqs.sin()
|
||||||
|
|
||||||
# Build rotation matrix [[cos, -sin], [sin, cos]] and add heads dimension
|
if pad_size != 0:
|
||||||
freqs_cis = torch.stack([
|
cos_padding = torch.ones_like(cos_freq[:, :, :pad_size])
|
||||||
torch.stack([cos_vals, -sin_vals], dim=-1),
|
sin_padding = torch.zeros_like(sin_freq[:, :, :pad_size])
|
||||||
torch.stack([sin_vals, cos_vals], dim=-1)
|
|
||||||
], dim=-2).unsqueeze(1) # [B, 1, N, dim//2, 2, 2]
|
|
||||||
|
|
||||||
return freqs_cis
|
cos_freq = torch.concatenate([cos_padding, cos_freq], axis=-1)
|
||||||
|
sin_freq = torch.concatenate([sin_padding, sin_freq], axis=-1)
|
||||||
|
|
||||||
|
# Reshape freqs to be compatible with multi-head attention
|
||||||
|
B , T, half_HD = cos_freq.shape
|
||||||
|
|
||||||
class LTXVModel(torch.nn.Module):
|
cos_freq = cos_freq.reshape(B, T, num_attention_heads, half_HD // num_attention_heads)
|
||||||
def __init__(self,
|
sin_freq = sin_freq.reshape(B, T, num_attention_heads, half_HD // num_attention_heads)
|
||||||
in_channels=128,
|
|
||||||
cross_attention_dim=2048,
|
|
||||||
attention_head_dim=64,
|
|
||||||
num_attention_heads=32,
|
|
||||||
|
|
||||||
caption_channels=4096,
|
cos_freq = torch.swapaxes(cos_freq, 1, 2) # (B,H,T,D//2)
|
||||||
num_layers=28,
|
sin_freq = torch.swapaxes(sin_freq, 1, 2) # (B,H,T,D//2)
|
||||||
|
return cos_freq, sin_freq
|
||||||
|
|
||||||
|
class LTXBaseModel(torch.nn.Module, ABC):
|
||||||
|
"""
|
||||||
|
Abstract base class for LTX models (Lightricks Transformer models).
|
||||||
|
|
||||||
positional_embedding_theta=10000.0,
|
This class defines the common interface and shared functionality for all LTX models,
|
||||||
positional_embedding_max_pos=[20, 2048, 2048],
|
including LTXV (video) and LTXAV (audio-video) variants.
|
||||||
causal_temporal_positioning=False,
|
"""
|
||||||
vae_scale_factors=(8, 32, 32),
|
|
||||||
dtype=None, device=None, operations=None, **kwargs):
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int,
|
||||||
|
cross_attention_dim: int,
|
||||||
|
attention_head_dim: int,
|
||||||
|
num_attention_heads: int,
|
||||||
|
caption_channels: int,
|
||||||
|
num_layers: int,
|
||||||
|
positional_embedding_theta: float = 10000.0,
|
||||||
|
positional_embedding_max_pos: list = [20, 2048, 2048],
|
||||||
|
causal_temporal_positioning: bool = False,
|
||||||
|
vae_scale_factors: tuple = (8, 32, 32),
|
||||||
|
use_middle_indices_grid=False,
|
||||||
|
timestep_scale_multiplier = 1000.0,
|
||||||
|
dtype=None,
|
||||||
|
device=None,
|
||||||
|
operations=None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.generator = None
|
self.generator = None
|
||||||
self.vae_scale_factors = vae_scale_factors
|
self.vae_scale_factors = vae_scale_factors
|
||||||
|
self.use_middle_indices_grid = use_middle_indices_grid
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
self.out_channels = in_channels
|
self.in_channels = in_channels
|
||||||
self.inner_dim = num_attention_heads * attention_head_dim
|
self.cross_attention_dim = cross_attention_dim
|
||||||
|
self.attention_head_dim = attention_head_dim
|
||||||
|
self.num_attention_heads = num_attention_heads
|
||||||
|
self.caption_channels = caption_channels
|
||||||
|
self.num_layers = num_layers
|
||||||
|
self.positional_embedding_theta = positional_embedding_theta
|
||||||
|
self.positional_embedding_max_pos = positional_embedding_max_pos
|
||||||
|
self.split_positional_embedding = LTXRopeType.from_dict(kwargs)
|
||||||
|
self.freq_grid_generator = (
|
||||||
|
generate_freq_grid_np if LTXFrequenciesPrecision.from_dict(kwargs) == LTXFrequenciesPrecision.FLOAT64
|
||||||
|
else generate_freq_grid_pytorch
|
||||||
|
)
|
||||||
self.causal_temporal_positioning = causal_temporal_positioning
|
self.causal_temporal_positioning = causal_temporal_positioning
|
||||||
|
self.operations = operations
|
||||||
|
self.timestep_scale_multiplier = timestep_scale_multiplier
|
||||||
|
|
||||||
self.patchify_proj = operations.Linear(in_channels, self.inner_dim, bias=True, dtype=dtype, device=device)
|
# Common dimensions
|
||||||
|
self.inner_dim = num_attention_heads * attention_head_dim
|
||||||
|
self.out_channels = in_channels
|
||||||
|
|
||||||
|
# Initialize common components
|
||||||
|
self._init_common_components(device, dtype)
|
||||||
|
|
||||||
|
# Initialize model-specific components
|
||||||
|
self._init_model_components(device, dtype, **kwargs)
|
||||||
|
|
||||||
|
# Initialize transformer blocks
|
||||||
|
self._init_transformer_blocks(device, dtype, **kwargs)
|
||||||
|
|
||||||
|
# Initialize output components
|
||||||
|
self._init_output_components(device, dtype)
|
||||||
|
|
||||||
|
def _init_common_components(self, device, dtype):
|
||||||
|
"""Initialize components common to all LTX models
|
||||||
|
- patchify_proj: Linear projection for patchifying input
|
||||||
|
- adaln_single: AdaLN layer for timestep embedding
|
||||||
|
- caption_projection: Linear projection for caption embedding
|
||||||
|
"""
|
||||||
|
self.patchify_proj = self.operations.Linear(
|
||||||
|
self.in_channels, self.inner_dim, bias=True, dtype=dtype, device=device
|
||||||
|
)
|
||||||
|
|
||||||
self.adaln_single = AdaLayerNormSingle(
|
self.adaln_single = AdaLayerNormSingle(
|
||||||
self.inner_dim, use_additional_conditions=False, dtype=dtype, device=device, operations=operations
|
self.inner_dim, use_additional_conditions=False, dtype=dtype, device=device, operations=self.operations
|
||||||
)
|
)
|
||||||
|
|
||||||
# self.adaln_single.linear = operations.Linear(self.inner_dim, 4 * self.inner_dim, bias=True, dtype=dtype, device=device)
|
|
||||||
|
|
||||||
self.caption_projection = PixArtAlphaTextProjection(
|
self.caption_projection = PixArtAlphaTextProjection(
|
||||||
in_features=caption_channels, hidden_size=self.inner_dim, dtype=dtype, device=device, operations=operations
|
in_features=self.caption_channels,
|
||||||
|
hidden_size=self.inner_dim,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
operations=self.operations,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _init_model_components(self, device, dtype, **kwargs):
|
||||||
|
"""Initialize model-specific components. Must be implemented by subclasses."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _init_transformer_blocks(self, device, dtype, **kwargs):
|
||||||
|
"""Initialize transformer blocks. Must be implemented by subclasses."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _init_output_components(self, device, dtype):
|
||||||
|
"""Initialize output components. Must be implemented by subclasses."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _process_input(self, x, keyframe_idxs, denoise_mask, **kwargs):
|
||||||
|
"""Process input data. Must be implemented by subclasses."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, **kwargs):
|
||||||
|
"""Process transformer blocks. Must be implemented by subclasses."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _process_output(self, x, embedded_timestep, keyframe_idxs, **kwargs):
|
||||||
|
"""Process output data. Must be implemented by subclasses."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _prepare_timestep(self, timestep, batch_size, hidden_dtype, **kwargs):
|
||||||
|
"""Prepare timestep embeddings."""
|
||||||
|
grid_mask = kwargs.get("grid_mask", None)
|
||||||
|
if grid_mask is not None:
|
||||||
|
timestep = timestep[:, grid_mask]
|
||||||
|
|
||||||
|
timestep = timestep * self.timestep_scale_multiplier
|
||||||
|
timestep, embedded_timestep = self.adaln_single(
|
||||||
|
timestep.flatten(),
|
||||||
|
{"resolution": None, "aspect_ratio": None},
|
||||||
|
batch_size=batch_size,
|
||||||
|
hidden_dtype=hidden_dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Second dimension is 1 or number of tokens (if timestep_per_token)
|
||||||
|
timestep = timestep.view(batch_size, -1, timestep.shape[-1])
|
||||||
|
embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.shape[-1])
|
||||||
|
|
||||||
|
return timestep, embedded_timestep
|
||||||
|
|
||||||
|
def _prepare_context(self, context, batch_size, x, attention_mask=None):
|
||||||
|
"""Prepare context for transformer blocks."""
|
||||||
|
if self.caption_projection is not None:
|
||||||
|
context = self.caption_projection(context)
|
||||||
|
context = context.view(batch_size, -1, x.shape[-1])
|
||||||
|
|
||||||
|
return context, attention_mask
|
||||||
|
|
||||||
|
def _precompute_freqs_cis(
|
||||||
|
self,
|
||||||
|
indices_grid,
|
||||||
|
dim,
|
||||||
|
out_dtype,
|
||||||
|
theta=10000.0,
|
||||||
|
max_pos=[20, 2048, 2048],
|
||||||
|
use_middle_indices_grid=False,
|
||||||
|
num_attention_heads=32,
|
||||||
|
):
|
||||||
|
split_mode = self.split_positional_embedding == LTXRopeType.SPLIT
|
||||||
|
indices = self.freq_grid_generator(theta, indices_grid.shape[1], dim, indices_grid.device)
|
||||||
|
freqs = generate_freqs(indices, indices_grid, max_pos, use_middle_indices_grid)
|
||||||
|
|
||||||
|
if split_mode:
|
||||||
|
expected_freqs = dim // 2
|
||||||
|
current_freqs = freqs.shape[-1]
|
||||||
|
pad_size = expected_freqs - current_freqs
|
||||||
|
cos_freq, sin_freq = split_freqs_cis(freqs, pad_size, num_attention_heads)
|
||||||
|
else:
|
||||||
|
# 2 because of cos and sin by 3 for (t, x, y), 1 for temporal only
|
||||||
|
n_elem = 2 * indices_grid.shape[1]
|
||||||
|
cos_freq, sin_freq = interleaved_freqs_cis(freqs, dim % n_elem)
|
||||||
|
return cos_freq.to(out_dtype), sin_freq.to(out_dtype), split_mode
|
||||||
|
|
||||||
|
def _prepare_positional_embeddings(self, pixel_coords, frame_rate, x_dtype):
|
||||||
|
"""Prepare positional embeddings."""
|
||||||
|
fractional_coords = pixel_coords.to(torch.float32)
|
||||||
|
fractional_coords[:, 0] = fractional_coords[:, 0] * (1.0 / frame_rate)
|
||||||
|
pe = self._precompute_freqs_cis(
|
||||||
|
fractional_coords,
|
||||||
|
dim=self.inner_dim,
|
||||||
|
out_dtype=x_dtype,
|
||||||
|
max_pos=self.positional_embedding_max_pos,
|
||||||
|
use_middle_indices_grid=self.use_middle_indices_grid,
|
||||||
|
num_attention_heads=self.num_attention_heads,
|
||||||
|
)
|
||||||
|
return pe
|
||||||
|
|
||||||
|
def _prepare_attention_mask(self, attention_mask, x_dtype):
|
||||||
|
"""Prepare attention mask."""
|
||||||
|
if attention_mask is not None and not torch.is_floating_point(attention_mask):
|
||||||
|
attention_mask = (attention_mask - 1).to(x_dtype).reshape(
|
||||||
|
(attention_mask.shape[0], 1, -1, attention_mask.shape[-1])
|
||||||
|
) * torch.finfo(x_dtype).max
|
||||||
|
return attention_mask
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, denoise_mask=None, **kwargs
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Forward pass for LTX models.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Input tensor
|
||||||
|
timestep: Timestep tensor
|
||||||
|
context: Context tensor (e.g., text embeddings)
|
||||||
|
attention_mask: Attention mask tensor
|
||||||
|
frame_rate: Frame rate for temporal processing
|
||||||
|
transformer_options: Additional options for transformer blocks
|
||||||
|
keyframe_idxs: Keyframe indices for temporal processing
|
||||||
|
**kwargs: Additional keyword arguments
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Processed output tensor
|
||||||
|
"""
|
||||||
|
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||||
|
self._forward,
|
||||||
|
self,
|
||||||
|
comfy.patcher_extension.get_all_wrappers(
|
||||||
|
comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options
|
||||||
|
),
|
||||||
|
).execute(x, timestep, context, attention_mask, frame_rate, transformer_options, keyframe_idxs, denoise_mask=denoise_mask, **kwargs)
|
||||||
|
|
||||||
|
def _forward(
|
||||||
|
self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, denoise_mask=None, **kwargs
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Internal forward pass for LTX models.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Input tensor
|
||||||
|
timestep: Timestep tensor
|
||||||
|
context: Context tensor (e.g., text embeddings)
|
||||||
|
attention_mask: Attention mask tensor
|
||||||
|
frame_rate: Frame rate for temporal processing
|
||||||
|
transformer_options: Additional options for transformer blocks
|
||||||
|
keyframe_idxs: Keyframe indices for temporal processing
|
||||||
|
**kwargs: Additional keyword arguments
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Processed output tensor
|
||||||
|
"""
|
||||||
|
if isinstance(x, list):
|
||||||
|
input_dtype = x[0].dtype
|
||||||
|
batch_size = x[0].shape[0]
|
||||||
|
else:
|
||||||
|
input_dtype = x.dtype
|
||||||
|
batch_size = x.shape[0]
|
||||||
|
# Process input
|
||||||
|
merged_args = {**transformer_options, **kwargs}
|
||||||
|
x, pixel_coords, additional_args = self._process_input(x, keyframe_idxs, denoise_mask, **merged_args)
|
||||||
|
merged_args.update(additional_args)
|
||||||
|
|
||||||
|
# Prepare timestep and context
|
||||||
|
timestep, embedded_timestep = self._prepare_timestep(timestep, batch_size, input_dtype, **merged_args)
|
||||||
|
context, attention_mask = self._prepare_context(context, batch_size, x, attention_mask)
|
||||||
|
|
||||||
|
# Prepare attention mask and positional embeddings
|
||||||
|
attention_mask = self._prepare_attention_mask(attention_mask, input_dtype)
|
||||||
|
pe = self._prepare_positional_embeddings(pixel_coords, frame_rate, input_dtype)
|
||||||
|
|
||||||
|
# Process transformer blocks
|
||||||
|
x = self._process_transformer_blocks(
|
||||||
|
x, context, attention_mask, timestep, pe, transformer_options=transformer_options, **merged_args
|
||||||
|
)
|
||||||
|
|
||||||
|
# Process output
|
||||||
|
x = self._process_output(x, embedded_timestep, keyframe_idxs, **merged_args)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class LTXVModel(LTXBaseModel):
|
||||||
|
"""LTXV model for video generation."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels=128,
|
||||||
|
cross_attention_dim=2048,
|
||||||
|
attention_head_dim=64,
|
||||||
|
num_attention_heads=32,
|
||||||
|
caption_channels=4096,
|
||||||
|
num_layers=28,
|
||||||
|
positional_embedding_theta=10000.0,
|
||||||
|
positional_embedding_max_pos=[20, 2048, 2048],
|
||||||
|
causal_temporal_positioning=False,
|
||||||
|
vae_scale_factors=(8, 32, 32),
|
||||||
|
use_middle_indices_grid=False,
|
||||||
|
timestep_scale_multiplier = 1000.0,
|
||||||
|
dtype=None,
|
||||||
|
device=None,
|
||||||
|
operations=None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
in_channels=in_channels,
|
||||||
|
cross_attention_dim=cross_attention_dim,
|
||||||
|
attention_head_dim=attention_head_dim,
|
||||||
|
num_attention_heads=num_attention_heads,
|
||||||
|
caption_channels=caption_channels,
|
||||||
|
num_layers=num_layers,
|
||||||
|
positional_embedding_theta=positional_embedding_theta,
|
||||||
|
positional_embedding_max_pos=positional_embedding_max_pos,
|
||||||
|
causal_temporal_positioning=causal_temporal_positioning,
|
||||||
|
vae_scale_factors=vae_scale_factors,
|
||||||
|
use_middle_indices_grid=use_middle_indices_grid,
|
||||||
|
timestep_scale_multiplier=timestep_scale_multiplier,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
operations=operations,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _init_model_components(self, device, dtype, **kwargs):
|
||||||
|
"""Initialize LTXV-specific components."""
|
||||||
|
# No additional components needed for LTXV beyond base class
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _init_transformer_blocks(self, device, dtype, **kwargs):
|
||||||
|
"""Initialize transformer blocks for LTXV."""
|
||||||
self.transformer_blocks = nn.ModuleList(
|
self.transformer_blocks = nn.ModuleList(
|
||||||
[
|
[
|
||||||
BasicTransformerBlock(
|
BasicTransformerBlock(
|
||||||
self.inner_dim,
|
self.inner_dim,
|
||||||
num_attention_heads,
|
self.num_attention_heads,
|
||||||
attention_head_dim,
|
self.attention_head_dim,
|
||||||
context_dim=cross_attention_dim,
|
context_dim=self.cross_attention_dim,
|
||||||
# attn_precision=attn_precision,
|
dtype=dtype,
|
||||||
dtype=dtype, device=device, operations=operations
|
device=device,
|
||||||
|
operations=self.operations,
|
||||||
)
|
)
|
||||||
for d in range(num_layers)
|
for _ in range(self.num_layers)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _init_output_components(self, device, dtype):
|
||||||
|
"""Initialize output components for LTXV."""
|
||||||
self.scale_shift_table = nn.Parameter(torch.empty(2, self.inner_dim, dtype=dtype, device=device))
|
self.scale_shift_table = nn.Parameter(torch.empty(2, self.inner_dim, dtype=dtype, device=device))
|
||||||
self.norm_out = operations.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
self.norm_out = self.operations.LayerNorm(
|
||||||
self.proj_out = operations.Linear(self.inner_dim, self.out_channels, dtype=dtype, device=device)
|
self.inner_dim, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device
|
||||||
|
)
|
||||||
self.patchifier = SymmetricPatchifier(1)
|
self.proj_out = self.operations.Linear(self.inner_dim, self.out_channels, dtype=dtype, device=device)
|
||||||
|
self.patchifier = SymmetricPatchifier(1, start_end=True)
|
||||||
def forward(self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, **kwargs):
|
|
||||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
|
||||||
self._forward,
|
|
||||||
self,
|
|
||||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
|
|
||||||
).execute(x, timestep, context, attention_mask, frame_rate, transformer_options, keyframe_idxs, **kwargs)
|
|
||||||
|
|
||||||
def _forward(self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, **kwargs):
|
|
||||||
patches_replace = transformer_options.get("patches_replace", {})
|
|
||||||
|
|
||||||
orig_shape = list(x.shape)
|
|
||||||
|
|
||||||
|
def _process_input(self, x, keyframe_idxs, denoise_mask, **kwargs):
|
||||||
|
"""Process input for LTXV."""
|
||||||
|
additional_args = {"orig_shape": list(x.shape)}
|
||||||
x, latent_coords = self.patchifier.patchify(x)
|
x, latent_coords = self.patchifier.patchify(x)
|
||||||
pixel_coords = latent_to_pixel_coords(
|
pixel_coords = latent_to_pixel_coords(
|
||||||
latent_coords=latent_coords,
|
latent_coords=latent_coords,
|
||||||
@ -423,44 +880,30 @@ class LTXVModel(torch.nn.Module):
|
|||||||
causal_fix=self.causal_temporal_positioning,
|
causal_fix=self.causal_temporal_positioning,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
grid_mask = None
|
||||||
if keyframe_idxs is not None:
|
if keyframe_idxs is not None:
|
||||||
pixel_coords[:, :, -keyframe_idxs.shape[2]:] = keyframe_idxs
|
additional_args.update({ "orig_patchified_shape": list(x.shape)})
|
||||||
|
denoise_mask = self.patchifier.patchify(denoise_mask)[0]
|
||||||
|
grid_mask = ~torch.any(denoise_mask < 0, dim=-1)[0]
|
||||||
|
additional_args.update({"grid_mask": grid_mask})
|
||||||
|
x = x[:, grid_mask, :]
|
||||||
|
pixel_coords = pixel_coords[:, :, grid_mask, ...]
|
||||||
|
|
||||||
fractional_coords = pixel_coords.to(torch.float32)
|
kf_grid_mask = grid_mask[-keyframe_idxs.shape[2]:]
|
||||||
fractional_coords[:, 0] = fractional_coords[:, 0] * (1.0 / frame_rate)
|
keyframe_idxs = keyframe_idxs[..., kf_grid_mask, :]
|
||||||
|
pixel_coords[:, :, -keyframe_idxs.shape[2]:, :] = keyframe_idxs
|
||||||
|
|
||||||
x = self.patchify_proj(x)
|
x = self.patchify_proj(x)
|
||||||
timestep = timestep * 1000.0
|
return x, pixel_coords, additional_args
|
||||||
|
|
||||||
if attention_mask is not None and not torch.is_floating_point(attention_mask):
|
|
||||||
attention_mask = (attention_mask - 1).to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])) * torch.finfo(x.dtype).max
|
|
||||||
|
|
||||||
pe = precompute_freqs_cis(fractional_coords, dim=self.inner_dim, out_dtype=x.dtype)
|
|
||||||
|
|
||||||
batch_size = x.shape[0]
|
|
||||||
timestep, embedded_timestep = self.adaln_single(
|
|
||||||
timestep.flatten(),
|
|
||||||
{"resolution": None, "aspect_ratio": None},
|
|
||||||
batch_size=batch_size,
|
|
||||||
hidden_dtype=x.dtype,
|
|
||||||
)
|
|
||||||
# Second dimension is 1 or number of tokens (if timestep_per_token)
|
|
||||||
timestep = timestep.view(batch_size, -1, timestep.shape[-1])
|
|
||||||
embedded_timestep = embedded_timestep.view(
|
|
||||||
batch_size, -1, embedded_timestep.shape[-1]
|
|
||||||
)
|
|
||||||
|
|
||||||
# 2. Blocks
|
|
||||||
if self.caption_projection is not None:
|
|
||||||
batch_size = x.shape[0]
|
|
||||||
context = self.caption_projection(context)
|
|
||||||
context = context.view(
|
|
||||||
batch_size, -1, x.shape[-1]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, transformer_options={}, **kwargs):
|
||||||
|
"""Process transformer blocks for LTXV."""
|
||||||
|
patches_replace = transformer_options.get("patches_replace", {})
|
||||||
blocks_replace = patches_replace.get("dit", {})
|
blocks_replace = patches_replace.get("dit", {})
|
||||||
|
|
||||||
for i, block in enumerate(self.transformer_blocks):
|
for i, block in enumerate(self.transformer_blocks):
|
||||||
if ("double_block", i) in blocks_replace:
|
if ("double_block", i) in blocks_replace:
|
||||||
|
|
||||||
def block_wrap(args):
|
def block_wrap(args):
|
||||||
out = {}
|
out = {}
|
||||||
out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"], transformer_options=args["transformer_options"])
|
out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"], transformer_options=args["transformer_options"])
|
||||||
@ -478,16 +921,28 @@ class LTXVModel(torch.nn.Module):
|
|||||||
transformer_options=transformer_options,
|
transformer_options=transformer_options,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 3. Output
|
return x
|
||||||
|
|
||||||
|
def _process_output(self, x, embedded_timestep, keyframe_idxs, **kwargs):
|
||||||
|
"""Process output for LTXV."""
|
||||||
|
# Apply scale-shift modulation
|
||||||
scale_shift_values = (
|
scale_shift_values = (
|
||||||
self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + embedded_timestep[:, :, None]
|
self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + embedded_timestep[:, :, None]
|
||||||
)
|
)
|
||||||
shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
|
shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
|
||||||
|
|
||||||
x = self.norm_out(x)
|
x = self.norm_out(x)
|
||||||
# Modulation
|
x = x * (1 + scale) + shift
|
||||||
x = torch.addcmul(x, x, scale).add_(shift)
|
|
||||||
x = self.proj_out(x)
|
x = self.proj_out(x)
|
||||||
|
|
||||||
|
if keyframe_idxs is not None:
|
||||||
|
grid_mask = kwargs["grid_mask"]
|
||||||
|
orig_patchified_shape = kwargs["orig_patchified_shape"]
|
||||||
|
full_x = torch.zeros(orig_patchified_shape, dtype=x.dtype, device=x.device)
|
||||||
|
full_x[:, grid_mask, :] = x
|
||||||
|
x = full_x
|
||||||
|
# Unpatchify to restore original dimensions
|
||||||
|
orig_shape = kwargs["orig_shape"]
|
||||||
x = self.patchifier.unpatchify(
|
x = self.patchifier.unpatchify(
|
||||||
latents=x,
|
latents=x,
|
||||||
output_height=orig_shape[3],
|
output_height=orig_shape[3],
|
||||||
|
|||||||
@ -21,20 +21,23 @@ def latent_to_pixel_coords(
|
|||||||
Returns:
|
Returns:
|
||||||
Tensor: A tensor of pixel coordinates corresponding to the input latent coordinates.
|
Tensor: A tensor of pixel coordinates corresponding to the input latent coordinates.
|
||||||
"""
|
"""
|
||||||
|
shape = [1] * latent_coords.ndim
|
||||||
|
shape[1] = -1
|
||||||
pixel_coords = (
|
pixel_coords = (
|
||||||
latent_coords
|
latent_coords
|
||||||
* torch.tensor(scale_factors, device=latent_coords.device)[None, :, None]
|
* torch.tensor(scale_factors, device=latent_coords.device).view(*shape)
|
||||||
)
|
)
|
||||||
if causal_fix:
|
if causal_fix:
|
||||||
# Fix temporal scale for first frame to 1 due to causality
|
# Fix temporal scale for first frame to 1 due to causality
|
||||||
pixel_coords[:, 0] = (pixel_coords[:, 0] + 1 - scale_factors[0]).clamp(min=0)
|
pixel_coords[:, 0, ...] = (pixel_coords[:, 0, ...] + 1 - scale_factors[0]).clamp(min=0)
|
||||||
return pixel_coords
|
return pixel_coords
|
||||||
|
|
||||||
|
|
||||||
class Patchifier(ABC):
|
class Patchifier(ABC):
|
||||||
def __init__(self, patch_size: int):
|
def __init__(self, patch_size: int, start_end: bool=False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._patch_size = (1, patch_size, patch_size)
|
self._patch_size = (1, patch_size, patch_size)
|
||||||
|
self.start_end = start_end
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def patchify(
|
def patchify(
|
||||||
@ -71,11 +74,23 @@ class Patchifier(ABC):
|
|||||||
torch.arange(0, latent_width, self._patch_size[2], device=device),
|
torch.arange(0, latent_width, self._patch_size[2], device=device),
|
||||||
indexing="ij",
|
indexing="ij",
|
||||||
)
|
)
|
||||||
latent_sample_coords = torch.stack(latent_sample_coords, dim=0)
|
latent_sample_coords_start = torch.stack(latent_sample_coords, dim=0)
|
||||||
latent_coords = latent_sample_coords.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
|
delta = torch.tensor(self._patch_size, device=latent_sample_coords_start.device, dtype=latent_sample_coords_start.dtype)[:, None, None, None]
|
||||||
latent_coords = rearrange(
|
latent_sample_coords_end = latent_sample_coords_start + delta
|
||||||
latent_coords, "b c f h w -> b c (f h w)", b=batch_size
|
|
||||||
|
latent_sample_coords_start = latent_sample_coords_start.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
|
||||||
|
latent_sample_coords_start = rearrange(
|
||||||
|
latent_sample_coords_start, "b c f h w -> b c (f h w)", b=batch_size
|
||||||
)
|
)
|
||||||
|
if self.start_end:
|
||||||
|
latent_sample_coords_end = latent_sample_coords_end.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
|
||||||
|
latent_sample_coords_end = rearrange(
|
||||||
|
latent_sample_coords_end, "b c f h w -> b c (f h w)", b=batch_size
|
||||||
|
)
|
||||||
|
|
||||||
|
latent_coords = torch.stack((latent_sample_coords_start, latent_sample_coords_end), dim=-1)
|
||||||
|
else:
|
||||||
|
latent_coords = latent_sample_coords_start
|
||||||
return latent_coords
|
return latent_coords
|
||||||
|
|
||||||
|
|
||||||
@ -115,3 +130,61 @@ class SymmetricPatchifier(Patchifier):
|
|||||||
q=self._patch_size[2],
|
q=self._patch_size[2],
|
||||||
)
|
)
|
||||||
return latents
|
return latents
|
||||||
|
|
||||||
|
|
||||||
|
class AudioPatchifier(Patchifier):
|
||||||
|
def __init__(self, patch_size: int,
|
||||||
|
sample_rate=16000,
|
||||||
|
hop_length=160,
|
||||||
|
audio_latent_downsample_factor=4,
|
||||||
|
is_causal=True,
|
||||||
|
start_end=False,
|
||||||
|
shift = 0
|
||||||
|
):
|
||||||
|
super().__init__(patch_size, start_end=start_end)
|
||||||
|
self.hop_length = hop_length
|
||||||
|
self.sample_rate = sample_rate
|
||||||
|
self.audio_latent_downsample_factor = audio_latent_downsample_factor
|
||||||
|
self.is_causal = is_causal
|
||||||
|
self.shift = shift
|
||||||
|
|
||||||
|
def copy_with_shift(self, shift):
|
||||||
|
return AudioPatchifier(
|
||||||
|
self.patch_size, self.sample_rate, self.hop_length, self.audio_latent_downsample_factor,
|
||||||
|
self.is_causal, self.start_end, shift
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_audio_latent_time_in_sec(self, start_latent, end_latent: int, dtype: torch.dtype, device=torch.device):
|
||||||
|
audio_latent_frame = torch.arange(start_latent, end_latent, dtype=dtype, device=device)
|
||||||
|
audio_mel_frame = audio_latent_frame * self.audio_latent_downsample_factor
|
||||||
|
if self.is_causal:
|
||||||
|
audio_mel_frame = (audio_mel_frame + 1 - self.audio_latent_downsample_factor).clip(min=0)
|
||||||
|
return audio_mel_frame * self.hop_length / self.sample_rate
|
||||||
|
|
||||||
|
|
||||||
|
def patchify(self, audio_latents: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
# audio_latents: (batch, channels, time, freq)
|
||||||
|
b, _, t, _ = audio_latents.shape
|
||||||
|
audio_latents = rearrange(
|
||||||
|
audio_latents,
|
||||||
|
"b c t f -> b t (c f)",
|
||||||
|
)
|
||||||
|
|
||||||
|
audio_latents_start_timings = self._get_audio_latent_time_in_sec(self.shift, t + self.shift, torch.float32, audio_latents.device)
|
||||||
|
audio_latents_start_timings = audio_latents_start_timings.unsqueeze(0).expand(b, -1).unsqueeze(1)
|
||||||
|
|
||||||
|
if self.start_end:
|
||||||
|
audio_latents_end_timings = self._get_audio_latent_time_in_sec(self.shift + 1, t + self.shift + 1, torch.float32, audio_latents.device)
|
||||||
|
audio_latents_end_timings = audio_latents_end_timings.unsqueeze(0).expand(b, -1).unsqueeze(1)
|
||||||
|
|
||||||
|
audio_latents_timings = torch.stack([audio_latents_start_timings, audio_latents_end_timings], dim=-1)
|
||||||
|
else:
|
||||||
|
audio_latents_timings = audio_latents_start_timings
|
||||||
|
return audio_latents, audio_latents_timings
|
||||||
|
|
||||||
|
def unpatchify(self, audio_latents: torch.Tensor, channels: int, freq: int) -> torch.Tensor:
|
||||||
|
# audio_latents: (batch, time, freq * channels)
|
||||||
|
audio_latents = rearrange(
|
||||||
|
audio_latents, "b t (c f) -> b c t f", c=channels, f=freq
|
||||||
|
)
|
||||||
|
return audio_latents
|
||||||
|
|||||||
286
comfy/ldm/lightricks/vae/audio_vae.py
Normal file
286
comfy/ldm/lightricks/vae/audio_vae.py
Normal file
@ -0,0 +1,286 @@
|
|||||||
|
import json
|
||||||
|
from dataclasses import dataclass
|
||||||
|
import math
|
||||||
|
import torch
|
||||||
|
import torchaudio
|
||||||
|
|
||||||
|
import comfy.model_management
|
||||||
|
import comfy.model_patcher
|
||||||
|
import comfy.utils as utils
|
||||||
|
from comfy.ldm.mmaudio.vae.distributions import DiagonalGaussianDistribution
|
||||||
|
from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier
|
||||||
|
from comfy.ldm.lightricks.vae.causal_audio_autoencoder import (
|
||||||
|
CausalityAxis,
|
||||||
|
CausalAudioAutoencoder,
|
||||||
|
)
|
||||||
|
from comfy.ldm.lightricks.vocoders.vocoder import Vocoder
|
||||||
|
|
||||||
|
LATENT_DOWNSAMPLE_FACTOR = 4
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class AudioVAEComponentConfig:
|
||||||
|
"""Container for model component configuration extracted from metadata."""
|
||||||
|
|
||||||
|
autoencoder: dict
|
||||||
|
vocoder: dict
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_metadata(cls, metadata: dict) -> "AudioVAEComponentConfig":
|
||||||
|
assert metadata is not None and "config" in metadata, "Metadata is required for audio VAE"
|
||||||
|
|
||||||
|
raw_config = metadata["config"]
|
||||||
|
if isinstance(raw_config, str):
|
||||||
|
parsed_config = json.loads(raw_config)
|
||||||
|
else:
|
||||||
|
parsed_config = raw_config
|
||||||
|
|
||||||
|
audio_config = parsed_config.get("audio_vae")
|
||||||
|
vocoder_config = parsed_config.get("vocoder")
|
||||||
|
|
||||||
|
assert audio_config is not None, "Audio VAE config is required for audio VAE"
|
||||||
|
assert vocoder_config is not None, "Vocoder config is required for audio VAE"
|
||||||
|
|
||||||
|
return cls(autoencoder=audio_config, vocoder=vocoder_config)
|
||||||
|
|
||||||
|
|
||||||
|
class ModelDeviceManager:
|
||||||
|
"""Manages device placement and GPU residency for the composed model."""
|
||||||
|
|
||||||
|
def __init__(self, module: torch.nn.Module):
|
||||||
|
load_device = comfy.model_management.get_torch_device()
|
||||||
|
offload_device = comfy.model_management.vae_offload_device()
|
||||||
|
self.patcher = comfy.model_patcher.ModelPatcher(module, load_device, offload_device)
|
||||||
|
|
||||||
|
def ensure_model_loaded(self) -> None:
|
||||||
|
comfy.model_management.free_memory(
|
||||||
|
self.patcher.model_size(),
|
||||||
|
self.patcher.load_device,
|
||||||
|
)
|
||||||
|
comfy.model_management.load_model_gpu(self.patcher)
|
||||||
|
|
||||||
|
def move_to_load_device(self, tensor: torch.Tensor) -> torch.Tensor:
|
||||||
|
return tensor.to(self.patcher.load_device)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def load_device(self):
|
||||||
|
return self.patcher.load_device
|
||||||
|
|
||||||
|
|
||||||
|
class AudioLatentNormalizer:
|
||||||
|
"""Applies per-channel statistics in patch space and restores original layout."""
|
||||||
|
|
||||||
|
def __init__(self, patchfier: AudioPatchifier, statistics_processor: torch.nn.Module):
|
||||||
|
self.patchifier = patchfier
|
||||||
|
self.statistics = statistics_processor
|
||||||
|
|
||||||
|
def normalize(self, latents: torch.Tensor) -> torch.Tensor:
|
||||||
|
channels = latents.shape[1]
|
||||||
|
freq = latents.shape[3]
|
||||||
|
patched, _ = self.patchifier.patchify(latents)
|
||||||
|
normalized = self.statistics.normalize(patched)
|
||||||
|
return self.patchifier.unpatchify(normalized, channels=channels, freq=freq)
|
||||||
|
|
||||||
|
def denormalize(self, latents: torch.Tensor) -> torch.Tensor:
|
||||||
|
channels = latents.shape[1]
|
||||||
|
freq = latents.shape[3]
|
||||||
|
patched, _ = self.patchifier.patchify(latents)
|
||||||
|
denormalized = self.statistics.un_normalize(patched)
|
||||||
|
return self.patchifier.unpatchify(denormalized, channels=channels, freq=freq)
|
||||||
|
|
||||||
|
|
||||||
|
class AudioPreprocessor:
|
||||||
|
"""Prepares raw waveforms for the autoencoder by matching training conditions."""
|
||||||
|
|
||||||
|
def __init__(self, target_sample_rate: int, mel_bins: int, mel_hop_length: int, n_fft: int):
|
||||||
|
self.target_sample_rate = target_sample_rate
|
||||||
|
self.mel_bins = mel_bins
|
||||||
|
self.mel_hop_length = mel_hop_length
|
||||||
|
self.n_fft = n_fft
|
||||||
|
|
||||||
|
def resample(self, waveform: torch.Tensor, source_rate: int) -> torch.Tensor:
|
||||||
|
if source_rate == self.target_sample_rate:
|
||||||
|
return waveform
|
||||||
|
return torchaudio.functional.resample(waveform, source_rate, self.target_sample_rate)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def normalize_amplitude(
|
||||||
|
waveform: torch.Tensor, max_amplitude: float = 0.5, eps: float = 1e-5
|
||||||
|
) -> torch.Tensor:
|
||||||
|
waveform = waveform - waveform.mean(dim=2, keepdim=True)
|
||||||
|
peak = torch.max(torch.abs(waveform)) + eps
|
||||||
|
scale = peak.clamp(max=max_amplitude) / peak
|
||||||
|
return waveform * scale
|
||||||
|
|
||||||
|
def waveform_to_mel(
|
||||||
|
self, waveform: torch.Tensor, waveform_sample_rate: int, device
|
||||||
|
) -> torch.Tensor:
|
||||||
|
waveform = self.resample(waveform, waveform_sample_rate)
|
||||||
|
waveform = self.normalize_amplitude(waveform)
|
||||||
|
|
||||||
|
mel_transform = torchaudio.transforms.MelSpectrogram(
|
||||||
|
sample_rate=self.target_sample_rate,
|
||||||
|
n_fft=self.n_fft,
|
||||||
|
win_length=self.n_fft,
|
||||||
|
hop_length=self.mel_hop_length,
|
||||||
|
f_min=0.0,
|
||||||
|
f_max=self.target_sample_rate / 2.0,
|
||||||
|
n_mels=self.mel_bins,
|
||||||
|
window_fn=torch.hann_window,
|
||||||
|
center=True,
|
||||||
|
pad_mode="reflect",
|
||||||
|
power=1.0,
|
||||||
|
mel_scale="slaney",
|
||||||
|
norm="slaney",
|
||||||
|
).to(device)
|
||||||
|
|
||||||
|
mel = mel_transform(waveform)
|
||||||
|
mel = torch.log(torch.clamp(mel, min=1e-5))
|
||||||
|
return mel.permute(0, 1, 3, 2).contiguous()
|
||||||
|
|
||||||
|
|
||||||
|
class AudioVAE(torch.nn.Module):
|
||||||
|
"""High-level Audio VAE wrapper exposing encode and decode entry points."""
|
||||||
|
|
||||||
|
def __init__(self, state_dict: dict, metadata: dict):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
component_config = AudioVAEComponentConfig.from_metadata(metadata)
|
||||||
|
|
||||||
|
vae_sd = utils.state_dict_prefix_replace(state_dict, {"audio_vae.": ""}, filter_keys=True)
|
||||||
|
vocoder_sd = utils.state_dict_prefix_replace(state_dict, {"vocoder.": ""}, filter_keys=True)
|
||||||
|
|
||||||
|
self.autoencoder = CausalAudioAutoencoder(config=component_config.autoencoder)
|
||||||
|
self.vocoder = Vocoder(config=component_config.vocoder)
|
||||||
|
|
||||||
|
self.autoencoder.load_state_dict(vae_sd, strict=False)
|
||||||
|
self.vocoder.load_state_dict(vocoder_sd, strict=False)
|
||||||
|
|
||||||
|
autoencoder_config = self.autoencoder.get_config()
|
||||||
|
self.normalizer = AudioLatentNormalizer(
|
||||||
|
AudioPatchifier(
|
||||||
|
patch_size=1,
|
||||||
|
audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR,
|
||||||
|
sample_rate=autoencoder_config["sampling_rate"],
|
||||||
|
hop_length=autoencoder_config["mel_hop_length"],
|
||||||
|
is_causal=autoencoder_config["is_causal"],
|
||||||
|
),
|
||||||
|
self.autoencoder.per_channel_statistics,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.preprocessor = AudioPreprocessor(
|
||||||
|
target_sample_rate=autoencoder_config["sampling_rate"],
|
||||||
|
mel_bins=autoencoder_config["mel_bins"],
|
||||||
|
mel_hop_length=autoencoder_config["mel_hop_length"],
|
||||||
|
n_fft=autoencoder_config["n_fft"],
|
||||||
|
)
|
||||||
|
|
||||||
|
self.device_manager = ModelDeviceManager(self)
|
||||||
|
|
||||||
|
def encode(self, audio: dict) -> torch.Tensor:
|
||||||
|
"""Encode a waveform dictionary into normalized latent tensors."""
|
||||||
|
|
||||||
|
waveform = audio["waveform"]
|
||||||
|
waveform_sample_rate = audio["sample_rate"]
|
||||||
|
input_device = waveform.device
|
||||||
|
# Ensure that Audio VAE is loaded on the correct device.
|
||||||
|
self.device_manager.ensure_model_loaded()
|
||||||
|
|
||||||
|
waveform = self.device_manager.move_to_load_device(waveform)
|
||||||
|
expected_channels = self.autoencoder.encoder.in_channels
|
||||||
|
if waveform.shape[1] != expected_channels:
|
||||||
|
raise ValueError(
|
||||||
|
f"Input audio must have {expected_channels} channels, got {waveform.shape[1]}"
|
||||||
|
)
|
||||||
|
|
||||||
|
mel_spec = self.preprocessor.waveform_to_mel(
|
||||||
|
waveform, waveform_sample_rate, device=self.device_manager.load_device
|
||||||
|
)
|
||||||
|
|
||||||
|
latents = self.autoencoder.encode(mel_spec)
|
||||||
|
posterior = DiagonalGaussianDistribution(latents)
|
||||||
|
latent_mode = posterior.mode()
|
||||||
|
|
||||||
|
normalized = self.normalizer.normalize(latent_mode)
|
||||||
|
return normalized.to(input_device)
|
||||||
|
|
||||||
|
def decode(self, latents: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""Decode normalized latent tensors into an audio waveform."""
|
||||||
|
original_shape = latents.shape
|
||||||
|
|
||||||
|
# Ensure that Audio VAE is loaded on the correct device.
|
||||||
|
self.device_manager.ensure_model_loaded()
|
||||||
|
|
||||||
|
latents = self.device_manager.move_to_load_device(latents)
|
||||||
|
latents = self.normalizer.denormalize(latents)
|
||||||
|
|
||||||
|
target_shape = self.target_shape_from_latents(original_shape)
|
||||||
|
mel_spec = self.autoencoder.decode(latents, target_shape=target_shape)
|
||||||
|
|
||||||
|
waveform = self.run_vocoder(mel_spec)
|
||||||
|
return self.device_manager.move_to_load_device(waveform)
|
||||||
|
|
||||||
|
def target_shape_from_latents(self, latents_shape):
|
||||||
|
batch, _, time, _ = latents_shape
|
||||||
|
target_length = time * LATENT_DOWNSAMPLE_FACTOR
|
||||||
|
if self.autoencoder.causality_axis != CausalityAxis.NONE:
|
||||||
|
target_length -= LATENT_DOWNSAMPLE_FACTOR - 1
|
||||||
|
return (
|
||||||
|
batch,
|
||||||
|
self.autoencoder.decoder.out_ch,
|
||||||
|
target_length,
|
||||||
|
self.autoencoder.mel_bins,
|
||||||
|
)
|
||||||
|
|
||||||
|
def num_of_latents_from_frames(self, frames_number: int, frame_rate: int) -> int:
|
||||||
|
return math.ceil((float(frames_number) / frame_rate) * self.latents_per_second)
|
||||||
|
|
||||||
|
def run_vocoder(self, mel_spec: torch.Tensor) -> torch.Tensor:
|
||||||
|
audio_channels = self.autoencoder.decoder.out_ch
|
||||||
|
vocoder_input = mel_spec.transpose(2, 3)
|
||||||
|
|
||||||
|
if audio_channels == 1:
|
||||||
|
vocoder_input = vocoder_input.squeeze(1)
|
||||||
|
elif audio_channels != 2:
|
||||||
|
raise ValueError(f"Unsupported audio_channels: {audio_channels}")
|
||||||
|
|
||||||
|
return self.vocoder(vocoder_input)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def sample_rate(self) -> int:
|
||||||
|
return int(self.autoencoder.sampling_rate)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def mel_hop_length(self) -> int:
|
||||||
|
return int(self.autoencoder.mel_hop_length)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def mel_bins(self) -> int:
|
||||||
|
return int(self.autoencoder.mel_bins)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def latent_channels(self) -> int:
|
||||||
|
return int(self.autoencoder.decoder.z_channels)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def latent_frequency_bins(self) -> int:
|
||||||
|
return int(self.mel_bins // LATENT_DOWNSAMPLE_FACTOR)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def latents_per_second(self) -> float:
|
||||||
|
return self.sample_rate / self.mel_hop_length / LATENT_DOWNSAMPLE_FACTOR
|
||||||
|
|
||||||
|
@property
|
||||||
|
def output_sample_rate(self) -> int:
|
||||||
|
output_rate = getattr(self.vocoder, "output_sample_rate", None)
|
||||||
|
if output_rate is not None:
|
||||||
|
return int(output_rate)
|
||||||
|
upsample_factor = getattr(self.vocoder, "upsample_factor", None)
|
||||||
|
if upsample_factor is None:
|
||||||
|
raise AttributeError(
|
||||||
|
"Vocoder is missing upsample_factor; cannot infer output sample rate"
|
||||||
|
)
|
||||||
|
return int(self.sample_rate * upsample_factor / self.mel_hop_length)
|
||||||
|
|
||||||
|
def memory_required(self, input_shape):
|
||||||
|
return self.device_manager.patcher.model_size()
|
||||||
909
comfy/ldm/lightricks/vae/causal_audio_autoencoder.py
Normal file
909
comfy/ldm/lightricks/vae/causal_audio_autoencoder.py
Normal file
@ -0,0 +1,909 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from torch.nn import functional as F
|
||||||
|
from typing import Optional
|
||||||
|
from enum import Enum
|
||||||
|
from .pixel_norm import PixelNorm
|
||||||
|
import comfy.ops
|
||||||
|
import logging
|
||||||
|
|
||||||
|
ops = comfy.ops.disable_weight_init
|
||||||
|
|
||||||
|
|
||||||
|
class StringConvertibleEnum(Enum):
|
||||||
|
"""
|
||||||
|
Base enum class that provides string-to-enum conversion functionality.
|
||||||
|
|
||||||
|
This mixin adds a str_to_enum() class method that handles conversion from
|
||||||
|
strings, None, or existing enum instances with case-insensitive matching.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def str_to_enum(cls, value):
|
||||||
|
"""
|
||||||
|
Convert a string, enum instance, or None to the appropriate enum member.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
value: Can be an enum instance of this class, a string, or None
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Enum member of this class
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the value cannot be converted to a valid enum member
|
||||||
|
"""
|
||||||
|
# Already an enum instance of this class
|
||||||
|
if isinstance(value, cls):
|
||||||
|
return value
|
||||||
|
|
||||||
|
# None maps to NONE member if it exists
|
||||||
|
if value is None:
|
||||||
|
if hasattr(cls, "NONE"):
|
||||||
|
return cls.NONE
|
||||||
|
raise ValueError(f"{cls.__name__} does not have a NONE member to map None to")
|
||||||
|
|
||||||
|
# String conversion (case-insensitive)
|
||||||
|
if isinstance(value, str):
|
||||||
|
value_lower = value.lower()
|
||||||
|
|
||||||
|
# Try to match against enum values
|
||||||
|
for member in cls:
|
||||||
|
# Handle members with None values
|
||||||
|
if member.value is None:
|
||||||
|
if value_lower == "none":
|
||||||
|
return member
|
||||||
|
# Handle members with string values
|
||||||
|
elif isinstance(member.value, str) and member.value.lower() == value_lower:
|
||||||
|
return member
|
||||||
|
|
||||||
|
# Build helpful error message with valid values
|
||||||
|
valid_values = []
|
||||||
|
for member in cls:
|
||||||
|
if member.value is None:
|
||||||
|
valid_values.append("none")
|
||||||
|
elif isinstance(member.value, str):
|
||||||
|
valid_values.append(member.value)
|
||||||
|
|
||||||
|
raise ValueError(f"Invalid {cls.__name__} string: '{value}'. " f"Valid values are: {valid_values}")
|
||||||
|
|
||||||
|
raise ValueError(
|
||||||
|
f"Cannot convert type {type(value).__name__} to {cls.__name__} enum. "
|
||||||
|
f"Expected string, None, or {cls.__name__} instance."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class AttentionType(StringConvertibleEnum):
|
||||||
|
"""Enum for specifying the attention mechanism type."""
|
||||||
|
|
||||||
|
VANILLA = "vanilla"
|
||||||
|
LINEAR = "linear"
|
||||||
|
NONE = "none"
|
||||||
|
|
||||||
|
|
||||||
|
class CausalityAxis(StringConvertibleEnum):
|
||||||
|
"""Enum for specifying the causality axis in causal convolutions."""
|
||||||
|
|
||||||
|
NONE = None
|
||||||
|
WIDTH = "width"
|
||||||
|
HEIGHT = "height"
|
||||||
|
WIDTH_COMPATIBILITY = "width-compatibility"
|
||||||
|
|
||||||
|
|
||||||
|
def Normalize(in_channels, *, num_groups=32, normtype="group"):
|
||||||
|
if normtype == "group":
|
||||||
|
return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
||||||
|
elif normtype == "pixel":
|
||||||
|
return PixelNorm(dim=1, eps=1e-6)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid normalization type: {normtype}")
|
||||||
|
|
||||||
|
|
||||||
|
class CausalConv2d(nn.Module):
|
||||||
|
"""
|
||||||
|
A causal 2D convolution.
|
||||||
|
|
||||||
|
This layer ensures that the output at time `t` only depends on inputs
|
||||||
|
at time `t` and earlier. It achieves this by applying asymmetric padding
|
||||||
|
to the time dimension (width) before the convolution.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size,
|
||||||
|
stride=1,
|
||||||
|
dilation=1,
|
||||||
|
groups=1,
|
||||||
|
bias=True,
|
||||||
|
causality_axis: CausalityAxis = CausalityAxis.HEIGHT,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.causality_axis = causality_axis
|
||||||
|
|
||||||
|
# Ensure kernel_size and dilation are tuples
|
||||||
|
kernel_size = nn.modules.utils._pair(kernel_size)
|
||||||
|
dilation = nn.modules.utils._pair(dilation)
|
||||||
|
|
||||||
|
# Calculate padding dimensions
|
||||||
|
pad_h = (kernel_size[0] - 1) * dilation[0]
|
||||||
|
pad_w = (kernel_size[1] - 1) * dilation[1]
|
||||||
|
|
||||||
|
# The padding tuple for F.pad is (pad_left, pad_right, pad_top, pad_bottom)
|
||||||
|
match self.causality_axis:
|
||||||
|
case CausalityAxis.NONE:
|
||||||
|
self.padding = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)
|
||||||
|
case CausalityAxis.WIDTH | CausalityAxis.WIDTH_COMPATIBILITY:
|
||||||
|
self.padding = (pad_w, 0, pad_h // 2, pad_h - pad_h // 2)
|
||||||
|
case CausalityAxis.HEIGHT:
|
||||||
|
self.padding = (pad_w // 2, pad_w - pad_w // 2, pad_h, 0)
|
||||||
|
case _:
|
||||||
|
raise ValueError(f"Invalid causality_axis: {causality_axis}")
|
||||||
|
|
||||||
|
# The internal convolution layer uses no padding, as we handle it manually
|
||||||
|
self.conv = ops.Conv2d(
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size,
|
||||||
|
stride=stride,
|
||||||
|
padding=0,
|
||||||
|
dilation=dilation,
|
||||||
|
groups=groups,
|
||||||
|
bias=bias,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
# Apply causal padding before convolution
|
||||||
|
x = F.pad(x, self.padding)
|
||||||
|
return self.conv(x)
|
||||||
|
|
||||||
|
|
||||||
|
def make_conv2d(
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size,
|
||||||
|
stride=1,
|
||||||
|
padding=None,
|
||||||
|
dilation=1,
|
||||||
|
groups=1,
|
||||||
|
bias=True,
|
||||||
|
causality_axis: Optional[CausalityAxis] = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Create a 2D convolution layer that can be either causal or non-causal.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
in_channels: Number of input channels
|
||||||
|
out_channels: Number of output channels
|
||||||
|
kernel_size: Size of the convolution kernel
|
||||||
|
stride: Convolution stride
|
||||||
|
padding: Padding (if None, will be calculated based on causal flag)
|
||||||
|
dilation: Dilation rate
|
||||||
|
groups: Number of groups for grouped convolution
|
||||||
|
bias: Whether to use bias
|
||||||
|
causality_axis: Dimension along which to apply causality.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Either a regular Conv2d or CausalConv2d layer
|
||||||
|
"""
|
||||||
|
if causality_axis is not None:
|
||||||
|
# For causal convolution, padding is handled internally by CausalConv2d
|
||||||
|
return CausalConv2d(in_channels, out_channels, kernel_size, stride, dilation, groups, bias, causality_axis)
|
||||||
|
else:
|
||||||
|
# For non-causal convolution, use symmetric padding if not specified
|
||||||
|
if padding is None:
|
||||||
|
if isinstance(kernel_size, int):
|
||||||
|
padding = kernel_size // 2
|
||||||
|
else:
|
||||||
|
padding = tuple(k // 2 for k in kernel_size)
|
||||||
|
return ops.Conv2d(
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size,
|
||||||
|
stride,
|
||||||
|
padding,
|
||||||
|
dilation,
|
||||||
|
groups,
|
||||||
|
bias,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Upsample(nn.Module):
|
||||||
|
def __init__(self, in_channels, with_conv, causality_axis: CausalityAxis = CausalityAxis.HEIGHT):
|
||||||
|
super().__init__()
|
||||||
|
self.with_conv = with_conv
|
||||||
|
self.causality_axis = causality_axis
|
||||||
|
if self.with_conv:
|
||||||
|
self.conv = make_conv2d(in_channels, in_channels, kernel_size=3, stride=1, causality_axis=causality_axis)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
||||||
|
if self.with_conv:
|
||||||
|
x = self.conv(x)
|
||||||
|
# Drop FIRST element in the causal axis to undo encoder's padding, while keeping the length 1 + 2 * n.
|
||||||
|
# For example, if the input is [0, 1, 2], after interpolation, the output is [0, 0, 1, 1, 2, 2].
|
||||||
|
# The causal convolution will pad the first element as [-, -, 0, 0, 1, 1, 2, 2],
|
||||||
|
# So the output elements rely on the following windows:
|
||||||
|
# 0: [-,-,0]
|
||||||
|
# 1: [-,0,0]
|
||||||
|
# 2: [0,0,1]
|
||||||
|
# 3: [0,1,1]
|
||||||
|
# 4: [1,1,2]
|
||||||
|
# 5: [1,2,2]
|
||||||
|
# Notice that the first and second elements in the output rely only on the first element in the input,
|
||||||
|
# while all other elements rely on two elements in the input.
|
||||||
|
# So we can drop the first element to undo the padding (rather than the last element).
|
||||||
|
# This is a no-op for non-causal convolutions.
|
||||||
|
match self.causality_axis:
|
||||||
|
case CausalityAxis.NONE:
|
||||||
|
pass # x remains unchanged
|
||||||
|
case CausalityAxis.HEIGHT:
|
||||||
|
x = x[:, :, 1:, :]
|
||||||
|
case CausalityAxis.WIDTH:
|
||||||
|
x = x[:, :, :, 1:]
|
||||||
|
case CausalityAxis.WIDTH_COMPATIBILITY:
|
||||||
|
pass # x remains unchanged
|
||||||
|
case _:
|
||||||
|
raise ValueError(f"Invalid causality_axis: {self.causality_axis}")
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class Downsample(nn.Module):
|
||||||
|
"""
|
||||||
|
A downsampling layer that can use either a strided convolution
|
||||||
|
or average pooling. Supports standard and causal padding for the
|
||||||
|
convolutional mode.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, in_channels, with_conv, causality_axis: CausalityAxis = CausalityAxis.WIDTH):
|
||||||
|
super().__init__()
|
||||||
|
self.with_conv = with_conv
|
||||||
|
self.causality_axis = causality_axis
|
||||||
|
|
||||||
|
if self.causality_axis != CausalityAxis.NONE and not self.with_conv:
|
||||||
|
raise ValueError("causality is only supported when `with_conv=True`.")
|
||||||
|
|
||||||
|
if self.with_conv:
|
||||||
|
# Do time downsampling here
|
||||||
|
# no asymmetric padding in torch conv, must do it ourselves
|
||||||
|
self.conv = ops.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
if self.with_conv:
|
||||||
|
# (pad_left, pad_right, pad_top, pad_bottom)
|
||||||
|
match self.causality_axis:
|
||||||
|
case CausalityAxis.NONE:
|
||||||
|
pad = (0, 1, 0, 1)
|
||||||
|
case CausalityAxis.WIDTH:
|
||||||
|
pad = (2, 0, 0, 1)
|
||||||
|
case CausalityAxis.HEIGHT:
|
||||||
|
pad = (0, 1, 2, 0)
|
||||||
|
case CausalityAxis.WIDTH_COMPATIBILITY:
|
||||||
|
pad = (1, 0, 0, 1)
|
||||||
|
case _:
|
||||||
|
raise ValueError(f"Invalid causality_axis: {self.causality_axis}")
|
||||||
|
|
||||||
|
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
|
||||||
|
x = self.conv(x)
|
||||||
|
else:
|
||||||
|
# This branch is only taken if with_conv=False, which implies causality_axis is NONE.
|
||||||
|
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class ResnetBlock(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
in_channels,
|
||||||
|
out_channels=None,
|
||||||
|
conv_shortcut=False,
|
||||||
|
dropout,
|
||||||
|
temb_channels=512,
|
||||||
|
norm_type="group",
|
||||||
|
causality_axis: CausalityAxis = CausalityAxis.HEIGHT,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.causality_axis = causality_axis
|
||||||
|
|
||||||
|
if self.causality_axis != CausalityAxis.NONE and norm_type == "group":
|
||||||
|
raise ValueError("Causal ResnetBlock with GroupNorm is not supported.")
|
||||||
|
self.in_channels = in_channels
|
||||||
|
out_channels = in_channels if out_channels is None else out_channels
|
||||||
|
self.out_channels = out_channels
|
||||||
|
self.use_conv_shortcut = conv_shortcut
|
||||||
|
|
||||||
|
self.norm1 = Normalize(in_channels, normtype=norm_type)
|
||||||
|
self.non_linearity = nn.SiLU()
|
||||||
|
self.conv1 = make_conv2d(in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis)
|
||||||
|
if temb_channels > 0:
|
||||||
|
self.temb_proj = ops.Linear(temb_channels, out_channels)
|
||||||
|
self.norm2 = Normalize(out_channels, normtype=norm_type)
|
||||||
|
self.dropout = torch.nn.Dropout(dropout)
|
||||||
|
self.conv2 = make_conv2d(out_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis)
|
||||||
|
if self.in_channels != self.out_channels:
|
||||||
|
if self.use_conv_shortcut:
|
||||||
|
self.conv_shortcut = make_conv2d(
|
||||||
|
in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.nin_shortcut = make_conv2d(
|
||||||
|
in_channels, out_channels, kernel_size=1, stride=1, causality_axis=causality_axis
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x, temb):
|
||||||
|
h = x
|
||||||
|
h = self.norm1(h)
|
||||||
|
h = self.non_linearity(h)
|
||||||
|
h = self.conv1(h)
|
||||||
|
|
||||||
|
if temb is not None:
|
||||||
|
h = h + self.temb_proj(self.non_linearity(temb))[:, :, None, None]
|
||||||
|
|
||||||
|
h = self.norm2(h)
|
||||||
|
h = self.non_linearity(h)
|
||||||
|
h = self.dropout(h)
|
||||||
|
h = self.conv2(h)
|
||||||
|
|
||||||
|
if self.in_channels != self.out_channels:
|
||||||
|
if self.use_conv_shortcut:
|
||||||
|
x = self.conv_shortcut(x)
|
||||||
|
else:
|
||||||
|
x = self.nin_shortcut(x)
|
||||||
|
|
||||||
|
return x + h
|
||||||
|
|
||||||
|
|
||||||
|
class AttnBlock(nn.Module):
|
||||||
|
def __init__(self, in_channels, norm_type="group"):
|
||||||
|
super().__init__()
|
||||||
|
self.in_channels = in_channels
|
||||||
|
|
||||||
|
self.norm = Normalize(in_channels, normtype=norm_type)
|
||||||
|
self.q = ops.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||||
|
self.k = ops.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||||
|
self.v = ops.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||||
|
self.proj_out = ops.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
h_ = x
|
||||||
|
h_ = self.norm(h_)
|
||||||
|
q = self.q(h_)
|
||||||
|
k = self.k(h_)
|
||||||
|
v = self.v(h_)
|
||||||
|
|
||||||
|
# compute attention
|
||||||
|
b, c, h, w = q.shape
|
||||||
|
q = q.reshape(b, c, h * w).contiguous()
|
||||||
|
q = q.permute(0, 2, 1).contiguous() # b,hw,c
|
||||||
|
k = k.reshape(b, c, h * w).contiguous() # b,c,hw
|
||||||
|
w_ = torch.bmm(q, k).contiguous() # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
|
||||||
|
w_ = w_ * (int(c) ** (-0.5))
|
||||||
|
w_ = torch.nn.functional.softmax(w_, dim=2)
|
||||||
|
|
||||||
|
# attend to values
|
||||||
|
v = v.reshape(b, c, h * w).contiguous()
|
||||||
|
w_ = w_.permute(0, 2, 1).contiguous() # b,hw,hw (first hw of k, second of q)
|
||||||
|
h_ = torch.bmm(v, w_).contiguous() # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
|
||||||
|
h_ = h_.reshape(b, c, h, w).contiguous()
|
||||||
|
|
||||||
|
h_ = self.proj_out(h_)
|
||||||
|
|
||||||
|
return x + h_
|
||||||
|
|
||||||
|
|
||||||
|
def make_attn(in_channels, attn_type="vanilla", norm_type="group"):
|
||||||
|
# Convert string to enum if needed
|
||||||
|
attn_type = AttentionType.str_to_enum(attn_type)
|
||||||
|
|
||||||
|
if attn_type != AttentionType.NONE:
|
||||||
|
logging.info(f"making attention of type '{attn_type.value}' with {in_channels} in_channels")
|
||||||
|
else:
|
||||||
|
logging.info(f"making identity attention with {in_channels} in_channels")
|
||||||
|
|
||||||
|
match attn_type:
|
||||||
|
case AttentionType.VANILLA:
|
||||||
|
return AttnBlock(in_channels, norm_type=norm_type)
|
||||||
|
case AttentionType.NONE:
|
||||||
|
return nn.Identity(in_channels)
|
||||||
|
case AttentionType.LINEAR:
|
||||||
|
raise NotImplementedError(f"Attention type {attn_type.value} is not supported yet.")
|
||||||
|
case _:
|
||||||
|
raise ValueError(f"Unknown attention type: {attn_type}")
|
||||||
|
|
||||||
|
|
||||||
|
class Encoder(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
ch,
|
||||||
|
out_ch,
|
||||||
|
ch_mult=(1, 2, 4, 8),
|
||||||
|
num_res_blocks,
|
||||||
|
attn_resolutions,
|
||||||
|
dropout=0.0,
|
||||||
|
resamp_with_conv=True,
|
||||||
|
in_channels,
|
||||||
|
resolution,
|
||||||
|
z_channels,
|
||||||
|
double_z=True,
|
||||||
|
attn_type="vanilla",
|
||||||
|
mid_block_add_attention=True,
|
||||||
|
norm_type="group",
|
||||||
|
causality_axis=CausalityAxis.WIDTH.value,
|
||||||
|
**ignore_kwargs,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.ch = ch
|
||||||
|
self.temb_ch = 0
|
||||||
|
self.num_resolutions = len(ch_mult)
|
||||||
|
self.num_res_blocks = num_res_blocks
|
||||||
|
self.resolution = resolution
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.z_channels = z_channels
|
||||||
|
self.double_z = double_z
|
||||||
|
self.norm_type = norm_type
|
||||||
|
# Convert string to enum if needed (for config loading)
|
||||||
|
causality_axis = CausalityAxis.str_to_enum(causality_axis)
|
||||||
|
self.attn_type = AttentionType.str_to_enum(attn_type)
|
||||||
|
|
||||||
|
# downsampling
|
||||||
|
self.conv_in = make_conv2d(
|
||||||
|
in_channels,
|
||||||
|
self.ch,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=1,
|
||||||
|
causality_axis=causality_axis,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.non_linearity = nn.SiLU()
|
||||||
|
|
||||||
|
curr_res = resolution
|
||||||
|
in_ch_mult = (1,) + tuple(ch_mult)
|
||||||
|
self.in_ch_mult = in_ch_mult
|
||||||
|
self.down = nn.ModuleList()
|
||||||
|
|
||||||
|
for i_level in range(self.num_resolutions):
|
||||||
|
block = nn.ModuleList()
|
||||||
|
attn = nn.ModuleList()
|
||||||
|
block_in = ch * in_ch_mult[i_level]
|
||||||
|
block_out = ch * ch_mult[i_level]
|
||||||
|
|
||||||
|
for _ in range(self.num_res_blocks):
|
||||||
|
block.append(
|
||||||
|
ResnetBlock(
|
||||||
|
in_channels=block_in,
|
||||||
|
out_channels=block_out,
|
||||||
|
temb_channels=self.temb_ch,
|
||||||
|
dropout=dropout,
|
||||||
|
norm_type=self.norm_type,
|
||||||
|
causality_axis=causality_axis,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
block_in = block_out
|
||||||
|
if curr_res in attn_resolutions:
|
||||||
|
attn.append(make_attn(block_in, attn_type=self.attn_type, norm_type=self.norm_type))
|
||||||
|
|
||||||
|
down = nn.Module()
|
||||||
|
down.block = block
|
||||||
|
down.attn = attn
|
||||||
|
if i_level != self.num_resolutions - 1:
|
||||||
|
down.downsample = Downsample(block_in, resamp_with_conv, causality_axis=causality_axis)
|
||||||
|
curr_res = curr_res // 2
|
||||||
|
self.down.append(down)
|
||||||
|
|
||||||
|
# middle
|
||||||
|
self.mid = nn.Module()
|
||||||
|
self.mid.block_1 = ResnetBlock(
|
||||||
|
in_channels=block_in,
|
||||||
|
out_channels=block_in,
|
||||||
|
temb_channels=self.temb_ch,
|
||||||
|
dropout=dropout,
|
||||||
|
norm_type=self.norm_type,
|
||||||
|
causality_axis=causality_axis,
|
||||||
|
)
|
||||||
|
if mid_block_add_attention:
|
||||||
|
self.mid.attn_1 = make_attn(block_in, attn_type=self.attn_type, norm_type=self.norm_type)
|
||||||
|
else:
|
||||||
|
self.mid.attn_1 = nn.Identity()
|
||||||
|
self.mid.block_2 = ResnetBlock(
|
||||||
|
in_channels=block_in,
|
||||||
|
out_channels=block_in,
|
||||||
|
temb_channels=self.temb_ch,
|
||||||
|
dropout=dropout,
|
||||||
|
norm_type=self.norm_type,
|
||||||
|
causality_axis=causality_axis,
|
||||||
|
)
|
||||||
|
|
||||||
|
# end
|
||||||
|
self.norm_out = Normalize(block_in, normtype=self.norm_type)
|
||||||
|
self.conv_out = make_conv2d(
|
||||||
|
block_in,
|
||||||
|
2 * z_channels if double_z else z_channels,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=1,
|
||||||
|
causality_axis=causality_axis,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
"""
|
||||||
|
Forward pass through the encoder.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Input tensor of shape [batch, channels, time, n_mels]
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Encoded latent representation
|
||||||
|
"""
|
||||||
|
feature_maps = [self.conv_in(x)]
|
||||||
|
|
||||||
|
# Process each resolution level (from high to low resolution)
|
||||||
|
for resolution_level in range(self.num_resolutions):
|
||||||
|
# Apply residual blocks at current resolution level
|
||||||
|
for block_idx in range(self.num_res_blocks):
|
||||||
|
# Apply ResNet block with optional timestep embedding
|
||||||
|
current_features = self.down[resolution_level].block[block_idx](feature_maps[-1], temb=None)
|
||||||
|
|
||||||
|
# Apply attention if configured for this resolution level
|
||||||
|
if len(self.down[resolution_level].attn) > 0:
|
||||||
|
current_features = self.down[resolution_level].attn[block_idx](current_features)
|
||||||
|
|
||||||
|
# Store processed features
|
||||||
|
feature_maps.append(current_features)
|
||||||
|
|
||||||
|
# Downsample spatial dimensions (except at the final resolution level)
|
||||||
|
if resolution_level != self.num_resolutions - 1:
|
||||||
|
downsampled_features = self.down[resolution_level].downsample(feature_maps[-1])
|
||||||
|
feature_maps.append(downsampled_features)
|
||||||
|
|
||||||
|
# === MIDDLE PROCESSING PHASE ===
|
||||||
|
# Take the lowest resolution features for middle processing
|
||||||
|
bottleneck_features = feature_maps[-1]
|
||||||
|
|
||||||
|
# Apply first middle ResNet block
|
||||||
|
bottleneck_features = self.mid.block_1(bottleneck_features, temb=None)
|
||||||
|
|
||||||
|
# Apply middle attention block
|
||||||
|
bottleneck_features = self.mid.attn_1(bottleneck_features)
|
||||||
|
|
||||||
|
# Apply second middle ResNet block
|
||||||
|
bottleneck_features = self.mid.block_2(bottleneck_features, temb=None)
|
||||||
|
|
||||||
|
# === OUTPUT PHASE ===
|
||||||
|
# Normalize the bottleneck features
|
||||||
|
output_features = self.norm_out(bottleneck_features)
|
||||||
|
|
||||||
|
# Apply non-linearity (SiLU activation)
|
||||||
|
output_features = self.non_linearity(output_features)
|
||||||
|
|
||||||
|
# Final convolution to produce latent representation
|
||||||
|
# [batch, channels, time, n_mels] -> [batch, 2 * z_channels if double_z else z_channels, time, n_mels]
|
||||||
|
return self.conv_out(output_features)
|
||||||
|
|
||||||
|
|
||||||
|
class Decoder(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
ch,
|
||||||
|
out_ch,
|
||||||
|
ch_mult=(1, 2, 4, 8),
|
||||||
|
num_res_blocks,
|
||||||
|
attn_resolutions,
|
||||||
|
dropout=0.0,
|
||||||
|
resamp_with_conv=True,
|
||||||
|
in_channels,
|
||||||
|
resolution,
|
||||||
|
z_channels,
|
||||||
|
give_pre_end=False,
|
||||||
|
tanh_out=False,
|
||||||
|
attn_type="vanilla",
|
||||||
|
mid_block_add_attention=True,
|
||||||
|
norm_type="group",
|
||||||
|
causality_axis=CausalityAxis.WIDTH.value,
|
||||||
|
**ignorekwargs,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.ch = ch
|
||||||
|
self.temb_ch = 0
|
||||||
|
self.num_resolutions = len(ch_mult)
|
||||||
|
self.num_res_blocks = num_res_blocks
|
||||||
|
self.resolution = resolution
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.out_ch = out_ch
|
||||||
|
self.give_pre_end = give_pre_end
|
||||||
|
self.tanh_out = tanh_out
|
||||||
|
self.norm_type = norm_type
|
||||||
|
self.z_channels = z_channels
|
||||||
|
# Convert string to enum if needed (for config loading)
|
||||||
|
causality_axis = CausalityAxis.str_to_enum(causality_axis)
|
||||||
|
self.attn_type = AttentionType.str_to_enum(attn_type)
|
||||||
|
|
||||||
|
# compute block_in and curr_res at lowest res
|
||||||
|
block_in = ch * ch_mult[self.num_resolutions - 1]
|
||||||
|
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
||||||
|
self.z_shape = (1, z_channels, curr_res, curr_res)
|
||||||
|
|
||||||
|
# z to block_in
|
||||||
|
self.conv_in = make_conv2d(z_channels, block_in, kernel_size=3, stride=1, causality_axis=causality_axis)
|
||||||
|
|
||||||
|
self.non_linearity = nn.SiLU()
|
||||||
|
|
||||||
|
# middle
|
||||||
|
self.mid = nn.Module()
|
||||||
|
self.mid.block_1 = ResnetBlock(
|
||||||
|
in_channels=block_in,
|
||||||
|
out_channels=block_in,
|
||||||
|
temb_channels=self.temb_ch,
|
||||||
|
dropout=dropout,
|
||||||
|
norm_type=self.norm_type,
|
||||||
|
causality_axis=causality_axis,
|
||||||
|
)
|
||||||
|
if mid_block_add_attention:
|
||||||
|
self.mid.attn_1 = make_attn(block_in, attn_type=self.attn_type, norm_type=self.norm_type)
|
||||||
|
else:
|
||||||
|
self.mid.attn_1 = nn.Identity()
|
||||||
|
self.mid.block_2 = ResnetBlock(
|
||||||
|
in_channels=block_in,
|
||||||
|
out_channels=block_in,
|
||||||
|
temb_channels=self.temb_ch,
|
||||||
|
dropout=dropout,
|
||||||
|
norm_type=self.norm_type,
|
||||||
|
causality_axis=causality_axis,
|
||||||
|
)
|
||||||
|
|
||||||
|
# upsampling
|
||||||
|
self.up = nn.ModuleList()
|
||||||
|
for i_level in reversed(range(self.num_resolutions)):
|
||||||
|
block = nn.ModuleList()
|
||||||
|
attn = nn.ModuleList()
|
||||||
|
block_out = ch * ch_mult[i_level]
|
||||||
|
for _ in range(self.num_res_blocks + 1):
|
||||||
|
block.append(
|
||||||
|
ResnetBlock(
|
||||||
|
in_channels=block_in,
|
||||||
|
out_channels=block_out,
|
||||||
|
temb_channels=self.temb_ch,
|
||||||
|
dropout=dropout,
|
||||||
|
norm_type=self.norm_type,
|
||||||
|
causality_axis=causality_axis,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
block_in = block_out
|
||||||
|
if curr_res in attn_resolutions:
|
||||||
|
attn.append(make_attn(block_in, attn_type=self.attn_type, norm_type=self.norm_type))
|
||||||
|
up = nn.Module()
|
||||||
|
up.block = block
|
||||||
|
up.attn = attn
|
||||||
|
if i_level != 0:
|
||||||
|
up.upsample = Upsample(block_in, resamp_with_conv, causality_axis=causality_axis)
|
||||||
|
curr_res = curr_res * 2
|
||||||
|
self.up.insert(0, up) # prepend to get consistent order
|
||||||
|
|
||||||
|
# end
|
||||||
|
self.norm_out = Normalize(block_in, normtype=self.norm_type)
|
||||||
|
self.conv_out = make_conv2d(block_in, out_ch, kernel_size=3, stride=1, causality_axis=causality_axis)
|
||||||
|
|
||||||
|
def _adjust_output_shape(self, decoded_output, target_shape):
|
||||||
|
"""
|
||||||
|
Adjust output shape to match target dimensions for variable-length audio.
|
||||||
|
|
||||||
|
This function handles the common case where decoded audio spectrograms need to be
|
||||||
|
resized to match a specific target shape.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
decoded_output: Tensor of shape (batch, channels, time, frequency)
|
||||||
|
target_shape: Target shape tuple (batch, channels, time, frequency)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor adjusted to match target_shape exactly
|
||||||
|
"""
|
||||||
|
# Current output shape: (batch, channels, time, frequency)
|
||||||
|
_, _, current_time, current_freq = decoded_output.shape
|
||||||
|
_, target_channels, target_time, target_freq = target_shape
|
||||||
|
|
||||||
|
# Step 1: Crop first to avoid exceeding target dimensions
|
||||||
|
decoded_output = decoded_output[
|
||||||
|
:, :target_channels, : min(current_time, target_time), : min(current_freq, target_freq)
|
||||||
|
]
|
||||||
|
|
||||||
|
# Step 2: Calculate padding needed for time and frequency dimensions
|
||||||
|
time_padding_needed = target_time - decoded_output.shape[2]
|
||||||
|
freq_padding_needed = target_freq - decoded_output.shape[3]
|
||||||
|
|
||||||
|
# Step 3: Apply padding if needed
|
||||||
|
if time_padding_needed > 0 or freq_padding_needed > 0:
|
||||||
|
# PyTorch padding format: (pad_left, pad_right, pad_top, pad_bottom)
|
||||||
|
# For audio: pad_left/right = frequency, pad_top/bottom = time
|
||||||
|
padding = (
|
||||||
|
0,
|
||||||
|
max(freq_padding_needed, 0), # frequency padding (left, right)
|
||||||
|
0,
|
||||||
|
max(time_padding_needed, 0), # time padding (top, bottom)
|
||||||
|
)
|
||||||
|
decoded_output = F.pad(decoded_output, padding)
|
||||||
|
|
||||||
|
# Step 4: Final safety crop to ensure exact target shape
|
||||||
|
decoded_output = decoded_output[:, :target_channels, :target_time, :target_freq]
|
||||||
|
|
||||||
|
return decoded_output
|
||||||
|
|
||||||
|
def get_config(self):
|
||||||
|
return {
|
||||||
|
"ch": self.ch,
|
||||||
|
"out_ch": self.out_ch,
|
||||||
|
"ch_mult": self.ch_mult,
|
||||||
|
"num_res_blocks": self.num_res_blocks,
|
||||||
|
"in_channels": self.in_channels,
|
||||||
|
"resolution": self.resolution,
|
||||||
|
"z_channels": self.z_channels,
|
||||||
|
}
|
||||||
|
|
||||||
|
def forward(self, latent_features, target_shape=None):
|
||||||
|
"""
|
||||||
|
Decode latent features back to audio spectrograms.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
latent_features: Encoded latent representation of shape (batch, channels, height, width)
|
||||||
|
target_shape: Optional target output shape (batch, channels, time, frequency)
|
||||||
|
If provided, output will be cropped/padded to match this shape
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Reconstructed audio spectrogram of shape (batch, channels, time, frequency)
|
||||||
|
"""
|
||||||
|
assert target_shape is not None, "Target shape is required for CausalAudioAutoencoder Decoder"
|
||||||
|
|
||||||
|
# Transform latent features to decoder's internal feature dimension
|
||||||
|
hidden_features = self.conv_in(latent_features)
|
||||||
|
|
||||||
|
# Middle processing
|
||||||
|
hidden_features = self.mid.block_1(hidden_features, temb=None)
|
||||||
|
hidden_features = self.mid.attn_1(hidden_features)
|
||||||
|
hidden_features = self.mid.block_2(hidden_features, temb=None)
|
||||||
|
|
||||||
|
# Upsampling
|
||||||
|
# Progressively increase spatial resolution from lowest to highest
|
||||||
|
for resolution_level in reversed(range(self.num_resolutions)):
|
||||||
|
# Apply residual blocks at current resolution level
|
||||||
|
for block_index in range(self.num_res_blocks + 1):
|
||||||
|
hidden_features = self.up[resolution_level].block[block_index](hidden_features, temb=None)
|
||||||
|
|
||||||
|
if len(self.up[resolution_level].attn) > 0:
|
||||||
|
hidden_features = self.up[resolution_level].attn[block_index](hidden_features)
|
||||||
|
|
||||||
|
if resolution_level != 0:
|
||||||
|
hidden_features = self.up[resolution_level].upsample(hidden_features)
|
||||||
|
|
||||||
|
# Output
|
||||||
|
if self.give_pre_end:
|
||||||
|
# Return intermediate features before final processing (for debugging/analysis)
|
||||||
|
decoded_output = hidden_features
|
||||||
|
else:
|
||||||
|
# Standard output path: normalize, activate, and convert to output channels
|
||||||
|
# Final normalization layer
|
||||||
|
hidden_features = self.norm_out(hidden_features)
|
||||||
|
|
||||||
|
# Apply SiLU (Swish) activation function
|
||||||
|
hidden_features = self.non_linearity(hidden_features)
|
||||||
|
|
||||||
|
# Final convolution to map to output channels (typically 2 for stereo audio)
|
||||||
|
decoded_output = self.conv_out(hidden_features)
|
||||||
|
|
||||||
|
# Optional tanh activation to bound output values to [-1, 1] range
|
||||||
|
if self.tanh_out:
|
||||||
|
decoded_output = torch.tanh(decoded_output)
|
||||||
|
|
||||||
|
# Adjust shape for audio data
|
||||||
|
if target_shape is not None:
|
||||||
|
decoded_output = self._adjust_output_shape(decoded_output, target_shape)
|
||||||
|
|
||||||
|
return decoded_output
|
||||||
|
|
||||||
|
|
||||||
|
class processor(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.register_buffer("std-of-means", torch.empty(128))
|
||||||
|
self.register_buffer("mean-of-means", torch.empty(128))
|
||||||
|
|
||||||
|
def un_normalize(self, x):
|
||||||
|
return (x * self.get_buffer("std-of-means").to(x)) + self.get_buffer("mean-of-means").to(x)
|
||||||
|
|
||||||
|
def normalize(self, x):
|
||||||
|
return (x - self.get_buffer("mean-of-means").to(x)) / self.get_buffer("std-of-means").to(x)
|
||||||
|
|
||||||
|
|
||||||
|
class CausalAudioAutoencoder(nn.Module):
|
||||||
|
def __init__(self, config=None):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
if config is None:
|
||||||
|
config = self._guess_config()
|
||||||
|
|
||||||
|
# Extract encoder and decoder configs from the new format
|
||||||
|
model_config = config.get("model", {}).get("params", {})
|
||||||
|
variables_config = config.get("variables", {})
|
||||||
|
|
||||||
|
self.sampling_rate = variables_config.get(
|
||||||
|
"sampling_rate",
|
||||||
|
model_config.get("sampling_rate", config.get("sampling_rate", 16000)),
|
||||||
|
)
|
||||||
|
encoder_config = model_config.get("encoder", model_config.get("ddconfig", {}))
|
||||||
|
decoder_config = model_config.get("decoder", encoder_config)
|
||||||
|
|
||||||
|
# Load mel spectrogram parameters
|
||||||
|
self.mel_bins = encoder_config.get("mel_bins", 64)
|
||||||
|
self.mel_hop_length = model_config.get("preprocessing", {}).get("stft", {}).get("hop_length", 160)
|
||||||
|
self.n_fft = model_config.get("preprocessing", {}).get("stft", {}).get("filter_length", 1024)
|
||||||
|
|
||||||
|
# Store causality configuration at VAE level (not just in encoder internals)
|
||||||
|
causality_axis_value = encoder_config.get("causality_axis", CausalityAxis.WIDTH.value)
|
||||||
|
self.causality_axis = CausalityAxis.str_to_enum(causality_axis_value)
|
||||||
|
self.is_causal = self.causality_axis == CausalityAxis.HEIGHT
|
||||||
|
|
||||||
|
self.encoder = Encoder(**encoder_config)
|
||||||
|
self.decoder = Decoder(**decoder_config)
|
||||||
|
|
||||||
|
self.per_channel_statistics = processor()
|
||||||
|
|
||||||
|
def _guess_config(self):
|
||||||
|
encoder_config = {
|
||||||
|
# Required parameters - based on ltx-video-av-1679000 model metadata
|
||||||
|
"ch": 128,
|
||||||
|
"out_ch": 8,
|
||||||
|
"ch_mult": [1, 2, 4], # Based on metadata: [1, 2, 4] not [1, 2, 4, 8]
|
||||||
|
"num_res_blocks": 2,
|
||||||
|
"attn_resolutions": [], # Based on metadata: empty list, no attention
|
||||||
|
"dropout": 0.0,
|
||||||
|
"resamp_with_conv": True,
|
||||||
|
"in_channels": 2, # stereo
|
||||||
|
"resolution": 256,
|
||||||
|
"z_channels": 8,
|
||||||
|
"double_z": True,
|
||||||
|
"attn_type": "vanilla",
|
||||||
|
"mid_block_add_attention": False, # Based on metadata: false
|
||||||
|
"norm_type": "pixel",
|
||||||
|
"causality_axis": "height", # Based on metadata
|
||||||
|
"mel_bins": 64, # Based on metadata: mel_bins = 64
|
||||||
|
}
|
||||||
|
|
||||||
|
decoder_config = {
|
||||||
|
# Inherits encoder config, can override specific params
|
||||||
|
**encoder_config,
|
||||||
|
"out_ch": 2, # Stereo audio output (2 channels)
|
||||||
|
"give_pre_end": False,
|
||||||
|
"tanh_out": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
config = {
|
||||||
|
"_class_name": "CausalAudioAutoencoder",
|
||||||
|
"sampling_rate": 16000,
|
||||||
|
"model": {
|
||||||
|
"params": {
|
||||||
|
"encoder": encoder_config,
|
||||||
|
"decoder": decoder_config,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
return config
|
||||||
|
|
||||||
|
def get_config(self):
|
||||||
|
return {
|
||||||
|
"sampling_rate": self.sampling_rate,
|
||||||
|
"mel_bins": self.mel_bins,
|
||||||
|
"mel_hop_length": self.mel_hop_length,
|
||||||
|
"n_fft": self.n_fft,
|
||||||
|
"causality_axis": self.causality_axis.value,
|
||||||
|
"is_causal": self.is_causal,
|
||||||
|
}
|
||||||
|
|
||||||
|
def encode(self, x):
|
||||||
|
return self.encoder(x)
|
||||||
|
|
||||||
|
def decode(self, x, target_shape=None):
|
||||||
|
return self.decoder(x, target_shape=target_shape)
|
||||||
213
comfy/ldm/lightricks/vocoders/vocoder.py
Normal file
213
comfy/ldm/lightricks/vocoders/vocoder.py
Normal file
@ -0,0 +1,213 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import torch.nn as nn
|
||||||
|
import comfy.ops
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
ops = comfy.ops.disable_weight_init
|
||||||
|
|
||||||
|
LRELU_SLOPE = 0.1
|
||||||
|
|
||||||
|
def get_padding(kernel_size, dilation=1):
|
||||||
|
return int((kernel_size * dilation - dilation) / 2)
|
||||||
|
|
||||||
|
|
||||||
|
class ResBlock1(torch.nn.Module):
|
||||||
|
def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
|
||||||
|
super(ResBlock1, self).__init__()
|
||||||
|
self.convs1 = nn.ModuleList(
|
||||||
|
[
|
||||||
|
ops.Conv1d(
|
||||||
|
channels,
|
||||||
|
channels,
|
||||||
|
kernel_size,
|
||||||
|
1,
|
||||||
|
dilation=dilation[0],
|
||||||
|
padding=get_padding(kernel_size, dilation[0]),
|
||||||
|
),
|
||||||
|
ops.Conv1d(
|
||||||
|
channels,
|
||||||
|
channels,
|
||||||
|
kernel_size,
|
||||||
|
1,
|
||||||
|
dilation=dilation[1],
|
||||||
|
padding=get_padding(kernel_size, dilation[1]),
|
||||||
|
),
|
||||||
|
ops.Conv1d(
|
||||||
|
channels,
|
||||||
|
channels,
|
||||||
|
kernel_size,
|
||||||
|
1,
|
||||||
|
dilation=dilation[2],
|
||||||
|
padding=get_padding(kernel_size, dilation[2]),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
self.convs2 = nn.ModuleList(
|
||||||
|
[
|
||||||
|
ops.Conv1d(
|
||||||
|
channels,
|
||||||
|
channels,
|
||||||
|
kernel_size,
|
||||||
|
1,
|
||||||
|
dilation=1,
|
||||||
|
padding=get_padding(kernel_size, 1),
|
||||||
|
),
|
||||||
|
ops.Conv1d(
|
||||||
|
channels,
|
||||||
|
channels,
|
||||||
|
kernel_size,
|
||||||
|
1,
|
||||||
|
dilation=1,
|
||||||
|
padding=get_padding(kernel_size, 1),
|
||||||
|
),
|
||||||
|
ops.Conv1d(
|
||||||
|
channels,
|
||||||
|
channels,
|
||||||
|
kernel_size,
|
||||||
|
1,
|
||||||
|
dilation=1,
|
||||||
|
padding=get_padding(kernel_size, 1),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
for c1, c2 in zip(self.convs1, self.convs2):
|
||||||
|
xt = F.leaky_relu(x, LRELU_SLOPE)
|
||||||
|
xt = c1(xt)
|
||||||
|
xt = F.leaky_relu(xt, LRELU_SLOPE)
|
||||||
|
xt = c2(xt)
|
||||||
|
x = xt + x
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class ResBlock2(torch.nn.Module):
|
||||||
|
def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
|
||||||
|
super(ResBlock2, self).__init__()
|
||||||
|
self.convs = nn.ModuleList(
|
||||||
|
[
|
||||||
|
ops.Conv1d(
|
||||||
|
channels,
|
||||||
|
channels,
|
||||||
|
kernel_size,
|
||||||
|
1,
|
||||||
|
dilation=dilation[0],
|
||||||
|
padding=get_padding(kernel_size, dilation[0]),
|
||||||
|
),
|
||||||
|
ops.Conv1d(
|
||||||
|
channels,
|
||||||
|
channels,
|
||||||
|
kernel_size,
|
||||||
|
1,
|
||||||
|
dilation=dilation[1],
|
||||||
|
padding=get_padding(kernel_size, dilation[1]),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
for c in self.convs:
|
||||||
|
xt = F.leaky_relu(x, LRELU_SLOPE)
|
||||||
|
xt = c(xt)
|
||||||
|
x = xt + x
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class Vocoder(torch.nn.Module):
|
||||||
|
"""
|
||||||
|
Vocoder model for synthesizing audio from spectrograms, based on: https://github.com/jik876/hifi-gan.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config=None):
|
||||||
|
super(Vocoder, self).__init__()
|
||||||
|
|
||||||
|
if config is None:
|
||||||
|
config = self.get_default_config()
|
||||||
|
|
||||||
|
resblock_kernel_sizes = config.get("resblock_kernel_sizes", [3, 7, 11])
|
||||||
|
upsample_rates = config.get("upsample_rates", [6, 5, 2, 2, 2])
|
||||||
|
upsample_kernel_sizes = config.get("upsample_kernel_sizes", [16, 15, 8, 4, 4])
|
||||||
|
resblock_dilation_sizes = config.get("resblock_dilation_sizes", [[1, 3, 5], [1, 3, 5], [1, 3, 5]])
|
||||||
|
upsample_initial_channel = config.get("upsample_initial_channel", 1024)
|
||||||
|
stereo = config.get("stereo", True)
|
||||||
|
resblock = config.get("resblock", "1")
|
||||||
|
|
||||||
|
self.output_sample_rate = config.get("output_sample_rate")
|
||||||
|
self.num_kernels = len(resblock_kernel_sizes)
|
||||||
|
self.num_upsamples = len(upsample_rates)
|
||||||
|
in_channels = 128 if stereo else 64
|
||||||
|
self.conv_pre = ops.Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3)
|
||||||
|
resblock_class = ResBlock1 if resblock == "1" else ResBlock2
|
||||||
|
|
||||||
|
self.ups = nn.ModuleList()
|
||||||
|
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
|
||||||
|
self.ups.append(
|
||||||
|
ops.ConvTranspose1d(
|
||||||
|
upsample_initial_channel // (2**i),
|
||||||
|
upsample_initial_channel // (2 ** (i + 1)),
|
||||||
|
k,
|
||||||
|
u,
|
||||||
|
padding=(k - u) // 2,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.resblocks = nn.ModuleList()
|
||||||
|
for i in range(len(self.ups)):
|
||||||
|
ch = upsample_initial_channel // (2 ** (i + 1))
|
||||||
|
for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
|
||||||
|
self.resblocks.append(resblock_class(ch, k, d))
|
||||||
|
|
||||||
|
out_channels = 2 if stereo else 1
|
||||||
|
self.conv_post = ops.Conv1d(ch, out_channels, 7, 1, padding=3)
|
||||||
|
|
||||||
|
self.upsample_factor = np.prod([self.ups[i].stride[0] for i in range(len(self.ups))])
|
||||||
|
|
||||||
|
def get_default_config(self):
|
||||||
|
"""Generate default configuration for the vocoder."""
|
||||||
|
|
||||||
|
config = {
|
||||||
|
"resblock_kernel_sizes": [3, 7, 11],
|
||||||
|
"upsample_rates": [6, 5, 2, 2, 2],
|
||||||
|
"upsample_kernel_sizes": [16, 15, 8, 4, 4],
|
||||||
|
"resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
||||||
|
"upsample_initial_channel": 1024,
|
||||||
|
"stereo": True,
|
||||||
|
"resblock": "1",
|
||||||
|
}
|
||||||
|
|
||||||
|
return config
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
"""
|
||||||
|
Forward pass of the vocoder.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Input spectrogram tensor. Can be:
|
||||||
|
- 3D: (batch_size, channels, time_steps) for mono
|
||||||
|
- 4D: (batch_size, 2, channels, time_steps) for stereo
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Audio tensor of shape (batch_size, out_channels, audio_length)
|
||||||
|
"""
|
||||||
|
if x.dim() == 4: # stereo
|
||||||
|
assert x.shape[1] == 2, "Input must have 2 channels for stereo"
|
||||||
|
x = torch.cat((x[:, 0, :, :], x[:, 1, :, :]), dim=1)
|
||||||
|
x = self.conv_pre(x)
|
||||||
|
for i in range(self.num_upsamples):
|
||||||
|
x = F.leaky_relu(x, LRELU_SLOPE)
|
||||||
|
x = self.ups[i](x)
|
||||||
|
xs = None
|
||||||
|
for j in range(self.num_kernels):
|
||||||
|
if xs is None:
|
||||||
|
xs = self.resblocks[i * self.num_kernels + j](x)
|
||||||
|
else:
|
||||||
|
xs += self.resblocks[i * self.num_kernels + j](x)
|
||||||
|
x = xs / self.num_kernels
|
||||||
|
x = F.leaky_relu(x)
|
||||||
|
x = self.conv_post(x)
|
||||||
|
x = torch.tanh(x)
|
||||||
|
|
||||||
|
return x
|
||||||
@ -491,7 +491,8 @@ class NextDiT(nn.Module):
|
|||||||
for layer_id in range(n_layers)
|
for layer_id in range(n_layers)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
self.norm_final = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
# This norm final is in the lumina 2.0 code but isn't actually used for anything.
|
||||||
|
# self.norm_final = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||||
self.final_layer = FinalLayer(dim, patch_size, self.out_channels, z_image_modulation=z_image_modulation, operation_settings=operation_settings)
|
self.final_layer = FinalLayer(dim, patch_size, self.out_channels, z_image_modulation=z_image_modulation, operation_settings=operation_settings)
|
||||||
|
|
||||||
if self.pad_tokens_multiple is not None:
|
if self.pad_tokens_multiple is not None:
|
||||||
@ -625,7 +626,7 @@ class NextDiT(nn.Module):
|
|||||||
if pooled is not None:
|
if pooled is not None:
|
||||||
pooled = self.clip_text_pooled_proj(pooled)
|
pooled = self.clip_text_pooled_proj(pooled)
|
||||||
else:
|
else:
|
||||||
pooled = torch.zeros((1, self.clip_text_dim), device=x.device, dtype=x.dtype)
|
pooled = torch.zeros((x.shape[0], self.clip_text_dim), device=x.device, dtype=x.dtype)
|
||||||
|
|
||||||
adaln_input = self.time_text_embed(torch.cat((t, pooled), dim=-1))
|
adaln_input = self.time_text_embed(torch.cat((t, pooled), dim=-1))
|
||||||
|
|
||||||
|
|||||||
@ -30,6 +30,13 @@ except ImportError as e:
|
|||||||
raise e
|
raise e
|
||||||
exit(-1)
|
exit(-1)
|
||||||
|
|
||||||
|
SAGE_ATTENTION3_IS_AVAILABLE = False
|
||||||
|
try:
|
||||||
|
from sageattn3 import sageattn3_blackwell
|
||||||
|
SAGE_ATTENTION3_IS_AVAILABLE = True
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
FLASH_ATTENTION_IS_AVAILABLE = False
|
FLASH_ATTENTION_IS_AVAILABLE = False
|
||||||
try:
|
try:
|
||||||
from flash_attn import flash_attn_func
|
from flash_attn import flash_attn_func
|
||||||
@ -563,6 +570,93 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
|
|||||||
out = out.reshape(b, -1, heads * dim_head)
|
out = out.reshape(b, -1, heads * dim_head)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
@wrap_attn
|
||||||
|
def attention3_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
|
||||||
|
exception_fallback = False
|
||||||
|
if (q.device.type != "cuda" or
|
||||||
|
q.dtype not in (torch.float16, torch.bfloat16) or
|
||||||
|
mask is not None):
|
||||||
|
return attention_pytorch(
|
||||||
|
q, k, v, heads,
|
||||||
|
mask=mask,
|
||||||
|
attn_precision=attn_precision,
|
||||||
|
skip_reshape=skip_reshape,
|
||||||
|
skip_output_reshape=skip_output_reshape,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
if skip_reshape:
|
||||||
|
B, H, L, D = q.shape
|
||||||
|
if H != heads:
|
||||||
|
return attention_pytorch(
|
||||||
|
q, k, v, heads,
|
||||||
|
mask=mask,
|
||||||
|
attn_precision=attn_precision,
|
||||||
|
skip_reshape=True,
|
||||||
|
skip_output_reshape=skip_output_reshape,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
q_s, k_s, v_s = q, k, v
|
||||||
|
N = q.shape[2]
|
||||||
|
dim_head = D
|
||||||
|
else:
|
||||||
|
B, N, inner_dim = q.shape
|
||||||
|
if inner_dim % heads != 0:
|
||||||
|
return attention_pytorch(
|
||||||
|
q, k, v, heads,
|
||||||
|
mask=mask,
|
||||||
|
attn_precision=attn_precision,
|
||||||
|
skip_reshape=False,
|
||||||
|
skip_output_reshape=skip_output_reshape,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
dim_head = inner_dim // heads
|
||||||
|
|
||||||
|
if dim_head >= 256 or N <= 1024:
|
||||||
|
return attention_pytorch(
|
||||||
|
q, k, v, heads,
|
||||||
|
mask=mask,
|
||||||
|
attn_precision=attn_precision,
|
||||||
|
skip_reshape=skip_reshape,
|
||||||
|
skip_output_reshape=skip_output_reshape,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
if not skip_reshape:
|
||||||
|
q_s, k_s, v_s = map(
|
||||||
|
lambda t: t.view(B, -1, heads, dim_head).permute(0, 2, 1, 3).contiguous(),
|
||||||
|
(q, k, v),
|
||||||
|
)
|
||||||
|
B, H, L, D = q_s.shape
|
||||||
|
|
||||||
|
try:
|
||||||
|
out = sageattn3_blackwell(q_s, k_s, v_s, is_causal=False)
|
||||||
|
except Exception as e:
|
||||||
|
exception_fallback = True
|
||||||
|
logging.error("Error running SageAttention3: %s, falling back to pytorch attention.", e)
|
||||||
|
|
||||||
|
if exception_fallback:
|
||||||
|
if not skip_reshape:
|
||||||
|
del q_s, k_s, v_s
|
||||||
|
return attention_pytorch(
|
||||||
|
q, k, v, heads,
|
||||||
|
mask=mask,
|
||||||
|
attn_precision=attn_precision,
|
||||||
|
skip_reshape=False,
|
||||||
|
skip_output_reshape=skip_output_reshape,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
if skip_reshape:
|
||||||
|
if not skip_output_reshape:
|
||||||
|
out = out.permute(0, 2, 1, 3).reshape(B, L, H * D)
|
||||||
|
else:
|
||||||
|
if skip_output_reshape:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
out = out.permute(0, 2, 1, 3).reshape(B, L, H * D)
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@torch.library.custom_op("flash_attention::flash_attn", mutates_args=())
|
@torch.library.custom_op("flash_attention::flash_attn", mutates_args=())
|
||||||
@ -650,6 +744,8 @@ optimized_attention_masked = optimized_attention
|
|||||||
# register core-supported attention functions
|
# register core-supported attention functions
|
||||||
if SAGE_ATTENTION_IS_AVAILABLE:
|
if SAGE_ATTENTION_IS_AVAILABLE:
|
||||||
register_attention_function("sage", attention_sage)
|
register_attention_function("sage", attention_sage)
|
||||||
|
if SAGE_ATTENTION3_IS_AVAILABLE:
|
||||||
|
register_attention_function("sage3", attention3_sage)
|
||||||
if FLASH_ATTENTION_IS_AVAILABLE:
|
if FLASH_ATTENTION_IS_AVAILABLE:
|
||||||
register_attention_function("flash", attention_flash)
|
register_attention_function("flash", attention_flash)
|
||||||
if model_management.xformers_enabled():
|
if model_management.xformers_enabled():
|
||||||
|
|||||||
@ -394,7 +394,8 @@ class Model(nn.Module):
|
|||||||
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
|
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
|
||||||
resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"):
|
resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
if use_linear_attn: attn_type = "linear"
|
if use_linear_attn:
|
||||||
|
attn_type = "linear"
|
||||||
self.ch = ch
|
self.ch = ch
|
||||||
self.temb_ch = self.ch*4
|
self.temb_ch = self.ch*4
|
||||||
self.num_resolutions = len(ch_mult)
|
self.num_resolutions = len(ch_mult)
|
||||||
@ -548,7 +549,8 @@ class Encoder(nn.Module):
|
|||||||
conv3d=False, time_compress=None,
|
conv3d=False, time_compress=None,
|
||||||
**ignore_kwargs):
|
**ignore_kwargs):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
if use_linear_attn: attn_type = "linear"
|
if use_linear_attn:
|
||||||
|
attn_type = "linear"
|
||||||
self.ch = ch
|
self.ch = ch
|
||||||
self.temb_ch = 0
|
self.temb_ch = 0
|
||||||
self.num_resolutions = len(ch_mult)
|
self.num_resolutions = len(ch_mult)
|
||||||
|
|||||||
@ -45,7 +45,7 @@ class LitEma(nn.Module):
|
|||||||
shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
|
shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
|
||||||
shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
|
shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
|
||||||
else:
|
else:
|
||||||
assert not key in self.m_name2s_name
|
assert key not in self.m_name2s_name
|
||||||
|
|
||||||
def copy_to(self, model):
|
def copy_to(self, model):
|
||||||
m_param = dict(model.named_parameters())
|
m_param = dict(model.named_parameters())
|
||||||
@ -54,7 +54,7 @@ class LitEma(nn.Module):
|
|||||||
if m_param[key].requires_grad:
|
if m_param[key].requires_grad:
|
||||||
m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
|
m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
|
||||||
else:
|
else:
|
||||||
assert not key in self.m_name2s_name
|
assert key not in self.m_name2s_name
|
||||||
|
|
||||||
def store(self, parameters):
|
def store(self, parameters):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -71,7 +71,7 @@ def count_params(model, verbose=False):
|
|||||||
|
|
||||||
|
|
||||||
def instantiate_from_config(config):
|
def instantiate_from_config(config):
|
||||||
if not "target" in config:
|
if "target" not in config:
|
||||||
if config == '__is_first_stage__':
|
if config == '__is_first_stage__':
|
||||||
return None
|
return None
|
||||||
elif config == "__is_unconditional__":
|
elif config == "__is_unconditional__":
|
||||||
|
|||||||
@ -20,6 +20,7 @@ import comfy.ldm.hunyuan3dv2_1
|
|||||||
import comfy.ldm.hunyuan3dv2_1.hunyuandit
|
import comfy.ldm.hunyuan3dv2_1.hunyuandit
|
||||||
import torch
|
import torch
|
||||||
import logging
|
import logging
|
||||||
|
import comfy.ldm.lightricks.av_model
|
||||||
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
|
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
|
||||||
from comfy.ldm.cascade.stage_c import StageC
|
from comfy.ldm.cascade.stage_c import StageC
|
||||||
from comfy.ldm.cascade.stage_b import StageB
|
from comfy.ldm.cascade.stage_b import StageB
|
||||||
@ -946,7 +947,7 @@ class GenmoMochi(BaseModel):
|
|||||||
|
|
||||||
class LTXV(BaseModel):
|
class LTXV(BaseModel):
|
||||||
def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
|
def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
|
||||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lightricks.model.LTXVModel) #TODO
|
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lightricks.model.LTXVModel)
|
||||||
|
|
||||||
def extra_conds(self, **kwargs):
|
def extra_conds(self, **kwargs):
|
||||||
out = super().extra_conds(**kwargs)
|
out = super().extra_conds(**kwargs)
|
||||||
@ -977,6 +978,60 @@ class LTXV(BaseModel):
|
|||||||
def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
|
def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
|
||||||
return latent_image
|
return latent_image
|
||||||
|
|
||||||
|
class LTXAV(BaseModel):
|
||||||
|
def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
|
||||||
|
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lightricks.av_model.LTXAVModel) #TODO
|
||||||
|
|
||||||
|
def extra_conds(self, **kwargs):
|
||||||
|
out = super().extra_conds(**kwargs)
|
||||||
|
attention_mask = kwargs.get("attention_mask", None)
|
||||||
|
if attention_mask is not None:
|
||||||
|
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
|
||||||
|
cross_attn = kwargs.get("cross_attn", None)
|
||||||
|
if cross_attn is not None:
|
||||||
|
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||||
|
|
||||||
|
out['frame_rate'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", 25))
|
||||||
|
|
||||||
|
denoise_mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
|
||||||
|
|
||||||
|
audio_denoise_mask = None
|
||||||
|
if denoise_mask is not None and "latent_shapes" in kwargs:
|
||||||
|
denoise_mask = utils.unpack_latents(denoise_mask, kwargs["latent_shapes"])
|
||||||
|
if len(denoise_mask) > 1:
|
||||||
|
audio_denoise_mask = denoise_mask[1]
|
||||||
|
denoise_mask = denoise_mask[0]
|
||||||
|
|
||||||
|
if denoise_mask is not None:
|
||||||
|
out["denoise_mask"] = comfy.conds.CONDRegular(denoise_mask)
|
||||||
|
|
||||||
|
if audio_denoise_mask is not None:
|
||||||
|
out["audio_denoise_mask"] = comfy.conds.CONDRegular(audio_denoise_mask)
|
||||||
|
|
||||||
|
keyframe_idxs = kwargs.get("keyframe_idxs", None)
|
||||||
|
if keyframe_idxs is not None:
|
||||||
|
out['keyframe_idxs'] = comfy.conds.CONDRegular(keyframe_idxs)
|
||||||
|
|
||||||
|
latent_shapes = kwargs.get("latent_shapes", None)
|
||||||
|
if latent_shapes is not None:
|
||||||
|
out['latent_shapes'] = comfy.conds.CONDConstant(latent_shapes)
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
def process_timestep(self, timestep, x, denoise_mask=None, audio_denoise_mask=None, **kwargs):
|
||||||
|
v_timestep = timestep
|
||||||
|
a_timestep = timestep
|
||||||
|
|
||||||
|
if denoise_mask is not None:
|
||||||
|
v_timestep = self.diffusion_model.patchifier.patchify(((denoise_mask) * timestep.view([timestep.shape[0]] + [1] * (denoise_mask.ndim - 1)))[:, :1])[0]
|
||||||
|
if audio_denoise_mask is not None:
|
||||||
|
a_timestep = self.diffusion_model.a_patchifier.patchify(((audio_denoise_mask) * timestep.view([timestep.shape[0]] + [1] * (audio_denoise_mask.ndim - 1)))[:, :1, :, :1])[0]
|
||||||
|
|
||||||
|
return v_timestep, a_timestep
|
||||||
|
|
||||||
|
def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
|
||||||
|
return latent_image
|
||||||
|
|
||||||
class HunyuanVideo(BaseModel):
|
class HunyuanVideo(BaseModel):
|
||||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan_video.model.HunyuanVideo)
|
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan_video.model.HunyuanVideo)
|
||||||
@ -1110,7 +1165,7 @@ class Lumina2(BaseModel):
|
|||||||
if 'num_tokens' not in out:
|
if 'num_tokens' not in out:
|
||||||
out['num_tokens'] = comfy.conds.CONDConstant(cross_attn.shape[1])
|
out['num_tokens'] = comfy.conds.CONDConstant(cross_attn.shape[1])
|
||||||
|
|
||||||
clip_text_pooled = kwargs["pooled_output"] # Newbie
|
clip_text_pooled = kwargs.get("pooled_output", None) # NewBie
|
||||||
if clip_text_pooled is not None:
|
if clip_text_pooled is not None:
|
||||||
out['clip_text_pooled'] = comfy.conds.CONDRegular(clip_text_pooled)
|
out['clip_text_pooled'] = comfy.conds.CONDRegular(clip_text_pooled)
|
||||||
|
|
||||||
|
|||||||
@ -305,7 +305,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
|||||||
|
|
||||||
if '{}adaln_single.emb.timestep_embedder.linear_1.bias'.format(key_prefix) in state_dict_keys: #Lightricks ltxv
|
if '{}adaln_single.emb.timestep_embedder.linear_1.bias'.format(key_prefix) in state_dict_keys: #Lightricks ltxv
|
||||||
dit_config = {}
|
dit_config = {}
|
||||||
dit_config["image_model"] = "ltxv"
|
dit_config["image_model"] = "ltxav" if f'{key_prefix}audio_adaln_single.linear.weight' in state_dict_keys else "ltxv"
|
||||||
dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.')
|
dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.')
|
||||||
shape = state_dict['{}transformer_blocks.0.attn2.to_k.weight'.format(key_prefix)].shape
|
shape = state_dict['{}transformer_blocks.0.attn2.to_k.weight'.format(key_prefix)].shape
|
||||||
dit_config["attention_head_dim"] = shape[0] // 32
|
dit_config["attention_head_dim"] = shape[0] // 32
|
||||||
@ -430,8 +430,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
|||||||
dit_config["rope_theta"] = 10000.0
|
dit_config["rope_theta"] = 10000.0
|
||||||
dit_config["ffn_dim_multiplier"] = 4.0
|
dit_config["ffn_dim_multiplier"] = 4.0
|
||||||
ctd_weight = state_dict.get('{}clip_text_pooled_proj.0.weight'.format(key_prefix), None)
|
ctd_weight = state_dict.get('{}clip_text_pooled_proj.0.weight'.format(key_prefix), None)
|
||||||
if ctd_weight is not None:
|
if ctd_weight is not None: # NewBie
|
||||||
dit_config["clip_text_dim"] = ctd_weight.shape[0]
|
dit_config["clip_text_dim"] = ctd_weight.shape[0]
|
||||||
|
# NewBie also sets axes_lens = [1024, 512, 512] but it's not used in ComfyUI
|
||||||
elif dit_config["dim"] == 3840: # Z image
|
elif dit_config["dim"] == 3840: # Z image
|
||||||
dit_config["n_heads"] = 30
|
dit_config["n_heads"] = 30
|
||||||
dit_config["n_kv_heads"] = 30
|
dit_config["n_kv_heads"] = 30
|
||||||
|
|||||||
@ -1019,8 +1019,8 @@ NUM_STREAMS = 0
|
|||||||
if args.async_offload is not None:
|
if args.async_offload is not None:
|
||||||
NUM_STREAMS = args.async_offload
|
NUM_STREAMS = args.async_offload
|
||||||
else:
|
else:
|
||||||
# Enable by default on Nvidia
|
# Enable by default on Nvidia and AMD
|
||||||
if is_nvidia():
|
if is_nvidia() or is_amd():
|
||||||
NUM_STREAMS = 2
|
NUM_STREAMS = 2
|
||||||
|
|
||||||
if args.disable_async_offload:
|
if args.disable_async_offload:
|
||||||
@ -1126,6 +1126,16 @@ if not args.disable_pinned_memory:
|
|||||||
|
|
||||||
PINNING_ALLOWED_TYPES = set(["Parameter", "QuantizedTensor"])
|
PINNING_ALLOWED_TYPES = set(["Parameter", "QuantizedTensor"])
|
||||||
|
|
||||||
|
def discard_cuda_async_error():
|
||||||
|
try:
|
||||||
|
a = torch.tensor([1], dtype=torch.uint8, device=get_torch_device())
|
||||||
|
b = torch.tensor([1], dtype=torch.uint8, device=get_torch_device())
|
||||||
|
_ = a + b
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
except torch.AcceleratorError:
|
||||||
|
#Dump it! We already know about it from the synchronous return
|
||||||
|
pass
|
||||||
|
|
||||||
def pin_memory(tensor):
|
def pin_memory(tensor):
|
||||||
global TOTAL_PINNED_MEMORY
|
global TOTAL_PINNED_MEMORY
|
||||||
if MAX_PINNED_MEMORY <= 0:
|
if MAX_PINNED_MEMORY <= 0:
|
||||||
@ -1158,6 +1168,9 @@ def pin_memory(tensor):
|
|||||||
PINNED_MEMORY[ptr] = size
|
PINNED_MEMORY[ptr] = size
|
||||||
TOTAL_PINNED_MEMORY += size
|
TOTAL_PINNED_MEMORY += size
|
||||||
return True
|
return True
|
||||||
|
else:
|
||||||
|
logging.warning("Pin error.")
|
||||||
|
discard_cuda_async_error()
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@ -1186,6 +1199,9 @@ def unpin_memory(tensor):
|
|||||||
if len(PINNED_MEMORY) == 0:
|
if len(PINNED_MEMORY) == 0:
|
||||||
TOTAL_PINNED_MEMORY = 0
|
TOTAL_PINNED_MEMORY = 0
|
||||||
return True
|
return True
|
||||||
|
else:
|
||||||
|
logging.warning("Unpin error.")
|
||||||
|
discard_cuda_async_error()
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@ -1526,6 +1542,10 @@ def soft_empty_cache(force=False):
|
|||||||
def unload_all_models():
|
def unload_all_models():
|
||||||
free_memory(1e30, get_torch_device())
|
free_memory(1e30, get_torch_device())
|
||||||
|
|
||||||
|
def debug_memory_summary():
|
||||||
|
if is_amd() or is_nvidia():
|
||||||
|
return torch.cuda.memory.memory_summary()
|
||||||
|
return ""
|
||||||
|
|
||||||
#TODO: might be cleaner to put this somewhere else
|
#TODO: might be cleaner to put this somewhere else
|
||||||
import threading
|
import threading
|
||||||
|
|||||||
27
comfy/sd.py
27
comfy/sd.py
@ -55,6 +55,8 @@ import comfy.text_encoders.hunyuan_image
|
|||||||
import comfy.text_encoders.z_image
|
import comfy.text_encoders.z_image
|
||||||
import comfy.text_encoders.ovis
|
import comfy.text_encoders.ovis
|
||||||
import comfy.text_encoders.kandinsky5
|
import comfy.text_encoders.kandinsky5
|
||||||
|
import comfy.text_encoders.jina_clip_2
|
||||||
|
import comfy.text_encoders.newbie
|
||||||
|
|
||||||
import comfy.model_patcher
|
import comfy.model_patcher
|
||||||
import comfy.lora
|
import comfy.lora
|
||||||
@ -1008,6 +1010,7 @@ class CLIPType(Enum):
|
|||||||
OVIS = 21
|
OVIS = 21
|
||||||
KANDINSKY5 = 22
|
KANDINSKY5 = 22
|
||||||
KANDINSKY5_IMAGE = 23
|
KANDINSKY5_IMAGE = 23
|
||||||
|
NEWBIE = 24
|
||||||
|
|
||||||
|
|
||||||
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
|
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
|
||||||
@ -1038,6 +1041,8 @@ class TEModel(Enum):
|
|||||||
MISTRAL3_24B_PRUNED_FLUX2 = 15
|
MISTRAL3_24B_PRUNED_FLUX2 = 15
|
||||||
QWEN3_4B = 16
|
QWEN3_4B = 16
|
||||||
QWEN3_2B = 17
|
QWEN3_2B = 17
|
||||||
|
GEMMA_3_12B = 18
|
||||||
|
JINA_CLIP_2 = 19
|
||||||
|
|
||||||
|
|
||||||
def detect_te_model(sd):
|
def detect_te_model(sd):
|
||||||
@ -1047,6 +1052,8 @@ def detect_te_model(sd):
|
|||||||
return TEModel.CLIP_H
|
return TEModel.CLIP_H
|
||||||
if "text_model.encoder.layers.0.mlp.fc1.weight" in sd:
|
if "text_model.encoder.layers.0.mlp.fc1.weight" in sd:
|
||||||
return TEModel.CLIP_L
|
return TEModel.CLIP_L
|
||||||
|
if "model.encoder.layers.0.mixer.Wqkv.weight" in sd:
|
||||||
|
return TEModel.JINA_CLIP_2
|
||||||
if "encoder.block.23.layer.1.DenseReluDense.wi_1.weight" in sd:
|
if "encoder.block.23.layer.1.DenseReluDense.wi_1.weight" in sd:
|
||||||
weight = sd["encoder.block.23.layer.1.DenseReluDense.wi_1.weight"]
|
weight = sd["encoder.block.23.layer.1.DenseReluDense.wi_1.weight"]
|
||||||
if weight.shape[-1] == 4096:
|
if weight.shape[-1] == 4096:
|
||||||
@ -1061,6 +1068,8 @@ def detect_te_model(sd):
|
|||||||
return TEModel.BYT5_SMALL_GLYPH
|
return TEModel.BYT5_SMALL_GLYPH
|
||||||
return TEModel.T5_BASE
|
return TEModel.T5_BASE
|
||||||
if 'model.layers.0.post_feedforward_layernorm.weight' in sd:
|
if 'model.layers.0.post_feedforward_layernorm.weight' in sd:
|
||||||
|
if 'model.layers.47.self_attn.q_norm.weight' in sd:
|
||||||
|
return TEModel.GEMMA_3_12B
|
||||||
if 'model.layers.0.self_attn.q_norm.weight' in sd:
|
if 'model.layers.0.self_attn.q_norm.weight' in sd:
|
||||||
return TEModel.GEMMA_3_4B
|
return TEModel.GEMMA_3_4B
|
||||||
return TEModel.GEMMA_2_2B
|
return TEModel.GEMMA_2_2B
|
||||||
@ -1207,6 +1216,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
|||||||
elif te_model == TEModel.QWEN3_2B:
|
elif te_model == TEModel.QWEN3_2B:
|
||||||
clip_target.clip = comfy.text_encoders.ovis.te(**llama_detect(clip_data))
|
clip_target.clip = comfy.text_encoders.ovis.te(**llama_detect(clip_data))
|
||||||
clip_target.tokenizer = comfy.text_encoders.ovis.OvisTokenizer
|
clip_target.tokenizer = comfy.text_encoders.ovis.OvisTokenizer
|
||||||
|
elif te_model == TEModel.JINA_CLIP_2:
|
||||||
|
clip_target.clip = comfy.text_encoders.jina_clip_2.JinaClip2TextModelWrapper
|
||||||
|
clip_target.tokenizer = comfy.text_encoders.jina_clip_2.JinaClip2TokenizerWrapper
|
||||||
else:
|
else:
|
||||||
# clip_l
|
# clip_l
|
||||||
if clip_type == CLIPType.SD3:
|
if clip_type == CLIPType.SD3:
|
||||||
@ -1262,6 +1274,21 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
|||||||
elif clip_type == CLIPType.KANDINSKY5_IMAGE:
|
elif clip_type == CLIPType.KANDINSKY5_IMAGE:
|
||||||
clip_target.clip = comfy.text_encoders.kandinsky5.te(**llama_detect(clip_data))
|
clip_target.clip = comfy.text_encoders.kandinsky5.te(**llama_detect(clip_data))
|
||||||
clip_target.tokenizer = comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage
|
clip_target.tokenizer = comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage
|
||||||
|
elif clip_type == CLIPType.LTXV:
|
||||||
|
clip_target.clip = comfy.text_encoders.lt.ltxav_te(**llama_detect(clip_data))
|
||||||
|
clip_target.tokenizer = comfy.text_encoders.lt.LTXAVGemmaTokenizer
|
||||||
|
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
|
||||||
|
elif clip_type == CLIPType.NEWBIE:
|
||||||
|
clip_target.clip = comfy.text_encoders.newbie.te(**llama_detect(clip_data))
|
||||||
|
clip_target.tokenizer = comfy.text_encoders.newbie.NewBieTokenizer
|
||||||
|
if "model.layers.0.self_attn.q_norm.weight" in clip_data[0]:
|
||||||
|
clip_data_gemma = clip_data[0]
|
||||||
|
clip_data_jina = clip_data[1]
|
||||||
|
else:
|
||||||
|
clip_data_gemma = clip_data[1]
|
||||||
|
clip_data_jina = clip_data[0]
|
||||||
|
tokenizer_data["gemma_spiece_model"] = clip_data_gemma.get("spiece_model", None)
|
||||||
|
tokenizer_data["jina_spiece_model"] = clip_data_jina.get("spiece_model", None)
|
||||||
else:
|
else:
|
||||||
clip_target.clip = sdxl_clip.SDXLClipModel
|
clip_target.clip = sdxl_clip.SDXLClipModel
|
||||||
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
|
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
|
||||||
|
|||||||
@ -466,7 +466,7 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
|
|||||||
return embed_out
|
return embed_out
|
||||||
|
|
||||||
class SDTokenizer:
|
class SDTokenizer:
|
||||||
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, min_padding=None, pad_left=False, tokenizer_data={}, tokenizer_args={}):
|
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, min_padding=None, pad_left=False, disable_weights=False, tokenizer_data={}, tokenizer_args={}):
|
||||||
if tokenizer_path is None:
|
if tokenizer_path is None:
|
||||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
|
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
|
||||||
self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path, **tokenizer_args)
|
self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path, **tokenizer_args)
|
||||||
@ -513,6 +513,8 @@ class SDTokenizer:
|
|||||||
self.embedding_size = embedding_size
|
self.embedding_size = embedding_size
|
||||||
self.embedding_key = embedding_key
|
self.embedding_key = embedding_key
|
||||||
|
|
||||||
|
self.disable_weights = disable_weights
|
||||||
|
|
||||||
def _try_get_embedding(self, embedding_name:str):
|
def _try_get_embedding(self, embedding_name:str):
|
||||||
'''
|
'''
|
||||||
Takes a potential embedding name and tries to retrieve it.
|
Takes a potential embedding name and tries to retrieve it.
|
||||||
@ -547,7 +549,7 @@ class SDTokenizer:
|
|||||||
min_padding = tokenizer_options.get("{}_min_padding".format(self.embedding_key), self.min_padding)
|
min_padding = tokenizer_options.get("{}_min_padding".format(self.embedding_key), self.min_padding)
|
||||||
|
|
||||||
text = escape_important(text)
|
text = escape_important(text)
|
||||||
if kwargs.get("disable_weights", False):
|
if kwargs.get("disable_weights", self.disable_weights):
|
||||||
parsed_weights = [(text, 1.0)]
|
parsed_weights = [(text, 1.0)]
|
||||||
else:
|
else:
|
||||||
parsed_weights = token_weights(text, 1.0)
|
parsed_weights = token_weights(text, 1.0)
|
||||||
|
|||||||
@ -836,6 +836,21 @@ class LTXV(supported_models_base.BASE):
|
|||||||
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
|
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
|
||||||
return supported_models_base.ClipTarget(comfy.text_encoders.lt.LTXVT5Tokenizer, comfy.text_encoders.lt.ltxv_te(**t5_detect))
|
return supported_models_base.ClipTarget(comfy.text_encoders.lt.LTXVT5Tokenizer, comfy.text_encoders.lt.ltxv_te(**t5_detect))
|
||||||
|
|
||||||
|
class LTXAV(LTXV):
|
||||||
|
unet_config = {
|
||||||
|
"image_model": "ltxav",
|
||||||
|
}
|
||||||
|
|
||||||
|
latent_format = latent_formats.LTXAV
|
||||||
|
|
||||||
|
def __init__(self, unet_config):
|
||||||
|
super().__init__(unet_config)
|
||||||
|
self.memory_usage_factor = 0.055 # TODO
|
||||||
|
|
||||||
|
def get_model(self, state_dict, prefix="", device=None):
|
||||||
|
out = model_base.LTXAV(self, device=device)
|
||||||
|
return out
|
||||||
|
|
||||||
class HunyuanVideo(supported_models_base.BASE):
|
class HunyuanVideo(supported_models_base.BASE):
|
||||||
unet_config = {
|
unet_config = {
|
||||||
"image_model": "hunyuan_video",
|
"image_model": "hunyuan_video",
|
||||||
@ -1536,6 +1551,6 @@ class Kandinsky5Image(Kandinsky5):
|
|||||||
return supported_models_base.ClipTarget(comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage, comfy.text_encoders.kandinsky5.te(**hunyuan_detect))
|
return supported_models_base.ClipTarget(comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage, comfy.text_encoders.kandinsky5.te(**hunyuan_detect))
|
||||||
|
|
||||||
|
|
||||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5]
|
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5]
|
||||||
|
|
||||||
models += [SVD_img2vid]
|
models += [SVD_img2vid]
|
||||||
|
|||||||
@ -154,7 +154,8 @@ class TAEHV(nn.Module):
|
|||||||
self._show_progress_bar = value
|
self._show_progress_bar = value
|
||||||
|
|
||||||
def encode(self, x, **kwargs):
|
def encode(self, x, **kwargs):
|
||||||
if self.patch_size > 1: x = F.pixel_unshuffle(x, self.patch_size)
|
if self.patch_size > 1:
|
||||||
|
x = F.pixel_unshuffle(x, self.patch_size)
|
||||||
x = x.movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W]
|
x = x.movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W]
|
||||||
if x.shape[1] % 4 != 0:
|
if x.shape[1] % 4 != 0:
|
||||||
# pad at end to multiple of 4
|
# pad at end to multiple of 4
|
||||||
@ -167,5 +168,6 @@ class TAEHV(nn.Module):
|
|||||||
def decode(self, x, **kwargs):
|
def decode(self, x, **kwargs):
|
||||||
x = self.process_in(x).movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W]
|
x = self.process_in(x).movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W]
|
||||||
x = apply_model_with_memblocks(self.decoder, x, self.parallel, self.show_progress_bar)
|
x = apply_model_with_memblocks(self.decoder, x, self.parallel, self.show_progress_bar)
|
||||||
if self.patch_size > 1: x = F.pixel_shuffle(x, self.patch_size)
|
if self.patch_size > 1:
|
||||||
|
x = F.pixel_shuffle(x, self.patch_size)
|
||||||
return x[:, self.frames_to_trim:].movedim(2, 1)
|
return x[:, self.frames_to_trim:].movedim(2, 1)
|
||||||
|
|||||||
219
comfy/text_encoders/jina_clip_2.py
Normal file
219
comfy/text_encoders/jina_clip_2.py
Normal file
@ -0,0 +1,219 @@
|
|||||||
|
# Jina CLIP v2 and Jina Embeddings v3 both use their modified XLM-RoBERTa architecture. Reference implementation:
|
||||||
|
# Jina CLIP v2 (both text and vision): https://huggingface.co/jinaai/jina-clip-implementation/blob/39e6a55ae971b59bea6e44675d237c99762e7ee2/modeling_clip.py
|
||||||
|
# Jina XLM-RoBERTa (text only): http://huggingface.co/jinaai/xlm-roberta-flash-implementation/blob/2b6bc3f30750b3a9648fe9b63448c09920efe9be/modeling_xlm_roberta.py
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn as nn
|
||||||
|
from torch.nn import functional as F
|
||||||
|
|
||||||
|
import comfy.model_management
|
||||||
|
import comfy.ops
|
||||||
|
from comfy import sd1_clip
|
||||||
|
from .spiece_tokenizer import SPieceTokenizer
|
||||||
|
|
||||||
|
class JinaClip2Tokenizer(sd1_clip.SDTokenizer):
|
||||||
|
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||||
|
tokenizer = tokenizer_data.get("spiece_model", None)
|
||||||
|
# The official NewBie uses max_length=8000, but Jina Embeddings v3 actually supports 8192
|
||||||
|
super().__init__(tokenizer, pad_with_end=False, embedding_size=1024, embedding_key='jina_clip_2', tokenizer_class=SPieceTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=False, max_length=8192, min_length=1, pad_token=1, end_token=2, tokenizer_args={"add_bos": True, "add_eos": True}, tokenizer_data=tokenizer_data)
|
||||||
|
|
||||||
|
def state_dict(self):
|
||||||
|
return {"spiece_model": self.tokenizer.serialize_model()}
|
||||||
|
|
||||||
|
class JinaClip2TokenizerWrapper(sd1_clip.SD1Tokenizer):
|
||||||
|
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||||
|
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, tokenizer=JinaClip2Tokenizer, name="jina_clip_2")
|
||||||
|
|
||||||
|
# https://huggingface.co/jinaai/jina-embeddings-v3/blob/343dbf534c76fe845f304fa5c2d1fd87e1e78918/config.json
|
||||||
|
@dataclass
|
||||||
|
class XLMRobertaConfig:
|
||||||
|
vocab_size: int = 250002
|
||||||
|
type_vocab_size: int = 1
|
||||||
|
hidden_size: int = 1024
|
||||||
|
num_hidden_layers: int = 24
|
||||||
|
num_attention_heads: int = 16
|
||||||
|
rotary_emb_base: float = 20000.0
|
||||||
|
intermediate_size: int = 4096
|
||||||
|
hidden_act: str = "gelu"
|
||||||
|
hidden_dropout_prob: float = 0.1
|
||||||
|
attention_probs_dropout_prob: float = 0.1
|
||||||
|
layer_norm_eps: float = 1e-05
|
||||||
|
bos_token_id: int = 0
|
||||||
|
eos_token_id: int = 2
|
||||||
|
pad_token_id: int = 1
|
||||||
|
|
||||||
|
class XLMRobertaEmbeddings(nn.Module):
|
||||||
|
def __init__(self, config, device=None, dtype=None, ops=None):
|
||||||
|
super().__init__()
|
||||||
|
embed_dim = config.hidden_size
|
||||||
|
self.word_embeddings = ops.Embedding(config.vocab_size, embed_dim, padding_idx=config.pad_token_id, device=device, dtype=dtype)
|
||||||
|
self.token_type_embeddings = ops.Embedding(config.type_vocab_size, embed_dim, device=device, dtype=dtype)
|
||||||
|
|
||||||
|
def forward(self, input_ids=None, embeddings=None):
|
||||||
|
if input_ids is not None and embeddings is None:
|
||||||
|
embeddings = self.word_embeddings(input_ids)
|
||||||
|
|
||||||
|
if embeddings is not None:
|
||||||
|
token_type_ids = torch.zeros(embeddings.shape[1], device=embeddings.device, dtype=torch.int32)
|
||||||
|
token_type_embeddings = self.token_type_embeddings(token_type_ids)
|
||||||
|
embeddings = embeddings + token_type_embeddings
|
||||||
|
return embeddings
|
||||||
|
|
||||||
|
class RotaryEmbedding(nn.Module):
|
||||||
|
def __init__(self, dim, base, device=None):
|
||||||
|
super().__init__()
|
||||||
|
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim))
|
||||||
|
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||||
|
self._seq_len_cached = 0
|
||||||
|
self._cos_cached = None
|
||||||
|
self._sin_cached = None
|
||||||
|
|
||||||
|
def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
|
||||||
|
if seqlen > self._seq_len_cached or self._cos_cached is None or self._cos_cached.device != device or self._cos_cached.dtype != dtype:
|
||||||
|
self._seq_len_cached = seqlen
|
||||||
|
t = torch.arange(seqlen, device=device, dtype=torch.float32)
|
||||||
|
freqs = torch.outer(t, self.inv_freq.to(device=t.device))
|
||||||
|
emb = torch.cat((freqs, freqs), dim=-1)
|
||||||
|
self._cos_cached = emb.cos().to(dtype)
|
||||||
|
self._sin_cached = emb.sin().to(dtype)
|
||||||
|
|
||||||
|
def forward(self, q, k):
|
||||||
|
batch, seqlen, heads, head_dim = q.shape
|
||||||
|
self._update_cos_sin_cache(seqlen, device=q.device, dtype=q.dtype)
|
||||||
|
|
||||||
|
cos = self._cos_cached[:seqlen].view(1, seqlen, 1, head_dim)
|
||||||
|
sin = self._sin_cached[:seqlen].view(1, seqlen, 1, head_dim)
|
||||||
|
|
||||||
|
def rotate_half(x):
|
||||||
|
size = x.shape[-1] // 2
|
||||||
|
x1, x2 = x[..., :size], x[..., size:]
|
||||||
|
return torch.cat((-x2, x1), dim=-1)
|
||||||
|
|
||||||
|
q_embed = (q * cos) + (rotate_half(q) * sin)
|
||||||
|
k_embed = (k * cos) + (rotate_half(k) * sin)
|
||||||
|
return q_embed, k_embed
|
||||||
|
|
||||||
|
class MHA(nn.Module):
|
||||||
|
def __init__(self, config, device=None, dtype=None, ops=None):
|
||||||
|
super().__init__()
|
||||||
|
embed_dim = config.hidden_size
|
||||||
|
self.num_heads = config.num_attention_heads
|
||||||
|
self.head_dim = embed_dim // config.num_attention_heads
|
||||||
|
|
||||||
|
self.rotary_emb = RotaryEmbedding(self.head_dim, config.rotary_emb_base, device=device)
|
||||||
|
self.Wqkv = ops.Linear(embed_dim, 3 * embed_dim, device=device, dtype=dtype)
|
||||||
|
self.out_proj = ops.Linear(embed_dim, embed_dim, device=device, dtype=dtype)
|
||||||
|
|
||||||
|
def forward(self, x, mask=None, optimized_attention=None):
|
||||||
|
qkv = self.Wqkv(x)
|
||||||
|
batch_size, seq_len, _ = qkv.shape
|
||||||
|
qkv = qkv.view(batch_size, seq_len, 3, self.num_heads, self.head_dim)
|
||||||
|
q, k, v = qkv.unbind(2)
|
||||||
|
|
||||||
|
q, k = self.rotary_emb(q, k)
|
||||||
|
|
||||||
|
# NHD -> HND
|
||||||
|
q = q.transpose(1, 2)
|
||||||
|
k = k.transpose(1, 2)
|
||||||
|
v = v.transpose(1, 2)
|
||||||
|
|
||||||
|
out = optimized_attention(q, k, v, heads=self.num_heads, mask=mask, skip_reshape=True)
|
||||||
|
return self.out_proj(out)
|
||||||
|
|
||||||
|
class MLP(nn.Module):
|
||||||
|
def __init__(self, config, device=None, dtype=None, ops=None):
|
||||||
|
super().__init__()
|
||||||
|
self.fc1 = ops.Linear(config.hidden_size, config.intermediate_size, device=device, dtype=dtype)
|
||||||
|
self.activation = F.gelu
|
||||||
|
self.fc2 = ops.Linear(config.intermediate_size, config.hidden_size, device=device, dtype=dtype)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.fc1(x)
|
||||||
|
x = self.activation(x)
|
||||||
|
x = self.fc2(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
class Block(nn.Module):
|
||||||
|
def __init__(self, config, device=None, dtype=None, ops=None):
|
||||||
|
super().__init__()
|
||||||
|
self.mixer = MHA(config, device=device, dtype=dtype, ops=ops)
|
||||||
|
self.dropout1 = nn.Dropout(config.hidden_dropout_prob)
|
||||||
|
self.norm1 = ops.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, device=device, dtype=dtype)
|
||||||
|
self.mlp = MLP(config, device=device, dtype=dtype, ops=ops)
|
||||||
|
self.dropout2 = nn.Dropout(config.hidden_dropout_prob)
|
||||||
|
self.norm2 = ops.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, device=device, dtype=dtype)
|
||||||
|
|
||||||
|
def forward(self, hidden_states, mask=None, optimized_attention=None):
|
||||||
|
mixer_out = self.mixer(hidden_states, mask=mask, optimized_attention=optimized_attention)
|
||||||
|
hidden_states = self.norm1(self.dropout1(mixer_out) + hidden_states)
|
||||||
|
mlp_out = self.mlp(hidden_states)
|
||||||
|
hidden_states = self.norm2(self.dropout2(mlp_out) + hidden_states)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
class XLMRobertaEncoder(nn.Module):
|
||||||
|
def __init__(self, config, device=None, dtype=None, ops=None):
|
||||||
|
super().__init__()
|
||||||
|
self.layers = nn.ModuleList([Block(config, device=device, dtype=dtype, ops=ops) for _ in range(config.num_hidden_layers)])
|
||||||
|
|
||||||
|
def forward(self, hidden_states, attention_mask=None):
|
||||||
|
optimized_attention = comfy.ldm.modules.attention.optimized_attention_for_device(hidden_states.device, mask=attention_mask is not None, small_input=True)
|
||||||
|
for layer in self.layers:
|
||||||
|
hidden_states = layer(hidden_states, mask=attention_mask, optimized_attention=optimized_attention)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
class XLMRobertaModel_(nn.Module):
|
||||||
|
def __init__(self, config, device=None, dtype=None, ops=None):
|
||||||
|
super().__init__()
|
||||||
|
self.embeddings = XLMRobertaEmbeddings(config, device=device, dtype=dtype, ops=ops)
|
||||||
|
self.emb_ln = ops.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, device=device, dtype=dtype)
|
||||||
|
self.emb_drop = nn.Dropout(config.hidden_dropout_prob)
|
||||||
|
self.encoder = XLMRobertaEncoder(config, device=device, dtype=dtype, ops=ops)
|
||||||
|
|
||||||
|
def forward(self, input_ids, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, embeds_info=[]):
|
||||||
|
x = self.embeddings(input_ids=input_ids, embeddings=embeds)
|
||||||
|
x = self.emb_ln(x)
|
||||||
|
x = self.emb_drop(x)
|
||||||
|
|
||||||
|
mask = None
|
||||||
|
if attention_mask is not None:
|
||||||
|
mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, 1, attention_mask.shape[-1]))
|
||||||
|
mask = mask.masked_fill(mask.to(torch.bool), -torch.finfo(x.dtype).max)
|
||||||
|
|
||||||
|
sequence_output = self.encoder(x, attention_mask=mask)
|
||||||
|
|
||||||
|
# Mean pool, see https://huggingface.co/jinaai/jina-clip-implementation/blob/39e6a55ae971b59bea6e44675d237c99762e7ee2/hf_model.py
|
||||||
|
pooled_output = None
|
||||||
|
if attention_mask is None:
|
||||||
|
pooled_output = sequence_output.mean(dim=1)
|
||||||
|
else:
|
||||||
|
attention_mask = attention_mask.to(sequence_output.dtype)
|
||||||
|
pooled_output = (sequence_output * attention_mask.unsqueeze(-1)).sum(dim=1) / attention_mask.sum(dim=-1, keepdim=True)
|
||||||
|
|
||||||
|
# Intermediate output is not yet implemented, use None for placeholder
|
||||||
|
return sequence_output, None, pooled_output
|
||||||
|
|
||||||
|
class XLMRobertaModel(nn.Module):
|
||||||
|
def __init__(self, config_dict, dtype, device, operations):
|
||||||
|
super().__init__()
|
||||||
|
self.config = XLMRobertaConfig(**config_dict)
|
||||||
|
self.model = XLMRobertaModel_(self.config, device=device, dtype=dtype, ops=operations)
|
||||||
|
self.num_layers = self.config.num_hidden_layers
|
||||||
|
|
||||||
|
def get_input_embeddings(self):
|
||||||
|
return self.model.embeddings.word_embeddings
|
||||||
|
|
||||||
|
def set_input_embeddings(self, embeddings):
|
||||||
|
self.model.embeddings.word_embeddings = embeddings
|
||||||
|
|
||||||
|
def forward(self, *args, **kwargs):
|
||||||
|
return self.model(*args, **kwargs)
|
||||||
|
|
||||||
|
class JinaClip2TextModel(sd1_clip.SDClipModel):
|
||||||
|
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||||
|
super().__init__(device=device, dtype=dtype, textmodel_json_config={}, model_class=XLMRobertaModel, special_tokens={"start": 0, "end": 2, "pad": 1}, enable_attention_masks=True, return_attention_masks=True, model_options=model_options)
|
||||||
|
|
||||||
|
class JinaClip2TextModelWrapper(sd1_clip.SD1ClipModel):
|
||||||
|
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||||
|
super().__init__(device=device, dtype=dtype, clip_model=JinaClip2TextModel, name="jina_clip_2", model_options=model_options)
|
||||||
@ -3,13 +3,12 @@ import torch.nn as nn
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Optional, Any
|
from typing import Optional, Any
|
||||||
import math
|
import math
|
||||||
import logging
|
|
||||||
|
|
||||||
from comfy.ldm.modules.attention import optimized_attention_for_device
|
from comfy.ldm.modules.attention import optimized_attention_for_device
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
import comfy.ldm.common_dit
|
import comfy.ldm.common_dit
|
||||||
|
import comfy.clip_model
|
||||||
|
|
||||||
import comfy.model_management
|
|
||||||
from . import qwen_vl
|
from . import qwen_vl
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -177,7 +176,7 @@ class Gemma3_4B_Config:
|
|||||||
num_key_value_heads: int = 4
|
num_key_value_heads: int = 4
|
||||||
max_position_embeddings: int = 131072
|
max_position_embeddings: int = 131072
|
||||||
rms_norm_eps: float = 1e-6
|
rms_norm_eps: float = 1e-6
|
||||||
rope_theta = [10000.0, 1000000.0]
|
rope_theta = [1000000.0, 10000.0]
|
||||||
transformer_type: str = "gemma3"
|
transformer_type: str = "gemma3"
|
||||||
head_dim = 256
|
head_dim = 256
|
||||||
rms_norm_add = True
|
rms_norm_add = True
|
||||||
@ -186,10 +185,35 @@ class Gemma3_4B_Config:
|
|||||||
rope_dims = None
|
rope_dims = None
|
||||||
q_norm = "gemma3"
|
q_norm = "gemma3"
|
||||||
k_norm = "gemma3"
|
k_norm = "gemma3"
|
||||||
sliding_attention = [False, False, False, False, False, 1024]
|
sliding_attention = [1024, 1024, 1024, 1024, 1024, False]
|
||||||
rope_scale = [1.0, 8.0]
|
rope_scale = [8.0, 1.0]
|
||||||
final_norm: bool = True
|
final_norm: bool = True
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Gemma3_12B_Config:
|
||||||
|
vocab_size: int = 262208
|
||||||
|
hidden_size: int = 3840
|
||||||
|
intermediate_size: int = 15360
|
||||||
|
num_hidden_layers: int = 48
|
||||||
|
num_attention_heads: int = 16
|
||||||
|
num_key_value_heads: int = 8
|
||||||
|
max_position_embeddings: int = 131072
|
||||||
|
rms_norm_eps: float = 1e-6
|
||||||
|
rope_theta = [1000000.0, 10000.0]
|
||||||
|
transformer_type: str = "gemma3"
|
||||||
|
head_dim = 256
|
||||||
|
rms_norm_add = True
|
||||||
|
mlp_activation = "gelu_pytorch_tanh"
|
||||||
|
qkv_bias = False
|
||||||
|
rope_dims = None
|
||||||
|
q_norm = "gemma3"
|
||||||
|
k_norm = "gemma3"
|
||||||
|
sliding_attention = [1024, 1024, 1024, 1024, 1024, False]
|
||||||
|
rope_scale = [8.0, 1.0]
|
||||||
|
final_norm: bool = True
|
||||||
|
vision_config = {"num_channels": 3, "hidden_act": "gelu_pytorch_tanh", "hidden_size": 1152, "image_size": 896, "intermediate_size": 4304, "model_type": "siglip_vision_model", "num_attention_heads": 16, "num_hidden_layers": 27, "patch_size": 14}
|
||||||
|
mm_tokens_per_image = 256
|
||||||
|
|
||||||
class RMSNorm(nn.Module):
|
class RMSNorm(nn.Module):
|
||||||
def __init__(self, dim: int, eps: float = 1e-5, add=False, device=None, dtype=None):
|
def __init__(self, dim: int, eps: float = 1e-5, add=False, device=None, dtype=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -370,7 +394,7 @@ class TransformerBlockGemma2(nn.Module):
|
|||||||
self.pre_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
|
self.pre_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
|
||||||
self.post_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
|
self.post_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
|
||||||
|
|
||||||
if config.sliding_attention is not None: # TODO: implement. (Not that necessary since models are trained on less than 1024 tokens)
|
if config.sliding_attention is not None:
|
||||||
self.sliding_attention = config.sliding_attention[index % len(config.sliding_attention)]
|
self.sliding_attention = config.sliding_attention[index % len(config.sliding_attention)]
|
||||||
else:
|
else:
|
||||||
self.sliding_attention = False
|
self.sliding_attention = False
|
||||||
@ -387,7 +411,12 @@ class TransformerBlockGemma2(nn.Module):
|
|||||||
if self.transformer_type == 'gemma3':
|
if self.transformer_type == 'gemma3':
|
||||||
if self.sliding_attention:
|
if self.sliding_attention:
|
||||||
if x.shape[1] > self.sliding_attention:
|
if x.shape[1] > self.sliding_attention:
|
||||||
logging.warning("Warning: sliding attention not implemented, results may be incorrect")
|
sliding_mask = torch.full((x.shape[1], x.shape[1]), float("-inf"), device=x.device, dtype=x.dtype)
|
||||||
|
sliding_mask.tril_(diagonal=-self.sliding_attention)
|
||||||
|
if attention_mask is not None:
|
||||||
|
attention_mask = attention_mask + sliding_mask
|
||||||
|
else:
|
||||||
|
attention_mask = sliding_mask
|
||||||
freqs_cis = freqs_cis[1]
|
freqs_cis = freqs_cis[1]
|
||||||
else:
|
else:
|
||||||
freqs_cis = freqs_cis[0]
|
freqs_cis = freqs_cis[0]
|
||||||
@ -517,6 +546,41 @@ class Llama2_(nn.Module):
|
|||||||
|
|
||||||
return x, intermediate
|
return x, intermediate
|
||||||
|
|
||||||
|
|
||||||
|
class Gemma3MultiModalProjector(torch.nn.Module):
|
||||||
|
def __init__(self, config, dtype, device, operations):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.mm_input_projection_weight = nn.Parameter(
|
||||||
|
torch.empty(config.vision_config["hidden_size"], config.hidden_size, device=device, dtype=dtype)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.mm_soft_emb_norm = RMSNorm(config.vision_config["hidden_size"], eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
|
||||||
|
|
||||||
|
self.patches_per_image = int(config.vision_config["image_size"] // config.vision_config["patch_size"])
|
||||||
|
self.tokens_per_side = int(config.mm_tokens_per_image**0.5)
|
||||||
|
self.kernel_size = self.patches_per_image // self.tokens_per_side
|
||||||
|
self.avg_pool = nn.AvgPool2d(kernel_size=self.kernel_size, stride=self.kernel_size)
|
||||||
|
|
||||||
|
def forward(self, vision_outputs: torch.Tensor):
|
||||||
|
batch_size, _, seq_length = vision_outputs.shape
|
||||||
|
|
||||||
|
reshaped_vision_outputs = vision_outputs.transpose(1, 2)
|
||||||
|
reshaped_vision_outputs = reshaped_vision_outputs.reshape(
|
||||||
|
batch_size, seq_length, self.patches_per_image, self.patches_per_image
|
||||||
|
)
|
||||||
|
reshaped_vision_outputs = reshaped_vision_outputs.contiguous()
|
||||||
|
|
||||||
|
pooled_vision_outputs = self.avg_pool(reshaped_vision_outputs)
|
||||||
|
pooled_vision_outputs = pooled_vision_outputs.flatten(2)
|
||||||
|
pooled_vision_outputs = pooled_vision_outputs.transpose(1, 2)
|
||||||
|
|
||||||
|
normed_vision_outputs = self.mm_soft_emb_norm(pooled_vision_outputs)
|
||||||
|
|
||||||
|
projected_vision_outputs = torch.matmul(normed_vision_outputs, comfy.model_management.cast_to_device(self.mm_input_projection_weight, device=normed_vision_outputs.device, dtype=normed_vision_outputs.dtype))
|
||||||
|
return projected_vision_outputs.type_as(vision_outputs)
|
||||||
|
|
||||||
|
|
||||||
class BaseLlama:
|
class BaseLlama:
|
||||||
def get_input_embeddings(self):
|
def get_input_embeddings(self):
|
||||||
return self.model.embed_tokens
|
return self.model.embed_tokens
|
||||||
@ -633,3 +697,21 @@ class Gemma3_4B(BaseLlama, torch.nn.Module):
|
|||||||
|
|
||||||
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
|
|
||||||
|
class Gemma3_12B(BaseLlama, torch.nn.Module):
|
||||||
|
def __init__(self, config_dict, dtype, device, operations):
|
||||||
|
super().__init__()
|
||||||
|
config = Gemma3_12B_Config(**config_dict)
|
||||||
|
self.num_layers = config.num_hidden_layers
|
||||||
|
|
||||||
|
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
||||||
|
self.multi_modal_projector = Gemma3MultiModalProjector(config, dtype, device, operations)
|
||||||
|
self.vision_model = comfy.clip_model.CLIPVision(config.vision_config, dtype, device, operations)
|
||||||
|
self.dtype = dtype
|
||||||
|
self.image_size = config.vision_config["image_size"]
|
||||||
|
|
||||||
|
def preprocess_embed(self, embed, device):
|
||||||
|
if embed["type"] == "image":
|
||||||
|
image = comfy.clip_model.clip_preprocess(embed["data"], size=self.image_size, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], crop=True)
|
||||||
|
return self.multi_modal_projector(self.vision_model(image.to(device, dtype=torch.float32))[0]), None
|
||||||
|
return None, None
|
||||||
|
|||||||
@ -1,7 +1,11 @@
|
|||||||
from comfy import sd1_clip
|
from comfy import sd1_clip
|
||||||
import os
|
import os
|
||||||
from transformers import T5TokenizerFast
|
from transformers import T5TokenizerFast
|
||||||
|
from .spiece_tokenizer import SPieceTokenizer
|
||||||
import comfy.text_encoders.genmo
|
import comfy.text_encoders.genmo
|
||||||
|
from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector
|
||||||
|
import torch
|
||||||
|
import comfy.utils
|
||||||
|
|
||||||
class T5XXLTokenizer(sd1_clip.SDTokenizer):
|
class T5XXLTokenizer(sd1_clip.SDTokenizer):
|
||||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||||
@ -16,3 +20,110 @@ class LTXVT5Tokenizer(sd1_clip.SD1Tokenizer):
|
|||||||
|
|
||||||
def ltxv_te(*args, **kwargs):
|
def ltxv_te(*args, **kwargs):
|
||||||
return comfy.text_encoders.genmo.mochi_te(*args, **kwargs)
|
return comfy.text_encoders.genmo.mochi_te(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class Gemma3_12BTokenizer(sd1_clip.SDTokenizer):
|
||||||
|
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||||
|
tokenizer = tokenizer_data.get("spiece_model", None)
|
||||||
|
super().__init__(tokenizer, pad_with_end=False, embedding_size=3840, embedding_key='gemma3_12b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data)
|
||||||
|
|
||||||
|
def state_dict(self):
|
||||||
|
return {"spiece_model": self.tokenizer.serialize_model()}
|
||||||
|
|
||||||
|
class LTXAVGemmaTokenizer(sd1_clip.SD1Tokenizer):
|
||||||
|
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||||
|
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="gemma3_12b", tokenizer=Gemma3_12BTokenizer)
|
||||||
|
|
||||||
|
class Gemma3_12BModel(sd1_clip.SDClipModel):
|
||||||
|
def __init__(self, device="cpu", layer="all", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
|
||||||
|
llama_scaled_fp8 = model_options.get("gemma_scaled_fp8", None)
|
||||||
|
if llama_scaled_fp8 is not None:
|
||||||
|
model_options = model_options.copy()
|
||||||
|
model_options["scaled_fp8"] = llama_scaled_fp8
|
||||||
|
|
||||||
|
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_12B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
|
||||||
|
|
||||||
|
def tokenize_with_weights(self, text, return_word_ids=False, llama_template="{}", image_embeds=None, **kwargs):
|
||||||
|
text = llama_template.format(text)
|
||||||
|
text_tokens = super().tokenize_with_weights(text, return_word_ids)
|
||||||
|
embed_count = 0
|
||||||
|
for k in text_tokens:
|
||||||
|
tt = text_tokens[k]
|
||||||
|
for r in tt:
|
||||||
|
for i in range(len(r)):
|
||||||
|
if r[i][0] == 262144:
|
||||||
|
if image_embeds is not None and embed_count < image_embeds.shape[0]:
|
||||||
|
r[i] = ({"type": "embedding", "data": image_embeds[embed_count], "original_type": "image"},) + r[i][1:]
|
||||||
|
embed_count += 1
|
||||||
|
return text_tokens
|
||||||
|
|
||||||
|
class LTXAVTEModel(torch.nn.Module):
|
||||||
|
def __init__(self, dtype_llama=None, device="cpu", dtype=None, model_options={}):
|
||||||
|
super().__init__()
|
||||||
|
self.dtypes = set()
|
||||||
|
self.dtypes.add(dtype)
|
||||||
|
|
||||||
|
self.gemma3_12b = Gemma3_12BModel(device=device, dtype=dtype_llama, model_options=model_options, layer="all", layer_idx=None)
|
||||||
|
self.dtypes.add(dtype_llama)
|
||||||
|
|
||||||
|
operations = self.gemma3_12b.operations # TODO
|
||||||
|
self.text_embedding_projection = operations.Linear(3840 * 49, 3840, bias=False, dtype=dtype, device=device)
|
||||||
|
|
||||||
|
self.audio_embeddings_connector = Embeddings1DConnector(
|
||||||
|
split_rope=True,
|
||||||
|
double_precision_rope=True,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
operations=operations,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.video_embeddings_connector = Embeddings1DConnector(
|
||||||
|
split_rope=True,
|
||||||
|
double_precision_rope=True,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
operations=operations,
|
||||||
|
)
|
||||||
|
|
||||||
|
def set_clip_options(self, options):
|
||||||
|
self.gemma3_12b.set_clip_options(options)
|
||||||
|
|
||||||
|
def reset_clip_options(self):
|
||||||
|
self.gemma3_12b.reset_clip_options()
|
||||||
|
|
||||||
|
def encode_token_weights(self, token_weight_pairs):
|
||||||
|
token_weight_pairs = token_weight_pairs["gemma3_12b"]
|
||||||
|
|
||||||
|
out, pooled, extra = self.gemma3_12b.encode_token_weights(token_weight_pairs)
|
||||||
|
out_device = out.device
|
||||||
|
out = out.movedim(1, -1).to(self.text_embedding_projection.weight.device)
|
||||||
|
out = 8.0 * (out - out.mean(dim=(1, 2), keepdim=True)) / (out.amax(dim=(1, 2), keepdim=True) - out.amin(dim=(1, 2), keepdim=True) + 1e-6)
|
||||||
|
out = out.reshape((out.shape[0], out.shape[1], -1))
|
||||||
|
out = self.text_embedding_projection(out)
|
||||||
|
out_vid = self.video_embeddings_connector(out)[0]
|
||||||
|
out_audio = self.audio_embeddings_connector(out)[0]
|
||||||
|
out = torch.concat((out_vid, out_audio), dim=-1)
|
||||||
|
|
||||||
|
return out.to(out_device), pooled
|
||||||
|
|
||||||
|
def load_sd(self, sd):
|
||||||
|
if "model.layers.47.self_attn.q_norm.weight" in sd:
|
||||||
|
return self.gemma3_12b.load_sd(sd)
|
||||||
|
else:
|
||||||
|
sdo = comfy.utils.state_dict_prefix_replace(sd, {"text_embedding_projection.aggregate_embed.weight": "text_embedding_projection.weight", "model.diffusion_model.video_embeddings_connector.": "video_embeddings_connector.", "model.diffusion_model.audio_embeddings_connector.": "audio_embeddings_connector."}, filter_keys=True)
|
||||||
|
if len(sdo) == 0:
|
||||||
|
sdo = sd
|
||||||
|
|
||||||
|
return self.load_state_dict(sdo, strict=False)
|
||||||
|
|
||||||
|
|
||||||
|
def ltxav_te(dtype_llama=None, llama_scaled_fp8=None):
|
||||||
|
class LTXAVTEModel_(LTXAVTEModel):
|
||||||
|
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||||
|
if llama_scaled_fp8 is not None and "llama_scaled_fp8" not in model_options:
|
||||||
|
model_options = model_options.copy()
|
||||||
|
model_options["llama_scaled_fp8"] = llama_scaled_fp8
|
||||||
|
if dtype_llama is not None:
|
||||||
|
dtype = dtype_llama
|
||||||
|
super().__init__(dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options)
|
||||||
|
return LTXAVTEModel_
|
||||||
|
|||||||
@ -14,7 +14,7 @@ class Gemma2BTokenizer(sd1_clip.SDTokenizer):
|
|||||||
class Gemma3_4BTokenizer(sd1_clip.SDTokenizer):
|
class Gemma3_4BTokenizer(sd1_clip.SDTokenizer):
|
||||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||||
tokenizer = tokenizer_data.get("spiece_model", None)
|
tokenizer = tokenizer_data.get("spiece_model", None)
|
||||||
super().__init__(tokenizer, pad_with_end=False, embedding_size=2560, embedding_key='gemma3_4b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data)
|
super().__init__(tokenizer, pad_with_end=False, embedding_size=2560, embedding_key='gemma3_4b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}, disable_weights=True, tokenizer_data=tokenizer_data)
|
||||||
|
|
||||||
def state_dict(self):
|
def state_dict(self):
|
||||||
return {"spiece_model": self.tokenizer.serialize_model()}
|
return {"spiece_model": self.tokenizer.serialize_model()}
|
||||||
@ -33,6 +33,11 @@ class Gemma2_2BModel(sd1_clip.SDClipModel):
|
|||||||
|
|
||||||
class Gemma3_4BModel(sd1_clip.SDClipModel):
|
class Gemma3_4BModel(sd1_clip.SDClipModel):
|
||||||
def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}):
|
def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}):
|
||||||
|
llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
|
||||||
|
if llama_quantization_metadata is not None:
|
||||||
|
model_options = model_options.copy()
|
||||||
|
model_options["quantization_metadata"] = llama_quantization_metadata
|
||||||
|
|
||||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_4B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
|
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_4B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
|
||||||
|
|
||||||
class LuminaModel(sd1_clip.SD1ClipModel):
|
class LuminaModel(sd1_clip.SD1ClipModel):
|
||||||
|
|||||||
62
comfy/text_encoders/newbie.py
Normal file
62
comfy/text_encoders/newbie.py
Normal file
@ -0,0 +1,62 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
import comfy.model_management
|
||||||
|
import comfy.text_encoders.jina_clip_2
|
||||||
|
import comfy.text_encoders.lumina2
|
||||||
|
|
||||||
|
class NewBieTokenizer:
|
||||||
|
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||||
|
self.gemma = comfy.text_encoders.lumina2.Gemma3_4BTokenizer(embedding_directory=embedding_directory, tokenizer_data={"spiece_model": tokenizer_data["gemma_spiece_model"]})
|
||||||
|
self.jina = comfy.text_encoders.jina_clip_2.JinaClip2Tokenizer(embedding_directory=embedding_directory, tokenizer_data={"spiece_model": tokenizer_data["jina_spiece_model"]})
|
||||||
|
|
||||||
|
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
|
||||||
|
out = {}
|
||||||
|
out["gemma"] = self.gemma.tokenize_with_weights(text, return_word_ids, **kwargs)
|
||||||
|
out["jina"] = self.jina.tokenize_with_weights(text, return_word_ids, **kwargs)
|
||||||
|
return out
|
||||||
|
|
||||||
|
def untokenize(self, token_weight_pair):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def state_dict(self):
|
||||||
|
return {}
|
||||||
|
|
||||||
|
class NewBieTEModel(torch.nn.Module):
|
||||||
|
def __init__(self, dtype_gemma=None, device="cpu", dtype=None, model_options={}):
|
||||||
|
super().__init__()
|
||||||
|
dtype_gemma = comfy.model_management.pick_weight_dtype(dtype_gemma, dtype, device)
|
||||||
|
self.gemma = comfy.text_encoders.lumina2.Gemma3_4BModel(device=device, dtype=dtype_gemma, model_options=model_options)
|
||||||
|
self.jina = comfy.text_encoders.jina_clip_2.JinaClip2TextModel(device=device, dtype=dtype, model_options=model_options)
|
||||||
|
self.dtypes = {dtype, dtype_gemma}
|
||||||
|
|
||||||
|
def set_clip_options(self, options):
|
||||||
|
self.gemma.set_clip_options(options)
|
||||||
|
self.jina.set_clip_options(options)
|
||||||
|
|
||||||
|
def reset_clip_options(self):
|
||||||
|
self.gemma.reset_clip_options()
|
||||||
|
self.jina.reset_clip_options()
|
||||||
|
|
||||||
|
def encode_token_weights(self, token_weight_pairs):
|
||||||
|
token_weight_pairs_gemma = token_weight_pairs["gemma"]
|
||||||
|
token_weight_pairs_jina = token_weight_pairs["jina"]
|
||||||
|
|
||||||
|
gemma_out, gemma_pooled, gemma_extra = self.gemma.encode_token_weights(token_weight_pairs_gemma)
|
||||||
|
jina_out, jina_pooled, jina_extra = self.jina.encode_token_weights(token_weight_pairs_jina)
|
||||||
|
|
||||||
|
return gemma_out, jina_pooled, gemma_extra
|
||||||
|
|
||||||
|
def load_sd(self, sd):
|
||||||
|
if "model.layers.0.self_attn.q_norm.weight" in sd:
|
||||||
|
return self.gemma.load_sd(sd)
|
||||||
|
else:
|
||||||
|
return self.jina.load_sd(sd)
|
||||||
|
|
||||||
|
def te(dtype_llama=None, llama_quantization_metadata=None):
|
||||||
|
class NewBieTEModel_(NewBieTEModel):
|
||||||
|
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||||
|
if llama_quantization_metadata is not None:
|
||||||
|
model_options = model_options.copy()
|
||||||
|
model_options["llama_quantization_metadata"] = llama_quantization_metadata
|
||||||
|
super().__init__(dtype_gemma=dtype_llama, device=device, dtype=dtype, model_options=model_options)
|
||||||
|
return NewBieTEModel_
|
||||||
@ -1198,7 +1198,7 @@ def unpack_latents(combined_latent, latent_shapes):
|
|||||||
combined_latent = combined_latent[:, :, cut:]
|
combined_latent = combined_latent[:, :, cut:]
|
||||||
output_tensors.append(tens.reshape([tens.shape[0]] + list(shape)[1:]))
|
output_tensors.append(tens.reshape([tens.shape[0]] + list(shape)[1:]))
|
||||||
else:
|
else:
|
||||||
output_tensors = combined_latent
|
output_tensors = [combined_latent]
|
||||||
return output_tensors
|
return output_tensors
|
||||||
|
|
||||||
def detect_layer_quantization(state_dict, prefix):
|
def detect_layer_quantization(state_dict, prefix):
|
||||||
@ -1230,6 +1230,8 @@ def convert_old_quants(state_dict, model_prefix="", metadata={}):
|
|||||||
out_sd = {}
|
out_sd = {}
|
||||||
layers = {}
|
layers = {}
|
||||||
for k in list(state_dict.keys()):
|
for k in list(state_dict.keys()):
|
||||||
|
if k == scaled_fp8_key:
|
||||||
|
continue
|
||||||
if not k.startswith(model_prefix):
|
if not k.startswith(model_prefix):
|
||||||
out_sd[k] = state_dict[k]
|
out_sd[k] = state_dict[k]
|
||||||
continue
|
continue
|
||||||
|
|||||||
@ -10,7 +10,6 @@ from ._input_impl import VideoFromFile, VideoFromComponents
|
|||||||
from ._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL
|
from ._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL
|
||||||
from . import _io_public as io
|
from . import _io_public as io
|
||||||
from . import _ui_public as ui
|
from . import _ui_public as ui
|
||||||
# from comfy_api.latest._resources import _RESOURCES as resources #noqa: F401
|
|
||||||
from comfy_execution.utils import get_executing_context
|
from comfy_execution.utils import get_executing_context
|
||||||
from comfy_execution.progress import get_progress_state, PreviewImageTuple
|
from comfy_execution.progress import get_progress_state, PreviewImageTuple
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|||||||
@ -26,11 +26,9 @@ if TYPE_CHECKING:
|
|||||||
from comfy_api.input import VideoInput
|
from comfy_api.input import VideoInput
|
||||||
from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classproperty, copy_class, first_real_override, is_class,
|
from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classproperty, copy_class, first_real_override, is_class,
|
||||||
prune_dict, shallow_clone_class)
|
prune_dict, shallow_clone_class)
|
||||||
from ._resources import Resources, ResourcesLocal
|
|
||||||
from comfy_execution.graph_utils import ExecutionBlocker
|
from comfy_execution.graph_utils import ExecutionBlocker
|
||||||
from ._util import MESH, VOXEL
|
from ._util import MESH, VOXEL, SVG as _SVG
|
||||||
|
|
||||||
# from comfy_extras.nodes_images import SVG as SVG_ # NOTE: needs to be moved before can be imported due to circular reference
|
|
||||||
|
|
||||||
class FolderType(str, Enum):
|
class FolderType(str, Enum):
|
||||||
input = "input"
|
input = "input"
|
||||||
@ -77,16 +75,6 @@ class NumberDisplay(str, Enum):
|
|||||||
slider = "slider"
|
slider = "slider"
|
||||||
|
|
||||||
|
|
||||||
class _StringIOType(str):
|
|
||||||
def __ne__(self, value: object) -> bool:
|
|
||||||
if self == "*" or value == "*":
|
|
||||||
return False
|
|
||||||
if not isinstance(value, str):
|
|
||||||
return True
|
|
||||||
a = frozenset(self.split(","))
|
|
||||||
b = frozenset(value.split(","))
|
|
||||||
return not (b.issubset(a) or a.issubset(b))
|
|
||||||
|
|
||||||
class _ComfyType(ABC):
|
class _ComfyType(ABC):
|
||||||
Type = Any
|
Type = Any
|
||||||
io_type: str = None
|
io_type: str = None
|
||||||
@ -126,8 +114,7 @@ def comfytype(io_type: str, **kwargs):
|
|||||||
new_cls.__module__ = cls.__module__
|
new_cls.__module__ = cls.__module__
|
||||||
new_cls.__doc__ = cls.__doc__
|
new_cls.__doc__ = cls.__doc__
|
||||||
# assign ComfyType attributes, if needed
|
# assign ComfyType attributes, if needed
|
||||||
# NOTE: use __ne__ trick for io_type (see node_typing.IO.__ne__ for details)
|
new_cls.io_type = io_type
|
||||||
new_cls.io_type = _StringIOType(io_type)
|
|
||||||
if hasattr(new_cls, "Input") and new_cls.Input is not None:
|
if hasattr(new_cls, "Input") and new_cls.Input is not None:
|
||||||
new_cls.Input.Parent = new_cls
|
new_cls.Input.Parent = new_cls
|
||||||
if hasattr(new_cls, "Output") and new_cls.Output is not None:
|
if hasattr(new_cls, "Output") and new_cls.Output is not None:
|
||||||
@ -166,7 +153,7 @@ class Input(_IO_V3):
|
|||||||
'''
|
'''
|
||||||
Base class for a V3 Input.
|
Base class for a V3 Input.
|
||||||
'''
|
'''
|
||||||
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None):
|
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None, raw_link: bool=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.id = id
|
self.id = id
|
||||||
self.display_name = display_name
|
self.display_name = display_name
|
||||||
@ -174,6 +161,7 @@ class Input(_IO_V3):
|
|||||||
self.tooltip = tooltip
|
self.tooltip = tooltip
|
||||||
self.lazy = lazy
|
self.lazy = lazy
|
||||||
self.extra_dict = extra_dict if extra_dict is not None else {}
|
self.extra_dict = extra_dict if extra_dict is not None else {}
|
||||||
|
self.rawLink = raw_link
|
||||||
|
|
||||||
def as_dict(self):
|
def as_dict(self):
|
||||||
return prune_dict({
|
return prune_dict({
|
||||||
@ -181,10 +169,11 @@ class Input(_IO_V3):
|
|||||||
"optional": self.optional,
|
"optional": self.optional,
|
||||||
"tooltip": self.tooltip,
|
"tooltip": self.tooltip,
|
||||||
"lazy": self.lazy,
|
"lazy": self.lazy,
|
||||||
|
"rawLink": self.rawLink,
|
||||||
}) | prune_dict(self.extra_dict)
|
}) | prune_dict(self.extra_dict)
|
||||||
|
|
||||||
def get_io_type(self):
|
def get_io_type(self):
|
||||||
return _StringIOType(self.io_type)
|
return self.io_type
|
||||||
|
|
||||||
def get_all(self) -> list[Input]:
|
def get_all(self) -> list[Input]:
|
||||||
return [self]
|
return [self]
|
||||||
@ -195,8 +184,8 @@ class WidgetInput(Input):
|
|||||||
'''
|
'''
|
||||||
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
|
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
|
||||||
default: Any=None,
|
default: Any=None,
|
||||||
socketless: bool=None, widget_type: str=None, force_input: bool=None, extra_dict=None):
|
socketless: bool=None, widget_type: str=None, force_input: bool=None, extra_dict=None, raw_link: bool=None):
|
||||||
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict)
|
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict, raw_link)
|
||||||
self.default = default
|
self.default = default
|
||||||
self.socketless = socketless
|
self.socketless = socketless
|
||||||
self.widget_type = widget_type
|
self.widget_type = widget_type
|
||||||
@ -218,13 +207,14 @@ class Output(_IO_V3):
|
|||||||
def __init__(self, id: str=None, display_name: str=None, tooltip: str=None,
|
def __init__(self, id: str=None, display_name: str=None, tooltip: str=None,
|
||||||
is_output_list=False):
|
is_output_list=False):
|
||||||
self.id = id
|
self.id = id
|
||||||
self.display_name = display_name
|
self.display_name = display_name if display_name else id
|
||||||
self.tooltip = tooltip
|
self.tooltip = tooltip
|
||||||
self.is_output_list = is_output_list
|
self.is_output_list = is_output_list
|
||||||
|
|
||||||
def as_dict(self):
|
def as_dict(self):
|
||||||
|
display_name = self.display_name if self.display_name else self.id
|
||||||
return prune_dict({
|
return prune_dict({
|
||||||
"display_name": self.display_name,
|
"display_name": display_name,
|
||||||
"tooltip": self.tooltip,
|
"tooltip": self.tooltip,
|
||||||
"is_output_list": self.is_output_list,
|
"is_output_list": self.is_output_list,
|
||||||
})
|
})
|
||||||
@ -252,8 +242,8 @@ class Boolean(ComfyTypeIO):
|
|||||||
'''Boolean input.'''
|
'''Boolean input.'''
|
||||||
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
|
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
|
||||||
default: bool=None, label_on: str=None, label_off: str=None,
|
default: bool=None, label_on: str=None, label_off: str=None,
|
||||||
socketless: bool=None, force_input: bool=None):
|
socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None):
|
||||||
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input)
|
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link)
|
||||||
self.label_on = label_on
|
self.label_on = label_on
|
||||||
self.label_off = label_off
|
self.label_off = label_off
|
||||||
self.default: bool
|
self.default: bool
|
||||||
@ -272,8 +262,8 @@ class Int(ComfyTypeIO):
|
|||||||
'''Integer input.'''
|
'''Integer input.'''
|
||||||
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
|
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
|
||||||
default: int=None, min: int=None, max: int=None, step: int=None, control_after_generate: bool=None,
|
default: int=None, min: int=None, max: int=None, step: int=None, control_after_generate: bool=None,
|
||||||
display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None):
|
display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None):
|
||||||
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input)
|
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link)
|
||||||
self.min = min
|
self.min = min
|
||||||
self.max = max
|
self.max = max
|
||||||
self.step = step
|
self.step = step
|
||||||
@ -298,8 +288,8 @@ class Float(ComfyTypeIO):
|
|||||||
'''Float input.'''
|
'''Float input.'''
|
||||||
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
|
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
|
||||||
default: float=None, min: float=None, max: float=None, step: float=None, round: float=None,
|
default: float=None, min: float=None, max: float=None, step: float=None, round: float=None,
|
||||||
display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None):
|
display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None):
|
||||||
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input)
|
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link)
|
||||||
self.min = min
|
self.min = min
|
||||||
self.max = max
|
self.max = max
|
||||||
self.step = step
|
self.step = step
|
||||||
@ -324,8 +314,8 @@ class String(ComfyTypeIO):
|
|||||||
'''String input.'''
|
'''String input.'''
|
||||||
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
|
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
|
||||||
multiline=False, placeholder: str=None, default: str=None, dynamic_prompts: bool=None,
|
multiline=False, placeholder: str=None, default: str=None, dynamic_prompts: bool=None,
|
||||||
socketless: bool=None, force_input: bool=None):
|
socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None):
|
||||||
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input)
|
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link)
|
||||||
self.multiline = multiline
|
self.multiline = multiline
|
||||||
self.placeholder = placeholder
|
self.placeholder = placeholder
|
||||||
self.dynamic_prompts = dynamic_prompts
|
self.dynamic_prompts = dynamic_prompts
|
||||||
@ -358,12 +348,14 @@ class Combo(ComfyTypeIO):
|
|||||||
image_folder: FolderType=None,
|
image_folder: FolderType=None,
|
||||||
remote: RemoteOptions=None,
|
remote: RemoteOptions=None,
|
||||||
socketless: bool=None,
|
socketless: bool=None,
|
||||||
|
extra_dict=None,
|
||||||
|
raw_link: bool=None,
|
||||||
):
|
):
|
||||||
if isinstance(options, type) and issubclass(options, Enum):
|
if isinstance(options, type) and issubclass(options, Enum):
|
||||||
options = [v.value for v in options]
|
options = [v.value for v in options]
|
||||||
if isinstance(default, Enum):
|
if isinstance(default, Enum):
|
||||||
default = default.value
|
default = default.value
|
||||||
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless)
|
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, None, extra_dict, raw_link)
|
||||||
self.multiselect = False
|
self.multiselect = False
|
||||||
self.options = options
|
self.options = options
|
||||||
self.control_after_generate = control_after_generate
|
self.control_after_generate = control_after_generate
|
||||||
@ -387,10 +379,6 @@ class Combo(ComfyTypeIO):
|
|||||||
super().__init__(id, display_name, tooltip, is_output_list)
|
super().__init__(id, display_name, tooltip, is_output_list)
|
||||||
self.options = options if options is not None else []
|
self.options = options if options is not None else []
|
||||||
|
|
||||||
@property
|
|
||||||
def io_type(self):
|
|
||||||
return self.options
|
|
||||||
|
|
||||||
@comfytype(io_type="COMBO")
|
@comfytype(io_type="COMBO")
|
||||||
class MultiCombo(ComfyTypeI):
|
class MultiCombo(ComfyTypeI):
|
||||||
'''Multiselect Combo input (dropdown for selecting potentially more than one value).'''
|
'''Multiselect Combo input (dropdown for selecting potentially more than one value).'''
|
||||||
@ -399,8 +387,8 @@ class MultiCombo(ComfyTypeI):
|
|||||||
class Input(Combo.Input):
|
class Input(Combo.Input):
|
||||||
def __init__(self, id: str, options: list[str], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
|
def __init__(self, id: str, options: list[str], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
|
||||||
default: list[str]=None, placeholder: str=None, chip: bool=None, control_after_generate: bool=None,
|
default: list[str]=None, placeholder: str=None, chip: bool=None, control_after_generate: bool=None,
|
||||||
socketless: bool=None):
|
socketless: bool=None, extra_dict=None, raw_link: bool=None):
|
||||||
super().__init__(id, options, display_name, optional, tooltip, lazy, default, control_after_generate, socketless=socketless)
|
super().__init__(id, options, display_name, optional, tooltip, lazy, default, control_after_generate, socketless=socketless, extra_dict=extra_dict, raw_link=raw_link)
|
||||||
self.multiselect = True
|
self.multiselect = True
|
||||||
self.placeholder = placeholder
|
self.placeholder = placeholder
|
||||||
self.chip = chip
|
self.chip = chip
|
||||||
@ -433,9 +421,9 @@ class Webcam(ComfyTypeIO):
|
|||||||
Type = str
|
Type = str
|
||||||
def __init__(
|
def __init__(
|
||||||
self, id: str, display_name: str=None, optional=False,
|
self, id: str, display_name: str=None, optional=False,
|
||||||
tooltip: str=None, lazy: bool=None, default: str=None, socketless: bool=None
|
tooltip: str=None, lazy: bool=None, default: str=None, socketless: bool=None, extra_dict=None, raw_link: bool=None
|
||||||
):
|
):
|
||||||
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless)
|
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, None, extra_dict, raw_link)
|
||||||
|
|
||||||
|
|
||||||
@comfytype(io_type="MASK")
|
@comfytype(io_type="MASK")
|
||||||
@ -656,7 +644,7 @@ class Video(ComfyTypeIO):
|
|||||||
|
|
||||||
@comfytype(io_type="SVG")
|
@comfytype(io_type="SVG")
|
||||||
class SVG(ComfyTypeIO):
|
class SVG(ComfyTypeIO):
|
||||||
Type = Any # TODO: SVG class is defined in comfy_extras/nodes_images.py, causing circular reference; should be moved to somewhere else before referenced directly in v3
|
Type = _SVG
|
||||||
|
|
||||||
@comfytype(io_type="LORA_MODEL")
|
@comfytype(io_type="LORA_MODEL")
|
||||||
class LoraModel(ComfyTypeIO):
|
class LoraModel(ComfyTypeIO):
|
||||||
@ -788,7 +776,7 @@ class MultiType:
|
|||||||
'''
|
'''
|
||||||
Input that permits more than one input type; if `id` is an instance of `ComfyType.Input`, then that input will be used to create a widget (if applicable) with overridden values.
|
Input that permits more than one input type; if `id` is an instance of `ComfyType.Input`, then that input will be used to create a widget (if applicable) with overridden values.
|
||||||
'''
|
'''
|
||||||
def __init__(self, id: str | Input, types: list[type[_ComfyType] | _ComfyType], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None):
|
def __init__(self, id: str | Input, types: list[type[_ComfyType] | _ComfyType], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None, raw_link: bool=None):
|
||||||
# if id is an Input, then use that Input with overridden values
|
# if id is an Input, then use that Input with overridden values
|
||||||
self.input_override = None
|
self.input_override = None
|
||||||
if isinstance(id, Input):
|
if isinstance(id, Input):
|
||||||
@ -801,7 +789,7 @@ class MultiType:
|
|||||||
# if is a widget input, make sure widget_type is set appropriately
|
# if is a widget input, make sure widget_type is set appropriately
|
||||||
if isinstance(self.input_override, WidgetInput):
|
if isinstance(self.input_override, WidgetInput):
|
||||||
self.input_override.widget_type = self.input_override.get_io_type()
|
self.input_override.widget_type = self.input_override.get_io_type()
|
||||||
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict)
|
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict, raw_link)
|
||||||
self._io_types = types
|
self._io_types = types
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -855,8 +843,8 @@ class MatchType(ComfyTypeIO):
|
|||||||
|
|
||||||
class Input(Input):
|
class Input(Input):
|
||||||
def __init__(self, id: str, template: MatchType.Template,
|
def __init__(self, id: str, template: MatchType.Template,
|
||||||
display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None):
|
display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None, raw_link: bool=None):
|
||||||
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict)
|
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict, raw_link)
|
||||||
self.template = template
|
self.template = template
|
||||||
|
|
||||||
def as_dict(self):
|
def as_dict(self):
|
||||||
@ -867,6 +855,8 @@ class MatchType(ComfyTypeIO):
|
|||||||
class Output(Output):
|
class Output(Output):
|
||||||
def __init__(self, template: MatchType.Template, id: str=None, display_name: str=None, tooltip: str=None,
|
def __init__(self, template: MatchType.Template, id: str=None, display_name: str=None, tooltip: str=None,
|
||||||
is_output_list=False):
|
is_output_list=False):
|
||||||
|
if not id and not display_name:
|
||||||
|
display_name = "MATCHTYPE"
|
||||||
super().__init__(id, display_name, tooltip, is_output_list)
|
super().__init__(id, display_name, tooltip, is_output_list)
|
||||||
self.template = template
|
self.template = template
|
||||||
|
|
||||||
@ -879,24 +869,30 @@ class DynamicInput(Input, ABC):
|
|||||||
'''
|
'''
|
||||||
Abstract class for dynamic input registration.
|
Abstract class for dynamic input registration.
|
||||||
'''
|
'''
|
||||||
def get_dynamic(self) -> list[Input]:
|
pass
|
||||||
return []
|
|
||||||
|
|
||||||
def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class DynamicOutput(Output, ABC):
|
class DynamicOutput(Output, ABC):
|
||||||
'''
|
'''
|
||||||
Abstract class for dynamic output registration.
|
Abstract class for dynamic output registration.
|
||||||
'''
|
'''
|
||||||
def __init__(self, id: str=None, display_name: str=None, tooltip: str=None,
|
pass
|
||||||
is_output_list=False):
|
|
||||||
super().__init__(id, display_name, tooltip, is_output_list)
|
|
||||||
|
|
||||||
def get_dynamic(self) -> list[Output]:
|
|
||||||
return []
|
|
||||||
|
|
||||||
|
def handle_prefix(prefix_list: list[str] | None, id: str | None = None) -> list[str]:
|
||||||
|
if prefix_list is None:
|
||||||
|
prefix_list = []
|
||||||
|
if id is not None:
|
||||||
|
prefix_list = prefix_list + [id]
|
||||||
|
return prefix_list
|
||||||
|
|
||||||
|
def finalize_prefix(prefix_list: list[str] | None, id: str | None = None) -> str:
|
||||||
|
assert not (prefix_list is None and id is None)
|
||||||
|
if prefix_list is None:
|
||||||
|
return id
|
||||||
|
elif id is not None:
|
||||||
|
prefix_list = prefix_list + [id]
|
||||||
|
return ".".join(prefix_list)
|
||||||
|
|
||||||
@comfytype(io_type="COMFY_AUTOGROW_V3")
|
@comfytype(io_type="COMFY_AUTOGROW_V3")
|
||||||
class Autogrow(ComfyTypeI):
|
class Autogrow(ComfyTypeI):
|
||||||
@ -933,14 +929,6 @@ class Autogrow(ComfyTypeI):
|
|||||||
def validate(self):
|
def validate(self):
|
||||||
self.input.validate()
|
self.input.validate()
|
||||||
|
|
||||||
def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''):
|
|
||||||
real_inputs = []
|
|
||||||
for name, input in self.cached_inputs.items():
|
|
||||||
if name in live_inputs:
|
|
||||||
real_inputs.append(input)
|
|
||||||
add_to_input_dict_v1(d, real_inputs, live_inputs, curr_prefix)
|
|
||||||
add_dynamic_id_mapping(d, real_inputs, curr_prefix)
|
|
||||||
|
|
||||||
class TemplatePrefix(_AutogrowTemplate):
|
class TemplatePrefix(_AutogrowTemplate):
|
||||||
def __init__(self, input: Input, prefix: str, min: int=1, max: int=10):
|
def __init__(self, input: Input, prefix: str, min: int=1, max: int=10):
|
||||||
super().__init__(input)
|
super().__init__(input)
|
||||||
@ -985,22 +973,45 @@ class Autogrow(ComfyTypeI):
|
|||||||
"template": self.template.as_dict(),
|
"template": self.template.as_dict(),
|
||||||
})
|
})
|
||||||
|
|
||||||
def get_dynamic(self) -> list[Input]:
|
|
||||||
return self.template.get_all()
|
|
||||||
|
|
||||||
def get_all(self) -> list[Input]:
|
def get_all(self) -> list[Input]:
|
||||||
return [self] + self.template.get_all()
|
return [self] + self.template.get_all()
|
||||||
|
|
||||||
def validate(self):
|
def validate(self):
|
||||||
self.template.validate()
|
self.template.validate()
|
||||||
|
|
||||||
def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''):
|
@staticmethod
|
||||||
curr_prefix = f"{curr_prefix}{self.id}."
|
def _expand_schema_for_dynamic(out_dict: dict[str, Any], live_inputs: dict[str, Any], value: tuple[str, dict[str, Any]], input_type: str, curr_prefix: list[str] | None):
|
||||||
# need to remove self from expected inputs dictionary; replaced by template inputs in frontend
|
# NOTE: purposely do not include self in out_dict; instead use only the template inputs
|
||||||
for inner_dict in d.values():
|
# need to figure out names based on template type
|
||||||
if self.id in inner_dict:
|
is_names = ("names" in value[1]["template"])
|
||||||
del inner_dict[self.id]
|
is_prefix = ("prefix" in value[1]["template"])
|
||||||
self.template.expand_schema_for_dynamic(d, live_inputs, curr_prefix)
|
input = value[1]["template"]["input"]
|
||||||
|
if is_names:
|
||||||
|
min = value[1]["template"]["min"]
|
||||||
|
names = value[1]["template"]["names"]
|
||||||
|
max = len(names)
|
||||||
|
elif is_prefix:
|
||||||
|
prefix = value[1]["template"]["prefix"]
|
||||||
|
min = value[1]["template"]["min"]
|
||||||
|
max = value[1]["template"]["max"]
|
||||||
|
names = [f"{prefix}{i}" for i in range(max)]
|
||||||
|
# need to create a new input based on the contents of input
|
||||||
|
template_input = None
|
||||||
|
for _, dict_input in input.items():
|
||||||
|
# for now, get just the first value from dict_input
|
||||||
|
template_input = list(dict_input.values())[0]
|
||||||
|
new_dict = {}
|
||||||
|
for i, name in enumerate(names):
|
||||||
|
expected_id = finalize_prefix(curr_prefix, name)
|
||||||
|
if expected_id in live_inputs:
|
||||||
|
# required
|
||||||
|
if i < min:
|
||||||
|
type_dict = new_dict.setdefault("required", {})
|
||||||
|
# optional
|
||||||
|
else:
|
||||||
|
type_dict = new_dict.setdefault("optional", {})
|
||||||
|
type_dict[name] = template_input
|
||||||
|
parse_class_inputs(out_dict, live_inputs, new_dict, curr_prefix)
|
||||||
|
|
||||||
@comfytype(io_type="COMFY_DYNAMICCOMBO_V3")
|
@comfytype(io_type="COMFY_DYNAMICCOMBO_V3")
|
||||||
class DynamicCombo(ComfyTypeI):
|
class DynamicCombo(ComfyTypeI):
|
||||||
@ -1023,23 +1034,6 @@ class DynamicCombo(ComfyTypeI):
|
|||||||
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict)
|
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict)
|
||||||
self.options = options
|
self.options = options
|
||||||
|
|
||||||
def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''):
|
|
||||||
# check if dynamic input's id is in live_inputs
|
|
||||||
if self.id in live_inputs:
|
|
||||||
curr_prefix = f"{curr_prefix}{self.id}."
|
|
||||||
key = live_inputs[self.id]
|
|
||||||
selected_option = None
|
|
||||||
for option in self.options:
|
|
||||||
if option.key == key:
|
|
||||||
selected_option = option
|
|
||||||
break
|
|
||||||
if selected_option is not None:
|
|
||||||
add_to_input_dict_v1(d, selected_option.inputs, live_inputs, curr_prefix)
|
|
||||||
add_dynamic_id_mapping(d, selected_option.inputs, curr_prefix, self)
|
|
||||||
|
|
||||||
def get_dynamic(self) -> list[Input]:
|
|
||||||
return [input for option in self.options for input in option.inputs]
|
|
||||||
|
|
||||||
def get_all(self) -> list[Input]:
|
def get_all(self) -> list[Input]:
|
||||||
return [self] + [input for option in self.options for input in option.inputs]
|
return [self] + [input for option in self.options for input in option.inputs]
|
||||||
|
|
||||||
@ -1054,6 +1048,24 @@ class DynamicCombo(ComfyTypeI):
|
|||||||
for input in option.inputs:
|
for input in option.inputs:
|
||||||
input.validate()
|
input.validate()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _expand_schema_for_dynamic(out_dict: dict[str, Any], live_inputs: dict[str, Any], value: tuple[str, dict[str, Any]], input_type: str, curr_prefix: list[str] | None):
|
||||||
|
finalized_id = finalize_prefix(curr_prefix)
|
||||||
|
if finalized_id in live_inputs:
|
||||||
|
key = live_inputs[finalized_id]
|
||||||
|
selected_option = None
|
||||||
|
# get options from dict
|
||||||
|
options: list[dict[str, str | dict[str, Any]]] = value[1]["options"]
|
||||||
|
for option in options:
|
||||||
|
if option["key"] == key:
|
||||||
|
selected_option = option
|
||||||
|
break
|
||||||
|
if selected_option is not None:
|
||||||
|
parse_class_inputs(out_dict, live_inputs, selected_option["inputs"], curr_prefix)
|
||||||
|
# add self to inputs
|
||||||
|
out_dict[input_type][finalized_id] = value
|
||||||
|
out_dict["dynamic_paths"][finalized_id] = finalize_prefix(curr_prefix, curr_prefix[-1])
|
||||||
|
|
||||||
@comfytype(io_type="COMFY_DYNAMICSLOT_V3")
|
@comfytype(io_type="COMFY_DYNAMICSLOT_V3")
|
||||||
class DynamicSlot(ComfyTypeI):
|
class DynamicSlot(ComfyTypeI):
|
||||||
Type = dict[str, Any]
|
Type = dict[str, Any]
|
||||||
@ -1076,17 +1088,8 @@ class DynamicSlot(ComfyTypeI):
|
|||||||
self.force_input = True
|
self.force_input = True
|
||||||
self.slot.force_input = True
|
self.slot.force_input = True
|
||||||
|
|
||||||
def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''):
|
|
||||||
if self.id in live_inputs:
|
|
||||||
curr_prefix = f"{curr_prefix}{self.id}."
|
|
||||||
add_to_input_dict_v1(d, self.inputs, live_inputs, curr_prefix)
|
|
||||||
add_dynamic_id_mapping(d, [self.slot] + self.inputs, curr_prefix)
|
|
||||||
|
|
||||||
def get_dynamic(self) -> list[Input]:
|
|
||||||
return [self.slot] + self.inputs
|
|
||||||
|
|
||||||
def get_all(self) -> list[Input]:
|
def get_all(self) -> list[Input]:
|
||||||
return [self] + [self.slot] + self.inputs
|
return [self.slot] + self.inputs
|
||||||
|
|
||||||
def as_dict(self):
|
def as_dict(self):
|
||||||
return super().as_dict() | prune_dict({
|
return super().as_dict() | prune_dict({
|
||||||
@ -1100,17 +1103,41 @@ class DynamicSlot(ComfyTypeI):
|
|||||||
for input in self.inputs:
|
for input in self.inputs:
|
||||||
input.validate()
|
input.validate()
|
||||||
|
|
||||||
def add_dynamic_id_mapping(d: dict[str, Any], inputs: list[Input], curr_prefix: str, self: DynamicInput=None):
|
@staticmethod
|
||||||
dynamic = d.setdefault("dynamic_paths", {})
|
def _expand_schema_for_dynamic(out_dict: dict[str, Any], live_inputs: dict[str, Any], value: tuple[str, dict[str, Any]], input_type: str, curr_prefix: list[str] | None):
|
||||||
if self is not None:
|
finalized_id = finalize_prefix(curr_prefix)
|
||||||
dynamic[self.id] = f"{curr_prefix}{self.id}"
|
if finalized_id in live_inputs:
|
||||||
for i in inputs:
|
inputs = value[1]["inputs"]
|
||||||
if not isinstance(i, DynamicInput):
|
parse_class_inputs(out_dict, live_inputs, inputs, curr_prefix)
|
||||||
dynamic[f"{i.id}"] = f"{curr_prefix}{i.id}"
|
# add self to inputs
|
||||||
|
out_dict[input_type][finalized_id] = value
|
||||||
|
out_dict["dynamic_paths"][finalized_id] = finalize_prefix(curr_prefix, curr_prefix[-1])
|
||||||
|
|
||||||
|
DYNAMIC_INPUT_LOOKUP: dict[str, Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]] = {}
|
||||||
|
def register_dynamic_input_func(io_type: str, func: Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]):
|
||||||
|
DYNAMIC_INPUT_LOOKUP[io_type] = func
|
||||||
|
|
||||||
|
def get_dynamic_input_func(io_type: str) -> Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]:
|
||||||
|
return DYNAMIC_INPUT_LOOKUP[io_type]
|
||||||
|
|
||||||
|
def setup_dynamic_input_funcs():
|
||||||
|
# Autogrow.Input
|
||||||
|
register_dynamic_input_func(Autogrow.io_type, Autogrow._expand_schema_for_dynamic)
|
||||||
|
# DynamicCombo.Input
|
||||||
|
register_dynamic_input_func(DynamicCombo.io_type, DynamicCombo._expand_schema_for_dynamic)
|
||||||
|
# DynamicSlot.Input
|
||||||
|
register_dynamic_input_func(DynamicSlot.io_type, DynamicSlot._expand_schema_for_dynamic)
|
||||||
|
|
||||||
|
if len(DYNAMIC_INPUT_LOOKUP) == 0:
|
||||||
|
setup_dynamic_input_funcs()
|
||||||
|
|
||||||
class V3Data(TypedDict):
|
class V3Data(TypedDict):
|
||||||
hidden_inputs: dict[str, Any]
|
hidden_inputs: dict[str, Any]
|
||||||
|
'Dictionary where the keys are the hidden input ids and the values are the values of the hidden inputs.'
|
||||||
dynamic_paths: dict[str, Any]
|
dynamic_paths: dict[str, Any]
|
||||||
|
'Dictionary where the keys are the input ids and the values dictate how to turn the inputs into a nested dictionary.'
|
||||||
|
create_dynamic_tuple: bool
|
||||||
|
'When True, the value of the dynamic input will be in the format (value, path_key).'
|
||||||
|
|
||||||
class HiddenHolder:
|
class HiddenHolder:
|
||||||
def __init__(self, unique_id: str, prompt: Any,
|
def __init__(self, unique_id: str, prompt: Any,
|
||||||
@ -1146,6 +1173,10 @@ class HiddenHolder:
|
|||||||
api_key_comfy_org=d.get(Hidden.api_key_comfy_org, None),
|
api_key_comfy_org=d.get(Hidden.api_key_comfy_org, None),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_v3_data(cls, v3_data: V3Data | None) -> HiddenHolder:
|
||||||
|
return cls.from_dict(v3_data["hidden_inputs"] if v3_data else None)
|
||||||
|
|
||||||
class Hidden(str, Enum):
|
class Hidden(str, Enum):
|
||||||
'''
|
'''
|
||||||
Enumerator for requesting hidden variables in nodes.
|
Enumerator for requesting hidden variables in nodes.
|
||||||
@ -1251,61 +1282,56 @@ class Schema:
|
|||||||
- verify ids on inputs and outputs are unique - both internally and in relation to each other
|
- verify ids on inputs and outputs are unique - both internally and in relation to each other
|
||||||
'''
|
'''
|
||||||
nested_inputs: list[Input] = []
|
nested_inputs: list[Input] = []
|
||||||
if self.inputs is not None:
|
for input in self.inputs:
|
||||||
for input in self.inputs:
|
if not isinstance(input, DynamicInput):
|
||||||
nested_inputs.extend(input.get_all())
|
nested_inputs.extend(input.get_all())
|
||||||
input_ids = [i.id for i in nested_inputs] if nested_inputs is not None else []
|
input_ids = [i.id for i in nested_inputs]
|
||||||
output_ids = [o.id for o in self.outputs] if self.outputs is not None else []
|
output_ids = [o.id for o in self.outputs]
|
||||||
input_set = set(input_ids)
|
input_set = set(input_ids)
|
||||||
output_set = set(output_ids)
|
output_set = set(output_ids)
|
||||||
issues = []
|
issues: list[str] = []
|
||||||
# verify ids are unique per list
|
# verify ids are unique per list
|
||||||
if len(input_set) != len(input_ids):
|
if len(input_set) != len(input_ids):
|
||||||
issues.append(f"Input ids must be unique, but {[item for item, count in Counter(input_ids).items() if count > 1]} are not.")
|
issues.append(f"Input ids must be unique, but {[item for item, count in Counter(input_ids).items() if count > 1]} are not.")
|
||||||
if len(output_set) != len(output_ids):
|
if len(output_set) != len(output_ids):
|
||||||
issues.append(f"Output ids must be unique, but {[item for item, count in Counter(output_ids).items() if count > 1]} are not.")
|
issues.append(f"Output ids must be unique, but {[item for item, count in Counter(output_ids).items() if count > 1]} are not.")
|
||||||
# verify ids are unique between lists
|
|
||||||
intersection = input_set & output_set
|
|
||||||
if len(intersection) > 0:
|
|
||||||
issues.append(f"Ids must be unique between inputs and outputs, but {intersection} are not.")
|
|
||||||
if len(issues) > 0:
|
if len(issues) > 0:
|
||||||
raise ValueError("\n".join(issues))
|
raise ValueError("\n".join(issues))
|
||||||
# validate inputs and outputs
|
# validate inputs and outputs
|
||||||
if self.inputs is not None:
|
for input in self.inputs:
|
||||||
for input in self.inputs:
|
input.validate()
|
||||||
input.validate()
|
for output in self.outputs:
|
||||||
if self.outputs is not None:
|
output.validate()
|
||||||
for output in self.outputs:
|
|
||||||
output.validate()
|
|
||||||
|
|
||||||
def finalize(self):
|
def finalize(self):
|
||||||
"""Add hidden based on selected schema options, and give outputs without ids default ids."""
|
"""Add hidden based on selected schema options, and give outputs without ids default ids."""
|
||||||
|
# ensure inputs, outputs, and hidden are lists
|
||||||
|
if self.inputs is None:
|
||||||
|
self.inputs = []
|
||||||
|
if self.outputs is None:
|
||||||
|
self.outputs = []
|
||||||
|
if self.hidden is None:
|
||||||
|
self.hidden = []
|
||||||
# if is an api_node, will need key-related hidden
|
# if is an api_node, will need key-related hidden
|
||||||
if self.is_api_node:
|
if self.is_api_node:
|
||||||
if self.hidden is None:
|
|
||||||
self.hidden = []
|
|
||||||
if Hidden.auth_token_comfy_org not in self.hidden:
|
if Hidden.auth_token_comfy_org not in self.hidden:
|
||||||
self.hidden.append(Hidden.auth_token_comfy_org)
|
self.hidden.append(Hidden.auth_token_comfy_org)
|
||||||
if Hidden.api_key_comfy_org not in self.hidden:
|
if Hidden.api_key_comfy_org not in self.hidden:
|
||||||
self.hidden.append(Hidden.api_key_comfy_org)
|
self.hidden.append(Hidden.api_key_comfy_org)
|
||||||
# if is an output_node, will need prompt and extra_pnginfo
|
# if is an output_node, will need prompt and extra_pnginfo
|
||||||
if self.is_output_node:
|
if self.is_output_node:
|
||||||
if self.hidden is None:
|
|
||||||
self.hidden = []
|
|
||||||
if Hidden.prompt not in self.hidden:
|
if Hidden.prompt not in self.hidden:
|
||||||
self.hidden.append(Hidden.prompt)
|
self.hidden.append(Hidden.prompt)
|
||||||
if Hidden.extra_pnginfo not in self.hidden:
|
if Hidden.extra_pnginfo not in self.hidden:
|
||||||
self.hidden.append(Hidden.extra_pnginfo)
|
self.hidden.append(Hidden.extra_pnginfo)
|
||||||
# give outputs without ids default ids
|
# give outputs without ids default ids
|
||||||
if self.outputs is not None:
|
for i, output in enumerate(self.outputs):
|
||||||
for i, output in enumerate(self.outputs):
|
if output.id is None:
|
||||||
if output.id is None:
|
output.id = f"_{i}_{output.io_type}_"
|
||||||
output.id = f"_{i}_{output.io_type}_"
|
|
||||||
|
|
||||||
def get_v1_info(self, cls, live_inputs: dict[str, Any]=None) -> NodeInfoV1:
|
def get_v1_info(self, cls) -> NodeInfoV1:
|
||||||
# NOTE: live_inputs will not be used anymore very soon and this will be done another way
|
|
||||||
# get V1 inputs
|
# get V1 inputs
|
||||||
input = create_input_dict_v1(self.inputs, live_inputs)
|
input = create_input_dict_v1(self.inputs)
|
||||||
if self.hidden:
|
if self.hidden:
|
||||||
for hidden in self.hidden:
|
for hidden in self.hidden:
|
||||||
input.setdefault("hidden", {})[hidden.name] = (hidden.value,)
|
input.setdefault("hidden", {})[hidden.name] = (hidden.value,)
|
||||||
@ -1385,33 +1411,54 @@ class Schema:
|
|||||||
)
|
)
|
||||||
return info
|
return info
|
||||||
|
|
||||||
|
def get_finalized_class_inputs(d: dict[str, Any], live_inputs: dict[str, Any], include_hidden=False) -> tuple[dict[str, Any], V3Data]:
|
||||||
|
out_dict = {
|
||||||
|
"required": {},
|
||||||
|
"optional": {},
|
||||||
|
"dynamic_paths": {},
|
||||||
|
}
|
||||||
|
d = d.copy()
|
||||||
|
# ignore hidden for parsing
|
||||||
|
hidden = d.pop("hidden", None)
|
||||||
|
parse_class_inputs(out_dict, live_inputs, d)
|
||||||
|
if hidden is not None and include_hidden:
|
||||||
|
out_dict["hidden"] = hidden
|
||||||
|
v3_data = {}
|
||||||
|
dynamic_paths = out_dict.pop("dynamic_paths", None)
|
||||||
|
if dynamic_paths is not None:
|
||||||
|
v3_data["dynamic_paths"] = dynamic_paths
|
||||||
|
return out_dict, hidden, v3_data
|
||||||
|
|
||||||
def create_input_dict_v1(inputs: list[Input], live_inputs: dict[str, Any]=None) -> dict:
|
def parse_class_inputs(out_dict: dict[str, Any], live_inputs: dict[str, Any], curr_dict: dict[str, Any], curr_prefix: list[str] | None=None) -> None:
|
||||||
|
for input_type, inner_d in curr_dict.items():
|
||||||
|
for id, value in inner_d.items():
|
||||||
|
io_type = value[0]
|
||||||
|
if io_type in DYNAMIC_INPUT_LOOKUP:
|
||||||
|
# dynamic inputs need to be handled with lookup functions
|
||||||
|
dynamic_input_func = get_dynamic_input_func(io_type)
|
||||||
|
new_prefix = handle_prefix(curr_prefix, id)
|
||||||
|
dynamic_input_func(out_dict, live_inputs, value, input_type, new_prefix)
|
||||||
|
else:
|
||||||
|
# non-dynamic inputs get directly transferred
|
||||||
|
finalized_id = finalize_prefix(curr_prefix, id)
|
||||||
|
out_dict[input_type][finalized_id] = value
|
||||||
|
if curr_prefix:
|
||||||
|
out_dict["dynamic_paths"][finalized_id] = finalized_id
|
||||||
|
|
||||||
|
def create_input_dict_v1(inputs: list[Input]) -> dict:
|
||||||
input = {
|
input = {
|
||||||
"required": {}
|
"required": {}
|
||||||
}
|
}
|
||||||
add_to_input_dict_v1(input, inputs, live_inputs)
|
for i in inputs:
|
||||||
|
add_to_dict_v1(i, input)
|
||||||
return input
|
return input
|
||||||
|
|
||||||
def add_to_input_dict_v1(d: dict[str, Any], inputs: list[Input], live_inputs: dict[str, Any]=None, curr_prefix=''):
|
def add_to_dict_v1(i: Input, d: dict):
|
||||||
for i in inputs:
|
|
||||||
if isinstance(i, DynamicInput):
|
|
||||||
add_to_dict_v1(i, d)
|
|
||||||
if live_inputs is not None:
|
|
||||||
i.expand_schema_for_dynamic(d, live_inputs, curr_prefix)
|
|
||||||
else:
|
|
||||||
add_to_dict_v1(i, d)
|
|
||||||
|
|
||||||
def add_to_dict_v1(i: Input, d: dict, dynamic_dict: dict=None):
|
|
||||||
key = "optional" if i.optional else "required"
|
key = "optional" if i.optional else "required"
|
||||||
as_dict = i.as_dict()
|
as_dict = i.as_dict()
|
||||||
# for v1, we don't want to include the optional key
|
# for v1, we don't want to include the optional key
|
||||||
as_dict.pop("optional", None)
|
as_dict.pop("optional", None)
|
||||||
if dynamic_dict is None:
|
d.setdefault(key, {})[i.id] = (i.get_io_type(), as_dict)
|
||||||
value = (i.get_io_type(), as_dict)
|
|
||||||
else:
|
|
||||||
value = (i.get_io_type(), as_dict, dynamic_dict)
|
|
||||||
d.setdefault(key, {})[i.id] = value
|
|
||||||
|
|
||||||
def add_to_dict_v3(io: Input | Output, d: dict):
|
def add_to_dict_v3(io: Input | Output, d: dict):
|
||||||
d[io.id] = (io.get_io_type(), io.as_dict())
|
d[io.id] = (io.get_io_type(), io.as_dict())
|
||||||
@ -1423,6 +1470,8 @@ def build_nested_inputs(values: dict[str, Any], v3_data: V3Data):
|
|||||||
values = values.copy()
|
values = values.copy()
|
||||||
result = {}
|
result = {}
|
||||||
|
|
||||||
|
create_tuple = v3_data.get("create_dynamic_tuple", False)
|
||||||
|
|
||||||
for key, path in paths.items():
|
for key, path in paths.items():
|
||||||
parts = path.split(".")
|
parts = path.split(".")
|
||||||
current = result
|
current = result
|
||||||
@ -1431,7 +1480,10 @@ def build_nested_inputs(values: dict[str, Any], v3_data: V3Data):
|
|||||||
is_last = (i == len(parts) - 1)
|
is_last = (i == len(parts) - 1)
|
||||||
|
|
||||||
if is_last:
|
if is_last:
|
||||||
current[p] = values.pop(key, None)
|
value = values.pop(key, None)
|
||||||
|
if create_tuple:
|
||||||
|
value = (value, key)
|
||||||
|
current[p] = value
|
||||||
else:
|
else:
|
||||||
current = current.setdefault(p, {})
|
current = current.setdefault(p, {})
|
||||||
|
|
||||||
@ -1446,7 +1498,6 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
|
|||||||
SCHEMA = None
|
SCHEMA = None
|
||||||
|
|
||||||
# filled in during execution
|
# filled in during execution
|
||||||
resources: Resources = None
|
|
||||||
hidden: HiddenHolder = None
|
hidden: HiddenHolder = None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -1493,7 +1544,6 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
|
|||||||
return [name for name in kwargs if kwargs[name] is None]
|
return [name for name in kwargs if kwargs[name] is None]
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.local_resources: ResourcesLocal = None
|
|
||||||
self.__class__.VALIDATE_CLASS()
|
self.__class__.VALIDATE_CLASS()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -1561,7 +1611,7 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
|
|||||||
c_type: type[ComfyNode] = cls if is_class(cls) else type(cls)
|
c_type: type[ComfyNode] = cls if is_class(cls) else type(cls)
|
||||||
type_clone: type[ComfyNode] = shallow_clone_class(c_type)
|
type_clone: type[ComfyNode] = shallow_clone_class(c_type)
|
||||||
# set hidden
|
# set hidden
|
||||||
type_clone.hidden = HiddenHolder.from_dict(v3_data["hidden_inputs"] if v3_data else None)
|
type_clone.hidden = HiddenHolder.from_v3_data(v3_data)
|
||||||
return type_clone
|
return type_clone
|
||||||
|
|
||||||
@final
|
@final
|
||||||
@ -1678,19 +1728,10 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
|
|||||||
|
|
||||||
@final
|
@final
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(cls, include_hidden=True, return_schema=False, live_inputs=None) -> dict[str, dict] | tuple[dict[str, dict], Schema, V3Data]:
|
def INPUT_TYPES(cls) -> dict[str, dict]:
|
||||||
schema = cls.FINALIZE_SCHEMA()
|
schema = cls.FINALIZE_SCHEMA()
|
||||||
info = schema.get_v1_info(cls, live_inputs)
|
info = schema.get_v1_info(cls)
|
||||||
input = info.input
|
return info.input
|
||||||
if not include_hidden:
|
|
||||||
input.pop("hidden", None)
|
|
||||||
if return_schema:
|
|
||||||
v3_data: V3Data = {}
|
|
||||||
dynamic = input.pop("dynamic_paths", None)
|
|
||||||
if dynamic is not None:
|
|
||||||
v3_data["dynamic_paths"] = dynamic
|
|
||||||
return input, schema, v3_data
|
|
||||||
return input
|
|
||||||
|
|
||||||
@final
|
@final
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -1809,7 +1850,7 @@ class NodeOutput(_NodeOutputInternal):
|
|||||||
return self.args if len(self.args) > 0 else None
|
return self.args if len(self.args) > 0 else None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_dict(cls, data: dict[str, Any]) -> "NodeOutput":
|
def from_dict(cls, data: dict[str, Any]) -> NodeOutput:
|
||||||
args = ()
|
args = ()
|
||||||
ui = None
|
ui = None
|
||||||
expand = None
|
expand = None
|
||||||
@ -1904,8 +1945,8 @@ __all__ = [
|
|||||||
"Tracks",
|
"Tracks",
|
||||||
# Dynamic Types
|
# Dynamic Types
|
||||||
"MatchType",
|
"MatchType",
|
||||||
# "DynamicCombo",
|
"DynamicCombo",
|
||||||
# "Autogrow",
|
"Autogrow",
|
||||||
# Other classes
|
# Other classes
|
||||||
"HiddenHolder",
|
"HiddenHolder",
|
||||||
"Hidden",
|
"Hidden",
|
||||||
|
|||||||
@ -1,72 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
import comfy.utils
|
|
||||||
import folder_paths
|
|
||||||
import logging
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from typing import Any
|
|
||||||
import torch
|
|
||||||
|
|
||||||
class ResourceKey(ABC):
|
|
||||||
Type = Any
|
|
||||||
def __init__(self):
|
|
||||||
...
|
|
||||||
|
|
||||||
class TorchDictFolderFilename(ResourceKey):
|
|
||||||
'''Key for requesting a torch file via file_name from a folder category.'''
|
|
||||||
Type = dict[str, torch.Tensor]
|
|
||||||
def __init__(self, folder_name: str, file_name: str):
|
|
||||||
self.folder_name = folder_name
|
|
||||||
self.file_name = file_name
|
|
||||||
|
|
||||||
def __hash__(self):
|
|
||||||
return hash((self.folder_name, self.file_name))
|
|
||||||
|
|
||||||
def __eq__(self, other: object) -> bool:
|
|
||||||
if not isinstance(other, TorchDictFolderFilename):
|
|
||||||
return False
|
|
||||||
return self.folder_name == other.folder_name and self.file_name == other.file_name
|
|
||||||
|
|
||||||
def __str__(self):
|
|
||||||
return f"{self.folder_name} -> {self.file_name}"
|
|
||||||
|
|
||||||
class Resources(ABC):
|
|
||||||
def __init__(self):
|
|
||||||
...
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get(self, key: ResourceKey, default: Any=...) -> Any:
|
|
||||||
pass
|
|
||||||
|
|
||||||
class ResourcesLocal(Resources):
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
self.local_resources: dict[ResourceKey, Any] = {}
|
|
||||||
|
|
||||||
def get(self, key: ResourceKey, default: Any=...) -> Any:
|
|
||||||
cached = self.local_resources.get(key, None)
|
|
||||||
if cached is not None:
|
|
||||||
logging.info(f"Using cached resource '{key}'")
|
|
||||||
return cached
|
|
||||||
logging.info(f"Loading resource '{key}'")
|
|
||||||
to_return = None
|
|
||||||
if isinstance(key, TorchDictFolderFilename):
|
|
||||||
if default is ...:
|
|
||||||
to_return = comfy.utils.load_torch_file(folder_paths.get_full_path_or_raise(key.folder_name, key.file_name), safe_load=True)
|
|
||||||
else:
|
|
||||||
full_path = folder_paths.get_full_path(key.folder_name, key.file_name)
|
|
||||||
if full_path is not None:
|
|
||||||
to_return = comfy.utils.load_torch_file(full_path, safe_load=True)
|
|
||||||
|
|
||||||
if to_return is not None:
|
|
||||||
self.local_resources[key] = to_return
|
|
||||||
return to_return
|
|
||||||
if default is not ...:
|
|
||||||
return default
|
|
||||||
raise Exception(f"Unsupported resource key type: {type(key)}")
|
|
||||||
|
|
||||||
|
|
||||||
class _RESOURCES:
|
|
||||||
ResourceKey = ResourceKey
|
|
||||||
TorchDictFolderFilename = TorchDictFolderFilename
|
|
||||||
Resources = Resources
|
|
||||||
ResourcesLocal = ResourcesLocal
|
|
||||||
@ -1,5 +1,6 @@
|
|||||||
from .video_types import VideoContainer, VideoCodec, VideoComponents
|
from .video_types import VideoContainer, VideoCodec, VideoComponents
|
||||||
from .geometry_types import VOXEL, MESH
|
from .geometry_types import VOXEL, MESH
|
||||||
|
from .image_types import SVG
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
# Utility Types
|
# Utility Types
|
||||||
@ -8,4 +9,5 @@ __all__ = [
|
|||||||
"VideoComponents",
|
"VideoComponents",
|
||||||
"VOXEL",
|
"VOXEL",
|
||||||
"MESH",
|
"MESH",
|
||||||
|
"SVG",
|
||||||
]
|
]
|
||||||
|
|||||||
18
comfy_api/latest/_util/image_types.py
Normal file
18
comfy_api/latest/_util/image_types.py
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
from io import BytesIO
|
||||||
|
|
||||||
|
|
||||||
|
class SVG:
|
||||||
|
"""Stores SVG representations via a list of BytesIO objects."""
|
||||||
|
|
||||||
|
def __init__(self, data: list[BytesIO]):
|
||||||
|
self.data = data
|
||||||
|
|
||||||
|
def combine(self, other: 'SVG') -> 'SVG':
|
||||||
|
return SVG(self.data + other.data)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def combine_all(svgs: list['SVG']) -> 'SVG':
|
||||||
|
all_svgs_list: list[BytesIO] = []
|
||||||
|
for svg_item in svgs:
|
||||||
|
all_svgs_list.extend(svg_item.data)
|
||||||
|
return SVG(all_svgs_list)
|
||||||
@ -10,7 +10,7 @@ class Text2ImageTaskCreationRequest(BaseModel):
|
|||||||
size: str | None = Field(None)
|
size: str | None = Field(None)
|
||||||
seed: int | None = Field(0, ge=0, le=2147483647)
|
seed: int | None = Field(0, ge=0, le=2147483647)
|
||||||
guidance_scale: float | None = Field(..., ge=1.0, le=10.0)
|
guidance_scale: float | None = Field(..., ge=1.0, le=10.0)
|
||||||
watermark: bool | None = Field(True)
|
watermark: bool | None = Field(False)
|
||||||
|
|
||||||
|
|
||||||
class Image2ImageTaskCreationRequest(BaseModel):
|
class Image2ImageTaskCreationRequest(BaseModel):
|
||||||
@ -21,7 +21,7 @@ class Image2ImageTaskCreationRequest(BaseModel):
|
|||||||
size: str | None = Field("adaptive")
|
size: str | None = Field("adaptive")
|
||||||
seed: int | None = Field(..., ge=0, le=2147483647)
|
seed: int | None = Field(..., ge=0, le=2147483647)
|
||||||
guidance_scale: float | None = Field(..., ge=1.0, le=10.0)
|
guidance_scale: float | None = Field(..., ge=1.0, le=10.0)
|
||||||
watermark: bool | None = Field(True)
|
watermark: bool | None = Field(False)
|
||||||
|
|
||||||
|
|
||||||
class Seedream4Options(BaseModel):
|
class Seedream4Options(BaseModel):
|
||||||
@ -37,7 +37,7 @@ class Seedream4TaskCreationRequest(BaseModel):
|
|||||||
seed: int = Field(..., ge=0, le=2147483647)
|
seed: int = Field(..., ge=0, le=2147483647)
|
||||||
sequential_image_generation: str = Field("disabled")
|
sequential_image_generation: str = Field("disabled")
|
||||||
sequential_image_generation_options: Seedream4Options = Field(Seedream4Options(max_images=15))
|
sequential_image_generation_options: Seedream4Options = Field(Seedream4Options(max_images=15))
|
||||||
watermark: bool = Field(True)
|
watermark: bool = Field(False)
|
||||||
|
|
||||||
|
|
||||||
class ImageTaskCreationResponse(BaseModel):
|
class ImageTaskCreationResponse(BaseModel):
|
||||||
|
|||||||
@ -133,6 +133,7 @@ class GeminiImageGenerateContentRequest(BaseModel):
|
|||||||
systemInstruction: GeminiSystemInstructionContent | None = Field(None)
|
systemInstruction: GeminiSystemInstructionContent | None = Field(None)
|
||||||
tools: list[GeminiTool] | None = Field(None)
|
tools: list[GeminiTool] | None = Field(None)
|
||||||
videoMetadata: GeminiVideoMetadata | None = Field(None)
|
videoMetadata: GeminiVideoMetadata | None = Field(None)
|
||||||
|
uploadImagesToStorage: bool = Field(True)
|
||||||
|
|
||||||
|
|
||||||
class GeminiGenerateContentRequest(BaseModel):
|
class GeminiGenerateContentRequest(BaseModel):
|
||||||
|
|||||||
@ -102,3 +102,12 @@ class ImageToVideoWithAudioRequest(BaseModel):
|
|||||||
prompt: str = Field(...)
|
prompt: str = Field(...)
|
||||||
mode: str = Field("pro")
|
mode: str = Field("pro")
|
||||||
sound: str = Field(..., description="'on' or 'off'")
|
sound: str = Field(..., description="'on' or 'off'")
|
||||||
|
|
||||||
|
|
||||||
|
class MotionControlRequest(BaseModel):
|
||||||
|
prompt: str = Field(...)
|
||||||
|
image_url: str = Field(...)
|
||||||
|
video_url: str = Field(...)
|
||||||
|
keep_original_sound: str = Field(...)
|
||||||
|
character_orientation: str = Field(...)
|
||||||
|
mode: str = Field(..., description="'pro' or 'std'")
|
||||||
|
|||||||
@ -112,7 +112,7 @@ class ByteDanceImageNode(IO.ComfyNode):
|
|||||||
),
|
),
|
||||||
IO.Boolean.Input(
|
IO.Boolean.Input(
|
||||||
"watermark",
|
"watermark",
|
||||||
default=True,
|
default=False,
|
||||||
tooltip='Whether to add an "AI generated" watermark to the image',
|
tooltip='Whether to add an "AI generated" watermark to the image',
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
@ -215,7 +215,7 @@ class ByteDanceImageEditNode(IO.ComfyNode):
|
|||||||
),
|
),
|
||||||
IO.Boolean.Input(
|
IO.Boolean.Input(
|
||||||
"watermark",
|
"watermark",
|
||||||
default=True,
|
default=False,
|
||||||
tooltip='Whether to add an "AI generated" watermark to the image',
|
tooltip='Whether to add an "AI generated" watermark to the image',
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
@ -229,6 +229,7 @@ class ByteDanceImageEditNode(IO.ComfyNode):
|
|||||||
IO.Hidden.unique_id,
|
IO.Hidden.unique_id,
|
||||||
],
|
],
|
||||||
is_api_node=True,
|
is_api_node=True,
|
||||||
|
is_deprecated=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -269,7 +270,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
|
|||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
return IO.Schema(
|
return IO.Schema(
|
||||||
node_id="ByteDanceSeedreamNode",
|
node_id="ByteDanceSeedreamNode",
|
||||||
display_name="ByteDance Seedream 4",
|
display_name="ByteDance Seedream 4.5",
|
||||||
category="api node/image/ByteDance",
|
category="api node/image/ByteDance",
|
||||||
description="Unified text-to-image generation and precise single-sentence editing at up to 4K resolution.",
|
description="Unified text-to-image generation and precise single-sentence editing at up to 4K resolution.",
|
||||||
inputs=[
|
inputs=[
|
||||||
@ -346,7 +347,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
|
|||||||
),
|
),
|
||||||
IO.Boolean.Input(
|
IO.Boolean.Input(
|
||||||
"watermark",
|
"watermark",
|
||||||
default=True,
|
default=False,
|
||||||
tooltip='Whether to add an "AI generated" watermark to the image.',
|
tooltip='Whether to add an "AI generated" watermark to the image.',
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
@ -380,7 +381,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
|
|||||||
sequential_image_generation: str = "disabled",
|
sequential_image_generation: str = "disabled",
|
||||||
max_images: int = 1,
|
max_images: int = 1,
|
||||||
seed: int = 0,
|
seed: int = 0,
|
||||||
watermark: bool = True,
|
watermark: bool = False,
|
||||||
fail_on_partial: bool = True,
|
fail_on_partial: bool = True,
|
||||||
) -> IO.NodeOutput:
|
) -> IO.NodeOutput:
|
||||||
validate_string(prompt, strip_whitespace=True, min_length=1)
|
validate_string(prompt, strip_whitespace=True, min_length=1)
|
||||||
@ -507,7 +508,7 @@ class ByteDanceTextToVideoNode(IO.ComfyNode):
|
|||||||
),
|
),
|
||||||
IO.Boolean.Input(
|
IO.Boolean.Input(
|
||||||
"watermark",
|
"watermark",
|
||||||
default=True,
|
default=False,
|
||||||
tooltip='Whether to add an "AI generated" watermark to the video.',
|
tooltip='Whether to add an "AI generated" watermark to the video.',
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
@ -617,7 +618,7 @@ class ByteDanceImageToVideoNode(IO.ComfyNode):
|
|||||||
),
|
),
|
||||||
IO.Boolean.Input(
|
IO.Boolean.Input(
|
||||||
"watermark",
|
"watermark",
|
||||||
default=True,
|
default=False,
|
||||||
tooltip='Whether to add an "AI generated" watermark to the video.',
|
tooltip='Whether to add an "AI generated" watermark to the video.',
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
@ -739,7 +740,7 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode):
|
|||||||
),
|
),
|
||||||
IO.Boolean.Input(
|
IO.Boolean.Input(
|
||||||
"watermark",
|
"watermark",
|
||||||
default=True,
|
default=False,
|
||||||
tooltip='Whether to add an "AI generated" watermark to the video.',
|
tooltip='Whether to add an "AI generated" watermark to the video.',
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
@ -862,7 +863,7 @@ class ByteDanceImageReferenceNode(IO.ComfyNode):
|
|||||||
),
|
),
|
||||||
IO.Boolean.Input(
|
IO.Boolean.Input(
|
||||||
"watermark",
|
"watermark",
|
||||||
default=True,
|
default=False,
|
||||||
tooltip='Whether to add an "AI generated" watermark to the video.',
|
tooltip='Whether to add an "AI generated" watermark to the video.',
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
|
|||||||
@ -34,6 +34,7 @@ from comfy_api_nodes.util import (
|
|||||||
ApiEndpoint,
|
ApiEndpoint,
|
||||||
audio_to_base64_string,
|
audio_to_base64_string,
|
||||||
bytesio_to_image_tensor,
|
bytesio_to_image_tensor,
|
||||||
|
download_url_to_image_tensor,
|
||||||
get_number_of_images,
|
get_number_of_images,
|
||||||
sync_op,
|
sync_op,
|
||||||
tensor_to_base64_string,
|
tensor_to_base64_string,
|
||||||
@ -141,9 +142,11 @@ def get_parts_by_type(response: GeminiGenerateContentResponse, part_type: Litera
|
|||||||
)
|
)
|
||||||
parts = []
|
parts = []
|
||||||
for part in response.candidates[0].content.parts:
|
for part in response.candidates[0].content.parts:
|
||||||
if part_type == "text" and hasattr(part, "text") and part.text:
|
if part_type == "text" and part.text:
|
||||||
parts.append(part)
|
parts.append(part)
|
||||||
elif hasattr(part, "inlineData") and part.inlineData and part.inlineData.mimeType == part_type:
|
elif part.inlineData and part.inlineData.mimeType == part_type:
|
||||||
|
parts.append(part)
|
||||||
|
elif part.fileData and part.fileData.mimeType == part_type:
|
||||||
parts.append(part)
|
parts.append(part)
|
||||||
# Skip parts that don't match the requested type
|
# Skip parts that don't match the requested type
|
||||||
return parts
|
return parts
|
||||||
@ -163,12 +166,15 @@ def get_text_from_response(response: GeminiGenerateContentResponse) -> str:
|
|||||||
return "\n".join([part.text for part in parts])
|
return "\n".join([part.text for part in parts])
|
||||||
|
|
||||||
|
|
||||||
def get_image_from_response(response: GeminiGenerateContentResponse) -> Input.Image:
|
async def get_image_from_response(response: GeminiGenerateContentResponse) -> Input.Image:
|
||||||
image_tensors: list[Input.Image] = []
|
image_tensors: list[Input.Image] = []
|
||||||
parts = get_parts_by_type(response, "image/png")
|
parts = get_parts_by_type(response, "image/png")
|
||||||
for part in parts:
|
for part in parts:
|
||||||
image_data = base64.b64decode(part.inlineData.data)
|
if part.inlineData:
|
||||||
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
|
image_data = base64.b64decode(part.inlineData.data)
|
||||||
|
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
|
||||||
|
else:
|
||||||
|
returned_image = await download_url_to_image_tensor(part.fileData.fileUri)
|
||||||
image_tensors.append(returned_image)
|
image_tensors.append(returned_image)
|
||||||
if len(image_tensors) == 0:
|
if len(image_tensors) == 0:
|
||||||
return torch.zeros((1, 1024, 1024, 4))
|
return torch.zeros((1, 1024, 1024, 4))
|
||||||
@ -596,7 +602,7 @@ class GeminiImage(IO.ComfyNode):
|
|||||||
|
|
||||||
response = await sync_op(
|
response = await sync_op(
|
||||||
cls,
|
cls,
|
||||||
endpoint=ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"),
|
ApiEndpoint(path=f"/proxy/vertexai/gemini/{model}", method="POST"),
|
||||||
data=GeminiImageGenerateContentRequest(
|
data=GeminiImageGenerateContentRequest(
|
||||||
contents=[
|
contents=[
|
||||||
GeminiContent(role=GeminiRole.user, parts=parts),
|
GeminiContent(role=GeminiRole.user, parts=parts),
|
||||||
@ -610,7 +616,7 @@ class GeminiImage(IO.ComfyNode):
|
|||||||
response_model=GeminiGenerateContentResponse,
|
response_model=GeminiGenerateContentResponse,
|
||||||
price_extractor=calculate_tokens_price,
|
price_extractor=calculate_tokens_price,
|
||||||
)
|
)
|
||||||
return IO.NodeOutput(get_image_from_response(response), get_text_from_response(response))
|
return IO.NodeOutput(await get_image_from_response(response), get_text_from_response(response))
|
||||||
|
|
||||||
|
|
||||||
class GeminiImage2(IO.ComfyNode):
|
class GeminiImage2(IO.ComfyNode):
|
||||||
@ -729,7 +735,7 @@ class GeminiImage2(IO.ComfyNode):
|
|||||||
|
|
||||||
response = await sync_op(
|
response = await sync_op(
|
||||||
cls,
|
cls,
|
||||||
ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"),
|
ApiEndpoint(path=f"/proxy/vertexai/gemini/{model}", method="POST"),
|
||||||
data=GeminiImageGenerateContentRequest(
|
data=GeminiImageGenerateContentRequest(
|
||||||
contents=[
|
contents=[
|
||||||
GeminiContent(role=GeminiRole.user, parts=parts),
|
GeminiContent(role=GeminiRole.user, parts=parts),
|
||||||
@ -743,7 +749,7 @@ class GeminiImage2(IO.ComfyNode):
|
|||||||
response_model=GeminiGenerateContentResponse,
|
response_model=GeminiGenerateContentResponse,
|
||||||
price_extractor=calculate_tokens_price,
|
price_extractor=calculate_tokens_price,
|
||||||
)
|
)
|
||||||
return IO.NodeOutput(get_image_from_response(response), get_text_from_response(response))
|
return IO.NodeOutput(await get_image_from_response(response), get_text_from_response(response))
|
||||||
|
|
||||||
|
|
||||||
class GeminiExtension(ComfyExtension):
|
class GeminiExtension(ComfyExtension):
|
||||||
|
|||||||
@ -51,6 +51,7 @@ from comfy_api_nodes.apis import (
|
|||||||
)
|
)
|
||||||
from comfy_api_nodes.apis.kling_api import (
|
from comfy_api_nodes.apis.kling_api import (
|
||||||
ImageToVideoWithAudioRequest,
|
ImageToVideoWithAudioRequest,
|
||||||
|
MotionControlRequest,
|
||||||
OmniImageParamImage,
|
OmniImageParamImage,
|
||||||
OmniParamImage,
|
OmniParamImage,
|
||||||
OmniParamVideo,
|
OmniParamVideo,
|
||||||
@ -806,6 +807,7 @@ class OmniProTextToVideoNode(IO.ComfyNode):
|
|||||||
),
|
),
|
||||||
IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]),
|
IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]),
|
||||||
IO.Combo.Input("duration", options=[5, 10]),
|
IO.Combo.Input("duration", options=[5, 10]),
|
||||||
|
IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True),
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
IO.Video.Output(),
|
IO.Video.Output(),
|
||||||
@ -825,6 +827,7 @@ class OmniProTextToVideoNode(IO.ComfyNode):
|
|||||||
prompt: str,
|
prompt: str,
|
||||||
aspect_ratio: str,
|
aspect_ratio: str,
|
||||||
duration: int,
|
duration: int,
|
||||||
|
resolution: str = "1080p",
|
||||||
) -> IO.NodeOutput:
|
) -> IO.NodeOutput:
|
||||||
validate_string(prompt, min_length=1, max_length=2500)
|
validate_string(prompt, min_length=1, max_length=2500)
|
||||||
response = await sync_op(
|
response = await sync_op(
|
||||||
@ -836,6 +839,7 @@ class OmniProTextToVideoNode(IO.ComfyNode):
|
|||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
aspect_ratio=aspect_ratio,
|
aspect_ratio=aspect_ratio,
|
||||||
duration=str(duration),
|
duration=str(duration),
|
||||||
|
mode="pro" if resolution == "1080p" else "std",
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
return await finish_omni_video_task(cls, response)
|
return await finish_omni_video_task(cls, response)
|
||||||
@ -858,7 +862,7 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
|
|||||||
tooltip="A text prompt describing the video content. "
|
tooltip="A text prompt describing the video content. "
|
||||||
"This can include both positive and negative descriptions.",
|
"This can include both positive and negative descriptions.",
|
||||||
),
|
),
|
||||||
IO.Combo.Input("duration", options=["5", "10"]),
|
IO.Int.Input("duration", default=5, min=3, max=10, display_mode=IO.NumberDisplay.slider),
|
||||||
IO.Image.Input("first_frame"),
|
IO.Image.Input("first_frame"),
|
||||||
IO.Image.Input(
|
IO.Image.Input(
|
||||||
"end_frame",
|
"end_frame",
|
||||||
@ -871,6 +875,7 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
|
|||||||
optional=True,
|
optional=True,
|
||||||
tooltip="Up to 6 additional reference images.",
|
tooltip="Up to 6 additional reference images.",
|
||||||
),
|
),
|
||||||
|
IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True),
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
IO.Video.Output(),
|
IO.Video.Output(),
|
||||||
@ -892,11 +897,16 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
|
|||||||
first_frame: Input.Image,
|
first_frame: Input.Image,
|
||||||
end_frame: Input.Image | None = None,
|
end_frame: Input.Image | None = None,
|
||||||
reference_images: Input.Image | None = None,
|
reference_images: Input.Image | None = None,
|
||||||
|
resolution: str = "1080p",
|
||||||
) -> IO.NodeOutput:
|
) -> IO.NodeOutput:
|
||||||
prompt = normalize_omni_prompt_references(prompt)
|
prompt = normalize_omni_prompt_references(prompt)
|
||||||
validate_string(prompt, min_length=1, max_length=2500)
|
validate_string(prompt, min_length=1, max_length=2500)
|
||||||
if end_frame is not None and reference_images is not None:
|
if end_frame is not None and reference_images is not None:
|
||||||
raise ValueError("The 'end_frame' input cannot be used simultaneously with 'reference_images'.")
|
raise ValueError("The 'end_frame' input cannot be used simultaneously with 'reference_images'.")
|
||||||
|
if duration not in (5, 10) and end_frame is None and reference_images is None:
|
||||||
|
raise ValueError(
|
||||||
|
"Duration is only supported for 5 or 10 seconds if there is no end frame or reference images."
|
||||||
|
)
|
||||||
validate_image_dimensions(first_frame, min_width=300, min_height=300)
|
validate_image_dimensions(first_frame, min_width=300, min_height=300)
|
||||||
validate_image_aspect_ratio(first_frame, (1, 2.5), (2.5, 1))
|
validate_image_aspect_ratio(first_frame, (1, 2.5), (2.5, 1))
|
||||||
image_list: list[OmniParamImage] = [
|
image_list: list[OmniParamImage] = [
|
||||||
@ -931,6 +941,7 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
|
|||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
duration=str(duration),
|
duration=str(duration),
|
||||||
image_list=image_list,
|
image_list=image_list,
|
||||||
|
mode="pro" if resolution == "1080p" else "std",
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
return await finish_omni_video_task(cls, response)
|
return await finish_omni_video_task(cls, response)
|
||||||
@ -959,6 +970,7 @@ class OmniProImageToVideoNode(IO.ComfyNode):
|
|||||||
"reference_images",
|
"reference_images",
|
||||||
tooltip="Up to 7 reference images.",
|
tooltip="Up to 7 reference images.",
|
||||||
),
|
),
|
||||||
|
IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True),
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
IO.Video.Output(),
|
IO.Video.Output(),
|
||||||
@ -979,6 +991,7 @@ class OmniProImageToVideoNode(IO.ComfyNode):
|
|||||||
aspect_ratio: str,
|
aspect_ratio: str,
|
||||||
duration: int,
|
duration: int,
|
||||||
reference_images: Input.Image,
|
reference_images: Input.Image,
|
||||||
|
resolution: str = "1080p",
|
||||||
) -> IO.NodeOutput:
|
) -> IO.NodeOutput:
|
||||||
prompt = normalize_omni_prompt_references(prompt)
|
prompt = normalize_omni_prompt_references(prompt)
|
||||||
validate_string(prompt, min_length=1, max_length=2500)
|
validate_string(prompt, min_length=1, max_length=2500)
|
||||||
@ -1000,6 +1013,7 @@ class OmniProImageToVideoNode(IO.ComfyNode):
|
|||||||
aspect_ratio=aspect_ratio,
|
aspect_ratio=aspect_ratio,
|
||||||
duration=str(duration),
|
duration=str(duration),
|
||||||
image_list=image_list,
|
image_list=image_list,
|
||||||
|
mode="pro" if resolution == "1080p" else "std",
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
return await finish_omni_video_task(cls, response)
|
return await finish_omni_video_task(cls, response)
|
||||||
@ -1031,6 +1045,7 @@ class OmniProVideoToVideoNode(IO.ComfyNode):
|
|||||||
tooltip="Up to 4 additional reference images.",
|
tooltip="Up to 4 additional reference images.",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
|
IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True),
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
IO.Video.Output(),
|
IO.Video.Output(),
|
||||||
@ -1053,6 +1068,7 @@ class OmniProVideoToVideoNode(IO.ComfyNode):
|
|||||||
reference_video: Input.Video,
|
reference_video: Input.Video,
|
||||||
keep_original_sound: bool,
|
keep_original_sound: bool,
|
||||||
reference_images: Input.Image | None = None,
|
reference_images: Input.Image | None = None,
|
||||||
|
resolution: str = "1080p",
|
||||||
) -> IO.NodeOutput:
|
) -> IO.NodeOutput:
|
||||||
prompt = normalize_omni_prompt_references(prompt)
|
prompt = normalize_omni_prompt_references(prompt)
|
||||||
validate_string(prompt, min_length=1, max_length=2500)
|
validate_string(prompt, min_length=1, max_length=2500)
|
||||||
@ -1085,6 +1101,7 @@ class OmniProVideoToVideoNode(IO.ComfyNode):
|
|||||||
duration=str(duration),
|
duration=str(duration),
|
||||||
image_list=image_list if image_list else None,
|
image_list=image_list if image_list else None,
|
||||||
video_list=video_list,
|
video_list=video_list,
|
||||||
|
mode="pro" if resolution == "1080p" else "std",
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
return await finish_omni_video_task(cls, response)
|
return await finish_omni_video_task(cls, response)
|
||||||
@ -1114,6 +1131,7 @@ class OmniProEditVideoNode(IO.ComfyNode):
|
|||||||
tooltip="Up to 4 additional reference images.",
|
tooltip="Up to 4 additional reference images.",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
|
IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True),
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
IO.Video.Output(),
|
IO.Video.Output(),
|
||||||
@ -1134,6 +1152,7 @@ class OmniProEditVideoNode(IO.ComfyNode):
|
|||||||
video: Input.Video,
|
video: Input.Video,
|
||||||
keep_original_sound: bool,
|
keep_original_sound: bool,
|
||||||
reference_images: Input.Image | None = None,
|
reference_images: Input.Image | None = None,
|
||||||
|
resolution: str = "1080p",
|
||||||
) -> IO.NodeOutput:
|
) -> IO.NodeOutput:
|
||||||
prompt = normalize_omni_prompt_references(prompt)
|
prompt = normalize_omni_prompt_references(prompt)
|
||||||
validate_string(prompt, min_length=1, max_length=2500)
|
validate_string(prompt, min_length=1, max_length=2500)
|
||||||
@ -1166,6 +1185,7 @@ class OmniProEditVideoNode(IO.ComfyNode):
|
|||||||
duration=None,
|
duration=None,
|
||||||
image_list=image_list if image_list else None,
|
image_list=image_list if image_list else None,
|
||||||
video_list=video_list,
|
video_list=video_list,
|
||||||
|
mode="pro" if resolution == "1080p" else "std",
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
return await finish_omni_video_task(cls, response)
|
return await finish_omni_video_task(cls, response)
|
||||||
@ -2159,6 +2179,91 @@ class ImageToVideoWithAudio(IO.ComfyNode):
|
|||||||
return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url))
|
return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url))
|
||||||
|
|
||||||
|
|
||||||
|
class MotionControl(IO.ComfyNode):
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls) -> IO.Schema:
|
||||||
|
return IO.Schema(
|
||||||
|
node_id="KlingMotionControl",
|
||||||
|
display_name="Kling Motion Control",
|
||||||
|
category="api node/video/Kling",
|
||||||
|
inputs=[
|
||||||
|
IO.String.Input("prompt", multiline=True),
|
||||||
|
IO.Image.Input("reference_image"),
|
||||||
|
IO.Video.Input(
|
||||||
|
"reference_video",
|
||||||
|
tooltip="Motion reference video used to drive movement/expression.\n"
|
||||||
|
"Duration limits depend on character_orientation:\n"
|
||||||
|
" - image: 3–10s (max 10s)\n"
|
||||||
|
" - video: 3–30s (max 30s)",
|
||||||
|
),
|
||||||
|
IO.Boolean.Input("keep_original_sound", default=True),
|
||||||
|
IO.Combo.Input(
|
||||||
|
"character_orientation",
|
||||||
|
options=["video", "image"],
|
||||||
|
tooltip="Controls where the character's facing/orientation comes from.\n"
|
||||||
|
"video: movements, expressions, camera moves, and orientation "
|
||||||
|
"follow the motion reference video (other details via prompt).\n"
|
||||||
|
"image: movements and expressions still follow the motion reference video, "
|
||||||
|
"but the character orientation matches the reference image (camera/other details via prompt).",
|
||||||
|
),
|
||||||
|
IO.Combo.Input("mode", options=["pro", "std"]),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
IO.Video.Output(),
|
||||||
|
],
|
||||||
|
hidden=[
|
||||||
|
IO.Hidden.auth_token_comfy_org,
|
||||||
|
IO.Hidden.api_key_comfy_org,
|
||||||
|
IO.Hidden.unique_id,
|
||||||
|
],
|
||||||
|
is_api_node=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def execute(
|
||||||
|
cls,
|
||||||
|
prompt: str,
|
||||||
|
reference_image: Input.Image,
|
||||||
|
reference_video: Input.Video,
|
||||||
|
keep_original_sound: bool,
|
||||||
|
character_orientation: str,
|
||||||
|
mode: str,
|
||||||
|
) -> IO.NodeOutput:
|
||||||
|
validate_string(prompt, max_length=2500)
|
||||||
|
validate_image_dimensions(reference_image, min_width=340, min_height=340)
|
||||||
|
validate_image_aspect_ratio(reference_image, (1, 2.5), (2.5, 1))
|
||||||
|
if character_orientation == "image":
|
||||||
|
validate_video_duration(reference_video, min_duration=3, max_duration=10)
|
||||||
|
else:
|
||||||
|
validate_video_duration(reference_video, min_duration=3, max_duration=30)
|
||||||
|
validate_video_dimensions(reference_video, min_width=340, min_height=340, max_width=3850, max_height=3850)
|
||||||
|
response = await sync_op(
|
||||||
|
cls,
|
||||||
|
ApiEndpoint(path="/proxy/kling/v1/videos/motion-control", method="POST"),
|
||||||
|
response_model=TaskStatusResponse,
|
||||||
|
data=MotionControlRequest(
|
||||||
|
prompt=prompt,
|
||||||
|
image_url=(await upload_images_to_comfyapi(cls, reference_image))[0],
|
||||||
|
video_url=await upload_video_to_comfyapi(cls, reference_video),
|
||||||
|
keep_original_sound="yes" if keep_original_sound else "no",
|
||||||
|
character_orientation=character_orientation,
|
||||||
|
mode=mode,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
if response.code:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Kling request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}"
|
||||||
|
)
|
||||||
|
final_response = await poll_op(
|
||||||
|
cls,
|
||||||
|
ApiEndpoint(path=f"/proxy/kling/v1/videos/motion-control/{response.data.task_id}"),
|
||||||
|
response_model=TaskStatusResponse,
|
||||||
|
status_extractor=lambda r: (r.data.task_status if r.data else None),
|
||||||
|
)
|
||||||
|
return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url))
|
||||||
|
|
||||||
|
|
||||||
class KlingExtension(ComfyExtension):
|
class KlingExtension(ComfyExtension):
|
||||||
@override
|
@override
|
||||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||||
@ -2184,6 +2289,7 @@ class KlingExtension(ComfyExtension):
|
|||||||
OmniProImageNode,
|
OmniProImageNode,
|
||||||
TextToVideoWithAudio,
|
TextToVideoWithAudio,
|
||||||
ImageToVideoWithAudio,
|
ImageToVideoWithAudio,
|
||||||
|
MotionControl,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -23,10 +23,6 @@ UPSCALER_MODELS_MAP = {
|
|||||||
"Starlight (Astra) Fast": "slf-1",
|
"Starlight (Astra) Fast": "slf-1",
|
||||||
"Starlight (Astra) Creative": "slc-1",
|
"Starlight (Astra) Creative": "slc-1",
|
||||||
}
|
}
|
||||||
UPSCALER_VALUES_MAP = {
|
|
||||||
"FullHD (1080p)": 1920,
|
|
||||||
"4K (2160p)": 3840,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class TopazImageEnhance(IO.ComfyNode):
|
class TopazImageEnhance(IO.ComfyNode):
|
||||||
@ -214,7 +210,7 @@ class TopazVideoEnhance(IO.ComfyNode):
|
|||||||
IO.Video.Input("video"),
|
IO.Video.Input("video"),
|
||||||
IO.Boolean.Input("upscaler_enabled", default=True),
|
IO.Boolean.Input("upscaler_enabled", default=True),
|
||||||
IO.Combo.Input("upscaler_model", options=list(UPSCALER_MODELS_MAP.keys())),
|
IO.Combo.Input("upscaler_model", options=list(UPSCALER_MODELS_MAP.keys())),
|
||||||
IO.Combo.Input("upscaler_resolution", options=list(UPSCALER_VALUES_MAP.keys())),
|
IO.Combo.Input("upscaler_resolution", options=["FullHD (1080p)", "4K (2160p)"]),
|
||||||
IO.Combo.Input(
|
IO.Combo.Input(
|
||||||
"upscaler_creativity",
|
"upscaler_creativity",
|
||||||
options=["low", "middle", "high"],
|
options=["low", "middle", "high"],
|
||||||
@ -306,8 +302,33 @@ class TopazVideoEnhance(IO.ComfyNode):
|
|||||||
target_frame_rate = src_frame_rate
|
target_frame_rate = src_frame_rate
|
||||||
filters = []
|
filters = []
|
||||||
if upscaler_enabled:
|
if upscaler_enabled:
|
||||||
target_width = UPSCALER_VALUES_MAP[upscaler_resolution]
|
if "1080p" in upscaler_resolution:
|
||||||
target_height = UPSCALER_VALUES_MAP[upscaler_resolution]
|
target_pixel_p = 1080
|
||||||
|
max_long_side = 1920
|
||||||
|
else:
|
||||||
|
target_pixel_p = 2160
|
||||||
|
max_long_side = 3840
|
||||||
|
ar = src_width / src_height
|
||||||
|
if src_width >= src_height:
|
||||||
|
# Landscape or Square; Attempt to set height to target (e.g., 2160), calculate width
|
||||||
|
target_height = target_pixel_p
|
||||||
|
target_width = int(target_height * ar)
|
||||||
|
# Check if width exceeds standard bounds (for ultra-wide e.g., 21:9 ARs)
|
||||||
|
if target_width > max_long_side:
|
||||||
|
target_width = max_long_side
|
||||||
|
target_height = int(target_width / ar)
|
||||||
|
else:
|
||||||
|
# Portrait; Attempt to set width to target (e.g., 2160), calculate height
|
||||||
|
target_width = target_pixel_p
|
||||||
|
target_height = int(target_width / ar)
|
||||||
|
# Check if height exceeds standard bounds
|
||||||
|
if target_height > max_long_side:
|
||||||
|
target_height = max_long_side
|
||||||
|
target_width = int(target_height * ar)
|
||||||
|
if target_width % 2 != 0:
|
||||||
|
target_width += 1
|
||||||
|
if target_height % 2 != 0:
|
||||||
|
target_height += 1
|
||||||
filters.append(
|
filters.append(
|
||||||
topaz_api.VideoEnhancementFilter(
|
topaz_api.VideoEnhancementFilter(
|
||||||
model=UPSCALER_MODELS_MAP[upscaler_model],
|
model=UPSCALER_MODELS_MAP[upscaler_model],
|
||||||
|
|||||||
@ -155,7 +155,7 @@ class TripoTextToModelNode(IO.ComfyNode):
|
|||||||
model_seed=model_seed,
|
model_seed=model_seed,
|
||||||
texture_seed=texture_seed,
|
texture_seed=texture_seed,
|
||||||
texture_quality=texture_quality,
|
texture_quality=texture_quality,
|
||||||
face_limit=face_limit,
|
face_limit=face_limit if face_limit != -1 else None,
|
||||||
geometry_quality=geometry_quality,
|
geometry_quality=geometry_quality,
|
||||||
auto_size=True,
|
auto_size=True,
|
||||||
quad=quad,
|
quad=quad,
|
||||||
@ -255,7 +255,7 @@ class TripoImageToModelNode(IO.ComfyNode):
|
|||||||
texture_alignment=texture_alignment,
|
texture_alignment=texture_alignment,
|
||||||
texture_seed=texture_seed,
|
texture_seed=texture_seed,
|
||||||
texture_quality=texture_quality,
|
texture_quality=texture_quality,
|
||||||
face_limit=face_limit,
|
face_limit=face_limit if face_limit != -1 else None,
|
||||||
auto_size=True,
|
auto_size=True,
|
||||||
quad=quad,
|
quad=quad,
|
||||||
),
|
),
|
||||||
@ -369,7 +369,7 @@ class TripoMultiviewToModelNode(IO.ComfyNode):
|
|||||||
texture_quality=texture_quality,
|
texture_quality=texture_quality,
|
||||||
geometry_quality=geometry_quality,
|
geometry_quality=geometry_quality,
|
||||||
texture_alignment=texture_alignment,
|
texture_alignment=texture_alignment,
|
||||||
face_limit=face_limit,
|
face_limit=face_limit if face_limit != -1 else None,
|
||||||
quad=quad,
|
quad=quad,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|||||||
@ -168,6 +168,8 @@ class VeoVideoGenerationNode(IO.ComfyNode):
|
|||||||
# Only add generateAudio for Veo 3 models
|
# Only add generateAudio for Veo 3 models
|
||||||
if model.find("veo-2.0") == -1:
|
if model.find("veo-2.0") == -1:
|
||||||
parameters["generateAudio"] = generate_audio
|
parameters["generateAudio"] = generate_audio
|
||||||
|
# force "enhance_prompt" to True for Veo3 models
|
||||||
|
parameters["enhancePrompt"] = True
|
||||||
|
|
||||||
initial_response = await sync_op(
|
initial_response = await sync_op(
|
||||||
cls,
|
cls,
|
||||||
@ -291,7 +293,7 @@ class Veo3VideoGenerationNode(VeoVideoGenerationNode):
|
|||||||
IO.Boolean.Input(
|
IO.Boolean.Input(
|
||||||
"enhance_prompt",
|
"enhance_prompt",
|
||||||
default=True,
|
default=True,
|
||||||
tooltip="Whether to enhance the prompt with AI assistance",
|
tooltip="This parameter is deprecated and ignored.",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
IO.Combo.Input(
|
IO.Combo.Input(
|
||||||
|
|||||||
@ -46,14 +46,14 @@ class Txt2ImageParametersField(BaseModel):
|
|||||||
n: int = Field(1, description="Number of images to generate.") # we support only value=1
|
n: int = Field(1, description="Number of images to generate.") # we support only value=1
|
||||||
seed: int = Field(..., ge=0, le=2147483647)
|
seed: int = Field(..., ge=0, le=2147483647)
|
||||||
prompt_extend: bool = Field(True)
|
prompt_extend: bool = Field(True)
|
||||||
watermark: bool = Field(True)
|
watermark: bool = Field(False)
|
||||||
|
|
||||||
|
|
||||||
class Image2ImageParametersField(BaseModel):
|
class Image2ImageParametersField(BaseModel):
|
||||||
size: str | None = Field(None)
|
size: str | None = Field(None)
|
||||||
n: int = Field(1, description="Number of images to generate.") # we support only value=1
|
n: int = Field(1, description="Number of images to generate.") # we support only value=1
|
||||||
seed: int = Field(..., ge=0, le=2147483647)
|
seed: int = Field(..., ge=0, le=2147483647)
|
||||||
watermark: bool = Field(True)
|
watermark: bool = Field(False)
|
||||||
|
|
||||||
|
|
||||||
class Text2VideoParametersField(BaseModel):
|
class Text2VideoParametersField(BaseModel):
|
||||||
@ -61,7 +61,7 @@ class Text2VideoParametersField(BaseModel):
|
|||||||
seed: int = Field(..., ge=0, le=2147483647)
|
seed: int = Field(..., ge=0, le=2147483647)
|
||||||
duration: int = Field(5, ge=5, le=15)
|
duration: int = Field(5, ge=5, le=15)
|
||||||
prompt_extend: bool = Field(True)
|
prompt_extend: bool = Field(True)
|
||||||
watermark: bool = Field(True)
|
watermark: bool = Field(False)
|
||||||
audio: bool = Field(False, description="Whether to generate audio automatically.")
|
audio: bool = Field(False, description="Whether to generate audio automatically.")
|
||||||
shot_type: str = Field("single")
|
shot_type: str = Field("single")
|
||||||
|
|
||||||
@ -71,7 +71,7 @@ class Image2VideoParametersField(BaseModel):
|
|||||||
seed: int = Field(..., ge=0, le=2147483647)
|
seed: int = Field(..., ge=0, le=2147483647)
|
||||||
duration: int = Field(5, ge=5, le=15)
|
duration: int = Field(5, ge=5, le=15)
|
||||||
prompt_extend: bool = Field(True)
|
prompt_extend: bool = Field(True)
|
||||||
watermark: bool = Field(True)
|
watermark: bool = Field(False)
|
||||||
audio: bool = Field(False, description="Whether to generate audio automatically.")
|
audio: bool = Field(False, description="Whether to generate audio automatically.")
|
||||||
shot_type: str = Field("single")
|
shot_type: str = Field("single")
|
||||||
|
|
||||||
@ -208,7 +208,7 @@ class WanTextToImageApi(IO.ComfyNode):
|
|||||||
),
|
),
|
||||||
IO.Boolean.Input(
|
IO.Boolean.Input(
|
||||||
"watermark",
|
"watermark",
|
||||||
default=True,
|
default=False,
|
||||||
tooltip="Whether to add an AI-generated watermark to the result.",
|
tooltip="Whether to add an AI-generated watermark to the result.",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
@ -234,7 +234,7 @@ class WanTextToImageApi(IO.ComfyNode):
|
|||||||
height: int = 1024,
|
height: int = 1024,
|
||||||
seed: int = 0,
|
seed: int = 0,
|
||||||
prompt_extend: bool = True,
|
prompt_extend: bool = True,
|
||||||
watermark: bool = True,
|
watermark: bool = False,
|
||||||
):
|
):
|
||||||
initial_response = await sync_op(
|
initial_response = await sync_op(
|
||||||
cls,
|
cls,
|
||||||
@ -327,7 +327,7 @@ class WanImageToImageApi(IO.ComfyNode):
|
|||||||
),
|
),
|
||||||
IO.Boolean.Input(
|
IO.Boolean.Input(
|
||||||
"watermark",
|
"watermark",
|
||||||
default=True,
|
default=False,
|
||||||
tooltip="Whether to add an AI-generated watermark to the result.",
|
tooltip="Whether to add an AI-generated watermark to the result.",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
@ -353,7 +353,7 @@ class WanImageToImageApi(IO.ComfyNode):
|
|||||||
# width: int = 1024,
|
# width: int = 1024,
|
||||||
# height: int = 1024,
|
# height: int = 1024,
|
||||||
seed: int = 0,
|
seed: int = 0,
|
||||||
watermark: bool = True,
|
watermark: bool = False,
|
||||||
):
|
):
|
||||||
n_images = get_number_of_images(image)
|
n_images = get_number_of_images(image)
|
||||||
if n_images not in (1, 2):
|
if n_images not in (1, 2):
|
||||||
@ -476,7 +476,7 @@ class WanTextToVideoApi(IO.ComfyNode):
|
|||||||
),
|
),
|
||||||
IO.Boolean.Input(
|
IO.Boolean.Input(
|
||||||
"watermark",
|
"watermark",
|
||||||
default=True,
|
default=False,
|
||||||
tooltip="Whether to add an AI-generated watermark to the result.",
|
tooltip="Whether to add an AI-generated watermark to the result.",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
@ -512,7 +512,7 @@ class WanTextToVideoApi(IO.ComfyNode):
|
|||||||
seed: int = 0,
|
seed: int = 0,
|
||||||
generate_audio: bool = False,
|
generate_audio: bool = False,
|
||||||
prompt_extend: bool = True,
|
prompt_extend: bool = True,
|
||||||
watermark: bool = True,
|
watermark: bool = False,
|
||||||
shot_type: str = "single",
|
shot_type: str = "single",
|
||||||
):
|
):
|
||||||
if "480p" in size and model == "wan2.6-t2v":
|
if "480p" in size and model == "wan2.6-t2v":
|
||||||
@ -637,7 +637,7 @@ class WanImageToVideoApi(IO.ComfyNode):
|
|||||||
),
|
),
|
||||||
IO.Boolean.Input(
|
IO.Boolean.Input(
|
||||||
"watermark",
|
"watermark",
|
||||||
default=True,
|
default=False,
|
||||||
tooltip="Whether to add an AI-generated watermark to the result.",
|
tooltip="Whether to add an AI-generated watermark to the result.",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
@ -674,7 +674,7 @@ class WanImageToVideoApi(IO.ComfyNode):
|
|||||||
seed: int = 0,
|
seed: int = 0,
|
||||||
generate_audio: bool = False,
|
generate_audio: bool = False,
|
||||||
prompt_extend: bool = True,
|
prompt_extend: bool = True,
|
||||||
watermark: bool = True,
|
watermark: bool = False,
|
||||||
shot_type: str = "single",
|
shot_type: str = "single",
|
||||||
):
|
):
|
||||||
if get_number_of_images(image) != 1:
|
if get_number_of_images(image) != 1:
|
||||||
|
|||||||
@ -1,16 +1,22 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import contextlib
|
import contextlib
|
||||||
import os
|
import os
|
||||||
|
import re
|
||||||
import time
|
import time
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
|
|
||||||
|
from yarl import URL
|
||||||
|
|
||||||
from comfy.cli_args import args
|
from comfy.cli_args import args
|
||||||
from comfy.model_management import processing_interrupted
|
from comfy.model_management import processing_interrupted
|
||||||
from comfy_api.latest import IO
|
from comfy_api.latest import IO
|
||||||
|
|
||||||
from .common_exceptions import ProcessingInterrupted
|
from .common_exceptions import ProcessingInterrupted
|
||||||
|
|
||||||
|
_HAS_PCT_ESC = re.compile(r"%[0-9A-Fa-f]{2}") # any % followed by 2 hex digits
|
||||||
|
_HAS_BAD_PCT = re.compile(r"%(?![0-9A-Fa-f]{2})") # any % not followed by 2 hex digits
|
||||||
|
|
||||||
|
|
||||||
def is_processing_interrupted() -> bool:
|
def is_processing_interrupted() -> bool:
|
||||||
"""Return True if user/runtime requested interruption."""
|
"""Return True if user/runtime requested interruption."""
|
||||||
@ -69,3 +75,17 @@ def get_fs_object_size(path_or_object: str | BytesIO) -> int:
|
|||||||
if isinstance(path_or_object, str):
|
if isinstance(path_or_object, str):
|
||||||
return os.path.getsize(path_or_object)
|
return os.path.getsize(path_or_object)
|
||||||
return len(path_or_object.getvalue())
|
return len(path_or_object.getvalue())
|
||||||
|
|
||||||
|
|
||||||
|
def to_aiohttp_url(url: str) -> URL:
|
||||||
|
"""If `url` appears to be already percent-encoded (contains at least one valid %HH
|
||||||
|
escape and no malformed '%' sequences) and contains no raw whitespace/control
|
||||||
|
characters preserve the original encoding byte-for-byte (important for signed/presigned URLs).
|
||||||
|
Otherwise, return `URL(url)` and allow yarl to normalize/quote as needed."""
|
||||||
|
if any(c.isspace() for c in url) or any(ord(c) < 0x20 for c in url):
|
||||||
|
# Avoid encoded=True if URL contains raw whitespace/control chars
|
||||||
|
return URL(url)
|
||||||
|
if _HAS_PCT_ESC.search(url) and not _HAS_BAD_PCT.search(url):
|
||||||
|
# Preserve encoding only if it appears pre-encoded AND has no invalid % sequences
|
||||||
|
return URL(url, encoded=True)
|
||||||
|
return URL(url)
|
||||||
|
|||||||
@ -430,9 +430,9 @@ def _display_text(
|
|||||||
if status:
|
if status:
|
||||||
display_lines.append(f"Status: {status.capitalize() if isinstance(status, str) else status}")
|
display_lines.append(f"Status: {status.capitalize() if isinstance(status, str) else status}")
|
||||||
if price is not None:
|
if price is not None:
|
||||||
p = f"{float(price):,.4f}".rstrip("0").rstrip(".")
|
p = f"{float(price) * 211:,.1f}".rstrip("0").rstrip(".")
|
||||||
if p != "0":
|
if p != "0":
|
||||||
display_lines.append(f"Price: ${p}")
|
display_lines.append(f"Price: {p} credits")
|
||||||
if text is not None:
|
if text is not None:
|
||||||
display_lines.append(text)
|
display_lines.append(text)
|
||||||
if display_lines:
|
if display_lines:
|
||||||
|
|||||||
@ -19,6 +19,7 @@ from ._helpers import (
|
|||||||
get_auth_header,
|
get_auth_header,
|
||||||
is_processing_interrupted,
|
is_processing_interrupted,
|
||||||
sleep_with_interrupt,
|
sleep_with_interrupt,
|
||||||
|
to_aiohttp_url,
|
||||||
)
|
)
|
||||||
from .client import _diagnose_connectivity
|
from .client import _diagnose_connectivity
|
||||||
from .common_exceptions import ApiServerError, LocalNetworkError, ProcessingInterrupted
|
from .common_exceptions import ApiServerError, LocalNetworkError, ProcessingInterrupted
|
||||||
@ -94,7 +95,7 @@ async def download_url_to_bytesio(
|
|||||||
|
|
||||||
monitor_task = asyncio.create_task(_monitor())
|
monitor_task = asyncio.create_task(_monitor())
|
||||||
|
|
||||||
req_task = asyncio.create_task(session.get(url, headers=headers))
|
req_task = asyncio.create_task(session.get(to_aiohttp_url(url), headers=headers))
|
||||||
done, pending = await asyncio.wait({req_task, monitor_task}, return_when=asyncio.FIRST_COMPLETED)
|
done, pending = await asyncio.wait({req_task, monitor_task}, return_when=asyncio.FIRST_COMPLETED)
|
||||||
|
|
||||||
if monitor_task in done and req_task in pending:
|
if monitor_task in done and req_task in pending:
|
||||||
|
|||||||
@ -97,6 +97,11 @@ def get_input_info(
|
|||||||
extra_info = input_info[1]
|
extra_info = input_info[1]
|
||||||
else:
|
else:
|
||||||
extra_info = {}
|
extra_info = {}
|
||||||
|
# if input_type is a list, it is a Combo defined in outdated format; convert it.
|
||||||
|
# NOTE: uncomment this when we are confident old format going away won't cause too much trouble.
|
||||||
|
# if isinstance(input_type, list):
|
||||||
|
# extra_info["options"] = input_type
|
||||||
|
# input_type = IO.Combo.io_type
|
||||||
return input_type, input_category, extra_info
|
return input_type, input_category, extra_info
|
||||||
|
|
||||||
class TopologicalSort:
|
class TopologicalSort:
|
||||||
@ -202,15 +207,15 @@ class ExecutionList(TopologicalSort):
|
|||||||
return self.output_cache.get(node_id) is not None
|
return self.output_cache.get(node_id) is not None
|
||||||
|
|
||||||
def cache_link(self, from_node_id, to_node_id):
|
def cache_link(self, from_node_id, to_node_id):
|
||||||
if not to_node_id in self.execution_cache:
|
if to_node_id not in self.execution_cache:
|
||||||
self.execution_cache[to_node_id] = {}
|
self.execution_cache[to_node_id] = {}
|
||||||
self.execution_cache[to_node_id][from_node_id] = self.output_cache.get(from_node_id)
|
self.execution_cache[to_node_id][from_node_id] = self.output_cache.get(from_node_id)
|
||||||
if not from_node_id in self.execution_cache_listeners:
|
if from_node_id not in self.execution_cache_listeners:
|
||||||
self.execution_cache_listeners[from_node_id] = set()
|
self.execution_cache_listeners[from_node_id] = set()
|
||||||
self.execution_cache_listeners[from_node_id].add(to_node_id)
|
self.execution_cache_listeners[from_node_id].add(to_node_id)
|
||||||
|
|
||||||
def get_cache(self, from_node_id, to_node_id):
|
def get_cache(self, from_node_id, to_node_id):
|
||||||
if not to_node_id in self.execution_cache:
|
if to_node_id not in self.execution_cache:
|
||||||
return None
|
return None
|
||||||
value = self.execution_cache[to_node_id].get(from_node_id)
|
value = self.execution_cache[to_node_id].get(from_node_id)
|
||||||
if value is None:
|
if value is None:
|
||||||
|
|||||||
@ -21,14 +21,24 @@ def validate_node_input(
|
|||||||
"""
|
"""
|
||||||
# If the types are exactly the same, we can return immediately
|
# If the types are exactly the same, we can return immediately
|
||||||
# Use pre-union behaviour: inverse of `__ne__`
|
# Use pre-union behaviour: inverse of `__ne__`
|
||||||
|
# NOTE: this lets legacy '*' Any types work that override the __ne__ method of the str class.
|
||||||
if not received_type != input_type:
|
if not received_type != input_type:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
# If one of the types is '*', we can return True immediately; this is the 'Any' type.
|
||||||
|
if received_type == IO.AnyType.io_type or input_type == IO.AnyType.io_type:
|
||||||
|
return True
|
||||||
|
|
||||||
# If the received type or input_type is a MatchType, we can return True immediately;
|
# If the received type or input_type is a MatchType, we can return True immediately;
|
||||||
# validation for this is handled by the frontend
|
# validation for this is handled by the frontend
|
||||||
if received_type == IO.MatchType.io_type or input_type == IO.MatchType.io_type:
|
if received_type == IO.MatchType.io_type or input_type == IO.MatchType.io_type:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
# This accounts for some custom nodes that output lists of options as the type;
|
||||||
|
# if we ever want to break them on purpose, this can be removed
|
||||||
|
if isinstance(received_type, list) and input_type == IO.Combo.io_type:
|
||||||
|
return True
|
||||||
|
|
||||||
# Not equal, and not strings
|
# Not equal, and not strings
|
||||||
if not isinstance(received_type, str) or not isinstance(input_type, str):
|
if not isinstance(received_type, str) or not isinstance(input_type, str):
|
||||||
return False
|
return False
|
||||||
@ -37,6 +47,10 @@ def validate_node_input(
|
|||||||
received_types = set(t.strip() for t in received_type.split(","))
|
received_types = set(t.strip() for t in received_type.split(","))
|
||||||
input_types = set(t.strip() for t in input_type.split(","))
|
input_types = set(t.strip() for t in input_type.split(","))
|
||||||
|
|
||||||
|
# If any of the types is '*', we can return True immediately; this is the 'Any' type.
|
||||||
|
if IO.AnyType.io_type in received_types or IO.AnyType.io_type in input_types:
|
||||||
|
return True
|
||||||
|
|
||||||
if strict:
|
if strict:
|
||||||
# In strict mode, all received types must be in the input types
|
# In strict mode, all received types must be in the input types
|
||||||
return received_types.issubset(input_types)
|
return received_types.issubset(input_types)
|
||||||
|
|||||||
@ -55,7 +55,8 @@ class APG(io.ComfyNode):
|
|||||||
def pre_cfg_function(args):
|
def pre_cfg_function(args):
|
||||||
nonlocal running_avg, prev_sigma
|
nonlocal running_avg, prev_sigma
|
||||||
|
|
||||||
if len(args["conds_out"]) == 1: return args["conds_out"]
|
if len(args["conds_out"]) == 1:
|
||||||
|
return args["conds_out"]
|
||||||
|
|
||||||
cond = args["conds_out"][0]
|
cond = args["conds_out"][0]
|
||||||
uncond = args["conds_out"][1]
|
uncond = args["conds_out"][1]
|
||||||
|
|||||||
@ -112,7 +112,7 @@ class VAEDecodeAudio(IO.ComfyNode):
|
|||||||
std = torch.std(audio, dim=[1,2], keepdim=True) * 5.0
|
std = torch.std(audio, dim=[1,2], keepdim=True) * 5.0
|
||||||
std[std < 1.0] = 1.0
|
std[std < 1.0] = 1.0
|
||||||
audio /= std
|
audio /= std
|
||||||
return IO.NodeOutput({"waveform": audio, "sample_rate": 44100})
|
return IO.NodeOutput({"waveform": audio, "sample_rate": 44100 if "sample_rate" not in samples else samples["sample_rate"]})
|
||||||
|
|
||||||
decode = execute # TODO: remove
|
decode = execute # TODO: remove
|
||||||
|
|
||||||
|
|||||||
@ -9,6 +9,7 @@ import comfy.utils
|
|||||||
import node_helpers
|
import node_helpers
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
from comfy_api.latest import ComfyExtension, io
|
from comfy_api.latest import ComfyExtension, io
|
||||||
|
import re
|
||||||
|
|
||||||
|
|
||||||
class BasicScheduler(io.ComfyNode):
|
class BasicScheduler(io.ComfyNode):
|
||||||
@ -760,8 +761,12 @@ class SamplerCustom(io.ComfyNode):
|
|||||||
out = latent.copy()
|
out = latent.copy()
|
||||||
out["samples"] = samples
|
out["samples"] = samples
|
||||||
if "x0" in x0_output:
|
if "x0" in x0_output:
|
||||||
|
x0_out = model.model.process_latent_out(x0_output["x0"].cpu())
|
||||||
|
if samples.is_nested:
|
||||||
|
latent_shapes = [x.shape for x in samples.unbind()]
|
||||||
|
x0_out = comfy.nested_tensor.NestedTensor(comfy.utils.unpack_latents(x0_out, latent_shapes))
|
||||||
out_denoised = latent.copy()
|
out_denoised = latent.copy()
|
||||||
out_denoised["samples"] = model.model.process_latent_out(x0_output["x0"].cpu())
|
out_denoised["samples"] = x0_out
|
||||||
else:
|
else:
|
||||||
out_denoised = out
|
out_denoised = out
|
||||||
return io.NodeOutput(out, out_denoised)
|
return io.NodeOutput(out, out_denoised)
|
||||||
@ -948,8 +953,12 @@ class SamplerCustomAdvanced(io.ComfyNode):
|
|||||||
out = latent.copy()
|
out = latent.copy()
|
||||||
out["samples"] = samples
|
out["samples"] = samples
|
||||||
if "x0" in x0_output:
|
if "x0" in x0_output:
|
||||||
|
x0_out = guider.model_patcher.model.process_latent_out(x0_output["x0"].cpu())
|
||||||
|
if samples.is_nested:
|
||||||
|
latent_shapes = [x.shape for x in samples.unbind()]
|
||||||
|
x0_out = comfy.nested_tensor.NestedTensor(comfy.utils.unpack_latents(x0_out, latent_shapes))
|
||||||
out_denoised = latent.copy()
|
out_denoised = latent.copy()
|
||||||
out_denoised["samples"] = guider.model_patcher.model.process_latent_out(x0_output["x0"].cpu())
|
out_denoised["samples"] = x0_out
|
||||||
else:
|
else:
|
||||||
out_denoised = out
|
out_denoised = out
|
||||||
return io.NodeOutput(out, out_denoised)
|
return io.NodeOutput(out, out_denoised)
|
||||||
@ -1005,6 +1014,25 @@ class AddNoise(io.ComfyNode):
|
|||||||
|
|
||||||
add_noise = execute
|
add_noise = execute
|
||||||
|
|
||||||
|
class ManualSigmas(io.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return io.Schema(
|
||||||
|
node_id="ManualSigmas",
|
||||||
|
category="_for_testing/custom_sampling",
|
||||||
|
is_experimental=True,
|
||||||
|
inputs=[
|
||||||
|
io.String.Input("sigmas", default="1, 0.5", multiline=False)
|
||||||
|
],
|
||||||
|
outputs=[io.Sigmas.Output()]
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, sigmas) -> io.NodeOutput:
|
||||||
|
sigmas = re.findall(r"[-+]?(?:\d*\.*\d+)", sigmas)
|
||||||
|
sigmas = [float(i) for i in sigmas]
|
||||||
|
sigmas = torch.FloatTensor(sigmas)
|
||||||
|
return io.NodeOutput(sigmas)
|
||||||
|
|
||||||
class CustomSamplersExtension(ComfyExtension):
|
class CustomSamplersExtension(ComfyExtension):
|
||||||
@override
|
@override
|
||||||
@ -1044,6 +1072,7 @@ class CustomSamplersExtension(ComfyExtension):
|
|||||||
DisableNoise,
|
DisableNoise,
|
||||||
AddNoise,
|
AddNoise,
|
||||||
SamplerCustomAdvanced,
|
SamplerCustomAdvanced,
|
||||||
|
ManualSigmas,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -667,16 +667,19 @@ class ResizeImagesByLongerEdgeNode(ImageProcessingNode):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _process(cls, image, longer_edge):
|
def _process(cls, image, longer_edge):
|
||||||
img = tensor_to_pil(image)
|
resized_images = []
|
||||||
w, h = img.size
|
for image_i in image:
|
||||||
if w > h:
|
img = tensor_to_pil(image_i)
|
||||||
new_w = longer_edge
|
w, h = img.size
|
||||||
new_h = int(h * (longer_edge / w))
|
if w > h:
|
||||||
else:
|
new_w = longer_edge
|
||||||
new_h = longer_edge
|
new_h = int(h * (longer_edge / w))
|
||||||
new_w = int(w * (longer_edge / h))
|
else:
|
||||||
img = img.resize((new_w, new_h), Image.Resampling.LANCZOS)
|
new_h = longer_edge
|
||||||
return pil_to_tensor(img)
|
new_w = int(w * (longer_edge / h))
|
||||||
|
img = img.resize((new_w, new_h), Image.Resampling.LANCZOS)
|
||||||
|
resized_images.append(pil_to_tensor(img))
|
||||||
|
return torch.cat(resized_images, dim=0)
|
||||||
|
|
||||||
|
|
||||||
class CenterCropImagesNode(ImageProcessingNode):
|
class CenterCropImagesNode(ImageProcessingNode):
|
||||||
|
|||||||
@ -5,7 +5,9 @@ import comfy.model_management
|
|||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
from comfy_api.latest import ComfyExtension, io
|
from comfy_api.latest import ComfyExtension, io
|
||||||
from comfy.ldm.hunyuan_video.upsampler import HunyuanVideo15SRModel
|
from comfy.ldm.hunyuan_video.upsampler import HunyuanVideo15SRModel
|
||||||
|
from comfy.ldm.lightricks.latent_upsampler import LatentUpsampler
|
||||||
import folder_paths
|
import folder_paths
|
||||||
|
import json
|
||||||
|
|
||||||
class CLIPTextEncodeHunyuanDiT(io.ComfyNode):
|
class CLIPTextEncodeHunyuanDiT(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -186,7 +188,7 @@ class LatentUpscaleModelLoader(io.ComfyNode):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, model_name) -> io.NodeOutput:
|
def execute(cls, model_name) -> io.NodeOutput:
|
||||||
model_path = folder_paths.get_full_path_or_raise("latent_upscale_models", model_name)
|
model_path = folder_paths.get_full_path_or_raise("latent_upscale_models", model_name)
|
||||||
sd = comfy.utils.load_torch_file(model_path, safe_load=True)
|
sd, metadata = comfy.utils.load_torch_file(model_path, safe_load=True, return_metadata=True)
|
||||||
|
|
||||||
if "blocks.0.block.0.conv.weight" in sd:
|
if "blocks.0.block.0.conv.weight" in sd:
|
||||||
config = {
|
config = {
|
||||||
@ -197,6 +199,8 @@ class LatentUpscaleModelLoader(io.ComfyNode):
|
|||||||
"global_residual": False,
|
"global_residual": False,
|
||||||
}
|
}
|
||||||
model_type = "720p"
|
model_type = "720p"
|
||||||
|
model = HunyuanVideo15SRModel(model_type, config)
|
||||||
|
model.load_sd(sd)
|
||||||
elif "up.0.block.0.conv1.conv.weight" in sd:
|
elif "up.0.block.0.conv1.conv.weight" in sd:
|
||||||
sd = {key.replace("nin_shortcut", "nin_shortcut.conv", 1): value for key, value in sd.items()}
|
sd = {key.replace("nin_shortcut", "nin_shortcut.conv", 1): value for key, value in sd.items()}
|
||||||
config = {
|
config = {
|
||||||
@ -205,9 +209,12 @@ class LatentUpscaleModelLoader(io.ComfyNode):
|
|||||||
"block_out_channels": tuple(sd[f"up.{i}.block.0.conv1.conv.weight"].shape[0] for i in range(len([k for k in sd.keys() if k.startswith("up.") and k.endswith(".block.0.conv1.conv.weight")]))),
|
"block_out_channels": tuple(sd[f"up.{i}.block.0.conv1.conv.weight"].shape[0] for i in range(len([k for k in sd.keys() if k.startswith("up.") and k.endswith(".block.0.conv1.conv.weight")]))),
|
||||||
}
|
}
|
||||||
model_type = "1080p"
|
model_type = "1080p"
|
||||||
|
model = HunyuanVideo15SRModel(model_type, config)
|
||||||
model = HunyuanVideo15SRModel(model_type, config)
|
model.load_sd(sd)
|
||||||
model.load_sd(sd)
|
elif "post_upsample_res_blocks.0.conv2.bias" in sd:
|
||||||
|
config = json.loads(metadata["config"])
|
||||||
|
model = LatentUpsampler.from_config(config).to(dtype=comfy.model_management.vae_dtype(allowed_dtypes=[torch.bfloat16, torch.float32]))
|
||||||
|
model.load_state_dict(sd)
|
||||||
|
|
||||||
return io.NodeOutput(model)
|
return io.NodeOutput(model)
|
||||||
|
|
||||||
|
|||||||
@ -2,280 +2,231 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import nodes
|
import nodes
|
||||||
import folder_paths
|
import folder_paths
|
||||||
from comfy.cli_args import args
|
|
||||||
|
|
||||||
from PIL import Image
|
|
||||||
from PIL.PngImagePlugin import PngInfo
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
from io import BytesIO
|
|
||||||
from inspect import cleandoc
|
|
||||||
import torch
|
import torch
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
|
|
||||||
from comfy.comfy_types import FileLocator, IO
|
|
||||||
from server import PromptServer
|
from server import PromptServer
|
||||||
|
from comfy_api.latest import ComfyExtension, IO, UI
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
|
SVG = IO.SVG.Type # TODO: temporary solution for backward compatibility, will be removed later.
|
||||||
|
|
||||||
MAX_RESOLUTION = nodes.MAX_RESOLUTION
|
MAX_RESOLUTION = nodes.MAX_RESOLUTION
|
||||||
|
|
||||||
class ImageCrop:
|
class ImageCrop(IO.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": { "image": ("IMAGE",),
|
return IO.Schema(
|
||||||
"width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
|
node_id="ImageCrop",
|
||||||
"height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
|
display_name="Image Crop",
|
||||||
"x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
|
category="image/transform",
|
||||||
"y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
|
inputs=[
|
||||||
}}
|
IO.Image.Input("image"),
|
||||||
RETURN_TYPES = ("IMAGE",)
|
IO.Int.Input("width", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1),
|
||||||
FUNCTION = "crop"
|
IO.Int.Input("height", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1),
|
||||||
|
IO.Int.Input("x", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
|
||||||
|
IO.Int.Input("y", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
|
||||||
|
],
|
||||||
|
outputs=[IO.Image.Output()],
|
||||||
|
)
|
||||||
|
|
||||||
CATEGORY = "image/transform"
|
@classmethod
|
||||||
|
def execute(cls, image, width, height, x, y) -> IO.NodeOutput:
|
||||||
def crop(self, image, width, height, x, y):
|
|
||||||
x = min(x, image.shape[2] - 1)
|
x = min(x, image.shape[2] - 1)
|
||||||
y = min(y, image.shape[1] - 1)
|
y = min(y, image.shape[1] - 1)
|
||||||
to_x = width + x
|
to_x = width + x
|
||||||
to_y = height + y
|
to_y = height + y
|
||||||
img = image[:,y:to_y, x:to_x, :]
|
img = image[:,y:to_y, x:to_x, :]
|
||||||
return (img,)
|
return IO.NodeOutput(img)
|
||||||
|
|
||||||
class RepeatImageBatch:
|
crop = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
|
class RepeatImageBatch(IO.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": { "image": ("IMAGE",),
|
return IO.Schema(
|
||||||
"amount": ("INT", {"default": 1, "min": 1, "max": 4096}),
|
node_id="RepeatImageBatch",
|
||||||
}}
|
category="image/batch",
|
||||||
RETURN_TYPES = ("IMAGE",)
|
inputs=[
|
||||||
FUNCTION = "repeat"
|
IO.Image.Input("image"),
|
||||||
|
IO.Int.Input("amount", default=1, min=1, max=4096),
|
||||||
|
],
|
||||||
|
outputs=[IO.Image.Output()],
|
||||||
|
)
|
||||||
|
|
||||||
CATEGORY = "image/batch"
|
@classmethod
|
||||||
|
def execute(cls, image, amount) -> IO.NodeOutput:
|
||||||
def repeat(self, image, amount):
|
|
||||||
s = image.repeat((amount, 1,1,1))
|
s = image.repeat((amount, 1,1,1))
|
||||||
return (s,)
|
return IO.NodeOutput(s)
|
||||||
|
|
||||||
class ImageFromBatch:
|
repeat = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
|
class ImageFromBatch(IO.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": { "image": ("IMAGE",),
|
return IO.Schema(
|
||||||
"batch_index": ("INT", {"default": 0, "min": 0, "max": 4095}),
|
node_id="ImageFromBatch",
|
||||||
"length": ("INT", {"default": 1, "min": 1, "max": 4096}),
|
category="image/batch",
|
||||||
}}
|
inputs=[
|
||||||
RETURN_TYPES = ("IMAGE",)
|
IO.Image.Input("image"),
|
||||||
FUNCTION = "frombatch"
|
IO.Int.Input("batch_index", default=0, min=0, max=4095),
|
||||||
|
IO.Int.Input("length", default=1, min=1, max=4096),
|
||||||
|
],
|
||||||
|
outputs=[IO.Image.Output()],
|
||||||
|
)
|
||||||
|
|
||||||
CATEGORY = "image/batch"
|
@classmethod
|
||||||
|
def execute(cls, image, batch_index, length) -> IO.NodeOutput:
|
||||||
def frombatch(self, image, batch_index, length):
|
|
||||||
s_in = image
|
s_in = image
|
||||||
batch_index = min(s_in.shape[0] - 1, batch_index)
|
batch_index = min(s_in.shape[0] - 1, batch_index)
|
||||||
length = min(s_in.shape[0] - batch_index, length)
|
length = min(s_in.shape[0] - batch_index, length)
|
||||||
s = s_in[batch_index:batch_index + length].clone()
|
s = s_in[batch_index:batch_index + length].clone()
|
||||||
return (s,)
|
return IO.NodeOutput(s)
|
||||||
|
|
||||||
|
frombatch = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
class ImageAddNoise:
|
class ImageAddNoise(IO.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": { "image": ("IMAGE",),
|
return IO.Schema(
|
||||||
"seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff, "control_after_generate": True, "tooltip": "The random seed used for creating the noise."}),
|
node_id="ImageAddNoise",
|
||||||
"strength": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
|
category="image",
|
||||||
}}
|
inputs=[
|
||||||
RETURN_TYPES = ("IMAGE",)
|
IO.Image.Input("image"),
|
||||||
FUNCTION = "repeat"
|
IO.Int.Input(
|
||||||
|
"seed",
|
||||||
|
default=0,
|
||||||
|
min=0,
|
||||||
|
max=0xFFFFFFFFFFFFFFFF,
|
||||||
|
control_after_generate=True,
|
||||||
|
tooltip="The random seed used for creating the noise.",
|
||||||
|
),
|
||||||
|
IO.Float.Input("strength", default=0.5, min=0.0, max=1.0, step=0.01),
|
||||||
|
],
|
||||||
|
outputs=[IO.Image.Output()],
|
||||||
|
)
|
||||||
|
|
||||||
CATEGORY = "image"
|
@classmethod
|
||||||
|
def execute(cls, image, seed, strength) -> IO.NodeOutput:
|
||||||
def repeat(self, image, seed, strength):
|
|
||||||
generator = torch.manual_seed(seed)
|
generator = torch.manual_seed(seed)
|
||||||
s = torch.clip((image + strength * torch.randn(image.size(), generator=generator, device="cpu").to(image)), min=0.0, max=1.0)
|
s = torch.clip((image + strength * torch.randn(image.size(), generator=generator, device="cpu").to(image)), min=0.0, max=1.0)
|
||||||
return (s,)
|
return IO.NodeOutput(s)
|
||||||
|
|
||||||
class SaveAnimatedWEBP:
|
repeat = execute # TODO: remove
|
||||||
def __init__(self):
|
|
||||||
self.output_dir = folder_paths.get_output_directory()
|
|
||||||
self.type = "output"
|
|
||||||
self.prefix_append = ""
|
|
||||||
|
|
||||||
methods = {"default": 4, "fastest": 0, "slowest": 6}
|
|
||||||
@classmethod
|
|
||||||
def INPUT_TYPES(s):
|
|
||||||
return {"required":
|
|
||||||
{"images": ("IMAGE", ),
|
|
||||||
"filename_prefix": ("STRING", {"default": "ComfyUI"}),
|
|
||||||
"fps": ("FLOAT", {"default": 6.0, "min": 0.01, "max": 1000.0, "step": 0.01}),
|
|
||||||
"lossless": ("BOOLEAN", {"default": True}),
|
|
||||||
"quality": ("INT", {"default": 80, "min": 0, "max": 100}),
|
|
||||||
"method": (list(s.methods.keys()),),
|
|
||||||
# "num_frames": ("INT", {"default": 0, "min": 0, "max": 8192}),
|
|
||||||
},
|
|
||||||
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
|
|
||||||
}
|
|
||||||
|
|
||||||
RETURN_TYPES = ()
|
class SaveAnimatedWEBP(IO.ComfyNode):
|
||||||
FUNCTION = "save_images"
|
COMPRESS_METHODS = {"default": 4, "fastest": 0, "slowest": 6}
|
||||||
|
|
||||||
OUTPUT_NODE = True
|
|
||||||
|
|
||||||
CATEGORY = "image/animation"
|
|
||||||
|
|
||||||
def save_images(self, images, fps, filename_prefix, lossless, quality, method, num_frames=0, prompt=None, extra_pnginfo=None):
|
|
||||||
method = self.methods.get(method)
|
|
||||||
filename_prefix += self.prefix_append
|
|
||||||
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0])
|
|
||||||
results: list[FileLocator] = []
|
|
||||||
pil_images = []
|
|
||||||
for image in images:
|
|
||||||
i = 255. * image.cpu().numpy()
|
|
||||||
img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
|
|
||||||
pil_images.append(img)
|
|
||||||
|
|
||||||
metadata = pil_images[0].getexif()
|
|
||||||
if not args.disable_metadata:
|
|
||||||
if prompt is not None:
|
|
||||||
metadata[0x0110] = "prompt:{}".format(json.dumps(prompt))
|
|
||||||
if extra_pnginfo is not None:
|
|
||||||
inital_exif = 0x010f
|
|
||||||
for x in extra_pnginfo:
|
|
||||||
metadata[inital_exif] = "{}:{}".format(x, json.dumps(extra_pnginfo[x]))
|
|
||||||
inital_exif -= 1
|
|
||||||
|
|
||||||
if num_frames == 0:
|
|
||||||
num_frames = len(pil_images)
|
|
||||||
|
|
||||||
c = len(pil_images)
|
|
||||||
for i in range(0, c, num_frames):
|
|
||||||
file = f"{filename}_{counter:05}_.webp"
|
|
||||||
pil_images[i].save(os.path.join(full_output_folder, file), save_all=True, duration=int(1000.0/fps), append_images=pil_images[i + 1:i + num_frames], exif=metadata, lossless=lossless, quality=quality, method=method)
|
|
||||||
results.append({
|
|
||||||
"filename": file,
|
|
||||||
"subfolder": subfolder,
|
|
||||||
"type": self.type
|
|
||||||
})
|
|
||||||
counter += 1
|
|
||||||
|
|
||||||
animated = num_frames != 1
|
|
||||||
return { "ui": { "images": results, "animated": (animated,) } }
|
|
||||||
|
|
||||||
class SaveAnimatedPNG:
|
|
||||||
def __init__(self):
|
|
||||||
self.output_dir = folder_paths.get_output_directory()
|
|
||||||
self.type = "output"
|
|
||||||
self.prefix_append = ""
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required":
|
return IO.Schema(
|
||||||
{"images": ("IMAGE", ),
|
node_id="SaveAnimatedWEBP",
|
||||||
"filename_prefix": ("STRING", {"default": "ComfyUI"}),
|
category="image/animation",
|
||||||
"fps": ("FLOAT", {"default": 6.0, "min": 0.01, "max": 1000.0, "step": 0.01}),
|
inputs=[
|
||||||
"compress_level": ("INT", {"default": 4, "min": 0, "max": 9})
|
IO.Image.Input("images"),
|
||||||
},
|
IO.String.Input("filename_prefix", default="ComfyUI"),
|
||||||
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
|
IO.Float.Input("fps", default=6.0, min=0.01, max=1000.0, step=0.01),
|
||||||
}
|
IO.Boolean.Input("lossless", default=True),
|
||||||
|
IO.Int.Input("quality", default=80, min=0, max=100),
|
||||||
|
IO.Combo.Input("method", options=list(cls.COMPRESS_METHODS.keys())),
|
||||||
|
# "num_frames": ("INT", {"default": 0, "min": 0, "max": 8192}),
|
||||||
|
],
|
||||||
|
hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
|
||||||
|
is_output_node=True,
|
||||||
|
)
|
||||||
|
|
||||||
RETURN_TYPES = ()
|
@classmethod
|
||||||
FUNCTION = "save_images"
|
def execute(cls, images, fps, filename_prefix, lossless, quality, method, num_frames=0) -> IO.NodeOutput:
|
||||||
|
return IO.NodeOutput(
|
||||||
|
ui=UI.ImageSaveHelper.get_save_animated_webp_ui(
|
||||||
|
images=images,
|
||||||
|
filename_prefix=filename_prefix,
|
||||||
|
cls=cls,
|
||||||
|
fps=fps,
|
||||||
|
lossless=lossless,
|
||||||
|
quality=quality,
|
||||||
|
method=cls.COMPRESS_METHODS.get(method)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
OUTPUT_NODE = True
|
save_images = execute # TODO: remove
|
||||||
|
|
||||||
CATEGORY = "image/animation"
|
|
||||||
|
|
||||||
def save_images(self, images, fps, compress_level, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
|
|
||||||
filename_prefix += self.prefix_append
|
|
||||||
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0])
|
|
||||||
results = list()
|
|
||||||
pil_images = []
|
|
||||||
for image in images:
|
|
||||||
i = 255. * image.cpu().numpy()
|
|
||||||
img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
|
|
||||||
pil_images.append(img)
|
|
||||||
|
|
||||||
metadata = None
|
|
||||||
if not args.disable_metadata:
|
|
||||||
metadata = PngInfo()
|
|
||||||
if prompt is not None:
|
|
||||||
metadata.add(b"comf", "prompt".encode("latin-1", "strict") + b"\0" + json.dumps(prompt).encode("latin-1", "strict"), after_idat=True)
|
|
||||||
if extra_pnginfo is not None:
|
|
||||||
for x in extra_pnginfo:
|
|
||||||
metadata.add(b"comf", x.encode("latin-1", "strict") + b"\0" + json.dumps(extra_pnginfo[x]).encode("latin-1", "strict"), after_idat=True)
|
|
||||||
|
|
||||||
file = f"{filename}_{counter:05}_.png"
|
|
||||||
pil_images[0].save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=compress_level, save_all=True, duration=int(1000.0/fps), append_images=pil_images[1:])
|
|
||||||
results.append({
|
|
||||||
"filename": file,
|
|
||||||
"subfolder": subfolder,
|
|
||||||
"type": self.type
|
|
||||||
})
|
|
||||||
|
|
||||||
return { "ui": { "images": results, "animated": (True,)} }
|
|
||||||
|
|
||||||
class SVG:
|
|
||||||
"""
|
|
||||||
Stores SVG representations via a list of BytesIO objects.
|
|
||||||
"""
|
|
||||||
def __init__(self, data: list[BytesIO]):
|
|
||||||
self.data = data
|
|
||||||
|
|
||||||
def combine(self, other: 'SVG') -> 'SVG':
|
|
||||||
return SVG(self.data + other.data)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def combine_all(svgs: list['SVG']) -> 'SVG':
|
|
||||||
all_svgs_list: list[BytesIO] = []
|
|
||||||
for svg_item in svgs:
|
|
||||||
all_svgs_list.extend(svg_item.data)
|
|
||||||
return SVG(all_svgs_list)
|
|
||||||
|
|
||||||
|
|
||||||
class ImageStitch:
|
class SaveAnimatedPNG(IO.ComfyNode):
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return IO.Schema(
|
||||||
|
node_id="SaveAnimatedPNG",
|
||||||
|
category="image/animation",
|
||||||
|
inputs=[
|
||||||
|
IO.Image.Input("images"),
|
||||||
|
IO.String.Input("filename_prefix", default="ComfyUI"),
|
||||||
|
IO.Float.Input("fps", default=6.0, min=0.01, max=1000.0, step=0.01),
|
||||||
|
IO.Int.Input("compress_level", default=4, min=0, max=9),
|
||||||
|
],
|
||||||
|
hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
|
||||||
|
is_output_node=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, images, fps, compress_level, filename_prefix="ComfyUI") -> IO.NodeOutput:
|
||||||
|
return IO.NodeOutput(
|
||||||
|
ui=UI.ImageSaveHelper.get_save_animated_png_ui(
|
||||||
|
images=images,
|
||||||
|
filename_prefix=filename_prefix,
|
||||||
|
cls=cls,
|
||||||
|
fps=fps,
|
||||||
|
compress_level=compress_level,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
save_images = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
|
class ImageStitch(IO.ComfyNode):
|
||||||
"""Upstreamed from https://github.com/kijai/ComfyUI-KJNodes"""
|
"""Upstreamed from https://github.com/kijai/ComfyUI-KJNodes"""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {
|
return IO.Schema(
|
||||||
"required": {
|
node_id="ImageStitch",
|
||||||
"image1": ("IMAGE",),
|
display_name="Image Stitch",
|
||||||
"direction": (["right", "down", "left", "up"], {"default": "right"}),
|
description="Stitches image2 to image1 in the specified direction.\n"
|
||||||
"match_image_size": ("BOOLEAN", {"default": True}),
|
"If image2 is not provided, returns image1 unchanged.\n"
|
||||||
"spacing_width": (
|
"Optional spacing can be added between images.",
|
||||||
"INT",
|
category="image/transform",
|
||||||
{"default": 0, "min": 0, "max": 1024, "step": 2},
|
inputs=[
|
||||||
),
|
IO.Image.Input("image1"),
|
||||||
"spacing_color": (
|
IO.Combo.Input("direction", options=["right", "down", "left", "up"], default="right"),
|
||||||
["white", "black", "red", "green", "blue"],
|
IO.Boolean.Input("match_image_size", default=True),
|
||||||
{"default": "white"},
|
IO.Int.Input("spacing_width", default=0, min=0, max=1024, step=2),
|
||||||
),
|
IO.Combo.Input("spacing_color", options=["white", "black", "red", "green", "blue"], default="white"),
|
||||||
},
|
IO.Image.Input("image2", optional=True),
|
||||||
"optional": {
|
],
|
||||||
"image2": ("IMAGE",),
|
outputs=[IO.Image.Output()],
|
||||||
},
|
)
|
||||||
}
|
|
||||||
|
|
||||||
RETURN_TYPES = ("IMAGE",)
|
@classmethod
|
||||||
FUNCTION = "stitch"
|
def execute(
|
||||||
CATEGORY = "image/transform"
|
cls,
|
||||||
DESCRIPTION = """
|
|
||||||
Stitches image2 to image1 in the specified direction.
|
|
||||||
If image2 is not provided, returns image1 unchanged.
|
|
||||||
Optional spacing can be added between images.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def stitch(
|
|
||||||
self,
|
|
||||||
image1,
|
image1,
|
||||||
direction,
|
direction,
|
||||||
match_image_size,
|
match_image_size,
|
||||||
spacing_width,
|
spacing_width,
|
||||||
spacing_color,
|
spacing_color,
|
||||||
image2=None,
|
image2=None,
|
||||||
):
|
) -> IO.NodeOutput:
|
||||||
if image2 is None:
|
if image2 is None:
|
||||||
return (image1,)
|
return IO.NodeOutput(image1)
|
||||||
|
|
||||||
# Handle batch size differences
|
# Handle batch size differences
|
||||||
if image1.shape[0] != image2.shape[0]:
|
if image1.shape[0] != image2.shape[0]:
|
||||||
@ -412,36 +363,30 @@ Optional spacing can be added between images.
|
|||||||
images.insert(1, spacing)
|
images.insert(1, spacing)
|
||||||
|
|
||||||
concat_dim = 2 if direction in ["left", "right"] else 1
|
concat_dim = 2 if direction in ["left", "right"] else 1
|
||||||
return (torch.cat(images, dim=concat_dim),)
|
return IO.NodeOutput(torch.cat(images, dim=concat_dim))
|
||||||
|
|
||||||
|
stitch = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
|
class ResizeAndPadImage(IO.ComfyNode):
|
||||||
|
|
||||||
class ResizeAndPadImage:
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(cls):
|
def define_schema(cls):
|
||||||
return {
|
return IO.Schema(
|
||||||
"required": {
|
node_id="ResizeAndPadImage",
|
||||||
"image": ("IMAGE",),
|
category="image/transform",
|
||||||
"target_width": ("INT", {
|
inputs=[
|
||||||
"default": 512,
|
IO.Image.Input("image"),
|
||||||
"min": 1,
|
IO.Int.Input("target_width", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1),
|
||||||
"max": MAX_RESOLUTION,
|
IO.Int.Input("target_height", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1),
|
||||||
"step": 1
|
IO.Combo.Input("padding_color", options=["white", "black"]),
|
||||||
}),
|
IO.Combo.Input("interpolation", options=["area", "bicubic", "nearest-exact", "bilinear", "lanczos"]),
|
||||||
"target_height": ("INT", {
|
],
|
||||||
"default": 512,
|
outputs=[IO.Image.Output()],
|
||||||
"min": 1,
|
)
|
||||||
"max": MAX_RESOLUTION,
|
|
||||||
"step": 1
|
|
||||||
}),
|
|
||||||
"padding_color": (["white", "black"],),
|
|
||||||
"interpolation": (["area", "bicubic", "nearest-exact", "bilinear", "lanczos"],),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
RETURN_TYPES = ("IMAGE",)
|
@classmethod
|
||||||
FUNCTION = "resize_and_pad"
|
def execute(cls, image, target_width, target_height, padding_color, interpolation) -> IO.NodeOutput:
|
||||||
CATEGORY = "image/transform"
|
|
||||||
|
|
||||||
def resize_and_pad(self, image, target_width, target_height, padding_color, interpolation):
|
|
||||||
batch_size, orig_height, orig_width, channels = image.shape
|
batch_size, orig_height, orig_width, channels = image.shape
|
||||||
|
|
||||||
scale_w = target_width / orig_width
|
scale_w = target_width / orig_width
|
||||||
@ -469,52 +414,47 @@ class ResizeAndPadImage:
|
|||||||
padded[:, :, y_offset:y_offset + new_height, x_offset:x_offset + new_width] = resized
|
padded[:, :, y_offset:y_offset + new_height, x_offset:x_offset + new_width] = resized
|
||||||
|
|
||||||
output = padded.permute(0, 2, 3, 1)
|
output = padded.permute(0, 2, 3, 1)
|
||||||
return (output,)
|
return IO.NodeOutput(output)
|
||||||
|
|
||||||
class SaveSVGNode:
|
resize_and_pad = execute # TODO: remove
|
||||||
"""
|
|
||||||
Save SVG files on disk.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.output_dir = folder_paths.get_output_directory()
|
|
||||||
self.type = "output"
|
|
||||||
self.prefix_append = ""
|
|
||||||
|
|
||||||
RETURN_TYPES = ()
|
class SaveSVGNode(IO.ComfyNode):
|
||||||
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
|
|
||||||
FUNCTION = "save_svg"
|
|
||||||
CATEGORY = "image/save" # Changed
|
|
||||||
OUTPUT_NODE = True
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {
|
return IO.Schema(
|
||||||
"required": {
|
node_id="SaveSVGNode",
|
||||||
"svg": ("SVG",), # Changed
|
description="Save SVG files on disk.",
|
||||||
"filename_prefix": ("STRING", {"default": "svg/ComfyUI", "tooltip": "The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."})
|
category="image/save",
|
||||||
},
|
inputs=[
|
||||||
"hidden": {
|
IO.SVG.Input("svg"),
|
||||||
"prompt": "PROMPT",
|
IO.String.Input(
|
||||||
"extra_pnginfo": "EXTRA_PNGINFO"
|
"filename_prefix",
|
||||||
}
|
default="svg/ComfyUI",
|
||||||
}
|
tooltip="The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes.",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
|
||||||
|
is_output_node=True,
|
||||||
|
)
|
||||||
|
|
||||||
def save_svg(self, svg: SVG, filename_prefix="svg/ComfyUI", prompt=None, extra_pnginfo=None):
|
@classmethod
|
||||||
filename_prefix += self.prefix_append
|
def execute(cls, svg: IO.SVG.Type, filename_prefix="svg/ComfyUI") -> IO.NodeOutput:
|
||||||
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
|
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, folder_paths.get_output_directory())
|
||||||
results = list()
|
results: list[UI.SavedResult] = []
|
||||||
|
|
||||||
# Prepare metadata JSON
|
# Prepare metadata JSON
|
||||||
metadata_dict = {}
|
metadata_dict = {}
|
||||||
if prompt is not None:
|
if cls.hidden.prompt is not None:
|
||||||
metadata_dict["prompt"] = prompt
|
metadata_dict["prompt"] = cls.hidden.prompt
|
||||||
if extra_pnginfo is not None:
|
if cls.hidden.extra_pnginfo is not None:
|
||||||
metadata_dict.update(extra_pnginfo)
|
metadata_dict.update(cls.hidden.extra_pnginfo)
|
||||||
|
|
||||||
# Convert metadata to JSON string
|
# Convert metadata to JSON string
|
||||||
metadata_json = json.dumps(metadata_dict, indent=2) if metadata_dict else None
|
metadata_json = json.dumps(metadata_dict, indent=2) if metadata_dict else None
|
||||||
|
|
||||||
|
|
||||||
for batch_number, svg_bytes in enumerate(svg.data):
|
for batch_number, svg_bytes in enumerate(svg.data):
|
||||||
filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
|
filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
|
||||||
file = f"{filename_with_batch_num}_{counter:05}_.svg"
|
file = f"{filename_with_batch_num}_{counter:05}_.svg"
|
||||||
@ -544,57 +484,64 @@ class SaveSVGNode:
|
|||||||
with open(os.path.join(full_output_folder, file), 'wb') as svg_file:
|
with open(os.path.join(full_output_folder, file), 'wb') as svg_file:
|
||||||
svg_file.write(svg_content.encode('utf-8'))
|
svg_file.write(svg_content.encode('utf-8'))
|
||||||
|
|
||||||
results.append({
|
results.append(UI.SavedResult(filename=file, subfolder=subfolder, type=IO.FolderType.output))
|
||||||
"filename": file,
|
|
||||||
"subfolder": subfolder,
|
|
||||||
"type": self.type
|
|
||||||
})
|
|
||||||
counter += 1
|
counter += 1
|
||||||
return { "ui": { "images": results } }
|
return IO.NodeOutput(ui={"images": results})
|
||||||
|
|
||||||
class GetImageSize:
|
save_svg = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
|
class GetImageSize(IO.ComfyNode):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {
|
return IO.Schema(
|
||||||
"required": {
|
node_id="GetImageSize",
|
||||||
"image": (IO.IMAGE,),
|
display_name="Get Image Size",
|
||||||
},
|
description="Returns width and height of the image, and passes it through unchanged.",
|
||||||
"hidden": {
|
category="image",
|
||||||
"unique_id": "UNIQUE_ID",
|
inputs=[
|
||||||
}
|
IO.Image.Input("image"),
|
||||||
}
|
],
|
||||||
|
outputs=[
|
||||||
|
IO.Int.Output(display_name="width"),
|
||||||
|
IO.Int.Output(display_name="height"),
|
||||||
|
IO.Int.Output(display_name="batch_size"),
|
||||||
|
],
|
||||||
|
hidden=[IO.Hidden.unique_id],
|
||||||
|
)
|
||||||
|
|
||||||
RETURN_TYPES = (IO.INT, IO.INT, IO.INT)
|
@classmethod
|
||||||
RETURN_NAMES = ("width", "height", "batch_size")
|
def execute(cls, image) -> IO.NodeOutput:
|
||||||
FUNCTION = "get_size"
|
|
||||||
|
|
||||||
CATEGORY = "image"
|
|
||||||
DESCRIPTION = """Returns width and height of the image, and passes it through unchanged."""
|
|
||||||
|
|
||||||
def get_size(self, image, unique_id=None) -> tuple[int, int]:
|
|
||||||
height = image.shape[1]
|
height = image.shape[1]
|
||||||
width = image.shape[2]
|
width = image.shape[2]
|
||||||
batch_size = image.shape[0]
|
batch_size = image.shape[0]
|
||||||
|
|
||||||
# Send progress text to display size on the node
|
# Send progress text to display size on the node
|
||||||
if unique_id:
|
if cls.hidden.unique_id:
|
||||||
PromptServer.instance.send_progress_text(f"width: {width}, height: {height}\n batch size: {batch_size}", unique_id)
|
PromptServer.instance.send_progress_text(f"width: {width}, height: {height}\n batch size: {batch_size}", cls.hidden.unique_id)
|
||||||
|
|
||||||
return width, height, batch_size
|
return IO.NodeOutput(width, height, batch_size)
|
||||||
|
|
||||||
|
get_size = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
|
class ImageRotate(IO.ComfyNode):
|
||||||
|
|
||||||
class ImageRotate:
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": { "image": (IO.IMAGE,),
|
return IO.Schema(
|
||||||
"rotation": (["none", "90 degrees", "180 degrees", "270 degrees"],),
|
node_id="ImageRotate",
|
||||||
}}
|
category="image/transform",
|
||||||
RETURN_TYPES = (IO.IMAGE,)
|
inputs=[
|
||||||
FUNCTION = "rotate"
|
IO.Image.Input("image"),
|
||||||
|
IO.Combo.Input("rotation", options=["none", "90 degrees", "180 degrees", "270 degrees"]),
|
||||||
|
],
|
||||||
|
outputs=[IO.Image.Output()],
|
||||||
|
)
|
||||||
|
|
||||||
CATEGORY = "image/transform"
|
@classmethod
|
||||||
|
def execute(cls, image, rotation) -> IO.NodeOutput:
|
||||||
def rotate(self, image, rotation):
|
|
||||||
rotate_by = 0
|
rotate_by = 0
|
||||||
if rotation.startswith("90"):
|
if rotation.startswith("90"):
|
||||||
rotate_by = 1
|
rotate_by = 1
|
||||||
@ -604,41 +551,57 @@ class ImageRotate:
|
|||||||
rotate_by = 3
|
rotate_by = 3
|
||||||
|
|
||||||
image = torch.rot90(image, k=rotate_by, dims=[2, 1])
|
image = torch.rot90(image, k=rotate_by, dims=[2, 1])
|
||||||
return (image,)
|
return IO.NodeOutput(image)
|
||||||
|
|
||||||
|
rotate = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
|
class ImageFlip(IO.ComfyNode):
|
||||||
|
|
||||||
class ImageFlip:
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": { "image": (IO.IMAGE,),
|
return IO.Schema(
|
||||||
"flip_method": (["x-axis: vertically", "y-axis: horizontally"],),
|
node_id="ImageFlip",
|
||||||
}}
|
category="image/transform",
|
||||||
RETURN_TYPES = (IO.IMAGE,)
|
inputs=[
|
||||||
FUNCTION = "flip"
|
IO.Image.Input("image"),
|
||||||
|
IO.Combo.Input("flip_method", options=["x-axis: vertically", "y-axis: horizontally"]),
|
||||||
|
],
|
||||||
|
outputs=[IO.Image.Output()],
|
||||||
|
)
|
||||||
|
|
||||||
CATEGORY = "image/transform"
|
@classmethod
|
||||||
|
def execute(cls, image, flip_method) -> IO.NodeOutput:
|
||||||
def flip(self, image, flip_method):
|
|
||||||
if flip_method.startswith("x"):
|
if flip_method.startswith("x"):
|
||||||
image = torch.flip(image, dims=[1])
|
image = torch.flip(image, dims=[1])
|
||||||
elif flip_method.startswith("y"):
|
elif flip_method.startswith("y"):
|
||||||
image = torch.flip(image, dims=[2])
|
image = torch.flip(image, dims=[2])
|
||||||
|
|
||||||
return (image,)
|
return IO.NodeOutput(image)
|
||||||
|
|
||||||
class ImageScaleToMaxDimension:
|
flip = execute # TODO: remove
|
||||||
upscale_methods = ["area", "lanczos", "bilinear", "nearest-exact", "bilinear", "bicubic"]
|
|
||||||
|
|
||||||
|
class ImageScaleToMaxDimension(IO.ComfyNode):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": {"image": ("IMAGE",),
|
return IO.Schema(
|
||||||
"upscale_method": (s.upscale_methods,),
|
node_id="ImageScaleToMaxDimension",
|
||||||
"largest_size": ("INT", {"default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 1})}}
|
category="image/upscaling",
|
||||||
RETURN_TYPES = ("IMAGE",)
|
inputs=[
|
||||||
FUNCTION = "upscale"
|
IO.Image.Input("image"),
|
||||||
|
IO.Combo.Input(
|
||||||
|
"upscale_method",
|
||||||
|
options=["area", "lanczos", "bilinear", "nearest-exact", "bilinear", "bicubic"],
|
||||||
|
),
|
||||||
|
IO.Int.Input("largest_size", default=512, min=0, max=MAX_RESOLUTION, step=1),
|
||||||
|
],
|
||||||
|
outputs=[IO.Image.Output()],
|
||||||
|
)
|
||||||
|
|
||||||
CATEGORY = "image/upscaling"
|
@classmethod
|
||||||
|
def execute(cls, image, upscale_method, largest_size) -> IO.NodeOutput:
|
||||||
def upscale(self, image, upscale_method, largest_size):
|
|
||||||
height = image.shape[1]
|
height = image.shape[1]
|
||||||
width = image.shape[2]
|
width = image.shape[2]
|
||||||
|
|
||||||
@ -655,20 +618,30 @@ class ImageScaleToMaxDimension:
|
|||||||
samples = image.movedim(-1, 1)
|
samples = image.movedim(-1, 1)
|
||||||
s = comfy.utils.common_upscale(samples, width, height, upscale_method, "disabled")
|
s = comfy.utils.common_upscale(samples, width, height, upscale_method, "disabled")
|
||||||
s = s.movedim(1, -1)
|
s = s.movedim(1, -1)
|
||||||
return (s,)
|
return IO.NodeOutput(s)
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
upscale = execute # TODO: remove
|
||||||
"ImageCrop": ImageCrop,
|
|
||||||
"RepeatImageBatch": RepeatImageBatch,
|
|
||||||
"ImageFromBatch": ImageFromBatch,
|
class ImagesExtension(ComfyExtension):
|
||||||
"ImageAddNoise": ImageAddNoise,
|
@override
|
||||||
"SaveAnimatedWEBP": SaveAnimatedWEBP,
|
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||||
"SaveAnimatedPNG": SaveAnimatedPNG,
|
return [
|
||||||
"SaveSVGNode": SaveSVGNode,
|
ImageCrop,
|
||||||
"ImageStitch": ImageStitch,
|
RepeatImageBatch,
|
||||||
"ResizeAndPadImage": ResizeAndPadImage,
|
ImageFromBatch,
|
||||||
"GetImageSize": GetImageSize,
|
ImageAddNoise,
|
||||||
"ImageRotate": ImageRotate,
|
SaveAnimatedWEBP,
|
||||||
"ImageFlip": ImageFlip,
|
SaveAnimatedPNG,
|
||||||
"ImageScaleToMaxDimension": ImageScaleToMaxDimension,
|
SaveSVGNode,
|
||||||
}
|
ImageStitch,
|
||||||
|
ResizeAndPadImage,
|
||||||
|
GetImageSize,
|
||||||
|
ImageRotate,
|
||||||
|
ImageFlip,
|
||||||
|
ImageScaleToMaxDimension,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
async def comfy_entrypoint() -> ImagesExtension:
|
||||||
|
return ImagesExtension()
|
||||||
|
|||||||
@ -255,6 +255,7 @@ class LatentBatch(io.ComfyNode):
|
|||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="LatentBatch",
|
node_id="LatentBatch",
|
||||||
category="latent/batch",
|
category="latent/batch",
|
||||||
|
is_deprecated=True,
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Latent.Input("samples1"),
|
io.Latent.Input("samples1"),
|
||||||
io.Latent.Input("samples2"),
|
io.Latent.Input("samples2"),
|
||||||
|
|||||||
@ -1,8 +1,11 @@
|
|||||||
|
from __future__ import annotations
|
||||||
from typing import TypedDict
|
from typing import TypedDict
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
from comfy_api.latest import ComfyExtension, io
|
from comfy_api.latest import ComfyExtension, io
|
||||||
from comfy_api.latest import _io
|
from comfy_api.latest import _io
|
||||||
|
|
||||||
|
# sentinel for missing inputs
|
||||||
|
MISSING = object()
|
||||||
|
|
||||||
|
|
||||||
class SwitchNode(io.ComfyNode):
|
class SwitchNode(io.ComfyNode):
|
||||||
@ -14,6 +17,37 @@ class SwitchNode(io.ComfyNode):
|
|||||||
display_name="Switch",
|
display_name="Switch",
|
||||||
category="logic",
|
category="logic",
|
||||||
is_experimental=True,
|
is_experimental=True,
|
||||||
|
inputs=[
|
||||||
|
io.Boolean.Input("switch"),
|
||||||
|
io.MatchType.Input("on_false", template=template, lazy=True),
|
||||||
|
io.MatchType.Input("on_true", template=template, lazy=True),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.MatchType.Output(template=template, display_name="output"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def check_lazy_status(cls, switch, on_false=None, on_true=None):
|
||||||
|
if switch and on_true is None:
|
||||||
|
return ["on_true"]
|
||||||
|
if not switch and on_false is None:
|
||||||
|
return ["on_false"]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, switch, on_true, on_false) -> io.NodeOutput:
|
||||||
|
return io.NodeOutput(on_true if switch else on_false)
|
||||||
|
|
||||||
|
|
||||||
|
class SoftSwitchNode(io.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
template = io.MatchType.Template("switch")
|
||||||
|
return io.Schema(
|
||||||
|
node_id="ComfySoftSwitchNode",
|
||||||
|
display_name="Soft Switch",
|
||||||
|
category="logic",
|
||||||
|
is_experimental=True,
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Boolean.Input("switch"),
|
io.Boolean.Input("switch"),
|
||||||
io.MatchType.Input("on_false", template=template, lazy=True, optional=True),
|
io.MatchType.Input("on_false", template=template, lazy=True, optional=True),
|
||||||
@ -25,14 +59,14 @@ class SwitchNode(io.ComfyNode):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_lazy_status(cls, switch, on_false=..., on_true=...):
|
def check_lazy_status(cls, switch, on_false=MISSING, on_true=MISSING):
|
||||||
# We use ... instead of None, as None is passed for connected-but-unevaluated inputs.
|
# We use MISSING instead of None, as None is passed for connected-but-unevaluated inputs.
|
||||||
# This trick allows us to ignore the value of the switch and still be able to run execute().
|
# This trick allows us to ignore the value of the switch and still be able to run execute().
|
||||||
|
|
||||||
# One of the inputs may be missing, in which case we need to evaluate the other input
|
# One of the inputs may be missing, in which case we need to evaluate the other input
|
||||||
if on_false is ...:
|
if on_false is MISSING:
|
||||||
return ["on_true"]
|
return ["on_true"]
|
||||||
if on_true is ...:
|
if on_true is MISSING:
|
||||||
return ["on_false"]
|
return ["on_false"]
|
||||||
# Normal lazy switch operation
|
# Normal lazy switch operation
|
||||||
if switch and on_true is None:
|
if switch and on_true is None:
|
||||||
@ -41,22 +75,50 @@ class SwitchNode(io.ComfyNode):
|
|||||||
return ["on_false"]
|
return ["on_false"]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_inputs(cls, switch, on_false=..., on_true=...):
|
def validate_inputs(cls, switch, on_false=MISSING, on_true=MISSING):
|
||||||
# This check happens before check_lazy_status(), so we can eliminate the case where
|
# This check happens before check_lazy_status(), so we can eliminate the case where
|
||||||
# both inputs are missing.
|
# both inputs are missing.
|
||||||
if on_false is ... and on_true is ...:
|
if on_false is MISSING and on_true is MISSING:
|
||||||
return "At least one of on_false or on_true must be connected to Switch node"
|
return "At least one of on_false or on_true must be connected to Switch node"
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, switch, on_true=..., on_false=...) -> io.NodeOutput:
|
def execute(cls, switch, on_true=MISSING, on_false=MISSING) -> io.NodeOutput:
|
||||||
if on_true is ...:
|
if on_true is MISSING:
|
||||||
return io.NodeOutput(on_false)
|
return io.NodeOutput(on_false)
|
||||||
if on_false is ...:
|
if on_false is MISSING:
|
||||||
return io.NodeOutput(on_true)
|
return io.NodeOutput(on_true)
|
||||||
return io.NodeOutput(on_true if switch else on_false)
|
return io.NodeOutput(on_true if switch else on_false)
|
||||||
|
|
||||||
|
|
||||||
|
class CustomComboNode(io.ComfyNode):
|
||||||
|
"""
|
||||||
|
Frontend node that allows user to write their own options for a combo.
|
||||||
|
This is here to make sure the node has a backend-representation to avoid some annoyances.
|
||||||
|
"""
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return io.Schema(
|
||||||
|
node_id="CustomCombo",
|
||||||
|
display_name="Custom Combo",
|
||||||
|
category="utils",
|
||||||
|
is_experimental=True,
|
||||||
|
inputs=[io.Combo.Input("choice", options=[])],
|
||||||
|
outputs=[io.String.Output()]
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def validate_inputs(cls, choice: io.Combo.Type) -> bool:
|
||||||
|
# NOTE: DO NOT DO THIS unless you want to skip validation entirely on the node's inputs.
|
||||||
|
# I am doing that here because the widgets (besides the combo dropdown) on this node are fully frontend defined.
|
||||||
|
# I need to skip checking that the chosen combo option is in the options list, since those are defined by the user.
|
||||||
|
return True
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, choice: io.Combo.Type) -> io.NodeOutput:
|
||||||
|
return io.NodeOutput(choice)
|
||||||
|
|
||||||
|
|
||||||
class DCTestNode(io.ComfyNode):
|
class DCTestNode(io.ComfyNode):
|
||||||
class DCValues(TypedDict):
|
class DCValues(TypedDict):
|
||||||
combo: str
|
combo: str
|
||||||
@ -72,14 +134,14 @@ class DCTestNode(io.ComfyNode):
|
|||||||
display_name="DCTest",
|
display_name="DCTest",
|
||||||
category="logic",
|
category="logic",
|
||||||
is_output_node=True,
|
is_output_node=True,
|
||||||
inputs=[_io.DynamicCombo.Input("combo", options=[
|
inputs=[io.DynamicCombo.Input("combo", options=[
|
||||||
_io.DynamicCombo.Option("option1", [io.String.Input("string")]),
|
io.DynamicCombo.Option("option1", [io.String.Input("string")]),
|
||||||
_io.DynamicCombo.Option("option2", [io.Int.Input("integer")]),
|
io.DynamicCombo.Option("option2", [io.Int.Input("integer")]),
|
||||||
_io.DynamicCombo.Option("option3", [io.Image.Input("image")]),
|
io.DynamicCombo.Option("option3", [io.Image.Input("image")]),
|
||||||
_io.DynamicCombo.Option("option4", [
|
io.DynamicCombo.Option("option4", [
|
||||||
_io.DynamicCombo.Input("subcombo", options=[
|
io.DynamicCombo.Input("subcombo", options=[
|
||||||
_io.DynamicCombo.Option("opt1", [io.Float.Input("float_x"), io.Float.Input("float_y")]),
|
io.DynamicCombo.Option("opt1", [io.Float.Input("float_x"), io.Float.Input("float_y")]),
|
||||||
_io.DynamicCombo.Option("opt2", [io.Mask.Input("mask1", optional=True)]),
|
io.DynamicCombo.Option("opt2", [io.Mask.Input("mask1", optional=True)]),
|
||||||
])
|
])
|
||||||
])]
|
])]
|
||||||
)],
|
)],
|
||||||
@ -141,14 +203,65 @@ class AutogrowPrefixTestNode(io.ComfyNode):
|
|||||||
combined = ",".join([str(x) for x in vals])
|
combined = ",".join([str(x) for x in vals])
|
||||||
return io.NodeOutput(combined)
|
return io.NodeOutput(combined)
|
||||||
|
|
||||||
|
class ComboOutputTestNode(io.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return io.Schema(
|
||||||
|
node_id="ComboOptionTestNode",
|
||||||
|
display_name="ComboOptionTest",
|
||||||
|
category="logic",
|
||||||
|
inputs=[io.Combo.Input("combo", options=["option1", "option2", "option3"]),
|
||||||
|
io.Combo.Input("combo2", options=["option4", "option5", "option6"])],
|
||||||
|
outputs=[io.Combo.Output(), io.Combo.Output()],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, combo: io.Combo.Type, combo2: io.Combo.Type) -> io.NodeOutput:
|
||||||
|
return io.NodeOutput(combo, combo2)
|
||||||
|
|
||||||
|
class ConvertStringToComboNode(io.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return io.Schema(
|
||||||
|
node_id="ConvertStringToComboNode",
|
||||||
|
display_name="Convert String to Combo",
|
||||||
|
category="logic",
|
||||||
|
inputs=[io.String.Input("string")],
|
||||||
|
outputs=[io.Combo.Output()],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, string: str) -> io.NodeOutput:
|
||||||
|
return io.NodeOutput(string)
|
||||||
|
|
||||||
|
class InvertBooleanNode(io.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return io.Schema(
|
||||||
|
node_id="InvertBooleanNode",
|
||||||
|
display_name="Invert Boolean",
|
||||||
|
category="logic",
|
||||||
|
inputs=[io.Boolean.Input("boolean")],
|
||||||
|
outputs=[io.Boolean.Output()],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, boolean: bool) -> io.NodeOutput:
|
||||||
|
return io.NodeOutput(not boolean)
|
||||||
|
|
||||||
class LogicExtension(ComfyExtension):
|
class LogicExtension(ComfyExtension):
|
||||||
@override
|
@override
|
||||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
return [
|
return [
|
||||||
# SwitchNode,
|
SwitchNode,
|
||||||
|
CustomComboNode,
|
||||||
|
# SoftSwitchNode,
|
||||||
|
# ConvertStringToComboNode,
|
||||||
# DCTestNode,
|
# DCTestNode,
|
||||||
# AutogrowNamesTestNode,
|
# AutogrowNamesTestNode,
|
||||||
# AutogrowPrefixTestNode,
|
# AutogrowPrefixTestNode,
|
||||||
|
# ComboOutputTestNode,
|
||||||
|
# InvertBooleanNode,
|
||||||
]
|
]
|
||||||
|
|
||||||
async def comfy_entrypoint() -> LogicExtension:
|
async def comfy_entrypoint() -> LogicExtension:
|
||||||
|
|||||||
@ -81,6 +81,59 @@ class LTXVImgToVideo(io.ComfyNode):
|
|||||||
generate = execute # TODO: remove
|
generate = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
|
class LTXVImgToVideoInplace(io.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return io.Schema(
|
||||||
|
node_id="LTXVImgToVideoInplace",
|
||||||
|
category="conditioning/video_models",
|
||||||
|
inputs=[
|
||||||
|
io.Vae.Input("vae"),
|
||||||
|
io.Image.Input("image"),
|
||||||
|
io.Latent.Input("latent"),
|
||||||
|
io.Float.Input("strength", default=1.0, min=0.0, max=1.0),
|
||||||
|
io.Boolean.Input("bypass", default=False, tooltip="Bypass the conditioning.")
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Latent.Output(display_name="latent"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, vae, image, latent, strength, bypass=False) -> io.NodeOutput:
|
||||||
|
if bypass:
|
||||||
|
return (latent,)
|
||||||
|
|
||||||
|
samples = latent["samples"]
|
||||||
|
_, height_scale_factor, width_scale_factor = (
|
||||||
|
vae.downscale_index_formula
|
||||||
|
)
|
||||||
|
|
||||||
|
batch, _, latent_frames, latent_height, latent_width = samples.shape
|
||||||
|
width = latent_width * width_scale_factor
|
||||||
|
height = latent_height * height_scale_factor
|
||||||
|
|
||||||
|
if image.shape[1] != height or image.shape[2] != width:
|
||||||
|
pixels = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
||||||
|
else:
|
||||||
|
pixels = image
|
||||||
|
encode_pixels = pixels[:, :, :, :3]
|
||||||
|
t = vae.encode(encode_pixels)
|
||||||
|
|
||||||
|
samples[:, :, :t.shape[2]] = t
|
||||||
|
|
||||||
|
conditioning_latent_frames_mask = torch.ones(
|
||||||
|
(batch, 1, latent_frames, 1, 1),
|
||||||
|
dtype=torch.float32,
|
||||||
|
device=samples.device,
|
||||||
|
)
|
||||||
|
conditioning_latent_frames_mask[:, :, :t.shape[2]] = 1.0 - strength
|
||||||
|
|
||||||
|
return io.NodeOutput({"samples": samples, "noise_mask": conditioning_latent_frames_mask})
|
||||||
|
|
||||||
|
generate = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
def conditioning_get_any_value(conditioning, key, default=None):
|
def conditioning_get_any_value(conditioning, key, default=None):
|
||||||
for t in conditioning:
|
for t in conditioning:
|
||||||
if key in t[1]:
|
if key in t[1]:
|
||||||
@ -106,12 +159,12 @@ def get_keyframe_idxs(cond):
|
|||||||
keyframe_idxs = conditioning_get_any_value(cond, "keyframe_idxs", None)
|
keyframe_idxs = conditioning_get_any_value(cond, "keyframe_idxs", None)
|
||||||
if keyframe_idxs is None:
|
if keyframe_idxs is None:
|
||||||
return None, 0
|
return None, 0
|
||||||
num_keyframes = torch.unique(keyframe_idxs[:, 0]).shape[0]
|
# keyframe_idxs contains start/end positions (last dimension), checking for unqiue values only for start
|
||||||
|
num_keyframes = torch.unique(keyframe_idxs[:, 0, :, 0]).shape[0]
|
||||||
return keyframe_idxs, num_keyframes
|
return keyframe_idxs, num_keyframes
|
||||||
|
|
||||||
class LTXVAddGuide(io.ComfyNode):
|
class LTXVAddGuide(io.ComfyNode):
|
||||||
NUM_PREFIX_FRAMES = 2
|
PATCHIFIER = SymmetricPatchifier(1, start_end=True)
|
||||||
PATCHIFIER = SymmetricPatchifier(1)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
@ -182,26 +235,35 @@ class LTXVAddGuide(io.ComfyNode):
|
|||||||
return node_helpers.conditioning_set_values(cond, {"keyframe_idxs": keyframe_idxs})
|
return node_helpers.conditioning_set_values(cond, {"keyframe_idxs": keyframe_idxs})
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def append_keyframe(cls, positive, negative, frame_idx, latent_image, noise_mask, guiding_latent, strength, scale_factors):
|
def append_keyframe(cls, positive, negative, frame_idx, latent_image, noise_mask, guiding_latent, strength, scale_factors, guide_mask=None, in_channels=128):
|
||||||
_, latent_idx = cls.get_latent_index(
|
if latent_image.shape[1] != in_channels or guiding_latent.shape[1] != in_channels:
|
||||||
cond=positive,
|
raise ValueError("Adding guide to a combined AV latent is not supported.")
|
||||||
latent_length=latent_image.shape[2],
|
|
||||||
guide_length=guiding_latent.shape[2],
|
|
||||||
frame_idx=frame_idx,
|
|
||||||
scale_factors=scale_factors,
|
|
||||||
)
|
|
||||||
noise_mask[:, :, latent_idx:latent_idx + guiding_latent.shape[2]] = 1.0
|
|
||||||
|
|
||||||
positive = cls.add_keyframe_index(positive, frame_idx, guiding_latent, scale_factors)
|
positive = cls.add_keyframe_index(positive, frame_idx, guiding_latent, scale_factors)
|
||||||
negative = cls.add_keyframe_index(negative, frame_idx, guiding_latent, scale_factors)
|
negative = cls.add_keyframe_index(negative, frame_idx, guiding_latent, scale_factors)
|
||||||
|
|
||||||
mask = torch.full(
|
if guide_mask is not None:
|
||||||
(noise_mask.shape[0], 1, guiding_latent.shape[2], noise_mask.shape[3], noise_mask.shape[4]),
|
target_h = max(noise_mask.shape[3], guide_mask.shape[3])
|
||||||
1.0 - strength,
|
target_w = max(noise_mask.shape[4], guide_mask.shape[4])
|
||||||
dtype=noise_mask.dtype,
|
|
||||||
device=noise_mask.device,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
if noise_mask.shape[3] == 1 or noise_mask.shape[4] == 1:
|
||||||
|
noise_mask = noise_mask.expand(-1, -1, -1, target_h, target_w)
|
||||||
|
|
||||||
|
if guide_mask.shape[3] == 1 or guide_mask.shape[4] == 1:
|
||||||
|
guide_mask = guide_mask.expand(-1, -1, -1, target_h, target_w)
|
||||||
|
mask = guide_mask - strength
|
||||||
|
else:
|
||||||
|
mask = torch.full(
|
||||||
|
(noise_mask.shape[0], 1, guiding_latent.shape[2], noise_mask.shape[3], noise_mask.shape[4]),
|
||||||
|
1.0 - strength,
|
||||||
|
dtype=noise_mask.dtype,
|
||||||
|
device=noise_mask.device,
|
||||||
|
)
|
||||||
|
# This solves audio video combined latent case where latent_image has audio latent concatenated
|
||||||
|
# in channel dimension with video latent. The solution is to pad guiding latent accordingly.
|
||||||
|
if latent_image.shape[1] > guiding_latent.shape[1]:
|
||||||
|
pad_len = latent_image.shape[1] - guiding_latent.shape[1]
|
||||||
|
guiding_latent = torch.nn.functional.pad(guiding_latent, pad=(0, 0, 0, 0, 0, 0, 0, pad_len), value=0)
|
||||||
latent_image = torch.cat([latent_image, guiding_latent], dim=2)
|
latent_image = torch.cat([latent_image, guiding_latent], dim=2)
|
||||||
noise_mask = torch.cat([noise_mask, mask], dim=2)
|
noise_mask = torch.cat([noise_mask, mask], dim=2)
|
||||||
return positive, negative, latent_image, noise_mask
|
return positive, negative, latent_image, noise_mask
|
||||||
@ -238,33 +300,17 @@ class LTXVAddGuide(io.ComfyNode):
|
|||||||
frame_idx, latent_idx = cls.get_latent_index(positive, latent_length, len(image), frame_idx, scale_factors)
|
frame_idx, latent_idx = cls.get_latent_index(positive, latent_length, len(image), frame_idx, scale_factors)
|
||||||
assert latent_idx + t.shape[2] <= latent_length, "Conditioning frames exceed the length of the latent sequence."
|
assert latent_idx + t.shape[2] <= latent_length, "Conditioning frames exceed the length of the latent sequence."
|
||||||
|
|
||||||
num_prefix_frames = min(cls.NUM_PREFIX_FRAMES, t.shape[2])
|
|
||||||
|
|
||||||
positive, negative, latent_image, noise_mask = cls.append_keyframe(
|
positive, negative, latent_image, noise_mask = cls.append_keyframe(
|
||||||
positive,
|
positive,
|
||||||
negative,
|
negative,
|
||||||
frame_idx,
|
frame_idx,
|
||||||
latent_image,
|
latent_image,
|
||||||
noise_mask,
|
noise_mask,
|
||||||
t[:, :, :num_prefix_frames],
|
t,
|
||||||
strength,
|
strength,
|
||||||
scale_factors,
|
scale_factors,
|
||||||
)
|
)
|
||||||
|
|
||||||
latent_idx += num_prefix_frames
|
|
||||||
|
|
||||||
t = t[:, :, num_prefix_frames:]
|
|
||||||
if t.shape[2] == 0:
|
|
||||||
return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask})
|
|
||||||
|
|
||||||
latent_image, noise_mask = cls.replace_latent_frames(
|
|
||||||
latent_image,
|
|
||||||
noise_mask,
|
|
||||||
t,
|
|
||||||
latent_idx,
|
|
||||||
strength,
|
|
||||||
)
|
|
||||||
|
|
||||||
return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask})
|
return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask})
|
||||||
|
|
||||||
generate = execute # TODO: remove
|
generate = execute # TODO: remove
|
||||||
@ -507,18 +553,90 @@ class LTXVPreprocess(io.ComfyNode):
|
|||||||
|
|
||||||
preprocess = execute # TODO: remove
|
preprocess = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
|
import comfy.nested_tensor
|
||||||
|
class LTXVConcatAVLatent(io.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return io.Schema(
|
||||||
|
node_id="LTXVConcatAVLatent",
|
||||||
|
category="latent/video/ltxv",
|
||||||
|
inputs=[
|
||||||
|
io.Latent.Input("video_latent"),
|
||||||
|
io.Latent.Input("audio_latent"),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Latent.Output(display_name="latent"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, video_latent, audio_latent) -> io.NodeOutput:
|
||||||
|
output = {}
|
||||||
|
output.update(video_latent)
|
||||||
|
output.update(audio_latent)
|
||||||
|
video_noise_mask = video_latent.get("noise_mask", None)
|
||||||
|
audio_noise_mask = audio_latent.get("noise_mask", None)
|
||||||
|
|
||||||
|
if video_noise_mask is not None or audio_noise_mask is not None:
|
||||||
|
if video_noise_mask is None:
|
||||||
|
video_noise_mask = torch.ones_like(video_latent["samples"])
|
||||||
|
if audio_noise_mask is None:
|
||||||
|
audio_noise_mask = torch.ones_like(audio_latent["samples"])
|
||||||
|
output["noise_mask"] = comfy.nested_tensor.NestedTensor((video_noise_mask, audio_noise_mask))
|
||||||
|
|
||||||
|
output["samples"] = comfy.nested_tensor.NestedTensor((video_latent["samples"], audio_latent["samples"]))
|
||||||
|
|
||||||
|
return io.NodeOutput(output)
|
||||||
|
|
||||||
|
|
||||||
|
class LTXVSeparateAVLatent(io.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return io.Schema(
|
||||||
|
node_id="LTXVSeparateAVLatent",
|
||||||
|
category="latent/video/ltxv",
|
||||||
|
description="LTXV Separate AV Latent",
|
||||||
|
inputs=[
|
||||||
|
io.Latent.Input("av_latent"),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Latent.Output(display_name="video_latent"),
|
||||||
|
io.Latent.Output(display_name="audio_latent"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, av_latent) -> io.NodeOutput:
|
||||||
|
latents = av_latent["samples"].unbind()
|
||||||
|
video_latent = av_latent.copy()
|
||||||
|
video_latent["samples"] = latents[0]
|
||||||
|
audio_latent = av_latent.copy()
|
||||||
|
audio_latent["samples"] = latents[1]
|
||||||
|
if "noise_mask" in av_latent:
|
||||||
|
masks = av_latent["noise_mask"]
|
||||||
|
if masks is not None:
|
||||||
|
masks = masks.unbind()
|
||||||
|
video_latent["noise_mask"] = masks[0]
|
||||||
|
audio_latent["noise_mask"] = masks[1]
|
||||||
|
return io.NodeOutput(video_latent, audio_latent)
|
||||||
|
|
||||||
|
|
||||||
class LtxvExtension(ComfyExtension):
|
class LtxvExtension(ComfyExtension):
|
||||||
@override
|
@override
|
||||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
return [
|
return [
|
||||||
EmptyLTXVLatentVideo,
|
EmptyLTXVLatentVideo,
|
||||||
LTXVImgToVideo,
|
LTXVImgToVideo,
|
||||||
|
LTXVImgToVideoInplace,
|
||||||
ModelSamplingLTXV,
|
ModelSamplingLTXV,
|
||||||
LTXVConditioning,
|
LTXVConditioning,
|
||||||
LTXVScheduler,
|
LTXVScheduler,
|
||||||
LTXVAddGuide,
|
LTXVAddGuide,
|
||||||
LTXVPreprocess,
|
LTXVPreprocess,
|
||||||
LTXVCropGuides,
|
LTXVCropGuides,
|
||||||
|
LTXVConcatAVLatent,
|
||||||
|
LTXVSeparateAVLatent,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
216
comfy_extras/nodes_lt_audio.py
Normal file
216
comfy_extras/nodes_lt_audio.py
Normal file
@ -0,0 +1,216 @@
|
|||||||
|
import folder_paths
|
||||||
|
import comfy.utils
|
||||||
|
import comfy.model_management
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from comfy.ldm.lightricks.vae.audio_vae import AudioVAE
|
||||||
|
from comfy_api.latest import ComfyExtension, io
|
||||||
|
|
||||||
|
|
||||||
|
class LTXVAudioVAELoader(io.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls) -> io.Schema:
|
||||||
|
return io.Schema(
|
||||||
|
node_id="LTXVAudioVAELoader",
|
||||||
|
display_name="LTXV Audio VAE Loader",
|
||||||
|
category="audio",
|
||||||
|
inputs=[
|
||||||
|
io.Combo.Input(
|
||||||
|
"ckpt_name",
|
||||||
|
options=folder_paths.get_filename_list("checkpoints"),
|
||||||
|
tooltip="Audio VAE checkpoint to load.",
|
||||||
|
)
|
||||||
|
],
|
||||||
|
outputs=[io.Vae.Output(display_name="Audio VAE")],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, ckpt_name: str) -> io.NodeOutput:
|
||||||
|
ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
|
||||||
|
sd, metadata = comfy.utils.load_torch_file(ckpt_path, return_metadata=True)
|
||||||
|
return io.NodeOutput(AudioVAE(sd, metadata))
|
||||||
|
|
||||||
|
|
||||||
|
class LTXVAudioVAEEncode(io.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls) -> io.Schema:
|
||||||
|
return io.Schema(
|
||||||
|
node_id="LTXVAudioVAEEncode",
|
||||||
|
display_name="LTXV Audio VAE Encode",
|
||||||
|
category="audio",
|
||||||
|
inputs=[
|
||||||
|
io.Audio.Input("audio", tooltip="The audio to be encoded."),
|
||||||
|
io.Vae.Input(
|
||||||
|
id="audio_vae",
|
||||||
|
display_name="Audio VAE",
|
||||||
|
tooltip="The Audio VAE model to use for encoding.",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
outputs=[io.Latent.Output(display_name="Audio Latent")],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, audio, audio_vae: AudioVAE) -> io.NodeOutput:
|
||||||
|
audio_latents = audio_vae.encode(audio)
|
||||||
|
return io.NodeOutput(
|
||||||
|
{
|
||||||
|
"samples": audio_latents,
|
||||||
|
"sample_rate": int(audio_vae.sample_rate),
|
||||||
|
"type": "audio",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class LTXVAudioVAEDecode(io.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls) -> io.Schema:
|
||||||
|
return io.Schema(
|
||||||
|
node_id="LTXVAudioVAEDecode",
|
||||||
|
display_name="LTXV Audio VAE Decode",
|
||||||
|
category="audio",
|
||||||
|
inputs=[
|
||||||
|
io.Latent.Input("samples", tooltip="The latent to be decoded."),
|
||||||
|
io.Vae.Input(
|
||||||
|
id="audio_vae",
|
||||||
|
display_name="Audio VAE",
|
||||||
|
tooltip="The Audio VAE model used for decoding the latent.",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
outputs=[io.Audio.Output(display_name="Audio")],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, samples, audio_vae: AudioVAE) -> io.NodeOutput:
|
||||||
|
audio_latent = samples["samples"]
|
||||||
|
if audio_latent.is_nested:
|
||||||
|
audio_latent = audio_latent.unbind()[-1]
|
||||||
|
audio = audio_vae.decode(audio_latent).to(audio_latent.device)
|
||||||
|
output_audio_sample_rate = audio_vae.output_sample_rate
|
||||||
|
return io.NodeOutput(
|
||||||
|
{
|
||||||
|
"waveform": audio,
|
||||||
|
"sample_rate": int(output_audio_sample_rate),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class LTXVEmptyLatentAudio(io.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls) -> io.Schema:
|
||||||
|
return io.Schema(
|
||||||
|
node_id="LTXVEmptyLatentAudio",
|
||||||
|
display_name="LTXV Empty Latent Audio",
|
||||||
|
category="latent/audio",
|
||||||
|
inputs=[
|
||||||
|
io.Int.Input(
|
||||||
|
"frames_number",
|
||||||
|
default=97,
|
||||||
|
min=1,
|
||||||
|
max=1000,
|
||||||
|
step=1,
|
||||||
|
display_mode=io.NumberDisplay.number,
|
||||||
|
tooltip="Number of frames.",
|
||||||
|
),
|
||||||
|
io.Int.Input(
|
||||||
|
"frame_rate",
|
||||||
|
default=25,
|
||||||
|
min=1,
|
||||||
|
max=1000,
|
||||||
|
step=1,
|
||||||
|
display_mode=io.NumberDisplay.number,
|
||||||
|
tooltip="Number of frames per second.",
|
||||||
|
),
|
||||||
|
io.Int.Input(
|
||||||
|
"batch_size",
|
||||||
|
default=1,
|
||||||
|
min=1,
|
||||||
|
max=4096,
|
||||||
|
display_mode=io.NumberDisplay.number,
|
||||||
|
tooltip="The number of latent audio samples in the batch.",
|
||||||
|
),
|
||||||
|
io.Vae.Input(
|
||||||
|
id="audio_vae",
|
||||||
|
display_name="Audio VAE",
|
||||||
|
tooltip="The Audio VAE model to get configuration from.",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
outputs=[io.Latent.Output(display_name="Latent")],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(
|
||||||
|
cls,
|
||||||
|
frames_number: int,
|
||||||
|
frame_rate: int,
|
||||||
|
batch_size: int,
|
||||||
|
audio_vae: AudioVAE,
|
||||||
|
) -> io.NodeOutput:
|
||||||
|
"""Generate empty audio latents matching the reference pipeline structure."""
|
||||||
|
|
||||||
|
assert audio_vae is not None, "Audio VAE model is required"
|
||||||
|
|
||||||
|
z_channels = audio_vae.latent_channels
|
||||||
|
audio_freq = audio_vae.latent_frequency_bins
|
||||||
|
sampling_rate = int(audio_vae.sample_rate)
|
||||||
|
|
||||||
|
num_audio_latents = audio_vae.num_of_latents_from_frames(frames_number, frame_rate)
|
||||||
|
|
||||||
|
audio_latents = torch.zeros(
|
||||||
|
(batch_size, z_channels, num_audio_latents, audio_freq),
|
||||||
|
device=comfy.model_management.intermediate_device(),
|
||||||
|
)
|
||||||
|
|
||||||
|
return io.NodeOutput(
|
||||||
|
{
|
||||||
|
"samples": audio_latents,
|
||||||
|
"sample_rate": sampling_rate,
|
||||||
|
"type": "audio",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class LTXAVTextEncoderLoader(io.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls) -> io.Schema:
|
||||||
|
return io.Schema(
|
||||||
|
node_id="LTXAVTextEncoderLoader",
|
||||||
|
display_name="LTXV Audio Text Encoder Loader",
|
||||||
|
category="advanced/loaders",
|
||||||
|
description="[Recipes]\n\nltxav: gemma 3 12B",
|
||||||
|
inputs=[
|
||||||
|
io.Combo.Input(
|
||||||
|
"text_encoder",
|
||||||
|
options=folder_paths.get_filename_list("text_encoders"),
|
||||||
|
),
|
||||||
|
io.Combo.Input(
|
||||||
|
"ckpt_name",
|
||||||
|
options=folder_paths.get_filename_list("checkpoints"),
|
||||||
|
)
|
||||||
|
],
|
||||||
|
outputs=[io.Clip.Output(display_name="Audio VAE")],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, text_encoder, ckpt_name, device="default"):
|
||||||
|
clip_type = comfy.sd.CLIPType.LTXV
|
||||||
|
|
||||||
|
clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", text_encoder)
|
||||||
|
clip_path2 = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
|
||||||
|
|
||||||
|
clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type)
|
||||||
|
return io.NodeOutput(clip)
|
||||||
|
|
||||||
|
|
||||||
|
class LTXVAudioExtension(ComfyExtension):
|
||||||
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
|
return [
|
||||||
|
LTXVAudioVAELoader,
|
||||||
|
LTXVAudioVAEEncode,
|
||||||
|
LTXVAudioVAEDecode,
|
||||||
|
LTXVEmptyLatentAudio,
|
||||||
|
LTXAVTextEncoderLoader,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
async def comfy_entrypoint() -> ComfyExtension:
|
||||||
|
return LTXVAudioExtension()
|
||||||
75
comfy_extras/nodes_lt_upsampler.py
Normal file
75
comfy_extras/nodes_lt_upsampler.py
Normal file
@ -0,0 +1,75 @@
|
|||||||
|
from comfy import model_management
|
||||||
|
import math
|
||||||
|
|
||||||
|
class LTXVLatentUpsampler:
|
||||||
|
"""
|
||||||
|
Upsamples a video latent by a factor of 2.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"samples": ("LATENT",),
|
||||||
|
"upscale_model": ("LATENT_UPSCALE_MODEL",),
|
||||||
|
"vae": ("VAE",),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("LATENT",)
|
||||||
|
FUNCTION = "upsample_latent"
|
||||||
|
CATEGORY = "latent/video"
|
||||||
|
EXPERIMENTAL = True
|
||||||
|
|
||||||
|
def upsample_latent(
|
||||||
|
self,
|
||||||
|
samples: dict,
|
||||||
|
upscale_model,
|
||||||
|
vae,
|
||||||
|
) -> tuple:
|
||||||
|
"""
|
||||||
|
Upsample the input latent using the provided model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
samples (dict): Input latent samples
|
||||||
|
upscale_model (LatentUpsampler): Loaded upscale model
|
||||||
|
vae: VAE model for normalization
|
||||||
|
auto_tiling (bool): Whether to automatically tile the input for processing
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple: Tuple containing the upsampled latent
|
||||||
|
"""
|
||||||
|
device = model_management.get_torch_device()
|
||||||
|
memory_required = model_management.module_size(upscale_model)
|
||||||
|
|
||||||
|
model_dtype = next(upscale_model.parameters()).dtype
|
||||||
|
latents = samples["samples"]
|
||||||
|
input_dtype = latents.dtype
|
||||||
|
|
||||||
|
memory_required += math.prod(latents.shape) * 3000.0 # TODO: more accurate
|
||||||
|
model_management.free_memory(memory_required, device)
|
||||||
|
|
||||||
|
try:
|
||||||
|
upscale_model.to(device) # TODO: use the comfy model management system.
|
||||||
|
|
||||||
|
latents = latents.to(dtype=model_dtype, device=device)
|
||||||
|
|
||||||
|
"""Upsample latents without tiling."""
|
||||||
|
latents = vae.first_stage_model.per_channel_statistics.un_normalize(latents)
|
||||||
|
upsampled_latents = upscale_model(latents)
|
||||||
|
finally:
|
||||||
|
upscale_model.cpu()
|
||||||
|
|
||||||
|
upsampled_latents = vae.first_stage_model.per_channel_statistics.normalize(
|
||||||
|
upsampled_latents
|
||||||
|
)
|
||||||
|
upsampled_latents = upsampled_latents.to(dtype=input_dtype, device=model_management.intermediate_device())
|
||||||
|
return_dict = samples.copy()
|
||||||
|
return_dict["samples"] = upsampled_latents
|
||||||
|
return_dict.pop("noise_mask", None)
|
||||||
|
return (return_dict,)
|
||||||
|
|
||||||
|
|
||||||
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
"LTXVLatentUpsampler": LTXVLatentUpsampler,
|
||||||
|
}
|
||||||
@ -10,7 +10,7 @@ class Mahiro(io.ComfyNode):
|
|||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="Mahiro",
|
node_id="Mahiro",
|
||||||
display_name="Mahiro is so cute that she deserves a better guidance function!! (。・ω・。)",
|
display_name="Mahiro CFG",
|
||||||
category="_for_testing",
|
category="_for_testing",
|
||||||
description="Modify the guidance to scale more on the 'direction' of the positive prompt rather than the difference between the negative prompt.",
|
description="Modify the guidance to scale more on the 'direction' of the positive prompt rather than the difference between the negative prompt.",
|
||||||
inputs=[
|
inputs=[
|
||||||
|
|||||||
@ -348,7 +348,7 @@ class ZImageControlPatch:
|
|||||||
if self.mask is None:
|
if self.mask is None:
|
||||||
mask_ = torch.zeros_like(inpaint_image_latent)[:, :1]
|
mask_ = torch.zeros_like(inpaint_image_latent)[:, :1]
|
||||||
else:
|
else:
|
||||||
mask_ = comfy.utils.common_upscale(self.mask.view(self.mask.shape[0], -1, self.mask.shape[-2], self.mask.shape[-1]).mean(dim=1, keepdim=True), inpaint_image_latent.shape[-1], inpaint_image_latent.shape[-2], "nearest", "center")
|
mask_ = comfy.utils.common_upscale(self.mask.view(self.mask.shape[0], -1, self.mask.shape[-2], self.mask.shape[-1]).mean(dim=1, keepdim=True).to(device=inpaint_image_latent.device), inpaint_image_latent.shape[-1], inpaint_image_latent.shape[-2], "nearest", "center")
|
||||||
|
|
||||||
if latent_image is None:
|
if latent_image is None:
|
||||||
latent_image = comfy.latent_formats.Flux().process_in(self.vae.encode(torch.ones_like(inpaint_image) * 0.5))
|
latent_image = comfy.latent_formats.Flux().process_in(self.vae.encode(torch.ones_like(inpaint_image) * 0.5))
|
||||||
|
|||||||
@ -4,11 +4,15 @@ import torch
|
|||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
import math
|
import math
|
||||||
|
from enum import Enum
|
||||||
|
from typing import TypedDict, Literal
|
||||||
|
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
|
from comfy_extras.nodes_latent import reshape_latent_to
|
||||||
import node_helpers
|
import node_helpers
|
||||||
from comfy_api.latest import ComfyExtension, io
|
from comfy_api.latest import ComfyExtension, io
|
||||||
|
from nodes import MAX_RESOLUTION
|
||||||
|
|
||||||
class Blend(io.ComfyNode):
|
class Blend(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -241,6 +245,353 @@ class ImageScaleToTotalPixels(io.ComfyNode):
|
|||||||
s = s.movedim(1,-1)
|
s = s.movedim(1,-1)
|
||||||
return io.NodeOutput(s)
|
return io.NodeOutput(s)
|
||||||
|
|
||||||
|
class ResizeType(str, Enum):
|
||||||
|
SCALE_BY = "scale by multiplier"
|
||||||
|
SCALE_DIMENSIONS = "scale dimensions"
|
||||||
|
SCALE_LONGER_DIMENSION = "scale longer dimension"
|
||||||
|
SCALE_SHORTER_DIMENSION = "scale shorter dimension"
|
||||||
|
SCALE_WIDTH = "scale width"
|
||||||
|
SCALE_HEIGHT = "scale height"
|
||||||
|
SCALE_TOTAL_PIXELS = "scale total pixels"
|
||||||
|
MATCH_SIZE = "match size"
|
||||||
|
|
||||||
|
def is_image(input: torch.Tensor) -> bool:
|
||||||
|
# images have 4 dimensions: [batch, height, width, channels]
|
||||||
|
# masks have 3 dimensions: [batch, height, width]
|
||||||
|
return len(input.shape) == 4
|
||||||
|
|
||||||
|
def init_image_mask_input(input: torch.Tensor, is_type_image: bool) -> torch.Tensor:
|
||||||
|
if is_type_image:
|
||||||
|
input = input.movedim(-1, 1)
|
||||||
|
else:
|
||||||
|
input = input.unsqueeze(1)
|
||||||
|
return input
|
||||||
|
|
||||||
|
def finalize_image_mask_input(input: torch.Tensor, is_type_image: bool) -> torch.Tensor:
|
||||||
|
if is_type_image:
|
||||||
|
input = input.movedim(1, -1)
|
||||||
|
else:
|
||||||
|
input = input.squeeze(1)
|
||||||
|
return input
|
||||||
|
|
||||||
|
def scale_by(input: torch.Tensor, multiplier: float, scale_method: str) -> torch.Tensor:
|
||||||
|
is_type_image = is_image(input)
|
||||||
|
input = init_image_mask_input(input, is_type_image)
|
||||||
|
width = round(input.shape[-1] * multiplier)
|
||||||
|
height = round(input.shape[-2] * multiplier)
|
||||||
|
|
||||||
|
input = comfy.utils.common_upscale(input, width, height, scale_method, "disabled")
|
||||||
|
input = finalize_image_mask_input(input, is_type_image)
|
||||||
|
return input
|
||||||
|
|
||||||
|
def scale_dimensions(input: torch.Tensor, width: int, height: int, scale_method: str, crop: str="disabled") -> torch.Tensor:
|
||||||
|
if width == 0 and height == 0:
|
||||||
|
return input
|
||||||
|
is_type_image = is_image(input)
|
||||||
|
input = init_image_mask_input(input, is_type_image)
|
||||||
|
|
||||||
|
if width == 0:
|
||||||
|
width = max(1, round(input.shape[-1] * height / input.shape[-2]))
|
||||||
|
elif height == 0:
|
||||||
|
height = max(1, round(input.shape[-2] * width / input.shape[-1]))
|
||||||
|
|
||||||
|
input = comfy.utils.common_upscale(input, width, height, scale_method, crop)
|
||||||
|
input = finalize_image_mask_input(input, is_type_image)
|
||||||
|
return input
|
||||||
|
|
||||||
|
def scale_longer_dimension(input: torch.Tensor, longer_size: int, scale_method: str) -> torch.Tensor:
|
||||||
|
is_type_image = is_image(input)
|
||||||
|
input = init_image_mask_input(input, is_type_image)
|
||||||
|
width = input.shape[-1]
|
||||||
|
height = input.shape[-2]
|
||||||
|
|
||||||
|
if height > width:
|
||||||
|
width = round((width / height) * longer_size)
|
||||||
|
height = longer_size
|
||||||
|
elif width > height:
|
||||||
|
height = round((height / width) * longer_size)
|
||||||
|
width = longer_size
|
||||||
|
else:
|
||||||
|
height = longer_size
|
||||||
|
width = longer_size
|
||||||
|
|
||||||
|
input = comfy.utils.common_upscale(input, width, height, scale_method, "disabled")
|
||||||
|
input = finalize_image_mask_input(input, is_type_image)
|
||||||
|
return input
|
||||||
|
|
||||||
|
def scale_shorter_dimension(input: torch.Tensor, shorter_size: int, scale_method: str) -> torch.Tensor:
|
||||||
|
is_type_image = is_image(input)
|
||||||
|
input = init_image_mask_input(input, is_type_image)
|
||||||
|
width = input.shape[-1]
|
||||||
|
height = input.shape[-2]
|
||||||
|
|
||||||
|
if height < width:
|
||||||
|
width = round((width / height) * shorter_size)
|
||||||
|
height = shorter_size
|
||||||
|
elif width > height:
|
||||||
|
height = round((height / width) * shorter_size)
|
||||||
|
width = shorter_size
|
||||||
|
else:
|
||||||
|
height = shorter_size
|
||||||
|
width = shorter_size
|
||||||
|
|
||||||
|
input = comfy.utils.common_upscale(input, width, height, scale_method, "disabled")
|
||||||
|
input = finalize_image_mask_input(input, is_type_image)
|
||||||
|
return input
|
||||||
|
|
||||||
|
def scale_total_pixels(input: torch.Tensor, megapixels: float, scale_method: str) -> torch.Tensor:
|
||||||
|
is_type_image = is_image(input)
|
||||||
|
input = init_image_mask_input(input, is_type_image)
|
||||||
|
total = int(megapixels * 1024 * 1024)
|
||||||
|
|
||||||
|
scale_by = math.sqrt(total / (input.shape[-1] * input.shape[-2]))
|
||||||
|
width = round(input.shape[-1] * scale_by)
|
||||||
|
height = round(input.shape[-2] * scale_by)
|
||||||
|
|
||||||
|
input = comfy.utils.common_upscale(input, width, height, scale_method, "disabled")
|
||||||
|
input = finalize_image_mask_input(input, is_type_image)
|
||||||
|
return input
|
||||||
|
|
||||||
|
def scale_match_size(input: torch.Tensor, match: torch.Tensor, scale_method: str, crop: str) -> torch.Tensor:
|
||||||
|
is_type_image = is_image(input)
|
||||||
|
input = init_image_mask_input(input, is_type_image)
|
||||||
|
match = init_image_mask_input(match, is_image(match))
|
||||||
|
|
||||||
|
width = match.shape[-1]
|
||||||
|
height = match.shape[-2]
|
||||||
|
input = comfy.utils.common_upscale(input, width, height, scale_method, crop)
|
||||||
|
input = finalize_image_mask_input(input, is_type_image)
|
||||||
|
return input
|
||||||
|
|
||||||
|
class ResizeImageMaskNode(io.ComfyNode):
|
||||||
|
|
||||||
|
scale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"]
|
||||||
|
crop_methods = ["disabled", "center"]
|
||||||
|
|
||||||
|
class ResizeTypedDict(TypedDict):
|
||||||
|
resize_type: ResizeType
|
||||||
|
scale_method: Literal["nearest-exact", "bilinear", "area", "bicubic", "lanczos"]
|
||||||
|
crop: Literal["disabled", "center"]
|
||||||
|
multiplier: float
|
||||||
|
width: int
|
||||||
|
height: int
|
||||||
|
longer_size: int
|
||||||
|
shorter_size: int
|
||||||
|
megapixels: float
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
template = io.MatchType.Template("input_type", [io.Image, io.Mask])
|
||||||
|
crop_combo = io.Combo.Input("crop", options=cls.crop_methods, default="center")
|
||||||
|
return io.Schema(
|
||||||
|
node_id="ResizeImageMaskNode",
|
||||||
|
display_name="Resize Image/Mask",
|
||||||
|
category="transform",
|
||||||
|
inputs=[
|
||||||
|
io.MatchType.Input("input", template=template),
|
||||||
|
io.DynamicCombo.Input("resize_type", options=[
|
||||||
|
io.DynamicCombo.Option(ResizeType.SCALE_BY, [
|
||||||
|
io.Float.Input("multiplier", default=1.00, min=0.01, max=8.0, step=0.01),
|
||||||
|
]),
|
||||||
|
io.DynamicCombo.Option(ResizeType.SCALE_DIMENSIONS, [
|
||||||
|
io.Int.Input("width", default=512, min=0, max=MAX_RESOLUTION, step=1),
|
||||||
|
io.Int.Input("height", default=512, min=0, max=MAX_RESOLUTION, step=1),
|
||||||
|
crop_combo,
|
||||||
|
]),
|
||||||
|
io.DynamicCombo.Option(ResizeType.SCALE_LONGER_DIMENSION, [
|
||||||
|
io.Int.Input("longer_size", default=512, min=0, max=MAX_RESOLUTION, step=1),
|
||||||
|
]),
|
||||||
|
io.DynamicCombo.Option(ResizeType.SCALE_SHORTER_DIMENSION, [
|
||||||
|
io.Int.Input("shorter_size", default=512, min=0, max=MAX_RESOLUTION, step=1),
|
||||||
|
]),
|
||||||
|
io.DynamicCombo.Option(ResizeType.SCALE_WIDTH, [
|
||||||
|
io.Int.Input("width", default=512, min=0, max=MAX_RESOLUTION, step=1),
|
||||||
|
]),
|
||||||
|
io.DynamicCombo.Option(ResizeType.SCALE_HEIGHT, [
|
||||||
|
io.Int.Input("height", default=512, min=0, max=MAX_RESOLUTION, step=1),
|
||||||
|
]),
|
||||||
|
io.DynamicCombo.Option(ResizeType.SCALE_TOTAL_PIXELS, [
|
||||||
|
io.Float.Input("megapixels", default=1.0, min=0.01, max=16.0, step=0.01),
|
||||||
|
]),
|
||||||
|
io.DynamicCombo.Option(ResizeType.MATCH_SIZE, [
|
||||||
|
io.MultiType.Input("match", [io.Image, io.Mask]),
|
||||||
|
crop_combo,
|
||||||
|
]),
|
||||||
|
]),
|
||||||
|
io.Combo.Input("scale_method", options=cls.scale_methods, default="area"),
|
||||||
|
],
|
||||||
|
outputs=[io.MatchType.Output(template=template, display_name="resized")]
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, input: io.Image.Type | io.Mask.Type, scale_method: io.Combo.Type, resize_type: ResizeTypedDict) -> io.NodeOutput:
|
||||||
|
selected_type = resize_type["resize_type"]
|
||||||
|
if selected_type == ResizeType.SCALE_BY:
|
||||||
|
return io.NodeOutput(scale_by(input, resize_type["multiplier"], scale_method))
|
||||||
|
elif selected_type == ResizeType.SCALE_DIMENSIONS:
|
||||||
|
return io.NodeOutput(scale_dimensions(input, resize_type["width"], resize_type["height"], scale_method, resize_type["crop"]))
|
||||||
|
elif selected_type == ResizeType.SCALE_LONGER_DIMENSION:
|
||||||
|
return io.NodeOutput(scale_longer_dimension(input, resize_type["longer_size"], scale_method))
|
||||||
|
elif selected_type == ResizeType.SCALE_SHORTER_DIMENSION:
|
||||||
|
return io.NodeOutput(scale_shorter_dimension(input, resize_type["shorter_size"], scale_method))
|
||||||
|
elif selected_type == ResizeType.SCALE_WIDTH:
|
||||||
|
return io.NodeOutput(scale_dimensions(input, resize_type["width"], 0, scale_method))
|
||||||
|
elif selected_type == ResizeType.SCALE_HEIGHT:
|
||||||
|
return io.NodeOutput(scale_dimensions(input, 0, resize_type["height"], scale_method))
|
||||||
|
elif selected_type == ResizeType.SCALE_TOTAL_PIXELS:
|
||||||
|
return io.NodeOutput(scale_total_pixels(input, resize_type["megapixels"], scale_method))
|
||||||
|
elif selected_type == ResizeType.MATCH_SIZE:
|
||||||
|
return io.NodeOutput(scale_match_size(input, resize_type["match"], scale_method, resize_type["crop"]))
|
||||||
|
raise ValueError(f"Unsupported resize type: {selected_type}")
|
||||||
|
|
||||||
|
def batch_images(images: list[torch.Tensor]) -> torch.Tensor | None:
|
||||||
|
if len(images) == 0:
|
||||||
|
return None
|
||||||
|
# first, get the max channels count
|
||||||
|
max_channels = max(image.shape[-1] for image in images)
|
||||||
|
# then, pad all images to have the same channels count
|
||||||
|
padded_images: list[torch.Tensor] = []
|
||||||
|
for image in images:
|
||||||
|
if image.shape[-1] < max_channels:
|
||||||
|
padded_images.append(torch.nn.functional.pad(image, (0,1), mode='constant', value=1.0))
|
||||||
|
else:
|
||||||
|
padded_images.append(image)
|
||||||
|
# resize all images to be the same size as the first image
|
||||||
|
resized_images: list[torch.Tensor] = []
|
||||||
|
first_image_shape = padded_images[0].shape
|
||||||
|
for image in padded_images:
|
||||||
|
if image.shape[1:] != first_image_shape[1:]:
|
||||||
|
resized_images.append(comfy.utils.common_upscale(image.movedim(-1,1), first_image_shape[2], first_image_shape[1], "bilinear", "center").movedim(1,-1))
|
||||||
|
else:
|
||||||
|
resized_images.append(image)
|
||||||
|
# batch the images in the format [b, h, w, c]
|
||||||
|
return torch.cat(resized_images, dim=0)
|
||||||
|
|
||||||
|
def batch_masks(masks: list[torch.Tensor]) -> torch.Tensor | None:
|
||||||
|
if len(masks) == 0:
|
||||||
|
return None
|
||||||
|
# resize all masks to be the same size as the first mask
|
||||||
|
resized_masks: list[torch.Tensor] = []
|
||||||
|
first_mask_shape = masks[0].shape
|
||||||
|
for mask in masks:
|
||||||
|
if mask.shape[1:] != first_mask_shape[1:]:
|
||||||
|
mask = init_image_mask_input(mask, is_type_image=False)
|
||||||
|
mask = comfy.utils.common_upscale(mask, first_mask_shape[2], first_mask_shape[1], "bilinear", "center")
|
||||||
|
resized_masks.append(finalize_image_mask_input(mask, is_type_image=False))
|
||||||
|
else:
|
||||||
|
resized_masks.append(mask)
|
||||||
|
# batch the masks in the format [b, h, w]
|
||||||
|
return torch.cat(resized_masks, dim=0)
|
||||||
|
|
||||||
|
def batch_latents(latents: list[dict[str, torch.Tensor]]) -> dict[str, torch.Tensor] | None:
|
||||||
|
if len(latents) == 0:
|
||||||
|
return None
|
||||||
|
samples_out = latents[0].copy()
|
||||||
|
samples_out["batch_index"] = []
|
||||||
|
first_samples = latents[0]["samples"]
|
||||||
|
tensors: list[torch.Tensor] = []
|
||||||
|
for latent in latents:
|
||||||
|
# first, deal with latent tensors
|
||||||
|
tensors.append(reshape_latent_to(first_samples.shape, latent["samples"], repeat_batch=False))
|
||||||
|
# next, deal with batch_index
|
||||||
|
samples_out["batch_index"].extend(latent.get("batch_index", [x for x in range(0, latent["samples"].shape[0])]))
|
||||||
|
samples_out["samples"] = torch.cat(tensors, dim=0)
|
||||||
|
return samples_out
|
||||||
|
|
||||||
|
class BatchImagesNode(io.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
autogrow_template = io.Autogrow.TemplatePrefix(io.Image.Input("image"), prefix="image", min=2, max=50)
|
||||||
|
return io.Schema(
|
||||||
|
node_id="BatchImagesNode",
|
||||||
|
display_name="Batch Images",
|
||||||
|
category="image",
|
||||||
|
inputs=[
|
||||||
|
io.Autogrow.Input("images", template=autogrow_template)
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Image.Output()
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, images: io.Autogrow.Type) -> io.NodeOutput:
|
||||||
|
return io.NodeOutput(batch_images(list(images.values())))
|
||||||
|
|
||||||
|
class BatchMasksNode(io.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
autogrow_template = io.Autogrow.TemplatePrefix(io.Mask.Input("mask"), prefix="mask", min=2, max=50)
|
||||||
|
return io.Schema(
|
||||||
|
node_id="BatchMasksNode",
|
||||||
|
display_name="Batch Masks",
|
||||||
|
category="mask",
|
||||||
|
inputs=[
|
||||||
|
io.Autogrow.Input("masks", template=autogrow_template)
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Mask.Output()
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, masks: io.Autogrow.Type) -> io.NodeOutput:
|
||||||
|
return io.NodeOutput(batch_masks(list(masks.values())))
|
||||||
|
|
||||||
|
class BatchLatentsNode(io.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
autogrow_template = io.Autogrow.TemplatePrefix(io.Latent.Input("latent"), prefix="latent", min=2, max=50)
|
||||||
|
return io.Schema(
|
||||||
|
node_id="BatchLatentsNode",
|
||||||
|
display_name="Batch Latents",
|
||||||
|
category="latent",
|
||||||
|
inputs=[
|
||||||
|
io.Autogrow.Input("latents", template=autogrow_template)
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Latent.Output()
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, latents: io.Autogrow.Type) -> io.NodeOutput:
|
||||||
|
return io.NodeOutput(batch_latents(list(latents.values())))
|
||||||
|
|
||||||
|
class BatchImagesMasksLatentsNode(io.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
matchtype_template = io.MatchType.Template("input", allowed_types=[io.Image, io.Mask, io.Latent])
|
||||||
|
autogrow_template = io.Autogrow.TemplatePrefix(
|
||||||
|
io.MatchType.Input("input", matchtype_template),
|
||||||
|
prefix="input", min=1, max=50)
|
||||||
|
return io.Schema(
|
||||||
|
node_id="BatchImagesMasksLatentsNode",
|
||||||
|
display_name="Batch Images/Masks/Latents",
|
||||||
|
category="util",
|
||||||
|
inputs=[
|
||||||
|
io.Autogrow.Input("inputs", template=autogrow_template)
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.MatchType.Output(id=None, template=matchtype_template)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, inputs: io.Autogrow.Type) -> io.NodeOutput:
|
||||||
|
batched = None
|
||||||
|
values = list(inputs.values())
|
||||||
|
# latents
|
||||||
|
if isinstance(values[0], dict):
|
||||||
|
batched = batch_latents(values)
|
||||||
|
# images
|
||||||
|
elif is_image(values[0]):
|
||||||
|
batched = batch_images(values)
|
||||||
|
# masks
|
||||||
|
else:
|
||||||
|
batched = batch_masks(values)
|
||||||
|
return io.NodeOutput(batched)
|
||||||
|
|
||||||
class PostProcessingExtension(ComfyExtension):
|
class PostProcessingExtension(ComfyExtension):
|
||||||
@override
|
@override
|
||||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
@ -250,6 +601,11 @@ class PostProcessingExtension(ComfyExtension):
|
|||||||
Quantize,
|
Quantize,
|
||||||
Sharpen,
|
Sharpen,
|
||||||
ImageScaleToTotalPixels,
|
ImageScaleToTotalPixels,
|
||||||
|
ResizeImageMaskNode,
|
||||||
|
BatchImagesNode,
|
||||||
|
BatchMasksNode,
|
||||||
|
BatchLatentsNode,
|
||||||
|
# BatchImagesMasksLatentsNode,
|
||||||
]
|
]
|
||||||
|
|
||||||
async def comfy_entrypoint() -> PostProcessingExtension:
|
async def comfy_entrypoint() -> PostProcessingExtension:
|
||||||
|
|||||||
@ -66,7 +66,7 @@ class Float(io.ComfyNode):
|
|||||||
display_name="Float",
|
display_name="Float",
|
||||||
category="utils/primitive",
|
category="utils/primitive",
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Float.Input("value", min=-sys.maxsize, max=sys.maxsize),
|
io.Float.Input("value", min=-sys.maxsize, max=sys.maxsize, step=0.1),
|
||||||
],
|
],
|
||||||
outputs=[io.Float.Output()],
|
outputs=[io.Float.Output()],
|
||||||
)
|
)
|
||||||
|
|||||||
@ -3,7 +3,9 @@ import comfy.utils
|
|||||||
import math
|
import math
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
from comfy_api.latest import ComfyExtension, io
|
from comfy_api.latest import ComfyExtension, io
|
||||||
|
import comfy.model_management
|
||||||
|
import torch
|
||||||
|
import nodes
|
||||||
|
|
||||||
class TextEncodeQwenImageEdit(io.ComfyNode):
|
class TextEncodeQwenImageEdit(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -104,12 +106,37 @@ class TextEncodeQwenImageEditPlus(io.ComfyNode):
|
|||||||
return io.NodeOutput(conditioning)
|
return io.NodeOutput(conditioning)
|
||||||
|
|
||||||
|
|
||||||
|
class EmptyQwenImageLayeredLatentImage(io.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return io.Schema(
|
||||||
|
node_id="EmptyQwenImageLayeredLatentImage",
|
||||||
|
display_name="Empty Qwen Image Layered Latent",
|
||||||
|
category="latent/qwen",
|
||||||
|
inputs=[
|
||||||
|
io.Int.Input("width", default=640, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||||
|
io.Int.Input("height", default=640, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||||
|
io.Int.Input("layers", default=3, min=0, max=nodes.MAX_RESOLUTION, step=1),
|
||||||
|
io.Int.Input("batch_size", default=1, min=1, max=4096),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Latent.Output(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, width, height, layers, batch_size=1) -> io.NodeOutput:
|
||||||
|
latent = torch.zeros([batch_size, 16, layers + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
||||||
|
return io.NodeOutput({"samples": latent})
|
||||||
|
|
||||||
|
|
||||||
class QwenExtension(ComfyExtension):
|
class QwenExtension(ComfyExtension):
|
||||||
@override
|
@override
|
||||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
return [
|
return [
|
||||||
TextEncodeQwenImageEdit,
|
TextEncodeQwenImageEdit,
|
||||||
TextEncodeQwenImageEditPlus,
|
TextEncodeQwenImageEditPlus,
|
||||||
|
EmptyQwenImageLayeredLatentImage,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -78,18 +78,20 @@ class ImageUpscaleWithModel(io.ComfyNode):
|
|||||||
overlap = 32
|
overlap = 32
|
||||||
|
|
||||||
oom = True
|
oom = True
|
||||||
while oom:
|
try:
|
||||||
try:
|
while oom:
|
||||||
steps = in_img.shape[0] * comfy.utils.get_tiled_scale_steps(in_img.shape[3], in_img.shape[2], tile_x=tile, tile_y=tile, overlap=overlap)
|
try:
|
||||||
pbar = comfy.utils.ProgressBar(steps)
|
steps = in_img.shape[0] * comfy.utils.get_tiled_scale_steps(in_img.shape[3], in_img.shape[2], tile_x=tile, tile_y=tile, overlap=overlap)
|
||||||
s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar)
|
pbar = comfy.utils.ProgressBar(steps)
|
||||||
oom = False
|
s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar)
|
||||||
except model_management.OOM_EXCEPTION as e:
|
oom = False
|
||||||
tile //= 2
|
except model_management.OOM_EXCEPTION as e:
|
||||||
if tile < 128:
|
tile //= 2
|
||||||
raise e
|
if tile < 128:
|
||||||
|
raise e
|
||||||
|
finally:
|
||||||
|
upscale_model.to("cpu")
|
||||||
|
|
||||||
upscale_model.to("cpu")
|
|
||||||
s = torch.clamp(s.movedim(-3,-1), min=0, max=1.0)
|
s = torch.clamp(s.movedim(-3,-1), min=0, max=1.0)
|
||||||
return io.NodeOutput(s)
|
return io.NodeOutput(s)
|
||||||
|
|
||||||
|
|||||||
@ -817,7 +817,7 @@ def get_sample_indices(original_fps,
|
|||||||
if required_duration > total_frames / original_fps:
|
if required_duration > total_frames / original_fps:
|
||||||
raise ValueError("required_duration must be less than video length")
|
raise ValueError("required_duration must be less than video length")
|
||||||
|
|
||||||
if not fixed_start is None and fixed_start >= 0:
|
if fixed_start is not None and fixed_start >= 0:
|
||||||
start_frame = fixed_start
|
start_frame = fixed_start
|
||||||
else:
|
else:
|
||||||
max_start = total_frames - required_origin_frames
|
max_start = total_frames - required_origin_frames
|
||||||
|
|||||||
@ -1,3 +1,3 @@
|
|||||||
# This file is automatically generated by the build process when version is
|
# This file is automatically generated by the build process when version is
|
||||||
# updated in pyproject.toml.
|
# updated in pyproject.toml.
|
||||||
__version__ = "0.5.1"
|
__version__ = "0.7.0"
|
||||||
|
|||||||
46
execution.py
46
execution.py
@ -79,7 +79,7 @@ class IsChangedCache:
|
|||||||
# Intentionally do not use cached outputs here. We only want constants in IS_CHANGED
|
# Intentionally do not use cached outputs here. We only want constants in IS_CHANGED
|
||||||
input_data_all, _, v3_data = get_input_data(node["inputs"], class_def, node_id, None)
|
input_data_all, _, v3_data = get_input_data(node["inputs"], class_def, node_id, None)
|
||||||
try:
|
try:
|
||||||
is_changed = await _async_map_node_over_list(self.prompt_id, node_id, class_def, input_data_all, is_changed_name)
|
is_changed = await _async_map_node_over_list(self.prompt_id, node_id, class_def, input_data_all, is_changed_name, v3_data=v3_data)
|
||||||
is_changed = await resolve_map_node_over_list_results(is_changed)
|
is_changed = await resolve_map_node_over_list_results(is_changed)
|
||||||
node["is_changed"] = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed]
|
node["is_changed"] = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed]
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -148,13 +148,12 @@ SENSITIVE_EXTRA_DATA_KEYS = ("auth_token_comfy_org", "api_key_comfy_org")
|
|||||||
def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=None, extra_data={}):
|
def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=None, extra_data={}):
|
||||||
is_v3 = issubclass(class_def, _ComfyNodeInternal)
|
is_v3 = issubclass(class_def, _ComfyNodeInternal)
|
||||||
v3_data: io.V3Data = {}
|
v3_data: io.V3Data = {}
|
||||||
|
hidden_inputs_v3 = {}
|
||||||
|
valid_inputs = class_def.INPUT_TYPES()
|
||||||
if is_v3:
|
if is_v3:
|
||||||
valid_inputs, schema, v3_data = class_def.INPUT_TYPES(include_hidden=False, return_schema=True, live_inputs=inputs)
|
valid_inputs, hidden, v3_data = _io.get_finalized_class_inputs(valid_inputs, inputs)
|
||||||
else:
|
|
||||||
valid_inputs = class_def.INPUT_TYPES()
|
|
||||||
input_data_all = {}
|
input_data_all = {}
|
||||||
missing_keys = {}
|
missing_keys = {}
|
||||||
hidden_inputs_v3 = {}
|
|
||||||
for x in inputs:
|
for x in inputs:
|
||||||
input_data = inputs[x]
|
input_data = inputs[x]
|
||||||
_, input_category, input_info = get_input_info(class_def, x, valid_inputs)
|
_, input_category, input_info = get_input_info(class_def, x, valid_inputs)
|
||||||
@ -180,18 +179,18 @@ def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=
|
|||||||
input_data_all[x] = [input_data]
|
input_data_all[x] = [input_data]
|
||||||
|
|
||||||
if is_v3:
|
if is_v3:
|
||||||
if schema.hidden:
|
if hidden is not None:
|
||||||
if io.Hidden.prompt in schema.hidden:
|
if io.Hidden.prompt.name in hidden:
|
||||||
hidden_inputs_v3[io.Hidden.prompt] = dynprompt.get_original_prompt() if dynprompt is not None else {}
|
hidden_inputs_v3[io.Hidden.prompt] = dynprompt.get_original_prompt() if dynprompt is not None else {}
|
||||||
if io.Hidden.dynprompt in schema.hidden:
|
if io.Hidden.dynprompt.name in hidden:
|
||||||
hidden_inputs_v3[io.Hidden.dynprompt] = dynprompt
|
hidden_inputs_v3[io.Hidden.dynprompt] = dynprompt
|
||||||
if io.Hidden.extra_pnginfo in schema.hidden:
|
if io.Hidden.extra_pnginfo.name in hidden:
|
||||||
hidden_inputs_v3[io.Hidden.extra_pnginfo] = extra_data.get('extra_pnginfo', None)
|
hidden_inputs_v3[io.Hidden.extra_pnginfo] = extra_data.get('extra_pnginfo', None)
|
||||||
if io.Hidden.unique_id in schema.hidden:
|
if io.Hidden.unique_id.name in hidden:
|
||||||
hidden_inputs_v3[io.Hidden.unique_id] = unique_id
|
hidden_inputs_v3[io.Hidden.unique_id] = unique_id
|
||||||
if io.Hidden.auth_token_comfy_org in schema.hidden:
|
if io.Hidden.auth_token_comfy_org.name in hidden:
|
||||||
hidden_inputs_v3[io.Hidden.auth_token_comfy_org] = extra_data.get("auth_token_comfy_org", None)
|
hidden_inputs_v3[io.Hidden.auth_token_comfy_org] = extra_data.get("auth_token_comfy_org", None)
|
||||||
if io.Hidden.api_key_comfy_org in schema.hidden:
|
if io.Hidden.api_key_comfy_org.name in hidden:
|
||||||
hidden_inputs_v3[io.Hidden.api_key_comfy_org] = extra_data.get("api_key_comfy_org", None)
|
hidden_inputs_v3[io.Hidden.api_key_comfy_org] = extra_data.get("api_key_comfy_org", None)
|
||||||
else:
|
else:
|
||||||
if "hidden" in valid_inputs:
|
if "hidden" in valid_inputs:
|
||||||
@ -258,7 +257,7 @@ async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, f
|
|||||||
pre_execute_cb(index)
|
pre_execute_cb(index)
|
||||||
# V3
|
# V3
|
||||||
if isinstance(obj, _ComfyNodeInternal) or (is_class(obj) and issubclass(obj, _ComfyNodeInternal)):
|
if isinstance(obj, _ComfyNodeInternal) or (is_class(obj) and issubclass(obj, _ComfyNodeInternal)):
|
||||||
# if is just a class, then assign no resources or state, just create clone
|
# if is just a class, then assign no state, just create clone
|
||||||
if is_class(obj):
|
if is_class(obj):
|
||||||
type_obj = obj
|
type_obj = obj
|
||||||
obj.VALIDATE_CLASS()
|
obj.VALIDATE_CLASS()
|
||||||
@ -481,7 +480,10 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
|||||||
else:
|
else:
|
||||||
lazy_status_present = getattr(obj, "check_lazy_status", None) is not None
|
lazy_status_present = getattr(obj, "check_lazy_status", None) is not None
|
||||||
if lazy_status_present:
|
if lazy_status_present:
|
||||||
required_inputs = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, "check_lazy_status", allow_interrupt=True, v3_data=v3_data)
|
# for check_lazy_status, the returned data should include the original key of the input
|
||||||
|
v3_data_lazy = v3_data.copy()
|
||||||
|
v3_data_lazy["create_dynamic_tuple"] = True
|
||||||
|
required_inputs = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, "check_lazy_status", allow_interrupt=True, v3_data=v3_data_lazy)
|
||||||
required_inputs = await resolve_map_node_over_list_results(required_inputs)
|
required_inputs = await resolve_map_node_over_list_results(required_inputs)
|
||||||
required_inputs = set(sum([r for r in required_inputs if isinstance(r,list)], []))
|
required_inputs = set(sum([r for r in required_inputs if isinstance(r,list)], []))
|
||||||
required_inputs = [x for x in required_inputs if isinstance(x,str) and (
|
required_inputs = [x for x in required_inputs if isinstance(x,str) and (
|
||||||
@ -599,6 +601,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
|||||||
|
|
||||||
if isinstance(ex, comfy.model_management.OOM_EXCEPTION):
|
if isinstance(ex, comfy.model_management.OOM_EXCEPTION):
|
||||||
tips = "This error means you ran out of memory on your GPU.\n\nTIPS: If the workflow worked before you might have accidentally set the batch_size to a large number."
|
tips = "This error means you ran out of memory on your GPU.\n\nTIPS: If the workflow worked before you might have accidentally set the batch_size to a large number."
|
||||||
|
logging.info("Memory summary: {}".format(comfy.model_management.debug_memory_summary()))
|
||||||
logging.error("Got an OOM, unloading all loaded models.")
|
logging.error("Got an OOM, unloading all loaded models.")
|
||||||
comfy.model_management.unload_all_models()
|
comfy.model_management.unload_all_models()
|
||||||
|
|
||||||
@ -756,10 +759,13 @@ async def validate_inputs(prompt_id, prompt, item, validated):
|
|||||||
errors = []
|
errors = []
|
||||||
valid = True
|
valid = True
|
||||||
|
|
||||||
|
v3_data = None
|
||||||
validate_function_inputs = []
|
validate_function_inputs = []
|
||||||
validate_has_kwargs = False
|
validate_has_kwargs = False
|
||||||
if issubclass(obj_class, _ComfyNodeInternal):
|
if issubclass(obj_class, _ComfyNodeInternal):
|
||||||
class_inputs, _, _ = obj_class.INPUT_TYPES(include_hidden=False, return_schema=True, live_inputs=inputs)
|
obj_class: _io._ComfyNodeBaseInternal
|
||||||
|
class_inputs = obj_class.INPUT_TYPES()
|
||||||
|
class_inputs, _, v3_data = _io.get_finalized_class_inputs(class_inputs, inputs)
|
||||||
validate_function_name = "validate_inputs"
|
validate_function_name = "validate_inputs"
|
||||||
validate_function = first_real_override(obj_class, validate_function_name)
|
validate_function = first_real_override(obj_class, validate_function_name)
|
||||||
else:
|
else:
|
||||||
@ -779,10 +785,11 @@ async def validate_inputs(prompt_id, prompt, item, validated):
|
|||||||
assert extra_info is not None
|
assert extra_info is not None
|
||||||
if x not in inputs:
|
if x not in inputs:
|
||||||
if input_category == "required":
|
if input_category == "required":
|
||||||
|
details = f"{x}" if not v3_data else x.split(".")[-1]
|
||||||
error = {
|
error = {
|
||||||
"type": "required_input_missing",
|
"type": "required_input_missing",
|
||||||
"message": "Required input is missing",
|
"message": "Required input is missing",
|
||||||
"details": f"{x}",
|
"details": details,
|
||||||
"extra_info": {
|
"extra_info": {
|
||||||
"input_name": x
|
"input_name": x
|
||||||
}
|
}
|
||||||
@ -916,8 +923,11 @@ async def validate_inputs(prompt_id, prompt, item, validated):
|
|||||||
errors.append(error)
|
errors.append(error)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if isinstance(input_type, list):
|
if isinstance(input_type, list) or input_type == io.Combo.io_type:
|
||||||
combo_options = input_type
|
if input_type == io.Combo.io_type:
|
||||||
|
combo_options = extra_info.get("options", [])
|
||||||
|
else:
|
||||||
|
combo_options = input_type
|
||||||
if val not in combo_options:
|
if val not in combo_options:
|
||||||
input_config = info
|
input_config = info
|
||||||
list_info = ""
|
list_info = ""
|
||||||
|
|||||||
@ -1 +1 @@
|
|||||||
comfyui_manager==4.0.3b7
|
comfyui_manager==4.0.4
|
||||||
|
|||||||
26
nodes.py
26
nodes.py
@ -295,7 +295,11 @@ class VAEDecode:
|
|||||||
DESCRIPTION = "Decodes latent images back into pixel space images."
|
DESCRIPTION = "Decodes latent images back into pixel space images."
|
||||||
|
|
||||||
def decode(self, vae, samples):
|
def decode(self, vae, samples):
|
||||||
images = vae.decode(samples["samples"])
|
latent = samples["samples"]
|
||||||
|
if latent.is_nested:
|
||||||
|
latent = latent.unbind()[0]
|
||||||
|
|
||||||
|
images = vae.decode(latent)
|
||||||
if len(images.shape) == 5: #Combine batches
|
if len(images.shape) == 5: #Combine batches
|
||||||
images = images.reshape(-1, images.shape[-3], images.shape[-2], images.shape[-1])
|
images = images.reshape(-1, images.shape[-3], images.shape[-2], images.shape[-1])
|
||||||
return (images, )
|
return (images, )
|
||||||
@ -970,7 +974,7 @@ class DualCLIPLoader:
|
|||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ),
|
return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ),
|
||||||
"clip_name2": (folder_paths.get_filename_list("text_encoders"), ),
|
"clip_name2": (folder_paths.get_filename_list("text_encoders"), ),
|
||||||
"type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream", "hunyuan_image", "hunyuan_video_15", "kandinsky5", "kandinsky5_image"], ),
|
"type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream", "hunyuan_image", "hunyuan_video_15", "kandinsky5", "kandinsky5_image", "ltxv", "newbie"], ),
|
||||||
},
|
},
|
||||||
"optional": {
|
"optional": {
|
||||||
"device": (["default", "cpu"], {"advanced": True}),
|
"device": (["default", "cpu"], {"advanced": True}),
|
||||||
@ -980,7 +984,7 @@ class DualCLIPLoader:
|
|||||||
|
|
||||||
CATEGORY = "advanced/loaders"
|
CATEGORY = "advanced/loaders"
|
||||||
|
|
||||||
DESCRIPTION = "[Recipes]\n\nsdxl: clip-l, clip-g\nsd3: clip-l, clip-g / clip-l, t5 / clip-g, t5\nflux: clip-l, t5\nhidream: at least one of t5 or llama, recommended t5 and llama\nhunyuan_image: qwen2.5vl 7b and byt5 small"
|
DESCRIPTION = "[Recipes]\n\nsdxl: clip-l, clip-g\nsd3: clip-l, clip-g / clip-l, t5 / clip-g, t5\nflux: clip-l, t5\nhidream: at least one of t5 or llama, recommended t5 and llama\nhunyuan_image: qwen2.5vl 7b and byt5 small\nnewbie: gemma-3-4b-it, jina clip v2"
|
||||||
|
|
||||||
def load_clip(self, clip_name1, clip_name2, type, device="default"):
|
def load_clip(self, clip_name1, clip_name2, type, device="default"):
|
||||||
clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION)
|
clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION)
|
||||||
@ -1663,8 +1667,6 @@ class LoadImage:
|
|||||||
output_masks = []
|
output_masks = []
|
||||||
w, h = None, None
|
w, h = None, None
|
||||||
|
|
||||||
excluded_formats = ['MPO']
|
|
||||||
|
|
||||||
for i in ImageSequence.Iterator(img):
|
for i in ImageSequence.Iterator(img):
|
||||||
i = node_helpers.pillow(ImageOps.exif_transpose, i)
|
i = node_helpers.pillow(ImageOps.exif_transpose, i)
|
||||||
|
|
||||||
@ -1692,7 +1694,10 @@ class LoadImage:
|
|||||||
output_images.append(image)
|
output_images.append(image)
|
||||||
output_masks.append(mask.unsqueeze(0))
|
output_masks.append(mask.unsqueeze(0))
|
||||||
|
|
||||||
if len(output_images) > 1 and img.format not in excluded_formats:
|
if img.format == "MPO":
|
||||||
|
break # ignore all frames except the first one for MPO format
|
||||||
|
|
||||||
|
if len(output_images) > 1:
|
||||||
output_image = torch.cat(output_images, dim=0)
|
output_image = torch.cat(output_images, dim=0)
|
||||||
output_mask = torch.cat(output_masks, dim=0)
|
output_mask = torch.cat(output_masks, dim=0)
|
||||||
else:
|
else:
|
||||||
@ -1863,6 +1868,7 @@ class ImageBatch:
|
|||||||
FUNCTION = "batch"
|
FUNCTION = "batch"
|
||||||
|
|
||||||
CATEGORY = "image"
|
CATEGORY = "image"
|
||||||
|
DEPRECATED = True
|
||||||
|
|
||||||
def batch(self, image1, image2):
|
def batch(self, image1, image2):
|
||||||
if image1.shape[-1] != image2.shape[-1]:
|
if image1.shape[-1] != image2.shape[-1]:
|
||||||
@ -2241,8 +2247,10 @@ async def init_external_custom_nodes():
|
|||||||
|
|
||||||
for possible_module in possible_modules:
|
for possible_module in possible_modules:
|
||||||
module_path = os.path.join(custom_node_path, possible_module)
|
module_path = os.path.join(custom_node_path, possible_module)
|
||||||
if os.path.isfile(module_path) and os.path.splitext(module_path)[1] != ".py": continue
|
if os.path.isfile(module_path) and os.path.splitext(module_path)[1] != ".py":
|
||||||
if module_path.endswith(".disabled"): continue
|
continue
|
||||||
|
if module_path.endswith(".disabled"):
|
||||||
|
continue
|
||||||
if args.disable_all_custom_nodes and possible_module not in args.whitelist_custom_nodes:
|
if args.disable_all_custom_nodes and possible_module not in args.whitelist_custom_nodes:
|
||||||
logging.info(f"Skipping {possible_module} due to disable_all_custom_nodes and whitelist_custom_nodes")
|
logging.info(f"Skipping {possible_module} due to disable_all_custom_nodes and whitelist_custom_nodes")
|
||||||
continue
|
continue
|
||||||
@ -2327,6 +2335,8 @@ async def init_builtin_extra_nodes():
|
|||||||
"nodes_mochi.py",
|
"nodes_mochi.py",
|
||||||
"nodes_slg.py",
|
"nodes_slg.py",
|
||||||
"nodes_mahiro.py",
|
"nodes_mahiro.py",
|
||||||
|
"nodes_lt_upsampler.py",
|
||||||
|
"nodes_lt_audio.py",
|
||||||
"nodes_lt.py",
|
"nodes_lt.py",
|
||||||
"nodes_hooks.py",
|
"nodes_hooks.py",
|
||||||
"nodes_load_3d.py",
|
"nodes_load_3d.py",
|
||||||
|
|||||||
@ -1,9 +1,9 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "ComfyUI"
|
name = "ComfyUI"
|
||||||
version = "0.5.1"
|
version = "0.7.0"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
license = { file = "LICENSE" }
|
license = { file = "LICENSE" }
|
||||||
requires-python = ">=3.9"
|
requires-python = ">=3.10"
|
||||||
|
|
||||||
[project.urls]
|
[project.urls]
|
||||||
homepage = "https://www.comfy.org/"
|
homepage = "https://www.comfy.org/"
|
||||||
@ -15,12 +15,16 @@ lint.select = [
|
|||||||
"N805", # invalid-first-argument-name-for-method
|
"N805", # invalid-first-argument-name-for-method
|
||||||
"S307", # suspicious-eval-usage
|
"S307", # suspicious-eval-usage
|
||||||
"S102", # exec
|
"S102", # exec
|
||||||
|
"E",
|
||||||
"T", # print-usage
|
"T", # print-usage
|
||||||
"W",
|
"W",
|
||||||
# The "F" series in Ruff stands for "Pyflakes" rules, which catch various Python syntax errors and undefined names.
|
# The "F" series in Ruff stands for "Pyflakes" rules, which catch various Python syntax errors and undefined names.
|
||||||
# See all rules here: https://docs.astral.sh/ruff/rules/#pyflakes-f
|
# See all rules here: https://docs.astral.sh/ruff/rules/#pyflakes-f
|
||||||
"F",
|
"F",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
lint.ignore = ["E501", "E722", "E731", "E712", "E402", "E741"]
|
||||||
|
|
||||||
exclude = ["*.ipynb", "**/generated/*.pyi"]
|
exclude = ["*.ipynb", "**/generated/*.pyi"]
|
||||||
|
|
||||||
[tool.pylint]
|
[tool.pylint]
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
comfyui-frontend-package==1.34.9
|
comfyui-frontend-package==1.35.9
|
||||||
comfyui-workflow-templates==0.7.60
|
comfyui-workflow-templates==0.7.65
|
||||||
comfyui-embedded-docs==0.3.1
|
comfyui-embedded-docs==0.3.1
|
||||||
torch
|
torch
|
||||||
torchsde
|
torchsde
|
||||||
|
|||||||
@ -324,7 +324,7 @@ class PromptServer():
|
|||||||
@routes.get("/models/{folder}")
|
@routes.get("/models/{folder}")
|
||||||
async def get_models(request):
|
async def get_models(request):
|
||||||
folder = request.match_info.get("folder", None)
|
folder = request.match_info.get("folder", None)
|
||||||
if not folder in folder_paths.folder_names_and_paths:
|
if folder not in folder_paths.folder_names_and_paths:
|
||||||
return web.Response(status=404)
|
return web.Response(status=404)
|
||||||
files = folder_paths.get_filename_list(folder)
|
files = folder_paths.get_filename_list(folder)
|
||||||
return web.json_response(files)
|
return web.json_response(files)
|
||||||
@ -579,7 +579,7 @@ class PromptServer():
|
|||||||
folder_name = request.match_info.get("folder_name", None)
|
folder_name = request.match_info.get("folder_name", None)
|
||||||
if folder_name is None:
|
if folder_name is None:
|
||||||
return web.Response(status=404)
|
return web.Response(status=404)
|
||||||
if not "filename" in request.rel_url.query:
|
if "filename" not in request.rel_url.query:
|
||||||
return web.Response(status=404)
|
return web.Response(status=404)
|
||||||
|
|
||||||
filename = request.rel_url.query["filename"]
|
filename = request.rel_url.query["filename"]
|
||||||
@ -593,7 +593,7 @@ class PromptServer():
|
|||||||
if out is None:
|
if out is None:
|
||||||
return web.Response(status=404)
|
return web.Response(status=404)
|
||||||
dt = json.loads(out)
|
dt = json.loads(out)
|
||||||
if not "__metadata__" in dt:
|
if "__metadata__" not in dt:
|
||||||
return web.Response(status=404)
|
return web.Response(status=404)
|
||||||
return web.json_response(dt["__metadata__"])
|
return web.json_response(dt["__metadata__"])
|
||||||
|
|
||||||
|
|||||||
@ -25,7 +25,7 @@ class TestImageStitch:
|
|||||||
|
|
||||||
result = node.stitch(image1, "right", True, 0, "white", image2=None)
|
result = node.stitch(image1, "right", True, 0, "white", image2=None)
|
||||||
|
|
||||||
assert len(result) == 1
|
assert len(result.result) == 1
|
||||||
assert torch.equal(result[0], image1)
|
assert torch.equal(result[0], image1)
|
||||||
|
|
||||||
def test_basic_horizontal_stitch_right(self):
|
def test_basic_horizontal_stitch_right(self):
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user