Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2025-12-17 18:13:01 +08:00)

Commit 483ba1e98b: Merge branch 'comfyanonymous:master' into master
@@ -66,8 +66,10 @@ if branch is None:
     try:
         ref = repo.lookup_reference('refs/remotes/origin/master')
     except:
-        print("pulling.") # noqa: T201
-        pull(repo)
+        print("fetching.") # noqa: T201
+        for remote in repo.remotes:
+            if remote.name == "origin":
+                remote.fetch()
         ref = repo.lookup_reference('refs/remotes/origin/master')
     repo.checkout(ref)
     branch = repo.lookup_branch('master')
@@ -149,3 +151,4 @@ try:
     shutil.copy(stable_update_script, stable_update_script_to)
 except:
     pass
+
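Note: the updater switches from the local pull() helper to fetching the origin remote directly; a fetch only refreshes refs/remotes/origin/* without attempting a merge, and the checkout then moves to the fetched ref. A minimal sketch of the same pattern with pygit2, assuming the working directory is a pygit2 repository with an "origin" remote:

    import pygit2

    repo = pygit2.Repository(".")
    for remote in repo.remotes:
        if remote.name == "origin":
            remote.fetch()  # updates refs/remotes/origin/*; no merge happens

    # check out the freshly fetched master tip
    ref = repo.lookup_reference('refs/remotes/origin/master')
    repo.checkout(ref)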
@@ -1,3 +1,2 @@
 # Admins
-* @comfyanonymous
-* @kosinkadink
+* @comfyanonymous @kosinkadink @guill
@@ -81,6 +81,7 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
   - [Hunyuan Video](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_video/)
   - [Wan 2.1](https://comfyanonymous.github.io/ComfyUI_examples/wan/)
   - [Wan 2.2](https://comfyanonymous.github.io/ComfyUI_examples/wan22/)
+  - [Hunyuan Video 1.5](https://docs.comfy.org/tutorials/video/hunyuan/hunyuan-video-1-5)
 - Audio Models
   - [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
   - [ACE Step](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
@@ -122,6 +122,12 @@ upcast.add_argument("--force-upcast-attention", action="store_true", help="Force
 upcast.add_argument("--dont-upcast-attention", action="store_true", help="Disable all upcasting of attention. Should be unnecessary except for debugging.")
 
 
+parser.add_argument("--enable-manager", action="store_true", help="Enable the ComfyUI-Manager feature.")
+manager_group = parser.add_mutually_exclusive_group()
+manager_group.add_argument("--disable-manager-ui", action="store_true", help="Disables only the ComfyUI-Manager UI and endpoints. Scheduled installations and similar background tasks will still operate.")
+manager_group.add_argument("--enable-manager-legacy-ui", action="store_true", help="Enables the legacy UI of ComfyUI-Manager")
+
+
 vram_group = parser.add_mutually_exclusive_group()
 vram_group.add_argument("--gpu-only", action="store_true", help="Store and run everything (text encoders/CLIP models, etc... on the GPU).")
 vram_group.add_argument("--highvram", action="store_true", help="By default models will be unloaded to CPU memory after being used. This option keeps them in GPU memory.")
@@ -169,6 +175,7 @@ parser.add_argument("--multi-user", action="store_true", help="Enables per-user
 parser.add_argument("--verbose", default='INFO', const='DEBUG', nargs="?", choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], help='Set the logging level')
 parser.add_argument("--log-stdout", action="store_true", help="Send normal process output to stdout instead of stderr (default).")
 
+
 # The default built-in provider hosted under web/
 DEFAULT_VERSION_STRING = "comfyanonymous/ComfyUI@latest"
 
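Note: the two manager UI flags sit in a mutually exclusive argparse group, so they can never be combined on one command line. A small self-contained illustration of that behavior (flag names copied from the hunk above):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--enable-manager", action="store_true")
    group = parser.add_mutually_exclusive_group()
    group.add_argument("--disable-manager-ui", action="store_true")
    group.add_argument("--enable-manager-legacy-ui", action="store_true")

    args = parser.parse_args(["--enable-manager", "--disable-manager-ui"])  # fine
    print(args.disable_manager_ui)  # True
    # parser.parse_args(["--disable-manager-ui", "--enable-manager-legacy-ui"])
    # would exit with: "not allowed with argument --disable-manager-ui"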
@@ -40,7 +40,8 @@ class ChromaParams:
     out_dim: int
     hidden_dim: int
     n_layers: int
+    txt_ids_dims: list
+    vec_in_dim: int
 
 
 
@@ -57,6 +57,35 @@ class MLPEmbedder(nn.Module):
     def forward(self, x: Tensor) -> Tensor:
         return self.out_layer(self.silu(self.in_layer(x)))
 
+class YakMLP(nn.Module):
+    def __init__(self, hidden_size: int, intermediate_size: int, dtype=None, device=None, operations=None):
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.gate_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=True, dtype=dtype, device=device)
+        self.up_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=True, dtype=dtype, device=device)
+        self.down_proj = operations.Linear(self.intermediate_size, self.hidden_size, bias=True, dtype=dtype, device=device)
+        self.act_fn = nn.SiLU()
+
+    def forward(self, x: Tensor) -> Tensor:
+        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+        return down_proj
+
+def build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=False, yak_mlp=False, dtype=None, device=None, operations=None):
+    if yak_mlp:
+        return YakMLP(hidden_size, mlp_hidden_dim, dtype=dtype, device=device, operations=operations)
+    if mlp_silu_act:
+        return nn.Sequential(
+            operations.Linear(hidden_size, mlp_hidden_dim * 2, bias=False, dtype=dtype, device=device),
+            SiLUActivation(),
+            operations.Linear(mlp_hidden_dim, hidden_size, bias=False, dtype=dtype, device=device),
+        )
+    else:
+        return nn.Sequential(
+            operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
+            nn.GELU(approximate="tanh"),
+            operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
+        )
+
 class RMSNorm(torch.nn.Module):
     def __init__(self, dim: int, dtype=None, device=None, operations=None):
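Note: YakMLP is a gated (SwiGLU-style) feed-forward: the SiLU of one projection gates a second, parallel projection before the down projection. A self-contained sketch with plain nn.Linear standing in for the repo's operations.Linear wrapper:

    import torch
    import torch.nn as nn

    class GatedMLP(nn.Module):
        """Same wiring as YakMLP: down(silu(gate(x)) * up(x))."""
        def __init__(self, hidden_size: int, intermediate_size: int):
            super().__init__()
            self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=True)
            self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=True)
            self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=True)
            self.act_fn = nn.SiLU()

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

    x = torch.randn(2, 16, 64)
    print(GatedMLP(64, 256)(x).shape)  # torch.Size([2, 16, 64])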
@@ -140,7 +169,7 @@ class SiLUActivation(nn.Module):
 
 
 class DoubleStreamBlock(nn.Module):
-    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, modulation=True, mlp_silu_act=False, proj_bias=True, dtype=None, device=None, operations=None):
+    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, modulation=True, mlp_silu_act=False, proj_bias=True, yak_mlp=False, dtype=None, device=None, operations=None):
         super().__init__()
 
         mlp_hidden_dim = int(hidden_size * mlp_ratio)
@@ -156,18 +185,7 @@ class DoubleStreamBlock(nn.Module):
 
         self.img_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
 
-        if mlp_silu_act:
-            self.img_mlp = nn.Sequential(
-                operations.Linear(hidden_size, mlp_hidden_dim * 2, bias=False, dtype=dtype, device=device),
-                SiLUActivation(),
-                operations.Linear(mlp_hidden_dim, hidden_size, bias=False, dtype=dtype, device=device),
-            )
-        else:
-            self.img_mlp = nn.Sequential(
-                operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
-                nn.GELU(approximate="tanh"),
-                operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
-            )
+        self.img_mlp = build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=mlp_silu_act, yak_mlp=yak_mlp, dtype=dtype, device=device, operations=operations)
 
         if self.modulation:
             self.txt_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)
@@ -177,18 +195,7 @@ class DoubleStreamBlock(nn.Module):
 
         self.txt_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
 
-        if mlp_silu_act:
-            self.txt_mlp = nn.Sequential(
-                operations.Linear(hidden_size, mlp_hidden_dim * 2, bias=False, dtype=dtype, device=device),
-                SiLUActivation(),
-                operations.Linear(mlp_hidden_dim, hidden_size, bias=False, dtype=dtype, device=device),
-            )
-        else:
-            self.txt_mlp = nn.Sequential(
-                operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
-                nn.GELU(approximate="tanh"),
-                operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
-            )
+        self.txt_mlp = build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=mlp_silu_act, yak_mlp=yak_mlp, dtype=dtype, device=device, operations=operations)
 
         self.flipped_img_txt = flipped_img_txt
 
@@ -275,6 +282,7 @@ class SingleStreamBlock(nn.Module):
         modulation=True,
         mlp_silu_act=False,
         bias=True,
+        yak_mlp=False,
         dtype=None,
         device=None,
         operations=None
@@ -288,12 +296,17 @@ class SingleStreamBlock(nn.Module):
         self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
 
         self.mlp_hidden_dim_first = self.mlp_hidden_dim
+        self.yak_mlp = yak_mlp
         if mlp_silu_act:
             self.mlp_hidden_dim_first = int(hidden_size * mlp_ratio * 2)
             self.mlp_act = SiLUActivation()
         else:
             self.mlp_act = nn.GELU(approximate="tanh")
 
+        if self.yak_mlp:
+            self.mlp_hidden_dim_first *= 2
+            self.mlp_act = nn.SiLU()
+
         # qkv and mlp_in
         self.linear1 = operations.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim_first, bias=bias, dtype=dtype, device=device)
         # proj and mlp_out
@@ -325,6 +338,9 @@ class SingleStreamBlock(nn.Module):
         attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
         del q, k, v
         # compute activation in mlp stream, cat again and run second linear layer
-        mlp = self.mlp_act(mlp)
+        if self.yak_mlp:
+            mlp = self.mlp_act(mlp[..., self.mlp_hidden_dim_first // 2:]) * mlp[..., :self.mlp_hidden_dim_first // 2]
+        else:
+            mlp = self.mlp_act(mlp)
         output = self.linear2(torch.cat((attn, mlp), 2))
         x += apply_mod(output, mod.gate, None, modulation_dims)
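Note: the single-stream block computes qkv and the MLP input in one matmul (linear1), so enabling yak_mlp just doubles the MLP slice and splits it into gate and value halves at activation time. The gating step in isolation (dimensions illustrative):

    import torch
    import torch.nn.functional as F

    mlp_hidden_dim_first = 512                       # doubled when yak_mlp is on
    mlp = torch.randn(2, 16, mlp_hidden_dim_first)   # MLP slice taken from linear1
    half = mlp_hidden_dim_first // 2

    # silu of the second half gates the first half, as in the yak_mlp branch above
    gated = F.silu(mlp[..., half:]) * mlp[..., :half]
    print(gated.shape)  # torch.Size([2, 16, 256])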
@@ -15,7 +15,8 @@ from .layers import (
     MLPEmbedder,
     SingleStreamBlock,
     timestep_embedding,
-    Modulation
+    Modulation,
+    RMSNorm
 )
 
 @dataclass
@@ -34,11 +35,14 @@ class FluxParams:
     patch_size: int
     qkv_bias: bool
     guidance_embed: bool
+    txt_ids_dims: list
     global_modulation: bool = False
     mlp_silu_act: bool = False
     ops_bias: bool = True
     default_ref_method: str = "offset"
     ref_index_scale: float = 1.0
+    yak_mlp: bool = False
+    txt_norm: bool = False
 
 
 class Flux(nn.Module):
@@ -76,6 +80,11 @@ class Flux(nn.Module):
         )
         self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device)
 
+        if params.txt_norm:
+            self.txt_norm = RMSNorm(params.context_in_dim, dtype=dtype, device=device, operations=operations)
+        else:
+            self.txt_norm = None
+
         self.double_blocks = nn.ModuleList(
             [
                 DoubleStreamBlock(
@@ -86,6 +95,7 @@ class Flux(nn.Module):
                     modulation=params.global_modulation is False,
                     mlp_silu_act=params.mlp_silu_act,
                     proj_bias=params.ops_bias,
+                    yak_mlp=params.yak_mlp,
                     dtype=dtype, device=device, operations=operations
                 )
                 for _ in range(params.depth)
@@ -94,7 +104,7 @@ class Flux(nn.Module):
 
         self.single_blocks = nn.ModuleList(
             [
-                SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, modulation=params.global_modulation is False, mlp_silu_act=params.mlp_silu_act, bias=params.ops_bias, dtype=dtype, device=device, operations=operations)
+                SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, modulation=params.global_modulation is False, mlp_silu_act=params.mlp_silu_act, bias=params.ops_bias, yak_mlp=params.yak_mlp, dtype=dtype, device=device, operations=operations)
                 for _ in range(params.depth_single_blocks)
             ]
         )
@@ -150,6 +160,8 @@ class Flux(nn.Module):
             y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype)
         vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
 
+        if self.txt_norm is not None:
+            txt = self.txt_norm(txt)
         txt = self.txt_in(txt)
 
         vec_orig = vec
@@ -332,8 +344,9 @@ class Flux(nn.Module):
 
         txt_ids = torch.zeros((bs, context.shape[1], len(self.params.axes_dim)), device=x.device, dtype=torch.float32)
 
-        if len(self.params.axes_dim) == 4: # Flux 2
-            txt_ids[:, :, 3] = torch.linspace(0, context.shape[1] - 1, steps=context.shape[1], device=x.device, dtype=torch.float32)
+        if len(self.params.txt_ids_dims) > 0:
+            for i in self.params.txt_ids_dims:
+                txt_ids[:, :, i] = torch.linspace(0, context.shape[1] - 1, steps=context.shape[1], device=x.device, dtype=torch.float32)
 
         out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
         out = out[:, :img_tokens]
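Note: this generalizes the hard-coded "Flux 2" case (axis 3 gets a 0..L-1 position ramp) to an arbitrary list of axes, so the Ovis configuration detected later in this diff can ramp axes 1 and 2 instead. The ramp itself in miniature:

    import torch

    bs, seq_len, n_axes = 2, 5, 4
    txt_ids_dims = [3]  # [3] reproduces the old Flux 2 behavior

    txt_ids = torch.zeros((bs, seq_len, n_axes))
    for i in txt_ids_dims:
        txt_ids[:, :, i] = torch.linspace(0, seq_len - 1, steps=seq_len)
    print(txt_ids[0, :, 3])  # tensor([0., 1., 2., 3., 4.])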
@@ -1,7 +1,8 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from comfy.ldm.hunyuan_video.vae_refiner import RMS_norm, ResnetBlock, VideoConv3d
+from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, VideoConv3d
+from comfy.ldm.hunyuan_video.vae_refiner import RMS_norm
 import model_management, model_patcher
 
 class SRResidualCausalBlock3D(nn.Module):
@@ -1,42 +1,12 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock, VideoConv3d, Normalize
+from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock, CarriedConv3d, Normalize, conv_carry_causal_3d, torch_cat_if_needed
 import comfy.ops
 import comfy.ldm.models.autoencoder
 import comfy.model_management
 ops = comfy.ops.disable_weight_init
 
-class NoPadConv3d(nn.Module):
-    def __init__(self, n_channels, out_channels, kernel_size, stride=1, dilation=1, padding=0, **kwargs):
-        super().__init__()
-        self.conv = ops.Conv3d(n_channels, out_channels, kernel_size, stride=stride, dilation=dilation, **kwargs)
-
-    def forward(self, x):
-        return self.conv(x)
-
-
-def conv_carry_causal_3d(xl, op, conv_carry_in=None, conv_carry_out=None):
-
-    x = xl[0]
-    xl.clear()
-
-    if conv_carry_out is not None:
-        to_push = x[:, :, -2:, :, :].clone()
-        conv_carry_out.append(to_push)
-
-    if isinstance(op, NoPadConv3d):
-        if conv_carry_in is None:
-            x = torch.nn.functional.pad(x, (1, 1, 1, 1, 2, 0), mode = 'replicate')
-        else:
-            carry_len = conv_carry_in[0].shape[2]
-            x = torch.cat([conv_carry_in.pop(0), x], dim=2)
-            x = torch.nn.functional.pad(x, (1, 1, 1, 1, 2 - carry_len, 0), mode = 'replicate')
-
-    out = op(x)
-
-    return out
-
-
 class RMS_norm(nn.Module):
     def __init__(self, dim):
@@ -49,7 +19,7 @@ class RMS_norm(nn.Module):
         return F.normalize(x, dim=1) * self.scale * comfy.model_management.cast_to(self.gamma, dtype=x.dtype, device=x.device)
 
 class DnSmpl(nn.Module):
-    def __init__(self, ic, oc, tds=True, refiner_vae=True, op=VideoConv3d):
+    def __init__(self, ic, oc, tds, refiner_vae, op):
         super().__init__()
         fct = 2 * 2 * 2 if tds else 1 * 2 * 2
         assert oc % fct == 0
@@ -109,7 +79,7 @@ class DnSmpl(nn.Module):
 
 
 class UpSmpl(nn.Module):
-    def __init__(self, ic, oc, tus=True, refiner_vae=True, op=VideoConv3d):
+    def __init__(self, ic, oc, tus, refiner_vae, op):
         super().__init__()
         fct = 2 * 2 * 2 if tus else 1 * 2 * 2
         self.conv = op(ic, oc * fct, kernel_size=3, stride=1, padding=1)
@@ -163,23 +133,6 @@ class UpSmpl(nn.Module):
 
         return h + x
 
-class HunyuanRefinerResnetBlock(ResnetBlock):
-    def __init__(self, in_channels, out_channels, conv_op=NoPadConv3d, norm_op=RMS_norm):
-        super().__init__(in_channels=in_channels, out_channels=out_channels, temb_channels=0, conv_op=conv_op, norm_op=norm_op)
-
-    def forward(self, x, conv_carry_in=None, conv_carry_out=None):
-        h = x
-        h = [ self.swish(self.norm1(x)) ]
-        h = conv_carry_causal_3d(h, self.conv1, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out)
-
-        h = [ self.dropout(self.swish(self.norm2(h))) ]
-        h = conv_carry_causal_3d(h, self.conv2, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out)
-
-        if self.in_channels != self.out_channels:
-            x = self.nin_shortcut(x)
-
-        return x+h
-
 class Encoder(nn.Module):
     def __init__(self, in_channels, z_channels, block_out_channels, num_res_blocks,
                  ffactor_spatial, ffactor_temporal, downsample_match_channel=True, refiner_vae=True, **_):
@@ -191,7 +144,7 @@ class Encoder(nn.Module):
 
         self.refiner_vae = refiner_vae
         if self.refiner_vae:
-            conv_op = NoPadConv3d
+            conv_op = CarriedConv3d
             norm_op = RMS_norm
         else:
             conv_op = ops.Conv3d
@@ -206,8 +159,9 @@ class Encoder(nn.Module):
 
         for i, tgt in enumerate(block_out_channels):
             stage = nn.Module()
-            stage.block = nn.ModuleList([HunyuanRefinerResnetBlock(in_channels=ch if j == 0 else tgt,
+            stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
                                                      out_channels=tgt,
+                                                     temb_channels=0,
                                                      conv_op=conv_op, norm_op=norm_op)
                                          for j in range(num_res_blocks)])
             ch = tgt
@@ -218,9 +172,9 @@ class Encoder(nn.Module):
             self.down.append(stage)
 
         self.mid = nn.Module()
-        self.mid.block_1 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
+        self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
         self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op)
-        self.mid.block_2 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
+        self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
 
         self.norm_out = norm_op(ch)
         self.conv_out = conv_op(ch, z_channels << 1, 3, 1, 1)
@@ -246,22 +200,20 @@ class Encoder(nn.Module):
             conv_carry_out = []
             if i == len(x) - 1:
                 conv_carry_out = None
 
             x1 = [ x1 ]
             x1 = conv_carry_causal_3d(x1, self.conv_in, conv_carry_in, conv_carry_out)
 
             for stage in self.down:
                 for blk in stage.block:
-                    x1 = blk(x1, conv_carry_in, conv_carry_out)
+                    x1 = blk(x1, None, conv_carry_in, conv_carry_out)
                 if hasattr(stage, 'downsample'):
                     x1 = stage.downsample(x1, conv_carry_in, conv_carry_out)
 
             out.append(x1)
             conv_carry_in = conv_carry_out
 
-        if len(out) > 1:
-            out = torch.cat(out, dim=2)
-        else:
-            out = out[0]
+        out = torch_cat_if_needed(out, dim=2)
 
         x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(out)))
         del out
@@ -288,7 +240,7 @@ class Decoder(nn.Module):
 
         self.refiner_vae = refiner_vae
         if self.refiner_vae:
-            conv_op = NoPadConv3d
+            conv_op = CarriedConv3d
             norm_op = RMS_norm
         else:
             conv_op = ops.Conv3d
@@ -298,9 +250,9 @@ class Decoder(nn.Module):
         self.conv_in = conv_op(z_channels, ch, kernel_size=3, stride=1, padding=1)
 
         self.mid = nn.Module()
-        self.mid.block_1 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
+        self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
         self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op)
-        self.mid.block_2 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
+        self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
 
         self.up = nn.ModuleList()
         depth = (ffactor_spatial >> 1).bit_length()
@@ -308,8 +260,9 @@ class Decoder(nn.Module):
 
         for i, tgt in enumerate(block_out_channels):
             stage = nn.Module()
-            stage.block = nn.ModuleList([HunyuanRefinerResnetBlock(in_channels=ch if j == 0 else tgt,
+            stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
                                                      out_channels=tgt,
+                                                     temb_channels=0,
                                                      conv_op=conv_op, norm_op=norm_op)
                                          for j in range(num_res_blocks + 1)])
             ch = tgt
@@ -340,7 +293,7 @@ class Decoder(nn.Module):
                 conv_carry_out = None
             for stage in self.up:
                 for blk in stage.block:
-                    x1 = blk(x1, conv_carry_in, conv_carry_out)
+                    x1 = blk(x1, None, conv_carry_in, conv_carry_out)
                 if hasattr(stage, 'upsample'):
                     x1 = stage.upsample(x1, conv_carry_in, conv_carry_out)
 
@@ -350,10 +303,7 @@ class Decoder(nn.Module):
             conv_carry_in = conv_carry_out
         del x
 
-        if len(out) > 1:
-            out = torch.cat(out, dim=2)
-        else:
-            out = out[0]
+        out = torch_cat_if_needed(out, dim=2)
 
         if not self.refiner_vae:
             if z.shape[-3] == 1:
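Note: the refiner VAE now reuses the generic ResnetBlock (the extra None in blk(x1, None, ...) is its temb parameter) and streams video through the network in temporal chunks, handing a two-frame carry from one chunk to the next so the causal 3D convolutions see contiguous context. A toy illustration of the chunk-and-carry bookkeeping only (shapes, not the real modules):

    import torch

    x = torch.randn(1, 8, 9, 4, 4)  # (B, C, T, H, W)
    chunks = [x[:, :, :1]] + list(torch.split(x[:, :, 1:], 4, dim=2))

    carry = None
    outs = []
    for chunk in chunks:
        if carry is not None:
            chunk = torch.cat([carry, chunk], dim=2)  # prepend carried frames
        carry = chunk[:, :, -2:].clone()              # carry last 2 frames forward
        outs.append(chunk)                            # the real code runs convs here
    print([o.shape[2] for o in outs])                 # per-chunk frame counts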
comfy/ldm/lumina/controlnet.py (new file, 113 lines)
@@ -0,0 +1,113 @@
+import torch
+from torch import nn
+
+from .model import JointTransformerBlock
+
+
+class ZImageControlTransformerBlock(JointTransformerBlock):
+    def __init__(
+        self,
+        layer_id: int,
+        dim: int,
+        n_heads: int,
+        n_kv_heads: int,
+        multiple_of: int,
+        ffn_dim_multiplier: float,
+        norm_eps: float,
+        qk_norm: bool,
+        modulation=True,
+        block_id=0,
+        operation_settings=None,
+    ):
+        super().__init__(layer_id, dim, n_heads, n_kv_heads, multiple_of, ffn_dim_multiplier, norm_eps, qk_norm, modulation, z_image_modulation=True, operation_settings=operation_settings)
+        self.block_id = block_id
+        if block_id == 0:
+            self.before_proj = operation_settings.get("operations").Linear(self.dim, self.dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+        self.after_proj = operation_settings.get("operations").Linear(self.dim, self.dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+
+    def forward(self, c, x, **kwargs):
+        if self.block_id == 0:
+            c = self.before_proj(c) + x
+        c = super().forward(c, **kwargs)
+        c_skip = self.after_proj(c)
+        return c_skip, c
+
+
+class ZImage_Control(torch.nn.Module):
+    def __init__(
+        self,
+        dim: int = 3840,
+        n_heads: int = 30,
+        n_kv_heads: int = 30,
+        multiple_of: int = 256,
+        ffn_dim_multiplier: float = (8.0 / 3.0),
+        norm_eps: float = 1e-5,
+        qk_norm: bool = True,
+        dtype=None,
+        device=None,
+        operations=None,
+        **kwargs
+    ):
+        super().__init__()
+        operation_settings = {"operations": operations, "device": device, "dtype": dtype}
+
+        self.additional_in_dim = 0
+        self.control_in_dim = 16
+        n_refiner_layers = 2
+        self.n_control_layers = 6
+        self.control_layers = nn.ModuleList(
+            [
+                ZImageControlTransformerBlock(
+                    i,
+                    dim,
+                    n_heads,
+                    n_kv_heads,
+                    multiple_of,
+                    ffn_dim_multiplier,
+                    norm_eps,
+                    qk_norm,
+                    block_id=i,
+                    operation_settings=operation_settings,
+                )
+                for i in range(self.n_control_layers)
+            ]
+        )
+
+        all_x_embedder = {}
+        patch_size = 2
+        f_patch_size = 1
+        x_embedder = operations.Linear(f_patch_size * patch_size * patch_size * self.control_in_dim, dim, bias=True, device=device, dtype=dtype)
+        all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder
+
+        self.control_all_x_embedder = nn.ModuleDict(all_x_embedder)
+        self.control_noise_refiner = nn.ModuleList(
+            [
+                JointTransformerBlock(
+                    layer_id,
+                    dim,
+                    n_heads,
+                    n_kv_heads,
+                    multiple_of,
+                    ffn_dim_multiplier,
+                    norm_eps,
+                    qk_norm,
+                    modulation=True,
+                    z_image_modulation=True,
+                    operation_settings=operation_settings,
+                )
+                for layer_id in range(n_refiner_layers)
+            ]
+        )
+
+    def forward(self, cap_feats, control_context, x_freqs_cis, adaln_input):
+        patch_size = 2
+        f_patch_size = 1
+        pH = pW = patch_size
+        B, C, H, W = control_context.shape
+        control_context = self.control_all_x_embedder[f"{patch_size}-{f_patch_size}"](control_context.view(B, C, H // pH, pH, W // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2))
+
+        x_attn_mask = None
+        for layer in self.control_noise_refiner:
+            control_context = layer(control_context, x_attn_mask, x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input)
+        return control_context
+
+    def forward_control_block(self, layer_id, control_context, x, x_attn_mask, x_freqs_cis, adaln_input):
+        return self.control_layers[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)
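Note: the new Lumina/Z-Image control module follows the usual ControlNet-style wiring: block 0 injects the embedded control image into the hidden stream (before_proj(c) + x), and every block emits a skip via after_proj that the base model adds back at the matching layer. A stripped-down sketch of that wiring (the tiny modules here are stand-ins, not the repo's API):

    import torch
    import torch.nn as nn

    class TinyControlBlock(nn.Module):
        def __init__(self, dim, block_id):
            super().__init__()
            self.block_id = block_id
            if block_id == 0:
                self.before_proj = nn.Linear(dim, dim)
            self.body = nn.Linear(dim, dim)        # stand-in for a transformer block
            self.after_proj = nn.Linear(dim, dim)

        def forward(self, c, x):
            if self.block_id == 0:
                c = self.before_proj(c) + x        # inject control into the stream
            c = torch.relu(self.body(c))
            return self.after_proj(c), c           # (skip for base model, carried state)

    dim = 8
    blocks = nn.ModuleList(TinyControlBlock(dim, i) for i in range(2))
    x = torch.randn(1, 4, dim)   # base hidden states
    c = torch.randn(1, 4, dim)   # embedded control image
    for blk in blocks:
        skip, c = blk(c, x)
        x = x + skip             # base model consumes the per-block skip
    print(x.shape)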
@@ -22,6 +22,10 @@ def modulate(x, scale):
 # Core NextDiT Model #
 #############################################################################
 
+def clamp_fp16(x):
+    if x.dtype == torch.float16:
+        return torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
+    return x
 
 class JointAttention(nn.Module):
     """Multi-head attention module."""
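Note: clamp_fp16 guards half-precision activations against overflow; 65504 is the largest finite float16 value, so any inf produced by an fp16 matmul is pulled back into range and NaNs are zeroed. Quick demonstration:

    import torch

    def clamp_fp16(x):
        if x.dtype == torch.float16:
            return torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
        return x

    x = torch.tensor([float("inf"), float("-inf"), float("nan"), 1.5],
                     dtype=torch.float16)
    print(clamp_fp16(x))  # tensor([ 65504., -65504., 0.0000, 1.5000], dtype=torch.float16)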
@@ -169,7 +173,7 @@ class FeedForward(nn.Module):
 
     # @torch.compile
     def _forward_silu_gating(self, x1, x3):
-        return F.silu(x1) * x3
+        return clamp_fp16(F.silu(x1) * x3)
 
     def forward(self, x):
         return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x)))
@@ -273,27 +277,27 @@ class JointTransformerBlock(nn.Module):
             scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).chunk(4, dim=1)
 
             x = x + gate_msa.unsqueeze(1).tanh() * self.attention_norm2(
-                self.attention(
+                clamp_fp16(self.attention(
                     modulate(self.attention_norm1(x), scale_msa),
                     x_mask,
                     freqs_cis,
                     transformer_options=transformer_options,
-                )
+                ))
             )
             x = x + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(
-                self.feed_forward(
+                clamp_fp16(self.feed_forward(
                     modulate(self.ffn_norm1(x), scale_mlp),
-                )
+                ))
             )
         else:
             assert adaln_input is None
             x = x + self.attention_norm2(
-                self.attention(
+                clamp_fp16(self.attention(
                     self.attention_norm1(x),
                     x_mask,
                     freqs_cis,
                     transformer_options=transformer_options,
-                )
+                ))
             )
             x = x + self.ffn_norm2(
                 self.feed_forward(
@@ -564,7 +568,7 @@ class NextDiT(nn.Module):
         ).execute(x, timesteps, context, num_tokens, attention_mask, **kwargs)
 
     # def forward(self, x, t, cap_feats, cap_mask):
-    def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs):
+    def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, transformer_options={}, **kwargs):
         t = 1.0 - timesteps
         cap_feats = context
         cap_mask = attention_mask
@@ -581,16 +585,24 @@ class NextDiT(nn.Module):
 
         cap_feats = self.cap_embedder(cap_feats)  # (N, L, D)  # todo check if able to batchify w.o. redundant compute
 
+        patches = transformer_options.get("patches", {})
         transformer_options = kwargs.get("transformer_options", {})
         x_is_tensor = isinstance(x, torch.Tensor)
-        x, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens, transformer_options=transformer_options)
-        freqs_cis = freqs_cis.to(x.device)
+        img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens, transformer_options=transformer_options)
+        freqs_cis = freqs_cis.to(img.device)
 
-        for layer in self.layers:
-            x = layer(x, mask, freqs_cis, adaln_input, transformer_options=transformer_options)
+        for i, layer in enumerate(self.layers):
+            img = layer(img, mask, freqs_cis, adaln_input, transformer_options=transformer_options)
+            if "double_block" in patches:
+                for p in patches["double_block"]:
+                    out = p({"img": img[:, cap_size[0]:], "txt": img[:, :cap_size[0]], "pe": freqs_cis[:, cap_size[0]:], "vec": adaln_input, "x": x, "block_index": i, "transformer_options": transformer_options})
+                    if "img" in out:
+                        img[:, cap_size[0]:] = out["img"]
+                    if "txt" in out:
+                        img[:, :cap_size[0]] = out["txt"]
 
-        x = self.final_layer(x, adaln_input)
-        x = self.unpatchify(x, img_size, cap_size, return_tensor=x_is_tensor)[:,:,:h,:w]
+        img = self.final_layer(img, adaln_input)
+        img = self.unpatchify(img, img_size, cap_size, return_tensor=x_is_tensor)[:, :, :h, :w]
 
-        return -x
+        return -img
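Note: NextDiT now honors the same "double_block" patch hook the Flux-style models expose: after each layer, registered callables receive the split image/text token streams and may return replacements. A hedged sketch of such a patch (the dict keys mirror the call above; the scaling is only an example):

    # A "double_block" patch: inspects/edits the token streams after each layer.
    def my_double_block_patch(args):
        img, txt = args["img"], args["txt"]
        if args["block_index"] == 0:
            img = img * 1.05   # e.g. nudge image tokens after the first layer
        return {"img": img, "txt": txt}

    # Registered via transformer_options, e.g.:
    # transformer_options = {"patches": {"double_block": [my_double_block_patch]}}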
@@ -529,6 +529,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
 
 @wrap_attn
 def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
+    exception_fallback = False
     if skip_reshape:
         b, _, _, dim_head = q.shape
         tensor_layout = "HND"
@@ -553,6 +554,8 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
         out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout)
     except Exception as e:
         logging.error("Error running sage attention: {}, using pytorch attention instead.".format(e))
+        exception_fallback = True
+    if exception_fallback:
         if tensor_layout == "NHD":
             q, k, v = map(
                 lambda t: t.transpose(1, 2),
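Note: this change (and the matching pytorch_attention one later in the diff) moves the fallback path out of the except block. Running the backup kernel inside an active exception handler keeps the exception, its traceback, and the tensors they reference alive; setting a flag and retrying after the handler exits lets that memory be freed first. The pattern in isolation:

    import logging

    def run_with_fallback(primary, fallback):
        failed = False
        try:
            return primary()
        except Exception as e:
            logging.error("primary failed: %s, using fallback", e)
            failed = True   # do not call fallback here: the handler still holds
                            # the exception frame and everything it references
        if failed:
            return fallback()   # handler exited; the failed attempt is freed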
@@ -13,6 +13,12 @@ if model_management.xformers_enabled_vae():
     import xformers
     import xformers.ops
 
+def torch_cat_if_needed(xl, dim):
+    if len(xl) > 1:
+        return torch.cat(xl, dim)
+    else:
+        return xl[0]
+
 def get_timestep_embedding(timesteps, embedding_dim):
     """
     This matches the implementation in Denoising Diffusion Probabilistic Models:
@@ -43,6 +49,37 @@ def Normalize(in_channels, num_groups=32):
     return ops.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
 
 
+class CarriedConv3d(nn.Module):
+    def __init__(self, n_channels, out_channels, kernel_size, stride=1, dilation=1, padding=0, **kwargs):
+        super().__init__()
+        self.conv = ops.Conv3d(n_channels, out_channels, kernel_size, stride=stride, dilation=dilation, **kwargs)
+
+    def forward(self, x):
+        return self.conv(x)
+
+
+def conv_carry_causal_3d(xl, op, conv_carry_in=None, conv_carry_out=None):
+
+    x = xl[0]
+    xl.clear()
+
+    if isinstance(op, CarriedConv3d):
+        if conv_carry_in is None:
+            x = torch.nn.functional.pad(x, (1, 1, 1, 1, 2, 0), mode = 'replicate')
+        else:
+            carry_len = conv_carry_in[0].shape[2]
+            x = torch.nn.functional.pad(x, (1, 1, 1, 1, 2 - carry_len, 0), mode = 'replicate')
+            x = torch.cat([conv_carry_in.pop(0), x], dim=2)
+
+    if conv_carry_out is not None:
+        to_push = x[:, :, -2:, :, :].clone()
+        conv_carry_out.append(to_push)
+
+    out = op(x)
+
+    return out
+
+
 class VideoConv3d(nn.Module):
     def __init__(self, n_channels, out_channels, kernel_size, stride=1, dilation=1, padding_mode='replicate', padding=1, **kwargs):
         super().__init__()
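Note: conv_carry_causal_3d implements causal temporal padding for chunked video: the first chunk gets 2 replicated frames of left context (a temporal kernel of 3 needs 2), and each chunk pushes its last 2 frames into conv_carry_out for the next chunk to consume instead of padding. A toy check that chunked processing matches processing the full clip (plain Conv3d, shared weights):

    import torch
    import torch.nn.functional as F

    conv = torch.nn.Conv3d(1, 1, 3)  # no built-in padding; pad manually

    def causal_full(x):
        # pad order: W(1,1), H(1,1), T(2,0) -- left context only in time
        return conv(F.pad(x, (1, 1, 1, 1, 2, 0), mode='replicate'))

    x = torch.randn(1, 1, 8, 4, 4)
    full = causal_full(x)

    # chunked: carry the last 2 input frames into the next chunk
    a, b = x[:, :, :4], x[:, :, 4:]
    out_a = conv(F.pad(a, (1, 1, 1, 1, 2, 0), mode='replicate'))
    carry = a[:, :, -2:]
    out_b = conv(F.pad(torch.cat([carry, b], dim=2), (1, 1, 1, 1, 0, 0), mode='replicate'))

    print(torch.allclose(full, torch.cat([out_a, out_b], dim=2), atol=1e-6))  # True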
@@ -89,29 +126,24 @@ class Upsample(nn.Module):
                                         stride=1,
                                         padding=1)
 
-    def forward(self, x):
+    def forward(self, x, conv_carry_in=None, conv_carry_out=None):
         scale_factor = self.scale_factor
         if isinstance(scale_factor, (int, float)):
             scale_factor = (scale_factor,) * (x.ndim - 2)
 
         if x.ndim == 5 and scale_factor[0] > 1.0:
-            t = x.shape[2]
-            if t > 1:
-                a, b = x.split((1, t - 1), dim=2)
-                del x
-                b = interpolate_up(b, scale_factor)
-            else:
-                a = x
-
-            a = interpolate_up(a.squeeze(2), scale_factor=scale_factor[1:]).unsqueeze(2)
-            if t > 1:
-                x = torch.cat((a, b), dim=2)
-            else:
-                x = a
+            results = []
+            if conv_carry_in is None:
+                first = x[:, :, :1, :, :]
+                results.append(interpolate_up(first.squeeze(2), scale_factor=scale_factor[1:]).unsqueeze(2))
+                x = x[:, :, 1:, :, :]
+            if x.shape[2] > 0:
+                results.append(interpolate_up(x, scale_factor))
+            x = torch_cat_if_needed(results, dim=2)
         else:
             x = interpolate_up(x, scale_factor)
         if self.with_conv:
-            x = self.conv(x)
+            x = conv_carry_causal_3d([x], self.conv, conv_carry_in, conv_carry_out)
         return x
 
 
@@ -127,12 +159,15 @@ class Downsample(nn.Module):
                                         stride=stride,
                                         padding=0)
 
-    def forward(self, x):
+    def forward(self, x, conv_carry_in=None, conv_carry_out=None):
         if self.with_conv:
-            if x.ndim == 4:
+            if isinstance(self.conv, CarriedConv3d):
+                x = conv_carry_causal_3d([x], self.conv, conv_carry_in, conv_carry_out)
+            elif x.ndim == 4:
                 pad = (0, 1, 0, 1)
                 mode = "constant"
                 x = torch.nn.functional.pad(x, pad, mode=mode, value=0)
+                x = self.conv(x)
             elif x.ndim == 5:
                 pad = (1, 1, 1, 1, 2, 0)
                 mode = "replicate"
@@ -183,23 +218,23 @@ class ResnetBlock(nn.Module):
                                         stride=1,
                                         padding=0)
 
-    def forward(self, x, temb=None):
+    def forward(self, x, temb=None, conv_carry_in=None, conv_carry_out=None):
         h = x
         h = self.norm1(h)
-        h = self.swish(h)
-        h = self.conv1(h)
+        h = [ self.swish(h) ]
+        h = conv_carry_causal_3d(h, self.conv1, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out)
 
         if temb is not None:
             h = h + self.temb_proj(self.swish(temb))[:,:,None,None]
 
         h = self.norm2(h)
         h = self.swish(h)
-        h = self.dropout(h)
-        h = self.conv2(h)
+        h = [ self.dropout(h) ]
+        h = conv_carry_causal_3d(h, self.conv2, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out)
 
         if self.in_channels != self.out_channels:
             if self.use_conv_shortcut:
-                x = self.conv_shortcut(x)
+                x = conv_carry_causal_3d([x], self.conv_shortcut, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out)
             else:
                 x = self.nin_shortcut(x)
 
@@ -279,6 +314,7 @@ def pytorch_attention(q, k, v):
     orig_shape = q.shape
     B = orig_shape[0]
    C = orig_shape[1]
+    oom_fallback = False
     q, k, v = map(
         lambda t: t.view(B, 1, C, -1).transpose(2, 3).contiguous(),
         (q, k, v),
@@ -289,6 +325,8 @@ def pytorch_attention(q, k, v):
         out = out.transpose(2, 3).reshape(orig_shape)
     except model_management.OOM_EXCEPTION:
         logging.warning("scaled_dot_product_attention OOMed: switched to slice attention")
+        oom_fallback = True
+    if oom_fallback:
         out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(orig_shape)
     return out
 
@@ -517,8 +555,13 @@ class Encoder(nn.Module):
         self.num_res_blocks = num_res_blocks
         self.resolution = resolution
         self.in_channels = in_channels
+        self.carried = False
 
         if conv3d:
-            conv_op = VideoConv3d
+            if not attn_resolutions:
+                conv_op = CarriedConv3d
+                self.carried = True
+            else:
+                conv_op = VideoConv3d
             mid_attn_conv_op = ops.Conv3d
         else:
@@ -532,6 +575,7 @@ class Encoder(nn.Module):
                                        stride=1,
                                        padding=1)
 
+        self.time_compress = 1
         curr_res = resolution
         in_ch_mult = (1,)+tuple(ch_mult)
         self.in_ch_mult = in_ch_mult
@@ -558,10 +602,15 @@ class Encoder(nn.Module):
                 if time_compress is not None:
                     if (self.num_resolutions - 1 - i_level) > math.log2(time_compress):
                         stride = (1, 2, 2)
+                    else:
+                        self.time_compress *= 2
                 down.downsample = Downsample(block_in, resamp_with_conv, stride=stride, conv_op=conv_op)
                 curr_res = curr_res // 2
             self.down.append(down)
 
+        if time_compress is not None:
+            self.time_compress = time_compress
+
         # middle
         self.mid = nn.Module()
         self.mid.block_1 = ResnetBlock(in_channels=block_in,
@@ -587,15 +636,42 @@ class Encoder(nn.Module):
     def forward(self, x):
         # timestep embedding
         temb = None
 
+        if self.carried:
+            xl = [x[:, :, :1, :, :]]
+            if x.shape[2] > self.time_compress:
+                tc = self.time_compress
+                xl += torch.split(x[:, :, 1: 1 + ((x.shape[2] - 1) // tc) * tc, :, :], tc * 2, dim = 2)
+            x = xl
+        else:
+            x = [x]
+        out = []
+
+        conv_carry_in = None
+
+        for i, x1 in enumerate(x):
+            conv_carry_out = []
+            if i == len(x) - 1:
+                conv_carry_out = None
+
             # downsampling
-        h = self.conv_in(x)
+            x1 = [ x1 ]
+            h1 = conv_carry_causal_3d(x1, self.conv_in, conv_carry_in, conv_carry_out)
+
             for i_level in range(self.num_resolutions):
                 for i_block in range(self.num_res_blocks):
-                    h = self.down[i_level].block[i_block](h, temb)
+                    h1 = self.down[i_level].block[i_block](h1, temb, conv_carry_in, conv_carry_out)
                     if len(self.down[i_level].attn) > 0:
-                        h = self.down[i_level].attn[i_block](h)
+                        assert i == 0 #carried should not happen if attn exists
+                        h1 = self.down[i_level].attn[i_block](h1)
                 if i_level != self.num_resolutions-1:
-                    h = self.down[i_level].downsample(h)
+                    h1 = self.down[i_level].downsample(h1, conv_carry_in, conv_carry_out)
 
+            out.append(h1)
+            conv_carry_in = conv_carry_out
+
+        h = torch_cat_if_needed(out, dim=2)
+        del out
+
         # middle
         h = self.mid.block_1(h, temb)
@@ -604,15 +680,15 @@ class Encoder(nn.Module):
 
         # end
         h = self.norm_out(h)
-        h = nonlinearity(h)
-        h = self.conv_out(h)
+        h = [ nonlinearity(h) ]
+        h = conv_carry_causal_3d(h, self.conv_out)
         return h
 
 
 class Decoder(nn.Module):
     def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
                  attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
-                 resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
+                 resolution, z_channels, tanh_out=False, use_linear_attn=False,
                  conv_out_op=ops.Conv2d,
                  resnet_op=ResnetBlock,
                  attn_op=AttnBlock,
@@ -626,12 +702,18 @@ class Decoder(nn.Module):
         self.num_res_blocks = num_res_blocks
         self.resolution = resolution
         self.in_channels = in_channels
-        self.give_pre_end = give_pre_end
         self.tanh_out = tanh_out
+        self.carried = False
 
         if conv3d:
+            if not attn_resolutions and resnet_op == ResnetBlock:
+                conv_op = CarriedConv3d
+                conv_out_op = CarriedConv3d
+                self.carried = True
+            else:
                 conv_op = VideoConv3d
                 conv_out_op = VideoConv3d
 
             mid_attn_conv_op = ops.Conv3d
         else:
             conv_op = ops.Conv2d
@ -706,29 +788,43 @@ class Decoder(nn.Module):
|
|||||||
temb = None
|
temb = None
|
||||||
|
|
||||||
# z to block_in
|
# z to block_in
|
||||||
h = self.conv_in(z)
|
h = conv_carry_causal_3d([z], self.conv_in)
|
||||||
|
|
||||||
# middle
|
# middle
|
||||||
h = self.mid.block_1(h, temb, **kwargs)
|
h = self.mid.block_1(h, temb, **kwargs)
|
||||||
h = self.mid.attn_1(h, **kwargs)
|
h = self.mid.attn_1(h, **kwargs)
|
||||||
h = self.mid.block_2(h, temb, **kwargs)
|
h = self.mid.block_2(h, temb, **kwargs)
|
||||||
|
|
||||||
|
if self.carried:
|
||||||
|
h = torch.split(h, 2, dim=2)
|
||||||
|
else:
|
||||||
|
h = [ h ]
|
||||||
|
out = []
|
||||||
|
|
||||||
|
conv_carry_in = None
|
||||||
|
|
||||||
# upsampling
|
# upsampling
|
||||||
|
for i, h1 in enumerate(h):
|
||||||
|
conv_carry_out = []
|
||||||
|
if i == len(h) - 1:
|
||||||
|
conv_carry_out = None
|
||||||
for i_level in reversed(range(self.num_resolutions)):
|
for i_level in reversed(range(self.num_resolutions)):
|
||||||
for i_block in range(self.num_res_blocks+1):
|
for i_block in range(self.num_res_blocks+1):
|
||||||
h = self.up[i_level].block[i_block](h, temb, **kwargs)
|
h1 = self.up[i_level].block[i_block](h1, temb, conv_carry_in, conv_carry_out, **kwargs)
|
||||||
if len(self.up[i_level].attn) > 0:
|
if len(self.up[i_level].attn) > 0:
|
||||||
h = self.up[i_level].attn[i_block](h, **kwargs)
|
assert i == 0 #carried should not happen if attn exists
|
||||||
|
h1 = self.up[i_level].attn[i_block](h1, **kwargs)
|
||||||
if i_level != 0:
|
if i_level != 0:
|
||||||
h = self.up[i_level].upsample(h)
|
h1 = self.up[i_level].upsample(h1, conv_carry_in, conv_carry_out)
|
||||||
|
|
||||||
# end
|
h1 = self.norm_out(h1)
|
||||||
if self.give_pre_end:
|
h1 = [ nonlinearity(h1) ]
|
||||||
return h
|
h1 = conv_carry_causal_3d(h1, self.conv_out, conv_carry_in, conv_carry_out)
|
||||||
|
|
||||||
h = self.norm_out(h)
|
|
||||||
h = nonlinearity(h)
|
|
||||||
h = self.conv_out(h, **kwargs)
|
|
||||||
if self.tanh_out:
|
if self.tanh_out:
|
||||||
h = torch.tanh(h)
|
h1 = torch.tanh(h1)
|
||||||
return h
|
out.append(h1)
|
||||||
|
conv_carry_in = conv_carry_out
|
||||||
|
|
||||||
|
out = torch_cat_if_needed(out, dim=2)
|
||||||
|
|
||||||
|
return out
|
||||||
|
|||||||
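The CarriedConv3d path above lets the decoder run over long videos in temporal chunks (torch.split(h, 2, dim=2)) while carrying convolution state from one chunk into the next, instead of holding the whole sequence in memory. A minimal sketch of the carry idea with a toy 1D causal convolution (an illustration, not ComfyUI's conv_carry_causal_3d helper):

    import torch
    import torch.nn.functional as F

    def causal_conv1d_chunked(x, weight, chunk=2):
        # x: [B, C, T]; weight: [C_out, C_in, K]; zeros act as the initial carry
        k = weight.shape[-1]
        carry = x.new_zeros(x.shape[0], x.shape[1], k - 1)
        outs = []
        for t in range(0, x.shape[-1], chunk):
            piece = torch.cat([carry, x[:, :, t:t + chunk]], dim=-1)
            outs.append(F.conv1d(piece, weight))
            carry = piece[:, :, -(k - 1):]  # tail of this chunk seeds the next
        return torch.cat(outs, dim=-1)

    x, w = torch.randn(1, 4, 8), torch.randn(4, 4, 3)
    full = F.conv1d(F.pad(x, (2, 0)), w)  # single-pass causal conv
    assert torch.allclose(causal_conv1d_chunked(x, w), full, atol=1e-5)

Because each chunk only needs the trailing k-1 frames of the previous one, chunked decoding matches the single-pass result exactly; that is also why the assert in the attention branch above insists carried mode never coexists with attention, which would need the full sequence.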
@@ -208,12 +208,12 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
         dit_config["theta"] = 2000
         dit_config["out_channels"] = 128
         dit_config["global_modulation"] = True
-        dit_config["vec_in_dim"] = None
         dit_config["mlp_silu_act"] = True
         dit_config["qkv_bias"] = False
         dit_config["ops_bias"] = False
         dit_config["default_ref_method"] = "index"
         dit_config["ref_index_scale"] = 10.0
+        dit_config["txt_ids_dims"] = [3]
         patch_size = 1
     else:
         dit_config["image_model"] = "flux"
@@ -223,6 +223,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
         dit_config["theta"] = 10000
         dit_config["out_channels"] = 16
         dit_config["qkv_bias"] = True
+        dit_config["txt_ids_dims"] = []
         patch_size = 2

     dit_config["in_channels"] = 16
@@ -245,6 +246,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
     vec_in_key = '{}vector_in.in_layer.weight'.format(key_prefix)
     if vec_in_key in state_dict_keys:
         dit_config["vec_in_dim"] = state_dict[vec_in_key].shape[1]
+    else:
+        dit_config["vec_in_dim"] = None

     dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.')
     dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.')
@@ -270,6 +273,11 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
         dit_config["nerf_embedder_dtype"] = torch.float32
     else:
         dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
+        dit_config["yak_mlp"] = '{}double_blocks.0.img_mlp.gate_proj.weight'.format(key_prefix) in state_dict_keys
+        dit_config["txt_norm"] = "{}txt_norm.scale".format(key_prefix) in state_dict_keys
+        if dit_config["yak_mlp"] and dit_config["txt_norm"]: # Ovis model
+            dit_config["txt_ids_dims"] = [1, 2]
+
     return dit_config

 if '{}t5_yproj.weight'.format(key_prefix) in state_dict_keys: #Genmo mochi preview
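detect_unet_config identifies architectures purely by probing which parameter names exist in the state dict, never by filename or metadata; the new Ovis branch is typical of the pattern. A toy illustration (keys shortened, key_prefix formatting omitted):

    state_dict_keys = {"double_blocks.0.img_mlp.gate_proj.weight", "txt_norm.scale"}

    yak_mlp = "double_blocks.0.img_mlp.gate_proj.weight" in state_dict_keys
    txt_norm = "txt_norm.scale" in state_dict_keys
    if yak_mlp and txt_norm:  # Ovis-style checkpoint
        txt_ids_dims = [1, 2]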
@@ -699,12 +699,12 @@ class ModelPatcher:
         offloaded = []
         offload_buffer = 0
         loading.sort(reverse=True)
-        for x in loading:
+        for i, x in enumerate(loading):
             module_offload_mem, module_mem, n, m, params = x

             lowvram_weight = False

-            potential_offload = max(offload_buffer, module_offload_mem * (comfy.model_management.NUM_STREAMS + 1))
+            potential_offload = max(offload_buffer, module_offload_mem + sum([ x1[1] for x1 in loading[i+1:i+1+comfy.model_management.NUM_STREAMS]]))
             lowvram_fits = mem_counter + module_mem + potential_offload < lowvram_model_memory

             weight_key = "{}.weight".format(n)
@@ -876,14 +876,18 @@ class ModelPatcher:
         patch_counter = 0
         unload_list = self._load_list()
         unload_list.sort()

         offload_buffer = self.model.model_offload_buffer_memory
+        if len(unload_list) > 0:
+            NS = comfy.model_management.NUM_STREAMS
+            offload_weight_factor = [ min(offload_buffer / (NS + 1), unload_list[0][1]) ] * NS
+
         for unload in unload_list:
             if memory_to_free + offload_buffer - self.model.model_offload_buffer_memory < memory_freed:
                 break
             module_offload_mem, module_mem, n, m, params = unload

-            potential_offload = (comfy.model_management.NUM_STREAMS + 1) * module_offload_mem
+            potential_offload = module_offload_mem + sum(offload_weight_factor)

             lowvram_possible = hasattr(m, "comfy_cast_weights")
             if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True:
@@ -935,6 +939,8 @@ class ModelPatcher:
                 m.comfy_patched_weights = False
                 memory_freed += module_mem
                 offload_buffer = max(offload_buffer, potential_offload)
+                offload_weight_factor.append(module_mem)
+                offload_weight_factor.pop(0)
                 logging.debug("freed {}".format(n))

         for param in params:
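Both ModelPatcher hunks replace a worst-case offload estimate, module_offload_mem * (NUM_STREAMS + 1), with one sized from the modules that can actually be in flight at once: the next NUM_STREAMS entries of the sorted load list, or a sliding window of recently freed module sizes on the unload path. A hedged sketch with made-up sizes showing how the estimate tightens:

    NUM_STREAMS = 2
    # (module_offload_mem, module_mem), already sorted largest-first
    loading = [(100, 100), (60, 60), (30, 30), (10, 10)]

    for i, (module_offload_mem, module_mem) in enumerate(loading):
        worst_case = module_offload_mem * (NUM_STREAMS + 1)  # old estimate
        windowed = module_offload_mem + sum(m for _, m in loading[i + 1:i + 1 + NUM_STREAMS])  # new
        print(i, worst_case, windowed)  # module 0: 300 vs 190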
27 comfy/ops.py
@@ -111,22 +111,24 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
     if s.bias is not None:
         bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream)

-    if bias_has_function:
-        with wf_context:
+    comfy.model_management.sync_stream(device, offload_stream)
+
+    bias_a = bias
+    weight_a = weight
+
+    if s.bias is not None:
         for f in s.bias_function:
             bias = f(bias)

     if weight_has_function or weight.dtype != dtype:
-        with wf_context:
         weight = weight.to(dtype=dtype)
         if isinstance(weight, QuantizedTensor):
             weight = weight.dequantize()
         for f in s.weight_function:
             weight = f(weight)

-    comfy.model_management.sync_stream(device, offload_stream)
     if offloadable:
-        return weight, bias, offload_stream
+        return weight, bias, (offload_stream, weight_a, bias_a)
     else:
         #Legacy function signature
         return weight, bias
@@ -135,13 +137,16 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
 def uncast_bias_weight(s, weight, bias, offload_stream):
     if offload_stream is None:
         return
-    if weight is not None:
-        device = weight.device
-    else:
-        if bias is None:
-            return
-        device = bias.device
-    offload_stream.wait_stream(comfy.model_management.current_stream(device))
+    os, weight_a, bias_a = offload_stream
+    if os is None:
+        return
+    if weight_a is not None:
+        device = weight_a.device
+    else:
+        if bias_a is None:
+            return
+        device = bias_a.device
+    os.wait_stream(comfy.model_management.current_stream(device))


 class CastWeightBiasOp:
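cast_bias_weight now returns (offload_stream, weight_a, bias_a) instead of the bare stream: the pre-weight_function tensors must stay referenced until the offload stream has been ordered after the compute stream, otherwise their storage could be recycled while asynchronous work still reads from it. A hedged sketch of the consuming side, mirroring the new uncast_bias_weight (the wrapper name is made up):

    import comfy.model_management

    def finish_cast(handle):
        # handle is the third return value of cast_bias_weight when offloadable=True
        if handle is None:
            return
        os, weight_a, bias_a = handle
        if os is None:
            return
        anchor = weight_a if weight_a is not None else bias_a
        if anchor is None:
            return
        # order the offload stream after compute so weight_a/bias_a outlive
        # any in-flight work that still reads from them
        os.wait_stream(comfy.model_management.current_stream(anchor.device))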
23 comfy/sd.py
@@ -53,6 +53,7 @@ import comfy.text_encoders.omnigen2
 import comfy.text_encoders.qwen_image
 import comfy.text_encoders.hunyuan_image
 import comfy.text_encoders.z_image
+import comfy.text_encoders.ovis

 import comfy.model_patcher
 import comfy.lora
@@ -192,6 +193,7 @@ class CLIP:
             self.cond_stage_model.set_clip_options({"projected_pooled": False})

         self.load_model()
+        self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
         all_hooks.reset()
         self.patcher.patch_hooks(None)
         if show_pbar:
@@ -239,6 +241,7 @@ class CLIP:
             self.cond_stage_model.set_clip_options({"projected_pooled": False})

         self.load_model()
+        self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
         o = self.cond_stage_model.encode_token_weights(tokens)
         cond, pooled = o[:2]
         if return_dict:
@@ -468,7 +471,7 @@ class VAE:
                                             decoder_config={'target': "comfy.ldm.hunyuan_video.vae_refiner.Decoder", 'params': ddconfig})

                 self.memory_used_encode = lambda shape, dtype: (1400 * 9 * shape[-2] * shape[-1]) * model_management.dtype_size(dtype)
-                self.memory_used_decode = lambda shape, dtype: (2800 * 4 * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype)
+                self.memory_used_decode = lambda shape, dtype: (3600 * 4 * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype)
             elif "decoder.conv_in.conv.weight" in sd:
                 ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
                 ddconfig["conv3d"] = True
@@ -480,8 +483,10 @@ class VAE:
                 self.latent_dim = 3
                 self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1]
                 self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1])
-                self.memory_used_decode = lambda shape, dtype: (1500 * shape[2] * shape[3] * shape[4] * (4 * 8 * 8)) * model_management.dtype_size(dtype)
-                self.memory_used_encode = lambda shape, dtype: (900 * max(shape[2], 2) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
+                #This is likely to significantly over-estimate with single image or low frame counts as the
+                #implementation is able to completely skip caching. Rework if used as an image only VAE
+                self.memory_used_decode = lambda shape, dtype: (2800 * min(8, ((shape[2] - 1) * 4) + 1) * shape[3] * shape[4] * (8 * 8)) * model_management.dtype_size(dtype)
+                self.memory_used_encode = lambda shape, dtype: (1400 * min(9, shape[2]) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
                 self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
             elif "decoder.unpatcher3d.wavelets" in sd:
                 self.upscale_ratio = (lambda a: max(0, a * 8 - 7), 8, 8)
@@ -956,6 +961,7 @@ class CLIPType(Enum):
     QWEN_IMAGE = 18
     HUNYUAN_IMAGE = 19
     HUNYUAN_VIDEO_15 = 20
+    OVIS = 21


 def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
@@ -987,6 +993,7 @@ class TEModel(Enum):
     MISTRAL3_24B = 14
     MISTRAL3_24B_PRUNED_FLUX2 = 15
     QWEN3_4B = 16
+    QWEN3_2B = 17


 def detect_te_model(sd):
@@ -1020,9 +1027,12 @@ def detect_te_model(sd):
         if weight.shape[0] == 512:
             return TEModel.QWEN25_7B
     if "model.layers.0.post_attention_layernorm.weight" in sd:
-        if 'model.layers.0.self_attn.q_norm.weight' in sd:
-            return TEModel.QWEN3_4B
         weight = sd['model.layers.0.post_attention_layernorm.weight']
+        if 'model.layers.0.self_attn.q_norm.weight' in sd:
+            if weight.shape[0] == 2560:
+                return TEModel.QWEN3_4B
+            elif weight.shape[0] == 2048:
+                return TEModel.QWEN3_2B
         if weight.shape[0] == 5120:
             if "model.layers.39.post_attention_layernorm.weight" in sd:
                 return TEModel.MISTRAL3_24B
@@ -1150,6 +1160,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
     elif te_model == TEModel.QWEN3_4B:
         clip_target.clip = comfy.text_encoders.z_image.te(**llama_detect(clip_data))
         clip_target.tokenizer = comfy.text_encoders.z_image.ZImageTokenizer
+    elif te_model == TEModel.QWEN3_2B:
+        clip_target.clip = comfy.text_encoders.ovis.te(**llama_detect(clip_data))
+        clip_target.tokenizer = comfy.text_encoders.ovis.OvisTokenizer
     else:
         # clip_l
         if clip_type == CLIPType.SD3:
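Since both Qwen3 sizes ship q_norm weights, presence of that key alone no longer identifies the 4B model; the hidden size of the first post-attention layernorm (2560 vs 2048) is what now separates QWEN3_4B from QWEN3_2B. A small sketch with stand-in tensors:

    import torch

    sd = {  # shaped like a Qwen3-2B text encoder
        "model.layers.0.post_attention_layernorm.weight": torch.empty(2048),
        "model.layers.0.self_attn.q_norm.weight": torch.empty(128),
    }
    w = sd["model.layers.0.post_attention_layernorm.weight"]
    if "model.layers.0.self_attn.q_norm.weight" in sd:
        model = {2560: "QWEN3_4B", 2048: "QWEN3_2B"}.get(w.shape[0])
    assert model == "QWEN3_2B"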
@@ -147,6 +147,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         self.layer_norm_hidden_state = layer_norm_hidden_state
         self.return_projected_pooled = return_projected_pooled
         self.return_attention_masks = return_attention_masks
+        self.execution_device = None

         if layer == "hidden":
             assert layer_idx is not None
@@ -163,6 +164,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
     def set_clip_options(self, options):
         layer_idx = options.get("layer", self.layer_idx)
         self.return_projected_pooled = options.get("projected_pooled", self.return_projected_pooled)
+        self.execution_device = options.get("execution_device", self.execution_device)
         if isinstance(self.layer, list) or self.layer == "all":
             pass
         elif layer_idx is None or abs(layer_idx) > self.num_layers:
@@ -175,6 +177,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
             self.layer = self.options_default[0]
             self.layer_idx = self.options_default[1]
             self.return_projected_pooled = self.options_default[2]
+        self.execution_device = None

     def process_tokens(self, tokens, device):
         end_token = self.special_tokens.get("end", None)
@@ -258,7 +261,11 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         return torch.cat(embeds_out), torch.tensor(attention_masks, device=device, dtype=torch.long), num_tokens, embeds_info

     def forward(self, tokens):
+        if self.execution_device is None:
             device = self.transformer.get_input_embeddings().weight.device
+        else:
+            device = self.execution_device
+
         embeds, attention_mask, num_tokens, embeds_info = self.process_tokens(tokens, device)

         attention_mask_model = None
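With execution_device set, token embedding lookups run on the device the patcher will load weights onto rather than wherever the (possibly offloaded) input embeddings currently sit; the options-default reset path clears it back to None. Usage matches the comfy/sd.py hunks above (clip being a comfy.sd.CLIP instance):

    clip.cond_stage_model.set_clip_options({"execution_device": clip.patcher.load_device})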
@@ -1027,6 +1027,8 @@ class ZImage(Lumina2):

     memory_usage_factor = 1.7

+    supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
+
     def clip_target(self, state_dict={}):
         pref = self.text_encoder_key_prefix[0]
         hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_4b.transformer.".format(pref))
@@ -100,6 +100,28 @@ class Qwen3_4BConfig:
     rope_scale = None
     final_norm: bool = True

+@dataclass
+class Ovis25_2BConfig:
+    vocab_size: int = 151936
+    hidden_size: int = 2048
+    intermediate_size: int = 6144
+    num_hidden_layers: int = 28
+    num_attention_heads: int = 16
+    num_key_value_heads: int = 8
+    max_position_embeddings: int = 40960
+    rms_norm_eps: float = 1e-6
+    rope_theta: float = 1000000.0
+    transformer_type: str = "llama"
+    head_dim = 128
+    rms_norm_add = False
+    mlp_activation = "silu"
+    qkv_bias = False
+    rope_dims = None
+    q_norm = "gemma3"
+    k_norm = "gemma3"
+    rope_scale = None
+    final_norm: bool = True
+
 @dataclass
 class Qwen25_7BVLI_Config:
     vocab_size: int = 152064
@@ -542,6 +564,15 @@ class Qwen3_4B(BaseLlama, torch.nn.Module):
         self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
         self.dtype = dtype

+class Ovis25_2B(BaseLlama, torch.nn.Module):
+    def __init__(self, config_dict, dtype, device, operations):
+        super().__init__()
+        config = Ovis25_2BConfig(**config_dict)
+        self.num_layers = config.num_hidden_layers
+
+        self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
+        self.dtype = dtype
+
 class Qwen25_7BVLI(BaseLlama, torch.nn.Module):
     def __init__(self, config_dict, dtype, device, operations):
         super().__init__()
69 comfy/text_encoders/ovis.py (new file)
@@ -0,0 +1,69 @@
+from transformers import Qwen2Tokenizer
+
+import comfy.text_encoders.llama
+from comfy import sd1_clip
+import os
+import torch
+import numbers
+
+
+class Qwen3Tokenizer(sd1_clip.SDTokenizer):
+    def __init__(self, embedding_directory=None, tokenizer_data={}):
+        tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer")
+        super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2048, embedding_key='qwen3_2b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=284, pad_token=151643, tokenizer_data=tokenizer_data)
+
+
+class OvisTokenizer(sd1_clip.SD1Tokenizer):
+    def __init__(self, embedding_directory=None, tokenizer_data={}):
+        super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="qwen3_2b", tokenizer=Qwen3Tokenizer)
+        self.llama_template = "<|im_start|>user\nDescribe the image by detailing the color, quantity, text, shape, size, texture, spatial relationships of the objects and background: {}<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
+
+    def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, **kwargs):
+        if llama_template is None:
+            llama_text = self.llama_template.format(text)
+        else:
+            llama_text = llama_template.format(text)
+
+        tokens = super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs)
+        return tokens
+
+class Ovis25_2BModel(sd1_clip.SDClipModel):
+    def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
+        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Ovis25_2B, enable_attention_masks=attention_mask, return_attention_masks=False, zero_out_masked=True, model_options=model_options)
+
+
+class OvisTEModel(sd1_clip.SD1ClipModel):
+    def __init__(self, device="cpu", dtype=None, model_options={}):
+        super().__init__(device=device, dtype=dtype, name="qwen3_2b", clip_model=Ovis25_2BModel, model_options=model_options)
+
+    def encode_token_weights(self, token_weight_pairs, template_end=-1):
+        out, pooled = super().encode_token_weights(token_weight_pairs)
+        tok_pairs = token_weight_pairs["qwen3_2b"][0]
+        count_im_start = 0
+        if template_end == -1:
+            for i, v in enumerate(tok_pairs):
+                elem = v[0]
+                if not torch.is_tensor(elem):
+                    if isinstance(elem, numbers.Integral):
+                        if elem == 4004 and count_im_start < 1:
+                            template_end = i
+                            count_im_start += 1
+
+        if out.shape[1] > (template_end + 1):
+            if tok_pairs[template_end + 1][0] == 25:
+                template_end += 1
+
+        out = out[:, template_end:]
+        return out, pooled, {}
+
+
+def te(dtype_llama=None, llama_scaled_fp8=None, llama_quantization_metadata=None):
+    class OvisTEModel_(OvisTEModel):
+        def __init__(self, device="cpu", dtype=None, model_options={}):
+            if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
+                model_options = model_options.copy()
+                model_options["scaled_fp8"] = llama_scaled_fp8
+            if dtype_llama is not None:
+                dtype = dtype_llama
+            if llama_quantization_metadata is not None:
+                model_options["quantization_metadata"] = llama_quantization_metadata
+            super().__init__(device=device, dtype=dtype, model_options=model_options)
+    return OvisTEModel_
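A hedged sketch of how these new Ovis pieces are likely wired together (checkpoint loading elided; te() returns a class so checkpoint-detected dtype/fp8 options can be baked in before instantiation):

    import comfy.text_encoders.ovis as ovis

    tokenizer = ovis.OvisTokenizer()
    tokens = tokenizer.tokenize_with_weights("a red cube on a glass table")

    OvisTE = ovis.te(dtype_llama=None, llama_scaled_fp8=None)
    model = OvisTE(device="cpu")  # real weights are then loaded from the state dict

encode_token_weights then slices the hidden states at the template boundary (token 4004, optionally plus a following token 25) so only the prompt's own conditioning is returned.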
@@ -179,36 +179,36 @@
       "special": false
     },
     "151665": {
-      "content": "<|img|>",
+      "content": "<tool_response>",
       "lstrip": false,
       "normalized": false,
       "rstrip": false,
       "single_word": false,
-      "special": true
+      "special": false
     },
     "151666": {
-      "content": "<|endofimg|>",
+      "content": "</tool_response>",
       "lstrip": false,
       "normalized": false,
       "rstrip": false,
       "single_word": false,
-      "special": true
+      "special": false
     },
     "151667": {
-      "content": "<|meta|>",
+      "content": "<think>",
       "lstrip": false,
       "normalized": false,
       "rstrip": false,
       "single_word": false,
-      "special": true
+      "special": false
     },
     "151668": {
-      "content": "<|endofmeta|>",
+      "content": "</think>",
       "lstrip": false,
       "normalized": false,
       "rstrip": false,
       "single_word": false,
-      "special": true
+      "special": false
     }
   },
   "additional_special_tokens": [
@@ -13,6 +13,7 @@ from comfy.cli_args import args
 SERVER_FEATURE_FLAGS: Dict[str, Any] = {
     "supports_preview_metadata": True,
     "max_upload_size": args.max_upload_size * 1024 * 1024,  # Convert MB to bytes
+    "extension": {"manager": {"supports_v4": True}},
 }
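Feature flags are exchanged with the frontend when it connects, so a client or manager extension can probe the new nested flag defensively:

    def supports_manager_v4(flags: dict) -> bool:
        return bool(flags.get("extension", {}).get("manager", {}).get("supports_v4", False))

    assert supports_manager_v4({"extension": {"manager": {"supports_v4": True}}})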
@@ -8,8 +8,8 @@ from comfy_api.internal.async_to_sync import create_sync_class
 from comfy_api.latest._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput
 from comfy_api.latest._input_impl import VideoFromFile, VideoFromComponents
 from comfy_api.latest._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL
-from . import _io as io
-from . import _ui as ui
+from . import _io_public as io
+from . import _ui_public as ui
 # from comfy_api.latest._resources import _RESOURCES as resources #noqa: F401
 from comfy_execution.utils import get_executing_context
 from comfy_execution.progress import get_progress_state, PreviewImageTuple
@@ -336,7 +336,10 @@ class VideoFromComponents(VideoInput):
             raise ValueError("Only MP4 format is supported for now")
         if codec != VideoCodec.AUTO and codec != VideoCodec.H264:
             raise ValueError("Only H264 codec is supported for now")
-        with av.open(path, mode='w', options={'movflags': 'use_metadata_tags'}) as output:
+        extra_kwargs = {}
+        if isinstance(format, VideoContainer) and format != VideoContainer.AUTO:
+            extra_kwargs["format"] = format.value
+        with av.open(path, mode='w', options={'movflags': 'use_metadata_tags'}, **extra_kwargs) as output:
             # Add metadata before writing any streams
             if metadata is not None:
                 for key, value in metadata.items():
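Passing the container explicitly lets PyAV write to paths whose suffix it cannot map to a format, while VideoContainer.AUTO keeps the old extension-based inference. A hedged sketch of the resulting call (the .value string is assumed to be a PyAV format name such as "mp4"):

    import av

    extra_kwargs = {"format": "mp4"}  # e.g. VideoContainer.MP4.value
    with av.open("/tmp/comfy_video_out", mode="w",
                 options={"movflags": "use_metadata_tags"}, **extra_kwargs) as output:
        pass  # metadata and streams are added here in the real method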
@@ -4,7 +4,8 @@ import copy
 import inspect
 from abc import ABC, abstractmethod
 from collections import Counter
-from dataclasses import asdict, dataclass
+from collections.abc import Iterable
+from dataclasses import asdict, dataclass, field
 from enum import Enum
 from typing import Any, Callable, Literal, TypedDict, TypeVar, TYPE_CHECKING
 from typing_extensions import NotRequired, final
@@ -150,6 +151,9 @@ class _IO_V3:
     def __init__(self):
         pass

+    def validate(self):
+        pass
+
     @property
     def io_type(self):
         return self.Parent.io_type
@@ -182,6 +186,9 @@ class Input(_IO_V3):
     def get_io_type(self):
         return _StringIOType(self.io_type)

+    def get_all(self) -> list[Input]:
+        return [self]
+
 class WidgetInput(Input):
     '''
     Base class for a V3 Input with widget.
@@ -814,13 +821,61 @@ class MultiType:
         else:
             return super().as_dict()

+@comfytype(io_type="COMFY_MATCHTYPE_V3")
+class MatchType(ComfyTypeIO):
+    class Template:
+        def __init__(self, template_id: str, allowed_types: _ComfyType | list[_ComfyType] = AnyType):
+            self.template_id = template_id
+            # account for syntactic sugar
+            if not isinstance(allowed_types, Iterable):
+                allowed_types = [allowed_types]
+            for t in allowed_types:
+                if not isinstance(t, type):
+                    if not isinstance(t, _ComfyType):
+                        raise ValueError(f"Allowed types must be a ComfyType or a list of ComfyTypes, got {t.__class__.__name__}")
+                else:
+                    if not issubclass(t, _ComfyType):
+                        raise ValueError(f"Allowed types must be a ComfyType or a list of ComfyTypes, got {t.__name__}")
+            self.allowed_types = allowed_types
+
+        def as_dict(self):
+            return {
+                "template_id": self.template_id,
+                "allowed_types": ",".join([t.io_type for t in self.allowed_types]),
+            }
+
+    class Input(Input):
+        def __init__(self, id: str, template: MatchType.Template,
+                     display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None):
+            super().__init__(id, display_name, optional, tooltip, lazy, extra_dict)
+            self.template = template
+
+        def as_dict(self):
+            return super().as_dict() | prune_dict({
+                "template": self.template.as_dict(),
+            })
+
+    class Output(Output):
+        def __init__(self, template: MatchType.Template, id: str=None, display_name: str=None, tooltip: str=None,
+                     is_output_list=False):
+            super().__init__(id, display_name, tooltip, is_output_list)
+            self.template = template
+
+        def as_dict(self):
+            return super().as_dict() | prune_dict({
+                "template": self.template.as_dict(),
+            })
+
 class DynamicInput(Input, ABC):
     '''
     Abstract class for dynamic input registration.
     '''
-    @abstractmethod
     def get_dynamic(self) -> list[Input]:
-        ...
+        return []
+
+    def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''):
+        pass
+

 class DynamicOutput(Output, ABC):
     '''
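A hedged sketch of what declaring a passthrough node with the new MatchType could look like: one template id ties an input and an output together so the frontend can resolve both to whatever type is actually connected (Schema fields beyond the essentials omitted):

    T = io.MatchType.Template("T")  # allowed_types defaults to AnyType

    schema = io.Schema(
        node_id="PassthroughExample",
        inputs=[io.MatchType.Input("value", template=T)],
        outputs=[io.MatchType.Output(template=T)],
    )

Serialized, the template becomes {"template_id": "T", "allowed_types": "..."} with the allowed io_types joined by commas, as in Template.as_dict above.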
@@ -830,99 +885,223 @@ class DynamicOutput(Output, ABC):
                  is_output_list=False):
         super().__init__(id, display_name, tooltip, is_output_list)

-    @abstractmethod
     def get_dynamic(self) -> list[Output]:
-        ...
+        return []


 @comfytype(io_type="COMFY_AUTOGROW_V3")
-class AutogrowDynamic(ComfyTypeI):
-    Type = list[Any]
-    class Input(DynamicInput):
-        def __init__(self, id: str, template_input: Input, min: int=1, max: int=None,
-                     display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None):
-            super().__init__(id, display_name, optional, tooltip, lazy, extra_dict)
-            self.template_input = template_input
-            if min is not None:
-                assert(min >= 1)
-            if max is not None:
-                assert(max >= 1)
-            self.min = min
-            self.max = max
-
-        def get_dynamic(self) -> list[Input]:
-            curr_count = 1
-            new_inputs = []
-            for i in range(self.min):
-                new_input = copy.copy(self.template_input)
-                new_input.id = f"{new_input.id}{curr_count}_${self.id}_ag$"
-                if new_input.display_name is not None:
-                    new_input.display_name = f"{new_input.display_name}{curr_count}"
-                new_input.optional = self.optional or new_input.optional
-                if isinstance(self.template_input, WidgetInput):
-                    new_input.force_input = True
-                new_inputs.append(new_input)
-                curr_count += 1
-            # pretend to expand up to max
-            for i in range(curr_count-1, self.max):
-                new_input = copy.copy(self.template_input)
-                new_input.id = f"{new_input.id}{curr_count}_${self.id}_ag$"
-                if new_input.display_name is not None:
-                    new_input.display_name = f"{new_input.display_name}{curr_count}"
-                new_input.optional = True
-                if isinstance(self.template_input, WidgetInput):
-                    new_input.force_input = True
-                new_inputs.append(new_input)
-                curr_count += 1
-            return new_inputs
-
-@comfytype(io_type="COMFY_COMBODYNAMIC_V3")
-class ComboDynamic(ComfyTypeI):
-    class Input(DynamicInput):
-        def __init__(self, id: str):
-            pass
-
-@comfytype(io_type="COMFY_MATCHTYPE_V3")
-class MatchType(ComfyTypeIO):
-    class Template:
-        def __init__(self, template_id: str, allowed_types: _ComfyType | list[_ComfyType]):
-            self.template_id = template_id
-            self.allowed_types = [allowed_types] if isinstance(allowed_types, _ComfyType) else allowed_types
-
-        def as_dict(self):
-            return {
-                "template_id": self.template_id,
-                "allowed_types": "".join(t.io_type for t in self.allowed_types),
-            }
-
-    class Input(DynamicInput):
-        def __init__(self, id: str, template: MatchType.Template,
-                     display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None):
-            super().__init__(id, display_name, optional, tooltip, lazy, extra_dict)
-            self.template = template
-
-        def get_dynamic(self) -> list[Input]:
-            return [self]
-
-        def as_dict(self):
-            return super().as_dict() | prune_dict({
-                "template": self.template.as_dict(),
-            })
-
-    class Output(DynamicOutput):
-        def __init__(self, id: str, template: MatchType.Template, display_name: str=None, tooltip: str=None,
-                     is_output_list=False):
-            super().__init__(id, display_name, tooltip, is_output_list)
-            self.template = template
-
-        def get_dynamic(self) -> list[Output]:
-            return [self]
-
-        def as_dict(self):
-            return super().as_dict() | prune_dict({
-                "template": self.template.as_dict(),
-            })
+class Autogrow(ComfyTypeI):
+    Type = dict[str, Any]
+    _MaxNames = 100 # NOTE: max 100 names for sanity
+
+    class _AutogrowTemplate:
+        def __init__(self, input: Input):
+            # dynamic inputs are not allowed as the template input
+            assert(not isinstance(input, DynamicInput))
+            self.input = copy.copy(input)
+            if isinstance(self.input, WidgetInput):
+                self.input.force_input = True
+            self.names: list[str] = []
+            self.cached_inputs = {}
+
+        def _create_input(self, input: Input, name: str):
+            new_input = copy.copy(self.input)
+            new_input.id = name
+            return new_input
+
+        def _create_cached_inputs(self):
+            for name in self.names:
+                self.cached_inputs[name] = self._create_input(self.input, name)
+
+        def get_all(self) -> list[Input]:
+            return list(self.cached_inputs.values())
+
+        def as_dict(self):
+            return prune_dict({
+                "input": create_input_dict_v1([self.input]),
+            })
+
+        def validate(self):
+            self.input.validate()
+
+        def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''):
+            real_inputs = []
+            for name, input in self.cached_inputs.items():
+                if name in live_inputs:
+                    real_inputs.append(input)
+            add_to_input_dict_v1(d, real_inputs, live_inputs, curr_prefix)
+            add_dynamic_id_mapping(d, real_inputs, curr_prefix)
+
+    class TemplatePrefix(_AutogrowTemplate):
+        def __init__(self, input: Input, prefix: str, min: int=1, max: int=10):
+            super().__init__(input)
+            self.prefix = prefix
+            assert(min >= 0)
+            assert(max >= 1)
+            assert(max <= Autogrow._MaxNames)
+            self.min = min
+            self.max = max
+            self.names = [f"{self.prefix}{i}" for i in range(self.max)]
+            self._create_cached_inputs()
+
+        def as_dict(self):
+            return super().as_dict() | prune_dict({
+                "prefix": self.prefix,
+                "min": self.min,
+                "max": self.max,
+            })
+
+    class TemplateNames(_AutogrowTemplate):
+        def __init__(self, input: Input, names: list[str], min: int=1):
+            super().__init__(input)
+            self.names = names[:Autogrow._MaxNames]
+            assert(min >= 0)
+            self.min = min
+            self._create_cached_inputs()
+
+        def as_dict(self):
+            return super().as_dict() | prune_dict({
+                "names": self.names,
+                "min": self.min,
+            })
+
+    class Input(DynamicInput):
+        def __init__(self, id: str, template: Autogrow.TemplatePrefix | Autogrow.TemplateNames,
+                     display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None):
+            super().__init__(id, display_name, optional, tooltip, lazy, extra_dict)
+            self.template = template
+
+        def as_dict(self):
+            return super().as_dict() | prune_dict({
+                "template": self.template.as_dict(),
+            })
+
+        def get_dynamic(self) -> list[Input]:
+            return self.template.get_all()
+
+        def get_all(self) -> list[Input]:
+            return [self] + self.template.get_all()
+
+        def validate(self):
+            self.template.validate()
+
+        def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''):
+            curr_prefix = f"{curr_prefix}{self.id}."
+            # need to remove self from expected inputs dictionary; replaced by template inputs in frontend
+            for inner_dict in d.values():
+                if self.id in inner_dict:
+                    del inner_dict[self.id]
+            self.template.expand_schema_for_dynamic(d, live_inputs, curr_prefix)
+
+@comfytype(io_type="COMFY_DYNAMICCOMBO_V3")
+class DynamicCombo(ComfyTypeI):
+    Type = dict[str, Any]
+
+    class Option:
+        def __init__(self, key: str, inputs: list[Input]):
+            self.key = key
+            self.inputs = inputs
+
+        def as_dict(self):
+            return {
+                "key": self.key,
+                "inputs": create_input_dict_v1(self.inputs),
+            }
+
+    class Input(DynamicInput):
+        def __init__(self, id: str, options: list[DynamicCombo.Option],
+                     display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None):
+            super().__init__(id, display_name, optional, tooltip, lazy, extra_dict)
+            self.options = options
+
+        def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''):
+            # check if dynamic input's id is in live_inputs
+            if self.id in live_inputs:
+                curr_prefix = f"{curr_prefix}{self.id}."
+                key = live_inputs[self.id]
+                selected_option = None
+                for option in self.options:
+                    if option.key == key:
+                        selected_option = option
+                        break
+                if selected_option is not None:
+                    add_to_input_dict_v1(d, selected_option.inputs, live_inputs, curr_prefix)
+                    add_dynamic_id_mapping(d, selected_option.inputs, curr_prefix, self)
+
+        def get_dynamic(self) -> list[Input]:
+            return [input for option in self.options for input in option.inputs]
+
+        def get_all(self) -> list[Input]:
+            return [self] + [input for option in self.options for input in option.inputs]
+
+        def as_dict(self):
+            return super().as_dict() | prune_dict({
+                "options": [o.as_dict() for o in self.options],
+            })
+
+        def validate(self):
+            # make sure all nested inputs are validated
+            for option in self.options:
+                for input in option.inputs:
+                    input.validate()
+
+@comfytype(io_type="COMFY_DYNAMICSLOT_V3")
+class DynamicSlot(ComfyTypeI):
+    Type = dict[str, Any]
+
+    class Input(DynamicInput):
+        def __init__(self, slot: Input, inputs: list[Input],
+                     display_name: str=None, tooltip: str=None, lazy: bool=None, extra_dict=None):
+            assert(not isinstance(slot, DynamicInput))
+            self.slot = copy.copy(slot)
+            self.slot.display_name = slot.display_name if slot.display_name is not None else display_name
+            optional = True
+            self.slot.tooltip = slot.tooltip if slot.tooltip is not None else tooltip
+            self.slot.lazy = slot.lazy if slot.lazy is not None else lazy
+            self.slot.extra_dict = slot.extra_dict if slot.extra_dict is not None else extra_dict
+            super().__init__(slot.id, self.slot.display_name, optional, self.slot.tooltip, self.slot.lazy, self.slot.extra_dict)
+            self.inputs = inputs
+            self.force_input = None
+            # force widget inputs to have no widgets, otherwise this would be awkward
+            if isinstance(self.slot, WidgetInput):
+                self.force_input = True
+                self.slot.force_input = True
+
+        def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''):
+            if self.id in live_inputs:
+                curr_prefix = f"{curr_prefix}{self.id}."
+                add_to_input_dict_v1(d, self.inputs, live_inputs, curr_prefix)
+                add_dynamic_id_mapping(d, [self.slot] + self.inputs, curr_prefix)
+
+        def get_dynamic(self) -> list[Input]:
+            return [self.slot] + self.inputs
+
+        def get_all(self) -> list[Input]:
+            return [self] + [self.slot] + self.inputs
+
+        def as_dict(self):
+            return super().as_dict() | prune_dict({
+                "slotType": str(self.slot.get_io_type()),
+                "inputs": create_input_dict_v1(self.inputs),
+                "forceInput": self.force_input,
+            })
+
+        def validate(self):
+            self.slot.validate()
+            for input in self.inputs:
+                input.validate()
+
+def add_dynamic_id_mapping(d: dict[str, Any], inputs: list[Input], curr_prefix: str, self: DynamicInput=None):
+    dynamic = d.setdefault("dynamic_paths", {})
+    if self is not None:
+        dynamic[self.id] = f"{curr_prefix}{self.id}"
+    for i in inputs:
+        if not isinstance(i, DynamicInput):
+            dynamic[f"{i.id}"] = f"{curr_prefix}{i.id}"
+
+class V3Data(TypedDict):
+    hidden_inputs: dict[str, Any]
+    dynamic_paths: dict[str, Any]
+
 class HiddenHolder:
     def __init__(self, unique_id: str, prompt: Any,
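A hedged sketch of the two new dynamic containers (ids and inner types illustrative; note both names are still commented out of __all__ further down, so the API is not yet public): DynamicCombo swaps in a different input set per selected key, while Autogrow stamps a template input out under generated names:

    combo = io.DynamicCombo.Input("mode", options=[
        io.DynamicCombo.Option("scale", [io.Float.Input("factor")]),
        io.DynamicCombo.Option("resize", [io.Int.Input("width"), io.Int.Input("height")]),
    ])

    grow = io.Autogrow.Input(
        "images",
        template=io.Autogrow.TemplatePrefix(io.Image.Input("img"), prefix="img", max=4),
    )

At execute time the connected values arrive nested under the container's id, e.g. {"mode": {"mode": "scale", "factor": 2.0}}, via the dynamic_paths machinery shown below.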
@@ -984,6 +1163,7 @@ class NodeInfoV1:
     output_is_list: list[bool]=None
     output_name: list[str]=None
     output_tooltips: list[str]=None
+    output_matchtypes: list[str]=None
     name: str=None
     display_name: str=None
     description: str=None
@@ -1019,9 +1199,9 @@ class Schema:
     """Display name of node."""
     category: str = "sd"
     """The category of the node, as per the "Add Node" menu."""
-    inputs: list[Input]=None
-    outputs: list[Output]=None
-    hidden: list[Hidden]=None
+    inputs: list[Input] = field(default_factory=list)
+    outputs: list[Output] = field(default_factory=list)
+    hidden: list[Hidden] = field(default_factory=list)
     description: str=""
     """Node description, shown as a tooltip when hovering over the node."""
     is_input_list: bool = False
@@ -1061,7 +1241,11 @@ class Schema:
         '''Validate the schema:
         - verify ids on inputs and outputs are unique - both internally and in relation to each other
         '''
-        input_ids = [i.id for i in self.inputs] if self.inputs is not None else []
+        nested_inputs: list[Input] = []
+        if self.inputs is not None:
+            for input in self.inputs:
+                nested_inputs.extend(input.get_all())
+        input_ids = [i.id for i in nested_inputs] if nested_inputs is not None else []
         output_ids = [o.id for o in self.outputs] if self.outputs is not None else []
         input_set = set(input_ids)
         output_set = set(output_ids)
@@ -1077,6 +1261,13 @@ class Schema:
             issues.append(f"Ids must be unique between inputs and outputs, but {intersection} are not.")
         if len(issues) > 0:
             raise ValueError("\n".join(issues))
+        # validate inputs and outputs
+        if self.inputs is not None:
+            for input in self.inputs:
+                input.validate()
+        if self.outputs is not None:
+            for output in self.outputs:
+                output.validate()

     def finalize(self):
         """Add hidden based on selected schema options, and give outputs without ids default ids."""
@@ -1102,19 +1293,10 @@ class Schema:
             if output.id is None:
                 output.id = f"_{i}_{output.io_type}_"

-    def get_v1_info(self, cls) -> NodeInfoV1:
+    def get_v1_info(self, cls, live_inputs: dict[str, Any]=None) -> NodeInfoV1:
+        # NOTE: live_inputs will not be used anymore very soon and this will be done another way
         # get V1 inputs
-        input = {
-            "required": {}
-        }
-        if self.inputs:
-            for i in self.inputs:
-                if isinstance(i, DynamicInput):
-                    dynamic_inputs = i.get_dynamic()
-                    for d in dynamic_inputs:
-                        add_to_dict_v1(d, input)
-                else:
-                    add_to_dict_v1(i, input)
+        input = create_input_dict_v1(self.inputs, live_inputs)
         if self.hidden:
             for hidden in self.hidden:
                 input.setdefault("hidden", {})[hidden.name] = (hidden.value,)
@@ -1123,12 +1305,24 @@ class Schema:
         output_is_list = []
         output_name = []
         output_tooltips = []
+        output_matchtypes = []
+        any_matchtypes = False
         if self.outputs:
             for o in self.outputs:
                 output.append(o.io_type)
                 output_is_list.append(o.is_output_list)
                 output_name.append(o.display_name if o.display_name else o.io_type)
                 output_tooltips.append(o.tooltip if o.tooltip else None)
+                # special handling for MatchType
+                if isinstance(o, MatchType.Output):
+                    output_matchtypes.append(o.template.template_id)
+                    any_matchtypes = True
+                else:
+                    output_matchtypes.append(None)
+
+        # clear out lists that are all None
+        if not any_matchtypes:
+            output_matchtypes = None

         info = NodeInfoV1(
             input=input,
@@ -1137,6 +1331,7 @@ class Schema:
             output_is_list=output_is_list,
             output_name=output_name,
             output_tooltips=output_tooltips,
+            output_matchtypes=output_matchtypes,
             name=self.node_id,
             display_name=self.display_name,
             category=self.category,
@@ -1182,16 +1377,57 @@ class Schema:
         return info


-def add_to_dict_v1(i: Input, input: dict):
+def create_input_dict_v1(inputs: list[Input], live_inputs: dict[str, Any]=None) -> dict:
+    input = {
+        "required": {}
+    }
+    add_to_input_dict_v1(input, inputs, live_inputs)
+    return input
+
+def add_to_input_dict_v1(d: dict[str, Any], inputs: list[Input], live_inputs: dict[str, Any]=None, curr_prefix=''):
+    for i in inputs:
+        if isinstance(i, DynamicInput):
+            add_to_dict_v1(i, d)
+            if live_inputs is not None:
+                i.expand_schema_for_dynamic(d, live_inputs, curr_prefix)
+        else:
+            add_to_dict_v1(i, d)
+
+def add_to_dict_v1(i: Input, d: dict, dynamic_dict: dict=None):
     key = "optional" if i.optional else "required"
     as_dict = i.as_dict()
     # for v1, we don't want to include the optional key
     as_dict.pop("optional", None)
-    input.setdefault(key, {})[i.id] = (i.get_io_type(), as_dict)
+    if dynamic_dict is None:
+        value = (i.get_io_type(), as_dict)
+    else:
+        value = (i.get_io_type(), as_dict, dynamic_dict)
+    d.setdefault(key, {})[i.id] = value

 def add_to_dict_v3(io: Input | Output, d: dict):
     d[io.id] = (io.get_io_type(), io.as_dict())

+def build_nested_inputs(values: dict[str, Any], v3_data: V3Data):
+    paths = v3_data.get("dynamic_paths", None)
+    if paths is None:
+        return values
+    values = values.copy()
+    result = {}
+
+    for key, path in paths.items():
+        parts = path.split(".")
+        current = result
+
+        for i, p in enumerate(parts):
+            is_last = (i == len(parts) - 1)
+
+            if is_last:
+                current[p] = values.pop(key, None)
+            else:
+                current = current.setdefault(p, {})
+
+    values.update(result)
+    return values
+

 class _ComfyNodeBaseInternal(_ComfyNodeInternal):
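build_nested_inputs is the execution-side counterpart of add_dynamic_id_mapping: the prompt carries flat input values plus the recorded dynamic_paths, and the executor folds them back into the nesting the node declared. A worked example:

    values = {"mode": "scale", "factor": 2.0, "other": 1}
    v3_data = {"dynamic_paths": {"mode": "mode.mode", "factor": "mode.factor"}}

    print(build_nested_inputs(values, v3_data))
    # {"other": 1, "mode": {"mode": "scale", "factor": 2.0}}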
@@ -1311,12 +1547,12 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):

     @final
     @classmethod
-    def PREPARE_CLASS_CLONE(cls, hidden_inputs: dict) -> type[ComfyNode]:
+    def PREPARE_CLASS_CLONE(cls, v3_data: V3Data) -> type[ComfyNode]:
         """Creates clone of real node class to prevent monkey-patching."""
         c_type: type[ComfyNode] = cls if is_class(cls) else type(cls)
         type_clone: type[ComfyNode] = shallow_clone_class(c_type)
         # set hidden
-        type_clone.hidden = HiddenHolder.from_dict(hidden_inputs)
+        type_clone.hidden = HiddenHolder.from_dict(v3_data["hidden_inputs"])
         return type_clone

     @final
@@ -1433,14 +1669,18 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):

     @final
     @classmethod
-    def INPUT_TYPES(cls, include_hidden=True, return_schema=False) -> dict[str, dict] | tuple[dict[str, dict], Schema]:
+    def INPUT_TYPES(cls, include_hidden=True, return_schema=False, live_inputs=None) -> dict[str, dict] | tuple[dict[str, dict], Schema, V3Data]:
         schema = cls.FINALIZE_SCHEMA()
-        info = schema.get_v1_info(cls)
+        info = schema.get_v1_info(cls, live_inputs)
         input = info.input
         if not include_hidden:
             input.pop("hidden", None)
         if return_schema:
-            return input, schema
+            v3_data: V3Data = {}
+            dynamic = input.pop("dynamic_paths", None)
+            if dynamic is not None:
+                v3_data["dynamic_paths"] = dynamic
+            return input, schema, v3_data
         return input

     @final
@@ -1513,7 +1753,7 @@ class ComfyNode(_ComfyNodeBaseInternal):
         raise NotImplementedError

     @classmethod
-    def validate_inputs(cls, **kwargs) -> bool:
+    def validate_inputs(cls, **kwargs) -> bool | str:
         """Optionally, define this function to validate inputs; equivalent to V1's VALIDATE_INPUTS."""
         raise NotImplementedError

@@ -1628,6 +1868,7 @@ __all__ = [
     "StyleModel",
     "Gligen",
     "UpscaleModel",
+    "LatentUpscaleModel",
     "Audio",
     "Video",
     "SVG",
@@ -1651,6 +1892,10 @@ __all__ = [
     "SEGS",
     "AnyType",
     "MultiType",
+    # Dynamic Types
+    "MatchType",
+    # "DynamicCombo",
+    # "Autogrow",
     # Other classes
     "HiddenHolder",
     "Hidden",
@@ -1661,4 +1906,5 @@ __all__ = [
     "NodeOutput",
     "add_to_dict_v1",
     "add_to_dict_v3",
+    "V3Data",
 ]

comfy_api/latest/_io_public.py (new file)
@@ -0,0 +1 @@
+from ._io import * # noqa: F403
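With the `INPUT_TYPES` change above, `return_schema=True` now yields a third element carrying v3 bookkeeping. A rough sketch of the shape, assuming a node with one dynamic input (illustrative names only; `hidden_inputs` is attached elsewhere, as `PREPARE_CLASS_CLONE` shows):

    input_types, schema, v3_data = MyNode.INPUT_TYPES(return_schema=True)  # hypothetical node
    # v3_data may look like:
    # {"dynamic_paths": {"combo.x": "combo.x"}}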
@@ -3,6 +3,7 @@ from __future__ import annotations
 import json
 import os
 import random
+import uuid
 from io import BytesIO
 from typing import Type

@@ -318,9 +319,10 @@ class AudioSaveHelper:
         for key, value in metadata.items():
             output_container.metadata[key] = value

+        layout = "mono" if waveform.shape[0] == 1 else "stereo"
         # Set up the output stream with appropriate properties
         if format == "opus":
-            out_stream = output_container.add_stream("libopus", rate=sample_rate)
+            out_stream = output_container.add_stream("libopus", rate=sample_rate, layout=layout)
             if quality == "64k":
                 out_stream.bit_rate = 64000
             elif quality == "96k":
@@ -332,7 +334,7 @@ class AudioSaveHelper:
             elif quality == "320k":
                 out_stream.bit_rate = 320000
         elif format == "mp3":
-            out_stream = output_container.add_stream("libmp3lame", rate=sample_rate)
+            out_stream = output_container.add_stream("libmp3lame", rate=sample_rate, layout=layout)
             if quality == "V0":
                 # TODO i would really love to support V3 and V5 but there doesn't seem to be a way to set the qscale level, the property below is a bool
                 out_stream.codec_context.qscale = 1
@@ -341,12 +343,12 @@ class AudioSaveHelper:
             elif quality == "320k":
                 out_stream.bit_rate = 320000
         else: # format == "flac":
-            out_stream = output_container.add_stream("flac", rate=sample_rate)
+            out_stream = output_container.add_stream("flac", rate=sample_rate, layout=layout)

         frame = av.AudioFrame.from_ndarray(
             waveform.movedim(0, 1).reshape(1, -1).float().numpy(),
             format="flt",
-            layout="mono" if waveform.shape[0] == 1 else "stereo",
+            layout=layout,
         )
         frame.sample_rate = sample_rate
         frame.pts = 0
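The shared `layout` computed above is now declared on the encoder stream as well as on the frame; before this change a mono waveform only set the layout on the frame, so the stream could default to stereo. A tiny sketch of the selection rule (stand-in tensor, not from the diff):

    import torch

    waveform = torch.zeros(1, 48000)  # (channels, samples)
    layout = "mono" if waveform.shape[0] == 1 else "stereo"
    assert layout == "mono"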
@@ -436,9 +438,19 @@ class PreviewUI3D(_UIOutput):
     def __init__(self, model_file, camera_info, **kwargs):
         self.model_file = model_file
         self.camera_info = camera_info
+        self.bg_image_path = None
+        bg_image = kwargs.get("bg_image", None)
+        if bg_image is not None:
+            img_array = (bg_image[0].cpu().numpy() * 255).astype(np.uint8)
+            img = PILImage.fromarray(img_array)
+            temp_dir = folder_paths.get_temp_directory()
+            filename = f"bg_{uuid.uuid4().hex}.png"
+            bg_image_path = os.path.join(temp_dir, filename)
+            img.save(bg_image_path, compress_level=1)
+            self.bg_image_path = f"temp/{filename}"

     def as_dict(self):
-        return {"result": [self.model_file, self.camera_info]}
+        return {"result": [self.model_file, self.camera_info, self.bg_image_path]}


 class PreviewText(_UIOutput):
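The constructor change above saves an optional background image to a temp PNG and appends its "temp/<name>" path to the result tuple. A rough standalone sketch of the save step (random stand-in array; the real code writes into `folder_paths.get_temp_directory()`, here it is the working directory):

    import uuid
    import numpy as np
    from PIL import Image

    bg_image = np.random.rand(1, 64, 64, 3)  # stand-in for a (batch, H, W, C) float tensor
    img = Image.fromarray((bg_image[0] * 255).astype(np.uint8))
    filename = f"bg_{uuid.uuid4().hex}.png"
    img.save(filename, compress_level=1)  # low compression: fast preview writes
    print(f"temp/{filename}")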
comfy_api/latest/_ui_public.py (new file)
@@ -0,0 +1 @@
+from ._ui import * # noqa: F403

@@ -6,7 +6,7 @@ from comfy_api.latest import (
 )
 from typing import Type, TYPE_CHECKING
 from comfy_api.internal.async_to_sync import create_sync_class
-from comfy_api.latest import io, ui, ComfyExtension #noqa: F401
+from comfy_api.latest import io, ui, IO, UI, ComfyExtension #noqa: F401


 class ComfyAPIAdapter_v0_0_2(ComfyAPI_latest):
@@ -42,4 +42,8 @@ __all__ = [
     "InputImpl",
     "Types",
     "ComfyExtension",
+    "io",
+    "IO",
+    "ui",
+    "UI",
 ]
comfy_api_nodes/apis/kling_api.py (new file)
@@ -0,0 +1,86 @@
+from pydantic import BaseModel, Field
+
+
+class OmniProText2VideoRequest(BaseModel):
+    model_name: str = Field(..., description="kling-video-o1")
+    aspect_ratio: str = Field(..., description="'16:9', '9:16' or '1:1'")
+    duration: str = Field(..., description="'5' or '10'")
+    prompt: str = Field(...)
+    mode: str = Field("pro")
+
+
+class OmniParamImage(BaseModel):
+    image_url: str = Field(...)
+    type: str | None = Field(None, description="Can be 'first_frame' or 'end_frame'")
+
+
+class OmniParamVideo(BaseModel):
+    video_url: str = Field(...)
+    refer_type: str | None = Field(..., description="Can be 'base' or 'feature'")
+    keep_original_sound: str = Field(..., description="'yes' or 'no'")
+
+
+class OmniProFirstLastFrameRequest(BaseModel):
+    model_name: str = Field(..., description="kling-video-o1")
+    image_list: list[OmniParamImage] = Field(..., min_length=1, max_length=7)
+    duration: str = Field(..., description="'5' or '10'")
+    prompt: str = Field(...)
+    mode: str = Field("pro")
+
+
+class OmniProReferences2VideoRequest(BaseModel):
+    model_name: str = Field(..., description="kling-video-o1")
+    aspect_ratio: str | None = Field(..., description="'16:9', '9:16' or '1:1'")
+    image_list: list[OmniParamImage] | None = Field(
+        None, max_length=7, description="Max length 4 when video is present."
+    )
+    video_list: list[OmniParamVideo] | None = Field(None, max_length=1)
+    duration: str | None = Field(..., description="From 3 to 10.")
+    prompt: str = Field(...)
+    mode: str = Field("pro")
+
+
+class TaskStatusVideoResult(BaseModel):
+    duration: str | None = Field(None, description="Total video duration")
+    id: str | None = Field(None, description="Generated video ID")
+    url: str | None = Field(None, description="URL for generated video")
+
+
+class TaskStatusImageResult(BaseModel):
+    index: int = Field(..., description="Image Number,0-9")
+    url: str = Field(..., description="URL for generated image")
+
+
+class OmniTaskStatusResults(BaseModel):
+    videos: list[TaskStatusVideoResult] | None = Field(None)
+    images: list[TaskStatusImageResult] | None = Field(None)
+
+
+class OmniTaskStatusResponseData(BaseModel):
+    created_at: int | None = Field(None, description="Task creation time")
+    updated_at: int | None = Field(None, description="Task update time")
+    task_status: str | None = None
+    task_status_msg: str | None = Field(None, description="Additional failure reason. Only for polling endpoint.")
+    task_id: str | None = Field(None, description="Task ID")
+    task_result: OmniTaskStatusResults | None = Field(None)
+
+
+class OmniTaskStatusResponse(BaseModel):
+    code: int | None = Field(None, description="Error code")
+    message: str | None = Field(None, description="Error message")
+    request_id: str | None = Field(None, description="Request ID")
+    data: OmniTaskStatusResponseData | None = Field(None)
+
+
+class OmniImageParamImage(BaseModel):
+    image: str = Field(...)
+
+
+class OmniProImageRequest(BaseModel):
+    model_name: str = Field(..., description="kling-image-o1")
+    resolution: str = Field(..., description="'1k' or '2k'")
+    aspect_ratio: str | None = Field(...)
+    prompt: str = Field(...)
+    mode: str = Field("pro")
+    n: int | None = Field(1, le=9)
+    image_list: list[OmniImageParamImage] | None = Field(..., max_length=10)
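For orientation, a hypothetical round-trip of the simplest new model; the class body is copied from the file above so the snippet runs on its own (pydantic v2):

    from pydantic import BaseModel, Field

    class OmniProText2VideoRequest(BaseModel):
        model_name: str = Field(..., description="kling-video-o1")
        aspect_ratio: str = Field(..., description="'16:9', '9:16' or '1:1'")
        duration: str = Field(..., description="'5' or '10'")
        prompt: str = Field(...)
        mode: str = Field("pro")

    req = OmniProText2VideoRequest(
        model_name="kling-video-o1",
        aspect_ratio="16:9",
        duration="5",
        prompt="a red fox running through snow",  # example prompt, not from the diff
    )
    print(req.model_dump_json())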
@@ -4,13 +4,14 @@ For source of truth on the allowed permutations of request fields, please refere
 - [Compatibility Table](https://app.klingai.com/global/dev/document-api/apiReference/model/skillsMap)
 """

-import math
 import logging
+import math
-from typing_extensions import override
+import re

 import torch
+from typing_extensions import override

+from comfy_api.latest import IO, ComfyExtension, Input, InputImpl
 from comfy_api_nodes.apis import (
     KlingCameraControl,
     KlingCameraConfig,
@@ -48,23 +49,33 @@ from comfy_api_nodes.apis import (
     KlingCharacterEffectModelName,
     KlingSingleImageEffectModelName,
 )
+from comfy_api_nodes.apis.kling_api import (
+    OmniImageParamImage,
+    OmniParamImage,
+    OmniParamVideo,
+    OmniProFirstLastFrameRequest,
+    OmniProImageRequest,
+    OmniProReferences2VideoRequest,
+    OmniProText2VideoRequest,
+    OmniTaskStatusResponse,
+)
 from comfy_api_nodes.util import (
-    validate_image_dimensions,
+    ApiEndpoint,
+    download_url_to_image_tensor,
+    download_url_to_video_output,
+    get_number_of_images,
+    poll_op,
+    sync_op,
+    tensor_to_base64_string,
+    upload_audio_to_comfyapi,
+    upload_images_to_comfyapi,
+    upload_video_to_comfyapi,
     validate_image_aspect_ratio,
+    validate_image_dimensions,
+    validate_string,
     validate_video_dimensions,
     validate_video_duration,
-    tensor_to_base64_string,
-    validate_string,
-    upload_audio_to_comfyapi,
-    download_url_to_image_tensor,
-    upload_video_to_comfyapi,
-    download_url_to_video_output,
-    sync_op,
-    ApiEndpoint,
-    poll_op,
 )
-from comfy_api.input_impl import VideoFromFile
-from comfy_api.latest import ComfyExtension, IO, Input

 KLING_API_VERSION = "v1"
 PATH_TEXT_TO_VIDEO = f"/proxy/kling/{KLING_API_VERSION}/videos/text2video"
@@ -202,6 +213,50 @@ VOICES_CONFIG = {
 }


+def normalize_omni_prompt_references(prompt: str) -> str:
+    """
+    Rewrites Kling Omni-style placeholders used in the app, like:
+
+        @image, @image1, @image2, ... @imageN
+        @video, @video1, @video2, ... @videoN
+
+    into the API-compatible form:
+
+        <<<image_1>>>, <<<image_2>>>, ...
+        <<<video_1>>>, <<<video_2>>>, ...
+
+    This is a UX shim for ComfyUI so users can type the same syntax as in the Kling app.
+    """
+    if not prompt:
+        return prompt
+
+    def _image_repl(match):
+        return f"<<<image_{match.group('idx') or '1'}>>>"
+
+    def _video_repl(match):
+        return f"<<<video_{match.group('idx') or '1'}>>>"
+
+    # (?<!\w) avoids matching e.g. "test@image.com"
+    # (?!\w) makes sure we only match @image / @image<digits> and not @imageFoo
+    prompt = re.sub(r"(?<!\w)@image(?P<idx>\d*)(?!\w)", _image_repl, prompt)
+    return re.sub(r"(?<!\w)@video(?P<idx>\d*)(?!\w)", _video_repl, prompt)
+
+
+async def finish_omni_video_task(cls: type[IO.ComfyNode], response: OmniTaskStatusResponse) -> IO.NodeOutput:
+    if response.code:
+        raise RuntimeError(
+            f"Kling request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}"
+        )
+    final_response = await poll_op(
+        cls,
+        ApiEndpoint(path=f"/proxy/kling/v1/videos/omni-video/{response.data.task_id}"),
+        response_model=OmniTaskStatusResponse,
+        status_extractor=lambda r: (r.data.task_status if r.data else None),
+        max_poll_attempts=160,
+    )
+    return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url))
+
+
 def is_valid_camera_control_configs(configs: list[float]) -> bool:
     """Verifies that at least one camera control configuration is non-zero."""
     return any(not math.isclose(value, 0.0) for value in configs)
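A quick standalone check of the placeholder rewrite above (the two regexes are copied verbatim; the sample prompts are invented):

    import re

    def normalize(prompt: str) -> str:
        prompt = re.sub(r"(?<!\w)@image(?P<idx>\d*)(?!\w)",
                        lambda m: f"<<<image_{m.group('idx') or '1'}>>>", prompt)
        return re.sub(r"(?<!\w)@video(?P<idx>\d*)(?!\w)",
                      lambda m: f"<<<video_{m.group('idx') or '1'}>>>", prompt)

    assert normalize("@image walks past @video2") == "<<<image_1>>> walks past <<<video_2>>>"
    assert normalize("mail test@image.com") == "mail test@image.com"  # e-mail addresses untouched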
@@ -449,7 +504,7 @@ async def execute_video_effect(
     image_1: torch.Tensor,
     image_2: torch.Tensor | None = None,
     model_mode: KlingVideoGenMode | None = None,
-) -> tuple[VideoFromFile, str, str]:
+) -> tuple[InputImpl.VideoFromFile, str, str]:
     if dual_character:
         request_input_field = KlingDualCharacterEffectInput(
             model_name=model_name,
@@ -736,6 +791,474 @@ class KlingTextToVideoNode(IO.ComfyNode):
         )


+class OmniProTextToVideoNode(IO.ComfyNode):
+
+    @classmethod
+    def define_schema(cls) -> IO.Schema:
+        return IO.Schema(
+            node_id="KlingOmniProTextToVideoNode",
+            display_name="Kling Omni Text to Video (Pro)",
+            category="api node/video/Kling",
+            description="Use text prompts to generate videos with the latest Kling model.",
+            inputs=[
+                IO.Combo.Input("model_name", options=["kling-video-o1"]),
+                IO.String.Input(
+                    "prompt",
+                    multiline=True,
+                    tooltip="A text prompt describing the video content. "
+                    "This can include both positive and negative descriptions.",
+                ),
+                IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]),
+                IO.Combo.Input("duration", options=[5, 10]),
+            ],
+            outputs=[
+                IO.Video.Output(),
+            ],
+            hidden=[
+                IO.Hidden.auth_token_comfy_org,
+                IO.Hidden.api_key_comfy_org,
+                IO.Hidden.unique_id,
+            ],
+            is_api_node=True,
+        )
+
+    @classmethod
+    async def execute(
+        cls,
+        model_name: str,
+        prompt: str,
+        aspect_ratio: str,
+        duration: int,
+    ) -> IO.NodeOutput:
+        validate_string(prompt, min_length=1, max_length=2500)
+        response = await sync_op(
+            cls,
+            ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
+            response_model=OmniTaskStatusResponse,
+            data=OmniProText2VideoRequest(
+                model_name=model_name,
+                prompt=prompt,
+                aspect_ratio=aspect_ratio,
+                duration=str(duration),
+            ),
+        )
+        return await finish_omni_video_task(cls, response)
+
+
+class OmniProFirstLastFrameNode(IO.ComfyNode):
+
+    @classmethod
+    def define_schema(cls) -> IO.Schema:
+        return IO.Schema(
+            node_id="KlingOmniProFirstLastFrameNode",
+            display_name="Kling Omni First-Last-Frame to Video (Pro)",
+            category="api node/video/Kling",
+            description="Use a start frame, an optional end frame, or reference images with the latest Kling model.",
+            inputs=[
+                IO.Combo.Input("model_name", options=["kling-video-o1"]),
+                IO.String.Input(
+                    "prompt",
+                    multiline=True,
+                    tooltip="A text prompt describing the video content. "
+                    "This can include both positive and negative descriptions.",
+                ),
+                IO.Combo.Input("duration", options=["5", "10"]),
+                IO.Image.Input("first_frame"),
+                IO.Image.Input(
+                    "end_frame",
+                    optional=True,
+                    tooltip="An optional end frame for the video. "
+                    "This cannot be used simultaneously with 'reference_images'.",
+                ),
+                IO.Image.Input(
+                    "reference_images",
+                    optional=True,
+                    tooltip="Up to 6 additional reference images.",
+                ),
+            ],
+            outputs=[
+                IO.Video.Output(),
+            ],
+            hidden=[
+                IO.Hidden.auth_token_comfy_org,
+                IO.Hidden.api_key_comfy_org,
+                IO.Hidden.unique_id,
+            ],
+            is_api_node=True,
+        )
+
+    @classmethod
+    async def execute(
+        cls,
+        model_name: str,
+        prompt: str,
+        duration: int,
+        first_frame: Input.Image,
+        end_frame: Input.Image | None = None,
+        reference_images: Input.Image | None = None,
+    ) -> IO.NodeOutput:
+        prompt = normalize_omni_prompt_references(prompt)
+        validate_string(prompt, min_length=1, max_length=2500)
+        if end_frame is not None and reference_images is not None:
+            raise ValueError("The 'end_frame' input cannot be used simultaneously with 'reference_images'.")
+        validate_image_dimensions(first_frame, min_width=300, min_height=300)
+        validate_image_aspect_ratio(first_frame, (1, 2.5), (2.5, 1))
+        image_list: list[OmniParamImage] = [
+            OmniParamImage(
+                image_url=(await upload_images_to_comfyapi(cls, first_frame, wait_label="Uploading first frame"))[0],
+                type="first_frame",
+            )
+        ]
+        if end_frame is not None:
+            validate_image_dimensions(end_frame, min_width=300, min_height=300)
+            validate_image_aspect_ratio(end_frame, (1, 2.5), (2.5, 1))
+            image_list.append(
+                OmniParamImage(
+                    image_url=(await upload_images_to_comfyapi(cls, end_frame, wait_label="Uploading end frame"))[0],
+                    type="end_frame",
+                )
+            )
+        if reference_images is not None:
+            if get_number_of_images(reference_images) > 6:
+                raise ValueError("The maximum number of reference images allowed is 6.")
+            for i in reference_images:
+                validate_image_dimensions(i, min_width=300, min_height=300)
+                validate_image_aspect_ratio(i, (1, 2.5), (2.5, 1))
+            for i in await upload_images_to_comfyapi(cls, reference_images, wait_label="Uploading reference frame(s)"):
+                image_list.append(OmniParamImage(image_url=i))
+        response = await sync_op(
+            cls,
+            ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
+            response_model=OmniTaskStatusResponse,
+            data=OmniProFirstLastFrameRequest(
+                model_name=model_name,
+                prompt=prompt,
+                duration=str(duration),
+                image_list=image_list,
+            ),
+        )
+        return await finish_omni_video_task(cls, response)
+
+
+class OmniProImageToVideoNode(IO.ComfyNode):
+
+    @classmethod
+    def define_schema(cls) -> IO.Schema:
+        return IO.Schema(
+            node_id="KlingOmniProImageToVideoNode",
+            display_name="Kling Omni Image to Video (Pro)",
+            category="api node/video/Kling",
+            description="Use up to 7 reference images to generate a video with the latest Kling model.",
+            inputs=[
+                IO.Combo.Input("model_name", options=["kling-video-o1"]),
+                IO.String.Input(
+                    "prompt",
+                    multiline=True,
+                    tooltip="A text prompt describing the video content. "
+                    "This can include both positive and negative descriptions.",
+                ),
+                IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]),
+                IO.Int.Input("duration", default=3, min=3, max=10, display_mode=IO.NumberDisplay.slider),
+                IO.Image.Input(
+                    "reference_images",
+                    tooltip="Up to 7 reference images.",
+                ),
+            ],
+            outputs=[
+                IO.Video.Output(),
+            ],
+            hidden=[
+                IO.Hidden.auth_token_comfy_org,
+                IO.Hidden.api_key_comfy_org,
+                IO.Hidden.unique_id,
+            ],
+            is_api_node=True,
+        )
+
+    @classmethod
+    async def execute(
+        cls,
+        model_name: str,
+        prompt: str,
+        aspect_ratio: str,
+        duration: int,
+        reference_images: Input.Image,
+    ) -> IO.NodeOutput:
+        prompt = normalize_omni_prompt_references(prompt)
+        validate_string(prompt, min_length=1, max_length=2500)
+        if get_number_of_images(reference_images) > 7:
+            raise ValueError("The maximum number of reference images is 7.")
+        for i in reference_images:
+            validate_image_dimensions(i, min_width=300, min_height=300)
+            validate_image_aspect_ratio(i, (1, 2.5), (2.5, 1))
+        image_list: list[OmniParamImage] = []
+        for i in await upload_images_to_comfyapi(cls, reference_images, wait_label="Uploading reference image"):
+            image_list.append(OmniParamImage(image_url=i))
+        response = await sync_op(
+            cls,
+            ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
+            response_model=OmniTaskStatusResponse,
+            data=OmniProReferences2VideoRequest(
+                model_name=model_name,
+                prompt=prompt,
+                aspect_ratio=aspect_ratio,
+                duration=str(duration),
+                image_list=image_list,
+            ),
+        )
+        return await finish_omni_video_task(cls, response)
+
+
+class OmniProVideoToVideoNode(IO.ComfyNode):
+
+    @classmethod
+    def define_schema(cls) -> IO.Schema:
+        return IO.Schema(
+            node_id="KlingOmniProVideoToVideoNode",
+            display_name="Kling Omni Video to Video (Pro)",
+            category="api node/video/Kling",
+            description="Use a video and up to 4 reference images to generate a video with the latest Kling model.",
+            inputs=[
+                IO.Combo.Input("model_name", options=["kling-video-o1"]),
+                IO.String.Input(
+                    "prompt",
+                    multiline=True,
+                    tooltip="A text prompt describing the video content. "
+                    "This can include both positive and negative descriptions.",
+                ),
+                IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]),
+                IO.Int.Input("duration", default=3, min=3, max=10, display_mode=IO.NumberDisplay.slider),
+                IO.Video.Input("reference_video", tooltip="Video to use as a reference."),
+                IO.Boolean.Input("keep_original_sound", default=True),
+                IO.Image.Input(
+                    "reference_images",
+                    tooltip="Up to 4 additional reference images.",
+                    optional=True,
+                ),
+            ],
+            outputs=[
+                IO.Video.Output(),
+            ],
+            hidden=[
+                IO.Hidden.auth_token_comfy_org,
+                IO.Hidden.api_key_comfy_org,
+                IO.Hidden.unique_id,
+            ],
+            is_api_node=True,
+        )
+
+    @classmethod
+    async def execute(
+        cls,
+        model_name: str,
+        prompt: str,
+        aspect_ratio: str,
+        duration: int,
+        reference_video: Input.Video,
+        keep_original_sound: bool,
+        reference_images: Input.Image | None = None,
+    ) -> IO.NodeOutput:
+        prompt = normalize_omni_prompt_references(prompt)
+        validate_string(prompt, min_length=1, max_length=2500)
+        validate_video_duration(reference_video, min_duration=3.0, max_duration=10.05)
+        validate_video_dimensions(reference_video, min_width=720, min_height=720, max_width=2160, max_height=2160)
+        image_list: list[OmniParamImage] = []
+        if reference_images is not None:
+            if get_number_of_images(reference_images) > 4:
+                raise ValueError("The maximum number of reference images allowed with a video input is 4.")
+            for i in reference_images:
+                validate_image_dimensions(i, min_width=300, min_height=300)
+                validate_image_aspect_ratio(i, (1, 2.5), (2.5, 1))
+            for i in await upload_images_to_comfyapi(cls, reference_images, wait_label="Uploading reference image"):
+                image_list.append(OmniParamImage(image_url=i))
+        video_list = [
+            OmniParamVideo(
+                video_url=await upload_video_to_comfyapi(cls, reference_video, wait_label="Uploading reference video"),
+                refer_type="feature",
+                keep_original_sound="yes" if keep_original_sound else "no",
+            )
+        ]
+        response = await sync_op(
+            cls,
+            ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
+            response_model=OmniTaskStatusResponse,
+            data=OmniProReferences2VideoRequest(
+                model_name=model_name,
+                prompt=prompt,
+                aspect_ratio=aspect_ratio,
+                duration=str(duration),
+                image_list=image_list if image_list else None,
+                video_list=video_list,
+            ),
+        )
+        return await finish_omni_video_task(cls, response)
+
+
+class OmniProEditVideoNode(IO.ComfyNode):
+
+    @classmethod
+    def define_schema(cls) -> IO.Schema:
+        return IO.Schema(
+            node_id="KlingOmniProEditVideoNode",
+            display_name="Kling Omni Edit Video (Pro)",
+            category="api node/video/Kling",
+            description="Edit an existing video with the latest model from Kling.",
+            inputs=[
+                IO.Combo.Input("model_name", options=["kling-video-o1"]),
+                IO.String.Input(
+                    "prompt",
+                    multiline=True,
+                    tooltip="A text prompt describing the video content. "
+                    "This can include both positive and negative descriptions.",
+                ),
+                IO.Video.Input("video", tooltip="Video for editing. The output video length will be the same."),
+                IO.Boolean.Input("keep_original_sound", default=True),
+                IO.Image.Input(
+                    "reference_images",
+                    tooltip="Up to 4 additional reference images.",
+                    optional=True,
+                ),
+            ],
+            outputs=[
+                IO.Video.Output(),
+            ],
+            hidden=[
+                IO.Hidden.auth_token_comfy_org,
+                IO.Hidden.api_key_comfy_org,
+                IO.Hidden.unique_id,
+            ],
+            is_api_node=True,
+        )
+
+    @classmethod
+    async def execute(
+        cls,
+        model_name: str,
+        prompt: str,
+        video: Input.Video,
+        keep_original_sound: bool,
+        reference_images: Input.Image | None = None,
+    ) -> IO.NodeOutput:
+        prompt = normalize_omni_prompt_references(prompt)
+        validate_string(prompt, min_length=1, max_length=2500)
+        validate_video_duration(video, min_duration=3.0, max_duration=10.05)
+        validate_video_dimensions(video, min_width=720, min_height=720, max_width=2160, max_height=2160)
+        image_list: list[OmniParamImage] = []
+        if reference_images is not None:
+            if get_number_of_images(reference_images) > 4:
+                raise ValueError("The maximum number of reference images allowed with a video input is 4.")
+            for i in reference_images:
+                validate_image_dimensions(i, min_width=300, min_height=300)
+                validate_image_aspect_ratio(i, (1, 2.5), (2.5, 1))
+            for i in await upload_images_to_comfyapi(cls, reference_images, wait_label="Uploading reference image"):
+                image_list.append(OmniParamImage(image_url=i))
+        video_list = [
+            OmniParamVideo(
+                video_url=await upload_video_to_comfyapi(cls, video, wait_label="Uploading base video"),
+                refer_type="base",
+                keep_original_sound="yes" if keep_original_sound else "no",
+            )
+        ]
+        response = await sync_op(
+            cls,
+            ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
+            response_model=OmniTaskStatusResponse,
+            data=OmniProReferences2VideoRequest(
+                model_name=model_name,
+                prompt=prompt,
+                aspect_ratio=None,
+                duration=None,
+                image_list=image_list if image_list else None,
+                video_list=video_list,
+            ),
+        )
+        return await finish_omni_video_task(cls, response)
+
+
+class OmniProImageNode(IO.ComfyNode):
+
+    @classmethod
+    def define_schema(cls) -> IO.Schema:
+        return IO.Schema(
+            node_id="KlingOmniProImageNode",
+            display_name="Kling Omni Image (Pro)",
+            category="api node/image/Kling",
+            description="Create or edit images with the latest model from Kling.",
+            inputs=[
+                IO.Combo.Input("model_name", options=["kling-image-o1"]),
+                IO.String.Input(
+                    "prompt",
+                    multiline=True,
+                    tooltip="A text prompt describing the image content. "
+                    "This can include both positive and negative descriptions.",
+                ),
+                IO.Combo.Input("resolution", options=["1K", "2K"]),
+                IO.Combo.Input(
+                    "aspect_ratio",
+                    options=["16:9", "9:16", "1:1", "4:3", "3:4", "3:2", "2:3", "21:9"],
+                ),
+                IO.Image.Input(
+                    "reference_images",
+                    tooltip="Up to 10 additional reference images.",
+                    optional=True,
+                ),
+            ],
+            outputs=[
+                IO.Image.Output(),
+            ],
+            hidden=[
+                IO.Hidden.auth_token_comfy_org,
+                IO.Hidden.api_key_comfy_org,
+                IO.Hidden.unique_id,
+            ],
+            is_api_node=True,
+        )
+
+    @classmethod
+    async def execute(
+        cls,
+        model_name: str,
+        prompt: str,
+        resolution: str,
+        aspect_ratio: str,
+        reference_images: Input.Image | None = None,
+    ) -> IO.NodeOutput:
+        prompt = normalize_omni_prompt_references(prompt)
+        validate_string(prompt, min_length=1, max_length=2500)
+        image_list: list[OmniImageParamImage] = []
+        if reference_images is not None:
+            if get_number_of_images(reference_images) > 10:
+                raise ValueError("The maximum number of reference images is 10.")
+            for i in reference_images:
+                validate_image_dimensions(i, min_width=300, min_height=300)
+                validate_image_aspect_ratio(i, (1, 2.5), (2.5, 1))
+            for i in await upload_images_to_comfyapi(cls, reference_images, wait_label="Uploading reference image"):
+                image_list.append(OmniImageParamImage(image=i))
+        response = await sync_op(
+            cls,
+            ApiEndpoint(path="/proxy/kling/v1/images/omni-image", method="POST"),
+            response_model=OmniTaskStatusResponse,
+            data=OmniProImageRequest(
+                model_name=model_name,
+                prompt=prompt,
+                resolution=resolution.lower(),
+                aspect_ratio=aspect_ratio,
+                image_list=image_list if image_list else None,
+            ),
+        )
+        if response.code:
+            raise RuntimeError(
+                f"Kling request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}"
+            )
+        final_response = await poll_op(
+            cls,
+            ApiEndpoint(path=f"/proxy/kling/v1/images/omni-image/{response.data.task_id}"),
+            response_model=OmniTaskStatusResponse,
+            status_extractor=lambda r: (r.data.task_status if r.data else None),
+        )
+        return IO.NodeOutput(await download_url_to_image_tensor(final_response.data.task_result.images[0].url))
+
+
 class KlingCameraControlT2VNode(IO.ComfyNode):
     """
     Kling Text to Video Camera Control Node. This node is a text to video node, but it supports controlling the camera.
@@ -1162,7 +1685,10 @@ class KlingSingleImageVideoEffectNode(IO.ComfyNode):
             category="api node/video/Kling",
             description="Achieve different special effects when generating a video based on the effect_scene.",
             inputs=[
-                IO.Image.Input("image", tooltip=" Reference Image. URL or Base64 encoded string (without data:image prefix). File size cannot exceed 10MB, resolution not less than 300*300px, aspect ratio between 1:2.5 ~ 2.5:1"),
+                IO.Image.Input(
+                    "image",
+                    tooltip=" Reference Image. URL or Base64 encoded string (without data:image prefix). File size cannot exceed 10MB, resolution not less than 300*300px, aspect ratio between 1:2.5 ~ 2.5:1",
+                ),
                 IO.Combo.Input(
                     "effect_scene",
                     options=[i.value for i in KlingSingleImageEffectsScene],
@@ -1525,6 +2051,12 @@ class KlingExtension(ComfyExtension):
             KlingImageGenerationNode,
             KlingSingleImageVideoEffectNode,
             KlingDualCharacterVideoEffectNode,
+            OmniProTextToVideoNode,
+            OmniProFirstLastFrameNode,
+            OmniProImageToVideoNode,
+            OmniProVideoToVideoNode,
+            OmniProEditVideoNode,
+            # OmniProImageNode, # need support from backend
         ]
@@ -47,6 +47,7 @@ from .validation_utils import (
     validate_string,
     validate_video_dimensions,
     validate_video_duration,
+    validate_video_frame_count,
 )

 __all__ = [
@@ -94,6 +95,7 @@ __all__ = [
     "validate_string",
     "validate_video_dimensions",
     "validate_video_duration",
+    "validate_video_frame_count",
     # Misc functions
     "get_fs_object_size",
 ]

@@ -2,8 +2,8 @@ import asyncio
 import contextlib
 import os
 import time
+from collections.abc import Callable
 from io import BytesIO
-from typing import Callable, Optional, Union

 from comfy.cli_args import args
 from comfy.model_management import processing_interrupted
@@ -35,12 +35,12 @@ def default_base_url() -> str:

 async def sleep_with_interrupt(
     seconds: float,
-    node_cls: Optional[type[IO.ComfyNode]],
-    label: Optional[str] = None,
-    start_ts: Optional[float] = None,
-    estimated_total: Optional[int] = None,
+    node_cls: type[IO.ComfyNode] | None,
+    label: str | None = None,
+    start_ts: float | None = None,
+    estimated_total: int | None = None,
     *,
-    display_callback: Optional[Callable[[type[IO.ComfyNode], str, int, Optional[int]], None]] = None,
+    display_callback: Callable[[type[IO.ComfyNode], str, int, int | None], None] | None = None,
 ):
     """
     Sleep in 1s slices while:
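The hunk above and those that follow are a mechanical typing modernisation: `typing.Optional`/`typing.Union` spellings become PEP 604 unions and `Callable`/`Iterable` move to `collections.abc`. The two forms below are equivalent; the bar syntax simply drops the `typing` imports (illustrative stubs, not code from this diff):

    from typing import Optional, Union

    def old_style(label: Optional[str] = None, status: Union[str, int] = 0) -> Optional[str]:
        return label

    def new_style(label: str | None = None, status: str | int = 0) -> str | None:
        return label  # identical semantics, no typing import needed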
@@ -65,7 +65,7 @@ def mimetype_to_extension(mime_type: str) -> str:
     return mime_type.split("/")[-1].lower()


-def get_fs_object_size(path_or_object: Union[str, BytesIO]) -> int:
+def get_fs_object_size(path_or_object: str | BytesIO) -> int:
     if isinstance(path_or_object, str):
         return os.path.getsize(path_or_object)
     return len(path_or_object.getvalue())

@@ -4,10 +4,11 @@ import json
 import logging
 import time
 import uuid
+from collections.abc import Callable, Iterable
 from dataclasses import dataclass
 from enum import Enum
 from io import BytesIO
-from typing import Any, Callable, Iterable, Literal, Optional, Type, TypeVar, Union
+from typing import Any, Literal, TypeVar
 from urllib.parse import urljoin, urlparse

 import aiohttp
@@ -37,8 +38,8 @@ class ApiEndpoint:
         path: str,
         method: Literal["GET", "POST", "PUT", "DELETE", "PATCH"] = "GET",
         *,
-        query_params: Optional[dict[str, Any]] = None,
-        headers: Optional[dict[str, str]] = None,
+        query_params: dict[str, Any] | None = None,
+        headers: dict[str, str] | None = None,
     ):
         self.path = path
         self.method = method
@@ -52,18 +53,18 @@ class _RequestConfig:
     endpoint: ApiEndpoint
     timeout: float
     content_type: str
-    data: Optional[dict[str, Any]]
-    files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]]
-    multipart_parser: Optional[Callable]
+    data: dict[str, Any] | None
+    files: dict[str, Any] | list[tuple[str, Any]] | None
+    multipart_parser: Callable | None
     max_retries: int
     retry_delay: float
     retry_backoff: float
     wait_label: str = "Waiting"
     monitor_progress: bool = True
-    estimated_total: Optional[int] = None
-    final_label_on_success: Optional[str] = "Completed"
-    progress_origin_ts: Optional[float] = None
-    price_extractor: Optional[Callable[[dict[str, Any]], Optional[float]]] = None
+    estimated_total: int | None = None
+    final_label_on_success: str | None = "Completed"
+    progress_origin_ts: float | None = None
+    price_extractor: Callable[[dict[str, Any]], float | None] | None = None


 @dataclass
@@ -71,10 +72,10 @@ class _PollUIState:
     started: float
     status_label: str = "Queued"
     is_queued: bool = True
-    price: Optional[float] = None
-    estimated_duration: Optional[int] = None
+    price: float | None = None
+    estimated_duration: int | None = None
     base_processing_elapsed: float = 0.0  # sum of completed active intervals
-    active_since: Optional[float] = None  # start time of current active interval (None if queued)
+    active_since: float | None = None  # start time of current active interval (None if queued)


 _RETRY_STATUS = {408, 429, 500, 502, 503, 504}
@@ -87,20 +88,20 @@ async def sync_op(
     cls: type[IO.ComfyNode],
     endpoint: ApiEndpoint,
     *,
-    response_model: Type[M],
-    price_extractor: Optional[Callable[[M], Optional[float]]] = None,
-    data: Optional[BaseModel] = None,
-    files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]] = None,
+    response_model: type[M],
+    price_extractor: Callable[[M | Any], float | None] | None = None,
+    data: BaseModel | None = None,
+    files: dict[str, Any] | list[tuple[str, Any]] | None = None,
     content_type: str = "application/json",
     timeout: float = 3600.0,
-    multipart_parser: Optional[Callable] = None,
+    multipart_parser: Callable | None = None,
     max_retries: int = 3,
     retry_delay: float = 1.0,
     retry_backoff: float = 2.0,
     wait_label: str = "Waiting for server",
-    estimated_duration: Optional[int] = None,
-    final_label_on_success: Optional[str] = "Completed",
-    progress_origin_ts: Optional[float] = None,
+    estimated_duration: int | None = None,
+    final_label_on_success: str | None = "Completed",
+    progress_origin_ts: float | None = None,
     monitor_progress: bool = True,
 ) -> M:
     raw = await sync_op_raw(
@@ -131,22 +132,22 @@ async def poll_op(
     cls: type[IO.ComfyNode],
     poll_endpoint: ApiEndpoint,
     *,
-    response_model: Type[M],
-    status_extractor: Callable[[M], Optional[Union[str, int]]],
-    progress_extractor: Optional[Callable[[M], Optional[int]]] = None,
-    price_extractor: Optional[Callable[[M], Optional[float]]] = None,
-    completed_statuses: Optional[list[Union[str, int]]] = None,
-    failed_statuses: Optional[list[Union[str, int]]] = None,
-    queued_statuses: Optional[list[Union[str, int]]] = None,
-    data: Optional[BaseModel] = None,
+    response_model: type[M],
+    status_extractor: Callable[[M | Any], str | int | None],
+    progress_extractor: Callable[[M | Any], int | None] | None = None,
+    price_extractor: Callable[[M | Any], float | None] | None = None,
+    completed_statuses: list[str | int] | None = None,
+    failed_statuses: list[str | int] | None = None,
+    queued_statuses: list[str | int] | None = None,
+    data: BaseModel | None = None,
     poll_interval: float = 5.0,
     max_poll_attempts: int = 120,
     timeout_per_poll: float = 120.0,
     max_retries_per_poll: int = 3,
     retry_delay_per_poll: float = 1.0,
     retry_backoff_per_poll: float = 2.0,
-    estimated_duration: Optional[int] = None,
-    cancel_endpoint: Optional[ApiEndpoint] = None,
+    estimated_duration: int | None = None,
+    cancel_endpoint: ApiEndpoint | None = None,
     cancel_timeout: float = 10.0,
 ) -> M:
     raw = await poll_op_raw(
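`poll_op` wraps the dict-based poller and keeps extracting statuses from validated responses until a terminal state. A generic, dependency-free sketch of that loop shape (all names and statuses here are illustrative, not the real implementation):

    import asyncio

    async def poll(fetch_status, *, interval: float = 5.0, max_attempts: int = 120) -> str:
        for _ in range(max_attempts):
            status = await fetch_status()
            if status in ("succeed", "failed"):  # terminal states
                return status
            await asyncio.sleep(interval)
        raise TimeoutError("task did not finish in time")

    async def demo():
        states = iter(["submitted", "processing", "succeed"])

        async def fetch_status():
            return next(states)

        print(await poll(fetch_status, interval=0))  # -> succeed

    asyncio.run(demo())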
@ -178,22 +179,22 @@ async def sync_op_raw(
|
|||||||
cls: type[IO.ComfyNode],
|
cls: type[IO.ComfyNode],
|
||||||
endpoint: ApiEndpoint,
|
endpoint: ApiEndpoint,
|
||||||
*,
|
*,
|
||||||
price_extractor: Optional[Callable[[dict[str, Any]], Optional[float]]] = None,
|
price_extractor: Callable[[dict[str, Any]], float | None] | None = None,
|
||||||
data: Optional[Union[dict[str, Any], BaseModel]] = None,
|
data: dict[str, Any] | BaseModel | None = None,
|
||||||
files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]] = None,
|
files: dict[str, Any] | list[tuple[str, Any]] | None = None,
|
||||||
content_type: str = "application/json",
|
content_type: str = "application/json",
|
||||||
timeout: float = 3600.0,
|
timeout: float = 3600.0,
|
||||||
multipart_parser: Optional[Callable] = None,
|
multipart_parser: Callable | None = None,
|
||||||
max_retries: int = 3,
|
max_retries: int = 3,
|
||||||
retry_delay: float = 1.0,
|
retry_delay: float = 1.0,
|
||||||
retry_backoff: float = 2.0,
|
retry_backoff: float = 2.0,
|
||||||
wait_label: str = "Waiting for server",
|
wait_label: str = "Waiting for server",
|
||||||
estimated_duration: Optional[int] = None,
|
estimated_duration: int | None = None,
|
||||||
as_binary: bool = False,
|
as_binary: bool = False,
|
||||||
final_label_on_success: Optional[str] = "Completed",
|
final_label_on_success: str | None = "Completed",
|
||||||
progress_origin_ts: Optional[float] = None,
|
progress_origin_ts: float | None = None,
|
||||||
monitor_progress: bool = True,
|
monitor_progress: bool = True,
|
||||||
) -> Union[dict[str, Any], bytes]:
|
) -> dict[str, Any] | bytes:
|
||||||
"""
|
"""
|
||||||
Make a single network request.
|
Make a single network request.
|
||||||
- If as_binary=False (default): returns JSON dict (or {'_raw': '<text>'} if non-JSON).
|
- If as_binary=False (default): returns JSON dict (or {'_raw': '<text>'} if non-JSON).
|
||||||
@ -229,21 +230,21 @@ async def poll_op_raw(
|
|||||||
cls: type[IO.ComfyNode],
|
cls: type[IO.ComfyNode],
|
||||||
poll_endpoint: ApiEndpoint,
|
poll_endpoint: ApiEndpoint,
|
||||||
*,
|
*,
|
||||||
status_extractor: Callable[[dict[str, Any]], Optional[Union[str, int]]],
|
status_extractor: Callable[[dict[str, Any]], str | int | None],
|
||||||
progress_extractor: Optional[Callable[[dict[str, Any]], Optional[int]]] = None,
|
progress_extractor: Callable[[dict[str, Any]], int | None] | None = None,
|
||||||
price_extractor: Optional[Callable[[dict[str, Any]], Optional[float]]] = None,
|
price_extractor: Callable[[dict[str, Any]], float | None] | None = None,
|
||||||
completed_statuses: Optional[list[Union[str, int]]] = None,
|
completed_statuses: list[str | int] | None = None,
|
||||||
failed_statuses: Optional[list[Union[str, int]]] = None,
|
failed_statuses: list[str | int] | None = None,
|
||||||
queued_statuses: Optional[list[Union[str, int]]] = None,
|
queued_statuses: list[str | int] | None = None,
|
||||||
data: Optional[Union[dict[str, Any], BaseModel]] = None,
|
data: dict[str, Any] | BaseModel | None = None,
|
||||||
poll_interval: float = 5.0,
|
poll_interval: float = 5.0,
|
||||||
max_poll_attempts: int = 120,
|
max_poll_attempts: int = 120,
|
||||||
timeout_per_poll: float = 120.0,
|
timeout_per_poll: float = 120.0,
|
||||||
max_retries_per_poll: int = 3,
|
max_retries_per_poll: int = 3,
|
||||||
retry_delay_per_poll: float = 1.0,
|
retry_delay_per_poll: float = 1.0,
|
||||||
retry_backoff_per_poll: float = 2.0,
|
retry_backoff_per_poll: float = 2.0,
|
||||||
estimated_duration: Optional[int] = None,
|
estimated_duration: int | None = None,
|
||||||
cancel_endpoint: Optional[ApiEndpoint] = None,
|
cancel_endpoint: ApiEndpoint | None = None,
|
||||||
cancel_timeout: float = 10.0,
|
cancel_timeout: float = 10.0,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
@ -261,7 +262,7 @@ async def poll_op_raw(
|
|||||||
consumed_attempts = 0 # counts only non-queued polls
|
consumed_attempts = 0 # counts only non-queued polls
|
||||||
|
|
||||||
progress_bar = utils.ProgressBar(100) if progress_extractor else None
|
progress_bar = utils.ProgressBar(100) if progress_extractor else None
|
||||||
last_progress: Optional[int] = None
|
last_progress: int | None = None
|
||||||
|
|
||||||
state = _PollUIState(started=started, estimated_duration=estimated_duration)
|
state = _PollUIState(started=started, estimated_duration=estimated_duration)
|
||||||
stop_ticker = asyncio.Event()
|
stop_ticker = asyncio.Event()
|
||||||
@ -420,10 +421,10 @@ async def poll_op_raw(
|
|||||||
|
|
||||||
def _display_text(
|
def _display_text(
|
||||||
node_cls: type[IO.ComfyNode],
|
node_cls: type[IO.ComfyNode],
|
||||||
text: Optional[str],
|
text: str | None,
|
||||||
*,
|
*,
|
||||||
status: Optional[Union[str, int]] = None,
|
status: str | int | None = None,
|
||||||
price: Optional[float] = None,
|
price: float | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
display_lines: list[str] = []
|
display_lines: list[str] = []
|
||||||
if status:
|
if status:
|
||||||
@ -440,13 +441,13 @@ def _display_text(
|
|||||||
|
|
||||||
def _display_time_progress(
|
def _display_time_progress(
|
||||||
node_cls: type[IO.ComfyNode],
|
node_cls: type[IO.ComfyNode],
|
||||||
status: Optional[Union[str, int]],
|
status: str | int | None,
|
||||||
elapsed_seconds: int,
|
elapsed_seconds: int,
|
||||||
estimated_total: Optional[int] = None,
|
estimated_total: int | None = None,
|
||||||
*,
|
*,
|
||||||
price: Optional[float] = None,
|
price: float | None = None,
|
||||||
is_queued: Optional[bool] = None,
|
is_queued: bool | None = None,
|
||||||
processing_elapsed_seconds: Optional[int] = None,
|
processing_elapsed_seconds: int | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
if estimated_total is not None and estimated_total > 0 and is_queued is False:
|
if estimated_total is not None and estimated_total > 0 and is_queued is False:
|
||||||
pe = processing_elapsed_seconds if processing_elapsed_seconds is not None else elapsed_seconds
|
pe = processing_elapsed_seconds if processing_elapsed_seconds is not None else elapsed_seconds
|
||||||
@ -488,7 +489,7 @@ def _unpack_tuple(t: tuple) -> tuple[str, Any, str]:
|
|||||||
raise ValueError("files tuple must be (filename, file[, content_type])")
|
raise ValueError("files tuple must be (filename, file[, content_type])")
|
||||||
|
|
||||||
|
|
||||||
def _merge_params(endpoint_params: dict[str, Any], method: str, data: Optional[dict[str, Any]]) -> dict[str, Any]:
|
def _merge_params(endpoint_params: dict[str, Any], method: str, data: dict[str, Any] | None) -> dict[str, Any]:
|
||||||
params = dict(endpoint_params or {})
|
params = dict(endpoint_params or {})
|
||||||
if method.upper() == "GET" and data:
|
if method.upper() == "GET" and data:
|
||||||
for k, v in data.items():
|
for k, v in data.items():
|
||||||
@ -534,9 +535,9 @@ def _generate_operation_id(method: str, path: str, attempt: int) -> str:
|
|||||||
def _snapshot_request_body_for_logging(
|
def _snapshot_request_body_for_logging(
|
||||||
content_type: str,
|
content_type: str,
|
||||||
method: str,
|
method: str,
|
||||||
data: Optional[dict[str, Any]],
|
data: dict[str, Any] | None,
|
||||||
files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]],
|
files: dict[str, Any] | list[tuple[str, Any]] | None,
|
||||||
) -> Optional[Union[dict[str, Any], str]]:
|
) -> dict[str, Any] | str | None:
|
||||||
if method.upper() == "GET":
|
if method.upper() == "GET":
|
||||||
return None
|
return None
|
||||||
if content_type == "multipart/form-data":
|
if content_type == "multipart/form-data":
|
||||||
@ -586,13 +587,13 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
|||||||
attempt = 0
|
attempt = 0
|
||||||
delay = cfg.retry_delay
|
delay = cfg.retry_delay
|
||||||
operation_succeeded: bool = False
|
operation_succeeded: bool = False
|
||||||
final_elapsed_seconds: Optional[int] = None
|
final_elapsed_seconds: int | None = None
|
||||||
extracted_price: Optional[float] = None
|
extracted_price: float | None = None
|
||||||
while True:
|
while True:
|
||||||
attempt += 1
|
attempt += 1
|
||||||
stop_event = asyncio.Event()
|
stop_event = asyncio.Event()
|
||||||
monitor_task: Optional[asyncio.Task] = None
|
monitor_task: asyncio.Task | None = None
|
||||||
sess: Optional[aiohttp.ClientSession] = None
|
sess: aiohttp.ClientSession | None = None
|
||||||
|
|
||||||
operation_id = _generate_operation_id(method, cfg.endpoint.path, attempt)
|
operation_id = _generate_operation_id(method, cfg.endpoint.path, attempt)
|
||||||
logging.debug("[DEBUG] HTTP %s %s (attempt %d)", method, url, attempt)
|
logging.debug("[DEBUG] HTTP %s %s (attempt %d)", method, url, attempt)
|
||||||
@ -887,7 +888,7 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _validate_or_raise(response_model: Type[M], payload: Any) -> M:
|
def _validate_or_raise(response_model: type[M], payload: Any) -> M:
|
||||||
try:
|
try:
|
||||||
return response_model.model_validate(payload)
|
return response_model.model_validate(payload)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -902,9 +903,9 @@ def _validate_or_raise(response_model: Type[M], payload: Any) -> M:
|
|||||||
|
|
||||||
|
|
||||||
def _wrap_model_extractor(
|
def _wrap_model_extractor(
|
||||||
response_model: Type[M],
|
response_model: type[M],
|
||||||
extractor: Optional[Callable[[M], Any]],
|
extractor: Callable[[M], Any] | None,
|
||||||
) -> Optional[Callable[[dict[str, Any]], Any]]:
|
) -> Callable[[dict[str, Any]], Any] | None:
|
||||||
"""Wrap a typed extractor so it can be used by the dict-based poller.
|
"""Wrap a typed extractor so it can be used by the dict-based poller.
|
||||||
Validates the dict into `response_model` before invoking `extractor`.
|
Validates the dict into `response_model` before invoking `extractor`.
|
||||||
Uses a small per-wrapper cache keyed by `id(dict)` to avoid re-validating
|
Uses a small per-wrapper cache keyed by `id(dict)` to avoid re-validating
|
||||||
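The docstring above pins down the wrapper contract: validate the raw dict into `response_model` once, cache by `id(dict)`, then hand the typed model to `extractor`. A minimal sketch of that shape, assuming a Pydantic model with `model_validate` (illustrative, not the repository code):

    def _sketch_wrap_model_extractor(response_model, extractor):
        if extractor is None:
            return None
        cache: dict[int, object] = {}  # keyed by id(dict), per the docstring

        def _wrapped(payload: dict) -> object:
            key = id(payload)
            if key not in cache:
                cache[key] = response_model.model_validate(payload)
            return extractor(cache[key])

        return _wrapped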
@ -929,10 +930,10 @@ def _wrap_model_extractor(
|
|||||||
return _wrapped
|
return _wrapped
|
||||||
|
|
||||||
|
|
||||||
def _normalize_statuses(values: Optional[Iterable[Union[str, int]]]) -> set[Union[str, int]]:
|
def _normalize_statuses(values: Iterable[str | int] | None) -> set[str | int]:
|
||||||
if not values:
|
if not values:
|
||||||
return set()
|
return set()
|
||||||
out: set[Union[str, int]] = set()
|
out: set[str | int] = set()
|
||||||
for v in values:
|
for v in values:
|
||||||
nv = _normalize_status_value(v)
|
nv = _normalize_status_value(v)
|
||||||
if nv is not None:
|
if nv is not None:
|
||||||
@ -940,7 +941,7 @@ def _normalize_statuses(values: Optional[Iterable[Union[str, int]]]) -> set[Unio
|
|||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
def _normalize_status_value(val: Union[str, int, None]) -> Union[str, int, None]:
|
def _normalize_status_value(val: str | int | None) -> str | int | None:
|
||||||
if isinstance(val, str):
|
if isinstance(val, str):
|
||||||
return val.strip().lower()
|
return val.strip().lower()
|
||||||
return val
|
return val
|
||||||
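All of the annotation hunks above are the same mechanical migration: `typing.Optional`/`typing.Union` spellings replaced with PEP 604 `X | Y` unions, which are equivalent at runtime on Python 3.10+. A quick self-contained check:

    from typing import Optional, Union

    def old_style(status: Optional[Union[str, int]]) -> None: ...
    def new_style(status: str | int | None) -> None: ...

    # Union objects compare equal to the typing spelling (Python 3.10+).
    assert (str | int | None) == Optional[Union[str, int]]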
|
|||||||
@ -4,7 +4,6 @@ import math
|
|||||||
import mimetypes
|
import mimetypes
|
||||||
import uuid
|
import uuid
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import av
|
import av
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -12,8 +11,7 @@ import torch
|
|||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from comfy.utils import common_upscale
|
from comfy.utils import common_upscale
|
||||||
from comfy_api.latest import Input, InputImpl
|
from comfy_api.latest import Input, InputImpl, Types
|
||||||
from comfy_api.util import VideoCodec, VideoContainer
|
|
||||||
|
|
||||||
from ._helpers import mimetype_to_extension
|
from ._helpers import mimetype_to_extension
|
||||||
|
|
||||||
@ -57,7 +55,7 @@ def image_tensor_pair_to_batch(image1: torch.Tensor, image2: torch.Tensor) -> to
|
|||||||
|
|
||||||
def tensor_to_bytesio(
|
def tensor_to_bytesio(
|
||||||
image: torch.Tensor,
|
image: torch.Tensor,
|
||||||
name: Optional[str] = None,
|
name: str | None = None,
|
||||||
total_pixels: int = 2048 * 2048,
|
total_pixels: int = 2048 * 2048,
|
||||||
mime_type: str = "image/png",
|
mime_type: str = "image/png",
|
||||||
) -> BytesIO:
|
) -> BytesIO:
|
||||||
@ -177,8 +175,8 @@ def audio_to_base64_string(audio: Input.Audio, container_format: str = "mp4", co
|
|||||||
|
|
||||||
def video_to_base64_string(
|
def video_to_base64_string(
|
||||||
video: Input.Video,
|
video: Input.Video,
|
||||||
container_format: VideoContainer = None,
|
container_format: Types.VideoContainer | None = None,
|
||||||
codec: VideoCodec = None
|
codec: Types.VideoCodec | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Converts a video input to a base64 string.
|
Converts a video input to a base64 string.
|
||||||
@ -189,12 +187,11 @@ def video_to_base64_string(
|
|||||||
codec: Optional codec to use (defaults to video.codec if available)
|
codec: Optional codec to use (defaults to video.codec if available)
|
||||||
"""
|
"""
|
||||||
video_bytes_io = BytesIO()
|
video_bytes_io = BytesIO()
|
||||||
|
video.save_to(
|
||||||
# Use provided format/codec if specified, otherwise use video's own if available
|
video_bytes_io,
|
||||||
format_to_use = container_format if container_format is not None else getattr(video, 'container', VideoContainer.MP4)
|
format=container_format or getattr(video, "container", Types.VideoContainer.MP4),
|
||||||
codec_to_use = codec if codec is not None else getattr(video, 'codec', VideoCodec.H264)
|
codec=codec or getattr(video, "codec", Types.VideoCodec.H264),
|
||||||
|
)
|
||||||
video.save_to(video_bytes_io, format=format_to_use, codec=codec_to_use)
|
|
||||||
video_bytes_io.seek(0)
|
video_bytes_io.seek(0)
|
||||||
return base64.b64encode(video_bytes_io.getvalue()).decode("utf-8")
|
return base64.b64encode(video_bytes_io.getvalue()).decode("utf-8")
|
||||||
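Because the rewritten body now feeds `video.save_to(...)` straight into a `BytesIO` and base64-encodes the buffer, the returned string round-trips back to the container bytes. A quick illustration (the `video` object here is hypothetical):

    import base64

    b64 = video_to_base64_string(video)  # defaults to the video's own container/codec
    raw = base64.b64decode(b64)          # the original container bytes, e.g. for an upload payload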
|
|
||||||
|
|||||||
@ -3,15 +3,15 @@ import contextlib
|
|||||||
import uuid
|
import uuid
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import IO, Optional, Union
|
from typing import IO
|
||||||
from urllib.parse import urljoin, urlparse
|
from urllib.parse import urljoin, urlparse
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
import torch
|
import torch
|
||||||
from aiohttp.client_exceptions import ClientError, ContentTypeError
|
from aiohttp.client_exceptions import ClientError, ContentTypeError
|
||||||
|
|
||||||
from comfy_api.input_impl import VideoFromFile
|
|
||||||
from comfy_api.latest import IO as COMFY_IO
|
from comfy_api.latest import IO as COMFY_IO
|
||||||
|
from comfy_api.latest import InputImpl
|
||||||
|
|
||||||
from . import request_logger
|
from . import request_logger
|
||||||
from ._helpers import (
|
from ._helpers import (
|
||||||
@ -29,9 +29,9 @@ _RETRY_STATUS = {408, 429, 500, 502, 503, 504}
|
|||||||
|
|
||||||
async def download_url_to_bytesio(
|
async def download_url_to_bytesio(
|
||||||
url: str,
|
url: str,
|
||||||
dest: Optional[Union[BytesIO, IO[bytes], str, Path]],
|
dest: BytesIO | IO[bytes] | str | Path | None,
|
||||||
*,
|
*,
|
||||||
timeout: Optional[float] = None,
|
timeout: float | None = None,
|
||||||
max_retries: int = 5,
|
max_retries: int = 5,
|
||||||
retry_delay: float = 1.0,
|
retry_delay: float = 1.0,
|
||||||
retry_backoff: float = 2.0,
|
retry_backoff: float = 2.0,
|
||||||
@ -71,10 +71,10 @@ async def download_url_to_bytesio(
|
|||||||
|
|
||||||
is_path_sink = isinstance(dest, (str, Path))
|
is_path_sink = isinstance(dest, (str, Path))
|
||||||
fhandle = None
|
fhandle = None
|
||||||
session: Optional[aiohttp.ClientSession] = None
|
session: aiohttp.ClientSession | None = None
|
||||||
stop_evt: Optional[asyncio.Event] = None
|
stop_evt: asyncio.Event | None = None
|
||||||
monitor_task: Optional[asyncio.Task] = None
|
monitor_task: asyncio.Task | None = None
|
||||||
req_task: Optional[asyncio.Task] = None
|
req_task: asyncio.Task | None = None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with contextlib.suppress(Exception):
|
with contextlib.suppress(Exception):
|
||||||
@ -234,11 +234,11 @@ async def download_url_to_video_output(
|
|||||||
timeout: float = None,
|
timeout: float = None,
|
||||||
max_retries: int = 5,
|
max_retries: int = 5,
|
||||||
cls: type[COMFY_IO.ComfyNode] = None,
|
cls: type[COMFY_IO.ComfyNode] = None,
|
||||||
) -> VideoFromFile:
|
) -> InputImpl.VideoFromFile:
|
||||||
"""Downloads a video from a URL and returns a `VIDEO` output."""
|
"""Downloads a video from a URL and returns a `VIDEO` output."""
|
||||||
result = BytesIO()
|
result = BytesIO()
|
||||||
await download_url_to_bytesio(video_url, result, timeout=timeout, max_retries=max_retries, cls=cls)
|
await download_url_to_bytesio(video_url, result, timeout=timeout, max_retries=max_retries, cls=cls)
|
||||||
return VideoFromFile(result)
|
return InputImpl.VideoFromFile(result)
|
||||||
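A usage sketch for the updated helper — the URL is a placeholder, and the call must run inside a coroutine since the function is async:

    video = await download_url_to_video_output(
        "https://example.com/clip.mp4",  # placeholder URL
        timeout=60.0,
        max_retries=3,
    )
    # `video` is an InputImpl.VideoFromFile backed by an in-memory BytesIO.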
|
|
||||||
|
|
||||||
async def download_url_as_bytesio(
|
async def download_url_as_bytesio(
|
||||||
|
|||||||
@ -1,5 +1,3 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import datetime
|
import datetime
|
||||||
import hashlib
|
import hashlib
|
||||||
import json
|
import json
|
||||||
|
|||||||
@ -4,15 +4,13 @@ import logging
|
|||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from typing import Optional
|
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
import torch
|
import torch
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from comfy_api.latest import IO, Input
|
from comfy_api.latest import IO, Input, Types
|
||||||
from comfy_api.util import VideoCodec, VideoContainer
|
|
||||||
|
|
||||||
from . import request_logger
|
from . import request_logger
|
||||||
from ._helpers import is_processing_interrupted, sleep_with_interrupt
|
from ._helpers import is_processing_interrupted, sleep_with_interrupt
|
||||||
@ -32,7 +30,7 @@ from .conversions import (
|
|||||||
|
|
||||||
class UploadRequest(BaseModel):
|
class UploadRequest(BaseModel):
|
||||||
file_name: str = Field(..., description="Filename to upload")
|
file_name: str = Field(..., description="Filename to upload")
|
||||||
content_type: Optional[str] = Field(
|
content_type: str | None = Field(
|
||||||
None,
|
None,
|
||||||
description="Mime type of the file. For example: image/png, image/jpeg, video/mp4, etc.",
|
description="Mime type of the file. For example: image/png, image/jpeg, video/mp4, etc.",
|
||||||
)
|
)
|
||||||
@ -56,7 +54,7 @@ async def upload_images_to_comfyapi(
|
|||||||
Uploads images to ComfyUI API and returns download URLs.
|
Uploads images to ComfyUI API and returns download URLs.
|
||||||
To upload multiple images, stack them in the batch dimension first.
|
To upload multiple images, stack them in the batch dimension first.
|
||||||
"""
|
"""
|
||||||
# if batch, try to upload each file if max_images is greater than 0
|
# if batched, try to upload each file if max_images is greater than 0
|
||||||
download_urls: list[str] = []
|
download_urls: list[str] = []
|
||||||
is_batch = len(image.shape) > 3
|
is_batch = len(image.shape) > 3
|
||||||
batch_len = image.shape[0] if is_batch else 1
|
batch_len = image.shape[0] if is_batch else 1
|
||||||
@ -100,9 +98,10 @@ async def upload_video_to_comfyapi(
|
|||||||
cls: type[IO.ComfyNode],
|
cls: type[IO.ComfyNode],
|
||||||
video: Input.Video,
|
video: Input.Video,
|
||||||
*,
|
*,
|
||||||
container: VideoContainer = VideoContainer.MP4,
|
container: Types.VideoContainer = Types.VideoContainer.MP4,
|
||||||
codec: VideoCodec = VideoCodec.H264,
|
codec: Types.VideoCodec = Types.VideoCodec.H264,
|
||||||
max_duration: Optional[int] = None,
|
max_duration: int | None = None,
|
||||||
|
wait_label: str | None = "Uploading",
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Uploads a single video to ComfyUI API and returns its download URL.
|
Uploads a single video to ComfyUI API and returns its download URL.
|
||||||
@ -127,7 +126,7 @@ async def upload_video_to_comfyapi(
|
|||||||
video.save_to(video_bytes_io, format=container, codec=codec)
|
video.save_to(video_bytes_io, format=container, codec=codec)
|
||||||
video_bytes_io.seek(0)
|
video_bytes_io.seek(0)
|
||||||
|
|
||||||
return await upload_file_to_comfyapi(cls, video_bytes_io, filename, upload_mime_type)
|
return await upload_file_to_comfyapi(cls, video_bytes_io, filename, upload_mime_type, wait_label)
|
||||||
|
|
||||||
|
|
||||||
async def upload_file_to_comfyapi(
|
async def upload_file_to_comfyapi(
|
||||||
@ -219,7 +218,7 @@ async def upload_file(
|
|||||||
return
|
return
|
||||||
|
|
||||||
monitor_task = asyncio.create_task(_monitor())
|
monitor_task = asyncio.create_task(_monitor())
|
||||||
sess: Optional[aiohttp.ClientSession] = None
|
sess: aiohttp.ClientSession | None = None
|
||||||
try:
|
try:
|
||||||
try:
|
try:
|
||||||
request_logger.log_request_response(
|
request_logger.log_request_response(
|
||||||
|
|||||||
@ -1,9 +1,7 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from comfy_api.input.video_types import VideoInput
|
|
||||||
from comfy_api.latest import Input
|
from comfy_api.latest import Input
|
||||||
|
|
||||||
|
|
||||||
@ -18,10 +16,10 @@ def get_image_dimensions(image: torch.Tensor) -> tuple[int, int]:
|
|||||||
|
|
||||||
def validate_image_dimensions(
|
def validate_image_dimensions(
|
||||||
image: torch.Tensor,
|
image: torch.Tensor,
|
||||||
min_width: Optional[int] = None,
|
min_width: int | None = None,
|
||||||
max_width: Optional[int] = None,
|
max_width: int | None = None,
|
||||||
min_height: Optional[int] = None,
|
min_height: int | None = None,
|
||||||
max_height: Optional[int] = None,
|
max_height: int | None = None,
|
||||||
):
|
):
|
||||||
height, width = get_image_dimensions(image)
|
height, width = get_image_dimensions(image)
|
||||||
|
|
||||||
@ -37,8 +35,8 @@ def validate_image_dimensions(
|
|||||||
|
|
||||||
def validate_image_aspect_ratio(
|
def validate_image_aspect_ratio(
|
||||||
image: torch.Tensor,
|
image: torch.Tensor,
|
||||||
min_ratio: Optional[tuple[float, float]] = None, # e.g. (1, 4)
|
min_ratio: tuple[float, float] | None = None, # e.g. (1, 4)
|
||||||
max_ratio: Optional[tuple[float, float]] = None, # e.g. (4, 1)
|
max_ratio: tuple[float, float] | None = None, # e.g. (4, 1)
|
||||||
*,
|
*,
|
||||||
strict: bool = True, # True -> (min, max); False -> [min, max]
|
strict: bool = True, # True -> (min, max); False -> [min, max]
|
||||||
) -> float:
|
) -> float:
|
||||||
@ -84,8 +82,8 @@ def validate_images_aspect_ratio_closeness(
|
|||||||
|
|
||||||
def validate_aspect_ratio_string(
|
def validate_aspect_ratio_string(
|
||||||
aspect_ratio: str,
|
aspect_ratio: str,
|
||||||
min_ratio: Optional[tuple[float, float]] = None, # e.g. (1, 4)
|
min_ratio: tuple[float, float] | None = None, # e.g. (1, 4)
|
||||||
max_ratio: Optional[tuple[float, float]] = None, # e.g. (4, 1)
|
max_ratio: tuple[float, float] | None = None, # e.g. (4, 1)
|
||||||
*,
|
*,
|
||||||
strict: bool = False, # True -> (min, max); False -> [min, max]
|
strict: bool = False, # True -> (min, max); False -> [min, max]
|
||||||
) -> float:
|
) -> float:
|
||||||
@ -97,10 +95,10 @@ def validate_aspect_ratio_string(
|
|||||||
|
|
||||||
def validate_video_dimensions(
|
def validate_video_dimensions(
|
||||||
video: Input.Video,
|
video: Input.Video,
|
||||||
min_width: Optional[int] = None,
|
min_width: int | None = None,
|
||||||
max_width: Optional[int] = None,
|
max_width: int | None = None,
|
||||||
min_height: Optional[int] = None,
|
min_height: int | None = None,
|
||||||
max_height: Optional[int] = None,
|
max_height: int | None = None,
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
width, height = video.get_dimensions()
|
width, height = video.get_dimensions()
|
||||||
@ -120,8 +118,8 @@ def validate_video_dimensions(
|
|||||||
|
|
||||||
def validate_video_duration(
|
def validate_video_duration(
|
||||||
video: Input.Video,
|
video: Input.Video,
|
||||||
min_duration: Optional[float] = None,
|
min_duration: float | None = None,
|
||||||
max_duration: Optional[float] = None,
|
max_duration: float | None = None,
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
duration = video.get_duration()
|
duration = video.get_duration()
|
||||||
@ -136,6 +134,23 @@ def validate_video_duration(
|
|||||||
raise ValueError(f"Video duration must be at most {max_duration}s, got {duration}s")
|
raise ValueError(f"Video duration must be at most {max_duration}s, got {duration}s")
|
||||||
|
|
||||||
|
|
||||||
|
def validate_video_frame_count(
|
||||||
|
video: Input.Video,
|
||||||
|
min_frame_count: int | None = None,
|
||||||
|
max_frame_count: int | None = None,
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
frame_count = video.get_frame_count()
|
||||||
|
except Exception as e:
|
||||||
|
logging.error("Error getting frame count of video: %s", e)
|
||||||
|
return
|
||||||
|
|
||||||
|
if min_frame_count is not None and min_frame_count > frame_count:
|
||||||
|
raise ValueError(f"Video frame count must be at least {min_frame_count}, got {frame_count}")
|
||||||
|
if max_frame_count is not None and frame_count > max_frame_count:
|
||||||
|
raise ValueError(f"Video frame count must be at most {max_frame_count}, got {frame_count}")
|
||||||
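The new `validate_video_frame_count` follows the same pattern as the duration validator above it: metadata-read failures are logged and tolerated, while bound violations raise `ValueError`. An example call with hypothetical bounds:

    # Accept clips of 16..240 frames; anything outside raises ValueError.
    validate_video_frame_count(video, min_frame_count=16, max_frame_count=240)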
|
|
||||||
|
|
||||||
def get_number_of_images(images):
|
def get_number_of_images(images):
|
||||||
if isinstance(images, torch.Tensor):
|
if isinstance(images, torch.Tensor):
|
||||||
return images.shape[0] if images.ndim >= 4 else 1
|
return images.shape[0] if images.ndim >= 4 else 1
|
||||||
@ -144,8 +159,8 @@ def get_number_of_images(images):
|
|||||||
|
|
||||||
def validate_audio_duration(
|
def validate_audio_duration(
|
||||||
audio: Input.Audio,
|
audio: Input.Audio,
|
||||||
min_duration: Optional[float] = None,
|
min_duration: float | None = None,
|
||||||
max_duration: Optional[float] = None,
|
max_duration: float | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
sr = int(audio["sample_rate"])
|
sr = int(audio["sample_rate"])
|
||||||
dur = int(audio["waveform"].shape[-1]) / sr
|
dur = int(audio["waveform"].shape[-1]) / sr
|
||||||
@ -177,7 +192,7 @@ def validate_string(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def validate_container_format_is_mp4(video: VideoInput) -> None:
|
def validate_container_format_is_mp4(video: Input.Video) -> None:
|
||||||
"""Validates video container format is MP4."""
|
"""Validates video container format is MP4."""
|
||||||
container_format = video.get_container_format()
|
container_format = video.get_container_format()
|
||||||
if container_format not in ["mp4", "mov,mp4,m4a,3gp,3g2,mj2"]:
|
if container_format not in ["mp4", "mov,mp4,m4a,3gp,3g2,mj2"]:
|
||||||
@ -194,8 +209,8 @@ def _ratio_from_tuple(r: tuple[float, float]) -> float:
|
|||||||
def _assert_ratio_bounds(
|
def _assert_ratio_bounds(
|
||||||
ar: float,
|
ar: float,
|
||||||
*,
|
*,
|
||||||
min_ratio: Optional[tuple[float, float]] = None,
|
min_ratio: tuple[float, float] | None = None,
|
||||||
max_ratio: Optional[tuple[float, float]] = None,
|
max_ratio: tuple[float, float] | None = None,
|
||||||
strict: bool = True,
|
strict: bool = True,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Validate a numeric aspect ratio against optional min/max ratio bounds."""
|
"""Validate a numeric aspect ratio against optional min/max ratio bounds."""
|
||||||
|
|||||||
@ -1,4 +1,5 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
from comfy_api.latest import IO
|
||||||
|
|
||||||
|
|
||||||
def validate_node_input(
|
def validate_node_input(
|
||||||
@ -23,6 +24,11 @@ def validate_node_input(
|
|||||||
if not received_type != input_type:
|
if not received_type != input_type:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
# If the received type or input_type is a MatchType, we can return True immediately;
|
||||||
|
# validation for this is handled by the frontend
|
||||||
|
if received_type == IO.MatchType.io_type or input_type == IO.MatchType.io_type:
|
||||||
|
return True
|
||||||
|
|
||||||
# Not equal, and not strings
|
# Not equal, and not strings
|
||||||
if not isinstance(received_type, str) or not isinstance(input_type, str):
|
if not isinstance(received_type, str) or not isinstance(input_type, str):
|
||||||
return False
|
return False
|
||||||
|
|||||||
@ -6,65 +6,80 @@ import torch
|
|||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
import folder_paths
|
import folder_paths
|
||||||
import os
|
import os
|
||||||
import io
|
|
||||||
import json
|
|
||||||
import random
|
|
||||||
import hashlib
|
import hashlib
|
||||||
import node_helpers
|
import node_helpers
|
||||||
import logging
|
import logging
|
||||||
from comfy.cli_args import args
|
from typing_extensions import override
|
||||||
from comfy.comfy_types import FileLocator
|
from comfy_api.latest import ComfyExtension, IO, UI
|
||||||
|
|
||||||
class EmptyLatentAudio:
|
class EmptyLatentAudio(IO.ComfyNode):
|
||||||
def __init__(self):
|
@classmethod
|
||||||
self.device = comfy.model_management.intermediate_device()
|
def define_schema(cls):
|
||||||
|
return IO.Schema(
|
||||||
|
node_id="EmptyLatentAudio",
|
||||||
|
display_name="Empty Latent Audio",
|
||||||
|
category="latent/audio",
|
||||||
|
inputs=[
|
||||||
|
IO.Float.Input("seconds", default=47.6, min=1.0, max=1000.0, step=0.1),
|
||||||
|
IO.Int.Input(
|
||||||
|
"batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch."
|
||||||
|
),
|
||||||
|
],
|
||||||
|
outputs=[IO.Latent.Output()],
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def execute(cls, seconds, batch_size) -> IO.NodeOutput:
|
||||||
return {"required": {"seconds": ("FLOAT", {"default": 47.6, "min": 1.0, "max": 1000.0, "step": 0.1}),
|
|
||||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096, "tooltip": "The number of latent images in the batch."}),
|
|
||||||
}}
|
|
||||||
RETURN_TYPES = ("LATENT",)
|
|
||||||
FUNCTION = "generate"
|
|
||||||
|
|
||||||
CATEGORY = "latent/audio"
|
|
||||||
|
|
||||||
def generate(self, seconds, batch_size):
|
|
||||||
length = round((seconds * 44100 / 2048) / 2) * 2
|
length = round((seconds * 44100 / 2048) / 2) * 2
|
||||||
latent = torch.zeros([batch_size, 64, length], device=self.device)
|
latent = torch.zeros([batch_size, 64, length], device=comfy.model_management.intermediate_device())
|
||||||
return ({"samples":latent, "type": "audio"}, )
|
return IO.NodeOutput({"samples":latent, "type": "audio"})
|
||||||
|
|
||||||
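The unchanged length formula maps seconds onto the audio VAE's temporal grid — 44100 samples per second, 2048 samples per latent step, rounded to an even step count. Worked through for the 47.6 s default:

    seconds = 47.6
    length = round((seconds * 44100 / 2048) / 2) * 2  # 2099160 samples -> 1024 latent steps
    # resulting latent shape: [batch_size, 64, 1024]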
class ConditioningStableAudio:
|
generate = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
|
class ConditioningStableAudio(IO.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": {"positive": ("CONDITIONING", ),
|
return IO.Schema(
|
||||||
"negative": ("CONDITIONING", ),
|
node_id="ConditioningStableAudio",
|
||||||
"seconds_start": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1000.0, "step": 0.1}),
|
category="conditioning",
|
||||||
"seconds_total": ("FLOAT", {"default": 47.0, "min": 0.0, "max": 1000.0, "step": 0.1}),
|
inputs=[
|
||||||
}}
|
IO.Conditioning.Input("positive"),
|
||||||
|
IO.Conditioning.Input("negative"),
|
||||||
|
IO.Float.Input("seconds_start", default=0.0, min=0.0, max=1000.0, step=0.1),
|
||||||
|
IO.Float.Input("seconds_total", default=47.0, min=0.0, max=1000.0, step=0.1),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
IO.Conditioning.Output(display_name="positive"),
|
||||||
|
IO.Conditioning.Output(display_name="negative"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
RETURN_TYPES = ("CONDITIONING","CONDITIONING")
|
@classmethod
|
||||||
RETURN_NAMES = ("positive", "negative")
|
def execute(cls, positive, negative, seconds_start, seconds_total) -> IO.NodeOutput:
|
||||||
|
|
||||||
FUNCTION = "append"
|
|
||||||
|
|
||||||
CATEGORY = "conditioning"
|
|
||||||
|
|
||||||
def append(self, positive, negative, seconds_start, seconds_total):
|
|
||||||
positive = node_helpers.conditioning_set_values(positive, {"seconds_start": seconds_start, "seconds_total": seconds_total})
|
positive = node_helpers.conditioning_set_values(positive, {"seconds_start": seconds_start, "seconds_total": seconds_total})
|
||||||
negative = node_helpers.conditioning_set_values(negative, {"seconds_start": seconds_start, "seconds_total": seconds_total})
|
negative = node_helpers.conditioning_set_values(negative, {"seconds_start": seconds_start, "seconds_total": seconds_total})
|
||||||
return (positive, negative)
|
return IO.NodeOutput(positive, negative)
|
||||||
|
|
||||||
class VAEEncodeAudio:
|
append = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
|
class VAEEncodeAudio(IO.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": { "audio": ("AUDIO", ), "vae": ("VAE", )}}
|
return IO.Schema(
|
||||||
RETURN_TYPES = ("LATENT",)
|
node_id="VAEEncodeAudio",
|
||||||
FUNCTION = "encode"
|
display_name="VAE Encode Audio",
|
||||||
|
category="latent/audio",
|
||||||
|
inputs=[
|
||||||
|
IO.Audio.Input("audio"),
|
||||||
|
IO.Vae.Input("vae"),
|
||||||
|
],
|
||||||
|
outputs=[IO.Latent.Output()],
|
||||||
|
)
|
||||||
|
|
||||||
CATEGORY = "latent/audio"
|
@classmethod
|
||||||
|
def execute(cls, vae, audio) -> IO.NodeOutput:
|
||||||
def encode(self, vae, audio):
|
|
||||||
sample_rate = audio["sample_rate"]
|
sample_rate = audio["sample_rate"]
|
||||||
if 44100 != sample_rate:
|
if 44100 != sample_rate:
|
||||||
waveform = torchaudio.functional.resample(audio["waveform"], sample_rate, 44100)
|
waveform = torchaudio.functional.resample(audio["waveform"], sample_rate, 44100)
|
||||||
@ -72,213 +87,134 @@ class VAEEncodeAudio:
|
|||||||
waveform = audio["waveform"]
|
waveform = audio["waveform"]
|
||||||
|
|
||||||
t = vae.encode(waveform.movedim(1, -1))
|
t = vae.encode(waveform.movedim(1, -1))
|
||||||
return ({"samples":t}, )
|
return IO.NodeOutput({"samples":t})
|
||||||
|
|
||||||
class VAEDecodeAudio:
|
encode = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
|
class VAEDecodeAudio(IO.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": { "samples": ("LATENT", ), "vae": ("VAE", )}}
|
return IO.Schema(
|
||||||
RETURN_TYPES = ("AUDIO",)
|
node_id="VAEDecodeAudio",
|
||||||
FUNCTION = "decode"
|
display_name="VAE Decode Audio",
|
||||||
|
category="latent/audio",
|
||||||
|
inputs=[
|
||||||
|
IO.Latent.Input("samples"),
|
||||||
|
IO.Vae.Input("vae"),
|
||||||
|
],
|
||||||
|
outputs=[IO.Audio.Output()],
|
||||||
|
)
|
||||||
|
|
||||||
CATEGORY = "latent/audio"
|
@classmethod
|
||||||
|
def execute(cls, vae, samples) -> IO.NodeOutput:
|
||||||
def decode(self, vae, samples):
|
|
||||||
audio = vae.decode(samples["samples"]).movedim(-1, 1)
|
audio = vae.decode(samples["samples"]).movedim(-1, 1)
|
||||||
std = torch.std(audio, dim=[1,2], keepdim=True) * 5.0
|
std = torch.std(audio, dim=[1,2], keepdim=True) * 5.0
|
||||||
std[std < 1.0] = 1.0
|
std[std < 1.0] = 1.0
|
||||||
audio /= std
|
audio /= std
|
||||||
return ({"waveform": audio, "sample_rate": 44100}, )
|
return IO.NodeOutput({"waveform": audio, "sample_rate": 44100})
|
||||||
|
|
||||||
|
decode = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
def save_audio(self, audio, filename_prefix="ComfyUI", format="flac", prompt=None, extra_pnginfo=None, quality="128k"):
|
class SaveAudio(IO.ComfyNode):
|
||||||
|
@classmethod
|
||||||
filename_prefix += self.prefix_append
|
def define_schema(cls):
|
||||||
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
|
return IO.Schema(
|
||||||
results: list[FileLocator] = []
|
node_id="SaveAudio",
|
||||||
|
display_name="Save Audio (FLAC)",
|
||||||
# Prepare metadata dictionary
|
category="audio",
|
||||||
metadata = {}
|
inputs=[
|
||||||
if not args.disable_metadata:
|
IO.Audio.Input("audio"),
|
||||||
if prompt is not None:
|
IO.String.Input("filename_prefix", default="audio/ComfyUI"),
|
||||||
metadata["prompt"] = json.dumps(prompt)
|
],
|
||||||
if extra_pnginfo is not None:
|
hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
|
||||||
for x in extra_pnginfo:
|
is_output_node=True,
|
||||||
metadata[x] = json.dumps(extra_pnginfo[x])
|
)
|
||||||
|
|
||||||
# Opus supported sample rates
|
|
||||||
OPUS_RATES = [8000, 12000, 16000, 24000, 48000]
|
|
||||||
|
|
||||||
for (batch_number, waveform) in enumerate(audio["waveform"].cpu()):
|
|
||||||
filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
|
|
||||||
file = f"{filename_with_batch_num}_{counter:05}_.{format}"
|
|
||||||
output_path = os.path.join(full_output_folder, file)
|
|
||||||
|
|
||||||
# Use original sample rate initially
|
|
||||||
sample_rate = audio["sample_rate"]
|
|
||||||
|
|
||||||
# Handle Opus sample rate requirements
|
|
||||||
if format == "opus":
|
|
||||||
if sample_rate > 48000:
|
|
||||||
sample_rate = 48000
|
|
||||||
elif sample_rate not in OPUS_RATES:
|
|
||||||
# Find the next highest supported rate
|
|
||||||
for rate in sorted(OPUS_RATES):
|
|
||||||
if rate > sample_rate:
|
|
||||||
sample_rate = rate
|
|
||||||
break
|
|
||||||
if sample_rate not in OPUS_RATES: # Fallback if still not supported
|
|
||||||
sample_rate = 48000
|
|
||||||
|
|
||||||
# Resample if necessary
|
|
||||||
if sample_rate != audio["sample_rate"]:
|
|
||||||
waveform = torchaudio.functional.resample(waveform, audio["sample_rate"], sample_rate)
|
|
||||||
|
|
||||||
# Create output with specified format
|
|
||||||
output_buffer = io.BytesIO()
|
|
||||||
output_container = av.open(output_buffer, mode='w', format=format)
|
|
||||||
|
|
||||||
# Set metadata on the container
|
|
||||||
for key, value in metadata.items():
|
|
||||||
output_container.metadata[key] = value
|
|
||||||
|
|
||||||
layout = 'mono' if waveform.shape[0] == 1 else 'stereo'
|
|
||||||
# Set up the output stream with appropriate properties
|
|
||||||
if format == "opus":
|
|
||||||
out_stream = output_container.add_stream("libopus", rate=sample_rate, layout=layout)
|
|
||||||
if quality == "64k":
|
|
||||||
out_stream.bit_rate = 64000
|
|
||||||
elif quality == "96k":
|
|
||||||
out_stream.bit_rate = 96000
|
|
||||||
elif quality == "128k":
|
|
||||||
out_stream.bit_rate = 128000
|
|
||||||
elif quality == "192k":
|
|
||||||
out_stream.bit_rate = 192000
|
|
||||||
elif quality == "320k":
|
|
||||||
out_stream.bit_rate = 320000
|
|
||||||
elif format == "mp3":
|
|
||||||
out_stream = output_container.add_stream("libmp3lame", rate=sample_rate, layout=layout)
|
|
||||||
if quality == "V0":
|
|
||||||
#TODO i would really love to support V3 and V5 but there doesn't seem to be a way to set the qscale level, the property below is a bool
|
|
||||||
out_stream.codec_context.qscale = 1
|
|
||||||
elif quality == "128k":
|
|
||||||
out_stream.bit_rate = 128000
|
|
||||||
elif quality == "320k":
|
|
||||||
out_stream.bit_rate = 320000
|
|
||||||
else: #format == "flac":
|
|
||||||
out_stream = output_container.add_stream("flac", rate=sample_rate, layout=layout)
|
|
||||||
|
|
||||||
frame = av.AudioFrame.from_ndarray(waveform.movedim(0, 1).reshape(1, -1).float().numpy(), format='flt', layout=layout)
|
|
||||||
frame.sample_rate = sample_rate
|
|
||||||
frame.pts = 0
|
|
||||||
output_container.mux(out_stream.encode(frame))
|
|
||||||
|
|
||||||
# Flush encoder
|
|
||||||
output_container.mux(out_stream.encode(None))
|
|
||||||
|
|
||||||
# Close containers
|
|
||||||
output_container.close()
|
|
||||||
|
|
||||||
# Write the output to file
|
|
||||||
output_buffer.seek(0)
|
|
||||||
with open(output_path, 'wb') as f:
|
|
||||||
f.write(output_buffer.getbuffer())
|
|
||||||
|
|
||||||
results.append({
|
|
||||||
"filename": file,
|
|
||||||
"subfolder": subfolder,
|
|
||||||
"type": self.type
|
|
||||||
})
|
|
||||||
counter += 1
|
|
||||||
|
|
||||||
return { "ui": { "audio": results } }
|
|
||||||
|
|
||||||
class SaveAudio:
|
|
||||||
def __init__(self):
|
|
||||||
self.output_dir = folder_paths.get_output_directory()
|
|
||||||
self.type = "output"
|
|
||||||
self.prefix_append = ""
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def execute(cls, audio, filename_prefix="ComfyUI", format="flac") -> IO.NodeOutput:
|
||||||
return {"required": { "audio": ("AUDIO", ),
|
return IO.NodeOutput(
|
||||||
"filename_prefix": ("STRING", {"default": "audio/ComfyUI"}),
|
ui=UI.AudioSaveHelper.get_save_audio_ui(audio, filename_prefix=filename_prefix, cls=cls, format=format)
|
||||||
},
|
)
|
||||||
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
|
|
||||||
}
|
|
||||||
|
|
||||||
RETURN_TYPES = ()
|
save_flac = execute # TODO: remove
|
||||||
FUNCTION = "save_flac"
|
|
||||||
|
|
||||||
OUTPUT_NODE = True
|
|
||||||
|
|
||||||
CATEGORY = "audio"
|
class SaveAudioMP3(IO.ComfyNode):
|
||||||
|
@classmethod
|
||||||
def save_flac(self, audio, filename_prefix="ComfyUI", format="flac", prompt=None, extra_pnginfo=None):
|
def define_schema(cls):
|
||||||
return save_audio(self, audio, filename_prefix, format, prompt, extra_pnginfo)
|
return IO.Schema(
|
||||||
|
node_id="SaveAudioMP3",
|
||||||
class SaveAudioMP3:
|
display_name="Save Audio (MP3)",
|
||||||
def __init__(self):
|
category="audio",
|
||||||
self.output_dir = folder_paths.get_output_directory()
|
inputs=[
|
||||||
self.type = "output"
|
IO.Audio.Input("audio"),
|
||||||
self.prefix_append = ""
|
IO.String.Input("filename_prefix", default="audio/ComfyUI"),
|
||||||
|
IO.Combo.Input("quality", options=["V0", "128k", "320k"], default="V0"),
|
||||||
|
],
|
||||||
|
hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
|
||||||
|
is_output_node=True,
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def execute(cls, audio, filename_prefix="ComfyUI", format="mp3", quality="128k") -> IO.NodeOutput:
|
||||||
return {"required": { "audio": ("AUDIO", ),
|
return IO.NodeOutput(
|
||||||
"filename_prefix": ("STRING", {"default": "audio/ComfyUI"}),
|
ui=UI.AudioSaveHelper.get_save_audio_ui(
|
||||||
"quality": (["V0", "128k", "320k"], {"default": "V0"}),
|
audio, filename_prefix=filename_prefix, cls=cls, format=format, quality=quality
|
||||||
},
|
)
|
||||||
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
|
)
|
||||||
}
|
|
||||||
|
|
||||||
RETURN_TYPES = ()
|
save_mp3 = execute # TODO: remove
|
||||||
FUNCTION = "save_mp3"
|
|
||||||
|
|
||||||
OUTPUT_NODE = True
|
|
||||||
|
|
||||||
CATEGORY = "audio"
|
class SaveAudioOpus(IO.ComfyNode):
|
||||||
|
@classmethod
|
||||||
def save_mp3(self, audio, filename_prefix="ComfyUI", format="mp3", prompt=None, extra_pnginfo=None, quality="128k"):
|
def define_schema(cls):
|
||||||
return save_audio(self, audio, filename_prefix, format, prompt, extra_pnginfo, quality)
|
return IO.Schema(
|
||||||
|
node_id="SaveAudioOpus",
|
||||||
class SaveAudioOpus:
|
display_name="Save Audio (Opus)",
|
||||||
def __init__(self):
|
category="audio",
|
||||||
self.output_dir = folder_paths.get_output_directory()
|
inputs=[
|
||||||
self.type = "output"
|
IO.Audio.Input("audio"),
|
||||||
self.prefix_append = ""
|
IO.String.Input("filename_prefix", default="audio/ComfyUI"),
|
||||||
|
IO.Combo.Input("quality", options=["64k", "96k", "128k", "192k", "320k"], default="128k"),
|
||||||
|
],
|
||||||
|
hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
|
||||||
|
is_output_node=True,
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def execute(cls, audio, filename_prefix="ComfyUI", format="opus", quality="V3") -> IO.NodeOutput:
|
||||||
return {"required": { "audio": ("AUDIO", ),
|
return IO.NodeOutput(
|
||||||
"filename_prefix": ("STRING", {"default": "audio/ComfyUI"}),
|
ui=UI.AudioSaveHelper.get_save_audio_ui(
|
||||||
"quality": (["64k", "96k", "128k", "192k", "320k"], {"default": "128k"}),
|
audio, filename_prefix=filename_prefix, cls=cls, format=format, quality=quality
|
||||||
},
|
)
|
||||||
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
|
)
|
||||||
}
|
|
||||||
|
|
||||||
RETURN_TYPES = ()
|
save_opus = execute # TODO: remove
|
||||||
FUNCTION = "save_opus"
|
|
||||||
|
|
||||||
OUTPUT_NODE = True
|
|
||||||
|
|
||||||
CATEGORY = "audio"
|
class PreviewAudio(IO.ComfyNode):
|
||||||
|
@classmethod
|
||||||
def save_opus(self, audio, filename_prefix="ComfyUI", format="opus", prompt=None, extra_pnginfo=None, quality="V3"):
|
def define_schema(cls):
|
||||||
return save_audio(self, audio, filename_prefix, format, prompt, extra_pnginfo, quality)
|
return IO.Schema(
|
||||||
|
node_id="PreviewAudio",
|
||||||
class PreviewAudio(SaveAudio):
|
display_name="Preview Audio",
|
||||||
def __init__(self):
|
category="audio",
|
||||||
self.output_dir = folder_paths.get_temp_directory()
|
inputs=[
|
||||||
self.type = "temp"
|
IO.Audio.Input("audio"),
|
||||||
self.prefix_append = "_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for x in range(5))
|
],
|
||||||
|
hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
|
||||||
|
is_output_node=True,
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def execute(cls, audio) -> IO.NodeOutput:
|
||||||
return {"required":
|
return IO.NodeOutput(ui=UI.PreviewAudio(audio, cls=cls))
|
||||||
{"audio": ("AUDIO", ), },
|
|
||||||
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
|
save_flac = execute # TODO: remove
|
||||||
}
|
|
||||||
|
|
||||||
def f32_pcm(wav: torch.Tensor) -> torch.Tensor:
|
def f32_pcm(wav: torch.Tensor) -> torch.Tensor:
|
||||||
"""Convert audio to float 32 bits PCM format."""
|
"""Convert audio to float 32 bits PCM format."""
|
||||||
@ -316,26 +252,30 @@ def load(filepath: str) -> tuple[torch.Tensor, int]:
|
|||||||
wav = f32_pcm(wav)
|
wav = f32_pcm(wav)
|
||||||
return wav, sr
|
return wav, sr
|
||||||
|
|
||||||
class LoadAudio:
|
class LoadAudio(IO.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
input_dir = folder_paths.get_input_directory()
|
input_dir = folder_paths.get_input_directory()
|
||||||
files = folder_paths.filter_files_content_types(os.listdir(input_dir), ["audio", "video"])
|
files = folder_paths.filter_files_content_types(os.listdir(input_dir), ["audio", "video"])
|
||||||
return {"required": {"audio": (sorted(files), {"audio_upload": True})}}
|
return IO.Schema(
|
||||||
|
node_id="LoadAudio",
|
||||||
|
display_name="Load Audio",
|
||||||
|
category="audio",
|
||||||
|
inputs=[
|
||||||
|
IO.Combo.Input("audio", upload=IO.UploadType.audio, options=sorted(files)),
|
||||||
|
],
|
||||||
|
outputs=[IO.Audio.Output()],
|
||||||
|
)
|
||||||
|
|
||||||
CATEGORY = "audio"
|
@classmethod
|
||||||
|
def execute(cls, audio) -> IO.NodeOutput:
|
||||||
RETURN_TYPES = ("AUDIO", )
|
|
||||||
FUNCTION = "load"
|
|
||||||
|
|
||||||
def load(self, audio):
|
|
||||||
audio_path = folder_paths.get_annotated_filepath(audio)
|
audio_path = folder_paths.get_annotated_filepath(audio)
|
||||||
waveform, sample_rate = load(audio_path)
|
waveform, sample_rate = load(audio_path)
|
||||||
audio = {"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate}
|
audio = {"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate}
|
||||||
return (audio, )
|
return IO.NodeOutput(audio)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def IS_CHANGED(s, audio):
|
def fingerprint_inputs(cls, audio):
|
||||||
image_path = folder_paths.get_annotated_filepath(audio)
|
image_path = folder_paths.get_annotated_filepath(audio)
|
||||||
m = hashlib.sha256()
|
m = hashlib.sha256()
|
||||||
with open(image_path, 'rb') as f:
|
with open(image_path, 'rb') as f:
|
||||||
@ -343,46 +283,69 @@ class LoadAudio:
|
|||||||
return m.digest().hex()
|
return m.digest().hex()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def VALIDATE_INPUTS(s, audio):
|
def validate_inputs(cls, audio):
|
||||||
if not folder_paths.exists_annotated_filepath(audio):
|
if not folder_paths.exists_annotated_filepath(audio):
|
||||||
return "Invalid audio file: {}".format(audio)
|
return "Invalid audio file: {}".format(audio)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
class RecordAudio:
|
load = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
|
class RecordAudio(IO.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": {"audio": ("AUDIO_RECORD", {})}}
|
return IO.Schema(
|
||||||
|
node_id="RecordAudio",
|
||||||
|
display_name="Record Audio",
|
||||||
|
category="audio",
|
||||||
|
inputs=[
|
||||||
|
IO.Custom("AUDIO_RECORD").Input("audio"),
|
||||||
|
],
|
||||||
|
outputs=[IO.Audio.Output()],
|
||||||
|
)
|
||||||
|
|
||||||
CATEGORY = "audio"
|
@classmethod
|
||||||
|
def execute(cls, audio) -> IO.NodeOutput:
|
||||||
RETURN_TYPES = ("AUDIO", )
|
|
||||||
FUNCTION = "load"
|
|
||||||
|
|
||||||
def load(self, audio):
|
|
||||||
audio_path = folder_paths.get_annotated_filepath(audio)
|
audio_path = folder_paths.get_annotated_filepath(audio)
|
||||||
|
|
||||||
waveform, sample_rate = load(audio_path)
|
waveform, sample_rate = load(audio_path)
|
||||||
audio = {"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate}
|
audio = {"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate}
|
||||||
return (audio, )
|
return IO.NodeOutput(audio)
|
||||||
|
|
||||||
|
load = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
class TrimAudioDuration:
|
class TrimAudioDuration(IO.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(cls):
|
def define_schema(cls):
|
||||||
return {
|
return IO.Schema(
|
||||||
"required": {
|
node_id="TrimAudioDuration",
|
||||||
"audio": ("AUDIO",),
|
display_name="Trim Audio Duration",
|
||||||
"start_index": ("FLOAT", {"default": 0.0, "min": -0xffffffffffffffff, "max": 0xffffffffffffffff, "step": 0.01, "tooltip": "Start time in seconds, can be negative to count from the end (supports sub-seconds)."}),
|
description="Trim audio tensor into chosen time range.",
|
||||||
"duration": ("FLOAT", {"default": 60.0, "min": 0.0, "step": 0.01, "tooltip": "Duration in seconds"}),
|
category="audio",
|
||||||
},
|
inputs=[
|
||||||
}
|
IO.Audio.Input("audio"),
|
||||||
|
IO.Float.Input(
|
||||||
|
"start_index",
|
||||||
|
default=0.0,
|
||||||
|
min=-0xffffffffffffffff,
|
||||||
|
max=0xffffffffffffffff,
|
||||||
|
step=0.01,
|
||||||
|
tooltip="Start time in seconds, can be negative to count from the end (supports sub-seconds).",
|
||||||
|
),
|
||||||
|
IO.Float.Input(
|
||||||
|
"duration",
|
||||||
|
default=60.0,
|
||||||
|
min=0.0,
|
||||||
|
step=0.01,
|
||||||
|
tooltip="Duration in seconds",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
outputs=[IO.Audio.Output()],
|
||||||
|
)
|
||||||
|
|
||||||
FUNCTION = "trim"
|
@classmethod
|
||||||
RETURN_TYPES = ("AUDIO",)
|
def execute(cls, audio, start_index, duration) -> IO.NodeOutput:
|
||||||
CATEGORY = "audio"
|
|
||||||
DESCRIPTION = "Trim audio tensor into chosen time range."
|
|
||||||
|
|
||||||
def trim(self, audio, start_index, duration):
|
|
||||||
waveform = audio["waveform"]
|
waveform = audio["waveform"]
|
||||||
sample_rate = audio["sample_rate"]
|
sample_rate = audio["sample_rate"]
|
||||||
audio_length = waveform.shape[-1]
|
audio_length = waveform.shape[-1]
|
||||||
@ -399,23 +362,30 @@ class TrimAudioDuration:
|
|||||||
if start_frame >= end_frame:
|
if start_frame >= end_frame:
|
||||||
raise ValueError("AudioTrim: Start time must be less than end time and be within the audio length.")
|
raise ValueError("AudioTrim: Start time must be less than end time and be within the audio length.")
|
||||||
|
|
||||||
return ({"waveform": waveform[..., start_frame:end_frame], "sample_rate": sample_rate},)
|
return IO.NodeOutput({"waveform": waveform[..., start_frame:end_frame], "sample_rate": sample_rate})
|
||||||
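Per the `start_index` tooltip, a negative start counts back from the end of the clip. An illustrative call on the migrated node (values are examples):

    # Keep only the final 5 seconds of the clip.
    out = TrimAudioDuration.execute(audio, start_index=-5.0, duration=5.0)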
|
|
||||||
|
trim = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
class SplitAudioChannels:
|
class SplitAudioChannels(IO.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": {
|
return IO.Schema(
|
||||||
"audio": ("AUDIO",),
|
node_id="SplitAudioChannels",
|
||||||
}}
|
display_name="Split Audio Channels",
|
||||||
|
description="Separates the audio into left and right channels.",
|
||||||
|
category="audio",
|
||||||
|
inputs=[
|
||||||
|
IO.Audio.Input("audio"),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
IO.Audio.Output(display_name="left"),
|
||||||
|
IO.Audio.Output(display_name="right"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
RETURN_TYPES = ("AUDIO", "AUDIO")
|
@classmethod
|
||||||
RETURN_NAMES = ("left", "right")
|
def execute(cls, audio) -> IO.NodeOutput:
|
||||||
FUNCTION = "separate"
|
|
||||||
CATEGORY = "audio"
|
|
||||||
DESCRIPTION = "Separates the audio into left and right channels."
|
|
||||||
|
|
||||||
def separate(self, audio):
|
|
||||||
waveform = audio["waveform"]
|
waveform = audio["waveform"]
|
||||||
sample_rate = audio["sample_rate"]
|
sample_rate = audio["sample_rate"]
|
||||||
|
|
||||||
@ -425,7 +395,9 @@ class SplitAudioChannels:
|
|||||||
left_channel = waveform[..., 0:1, :]
|
left_channel = waveform[..., 0:1, :]
|
||||||
right_channel = waveform[..., 1:2, :]
|
right_channel = waveform[..., 1:2, :]
|
||||||
|
|
||||||
return ({"waveform": left_channel, "sample_rate": sample_rate}, {"waveform": right_channel, "sample_rate": sample_rate})
|
return IO.NodeOutput({"waveform": left_channel, "sample_rate": sample_rate}, {"waveform": right_channel, "sample_rate": sample_rate})
|
||||||
|
|
||||||
|
separate = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
def match_audio_sample_rates(waveform_1, sample_rate_1, waveform_2, sample_rate_2):
|
def match_audio_sample_rates(waveform_1, sample_rate_1, waveform_2, sample_rate_2):
|
||||||
@ -443,21 +415,29 @@ def match_audio_sample_rates(waveform_1, sample_rate_1, waveform_2, sample_rate_
|
|||||||
return waveform_1, waveform_2, output_sample_rate
|
return waveform_1, waveform_2, output_sample_rate
|
||||||
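Only the signature and return of `match_audio_sample_rates` fall inside this hunk. A plausible sketch of what such a helper does — resampling the lower-rate track up to the higher rate — offered as an assumption rather than the shown implementation:

    import torchaudio

    def match_audio_sample_rates_sketch(waveform_1, sample_rate_1, waveform_2, sample_rate_2):
        target = max(sample_rate_1, sample_rate_2)  # assumed policy: upsample to the higher rate
        if sample_rate_1 != target:
            waveform_1 = torchaudio.functional.resample(waveform_1, sample_rate_1, target)
        if sample_rate_2 != target:
            waveform_2 = torchaudio.functional.resample(waveform_2, sample_rate_2, target)
        return waveform_1, waveform_2, target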
|
|
||||||
|
|
||||||
class AudioConcat:
|
class AudioConcat(IO.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": {
|
return IO.Schema(
|
||||||
"audio1": ("AUDIO",),
|
node_id="AudioConcat",
|
||||||
"audio2": ("AUDIO",),
|
display_name="Audio Concat",
|
||||||
"direction": (['after', 'before'], {"default": 'after', "tooltip": "Whether to append audio2 after or before audio1."}),
|
description="Concatenates the audio1 to audio2 in the specified direction.",
|
||||||
}}
|
category="audio",
|
||||||
|
inputs=[
|
||||||
|
IO.Audio.Input("audio1"),
|
||||||
|
IO.Audio.Input("audio2"),
|
||||||
|
IO.Combo.Input(
|
||||||
|
"direction",
|
||||||
|
options=['after', 'before'],
|
||||||
|
default="after",
|
||||||
|
tooltip="Whether to append audio2 after or before audio1.",
|
||||||
|
)
|
||||||
|
],
|
||||||
|
outputs=[IO.Audio.Output()],
|
||||||
|
)
|
||||||
|
|
||||||
RETURN_TYPES = ("AUDIO",)
|
@classmethod
|
||||||
FUNCTION = "concat"
|
def execute(cls, audio1, audio2, direction) -> IO.NodeOutput:
|
||||||
CATEGORY = "audio"
|
|
||||||
DESCRIPTION = "Concatenates the audio1 to audio2 in the specified direction."
|
|
||||||
|
|
||||||
def concat(self, audio1, audio2, direction):
|
|
||||||
waveform_1 = audio1["waveform"]
|
waveform_1 = audio1["waveform"]
|
||||||
waveform_2 = audio2["waveform"]
|
waveform_2 = audio2["waveform"]
|
||||||
sample_rate_1 = audio1["sample_rate"]
|
sample_rate_1 = audio1["sample_rate"]
|
||||||
@ -477,26 +457,33 @@ class AudioConcat:
|
|||||||
elif direction == 'before':
|
elif direction == 'before':
|
||||||
concatenated_audio = torch.cat((waveform_2, waveform_1), dim=2)
|
concatenated_audio = torch.cat((waveform_2, waveform_1), dim=2)
|
||||||
|
|
||||||
return ({"waveform": concatenated_audio, "sample_rate": output_sample_rate},)
|
return IO.NodeOutput({"waveform": concatenated_audio, "sample_rate": output_sample_rate})
|
||||||
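Both branches concatenate along `dim=2`, the sample axis of the `[batch, channels, samples]` layout, so the two tracks must already agree on batch and channel counts. For example:

    import torch

    a = torch.zeros(1, 2, 44100)  # 1.0 s of stereo at 44.1 kHz
    b = torch.zeros(1, 2, 22050)  # 0.5 s of stereo
    print(torch.cat((a, b), dim=2).shape)  # torch.Size([1, 2, 66150])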
|
|
||||||
|
concat = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
class AudioMerge:
|
class AudioMerge(IO.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(cls):
|
def define_schema(cls):
|
||||||
return {
|
return IO.Schema(
|
||||||
"required": {
|
node_id="AudioMerge",
|
||||||
"audio1": ("AUDIO",),
|
display_name="Audio Merge",
|
||||||
"audio2": ("AUDIO",),
|
description="Combine two audio tracks by overlaying their waveforms.",
|
||||||
"merge_method": (["add", "mean", "subtract", "multiply"], {"tooltip": "The method used to combine the audio waveforms."}),
|
category="audio",
|
||||||
},
|
inputs=[
|
||||||
}
|
IO.Audio.Input("audio1"),
|
||||||
|
IO.Audio.Input("audio2"),
|
||||||
|
IO.Combo.Input(
|
||||||
|
"merge_method",
|
||||||
|
options=["add", "mean", "subtract", "multiply"],
|
||||||
|
tooltip="The method used to combine the audio waveforms.",
|
||||||
|
)
|
||||||
|
],
|
||||||
|
outputs=[IO.Audio.Output()],
|
||||||
|
)
|
||||||
|
|
||||||
FUNCTION = "merge"
|
@classmethod
|
||||||
RETURN_TYPES = ("AUDIO",)
|
def execute(cls, audio1, audio2, merge_method) -> IO.NodeOutput:
|
||||||
CATEGORY = "audio"
|
|
||||||
DESCRIPTION = "Combine two audio tracks by overlaying their waveforms."
|
|
||||||
|
|
||||||
def merge(self, audio1, audio2, merge_method):
|
|
||||||
waveform_1 = audio1["waveform"]
|
waveform_1 = audio1["waveform"]
|
||||||
waveform_2 = audio2["waveform"]
|
waveform_2 = audio2["waveform"]
|
||||||
sample_rate_1 = audio1["sample_rate"]
|
sample_rate_1 = audio1["sample_rate"]
|
||||||
@ -530,85 +517,108 @@ class AudioMerge:
|
|||||||
if max_val > 1.0:
|
if max_val > 1.0:
|
||||||
waveform = waveform / max_val
|
waveform = waveform / max_val
|
||||||
|
|
||||||
return ({"waveform": waveform, "sample_rate": output_sample_rate},)
|
return IO.NodeOutput({"waveform": waveform, "sample_rate": output_sample_rate})
|
||||||
|
|
||||||
|
merge = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
class AudioAdjustVolume:
|
class AudioAdjustVolume(IO.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": {
|
return IO.Schema(
|
||||||
"audio": ("AUDIO",),
|
node_id="AudioAdjustVolume",
|
||||||
"volume": ("INT", {"default": 1.0, "min": -100, "max": 100, "tooltip": "Volume adjustment in decibels (dB). 0 = no change, +6 = double, -6 = half, etc"}),
|
display_name="Audio Adjust Volume",
|
||||||
}}
|
category="audio",
|
||||||
|
inputs=[
|
||||||
|
IO.Audio.Input("audio"),
|
||||||
|
IO.Int.Input(
|
||||||
|
"volume",
|
||||||
|
default=1,
|
||||||
|
min=-100,
|
||||||
|
max=100,
|
||||||
|
tooltip="Volume adjustment in decibels (dB). 0 = no change, +6 = double, -6 = half, etc",
|
||||||
|
)
|
||||||
|
],
|
||||||
|
outputs=[IO.Audio.Output()],
|
||||||
|
)
|
||||||
|
|
||||||
RETURN_TYPES = ("AUDIO",)
|
@classmethod
|
||||||
FUNCTION = "adjust_volume"
|
def execute(cls, audio, volume) -> IO.NodeOutput:
|
||||||
CATEGORY = "audio"
|
|
||||||
|
|
||||||
def adjust_volume(self, audio, volume):
|
|
||||||
if volume == 0:
|
if volume == 0:
|
||||||
return (audio,)
|
return IO.NodeOutput(audio)
|
||||||
waveform = audio["waveform"]
|
waveform = audio["waveform"]
|
||||||
sample_rate = audio["sample_rate"]
|
sample_rate = audio["sample_rate"]
|
||||||
|
|
||||||
gain = 10 ** (volume / 20)
|
gain = 10 ** (volume / 20)
|
||||||
waveform = waveform * gain
|
waveform = waveform * gain
|
||||||
|
|
||||||
return ({"waveform": waveform, "sample_rate": sample_rate},)
|
return IO.NodeOutput({"waveform": waveform, "sample_rate": sample_rate})
|
||||||
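`gain = 10 ** (volume / 20)` is the standard decibel-to-amplitude conversion, which is why the tooltip says +6 dB roughly doubles the signal and -6 dB roughly halves it:

    for db in (-6, 0, 6):
        print(db, 10 ** (db / 20))  # -6 -> 0.501..., 0 -> 1.0, 6 -> 1.995...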
|
|
||||||
|
adjust_volume = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
class EmptyAudio:
|
class EmptyAudio(IO.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": {
|
return IO.Schema(
|
||||||
"duration": ("FLOAT", {"default": 60.0, "min": 0.0, "max": 0xffffffffffffffff, "step": 0.01, "tooltip": "Duration of the empty audio clip in seconds"}),
|
node_id="EmptyAudio",
|
||||||
"sample_rate": ("INT", {"default": 44100, "tooltip": "Sample rate of the empty audio clip."}),
|
display_name="Empty Audio",
|
||||||
"channels": ("INT", {"default": 2, "min": 1, "max": 2, "tooltip": "Number of audio channels (1 for mono, 2 for stereo)."}),
|
category="audio",
|
||||||
}}
|
inputs=[
|
||||||
|
IO.Float.Input(
|
||||||
|
"duration",
|
||||||
|
default=60.0,
|
||||||
|
min=0.0,
|
||||||
|
max=0xffffffffffffffff,
|
||||||
|
step=0.01,
|
||||||
|
tooltip="Duration of the empty audio clip in seconds",
|
||||||
|
),
|
||||||
|
IO.Int.Input(
|
||||||
|
"sample_rate",
|
||||||
|
default=44100,
|
||||||
|
tooltip="Sample rate of the empty audio clip.",
|
||||||
|
),
|
||||||
|
IO.Int.Input(
|
||||||
|
"channels",
|
||||||
|
default=2,
|
||||||
|
min=1,
|
||||||
|
max=2,
|
||||||
|
tooltip="Number of audio channels (1 for mono, 2 for stereo).",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
outputs=[IO.Audio.Output()],
|
||||||
|
)
|
||||||
|
|
||||||
RETURN_TYPES = ("AUDIO",)
|
@classmethod
|
||||||
FUNCTION = "create_empty_audio"
|
def execute(cls, duration, sample_rate, channels) -> IO.NodeOutput:
|
||||||
CATEGORY = "audio"
|
|
||||||
|
|
||||||
def create_empty_audio(self, duration, sample_rate, channels):
|
|
||||||
num_samples = int(round(duration * sample_rate))
|
num_samples = int(round(duration * sample_rate))
|
||||||
waveform = torch.zeros((1, channels, num_samples), dtype=torch.float32)
|
waveform = torch.zeros((1, channels, num_samples), dtype=torch.float32)
|
||||||
return ({"waveform": waveform, "sample_rate": sample_rate},)
|
return IO.NodeOutput({"waveform": waveform, "sample_rate": sample_rate})
|
||||||
|
|
||||||
|
create_empty_audio = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
class AudioExtension(ComfyExtension):
|
||||||
"EmptyLatentAudio": EmptyLatentAudio,
|
@override
|
||||||
"VAEEncodeAudio": VAEEncodeAudio,
|
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||||
"VAEDecodeAudio": VAEDecodeAudio,
|
return [
|
||||||
"SaveAudio": SaveAudio,
|
EmptyLatentAudio,
|
||||||
"SaveAudioMP3": SaveAudioMP3,
|
VAEEncodeAudio,
|
||||||
"SaveAudioOpus": SaveAudioOpus,
|
VAEDecodeAudio,
|
||||||
"LoadAudio": LoadAudio,
|
SaveAudio,
|
||||||
"PreviewAudio": PreviewAudio,
|
SaveAudioMP3,
|
||||||
"ConditioningStableAudio": ConditioningStableAudio,
|
SaveAudioOpus,
|
||||||
"RecordAudio": RecordAudio,
|
LoadAudio,
|
||||||
"TrimAudioDuration": TrimAudioDuration,
|
PreviewAudio,
|
||||||
"SplitAudioChannels": SplitAudioChannels,
|
ConditioningStableAudio,
|
||||||
"AudioConcat": AudioConcat,
|
RecordAudio,
|
||||||
"AudioMerge": AudioMerge,
|
TrimAudioDuration,
|
||||||
"AudioAdjustVolume": AudioAdjustVolume,
|
SplitAudioChannels,
|
||||||
"EmptyAudio": EmptyAudio,
|
AudioConcat,
|
||||||
}
|
AudioMerge,
|
||||||
|
AudioAdjustVolume,
|
||||||
|
EmptyAudio,
|
||||||
|
]
|
||||||
|
|
||||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
async def comfy_entrypoint() -> AudioExtension:
|
||||||
"EmptyLatentAudio": "Empty Latent Audio",
|
return AudioExtension()
|
||||||
"VAEEncodeAudio": "VAE Encode Audio",
|
|
||||||
"VAEDecodeAudio": "VAE Decode Audio",
|
|
||||||
"PreviewAudio": "Preview Audio",
|
|
||||||
"LoadAudio": "Load Audio",
|
|
||||||
"SaveAudio": "Save Audio (FLAC)",
|
|
||||||
"SaveAudioMP3": "Save Audio (MP3)",
|
|
||||||
"SaveAudioOpus": "Save Audio (Opus)",
|
|
||||||
"RecordAudio": "Record Audio",
|
|
||||||
"TrimAudioDuration": "Trim Audio Duration",
|
|
||||||
"SplitAudioChannels": "Split Audio Channels",
|
|
||||||
"AudioConcat": "Audio Concat",
|
|
||||||
"AudioMerge": "Audio Merge",
|
|
||||||
"AudioAdjustVolume": "Audio Adjust Volume",
|
|
||||||
"EmptyAudio": "Empty Audio",
|
|
||||||
}
|
|
||||||
|
@@ -2,22 +2,18 @@ import nodes
 import folder_paths
 import os
 
-from comfy.comfy_types import IO
-from comfy_api.input_impl import VideoFromFile
+from typing_extensions import override
+from comfy_api.latest import IO, ComfyExtension, InputImpl, UI
 
 from pathlib import Path
-
-from PIL import Image
-import numpy as np
-import uuid
 
 def normalize_path(path):
     return path.replace('\\', '/')
 
-class Load3D():
+class Load3D(IO.ComfyNode):
     @classmethod
-    def INPUT_TYPES(s):
+    def define_schema(cls):
         input_dir = os.path.join(folder_paths.get_input_directory(), "3d")
 
         os.makedirs(input_dir, exist_ok=True)
@@ -30,23 +26,29 @@ class Load3D():
             for file_path in input_path.rglob("*")
             if file_path.suffix.lower() in {'.gltf', '.glb', '.obj', '.fbx', '.stl'}
         ]
+        return IO.Schema(
+            node_id="Load3D",
+            display_name="Load 3D & Animation",
+            category="3d",
+            is_experimental=True,
+            inputs=[
+                IO.Combo.Input("model_file", options=sorted(files), upload=IO.UploadType.model),
+                IO.Load3D.Input("image"),
+                IO.Int.Input("width", default=1024, min=1, max=4096, step=1),
+                IO.Int.Input("height", default=1024, min=1, max=4096, step=1),
+            ],
+            outputs=[
+                IO.Image.Output(display_name="image"),
+                IO.Mask.Output(display_name="mask"),
+                IO.String.Output(display_name="mesh_path"),
+                IO.Image.Output(display_name="normal"),
+                IO.Load3DCamera.Output(display_name="camera_info"),
+                IO.Video.Output(display_name="recording_video"),
+            ],
+        )
 
-        return {"required": {
-            "model_file": (sorted(files), {"file_upload": True}),
-            "image": ("LOAD_3D", {}),
-            "width": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
-            "height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
-        }}
-
-    RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "LOAD3D_CAMERA", IO.VIDEO)
-    RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "camera_info", "recording_video")
-
-    FUNCTION = "process"
-    EXPERIMENTAL = True
-
-    CATEGORY = "3d"
-
-    def process(self, model_file, image, **kwargs):
+    @classmethod
+    def execute(cls, model_file, image, **kwargs) -> IO.NodeOutput:
         image_path = folder_paths.get_annotated_filepath(image['image'])
         mask_path = folder_paths.get_annotated_filepath(image['mask'])
         normal_path = folder_paths.get_annotated_filepath(image['normal'])
@@ -61,58 +63,47 @@ class Load3D():
         if image['recording'] != "":
             recording_video_path = folder_paths.get_annotated_filepath(image['recording'])
 
-            video = VideoFromFile(recording_video_path)
+            video = InputImpl.VideoFromFile(recording_video_path)
 
-        return output_image, output_mask, model_file, normal_image, image['camera_info'], video
+        return IO.NodeOutput(output_image, output_mask, model_file, normal_image, image['camera_info'], video)
 
-class Preview3D():
+    process = execute  # TODO: remove
+
+
+class Preview3D(IO.ComfyNode):
     @classmethod
-    def INPUT_TYPES(s):
-        return {"required": {
-            "model_file": ("STRING", {"default": "", "multiline": False}),
-        },
-        "optional": {
-            "camera_info": ("LOAD3D_CAMERA", {}),
-            "bg_image": ("IMAGE", {})
-        }}
-
-    OUTPUT_NODE = True
-    RETURN_TYPES = ()
-
-    CATEGORY = "3d"
-
-    FUNCTION = "process"
-    EXPERIMENTAL = True
-
-    def process(self, model_file, **kwargs):
+    def define_schema(cls):
+        return IO.Schema(
+            node_id="Preview3D",
+            display_name="Preview 3D & Animation",
+            category="3d",
+            is_experimental=True,
+            is_output_node=True,
+            inputs=[
+                IO.String.Input("model_file", default="", multiline=False),
+                IO.Load3DCamera.Input("camera_info", optional=True),
+                IO.Image.Input("bg_image", optional=True),
+            ],
+            outputs=[],
+        )
+
+    @classmethod
+    def execute(cls, model_file, **kwargs) -> IO.NodeOutput:
         camera_info = kwargs.get("camera_info", None)
         bg_image = kwargs.get("bg_image", None)
+        return IO.NodeOutput(ui=UI.PreviewUI3D(model_file, camera_info, bg_image=bg_image))
 
-        bg_image_path = None
-        if bg_image is not None:
-            img_array = (bg_image[0].cpu().numpy() * 255).astype(np.uint8)
-            img = Image.fromarray(img_array)
-
-            temp_dir = folder_paths.get_temp_directory()
-            filename = f"bg_{uuid.uuid4().hex}.png"
-            bg_image_path = os.path.join(temp_dir, filename)
-            img.save(bg_image_path, compress_level=1)
-
-            bg_image_path = f"temp/{filename}"
-
-        return {
-            "ui": {
-                "result": [model_file, camera_info, bg_image_path]
-            }
-        }
-
-NODE_CLASS_MAPPINGS = {
-    "Load3D": Load3D,
-    "Preview3D": Preview3D,
-}
-
-NODE_DISPLAY_NAME_MAPPINGS = {
-    "Load3D": "Load 3D & Animation",
-    "Preview3D": "Preview 3D & Animation",
-}
+    process = execute  # TODO: remove
+
+
+class Load3DExtension(ComfyExtension):
+    @override
+    async def get_node_list(self) -> list[type[IO.ComfyNode]]:
+        return [
+            Load3D,
+            Preview3D,
+        ]
+
+
+async def comfy_entrypoint() -> Load3DExtension:
+    return Load3DExtension()
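
Note: Preview3D no longer composites the background image server-side. The old path wrote a bg_<uuid>.png into the temp directory via PIL/numpy and returned its relative path in a raw "ui" dict; the new path hands the tensor to the typed UI.PreviewUI3D payload, which is why the PIL, numpy, and uuid imports are dropped at the top of the file. The TODO-marked aliases keep the old V1 entry points resolving while callers migrate; since execute() is a classmethod, comparing the aliases has to go through __func__ (each attribute access builds a fresh bound method):

    # Hedged sanity check of the compatibility alias:
    assert Load3D.process.__func__ is Load3D.execute.__func__
    assert Preview3D.process.__func__ is Preview3D.execute.__func__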
comfy_extras/nodes_logic.py (new file, 155 lines)
@@ -0,0 +1,155 @@
+from typing import TypedDict
+from typing_extensions import override
+from comfy_api.latest import ComfyExtension, io
+from comfy_api.latest import _io
+
+
+class SwitchNode(io.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        template = io.MatchType.Template("switch")
+        return io.Schema(
+            node_id="ComfySwitchNode",
+            display_name="Switch",
+            category="logic",
+            is_experimental=True,
+            inputs=[
+                io.Boolean.Input("switch"),
+                io.MatchType.Input("on_false", template=template, lazy=True, optional=True),
+                io.MatchType.Input("on_true", template=template, lazy=True, optional=True),
+            ],
+            outputs=[
+                io.MatchType.Output(template=template, display_name="output"),
+            ],
+        )
+
+    @classmethod
+    def check_lazy_status(cls, switch, on_false=..., on_true=...):
+        # We use ... instead of None, as None is passed for connected-but-unevaluated inputs.
+        # This trick allows us to ignore the value of the switch and still be able to run execute().
+
+        # One of the inputs may be missing, in which case we need to evaluate the other input
+        if on_false is ...:
+            return ["on_true"]
+        if on_true is ...:
+            return ["on_false"]
+        # Normal lazy switch operation
+        if switch and on_true is None:
+            return ["on_true"]
+        if not switch and on_false is None:
+            return ["on_false"]
+
+    @classmethod
+    def validate_inputs(cls, switch, on_false=..., on_true=...):
+        # This check happens before check_lazy_status(), so we can eliminate the case where
+        # both inputs are missing.
+        if on_false is ... and on_true is ...:
+            return "At least one of on_false or on_true must be connected to Switch node"
+        return True
+
+    @classmethod
+    def execute(cls, switch, on_true=..., on_false=...) -> io.NodeOutput:
+        if on_true is ...:
+            return io.NodeOutput(on_false)
+        if on_false is ...:
+            return io.NodeOutput(on_true)
+        return io.NodeOutput(on_true if switch else on_false)
+
+
+class DCTestNode(io.ComfyNode):
+    class DCValues(TypedDict):
+        combo: str
+        string: str
+        integer: int
+        image: io.Image.Type
+        subcombo: dict[str]
+
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="DCTestNode",
+            display_name="DCTest",
+            category="logic",
+            is_output_node=True,
+            inputs=[_io.DynamicCombo.Input("combo", options=[
+                _io.DynamicCombo.Option("option1", [io.String.Input("string")]),
+                _io.DynamicCombo.Option("option2", [io.Int.Input("integer")]),
+                _io.DynamicCombo.Option("option3", [io.Image.Input("image")]),
+                _io.DynamicCombo.Option("option4", [
+                    _io.DynamicCombo.Input("subcombo", options=[
+                        _io.DynamicCombo.Option("opt1", [io.Float.Input("float_x"), io.Float.Input("float_y")]),
+                        _io.DynamicCombo.Option("opt2", [io.Mask.Input("mask1", optional=True)]),
+                    ])
+                ])]
+            )],
+            outputs=[io.AnyType.Output()],
+        )
+
+    @classmethod
+    def execute(cls, combo: DCValues) -> io.NodeOutput:
+        combo_val = combo["combo"]
+        if combo_val == "option1":
+            return io.NodeOutput(combo["string"])
+        elif combo_val == "option2":
+            return io.NodeOutput(combo["integer"])
+        elif combo_val == "option3":
+            return io.NodeOutput(combo["image"])
+        elif combo_val == "option4":
+            return io.NodeOutput(f"{combo['subcombo']}")
+        else:
+            raise ValueError(f"Invalid combo: {combo_val}")
+
+
+class AutogrowNamesTestNode(io.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        template = _io.Autogrow.TemplateNames(input=io.Float.Input("float"), names=["a", "b", "c"])
+        return io.Schema(
+            node_id="AutogrowNamesTestNode",
+            display_name="AutogrowNamesTest",
+            category="logic",
+            inputs=[
+                _io.Autogrow.Input("autogrow", template=template)
+            ],
+            outputs=[io.String.Output()],
+        )
+
+    @classmethod
+    def execute(cls, autogrow: _io.Autogrow.Type) -> io.NodeOutput:
+        vals = list(autogrow.values())
+        combined = ",".join([str(x) for x in vals])
+        return io.NodeOutput(combined)
+
+
+class AutogrowPrefixTestNode(io.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        template = _io.Autogrow.TemplatePrefix(input=io.Float.Input("float"), prefix="float", min=1, max=10)
+        return io.Schema(
+            node_id="AutogrowPrefixTestNode",
+            display_name="AutogrowPrefixTest",
+            category="logic",
+            inputs=[
+                _io.Autogrow.Input("autogrow", template=template)
+            ],
+            outputs=[io.String.Output()],
+        )
+
+    @classmethod
+    def execute(cls, autogrow: _io.Autogrow.Type) -> io.NodeOutput:
+        vals = list(autogrow.values())
+        combined = ",".join([str(x) for x in vals])
+        return io.NodeOutput(combined)
+
+
+class LogicExtension(ComfyExtension):
+    @override
+    async def get_node_list(self) -> list[type[io.ComfyNode]]:
+        return [
+            # SwitchNode,
+            # DCTestNode,
+            # AutogrowNamesTestNode,
+            # AutogrowPrefixTestNode,
+        ]
+
+
+async def comfy_entrypoint() -> LogicExtension:
+    return LogicExtension()
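
Note: this file ships disabled for now (get_node_list() returns only commented-out entries, so nothing registers yet), but SwitchNode's `...` sentinel is worth a closer look: a disconnected optional input arrives as the Ellipsis default, while a connected-but-unevaluated input arrives as None, which lets check_lazy_status() request evaluation of exactly the branch that is needed. A few hedged sanity checks of that contract, assuming the class as written above:

    # Disconnected on_false (still ...): the only usable branch must be evaluated,
    # regardless of the switch value.
    assert SwitchNode.check_lazy_status(switch=False, on_false=...) == ["on_true"]
    # Both connected; switch selects on_true, which is not evaluated yet (None):
    assert SwitchNode.check_lazy_status(switch=True, on_false=1.0, on_true=None) == ["on_true"]
    # Selected branch already evaluated: nothing more is requested.
    assert SwitchNode.check_lazy_status(switch=True, on_false=None, on_true=2.0) is None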
@@ -6,6 +6,7 @@ import comfy.ops
 import comfy.model_management
 import comfy.ldm.common_dit
 import comfy.latent_formats
+import comfy.ldm.lumina.controlnet
 
 
 class BlockWiseControlBlock(torch.nn.Module):
@@ -189,6 +190,35 @@ class SigLIPMultiFeatProjModel(torch.nn.Module):
 
         return embedding
 
+
+def z_image_convert(sd):
+    replace_keys = {".attention.to_out.0.bias": ".attention.out.bias",
+                    ".attention.norm_k.weight": ".attention.k_norm.weight",
+                    ".attention.norm_q.weight": ".attention.q_norm.weight",
+                    ".attention.to_out.0.weight": ".attention.out.weight"
+                    }
+
+    out_sd = {}
+    for k in sorted(sd.keys()):
+        w = sd[k]
+
+        k_out = k
+        if k_out.endswith(".attention.to_k.weight"):
+            cc = [w]
+            continue
+        if k_out.endswith(".attention.to_q.weight"):
+            cc = [w] + cc
+            continue
+        if k_out.endswith(".attention.to_v.weight"):
+            cc = cc + [w]
+            w = torch.cat(cc, dim=0)
+            k_out = k_out.replace(".attention.to_v.weight", ".attention.qkv.weight")
+
+        for r, rr in replace_keys.items():
+            k_out = k_out.replace(r, rr)
+        out_sd[k_out] = w
+
+    return out_sd
+
+
 class ModelPatchLoader:
     @classmethod
     def INPUT_TYPES(s):
@@ -211,6 +241,9 @@ class ModelPatchLoader:
         elif 'feature_embedder.mid_layer_norm.bias' in sd:
             sd = comfy.utils.state_dict_prefix_replace(sd, {"feature_embedder.": ""}, filter_keys=True)
             model = SigLIPMultiFeatProjModel(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast)
+        elif 'control_all_x_embedder.2-1.weight' in sd:  # alipai z image fun controlnet
+            sd = z_image_convert(sd)
+            model = comfy.ldm.lumina.controlnet.ZImage_Control(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast)
 
         model.load_state_dict(sd)
         model = comfy.model_patcher.ModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
@@ -263,6 +296,69 @@ class DiffSynthCnetPatch:
     def models(self):
         return [self.model_patch]
 
+
+class ZImageControlPatch:
+    def __init__(self, model_patch, vae, image, strength):
+        self.model_patch = model_patch
+        self.vae = vae
+        self.image = image
+        self.strength = strength
+        self.encoded_image = self.encode_latent_cond(image)
+        self.encoded_image_size = (image.shape[1], image.shape[2])
+        self.temp_data = None
+
+    def encode_latent_cond(self, image):
+        latent_image = comfy.latent_formats.Flux().process_in(self.vae.encode(image))
+        return latent_image
+
+    def __call__(self, kwargs):
+        x = kwargs.get("x")
+        img = kwargs.get("img")
+        txt = kwargs.get("txt")
+        pe = kwargs.get("pe")
+        vec = kwargs.get("vec")
+        block_index = kwargs.get("block_index")
+        spacial_compression = self.vae.spacial_compression_encode()
+        if self.encoded_image is None or self.encoded_image_size != (x.shape[-2] * spacial_compression, x.shape[-1] * spacial_compression):
+            image_scaled = comfy.utils.common_upscale(self.image.movedim(-1, 1), x.shape[-1] * spacial_compression, x.shape[-2] * spacial_compression, "area", "center")
+            loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
+            self.encoded_image = self.encode_latent_cond(image_scaled.movedim(1, -1))
+            self.encoded_image_size = (image_scaled.shape[-2], image_scaled.shape[-1])
+            comfy.model_management.load_models_gpu(loaded_models)
+
+        cnet_index = (block_index // 5)
+        cnet_index_float = (block_index / 5)
+
+        kwargs.pop("img")  # we do ops in place
+        kwargs.pop("txt")
+
+        cnet_blocks = self.model_patch.model.n_control_layers
+        if cnet_index_float > (cnet_blocks - 1):
+            self.temp_data = None
+            return kwargs
+
+        if self.temp_data is None or self.temp_data[0] > cnet_index:
+            self.temp_data = (-1, (None, self.model_patch.model(txt, self.encoded_image.to(img.dtype), pe, vec)))
+
+        while self.temp_data[0] < cnet_index and (self.temp_data[0] + 1) < cnet_blocks:
+            next_layer = self.temp_data[0] + 1
+            self.temp_data = (next_layer, self.model_patch.model.forward_control_block(next_layer, self.temp_data[1][1], img[:, :self.temp_data[1][1].shape[1]], None, pe, vec))
+
+        if cnet_index_float == self.temp_data[0]:
+            img[:, :self.temp_data[1][0].shape[1]] += (self.temp_data[1][0] * self.strength)
+            if cnet_blocks == self.temp_data[0] + 1:
+                self.temp_data = None
+
+        return kwargs
+
+    def to(self, device_or_dtype):
+        if isinstance(device_or_dtype, torch.device):
+            self.encoded_image = self.encoded_image.to(device_or_dtype)
+            self.temp_data = None
+        return self
+
+    def models(self):
+        return [self.model_patch]
+
+
 class QwenImageDiffsynthControlnet:
     @classmethod
     def INPUT_TYPES(s):
@@ -289,6 +385,9 @@ class QwenImageDiffsynthControlnet:
             mask = mask.unsqueeze(2)
             mask = 1.0 - mask
 
+        if isinstance(model_patch.model, comfy.ldm.lumina.controlnet.ZImage_Control):
+            model_patched.set_model_double_block_patch(ZImageControlPatch(model_patch, vae, image, strength))
+        else:
             model_patched.set_model_double_block_patch(DiffSynthCnetPatch(model_patch, vae, image, strength, mask))
         return (model_patched,)
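
Note: z_image_convert() leans on sorted key order to fuse the separate attention projections: within a layer, '.attention.to_k.weight' sorts before '.attention.to_q.weight', which sorts before '.attention.to_v.weight', so cc is always [q, k, v] by the time the to_v key is reached and the concatenated weight lands under '.attention.qkv.weight'. A small illustration of the fusion itself (the shapes are assumed for the example):

    import torch

    q = torch.randn(64, 48)   # per-projection weights, illustrative shapes
    k = torch.randn(64, 48)
    v = torch.randn(64, 48)
    qkv = torch.cat([q, k, v], dim=0)   # fused along the output dimension
    assert qkv.shape == (192, 48)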
@@ -623,7 +623,7 @@ class TrainLoraNode(io.ComfyNode):
             noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(seed)
             if multi_res:
                 # use first latent as dummy latent if multi_res
-                latents = latents[0].repeat(num_images, 1, 1, 1)
+                latents = latents[0].repeat((num_images,) + ((1,) * (latents[0].ndim - 1)))
             guider.sample(
                 noise.generate_noise({"samples": latents}),
                 latents,
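
Note: the one-line fix above generalizes the dummy-latent batching from 4D image latents to any rank: building the repeat counts from ndim replicates only the leading (batch) dimension, so 5D video latents work too, where the hard-coded four-argument form would raise. For example:

    import torch

    lat = torch.zeros(1, 16, 4, 32, 32)   # e.g. a 5D video latent
    num_images = 3
    out = lat.repeat((num_images,) + ((1,) * (lat.ndim - 1)))
    assert out.shape == (3, 16, 4, 32, 32)
    # lat.repeat(3, 1, 1, 1) would raise: repeat dims can't be fewer than tensor dims.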
@@ -88,7 +88,7 @@ class SaveVideo(io.ComfyNode):
         )
 
     @classmethod
-    def execute(cls, video: VideoInput, filename_prefix, format, codec) -> io.NodeOutput:
+    def execute(cls, video: VideoInput, filename_prefix, format: str, codec) -> io.NodeOutput:
         width, height = video.get_dimensions()
         full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(
             filename_prefix,
@@ -108,7 +108,7 @@ class SaveVideo(io.ComfyNode):
         file = f"{filename}_{counter:05}_.{VideoContainer.get_extension(format)}"
         video.save_to(
             os.path.join(full_output_folder, file),
-            format=format,
+            format=VideoContainer(format),
             codec=codec,
             metadata=saved_metadata
         )
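
Note: the SaveVideo change tightens types rather than behavior. The combo widget delivers format as a plain str, and save_to() expects a VideoContainer, so the value is coerced through the enum constructor. Assuming VideoContainer is a value-backed Enum (as the by-value lookup implies), the mechanics are:

    from enum import Enum

    class Container(Enum):   # stand-in for comfy_api's VideoContainer, values assumed
        AUTO = "auto"
        MP4 = "mp4"

    assert Container("mp4") is Container.MP4   # by-value lookup coerces the str
    # Container("mkv") would raise ValueError for an unsupported container string.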
@@ -1,3 +1,3 @@
 # This file is automatically generated by the build process when version is
 # updated in pyproject.toml.
-__version__ = "0.3.75"
+__version__ = "0.3.76"
execution.py (40 lines changed)
@@ -34,7 +34,7 @@ from comfy_execution.validation import validate_node_input
 from comfy_execution.progress import get_progress_state, reset_progress_state, add_progress_handler, WebUIProgressHandler
 from comfy_execution.utils import CurrentNodeContext
 from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_real_override, is_class, make_locked_method_func
-from comfy_api.latest import io
+from comfy_api.latest import io, _io
 
 
 class ExecutionResult(Enum):
@@ -76,7 +76,7 @@ class IsChangedCache:
             return self.is_changed[node_id]
 
         # Intentionally do not use cached outputs here. We only want constants in IS_CHANGED
-        input_data_all, _, hidden_inputs = get_input_data(node["inputs"], class_def, node_id, None)
+        input_data_all, _, v3_data = get_input_data(node["inputs"], class_def, node_id, None)
         try:
             is_changed = await _async_map_node_over_list(self.prompt_id, node_id, class_def, input_data_all, is_changed_name)
             is_changed = await resolve_map_node_over_list_results(is_changed)
@@ -146,8 +146,9 @@ SENSITIVE_EXTRA_DATA_KEYS = ("auth_token_comfy_org", "api_key_comfy_org")
 
 def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=None, extra_data={}):
     is_v3 = issubclass(class_def, _ComfyNodeInternal)
+    v3_data: io.V3Data = {}
     if is_v3:
-        valid_inputs, schema = class_def.INPUT_TYPES(include_hidden=False, return_schema=True)
+        valid_inputs, schema, v3_data = class_def.INPUT_TYPES(include_hidden=False, return_schema=True, live_inputs=inputs)
     else:
         valid_inputs = class_def.INPUT_TYPES()
     input_data_all = {}
@@ -207,7 +208,8 @@ def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=
             input_data_all[x] = [extra_data.get("auth_token_comfy_org", None)]
         if h[x] == "API_KEY_COMFY_ORG":
             input_data_all[x] = [extra_data.get("api_key_comfy_org", None)]
-    return input_data_all, missing_keys, hidden_inputs_v3
+    v3_data["hidden_inputs"] = hidden_inputs_v3
+    return input_data_all, missing_keys, v3_data
 
 map_node_over_list = None #Don't hook this please
 
@@ -223,7 +225,7 @@ async def resolve_map_node_over_list_results(results):
                 raise exc
     return [x.result() if isinstance(x, asyncio.Task) else x for x in results]
 
-async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None, hidden_inputs=None):
+async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None, v3_data=None):
     # check if node wants the lists
     input_is_list = getattr(obj, "INPUT_IS_LIST", False)
 
@@ -259,13 +261,16 @@ async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, f
             if is_class(obj):
                 type_obj = obj
                 obj.VALIDATE_CLASS()
-                class_clone = obj.PREPARE_CLASS_CLONE(hidden_inputs)
+                class_clone = obj.PREPARE_CLASS_CLONE(v3_data)
             # otherwise, use class instance to populate/reuse some fields
             else:
                 type_obj = type(obj)
                 type_obj.VALIDATE_CLASS()
-                class_clone = type_obj.PREPARE_CLASS_CLONE(hidden_inputs)
+                class_clone = type_obj.PREPARE_CLASS_CLONE(v3_data)
             f = make_locked_method_func(type_obj, func, class_clone)
+            # in case of dynamic inputs, restructure inputs to expected nested dict
+            if v3_data is not None:
+                inputs = _io.build_nested_inputs(inputs, v3_data)
         # V1
         else:
             f = getattr(obj, func)
@@ -320,8 +325,8 @@ def merge_result_data(results, obj):
             output.append([o[i] for o in results])
     return output
 
-async def get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=None, pre_execute_cb=None, hidden_inputs=None):
-    return_values = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, hidden_inputs=hidden_inputs)
+async def get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=None, pre_execute_cb=None, v3_data=None):
+    return_values = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data)
     has_pending_task = any(isinstance(r, asyncio.Task) and not r.done() for r in return_values)
     if has_pending_task:
         return return_values, {}, False, has_pending_task
@@ -460,7 +465,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
             has_subgraph = False
         else:
             get_progress_state().start_progress(unique_id)
-            input_data_all, missing_keys, hidden_inputs = get_input_data(inputs, class_def, unique_id, execution_list, dynprompt, extra_data)
+            input_data_all, missing_keys, v3_data = get_input_data(inputs, class_def, unique_id, execution_list, dynprompt, extra_data)
             if server.client_id is not None:
                 server.last_node_id = display_node_id
                 server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id)
@@ -475,7 +480,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
             else:
                 lazy_status_present = getattr(obj, "check_lazy_status", None) is not None
                 if lazy_status_present:
-                    required_inputs = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, "check_lazy_status", allow_interrupt=True, hidden_inputs=hidden_inputs)
+                    required_inputs = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, "check_lazy_status", allow_interrupt=True, v3_data=v3_data)
                     required_inputs = await resolve_map_node_over_list_results(required_inputs)
                     required_inputs = set(sum([r for r in required_inputs if isinstance(r,list)], []))
                     required_inputs = [x for x in required_inputs if isinstance(x,str) and (
@@ -507,7 +512,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
             def pre_execute_cb(call_index):
                 # TODO - How to handle this with async functions without contextvars (which requires Python 3.12)?
                 GraphBuilder.set_default_prefix(unique_id, call_index, 0)
-            output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, hidden_inputs=hidden_inputs)
+            output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data)
             if has_pending_tasks:
                 pending_async_nodes[unique_id] = output_data
                 unblock = execution_list.add_external_block(unique_id)
@@ -745,18 +750,17 @@ async def validate_inputs(prompt_id, prompt, item, validated):
     class_type = prompt[unique_id]['class_type']
     obj_class = nodes.NODE_CLASS_MAPPINGS[class_type]
 
-    class_inputs = obj_class.INPUT_TYPES()
-    valid_inputs = set(class_inputs.get('required',{})).union(set(class_inputs.get('optional',{})))
-
     errors = []
     valid = True
 
     validate_function_inputs = []
     validate_has_kwargs = False
     if issubclass(obj_class, _ComfyNodeInternal):
+        class_inputs, _, _ = obj_class.INPUT_TYPES(include_hidden=False, return_schema=True, live_inputs=inputs)
         validate_function_name = "validate_inputs"
         validate_function = first_real_override(obj_class, validate_function_name)
     else:
+        class_inputs = obj_class.INPUT_TYPES()
         validate_function_name = "VALIDATE_INPUTS"
         validate_function = getattr(obj_class, validate_function_name, None)
     if validate_function is not None:
@@ -765,6 +769,8 @@ async def validate_inputs(prompt_id, prompt, item, validated):
         validate_has_kwargs = argspec.varkw is not None
     received_types = {}
 
+    valid_inputs = set(class_inputs.get('required',{})).union(set(class_inputs.get('optional',{})))
+
     for x in valid_inputs:
         input_type, input_category, extra_info = get_input_info(obj_class, x, class_inputs)
         assert extra_info is not None
@@ -935,7 +941,7 @@ async def validate_inputs(prompt_id, prompt, item, validated):
             continue
 
     if len(validate_function_inputs) > 0 or validate_has_kwargs:
-        input_data_all, _, hidden_inputs = get_input_data(inputs, obj_class, unique_id)
+        input_data_all, _, v3_data = get_input_data(inputs, obj_class, unique_id)
         input_filtered = {}
         for x in input_data_all:
             if x in validate_function_inputs or validate_has_kwargs:
@@ -943,7 +949,7 @@ async def validate_inputs(prompt_id, prompt, item, validated):
         if 'input_types' in validate_function_inputs:
             input_filtered['input_types'] = [received_types]
 
-        ret = await _async_map_node_over_list(prompt_id, unique_id, obj_class, input_filtered, validate_function_name, hidden_inputs=hidden_inputs)
+        ret = await _async_map_node_over_list(prompt_id, unique_id, obj_class, input_filtered, validate_function_name, v3_data=v3_data)
         ret = await resolve_map_node_over_list_results(ret)
         for x in input_filtered:
             for i, r in enumerate(ret):
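
Note: the execution.py changes thread a v3_data dict (replacing the bare hidden_inputs argument) from get_input_data() down through _async_map_node_over_list(), so schema information about dynamic inputs survives to the point where _io.build_nested_inputs() regroups flat widget values; INPUT_TYPES() on V3 nodes now also takes the live inputs so dynamic schemas can resolve against actual values during validation. Conceptually, and only as a sketch (the real _io.build_nested_inputs signature is internal and not shown here), the regrouping for a DynamicCombo input named "combo" looks like:

    # Flat values as stored in the prompt:
    flat = {"combo": "option1", "string": "hello"}
    # What a V3 execute() such as DCTestNode.execute(cls, combo) expects:
    nested = {"combo": {"combo": "option1", "string": "hello"}}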
main.py (30 lines changed)
@@ -15,6 +15,7 @@ from comfy_execution.progress import get_progress_state
 from comfy_execution.utils import get_executing_context
 from comfy_api import feature_flags
 
+
 if __name__ == "__main__":
     #NOTE: These do not do anything on core ComfyUI, they are for custom nodes.
     os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1'
@@ -22,6 +23,23 @@ if __name__ == "__main__":
 
     setup_logger(log_level=args.verbose, use_stdout=args.log_stdout)
 
+
+def handle_comfyui_manager_unavailable():
+    if not args.windows_standalone_build:
+        logging.warning(f"\n\nYou appear to be running comfyui-manager from source, this is not recommended. Please install comfyui-manager using the following command:\ncommand:\n\t{sys.executable} -m pip install --pre comfyui_manager\n")
+    args.enable_manager = False
+
+
+if args.enable_manager:
+    if importlib.util.find_spec("comfyui_manager"):
+        import comfyui_manager
+
+        if not comfyui_manager.__file__ or not comfyui_manager.__file__.endswith('__init__.py'):
+            handle_comfyui_manager_unavailable()
+    else:
+        handle_comfyui_manager_unavailable()
+
+
 def apply_custom_paths():
     # extra model paths
     extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml")
@@ -79,6 +97,11 @@ def execute_prestartup_script():
 
         for possible_module in possible_modules:
             module_path = os.path.join(custom_node_path, possible_module)
+
+            if args.enable_manager:
+                if comfyui_manager.should_be_disabled(module_path):
+                    continue
+
             if os.path.isfile(module_path) or module_path.endswith(".disabled") or module_path == "__pycache__":
                 continue
 
@@ -101,6 +124,10 @@ def execute_prestartup_script():
     logging.info("")
 
 apply_custom_paths()
+
+if args.enable_manager:
+    comfyui_manager.prestartup()
+
 execute_prestartup_script()
@@ -323,6 +350,9 @@ def start_comfyui(asyncio_loop=None):
         asyncio.set_event_loop(asyncio_loop)
     prompt_server = server.PromptServer(asyncio_loop)
 
+    if args.enable_manager and not args.disable_manager_ui:
+        comfyui_manager.start()
+
     hook_breaker_ac10a0.save_functions()
     asyncio_loop.run_until_complete(nodes.init_extra_nodes(
         init_custom_nodes=(not args.disable_all_custom_nodes) or len(args.whitelist_custom_nodes) > 0,
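
Note: manager startup is strictly opt-in and degrades in stages: --enable-manager imports the comfyui_manager package (and quietly disables itself, with a warning outside the standalone build, when the package is missing or shadowed by a non-package checkout); --disable-manager-ui keeps prestartup() and the policy checks while skipping comfyui_manager.start(). Typical invocations:

    python main.py --enable-manager
    python main.py --enable-manager --disable-manager-ui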
manager_requirements.txt (new file, 1 line)
@@ -0,0 +1 @@
+comfyui_manager==4.0.3b3
nodes.py (12 lines changed)
@@ -43,6 +43,9 @@ import folder_paths
 import latent_preview
 import node_helpers
 
+if args.enable_manager:
+    import comfyui_manager
+
 def before_node_execution():
     comfy.model_management.throw_exception_if_processing_interrupted()
 
@@ -939,7 +942,7 @@ class CLIPLoader:
     @classmethod
     def INPUT_TYPES(s):
         return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ),
-                              "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2"], ),
+                              "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis"], ),
                              },
                 "optional": {
                               "device": (["default", "cpu"], {"advanced": True}),
@@ -2243,6 +2246,12 @@ async def init_external_custom_nodes():
             if args.disable_all_custom_nodes and possible_module not in args.whitelist_custom_nodes:
                 logging.info(f"Skipping {possible_module} due to disable_all_custom_nodes and whitelist_custom_nodes")
                 continue
+
+            if args.enable_manager:
+                if comfyui_manager.should_be_disabled(module_path):
+                    logging.info(f"Blocked by policy: {module_path}")
+                    continue
+
             time_before = time.perf_counter()
             success = await load_custom_node(module_path, base_node_names, module_parent="custom_nodes")
             node_import_times.append((time.perf_counter() - time_before, module_path, success))
@@ -2346,6 +2355,7 @@ async def init_builtin_extra_nodes():
         "nodes_easycache.py",
         "nodes_audio_encoder.py",
         "nodes_rope.py",
+        "nodes_logic.py",
         "nodes_nop.py",
     ]
@@ -1,6 +1,6 @@
 [project]
 name = "ComfyUI"
-version = "0.3.75"
+version = "0.3.76"
 readme = "README.md"
 license = { file = "LICENSE" }
 requires-python = ">=3.9"
@@ -1,4 +1,4 @@
-comfyui-frontend-package==1.32.9
+comfyui-frontend-package==1.33.10
 comfyui-workflow-templates==0.7.25
 comfyui-embedded-docs==0.3.1
 torch
server.py (10 lines changed)
@@ -44,6 +44,9 @@ from protocol import BinaryEventTypes
 # Import cache control middleware
 from middleware.cache_middleware import cache_control
 
+if args.enable_manager:
+    import comfyui_manager
+
 async def send_socket_catch_exception(function, message):
     try:
         await function(message)
@@ -95,7 +98,7 @@ def create_cors_middleware(allowed_origin: str):
         response = await handler(request)
 
         response.headers['Access-Control-Allow-Origin'] = allowed_origin
-        response.headers['Access-Control-Allow-Methods'] = 'POST, GET, DELETE, PUT, OPTIONS'
+        response.headers['Access-Control-Allow-Methods'] = 'POST, GET, DELETE, PUT, OPTIONS, PATCH'
         response.headers['Access-Control-Allow-Headers'] = 'Content-Type, Authorization'
         response.headers['Access-Control-Allow-Credentials'] = 'true'
         return response
@@ -212,6 +215,9 @@ class PromptServer():
         if args.disable_api_nodes:
             middlewares.append(create_block_external_middleware())
 
+        if args.enable_manager:
+            middlewares.append(comfyui_manager.create_middleware())
+
         max_upload_size = round(args.max_upload_size * 1024 * 1024)
         self.app = web.Application(client_max_size=max_upload_size, middlewares=middlewares)
         self.sockets = dict()
@@ -599,7 +605,7 @@ class PromptServer():
 
             system_stats = {
                 "system": {
-                    "os": os.name,
+                    "os": sys.platform,
                    "ram_total": ram_total,
                     "ram_free": ram_free,
                     "comfyui_version": __version__,
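
Note: the /system_stats change swaps os.name for sys.platform because the former only distinguishes the broad OS family, while the latter names the actual platform, which is more useful in bug reports:

    import os, sys
    # os.name is 'posix' on both Linux and macOS, 'nt' on Windows;
    # sys.platform is 'linux', 'darwin', or 'win32' respectively.
    print(os.name, sys.platform)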