From 9461f3038783dec19cf8c12bbdfffc171c32f792 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Thu, 28 Aug 2025 20:56:56 -0700 Subject: [PATCH] Made StableAudio work with optimized_attention_override --- comfy/ldm/audio/dit.py | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/comfy/ldm/audio/dit.py b/comfy/ldm/audio/dit.py index 179c5b67e..5ca0e1e52 100644 --- a/comfy/ldm/audio/dit.py +++ b/comfy/ldm/audio/dit.py @@ -298,7 +298,8 @@ class Attention(nn.Module): mask = None, context_mask = None, rotary_pos_emb = None, - causal = None + causal = None, + transformer_options={}, ): h, kv_h, has_context = self.num_heads, self.kv_heads, context is not None @@ -363,7 +364,7 @@ class Attention(nn.Module): heads_per_kv_head = h // kv_h k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim = 1), (k, v)) - out = optimized_attention(q, k, v, h, skip_reshape=True) + out = optimized_attention(q, k, v, h, skip_reshape=True, transformer_options=transformer_options) out = self.to_out(out) if mask is not None: @@ -488,7 +489,8 @@ class TransformerBlock(nn.Module): global_cond=None, mask = None, context_mask = None, - rotary_pos_emb = None + rotary_pos_emb = None, + transformer_options={} ): if self.global_cond_dim is not None and self.global_cond_dim > 0 and global_cond is not None: @@ -498,12 +500,12 @@ class TransformerBlock(nn.Module): residual = x x = self.pre_norm(x) x = x * (1 + scale_self) + shift_self - x = self.self_attn(x, mask = mask, rotary_pos_emb = rotary_pos_emb) + x = self.self_attn(x, mask = mask, rotary_pos_emb = rotary_pos_emb, transformer_options=transformer_options) x = x * torch.sigmoid(1 - gate_self) x = x + residual if context is not None: - x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask) + x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask, transformer_options=transformer_options) if self.conformer is not None: x = x + self.conformer(x) @@ -517,10 +519,10 @@ class TransformerBlock(nn.Module): x = x + residual else: - x = x + self.self_attn(self.pre_norm(x), mask = mask, rotary_pos_emb = rotary_pos_emb) + x = x + self.self_attn(self.pre_norm(x), mask = mask, rotary_pos_emb = rotary_pos_emb, transformer_options=transformer_options) if context is not None: - x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask) + x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask, transformer_options=transformer_options) if self.conformer is not None: x = x + self.conformer(x) @@ -606,7 +608,8 @@ class ContinuousTransformer(nn.Module): return_info = False, **kwargs ): - patches_replace = kwargs.get("transformer_options", {}).get("patches_replace", {}) + transformer_options = kwargs.get("transformer_options", {}) + patches_replace = transformer_options.get("patches_replace", {}) batch, seq, device = *x.shape[:2], x.device context = kwargs["context"] @@ -645,13 +648,13 @@ class ContinuousTransformer(nn.Module): if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} - out["img"] = layer(args["img"], rotary_pos_emb=args["pe"], global_cond=args["vec"], context=args["txt"]) + out["img"] = layer(args["img"], rotary_pos_emb=args["pe"], global_cond=args["vec"], context=args["txt"], transformer_options=args["transformer_options"]) return out - out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": global_cond, "pe": rotary_pos_emb}, {"original_block": block_wrap}) + out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": global_cond, "pe": rotary_pos_emb, "transformer_options": transformer_options}, {"original_block": block_wrap}) x = out["img"] else: - x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, context=context) + x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, context=context, transformer_options=transformer_options) # x = checkpoint(layer, x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs) if return_info: