Implement _compute_twinflow_adaln method

Added a method to compute TwinFlow adaLN input with delta-time conditioning.
This commit is contained in:
azazeal04 2026-04-04 18:49:19 +02:00 committed by GitHub
parent 4fe75dcce8
commit de2ff57f3c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -516,6 +516,41 @@ class TwinFlowZImageTransformer(nn.Module):
self.dim = dim
self.n_heads = n_heads
def _compute_twinflow_adaln(self, t: torch.Tensor, x_dtype: torch.dtype, transformer_options={}):
"""
Compute TwinFlow adaLN input.
If `target_timestep` is provided in transformer options, apply the
TwinFlow delta-time conditioning:
t_emb + t_embedder_2((target - t) * time_scale) * abs(target - t)
otherwise fallback to the baseline additive embedding.
"""
t_emb = self.t_embedder(t * self.time_scale, dtype=x_dtype)
target_timestep = transformer_options.get("target_timestep", None)
if target_timestep is None:
t_emb_2 = self.t_embedder_2(t * self.time_scale, dtype=x_dtype)
return t_emb + t_emb_2
target_t = target_timestep.to(device=t.device, dtype=t.dtype)
if target_t.ndim == 0:
target_t = target_t.expand_as(t)
# If values look scaled (roughly sigma/timestep in [0..1000]), normalize.
t_abs_max = float(t.detach().abs().max().item()) if t.numel() else 0.0
tt_abs_max = float(target_t.detach().abs().max().item()) if target_t.numel() else 0.0
scaled_domain = (max(t_abs_max, tt_abs_max) > 2.0) and (self.time_scale > 2.0)
if scaled_domain:
t_norm = t / self.time_scale
tt_norm = target_t / self.time_scale
else:
t_norm = t
tt_norm = target_t
delta_abs = (t_norm - tt_norm).abs().unsqueeze(1).to(t_emb.dtype)
diff_in = (tt_norm - t_norm) * self.time_scale
t_emb_2 = self.t_embedder_2(diff_in, dtype=x_dtype)
return t_emb + t_emb_2 * delta_abs
def unpatchify(
self,
x: torch.Tensor,
@ -651,10 +686,8 @@ class TwinFlowZImageTransformer(nn.Module):
):
t = 1.0 - timesteps
adaln_input = self._compute_twinflow_adaln(t, x.dtype, transformer_options=transformer_options)
t_emb = self.t_embedder(t * self.time_scale, dtype=x.dtype)
t_emb_2 = self.t_embedder_2(t * self.time_scale, dtype=x.dtype)
adaln_input = t_emb + t_emb_2
cap_feats = context
cap_mask = attention_mask