mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-15 13:02:35 +08:00
Implement _compute_twinflow_adaln method
Added a method to compute TwinFlow adaLN input with delta-time conditioning.
This commit is contained in:
parent
4fe75dcce8
commit
de2ff57f3c
@ -516,6 +516,41 @@ class TwinFlowZImageTransformer(nn.Module):
|
||||
self.dim = dim
|
||||
self.n_heads = n_heads
|
||||
|
||||
def _compute_twinflow_adaln(self, t: torch.Tensor, x_dtype: torch.dtype, transformer_options={}):
|
||||
"""
|
||||
Compute TwinFlow adaLN input.
|
||||
|
||||
If `target_timestep` is provided in transformer options, apply the
|
||||
TwinFlow delta-time conditioning:
|
||||
t_emb + t_embedder_2((target - t) * time_scale) * abs(target - t)
|
||||
otherwise fallback to the baseline additive embedding.
|
||||
"""
|
||||
t_emb = self.t_embedder(t * self.time_scale, dtype=x_dtype)
|
||||
target_timestep = transformer_options.get("target_timestep", None)
|
||||
if target_timestep is None:
|
||||
t_emb_2 = self.t_embedder_2(t * self.time_scale, dtype=x_dtype)
|
||||
return t_emb + t_emb_2
|
||||
|
||||
target_t = target_timestep.to(device=t.device, dtype=t.dtype)
|
||||
if target_t.ndim == 0:
|
||||
target_t = target_t.expand_as(t)
|
||||
|
||||
# If values look scaled (roughly sigma/timestep in [0..1000]), normalize.
|
||||
t_abs_max = float(t.detach().abs().max().item()) if t.numel() else 0.0
|
||||
tt_abs_max = float(target_t.detach().abs().max().item()) if target_t.numel() else 0.0
|
||||
scaled_domain = (max(t_abs_max, tt_abs_max) > 2.0) and (self.time_scale > 2.0)
|
||||
if scaled_domain:
|
||||
t_norm = t / self.time_scale
|
||||
tt_norm = target_t / self.time_scale
|
||||
else:
|
||||
t_norm = t
|
||||
tt_norm = target_t
|
||||
|
||||
delta_abs = (t_norm - tt_norm).abs().unsqueeze(1).to(t_emb.dtype)
|
||||
diff_in = (tt_norm - t_norm) * self.time_scale
|
||||
t_emb_2 = self.t_embedder_2(diff_in, dtype=x_dtype)
|
||||
return t_emb + t_emb_2 * delta_abs
|
||||
|
||||
def unpatchify(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
@ -651,10 +686,8 @@ class TwinFlowZImageTransformer(nn.Module):
|
||||
):
|
||||
t = 1.0 - timesteps
|
||||
|
||||
adaln_input = self._compute_twinflow_adaln(t, x.dtype, transformer_options=transformer_options)
|
||||
t_emb = self.t_embedder(t * self.time_scale, dtype=x.dtype)
|
||||
t_emb_2 = self.t_embedder_2(t * self.time_scale, dtype=x.dtype)
|
||||
|
||||
adaln_input = t_emb + t_emb_2
|
||||
|
||||
cap_feats = context
|
||||
cap_mask = attention_mask
|
||||
|
||||
Loading…
Reference in New Issue
Block a user