Fix condition and update transformer_options handling

This commit is contained in:
azazeal04 2026-04-04 19:57:37 +02:00 committed by GitHub
parent 479536e34d
commit 6965b83a2a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -108,7 +108,7 @@ class JointAttention(nn.Module):
xq, xk = apply_rope(xq, xk, freqs_cis)
n_rep = self.n_local_heads // self.n_local_kv_heads
if n_rep >= 1:
if n_rep > 1:
xk = xk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
xv = xv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
@@ -700,13 +700,15 @@ class TwinFlowZImageTransformer(nn.Module):
context,
num_tokens,
attention_mask=None,
transformer_options={},
transformer_options=None,
**kwargs,
):
if transformer_options is None:
transformer_options = {}
t = 1.0 - timesteps
adaln_input = self._compute_twinflow_adaln(t, x.dtype, transformer_options=transformer_options)
t_emb = self.t_embedder(t * self.time_scale, dtype=x.dtype)
cap_feats = context
cap_mask = attention_mask
@@ -721,7 +723,7 @@ class TwinFlowZImageTransformer(nn.Module):
pooled = self.clip_text_pooled_proj(pooled)
else:
pooled = torch.zeros((x.shape[0], self.clip_text_dim), device=x.device, dtype=x.dtype)
adaln_input = torch.cat((t_emb, pooled), dim=-1)
adaln_input = torch.cat((adaln_input, pooled), dim=-1)
adaln_input = self.clip_text_concat_proj(adaln_input)
img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(