Replace empty tensor initialization with zeros

This commit is contained in:
azazeal04 2026-04-04 20:21:27 +02:00 committed by GitHub
parent 5d119f0532
commit 49fef1697c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -523,8 +523,8 @@ class TwinFlowZImageTransformer(nn.Module):
)
if self.pad_tokens_multiple is not None:
self.x_pad_token = nn.Parameter(torch.empty((1, dim), device=device, dtype=dtype))
self.cap_pad_token = nn.Parameter(torch.empty((1, dim), device=device, dtype=dtype))
self.x_pad_token = nn.Parameter(torch.zeros((1, dim), device=device, dtype=dtype))
self.cap_pad_token = nn.Parameter(torch.zeros((1, dim), device=device, dtype=dtype))
assert (dim // n_heads) == sum(axes_dims)
self.axes_dims = axes_dims
@@ -630,31 +630,33 @@ class TwinFlowZImageTransformer(nn.Module):
torch.arange(W_tokens, dtype=torch.float32, device=device) * w_scale + w_start
).view(1, -1).repeat(H_tokens, 1).flatten()
x_pad_extra = 0
if self.pad_tokens_multiple is not None:
pad_extra = (-x.shape[1]) % self.pad_tokens_multiple
x_pad_extra = (-x.shape[1]) % self.pad_tokens_multiple
x = torch.cat(
(
x,
self.x_pad_token.to(device=x.device, dtype=x.dtype, copy=True).unsqueeze(0).repeat(x.shape[0], pad_extra, 1),
self.x_pad_token.to(device=x.device, dtype=x.dtype, copy=True).unsqueeze(0).repeat(x.shape[0], x_pad_extra, 1),
),
dim=1,
)
x_pos_ids = torch.nn.functional.pad(x_pos_ids, (0, 0, 0, pad_extra))
x_pos_ids = torch.nn.functional.pad(x_pos_ids, (0, 0, 0, x_pad_extra))
cap_pad_extra = 0
if self.pad_tokens_multiple is not None:
pad_extra = (-cap_feats.shape[1]) % self.pad_tokens_multiple
cap_pad_extra = (-cap_feats.shape[1]) % self.pad_tokens_multiple
cap_feats = torch.cat(
(
cap_feats,
self.cap_pad_token.to(device=cap_feats.device, dtype=cap_feats.dtype, copy=True)
.unsqueeze(0)
.repeat(cap_feats.shape[0], pad_extra, 1),
.repeat(cap_feats.shape[0], cap_pad_extra, 1),
),
dim=1,
)
cap_pos_ids = torch.nn.functional.pad(cap_pos_ids, (0, 0, 0, pad_extra), value=0)
if cap_mask is not None and pad_extra > 0:
cap_mask = torch.nn.functional.pad(cap_mask, (0, pad_extra), value=0)
cap_pos_ids = torch.nn.functional.pad(cap_pos_ids, (0, 0, 0, cap_pad_extra), value=0)
if cap_mask is not None and cap_pad_extra > 0:
cap_mask = torch.nn.functional.pad(cap_mask, (0, cap_pad_extra), value=0)
freqs_cis = self.rope_embedder(torch.cat((cap_pos_ids, x_pos_ids), dim=1)).movedim(1, 2)
@@ -677,7 +679,14 @@ class TwinFlowZImageTransformer(nn.Module):
)
padded_full_embed = torch.cat((cap_feats, x), dim=1)
mask = None
if cap_mask is not None:
cap_mask_bool = cap_mask if cap_mask.dtype == torch.bool else cap_mask > 0
img_mask = torch.ones((bsz, x.shape[1]), device=cap_mask.device, dtype=torch.bool)
if x_pad_extra > 0:
img_mask[:, -x_pad_extra:] = False
mask = torch.cat((cap_mask_bool, img_mask), dim=1)
else:
mask = None
img_sizes = [(H, W)] * bsz
l_effective_cap_len = [cap_feats.shape[1]] * bsz