mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-19 15:02:44 +08:00
Fix condition and update transformer_options handling
This commit is contained in:
parent
479536e34d
commit
6965b83a2a
@ -108,7 +108,7 @@ class JointAttention(nn.Module):
|
|||||||
xq, xk = apply_rope(xq, xk, freqs_cis)
|
xq, xk = apply_rope(xq, xk, freqs_cis)
|
||||||
|
|
||||||
n_rep = self.n_local_heads // self.n_local_kv_heads
|
n_rep = self.n_local_heads // self.n_local_kv_heads
|
||||||
if n_rep >= 1:
|
if n_rep > 1:
|
||||||
xk = xk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
|
xk = xk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
|
||||||
xv = xv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
|
xv = xv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
|
||||||
|
|
||||||
@ -700,13 +700,15 @@ class TwinFlowZImageTransformer(nn.Module):
|
|||||||
context,
|
context,
|
||||||
num_tokens,
|
num_tokens,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
transformer_options={},
|
transformer_options=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
|
if transformer_options is None:
|
||||||
|
transformer_options = {}
|
||||||
|
|
||||||
t = 1.0 - timesteps
|
t = 1.0 - timesteps
|
||||||
|
|
||||||
adaln_input = self._compute_twinflow_adaln(t, x.dtype, transformer_options=transformer_options)
|
adaln_input = self._compute_twinflow_adaln(t, x.dtype, transformer_options=transformer_options)
|
||||||
t_emb = self.t_embedder(t * self.time_scale, dtype=x.dtype)
|
|
||||||
|
|
||||||
cap_feats = context
|
cap_feats = context
|
||||||
cap_mask = attention_mask
|
cap_mask = attention_mask
|
||||||
@ -721,7 +723,7 @@ class TwinFlowZImageTransformer(nn.Module):
|
|||||||
pooled = self.clip_text_pooled_proj(pooled)
|
pooled = self.clip_text_pooled_proj(pooled)
|
||||||
else:
|
else:
|
||||||
pooled = torch.zeros((x.shape[0], self.clip_text_dim), device=x.device, dtype=x.dtype)
|
pooled = torch.zeros((x.shape[0], self.clip_text_dim), device=x.device, dtype=x.dtype)
|
||||||
adaln_input = torch.cat((t_emb, pooled), dim=-1)
|
adaln_input = torch.cat((adaln_input, pooled), dim=-1)
|
||||||
adaln_input = self.clip_text_concat_proj(adaln_input)
|
adaln_input = self.clip_text_concat_proj(adaln_input)
|
||||||
|
|
||||||
img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(
|
img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user