diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index 9cf3c171d..54616e6eb 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -1355,7 +1355,7 @@ class WanT2VCrossAttentionGather(WanSelfAttention): x = optimized_attention(q, k, v, heads=self.num_heads, skip_reshape=True, skip_output_reshape=True, transformer_options=transformer_options) - x = x.transpose(1, 2).view(b, -1, n, d).flatten(2) + x = x.transpose(1, 2).reshape(b, -1, n * d) x = self.o(x) return x @@ -1551,6 +1551,9 @@ class HumoWanModel(WanModel): context_img_len = None if audio_embed is not None: + if reference_latent is not None: + zero_audio_pad = torch.zeros(audio_embed.shape[0], reference_latent.shape[-3], *audio_embed.shape[2:], device=audio_embed.device, dtype=audio_embed.dtype) + audio_embed = torch.cat([audio_embed, zero_audio_pad], dim=1) audio = self.audio_proj(audio_embed).permute(0, 3, 1, 2).flatten(2).transpose(1, 2) else: audio = None diff --git a/comfy/ldm/wan/model_animate.py b/comfy/ldm/wan/model_animate.py new file mode 100644 index 000000000..7c87835d4 --- /dev/null +++ b/comfy/ldm/wan/model_animate.py @@ -0,0 +1,548 @@ +from torch import nn +import torch +from typing import Tuple, Optional +from einops import rearrange +import torch.nn.functional as F +import math +from .model import WanModel, sinusoidal_embedding_1d +from comfy.ldm.modules.attention import optimized_attention +import comfy.model_management + +class CausalConv1d(nn.Module): + + def __init__(self, chan_in, chan_out, kernel_size=3, stride=1, dilation=1, pad_mode="replicate", operations=None, **kwargs): + super().__init__() + + self.pad_mode = pad_mode + padding = (kernel_size - 1, 0) # T + self.time_causal_padding = padding + + self.conv = operations.Conv1d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs) + + def forward(self, x): + x = F.pad(x, self.time_causal_padding, mode=self.pad_mode) + return self.conv(x) + + +class FaceEncoder(nn.Module): + def __init__(self, in_dim: int, hidden_dim: int, num_heads=int, dtype=None, device=None, operations=None): + factory_kwargs = {"dtype": dtype, "device": device} + super().__init__() + + self.num_heads = num_heads + self.conv1_local = CausalConv1d(in_dim, 1024 * num_heads, 3, stride=1, operations=operations, **factory_kwargs) + self.norm1 = operations.LayerNorm(hidden_dim // 8, elementwise_affine=False, eps=1e-6, **factory_kwargs) + self.act = nn.SiLU() + self.conv2 = CausalConv1d(1024, 1024, 3, stride=2, operations=operations, **factory_kwargs) + self.conv3 = CausalConv1d(1024, 1024, 3, stride=2, operations=operations, **factory_kwargs) + + self.out_proj = operations.Linear(1024, hidden_dim, **factory_kwargs) + self.norm1 = operations.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs) + + self.norm2 = operations.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs) + + self.norm3 = operations.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs) + + self.padding_tokens = nn.Parameter(torch.empty(1, 1, 1, hidden_dim, **factory_kwargs)) + + def forward(self, x): + + x = rearrange(x, "b t c -> b c t") + b, c, t = x.shape + + x = self.conv1_local(x) + x = rearrange(x, "b (n c) t -> (b n) t c", n=self.num_heads) + + x = self.norm1(x) + x = self.act(x) + x = rearrange(x, "b t c -> b c t") + x = self.conv2(x) + x = rearrange(x, "b c t -> b t c") + x = self.norm2(x) + x = self.act(x) + x = rearrange(x, "b t c -> b c t") + x = self.conv3(x) + x = rearrange(x, "b c t -> b t c") + x = self.norm3(x) + x = self.act(x) + x = self.out_proj(x) + x = rearrange(x, "(b n) t c -> b t n c", b=b) + padding = comfy.model_management.cast_to(self.padding_tokens, dtype=x.dtype, device=x.device).repeat(b, x.shape[1], 1, 1) + x = torch.cat([x, padding], dim=-2) + x_local = x.clone() + + return x_local + + +def get_norm_layer(norm_layer, operations=None): + """ + Get the normalization layer. + + Args: + norm_layer (str): The type of normalization layer. + + Returns: + norm_layer (nn.Module): The normalization layer. + """ + if norm_layer == "layer": + return operations.LayerNorm + elif norm_layer == "rms": + return operations.RMSNorm + else: + raise NotImplementedError(f"Norm layer {norm_layer} is not implemented") + + +class FaceAdapter(nn.Module): + def __init__( + self, + hidden_dim: int, + heads_num: int, + qk_norm: bool = True, + qk_norm_type: str = "rms", + num_adapter_layers: int = 1, + dtype=None, device=None, operations=None + ): + + factory_kwargs = {"dtype": dtype, "device": device} + super().__init__() + self.hidden_size = hidden_dim + self.heads_num = heads_num + self.fuser_blocks = nn.ModuleList( + [ + FaceBlock( + self.hidden_size, + self.heads_num, + qk_norm=qk_norm, + qk_norm_type=qk_norm_type, + operations=operations, + **factory_kwargs, + ) + for _ in range(num_adapter_layers) + ] + ) + + def forward( + self, + x: torch.Tensor, + motion_embed: torch.Tensor, + idx: int, + freqs_cis_q: Tuple[torch.Tensor, torch.Tensor] = None, + freqs_cis_k: Tuple[torch.Tensor, torch.Tensor] = None, + ) -> torch.Tensor: + + return self.fuser_blocks[idx](x, motion_embed, freqs_cis_q, freqs_cis_k) + + + +class FaceBlock(nn.Module): + def __init__( + self, + hidden_size: int, + heads_num: int, + qk_norm: bool = True, + qk_norm_type: str = "rms", + qk_scale: float = None, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + operations=None + ): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + + self.deterministic = False + self.hidden_size = hidden_size + self.heads_num = heads_num + head_dim = hidden_size // heads_num + self.scale = qk_scale or head_dim**-0.5 + + self.linear1_kv = operations.Linear(hidden_size, hidden_size * 2, **factory_kwargs) + self.linear1_q = operations.Linear(hidden_size, hidden_size, **factory_kwargs) + + self.linear2 = operations.Linear(hidden_size, hidden_size, **factory_kwargs) + + qk_norm_layer = get_norm_layer(qk_norm_type, operations=operations) + self.q_norm = ( + qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity() + ) + self.k_norm = ( + qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity() + ) + + self.pre_norm_feat = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs) + + self.pre_norm_motion = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs) + + def forward( + self, + x: torch.Tensor, + motion_vec: torch.Tensor, + motion_mask: Optional[torch.Tensor] = None, + # use_context_parallel=False, + ) -> torch.Tensor: + + B, T, N, C = motion_vec.shape + T_comp = T + + x_motion = self.pre_norm_motion(motion_vec) + x_feat = self.pre_norm_feat(x) + + kv = self.linear1_kv(x_motion) + q = self.linear1_q(x_feat) + + k, v = rearrange(kv, "B L N (K H D) -> K B L N H D", K=2, H=self.heads_num) + q = rearrange(q, "B S (H D) -> B S H D", H=self.heads_num) + + # Apply QK-Norm if needed. + q = self.q_norm(q).to(v) + k = self.k_norm(k).to(v) + + k = rearrange(k, "B L N H D -> (B L) N H D") + v = rearrange(v, "B L N H D -> (B L) N H D") + + q = rearrange(q, "B (L S) H D -> (B L) S (H D)", L=T_comp) + + attn = optimized_attention(q, k, v, heads=self.heads_num) + + attn = rearrange(attn, "(B L) S C -> B (L S) C", L=T_comp) + + output = self.linear2(attn) + + if motion_mask is not None: + output = output * rearrange(motion_mask, "B T H W -> B (T H W)").unsqueeze(-1) + + return output + +# https://github.com/XPixelGroup/BasicSR/blob/8d56e3a045f9fb3e1d8872f92ee4a4f07f886b0a/basicsr/ops/upfirdn2d/upfirdn2d.py#L162 +def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1): + _, minor, in_h, in_w = input.shape + kernel_h, kernel_w = kernel.shape + + out = input.view(-1, minor, in_h, 1, in_w, 1) + out = F.pad(out, [0, up_x - 1, 0, 0, 0, up_y - 1, 0, 0]) + out = out.view(-1, minor, in_h * up_y, in_w * up_x) + + out = F.pad(out, [max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]) + out = out[:, :, max(-pad_y0, 0): out.shape[2] - max(-pad_y1, 0), max(-pad_x0, 0): out.shape[3] - max(-pad_x1, 0)] + + out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]) + w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) + out = F.conv2d(out, w) + out = out.reshape(-1, minor, in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1) + return out[:, :, ::down_y, ::down_x] + +def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): + return upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1]) + +# https://github.com/XPixelGroup/BasicSR/blob/8d56e3a045f9fb3e1d8872f92ee4a4f07f886b0a/basicsr/ops/fused_act/fused_act.py#L81 +class FusedLeakyReLU(torch.nn.Module): + def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5, dtype=None, device=None): + super().__init__() + self.bias = torch.nn.Parameter(torch.empty(1, channel, 1, 1, dtype=dtype, device=device)) + self.negative_slope = negative_slope + self.scale = scale + + def forward(self, input): + return fused_leaky_relu(input, comfy.model_management.cast_to(self.bias, device=input.device, dtype=input.dtype), self.negative_slope, self.scale) + +def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5): + return F.leaky_relu(input + bias, negative_slope) * scale + +class Blur(torch.nn.Module): + def __init__(self, kernel, pad, dtype=None, device=None): + super().__init__() + kernel = torch.tensor(kernel, dtype=dtype, device=device) + kernel = kernel[None, :] * kernel[:, None] + kernel = kernel / kernel.sum() + self.register_buffer('kernel', kernel) + self.pad = pad + + def forward(self, input): + return upfirdn2d(input, comfy.model_management.cast_to(self.kernel, dtype=input.dtype, device=input.device), pad=self.pad) + +#https://github.com/XPixelGroup/BasicSR/blob/8d56e3a045f9fb3e1d8872f92ee4a4f07f886b0a/basicsr/archs/stylegan2_arch.py#L590 +class ScaledLeakyReLU(torch.nn.Module): + def __init__(self, negative_slope=0.2): + super().__init__() + self.negative_slope = negative_slope + + def forward(self, input): + return F.leaky_relu(input, negative_slope=self.negative_slope) + +# https://github.com/XPixelGroup/BasicSR/blob/8d56e3a045f9fb3e1d8872f92ee4a4f07f886b0a/basicsr/archs/stylegan2_arch.py#L605 +class EqualConv2d(torch.nn.Module): + def __init__(self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True, dtype=None, device=None, operations=None): + super().__init__() + self.weight = torch.nn.Parameter(torch.empty(out_channel, in_channel, kernel_size, kernel_size, device=device, dtype=dtype)) + self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2) + self.stride = stride + self.padding = padding + self.bias = torch.nn.Parameter(torch.empty(out_channel, device=device, dtype=dtype)) if bias else None + + def forward(self, input): + if self.bias is None: + bias = None + else: + bias = comfy.model_management.cast_to(self.bias, device=input.device, dtype=input.dtype) + + return F.conv2d(input, comfy.model_management.cast_to(self.weight, device=input.device, dtype=input.dtype) * self.scale, bias=bias, stride=self.stride, padding=self.padding) + +# https://github.com/XPixelGroup/BasicSR/blob/8d56e3a045f9fb3e1d8872f92ee4a4f07f886b0a/basicsr/archs/stylegan2_arch.py#L134 +class EqualLinear(torch.nn.Module): + def __init__(self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None, dtype=None, device=None, operations=None): + super().__init__() + self.weight = torch.nn.Parameter(torch.empty(out_dim, in_dim, device=device, dtype=dtype)) + self.bias = torch.nn.Parameter(torch.empty(out_dim, device=device, dtype=dtype)) if bias else None + self.activation = activation + self.scale = (1 / math.sqrt(in_dim)) * lr_mul + self.lr_mul = lr_mul + + def forward(self, input): + if self.bias is None: + bias = None + else: + bias = comfy.model_management.cast_to(self.bias, device=input.device, dtype=input.dtype) * self.lr_mul + + if self.activation: + out = F.linear(input, comfy.model_management.cast_to(self.weight, device=input.device, dtype=input.dtype) * self.scale) + return fused_leaky_relu(out, bias) + return F.linear(input, comfy.model_management.cast_to(self.weight, device=input.device, dtype=input.dtype) * self.scale, bias=bias) + +# https://github.com/XPixelGroup/BasicSR/blob/8d56e3a045f9fb3e1d8872f92ee4a4f07f886b0a/basicsr/archs/stylegan2_arch.py#L654 +class ConvLayer(torch.nn.Sequential): + def __init__(self, in_channel, out_channel, kernel_size, downsample=False, blur_kernel=[1, 3, 3, 1], bias=True, activate=True, dtype=None, device=None, operations=None): + layers = [] + + if downsample: + factor = 2 + p = (len(blur_kernel) - factor) + (kernel_size - 1) + layers.append(Blur(blur_kernel, pad=((p + 1) // 2, p // 2))) + stride, padding = 2, 0 + else: + stride, padding = 1, kernel_size // 2 + + layers.append(EqualConv2d(in_channel, out_channel, kernel_size, padding=padding, stride=stride, bias=bias and not activate, dtype=dtype, device=device, operations=operations)) + + if activate: + layers.append(FusedLeakyReLU(out_channel) if bias else ScaledLeakyReLU(0.2)) + + super().__init__(*layers) + +# https://github.com/XPixelGroup/BasicSR/blob/8d56e3a045f9fb3e1d8872f92ee4a4f07f886b0a/basicsr/archs/stylegan2_arch.py#L704 +class ResBlock(torch.nn.Module): + def __init__(self, in_channel, out_channel, dtype=None, device=None, operations=None): + super().__init__() + self.conv1 = ConvLayer(in_channel, in_channel, 3, dtype=dtype, device=device, operations=operations) + self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True, dtype=dtype, device=device, operations=operations) + self.skip = ConvLayer(in_channel, out_channel, 1, downsample=True, activate=False, bias=False, dtype=dtype, device=device, operations=operations) + + def forward(self, input): + out = self.conv2(self.conv1(input)) + skip = self.skip(input) + return (out + skip) / math.sqrt(2) + + +class EncoderApp(torch.nn.Module): + def __init__(self, w_dim=512, dtype=None, device=None, operations=None): + super().__init__() + kwargs = {"device": device, "dtype": dtype, "operations": operations} + + self.convs = torch.nn.ModuleList([ + ConvLayer(3, 32, 1, **kwargs), ResBlock(32, 64, **kwargs), + ResBlock(64, 128, **kwargs), ResBlock(128, 256, **kwargs), + ResBlock(256, 512, **kwargs), ResBlock(512, 512, **kwargs), + ResBlock(512, 512, **kwargs), ResBlock(512, 512, **kwargs), + EqualConv2d(512, w_dim, 4, padding=0, bias=False, **kwargs) + ]) + + def forward(self, x): + h = x + for conv in self.convs: + h = conv(h) + return h.squeeze(-1).squeeze(-1) + +class Encoder(torch.nn.Module): + def __init__(self, dim=512, motion_dim=20, dtype=None, device=None, operations=None): + super().__init__() + self.net_app = EncoderApp(dim, dtype=dtype, device=device, operations=operations) + self.fc = torch.nn.Sequential(*[EqualLinear(dim, dim, dtype=dtype, device=device, operations=operations) for _ in range(4)] + [EqualLinear(dim, motion_dim, dtype=dtype, device=device, operations=operations)]) + + def encode_motion(self, x): + return self.fc(self.net_app(x)) + +class Direction(torch.nn.Module): + def __init__(self, motion_dim, dtype=None, device=None, operations=None): + super().__init__() + self.weight = torch.nn.Parameter(torch.empty(512, motion_dim, device=device, dtype=dtype)) + self.motion_dim = motion_dim + + def forward(self, input): + stabilized_weight = comfy.model_management.cast_to(self.weight, device=input.device, dtype=input.dtype) + 1e-8 * torch.eye(512, self.motion_dim, device=input.device, dtype=input.dtype) + Q, _ = torch.linalg.qr(stabilized_weight.float()) + if input is None: + return Q + return torch.sum(input.unsqueeze(-1) * Q.T.to(input.dtype), dim=1) + +class Synthesis(torch.nn.Module): + def __init__(self, motion_dim, dtype=None, device=None, operations=None): + super().__init__() + self.direction = Direction(motion_dim, dtype=dtype, device=device, operations=operations) + +class Generator(torch.nn.Module): + def __init__(self, style_dim=512, motion_dim=20, dtype=None, device=None, operations=None): + super().__init__() + self.enc = Encoder(style_dim, motion_dim, dtype=dtype, device=device, operations=operations) + self.dec = Synthesis(motion_dim, dtype=dtype, device=device, operations=operations) + + def get_motion(self, img): + motion_feat = self.enc.encode_motion(img) + return self.dec.direction(motion_feat) + +class AnimateWanModel(WanModel): + r""" + Wan diffusion backbone supporting both text-to-video and image-to-video. + """ + + def __init__(self, + model_type='animate', + patch_size=(1, 2, 2), + text_len=512, + in_dim=16, + dim=2048, + ffn_dim=8192, + freq_dim=256, + text_dim=4096, + out_dim=16, + num_heads=16, + num_layers=32, + window_size=(-1, -1), + qk_norm=True, + cross_attn_norm=True, + eps=1e-6, + flf_pos_embed_token_number=None, + motion_encoder_dim=512, + image_model=None, + device=None, + dtype=None, + operations=None, + ): + + super().__init__(model_type='i2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations) + + self.pose_patch_embedding = operations.Conv3d( + 16, dim, kernel_size=patch_size, stride=patch_size, device=device, dtype=dtype + ) + + self.motion_encoder = Generator(style_dim=512, motion_dim=20, device=device, dtype=dtype, operations=operations) + + self.face_adapter = FaceAdapter( + heads_num=self.num_heads, + hidden_dim=self.dim, + num_adapter_layers=self.num_layers // 5, + device=device, dtype=dtype, operations=operations + ) + + self.face_encoder = FaceEncoder( + in_dim=motion_encoder_dim, + hidden_dim=self.dim, + num_heads=4, + device=device, dtype=dtype, operations=operations + ) + + def after_patch_embedding(self, x, pose_latents, face_pixel_values): + if pose_latents is not None: + pose_latents = self.pose_patch_embedding(pose_latents) + x[:, :, 1:pose_latents.shape[2] + 1] += pose_latents[:, :, :x.shape[2] - 1] + + if face_pixel_values is None: + return x, None + + b, c, T, h, w = face_pixel_values.shape + face_pixel_values = rearrange(face_pixel_values, "b c t h w -> (b t) c h w") + encode_bs = 8 + face_pixel_values_tmp = [] + for i in range(math.ceil(face_pixel_values.shape[0] / encode_bs)): + face_pixel_values_tmp.append(self.motion_encoder.get_motion(face_pixel_values[i * encode_bs: (i + 1) * encode_bs])) + + motion_vec = torch.cat(face_pixel_values_tmp) + + motion_vec = rearrange(motion_vec, "(b t) c -> b t c", t=T) + motion_vec = self.face_encoder(motion_vec) + + B, L, H, C = motion_vec.shape + pad_face = torch.zeros(B, 1, H, C).type_as(motion_vec) + motion_vec = torch.cat([pad_face, motion_vec], dim=1) + + if motion_vec.shape[1] < x.shape[2]: + B, L, H, C = motion_vec.shape + pad = torch.zeros(B, x.shape[2] - motion_vec.shape[1], H, C).type_as(motion_vec) + motion_vec = torch.cat([motion_vec, pad], dim=1) + else: + motion_vec = motion_vec[:, :x.shape[2]] + return x, motion_vec + + def forward_orig( + self, + x, + t, + context, + clip_fea=None, + pose_latents=None, + face_pixel_values=None, + freqs=None, + transformer_options={}, + **kwargs, + ): + # embeddings + x = self.patch_embedding(x.float()).to(x.dtype) + x, motion_vec = self.after_patch_embedding(x, pose_latents, face_pixel_values) + grid_sizes = x.shape[2:] + x = x.flatten(2).transpose(1, 2) + + # time embeddings + e = self.time_embedding( + sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(dtype=x[0].dtype)) + e = e.reshape(t.shape[0], -1, e.shape[-1]) + e0 = self.time_projection(e).unflatten(2, (6, self.dim)) + + full_ref = None + if self.ref_conv is not None: + full_ref = kwargs.get("reference_latent", None) + if full_ref is not None: + full_ref = self.ref_conv(full_ref).flatten(2).transpose(1, 2) + x = torch.concat((full_ref, x), dim=1) + + # context + context = self.text_embedding(context) + + context_img_len = None + if clip_fea is not None: + if self.img_emb is not None: + context_clip = self.img_emb(clip_fea) # bs x 257 x dim + context = torch.concat([context_clip, context], dim=1) + context_img_len = clip_fea.shape[-2] + + patches_replace = transformer_options.get("patches_replace", {}) + blocks_replace = patches_replace.get("dit", {}) + for i, block in enumerate(self.blocks): + if ("double_block", i) in blocks_replace: + def block_wrap(args): + out = {} + out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, transformer_options=args["transformer_options"]) + return out + out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap}) + x = out["img"] + else: + x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options) + + if i % 5 == 0 and motion_vec is not None: + x = x + self.face_adapter.fuser_blocks[i // 5](x, motion_vec) + + # head + x = self.head(x, e) + + if full_ref is not None: + x = x[:, full_ref.shape[1]:] + + # unpatchify + x = self.unpatchify(x, grid_sizes) + return x diff --git a/comfy/model_base.py b/comfy/model_base.py index 70b67b7c1..b0b9cde7d 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -39,6 +39,7 @@ import comfy.ldm.cosmos.model import comfy.ldm.cosmos.predict2 import comfy.ldm.lumina.model import comfy.ldm.wan.model +import comfy.ldm.wan.model_animate import comfy.ldm.hunyuan3d.model import comfy.ldm.hidream.model import comfy.ldm.chroma.model @@ -1253,6 +1254,23 @@ class WAN21_HuMo(WAN21): return out +class WAN22_Animate(WAN21): + def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None): + super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model_animate.AnimateWanModel) + self.image_to_video = image_to_video + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + + face_video_pixels = kwargs.get("face_video_pixels", None) + if face_video_pixels is not None: + out['face_pixel_values'] = comfy.conds.CONDRegular(face_video_pixels) + + pose_latents = kwargs.get("pose_video_latent", None) + if pose_latents is not None: + out['pose_latents'] = comfy.conds.CONDRegular(self.process_latent_in(pose_latents)) + return out + class WAN22_S2V(WAN21): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel_S2V) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 72621bed6..46415c17a 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -404,6 +404,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["model_type"] = "s2v" elif '{}audio_proj.audio_proj_glob_1.layer.bias'.format(key_prefix) in state_dict_keys: dit_config["model_type"] = "humo" + elif '{}face_adapter.fuser_blocks.0.k_norm.weight'.format(key_prefix) in state_dict_keys: + dit_config["model_type"] = "animate" else: if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys: dit_config["model_type"] = "i2v" diff --git a/comfy/model_management.py b/comfy/model_management.py index bbfc3c7a1..c5b817b62 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -348,7 +348,7 @@ try: # if any((a in arch) for a in ["gfx1201"]): # ENABLE_PYTORCH_ATTENTION = True if torch_version_numeric >= (2, 7) and rocm_version >= (6, 4): - if any((a in arch) for a in ["gfx1201", "gfx942", "gfx950"]): # TODO: more arches + if any((a in arch) for a in ["gfx1200", "gfx1201", "gfx942", "gfx950"]): # TODO: more arches SUPPORT_FP8_OPS = True except: @@ -645,7 +645,9 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu if loaded_model.model.is_clone(current_loaded_models[i].model): to_unload = [i] + to_unload for i in to_unload: - current_loaded_models.pop(i).model.detach(unpatch_all=False) + model_to_unload = current_loaded_models.pop(i) + model_to_unload.model.detach(unpatch_all=False) + model_to_unload.model_finalizer.detach() total_memory_required = {} for loaded_model in models_to_load: diff --git a/comfy/ops.py b/comfy/ops.py index 55e958adb..9d7dedd37 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -365,12 +365,13 @@ class fp8_ops(manual_cast): return None def forward_comfy_cast_weights(self, input): - try: - out = fp8_linear(self, input) - if out is not None: - return out - except Exception as e: - logging.info("Exception during fp8 op: {}".format(e)) + if not self.training: + try: + out = fp8_linear(self, input) + if out is not None: + return out + except Exception as e: + logging.info("Exception during fp8 op: {}".format(e)) weight, bias = cast_bias_weight(self, input) return torch.nn.functional.linear(input, weight, bias) diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 213b5b92c..4064bdae1 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -995,7 +995,7 @@ class WAN21_T2V(supported_models_base.BASE): unet_extra_config = {} latent_format = latent_formats.Wan21 - memory_usage_factor = 1.0 + memory_usage_factor = 0.9 supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32] @@ -1004,7 +1004,7 @@ class WAN21_T2V(supported_models_base.BASE): def __init__(self, unet_config): super().__init__(unet_config) - self.memory_usage_factor = self.unet_config.get("dim", 2000) / 2000 + self.memory_usage_factor = self.unet_config.get("dim", 2000) / 2222 def get_model(self, state_dict, prefix="", device=None): out = model_base.WAN21(self, device=device) @@ -1096,6 +1096,19 @@ class WAN22_S2V(WAN21_T2V): out = model_base.WAN22_S2V(self, device=device) return out +class WAN22_Animate(WAN21_T2V): + unet_config = { + "image_model": "wan2.1", + "model_type": "animate", + } + + def __init__(self, unet_config): + super().__init__(unet_config) + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.WAN22_Animate(self, device=device) + return out + class WAN22_T2V(WAN21_T2V): unet_config = { "image_model": "wan2.1", @@ -1361,6 +1374,6 @@ class HunyuanImage21Refiner(HunyuanVideo): out = model_base.HunyuanImage21Refiner(self, device=device) return out -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage] +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage] models += [SVD_img2vid] diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py index 5e11956b5..c5a48ba9f 100644 --- a/comfy/text_encoders/llama.py +++ b/comfy/text_encoders/llama.py @@ -400,21 +400,25 @@ class Qwen25_7BVLI(BaseLlama, torch.nn.Module): def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, embeds_info=[]): grid = None + position_ids = None + offset = 0 for e in embeds_info: if e.get("type") == "image": grid = e.get("extra", None) - position_ids = torch.zeros((3, embeds.shape[1]), device=embeds.device) start = e.get("index") - position_ids[:, :start] = torch.arange(0, start, device=embeds.device) + if position_ids is None: + position_ids = torch.zeros((3, embeds.shape[1]), device=embeds.device) + position_ids[:, :start] = torch.arange(0, start, device=embeds.device) end = e.get("size") + start len_max = int(grid.max()) // 2 start_next = len_max + start - position_ids[:, end:] = torch.arange(start_next, start_next + (embeds.shape[1] - end), device=embeds.device) - position_ids[0, start:end] = start + position_ids[:, end:] = torch.arange(start_next + offset, start_next + (embeds.shape[1] - end) + offset, device=embeds.device) + position_ids[0, start:end] = start + offset max_d = int(grid[0][1]) // 2 - position_ids[1, start:end] = torch.arange(start, start + max_d, device=embeds.device).unsqueeze(1).repeat(1, math.ceil((end - start) / max_d)).flatten(0)[:end - start] + position_ids[1, start:end] = torch.arange(start + offset, start + max_d + offset, device=embeds.device).unsqueeze(1).repeat(1, math.ceil((end - start) / max_d)).flatten(0)[:end - start] max_d = int(grid[0][2]) // 2 - position_ids[2, start:end] = torch.arange(start, start + max_d, device=embeds.device).unsqueeze(0).repeat(math.ceil((end - start) / max_d), 1).flatten(0)[:end - start] + position_ids[2, start:end] = torch.arange(start + offset, start + max_d + offset, device=embeds.device).unsqueeze(0).repeat(math.ceil((end - start) / max_d), 1).flatten(0)[:end - start] + offset += len_max - (end - start) if grid is None: position_ids = None diff --git a/comfy/weight_adapter/loha.py b/comfy/weight_adapter/loha.py index 55c97a3af..0abb2d403 100644 --- a/comfy/weight_adapter/loha.py +++ b/comfy/weight_adapter/loha.py @@ -130,12 +130,12 @@ class LoHaAdapter(WeightAdapterBase): def create_train(cls, weight, rank=1, alpha=1.0): out_dim = weight.shape[0] in_dim = weight.shape[1:].numel() - mat1 = torch.empty(out_dim, rank, device=weight.device, dtype=weight.dtype) - mat2 = torch.empty(rank, in_dim, device=weight.device, dtype=weight.dtype) + mat1 = torch.empty(out_dim, rank, device=weight.device, dtype=torch.float32) + mat2 = torch.empty(rank, in_dim, device=weight.device, dtype=torch.float32) torch.nn.init.normal_(mat1, 0.1) torch.nn.init.constant_(mat2, 0.0) - mat3 = torch.empty(out_dim, rank, device=weight.device, dtype=weight.dtype) - mat4 = torch.empty(rank, in_dim, device=weight.device, dtype=weight.dtype) + mat3 = torch.empty(out_dim, rank, device=weight.device, dtype=torch.float32) + mat4 = torch.empty(rank, in_dim, device=weight.device, dtype=torch.float32) torch.nn.init.normal_(mat3, 0.1) torch.nn.init.normal_(mat4, 0.01) return LohaDiff( diff --git a/comfy/weight_adapter/lokr.py b/comfy/weight_adapter/lokr.py index 563c835f5..9b2aff2d7 100644 --- a/comfy/weight_adapter/lokr.py +++ b/comfy/weight_adapter/lokr.py @@ -89,8 +89,8 @@ class LoKrAdapter(WeightAdapterBase): in_dim = weight.shape[1:].numel() out1, out2 = factorization(out_dim, rank) in1, in2 = factorization(in_dim, rank) - mat1 = torch.empty(out1, in1, device=weight.device, dtype=weight.dtype) - mat2 = torch.empty(out2, in2, device=weight.device, dtype=weight.dtype) + mat1 = torch.empty(out1, in1, device=weight.device, dtype=torch.float32) + mat2 = torch.empty(out2, in2, device=weight.device, dtype=torch.float32) torch.nn.init.kaiming_uniform_(mat2, a=5**0.5) torch.nn.init.constant_(mat1, 0.0) return LokrDiff( diff --git a/comfy/weight_adapter/lora.py b/comfy/weight_adapter/lora.py index 47aa17d13..4db004e50 100644 --- a/comfy/weight_adapter/lora.py +++ b/comfy/weight_adapter/lora.py @@ -66,8 +66,8 @@ class LoRAAdapter(WeightAdapterBase): def create_train(cls, weight, rank=1, alpha=1.0): out_dim = weight.shape[0] in_dim = weight.shape[1:].numel() - mat1 = torch.empty(out_dim, rank, device=weight.device, dtype=weight.dtype) - mat2 = torch.empty(rank, in_dim, device=weight.device, dtype=weight.dtype) + mat1 = torch.empty(out_dim, rank, device=weight.device, dtype=torch.float32) + mat2 = torch.empty(rank, in_dim, device=weight.device, dtype=torch.float32) torch.nn.init.kaiming_uniform_(mat1, a=5**0.5) torch.nn.init.constant_(mat2, 0.0) return LoraDiff( diff --git a/comfy/weight_adapter/oft.py b/comfy/weight_adapter/oft.py index 9d4982083..c0aab9635 100644 --- a/comfy/weight_adapter/oft.py +++ b/comfy/weight_adapter/oft.py @@ -68,7 +68,7 @@ class OFTAdapter(WeightAdapterBase): def create_train(cls, weight, rank=1, alpha=1.0): out_dim = weight.shape[0] block_size, block_num = factorization(out_dim, rank) - block = torch.zeros(block_num, block_size, block_size, device=weight.device, dtype=weight.dtype) + block = torch.zeros(block_num, block_size, block_size, device=weight.device, dtype=torch.float32) return OFTDiff( (block, None, alpha, None) ) diff --git a/comfy_api_nodes/apis/client.py b/comfy_api_nodes/apis/client.py index 4ad0b783b..0aed906fb 100644 --- a/comfy_api_nodes/apis/client.py +++ b/comfy_api_nodes/apis/client.py @@ -683,7 +683,7 @@ class SynchronousOperation(Generic[T, R]): auth_token: Optional[str] = None, comfy_api_key: Optional[str] = None, auth_kwargs: Optional[Dict[str, str]] = None, - timeout: float = 604800.0, + timeout: float = 7200.0, verify_ssl: bool = True, content_type: str = "application/json", multipart_parser: Callable | None = None, diff --git a/comfy_api_nodes/apis/rodin_api.py b/comfy_api_nodes/apis/rodin_api.py index b0cf171fa..02cf42c29 100644 --- a/comfy_api_nodes/apis/rodin_api.py +++ b/comfy_api_nodes/apis/rodin_api.py @@ -9,8 +9,9 @@ class Rodin3DGenerateRequest(BaseModel): seed: int = Field(..., description="seed_") tier: str = Field(..., description="Tier of generation.") material: str = Field(..., description="The material type.") - quality: str = Field(..., description="The generation quality of the mesh.") + quality_override: int = Field(..., description="The poly count of the mesh.") mesh_mode: str = Field(..., description="It controls the type of faces of generated models.") + TAPose: Optional[bool] = Field(None, description="") class GenerateJobsData(BaseModel): uuids: List[str] = Field(..., description="str LIST") diff --git a/comfy_api_nodes/nodes_bytedance.py b/comfy_api_nodes/nodes_bytedance.py index 369a3a4fe..a7eeaf15a 100644 --- a/comfy_api_nodes/nodes_bytedance.py +++ b/comfy_api_nodes/nodes_bytedance.py @@ -567,6 +567,12 @@ class ByteDanceSeedreamNode(comfy_io.ComfyNode): tooltip="Whether to add an \"AI generated\" watermark to the image.", optional=True, ), + comfy_io.Boolean.Input( + "fail_on_partial", + default=True, + tooltip="If enabled, abort execution if any requested images are missing or return an error.", + optional=True, + ), ], outputs=[ comfy_io.Image.Output(), @@ -592,6 +598,7 @@ class ByteDanceSeedreamNode(comfy_io.ComfyNode): max_images: int = 1, seed: int = 0, watermark: bool = True, + fail_on_partial: bool = True, ) -> comfy_io.NodeOutput: validate_string(prompt, strip_whitespace=True, min_length=1) w = h = None @@ -651,9 +658,10 @@ class ByteDanceSeedreamNode(comfy_io.ComfyNode): if len(response.data) == 1: return comfy_io.NodeOutput(await download_url_to_image_tensor(get_image_url_from_response(response))) - return comfy_io.NodeOutput( - torch.cat([await download_url_to_image_tensor(str(i["url"])) for i in response.data]) - ) + urls = [str(d["url"]) for d in response.data if isinstance(d, dict) and "url" in d] + if fail_on_partial and len(urls) < len(response.data): + raise RuntimeError(f"Only {len(urls)} of {len(response.data)} images were generated before error.") + return comfy_io.NodeOutput(torch.cat([await download_url_to_image_tensor(i) for i in urls])) class ByteDanceTextToVideoNode(comfy_io.ComfyNode): @@ -1171,7 +1179,7 @@ async def process_video_task( payload: Union[Text2VideoTaskCreationRequest, Image2VideoTaskCreationRequest], auth_kwargs: dict, node_id: str, - estimated_duration: int | None, + estimated_duration: Optional[int], ) -> comfy_io.NodeOutput: initial_response = await SynchronousOperation( endpoint=ApiEndpoint( diff --git a/comfy_api_nodes/nodes_rodin.py b/comfy_api_nodes/nodes_rodin.py index c89d087e5..1af393eba 100644 --- a/comfy_api_nodes/nodes_rodin.py +++ b/comfy_api_nodes/nodes_rodin.py @@ -121,10 +121,10 @@ class Rodin3DAPI: else: return "Generating" - async def create_generate_task(self, images=None, seed=1, material="PBR", quality="medium", tier="Regular", mesh_mode="Quad", **kwargs): + async def create_generate_task(self, images=None, seed=1, material="PBR", quality_override=18000, tier="Regular", mesh_mode="Quad", TAPose = False, **kwargs): if images is None: raise Exception("Rodin 3D generate requires at least 1 image.") - if len(images) >= 5: + if len(images) > 5: raise Exception("Rodin 3D generate requires up to 5 image.") path = "/proxy/rodin/api/v2/rodin" @@ -139,8 +139,9 @@ class Rodin3DAPI: seed=seed, tier=tier, material=material, - quality=quality, - mesh_mode=mesh_mode + quality_override=quality_override, + mesh_mode=mesh_mode, + TAPose=TAPose, ), files=[ ( @@ -211,23 +212,36 @@ class Rodin3DAPI: return await operation.execute() def get_quality_mode(self, poly_count): - if poly_count == "200K-Triangle": + polycount = poly_count.split("-") + poly = polycount[1] + count = polycount[0] + if poly == "Triangle": mesh_mode = "Raw" - quality = "medium" + elif poly == "Quad": + mesh_mode = "Quad" else: mesh_mode = "Quad" - if poly_count == "4K-Quad": - quality = "extra-low" - elif poly_count == "8K-Quad": - quality = "low" - elif poly_count == "18K-Quad": - quality = "medium" - elif poly_count == "50K-Quad": - quality = "high" - else: - quality = "medium" - return mesh_mode, quality + if count == "4K": + quality_override = 4000 + elif count == "8K": + quality_override = 8000 + elif count == "18K": + quality_override = 18000 + elif count == "50K": + quality_override = 50000 + elif count == "2K": + quality_override = 2000 + elif count == "20K": + quality_override = 20000 + elif count == "150K": + quality_override = 150000 + elif count == "500K": + quality_override = 500000 + else: + quality_override = 18000 + + return mesh_mode, quality_override async def download_files(self, url_list): save_path = os.path.join(comfy_paths.get_output_directory(), "Rodin3D", datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")) @@ -300,9 +314,9 @@ class Rodin3D_Regular(Rodin3DAPI): m_images = [] for i in range(num_images): m_images.append(Images[i]) - mesh_mode, quality = self.get_quality_mode(Polygon_count) + mesh_mode, quality_override = self.get_quality_mode(Polygon_count) task_uuid, subscription_key = await self.create_generate_task(images=m_images, seed=Seed, material=Material_Type, - quality=quality, tier=tier, mesh_mode=mesh_mode, + quality_override=quality_override, tier=tier, mesh_mode=mesh_mode, **kwargs) await self.poll_for_task_status(subscription_key, **kwargs) download_list = await self.get_rodin_download_list(task_uuid, **kwargs) @@ -346,9 +360,9 @@ class Rodin3D_Detail(Rodin3DAPI): m_images = [] for i in range(num_images): m_images.append(Images[i]) - mesh_mode, quality = self.get_quality_mode(Polygon_count) + mesh_mode, quality_override = self.get_quality_mode(Polygon_count) task_uuid, subscription_key = await self.create_generate_task(images=m_images, seed=Seed, material=Material_Type, - quality=quality, tier=tier, mesh_mode=mesh_mode, + quality_override=quality_override, tier=tier, mesh_mode=mesh_mode, **kwargs) await self.poll_for_task_status(subscription_key, **kwargs) download_list = await self.get_rodin_download_list(task_uuid, **kwargs) @@ -392,9 +406,9 @@ class Rodin3D_Smooth(Rodin3DAPI): m_images = [] for i in range(num_images): m_images.append(Images[i]) - mesh_mode, quality = self.get_quality_mode(Polygon_count) + mesh_mode, quality_override = self.get_quality_mode(Polygon_count) task_uuid, subscription_key = await self.create_generate_task(images=m_images, seed=Seed, material=Material_Type, - quality=quality, tier=tier, mesh_mode=mesh_mode, + quality_override=quality_override, tier=tier, mesh_mode=mesh_mode, **kwargs) await self.poll_for_task_status(subscription_key, **kwargs) download_list = await self.get_rodin_download_list(task_uuid, **kwargs) @@ -446,10 +460,10 @@ class Rodin3D_Sketch(Rodin3DAPI): for i in range(num_images): m_images.append(Images[i]) material_type = "PBR" - quality = "medium" + quality_override = 18000 mesh_mode = "Quad" task_uuid, subscription_key = await self.create_generate_task( - images=m_images, seed=Seed, material=material_type, quality=quality, tier=tier, mesh_mode=mesh_mode, **kwargs + images=m_images, seed=Seed, material=material_type, quality_override=quality_override, tier=tier, mesh_mode=mesh_mode, **kwargs ) await self.poll_for_task_status(subscription_key, **kwargs) download_list = await self.get_rodin_download_list(task_uuid, **kwargs) @@ -457,6 +471,80 @@ class Rodin3D_Sketch(Rodin3DAPI): return (model,) +class Rodin3D_Gen2(Rodin3DAPI): + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "Images": + ( + IO.IMAGE, + { + "forceInput":True, + } + ) + }, + "optional": { + "Seed": ( + IO.INT, + { + "default":0, + "min":0, + "max":65535, + "display":"number" + } + ), + "Material_Type": ( + IO.COMBO, + { + "options": ["PBR", "Shaded"], + "default": "PBR" + } + ), + "Polygon_count": ( + IO.COMBO, + { + "options": ["4K-Quad", "8K-Quad", "18K-Quad", "50K-Quad", "2K-Triangle", "20K-Triangle", "150K-Triangle", "500K-Triangle"], + "default": "500K-Triangle" + } + ), + "TAPose": ( + IO.BOOLEAN, + { + "default": False, + } + ) + }, + "hidden": { + "auth_token": "AUTH_TOKEN_COMFY_ORG", + "comfy_api_key": "API_KEY_COMFY_ORG", + }, + } + + async def api_call( + self, + Images, + Seed, + Material_Type, + Polygon_count, + TAPose, + **kwargs + ): + tier = "Gen-2" + num_images = Images.shape[0] + m_images = [] + for i in range(num_images): + m_images.append(Images[i]) + mesh_mode, quality_override = self.get_quality_mode(Polygon_count) + task_uuid, subscription_key = await self.create_generate_task(images=m_images, seed=Seed, material=Material_Type, + quality_override=quality_override, tier=tier, mesh_mode=mesh_mode, TAPose=TAPose, + **kwargs) + await self.poll_for_task_status(subscription_key, **kwargs) + download_list = await self.get_rodin_download_list(task_uuid, **kwargs) + model = await self.download_files(download_list) + + return (model,) + # A dictionary that contains all nodes you want to export with their names # NOTE: names should be globally unique NODE_CLASS_MAPPINGS = { @@ -464,6 +552,7 @@ NODE_CLASS_MAPPINGS = { "Rodin3D_Detail": Rodin3D_Detail, "Rodin3D_Smooth": Rodin3D_Smooth, "Rodin3D_Sketch": Rodin3D_Sketch, + "Rodin3D_Gen2": Rodin3D_Gen2, } # A dictionary that contains the friendly/humanly readable titles for the nodes @@ -472,4 +561,5 @@ NODE_DISPLAY_NAME_MAPPINGS = { "Rodin3D_Detail": "Rodin 3D Generate - Detail Generate", "Rodin3D_Smooth": "Rodin 3D Generate - Smooth Generate", "Rodin3D_Sketch": "Rodin 3D Generate - Sketch Generate", + "Rodin3D_Gen2": "Rodin 3D Generate - Gen-2 Generate", } diff --git a/comfy_api_nodes/nodes_wan.py b/comfy_api_nodes/nodes_wan.py new file mode 100644 index 000000000..db5bd41c1 --- /dev/null +++ b/comfy_api_nodes/nodes_wan.py @@ -0,0 +1,602 @@ +import re +from typing import Optional, Type, Union +from typing_extensions import override + +import torch +from pydantic import BaseModel, Field +from comfy_api.latest import ComfyExtension, Input, io as comfy_io +from comfy_api_nodes.apis.client import ( + ApiEndpoint, + HttpMethod, + SynchronousOperation, + PollingOperation, + EmptyRequest, + R, + T, +) +from comfy_api_nodes.util.validation_utils import get_number_of_images, validate_audio_duration + +from comfy_api_nodes.apinode_utils import ( + download_url_to_image_tensor, + download_url_to_video_output, + tensor_to_base64_string, + audio_to_base64_string, +) + +class Text2ImageInputField(BaseModel): + prompt: str = Field(...) + negative_prompt: Optional[str] = Field(None) + + +class Text2VideoInputField(BaseModel): + prompt: str = Field(...) + negative_prompt: Optional[str] = Field(None) + audio_url: Optional[str] = Field(None) + + +class Image2VideoInputField(BaseModel): + prompt: str = Field(...) + negative_prompt: Optional[str] = Field(None) + img_url: str = Field(...) + audio_url: Optional[str] = Field(None) + + +class Txt2ImageParametersField(BaseModel): + size: str = Field(...) + n: int = Field(1, description="Number of images to generate.") # we support only value=1 + seed: int = Field(..., ge=0, le=2147483647) + prompt_extend: bool = Field(True) + watermark: bool = Field(True) + + +class Text2VideoParametersField(BaseModel): + size: str = Field(...) + seed: int = Field(..., ge=0, le=2147483647) + duration: int = Field(5, ge=5, le=10) + prompt_extend: bool = Field(True) + watermark: bool = Field(True) + audio: bool = Field(False, description="Should be audio generated automatically") + + +class Image2VideoParametersField(BaseModel): + resolution: str = Field(...) + seed: int = Field(..., ge=0, le=2147483647) + duration: int = Field(5, ge=5, le=10) + prompt_extend: bool = Field(True) + watermark: bool = Field(True) + audio: bool = Field(False, description="Should be audio generated automatically") + + +class Text2ImageTaskCreationRequest(BaseModel): + model: str = Field(...) + input: Text2ImageInputField = Field(...) + parameters: Txt2ImageParametersField = Field(...) + + +class Text2VideoTaskCreationRequest(BaseModel): + model: str = Field(...) + input: Text2VideoInputField = Field(...) + parameters: Text2VideoParametersField = Field(...) + + +class Image2VideoTaskCreationRequest(BaseModel): + model: str = Field(...) + input: Image2VideoInputField = Field(...) + parameters: Image2VideoParametersField = Field(...) + + +class TaskCreationOutputField(BaseModel): + task_id: str = Field(...) + task_status: str = Field(...) + + +class TaskCreationResponse(BaseModel): + output: Optional[TaskCreationOutputField] = Field(None) + request_id: str = Field(...) + code: Optional[str] = Field(None, description="The error code of the failed request.") + message: Optional[str] = Field(None, description="Details of the failed request.") + + +class TaskResult(BaseModel): + url: Optional[str] = Field(None) + code: Optional[str] = Field(None) + message: Optional[str] = Field(None) + + +class ImageTaskStatusOutputField(TaskCreationOutputField): + task_id: str = Field(...) + task_status: str = Field(...) + results: Optional[list[TaskResult]] = Field(None) + + +class VideoTaskStatusOutputField(TaskCreationOutputField): + task_id: str = Field(...) + task_status: str = Field(...) + video_url: Optional[str] = Field(None) + code: Optional[str] = Field(None) + message: Optional[str] = Field(None) + + +class ImageTaskStatusResponse(BaseModel): + output: Optional[ImageTaskStatusOutputField] = Field(None) + request_id: str = Field(...) + + +class VideoTaskStatusResponse(BaseModel): + output: Optional[VideoTaskStatusOutputField] = Field(None) + request_id: str = Field(...) + + +RES_IN_PARENS = re.compile(r'\((\d+)\s*[x×]\s*(\d+)\)') + + +async def process_task( + auth_kwargs: dict[str, str], + url: str, + request_model: Type[T], + response_model: Type[R], + payload: Union[Text2ImageTaskCreationRequest, Text2VideoTaskCreationRequest, Image2VideoTaskCreationRequest], + node_id: str, + estimated_duration: int, + poll_interval: int, +) -> Type[R]: + initial_response = await SynchronousOperation( + endpoint=ApiEndpoint( + path=url, + method=HttpMethod.POST, + request_model=request_model, + response_model=TaskCreationResponse, + ), + request=payload, + auth_kwargs=auth_kwargs, + ).execute() + + if not initial_response.output: + raise Exception(f"Unknown error occurred: {initial_response.code} - {initial_response.message}") + + return await PollingOperation( + poll_endpoint=ApiEndpoint( + path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}", + method=HttpMethod.GET, + request_model=EmptyRequest, + response_model=response_model, + ), + completed_statuses=["SUCCEEDED"], + failed_statuses=["FAILED", "CANCELED", "UNKNOWN"], + status_extractor=lambda x: x.output.task_status, + estimated_duration=estimated_duration, + poll_interval=poll_interval, + node_id=node_id, + auth_kwargs=auth_kwargs, + ).execute() + + +class WanTextToImageApi(comfy_io.ComfyNode): + @classmethod + def define_schema(cls): + return comfy_io.Schema( + node_id="WanTextToImageApi", + display_name="Wan Text to Image", + category="api node/image/Wan", + description="Generates image based on text prompt.", + inputs=[ + comfy_io.Combo.Input( + "model", + options=["wan2.5-t2i-preview"], + default="wan2.5-t2i-preview", + tooltip="Model to use.", + ), + comfy_io.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Prompt used to describe the elements and visual features, supports English/Chinese.", + ), + comfy_io.String.Input( + "negative_prompt", + multiline=True, + default="", + tooltip="Negative text prompt to guide what to avoid.", + optional=True, + ), + comfy_io.Int.Input( + "width", + default=1024, + min=768, + max=1440, + step=32, + optional=True, + ), + comfy_io.Int.Input( + "height", + default=1024, + min=768, + max=1440, + step=32, + optional=True, + ), + comfy_io.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + step=1, + display_mode=comfy_io.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed to use for generation.", + optional=True, + ), + comfy_io.Boolean.Input( + "prompt_extend", + default=True, + tooltip="Whether to enhance the prompt with AI assistance.", + optional=True, + ), + comfy_io.Boolean.Input( + "watermark", + default=True, + tooltip="Whether to add an \"AI generated\" watermark to the result.", + optional=True, + ), + ], + outputs=[ + comfy_io.Image.Output(), + ], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + model: str, + prompt: str, + negative_prompt: str = "", + width: int = 1024, + height: int = 1024, + seed: int = 0, + prompt_extend: bool = True, + watermark: bool = True, + ): + payload = Text2ImageTaskCreationRequest( + model=model, + input=Text2ImageInputField(prompt=prompt, negative_prompt=negative_prompt), + parameters=Txt2ImageParametersField( + size=f"{width}*{height}", + seed=seed, + prompt_extend=prompt_extend, + watermark=watermark, + ), + ) + response = await process_task( + { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + }, + "/proxy/wan/api/v1/services/aigc/text2image/image-synthesis", + request_model=Text2ImageTaskCreationRequest, + response_model=ImageTaskStatusResponse, + payload=payload, + node_id=cls.hidden.unique_id, + estimated_duration=9, + poll_interval=3, + ) + return comfy_io.NodeOutput(await download_url_to_image_tensor(str(response.output.results[0].url))) + + +class WanTextToVideoApi(comfy_io.ComfyNode): + @classmethod + def define_schema(cls): + return comfy_io.Schema( + node_id="WanTextToVideoApi", + display_name="Wan Text to Video", + category="api node/video/Wan", + description="Generates video based on text prompt.", + inputs=[ + comfy_io.Combo.Input( + "model", + options=["wan2.5-t2v-preview"], + default="wan2.5-t2v-preview", + tooltip="Model to use.", + ), + comfy_io.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Prompt used to describe the elements and visual features, supports English/Chinese.", + ), + comfy_io.String.Input( + "negative_prompt", + multiline=True, + default="", + tooltip="Negative text prompt to guide what to avoid.", + optional=True, + ), + comfy_io.Combo.Input( + "size", + options=[ + "480p: 1:1 (624x624)", + "480p: 16:9 (832x480)", + "480p: 9:16 (480x832)", + "720p: 1:1 (960x960)", + "720p: 16:9 (1280x720)", + "720p: 9:16 (720x1280)", + "720p: 4:3 (1088x832)", + "720p: 3:4 (832x1088)", + "1080p: 1:1 (1440x1440)", + "1080p: 16:9 (1920x1080)", + "1080p: 9:16 (1080x1920)", + "1080p: 4:3 (1632x1248)", + "1080p: 3:4 (1248x1632)", + ], + default="480p: 1:1 (624x624)", + optional=True, + ), + comfy_io.Int.Input( + "duration", + default=5, + min=5, + max=10, + step=5, + display_mode=comfy_io.NumberDisplay.number, + tooltip="Available durations: 5 and 10 seconds", + optional=True, + ), + comfy_io.Audio.Input( + "audio", + optional=True, + tooltip="Audio must contain a clear, loud voice, without extraneous noise, background music.", + ), + comfy_io.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + step=1, + display_mode=comfy_io.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed to use for generation.", + optional=True, + ), + comfy_io.Boolean.Input( + "generate_audio", + default=False, + optional=True, + tooltip="If there is no audio input, generate audio automatically.", + ), + comfy_io.Boolean.Input( + "prompt_extend", + default=True, + tooltip="Whether to enhance the prompt with AI assistance.", + optional=True, + ), + comfy_io.Boolean.Input( + "watermark", + default=True, + tooltip="Whether to add an \"AI generated\" watermark to the result.", + optional=True, + ), + ], + outputs=[ + comfy_io.Video.Output(), + ], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + model: str, + prompt: str, + negative_prompt: str = "", + size: str = "480p: 1:1 (624x624)", + duration: int = 5, + audio: Optional[Input.Audio] = None, + seed: int = 0, + generate_audio: bool = False, + prompt_extend: bool = True, + watermark: bool = True, + ): + width, height = RES_IN_PARENS.search(size).groups() + audio_url = None + if audio is not None: + validate_audio_duration(audio, 3.0, 29.0) + audio_url = "data:audio/mp3;base64," + audio_to_base64_string(audio, "mp3", "libmp3lame") + payload = Text2VideoTaskCreationRequest( + model=model, + input=Text2VideoInputField(prompt=prompt, negative_prompt=negative_prompt, audio_url=audio_url), + parameters=Text2VideoParametersField( + size=f"{width}*{height}", + duration=duration, + seed=seed, + audio=generate_audio, + prompt_extend=prompt_extend, + watermark=watermark, + ), + ) + response = await process_task( + { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + }, + "/proxy/wan/api/v1/services/aigc/video-generation/video-synthesis", + request_model=Text2VideoTaskCreationRequest, + response_model=VideoTaskStatusResponse, + payload=payload, + node_id=cls.hidden.unique_id, + estimated_duration=120 * int(duration / 5), + poll_interval=6, + ) + return comfy_io.NodeOutput(await download_url_to_video_output(response.output.video_url)) + + +class WanImageToVideoApi(comfy_io.ComfyNode): + @classmethod + def define_schema(cls): + return comfy_io.Schema( + node_id="WanImageToVideoApi", + display_name="Wan Image to Video", + category="api node/video/Wan", + description="Generates video based on the first frame and text prompt.", + inputs=[ + comfy_io.Combo.Input( + "model", + options=["wan2.5-i2v-preview"], + default="wan2.5-i2v-preview", + tooltip="Model to use.", + ), + comfy_io.Image.Input( + "image", + ), + comfy_io.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Prompt used to describe the elements and visual features, supports English/Chinese.", + ), + comfy_io.String.Input( + "negative_prompt", + multiline=True, + default="", + tooltip="Negative text prompt to guide what to avoid.", + optional=True, + ), + comfy_io.Combo.Input( + "resolution", + options=[ + "480P", + "720P", + "1080P", + ], + default="480P", + optional=True, + ), + comfy_io.Int.Input( + "duration", + default=5, + min=5, + max=10, + step=5, + display_mode=comfy_io.NumberDisplay.number, + tooltip="Available durations: 5 and 10 seconds", + optional=True, + ), + comfy_io.Audio.Input( + "audio", + optional=True, + tooltip="Audio must contain a clear, loud voice, without extraneous noise, background music.", + ), + comfy_io.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + step=1, + display_mode=comfy_io.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed to use for generation.", + optional=True, + ), + comfy_io.Boolean.Input( + "generate_audio", + default=False, + optional=True, + tooltip="If there is no audio input, generate audio automatically.", + ), + comfy_io.Boolean.Input( + "prompt_extend", + default=True, + tooltip="Whether to enhance the prompt with AI assistance.", + optional=True, + ), + comfy_io.Boolean.Input( + "watermark", + default=True, + tooltip="Whether to add an \"AI generated\" watermark to the result.", + optional=True, + ), + ], + outputs=[ + comfy_io.Video.Output(), + ], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + model: str, + image: torch.Tensor, + prompt: str, + negative_prompt: str = "", + resolution: str = "480P", + duration: int = 5, + audio: Optional[Input.Audio] = None, + seed: int = 0, + generate_audio: bool = False, + prompt_extend: bool = True, + watermark: bool = True, + ): + if get_number_of_images(image) != 1: + raise ValueError("Exactly one input image is required.") + image_url = "data:image/png;base64," + tensor_to_base64_string(image, total_pixels=2000*2000) + audio_url = None + if audio is not None: + validate_audio_duration(audio, 3.0, 29.0) + audio_url = "data:audio/mp3;base64," + audio_to_base64_string(audio, "mp3", "libmp3lame") + payload = Image2VideoTaskCreationRequest( + model=model, + input=Image2VideoInputField( + prompt=prompt, negative_prompt=negative_prompt, img_url=image_url, audio_url=audio_url + ), + parameters=Image2VideoParametersField( + resolution=resolution, + duration=duration, + seed=seed, + audio=generate_audio, + prompt_extend=prompt_extend, + watermark=watermark, + ), + ) + response = await process_task( + { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + }, + "/proxy/wan/api/v1/services/aigc/video-generation/video-synthesis", + request_model=Image2VideoTaskCreationRequest, + response_model=VideoTaskStatusResponse, + payload=payload, + node_id=cls.hidden.unique_id, + estimated_duration=120 * int(duration / 5), + poll_interval=6, + ) + return comfy_io.NodeOutput(await download_url_to_video_output(response.output.video_url)) + + +class WanApiExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]: + return [ + WanTextToImageApi, + WanTextToVideoApi, + WanImageToVideoApi, + ] + + +async def comfy_entrypoint() -> WanApiExtension: + return WanApiExtension() diff --git a/comfy_extras/nodes_audio.py b/comfy_extras/nodes_audio.py index 3b23f65d8..51c8b9dd9 100644 --- a/comfy_extras/nodes_audio.py +++ b/comfy_extras/nodes_audio.py @@ -11,6 +11,7 @@ import json import random import hashlib import node_helpers +import logging from comfy.cli_args import args from comfy.comfy_types import FileLocator @@ -364,6 +365,216 @@ class RecordAudio: return (audio, ) +class TrimAudioDuration: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "audio": ("AUDIO",), + "start_index": ("FLOAT", {"default": 0.0, "min": -0xffffffffffffffff, "max": 0xffffffffffffffff, "step": 0.01, "tooltip": "Start time in seconds, can be negative to count from the end (supports sub-seconds)."}), + "duration": ("FLOAT", {"default": 60.0, "min": 0.0, "step": 0.01, "tooltip": "Duration in seconds"}), + }, + } + + FUNCTION = "trim" + RETURN_TYPES = ("AUDIO",) + CATEGORY = "audio" + DESCRIPTION = "Trim audio tensor into chosen time range." + + def trim(self, audio, start_index, duration): + waveform = audio["waveform"] + sample_rate = audio["sample_rate"] + audio_length = waveform.shape[-1] + + if start_index < 0: + start_frame = audio_length + int(round(start_index * sample_rate)) + else: + start_frame = int(round(start_index * sample_rate)) + start_frame = max(0, min(start_frame, audio_length - 1)) + + end_frame = start_frame + int(round(duration * sample_rate)) + end_frame = max(0, min(end_frame, audio_length)) + + if start_frame >= end_frame: + raise ValueError("AudioTrim: Start time must be less than end time and be within the audio length.") + + return ({"waveform": waveform[..., start_frame:end_frame], "sample_rate": sample_rate},) + + +class SplitAudioChannels: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "audio": ("AUDIO",), + }} + + RETURN_TYPES = ("AUDIO", "AUDIO") + RETURN_NAMES = ("left", "right") + FUNCTION = "separate" + CATEGORY = "audio" + DESCRIPTION = "Separates the audio into left and right channels." + + def separate(self, audio): + waveform = audio["waveform"] + sample_rate = audio["sample_rate"] + + if waveform.shape[1] != 2: + raise ValueError("AudioSplit: Input audio has only one channel.") + + left_channel = waveform[..., 0:1, :] + right_channel = waveform[..., 1:2, :] + + return ({"waveform": left_channel, "sample_rate": sample_rate}, {"waveform": right_channel, "sample_rate": sample_rate}) + + +def match_audio_sample_rates(waveform_1, sample_rate_1, waveform_2, sample_rate_2): + if sample_rate_1 != sample_rate_2: + if sample_rate_1 > sample_rate_2: + waveform_2 = torchaudio.functional.resample(waveform_2, sample_rate_2, sample_rate_1) + output_sample_rate = sample_rate_1 + logging.info(f"Resampling audio2 from {sample_rate_2}Hz to {sample_rate_1}Hz for merging.") + else: + waveform_1 = torchaudio.functional.resample(waveform_1, sample_rate_1, sample_rate_2) + output_sample_rate = sample_rate_2 + logging.info(f"Resampling audio1 from {sample_rate_1}Hz to {sample_rate_2}Hz for merging.") + else: + output_sample_rate = sample_rate_1 + return waveform_1, waveform_2, output_sample_rate + + +class AudioConcat: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "audio1": ("AUDIO",), + "audio2": ("AUDIO",), + "direction": (['after', 'before'], {"default": 'after', "tooltip": "Whether to append audio2 after or before audio1."}), + }} + + RETURN_TYPES = ("AUDIO",) + FUNCTION = "concat" + CATEGORY = "audio" + DESCRIPTION = "Concatenates the audio1 to audio2 in the specified direction." + + def concat(self, audio1, audio2, direction): + waveform_1 = audio1["waveform"] + waveform_2 = audio2["waveform"] + sample_rate_1 = audio1["sample_rate"] + sample_rate_2 = audio2["sample_rate"] + + if waveform_1.shape[1] == 1: + waveform_1 = waveform_1.repeat(1, 2, 1) + logging.info("AudioConcat: Converted mono audio1 to stereo by duplicating the channel.") + if waveform_2.shape[1] == 1: + waveform_2 = waveform_2.repeat(1, 2, 1) + logging.info("AudioConcat: Converted mono audio2 to stereo by duplicating the channel.") + + waveform_1, waveform_2, output_sample_rate = match_audio_sample_rates(waveform_1, sample_rate_1, waveform_2, sample_rate_2) + + if direction == 'after': + concatenated_audio = torch.cat((waveform_1, waveform_2), dim=2) + elif direction == 'before': + concatenated_audio = torch.cat((waveform_2, waveform_1), dim=2) + + return ({"waveform": concatenated_audio, "sample_rate": output_sample_rate},) + + +class AudioMerge: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "audio1": ("AUDIO",), + "audio2": ("AUDIO",), + "merge_method": (["add", "mean", "subtract", "multiply"], {"tooltip": "The method used to combine the audio waveforms."}), + }, + } + + FUNCTION = "merge" + RETURN_TYPES = ("AUDIO",) + CATEGORY = "audio" + DESCRIPTION = "Combine two audio tracks by overlaying their waveforms." + + def merge(self, audio1, audio2, merge_method): + waveform_1 = audio1["waveform"] + waveform_2 = audio2["waveform"] + sample_rate_1 = audio1["sample_rate"] + sample_rate_2 = audio2["sample_rate"] + + waveform_1, waveform_2, output_sample_rate = match_audio_sample_rates(waveform_1, sample_rate_1, waveform_2, sample_rate_2) + + length_1 = waveform_1.shape[-1] + length_2 = waveform_2.shape[-1] + + if length_2 > length_1: + logging.info(f"AudioMerge: Trimming audio2 from {length_2} to {length_1} samples to match audio1 length.") + waveform_2 = waveform_2[..., :length_1] + elif length_2 < length_1: + logging.info(f"AudioMerge: Padding audio2 from {length_2} to {length_1} samples to match audio1 length.") + pad_shape = list(waveform_2.shape) + pad_shape[-1] = length_1 - length_2 + pad_tensor = torch.zeros(pad_shape, dtype=waveform_2.dtype, device=waveform_2.device) + waveform_2 = torch.cat((waveform_2, pad_tensor), dim=-1) + + if merge_method == "add": + waveform = waveform_1 + waveform_2 + elif merge_method == "subtract": + waveform = waveform_1 - waveform_2 + elif merge_method == "multiply": + waveform = waveform_1 * waveform_2 + elif merge_method == "mean": + waveform = (waveform_1 + waveform_2) / 2 + + max_val = waveform.abs().max() + if max_val > 1.0: + waveform = waveform / max_val + + return ({"waveform": waveform, "sample_rate": output_sample_rate},) + + +class AudioAdjustVolume: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "audio": ("AUDIO",), + "volume": ("INT", {"default": 1.0, "min": -100, "max": 100, "tooltip": "Volume adjustment in decibels (dB). 0 = no change, +6 = double, -6 = half, etc"}), + }} + + RETURN_TYPES = ("AUDIO",) + FUNCTION = "adjust_volume" + CATEGORY = "audio" + + def adjust_volume(self, audio, volume): + if volume == 0: + return (audio,) + waveform = audio["waveform"] + sample_rate = audio["sample_rate"] + + gain = 10 ** (volume / 20) + waveform = waveform * gain + + return ({"waveform": waveform, "sample_rate": sample_rate},) + + +class EmptyAudio: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "duration": ("FLOAT", {"default": 60.0, "min": 0.0, "max": 0xffffffffffffffff, "step": 0.01, "tooltip": "Duration of the empty audio clip in seconds"}), + "sample_rate": ("INT", {"default": 44100, "tooltip": "Sample rate of the empty audio clip."}), + "channels": ("INT", {"default": 2, "min": 1, "max": 2, "tooltip": "Number of audio channels (1 for mono, 2 for stereo)."}), + }} + + RETURN_TYPES = ("AUDIO",) + FUNCTION = "create_empty_audio" + CATEGORY = "audio" + + def create_empty_audio(self, duration, sample_rate, channels): + num_samples = int(round(duration * sample_rate)) + waveform = torch.zeros((1, channels, num_samples), dtype=torch.float32) + return ({"waveform": waveform, "sample_rate": sample_rate},) + + NODE_CLASS_MAPPINGS = { "EmptyLatentAudio": EmptyLatentAudio, "VAEEncodeAudio": VAEEncodeAudio, @@ -375,6 +586,12 @@ NODE_CLASS_MAPPINGS = { "PreviewAudio": PreviewAudio, "ConditioningStableAudio": ConditioningStableAudio, "RecordAudio": RecordAudio, + "TrimAudioDuration": TrimAudioDuration, + "SplitAudioChannels": SplitAudioChannels, + "AudioConcat": AudioConcat, + "AudioMerge": AudioMerge, + "AudioAdjustVolume": AudioAdjustVolume, + "EmptyAudio": EmptyAudio, } NODE_DISPLAY_NAME_MAPPINGS = { @@ -387,4 +604,10 @@ NODE_DISPLAY_NAME_MAPPINGS = { "SaveAudioMP3": "Save Audio (MP3)", "SaveAudioOpus": "Save Audio (Opus)", "RecordAudio": "Record Audio", + "TrimAudioDuration": "Trim Audio Duration", + "SplitAudioChannels": "Split Audio Channels", + "AudioConcat": "Audio Concat", + "AudioMerge": "Audio Merge", + "AudioAdjustVolume": "Audio Adjust Volume", + "EmptyAudio": "Empty Audio", } diff --git a/comfy_extras/nodes_differential_diffusion.py b/comfy_extras/nodes_differential_diffusion.py index 98dbbf102..255ac420d 100644 --- a/comfy_extras/nodes_differential_diffusion.py +++ b/comfy_extras/nodes_differential_diffusion.py @@ -5,19 +5,30 @@ import torch class DifferentialDiffusion(): @classmethod def INPUT_TYPES(s): - return {"required": {"model": ("MODEL", ), - }} + return { + "required": { + "model": ("MODEL", ), + }, + "optional": { + "strength": ("FLOAT", { + "default": 1.0, + "min": 0.0, + "max": 1.0, + "step": 0.01, + }), + } + } RETURN_TYPES = ("MODEL",) FUNCTION = "apply" CATEGORY = "_for_testing" INIT = False - def apply(self, model): + def apply(self, model, strength=1.0): model = model.clone() - model.set_model_denoise_mask_function(self.forward) - return (model,) + model.set_model_denoise_mask_function(lambda *args, **kwargs: self.forward(*args, **kwargs, strength=strength)) + return (model, ) - def forward(self, sigma: torch.Tensor, denoise_mask: torch.Tensor, extra_options: dict): + def forward(self, sigma: torch.Tensor, denoise_mask: torch.Tensor, extra_options: dict, strength: float): model = extra_options["model"] step_sigmas = extra_options["sigmas"] sigma_to = model.inner_model.model_sampling.sigma_min @@ -31,7 +42,15 @@ class DifferentialDiffusion(): threshold = (current_ts - ts_to) / (ts_from - ts_to) - return (denoise_mask >= threshold).to(denoise_mask.dtype) + # Generate the binary mask based on the threshold + binary_mask = (denoise_mask >= threshold).to(denoise_mask.dtype) + + # Blend binary mask with the original denoise_mask using strength + if strength and strength < 1: + blended_mask = strength * binary_mask + (1 - strength) * denoise_mask + return blended_mask + else: + return binary_mask NODE_CLASS_MAPPINGS = { diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py index cb1a0d883..ed7a07152 100644 --- a/comfy_extras/nodes_post_processing.py +++ b/comfy_extras/nodes_post_processing.py @@ -233,6 +233,7 @@ class Sharpen: kernel_size = sharpen_radius * 2 + 1 kernel = gaussian_kernel(kernel_size, sigma, device=image.device) * -(alpha*10) + kernel = kernel.to(dtype=image.dtype) center = kernel_size // 2 kernel[center, center] = kernel[center, center] - kernel.sum() + 1.0 kernel = kernel.repeat(channels, 1, 1).unsqueeze(1) diff --git a/comfy_extras/nodes_qwen.py b/comfy_extras/nodes_qwen.py index fff89556f..49747dc7a 100644 --- a/comfy_extras/nodes_qwen.py +++ b/comfy_extras/nodes_qwen.py @@ -43,6 +43,61 @@ class TextEncodeQwenImageEdit: return (conditioning, ) +class TextEncodeQwenImageEditPlus: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "clip": ("CLIP", ), + "prompt": ("STRING", {"multiline": True, "dynamicPrompts": True}), + }, + "optional": {"vae": ("VAE", ), + "image1": ("IMAGE", ), + "image2": ("IMAGE", ), + "image3": ("IMAGE", ), + }} + + RETURN_TYPES = ("CONDITIONING",) + FUNCTION = "encode" + + CATEGORY = "advanced/conditioning" + + def encode(self, clip, prompt, vae=None, image1=None, image2=None, image3=None): + ref_latents = [] + images = [image1, image2, image3] + images_vl = [] + llama_template = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" + image_prompt = "" + + for i, image in enumerate(images): + if image is not None: + samples = image.movedim(-1, 1) + total = int(384 * 384) + + scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2])) + width = round(samples.shape[3] * scale_by) + height = round(samples.shape[2] * scale_by) + + s = comfy.utils.common_upscale(samples, width, height, "area", "disabled") + images_vl.append(s.movedim(1, -1)) + if vae is not None: + total = int(1024 * 1024) + scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2])) + width = round(samples.shape[3] * scale_by / 8.0) * 8 + height = round(samples.shape[2] * scale_by / 8.0) * 8 + + s = comfy.utils.common_upscale(samples, width, height, "area", "disabled") + ref_latents.append(vae.encode(s.movedim(1, -1)[:, :, :, :3])) + + image_prompt += "Picture {}: <|vision_start|><|image_pad|><|vision_end|>".format(i + 1) + + tokens = clip.tokenize(image_prompt + prompt, images=images_vl, llama_template=llama_template) + conditioning = clip.encode_from_tokens_scheduled(tokens) + if len(ref_latents) > 0: + conditioning = node_helpers.conditioning_set_values(conditioning, {"reference_latents": ref_latents}, append=True) + return (conditioning, ) + + NODE_CLASS_MAPPINGS = { "TextEncodeQwenImageEdit": TextEncodeQwenImageEdit, + "TextEncodeQwenImageEditPlus": TextEncodeQwenImageEditPlus, } diff --git a/comfy_extras/nodes_train.py b/comfy_extras/nodes_train.py index c3aaaee9b..9e6ec6780 100644 --- a/comfy_extras/nodes_train.py +++ b/comfy_extras/nodes_train.py @@ -38,6 +38,23 @@ def make_batch_extra_option_dict(d, indicies, full_size=None): return new_dict +def process_cond_list(d, prefix=""): + if hasattr(d, "__iter__") and not hasattr(d, "items"): + for index, item in enumerate(d): + process_cond_list(item, f"{prefix}.{index}") + return d + elif hasattr(d, "items"): + for k, v in list(d.items()): + if isinstance(v, dict): + process_cond_list(v, f"{prefix}.{k}") + elif isinstance(v, torch.Tensor): + d[k] = v.clone() + elif isinstance(v, (list, tuple)): + for index, item in enumerate(v): + process_cond_list(item, f"{prefix}.{k}.{index}") + return d + + class TrainSampler(comfy.samplers.Sampler): def __init__(self, loss_fn, optimizer, loss_callback=None, batch_size=1, grad_acc=1, total_steps=1, seed=0, training_dtype=torch.bfloat16): self.loss_fn = loss_fn @@ -50,6 +67,7 @@ class TrainSampler(comfy.samplers.Sampler): self.training_dtype = training_dtype def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False): + model_wrap.conds = process_cond_list(model_wrap.conds) cond = model_wrap.conds["positive"] dataset_size = sigmas.size(0) torch.cuda.empty_cache() diff --git a/comfy_extras/nodes_wan.py b/comfy_extras/nodes_wan.py index 0b8b55813..b0bd471bf 100644 --- a/comfy_extras/nodes_wan.py +++ b/comfy_extras/nodes_wan.py @@ -287,7 +287,6 @@ class WanVaceToVideo(io.ComfyNode): return io.Schema( node_id="WanVaceToVideo", category="conditioning/video_models", - is_experimental=True, inputs=[ io.Conditioning.Input("positive"), io.Conditioning.Input("negative"), @@ -375,7 +374,6 @@ class TrimVideoLatent(io.ComfyNode): return io.Schema( node_id="TrimVideoLatent", category="latent/video", - is_experimental=True, inputs=[ io.Latent.Input("samples"), io.Int.Input("trim_amount", default=0, min=0, max=99999), @@ -969,7 +967,6 @@ class WanSoundImageToVideo(io.ComfyNode): io.Conditioning.Output(display_name="negative"), io.Latent.Output(display_name="latent"), ], - is_experimental=True, ) @classmethod @@ -1000,7 +997,6 @@ class WanSoundImageToVideoExtend(io.ComfyNode): io.Conditioning.Output(display_name="negative"), io.Latent.Output(display_name="latent"), ], - is_experimental=True, ) @classmethod @@ -1095,10 +1091,6 @@ class WanHuMoImageToVideo(io.ComfyNode): audio_emb = torch.stack([feat0, feat1, feat2, feat3, feat4], dim=2)[0] # [T, 5, 1280] audio_emb, _ = get_audio_emb_window(audio_emb, length, frame0_idx=0) - # pad for ref latent - zero_audio_pad = torch.zeros(ref_latent.shape[2], *audio_emb.shape[1:], device=audio_emb.device, dtype=audio_emb.dtype) - audio_emb = torch.cat([audio_emb, zero_audio_pad], dim=0) - audio_emb = audio_emb.unsqueeze(0) audio_emb_neg = torch.zeros_like(audio_emb) positive = node_helpers.conditioning_set_values(positive, {"audio_embed": audio_emb}) @@ -1112,6 +1104,146 @@ class WanHuMoImageToVideo(io.ComfyNode): out_latent["samples"] = latent return io.NodeOutput(positive, negative, out_latent) +class WanAnimateToVideo(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="WanAnimateToVideo", + category="conditioning/video_models", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Vae.Input("vae"), + io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("length", default=77, min=1, max=nodes.MAX_RESOLUTION, step=4), + io.Int.Input("batch_size", default=1, min=1, max=4096), + io.ClipVisionOutput.Input("clip_vision_output", optional=True), + io.Image.Input("reference_image", optional=True), + io.Image.Input("face_video", optional=True), + io.Image.Input("pose_video", optional=True), + io.Int.Input("continue_motion_max_frames", default=5, min=1, max=nodes.MAX_RESOLUTION, step=4), + io.Image.Input("background_video", optional=True), + io.Mask.Input("character_mask", optional=True), + io.Image.Input("continue_motion", optional=True), + io.Int.Input("video_frame_offset", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1, tooltip="The amount of frames to seek in all the input videos. Used for generating longer videos by chunk. Connect to the video_frame_offset output of the previous node for extending a video."), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent"), + io.Int.Output(display_name="trim_latent"), + io.Int.Output(display_name="trim_image"), + io.Int.Output(display_name="video_frame_offset"), + ], + is_experimental=True, + ) + + @classmethod + def execute(cls, positive, negative, vae, width, height, length, batch_size, continue_motion_max_frames, video_frame_offset, reference_image=None, clip_vision_output=None, face_video=None, pose_video=None, continue_motion=None, background_video=None, character_mask=None) -> io.NodeOutput: + trim_to_pose_video = False + latent_length = ((length - 1) // 4) + 1 + latent_width = width // 8 + latent_height = height // 8 + trim_latent = 0 + + if reference_image is None: + reference_image = torch.zeros((1, height, width, 3)) + + image = comfy.utils.common_upscale(reference_image[:length].movedim(-1, 1), width, height, "area", "center").movedim(1, -1) + concat_latent_image = vae.encode(image[:, :, :, :3]) + mask = torch.zeros((1, 4, concat_latent_image.shape[-3], concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=concat_latent_image.device, dtype=concat_latent_image.dtype) + trim_latent += concat_latent_image.shape[2] + ref_motion_latent_length = 0 + + if continue_motion is None: + image = torch.ones((length, height, width, 3)) * 0.5 + else: + continue_motion = continue_motion[-continue_motion_max_frames:] + video_frame_offset -= continue_motion.shape[0] + video_frame_offset = max(0, video_frame_offset) + continue_motion = comfy.utils.common_upscale(continue_motion[-length:].movedim(-1, 1), width, height, "area", "center").movedim(1, -1) + image = torch.ones((length, height, width, continue_motion.shape[-1]), device=continue_motion.device, dtype=continue_motion.dtype) * 0.5 + image[:continue_motion.shape[0]] = continue_motion + ref_motion_latent_length += ((continue_motion.shape[0] - 1) // 4) + 1 + + if clip_vision_output is not None: + positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output}) + negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output}) + + if pose_video is not None: + if pose_video.shape[0] <= video_frame_offset: + pose_video = None + else: + pose_video = pose_video[video_frame_offset:] + + if pose_video is not None: + pose_video = comfy.utils.common_upscale(pose_video[:length].movedim(-1, 1), width, height, "area", "center").movedim(1, -1) + if not trim_to_pose_video: + if pose_video.shape[0] < length: + pose_video = torch.cat((pose_video,) + (pose_video[-1:],) * (length - pose_video.shape[0]), dim=0) + + pose_video_latent = vae.encode(pose_video[:, :, :, :3]) + positive = node_helpers.conditioning_set_values(positive, {"pose_video_latent": pose_video_latent}) + negative = node_helpers.conditioning_set_values(negative, {"pose_video_latent": pose_video_latent}) + + if trim_to_pose_video: + latent_length = pose_video_latent.shape[2] + length = latent_length * 4 - 3 + image = image[:length] + + if face_video is not None: + if face_video.shape[0] <= video_frame_offset: + face_video = None + else: + face_video = face_video[video_frame_offset:] + + if face_video is not None: + face_video = comfy.utils.common_upscale(face_video[:length].movedim(-1, 1), 512, 512, "area", "center") * 2.0 - 1.0 + face_video = face_video.movedim(0, 1).unsqueeze(0) + positive = node_helpers.conditioning_set_values(positive, {"face_video_pixels": face_video}) + negative = node_helpers.conditioning_set_values(negative, {"face_video_pixels": face_video * 0.0 - 1.0}) + + ref_images_num = max(0, ref_motion_latent_length * 4 - 3) + if background_video is not None: + if background_video.shape[0] > video_frame_offset: + background_video = background_video[video_frame_offset:] + background_video = comfy.utils.common_upscale(background_video[:length].movedim(-1, 1), width, height, "area", "center").movedim(1, -1) + if background_video.shape[0] > ref_images_num: + image[ref_images_num:background_video.shape[0]] = background_video[ref_images_num:] + + mask_refmotion = torch.ones((1, 1, latent_length * 4, concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=mask.device, dtype=mask.dtype) + if continue_motion is not None: + mask_refmotion[:, :, :ref_motion_latent_length * 4] = 0.0 + + if character_mask is not None: + if character_mask.shape[0] > video_frame_offset or character_mask.shape[0] == 1: + if character_mask.shape[0] == 1: + character_mask = character_mask.repeat((length,) + (1,) * (character_mask.ndim - 1)) + else: + character_mask = character_mask[video_frame_offset:] + if character_mask.ndim == 3: + character_mask = character_mask.unsqueeze(1) + character_mask = character_mask.movedim(0, 1) + if character_mask.ndim == 4: + character_mask = character_mask.unsqueeze(1) + character_mask = comfy.utils.common_upscale(character_mask[:, :, :length], concat_latent_image.shape[-1], concat_latent_image.shape[-2], "nearest-exact", "center") + if character_mask.shape[2] > ref_images_num: + mask_refmotion[:, :, ref_images_num:character_mask.shape[2]] = character_mask[:, :, ref_images_num:] + + concat_latent_image = torch.cat((concat_latent_image, vae.encode(image[:, :, :, :3])), dim=2) + + + mask_refmotion = mask_refmotion.view(1, mask_refmotion.shape[2] // 4, 4, mask_refmotion.shape[3], mask_refmotion.shape[4]).transpose(1, 2) + mask = torch.cat((mask, mask_refmotion), dim=2) + positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": mask}) + negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": mask}) + + latent = torch.zeros([batch_size, 16, latent_length + trim_latent, latent_height, latent_width], device=comfy.model_management.intermediate_device()) + out_latent = {} + out_latent["samples"] = latent + return io.NodeOutput(positive, negative, out_latent, trim_latent, max(0, ref_motion_latent_length * 4 - 3), video_frame_offset + length) + class Wan22ImageToVideoLatent(io.ComfyNode): @classmethod def define_schema(cls): @@ -1173,6 +1305,7 @@ class WanExtension(ComfyExtension): WanSoundImageToVideo, WanSoundImageToVideoExtend, WanHuMoImageToVideo, + WanAnimateToVideo, Wan22ImageToVideoLatent, ] diff --git a/comfyui_version.py b/comfyui_version.py index ee58205f5..d469a8194 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.3.59" +__version__ = "0.3.60" diff --git a/nodes.py b/nodes.py index 5a5fdcb8e..1a6784b68 100644 --- a/nodes.py +++ b/nodes.py @@ -2361,6 +2361,7 @@ async def init_builtin_api_nodes(): "nodes_rodin.py", "nodes_gemini.py", "nodes_vidu.py", + "nodes_wan.py", ] if not await load_custom_node(os.path.join(api_nodes_dir, "canary.py"), module_parent="comfy_api_nodes"): diff --git a/pyproject.toml b/pyproject.toml index a7fc1a5a6..7340c320b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.3.59" +version = "0.3.60" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.9" diff --git a/requirements.txt b/requirements.txt index 0d204858b..f9497c93a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ -comfyui-frontend-package==1.26.11 -comfyui-workflow-templates==0.1.81 +comfyui-frontend-package==1.26.13 +comfyui-workflow-templates==0.1.86 comfyui-embedded-docs==0.2.6 torch torchsde diff --git a/server.py b/server.py index 3bb6f0bed..d8c1c02c3 100644 --- a/server.py +++ b/server.py @@ -648,7 +648,14 @@ class PromptServer(): max_items = request.rel_url.query.get("max_items", None) if max_items is not None: max_items = int(max_items) - return web.json_response(self.prompt_queue.get_history(max_items=max_items)) + + offset = request.rel_url.query.get("offset", None) + if offset is not None: + offset = int(offset) + else: + offset = -1 + + return web.json_response(self.prompt_queue.get_history(max_items=max_items, offset=offset)) @routes.get("/history/{prompt_id}") async def get_history_prompt_id(request): diff --git a/tests/execution/test_execution.py b/tests/execution/test_execution.py index 8ea05fdd8..ef73ad9fd 100644 --- a/tests/execution/test_execution.py +++ b/tests/execution/test_execution.py @@ -84,6 +84,21 @@ class ComfyClient: with urllib.request.urlopen("http://{}/history/{}".format(self.server_address, prompt_id)) as response: return json.loads(response.read()) + def get_all_history(self, max_items=None, offset=None): + url = "http://{}/history".format(self.server_address) + params = {} + if max_items is not None: + params["max_items"] = max_items + if offset is not None: + params["offset"] = offset + + if params: + url_values = urllib.parse.urlencode(params) + url = "{}?{}".format(url, url_values) + + with urllib.request.urlopen(url) as response: + return json.loads(response.read()) + def set_test_name(self, name): self.test_name = name @@ -498,7 +513,6 @@ class TestExecution: assert len(images1) == 1, "Should have 1 image" assert len(images2) == 1, "Should have 1 image" - # This tests that only constant outputs are used in the call to `IS_CHANGED` def test_is_changed_with_outputs(self, client: ComfyClient, builder: GraphBuilder): g = builder @@ -762,3 +776,92 @@ class TestExecution: except urllib.error.HTTPError: pass # Expected behavior + def _create_history_item(self, client, builder): + g = GraphBuilder(prefix="offset_test") + input_node = g.node( + "StubImage", content="BLACK", height=32, width=32, batch_size=1 + ) + g.node("SaveImage", images=input_node.out(0)) + return client.run(g) + + def test_offset_returns_different_items_than_beginning_of_history( + self, client: ComfyClient, builder: GraphBuilder + ): + """Test that offset skips items at the beginning""" + for _ in range(5): + self._create_history_item(client, builder) + + first_two = client.get_all_history(max_items=2, offset=0) + next_two = client.get_all_history(max_items=2, offset=2) + + assert set(first_two.keys()).isdisjoint( + set(next_two.keys()) + ), "Offset should skip initial items" + + def test_offset_beyond_history_length_returns_empty( + self, client: ComfyClient, builder: GraphBuilder + ): + """Test offset larger than total history returns empty result""" + self._create_history_item(client, builder) + + result = client.get_all_history(offset=100) + assert len(result) == 0, "Large offset should return no items" + + def test_offset_at_exact_history_length_returns_empty( + self, client: ComfyClient, builder: GraphBuilder + ): + """Test offset equal to history length returns empty""" + for _ in range(3): + self._create_history_item(client, builder) + + all_history = client.get_all_history() + result = client.get_all_history(offset=len(all_history)) + assert len(result) == 0, "Offset at history length should return empty" + + def test_offset_zero_equals_no_offset_parameter( + self, client: ComfyClient, builder: GraphBuilder + ): + """Test offset=0 behaves same as omitting offset""" + self._create_history_item(client, builder) + + with_zero = client.get_all_history(offset=0) + without_offset = client.get_all_history() + + assert with_zero == without_offset, "offset=0 should equal no offset" + + def test_offset_without_max_items_skips_from_beginning( + self, client: ComfyClient, builder: GraphBuilder + ): + """Test offset alone (no max_items) returns remaining items""" + for _ in range(4): + self._create_history_item(client, builder) + + all_items = client.get_all_history() + offset_items = client.get_all_history(offset=2) + + assert ( + len(offset_items) == len(all_items) - 2 + ), "Offset should skip specified number of items" + + def test_offset_with_max_items_returns_correct_window( + self, client: ComfyClient, builder: GraphBuilder + ): + """Test offset + max_items returns correct slice of history""" + for _ in range(6): + self._create_history_item(client, builder) + + window = client.get_all_history(max_items=2, offset=1) + assert len(window) <= 2, "Should respect max_items limit" + + def test_offset_near_end_returns_remaining_items_only( + self, client: ComfyClient, builder: GraphBuilder + ): + """Test offset near end of history returns only remaining items""" + for _ in range(3): + self._create_history_item(client, builder) + + all_history = client.get_all_history() + # Offset to near the end + result = client.get_all_history(max_items=5, offset=len(all_history) - 1) + + assert len(result) <= 1, "Should return at most 1 item when offset is near end"