mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-12-16 01:37:04 +08:00
Merge branch 'master' into asset-management
This commit is contained in:
commit
ca39552954
@ -1355,7 +1355,7 @@ class WanT2VCrossAttentionGather(WanSelfAttention):
|
||||
|
||||
x = optimized_attention(q, k, v, heads=self.num_heads, skip_reshape=True, skip_output_reshape=True, transformer_options=transformer_options)
|
||||
|
||||
x = x.transpose(1, 2).view(b, -1, n, d).flatten(2)
|
||||
x = x.transpose(1, 2).reshape(b, -1, n * d)
|
||||
x = self.o(x)
|
||||
return x
|
||||
|
||||
@ -1551,6 +1551,9 @@ class HumoWanModel(WanModel):
|
||||
context_img_len = None
|
||||
|
||||
if audio_embed is not None:
|
||||
if reference_latent is not None:
|
||||
zero_audio_pad = torch.zeros(audio_embed.shape[0], reference_latent.shape[-3], *audio_embed.shape[2:], device=audio_embed.device, dtype=audio_embed.dtype)
|
||||
audio_embed = torch.cat([audio_embed, zero_audio_pad], dim=1)
|
||||
audio = self.audio_proj(audio_embed).permute(0, 3, 1, 2).flatten(2).transpose(1, 2)
|
||||
else:
|
||||
audio = None
|
||||
|
||||
548
comfy/ldm/wan/model_animate.py
Normal file
548
comfy/ldm/wan/model_animate.py
Normal file
@ -0,0 +1,548 @@
|
||||
from torch import nn
|
||||
import torch
|
||||
from typing import Tuple, Optional
|
||||
from einops import rearrange
|
||||
import torch.nn.functional as F
|
||||
import math
|
||||
from .model import WanModel, sinusoidal_embedding_1d
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
import comfy.model_management
|
||||
|
||||
class CausalConv1d(nn.Module):
|
||||
|
||||
def __init__(self, chan_in, chan_out, kernel_size=3, stride=1, dilation=1, pad_mode="replicate", operations=None, **kwargs):
|
||||
super().__init__()
|
||||
|
||||
self.pad_mode = pad_mode
|
||||
padding = (kernel_size - 1, 0) # T
|
||||
self.time_causal_padding = padding
|
||||
|
||||
self.conv = operations.Conv1d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)
|
||||
|
||||
def forward(self, x):
|
||||
x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class FaceEncoder(nn.Module):
|
||||
def __init__(self, in_dim: int, hidden_dim: int, num_heads=int, dtype=None, device=None, operations=None):
|
||||
factory_kwargs = {"dtype": dtype, "device": device}
|
||||
super().__init__()
|
||||
|
||||
self.num_heads = num_heads
|
||||
self.conv1_local = CausalConv1d(in_dim, 1024 * num_heads, 3, stride=1, operations=operations, **factory_kwargs)
|
||||
self.norm1 = operations.LayerNorm(hidden_dim // 8, elementwise_affine=False, eps=1e-6, **factory_kwargs)
|
||||
self.act = nn.SiLU()
|
||||
self.conv2 = CausalConv1d(1024, 1024, 3, stride=2, operations=operations, **factory_kwargs)
|
||||
self.conv3 = CausalConv1d(1024, 1024, 3, stride=2, operations=operations, **factory_kwargs)
|
||||
|
||||
self.out_proj = operations.Linear(1024, hidden_dim, **factory_kwargs)
|
||||
self.norm1 = operations.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs)
|
||||
|
||||
self.norm2 = operations.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs)
|
||||
|
||||
self.norm3 = operations.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs)
|
||||
|
||||
self.padding_tokens = nn.Parameter(torch.empty(1, 1, 1, hidden_dim, **factory_kwargs))
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
x = rearrange(x, "b t c -> b c t")
|
||||
b, c, t = x.shape
|
||||
|
||||
x = self.conv1_local(x)
|
||||
x = rearrange(x, "b (n c) t -> (b n) t c", n=self.num_heads)
|
||||
|
||||
x = self.norm1(x)
|
||||
x = self.act(x)
|
||||
x = rearrange(x, "b t c -> b c t")
|
||||
x = self.conv2(x)
|
||||
x = rearrange(x, "b c t -> b t c")
|
||||
x = self.norm2(x)
|
||||
x = self.act(x)
|
||||
x = rearrange(x, "b t c -> b c t")
|
||||
x = self.conv3(x)
|
||||
x = rearrange(x, "b c t -> b t c")
|
||||
x = self.norm3(x)
|
||||
x = self.act(x)
|
||||
x = self.out_proj(x)
|
||||
x = rearrange(x, "(b n) t c -> b t n c", b=b)
|
||||
padding = comfy.model_management.cast_to(self.padding_tokens, dtype=x.dtype, device=x.device).repeat(b, x.shape[1], 1, 1)
|
||||
x = torch.cat([x, padding], dim=-2)
|
||||
x_local = x.clone()
|
||||
|
||||
return x_local
|
||||
|
||||
|
||||
def get_norm_layer(norm_layer, operations=None):
|
||||
"""
|
||||
Get the normalization layer.
|
||||
|
||||
Args:
|
||||
norm_layer (str): The type of normalization layer.
|
||||
|
||||
Returns:
|
||||
norm_layer (nn.Module): The normalization layer.
|
||||
"""
|
||||
if norm_layer == "layer":
|
||||
return operations.LayerNorm
|
||||
elif norm_layer == "rms":
|
||||
return operations.RMSNorm
|
||||
else:
|
||||
raise NotImplementedError(f"Norm layer {norm_layer} is not implemented")
|
||||
|
||||
|
||||
class FaceAdapter(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_dim: int,
|
||||
heads_num: int,
|
||||
qk_norm: bool = True,
|
||||
qk_norm_type: str = "rms",
|
||||
num_adapter_layers: int = 1,
|
||||
dtype=None, device=None, operations=None
|
||||
):
|
||||
|
||||
factory_kwargs = {"dtype": dtype, "device": device}
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_dim
|
||||
self.heads_num = heads_num
|
||||
self.fuser_blocks = nn.ModuleList(
|
||||
[
|
||||
FaceBlock(
|
||||
self.hidden_size,
|
||||
self.heads_num,
|
||||
qk_norm=qk_norm,
|
||||
qk_norm_type=qk_norm_type,
|
||||
operations=operations,
|
||||
**factory_kwargs,
|
||||
)
|
||||
for _ in range(num_adapter_layers)
|
||||
]
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
motion_embed: torch.Tensor,
|
||||
idx: int,
|
||||
freqs_cis_q: Tuple[torch.Tensor, torch.Tensor] = None,
|
||||
freqs_cis_k: Tuple[torch.Tensor, torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
|
||||
return self.fuser_blocks[idx](x, motion_embed, freqs_cis_q, freqs_cis_k)
|
||||
|
||||
|
||||
|
||||
class FaceBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
heads_num: int,
|
||||
qk_norm: bool = True,
|
||||
qk_norm_type: str = "rms",
|
||||
qk_scale: float = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
operations=None
|
||||
):
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
super().__init__()
|
||||
|
||||
self.deterministic = False
|
||||
self.hidden_size = hidden_size
|
||||
self.heads_num = heads_num
|
||||
head_dim = hidden_size // heads_num
|
||||
self.scale = qk_scale or head_dim**-0.5
|
||||
|
||||
self.linear1_kv = operations.Linear(hidden_size, hidden_size * 2, **factory_kwargs)
|
||||
self.linear1_q = operations.Linear(hidden_size, hidden_size, **factory_kwargs)
|
||||
|
||||
self.linear2 = operations.Linear(hidden_size, hidden_size, **factory_kwargs)
|
||||
|
||||
qk_norm_layer = get_norm_layer(qk_norm_type, operations=operations)
|
||||
self.q_norm = (
|
||||
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
|
||||
)
|
||||
self.k_norm = (
|
||||
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
|
||||
)
|
||||
|
||||
self.pre_norm_feat = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
|
||||
|
||||
self.pre_norm_motion = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
motion_vec: torch.Tensor,
|
||||
motion_mask: Optional[torch.Tensor] = None,
|
||||
# use_context_parallel=False,
|
||||
) -> torch.Tensor:
|
||||
|
||||
B, T, N, C = motion_vec.shape
|
||||
T_comp = T
|
||||
|
||||
x_motion = self.pre_norm_motion(motion_vec)
|
||||
x_feat = self.pre_norm_feat(x)
|
||||
|
||||
kv = self.linear1_kv(x_motion)
|
||||
q = self.linear1_q(x_feat)
|
||||
|
||||
k, v = rearrange(kv, "B L N (K H D) -> K B L N H D", K=2, H=self.heads_num)
|
||||
q = rearrange(q, "B S (H D) -> B S H D", H=self.heads_num)
|
||||
|
||||
# Apply QK-Norm if needed.
|
||||
q = self.q_norm(q).to(v)
|
||||
k = self.k_norm(k).to(v)
|
||||
|
||||
k = rearrange(k, "B L N H D -> (B L) N H D")
|
||||
v = rearrange(v, "B L N H D -> (B L) N H D")
|
||||
|
||||
q = rearrange(q, "B (L S) H D -> (B L) S (H D)", L=T_comp)
|
||||
|
||||
attn = optimized_attention(q, k, v, heads=self.heads_num)
|
||||
|
||||
attn = rearrange(attn, "(B L) S C -> B (L S) C", L=T_comp)
|
||||
|
||||
output = self.linear2(attn)
|
||||
|
||||
if motion_mask is not None:
|
||||
output = output * rearrange(motion_mask, "B T H W -> B (T H W)").unsqueeze(-1)
|
||||
|
||||
return output
|
||||
|
||||
# https://github.com/XPixelGroup/BasicSR/blob/8d56e3a045f9fb3e1d8872f92ee4a4f07f886b0a/basicsr/ops/upfirdn2d/upfirdn2d.py#L162
|
||||
def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1):
|
||||
_, minor, in_h, in_w = input.shape
|
||||
kernel_h, kernel_w = kernel.shape
|
||||
|
||||
out = input.view(-1, minor, in_h, 1, in_w, 1)
|
||||
out = F.pad(out, [0, up_x - 1, 0, 0, 0, up_y - 1, 0, 0])
|
||||
out = out.view(-1, minor, in_h * up_y, in_w * up_x)
|
||||
|
||||
out = F.pad(out, [max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
|
||||
out = out[:, :, max(-pad_y0, 0): out.shape[2] - max(-pad_y1, 0), max(-pad_x0, 0): out.shape[3] - max(-pad_x1, 0)]
|
||||
|
||||
out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
|
||||
w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
|
||||
out = F.conv2d(out, w)
|
||||
out = out.reshape(-1, minor, in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1)
|
||||
return out[:, :, ::down_y, ::down_x]
|
||||
|
||||
def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
|
||||
return upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1])
|
||||
|
||||
# https://github.com/XPixelGroup/BasicSR/blob/8d56e3a045f9fb3e1d8872f92ee4a4f07f886b0a/basicsr/ops/fused_act/fused_act.py#L81
|
||||
class FusedLeakyReLU(torch.nn.Module):
|
||||
def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5, dtype=None, device=None):
|
||||
super().__init__()
|
||||
self.bias = torch.nn.Parameter(torch.empty(1, channel, 1, 1, dtype=dtype, device=device))
|
||||
self.negative_slope = negative_slope
|
||||
self.scale = scale
|
||||
|
||||
def forward(self, input):
|
||||
return fused_leaky_relu(input, comfy.model_management.cast_to(self.bias, device=input.device, dtype=input.dtype), self.negative_slope, self.scale)
|
||||
|
||||
def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
|
||||
return F.leaky_relu(input + bias, negative_slope) * scale
|
||||
|
||||
class Blur(torch.nn.Module):
|
||||
def __init__(self, kernel, pad, dtype=None, device=None):
|
||||
super().__init__()
|
||||
kernel = torch.tensor(kernel, dtype=dtype, device=device)
|
||||
kernel = kernel[None, :] * kernel[:, None]
|
||||
kernel = kernel / kernel.sum()
|
||||
self.register_buffer('kernel', kernel)
|
||||
self.pad = pad
|
||||
|
||||
def forward(self, input):
|
||||
return upfirdn2d(input, comfy.model_management.cast_to(self.kernel, dtype=input.dtype, device=input.device), pad=self.pad)
|
||||
|
||||
#https://github.com/XPixelGroup/BasicSR/blob/8d56e3a045f9fb3e1d8872f92ee4a4f07f886b0a/basicsr/archs/stylegan2_arch.py#L590
|
||||
class ScaledLeakyReLU(torch.nn.Module):
|
||||
def __init__(self, negative_slope=0.2):
|
||||
super().__init__()
|
||||
self.negative_slope = negative_slope
|
||||
|
||||
def forward(self, input):
|
||||
return F.leaky_relu(input, negative_slope=self.negative_slope)
|
||||
|
||||
# https://github.com/XPixelGroup/BasicSR/blob/8d56e3a045f9fb3e1d8872f92ee4a4f07f886b0a/basicsr/archs/stylegan2_arch.py#L605
|
||||
class EqualConv2d(torch.nn.Module):
|
||||
def __init__(self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.weight = torch.nn.Parameter(torch.empty(out_channel, in_channel, kernel_size, kernel_size, device=device, dtype=dtype))
|
||||
self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)
|
||||
self.stride = stride
|
||||
self.padding = padding
|
||||
self.bias = torch.nn.Parameter(torch.empty(out_channel, device=device, dtype=dtype)) if bias else None
|
||||
|
||||
def forward(self, input):
|
||||
if self.bias is None:
|
||||
bias = None
|
||||
else:
|
||||
bias = comfy.model_management.cast_to(self.bias, device=input.device, dtype=input.dtype)
|
||||
|
||||
return F.conv2d(input, comfy.model_management.cast_to(self.weight, device=input.device, dtype=input.dtype) * self.scale, bias=bias, stride=self.stride, padding=self.padding)
|
||||
|
||||
# https://github.com/XPixelGroup/BasicSR/blob/8d56e3a045f9fb3e1d8872f92ee4a4f07f886b0a/basicsr/archs/stylegan2_arch.py#L134
|
||||
class EqualLinear(torch.nn.Module):
|
||||
def __init__(self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.weight = torch.nn.Parameter(torch.empty(out_dim, in_dim, device=device, dtype=dtype))
|
||||
self.bias = torch.nn.Parameter(torch.empty(out_dim, device=device, dtype=dtype)) if bias else None
|
||||
self.activation = activation
|
||||
self.scale = (1 / math.sqrt(in_dim)) * lr_mul
|
||||
self.lr_mul = lr_mul
|
||||
|
||||
def forward(self, input):
|
||||
if self.bias is None:
|
||||
bias = None
|
||||
else:
|
||||
bias = comfy.model_management.cast_to(self.bias, device=input.device, dtype=input.dtype) * self.lr_mul
|
||||
|
||||
if self.activation:
|
||||
out = F.linear(input, comfy.model_management.cast_to(self.weight, device=input.device, dtype=input.dtype) * self.scale)
|
||||
return fused_leaky_relu(out, bias)
|
||||
return F.linear(input, comfy.model_management.cast_to(self.weight, device=input.device, dtype=input.dtype) * self.scale, bias=bias)
|
||||
|
||||
# https://github.com/XPixelGroup/BasicSR/blob/8d56e3a045f9fb3e1d8872f92ee4a4f07f886b0a/basicsr/archs/stylegan2_arch.py#L654
|
||||
class ConvLayer(torch.nn.Sequential):
|
||||
def __init__(self, in_channel, out_channel, kernel_size, downsample=False, blur_kernel=[1, 3, 3, 1], bias=True, activate=True, dtype=None, device=None, operations=None):
|
||||
layers = []
|
||||
|
||||
if downsample:
|
||||
factor = 2
|
||||
p = (len(blur_kernel) - factor) + (kernel_size - 1)
|
||||
layers.append(Blur(blur_kernel, pad=((p + 1) // 2, p // 2)))
|
||||
stride, padding = 2, 0
|
||||
else:
|
||||
stride, padding = 1, kernel_size // 2
|
||||
|
||||
layers.append(EqualConv2d(in_channel, out_channel, kernel_size, padding=padding, stride=stride, bias=bias and not activate, dtype=dtype, device=device, operations=operations))
|
||||
|
||||
if activate:
|
||||
layers.append(FusedLeakyReLU(out_channel) if bias else ScaledLeakyReLU(0.2))
|
||||
|
||||
super().__init__(*layers)
|
||||
|
||||
# https://github.com/XPixelGroup/BasicSR/blob/8d56e3a045f9fb3e1d8872f92ee4a4f07f886b0a/basicsr/archs/stylegan2_arch.py#L704
|
||||
class ResBlock(torch.nn.Module):
|
||||
def __init__(self, in_channel, out_channel, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.conv1 = ConvLayer(in_channel, in_channel, 3, dtype=dtype, device=device, operations=operations)
|
||||
self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True, dtype=dtype, device=device, operations=operations)
|
||||
self.skip = ConvLayer(in_channel, out_channel, 1, downsample=True, activate=False, bias=False, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
def forward(self, input):
|
||||
out = self.conv2(self.conv1(input))
|
||||
skip = self.skip(input)
|
||||
return (out + skip) / math.sqrt(2)
|
||||
|
||||
|
||||
class EncoderApp(torch.nn.Module):
|
||||
def __init__(self, w_dim=512, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
kwargs = {"device": device, "dtype": dtype, "operations": operations}
|
||||
|
||||
self.convs = torch.nn.ModuleList([
|
||||
ConvLayer(3, 32, 1, **kwargs), ResBlock(32, 64, **kwargs),
|
||||
ResBlock(64, 128, **kwargs), ResBlock(128, 256, **kwargs),
|
||||
ResBlock(256, 512, **kwargs), ResBlock(512, 512, **kwargs),
|
||||
ResBlock(512, 512, **kwargs), ResBlock(512, 512, **kwargs),
|
||||
EqualConv2d(512, w_dim, 4, padding=0, bias=False, **kwargs)
|
||||
])
|
||||
|
||||
def forward(self, x):
|
||||
h = x
|
||||
for conv in self.convs:
|
||||
h = conv(h)
|
||||
return h.squeeze(-1).squeeze(-1)
|
||||
|
||||
class Encoder(torch.nn.Module):
|
||||
def __init__(self, dim=512, motion_dim=20, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.net_app = EncoderApp(dim, dtype=dtype, device=device, operations=operations)
|
||||
self.fc = torch.nn.Sequential(*[EqualLinear(dim, dim, dtype=dtype, device=device, operations=operations) for _ in range(4)] + [EqualLinear(dim, motion_dim, dtype=dtype, device=device, operations=operations)])
|
||||
|
||||
def encode_motion(self, x):
|
||||
return self.fc(self.net_app(x))
|
||||
|
||||
class Direction(torch.nn.Module):
|
||||
def __init__(self, motion_dim, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.weight = torch.nn.Parameter(torch.empty(512, motion_dim, device=device, dtype=dtype))
|
||||
self.motion_dim = motion_dim
|
||||
|
||||
def forward(self, input):
|
||||
stabilized_weight = comfy.model_management.cast_to(self.weight, device=input.device, dtype=input.dtype) + 1e-8 * torch.eye(512, self.motion_dim, device=input.device, dtype=input.dtype)
|
||||
Q, _ = torch.linalg.qr(stabilized_weight.float())
|
||||
if input is None:
|
||||
return Q
|
||||
return torch.sum(input.unsqueeze(-1) * Q.T.to(input.dtype), dim=1)
|
||||
|
||||
class Synthesis(torch.nn.Module):
|
||||
def __init__(self, motion_dim, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.direction = Direction(motion_dim, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
class Generator(torch.nn.Module):
|
||||
def __init__(self, style_dim=512, motion_dim=20, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.enc = Encoder(style_dim, motion_dim, dtype=dtype, device=device, operations=operations)
|
||||
self.dec = Synthesis(motion_dim, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
def get_motion(self, img):
|
||||
motion_feat = self.enc.encode_motion(img)
|
||||
return self.dec.direction(motion_feat)
|
||||
|
||||
class AnimateWanModel(WanModel):
|
||||
r"""
|
||||
Wan diffusion backbone supporting both text-to-video and image-to-video.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
model_type='animate',
|
||||
patch_size=(1, 2, 2),
|
||||
text_len=512,
|
||||
in_dim=16,
|
||||
dim=2048,
|
||||
ffn_dim=8192,
|
||||
freq_dim=256,
|
||||
text_dim=4096,
|
||||
out_dim=16,
|
||||
num_heads=16,
|
||||
num_layers=32,
|
||||
window_size=(-1, -1),
|
||||
qk_norm=True,
|
||||
cross_attn_norm=True,
|
||||
eps=1e-6,
|
||||
flf_pos_embed_token_number=None,
|
||||
motion_encoder_dim=512,
|
||||
image_model=None,
|
||||
device=None,
|
||||
dtype=None,
|
||||
operations=None,
|
||||
):
|
||||
|
||||
super().__init__(model_type='i2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations)
|
||||
|
||||
self.pose_patch_embedding = operations.Conv3d(
|
||||
16, dim, kernel_size=patch_size, stride=patch_size, device=device, dtype=dtype
|
||||
)
|
||||
|
||||
self.motion_encoder = Generator(style_dim=512, motion_dim=20, device=device, dtype=dtype, operations=operations)
|
||||
|
||||
self.face_adapter = FaceAdapter(
|
||||
heads_num=self.num_heads,
|
||||
hidden_dim=self.dim,
|
||||
num_adapter_layers=self.num_layers // 5,
|
||||
device=device, dtype=dtype, operations=operations
|
||||
)
|
||||
|
||||
self.face_encoder = FaceEncoder(
|
||||
in_dim=motion_encoder_dim,
|
||||
hidden_dim=self.dim,
|
||||
num_heads=4,
|
||||
device=device, dtype=dtype, operations=operations
|
||||
)
|
||||
|
||||
def after_patch_embedding(self, x, pose_latents, face_pixel_values):
|
||||
if pose_latents is not None:
|
||||
pose_latents = self.pose_patch_embedding(pose_latents)
|
||||
x[:, :, 1:pose_latents.shape[2] + 1] += pose_latents[:, :, :x.shape[2] - 1]
|
||||
|
||||
if face_pixel_values is None:
|
||||
return x, None
|
||||
|
||||
b, c, T, h, w = face_pixel_values.shape
|
||||
face_pixel_values = rearrange(face_pixel_values, "b c t h w -> (b t) c h w")
|
||||
encode_bs = 8
|
||||
face_pixel_values_tmp = []
|
||||
for i in range(math.ceil(face_pixel_values.shape[0] / encode_bs)):
|
||||
face_pixel_values_tmp.append(self.motion_encoder.get_motion(face_pixel_values[i * encode_bs: (i + 1) * encode_bs]))
|
||||
|
||||
motion_vec = torch.cat(face_pixel_values_tmp)
|
||||
|
||||
motion_vec = rearrange(motion_vec, "(b t) c -> b t c", t=T)
|
||||
motion_vec = self.face_encoder(motion_vec)
|
||||
|
||||
B, L, H, C = motion_vec.shape
|
||||
pad_face = torch.zeros(B, 1, H, C).type_as(motion_vec)
|
||||
motion_vec = torch.cat([pad_face, motion_vec], dim=1)
|
||||
|
||||
if motion_vec.shape[1] < x.shape[2]:
|
||||
B, L, H, C = motion_vec.shape
|
||||
pad = torch.zeros(B, x.shape[2] - motion_vec.shape[1], H, C).type_as(motion_vec)
|
||||
motion_vec = torch.cat([motion_vec, pad], dim=1)
|
||||
else:
|
||||
motion_vec = motion_vec[:, :x.shape[2]]
|
||||
return x, motion_vec
|
||||
|
||||
def forward_orig(
|
||||
self,
|
||||
x,
|
||||
t,
|
||||
context,
|
||||
clip_fea=None,
|
||||
pose_latents=None,
|
||||
face_pixel_values=None,
|
||||
freqs=None,
|
||||
transformer_options={},
|
||||
**kwargs,
|
||||
):
|
||||
# embeddings
|
||||
x = self.patch_embedding(x.float()).to(x.dtype)
|
||||
x, motion_vec = self.after_patch_embedding(x, pose_latents, face_pixel_values)
|
||||
grid_sizes = x.shape[2:]
|
||||
x = x.flatten(2).transpose(1, 2)
|
||||
|
||||
# time embeddings
|
||||
e = self.time_embedding(
|
||||
sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(dtype=x[0].dtype))
|
||||
e = e.reshape(t.shape[0], -1, e.shape[-1])
|
||||
e0 = self.time_projection(e).unflatten(2, (6, self.dim))
|
||||
|
||||
full_ref = None
|
||||
if self.ref_conv is not None:
|
||||
full_ref = kwargs.get("reference_latent", None)
|
||||
if full_ref is not None:
|
||||
full_ref = self.ref_conv(full_ref).flatten(2).transpose(1, 2)
|
||||
x = torch.concat((full_ref, x), dim=1)
|
||||
|
||||
# context
|
||||
context = self.text_embedding(context)
|
||||
|
||||
context_img_len = None
|
||||
if clip_fea is not None:
|
||||
if self.img_emb is not None:
|
||||
context_clip = self.img_emb(clip_fea) # bs x 257 x dim
|
||||
context = torch.concat([context_clip, context], dim=1)
|
||||
context_img_len = clip_fea.shape[-2]
|
||||
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
for i, block in enumerate(self.blocks):
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, transformer_options=args["transformer_options"])
|
||||
return out
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})
|
||||
x = out["img"]
|
||||
else:
|
||||
x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options)
|
||||
|
||||
if i % 5 == 0 and motion_vec is not None:
|
||||
x = x + self.face_adapter.fuser_blocks[i // 5](x, motion_vec)
|
||||
|
||||
# head
|
||||
x = self.head(x, e)
|
||||
|
||||
if full_ref is not None:
|
||||
x = x[:, full_ref.shape[1]:]
|
||||
|
||||
# unpatchify
|
||||
x = self.unpatchify(x, grid_sizes)
|
||||
return x
|
||||
@ -39,6 +39,7 @@ import comfy.ldm.cosmos.model
|
||||
import comfy.ldm.cosmos.predict2
|
||||
import comfy.ldm.lumina.model
|
||||
import comfy.ldm.wan.model
|
||||
import comfy.ldm.wan.model_animate
|
||||
import comfy.ldm.hunyuan3d.model
|
||||
import comfy.ldm.hidream.model
|
||||
import comfy.ldm.chroma.model
|
||||
@ -1253,6 +1254,23 @@ class WAN21_HuMo(WAN21):
|
||||
|
||||
return out
|
||||
|
||||
class WAN22_Animate(WAN21):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
|
||||
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model_animate.AnimateWanModel)
|
||||
self.image_to_video = image_to_video
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
|
||||
face_video_pixels = kwargs.get("face_video_pixels", None)
|
||||
if face_video_pixels is not None:
|
||||
out['face_pixel_values'] = comfy.conds.CONDRegular(face_video_pixels)
|
||||
|
||||
pose_latents = kwargs.get("pose_video_latent", None)
|
||||
if pose_latents is not None:
|
||||
out['pose_latents'] = comfy.conds.CONDRegular(self.process_latent_in(pose_latents))
|
||||
return out
|
||||
|
||||
class WAN22_S2V(WAN21):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel_S2V)
|
||||
|
||||
@ -404,6 +404,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["model_type"] = "s2v"
|
||||
elif '{}audio_proj.audio_proj_glob_1.layer.bias'.format(key_prefix) in state_dict_keys:
|
||||
dit_config["model_type"] = "humo"
|
||||
elif '{}face_adapter.fuser_blocks.0.k_norm.weight'.format(key_prefix) in state_dict_keys:
|
||||
dit_config["model_type"] = "animate"
|
||||
else:
|
||||
if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys:
|
||||
dit_config["model_type"] = "i2v"
|
||||
|
||||
@ -348,7 +348,7 @@ try:
|
||||
# if any((a in arch) for a in ["gfx1201"]):
|
||||
# ENABLE_PYTORCH_ATTENTION = True
|
||||
if torch_version_numeric >= (2, 7) and rocm_version >= (6, 4):
|
||||
if any((a in arch) for a in ["gfx1201", "gfx942", "gfx950"]): # TODO: more arches
|
||||
if any((a in arch) for a in ["gfx1200", "gfx1201", "gfx942", "gfx950"]): # TODO: more arches
|
||||
SUPPORT_FP8_OPS = True
|
||||
|
||||
except:
|
||||
@ -645,7 +645,9 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
|
||||
if loaded_model.model.is_clone(current_loaded_models[i].model):
|
||||
to_unload = [i] + to_unload
|
||||
for i in to_unload:
|
||||
current_loaded_models.pop(i).model.detach(unpatch_all=False)
|
||||
model_to_unload = current_loaded_models.pop(i)
|
||||
model_to_unload.model.detach(unpatch_all=False)
|
||||
model_to_unload.model_finalizer.detach()
|
||||
|
||||
total_memory_required = {}
|
||||
for loaded_model in models_to_load:
|
||||
|
||||
13
comfy/ops.py
13
comfy/ops.py
@ -365,12 +365,13 @@ class fp8_ops(manual_cast):
|
||||
return None
|
||||
|
||||
def forward_comfy_cast_weights(self, input):
|
||||
try:
|
||||
out = fp8_linear(self, input)
|
||||
if out is not None:
|
||||
return out
|
||||
except Exception as e:
|
||||
logging.info("Exception during fp8 op: {}".format(e))
|
||||
if not self.training:
|
||||
try:
|
||||
out = fp8_linear(self, input)
|
||||
if out is not None:
|
||||
return out
|
||||
except Exception as e:
|
||||
logging.info("Exception during fp8 op: {}".format(e))
|
||||
|
||||
weight, bias = cast_bias_weight(self, input)
|
||||
return torch.nn.functional.linear(input, weight, bias)
|
||||
|
||||
@ -995,7 +995,7 @@ class WAN21_T2V(supported_models_base.BASE):
|
||||
unet_extra_config = {}
|
||||
latent_format = latent_formats.Wan21
|
||||
|
||||
memory_usage_factor = 1.0
|
||||
memory_usage_factor = 0.9
|
||||
|
||||
supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||
|
||||
@ -1004,7 +1004,7 @@ class WAN21_T2V(supported_models_base.BASE):
|
||||
|
||||
def __init__(self, unet_config):
|
||||
super().__init__(unet_config)
|
||||
self.memory_usage_factor = self.unet_config.get("dim", 2000) / 2000
|
||||
self.memory_usage_factor = self.unet_config.get("dim", 2000) / 2222
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.WAN21(self, device=device)
|
||||
@ -1096,6 +1096,19 @@ class WAN22_S2V(WAN21_T2V):
|
||||
out = model_base.WAN22_S2V(self, device=device)
|
||||
return out
|
||||
|
||||
class WAN22_Animate(WAN21_T2V):
|
||||
unet_config = {
|
||||
"image_model": "wan2.1",
|
||||
"model_type": "animate",
|
||||
}
|
||||
|
||||
def __init__(self, unet_config):
|
||||
super().__init__(unet_config)
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.WAN22_Animate(self, device=device)
|
||||
return out
|
||||
|
||||
class WAN22_T2V(WAN21_T2V):
|
||||
unet_config = {
|
||||
"image_model": "wan2.1",
|
||||
@ -1361,6 +1374,6 @@ class HunyuanImage21Refiner(HunyuanVideo):
|
||||
out = model_base.HunyuanImage21Refiner(self, device=device)
|
||||
return out
|
||||
|
||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage]
|
||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage]
|
||||
|
||||
models += [SVD_img2vid]
|
||||
|
||||
@ -400,21 +400,25 @@ class Qwen25_7BVLI(BaseLlama, torch.nn.Module):
|
||||
|
||||
def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, embeds_info=[]):
|
||||
grid = None
|
||||
position_ids = None
|
||||
offset = 0
|
||||
for e in embeds_info:
|
||||
if e.get("type") == "image":
|
||||
grid = e.get("extra", None)
|
||||
position_ids = torch.zeros((3, embeds.shape[1]), device=embeds.device)
|
||||
start = e.get("index")
|
||||
position_ids[:, :start] = torch.arange(0, start, device=embeds.device)
|
||||
if position_ids is None:
|
||||
position_ids = torch.zeros((3, embeds.shape[1]), device=embeds.device)
|
||||
position_ids[:, :start] = torch.arange(0, start, device=embeds.device)
|
||||
end = e.get("size") + start
|
||||
len_max = int(grid.max()) // 2
|
||||
start_next = len_max + start
|
||||
position_ids[:, end:] = torch.arange(start_next, start_next + (embeds.shape[1] - end), device=embeds.device)
|
||||
position_ids[0, start:end] = start
|
||||
position_ids[:, end:] = torch.arange(start_next + offset, start_next + (embeds.shape[1] - end) + offset, device=embeds.device)
|
||||
position_ids[0, start:end] = start + offset
|
||||
max_d = int(grid[0][1]) // 2
|
||||
position_ids[1, start:end] = torch.arange(start, start + max_d, device=embeds.device).unsqueeze(1).repeat(1, math.ceil((end - start) / max_d)).flatten(0)[:end - start]
|
||||
position_ids[1, start:end] = torch.arange(start + offset, start + max_d + offset, device=embeds.device).unsqueeze(1).repeat(1, math.ceil((end - start) / max_d)).flatten(0)[:end - start]
|
||||
max_d = int(grid[0][2]) // 2
|
||||
position_ids[2, start:end] = torch.arange(start, start + max_d, device=embeds.device).unsqueeze(0).repeat(math.ceil((end - start) / max_d), 1).flatten(0)[:end - start]
|
||||
position_ids[2, start:end] = torch.arange(start + offset, start + max_d + offset, device=embeds.device).unsqueeze(0).repeat(math.ceil((end - start) / max_d), 1).flatten(0)[:end - start]
|
||||
offset += len_max - (end - start)
|
||||
|
||||
if grid is None:
|
||||
position_ids = None
|
||||
|
||||
@ -130,12 +130,12 @@ class LoHaAdapter(WeightAdapterBase):
|
||||
def create_train(cls, weight, rank=1, alpha=1.0):
|
||||
out_dim = weight.shape[0]
|
||||
in_dim = weight.shape[1:].numel()
|
||||
mat1 = torch.empty(out_dim, rank, device=weight.device, dtype=weight.dtype)
|
||||
mat2 = torch.empty(rank, in_dim, device=weight.device, dtype=weight.dtype)
|
||||
mat1 = torch.empty(out_dim, rank, device=weight.device, dtype=torch.float32)
|
||||
mat2 = torch.empty(rank, in_dim, device=weight.device, dtype=torch.float32)
|
||||
torch.nn.init.normal_(mat1, 0.1)
|
||||
torch.nn.init.constant_(mat2, 0.0)
|
||||
mat3 = torch.empty(out_dim, rank, device=weight.device, dtype=weight.dtype)
|
||||
mat4 = torch.empty(rank, in_dim, device=weight.device, dtype=weight.dtype)
|
||||
mat3 = torch.empty(out_dim, rank, device=weight.device, dtype=torch.float32)
|
||||
mat4 = torch.empty(rank, in_dim, device=weight.device, dtype=torch.float32)
|
||||
torch.nn.init.normal_(mat3, 0.1)
|
||||
torch.nn.init.normal_(mat4, 0.01)
|
||||
return LohaDiff(
|
||||
|
||||
@ -89,8 +89,8 @@ class LoKrAdapter(WeightAdapterBase):
|
||||
in_dim = weight.shape[1:].numel()
|
||||
out1, out2 = factorization(out_dim, rank)
|
||||
in1, in2 = factorization(in_dim, rank)
|
||||
mat1 = torch.empty(out1, in1, device=weight.device, dtype=weight.dtype)
|
||||
mat2 = torch.empty(out2, in2, device=weight.device, dtype=weight.dtype)
|
||||
mat1 = torch.empty(out1, in1, device=weight.device, dtype=torch.float32)
|
||||
mat2 = torch.empty(out2, in2, device=weight.device, dtype=torch.float32)
|
||||
torch.nn.init.kaiming_uniform_(mat2, a=5**0.5)
|
||||
torch.nn.init.constant_(mat1, 0.0)
|
||||
return LokrDiff(
|
||||
|
||||
@ -66,8 +66,8 @@ class LoRAAdapter(WeightAdapterBase):
|
||||
def create_train(cls, weight, rank=1, alpha=1.0):
|
||||
out_dim = weight.shape[0]
|
||||
in_dim = weight.shape[1:].numel()
|
||||
mat1 = torch.empty(out_dim, rank, device=weight.device, dtype=weight.dtype)
|
||||
mat2 = torch.empty(rank, in_dim, device=weight.device, dtype=weight.dtype)
|
||||
mat1 = torch.empty(out_dim, rank, device=weight.device, dtype=torch.float32)
|
||||
mat2 = torch.empty(rank, in_dim, device=weight.device, dtype=torch.float32)
|
||||
torch.nn.init.kaiming_uniform_(mat1, a=5**0.5)
|
||||
torch.nn.init.constant_(mat2, 0.0)
|
||||
return LoraDiff(
|
||||
|
||||
@ -68,7 +68,7 @@ class OFTAdapter(WeightAdapterBase):
|
||||
def create_train(cls, weight, rank=1, alpha=1.0):
|
||||
out_dim = weight.shape[0]
|
||||
block_size, block_num = factorization(out_dim, rank)
|
||||
block = torch.zeros(block_num, block_size, block_size, device=weight.device, dtype=weight.dtype)
|
||||
block = torch.zeros(block_num, block_size, block_size, device=weight.device, dtype=torch.float32)
|
||||
return OFTDiff(
|
||||
(block, None, alpha, None)
|
||||
)
|
||||
|
||||
@ -683,7 +683,7 @@ class SynchronousOperation(Generic[T, R]):
|
||||
auth_token: Optional[str] = None,
|
||||
comfy_api_key: Optional[str] = None,
|
||||
auth_kwargs: Optional[Dict[str, str]] = None,
|
||||
timeout: float = 604800.0,
|
||||
timeout: float = 7200.0,
|
||||
verify_ssl: bool = True,
|
||||
content_type: str = "application/json",
|
||||
multipart_parser: Callable | None = None,
|
||||
|
||||
@ -9,8 +9,9 @@ class Rodin3DGenerateRequest(BaseModel):
|
||||
seed: int = Field(..., description="seed_")
|
||||
tier: str = Field(..., description="Tier of generation.")
|
||||
material: str = Field(..., description="The material type.")
|
||||
quality: str = Field(..., description="The generation quality of the mesh.")
|
||||
quality_override: int = Field(..., description="The poly count of the mesh.")
|
||||
mesh_mode: str = Field(..., description="It controls the type of faces of generated models.")
|
||||
TAPose: Optional[bool] = Field(None, description="")
|
||||
|
||||
class GenerateJobsData(BaseModel):
|
||||
uuids: List[str] = Field(..., description="str LIST")
|
||||
|
||||
@ -567,6 +567,12 @@ class ByteDanceSeedreamNode(comfy_io.ComfyNode):
|
||||
tooltip="Whether to add an \"AI generated\" watermark to the image.",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Boolean.Input(
|
||||
"fail_on_partial",
|
||||
default=True,
|
||||
tooltip="If enabled, abort execution if any requested images are missing or return an error.",
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
comfy_io.Image.Output(),
|
||||
@ -592,6 +598,7 @@ class ByteDanceSeedreamNode(comfy_io.ComfyNode):
|
||||
max_images: int = 1,
|
||||
seed: int = 0,
|
||||
watermark: bool = True,
|
||||
fail_on_partial: bool = True,
|
||||
) -> comfy_io.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=True, min_length=1)
|
||||
w = h = None
|
||||
@ -651,9 +658,10 @@ class ByteDanceSeedreamNode(comfy_io.ComfyNode):
|
||||
|
||||
if len(response.data) == 1:
|
||||
return comfy_io.NodeOutput(await download_url_to_image_tensor(get_image_url_from_response(response)))
|
||||
return comfy_io.NodeOutput(
|
||||
torch.cat([await download_url_to_image_tensor(str(i["url"])) for i in response.data])
|
||||
)
|
||||
urls = [str(d["url"]) for d in response.data if isinstance(d, dict) and "url" in d]
|
||||
if fail_on_partial and len(urls) < len(response.data):
|
||||
raise RuntimeError(f"Only {len(urls)} of {len(response.data)} images were generated before error.")
|
||||
return comfy_io.NodeOutput(torch.cat([await download_url_to_image_tensor(i) for i in urls]))
|
||||
|
||||
|
||||
class ByteDanceTextToVideoNode(comfy_io.ComfyNode):
|
||||
@ -1171,7 +1179,7 @@ async def process_video_task(
|
||||
payload: Union[Text2VideoTaskCreationRequest, Image2VideoTaskCreationRequest],
|
||||
auth_kwargs: dict,
|
||||
node_id: str,
|
||||
estimated_duration: int | None,
|
||||
estimated_duration: Optional[int],
|
||||
) -> comfy_io.NodeOutput:
|
||||
initial_response = await SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
|
||||
@ -121,10 +121,10 @@ class Rodin3DAPI:
|
||||
else:
|
||||
return "Generating"
|
||||
|
||||
async def create_generate_task(self, images=None, seed=1, material="PBR", quality="medium", tier="Regular", mesh_mode="Quad", **kwargs):
|
||||
async def create_generate_task(self, images=None, seed=1, material="PBR", quality_override=18000, tier="Regular", mesh_mode="Quad", TAPose = False, **kwargs):
|
||||
if images is None:
|
||||
raise Exception("Rodin 3D generate requires at least 1 image.")
|
||||
if len(images) >= 5:
|
||||
if len(images) > 5:
|
||||
raise Exception("Rodin 3D generate requires up to 5 image.")
|
||||
|
||||
path = "/proxy/rodin/api/v2/rodin"
|
||||
@ -139,8 +139,9 @@ class Rodin3DAPI:
|
||||
seed=seed,
|
||||
tier=tier,
|
||||
material=material,
|
||||
quality=quality,
|
||||
mesh_mode=mesh_mode
|
||||
quality_override=quality_override,
|
||||
mesh_mode=mesh_mode,
|
||||
TAPose=TAPose,
|
||||
),
|
||||
files=[
|
||||
(
|
||||
@ -211,23 +212,36 @@ class Rodin3DAPI:
|
||||
return await operation.execute()
|
||||
|
||||
def get_quality_mode(self, poly_count):
|
||||
if poly_count == "200K-Triangle":
|
||||
polycount = poly_count.split("-")
|
||||
poly = polycount[1]
|
||||
count = polycount[0]
|
||||
if poly == "Triangle":
|
||||
mesh_mode = "Raw"
|
||||
quality = "medium"
|
||||
elif poly == "Quad":
|
||||
mesh_mode = "Quad"
|
||||
else:
|
||||
mesh_mode = "Quad"
|
||||
if poly_count == "4K-Quad":
|
||||
quality = "extra-low"
|
||||
elif poly_count == "8K-Quad":
|
||||
quality = "low"
|
||||
elif poly_count == "18K-Quad":
|
||||
quality = "medium"
|
||||
elif poly_count == "50K-Quad":
|
||||
quality = "high"
|
||||
else:
|
||||
quality = "medium"
|
||||
|
||||
return mesh_mode, quality
|
||||
if count == "4K":
|
||||
quality_override = 4000
|
||||
elif count == "8K":
|
||||
quality_override = 8000
|
||||
elif count == "18K":
|
||||
quality_override = 18000
|
||||
elif count == "50K":
|
||||
quality_override = 50000
|
||||
elif count == "2K":
|
||||
quality_override = 2000
|
||||
elif count == "20K":
|
||||
quality_override = 20000
|
||||
elif count == "150K":
|
||||
quality_override = 150000
|
||||
elif count == "500K":
|
||||
quality_override = 500000
|
||||
else:
|
||||
quality_override = 18000
|
||||
|
||||
return mesh_mode, quality_override
|
||||
|
||||
async def download_files(self, url_list):
|
||||
save_path = os.path.join(comfy_paths.get_output_directory(), "Rodin3D", datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
|
||||
@ -300,9 +314,9 @@ class Rodin3D_Regular(Rodin3DAPI):
|
||||
m_images = []
|
||||
for i in range(num_images):
|
||||
m_images.append(Images[i])
|
||||
mesh_mode, quality = self.get_quality_mode(Polygon_count)
|
||||
mesh_mode, quality_override = self.get_quality_mode(Polygon_count)
|
||||
task_uuid, subscription_key = await self.create_generate_task(images=m_images, seed=Seed, material=Material_Type,
|
||||
quality=quality, tier=tier, mesh_mode=mesh_mode,
|
||||
quality_override=quality_override, tier=tier, mesh_mode=mesh_mode,
|
||||
**kwargs)
|
||||
await self.poll_for_task_status(subscription_key, **kwargs)
|
||||
download_list = await self.get_rodin_download_list(task_uuid, **kwargs)
|
||||
@ -346,9 +360,9 @@ class Rodin3D_Detail(Rodin3DAPI):
|
||||
m_images = []
|
||||
for i in range(num_images):
|
||||
m_images.append(Images[i])
|
||||
mesh_mode, quality = self.get_quality_mode(Polygon_count)
|
||||
mesh_mode, quality_override = self.get_quality_mode(Polygon_count)
|
||||
task_uuid, subscription_key = await self.create_generate_task(images=m_images, seed=Seed, material=Material_Type,
|
||||
quality=quality, tier=tier, mesh_mode=mesh_mode,
|
||||
quality_override=quality_override, tier=tier, mesh_mode=mesh_mode,
|
||||
**kwargs)
|
||||
await self.poll_for_task_status(subscription_key, **kwargs)
|
||||
download_list = await self.get_rodin_download_list(task_uuid, **kwargs)
|
||||
@ -392,9 +406,9 @@ class Rodin3D_Smooth(Rodin3DAPI):
|
||||
m_images = []
|
||||
for i in range(num_images):
|
||||
m_images.append(Images[i])
|
||||
mesh_mode, quality = self.get_quality_mode(Polygon_count)
|
||||
mesh_mode, quality_override = self.get_quality_mode(Polygon_count)
|
||||
task_uuid, subscription_key = await self.create_generate_task(images=m_images, seed=Seed, material=Material_Type,
|
||||
quality=quality, tier=tier, mesh_mode=mesh_mode,
|
||||
quality_override=quality_override, tier=tier, mesh_mode=mesh_mode,
|
||||
**kwargs)
|
||||
await self.poll_for_task_status(subscription_key, **kwargs)
|
||||
download_list = await self.get_rodin_download_list(task_uuid, **kwargs)
|
||||
@ -446,10 +460,10 @@ class Rodin3D_Sketch(Rodin3DAPI):
|
||||
for i in range(num_images):
|
||||
m_images.append(Images[i])
|
||||
material_type = "PBR"
|
||||
quality = "medium"
|
||||
quality_override = 18000
|
||||
mesh_mode = "Quad"
|
||||
task_uuid, subscription_key = await self.create_generate_task(
|
||||
images=m_images, seed=Seed, material=material_type, quality=quality, tier=tier, mesh_mode=mesh_mode, **kwargs
|
||||
images=m_images, seed=Seed, material=material_type, quality_override=quality_override, tier=tier, mesh_mode=mesh_mode, **kwargs
|
||||
)
|
||||
await self.poll_for_task_status(subscription_key, **kwargs)
|
||||
download_list = await self.get_rodin_download_list(task_uuid, **kwargs)
|
||||
@ -457,6 +471,80 @@ class Rodin3D_Sketch(Rodin3DAPI):
|
||||
|
||||
return (model,)
|
||||
|
||||
class Rodin3D_Gen2(Rodin3DAPI):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"Images":
|
||||
(
|
||||
IO.IMAGE,
|
||||
{
|
||||
"forceInput":True,
|
||||
}
|
||||
)
|
||||
},
|
||||
"optional": {
|
||||
"Seed": (
|
||||
IO.INT,
|
||||
{
|
||||
"default":0,
|
||||
"min":0,
|
||||
"max":65535,
|
||||
"display":"number"
|
||||
}
|
||||
),
|
||||
"Material_Type": (
|
||||
IO.COMBO,
|
||||
{
|
||||
"options": ["PBR", "Shaded"],
|
||||
"default": "PBR"
|
||||
}
|
||||
),
|
||||
"Polygon_count": (
|
||||
IO.COMBO,
|
||||
{
|
||||
"options": ["4K-Quad", "8K-Quad", "18K-Quad", "50K-Quad", "2K-Triangle", "20K-Triangle", "150K-Triangle", "500K-Triangle"],
|
||||
"default": "500K-Triangle"
|
||||
}
|
||||
),
|
||||
"TAPose": (
|
||||
IO.BOOLEAN,
|
||||
{
|
||||
"default": False,
|
||||
}
|
||||
)
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
},
|
||||
}
|
||||
|
||||
async def api_call(
|
||||
self,
|
||||
Images,
|
||||
Seed,
|
||||
Material_Type,
|
||||
Polygon_count,
|
||||
TAPose,
|
||||
**kwargs
|
||||
):
|
||||
tier = "Gen-2"
|
||||
num_images = Images.shape[0]
|
||||
m_images = []
|
||||
for i in range(num_images):
|
||||
m_images.append(Images[i])
|
||||
mesh_mode, quality_override = self.get_quality_mode(Polygon_count)
|
||||
task_uuid, subscription_key = await self.create_generate_task(images=m_images, seed=Seed, material=Material_Type,
|
||||
quality_override=quality_override, tier=tier, mesh_mode=mesh_mode, TAPose=TAPose,
|
||||
**kwargs)
|
||||
await self.poll_for_task_status(subscription_key, **kwargs)
|
||||
download_list = await self.get_rodin_download_list(task_uuid, **kwargs)
|
||||
model = await self.download_files(download_list)
|
||||
|
||||
return (model,)
|
||||
|
||||
# A dictionary that contains all nodes you want to export with their names
|
||||
# NOTE: names should be globally unique
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
@ -464,6 +552,7 @@ NODE_CLASS_MAPPINGS = {
|
||||
"Rodin3D_Detail": Rodin3D_Detail,
|
||||
"Rodin3D_Smooth": Rodin3D_Smooth,
|
||||
"Rodin3D_Sketch": Rodin3D_Sketch,
|
||||
"Rodin3D_Gen2": Rodin3D_Gen2,
|
||||
}
|
||||
|
||||
# A dictionary that contains the friendly/humanly readable titles for the nodes
|
||||
@ -472,4 +561,5 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"Rodin3D_Detail": "Rodin 3D Generate - Detail Generate",
|
||||
"Rodin3D_Smooth": "Rodin 3D Generate - Smooth Generate",
|
||||
"Rodin3D_Sketch": "Rodin 3D Generate - Sketch Generate",
|
||||
"Rodin3D_Gen2": "Rodin 3D Generate - Gen-2 Generate",
|
||||
}
|
||||
|
||||
602
comfy_api_nodes/nodes_wan.py
Normal file
602
comfy_api_nodes/nodes_wan.py
Normal file
@ -0,0 +1,602 @@
|
||||
import re
|
||||
from typing import Optional, Type, Union
|
||||
from typing_extensions import override
|
||||
|
||||
import torch
|
||||
from pydantic import BaseModel, Field
|
||||
from comfy_api.latest import ComfyExtension, Input, io as comfy_io
|
||||
from comfy_api_nodes.apis.client import (
|
||||
ApiEndpoint,
|
||||
HttpMethod,
|
||||
SynchronousOperation,
|
||||
PollingOperation,
|
||||
EmptyRequest,
|
||||
R,
|
||||
T,
|
||||
)
|
||||
from comfy_api_nodes.util.validation_utils import get_number_of_images, validate_audio_duration
|
||||
|
||||
from comfy_api_nodes.apinode_utils import (
|
||||
download_url_to_image_tensor,
|
||||
download_url_to_video_output,
|
||||
tensor_to_base64_string,
|
||||
audio_to_base64_string,
|
||||
)
|
||||
|
||||
class Text2ImageInputField(BaseModel):
|
||||
prompt: str = Field(...)
|
||||
negative_prompt: Optional[str] = Field(None)
|
||||
|
||||
|
||||
class Text2VideoInputField(BaseModel):
|
||||
prompt: str = Field(...)
|
||||
negative_prompt: Optional[str] = Field(None)
|
||||
audio_url: Optional[str] = Field(None)
|
||||
|
||||
|
||||
class Image2VideoInputField(BaseModel):
|
||||
prompt: str = Field(...)
|
||||
negative_prompt: Optional[str] = Field(None)
|
||||
img_url: str = Field(...)
|
||||
audio_url: Optional[str] = Field(None)
|
||||
|
||||
|
||||
class Txt2ImageParametersField(BaseModel):
|
||||
size: str = Field(...)
|
||||
n: int = Field(1, description="Number of images to generate.") # we support only value=1
|
||||
seed: int = Field(..., ge=0, le=2147483647)
|
||||
prompt_extend: bool = Field(True)
|
||||
watermark: bool = Field(True)
|
||||
|
||||
|
||||
class Text2VideoParametersField(BaseModel):
|
||||
size: str = Field(...)
|
||||
seed: int = Field(..., ge=0, le=2147483647)
|
||||
duration: int = Field(5, ge=5, le=10)
|
||||
prompt_extend: bool = Field(True)
|
||||
watermark: bool = Field(True)
|
||||
audio: bool = Field(False, description="Should be audio generated automatically")
|
||||
|
||||
|
||||
class Image2VideoParametersField(BaseModel):
|
||||
resolution: str = Field(...)
|
||||
seed: int = Field(..., ge=0, le=2147483647)
|
||||
duration: int = Field(5, ge=5, le=10)
|
||||
prompt_extend: bool = Field(True)
|
||||
watermark: bool = Field(True)
|
||||
audio: bool = Field(False, description="Should be audio generated automatically")
|
||||
|
||||
|
||||
class Text2ImageTaskCreationRequest(BaseModel):
|
||||
model: str = Field(...)
|
||||
input: Text2ImageInputField = Field(...)
|
||||
parameters: Txt2ImageParametersField = Field(...)
|
||||
|
||||
|
||||
class Text2VideoTaskCreationRequest(BaseModel):
|
||||
model: str = Field(...)
|
||||
input: Text2VideoInputField = Field(...)
|
||||
parameters: Text2VideoParametersField = Field(...)
|
||||
|
||||
|
||||
class Image2VideoTaskCreationRequest(BaseModel):
|
||||
model: str = Field(...)
|
||||
input: Image2VideoInputField = Field(...)
|
||||
parameters: Image2VideoParametersField = Field(...)
|
||||
|
||||
|
||||
class TaskCreationOutputField(BaseModel):
|
||||
task_id: str = Field(...)
|
||||
task_status: str = Field(...)
|
||||
|
||||
|
||||
class TaskCreationResponse(BaseModel):
|
||||
output: Optional[TaskCreationOutputField] = Field(None)
|
||||
request_id: str = Field(...)
|
||||
code: Optional[str] = Field(None, description="The error code of the failed request.")
|
||||
message: Optional[str] = Field(None, description="Details of the failed request.")
|
||||
|
||||
|
||||
class TaskResult(BaseModel):
|
||||
url: Optional[str] = Field(None)
|
||||
code: Optional[str] = Field(None)
|
||||
message: Optional[str] = Field(None)
|
||||
|
||||
|
||||
class ImageTaskStatusOutputField(TaskCreationOutputField):
|
||||
task_id: str = Field(...)
|
||||
task_status: str = Field(...)
|
||||
results: Optional[list[TaskResult]] = Field(None)
|
||||
|
||||
|
||||
class VideoTaskStatusOutputField(TaskCreationOutputField):
|
||||
task_id: str = Field(...)
|
||||
task_status: str = Field(...)
|
||||
video_url: Optional[str] = Field(None)
|
||||
code: Optional[str] = Field(None)
|
||||
message: Optional[str] = Field(None)
|
||||
|
||||
|
||||
class ImageTaskStatusResponse(BaseModel):
|
||||
output: Optional[ImageTaskStatusOutputField] = Field(None)
|
||||
request_id: str = Field(...)
|
||||
|
||||
|
||||
class VideoTaskStatusResponse(BaseModel):
|
||||
output: Optional[VideoTaskStatusOutputField] = Field(None)
|
||||
request_id: str = Field(...)
|
||||
|
||||
|
||||
RES_IN_PARENS = re.compile(r'\((\d+)\s*[x×]\s*(\d+)\)')
|
||||
|
||||
|
||||
async def process_task(
|
||||
auth_kwargs: dict[str, str],
|
||||
url: str,
|
||||
request_model: Type[T],
|
||||
response_model: Type[R],
|
||||
payload: Union[Text2ImageTaskCreationRequest, Text2VideoTaskCreationRequest, Image2VideoTaskCreationRequest],
|
||||
node_id: str,
|
||||
estimated_duration: int,
|
||||
poll_interval: int,
|
||||
) -> Type[R]:
|
||||
initial_response = await SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=url,
|
||||
method=HttpMethod.POST,
|
||||
request_model=request_model,
|
||||
response_model=TaskCreationResponse,
|
||||
),
|
||||
request=payload,
|
||||
auth_kwargs=auth_kwargs,
|
||||
).execute()
|
||||
|
||||
if not initial_response.output:
|
||||
raise Exception(f"Unknown error occurred: {initial_response.code} - {initial_response.message}")
|
||||
|
||||
return await PollingOperation(
|
||||
poll_endpoint=ApiEndpoint(
|
||||
path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=response_model,
|
||||
),
|
||||
completed_statuses=["SUCCEEDED"],
|
||||
failed_statuses=["FAILED", "CANCELED", "UNKNOWN"],
|
||||
status_extractor=lambda x: x.output.task_status,
|
||||
estimated_duration=estimated_duration,
|
||||
poll_interval=poll_interval,
|
||||
node_id=node_id,
|
||||
auth_kwargs=auth_kwargs,
|
||||
).execute()
|
||||
|
||||
|
||||
class WanTextToImageApi(comfy_io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return comfy_io.Schema(
|
||||
node_id="WanTextToImageApi",
|
||||
display_name="Wan Text to Image",
|
||||
category="api node/image/Wan",
|
||||
description="Generates image based on text prompt.",
|
||||
inputs=[
|
||||
comfy_io.Combo.Input(
|
||||
"model",
|
||||
options=["wan2.5-t2i-preview"],
|
||||
default="wan2.5-t2i-preview",
|
||||
tooltip="Model to use.",
|
||||
),
|
||||
comfy_io.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Prompt used to describe the elements and visual features, supports English/Chinese.",
|
||||
),
|
||||
comfy_io.String.Input(
|
||||
"negative_prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Negative text prompt to guide what to avoid.",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
"width",
|
||||
default=1024,
|
||||
min=768,
|
||||
max=1440,
|
||||
step=32,
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
"height",
|
||||
default=1024,
|
||||
min=768,
|
||||
max=1440,
|
||||
step=32,
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
step=1,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed to use for generation.",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Boolean.Input(
|
||||
"prompt_extend",
|
||||
default=True,
|
||||
tooltip="Whether to enhance the prompt with AI assistance.",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Boolean.Input(
|
||||
"watermark",
|
||||
default=True,
|
||||
tooltip="Whether to add an \"AI generated\" watermark to the result.",
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
comfy_io.Image.Output(),
|
||||
],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
model: str,
|
||||
prompt: str,
|
||||
negative_prompt: str = "",
|
||||
width: int = 1024,
|
||||
height: int = 1024,
|
||||
seed: int = 0,
|
||||
prompt_extend: bool = True,
|
||||
watermark: bool = True,
|
||||
):
|
||||
payload = Text2ImageTaskCreationRequest(
|
||||
model=model,
|
||||
input=Text2ImageInputField(prompt=prompt, negative_prompt=negative_prompt),
|
||||
parameters=Txt2ImageParametersField(
|
||||
size=f"{width}*{height}",
|
||||
seed=seed,
|
||||
prompt_extend=prompt_extend,
|
||||
watermark=watermark,
|
||||
),
|
||||
)
|
||||
response = await process_task(
|
||||
{
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
},
|
||||
"/proxy/wan/api/v1/services/aigc/text2image/image-synthesis",
|
||||
request_model=Text2ImageTaskCreationRequest,
|
||||
response_model=ImageTaskStatusResponse,
|
||||
payload=payload,
|
||||
node_id=cls.hidden.unique_id,
|
||||
estimated_duration=9,
|
||||
poll_interval=3,
|
||||
)
|
||||
return comfy_io.NodeOutput(await download_url_to_image_tensor(str(response.output.results[0].url)))
|
||||
|
||||
|
||||
class WanTextToVideoApi(comfy_io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return comfy_io.Schema(
|
||||
node_id="WanTextToVideoApi",
|
||||
display_name="Wan Text to Video",
|
||||
category="api node/video/Wan",
|
||||
description="Generates video based on text prompt.",
|
||||
inputs=[
|
||||
comfy_io.Combo.Input(
|
||||
"model",
|
||||
options=["wan2.5-t2v-preview"],
|
||||
default="wan2.5-t2v-preview",
|
||||
tooltip="Model to use.",
|
||||
),
|
||||
comfy_io.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Prompt used to describe the elements and visual features, supports English/Chinese.",
|
||||
),
|
||||
comfy_io.String.Input(
|
||||
"negative_prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Negative text prompt to guide what to avoid.",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
"size",
|
||||
options=[
|
||||
"480p: 1:1 (624x624)",
|
||||
"480p: 16:9 (832x480)",
|
||||
"480p: 9:16 (480x832)",
|
||||
"720p: 1:1 (960x960)",
|
||||
"720p: 16:9 (1280x720)",
|
||||
"720p: 9:16 (720x1280)",
|
||||
"720p: 4:3 (1088x832)",
|
||||
"720p: 3:4 (832x1088)",
|
||||
"1080p: 1:1 (1440x1440)",
|
||||
"1080p: 16:9 (1920x1080)",
|
||||
"1080p: 9:16 (1080x1920)",
|
||||
"1080p: 4:3 (1632x1248)",
|
||||
"1080p: 3:4 (1248x1632)",
|
||||
],
|
||||
default="480p: 1:1 (624x624)",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
"duration",
|
||||
default=5,
|
||||
min=5,
|
||||
max=10,
|
||||
step=5,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
tooltip="Available durations: 5 and 10 seconds",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Audio.Input(
|
||||
"audio",
|
||||
optional=True,
|
||||
tooltip="Audio must contain a clear, loud voice, without extraneous noise, background music.",
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
step=1,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed to use for generation.",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Boolean.Input(
|
||||
"generate_audio",
|
||||
default=False,
|
||||
optional=True,
|
||||
tooltip="If there is no audio input, generate audio automatically.",
|
||||
),
|
||||
comfy_io.Boolean.Input(
|
||||
"prompt_extend",
|
||||
default=True,
|
||||
tooltip="Whether to enhance the prompt with AI assistance.",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Boolean.Input(
|
||||
"watermark",
|
||||
default=True,
|
||||
tooltip="Whether to add an \"AI generated\" watermark to the result.",
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
comfy_io.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
model: str,
|
||||
prompt: str,
|
||||
negative_prompt: str = "",
|
||||
size: str = "480p: 1:1 (624x624)",
|
||||
duration: int = 5,
|
||||
audio: Optional[Input.Audio] = None,
|
||||
seed: int = 0,
|
||||
generate_audio: bool = False,
|
||||
prompt_extend: bool = True,
|
||||
watermark: bool = True,
|
||||
):
|
||||
width, height = RES_IN_PARENS.search(size).groups()
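# RES_IN_PARENS is assumed to capture the two numbers inside the parentheses of the size label, e.g. "(624x624)" -> ("624", "624").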
|
||||
audio_url = None
|
||||
if audio is not None:
|
||||
validate_audio_duration(audio, 3.0, 29.0)
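# Driving audio is restricted to 3.0-29.0 seconds, then inlined below as a base64 MP3 data URL.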
|
||||
audio_url = "data:audio/mp3;base64," + audio_to_base64_string(audio, "mp3", "libmp3lame")
|
||||
payload = Text2VideoTaskCreationRequest(
|
||||
model=model,
|
||||
input=Text2VideoInputField(prompt=prompt, negative_prompt=negative_prompt, audio_url=audio_url),
|
||||
parameters=Text2VideoParametersField(
|
||||
size=f"{width}*{height}",
|
||||
duration=duration,
|
||||
seed=seed,
|
||||
audio=generate_audio,
|
||||
prompt_extend=prompt_extend,
|
||||
watermark=watermark,
|
||||
),
|
||||
)
|
||||
response = await process_task(
|
||||
{
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
},
|
||||
"/proxy/wan/api/v1/services/aigc/video-generation/video-synthesis",
|
||||
request_model=Text2VideoTaskCreationRequest,
|
||||
response_model=VideoTaskStatusResponse,
|
||||
payload=payload,
|
||||
node_id=cls.hidden.unique_id,
|
||||
estimated_duration=120 * int(duration / 5),
|
||||
poll_interval=6,
|
||||
)
|
||||
return comfy_io.NodeOutput(await download_url_to_video_output(response.output.video_url))
|
||||
|
||||
|
||||
class WanImageToVideoApi(comfy_io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return comfy_io.Schema(
|
||||
node_id="WanImageToVideoApi",
|
||||
display_name="Wan Image to Video",
|
||||
category="api node/video/Wan",
|
||||
description="Generates video based on the first frame and text prompt.",
|
||||
inputs=[
|
||||
comfy_io.Combo.Input(
|
||||
"model",
|
||||
options=["wan2.5-i2v-preview"],
|
||||
default="wan2.5-i2v-preview",
|
||||
tooltip="Model to use.",
|
||||
),
|
||||
comfy_io.Image.Input(
|
||||
"image",
|
||||
),
|
||||
comfy_io.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Prompt used to describe the elements and visual features, supports English/Chinese.",
|
||||
),
|
||||
comfy_io.String.Input(
|
||||
"negative_prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Negative text prompt to guide what to avoid.",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
"resolution",
|
||||
options=[
|
||||
"480P",
|
||||
"720P",
|
||||
"1080P",
|
||||
],
|
||||
default="480P",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
"duration",
|
||||
default=5,
|
||||
min=5,
|
||||
max=10,
|
||||
step=5,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
tooltip="Available durations: 5 and 10 seconds",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Audio.Input(
|
||||
"audio",
|
||||
optional=True,
|
||||
tooltip="Audio must contain a clear, loud voice, without extraneous noise, background music.",
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
step=1,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed to use for generation.",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Boolean.Input(
|
||||
"generate_audio",
|
||||
default=False,
|
||||
optional=True,
|
||||
tooltip="If there is no audio input, generate audio automatically.",
|
||||
),
|
||||
comfy_io.Boolean.Input(
|
||||
"prompt_extend",
|
||||
default=True,
|
||||
tooltip="Whether to enhance the prompt with AI assistance.",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Boolean.Input(
|
||||
"watermark",
|
||||
default=True,
|
||||
tooltip="Whether to add an \"AI generated\" watermark to the result.",
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
comfy_io.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
model: str,
|
||||
image: torch.Tensor,
|
||||
prompt: str,
|
||||
negative_prompt: str = "",
|
||||
resolution: str = "480P",
|
||||
duration: int = 5,
|
||||
audio: Optional[Input.Audio] = None,
|
||||
seed: int = 0,
|
||||
generate_audio: bool = False,
|
||||
prompt_extend: bool = True,
|
||||
watermark: bool = True,
|
||||
):
|
||||
if get_number_of_images(image) != 1:
|
||||
raise ValueError("Exactly one input image is required.")
|
||||
image_url = "data:image/png;base64," + tensor_to_base64_string(image, total_pixels=2000*2000)
|
||||
audio_url = None
|
||||
if audio is not None:
|
||||
validate_audio_duration(audio, 3.0, 29.0)
|
||||
audio_url = "data:audio/mp3;base64," + audio_to_base64_string(audio, "mp3", "libmp3lame")
|
||||
payload = Image2VideoTaskCreationRequest(
|
||||
model=model,
|
||||
input=Image2VideoInputField(
|
||||
prompt=prompt, negative_prompt=negative_prompt, img_url=image_url, audio_url=audio_url
|
||||
),
|
||||
parameters=Image2VideoParametersField(
|
||||
resolution=resolution,
|
||||
duration=duration,
|
||||
seed=seed,
|
||||
audio=generate_audio,
|
||||
prompt_extend=prompt_extend,
|
||||
watermark=watermark,
|
||||
),
|
||||
)
|
||||
response = await process_task(
|
||||
{
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
},
|
||||
"/proxy/wan/api/v1/services/aigc/video-generation/video-synthesis",
|
||||
request_model=Image2VideoTaskCreationRequest,
|
||||
response_model=VideoTaskStatusResponse,
|
||||
payload=payload,
|
||||
node_id=cls.hidden.unique_id,
|
||||
estimated_duration=120 * int(duration / 5),
|
||||
poll_interval=6,
|
||||
)
|
||||
return comfy_io.NodeOutput(await download_url_to_video_output(response.output.video_url))
|
||||
|
||||
|
||||
class WanApiExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
|
||||
return [
|
||||
WanTextToImageApi,
|
||||
WanTextToVideoApi,
|
||||
WanImageToVideoApi,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> WanApiExtension:
|
||||
return WanApiExtension()
|
||||
@ -11,6 +11,7 @@ import json
|
||||
import random
|
||||
import hashlib
|
||||
import node_helpers
|
||||
import logging
|
||||
from comfy.cli_args import args
|
||||
from comfy.comfy_types import FileLocator
|
||||
|
||||
@ -364,6 +365,216 @@ class RecordAudio:
|
||||
return (audio, )
|
||||
|
||||
|
||||
class TrimAudioDuration:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"audio": ("AUDIO",),
|
||||
"start_index": ("FLOAT", {"default": 0.0, "min": -0xffffffffffffffff, "max": 0xffffffffffffffff, "step": 0.01, "tooltip": "Start time in seconds, can be negative to count from the end (supports sub-seconds)."}),
|
||||
"duration": ("FLOAT", {"default": 60.0, "min": 0.0, "step": 0.01, "tooltip": "Duration in seconds"}),
|
||||
},
|
||||
}
|
||||
|
||||
FUNCTION = "trim"
|
||||
RETURN_TYPES = ("AUDIO",)
|
||||
CATEGORY = "audio"
|
||||
DESCRIPTION = "Trim audio tensor into chosen time range."
|
||||
|
||||
def trim(self, audio, start_index, duration):
|
||||
waveform = audio["waveform"]
|
||||
sample_rate = audio["sample_rate"]
|
||||
audio_length = waveform.shape[-1]
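# A negative start_index is interpreted as an offset from the end of the clip.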
|
||||
|
||||
if start_index < 0:
|
||||
start_frame = audio_length + int(round(start_index * sample_rate))
|
||||
else:
|
||||
start_frame = int(round(start_index * sample_rate))
|
||||
start_frame = max(0, min(start_frame, audio_length - 1))
|
||||
|
||||
end_frame = start_frame + int(round(duration * sample_rate))
|
||||
end_frame = max(0, min(end_frame, audio_length))
|
||||
|
||||
if start_frame >= end_frame:
|
||||
raise ValueError("AudioTrim: Start time must be less than end time and be within the audio length.")
|
||||
|
||||
return ({"waveform": waveform[..., start_frame:end_frame], "sample_rate": sample_rate},)
|
||||
|
||||
|
||||
class SplitAudioChannels:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {
|
||||
"audio": ("AUDIO",),
|
||||
}}
|
||||
|
||||
RETURN_TYPES = ("AUDIO", "AUDIO")
|
||||
RETURN_NAMES = ("left", "right")
|
||||
FUNCTION = "separate"
|
||||
CATEGORY = "audio"
|
||||
DESCRIPTION = "Separates the audio into left and right channels."
|
||||
|
||||
def separate(self, audio):
|
||||
waveform = audio["waveform"]
|
||||
sample_rate = audio["sample_rate"]
|
||||
|
||||
if waveform.shape[1] != 2:
|
||||
raise ValueError("AudioSplit: Input audio has only one channel.")
|
||||
|
||||
left_channel = waveform[..., 0:1, :]
|
||||
right_channel = waveform[..., 1:2, :]
|
||||
|
||||
return ({"waveform": left_channel, "sample_rate": sample_rate}, {"waveform": right_channel, "sample_rate": sample_rate})
|
||||
|
||||
|
||||
def match_audio_sample_rates(waveform_1, sample_rate_1, waveform_2, sample_rate_2):
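# Resample the lower-rate waveform up to the higher sample rate so the two signals can be combined.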
|
||||
if sample_rate_1 != sample_rate_2:
|
||||
if sample_rate_1 > sample_rate_2:
|
||||
waveform_2 = torchaudio.functional.resample(waveform_2, sample_rate_2, sample_rate_1)
|
||||
output_sample_rate = sample_rate_1
|
||||
logging.info(f"Resampling audio2 from {sample_rate_2}Hz to {sample_rate_1}Hz for merging.")
|
||||
else:
|
||||
waveform_1 = torchaudio.functional.resample(waveform_1, sample_rate_1, sample_rate_2)
|
||||
output_sample_rate = sample_rate_2
|
||||
logging.info(f"Resampling audio1 from {sample_rate_1}Hz to {sample_rate_2}Hz for merging.")
|
||||
else:
|
||||
output_sample_rate = sample_rate_1
|
||||
return waveform_1, waveform_2, output_sample_rate
|
||||
|
||||
|
||||
class AudioConcat:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {
|
||||
"audio1": ("AUDIO",),
|
||||
"audio2": ("AUDIO",),
|
||||
"direction": (['after', 'before'], {"default": 'after', "tooltip": "Whether to append audio2 after or before audio1."}),
|
||||
}}
|
||||
|
||||
RETURN_TYPES = ("AUDIO",)
|
||||
FUNCTION = "concat"
|
||||
CATEGORY = "audio"
|
||||
DESCRIPTION = "Concatenates the audio1 to audio2 in the specified direction."
|
||||
|
||||
def concat(self, audio1, audio2, direction):
|
||||
waveform_1 = audio1["waveform"]
|
||||
waveform_2 = audio2["waveform"]
|
||||
sample_rate_1 = audio1["sample_rate"]
|
||||
sample_rate_2 = audio2["sample_rate"]
|
||||
|
||||
if waveform_1.shape[1] == 1:
|
||||
waveform_1 = waveform_1.repeat(1, 2, 1)
|
||||
logging.info("AudioConcat: Converted mono audio1 to stereo by duplicating the channel.")
|
||||
if waveform_2.shape[1] == 1:
|
||||
waveform_2 = waveform_2.repeat(1, 2, 1)
|
||||
logging.info("AudioConcat: Converted mono audio2 to stereo by duplicating the channel.")
|
||||
|
||||
waveform_1, waveform_2, output_sample_rate = match_audio_sample_rates(waveform_1, sample_rate_1, waveform_2, sample_rate_2)
|
||||
|
||||
if direction == 'after':
|
||||
concatenated_audio = torch.cat((waveform_1, waveform_2), dim=2)
|
||||
elif direction == 'before':
|
||||
concatenated_audio = torch.cat((waveform_2, waveform_1), dim=2)
|
||||
|
||||
return ({"waveform": concatenated_audio, "sample_rate": output_sample_rate},)
|
||||
|
||||
|
||||
class AudioMerge:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"audio1": ("AUDIO",),
|
||||
"audio2": ("AUDIO",),
|
||||
"merge_method": (["add", "mean", "subtract", "multiply"], {"tooltip": "The method used to combine the audio waveforms."}),
|
||||
},
|
||||
}
|
||||
|
||||
FUNCTION = "merge"
|
||||
RETURN_TYPES = ("AUDIO",)
|
||||
CATEGORY = "audio"
|
||||
DESCRIPTION = "Combine two audio tracks by overlaying their waveforms."
|
||||
|
||||
def merge(self, audio1, audio2, merge_method):
|
||||
waveform_1 = audio1["waveform"]
|
||||
waveform_2 = audio2["waveform"]
|
||||
sample_rate_1 = audio1["sample_rate"]
|
||||
sample_rate_2 = audio2["sample_rate"]
|
||||
|
||||
waveform_1, waveform_2, output_sample_rate = match_audio_sample_rates(waveform_1, sample_rate_1, waveform_2, sample_rate_2)
|
||||
|
||||
length_1 = waveform_1.shape[-1]
|
||||
length_2 = waveform_2.shape[-1]
|
||||
|
||||
if length_2 > length_1:
|
||||
logging.info(f"AudioMerge: Trimming audio2 from {length_2} to {length_1} samples to match audio1 length.")
|
||||
waveform_2 = waveform_2[..., :length_1]
|
||||
elif length_2 < length_1:
|
||||
logging.info(f"AudioMerge: Padding audio2 from {length_2} to {length_1} samples to match audio1 length.")
|
||||
pad_shape = list(waveform_2.shape)
|
||||
pad_shape[-1] = length_1 - length_2
|
||||
pad_tensor = torch.zeros(pad_shape, dtype=waveform_2.dtype, device=waveform_2.device)
|
||||
waveform_2 = torch.cat((waveform_2, pad_tensor), dim=-1)
|
||||
|
||||
if merge_method == "add":
|
||||
waveform = waveform_1 + waveform_2
|
||||
elif merge_method == "subtract":
|
||||
waveform = waveform_1 - waveform_2
|
||||
elif merge_method == "multiply":
|
||||
waveform = waveform_1 * waveform_2
|
||||
elif merge_method == "mean":
|
||||
waveform = (waveform_1 + waveform_2) / 2
|
||||
|
||||
max_val = waveform.abs().max()
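# Peak-normalize below so the merged result does not clip outside [-1.0, 1.0].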
|
||||
if max_val > 1.0:
|
||||
waveform = waveform / max_val
|
||||
|
||||
return ({"waveform": waveform, "sample_rate": output_sample_rate},)
|
||||
|
||||
|
||||
class AudioAdjustVolume:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {
|
||||
"audio": ("AUDIO",),
|
||||
"volume": ("INT", {"default": 1.0, "min": -100, "max": 100, "tooltip": "Volume adjustment in decibels (dB). 0 = no change, +6 = double, -6 = half, etc"}),
|
||||
}}
|
||||
|
||||
RETURN_TYPES = ("AUDIO",)
|
||||
FUNCTION = "adjust_volume"
|
||||
CATEGORY = "audio"
|
||||
|
||||
def adjust_volume(self, audio, volume):
|
||||
if volume == 0:
|
||||
return (audio,)
|
||||
waveform = audio["waveform"]
|
||||
sample_rate = audio["sample_rate"]
|
||||
|
||||
gain = 10 ** (volume / 20)
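# 10 ** (dB / 20) converts the decibel adjustment to a linear amplitude gain (+6 dB ~ 2x, -6 dB ~ 0.5x).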
|
||||
waveform = waveform * gain
|
||||
|
||||
return ({"waveform": waveform, "sample_rate": sample_rate},)
|
||||
|
||||
|
||||
class EmptyAudio:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {
|
||||
"duration": ("FLOAT", {"default": 60.0, "min": 0.0, "max": 0xffffffffffffffff, "step": 0.01, "tooltip": "Duration of the empty audio clip in seconds"}),
|
||||
"sample_rate": ("INT", {"default": 44100, "tooltip": "Sample rate of the empty audio clip."}),
|
||||
"channels": ("INT", {"default": 2, "min": 1, "max": 2, "tooltip": "Number of audio channels (1 for mono, 2 for stereo)."}),
|
||||
}}
|
||||
|
||||
RETURN_TYPES = ("AUDIO",)
|
||||
FUNCTION = "create_empty_audio"
|
||||
CATEGORY = "audio"
|
||||
|
||||
def create_empty_audio(self, duration, sample_rate, channels):
|
||||
num_samples = int(round(duration * sample_rate))
|
||||
waveform = torch.zeros((1, channels, num_samples), dtype=torch.float32)
|
||||
return ({"waveform": waveform, "sample_rate": sample_rate},)
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"EmptyLatentAudio": EmptyLatentAudio,
|
||||
"VAEEncodeAudio": VAEEncodeAudio,
|
||||
@ -375,6 +586,12 @@ NODE_CLASS_MAPPINGS = {
|
||||
"PreviewAudio": PreviewAudio,
|
||||
"ConditioningStableAudio": ConditioningStableAudio,
|
||||
"RecordAudio": RecordAudio,
|
||||
"TrimAudioDuration": TrimAudioDuration,
|
||||
"SplitAudioChannels": SplitAudioChannels,
|
||||
"AudioConcat": AudioConcat,
|
||||
"AudioMerge": AudioMerge,
|
||||
"AudioAdjustVolume": AudioAdjustVolume,
|
||||
"EmptyAudio": EmptyAudio,
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
@ -387,4 +604,10 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"SaveAudioMP3": "Save Audio (MP3)",
|
||||
"SaveAudioOpus": "Save Audio (Opus)",
|
||||
"RecordAudio": "Record Audio",
|
||||
"TrimAudioDuration": "Trim Audio Duration",
|
||||
"SplitAudioChannels": "Split Audio Channels",
|
||||
"AudioConcat": "Audio Concat",
|
||||
"AudioMerge": "Audio Merge",
|
||||
"AudioAdjustVolume": "Audio Adjust Volume",
|
||||
"EmptyAudio": "Empty Audio",
|
||||
}
|
||||
|
||||
@ -5,19 +5,30 @@ import torch
|
||||
class DifferentialDiffusion():
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"model": ("MODEL", ),
|
||||
}}
|
||||
return {
|
||||
"required": {
|
||||
"model": ("MODEL", ),
|
||||
},
|
||||
"optional": {
|
||||
"strength": ("FLOAT", {
|
||||
"default": 1.0,
|
||||
"min": 0.0,
|
||||
"max": 1.0,
|
||||
"step": 0.01,
|
||||
}),
|
||||
}
|
||||
}
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
FUNCTION = "apply"
|
||||
CATEGORY = "_for_testing"
|
||||
INIT = False
|
||||
|
||||
def apply(self, model):
|
||||
def apply(self, model, strength=1.0):
|
||||
model = model.clone()
|
||||
model.set_model_denoise_mask_function(self.forward)
|
||||
return (model,)
|
||||
model.set_model_denoise_mask_function(lambda *args, **kwargs: self.forward(*args, **kwargs, strength=strength))
|
||||
return (model, )
|
||||
|
||||
def forward(self, sigma: torch.Tensor, denoise_mask: torch.Tensor, extra_options: dict):
|
||||
def forward(self, sigma: torch.Tensor, denoise_mask: torch.Tensor, extra_options: dict, strength: float):
|
||||
model = extra_options["model"]
|
||||
step_sigmas = extra_options["sigmas"]
|
||||
sigma_to = model.inner_model.model_sampling.sigma_min
|
||||
@ -31,7 +42,15 @@ class DifferentialDiffusion():
|
||||
|
||||
threshold = (current_ts - ts_to) / (ts_from - ts_to)
|
||||
|
||||
return (denoise_mask >= threshold).to(denoise_mask.dtype)
|
||||
# Generate the binary mask based on the threshold
|
||||
binary_mask = (denoise_mask >= threshold).to(denoise_mask.dtype)
|
||||
|
||||
# Blend binary mask with the original denoise_mask using strength
|
||||
if strength < 1.0:
|
||||
blended_mask = strength * binary_mask + (1 - strength) * denoise_mask
|
||||
return blended_mask
|
||||
else:
|
||||
return binary_mask
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
|
||||
@ -233,6 +233,7 @@ class Sharpen:
|
||||
|
||||
kernel_size = sharpen_radius * 2 + 1
|
||||
kernel = gaussian_kernel(kernel_size, sigma, device=image.device) * -(alpha*10)
|
||||
kernel = kernel.to(dtype=image.dtype)
|
||||
center = kernel_size // 2
|
||||
kernel[center, center] = kernel[center, center] - kernel.sum() + 1.0
|
||||
kernel = kernel.repeat(channels, 1, 1).unsqueeze(1)
|
||||
|
||||
@ -43,6 +43,61 @@ class TextEncodeQwenImageEdit:
|
||||
return (conditioning, )
|
||||
|
||||
|
||||
class TextEncodeQwenImageEditPlus:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {
|
||||
"clip": ("CLIP", ),
|
||||
"prompt": ("STRING", {"multiline": True, "dynamicPrompts": True}),
|
||||
},
|
||||
"optional": {"vae": ("VAE", ),
|
||||
"image1": ("IMAGE", ),
|
||||
"image2": ("IMAGE", ),
|
||||
"image3": ("IMAGE", ),
|
||||
}}
|
||||
|
||||
RETURN_TYPES = ("CONDITIONING",)
|
||||
FUNCTION = "encode"
|
||||
|
||||
CATEGORY = "advanced/conditioning"
|
||||
|
||||
def encode(self, clip, prompt, vae=None, image1=None, image2=None, image3=None):
|
||||
ref_latents = []
|
||||
images = [image1, image2, image3]
|
||||
images_vl = []
|
||||
llama_template = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
|
||||
image_prompt = ""
|
||||
|
||||
for i, image in enumerate(images):
|
||||
if image is not None:
|
||||
samples = image.movedim(-1, 1)
|
||||
total = int(384 * 384)
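# Rescale to roughly 384x384 pixels (area preserved) for the vision-language tokenizer input.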
|
||||
|
||||
scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2]))
|
||||
width = round(samples.shape[3] * scale_by)
|
||||
height = round(samples.shape[2] * scale_by)
|
||||
|
||||
s = comfy.utils.common_upscale(samples, width, height, "area", "disabled")
|
||||
images_vl.append(s.movedim(1, -1))
|
||||
if vae is not None:
|
||||
total = int(1024 * 1024)
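# Rescale separately to roughly one megapixel, rounded to multiples of 8, for the VAE reference latent.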
|
||||
scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2]))
|
||||
width = round(samples.shape[3] * scale_by / 8.0) * 8
|
||||
height = round(samples.shape[2] * scale_by / 8.0) * 8
|
||||
|
||||
s = comfy.utils.common_upscale(samples, width, height, "area", "disabled")
|
||||
ref_latents.append(vae.encode(s.movedim(1, -1)[:, :, :, :3]))
|
||||
|
||||
image_prompt += "Picture {}: <|vision_start|><|image_pad|><|vision_end|>".format(i + 1)
|
||||
|
||||
tokens = clip.tokenize(image_prompt + prompt, images=images_vl, llama_template=llama_template)
|
||||
conditioning = clip.encode_from_tokens_scheduled(tokens)
|
||||
if len(ref_latents) > 0:
|
||||
conditioning = node_helpers.conditioning_set_values(conditioning, {"reference_latents": ref_latents}, append=True)
|
||||
return (conditioning, )
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"TextEncodeQwenImageEdit": TextEncodeQwenImageEdit,
|
||||
"TextEncodeQwenImageEditPlus": TextEncodeQwenImageEditPlus,
|
||||
}
|
||||
|
||||
@ -38,6 +38,23 @@ def make_batch_extra_option_dict(d, indicies, full_size=None):
|
||||
return new_dict
|
||||
|
||||
|
||||
def process_cond_list(d, prefix=""):
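# Recursively walk the cond structures and clone any tensors so training does not modify cached conditioning in place.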
|
||||
if hasattr(d, "__iter__") and not hasattr(d, "items"):
|
||||
for index, item in enumerate(d):
|
||||
process_cond_list(item, f"{prefix}.{index}")
|
||||
return d
|
||||
elif hasattr(d, "items"):
|
||||
for k, v in list(d.items()):
|
||||
if isinstance(v, dict):
|
||||
process_cond_list(v, f"{prefix}.{k}")
|
||||
elif isinstance(v, torch.Tensor):
|
||||
d[k] = v.clone()
|
||||
elif isinstance(v, (list, tuple)):
|
||||
for index, item in enumerate(v):
|
||||
process_cond_list(item, f"{prefix}.{k}.{index}")
|
||||
return d
|
||||
|
||||
|
||||
class TrainSampler(comfy.samplers.Sampler):
|
||||
def __init__(self, loss_fn, optimizer, loss_callback=None, batch_size=1, grad_acc=1, total_steps=1, seed=0, training_dtype=torch.bfloat16):
|
||||
self.loss_fn = loss_fn
|
||||
@ -50,6 +67,7 @@ class TrainSampler(comfy.samplers.Sampler):
|
||||
self.training_dtype = training_dtype
|
||||
|
||||
def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
|
||||
model_wrap.conds = process_cond_list(model_wrap.conds)
|
||||
cond = model_wrap.conds["positive"]
|
||||
dataset_size = sigmas.size(0)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@ -287,7 +287,6 @@ class WanVaceToVideo(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="WanVaceToVideo",
|
||||
category="conditioning/video_models",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
io.Conditioning.Input("positive"),
|
||||
io.Conditioning.Input("negative"),
|
||||
@ -375,7 +374,6 @@ class TrimVideoLatent(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="TrimVideoLatent",
|
||||
category="latent/video",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
io.Latent.Input("samples"),
|
||||
io.Int.Input("trim_amount", default=0, min=0, max=99999),
|
||||
@ -969,7 +967,6 @@ class WanSoundImageToVideo(io.ComfyNode):
|
||||
io.Conditioning.Output(display_name="negative"),
|
||||
io.Latent.Output(display_name="latent"),
|
||||
],
|
||||
is_experimental=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -1000,7 +997,6 @@ class WanSoundImageToVideoExtend(io.ComfyNode):
|
||||
io.Conditioning.Output(display_name="negative"),
|
||||
io.Latent.Output(display_name="latent"),
|
||||
],
|
||||
is_experimental=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -1095,10 +1091,6 @@ class WanHuMoImageToVideo(io.ComfyNode):
|
||||
audio_emb = torch.stack([feat0, feat1, feat2, feat3, feat4], dim=2)[0] # [T, 5, 1280]
|
||||
audio_emb, _ = get_audio_emb_window(audio_emb, length, frame0_idx=0)
|
||||
|
||||
# pad for ref latent
|
||||
zero_audio_pad = torch.zeros(ref_latent.shape[2], *audio_emb.shape[1:], device=audio_emb.device, dtype=audio_emb.dtype)
|
||||
audio_emb = torch.cat([audio_emb, zero_audio_pad], dim=0)
|
||||
|
||||
audio_emb = audio_emb.unsqueeze(0)
|
||||
audio_emb_neg = torch.zeros_like(audio_emb)
|
||||
positive = node_helpers.conditioning_set_values(positive, {"audio_embed": audio_emb})
|
||||
@ -1112,6 +1104,146 @@ class WanHuMoImageToVideo(io.ComfyNode):
|
||||
out_latent["samples"] = latent
|
||||
return io.NodeOutput(positive, negative, out_latent)
|
||||
|
||||
class WanAnimateToVideo(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="WanAnimateToVideo",
|
||||
category="conditioning/video_models",
|
||||
inputs=[
|
||||
io.Conditioning.Input("positive"),
|
||||
io.Conditioning.Input("negative"),
|
||||
io.Vae.Input("vae"),
|
||||
io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||
io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||
io.Int.Input("length", default=77, min=1, max=nodes.MAX_RESOLUTION, step=4),
|
||||
io.Int.Input("batch_size", default=1, min=1, max=4096),
|
||||
io.ClipVisionOutput.Input("clip_vision_output", optional=True),
|
||||
io.Image.Input("reference_image", optional=True),
|
||||
io.Image.Input("face_video", optional=True),
|
||||
io.Image.Input("pose_video", optional=True),
|
||||
io.Int.Input("continue_motion_max_frames", default=5, min=1, max=nodes.MAX_RESOLUTION, step=4),
|
||||
io.Image.Input("background_video", optional=True),
|
||||
io.Mask.Input("character_mask", optional=True),
|
||||
io.Image.Input("continue_motion", optional=True),
|
||||
io.Int.Input("video_frame_offset", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1, tooltip="The amount of frames to seek in all the input videos. Used for generating longer videos by chunk. Connect to the video_frame_offset output of the previous node for extending a video."),
|
||||
],
|
||||
outputs=[
|
||||
io.Conditioning.Output(display_name="positive"),
|
||||
io.Conditioning.Output(display_name="negative"),
|
||||
io.Latent.Output(display_name="latent"),
|
||||
io.Int.Output(display_name="trim_latent"),
|
||||
io.Int.Output(display_name="trim_image"),
|
||||
io.Int.Output(display_name="video_frame_offset"),
|
||||
],
|
||||
is_experimental=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, positive, negative, vae, width, height, length, batch_size, continue_motion_max_frames, video_frame_offset, reference_image=None, clip_vision_output=None, face_video=None, pose_video=None, continue_motion=None, background_video=None, character_mask=None) -> io.NodeOutput:
|
||||
trim_to_pose_video = False
|
||||
latent_length = ((length - 1) // 4) + 1
|
||||
latent_width = width // 8
|
||||
latent_height = height // 8
|
||||
trim_latent = 0
|
||||
|
||||
if reference_image is None:
|
||||
reference_image = torch.zeros((1, height, width, 3))
|
||||
|
||||
image = comfy.utils.common_upscale(reference_image[:length].movedim(-1, 1), width, height, "area", "center").movedim(1, -1)
|
||||
concat_latent_image = vae.encode(image[:, :, :, :3])
|
||||
mask = torch.zeros((1, 4, concat_latent_image.shape[-3], concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=concat_latent_image.device, dtype=concat_latent_image.dtype)
|
||||
trim_latent += concat_latent_image.shape[2]
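# The encoded reference image is prepended as extra latent frames; trim_latent records how many to trim from the output.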
|
||||
ref_motion_latent_length = 0
|
||||
|
||||
if continue_motion is None:
|
||||
image = torch.ones((length, height, width, 3)) * 0.5
|
||||
else:
|
||||
continue_motion = continue_motion[-continue_motion_max_frames:]
|
||||
video_frame_offset -= continue_motion.shape[0]
|
||||
video_frame_offset = max(0, video_frame_offset)
|
||||
continue_motion = comfy.utils.common_upscale(continue_motion[-length:].movedim(-1, 1), width, height, "area", "center").movedim(1, -1)
|
||||
image = torch.ones((length, height, width, continue_motion.shape[-1]), device=continue_motion.device, dtype=continue_motion.dtype) * 0.5
|
||||
image[:continue_motion.shape[0]] = continue_motion
|
||||
ref_motion_latent_length += ((continue_motion.shape[0] - 1) // 4) + 1
|
||||
|
||||
if clip_vision_output is not None:
|
||||
positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
|
||||
negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})
|
||||
|
||||
if pose_video is not None:
|
||||
if pose_video.shape[0] <= video_frame_offset:
|
||||
pose_video = None
|
||||
else:
|
||||
pose_video = pose_video[video_frame_offset:]
|
||||
|
||||
if pose_video is not None:
|
||||
pose_video = comfy.utils.common_upscale(pose_video[:length].movedim(-1, 1), width, height, "area", "center").movedim(1, -1)
|
||||
if not trim_to_pose_video:
|
||||
if pose_video.shape[0] < length:
|
||||
pose_video = torch.cat((pose_video,) + (pose_video[-1:],) * (length - pose_video.shape[0]), dim=0)
|
||||
|
||||
pose_video_latent = vae.encode(pose_video[:, :, :, :3])
|
||||
positive = node_helpers.conditioning_set_values(positive, {"pose_video_latent": pose_video_latent})
|
||||
negative = node_helpers.conditioning_set_values(negative, {"pose_video_latent": pose_video_latent})
|
||||
|
||||
if trim_to_pose_video:
|
||||
latent_length = pose_video_latent.shape[2]
|
||||
length = latent_length * 4 - 3
|
||||
image = image[:length]
|
||||
|
||||
if face_video is not None:
|
||||
if face_video.shape[0] <= video_frame_offset:
|
||||
face_video = None
|
||||
else:
|
||||
face_video = face_video[video_frame_offset:]
|
||||
|
||||
if face_video is not None:
|
||||
face_video = comfy.utils.common_upscale(face_video[:length].movedim(-1, 1), 512, 512, "area", "center") * 2.0 - 1.0
|
||||
face_video = face_video.movedim(0, 1).unsqueeze(0)
|
||||
positive = node_helpers.conditioning_set_values(positive, {"face_video_pixels": face_video})
|
||||
negative = node_helpers.conditioning_set_values(negative, {"face_video_pixels": face_video * 0.0 - 1.0})
|
||||
|
||||
ref_images_num = max(0, ref_motion_latent_length * 4 - 3)
|
||||
if background_video is not None:
|
||||
if background_video.shape[0] > video_frame_offset:
|
||||
background_video = background_video[video_frame_offset:]
|
||||
background_video = comfy.utils.common_upscale(background_video[:length].movedim(-1, 1), width, height, "area", "center").movedim(1, -1)
|
||||
if background_video.shape[0] > ref_images_num:
|
||||
image[ref_images_num:background_video.shape[0]] = background_video[ref_images_num:]
|
||||
|
||||
mask_refmotion = torch.ones((1, 1, latent_length * 4, concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=mask.device, dtype=mask.dtype)
|
||||
if continue_motion is not None:
|
||||
mask_refmotion[:, :, :ref_motion_latent_length * 4] = 0.0
|
||||
|
||||
if character_mask is not None:
|
||||
if character_mask.shape[0] > video_frame_offset or character_mask.shape[0] == 1:
|
||||
if character_mask.shape[0] == 1:
|
||||
character_mask = character_mask.repeat((length,) + (1,) * (character_mask.ndim - 1))
|
||||
else:
|
||||
character_mask = character_mask[video_frame_offset:]
|
||||
if character_mask.ndim == 3:
|
||||
character_mask = character_mask.unsqueeze(1)
|
||||
character_mask = character_mask.movedim(0, 1)
|
||||
if character_mask.ndim == 4:
|
||||
character_mask = character_mask.unsqueeze(1)
|
||||
character_mask = comfy.utils.common_upscale(character_mask[:, :, :length], concat_latent_image.shape[-1], concat_latent_image.shape[-2], "nearest-exact", "center")
|
||||
if character_mask.shape[2] > ref_images_num:
|
||||
mask_refmotion[:, :, ref_images_num:character_mask.shape[2]] = character_mask[:, :, ref_images_num:]
|
||||
|
||||
concat_latent_image = torch.cat((concat_latent_image, vae.encode(image[:, :, :, :3])), dim=2)
|
||||
|
||||
|
||||
mask_refmotion = mask_refmotion.view(1, mask_refmotion.shape[2] // 4, 4, mask_refmotion.shape[3], mask_refmotion.shape[4]).transpose(1, 2)
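# Regroup the per-frame mask into blocks of 4 so it lines up with the 4x temporal compression of the video latent.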
|
||||
mask = torch.cat((mask, mask_refmotion), dim=2)
|
||||
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
|
||||
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
|
||||
|
||||
latent = torch.zeros([batch_size, 16, latent_length + trim_latent, latent_height, latent_width], device=comfy.model_management.intermediate_device())
|
||||
out_latent = {}
|
||||
out_latent["samples"] = latent
|
||||
return io.NodeOutput(positive, negative, out_latent, trim_latent, max(0, ref_motion_latent_length * 4 - 3), video_frame_offset + length)
|
||||
|
||||
class Wan22ImageToVideoLatent(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
@ -1173,6 +1305,7 @@ class WanExtension(ComfyExtension):
|
||||
WanSoundImageToVideo,
|
||||
WanSoundImageToVideoExtend,
|
||||
WanHuMoImageToVideo,
|
||||
WanAnimateToVideo,
|
||||
Wan22ImageToVideoLatent,
|
||||
]
|
||||
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
# This file is automatically generated by the build process when version is
|
||||
# updated in pyproject.toml.
|
||||
__version__ = "0.3.59"
|
||||
__version__ = "0.3.60"
|
||||
|
||||
1
nodes.py
@ -2361,6 +2361,7 @@ async def init_builtin_api_nodes():
|
||||
"nodes_rodin.py",
|
||||
"nodes_gemini.py",
|
||||
"nodes_vidu.py",
|
||||
"nodes_wan.py",
|
||||
]
|
||||
|
||||
if not await load_custom_node(os.path.join(api_nodes_dir, "canary.py"), module_parent="comfy_api_nodes"):
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "ComfyUI"
|
||||
version = "0.3.59"
|
||||
version = "0.3.60"
|
||||
readme = "README.md"
|
||||
license = { file = "LICENSE" }
|
||||
requires-python = ">=3.9"
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
comfyui-frontend-package==1.26.11
|
||||
comfyui-workflow-templates==0.1.81
|
||||
comfyui-frontend-package==1.26.13
|
||||
comfyui-workflow-templates==0.1.86
|
||||
comfyui-embedded-docs==0.2.6
|
||||
torch
|
||||
torchsde
|
||||
|
||||
@ -648,7 +648,14 @@ class PromptServer():
|
||||
max_items = request.rel_url.query.get("max_items", None)
|
||||
if max_items is not None:
|
||||
max_items = int(max_items)
|
||||
return web.json_response(self.prompt_queue.get_history(max_items=max_items))
|
||||
|
||||
offset = request.rel_url.query.get("offset", None)
|
||||
if offset is not None:
|
||||
offset = int(offset)
|
||||
else:
|
||||
offset = -1
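# -1 is the default passed to get_history when no offset query parameter is supplied.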
|
||||
|
||||
return web.json_response(self.prompt_queue.get_history(max_items=max_items, offset=offset))
|
||||
|
||||
@routes.get("/history/{prompt_id}")
|
||||
async def get_history_prompt_id(request):
|
||||
|
||||
@ -84,6 +84,21 @@ class ComfyClient:
|
||||
with urllib.request.urlopen("http://{}/history/{}".format(self.server_address, prompt_id)) as response:
|
||||
return json.loads(response.read())
|
||||
|
||||
def get_all_history(self, max_items=None, offset=None):
|
||||
url = "http://{}/history".format(self.server_address)
|
||||
params = {}
|
||||
if max_items is not None:
|
||||
params["max_items"] = max_items
|
||||
if offset is not None:
|
||||
params["offset"] = offset
|
||||
|
||||
if params:
|
||||
url_values = urllib.parse.urlencode(params)
|
||||
url = "{}?{}".format(url, url_values)
|
||||
|
||||
with urllib.request.urlopen(url) as response:
|
||||
return json.loads(response.read())
|
||||
|
||||
def set_test_name(self, name):
|
||||
self.test_name = name
|
||||
|
||||
@ -498,7 +513,6 @@ class TestExecution:
|
||||
assert len(images1) == 1, "Should have 1 image"
|
||||
assert len(images2) == 1, "Should have 1 image"
|
||||
|
||||
|
||||
# This tests that only constant outputs are used in the call to `IS_CHANGED`
|
||||
def test_is_changed_with_outputs(self, client: ComfyClient, builder: GraphBuilder):
|
||||
g = builder
|
||||
@ -762,3 +776,92 @@ class TestExecution:
|
||||
except urllib.error.HTTPError:
|
||||
pass # Expected behavior
|
||||
|
||||
def _create_history_item(self, client, builder):
|
||||
g = GraphBuilder(prefix="offset_test")
|
||||
input_node = g.node(
|
||||
"StubImage", content="BLACK", height=32, width=32, batch_size=1
|
||||
)
|
||||
g.node("SaveImage", images=input_node.out(0))
|
||||
return client.run(g)
|
||||
|
||||
def test_offset_returns_different_items_than_beginning_of_history(
|
||||
self, client: ComfyClient, builder: GraphBuilder
|
||||
):
|
||||
"""Test that offset skips items at the beginning"""
|
||||
for _ in range(5):
|
||||
self._create_history_item(client, builder)
|
||||
|
||||
first_two = client.get_all_history(max_items=2, offset=0)
|
||||
next_two = client.get_all_history(max_items=2, offset=2)
|
||||
|
||||
assert set(first_two.keys()).isdisjoint(
|
||||
set(next_two.keys())
|
||||
), "Offset should skip initial items"
|
||||
|
||||
def test_offset_beyond_history_length_returns_empty(
|
||||
self, client: ComfyClient, builder: GraphBuilder
|
||||
):
|
||||
"""Test offset larger than total history returns empty result"""
|
||||
self._create_history_item(client, builder)
|
||||
|
||||
result = client.get_all_history(offset=100)
|
||||
assert len(result) == 0, "Large offset should return no items"
|
||||
|
||||
def test_offset_at_exact_history_length_returns_empty(
|
||||
self, client: ComfyClient, builder: GraphBuilder
|
||||
):
|
||||
"""Test offset equal to history length returns empty"""
|
||||
for _ in range(3):
|
||||
self._create_history_item(client, builder)
|
||||
|
||||
all_history = client.get_all_history()
|
||||
result = client.get_all_history(offset=len(all_history))
|
||||
assert len(result) == 0, "Offset at history length should return empty"
|
||||
|
||||
def test_offset_zero_equals_no_offset_parameter(
|
||||
self, client: ComfyClient, builder: GraphBuilder
|
||||
):
|
||||
"""Test offset=0 behaves same as omitting offset"""
|
||||
self._create_history_item(client, builder)
|
||||
|
||||
with_zero = client.get_all_history(offset=0)
|
||||
without_offset = client.get_all_history()
|
||||
|
||||
assert with_zero == without_offset, "offset=0 should equal no offset"
|
||||
|
||||
def test_offset_without_max_items_skips_from_beginning(
|
||||
self, client: ComfyClient, builder: GraphBuilder
|
||||
):
|
||||
"""Test offset alone (no max_items) returns remaining items"""
|
||||
for _ in range(4):
|
||||
self._create_history_item(client, builder)
|
||||
|
||||
all_items = client.get_all_history()
|
||||
offset_items = client.get_all_history(offset=2)
|
||||
|
||||
assert (
|
||||
len(offset_items) == len(all_items) - 2
|
||||
), "Offset should skip specified number of items"
|
||||
|
||||
def test_offset_with_max_items_returns_correct_window(
|
||||
self, client: ComfyClient, builder: GraphBuilder
|
||||
):
|
||||
"""Test offset + max_items returns correct slice of history"""
|
||||
for _ in range(6):
|
||||
self._create_history_item(client, builder)
|
||||
|
||||
window = client.get_all_history(max_items=2, offset=1)
|
||||
assert len(window) <= 2, "Should respect max_items limit"
|
||||
|
||||
def test_offset_near_end_returns_remaining_items_only(
|
||||
self, client: ComfyClient, builder: GraphBuilder
|
||||
):
|
||||
"""Test offset near end of history returns only remaining items"""
|
||||
for _ in range(3):
|
||||
self._create_history_item(client, builder)
|
||||
|
||||
all_history = client.get_all_history()
|
||||
# Offset to near the end
|
||||
result = client.get_all_history(max_items=5, offset=len(all_history) - 1)
|
||||
|
||||
assert len(result) <= 1, "Should return at most 1 item when offset is near end"
|
||||
|
||||