From 5f041f4e5e1131bcccd0dc339c39ab75717e421b Mon Sep 17 00:00:00 2001 From: azazeal04 <132445160+azazeal04@users.noreply.github.com> Date: Sat, 4 Apr 2026 19:38:55 +0200 Subject: [PATCH] Refactor time embedding and projection layers --- comfy/ldm/twinflow/model.py | 27 +++++++++++++++++++++++---- 1 file changed, 23 insertions(+), 4 deletions(-) diff --git a/comfy/ldm/twinflow/model.py b/comfy/ldm/twinflow/model.py index 94ec1e3f3..7ed399866 100644 --- a/comfy/ldm/twinflow/model.py +++ b/comfy/ldm/twinflow/model.py @@ -379,6 +379,7 @@ class TwinFlowZImageTransformer(nn.Module): "device": device, "dtype": dtype, } + self.time_embed_dim = 256 if z_image_modulation else min(dim, 1024) self.in_channels = in_channels self.out_channels = in_channels self.patch_size = patch_size @@ -395,12 +396,12 @@ class TwinFlowZImageTransformer(nn.Module): self.t_embedder = TimestepEmbedder( min(dim, 1024), - output_size=256 if z_image_modulation else None, + output_size=self.time_embed_dim if z_image_modulation else None, **operation_settings, ) self.t_embedder_2 = TimestepEmbedder( min(dim, 1024), - output_size=256 if z_image_modulation else None, + output_size=self.time_embed_dim if z_image_modulation else None, **operation_settings, ) @@ -477,6 +478,22 @@ class TwinFlowZImageTransformer(nn.Module): dtype=operation_settings.get("dtype"), ), ) + self.clip_text_concat_proj = nn.Sequential( + operation_settings.get("operations").RMSNorm( + clip_text_dim + self.time_embed_dim, + eps=norm_eps, + elementwise_affine=True, + device=operation_settings.get("device"), + dtype=operation_settings.get("dtype"), + ), + operation_settings.get("operations").Linear( + clip_text_dim + self.time_embed_dim, + self.time_embed_dim, + bias=True, + device=operation_settings.get("device"), + dtype=operation_settings.get("dtype"), + ), + ) self.layers = nn.ModuleList( [ @@ -531,7 +548,7 @@ class TwinFlowZImageTransformer(nn.Module): t_emb_2 = self.t_embedder_2(t * self.time_scale, dtype=x_dtype) return t_emb + t_emb_2 - target_t = target_timestep.to(device=t.device, dtype=t.dtype) + target_t = torch.as_tensor(target_timestep, device=t.device, dtype=t.dtype) if target_t.ndim == 0: target_t = target_t.expand_as(t) @@ -636,6 +653,8 @@ class TwinFlowZImageTransformer(nn.Module): dim=1, ) cap_pos_ids = torch.nn.functional.pad(cap_pos_ids, (0, 0, 0, pad_extra), value=0) + if cap_mask is not None and pad_extra > 0: + cap_mask = torch.nn.functional.pad(cap_mask, (0, pad_extra), value=0) freqs_cis = self.rope_embedder(torch.cat((cap_pos_ids, x_pos_ids), dim=1)).movedim(1, 2) @@ -703,7 +722,7 @@ class TwinFlowZImageTransformer(nn.Module): else: pooled = torch.zeros((x.shape[0], self.clip_text_dim), device=x.device, dtype=x.dtype) adaln_input = torch.cat((t_emb, pooled), dim=-1) - adaln_input = self.clip_text_pooled_proj(adaln_input) + adaln_input = self.clip_text_concat_proj(adaln_input) img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed( x,