From 78b5dec6b6beefb9fb40f917d33d2f10a40d9e53 Mon Sep 17 00:00:00 2001 From: Cezarijus Kivylius Date: Wed, 20 May 2026 12:58:49 +0100 Subject: [PATCH 1/7] fix: Hunyuan3D 2.1 batch size crashes in attention and forward pass (#13699) --- comfy/ldm/hunyuan3dv2_1/hunyuandit.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/comfy/ldm/hunyuan3dv2_1/hunyuandit.py b/comfy/ldm/hunyuan3dv2_1/hunyuandit.py index f67ba84e9..bc36b8998 100644 --- a/comfy/ldm/hunyuan3dv2_1/hunyuandit.py +++ b/comfy/ldm/hunyuan3dv2_1/hunyuandit.py @@ -328,7 +328,7 @@ class CrossAttention(nn.Module): kv = torch.cat((k, v), dim=-1) split_size = kv.shape[-1] // self.num_heads // 2 - kv = kv.view(1, -1, self.num_heads, split_size * 2) + kv = kv.view(b, -1, self.num_heads, split_size * 2) k, v = torch.split(kv, split_size, dim=-1) q = q.view(b, s1, self.num_heads, self.head_dim) @@ -398,7 +398,7 @@ class Attention(nn.Module): qkv_combined = torch.cat((query, key, value), dim=-1) split_size = qkv_combined.shape[-1] // self.num_heads // 3 - qkv = qkv_combined.view(1, -1, self.num_heads, split_size * 3) + qkv = qkv_combined.view(B, -1, self.num_heads, split_size * 3) query, key, value = torch.split(qkv, split_size, dim=-1) query = query.reshape(B, N, self.num_heads, self.head_dim) @@ -607,9 +607,9 @@ class HunYuanDiTPlain(nn.Module): def forward(self, x, t, context, transformer_options = {}, **kwargs): x = x.movedim(-1, -2) - uncond_emb, cond_emb = context.chunk(2, dim = 0) - - context = torch.cat([cond_emb, uncond_emb], dim = 0) + if context.shape[0] >= 2: + uncond_emb, cond_emb = context.chunk(2, dim = 0) + context = torch.cat([cond_emb, uncond_emb], dim = 0) main_condition = context t = 1.0 - t @@ -657,5 +657,8 @@ class HunYuanDiTPlain(nn.Module): output = self.final_layer(combined) output = output.movedim(-2, -1) * (-1.0) - cond_emb, uncond_emb = output.chunk(2, dim = 0) - return torch.cat([uncond_emb, cond_emb]) + if output.shape[0] >= 2: + cond_emb, uncond_emb = output.chunk(2, dim = 0) + return torch.cat([uncond_emb, cond_emb]) + else: + return output From f9c84c94b43855ec46e26afc944c06fb121cb9a8 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 20 May 2026 08:34:22 -0700 Subject: [PATCH 2/7] Support Stable Audio 3 model. (#14010) --- comfy/latent_formats.py | 5 + comfy/ldm/audio/dit.py | 250 +++++++++++++--- comfy/ldm/audio/embedders.py | 31 +- comfy/ldm/audio/vae_sa3.py | 533 +++++++++++++++++++++++++++++++++++ comfy/model_base.py | 79 ++++++ comfy/model_detection.py | 39 +++ comfy/sd.py | 37 +++ comfy/supported_models.py | 25 ++ comfy/text_encoders/sa3.py | 207 ++++++++++++++ 9 files changed, 1161 insertions(+), 45 deletions(-) create mode 100644 comfy/ldm/audio/vae_sa3.py create mode 100644 comfy/text_encoders/sa3.py diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index 6e37080bb..75d459b59 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -152,6 +152,11 @@ class StableAudio1(LatentFormat): latent_dimensions = 1 temporal_downscale_ratio = 2048 +class StableAudio3(LatentFormat): + latent_channels = 256 + latent_dimensions = 1 + temporal_downscale_ratio = 4096 + class Flux(SD3): latent_channels = 16 def __init__(self): diff --git a/comfy/ldm/audio/dit.py b/comfy/ldm/audio/dit.py index ca865189e..a6258b755 100644 --- a/comfy/ldm/audio/dit.py +++ b/comfy/ldm/audio/dit.py @@ -10,6 +10,17 @@ from torch import nn from torch.nn import functional as F import math import comfy.ops +from .embedders import ExpoFourierFeatures + + +def _left_pad_to_match(emb, target_len): + emb_len = emb.shape[-2] + if emb_len < target_len: + return F.pad(emb, (0, 0, target_len - emb_len, 0), value=0.) + elif emb_len > target_len: + return emb[:, -target_len:, :] + return emb + class FourierFeatures(nn.Module): def __init__(self, in_features, out_features, std=1., dtype=None, device=None): @@ -22,6 +33,7 @@ class FourierFeatures(nn.Module): f = 2 * math.pi * input @ comfy.ops.cast_to_input(self.weight.T, input) return torch.cat([f.cos(), f.sin()], dim=-1) + # norms class LayerNorm(nn.Module): def __init__(self, dim, bias=False, fix_scale=False, dtype=None, device=None): @@ -43,6 +55,16 @@ class LayerNorm(nn.Module): beta = comfy.ops.cast_to_input(beta, x) return F.layer_norm(x, x.shape[-1:], weight=comfy.ops.cast_to_input(self.gamma, x), bias=beta) + +class RMSNorm(nn.Module): + def __init__(self, dim, dtype=None, device=None): + super().__init__() + self.gamma = nn.Parameter(torch.empty(dim, dtype=dtype, device=device)) + + def forward(self, x): + return F.rms_norm(x, x.shape[-1:], weight=comfy.ops.cast_to_input(self.gamma, x)) + + class GLU(nn.Module): def __init__( self, @@ -236,13 +258,6 @@ class FeedForward(nn.Module): linear_out = operations.Linear(inner_dim, dim_out, bias = not no_bias, dtype=dtype, device=device) if not use_conv else operations.Conv1d(inner_dim, dim_out, conv_kernel_size, padding = (conv_kernel_size // 2), bias = not no_bias, dtype=dtype, device=device) - # # init last linear layer to 0 - # if zero_init_output: - # nn.init.zeros_(linear_out.weight) - # if not no_bias: - # nn.init.zeros_(linear_out.bias) - - self.ff = nn.Sequential( linear_in, rearrange('b d n -> b n d') if use_conv else nn.Identity(), @@ -261,8 +276,10 @@ class Attention(nn.Module): dim_context = None, causal = False, zero_init_output=True, - qk_norm = False, + qk_norm = "none", + differential = False, natten_kernel_size = None, + feat_scale = False, dtype=None, device=None, operations=None, @@ -271,6 +288,7 @@ class Attention(nn.Module): self.dim = dim self.dim_heads = dim_heads self.causal = causal + self.differential = differential dim_kv = dim_context if dim_context is not None else dim @@ -278,18 +296,37 @@ class Attention(nn.Module): self.kv_heads = dim_kv // dim_heads if dim_context is not None: - self.to_q = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device) - self.to_kv = operations.Linear(dim_kv, dim_kv * 2, bias=False, dtype=dtype, device=device) + if differential: + self.to_q = operations.Linear(dim, dim * 2, bias=False, dtype=dtype, device=device) + self.to_kv = operations.Linear(dim_kv, dim_kv * 3, bias=False, dtype=dtype, device=device) + else: + self.to_q = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device) + self.to_kv = operations.Linear(dim_kv, dim_kv * 2, bias=False, dtype=dtype, device=device) else: - self.to_qkv = operations.Linear(dim, dim * 3, bias=False, dtype=dtype, device=device) + if differential: + self.to_qkv = operations.Linear(dim, dim * 5, bias=False, dtype=dtype, device=device) + else: + self.to_qkv = operations.Linear(dim, dim * 3, bias=False, dtype=dtype, device=device) self.to_out = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device) - # if zero_init_output: - # nn.init.zeros_(self.to_out.weight) - + # Accept bool for backward compat + if isinstance(qk_norm, bool): + qk_norm = "l2" if qk_norm else "none" self.qk_norm = qk_norm + if self.qk_norm == "ln": + self.q_norm = operations.LayerNorm(dim_heads, elementwise_affine=True, eps=1.0e-6, dtype=dtype, device=device) + self.k_norm = operations.LayerNorm(dim_heads, elementwise_affine=True, eps=1.0e-6, dtype=dtype, device=device) + elif self.qk_norm == "rms": + self.q_norm = RMSNorm(dim_heads, dtype=dtype, device=device) + self.k_norm = RMSNorm(dim_heads, dtype=dtype, device=device) + + self.feat_scale = feat_scale + + if self.feat_scale: + self.lambda_dc = nn.Parameter(torch.empty(dim, dtype=dtype, device=device)) + self.lambda_hf = nn.Parameter(torch.empty(dim, dtype=dtype, device=device)) def forward( self, @@ -306,22 +343,51 @@ class Attention(nn.Module): kv_input = context if has_context else x if hasattr(self, 'to_q'): - # Use separate linear projections for q and k/v - q = self.to_q(x) - q = rearrange(q, 'b n (h d) -> b h n d', h = h) + if self.differential: + # cross-attention differential: to_q → (q, q_diff), to_kv → (k, k_diff, v) + q, q_diff = self.to_q(x).chunk(2, dim=-1) + q = rearrange(q, 'b n (h d) -> b h n d', h=h) + q_diff = rearrange(q_diff, 'b n (h d) -> b h n d', h=h) + q = torch.stack([q, q_diff], dim=1) # (B, 2, H, N, D) + k, k_diff, v = self.to_kv(kv_input).chunk(3, dim=-1) + k = rearrange(k, 'b n (h d) -> b h n d', h=kv_h) + k_diff = rearrange(k_diff, 'b n (h d) -> b h n d', h=kv_h) + v = rearrange(v, 'b n (h d) -> b h n d', h=kv_h) + k = torch.stack([k, k_diff], dim=1) # (B, 2, H, M, D) + else: + # Use separate linear projections for q and k/v + q = self.to_q(x) + q = rearrange(q, 'b n (h d) -> b h n d', h = h) - k, v = self.to_kv(kv_input).chunk(2, dim=-1) + k, v = self.to_kv(kv_input).chunk(2, dim=-1) - k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = kv_h), (k, v)) + k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = kv_h), (k, v)) else: - # Use fused linear projection - q, k, v = self.to_qkv(x).chunk(3, dim=-1) - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v)) + if self.differential: + # self-attention differential: to_qkv → (q, k, v, q_diff, k_diff) + q, k, v, q_diff, k_diff = self.to_qkv(x).chunk(5, dim=-1) + q, k, v, q_diff, k_diff = map( + lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), + (q, k, v, q_diff, k_diff) + ) + q = torch.stack([q, q_diff], dim=1) # (B, 2, H, N, D) + k = torch.stack([k, k_diff], dim=1) + else: + # Use fused linear projection + q, k, v = self.to_qkv(x).chunk(3, dim=-1) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v)) # Normalize q and k for cosine sim attention - if self.qk_norm: + if self.qk_norm == "l2": q = F.normalize(q, dim=-1) k = F.normalize(k, dim=-1) + elif self.qk_norm == "rms": + q_type, k_type = q.dtype, k.dtype + q = self.q_norm(q).to(q_type) + k = self.k_norm(k).to(k_type) + elif self.qk_norm != 'none': + q = self.q_norm(q) + k = self.k_norm(k) if rotary_pos_emb is not None and not has_context: freqs, _ = rotary_pos_emb @@ -364,9 +430,24 @@ class Attention(nn.Module): heads_per_kv_head = h // kv_h k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim = 1), (k, v)) - out = optimized_attention(q, k, v, h, skip_reshape=True, transformer_options=transformer_options) + if self.differential: + q, q_diff = q.unbind(dim=1) + k, k_diff = k.unbind(dim=1) + out = optimized_attention(q, k, v, h, skip_reshape=True, transformer_options=transformer_options) + out_diff = optimized_attention(q_diff, k_diff, v, h, skip_reshape=True, transformer_options=transformer_options) + out = out - out_diff + else: + out = optimized_attention(q, k, v, h, skip_reshape=True, transformer_options=transformer_options) + out = self.to_out(out) + if self.feat_scale: + out_dc = out.mean(dim=-2, keepdim=True) + out_hf = out - out_dc + + # Selectively modulate DC and high frequency components + out = out + comfy.ops.cast_to_input(self.lambda_dc, out) * out_dc + comfy.ops.cast_to_input(self.lambda_hf, out) * out_hf + if mask is not None: mask = rearrange(mask, 'b n -> b n 1') out = out.masked_fill(~mask, 0.) @@ -417,11 +498,14 @@ class TransformerBlock(nn.Module): cross_attend = False, dim_context = None, global_cond_dim = None, + global_cond_shared_embed = False, + local_add_cond_dim = None, causal = False, zero_init_branch_outputs = True, conformer = False, layer_ix = -1, remove_norms = False, + norm_type = "layer_norm", attn_kwargs = {}, ff_kwargs = {}, norm_kwargs = {}, @@ -436,8 +520,20 @@ class TransformerBlock(nn.Module): self.cross_attend = cross_attend self.dim_context = dim_context self.causal = causal + self.global_cond_shared_embed = global_cond_shared_embed - self.pre_norm = LayerNorm(dim, dtype=dtype, device=device, **norm_kwargs) if not remove_norms else nn.Identity() + norm_layer_map = { + "layer_norm": LayerNorm, + "rms_norm": RMSNorm, + } + norm_cls = norm_layer_map.get(norm_type, LayerNorm) + + def make_norm(): + if remove_norms: + return nn.Identity() + return norm_cls(dim, dtype=dtype, device=device, **norm_kwargs) + + self.pre_norm = make_norm() self.self_attn = Attention( dim, @@ -451,7 +547,7 @@ class TransformerBlock(nn.Module): ) if cross_attend: - self.cross_attend_norm = LayerNorm(dim, dtype=dtype, device=device, **norm_kwargs) if not remove_norms else nn.Identity() + self.cross_attend_norm = make_norm() self.cross_attn = Attention( dim, dim_heads = dim_heads, @@ -464,37 +560,56 @@ class TransformerBlock(nn.Module): **attn_kwargs ) - self.ff_norm = LayerNorm(dim, dtype=dtype, device=device, **norm_kwargs) if not remove_norms else nn.Identity() - self.ff = FeedForward(dim, zero_init_output=zero_init_branch_outputs, dtype=dtype, device=device, operations=operations,**ff_kwargs) + self.ff_norm = make_norm() + self.ff = FeedForward(dim, zero_init_output=zero_init_branch_outputs, dtype=dtype, device=device, operations=operations, **ff_kwargs) self.layer_ix = layer_ix self.conformer = ConformerModule(dim, norm_kwargs=norm_kwargs) if conformer else None - self.global_cond_dim = global_cond_dim + # Global conditioning + self.has_global_cond = (global_cond_dim is not None) or global_cond_shared_embed - if global_cond_dim is not None: + if global_cond_shared_embed: + # SA3 style: learnable per-block additive bias; global_cond is pre-projected to (B, dim*6) + self.to_scale_shift_gate = nn.Parameter(torch.empty(dim * 6, device=device, dtype=dtype)) + elif global_cond_dim is not None: + # SA1 style: per-block MLP projects global_cond → (B, dim*6) self.to_scale_shift_gate = nn.Sequential( nn.SiLU(), - nn.Linear(global_cond_dim, dim * 6, bias=False) + operations.Linear(global_cond_dim, dim * 6, bias=False, device=device, dtype=dtype) ) - nn.init.zeros_(self.to_scale_shift_gate[1].weight) - #nn.init.zeros_(self.to_scale_shift_gate_self[1].bias) + # Local additive conditioning (e.g. inpaint mask + masked latent) + self.local_add_cond_dim = local_add_cond_dim + if local_add_cond_dim is not None: + self.to_local_embed = nn.Sequential( + operations.Linear(local_add_cond_dim, dim, bias=True, dtype=dtype, device=device), + nn.SiLU(), + operations.Linear(dim, dim, bias=True, dtype=dtype, device=device), + ) + else: + self.to_local_embed = None def forward( self, x, context = None, global_cond=None, + local_add_cond=None, mask = None, context_mask = None, rotary_pos_emb = None, transformer_options={} ): - if self.global_cond_dim is not None and self.global_cond_dim > 0 and global_cond is not None: + if self.has_global_cond and global_cond is not None: + if self.global_cond_shared_embed: + # global_cond already has shape (B, dim*6) + ssg = (comfy.ops.cast_to_input(self.to_scale_shift_gate, global_cond) + global_cond).unsqueeze(1) + else: + ssg = self.to_scale_shift_gate(global_cond).unsqueeze(1) - scale_self, shift_self, gate_self, scale_ff, shift_ff, gate_ff = self.to_scale_shift_gate(global_cond).unsqueeze(1).chunk(6, dim = -1) + scale_self, shift_self, gate_self, scale_ff, shift_ff, gate_ff = ssg.chunk(6, dim = -1) # self-attention with adaLN residual = x @@ -510,6 +625,9 @@ class TransformerBlock(nn.Module): if self.conformer is not None: x = x + self.conformer(x) + if local_add_cond is not None and self.to_local_embed is not None: + x = x + _left_pad_to_match(self.to_local_embed(local_add_cond), x.shape[-2]) + # feedforward with adaLN residual = x x = self.ff_norm(x) @@ -527,6 +645,9 @@ class TransformerBlock(nn.Module): if self.conformer is not None: x = x + self.conformer(x) + if local_add_cond is not None and self.to_local_embed is not None: + x = x + _left_pad_to_match(self.to_local_embed(local_add_cond), x.shape[-2]) + x = x + self.ff(self.ff_norm(x)) return x @@ -543,6 +664,8 @@ class ContinuousTransformer(nn.Module): cross_attend=False, cond_token_dim=None, global_cond_dim=None, + global_cond_shared_embed=False, + local_add_cond_dim=None, causal=False, rotary_pos_emb=True, zero_init_branch_outputs=True, @@ -550,6 +673,7 @@ class ContinuousTransformer(nn.Module): use_sinusoidal_emb=False, use_abs_pos_emb=False, abs_pos_emb_max_length=10000, + num_memory_tokens=0, dtype=None, device=None, operations=None, @@ -562,6 +686,8 @@ class ContinuousTransformer(nn.Module): self.depth = depth self.causal = causal self.layers = nn.ModuleList([]) + self.num_memory_tokens = num_memory_tokens + self.global_cond_shared_embed = global_cond_shared_embed self.project_in = operations.Linear(dim_in, dim, bias=False, dtype=dtype, device=device) if dim_in is not None else nn.Identity() self.project_out = operations.Linear(dim, dim_out, bias=False, dtype=dtype, device=device) if dim_out is not None else nn.Identity() @@ -577,7 +703,22 @@ class ContinuousTransformer(nn.Module): self.use_abs_pos_emb = use_abs_pos_emb if use_abs_pos_emb: - self.pos_emb = AbsolutePositionalEmbedding(dim, abs_pos_emb_max_length) + self.pos_emb = AbsolutePositionalEmbedding(dim, abs_pos_emb_max_length + num_memory_tokens) + + if num_memory_tokens > 0: + self.memory_tokens = nn.Parameter(torch.empty(num_memory_tokens, dim, device=device, dtype=dtype)) + + # Shared global-cond embedder (SA3 style): projects (B, global_cond_dim) → (B, dim*6) + self.global_cond_embedder = None + if global_cond_shared_embed and global_cond_dim is not None: + self.global_cond_embedder = nn.Sequential( + operations.Linear(global_cond_dim, dim, bias=True, dtype=dtype, device=device), + nn.SiLU(), + operations.Linear(dim, dim * 6, bias=True, dtype=dtype, device=device), + ) + + # When using shared embed, TransformerBlocks use per-block Parameter (not per-block MLP) + block_global_cond_dim = None if global_cond_shared_embed else global_cond_dim for i in range(depth): self.layers.append( @@ -586,7 +727,9 @@ class ContinuousTransformer(nn.Module): dim_heads = dim_heads, cross_attend = cross_attend, dim_context = cond_token_dim, - global_cond_dim = global_cond_dim, + global_cond_dim = block_global_cond_dim, + global_cond_shared_embed = global_cond_shared_embed, + local_add_cond_dim = local_add_cond_dim, causal = causal, zero_init_branch_outputs = zero_init_branch_outputs, conformer=conformer, @@ -605,6 +748,7 @@ class ContinuousTransformer(nn.Module): prepend_embeds = None, prepend_mask = None, global_cond = None, + local_add_cond = None, return_info = False, **kwargs ): @@ -632,7 +776,9 @@ class ContinuousTransformer(nn.Module): mask = torch.cat((prepend_mask, mask), dim = -1) - # Attention layers + if self.num_memory_tokens > 0: + memory_tokens = comfy.ops.cast_to_input(self.memory_tokens, x).expand(batch, -1, -1) + x = torch.cat((memory_tokens, x), dim=1) if self.rotary_pos_emb is not None: rotary_pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1], dtype=torch.float, device=x.device) @@ -642,6 +788,10 @@ class ContinuousTransformer(nn.Module): if self.use_sinusoidal_emb or self.use_abs_pos_emb: x = x + self.pos_emb(x) + # Project global_cond once (SA3 shared-embed path) + if global_cond is not None and self.global_cond_embedder is not None: + global_cond = self.global_cond_embedder(global_cond) + blocks_replace = patches_replace.get("dit", {}) # Iterate over the transformer layers for i, layer in enumerate(self.layers): @@ -654,12 +804,17 @@ class ContinuousTransformer(nn.Module): out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": global_cond, "pe": rotary_pos_emb, "transformer_options": transformer_options}, {"original_block": block_wrap}) x = out["img"] else: - x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, context=context, transformer_options=transformer_options) - # x = checkpoint(layer, x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs) + x = layer(x, rotary_pos_emb=rotary_pos_emb, global_cond=global_cond, + local_add_cond=local_add_cond, context=context, + transformer_options=transformer_options) if return_info: info["hidden_states"].append(x) + # Strip memory tokens before projecting out + if self.num_memory_tokens > 0: + x = x[:, self.num_memory_tokens:, :] + x = self.project_out(x) if return_info: @@ -682,6 +837,7 @@ class AudioDiffusionTransformer(nn.Module): num_heads=24, transformer_type: tp.Literal["continuous_transformer"] = "continuous_transformer", global_cond_type: tp.Literal["prepend", "adaLN"] = "prepend", + timestep_features_type: str = "learned", audio_model="", dtype=None, device=None, @@ -696,7 +852,10 @@ class AudioDiffusionTransformer(nn.Module): # Timestep embeddings timestep_features_dim = 256 - self.timestep_features = FourierFeatures(1, timestep_features_dim, dtype=dtype, device=device) + if timestep_features_type == "expo": + self.timestep_features = ExpoFourierFeatures(timestep_features_dim, 0.5, 10000.0) + else: + self.timestep_features = FourierFeatures(1, timestep_features_dim, dtype=dtype, device=device) self.to_timestep_embed = nn.Sequential( operations.Linear(timestep_features_dim, embed_dim, bias=True, dtype=dtype, device=device), @@ -781,6 +940,7 @@ class AudioDiffusionTransformer(nn.Module): cross_attn_cond=None, cross_attn_cond_mask=None, input_concat_cond=None, + local_add_cond=None, global_embed=None, prepend_cond=None, prepend_cond_mask=None, @@ -802,9 +962,13 @@ class AudioDiffusionTransformer(nn.Module): prepend_cond = self.to_prepend_embed(prepend_cond) prepend_inputs = prepend_cond + prepend_length = prepend_cond.shape[1] if prepend_cond_mask is not None: prepend_mask = prepend_cond_mask + if local_add_cond is not None and local_add_cond.dim() == 3: + local_add_cond = local_add_cond.permute(0, 2, 1) + if input_concat_cond is not None: # Interpolate input_concat_cond to the same length as x @@ -850,7 +1014,7 @@ class AudioDiffusionTransformer(nn.Module): if self.transformer_type == "x-transformers": output = self.transformer(x, prepend_embeds=prepend_inputs, context=cross_attn_cond, context_mask=cross_attn_cond_mask, mask=mask, prepend_mask=prepend_mask, **extra_args, **kwargs) elif self.transformer_type == "continuous_transformer": - output = self.transformer(x, prepend_embeds=prepend_inputs, context=cross_attn_cond, context_mask=cross_attn_cond_mask, mask=mask, prepend_mask=prepend_mask, return_info=return_info, **extra_args, **kwargs) + output = self.transformer(x, prepend_embeds=prepend_inputs, context=cross_attn_cond, context_mask=cross_attn_cond_mask, mask=mask, prepend_mask=prepend_mask, return_info=return_info, local_add_cond=local_add_cond, **extra_args, **kwargs) if return_info: output, info = output @@ -876,6 +1040,7 @@ class AudioDiffusionTransformer(nn.Module): context=None, context_mask=None, input_concat_cond=None, + local_add_cond=None, global_embed=None, negative_global_embed=None, prepend_cond=None, @@ -890,6 +1055,7 @@ class AudioDiffusionTransformer(nn.Module): cross_attn_cond=context, cross_attn_cond_mask=context_mask, input_concat_cond=input_concat_cond, + local_add_cond=local_add_cond, global_embed=global_embed, prepend_cond=prepend_cond, prepend_cond_mask=prepend_cond_mask, diff --git a/comfy/ldm/audio/embedders.py b/comfy/ldm/audio/embedders.py index 20edb365a..ba9a62837 100644 --- a/comfy/ldm/audio/embedders.py +++ b/comfy/ldm/audio/embedders.py @@ -31,15 +31,39 @@ def TimePositionalEmbedding(dim: int, out_features: int) -> nn.Module: ) +class ExpoFourierFeatures(nn.Module): + """Exponentially-spaced Fourier features (no learnable parameters).""" + def __init__(self, dim, min_freq=0.5, max_freq=10000.0): + super().__init__() + self.dim = dim + self.min_freq = min_freq + self.max_freq = max_freq + + def forward(self, t): + in_dtype = t.dtype + t = t.float() + if t.dim() == 1: + t = t.unsqueeze(-1) + half_dim = self.dim // 2 + ramp = torch.linspace(0, 1, half_dim, device=t.device, dtype=torch.float32) + freqs = torch.exp(ramp * (math.log(self.max_freq) - math.log(self.min_freq)) + math.log(self.min_freq)) + args = t * freqs * 2 * math.pi + return torch.cat([args.cos(), args.sin()], dim=-1).to(in_dtype) + + class NumberEmbedder(nn.Module): def __init__( self, features: int, dim: int = 256, + fourier_features_type="learned", ): super().__init__() self.features = features - self.embedding = TimePositionalEmbedding(dim=dim, out_features=features) + if fourier_features_type == "expo": + self.embedding = nn.Sequential(ExpoFourierFeatures(dim=dim), comfy.ops.manual_cast.Linear(in_features=dim, out_features=features)) + else: + self.embedding = TimePositionalEmbedding(dim=dim, out_features=features) def forward(self, x: Union[List[float], Tensor]) -> Tensor: if not torch.is_tensor(x): @@ -77,14 +101,15 @@ class NumberConditioner(Conditioner): def __init__(self, output_dim: int, min_val: float=0, - max_val: float=1 + max_val: float=1, + fourier_features_type: str = "learned", ): super().__init__(output_dim, output_dim) self.min_val = min_val self.max_val = max_val - self.embedder = NumberEmbedder(features=output_dim) + self.embedder = NumberEmbedder(features=output_dim, fourier_features_type=fourier_features_type) def forward(self, floats, device=None): # Cast the inputs to floats diff --git a/comfy/ldm/audio/vae_sa3.py b/comfy/ldm/audio/vae_sa3.py new file mode 100644 index 000000000..276846444 --- /dev/null +++ b/comfy/ldm/audio/vae_sa3.py @@ -0,0 +1,533 @@ +import torch +import torch.nn as nn + +import comfy.ops +import comfy.model_management +from comfy.ldm.modules.attention import optimized_attention +from comfy.ldm.audio.autoencoder import WNConv1d + +ops = comfy.ops.disable_weight_init + +class Transpose(nn.Module): + def forward(self, x, **kwargs): + return x.transpose(-2, -1) + + +def _zero_pad_modulo_sequence(x, size, dim=-2): + input_len = x.shape[dim] + pad_len = (size - input_len % size) % size + if pad_len > 0: + pad_shape = list(x.shape) + pad_shape[dim] = pad_len + x = torch.cat([x, torch.zeros(pad_shape, device=x.device, dtype=x.dtype)], dim=dim) + return x + + +def _sliding_window_mask(seq_len, window, device, dtype): + """Additive attention mask enforcing a ±window local window (matches flash_attn window_size).""" + i = torch.arange(seq_len, device=device).unsqueeze(1) + j = torch.arange(seq_len, device=device).unsqueeze(0) + out_of_window = (j - i).abs() > window + return torch.where( + out_of_window, + torch.full((1,), torch.finfo(dtype).min / 4, device=device, dtype=dtype), + torch.zeros(1, device=device, dtype=dtype), + ) + + +class DynamicTanh(nn.Module): + def __init__(self, dim, init_alpha=4.0, dtype=None, device=None, **kwargs): + super().__init__() + self.alpha = nn.Parameter(torch.empty(1, dtype=dtype, device=device)) + self.gamma = nn.Parameter(torch.empty(dim, dtype=dtype, device=device)) + self.beta = nn.Parameter(torch.empty(dim, dtype=dtype, device=device)) + + def forward(self, x): + alpha = comfy.ops.cast_to_input(self.alpha, x) + gamma = comfy.ops.cast_to_input(self.gamma, x) + beta = comfy.ops.cast_to_input(self.beta, x) + return gamma * torch.tanh(alpha * x) + beta + + +class RotaryEmbedding(nn.Module): + def __init__(self, dim, base=10000, base_rescale_factor=1., dtype=None, device=None): + super().__init__() + base = base * base_rescale_factor ** (dim / (dim - 2)) + self.register_buffer("inv_freq", torch.empty(dim // 2, dtype=dtype, device=device)) + + def forward_from_seq_len(self, seq_len, device, dtype=None): + t = torch.arange(seq_len, device=device, dtype=torch.float32) + return self.forward(t) + + def forward(self, t): + freqs = torch.outer(t.float(), comfy.model_management.cast_to(self.inv_freq, dtype=torch.float32, device=t.device)) + freqs = torch.cat((freqs, freqs), dim=-1) + return freqs, 1. + + +def _rotate_half(x): + d = x.shape[-1] // 2 + return torch.cat((-x[..., d:], x[..., :d]), dim=-1) + + +def _apply_rotary_pos_emb(t, freqs): + out_dtype = t.dtype + rot_dim = freqs.shape[-1] + seq_len = t.shape[-2] + freqs = freqs[-seq_len:] + t_rot, t_pass = t[..., :rot_dim], t[..., rot_dim:] + t_rot = t_rot * freqs.cos() + _rotate_half(t_rot) * freqs.sin() + return torch.cat((t_rot.to(out_dtype), t_pass.to(out_dtype)), dim=-1) + + +class Attention(nn.Module): + def __init__(self, dim, dim_heads=64, qk_norm="none", qk_norm_eps=1e-6, + differential=False, zero_init_output=True, + dtype=None, device=None, operations=None, **kwargs): + super().__init__() + self.num_heads = dim // dim_heads + self.differential = differential + self.qk_norm = qk_norm + + self.to_qkv = operations.Linear( + dim, dim * (5 if differential else 3), bias=False, dtype=dtype, device=device) + self.to_out = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device) + + if qk_norm == "dyt": + self.q_norm = DynamicTanh(dim_heads, dtype=dtype, device=device) + self.k_norm = DynamicTanh(dim_heads, dtype=dtype, device=device) + elif qk_norm == "rms": + self.q_norm = operations.RMSNorm(dim_heads, eps=qk_norm_eps, dtype=dtype, device=device) + self.k_norm = operations.RMSNorm(dim_heads, eps=qk_norm_eps, dtype=dtype, device=device) + + def forward(self, x, rotary_pos_emb=None, mask=None, **kwargs): + B, N, _ = x.shape + h = self.num_heads + + qkv = self.to_qkv(x) + if self.differential: + q, k, v, q_diff, k_diff = qkv.chunk(5, dim=-1) + del qkv + q = q.view(B, N, h, -1).transpose(1, 2) + k = k.view(B, N, h, -1).transpose(1, 2) + v = v.view(B, N, h, -1).transpose(1, 2) + q_diff = q_diff.view(B, N, h, -1).transpose(1, 2) + k_diff = k_diff.view(B, N, h, -1).transpose(1, 2) + else: + q, k, v = qkv.chunk(3, dim=-1) + del qkv + q = q.view(B, N, h, -1).transpose(1, 2) + k = k.view(B, N, h, -1).transpose(1, 2) + v = v.view(B, N, h, -1).transpose(1, 2) + + if self.qk_norm != "none": + q_dtype, k_dtype = q.dtype, k.dtype + q = self.q_norm(q).to(q_dtype) + k = self.k_norm(k).to(k_dtype) + if self.differential: + q_diff = self.q_norm(q_diff).to(q_dtype) + k_diff = self.k_norm(k_diff).to(k_dtype) + + if rotary_pos_emb is not None: + freqs, _ = rotary_pos_emb + q_dtype, k_dtype = q.dtype, k.dtype + q = _apply_rotary_pos_emb(q.float(), freqs).to(q_dtype) + k = _apply_rotary_pos_emb(k.float(), freqs).to(k_dtype) + if self.differential: + q_diff = _apply_rotary_pos_emb(q_diff.float(), freqs).to(q_dtype) + k_diff = _apply_rotary_pos_emb(k_diff.float(), freqs).to(k_dtype) + + if self.differential: + out = (optimized_attention(q, k, v, h, mask=mask, skip_reshape=True) + - optimized_attention(q_diff, k_diff, v, h, mask=mask, skip_reshape=True)) + del q, k, v, q_diff, k_diff + else: + out = optimized_attention(q, k, v, h, mask=mask, skip_reshape=True) + del q, k, v + + return self.to_out(out) + + +class _Sin(nn.Module): + def forward(self, x): + return torch.sin(3.14159265359 * x) + + +class _GLU(nn.Module): + def __init__(self, dim_in, dim_out, activation, dtype=None, device=None, operations=None): + super().__init__() + self.act = activation + self.proj = operations.Linear(dim_in, dim_out * 2, dtype=dtype, device=device) + + def forward(self, x): + x = self.proj(x) + x, gate = x.chunk(2, dim=-1) + return x * self.act(gate) + + +class FeedForward(nn.Module): + def __init__(self, dim, mult=4, no_bias=False, zero_init_output=True, + sinusoidal=False, dtype=None, device=None, operations=None, **kwargs): + super().__init__() + inner_dim = int(dim * mult) + act = _Sin() if sinusoidal else nn.SiLU() + self.ff = nn.Sequential( + _GLU(dim, inner_dim, act, dtype=dtype, device=device, operations=operations), + nn.Identity(), + operations.Linear(inner_dim, dim, bias=not no_bias, dtype=dtype, device=device), + nn.Identity(), + ) + + def forward(self, x, **kwargs): + return self.ff(x) + + +class TransformerBlock(nn.Module): + def __init__(self, dim, dim_heads=64, causal=False, zero_init_branch_outputs=True, + norm_type="dyt", add_rope=False, attn_kwargs=None, ff_kwargs=None, + norm_kwargs=None, dtype=None, device=None, operations=None, **kwargs): + super().__init__() + if attn_kwargs is None: + attn_kwargs = {} + if ff_kwargs is None: + ff_kwargs = {} + if norm_kwargs is None: + norm_kwargs = {} + dim_heads = min(dim_heads, dim) + + Norm = DynamicTanh if norm_type == "dyt" else operations.RMSNorm + norm_kw = {**norm_kwargs, "dtype": dtype, "device": device} + + self.pre_norm = Norm(dim, **norm_kw) + self.self_attn = Attention(dim, dim_heads=dim_heads, + zero_init_output=zero_init_branch_outputs, + dtype=dtype, device=device, operations=operations, + **attn_kwargs) + self.ff_norm = Norm(dim, **norm_kw) + self.ff = FeedForward(dim, zero_init_output=zero_init_branch_outputs, + dtype=dtype, device=device, operations=operations, **ff_kwargs) + self.rope = RotaryEmbedding(dim_heads // 2, dtype=dtype, device=device) if add_rope else None + + def forward(self, x, mask=None, **kwargs): + rope = self.rope.forward_from_seq_len(x.shape[-2], device=x.device) \ + if self.rope is not None else None + x = x + self.self_attn(self.pre_norm(x), rotary_pos_emb=rope, mask=mask) + x = x + self.ff(self.ff_norm(x)) + return x + + +class TransformerResamplingBlock(nn.Module): + def __init__(self, in_channels, out_channels, stride, type="encoder", + transformer_depth=3, dim_heads=128, differential=True, + sliding_window=None, chunk_size=128, chunk_midpoint_shift=False, + dyt=True, ff_mult=3, mapping_bias=True, variable_stride=False, + sinusoidal_blocks=0, conv_mapping=False, dtype=None, device=None, operations=None, **kwargs): + super().__init__() + if type not in ("encoder", "decoder"): + raise ValueError(f"type must be 'encoder' or 'decoder', got {type!r}") + + self.type = type + self.stride = stride + self.chunk_size = chunk_size + self.chunk_midpoint_shift = chunk_midpoint_shift + self.variable_stride = variable_stride + self.transformer_depth = transformer_depth + + transformer_dim = out_channels if type == "encoder" else in_channels + + self.mapping = (WNConv1d(in_channels, out_channels, 3 if conv_mapping else 1, padding="same", bias=mapping_bias) + if in_channels != out_channels else nn.Identity()) + + self.sliding_window_latents = sliding_window + self.sliding_window_seq = self._get_sliding_window_size(sliding_window, stride) + self.input_seg_size, self.output_seg_size, self.sub_chunk_size = self._get_seg_sizes(stride) + + token_seq = 1 if variable_stride else self.output_seg_size + self.new_tokens = nn.Parameter(torch.empty(1, token_seq, transformer_dim, dtype=dtype, device=device)) + + norm_type = "dyt" if dyt else "rms_norm" + attn_kwargs = {"qk_norm": "dyt" if dyt else "rms", "qk_norm_eps": 1e-3, + "differential": differential} + norm_kwargs = {"eps": 1e-3} + transformers = [] + for i in range(transformer_depth): + sinusoidal = (transformer_depth - i) < sinusoidal_blocks + transformers.append(TransformerBlock( + transformer_dim, + dim_heads=dim_heads, + causal=False, + zero_init_branch_outputs=True, + norm_type=norm_type, + add_rope=True, + attn_kwargs=attn_kwargs, + ff_kwargs={"mult": ff_mult, "no_bias": False, "sinusoidal": sinusoidal}, + norm_kwargs=norm_kwargs, + dtype=dtype, device=device, operations=operations, + )) + self.transformers = nn.ModuleList(transformers) + + def _get_sliding_window_size(self, window, stride, prepend_cond_length=0): + if window is None: + return None + return [w * (stride + 1 + prepend_cond_length) for w in window] + + def _get_seg_sizes(self, stride, prepend_cond_length=0): + sub_chunk_size = stride + 1 + prepend_cond_length + input_seg_size = stride if self.type == "encoder" else 1 + output_seg_size = 1 if self.type == "encoder" else stride + return input_seg_size, output_seg_size, sub_chunk_size + + def forward(self, x, stride=None, **kwargs): + B = x.shape[0] + + if stride is None: + input_seg = self.input_seg_size + output_seg = self.output_seg_size + sub_chunk = self.sub_chunk_size + sliding_window = self.sliding_window_seq + else: + input_seg, output_seg, sub_chunk = self._get_seg_sizes(stride) + sliding_window = self._get_sliding_window_size(self.sliding_window_latents, stride) + + if self.type == "encoder": + if self.transformer_depth > 0: + pad_mod = self.chunk_size if sliding_window is None else input_seg + x = _zero_pad_modulo_sequence(x, pad_mod, dim=-1) + x = self.mapping(x) + + if self.transformer_depth > 0: + x = x.permute(0, 2, 1) + + if self.type != "encoder": + pad_mod = 1 if sliding_window is not None else ( + self.chunk_size // (stride if stride is not None else self.stride)) + x = _zero_pad_modulo_sequence(x, pad_mod) + + C = x.shape[2] + x = x.reshape(-1, input_seg, C) + + new_tokens = self.new_tokens.expand(x.shape[0], output_seg, -1) + x = torch.cat([x, comfy.ops.cast_to_input(new_tokens, x)], dim=-2) + del new_tokens + + x = x.reshape(B, -1, C) + + if sliding_window is None: + eff_chunk = self.chunk_size + self.chunk_size // (stride if stride is not None else self.stride) + + if sliding_window is None and self.chunk_midpoint_shift: + split = self.transformer_depth // 2 + shift = eff_chunk // 2 + + x = x.reshape(-1, eff_chunk, C) + for layer in self.transformers[:split]: + x = layer(x) + x = x.reshape(B, -1, C) + + shifted = torch.cat([x[:, :shift, :], x, x[:, -shift:, :]], dim=1) + del x + x = shifted.reshape(-1, eff_chunk, C) + del shifted + for layer in self.transformers[split:]: + x = layer(x) + x = x.reshape(B, -1, C) + x = x[:, shift:-shift, :] + elif sliding_window is None: + x = x.reshape(-1, eff_chunk, C) + for layer in self.transformers: + x = layer(x) + x = x.reshape(B, -1, C) + else: + attn_mask = _sliding_window_mask(x.shape[1], sliding_window[0], x.device, x.dtype) + for layer in self.transformers: + x = layer(x, mask=attn_mask) + + x = x.reshape(-1, sub_chunk, C) + x = x[:, -output_seg:, :] + x = x.reshape(B, -1, C).transpose(1, 2) + + if self.type == "decoder": + x = self.mapping(x) + + return x + + +class SAMEEncoder(nn.Module): + def __init__(self, in_channels=2, channels=128, latent_dim=32, + c_mults=(1, 2, 4, 8), strides=(2, 4, 8, 8), + transformer_depths=(3, 3, 3, 3), + dtype=None, device=None, operations=None, **kwargs): + super().__init__() + channel_dims = [in_channels] + [channels * c for c in c_mults] + layers = [] + for i in range(len(c_mults)): + layers.append(TransformerResamplingBlock( + in_channels=channel_dims[i], out_channels=channel_dims[i + 1], + stride=strides[i], type="encoder", + transformer_depth=transformer_depths[i], + dtype=dtype, device=device, operations=operations, **kwargs)) + layers += [ + Transpose(), + operations.Linear(channel_dims[-1], latent_dim, dtype=dtype, device=device), + Transpose(), + ] + self.layers = nn.ModuleList(layers) + + def forward(self, x, **kwargs): + for layer in self.layers: + x = layer(x) + return x + + +class SAMEDecoder(nn.Module): + def __init__(self, out_channels=2, channels=128, latent_dim=32, + c_mults=(1, 2, 4, 8), strides=(2, 4, 8, 8), + transformer_depths=(3, 3, 3, 3), sinusoidal_blocks=None, + dtype=None, device=None, operations=None, **kwargs): + super().__init__() + if sinusoidal_blocks is None: + sinusoidal_blocks = [0] * len(c_mults) + channel_dims = [out_channels] + [channels * c for c in c_mults] + layers = [ + Transpose(), + operations.Linear(latent_dim, channel_dims[-1], dtype=dtype, device=device), + Transpose(), + ] + for i in range(len(c_mults) - 1, -1, -1): + layers.append(TransformerResamplingBlock( + in_channels=channel_dims[i + 1], out_channels=channel_dims[i], + stride=strides[i], type="decoder", + transformer_depth=transformer_depths[i], + sinusoidal_blocks=sinusoidal_blocks[i], + dtype=dtype, device=device, operations=operations, **kwargs)) + self.layers = nn.ModuleList(layers) + + def forward(self, x, **kwargs): + for layer in self.layers: + x = layer(x) + return x + + +class SoftNormBottleneck(nn.Module): + def __init__(self, dim=32, noise_augment_dim=0, noise_regularize=False, + auto_scale=False, freeze=False, dtype=None, device=None, **kwargs): + super().__init__() + self.noise_augment_dim = noise_augment_dim + self.noise_regularize = noise_regularize + self.scaling_factor = nn.Parameter(torch.empty(1, dim, 1, dtype=dtype, device=device)) + self.bias = nn.Parameter(torch.empty(1, dim, 1, dtype=dtype, device=device)) + self.noise_scaling_factor = nn.Parameter(torch.empty(1, noise_augment_dim, 1, dtype=dtype, device=device)) + if auto_scale: + self.register_parameter("running_std", nn.Parameter( + torch.empty(1, dtype=dtype, device=device), requires_grad=False)) + if freeze: + for p in self.parameters(): + p.requires_grad = False + + def encode(self, x, return_info=False, **kwargs): + x = x * comfy.ops.cast_to_input(self.scaling_factor, x) \ + + comfy.ops.cast_to_input(self.bias, x) + if hasattr(self, "running_std"): + x = x / comfy.ops.cast_to_input(self.running_std, x) + if return_info: + return x, {} + return x + + def decode(self, x, **kwargs): + if hasattr(self, "running_std"): + x = x * comfy.ops.cast_to_input(self.running_std, x) + if self.noise_regularize: + scaling = self.running_std if hasattr(self, "running_std") \ + else x.std(dim=-1, keepdim=True) + noise = torch.randn_like(x) * comfy.ops.cast_to_input(scaling, x) * 1e-3 + x = x + noise + if self.noise_augment_dim > 0: + noise = comfy.ops.cast_to_input(self.noise_scaling_factor, x) * torch.randn( + x.shape[0], self.noise_augment_dim, x.shape[-1], device=x.device, dtype=x.dtype) + x = torch.cat([x, noise], dim=1) + return x + + +class PatchedPretransform(nn.Module): + def __init__(self, channels, patch_size, **kwargs): + super().__init__() + self.channels = channels + self.patch_size = patch_size + self.enable_grad = False + + def _pad(self, x): + pad_len = (self.patch_size - x.shape[-1] % self.patch_size) % self.patch_size + if pad_len > 0: + x = torch.cat([x, torch.zeros_like(x[:, :, :pad_len])], dim=-1) + return x + + def encode(self, x): + x = self._pad(x) + B, C, T = x.shape + h = self.patch_size + L = T // h + # b c (l h) -> b (c h) l + return x.reshape(B, C, L, h).permute(0, 1, 3, 2).reshape(B, C * h, L) + + def decode(self, x): + B, Ch, L = x.shape + h = self.patch_size + C = Ch // h + # b (c h) l -> b c (l h) + return x.reshape(B, C, h, L).permute(0, 1, 3, 2).reshape(B, C, L * h) + + +class SA3AudioVAE(nn.Module): + """SA3 VAE. State dict keys match checkpoint after stripping 'pretransform.model.'""" + + def __init__(self, channels=256, transformer_depths=12, sinusoidal_blocks=8, + sliding_window=None, decoder_conv_mapping=False, + chunk_size=128, chunk_midpoint_shift=False, + dtype=None, device=None, operations=None): + super().__init__() + if operations is None: + operations = ops + + self.pretransform = PatchedPretransform(channels=2, patch_size=256) + + common_kwargs = dict( + differential=True, dyt=True, dim_heads=64, + sliding_window=sliding_window, variable_stride=True, + chunk_size=chunk_size, chunk_midpoint_shift=chunk_midpoint_shift, + dtype=dtype, device=device, operations=operations, + ) + self.encoder = SAMEEncoder( + in_channels=512, channels=channels, c_mults=[6], strides=[16], + latent_dim=256, transformer_depths=[transformer_depths], + conv_mapping=False, **common_kwargs, + ) + self.decoder = SAMEDecoder( + out_channels=512, channels=channels, c_mults=[6], strides=[16], + latent_dim=256, transformer_depths=[transformer_depths], sinusoidal_blocks=[sinusoidal_blocks], + conv_mapping=decoder_conv_mapping, **common_kwargs, + ) + self.bottleneck = SoftNormBottleneck( + dim=256, noise_augment_dim=0, noise_regularize=True, + auto_scale=True, freeze=True, + dtype=dtype, device=device, + ) + + @torch.no_grad() + def _pretransform_encode(self, x): + return self.pretransform.encode(x) + + @torch.no_grad() + def _pretransform_decode(self, x): + return self.pretransform.decode(x) + + def encode(self, x): + x = self._pretransform_encode(x) + x = self.encoder(x) + x = self.bottleneck.encode(x) + return x + + def decode(self, x): + x = self.bottleneck.decode(x) + x = self.decoder(x) + x = self._pretransform_decode(x) + return x diff --git a/comfy/model_base.py b/comfy/model_base.py index c22705655..d81f13c69 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -813,6 +813,85 @@ class StableAudio1(BaseModel): sd["{}{}".format(k, l)] = s[l] return sd +class StableAudio3(BaseModel): + def __init__(self, model_config, seconds_total_embedder_weights, padding_embedding=None, model_type=ModelType.FLOW, device=None): + super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.audio.dit.AudioDiffusionTransformer) + self.seconds_total_embedder = comfy.ldm.audio.embedders.NumberConditioner(768, min_val=0, max_val=384, fourier_features_type=model_config.unet_config["timestep_features_type"]) + self.seconds_total_embedder.load_state_dict(seconds_total_embedder_weights) + if padding_embedding is not None: + self.padding_embedding = torch.nn.Parameter(padding_embedding, requires_grad=False) + else: + self.padding_embedding = None + + def concat_cond(self, **kwargs): + noise = kwargs.get("noise", None) + image = kwargs.get("concat_latent_image", None) + + if image is None: + shape_image = list(noise.shape) + image = torch.zeros(shape_image, dtype=noise.dtype, layout=noise.layout, device=noise.device) + else: + image = self.process_latent_in(image) + # TODO: scale if not match + image = utils.resize_to_batch_size(image, noise.shape[0]) + + mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None)) + if mask is None: + mask = torch.zeros_like(noise)[:, :1] + else: + if mask.shape[1] != 1: + mask = torch.mean(mask, dim=1, keepdim=True) + mask = 1.0 - mask + # TODO: scale if not match + mask = utils.resize_to_batch_size(mask, noise.shape[0]) + + return torch.cat((mask, image), dim=1) + + def extra_conds(self, **kwargs): + out = {} + + concat_cond = self.concat_cond(**kwargs) + if concat_cond is not None: + out['local_add_cond'] = comfy.conds.CONDNoiseShape(concat_cond) + + noise = kwargs.get("noise", None) + device = kwargs["device"] + + seconds_total = kwargs.get("seconds_total", int(noise.shape[-1] / 10.7666)) + seconds_total_embed = self.seconds_total_embedder([seconds_total])[0].to(device) + + global_embed = seconds_total_embed.reshape((1, -1)) + out['global_embed'] = comfy.conds.CONDRegular(global_embed) + + cross_attn = kwargs.get("cross_attn", None) + if cross_attn is not None: + cross_attn = cross_attn.to(device) + if self.padding_embedding is not None: + pe = self.padding_embedding.to(device=device, dtype=cross_attn.dtype) + max_text_tokens = self.model_config.unet_config.get("max_text_tokens", 256) + n_text = cross_attn.shape[1] + if n_text < max_text_tokens: + pad = pe.view(1, 1, -1).expand(cross_attn.shape[0], max_text_tokens - n_text, -1) + cross_attn = torch.cat([cross_attn, pad], dim=1) + cross_attn = torch.cat([cross_attn, seconds_total_embed.repeat((cross_attn.shape[0], 1, 1))], dim=1) + out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) + + return out + + def state_dict_for_saving(self, unet_state_dict, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None): + sd = super().state_dict_for_saving(unet_state_dict, clip_state_dict=clip_state_dict, vae_state_dict=vae_state_dict, clip_vision_state_dict=clip_vision_state_dict) + + d = {"conditioner.conditioners.seconds_total.": self.seconds_total_embedder.state_dict()} + + for k in d: + s = d[k] + for l in s: + sd["{}{}".format(k, l)] = s[l] + + if self.padding_embedding is not None: + sd["conditioner.conditioners.prompt.padding_embedding"] = self.padding_embedding.data + return sd + class HunyuanDiT(BaseModel): def __init__(self, model_config, model_type=ModelType.V_PREDICTION, device=None): diff --git a/comfy/model_detection.py b/comfy/model_detection.py index bc0b933bc..70b4df8b3 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -116,6 +116,45 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): if '{}transformer.rotary_pos_emb.inv_freq'.format(key_prefix) in state_dict_keys: #stable audio dit unet_config = {} unet_config["audio_model"] = "dit1.0" + unet_config["global_cond_dim"] = state_dict['{}to_global_embed.0.weight'.format(key_prefix)].shape[1] + cond_embed = state_dict['{}to_cond_embed.0.weight'.format(key_prefix)] + unet_config["project_cond_tokens"] = cond_embed.shape[0] != cond_embed.shape[1] + unet_config["embed_dim"] = state_dict['{}to_timestep_embed.0.weight'.format(key_prefix)].shape[0] + mem_tokens = state_dict.get('{}transformer.memory_tokens'.format(key_prefix), None) + to_qkv = state_dict.get('{}transformer.layers.0.self_attn.to_qkv.weight'.format(key_prefix), None) + differential = False + if to_qkv is not None: + if to_qkv.shape[0] == to_qkv.shape[1] * 5: + differential = True + if mem_tokens is not None: + unet_config["num_memory_tokens"] = mem_tokens.shape[0] + if '{}transformer.layers.0.self_attn.q_norm.weight'.format(key_prefix) in state_dict: + unet_config["attn_kwargs"] = {"qk_norm": "ln", "feat_scale": True} + rms_norm = state_dict.get('{}transformer.layers.0.self_attn.q_norm.gamma'.format(key_prefix), None) + if rms_norm is not None: + unet_config["attn_kwargs"] = {"qk_norm": "rms", "differential": differential} + unet_config["norm_type"] = "rms_norm" + unet_config["num_heads"] = unet_config["embed_dim"] // rms_norm.shape[0] + + if '{}timestep_features.weight'.format(key_prefix) in state_dict: + unet_config["timestep_features_type"] = "learned" + else: + unet_config["timestep_features_type"] = "expo" + + io_channels = state_dict['{}postprocess_conv.weight'.format(key_prefix)].shape[0] + unet_config["io_channels"] = io_channels + unet_config["input_concat_dim"] = state_dict['{}transformer.project_in.weight'.format(key_prefix)].shape[1] - io_channels + + local_add_cond = state_dict.get('{}transformer.layers.0.to_local_embed.0.weight'.format(key_prefix), None) + if local_add_cond is not None: + unet_config["local_add_cond_dim"] = local_add_cond.shape[1] + + global_cond_embed = state_dict.get('{}transformer.global_cond_embedder.0.weight'.format(key_prefix), None) + if global_cond_embed is not None: + unet_config["global_cond_shared_embed"] = True + unet_config["global_cond_type"] = "adaLN" + + unet_config["depth"] = count_blocks(state_dict_keys, '{}transformer.layers.'.format(key_prefix) + '{}.') return unet_config if '{}double_layers.0.attn.w1q.weight'.format(key_prefix) in state_dict_keys: #aura flow dit diff --git a/comfy/sd.py b/comfy/sd.py index 2443353a4..7bd07ed3a 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -21,6 +21,7 @@ import comfy.ldm.ace.vae.music_dcae_pipeline import comfy.ldm.cogvideo.vae import comfy.ldm.hunyuan_video.vae import comfy.ldm.mmaudio.vae.autoencoder +import comfy.ldm.audio.vae_sa3 import comfy.pixel_space_convert import comfy.weight_adapter import yaml @@ -67,6 +68,7 @@ import comfy.text_encoders.qwen35 import comfy.text_encoders.ernie import comfy.text_encoders.gemma4 import comfy.text_encoders.cogvideo +import comfy.text_encoders.sa3 import comfy.model_patcher import comfy.lora @@ -854,6 +856,34 @@ class VAE: self.working_dtypes = [torch.float32] self.disable_offload = True self.extra_1d_channel = 16 + elif "decoder.layers.3.transformers.0.pre_norm.alpha" in sd: # Stable Audio 3 VAE + if "decoder.layers.3.transformers.11.self_attn.to_out.weight" in sd: + config = {"channels": 256, "transformer_depths": 12, "sinusoidal_blocks": 8, + "sliding_window": [1, 1], "decoder_conv_mapping": False, + "chunk_size": 128, "chunk_midpoint_shift": False} + self.memory_used_encode = lambda shape, dtype: (1500 * shape[2]) * model_management.dtype_size(dtype) + self.memory_used_decode = lambda shape, dtype: (1500 * shape[2] * 4096) * model_management.dtype_size(dtype) + else: + config = {"channels": 128, "transformer_depths": 6, "sinusoidal_blocks": 0, + "sliding_window": None, "decoder_conv_mapping": True, + "chunk_size": 32, "chunk_midpoint_shift": True} + self.memory_used_encode = lambda shape, dtype: (72 * shape[2]) * model_management.dtype_size(dtype) + self.memory_used_decode = lambda shape, dtype: (72 * shape[2] * 4096) * model_management.dtype_size(dtype) + + self.first_stage_model = comfy.ldm.audio.vae_sa3.SA3AudioVAE(**config) + self.latent_channels = 256 + self.output_channels = 2 + self.upscale_ratio = 4096 + self.downscale_ratio = 4096 + self.latent_dim = 1 + self.audio_sample_rate = 44100 + self.process_output = lambda audio: audio + self.process_input = lambda audio: audio + self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32] + #This VAE has Parameters and Buffers the non-dynamic caster cannot handle + #Force cast it for --disable-dynamic-vram users until there is a true core fix. + if not comfy.memory_management.aimdo_enabled: + self.disable_offload = True else: logging.warning("WARNING: No VAE weights detected, VAE not initalized.") self.first_stage_model = None @@ -1290,6 +1320,7 @@ class TEModel(Enum): GEMMA_4_E4B = 29 GEMMA_4_E2B = 30 GEMMA_4_31B = 31 + T5_GEMMA = 32 def detect_te_model(sd): @@ -1314,6 +1345,8 @@ def detect_te_model(sd): if weight.shape[0] == 384: return TEModel.BYT5_SMALL_GLYPH return TEModel.T5_BASE + if "model.encoder.layers.0.pre_self_attn_layernorm.weight" in sd: + return TEModel.T5_GEMMA if 'model.layers.0.post_feedforward_layernorm.weight' in sd: if 'model.layers.59.self_attn.q_norm.weight' in sd: return TEModel.GEMMA_4_31B @@ -1463,6 +1496,10 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip else: clip_target.clip = comfy.text_encoders.sa_t5.SAT5Model clip_target.tokenizer = comfy.text_encoders.sa_t5.SAT5Tokenizer + elif te_model == TEModel.T5_GEMMA: + clip_target.clip = comfy.text_encoders.sa3.SAT5GemmaModel + clip_target.tokenizer = comfy.text_encoders.sa3.SAT5GemmaTokenizer + tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None) elif te_model in (TEModel.GEMMA_4_E4B, TEModel.GEMMA_4_E2B, TEModel.GEMMA_4_31B): variant = {TEModel.GEMMA_4_E4B: comfy.text_encoders.gemma4.Gemma4_E4B, TEModel.GEMMA_4_E2B: comfy.text_encoders.gemma4.Gemma4_E2B, diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 1e4434fd5..617db4f28 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -7,6 +7,7 @@ from . import sdxl_clip import comfy.text_encoders.sd2_clip import comfy.text_encoders.sd3_clip import comfy.text_encoders.sa_t5 +import comfy.text_encoders.sa3 import comfy.text_encoders.aura_t5 import comfy.text_encoders.pixart_t5 import comfy.text_encoders.hydit @@ -603,6 +604,29 @@ class StableAudio(supported_models_base.BASE): def clip_target(self, state_dict={}): return supported_models_base.ClipTarget(comfy.text_encoders.sa_t5.SAT5Tokenizer, comfy.text_encoders.sa_t5.SAT5Model) +class StableAudio3(StableAudio): + unet_config = { + "audio_model": "dit1.0", + "global_cond_shared_embed": True, + } + + sampling_settings = { + "multiplier": 1.0, + "shift": 2.0, + } + + latent_format = latent_formats.StableAudio3 + + memory_usage_factor = 7 + + def get_model(self, state_dict, prefix="", device=None): + seconds_total_sd = utils.state_dict_prefix_replace(state_dict, {"conditioner.conditioners.seconds_total.": ""}, filter_keys=True) + padding_embedding = state_dict.get("conditioner.conditioners.prompt.padding_embedding", None) + return model_base.StableAudio3(self, seconds_total_embedder_weights=seconds_total_sd, padding_embedding=padding_embedding, device=device) + + def clip_target(self, state_dict={}): + return supported_models_base.ClipTarget(comfy.text_encoders.sa3.SAT5GemmaTokenizer, comfy.text_encoders.sa3.SAT5GemmaModel) + class AuraFlow(supported_models_base.BASE): unet_config = { "cond_seq_dim": 2048, @@ -2018,6 +2042,7 @@ models = [ SV3D_u, SV3D_p, SD3, + StableAudio3, StableAudio, AuraFlow, PixArtAlpha, diff --git a/comfy/text_encoders/sa3.py b/comfy/text_encoders/sa3.py new file mode 100644 index 000000000..0a1c73ec1 --- /dev/null +++ b/comfy/text_encoders/sa3.py @@ -0,0 +1,207 @@ +import torch +import torch.nn as nn +from comfy import sd1_clip +from comfy.text_encoders.llama import Attention as LlamaAttention, RMSNorm, MLP, precompute_freqs_cis, apply_rope, _make_scaled_embedding +from comfy.text_encoders.spiece_tokenizer import SPieceTokenizer + + +class T5GemmaEncoderConfig: + def __init__(self): + self.vocab_size = 256000 + self.hidden_size = 768 + self.intermediate_size = 2048 + self.num_hidden_layers = 12 + self.num_attention_heads = 12 + self.num_key_value_heads = 12 + self.head_dim = 64 + self.rms_norm_eps = 1e-6 + self.rms_norm_add = False + self.rope_theta = 10000.0 + self.attn_logit_softcapping = 50.0 + self.query_pre_attn_scalar = 64 + self.sliding_window = 4096 + self.mlp_activation = "gelu_pytorch_tanh" + self.layer_types = ["sliding_attention", "full_attention"] * 6 + self.qkv_bias = False + self.q_norm = None + self.k_norm = None + self.rms_norm_add = True + + +class T5GemmaAttention(LlamaAttention): + """Reuses LlamaAttention projection setup; overrides forward for softcap attention. + + T5Gemma applies tanh(QK^T * scale / cap) * cap between the matmul and softmax. + This nonlinearity is incompatible with fused SDPA kernels, so attention is + computed manually. Everything else (projections, RoPE, GQA expansion) is identical + to LlamaAttention so __init__ is inherited unchanged. + """ + + def __init__(self, config, device=None, dtype=None, ops=None): + super().__init__(config, device=device, dtype=dtype, ops=ops) + self.scale = config.query_pre_attn_scalar ** -0.5 + self.softcap = config.attn_logit_softcapping + + def forward(self, hidden_states, attention_mask=None, freqs_cis=None, **kwargs): + B, S, _ = hidden_states.shape + xq = self.q_proj(hidden_states).view(B, S, self.num_heads, self.head_dim).transpose(1, 2) + xk = self.k_proj(hidden_states).view(B, S, self.num_kv_heads, self.head_dim).transpose(1, 2) + xv = self.v_proj(hidden_states).view(B, S, self.num_kv_heads, self.head_dim).transpose(1, 2) + xq, xk = apply_rope(xq, xk, freqs_cis) + xk = xk.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1) + xv = xv.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1) + attn = torch.matmul(xq * self.scale, xk.transpose(-2, -1)) + attn = torch.tanh(attn / self.softcap) * self.softcap + if attention_mask is not None: + attn = attn + attention_mask + attn = torch.nn.functional.softmax(attn.float(), dim=-1).to(xq.dtype) + out = torch.matmul(attn, xv).transpose(1, 2).reshape(B, S, self.inner_size) + return self.o_proj(out), None + + +class T5GemmaBlock(nn.Module): + def __init__(self, config, layer_type, device=None, dtype=None, ops=None): + super().__init__() + self.self_attn = T5GemmaAttention(config, device=device, dtype=dtype, ops=ops) + self.mlp = MLP(config, device=device, dtype=dtype, ops=ops) + # Names match checkpoint keys: model.encoder.layers.X..weight + self.pre_self_attn_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=True, device=device, dtype=dtype) + self.post_self_attn_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=True, device=device, dtype=dtype) + self.pre_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=True, device=device, dtype=dtype) + self.post_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=True, device=device, dtype=dtype) + self.is_sliding = (layer_type == "sliding_attention") + self.sliding_window = config.sliding_window + + def forward(self, x, attention_mask=None, freqs_cis=None): + attn_mask = attention_mask + if self.is_sliding and x.shape[1] > self.sliding_window: + S = x.shape[1] + pos = torch.arange(S, device=x.device) + dist = (pos.unsqueeze(0) - pos.unsqueeze(1)).abs() + sw_mask = torch.zeros(S, S, dtype=x.dtype, device=x.device) + sw_mask.masked_fill_(dist > self.sliding_window, -torch.finfo(x.dtype).max) + sw_mask = sw_mask.unsqueeze(0).unsqueeze(0) + attn_mask = (attention_mask + sw_mask) if attention_mask is not None else sw_mask + residual = x + x = self.pre_self_attn_layernorm(x) + x, _ = self.self_attn(x, attention_mask=attn_mask, freqs_cis=freqs_cis) + x = self.post_self_attn_layernorm(x) + x = residual + x + residual = x + x = self.pre_feedforward_layernorm(x) + x = self.mlp(x) + x = self.post_feedforward_layernorm(x) + x = residual + x + return x + + +class T5GemmaEncoder(nn.Module): + """Encoder stack: embed_tokens, layers, norm. + Keys: embed_tokens.*, layers.X.*, norm.*""" + + def __init__(self, config, device, dtype, ops): + super().__init__() + self.config = config + # Gemma-style scaled embedding: output *= sqrt(hidden_size) + self.embed_tokens = _make_scaled_embedding( + ops, config.vocab_size, config.hidden_size, config.hidden_size ** 0.5, device, dtype) + self.layers = nn.ModuleList([ + T5GemmaBlock(config, config.layer_types[i], device=device, dtype=dtype, ops=ops) + for i in range(config.num_hidden_layers) + ]) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=True, device=device, dtype=dtype) + + def forward(self, input_ids, attention_mask=None, embeds=None, intermediate_output=None, + final_layer_norm_intermediate=True, dtype=None, num_layers=None): + x = embeds if embeds is not None else self.embed_tokens(input_ids, out_dtype=dtype or torch.float32) + seq_len = x.shape[1] + position_ids = torch.arange(seq_len, device=x.device).unsqueeze(0) + freqs_cis = precompute_freqs_cis(self.config.head_dim, position_ids, self.config.rope_theta, device=x.device) + mask = None + if attention_mask is not None: + mask = 1.0 - attention_mask.to(x.dtype).reshape( + (attention_mask.shape[0], 1, -1, attention_mask.shape[-1]) + ).expand(attention_mask.shape[0], 1, seq_len, attention_mask.shape[-1]) + mask = mask.masked_fill(mask.to(torch.bool), -torch.finfo(x.dtype).max) + intermediate = None + for i, layer in enumerate(self.layers): + x = layer(x, attention_mask=mask, freqs_cis=freqs_cis) + if i == intermediate_output: + intermediate = x.clone() + x = self.norm(x) + if intermediate is not None and final_layer_norm_intermediate: + intermediate = self.norm(intermediate) + return x, intermediate + + +class T5GemmaBody(nn.Module): + """Provides the 'encoder' sub-module. + Keys: encoder.*""" + + def __init__(self, config, device, dtype, ops): + super().__init__() + self.encoder = T5GemmaEncoder(config, device, dtype, ops) + + +class T5GemmaModel(nn.Module): + """Top-level model class passed to SDClipModel as model_class. + Module layout: self.model.encoder.* → matches checkpoint keys model.encoder.*""" + + def __init__(self, config_dict, dtype, device, operations): + super().__init__() + config = T5GemmaEncoderConfig() + self.num_layers = config.num_hidden_layers + self.dtype = dtype + self.model = T5GemmaBody(config, device, dtype, operations) + + def get_input_embeddings(self): + return self.model.encoder.embed_tokens + + def set_input_embeddings(self, embeddings): + self.model.encoder.embed_tokens = embeddings + + def forward(self, input_ids, attention_mask=None, embeds=None, num_tokens=None, + intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, **kwargs): + if intermediate_output is not None and intermediate_output < 0: + intermediate_output = self.num_layers + intermediate_output + return self.model.encoder( + input_ids, attention_mask=attention_mask, embeds=embeds, + intermediate_output=intermediate_output, + final_layer_norm_intermediate=final_layer_norm_intermediate, + dtype=dtype, num_layers=self.num_layers) + + +class T5GemmaSDClipModel(sd1_clip.SDClipModel): + def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}): + super().__init__(device=device, layer=layer, layer_idx=layer_idx, + textmodel_json_config={}, dtype=dtype, + special_tokens={"pad": 0}, + model_class=T5GemmaModel, + enable_attention_masks=True, zero_out_masked=True, + model_options=model_options) + + +class T5GemmaSDTokenizer(sd1_clip.SDTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + tokenizer_model = tokenizer_data.get("spiece_model", None) + super().__init__(tokenizer_model, pad_with_end=False, embedding_size=768, + embedding_key="t5gemma", tokenizer_class=SPieceTokenizer, + has_start_token=False, has_end_token=False, pad_to_max_length=False, + max_length=99999999, min_length=1, pad_token=0, + tokenizer_data=tokenizer_data, + tokenizer_args={"add_bos": False, "add_eos": False}) + + def state_dict(self): + return {"spiece_model": self.tokenizer.serialize_model()} + + +class SAT5GemmaTokenizer(sd1_clip.SD1Tokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + super().__init__(embedding_directory=embedding_directory, + tokenizer_data=tokenizer_data, clip_name="t5gemma", tokenizer=T5GemmaSDTokenizer) + + +class SAT5GemmaModel(sd1_clip.SD1ClipModel): + def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs): + super().__init__(device=device, dtype=dtype, model_options=model_options, + name="t5gemma", clip_model=T5GemmaSDClipModel, **kwargs) From 4efe1ddb5c7933e9db28d5ecb0be4fa77e0edcde Mon Sep 17 00:00:00 2001 From: "Daxiong (Lin)" Date: Wed, 20 May 2026 23:46:20 +0800 Subject: [PATCH 3/7] chore: update workflow templates to v0.9.79 (#14011) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index f499a10ae..1c87690da 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.43.18 -comfyui-workflow-templates==0.9.77 +comfyui-workflow-templates==0.9.79 comfyui-embedded-docs==0.5.0 torch torchsde From a8d2519058ea766ca3b14916bcc01ecef5efd235 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 20 May 2026 13:49:36 -0400 Subject: [PATCH 4/7] ComfyUI v0.22.0 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index 4c6f5eb2a..0bb0f780c 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.21.1" +__version__ = "0.22.0" diff --git a/pyproject.toml b/pyproject.toml index 0a1554428..1e449b4a3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.21.1" +version = "0.22.0" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.10" From 4d6a058bf1dd18fb6d4594081c3f9a7575c97256 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com> Date: Thu, 21 May 2026 02:07:48 +0300 Subject: [PATCH 5/7] feat: MediaPipe face detection (CORE-235) (#14009) * Initial mediapipe face detection support * Update face_geometry.py * Account for diff sized batch input * Model folder placeholder --- comfy_extras/mediapipe/face_geometry.py | 111 ++++ comfy_extras/mediapipe/face_landmarker.py | 682 +++++++++++++++++++++ comfy_extras/nodes_mediapipe.py | 502 +++++++++++++++ folder_paths.py | 2 + models/mediapipe/put_mediapipe_models_here | 0 nodes.py | 1 + 6 files changed, 1298 insertions(+) create mode 100644 comfy_extras/mediapipe/face_geometry.py create mode 100644 comfy_extras/mediapipe/face_landmarker.py create mode 100644 comfy_extras/nodes_mediapipe.py create mode 100644 models/mediapipe/put_mediapipe_models_here diff --git a/comfy_extras/mediapipe/face_geometry.py b/comfy_extras/mediapipe/face_geometry.py new file mode 100644 index 000000000..04b2b0557 --- /dev/null +++ b/comfy_extras/mediapipe/face_geometry.py @@ -0,0 +1,111 @@ +"""Pure-numpy port of MediaPipe's face_geometry (FACE_LANDMARK_PIPELINE mode) ++ weighted Procrustes solver. Computes the 4x4 facial transformation matrix. +""" + +from __future__ import annotations + +import math +import numpy as np + + +def _solve_weighted_orthogonal_problem(src: np.ndarray, tgt: np.ndarray, weights: np.ndarray) -> np.ndarray: + """Weighted orthogonal Procrustes (similarity). Returns 4x4 M with + `target ≈ M @ homogeneous(source)` in the weighted LS sense. fp64 for + SVD stability. Port of procrustes_solver.cc.""" + sqrt_w = np.sqrt(weights.astype(np.float64)) + w_total = float((sqrt_w ** 2).sum()) + ws = src.astype(np.float64) * sqrt_w + wt = tgt.astype(np.float64) * sqrt_w + + c_w = (ws @ sqrt_w) / w_total + centered = ws - np.outer(c_w, sqrt_w) + U, _S, Vt = np.linalg.svd(wt @ centered.T, full_matrices=True) + # Disallow reflection: flip the least-significant axis when det(U)·det(V)<0. + post, pre = U.copy(), Vt.T.copy() + if np.linalg.det(post) * np.linalg.det(pre) < 0: + post[:, 2] *= -1.0 + R = post @ pre.T + + denom = float((centered * ws).sum()) + if denom < 1e-12: + raise ValueError("Procrustes denominator collapsed (degenerate source).") + scale = float((R @ centered * wt).sum()) / denom + translation = ((wt - scale * (R @ ws)) @ sqrt_w) / w_total + + M = np.eye(4, dtype=np.float64) + M[:3, :3] = scale * R + M[:3, 3] = translation + return M + + +def _estimate_scale(canonical: np.ndarray, runtime: np.ndarray, weights: np.ndarray) -> float: + """scale = ‖first column of M[:3]‖ per geometry_pipeline.cc::EstimateScale.""" + return float(np.linalg.norm(_solve_weighted_orthogonal_problem(canonical, runtime, weights)[:3, 0])) + + +def solve_facial_transformation_matrix( + landmarks_normalized: np.ndarray, + canonical_vertices: np.ndarray, + procrustes_indices: np.ndarray, + procrustes_weights: np.ndarray, + image_width: int, + image_height: int, + # face_geometry_calculator_options.pbtxt defaults + vertical_fov_degrees: float = 63.0, + near: float = 1.0, +) -> np.ndarray: + """4x4 facial transformation matrix via two-pass scale recovery + `landmarks_normalized` is (N, 3) in MediaPipe normalized convention: x, y + in [0,1] with TOP-LEFT origin, z in width-scaled units. + """ + + h_near = 2.0 * near * math.tan(0.5 * math.radians(vertical_fov_degrees)) + w_near = image_width * h_near / image_height + + sub = procrustes_indices.astype(np.int64) + screen = landmarks_normalized[sub].T.astype(np.float64).copy() + canon = canonical_vertices[sub].T.astype(np.float64).copy() + weights = procrustes_weights.astype(np.float64) + + # ProjectXY (TOP_LEFT y-flip, then scale all 3 axes; z uses x-scale). + screen[1] = 1.0 - screen[1] + screen[0] = screen[0] * w_near - 0.5 * w_near + screen[1] = screen[1] * h_near - 0.5 * h_near + screen[2] = screen[2] * w_near + depth_offset = float(screen[2].mean()) + + def _unproject(s: np.ndarray, scale: float) -> np.ndarray: + s = s.copy() + s[2] = (s[2] - depth_offset + near) / scale + s[0] *= s[2] / near + s[1] *= s[2] / near + s[2] *= -1.0 + return s + + first = screen.copy() + first[2] *= -1.0 + s1 = _estimate_scale(canon, first, weights) # 1st pass: Procrustes on projected XY + s2 = _estimate_scale(canon, _unproject(screen, s1), weights) # 2nd pass: rescale z by s1, un-project XY + return _solve_weighted_orthogonal_problem(canon, _unproject(screen, s1 * s2), weights).astype(np.float32) + + +def transformation_matrix_from_detection(face_dict: dict, image_width: int, image_height: int, canonical_data: dict) -> np.ndarray: + """Adapt a FaceLandmarker face dict to MP's normalized convention and solve. + FaceMesh emits (x, y, z) in 192-canonical units; MP's geometry expects + z_norm = z_canonical * scale_x / image_width""" + + lmks_xy, lmks_3d = face_dict["landmarks_xy"], face_dict["landmarks_3d"] + aug = np.concatenate([lmks_3d[:, :2].astype(np.float64), np.ones((lmks_xy.shape[0], 1))], axis=1) + M, *_ = np.linalg.lstsq(aug, lmks_xy.astype(np.float64), rcond=None) + scale_x = float(np.linalg.norm(M[0])) + z_scale = scale_x / image_width if scale_x > 1e-6 else 1.0 / image_width + + normalized = np.empty((lmks_xy.shape[0], 3), dtype=np.float32) + normalized[:, 0] = lmks_xy[:, 0] / image_width + normalized[:, 1] = lmks_xy[:, 1] / image_height + normalized[:, 2] = lmks_3d[:, 2] * z_scale + return solve_facial_transformation_matrix( + normalized, canonical_data["canonical_vertices"], + canonical_data["procrustes_indices"], canonical_data["procrustes_weights"], + image_width=image_width, image_height=image_height, + ) diff --git a/comfy_extras/mediapipe/face_landmarker.py b/comfy_extras/mediapipe/face_landmarker.py new file mode 100644 index 000000000..a792b6046 --- /dev/null +++ b/comfy_extras/mediapipe/face_landmarker.py @@ -0,0 +1,682 @@ +"""Pure-PyTorch port of MediaPipe's face_landmarker_v2_with_blendshapes.task: +BlazeFace detector → FaceMesh v2 → ARKit-52 blendshapes.""" + +from __future__ import annotations + +import math +from functools import lru_cache +from typing import List, Tuple + +import numpy as np +import torch +import torch.nn.functional as F +from scipy.special import expit +from torch import Tensor, nn + + +# Values below must stay verbatim with the published face_landmarker_v2 graph + +# face_blendshapes_graph.cc::kLandmarksSubsetIdxs +_BS_INPUT_INDICES: Tuple[int, ...] = ( + 0, 1, 4, 5, 6, 7, 8, 10, 13, 14, 17, 21, 33, 37, 39, 40, 46, 52, 53, 54, + 55, 58, 61, 63, 65, 66, 67, 70, 78, 80, 81, 82, 84, 87, 88, 91, 93, 95, + 103, 105, 107, 109, 127, 132, 133, 136, 144, 145, 146, 148, 149, 150, 152, + 153, 154, 155, 157, 158, 159, 160, 161, 162, 163, 168, 172, 173, 176, 178, + 181, 185, 191, 195, 197, 234, 246, 249, 251, 263, 267, 269, 270, 276, 282, + 283, 284, 285, 288, 291, 293, 295, 296, 297, 300, 308, 310, 311, 312, 314, + 317, 318, 321, 323, 324, 332, 334, 336, 338, 356, 361, 362, 365, 373, 374, + 375, 377, 378, 379, 380, 381, 382, 384, 385, 386, 387, 388, 389, 390, 397, + 398, 400, 402, 405, 409, 415, 454, 466, 468, 469, 470, 471, 472, 473, 474, + 475, 476, 477, +) + +# face_blendshapes_graph.cc::kCategoryNames +BLENDSHAPE_NAMES: Tuple[str, ...] = ( + "_neutral", "browDownLeft", "browDownRight", "browInnerUp", "browOuterUpLeft", + "browOuterUpRight", "cheekPuff", "cheekSquintLeft", "cheekSquintRight", + "eyeBlinkLeft", "eyeBlinkRight", "eyeLookDownLeft", "eyeLookDownRight", + "eyeLookInLeft", "eyeLookInRight", "eyeLookOutLeft", "eyeLookOutRight", + "eyeLookUpLeft", "eyeLookUpRight", "eyeSquintLeft", "eyeSquintRight", + "eyeWideLeft", "eyeWideRight", "jawForward", "jawLeft", "jawOpen", + "jawRight", "mouthClose", "mouthDimpleLeft", "mouthDimpleRight", + "mouthFrownLeft", "mouthFrownRight", "mouthFunnel", "mouthLeft", + "mouthLowerDownLeft", "mouthLowerDownRight", "mouthPressLeft", + "mouthPressRight", "mouthPucker", "mouthRight", "mouthRollLower", + "mouthRollUpper", "mouthShrugLower", "mouthShrugUpper", "mouthSmileLeft", + "mouthSmileRight", "mouthStretchLeft", "mouthStretchRight", + "mouthUpperUpLeft", "mouthUpperUpRight", "noseSneerLeft", "noseSneerRight", +) + +# face_detection.pbtxt — short-range BlazeFace. +_BF_NUM_LAYERS = 4 +_BF_INPUT_SIZE = 128 +_BF_STRIDES = (8, 16, 16, 16) +_BF_ANCHOR_OFFSET_X = 0.5 +_BF_ANCHOR_OFFSET_Y = 0.5 +_BF_ASPECT_RATIOS = (1.0,) +_BF_INTERP_SCALE_AR = 1.0 +_BF_BOX_SCALE = 128.0 +_BF_KP_OFFSET = 4 +_BF_SCORE_CLIP = 100.0 +_BF_MIN_SCORE = 0.5 + +# face_detection_full_range.pbtxt — 48x48 grid at stride 4, 1 anchor/cell. +_BF_FR_INPUT_SIZE = 192 +_BF_FR_GRID = 48 +_BF_FR_NUM_ANCHORS = _BF_FR_GRID * _BF_FR_GRID +_BF_FR_BOX_SCALE = 192.0 +_BF_FR_SCORE_CLIP = 100.0 + +_FM_INPUT_SIZE = 192 + +# Face ROI: 1.5xbbox rect warped anisotropically into 192x192. +_FACE_LEFT_EYE_KP = 0 +_FACE_RIGHT_EYE_KP = 1 +_FACE_ROI_SCALE_X = 1.5 +_FACE_ROI_SCALE_Y = 1.5 +_FACE_ROI_TARGET_ANGLE = 0.0 + + +def _tf_same_pad(x: Tensor, kernel: int, stride: int) -> Tensor: + """TF SAME pad (asymmetric on stride-2; PyTorch's symmetric pad undershoots by 1 px).""" + H, W = x.shape[-2], x.shape[-1] + pad_h = max(((H + stride - 1) // stride - 1) * stride + kernel - H, 0) + pad_w = max(((W + stride - 1) // stride - 1) * stride + kernel - W, 0) + if pad_h == 0 and pad_w == 0: + return x + return F.pad(x, (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)) + + +# BlazeFace short-range: stem 5x5/s2 → 16 BlazeBlocks → parallel heads at +# 16²x88 (2 anchors/cell) and 8²x96 (6/cell) = 896 anchors. (in, out, stride): +_BLAZEFACE_BLOCKS = [ + (24, 24, 1), (24, 28, 1), (28, 32, 2), (32, 36, 1), + (36, 42, 1), (42, 48, 2), (48, 56, 1), (56, 64, 1), + (64, 72, 1), (72, 80, 1), (80, 88, 1), (88, 96, 2), + (96, 96, 1), (96, 96, 1), (96, 96, 1), (96, 96, 1), +] + + +class BlazeFaceBlock(nn.Module): + """DW 3x3 + PW + residual. Residual max-pools on stride>1, channel-pads on out_ch>in_ch.""" + + def __init__(self, in_ch: int, out_ch: int, stride: int, device=None, dtype=None, operations=None): + super().__init__() + ops = operations if operations is not None else nn + self.in_ch, self.out_ch, self.stride = in_ch, out_ch, stride + self.depthwise = ops.Conv2d(in_ch, in_ch, 3, stride=stride, padding=0, groups=in_ch, bias=True, device=device, dtype=dtype) + self.pointwise = ops.Conv2d(in_ch, out_ch, 1, padding=0, bias=True, device=device, dtype=dtype) + + def forward(self, x: Tensor) -> Tensor: + residual = F.max_pool2d(x, 2, 2) if self.stride > 1 else x + if self.out_ch > self.in_ch: + residual = F.pad(residual, (0, 0, 0, 0, 0, self.out_ch - self.in_ch)) + x = _tf_same_pad(x, 3, self.stride) if self.stride > 1 else F.pad(x, (1, 1, 1, 1)) + return F.relu(self.pointwise(self.depthwise(x)) + residual) + + +class BlazeFace(nn.Module): + """Short-range BlazeFace: (B, 3, 128, 128) in [-1, 1] → 896 anchors x 17.""" + + def __init__(self, device=None, dtype=None, operations=None): + super().__init__() + ops = operations if operations is not None else nn + kw = dict(device=device, dtype=dtype) + self.stem = ops.Conv2d(3, 24, 5, stride=2, padding=0, bias=True, **kw) + self.blocks = nn.ModuleList(BlazeFaceBlock(i, o, s, device=device, dtype=dtype, operations=operations) + for (i, o, s) in _BLAZEFACE_BLOCKS) + # 16²x2 + 8²x6 = 512 + 384 = 896 anchors. + self.cls_16 = ops.Conv2d(88, 2, 1, padding=0, bias=True, **kw) + self.cls_8 = ops.Conv2d(96, 6, 1, padding=0, bias=True, **kw) + self.reg_16 = ops.Conv2d(88, 32, 1, padding=0, bias=True, **kw) + self.reg_8 = ops.Conv2d(96, 96, 1, padding=0, bias=True, **kw) + + def forward(self, image_chw_normalized: Tensor) -> tuple[Tensor, Tensor]: + x = F.relu(self.stem(_tf_same_pad(image_chw_normalized, 5, 2))) + # 16x16 tap is block-10 output (before the 88→96 stride-2 in block 11). + for i in range(11): + x = self.blocks[i](x) + feat_16 = x + for i in range(11, 16): + x = self.blocks[i](x) + feat_8 = x + + def flat(t, a, k): # NHWC flatten → (B, H*W*A, K) + B, _, H, W = t.shape + return t.permute(0, 2, 3, 1).reshape(B, H * W * a, k) + + cls = torch.cat([flat(self.cls_16(feat_16), 2, 1), flat(self.cls_8(feat_8), 6, 1)], dim=1) + reg = torch.cat([flat(self.reg_16(feat_16), 2, 16), flat(self.reg_8(feat_8), 6, 16)], dim=1) + return reg, cls + + +# BlazeFace full-range (face_detection_full_range_sparse.tflite): MobileNetV2-ish +# backbone + top-down FPN, 192² input → 2304 anchors at the 48x48 grid. +class FRBlock(nn.Module): + """Double inverted residual: DW → PW(mid) → DW → PW(out) [+ residual]. + + Per source tflite: dw* have no fused activation, pw1 is always ReLU, pw2 + is ReLU only when no residual (else ReLU fuses into the ADD). + """ + + def __init__(self, in_ch: int, mid_ch: int, out_ch: int, stride: int, device=None, dtype=None, operations=None): + super().__init__() + ops = operations if operations is not None else nn + kw = dict(device=device, dtype=dtype) + self.has_residual = (in_ch == out_ch and stride == 1) + self.dw1 = ops.Conv2d(in_ch, in_ch, 3, stride=stride, padding=0, groups=in_ch, bias=True, **kw) + self.pw1 = ops.Conv2d(in_ch, mid_ch, 1, padding=0, bias=True, **kw) + self.dw2 = ops.Conv2d(mid_ch, mid_ch, 3, stride=1, padding=0, groups=mid_ch, bias=True, **kw) + self.pw2 = ops.Conv2d(mid_ch, out_ch, 1, padding=0, bias=True, **kw) + + def forward(self, x: Tensor) -> Tensor: + residual = x if self.has_residual else None + x = F.relu(self.pw1(self.dw1(F.pad(x, (1, 1, 1, 1))))) + x = self.pw2(self.dw2(F.pad(x, (1, 1, 1, 1)))) + return F.relu(x + residual) if residual is not None else F.relu(x) + + +# (in_ch, mid_ch, out_ch, stride). Stages downsample 96²x32 → 48²x64 → 24²x128 +# → 12²x192 → 6²x384. Lateral taps at indices 4, 7, 10 (see _FR_LATERAL_*). +_FR_BACKBONE_BLOCKS = [ + (32, 8, 32, 1), (32, 8, 32, 1), # 96²x32 + (32, 16, 64, 2), (64, 16, 64, 1), (64, 16, 64, 1), # 48²x64 — tap[0] + (64, 32, 128, 2), (128, 32, 128, 1), (128, 32, 128, 1), # 24²x128 — tap[1] + (128, 48, 192, 2), (192, 48, 192, 1), (192, 48, 192, 1), # 12²x192 — tap[2] + (192, 96, 384, 2), (384, 96, 384, 1), (384, 96, 384, 1), (384, 96, 384, 1), # 6²x384 +] +_FR_LATERAL_TAP_INDICES = (4, 7, 10) +_FR_LATERAL_CHANNELS = ((64, 48), (128, 64), (192, 96)) # (in, out) per side-conv + +# Decoder blocks per FPN level (after upsample-and-merge with the lateral). +_FR_DECODER_BLOCKS = [ + [(96, 48, 96, 1), (96, 48, 96, 1)], # 12²x96 + [(64, 32, 64, 1), (64, 32, 64, 1)], # 24²x64 + [(48, 24, 48, 1)], # 48²x48 — feeds the heads +] + + +def _dcr_depth_to_space(t: Tensor, r: int, c_out: int) -> Tensor: + """TF DEPTH_TO_SPACE in DCR layout (input channels = (i, j, c_out)). + pixel_shuffle uses CRD which permutes output channels for c_out > 1.""" + B_, _, H_, W_ = t.shape + t = t.reshape(B_, r, r, c_out, H_, W_) + t = t.permute(0, 3, 4, 1, 5, 2).contiguous() + return t.reshape(B_, c_out, H_ * r, W_ * r) + + +class BlazeFaceFullRange(nn.Module): + """Full-range face detector: (B, 3, 192, 192) in [-1, 1] → 2304 anchors x 17 values.""" + + def __init__(self, device=None, dtype=None, operations=None): + super().__init__() + ops = operations if operations is not None else nn + kw = dict(device=device, dtype=dtype) + mk_block = lambda i, m, o, s: FRBlock(i, m, o, s, device=device, dtype=dtype, operations=operations) + self.stem = ops.Conv2d(3, 32, 3, stride=2, padding=0, bias=True, **kw) + self.backbone = nn.ModuleList(mk_block(i, m, o, s) for (i, m, o, s) in _FR_BACKBONE_BLOCKS) + self.lateral_convs = nn.ModuleList(ops.Conv2d(i, o, 1, padding=0, bias=True, **kw) for (i, o) in _FR_LATERAL_CHANNELS) + self.top_conv = ops.Conv2d(384, 96, 1, padding=0, bias=True, **kw) + self.decoder_levels = nn.ModuleList( + nn.ModuleList(mk_block(i, m, o, s) for (i, m, o, s) in lvl) for lvl in _FR_DECODER_BLOCKS + ) + # 96→64 before 12→24, 64→48 before 24→48. + self.decoder_reduce_convs = nn.ModuleList([ + ops.Conv2d(96, 64, 1, padding=0, bias=True, **kw), + ops.Conv2d(64, 48, 1, padding=0, bias=True, **kw), + ]) + # Heads mix 2x2-cell info via DW-stride-2 + depth_to_space block_size=2. + self.cls_conv = ops.Conv2d(48, 4, 1, padding=0, bias=True, **kw) + self.cls_dw = ops.Conv2d(4, 4, 3, stride=2, padding=0, groups=4, bias=True, **kw) + self.reg_conv = ops.Conv2d(48, 64, 1, padding=0, bias=True, **kw) + self.reg_dw = ops.Conv2d(64, 64, 3, stride=2, padding=0, groups=64, bias=True, **kw) + + def forward(self, image_chw_normalized: Tensor) -> tuple[Tensor, Tensor]: + # Symmetric pad-1 throughout (full-range tflite uses explicit TF PAD, not SAME). + x = F.relu(self.stem(F.pad(image_chw_normalized, (1, 1, 1, 1)))) + tap_set = set(_FR_LATERAL_TAP_INDICES) + laterals: list[Tensor] = [] + for i, blk in enumerate(self.backbone): + x = blk(x) + if i in tap_set: + laterals.append(x) + + # top_conv / lateral_convs / decoder_reduce_convs all have fused ReLU in the tflite. + p = F.relu(self.top_conv(x)) + laterals_rev = list(reversed(laterals)) + lateral_convs_rev = list(reversed(self.lateral_convs)) + for level in range(len(self.decoder_levels)): + lateral = laterals_rev[level] + p = F.interpolate(p, size=lateral.shape[-2:], mode="bilinear", align_corners=False) + p = p + F.relu(lateral_convs_rev[level](lateral)) + for blk in self.decoder_levels[level]: + p = blk(p) + if level < len(self.decoder_reduce_convs): + p = F.relu(self.decoder_reduce_convs[level](p)) + + c = self.cls_dw(F.pad(self.cls_conv(p), (1, 1, 1, 1))) + c = _dcr_depth_to_space(c, r=2, c_out=1) + r = self.reg_dw(F.pad(self.reg_conv(p), (1, 1, 1, 1))) + r = _dcr_depth_to_space(r, r=2, c_out=16) + B = c.shape[0] + cls_out = c.permute(0, 2, 3, 1).reshape(B, _BF_FR_NUM_ANCHORS, 1) + reg_out = r.permute(0, 2, 3, 1).reshape(B, _BF_FR_NUM_ANCHORS, 16) + return reg_out, cls_out + + +@lru_cache(maxsize=1) +def _blazeface_full_range_anchors() -> np.ndarray: + """2304 anchors over 48x48; anchor_w=anchor_h=1 (fixed_anchor_size).""" + feat = _BF_FR_GRID + yy, xx = np.meshgrid(np.arange(feat, dtype=np.float32), np.arange(feat, dtype=np.float32), indexing="ij") + cx, cy, ones = (xx + 0.5) / feat, (yy + 0.5) / feat, np.ones_like(xx) + return np.stack([cx, cy, ones, ones], axis=-1).reshape(_BF_FR_NUM_ANCHORS, 4) + + +def _decode_blazeface_full_range(regressors: np.ndarray, classificators: np.ndarray, + score_thresh: float = _BF_MIN_SCORE) -> np.ndarray: + """Same decode as short-range with 2304-anchor grid and box_scale=192.""" + scores = expit(np.clip(classificators[:, 0], -_BF_FR_SCORE_CLIP, _BF_FR_SCORE_CLIP)) + keep = scores >= score_thresh + if not keep.any(): + return np.empty((0, 17), dtype=np.float32) + r = regressors[keep] / _BF_FR_BOX_SCALE + a = _blazeface_full_range_anchors()[keep] + cxs, cys, aws, ahs = a[:, 0:1], a[:, 1:2], a[:, 2:3], a[:, 3:4] + xc, yc = r[:, 0:1] * aws + cxs, r[:, 1:2] * ahs + cys + w, h = r[:, 2:3] * aws, r[:, 3:4] * ahs + out = np.empty((r.shape[0], 17), dtype=np.float32) + out[:, 0:1], out[:, 1:2], out[:, 2:3], out[:, 3:4] = xc - w / 2, yc - h / 2, xc + w / 2, yc + h / 2 + out[:, 4:16:2] = r[:, _BF_KP_OFFSET::2] * aws + cxs + out[:, 5:16:2] = r[:, _BF_KP_OFFSET + 1::2] * ahs + cys + out[:, 16] = scores[keep] + return out + + +# FaceMesh (face_landmarks_detector.tflite): PReLU variant of BlazeBlock, +# 17 blocks, heads for 478x3 landmarks + presence. +_FACEMESH_BLOCKS = [ # (in_ch, out_ch, stride) + (16, 16, 1), (16, 16, 1), (16, 32, 2), (32, 32, 1), (32, 32, 1), (32, 64, 2), + (64, 64, 1), (64, 64, 1), (64, 128, 2), (128, 128, 1), (128, 128, 1), (128, 128, 2), + (128, 128, 1), (128, 128, 1), (128, 128, 2), (128, 128, 1), (128, 128, 1), +] + + +class FaceMeshBlock(nn.Module): + """PReLU BlazeBlock: PReLU between DW and PW, and after the residual add.""" + + def __init__(self, in_ch: int, out_ch: int, stride: int, device=None, dtype=None, operations=None): + super().__init__() + ops = operations if operations is not None else nn + kw = dict(device=device, dtype=dtype) + self.in_ch, self.out_ch, self.stride = in_ch, out_ch, stride + self.depthwise = ops.Conv2d(in_ch, in_ch, 3, stride=stride, padding=0, groups=in_ch, bias=True, **kw) + self.prelu_dwise = nn.PReLU(num_parameters=in_ch, **kw) + self.pointwise = ops.Conv2d(in_ch, out_ch, 1, padding=0, bias=True, **kw) + self.prelu_out = nn.PReLU(num_parameters=out_ch, **kw) + + def forward(self, x: Tensor) -> Tensor: + residual = F.max_pool2d(x, 2, 2) if self.stride > 1 else x + if self.out_ch > self.in_ch: + residual = F.pad(residual, (0, 0, 0, 0, 0, self.out_ch - self.in_ch)) + x = _tf_same_pad(x, 3, self.stride) if self.stride > 1 else F.pad(x, (1, 1, 1, 1)) + return self.prelu_out(self.pointwise(self.prelu_dwise(self.depthwise(x))) + residual) + + +class FaceMesh(nn.Module): + NUM_LANDMARKS = 478 + + def __init__(self, device=None, dtype=None, operations=None): + super().__init__() + ops = operations if operations is not None else nn + kw = dict(device=device, dtype=dtype) + self.stem = ops.Conv2d(3, 16, 3, stride=2, padding=0, bias=True, **kw) + self.prelu_stem = nn.PReLU(num_parameters=16, **kw) + self.blocks = nn.ModuleList(FaceMeshBlock(i, o, s, device=device, dtype=dtype, operations=operations) + for (i, o, s) in _FACEMESH_BLOCKS) + self.head_reduce = ops.Conv2d(128, 8, 1, padding=0, bias=True, **kw) + self.prelu_head_reduce = nn.PReLU(num_parameters=8, **kw) + self.head_block = FaceMeshBlock(8, 8, 1, device=device, dtype=dtype, operations=operations) + self.head_presence = ops.Conv2d(8, 1, 3, padding=0, bias=True, **kw) + self.head_landmarks = ops.Conv2d(8, self.NUM_LANDMARKS * 3, 3, padding=0, bias=True, **kw) + + def forward(self, face_chw_normalized: Tensor) -> tuple[Tensor, Tensor]: + """(B, 3, 192, 192) in [0, 1] → ((B, 478, 3) landmarks in 192-canonical, (B,) presence).""" + x = self.prelu_stem(self.stem(_tf_same_pad(face_chw_normalized, 3, 2))) + for blk in self.blocks: + x = blk(x) + x = self.prelu_head_reduce(self.head_reduce(x)) + x = self.head_block(x) + B = x.shape[0] + presence = self.head_presence(x).reshape(B) + lmks = self.head_landmarks(x).reshape(B, self.NUM_LANDMARKS, 3) + return lmks, presence + + +# FaceBlendshapes (MLP-Mixer "GhumMarkerPoserMlpMixerGeneral"): +# 146x2 → token-reduce 146→96 → embed 2→64 → +cls token → 4x mixer → cls→52. +_BS_NUM_INPUT_LANDMARKS = 146 +_BS_NUM_TOKENS_REDUCED = 96 +_BS_NUM_TOKENS = 97 # +1 cls +_BS_TOKEN_DIM = 64 +_BS_TOKEN_MIX_HIDDEN = 384 +_BS_CHANNEL_MIX_HIDDEN = 256 +_BS_NUM_BLENDSHAPES = 52 +_BS_LN_EPS = 1e-6 + + +class MlpMixerBlock(nn.Module): + """MLP-Mixer block: token-mixing MLP (over tokens) → channel-mixing MLP (over dim). + Both pre-LN, both residual. LN has no beta (bias=False) to match MP.""" + + def __init__(self, num_tokens: int, token_dim: int, token_hidden: int, channel_hidden: int, + device=None, dtype=None, operations=None): + super().__init__() + ops = operations if operations is not None else nn + kw = dict(device=device, dtype=dtype) + # bias=False → no LN beta (matches MP). + self.ln1 = ops.LayerNorm(token_dim, eps=_BS_LN_EPS, bias=False, **kw) + self.ln2 = ops.LayerNorm(token_dim, eps=_BS_LN_EPS, bias=False, **kw) + self.token_mlp1 = ops.Linear(num_tokens, token_hidden, bias=True, **kw) + self.token_mlp2 = ops.Linear(token_hidden, num_tokens, bias=True, **kw) + self.channel_mlp1 = ops.Linear(token_dim, channel_hidden, bias=True, **kw) + self.channel_mlp2 = ops.Linear(channel_hidden, token_dim, bias=True, **kw) + + def forward(self, x: Tensor) -> Tensor: + y = self.ln1(x).transpose(1, 2) + x = x + self.token_mlp2(F.relu(self.token_mlp1(y))).transpose(1, 2) + return x + self.channel_mlp2(F.relu(self.channel_mlp1(self.ln2(x)))) + + +class FaceBlendshapes(nn.Module): + def __init__(self, device=None, dtype=None, operations=None): + super().__init__() + ops = operations if operations is not None else nn + kw = dict(device=device, dtype=dtype) + self.token_reduce = ops.Linear(_BS_NUM_INPUT_LANDMARKS, _BS_NUM_TOKENS_REDUCED, bias=True, **kw) + self.token_embed = ops.Linear(2, _BS_TOKEN_DIM, bias=True, **kw) + self.cls_token = nn.Parameter(torch.zeros(1, 1, _BS_TOKEN_DIM, **kw)) + self.blocks = nn.ModuleList( + MlpMixerBlock(_BS_NUM_TOKENS, _BS_TOKEN_DIM, _BS_TOKEN_MIX_HIDDEN, _BS_CHANNEL_MIX_HIDDEN, + device=device, dtype=dtype, operations=operations) for _ in range(4) + ) + self.head = ops.Linear(_BS_TOKEN_DIM, _BS_NUM_BLENDSHAPES, bias=True, **kw) + + @staticmethod + def _input_normalize(landmarks_2d: Tensor) -> Tensor: + # Centroid-subtract → L2 scale → x0.5. The 0.5 is baked into training. + centroid = landmarks_2d.mean(dim=1, keepdim=True) + x = landmarks_2d - centroid + mag = torch.sqrt((x * x).sum(dim=-1, keepdim=True)) + scale = mag.mean(dim=1, keepdim=True) + return (x / scale.clamp(min=1e-12)) * 0.5 + + def forward(self, landmarks_2d: Tensor) -> Tensor: + """(B, 146, 2) → (B, 52) in [0, 1]. Input units don't matter (centroid + L2 normalize).""" + x = self._input_normalize(landmarks_2d) + x = self.token_reduce(x.transpose(1, 2)).transpose(1, 2) + x = self.token_embed(x) + cls = self.cls_token.expand(x.shape[0], -1, -1) + x = torch.cat([cls, x], dim=1) + for blk in self.blocks: + x = blk(x) + return torch.sigmoid(self.head(x[:, 0])) + + +@lru_cache(maxsize=1) +def _blazeface_anchors() -> np.ndarray: + """896 anchors per SsdAnchorsCalculator (fixed_anchor_size → anchor_w=anchor_h=1).""" + per_ar = len(_BF_ASPECT_RATIOS) + (1 if _BF_INTERP_SCALE_AR > 0 else 0) + layer_anchors: List[np.ndarray] = [] + layer = 0 + while layer < _BF_NUM_LAYERS: + stride = _BF_STRIDES[layer] + last = layer + while last < _BF_NUM_LAYERS and _BF_STRIDES[last] == stride: + last += 1 + per_cell = per_ar * (last - layer) + feat = (_BF_INPUT_SIZE + stride - 1) // stride + yy, xx = np.meshgrid(np.arange(feat, dtype=np.float32), np.arange(feat, dtype=np.float32), indexing="ij") + cx, cy, ones = (xx + _BF_ANCHOR_OFFSET_X) / feat, (yy + _BF_ANCHOR_OFFSET_Y) / feat, np.ones_like(xx) + cell = np.stack([cx, cy, ones, ones], axis=-1).reshape(-1, 4) + layer_anchors.append(np.repeat(cell, per_cell, axis=0)) + layer = last + out = np.concatenate(layer_anchors, axis=0) + assert out.shape == (896, 4), out.shape + return out + + +def _decode_blazeface(regressors: np.ndarray, classificators: np.ndarray, + score_thresh: float = _BF_MIN_SCORE) -> np.ndarray: + """Decode (regs (896,16), cls (896,1)) → (N, 17) = [xyxy, kp0x..kp5y, score] in [0, 1].""" + scores = expit(np.clip(classificators[:, 0], -_BF_SCORE_CLIP, _BF_SCORE_CLIP)) + keep = scores >= score_thresh + if not keep.any(): + return np.empty((0, 17), dtype=np.float32) + r = regressors[keep] / _BF_BOX_SCALE + a = _blazeface_anchors()[keep] # (N, 4) cx, cy, 1, 1 + cxs, cys, aws, ahs = a[:, 0:1], a[:, 1:2], a[:, 2:3], a[:, 3:4] + xc, yc = r[:, 0:1] * aws + cxs, r[:, 1:2] * ahs + cys + w, h = r[:, 2:3] * aws, r[:, 3:4] * ahs + out = np.empty((r.shape[0], 17), dtype=np.float32) + out[:, 0:1], out[:, 1:2], out[:, 2:3], out[:, 3:4] = xc - w / 2, yc - h / 2, xc + w / 2, yc + h / 2 + out[:, 4:16:2] = r[:, _BF_KP_OFFSET::2] * aws + cxs + out[:, 5:16:2] = r[:, _BF_KP_OFFSET + 1::2] * ahs + cys + out[:, 16] = scores[keep] + return out + + +def _weighted_nms(detections: np.ndarray, iou_thresh: float = 0.5) -> np.ndarray: + """MP weighted NMS — kept boxes are score-weighted averages of overlapping detections.""" + if detections.shape[0] == 0: + return detections + dets = detections[np.argsort(-detections[:, 16])] + N = dets.shape[0] + areas = np.clip(dets[:, 2] - dets[:, 0], 0, None) * np.clip(dets[:, 3] - dets[:, 1], 0, None) + kept: List[np.ndarray] = [] + used = np.zeros(N, dtype=bool) + for i in range(N): + if used[i]: + continue + ax1, ay1, ax2, ay2 = dets[i, 0:4] + merge_idx = [i] + for j in range(i + 1, N): + if used[j]: + continue + bx1, by1, bx2, by2 = dets[j, 0:4] + iw = max(0.0, min(ax2, bx2) - max(ax1, bx1)) + ih = max(0.0, min(ay2, by2) - max(ay1, by1)) + inter = iw * ih + union = areas[i] + areas[j] - inter + if union > 0 and inter / union > iou_thresh: # strict > matches MP + merge_idx.append(j) + used[j] = True + used[i] = True + cluster = dets[merge_idx] + ws = cluster[:, 16:17] + ws_sum = ws.sum() + merged = np.copy(cluster[0]) + if ws_sum > 0: + merged[:16] = (cluster[:, :16] * ws).sum(axis=0) / ws_sum + kept.append(merged) + return np.stack(kept, axis=0) if kept else np.empty((0, 17), dtype=np.float32) + + +def _detection_to_face_rect(detection: np.ndarray, image_w: int, image_h: int) -> Tuple[float, float, float, float, float]: + """Detection (normalized) → rotated 1.5xbbox ROI in image pixels (anisotropic).""" + xmin, ymin, xmax, ymax = detection[0:4] + lx = detection[4 + _FACE_LEFT_EYE_KP * 2 + 0] * image_w + ly = detection[4 + _FACE_LEFT_EYE_KP * 2 + 1] * image_h + rx = detection[4 + _FACE_RIGHT_EYE_KP * 2 + 0] * image_w + ry = detection[4 + _FACE_RIGHT_EYE_KP * 2 + 1] * image_h + # Image-y-down convention: angle = target - atan2(-dy, dx). + angle = _FACE_ROI_TARGET_ANGLE - math.atan2(ly - ry, rx - lx) + return (float((xmin + xmax) * 0.5 * image_w), + float((ymin + ymax) * 0.5 * image_h), + float((xmax - xmin) * image_w * _FACE_ROI_SCALE_X), + float((ymax - ymin) * image_h * _FACE_ROI_SCALE_Y), + float(angle)) + + +def _sample_warp(image_chw: Tensor, src_x: Tensor, src_y: Tensor, padding_mode: str) -> Tensor: + """Bilinear-sample image_chw at corner-aligned (src_x, src_y).""" + H, W = int(image_chw.shape[-2]), int(image_chw.shape[-1]) + grid = torch.stack([(2.0 * src_x + 1.0) / W - 1.0, + (2.0 * src_y + 1.0) / H - 1.0], dim=-1).unsqueeze(0) + return F.grid_sample(image_chw.unsqueeze(0), grid, mode="bilinear", + align_corners=False, padding_mode=padding_mode).squeeze(0) + + +def _warp_face_crop(image_chw: Tensor, cx: float, cy: float, width: float, height: float, + angle: float, output_size: int = _FM_INPUT_SIZE) -> Tensor: + """Rotated rect → output_size² with BORDER_REPLICATE. image_chw must be in [0, 1].""" + s_x, s_y = width / output_size, height / output_size + cos_a, sin_a = math.cos(angle), math.sin(angle) + arange = torch.arange(output_size, dtype=image_chw.dtype, device=image_chw.device) - output_size * 0.5 + v_grid, u_grid = torch.meshgrid(arange, arange, indexing="ij") + src_x = cx + u_grid * s_x * cos_a - v_grid * s_y * sin_a + src_y = cy + u_grid * s_x * sin_a + v_grid * s_y * cos_a + return _sample_warp(image_chw, src_x, src_y, "border") + + +def _blazeface_input_warp(image_chw_raw: Tensor, target: int = _BF_INPUT_SIZE) -> Tuple[Tensor, float, float, float]: + """Centered max(W,H) square → target² with BORDER_ZERO + [-1, 1] norm. + + Sub-pixel grid_sample matters; integer-pad-then-resize drifts the bbox ~5%. + Returns (warped, sub_rect_cx, sub_rect_cy, sub_rect_size) — the triplet maps + tensor-normalized [0,1] detections back to image pixels. + """ + H, W = int(image_chw_raw.shape[1]), int(image_chw_raw.shape[2]) + sub_rect_size = float(max(W, H)) + sub_rect_cx, sub_rect_cy = W * 0.5, H * 0.5 + s = sub_rect_size / target + arange = torch.arange(target, dtype=image_chw_raw.dtype, device=image_chw_raw.device) - target * 0.5 + v_grid, u_grid = torch.meshgrid(arange, arange, indexing="ij") + out = _sample_warp(image_chw_raw, sub_rect_cx + u_grid * s, sub_rect_cy + v_grid * s, "zeros") + return (out / 127.5) - 1.0, sub_rect_cx, sub_rect_cy, sub_rect_size + + +class FaceLandmarker(nn.Module): + """BlazeFace → FaceMesh v2 → blendshapes. `detector_variant` selects 'short' + (128², ≤2m) or 'full' (192² FPN, ≤5m). State dict uses inner-module prefixes + `detector.*` / `mesh.*` / `blendshapes.*`; the outer FaceLandmarkerModel + wrapper rewrites `detector_{variant}.*` keys to `detector.*` before loading. + """ + + def __init__(self, device=None, dtype=None, operations=None, detector_variant: str = "short"): + super().__init__() + det_cls = {"short": BlazeFace, "full": BlazeFaceFullRange}.get(detector_variant) + + self.detector_variant = detector_variant + self.detector = det_cls(device=device, dtype=dtype, operations=operations) + self.mesh = FaceMesh(device=device, dtype=dtype, operations=operations) + self.blendshapes = FaceBlendshapes(device=device, dtype=dtype, operations=operations) + self.register_buffer("_bs_idx", torch.tensor(_BS_INPUT_INDICES, dtype=torch.long), persistent=False) + + def run_detector_batch(self, images_rgb_uint8: List[np.ndarray], + score_thresh: float = _BF_MIN_SCORE, + iou_thresh: float = 0.5): + """Batched detector pass. Returns (img_raws, sub_rects, sizes, per_frame_decoded) + where per_frame_decoded[b] is (N, 17) in tensor-normalized [0,1] coords.""" + if not images_rgb_uint8: + return [], [], [], [] + device, dtype = self.detector.stem.weight.device, self.detector.stem.weight.dtype + det_input_size, decode_fn = ((_BF_FR_INPUT_SIZE, _decode_blazeface_full_range) + if self.detector_variant == "full" + else (_BF_INPUT_SIZE, _decode_blazeface)) + + # Same-size frames: stack once and transfer once. Variable size falls back + # to per-image (only triggers for SAM3DBody's head crops). + sizes = [tuple(img.shape[:2]) for img in images_rgb_uint8] + if len(set(sizes)) == 1: + batch_chw = torch.from_numpy(np.stack(images_rgb_uint8, axis=0)).to(device, dtype).movedim(-1, -3).contiguous() + img_raws = [batch_chw[bi] for bi in range(batch_chw.shape[0])] + else: + img_raws = [torch.from_numpy(img).to(device, dtype).movedim(-1, -3).contiguous() for img in images_rgb_uint8] + + warps = [_blazeface_input_warp(img_raw, det_input_size) for img_raw in img_raws] + det_crops = [w[0] for w in warps] + sub_rects = [(w[1], w[2], w[3]) for w in warps] + + regs_b, cls_b = self.detector(torch.stack(det_crops, dim=0)) + regs_np, cls_np = regs_b.float().cpu().numpy(), cls_b.float().cpu().numpy() + per_frame = [] + for b in range(len(images_rgb_uint8)): + decoded = decode_fn(regs_np[b], cls_np[b], score_thresh=score_thresh) + per_frame.append(_weighted_nms(decoded, iou_thresh=iou_thresh) if decoded.shape[0] > 0 else decoded) + return img_raws, sub_rects, sizes, per_frame + + def detect_batch(self, images_rgb_uint8: List[np.ndarray], num_faces: int = 1, + score_thresh: float = _BF_MIN_SCORE) -> List[List[dict]]: + """Full pipeline batched across `images_rgb_uint8`. Returns one face-dict + list per image (empty if nothing detected). Face dict: + bbox_xyxy (4,) image pixels, blendshapes {52} ∈ [0,1], + landmarks_xy (478, 2) image pixels, landmarks_3d (478, 3) in + 192-canonical (pre-transformation) units, presence float (raw logit). + """ + img_raws, sub_rects, sizes, per_frame_dets = self.run_detector_batch( + images_rgb_uint8, score_thresh=score_thresh, + ) + # tensor-normalized → image-normalized [0,1] for _detection_to_face_rect. + for b, decoded in enumerate(per_frame_dets): + if decoded.shape[0] == 0: + continue + cx, cy, size = sub_rects[b] + H, W = sizes[b] + sx0, sy0 = cx - size * 0.5, cy - size * 0.5 + decoded[:, 0:16:2] = (sx0 + size * decoded[:, 0:16:2]) / W + decoded[:, 1:16:2] = (sy0 + size * decoded[:, 1:16:2]) / H + if num_faces > 0: + per_frame_dets[b] = decoded[: int(num_faces)] + + # Collect every detected face across all frames into one mesh input. + face_params: List[Tuple[int, float, float, float, float, float, float]] = [] + mesh_crops: List[Tensor] = [] + for b, dets in enumerate(per_frame_dets): + if dets.shape[0] == 0: + continue + H, W = sizes[b] + img_for_mesh = img_raws[b] / 255.0 + for det in dets: + cx, cy, w, h, angle = _detection_to_face_rect(det, W, H) + mesh_crops.append(_warp_face_crop(img_for_mesh, cx, cy, w, h, angle, _FM_INPUT_SIZE)) + face_params.append((b, float(det[16]), cx, cy, w, h, angle)) + + results: List[List[dict]] = [[] for _ in range(len(images_rgb_uint8))] + if not mesh_crops: + return results + + lmks_canon_b, presence_b = self.mesh(torch.stack(mesh_crops, dim=0)) + bs_out_b = self.blendshapes(lmks_canon_b[:, self._bs_idx, :2]) + + # Batched canonical→image affine + params_t = torch.tensor( + [(cx, cy, w, h, math.cos(a), math.sin(a)) for (_b, _s, cx, cy, w, h, a) in face_params], + device=lmks_canon_b.device, dtype=lmks_canon_b.dtype, + ) + cxs, cys, ws, hs, cos_a, sin_a = params_t.unbind(dim=1) + inv = 1.0 / _FM_INPUT_SIZE + u = lmks_canon_b[..., 0] - _FM_INPUT_SIZE * 0.5 + v = lmks_canon_b[..., 1] - _FM_INPUT_SIZE * 0.5 + lmks_xy_t = torch.stack([ + cxs[:, None] + u * (ws * inv * cos_a)[:, None] - v * (hs * inv * sin_a)[:, None], + cys[:, None] + u * (ws * inv * sin_a)[:, None] + v * (hs * inv * cos_a)[:, None], + ], dim=-1) + + lmks_xy_np = lmks_xy_t.float().cpu().numpy() + lmks_canon_np = lmks_canon_b.float().cpu().numpy() + presence_np = presence_b.float().cpu().numpy() + bs_np = bs_out_b.float().cpu().numpy() + + for i, (b, score, *_) in enumerate(face_params): + lmks_xy = lmks_xy_np[i] + mn, mx = lmks_xy.min(0), lmks_xy.max(0) + results[b].append({ + "bbox_xyxy": np.array([mn[0], mn[1], mx[0], mx[1]], dtype=np.float32), + "blendshapes": dict(zip(BLENDSHAPE_NAMES, bs_np[i].tolist())), + "landmarks_xy": lmks_xy, + "landmarks_3d": lmks_canon_np[i], + "presence": float(presence_np[i]), + "score": score, + }) + return results diff --git a/comfy_extras/nodes_mediapipe.py b/comfy_extras/nodes_mediapipe.py new file mode 100644 index 000000000..2e67ae83f --- /dev/null +++ b/comfy_extras/nodes_mediapipe.py @@ -0,0 +1,502 @@ +"""ComfyUI nodes for the pure-PyTorch MediaPipe Face Landmarker port. + +Custom IO types: + FACE_LANDMARKER — FaceLandmarkerModel wrapper (ModelPatcher inside) + FACE_LANDMARKS — {"frames": List[List[face_dict]], "image_size": (H, W), + "connection_sets": dict[str, frozenset[(int, int)]]} + face_dict: bbox_xyxy, blendshapes, landmarks_xy, + landmarks_3d, presence, score, transformation_matrix + +MediaPipeFaceLandmarker also emits the core BOUNDING_BOX type — pair with DrawBBoxes. +""" + +from __future__ import annotations + +import numpy as np +import torch +from PIL import Image, ImageColor, ImageDraw +from tqdm.auto import tqdm +from typing_extensions import override + +import comfy.model_management +import comfy.model_patcher +import comfy.utils +import folder_paths +from comfy_api.latest import ComfyExtension, io + +from comfy_extras.mediapipe.face_landmarker import FaceLandmarker +from comfy_extras.mediapipe.face_geometry import transformation_matrix_from_detection + + +FaceLandmarkerType = io.Custom("FACE_LANDMARKER") +FaceLandmarksType = io.Custom("FACE_LANDMARKS") + +_CANONICAL_KEYS = ("canonical_vertices", "procrustes_indices", "procrustes_weights") +_CONTOUR_PARTS = ("face_oval", "left_eye", "right_eye", "left_eyebrow", "right_eyebrow", "lips") + + +class FaceLandmarkerModel: + """Loaded FaceLandmarker variants + ModelPatcher per variant. + + Safetensors layout: `detector_short.*` / `detector_full.*` plus shared + `mesh.*`, `blendshapes.*`, `canonical_*`, and `topology.*`. + PReLU forces plain-nn / fp32 (manual_cast strands buffers across devices). + """ + + def __init__(self, state_dict: dict): + self.load_device = comfy.model_management.text_encoder_device() + offload_device = comfy.model_management.text_encoder_offload_device() + self.dtype = torch.float32 + + # FACEMESH_* connection sets, embedded as int32 (N, 2) under topology.*. + base: dict[str, frozenset] = {} + for k in [k for k in state_dict if k.startswith("topology.")]: + base[k[len("topology."):]] = frozenset(map(tuple, state_dict.pop(k).tolist())) + base["contours"] = frozenset().union(*(base[p] for p in _CONTOUR_PARTS)) + base["all"] = base["contours"] | base["irises"] | base["nose"] + + self.connection_sets: dict[str, frozenset] = base + self.canonical_data: dict[str, np.ndarray] = {k: state_dict.pop(k).numpy() for k in _CANONICAL_KEYS} + + shared = {k: v for k, v in state_dict.items() if k.startswith(("mesh.", "blendshapes."))} + + self.models: dict[str, FaceLandmarker] = {} + self.patchers: dict[str, comfy.model_patcher.ModelPatcher] = {} + for variant in ("short", "full"): + prefix = f"detector_{variant}." + sub = dict(shared) + sub.update({f"detector.{k[len(prefix):]}": v for k, v in state_dict.items() if k.startswith(prefix)}) + fl = FaceLandmarker(device=offload_device, dtype=self.dtype, operations=None, detector_variant=variant).eval() + fl.load_state_dict(sub, strict=False) + + self.models[variant] = fl + self.patchers[variant] = comfy.model_patcher.CoreModelPatcher( + fl, load_device=self.load_device, offload_device=offload_device, + size=comfy.model_management.module_size(fl), + ) + + def detect_batch(self, images, num_faces: int, score_thresh: float, variant: str): + comfy.model_management.load_model_gpu(self.patchers[variant]) + return self.models[variant].detect_batch(images, num_faces=num_faces, score_thresh=score_thresh) + + +def _image_to_uint8(image: torch.Tensor) -> np.ndarray: + return image[..., :3].mul(255.0).add_(0.5).clamp_(0, 255).to(torch.uint8).cpu().numpy() + + +def _parse_color(color: str) -> tuple[int, int, int]: + try: + return ImageColor.getrgb(color)[:3] + except ValueError: + return (0, 255, 0) + + +def _copy_face(face: dict) -> dict: + """Shallow copy of a face_dict with array-fields cloned so callers can mutate.""" + return { + "bbox_xyxy": face["bbox_xyxy"].copy(), + "blendshapes": dict(face["blendshapes"]), + "landmarks_xy": face["landmarks_xy"].copy(), + "landmarks_3d": face["landmarks_3d"].copy(), + "presence": face["presence"], + "score": face["score"], + } + + +def _lerp_face(a: dict, b: dict, t: float) -> dict: + return { + "bbox_xyxy": (1 - t) * a["bbox_xyxy"] + t * b["bbox_xyxy"], + "blendshapes": {k: (1 - t) * a["blendshapes"][k] + t * b["blendshapes"][k] for k in a["blendshapes"]}, + "landmarks_xy": (1 - t) * a["landmarks_xy"] + t * b["landmarks_xy"], + "landmarks_3d": (1 - t) * a["landmarks_3d"] + t * b["landmarks_3d"], + "presence": (1 - t) * a["presence"] + t * b["presence"], + "score": (1 - t) * a["score"] + t * b["score"], + } + + +def _match_faces(a: list[dict], b: list[dict]) -> list[tuple[int, int]]: + """Greedy nearest-neighbour pairing of faces between two frames by bbox + centre distance. Unmatched (when counts differ) are dropped.""" + if not a or not b: + return [] + centers_a = np.array([(0.5 * (f["bbox_xyxy"][0] + f["bbox_xyxy"][2]), + 0.5 * (f["bbox_xyxy"][1] + f["bbox_xyxy"][3])) for f in a]) + centers_b = np.array([(0.5 * (f["bbox_xyxy"][0] + f["bbox_xyxy"][2]), + 0.5 * (f["bbox_xyxy"][1] + f["bbox_xyxy"][3])) for f in b]) + dists = np.linalg.norm(centers_a[:, None] - centers_b[None], axis=-1) + pairs: list[tuple[int, int]] = [] + used_a: set[int] = set() + used_b: set[int] = set() + candidates = sorted((dists[ia, ib], ia, ib) for ia in range(len(a)) for ib in range(len(b))) + for _, ia, ib in candidates: + if ia in used_a or ib in used_b: + continue + pairs.append((ia, ib)) + used_a.add(ia) + used_b.add(ib) + return pairs + + +def _fill_missing_frames(frames: list[list[dict]], mode: str) -> None: + """In-place fill empty frame slots from neighbouring detections. Multi-face + aware: pairs faces across bracketing frames by greedy bbox-centre NN. + When counts differ, unmatched faces are dropped from the synthesised frame.""" + if mode == "empty": + return + valid = [i for i, fr in enumerate(frames) if fr] + if not valid: + return # nothing to fill from + if mode == "previous": + last: list[dict] = [] + for i, fr in enumerate(frames): + if fr: + last = fr + elif last: + frames[i] = [_copy_face(f) for f in last] + return + # interpolate: lerp between bracketing valid frames; clamp at ends. + for i in range(len(frames)): + if frames[i]: + continue + prev_i = max((v for v in valid if v < i), default=None) + next_i = min((v for v in valid if v > i), default=None) + if prev_i is None: + frames[i] = [_copy_face(f) for f in frames[next_i]] + elif next_i is None: + frames[i] = [_copy_face(f) for f in frames[prev_i]] + else: + t = (i - prev_i) / (next_i - prev_i) + pairs = _match_faces(frames[prev_i], frames[next_i]) + frames[i] = [_lerp_face(frames[prev_i][a], frames[next_i][b], t) for a, b in pairs] + + +def _ordered_rings(edges: frozenset[tuple[int, int]]) -> list[list[int]]: + """Walk an unordered edge set into one or more closed-loop vertex rings + (handles multi-loop sets like FACEMESH_LIPS: outer + inner).""" + adj: dict[int, set[int]] = {} + for a, b in edges: + adj.setdefault(a, set()).add(b) + adj.setdefault(b, set()).add(a) + visited: set[int] = set() + rings: list[list[int]] = [] + for start in adj: + if start in visited: + continue + ring = [start] + visited.add(start) + prev, cur = -1, start + while True: + nxt = next((v for v in adj[cur] if v != prev), None) + if nxt is None or nxt == start: + break + ring.append(nxt) + visited.add(nxt) + prev, cur = cur, nxt + rings.append(ring) + return rings + + +class LoadMediaPipeFaceLandmarker(io.ComfyNode): + """Load MediaPipe Face Landmarker v2 weights. Contains both detector variants + (short / full), shared mesh, blendshapes, and canonical geometry.""" + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="LoadMediaPipeFaceLandmarker", + display_name="Load MediaPipe Face Landmarker", + category="loaders", + inputs=[ + io.Combo.Input("model_name", options=folder_paths.get_filename_list("mediapipe"), + tooltip="Face Landmarker safetensors from models/mediapipe/."), + ], + outputs=[FaceLandmarkerType.Output()], + ) + + @classmethod + def execute(cls, model_name) -> io.NodeOutput: + sd = comfy.utils.load_torch_file(folder_paths.get_full_path_or_raise("mediapipe", model_name), safe_load=True) + wrapper = FaceLandmarkerModel(sd) + return io.NodeOutput(wrapper) + + +# Per-frame fallback modes for detection failures in a batch. +_FALLBACK_MODES = ("empty", "previous", "interpolate") + + +class MediaPipeFaceLandmarker(io.ComfyNode): + """BlazeFace → FaceMesh v2 → ARKit-52 blendshapes, batched across the + input. Also emits a BOUNDING_BOX list (landmark-extent bbox per face) — + pair with DrawBBoxes for detector-only viz or MediaPipeFaceMeshVisualize + for the mesh overlay.""" + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="MediaPipeFaceLandmarker", + display_name="MediaPipe Face Landmarker", + category="image/detection", + inputs=[ + FaceLandmarkerType.Input("face_landmarker"), + io.Image.Input("image"), + io.Combo.Input("detector_variant", options=["short", "full", "both"], default="short", + tooltip="Face detector range. 'short' is tuned for close-up faces " + "(within ~2 m of the camera); 'full' covers farther / smaller " + "faces (up to ~5 m) but is slower. 'both' runs both detectors and " + "keeps whichever found more faces per frame (~2× detection cost)."), + io.Int.Input("num_faces", default=1, min=0, max=16, step=1, + tooltip="Maximum faces to return per frame. 0 = no cap (return all detected)."), + io.Float.Input("min_confidence", default=0.5, min=0.0, max=1.0, step=0.01, advanced=True, + tooltip="BlazeFace score threshold. Lower to catch small/occluded faces."), + io.Combo.Input("missing_frame_fallback", options=list(_FALLBACK_MODES), default="empty", advanced=True, + tooltip="Per-frame behaviour when detection fails in a batch. " + "'empty' leaves the frame faceless. 'previous' copies the most recent successful " + "detection. 'interpolate' lerps landmarks/bbox/blendshapes between bracketing " + "successful frames. Multi-face: pairs faces across frames by greedy bbox-centre NN."), + ], + outputs=[ + FaceLandmarksType.Output(display_name="face_landmarks"), + io.BoundingBox.Output("bboxes"), + ], + ) + + @classmethod + def execute(cls, face_landmarker, image, detector_variant, num_faces, min_confidence, + missing_frame_fallback) -> io.NodeOutput: + canonical = face_landmarker.canonical_data + img_np = _image_to_uint8(image) + B, H, W = img_np.shape[:3] + chunk = 16 + is_both = detector_variant == "both" + total_work = 2 * B if is_both else B + pbar = comfy.utils.ProgressBar(total_work) + + def _run(variant: str) -> list[list[dict]]: + res: list[list[dict]] = [] + with tqdm(total=B, desc=f"MediaPipe Face Landmarker ({variant})") as tq: + for i in range(0, B, chunk): + end = min(i + chunk, B) + res.extend(face_landmarker.detect_batch( + [img_np[bi] for bi in range(i, end)], + num_faces=int(num_faces), + score_thresh=float(min_confidence), + variant=variant, + )) + pbar.update_absolute(min(pbar.current + (end - i), total_work)) + tq.update(end - i) + return res + + if is_both: + short_res = _run("short") + full_res = _run("full") + # Per-frame keep whichever found more faces (tie → short). + frames: list[list[dict]] = [ + short_res[bi] if len(short_res[bi]) >= len(full_res[bi]) else full_res[bi] + for bi in range(B) + ] + else: + frames = _run(detector_variant) + _fill_missing_frames(frames, missing_frame_fallback) + bboxes = [] + for per_frame in frames: + per_bb = [] + for f in per_frame: + f["transformation_matrix"] = transformation_matrix_from_detection(f, W, H, canonical) + x1, y1, x2, y2 = (float(v) for v in f["bbox_xyxy"]) + per_bb.append({"x": x1, "y": y1, "width": x2 - x1, "height": y2 - y1, "label": "face", "score": float(f["score"])}) + bboxes.append(per_bb) + return io.NodeOutput({"frames": frames, "image_size": (H, W), + "connection_sets": face_landmarker.connection_sets}, bboxes) + + +# Topology keys unioned by the 'all' connections preset (contour parts + irises + nose). +_ALL_CONNECTION_PARTS: tuple[str, ...] = (*_CONTOUR_PARTS, "irises", "nose") +_CUSTOM_FEATURES: tuple[tuple[str, bool], ...] = ( + ("face_oval", True), + ("lips", True), + ("left_eye", True), + ("right_eye", True), + ("left_eyebrow", True), + ("right_eyebrow", True), + ("irises", True), + ("nose", True), + ("tesselation", False), +) + + +class MediaPipeFaceMeshVisualize(io.ComfyNode): + """Draw a FACEMESH_* subset over an image. Topology travels with the + FACE_LANDMARKS payload (set at detection time).""" + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="MediaPipeFaceMeshVisualize", + display_name="MediaPipe Face Mesh Visualize", + category="image/detection", + inputs=[ + FaceLandmarksType.Input("face_landmarks"), + io.Image.Input("image", optional=True, tooltip="If not connected, a black canvas will be used."), + io.DynamicCombo.Input( + "connections", + tooltip="'all' = oval+eyes+brows+lips+irises+nose. 'fill' = solid face_oval polygon (silhouette mask). 'custom' = toggle each feature individually (including 'tesselation', the full 2547-edge wireframe).", + options=[ + io.DynamicCombo.Option("all", []), + io.DynamicCombo.Option("fill", []), + io.DynamicCombo.Option("custom", [ + io.Boolean.Input(feat, default=default, + tooltip=f"Draw the '{feat}' connection set.") + for feat, default in _CUSTOM_FEATURES + ]), + ], + ), + io.Color.Input("color", default="#00ff00"), + io.Int.Input("thickness", default=1, min=0, max=8, step=1, + tooltip="Edge line thickness in pixels. 0 disables edge drawing."), + io.Int.Input("point_size", default=2, min=0, max=16, step=1, + tooltip="Landmark dot radius in pixels. 0 disables point drawing."), + ], + outputs=[io.Image.Output()], + ) + + @classmethod + def execute(cls, face_landmarks, connections, color, thickness, point_size, image=None) -> io.NodeOutput: + sets = face_landmarks["connection_sets"] + sel = connections["connections"] + fill_rings: list[list[int]] | None = None + if sel == "fill": + fill_rings = _ordered_rings(sets["face_oval"]) + edges = frozenset() + elif sel == "custom": + parts = [feat for feat, _ in _CUSTOM_FEATURES if connections.get(feat, False)] + edges = frozenset().union(*(sets[p] for p in parts)) + else: # "all" + edges = frozenset().union(*(sets[p] for p in _ALL_CONNECTION_PARTS)) + rgb, thick, psize = _parse_color(color), int(thickness), int(point_size) + frames = face_landmarks["frames"] + if image is None: + H, W = face_landmarks["image_size"] + img_np = np.zeros((len(frames), H, W, 3), dtype=np.uint8) + else: + img_np = _image_to_uint8(image) + B = img_np.shape[0] + n_frames = len(frames) + pbar = comfy.utils.ProgressBar(B) + out = np.empty_like(img_np) + for bi in range(B): + faces = frames[bi] if bi < n_frames else [] + out[bi] = _draw_mesh(img_np[bi], faces, edges, rgb, thick, psize, fill_rings) + pbar.update_absolute(bi + 1) + return io.NodeOutput(torch.from_numpy(out).to( + device=comfy.model_management.intermediate_device(), + dtype=comfy.model_management.intermediate_dtype(), + ).div_(255.0)) + + +def _draw_mesh(image_rgb: np.ndarray, faces: list, edges, + rgb: tuple[int, int, int], thickness: int, + point_size: int, fill_rings: list[list[int]] | None = None) -> np.ndarray: + draw_edges = thickness > 0 and edges + if not faces or (fill_rings is None and not draw_edges and point_size <= 0): + return image_rgb.copy() + pil = Image.fromarray(image_rgb) + draw = ImageDraw.Draw(pil) + r = point_size * 0.5 + if fill_rings is not None: + for f in faces: + lmks = f["landmarks_xy"] + for ring in fill_rings: + draw.polygon([(float(lmks[i, 0]), float(lmks[i, 1])) for i in ring], fill=rgb) + return np.asarray(pil) + for f in faces: + lmks = f["landmarks_xy"] + n = lmks.shape[0] + if draw_edges: + for a, b in edges: + if a < n and b < n: + draw.line([(float(lmks[a, 0]), float(lmks[a, 1])), + (float(lmks[b, 0]), float(lmks[b, 1]))], fill=rgb, width=thickness) + if point_size == 1: + draw.point(lmks.flatten().tolist(), fill=rgb) + elif point_size > 1: + for x, y in lmks: + draw.ellipse((float(x) - r, float(y) - r, float(x) + r, float(y) + r), fill=rgb) + return np.asarray(pil) + + +# Mask region presets — closed-loop topologies only. +_MASK_REGIONS: tuple[str, ...] = ("face_oval", "lips", "left_eye", "right_eye", "irises") +_MASK_CUSTOM_FEATURES: tuple[tuple[str, bool], ...] = ( + ("face_oval", True), + ("lips", False), + ("left_eye", False), + ("right_eye", False), + ("irises", False), +) + + +class MediaPipeFaceMask(io.ComfyNode): + """Binary mask from face landmarks, filled polygon per face. One mask per + frame in the batch; faces in the same frame composite (union).""" + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="MediaPipeFaceMask", + display_name="MediaPipe Face Mask", + category="image/detection", + inputs=[ + FaceLandmarksType.Input("face_landmarks"), + io.DynamicCombo.Input( + "regions", + tooltip="'all' = union of face_oval+lips+eyes+irises (which collapses to face_oval since it encloses the rest). 'custom' = toggle each region individually for combos like lips+eyes.", + options=[ + io.DynamicCombo.Option("all", []), + io.DynamicCombo.Option("custom", [ + io.Boolean.Input(reg, default=default, + tooltip=f"Include the '{reg}' region in the mask.") + for reg, default in _MASK_CUSTOM_FEATURES + ]), + ], + ), + ], + outputs=[io.Mask.Output()], + ) + + @classmethod + def execute(cls, face_landmarks, regions) -> io.NodeOutput: + sets = face_landmarks["connection_sets"] + sel = regions["regions"] + if sel == "custom": + picked = [reg for reg, _ in _MASK_CUSTOM_FEATURES if regions.get(reg, False)] + else: + picked = list(_MASK_REGIONS) + rings = [r for reg in picked for r in _ordered_rings(sets[reg])] + frames = face_landmarks["frames"] + H, W = face_landmarks["image_size"] + masks = np.zeros((len(frames), H, W), dtype=np.uint8) + pbar = comfy.utils.ProgressBar(len(frames)) + for bi, per_frame in enumerate(frames): + if per_frame: + pil = Image.new("L", (W, H), 0) + draw = ImageDraw.Draw(pil) + for f in per_frame: + lmks = f["landmarks_xy"] + for ring in rings: + draw.polygon([(float(lmks[i, 0]), float(lmks[i, 1])) for i in ring], fill=255) + masks[bi] = np.asarray(pil) + pbar.update_absolute(bi + 1) + return io.NodeOutput(torch.from_numpy(masks).to( + device=comfy.model_management.intermediate_device(), + dtype=comfy.model_management.intermediate_dtype(), + ).div_(255.0)) + + +class MediaPipeFaceExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [LoadMediaPipeFaceLandmarker, MediaPipeFaceLandmarker, MediaPipeFaceMeshVisualize, MediaPipeFaceMask] + + +async def comfy_entrypoint() -> MediaPipeFaceExtension: + return MediaPipeFaceExtension() diff --git a/folder_paths.py b/folder_paths.py index ad7f0f4fc..ce152eb37 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -60,6 +60,8 @@ folder_names_and_paths["geometry_estimation"] = ([os.path.join(models_dir, "geom folder_names_and_paths["optical_flow"] = ([os.path.join(models_dir, "optical_flow")], supported_pt_extensions) +folder_names_and_paths["mediapipe"] = ([os.path.join(models_dir, "mediapipe")], supported_pt_extensions) + output_directory = os.path.join(base_path, "output") temp_directory = os.path.join(base_path, "temp") input_directory = os.path.join(base_path, "input") diff --git a/models/mediapipe/put_mediapipe_models_here b/models/mediapipe/put_mediapipe_models_here new file mode 100644 index 000000000..e69de29bb diff --git a/nodes.py b/nodes.py index fdd6eeb5f..13e46ac8a 100644 --- a/nodes.py +++ b/nodes.py @@ -2444,6 +2444,7 @@ async def init_builtin_extra_nodes(): "nodes_hidream_o1.py", "nodes_save_3d.py", "nodes_moge.py", + "nodes_mediapipe.py", ] import_failed = [] From 5aa5ccc9e02aec94cf43e0f71d4b2f62b204b5b6 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Thu, 21 May 2026 10:03:58 +1000 Subject: [PATCH 6/7] Multi-threaded load of models from disk (big load time speedups & Offload to disk) (CORE-43,CORE-152,CORE-164,CORE-165,CORE-117) (#13802) * model_management: disable non-dynamic smart memory Disable smart memory outright for non dynamic models. This is a minor step towards deprecation of --disable-dynamic-vram and the legacy ModelPatcher. This is needed for estimate-free model development, where new models can opt-out of supplying a memory estimate and not have to worry about hard VRAM allocations due to legacy non-dynamic model patchers This is also a general stability increase for a lot of stray use cases where estimates may still be off and going forward we are not going to accurately maintain such estimates. * pinned_memory: implement with aimdo growable buffer Use a single growable buffer so we can do threaded pre-warming on pinned memory. * mm: use aimdo to do transfer from disk to pin Aimdo implements a faster threaded loader. * Add stream host pin buffer for AIMDO casts Introduce per-offload-stream HostBuffer reuse for pinned staging, include it in cast buffer reset synchronization. Defer actual casts that go via this pin path to a separate pass such that the buffer can be allocated monolithically (to avoid cudaHostRegister thrash). * remove old pin path * Implement JIT pinned memory pressure Replace the predictive pin pressure mechanism with JIT PIN memory pressure. * LowVRAMPatch: change to two-phase visit * lora: re-implement as inplace swiss-army-knife operation * prepare for multiple pin sets * implement pinned loras * requirements: comfy-aimdo 0.4.0 * ops: remove unused arg This was defeatured in aimdo iteration * ops: sync the CPU with only the offload stream activity This was syncing with the offload stream which itself is synced with the compute stream, so this was syncing CPU with compute transitively. Define the event to sync it more gently. * pins: implement freeing intermediate for pinned memory Pinning is more important than inactive intermediates and the stream pin buffer is more important than even active intermediates. * execution: implement pin eviction on RAM presure Add back proper pin freeing on RAM pressure * implement pin registration swaps Uncap the windows pins from 50% by extending the pool and have a pressure mechanism to move the pin reservations om demand. This unfortunately implies a GPU sync to do the freeing so significant hysterisis needs to be added to consolidate these pressure events. * cli_args/execution: Implement lower background cache-ram threshold Limit the amount of RAM background intermediates can use, so that switching workflows doesn't degrade performance too much. * make default * bump aimdo * model-patcher: force-cast tiny weights Flux 2 gets crazy stalls due to a mix of tiny and giant weights creating lopsided steam buffer rotations which creates stalls. * ops: refactor in prep for chunking * mm: delegate pin-on-the-way to aimdo Aimdo is able to chunk and slice this on the way for better CPU->GPU overlap. The main advantage is the ability to shorten the bus contention window between previous weight transfer and the next weights vbar fault. * bump aimdo * pinning updates * specify hostbuf max allocation size There a signs of virtual memory exhaustion on some linux systems when throwing 128GB for every little piece. Pass the actual to save aimdo from over-estimates * tests: update execution tests for caching The default caching changed to ram-cache so update these tests accordingly. Remove the LRU 0 test as this also falls through to RAM cache. --- comfy/cli_args.py | 7 +- comfy/lora.py | 19 ++- comfy/memory_management.py | 24 +++- comfy/model_management.py | 189 +++++++++++++++++----------- comfy/model_patcher.py | 138 +++++++++++++++----- comfy/ops.py | 88 +++++++++++-- comfy/pinned_memory.py | 68 ++++++---- comfy/utils.py | 2 - comfy/windows.py | 52 -------- execution.py | 12 +- main.py | 20 +-- requirements.txt | 2 +- tests/execution/test_async_nodes.py | 3 +- tests/execution/test_execution.py | 3 +- 14 files changed, 408 insertions(+), 219 deletions(-) delete mode 100644 comfy/windows.py diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 76faed3ad..9d88c8517 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -110,13 +110,11 @@ parser.add_argument("--preview-method", type=LatentPreviewMethod, default=Latent parser.add_argument("--preview-size", type=int, default=512, help="Sets the maximum preview size for sampler nodes.") -CACHE_RAM_AUTO_GB = -1.0 - cache_group = parser.add_mutually_exclusive_group() +cache_group.add_argument("--cache-ram", nargs='*', type=float, default=[], metavar="GB", help="Use RAM pressure caching with the specified headroom thresholds. This is the default caching mode. The first value sets the active-cache threshold; the optional second value sets the inactive-cache/pin threshold. Defaults when no values are provided: active 25%% of system RAM (min 4GB, max 32GB), inactive 75%% of system RAM (min 12GB, max 96GB).") cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.") cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.") cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.") -cache_group.add_argument("--cache-ram", nargs='?', const=CACHE_RAM_AUTO_GB, type=float, default=0, help="Use RAM pressure caching with the specified headroom threshold. If available RAM drops below the threshold the cache removes large items to free RAM. Default (when no value is provided): 25%% of system RAM (min 4GB, max 32GB).") attn_group = parser.add_mutually_exclusive_group() attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.") @@ -245,6 +243,9 @@ if comfy.options.args_parsing: else: args = parser.parse_args([]) +if args.cache_ram is not None and len(args.cache_ram) > 2: + parser.error("--cache-ram accepts at most two values: active GB and inactive GB") + if args.windows_standalone_build: args.auto_launch = True diff --git a/comfy/lora.py b/comfy/lora.py index f11e26ec9..c0e8b865c 100644 --- a/comfy/lora.py +++ b/comfy/lora.py @@ -484,16 +484,23 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, ori return weight -def prefetch_prepared_value(value, allocate_buffer, stream): +def prefetch_prepared_value(value, counter, destination, stream, copy): if isinstance(value, torch.Tensor): - dest = allocate_buffer(comfy.memory_management.vram_aligned_size(value)) - comfy.model_management.cast_to_gathered([value], dest, non_blocking=True, stream=stream) + size = comfy.memory_management.vram_aligned_size(value) + offset = counter[0] + counter[0] += size + if destination is None: + return value + + dest = destination[offset:offset + size] + if copy: + comfy.model_management.cast_to_gathered([value], dest, non_blocking=True, stream=stream) return comfy.memory_management.interpret_gathered_like([value], dest)[0] elif isinstance(value, weight_adapter.WeightAdapterBase): - return type(value)(value.loaded_keys, prefetch_prepared_value(value.weights, allocate_buffer, stream)) + return type(value)(value.loaded_keys, prefetch_prepared_value(value.weights, counter, destination, stream, copy)) elif isinstance(value, tuple): - return tuple(prefetch_prepared_value(item, allocate_buffer, stream) for item in value) + return tuple(prefetch_prepared_value(item, counter, destination, stream, copy) for item in value) elif isinstance(value, list): - return [prefetch_prepared_value(item, allocate_buffer, stream) for item in value] + return [prefetch_prepared_value(item, counter, destination, stream, copy) for item in value] return value diff --git a/comfy/memory_management.py b/comfy/memory_management.py index 48e3c11da..c43f0c4a2 100644 --- a/comfy/memory_management.py +++ b/comfy/memory_management.py @@ -15,7 +15,7 @@ class TensorFileSlice(NamedTuple): size: int -def read_tensor_file_slice_into(tensor, destination): +def read_tensor_file_slice_into(tensor, destination, stream=None, destination2=None): if isinstance(tensor, QuantizedTensor): if not isinstance(destination, QuantizedTensor): @@ -23,12 +23,17 @@ def read_tensor_file_slice_into(tensor, destination): if tensor._layout_cls != destination._layout_cls: return False - if not read_tensor_file_slice_into(tensor._qdata, destination._qdata): + if not read_tensor_file_slice_into(tensor._qdata, destination._qdata, stream=stream, + destination2=(destination2._qdata if destination2 is not None else None)): return False dst_orig_dtype = destination._params.orig_dtype destination._params.copy_from(tensor._params, non_blocking=False) destination._params = dataclasses.replace(destination._params, orig_dtype=dst_orig_dtype) + if destination2 is not None: + dst_orig_dtype = destination2._params.orig_dtype + destination2._params.copy_from(destination._params, non_blocking=True) + destination2._params = dataclasses.replace(destination2._params, orig_dtype=dst_orig_dtype) return True info = getattr(tensor.untyped_storage(), "_comfy_tensor_file_slice", None) @@ -48,6 +53,17 @@ def read_tensor_file_slice_into(tensor, destination): if info.size == 0: return True + hostbuf = getattr(destination.untyped_storage(), "_comfy_hostbuf", None) + if hostbuf is not None: + stream_ptr = getattr(stream, "cuda_stream", 0) if stream is not None else 0 + device_ptr = destination2.data_ptr() if destination2 is not None else 0 + hostbuf.read_file_slice(file_obj, info.offset, info.size, + offset=destination.data_ptr() - hostbuf.get_raw_address(), + stream=stream_ptr, + device_ptr=device_ptr, + device=None if destination2 is None else destination2.device.index) + return True + buf_type = ctypes.c_ubyte * info.size view = memoryview(buf_type.from_address(destination.data_ptr())) @@ -151,7 +167,7 @@ def set_ram_cache_release_state(callback, headroom): extra_ram_release_callback = callback RAM_CACHE_HEADROOM = max(0, int(headroom)) -def extra_ram_release(target): +def extra_ram_release(target, free_active=False): if extra_ram_release_callback is None: return 0 - return extra_ram_release_callback(target) + return extra_ram_release_callback(target, free_active=free_active) diff --git a/comfy/model_management.py b/comfy/model_management.py index 21738a4c7..3894dfa9c 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -31,6 +31,7 @@ from contextlib import nullcontext import comfy.memory_management import comfy.utils import comfy.quant_ops +import comfy_aimdo.host_buffer import comfy_aimdo.vram_buffer class VRAMState(Enum): @@ -495,6 +496,14 @@ except: current_loaded_models = [] +DIRTY_MMAPS = set() + +PIN_PRESSURE_HYSTERESIS = 256 * 1024 * 1024 + +#Freeing registerables on pressure does imply a GPU sync, so go big on +#the hysteresis so each expensive sync gives us back a good chunk. +REGISTERABLE_PIN_HYSTERESIS = 2048 * 1024 * 1024 + def module_size(module): module_mem = 0 sd = module.state_dict() @@ -503,27 +512,46 @@ def module_size(module): module_mem += t.nbytes return module_mem -def module_mmap_residency(module, free=False): - mmap_touched_mem = 0 - module_mem = 0 - bounced_mmaps = set() - sd = module.state_dict() - for k in sd: - t = sd[k] - module_mem += t.nbytes - storage = t._qdata.untyped_storage() if isinstance(t, comfy.quant_ops.QuantizedTensor) else t.untyped_storage() - if not getattr(storage, "_comfy_tensor_mmap_touched", False): - continue - mmap_touched_mem += t.nbytes - if not free: - continue - storage._comfy_tensor_mmap_touched = False - mmap_obj = storage._comfy_tensor_mmap_refs[0] - if mmap_obj in bounced_mmaps: - continue - mmap_obj.bounce() - bounced_mmaps.add(mmap_obj) - return mmap_touched_mem, module_mem +def mark_mmap_dirty(storage): + mmap_refs = getattr(storage, "_comfy_tensor_mmap_refs", None) + if mmap_refs is not None: + DIRTY_MMAPS.add(mmap_refs[0]) + +def free_pins(size, evict_active=False): + freed_total = 0 + for loaded_model in reversed(current_loaded_models): + if size <= 0: + return freed_total + model = loaded_model.model + if model is not None and model.is_dynamic() and (evict_active or not model.model.dynamic_pins[model.load_device]["active"]): + freed = model.partially_unload_ram(size) + freed_total += freed + size -= freed + return freed_total + +def ensure_pin_budget(size, evict_active=False): + shortfall = size + comfy.memory_management.RAM_CACHE_HEADROOM / 2 - psutil.virtual_memory().available + if shortfall <= 0: + return True + + to_free = shortfall + PIN_PRESSURE_HYSTERESIS + return free_pins(to_free, evict_active=evict_active) >= shortfall + +def ensure_pin_registerable(size, evict_active=False): + shortfall = TOTAL_PINNED_MEMORY + size - MAX_PINNED_MEMORY + if MAX_PINNED_MEMORY <= 0: + return False + if shortfall <= 0: + return True + + shortfall += REGISTERABLE_PIN_HYSTERESIS + for loaded_model in reversed(current_loaded_models): + model = loaded_model.model + if model is not None and model.is_dynamic() and (evict_active or not model.model.dynamic_pins[model.load_device]["active"]): + shortfall -= model.unregister_inactive_pins(shortfall) + if shortfall <= 0: + return True + return shortfall <= REGISTERABLE_PIN_HYSTERESIS class LoadedModel: def __init__(self, model): @@ -553,9 +581,6 @@ class LoadedModel: def model_memory(self): return self.model.model_size() - def model_mmap_residency(self, free=False): - return self.model.model_mmap_residency(free=free) - def model_loaded_memory(self): return self.model.loaded_size() @@ -635,15 +660,9 @@ WINDOWS = any(platform.win32_ver()) EXTRA_RESERVED_VRAM = 400 * 1024 * 1024 if WINDOWS: - import comfy.windows EXTRA_RESERVED_VRAM = 600 * 1024 * 1024 #Windows is higher because of the shared vram issue if total_vram > (15 * 1024): # more extra reserved vram on 16GB+ cards EXTRA_RESERVED_VRAM += 100 * 1024 * 1024 - def get_free_ram(): - return comfy.windows.get_free_ram() -else: - def get_free_ram(): - return psutil.virtual_memory().available if args.reserve_vram is not None: EXTRA_RESERVED_VRAM = args.reserve_vram * 1024 * 1024 * 1024 @@ -657,7 +676,6 @@ def minimum_inference_memory(): def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins_required=0, ram_required=0): cleanup_models_gc() - comfy.memory_management.extra_ram_release(max(pins_required, ram_required)) unloaded_model = [] can_unload = [] unloaded_models = [] @@ -673,11 +691,9 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins for x in can_unload_sorted: i = x[-1] memory_to_free = 1e32 - pins_to_free = 1e32 - if not DISABLE_SMART_MEMORY or device is None: + if current_loaded_models[i].model.is_dynamic() and (not DISABLE_SMART_MEMORY or device is None): memory_to_free = 0 if device is None else memory_required - get_free_memory(device) - pins_to_free = pins_required - get_free_ram() - if current_loaded_models[i].model.is_dynamic() and for_dynamic: + if for_dynamic: #don't actually unload dynamic models for the sake of other dynamic models #as that works on-demand. memory_required -= current_loaded_models[i].model.loaded_size() @@ -685,18 +701,6 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins if memory_to_free > 0 and current_loaded_models[i].model_unload(memory_to_free): logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}") unloaded_model.append(i) - if pins_to_free > 0: - logging.debug(f"PIN Unloading {current_loaded_models[i].model.model.__class__.__name__}") - current_loaded_models[i].model.partially_unload_ram(pins_to_free) - - for x in can_unload_sorted: - i = x[-1] - ram_to_free = ram_required - psutil.virtual_memory().available - if ram_to_free <= 0 and i not in unloaded_model: - continue - resident_memory, _ = current_loaded_models[i].model_mmap_residency(free=True) - if resident_memory > 0: - logging.debug(f"RAM Unloading {current_loaded_models[i].model.model.__class__.__name__}") for i in sorted(unloaded_model, reverse=True): unloaded_models.append(current_loaded_models.pop(i)) @@ -762,29 +766,16 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu model_to_unload.model.detach(unpatch_all=False) model_to_unload.model_finalizer.detach() - total_memory_required = {} - total_pins_required = {} - total_ram_required = {} for loaded_model in models_to_load: device = loaded_model.device total_memory_required[device] = total_memory_required.get(device, 0) + loaded_model.model_memory_required(device) - resident_memory, model_memory = loaded_model.model_mmap_residency() - pinned_memory = loaded_model.model.pinned_memory_size() - #FIXME: This can over-free the pins as it budgets to pin the entire model. We should - #make this JIT to keep as much pinned as possible. - pins_required = model_memory - pinned_memory - ram_required = model_memory - resident_memory - total_pins_required[device] = total_pins_required.get(device, 0) + pins_required - total_ram_required[device] = total_ram_required.get(device, 0) + ram_required for device in total_memory_required: if device != torch.device("cpu"): free_memory(total_memory_required[device] * 1.1 + extra_mem, device, - for_dynamic=free_for_dynamic, - pins_required=total_pins_required[device], - ram_required=total_ram_required[device]) + for_dynamic=free_for_dynamic) for device in total_memory_required: if device != torch.device("cpu"): @@ -1180,6 +1171,7 @@ STREAM_CAST_BUFFERS = {} LARGEST_CASTED_WEIGHT = (None, 0) STREAM_AIMDO_CAST_BUFFERS = {} LARGEST_AIMDO_CASTED_WEIGHT = (None, 0) +STREAM_PIN_BUFFERS = {} DEFAULT_AIMDO_CAST_BUFFER_RESERVATION_SIZE = 16 * 1024 ** 3 @@ -1220,21 +1212,66 @@ def get_aimdo_cast_buffer(offload_stream, device): if cast_buffer is None: cast_buffer = comfy_aimdo.vram_buffer.VRAMBuffer(DEFAULT_AIMDO_CAST_BUFFER_RESERVATION_SIZE, device.index) STREAM_AIMDO_CAST_BUFFERS[offload_stream] = cast_buffer - return cast_buffer + +def get_pin_buffer(offload_stream): + pin_buffer = STREAM_PIN_BUFFERS.get(offload_stream, None) + if pin_buffer is None: + pin_buffer = comfy_aimdo.host_buffer.HostBuffer(0, 0, pinned_hostbuf_size(8 * 1024**3)) + STREAM_PIN_BUFFERS[offload_stream] = pin_buffer + elif offload_stream is not None: + event = getattr(pin_buffer, "_comfy_event", None) + if event is not None: + event.synchronize() + delattr(pin_buffer, "_comfy_event") + return pin_buffer + +def resize_pin_buffer(pin_buffer, size): + global TOTAL_PINNED_MEMORY + old_size = pin_buffer.size + if size <= old_size: + return True + growth = size - old_size + comfy.memory_management.extra_ram_release(comfy.memory_management.RAM_CACHE_HEADROOM) + ensure_pin_budget(growth, evict_active=True) + ensure_pin_registerable(growth, evict_active=True) + try: + pin_buffer.extend(size=size, reallocate=True) + except RuntimeError: + return False + TOTAL_PINNED_MEMORY += pin_buffer.size - old_size + return True + def reset_cast_buffers(): + global TOTAL_PINNED_MEMORY global LARGEST_CASTED_WEIGHT global LARGEST_AIMDO_CASTED_WEIGHT LARGEST_CASTED_WEIGHT = (None, 0) LARGEST_AIMDO_CASTED_WEIGHT = (None, 0) - for offload_stream in set(STREAM_CAST_BUFFERS) | set(STREAM_AIMDO_CAST_BUFFERS): + for offload_stream in set(STREAM_CAST_BUFFERS) | set(STREAM_AIMDO_CAST_BUFFERS) | set(STREAM_PIN_BUFFERS): if offload_stream is not None: offload_stream.synchronize() synchronize() + for mmap_obj in DIRTY_MMAPS: + mmap_obj.bounce() + DIRTY_MMAPS.clear() + + for pin_buffer in STREAM_PIN_BUFFERS.values(): + TOTAL_PINNED_MEMORY -= pin_buffer.size + TOTAL_PINNED_MEMORY = max(0, TOTAL_PINNED_MEMORY) + + for loaded_model in current_loaded_models: + model = loaded_model.model + if model is not None and model.is_dynamic(): + model.model.dynamic_pins[model.load_device]["active"] = False + model.partially_unload_ram(1e30, subsets=[ "patches" ]) + model.model.dynamic_pins[model.load_device]["patches"] = (comfy_aimdo.host_buffer.HostBuffer(0, 8 * 1024 * 1024, pinned_hostbuf_size(model.model_size())), [], [-1], [0]) + STREAM_CAST_BUFFERS.clear() STREAM_AIMDO_CAST_BUFFERS.clear() + STREAM_PIN_BUFFERS.clear() soft_empty_cache() def get_offload_stream(device): @@ -1280,7 +1317,7 @@ def sync_stream(device, stream): current_stream(device).wait_stream(stream) -def cast_to_gathered(tensors, r, non_blocking=False, stream=None): +def cast_to_gathered(tensors, r, non_blocking=False, stream=None, r2=None): wf_context = nullcontext() if stream is not None: wf_context = stream @@ -1288,17 +1325,20 @@ def cast_to_gathered(tensors, r, non_blocking=False, stream=None): wf_context = wf_context.as_context(stream) dest_views = comfy.memory_management.interpret_gathered_like(tensors, r) + dest2_views = comfy.memory_management.interpret_gathered_like(tensors, r2) if r2 is not None else None with wf_context: for tensor in tensors: dest_view = dest_views.pop(0) + dest2_view = dest2_views.pop(0) if dest2_views is not None else None if tensor is None: continue - if comfy.memory_management.read_tensor_file_slice_into(tensor, dest_view): + if comfy.memory_management.read_tensor_file_slice_into(tensor, dest_view, stream=stream, destination2=dest2_view): continue storage = tensor._qdata.untyped_storage() if isinstance(tensor, comfy.quant_ops.QuantizedTensor) else tensor.untyped_storage() - if hasattr(storage, "_comfy_tensor_mmap_touched"): - storage._comfy_tensor_mmap_touched = True + mark_mmap_dirty(storage) dest_view.copy_(tensor, non_blocking=non_blocking) + if dest2_view is not None: + dest2_view.copy_(dest_view, non_blocking=non_blocking) def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None, r=None): @@ -1339,14 +1379,18 @@ TOTAL_PINNED_MEMORY = 0 MAX_PINNED_MEMORY = -1 if not args.disable_pinned_memory: if is_nvidia() or is_amd(): + ram = get_total_memory(torch.device("cpu")) if WINDOWS: - MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.40 # Windows limit is apparently 50% + MAX_PINNED_MEMORY = ram * 0.40 # Windows limit is apparently 50% else: - MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.90 + MAX_PINNED_MEMORY = ram * 0.90 logging.info("Enabled pinned memory {}".format(MAX_PINNED_MEMORY // (1024 * 1024))) PINNING_ALLOWED_TYPES = set(["Tensor", "Parameter", "QuantizedTensor"]) +def pinned_hostbuf_size(size): + return max(0, int(min(size, MAX_PINNED_MEMORY) * 2)) + def discard_cuda_async_error(): try: a = torch.tensor([1], dtype=torch.uint8, device=get_torch_device()) @@ -1378,8 +1422,8 @@ def pin_memory(tensor): return False size = tensor.nbytes - if (TOTAL_PINNED_MEMORY + size) > MAX_PINNED_MEMORY: - return False + comfy.memory_management.extra_ram_release(comfy.memory_management.RAM_CACHE_HEADROOM) + ensure_pin_registerable(size) ptr = tensor.data_ptr() if ptr == 0: @@ -1416,7 +1460,8 @@ def unpin_memory(tensor): return False if torch.cuda.cudart().cudaHostUnregister(ptr) == 0: - TOTAL_PINNED_MEMORY -= PINNED_MEMORY.pop(ptr) + size = PINNED_MEMORY.pop(ptr) + TOTAL_PINNED_MEMORY -= size return True else: logging.warning("Unpin error.") diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 4f9d8403e..c8ed02e70 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -35,6 +35,7 @@ import comfy.model_management import comfy.ops import comfy.patcher_extension import comfy.utils +import comfy_aimdo.host_buffer from comfy.comfy_types import UnetWrapperFunction from comfy.quant_ops import QuantizedTensor from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP @@ -117,6 +118,8 @@ def string_to_seed(data): return comfy.utils.string_to_seed(data) class LowVramPatch: + is_lowvram_patch = True + def __init__(self, key, patches, convert_func=None, set_func=None): self.key = key self.patches = patches @@ -124,11 +127,21 @@ class LowVramPatch: self.set_func = set_func self.prepared_patches = None - def prepare(self, allocate_buffer, stream): - self.prepared_patches = [ - (patch[0], comfy.lora.prefetch_prepared_value(patch[1], allocate_buffer, stream), patch[2], patch[3], patch[4]) + def memory_required(self): + counter = [0] + for patch in self.patches[self.key]: + comfy.lora.prefetch_prepared_value(patch[1], counter, None, None, False) + return counter[0] + + def prepare(self, destination, stream, copy=True, commit=True): + counter = [0] + prepared_patches = [ + (patch[0], comfy.lora.prefetch_prepared_value(patch[1], counter, destination, stream, copy), patch[2], patch[3], patch[4]) for patch in self.patches[self.key] ] + if commit: + self.prepared_patches = prepared_patches + return prepared_patches def clear_prepared(self): self.prepared_patches = None @@ -341,9 +354,6 @@ class ModelPatcher: self.size = comfy.model_management.module_size(self.model) return self.size - def model_mmap_residency(self, free=False): - return comfy.model_management.module_mmap_residency(self.model, free=free) - def loaded_size(self): return self.model.model_loaded_weight_memory @@ -1118,8 +1128,12 @@ class ModelPatcher: # Pinned memory pressure tracking is only implemented for DynamicVram loading return 0 + def loaded_ram_size(self): + # Loaded RAM pressure tracking is only implemented for DynamicVram loading + return 0 + def partially_unload_ram(self, ram_to_unload): - pass + return 0 def detach(self, unpatch_all=True): self.eject_model() @@ -1550,6 +1564,16 @@ class ModelPatcherDynamic(ModelPatcher): super().__init__(model, load_device, offload_device, size, weight_inplace_update) if not hasattr(self.model, "dynamic_vbars"): self.model.dynamic_vbars = {} + if not hasattr(self.model, "dynamic_pins"): + self.model.dynamic_pins = {} + if self.load_device not in self.model.dynamic_pins: + self.model.dynamic_pins[self.load_device] = { + "weights": (comfy_aimdo.host_buffer.HostBuffer(0, 0, 0), [], [-1], [0]), + "patches": (comfy_aimdo.host_buffer.HostBuffer(0, 0, 0), [], [-1], [0]), + "hostbufs_initialized": False, + "failed": False, + "active": False, + } self.non_dynamic_delegate_model = None assert load_device is not None @@ -1611,6 +1635,14 @@ class ModelPatcherDynamic(ModelPatcher): self.unpatch_hooks() vbar = self._vbar_get(create=True) + pin_state = self.model.dynamic_pins[self.load_device] + if not pin_state["hostbufs_initialized"]: + hostbuf_size = comfy.model_management.pinned_hostbuf_size(self.model_size()) + pin_state["weights"] = (comfy_aimdo.host_buffer.HostBuffer(0, 64 * 1024 * 1024, hostbuf_size), [], [-1], [0]) + pin_state["patches"] = (comfy_aimdo.host_buffer.HostBuffer(0, 8 * 1024 * 1024, hostbuf_size), [], [-1], [0]) + pin_state["hostbufs_initialized"] = True + pin_state["failed"] = False + pin_state["active"] = True if vbar is not None: vbar.prioritize() @@ -1636,7 +1668,9 @@ class ModelPatcherDynamic(ModelPatcher): if key in self.patches: if comfy.lora.calculate_shape(self.patches[key], weight, key) != weight.shape: return (True, 0) - setattr(m, param_key + "_lowvram_function", LowVramPatch(key, self.patches)) + lowvram_patch = LowVramPatch(key, self.patches) + lowvram_patch._pin_state = pin_state + setattr(m, param_key + "_lowvram_function", lowvram_patch) num_patches += 1 else: setattr(m, param_key + "_lowvram_function", None) @@ -1653,6 +1687,9 @@ class ModelPatcherDynamic(ModelPatcher): def force_load_param(self, param_key, device_to): key = key_param_name_to_key(n, param_key) + weight, _, _ = get_key_weight(self.model, key) + if weight is None: + return if key in self.backup: comfy.utils.set_attr_param(self.model, key, self.backup[key].weight) self.patch_weight_to_device(key, device_to=device_to, force_cast=True) @@ -1662,17 +1699,23 @@ class ModelPatcherDynamic(ModelPatcher): if hasattr(m, "comfy_cast_weights"): m.comfy_cast_weights = True - m.pin_failed = False m.seed_key = n + m._pin_state = pin_state set_dirty(m, dirty) - force_load, v_weight_size = setup_param(self, m, n, "weight") - force_load_bias, v_weight_bias = setup_param(self, m, n, "bias") - force_load = force_load or force_load_bias - v_weight_size += v_weight_bias + #Models that mix tiny and giant weights can causing lopsided stream buffer + #rotations and stall. force the tinys over. + if module_mem > 16 * 1024: + force_load, v_weight_size = setup_param(self, m, n, "weight") + force_load_bias, v_weight_bias = setup_param(self, m, n, "bias") + force_load = force_load or force_load_bias + v_weight_size += v_weight_bias + if force_load: + logging.info(f"Module {n} has resizing Lora - force loading") + else: + force_load=True if force_load: - logging.info(f"Module {n} has resizing Lora - force loading") force_load_param(self, "weight", device_to) force_load_param(self, "bias", device_to) else: @@ -1740,23 +1783,58 @@ class ModelPatcherDynamic(ModelPatcher): return freed - def pinned_memory_size(self): - total = 0 - loading = self._load_list(for_dynamic=True) - for x in loading: - _, _, _, _, m, _ = x - pin = comfy.pinned_memory.get_pin(m) - if pin is not None: - total += pin.numel() * pin.element_size() - return total + def loaded_ram_size(self): + return (self.model.dynamic_pins[self.load_device]["weights"][0].size + + self.model.dynamic_pins[self.load_device]["patches"][0].size) - def partially_unload_ram(self, ram_to_unload): - loading = self._load_list(for_dynamic=True, default_device=self.offload_device) - for x in loading: - *_, m, _ = x - ram_to_unload -= comfy.pinned_memory.unpin_memory(m) - if ram_to_unload <= 0: - return + def pinned_memory_size(self): + return (self.model.dynamic_pins[self.load_device]["weights"][3][0] + + self.model.dynamic_pins[self.load_device]["patches"][3][0]) + + def unregister_inactive_pins(self, ram_to_unload, subsets=[ "weights", "patches" ]): + freed = 0 + pin_state = self.model.dynamic_pins[self.load_device] + for subset in subsets: + hostbuf, stack, stack_split, pinned_size = pin_state[subset] + split = stack_split[0] + while split >= 0: + module, offset = stack[split] + split -= 1 + stack_split[0] = split + if not module._pin_registered: + continue + size = module._pin.numel() * module._pin.element_size() + if torch.cuda.cudart().cudaHostUnregister(module._pin.data_ptr()) != 0: + comfy.model_management.discard_cuda_async_error() + continue + module._pin_registered = False + comfy.model_management.TOTAL_PINNED_MEMORY = max(0, comfy.model_management.TOTAL_PINNED_MEMORY - size) + pinned_size[0] = max(0, pinned_size[0] - size) + freed += size + ram_to_unload -= size + if ram_to_unload <= 0: + return freed + return freed + + def partially_unload_ram(self, ram_to_unload, subsets=[ "weights", "patches" ]): + freed = 0 + pin_state = self.model.dynamic_pins[self.load_device] + for subset in subsets: + hostbuf, stack, stack_split, pinned_size = pin_state[subset] + while len(stack) > 0: + module, offset = stack.pop() + size = module._pin.numel() * module._pin.element_size() + del module._pin + hostbuf.truncate(offset, do_unregister=module._pin_registered) + stack_split[0] = min(stack_split[0], len(stack) - 1) + if module._pin_registered: + comfy.model_management.TOTAL_PINNED_MEMORY = max(0, comfy.model_management.TOTAL_PINNED_MEMORY - size) + pinned_size[0] = max(0, pinned_size[0] - size) + freed += size + ram_to_unload -= size + if ram_to_unload <= 0: + return freed + return freed def patch_model(self, device_to=None, lowvram_model_memory=0, load_weights=True, force_patch_weights=False): #This isn't used by the core at all and can only be to load a model out of diff --git a/comfy/ops.py b/comfy/ops.py index eae3bd873..9bcd6c900 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -75,6 +75,8 @@ except: cast_to = comfy.model_management.cast_to #TODO: remove once no more references +STREAM_PIN_BUFFER_HEADROOM = 8 * 1024 * 1024 + def cast_to_input(weight, input, non_blocking=False, copy=True): return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy) @@ -91,6 +93,9 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin offload_stream = None cast_buffer = None cast_buffer_offset = 0 + stream_pin_hostbuf = None + stream_pin_offset = 0 + stream_pin_queue = [] def ensure_offload_stream(module, required_size, check_largest): nonlocal offload_stream @@ -124,6 +129,22 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin cast_buffer_offset += buffer_size return buffer + def get_stream_pin_buffer_offset(buffer_size): + nonlocal stream_pin_hostbuf + nonlocal stream_pin_offset + + if buffer_size == 0 or offload_stream is None: + return None + + if stream_pin_hostbuf is None: + stream_pin_hostbuf = comfy.model_management.get_pin_buffer(offload_stream) + if stream_pin_hostbuf is None: + return None + + offset = stream_pin_offset + stream_pin_offset += buffer_size + return offset + for s in comfy_modules: signature = comfy_aimdo.model_vbar.vbar_fault(s._v) resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature) @@ -162,23 +183,47 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin if xfer_dest is None: xfer_dest = get_cast_buffer(dest_size) - if signature is None and pin is None: - comfy.pinned_memory.pin_memory(s) - pin = comfy.pinned_memory.get_pin(s) - else: - pin = None + def cast_maybe_lowvram_patch(xfer_source, xfer_dest, stream): + if xfer_source is not None: + if getattr(xfer_source, "is_lowvram_patch", False): + xfer_source.prepare(xfer_dest, stream, copy=True, commit=False) + else: + comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=stream) - if pin is not None: - comfy.model_management.cast_to_gathered(xfer_source, pin) - xfer_source = [ pin ] - #send it over - comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=offload_stream) + def handle_pin(m, pin, source, dest, subset="weights", size=None): + if pin is not None: + cast_maybe_lowvram_patch([pin], dest, offload_stream) + return + if signature is None: + comfy.pinned_memory.pin_memory(m, subset=subset, size=size) + pin = comfy.pinned_memory.get_pin(m, subset=subset) + if pin is not None: + if isinstance(source, list): + comfy.model_management.cast_to_gathered(source, pin, non_blocking=non_blocking, stream=offload_stream, r2=dest) + else: + cast_maybe_lowvram_patch(source, pin, None) + cast_maybe_lowvram_patch([ pin ], dest, offload_stream) + return + if pin is None: + pin_offset = get_stream_pin_buffer_offset(size) + if pin_offset is not None: + stream_pin_queue.append((source, pin_offset, size, dest)) + return + cast_maybe_lowvram_patch(source, dest, offload_stream) + + handle_pin(s, pin, xfer_source, xfer_dest, size=dest_size) for param_key in ("weight", "bias"): - lowvram_fn = getattr(s, param_key + "_lowvram_function", None) - if lowvram_fn is not None: + lowvram_source = getattr(s, param_key + "_lowvram_function", None) + if lowvram_source is not None: ensure_offload_stream(s, cast_buffer_offset, False) - lowvram_fn.prepare(lambda size: get_cast_buffer(size), offload_stream) + lowvram_size = lowvram_source.memory_required() + lowvram_dest = get_cast_buffer(lowvram_size) + lowvram_source.prepare(lowvram_dest, None, copy=False, commit=True) + + pin = comfy.pinned_memory.get_pin(lowvram_source, subset="patches") + handle_pin(lowvram_source, pin, lowvram_source, lowvram_dest, subset="patches", size=lowvram_size) + prefetch["xfer_dest"] = xfer_dest prefetch["cast_dest"] = cast_dest @@ -186,6 +231,23 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin prefetch["needs_cast"] = needs_cast s._prefetch = prefetch + if stream_pin_offset > 0: + if stream_pin_hostbuf.size < stream_pin_offset: + if not comfy.model_management.resize_pin_buffer(stream_pin_hostbuf, stream_pin_offset + STREAM_PIN_BUFFER_HEADROOM): + for xfer_source, _, _, xfer_dest in stream_pin_queue: + cast_maybe_lowvram_patch(xfer_source, xfer_dest, offload_stream) + return offload_stream + stream_pin_tensor = comfy_aimdo.torch.hostbuf_to_tensor(stream_pin_hostbuf) + stream_pin_tensor.untyped_storage()._comfy_hostbuf = stream_pin_hostbuf + for xfer_source, pin_offset, pin_size, xfer_dest in stream_pin_queue: + pin = stream_pin_tensor[pin_offset:pin_offset + pin_size] + if isinstance(xfer_source, list): + comfy.model_management.cast_to_gathered(xfer_source, pin, non_blocking=non_blocking, stream=offload_stream, r2=xfer_dest) + else: + cast_maybe_lowvram_patch(xfer_source, pin, None) + comfy.model_management.cast_to_gathered([ pin ], xfer_dest, non_blocking=non_blocking, stream=offload_stream) + stream_pin_hostbuf._comfy_event = offload_stream.record_event() + return offload_stream diff --git a/comfy/pinned_memory.py b/comfy/pinned_memory.py index 6d3ba367a..0e8f573ba 100644 --- a/comfy/pinned_memory.py +++ b/comfy/pinned_memory.py @@ -2,42 +2,62 @@ import comfy.model_management import comfy.memory_management import comfy_aimdo.host_buffer import comfy_aimdo.torch +import torch from comfy.cli_args import args -def get_pin(module): - return getattr(module, "_pin", None) +def get_pin(module, subset="weights"): + pin = getattr(module, "_pin", None) + if pin is None or module._pin_registered or args.disable_pinned_memory: + return pin -def pin_memory(module): - if module.pin_failed or args.disable_pinned_memory or get_pin(module) is not None: + _, _, stack_split, pinned_size = module._pin_state[subset] + size = pin.nbytes + comfy.model_management.ensure_pin_registerable(size) + + if torch.cuda.cudart().cudaHostRegister(pin.data_ptr(), size, 1) != 0: + comfy.model_management.discard_cuda_async_error() + return pin + + module._pin_registered = True + stack_split[0] = max(stack_split[0], module._pin_stack_index) + comfy.model_management.TOTAL_PINNED_MEMORY += size + pinned_size[0] += size + return pin + +def pin_memory(module, subset="weights", size=None): + pin_state = module._pin_state + if args.disable_pinned_memory: return - size = comfy.memory_management.vram_aligned_size([ module.weight, module.bias ]) + pin = get_pin(module, subset) + if pin is not None or pin_state["failed"]: + return - if comfy.model_management.MAX_PINNED_MEMORY <= 0 or (comfy.model_management.TOTAL_PINNED_MEMORY + size) > comfy.model_management.MAX_PINNED_MEMORY: - module.pin_failed = True + hostbuf, stack, stack_split, pinned_size = pin_state[subset] + if size is None: + size = comfy.memory_management.vram_aligned_size([ module.weight, module.bias ]) + offset = hostbuf.size + registerable_size = size + max(0, hostbuf.size - pinned_size[0]) + + comfy.memory_management.extra_ram_release(comfy.memory_management.RAM_CACHE_HEADROOM) + if (not comfy.model_management.ensure_pin_budget(size) or + not comfy.model_management.ensure_pin_registerable(registerable_size)): + pin_state["failed"] = True return False try: - hostbuf = comfy_aimdo.host_buffer.HostBuffer(size) + hostbuf.extend(size=size) except RuntimeError: - module.pin_failed = True + pin_state["failed"] = True return False - module._pin = comfy_aimdo.torch.hostbuf_to_tensor(hostbuf) - module._pin_hostbuf = hostbuf + module._pin = comfy_aimdo.torch.hostbuf_to_tensor(hostbuf)[offset:offset + size] + module._pin.untyped_storage()._comfy_hostbuf = hostbuf + stack.append((module, offset)) + module._pin_registered = True + module._pin_stack_index = len(stack) - 1 + stack_split[0] = max(stack_split[0], module._pin_stack_index) comfy.model_management.TOTAL_PINNED_MEMORY += size + pinned_size[0] += size return True - -def unpin_memory(module): - if get_pin(module) is None: - return 0 - size = module._pin.numel() * module._pin.element_size() - - comfy.model_management.TOTAL_PINNED_MEMORY -= size - if comfy.model_management.TOTAL_PINNED_MEMORY < 0: - comfy.model_management.TOTAL_PINNED_MEMORY = 0 - - del module._pin - del module._pin_hostbuf - return size diff --git a/comfy/utils.py b/comfy/utils.py index 66682690a..00e382fac 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -113,7 +113,6 @@ def load_safetensors(ckpt): "_comfy_tensor_file_slice", comfy.memory_management.TensorFileSlice(f, threading.get_ident(), data_base_offset + start, end - start)) setattr(storage, "_comfy_tensor_mmap_refs", (model_mmap, mv)) - setattr(storage, "_comfy_tensor_mmap_touched", False) sd[name] = tensor return sd, header.get("__metadata__", {}), @@ -1451,4 +1450,3 @@ def deepcopy_list_dict(obj, memo=None): memo[obj_id] = res return res - diff --git a/comfy/windows.py b/comfy/windows.py deleted file mode 100644 index 213dc481d..000000000 --- a/comfy/windows.py +++ /dev/null @@ -1,52 +0,0 @@ -import ctypes -import logging -import psutil -from ctypes import wintypes - -import comfy_aimdo.control - -psapi = ctypes.WinDLL("psapi") -kernel32 = ctypes.WinDLL("kernel32") - -class PERFORMANCE_INFORMATION(ctypes.Structure): - _fields_ = [ - ("cb", wintypes.DWORD), - ("CommitTotal", ctypes.c_size_t), - ("CommitLimit", ctypes.c_size_t), - ("CommitPeak", ctypes.c_size_t), - ("PhysicalTotal", ctypes.c_size_t), - ("PhysicalAvailable", ctypes.c_size_t), - ("SystemCache", ctypes.c_size_t), - ("KernelTotal", ctypes.c_size_t), - ("KernelPaged", ctypes.c_size_t), - ("KernelNonpaged", ctypes.c_size_t), - ("PageSize", ctypes.c_size_t), - ("HandleCount", wintypes.DWORD), - ("ProcessCount", wintypes.DWORD), - ("ThreadCount", wintypes.DWORD), - ] - -def get_free_ram(): - #Windows is way too conservative and chalks recently used uncommitted model RAM - #as "in-use". So, calculate free RAM for the sake of general use as the greater of: - # - #1: What psutil says - #2: Total Memory - (Committed Memory - VRAM in use) - # - #We have to subtract VRAM in use from the comitted memory as WDDM creates a naked - #commit charge for all VRAM used just incase it wants to page it all out. This just - #isn't realistic so "overcommit" on our calculations by just subtracting it off. - - pi = PERFORMANCE_INFORMATION() - pi.cb = ctypes.sizeof(pi) - - if not psapi.GetPerformanceInfo(ctypes.byref(pi), pi.cb): - logging.warning("WARNING: Failed to query windows performance info. RAM usage may be sub optimal") - return psutil.virtual_memory().available - - committed = pi.CommitTotal * pi.PageSize - total = pi.PhysicalTotal * pi.PageSize - - return max(psutil.virtual_memory().available, - total - (committed - comfy_aimdo.control.get_total_vram_usage())) - diff --git a/execution.py b/execution.py index 4c7de2e84..5246d651c 100644 --- a/execution.py +++ b/execution.py @@ -2,6 +2,7 @@ import copy import heapq import inspect import logging +import psutil import sys import threading import time @@ -727,6 +728,7 @@ class PromptExecutor: self._notify_prompt_lifecycle("start", prompt_id) ram_headroom = int(self.cache_args["ram"] * (1024 ** 3)) + ram_inactive_headroom = int(self.cache_args["ram_inactive"] * (1024 ** 3)) ram_release_callback = self.caches.outputs.ram_release if self.cache_type == CacheType.RAM_PRESSURE else None comfy.memory_management.set_ram_cache_release_state(ram_release_callback, ram_headroom) @@ -780,8 +782,14 @@ class PromptExecutor: execution_list.complete_node_execution() if self.cache_type == CacheType.RAM_PRESSURE: - comfy.model_management.free_memory(0, None, pins_required=ram_headroom, ram_required=ram_headroom) - ram_release_callback(ram_headroom, free_active=True) + ram_release_callback(ram_inactive_headroom) + ram_shortfall = ram_headroom - psutil.virtual_memory().available + freed = comfy.model_management.free_pins(ram_shortfall + 512 * (1024 ** 2)) + if freed < ram_shortfall: + if freed > 64 * (1024 ** 2): + # AIMDO MEM_DECOMMIT can outrun psutil.available catching up. + time.sleep(0.05) + ram_release_callback(ram_headroom, free_active=True) else: # Only execute when the while-loop ends without break # Send cached UI for intermediate output nodes that weren't executed diff --git a/main.py b/main.py index a6fdaf43c..1e47cab84 100644 --- a/main.py +++ b/main.py @@ -283,19 +283,25 @@ def _collect_output_absolute_paths(history_result: dict) -> list[str]: def prompt_worker(q, server_instance): current_time: float = 0.0 - cache_ram = args.cache_ram - if cache_ram < 0: + cache_ram = 0 + cache_ram_inactive = 0 + if not args.cache_classic and not args.cache_none and args.cache_lru <= 0: cache_ram = min(32.0, max(4.0, comfy.model_management.total_ram * 0.25 / 1024.0)) + cache_ram_inactive = min(96.0, max(12.0, comfy.model_management.total_ram * 0.75 / 1024.0)) + if len(args.cache_ram) > 0: + cache_ram = args.cache_ram[0] + if len(args.cache_ram) > 1: + cache_ram_inactive = args.cache_ram[1] - cache_type = execution.CacheType.CLASSIC - if args.cache_lru > 0: + cache_type = execution.CacheType.RAM_PRESSURE + if args.cache_classic: + cache_type = execution.CacheType.CLASSIC + elif args.cache_lru > 0: cache_type = execution.CacheType.LRU - elif cache_ram > 0: - cache_type = execution.CacheType.RAM_PRESSURE elif args.cache_none: cache_type = execution.CacheType.NONE - e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_args={ "lru" : args.cache_lru, "ram" : cache_ram } ) + e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_args={ "lru" : args.cache_lru, "ram" : cache_ram, "ram_inactive" : cache_ram_inactive } ) last_gc_collect = 0 need_gc = False gc_collect_interval = 10.0 diff --git a/requirements.txt b/requirements.txt index 1c87690da..d2986eda8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -23,7 +23,7 @@ SQLAlchemy>=2.0.0 filelock av>=14.2.0 comfy-kitchen>=0.2.8 -comfy-aimdo==0.3.0 +comfy-aimdo==0.4.3 requests simpleeval>=1.0.0 blake3 diff --git a/tests/execution/test_async_nodes.py b/tests/execution/test_async_nodes.py index c771b4b36..54660c112 100644 --- a/tests/execution/test_async_nodes.py +++ b/tests/execution/test_async_nodes.py @@ -14,7 +14,6 @@ from tests.execution.test_execution import ComfyClient, run_warmup class TestAsyncNodes: @fixture(scope="class", autouse=True, params=[ (False, 0), - (True, 0), (True, 100), ]) def _server(self, args_pytest, request): @@ -29,6 +28,8 @@ class TestAsyncNodes: use_lru, lru_size = request.param if use_lru: pargs += ['--cache-lru', str(lru_size)] + else: + pargs += ['--cache-classic'] # Running server with args: pargs p = subprocess.Popen(pargs) yield diff --git a/tests/execution/test_execution.py b/tests/execution/test_execution.py index f73ca7e3c..15e2304fc 100644 --- a/tests/execution/test_execution.py +++ b/tests/execution/test_execution.py @@ -183,8 +183,7 @@ class TestExecution: # Initialize server and client # @fixture(scope="class", autouse=True, params=[ - { "extra_args" : [], "should_cache_results" : True }, - { "extra_args" : ["--cache-lru", 0], "should_cache_results" : True }, + { "extra_args" : ["--cache-classic"], "should_cache_results" : True }, { "extra_args" : ["--cache-lru", 100], "should_cache_results" : True }, { "extra_args" : ["--cache-none"], "should_cache_results" : False }, ]) From 95fdc6cf910f809e39edc3254470e619ffa9dbf8 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 20 May 2026 17:17:55 -0700 Subject: [PATCH 7/7] Repo security stuff. (#14019) --- CODEOWNERS | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/CODEOWNERS b/CODEOWNERS index 946dbf946..043c0ec75 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -1,2 +1,5 @@ -# Admins * @comfyanonymous @kosinkadink @guill @alexisrolland @rattus128 @kijai + +/CODEOWNERS @comfyanonymous +/.ci/ @comfyanonymous +/.github/ @comfyanonymous