mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-15 13:02:35 +08:00
Refactor time embedding and projection layers
This commit is contained in:
parent
de2ff57f3c
commit
5f041f4e5e
@ -379,6 +379,7 @@ class TwinFlowZImageTransformer(nn.Module):
|
||||
"device": device,
|
||||
"dtype": dtype,
|
||||
}
|
||||
self.time_embed_dim = 256 if z_image_modulation else min(dim, 1024)
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = in_channels
|
||||
self.patch_size = patch_size
|
||||
@ -395,12 +396,12 @@ class TwinFlowZImageTransformer(nn.Module):
|
||||
|
||||
self.t_embedder = TimestepEmbedder(
|
||||
min(dim, 1024),
|
||||
output_size=256 if z_image_modulation else None,
|
||||
output_size=self.time_embed_dim if z_image_modulation else None,
|
||||
**operation_settings,
|
||||
)
|
||||
self.t_embedder_2 = TimestepEmbedder(
|
||||
min(dim, 1024),
|
||||
output_size=256 if z_image_modulation else None,
|
||||
output_size=self.time_embed_dim if z_image_modulation else None,
|
||||
**operation_settings,
|
||||
)
|
||||
|
||||
@ -477,6 +478,22 @@ class TwinFlowZImageTransformer(nn.Module):
|
||||
dtype=operation_settings.get("dtype"),
|
||||
),
|
||||
)
|
||||
self.clip_text_concat_proj = nn.Sequential(
|
||||
operation_settings.get("operations").RMSNorm(
|
||||
clip_text_dim + self.time_embed_dim,
|
||||
eps=norm_eps,
|
||||
elementwise_affine=True,
|
||||
device=operation_settings.get("device"),
|
||||
dtype=operation_settings.get("dtype"),
|
||||
),
|
||||
operation_settings.get("operations").Linear(
|
||||
clip_text_dim + self.time_embed_dim,
|
||||
self.time_embed_dim,
|
||||
bias=True,
|
||||
device=operation_settings.get("device"),
|
||||
dtype=operation_settings.get("dtype"),
|
||||
),
|
||||
)
|
||||
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
@ -531,7 +548,7 @@ class TwinFlowZImageTransformer(nn.Module):
|
||||
t_emb_2 = self.t_embedder_2(t * self.time_scale, dtype=x_dtype)
|
||||
return t_emb + t_emb_2
|
||||
|
||||
target_t = target_timestep.to(device=t.device, dtype=t.dtype)
|
||||
target_t = torch.as_tensor(target_timestep, device=t.device, dtype=t.dtype)
|
||||
if target_t.ndim == 0:
|
||||
target_t = target_t.expand_as(t)
|
||||
|
||||
@ -636,6 +653,8 @@ class TwinFlowZImageTransformer(nn.Module):
|
||||
dim=1,
|
||||
)
|
||||
cap_pos_ids = torch.nn.functional.pad(cap_pos_ids, (0, 0, 0, pad_extra), value=0)
|
||||
if cap_mask is not None and pad_extra > 0:
|
||||
cap_mask = torch.nn.functional.pad(cap_mask, (0, pad_extra), value=0)
|
||||
|
||||
freqs_cis = self.rope_embedder(torch.cat((cap_pos_ids, x_pos_ids), dim=1)).movedim(1, 2)
|
||||
|
||||
@ -703,7 +722,7 @@ class TwinFlowZImageTransformer(nn.Module):
|
||||
else:
|
||||
pooled = torch.zeros((x.shape[0], self.clip_text_dim), device=x.device, dtype=x.dtype)
|
||||
adaln_input = torch.cat((t_emb, pooled), dim=-1)
|
||||
adaln_input = self.clip_text_pooled_proj(adaln_input)
|
||||
adaln_input = self.clip_text_concat_proj(adaln_input)
|
||||
|
||||
img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(
|
||||
x,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user