mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-15 13:02:35 +08:00
Fix condition and update transformer_options handling
This commit is contained in:
parent
479536e34d
commit
6965b83a2a
@ -108,7 +108,7 @@ class JointAttention(nn.Module):
|
||||
xq, xk = apply_rope(xq, xk, freqs_cis)
|
||||
|
||||
n_rep = self.n_local_heads // self.n_local_kv_heads
|
||||
if n_rep >= 1:
|
||||
if n_rep > 1:
|
||||
xk = xk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
|
||||
xv = xv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
|
||||
|
||||
@ -700,13 +700,15 @@ class TwinFlowZImageTransformer(nn.Module):
|
||||
context,
|
||||
num_tokens,
|
||||
attention_mask=None,
|
||||
transformer_options={},
|
||||
transformer_options=None,
|
||||
**kwargs,
|
||||
):
|
||||
if transformer_options is None:
|
||||
transformer_options = {}
|
||||
|
||||
t = 1.0 - timesteps
|
||||
|
||||
adaln_input = self._compute_twinflow_adaln(t, x.dtype, transformer_options=transformer_options)
|
||||
t_emb = self.t_embedder(t * self.time_scale, dtype=x.dtype)
|
||||
|
||||
cap_feats = context
|
||||
cap_mask = attention_mask
|
||||
@ -721,7 +723,7 @@ class TwinFlowZImageTransformer(nn.Module):
|
||||
pooled = self.clip_text_pooled_proj(pooled)
|
||||
else:
|
||||
pooled = torch.zeros((x.shape[0], self.clip_text_dim), device=x.device, dtype=x.dtype)
|
||||
adaln_input = torch.cat((t_emb, pooled), dim=-1)
|
||||
adaln_input = torch.cat((adaln_input, pooled), dim=-1)
|
||||
adaln_input = self.clip_text_concat_proj(adaln_input)
|
||||
|
||||
img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(
|
||||
|
||||
Loading…
Reference in New Issue
Block a user