Refactor time embedding and projection layers

azazeal04 committed 2026-04-04 19:38:55 +02:00 (via GitHub)
parent de2ff57f3c
commit 5f041f4e5e

@@ -379,6 +379,7 @@ class TwinFlowZImageTransformer(nn.Module):
             "device": device,
             "dtype": dtype,
         }
+        self.time_embed_dim = 256 if z_image_modulation else min(dim, 1024)
         self.in_channels = in_channels
         self.out_channels = in_channels
         self.patch_size = patch_size
@@ -395,12 +396,12 @@ class TwinFlowZImageTransformer(nn.Module):
         self.t_embedder = TimestepEmbedder(
             min(dim, 1024),
-            output_size=256 if z_image_modulation else None,
+            output_size=self.time_embed_dim if z_image_modulation else None,
             **operation_settings,
         )
         self.t_embedder_2 = TimestepEmbedder(
             min(dim, 1024),
-            output_size=256 if z_image_modulation else None,
+            output_size=self.time_embed_dim if z_image_modulation else None,
             **operation_settings,
         )
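Both embedders now read their optional output width from self.time_embed_dim, so the value chosen in __init__ stays in sync with the concat projection added below. For orientation, here is a minimal sketch of the usual sinusoidal-plus-MLP design that an output_size override like this suggests; it is a hypothetical stand-in, not the repository's TimestepEmbedder:

    import math
    import torch
    import torch.nn as nn

    class TinyTimestepEmbedder(nn.Module):
        # Hypothetical stand-in: sinusoidal features of a fixed size, then an
        # MLP whose final width can be overridden by output_size.
        def __init__(self, hidden_size, frequency_embedding_size=256, output_size=None):
            super().__init__()
            self.frequency_embedding_size = frequency_embedding_size
            final = output_size if output_size is not None else hidden_size
            self.mlp = nn.Sequential(
                nn.Linear(frequency_embedding_size, hidden_size),
                nn.SiLU(),
                nn.Linear(hidden_size, final),
            )

        def forward(self, t):
            half = self.frequency_embedding_size // 2
            freqs = torch.exp(-math.log(10000.0) * torch.arange(half, dtype=t.dtype) / half)
            args = t[:, None] * freqs[None]
            return self.mlp(torch.cat([torch.cos(args), torch.sin(args)], dim=-1))

With z_image_modulation enabled, both embedders emit time_embed_dim-wide vectors, which is what allows t_emb to be concatenated with the pooled CLIP features further down.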
@@ -477,6 +478,22 @@ class TwinFlowZImageTransformer(nn.Module):
                 dtype=operation_settings.get("dtype"),
             ),
         )
+        self.clip_text_concat_proj = nn.Sequential(
+            operation_settings.get("operations").RMSNorm(
+                clip_text_dim + self.time_embed_dim,
+                eps=norm_eps,
+                elementwise_affine=True,
+                device=operation_settings.get("device"),
+                dtype=operation_settings.get("dtype"),
+            ),
+            operation_settings.get("operations").Linear(
+                clip_text_dim + self.time_embed_dim,
+                self.time_embed_dim,
+                bias=True,
+                device=operation_settings.get("device"),
+                dtype=operation_settings.get("dtype"),
+            ),
+        )
         self.layers = nn.ModuleList(
             [
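The new clip_text_concat_proj normalizes the concatenation of the timestep embedding and the pooled CLIP text embedding, then projects it back down to time_embed_dim; its call site is in the last hunk below. A plain-torch sketch of the shape flow, with assumed example widths (the real module is built through operation_settings so it inherits the model's device and dtype; nn.RMSNorm needs PyTorch 2.4+):

    import torch
    import torch.nn as nn

    time_embed_dim, clip_text_dim, norm_eps = 256, 768, 1e-5  # assumed example values

    clip_text_concat_proj = nn.Sequential(
        nn.RMSNorm(clip_text_dim + time_embed_dim, eps=norm_eps, elementwise_affine=True),
        nn.Linear(clip_text_dim + time_embed_dim, time_embed_dim, bias=True),
    )

    t_emb = torch.randn(2, time_embed_dim)   # timestep embedding
    pooled = torch.randn(2, clip_text_dim)   # pooled CLIP text embedding
    adaln_input = clip_text_concat_proj(torch.cat((t_emb, pooled), dim=-1))
    print(adaln_input.shape)  # torch.Size([2, 256])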
@@ -531,7 +548,7 @@ class TwinFlowZImageTransformer(nn.Module):
             t_emb_2 = self.t_embedder_2(t * self.time_scale, dtype=x_dtype)
             return t_emb + t_emb_2
-        target_t = target_timestep.to(device=t.device, dtype=t.dtype)
+        target_t = torch.as_tensor(target_timestep, device=t.device, dtype=t.dtype)
         if target_t.ndim == 0:
             target_t = target_t.expand_as(t)
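torch.as_tensor accepts a plain Python scalar as well as an existing tensor, while the old target_timestep.to(...) call assumed a tensor and would raise AttributeError on a float. A small check of both paths (names mirror the diff, values are made up):

    import torch

    t = torch.rand(4)  # per-sample timesteps

    for target_timestep in (0.5, torch.tensor(0.5), torch.rand(4)):
        target_t = torch.as_tensor(target_timestep, device=t.device, dtype=t.dtype)
        if target_t.ndim == 0:
            target_t = target_t.expand_as(t)  # broadcast a scalar target to the batch
        assert target_t.shape == t.shape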
@@ -636,6 +653,8 @@ class TwinFlowZImageTransformer(nn.Module):
                 dim=1,
             )
             cap_pos_ids = torch.nn.functional.pad(cap_pos_ids, (0, 0, 0, pad_extra), value=0)
+            if cap_mask is not None and pad_extra > 0:
+                cap_mask = torch.nn.functional.pad(cap_mask, (0, pad_extra), value=0)
         freqs_cis = self.rope_embedder(torch.cat((cap_pos_ids, x_pos_ids), dim=1)).movedim(1, 2)
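Previously only cap_pos_ids was padded; padding cap_mask by the same pad_extra keeps the mask aligned with the padded caption tokens, presumably so the extra positions are excluded from attention. A minimal illustration with made-up sizes:

    import torch
    import torch.nn.functional as F

    cap_mask = torch.ones(2, 77, dtype=torch.bool)  # made-up caption mask
    pad_extra = 3
    # Right-pad the sequence dimension; value=0 marks the new positions as padding.
    cap_mask = F.pad(cap_mask, (0, pad_extra), value=0)
    print(cap_mask.shape)  # torch.Size([2, 80])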
@@ -703,7 +722,7 @@ class TwinFlowZImageTransformer(nn.Module):
         else:
             pooled = torch.zeros((x.shape[0], self.clip_text_dim), device=x.device, dtype=x.dtype)
         adaln_input = torch.cat((t_emb, pooled), dim=-1)
-        adaln_input = self.clip_text_pooled_proj(adaln_input)
+        adaln_input = self.clip_text_concat_proj(adaln_input)
         img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(
             x,