mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-19 06:52:31 +08:00
Implement _compute_twinflow_adaln method
Added a method to compute TwinFlow adaLN input with delta-time conditioning.
This commit is contained in:
parent
4fe75dcce8
commit
de2ff57f3c
@ -516,6 +516,41 @@ class TwinFlowZImageTransformer(nn.Module):
|
|||||||
self.dim = dim
|
self.dim = dim
|
||||||
self.n_heads = n_heads
|
self.n_heads = n_heads
|
||||||
|
|
||||||
|
def _compute_twinflow_adaln(self, t: torch.Tensor, x_dtype: torch.dtype, transformer_options=None):
    """
    Compute TwinFlow adaLN input.

    If `target_timestep` is provided in transformer options, apply the
    TwinFlow delta-time conditioning:
        t_emb + t_embedder_2((target - t) * time_scale) * abs(target - t)
    otherwise fall back to the baseline additive embedding:
        t_embedder(t * time_scale) + t_embedder_2(t * time_scale)

    Args:
        t: Timestep tensor passed to both embedders after scaling by
            `self.time_scale`.
            NOTE(review): presumably 1-D of shape (batch,); the
            `unsqueeze(1)` below needs at least one dimension -- confirm.
        x_dtype: dtype forwarded to the embedders as `dtype=`.
        transformer_options: Optional mapping. Only the key
            "target_timestep" is read; its value may be a tensor or a
            plain Python number. `None` (the default) or a missing key
            selects the baseline path.

    Returns:
        The adaLN conditioning tensor produced by combining the two
        timestep embeddings as described above.
    """
    # `None` default instead of a shared mutable `{}`; callers that pass a
    # dict (or nothing at all) behave exactly as before.
    options = transformer_options if transformer_options is not None else {}

    t_emb = self.t_embedder(t * self.time_scale, dtype=x_dtype)
    target_timestep = options.get("target_timestep", None)
    if target_timestep is None:
        t_emb_2 = self.t_embedder_2(t * self.time_scale, dtype=x_dtype)
        return t_emb + t_emb_2

    # as_tensor accepts tensors and plain numbers alike, so a float
    # target no longer fails on a missing `.to` attribute.
    target_t = torch.as_tensor(target_timestep, device=t.device, dtype=t.dtype)
    if target_t.ndim == 0:
        target_t = target_t.expand_as(t)

    # If values look scaled (roughly sigma/timestep in [0..1000]), normalize.
    # The magnitude probe forces a device->host sync, so it is skipped when
    # `time_scale` alone already rules the scaled domain out, and both
    # tensors are reduced together so at most one sync happens.
    scaled_domain = False
    if self.time_scale > 2.0:
        peaks = [v.detach().abs().max() for v in (t, target_t) if v.numel()]
        if peaks:
            scaled_domain = bool(torch.stack(peaks).max() > 2.0)

    if scaled_domain:
        t_norm = t / self.time_scale
        tt_norm = target_t / self.time_scale
    else:
        t_norm = t
        tt_norm = target_t

    # |target - t| weights the second embedding; the extra axis lets it
    # broadcast across the embedding dimension of `t_emb_2`.
    delta_abs = (t_norm - tt_norm).abs().unsqueeze(1).to(t_emb.dtype)
    diff_in = (tt_norm - t_norm) * self.time_scale
    t_emb_2 = self.t_embedder_2(diff_in, dtype=x_dtype)
    return t_emb + t_emb_2 * delta_abs
|
||||||
|
|
||||||
def unpatchify(
|
def unpatchify(
|
||||||
self,
|
self,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
@ -651,10 +686,8 @@ class TwinFlowZImageTransformer(nn.Module):
|
|||||||
):
|
):
|
||||||
t = 1.0 - timesteps
|
t = 1.0 - timesteps
|
||||||
|
|
||||||
|
adaln_input = self._compute_twinflow_adaln(t, x.dtype, transformer_options=transformer_options)
|
||||||
t_emb = self.t_embedder(t * self.time_scale, dtype=x.dtype)
|
t_emb = self.t_embedder(t * self.time_scale, dtype=x.dtype)
|
||||||
t_emb_2 = self.t_embedder_2(t * self.time_scale, dtype=x.dtype)
|
|
||||||
|
|
||||||
adaln_input = t_emb + t_emb_2
|
|
||||||
|
|
||||||
cap_feats = context
|
cap_feats = context
|
||||||
cap_mask = attention_mask
|
cap_mask = attention_mask
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user