mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-15 04:52:31 +08:00
Replace empty tensor initialization with zeros
This commit is contained in:
parent
5d119f0532
commit
49fef1697c
@@ -523,8 +523,8 @@ class TwinFlowZImageTransformer(nn.Module):
|
||||
)
|
||||
|
||||
if self.pad_tokens_multiple is not None:
|
||||
self.x_pad_token = nn.Parameter(torch.empty((1, dim), device=device, dtype=dtype))
|
||||
self.cap_pad_token = nn.Parameter(torch.empty((1, dim), device=device, dtype=dtype))
|
||||
self.x_pad_token = nn.Parameter(torch.zeros((1, dim), device=device, dtype=dtype))
|
||||
self.cap_pad_token = nn.Parameter(torch.zeros((1, dim), device=device, dtype=dtype))
|
||||
|
||||
assert (dim // n_heads) == sum(axes_dims)
|
||||
self.axes_dims = axes_dims
|
||||
@@ -630,31 +630,33 @@ class TwinFlowZImageTransformer(nn.Module):
|
||||
torch.arange(W_tokens, dtype=torch.float32, device=device) * w_scale + w_start
|
||||
).view(1, -1).repeat(H_tokens, 1).flatten()
|
||||
|
||||
x_pad_extra = 0
|
||||
if self.pad_tokens_multiple is not None:
|
||||
pad_extra = (-x.shape[1]) % self.pad_tokens_multiple
|
||||
x_pad_extra = (-x.shape[1]) % self.pad_tokens_multiple
|
||||
x = torch.cat(
|
||||
(
|
||||
x,
|
||||
self.x_pad_token.to(device=x.device, dtype=x.dtype, copy=True).unsqueeze(0).repeat(x.shape[0], pad_extra, 1),
|
||||
self.x_pad_token.to(device=x.device, dtype=x.dtype, copy=True).unsqueeze(0).repeat(x.shape[0], x_pad_extra, 1),
|
||||
),
|
||||
dim=1,
|
||||
)
|
||||
x_pos_ids = torch.nn.functional.pad(x_pos_ids, (0, 0, 0, pad_extra))
|
||||
x_pos_ids = torch.nn.functional.pad(x_pos_ids, (0, 0, 0, x_pad_extra))
|
||||
|
||||
cap_pad_extra = 0
|
||||
if self.pad_tokens_multiple is not None:
|
||||
pad_extra = (-cap_feats.shape[1]) % self.pad_tokens_multiple
|
||||
cap_pad_extra = (-cap_feats.shape[1]) % self.pad_tokens_multiple
|
||||
cap_feats = torch.cat(
|
||||
(
|
||||
cap_feats,
|
||||
self.cap_pad_token.to(device=cap_feats.device, dtype=cap_feats.dtype, copy=True)
|
||||
.unsqueeze(0)
|
||||
.repeat(cap_feats.shape[0], pad_extra, 1),
|
||||
.repeat(cap_feats.shape[0], cap_pad_extra, 1),
|
||||
),
|
||||
dim=1,
|
||||
)
|
||||
cap_pos_ids = torch.nn.functional.pad(cap_pos_ids, (0, 0, 0, pad_extra), value=0)
|
||||
if cap_mask is not None and pad_extra > 0:
|
||||
cap_mask = torch.nn.functional.pad(cap_mask, (0, pad_extra), value=0)
|
||||
cap_pos_ids = torch.nn.functional.pad(cap_pos_ids, (0, 0, 0, cap_pad_extra), value=0)
|
||||
if cap_mask is not None and cap_pad_extra > 0:
|
||||
cap_mask = torch.nn.functional.pad(cap_mask, (0, cap_pad_extra), value=0)
|
||||
|
||||
freqs_cis = self.rope_embedder(torch.cat((cap_pos_ids, x_pos_ids), dim=1)).movedim(1, 2)
|
||||
|
||||
@@ -677,7 +679,14 @@ class TwinFlowZImageTransformer(nn.Module):
|
||||
)
|
||||
|
||||
padded_full_embed = torch.cat((cap_feats, x), dim=1)
|
||||
mask = None
|
||||
if cap_mask is not None:
|
||||
cap_mask_bool = cap_mask if cap_mask.dtype == torch.bool else cap_mask > 0
|
||||
img_mask = torch.ones((bsz, x.shape[1]), device=cap_mask.device, dtype=torch.bool)
|
||||
if x_pad_extra > 0:
|
||||
img_mask[:, -x_pad_extra:] = False
|
||||
mask = torch.cat((cap_mask_bool, img_mask), dim=1)
|
||||
else:
|
||||
mask = None
|
||||
img_sizes = [(H, W)] * bsz
|
||||
l_effective_cap_len = [cap_feats.shape[1]] * bsz
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user