diff --git a/comfy/ldm/twinflow/model.py b/comfy/ldm/twinflow/model.py
index a99bcd06e..94ec1e3f3 100644
--- a/comfy/ldm/twinflow/model.py
+++ b/comfy/ldm/twinflow/model.py
@@ -516,6 +516,42 @@ class TwinFlowZImageTransformer(nn.Module):
         self.dim = dim
         self.n_heads = n_heads
 
+    def _compute_twinflow_adaln(self, t: torch.Tensor, x_dtype: torch.dtype, transformer_options={}):
+        """
+        Compute TwinFlow adaLN input.
+
+        If `target_timestep` is provided in transformer options, apply the
+        TwinFlow delta-time conditioning:
+            t_emb + t_embedder_2((target - t) * time_scale) * abs(target - t)
+        otherwise fall back to the baseline additive embedding.
+        """
+        t_emb = self.t_embedder(t * self.time_scale, dtype=x_dtype)
+        target_timestep = transformer_options.get("target_timestep", None)
+        if target_timestep is None:
+            t_emb_2 = self.t_embedder_2(t * self.time_scale, dtype=x_dtype)
+            return t_emb + t_emb_2
+
+        target_t = target_timestep.to(device=t.device, dtype=t.dtype)
+        if target_t.ndim == 0:
+            target_t = target_t.expand_as(t)
+
+        # If values look scaled (roughly sigma/timestep in [0..1000]), normalize.
+        t_abs_max = float(t.detach().abs().max().item()) if t.numel() else 0.0
+        tt_abs_max = float(target_t.detach().abs().max().item()) if target_t.numel() else 0.0
+        scaled_domain = (max(t_abs_max, tt_abs_max) > 2.0) and (self.time_scale > 2.0)
+        if scaled_domain:
+            t_norm = t / self.time_scale
+            tt_norm = target_t / self.time_scale
+        else:
+            t_norm = t
+            tt_norm = target_t
+
+        # NOTE(review): unsqueeze(1) assumes t is 1-D (batch,) so delta_abs broadcasts against (batch, dim) embeddings — confirm.
+        delta_abs = (t_norm - tt_norm).abs().unsqueeze(1).to(t_emb.dtype)
+        diff_in = (tt_norm - t_norm) * self.time_scale
+        t_emb_2 = self.t_embedder_2(diff_in, dtype=x_dtype)
+        return t_emb + t_emb_2 * delta_abs
+
     def unpatchify(
         self,
         x: torch.Tensor,
@@ -651,10 +687,8 @@ class TwinFlowZImageTransformer(nn.Module):
     ):
         t = 1.0 - timesteps
 
+        adaln_input = self._compute_twinflow_adaln(t, x.dtype, transformer_options=transformer_options)
         t_emb = self.t_embedder(t * self.time_scale, dtype=x.dtype)
-        t_emb_2 = self.t_embedder_2(t * self.time_scale, dtype=x.dtype)
-
-        adaln_input = t_emb + t_emb_2
 
         cap_feats = context
         cap_mask = attention_mask