diff --git a/comfy/ldm/twinflow/model.py b/comfy/ldm/twinflow/model.py
index 4b21a1f84..90e511c92 100644
--- a/comfy/ldm/twinflow/model.py
+++ b/comfy/ldm/twinflow/model.py
@@ -523,8 +523,8 @@ class TwinFlowZImageTransformer(nn.Module):
             )
 
         if self.pad_tokens_multiple is not None:
-            self.x_pad_token = nn.Parameter(torch.empty((1, dim), device=device, dtype=dtype))
-            self.cap_pad_token = nn.Parameter(torch.empty((1, dim), device=device, dtype=dtype))
+            self.x_pad_token = nn.Parameter(torch.zeros((1, dim), device=device, dtype=dtype))
+            self.cap_pad_token = nn.Parameter(torch.zeros((1, dim), device=device, dtype=dtype))
 
         assert (dim // n_heads) == sum(axes_dims)
         self.axes_dims = axes_dims
@@ -630,31 +630,33 @@ class TwinFlowZImageTransformer(nn.Module):
             torch.arange(W_tokens, dtype=torch.float32, device=device) * w_scale + w_start
         ).view(1, -1).repeat(H_tokens, 1).flatten()
 
+        x_pad_extra = 0
         if self.pad_tokens_multiple is not None:
-            pad_extra = (-x.shape[1]) % self.pad_tokens_multiple
+            x_pad_extra = (-x.shape[1]) % self.pad_tokens_multiple
             x = torch.cat(
                 (
                     x,
-                    self.x_pad_token.to(device=x.device, dtype=x.dtype, copy=True).unsqueeze(0).repeat(x.shape[0], pad_extra, 1),
+                    self.x_pad_token.to(device=x.device, dtype=x.dtype, copy=True).unsqueeze(0).repeat(x.shape[0], x_pad_extra, 1),
                 ),
                 dim=1,
             )
-            x_pos_ids = torch.nn.functional.pad(x_pos_ids, (0, 0, 0, pad_extra))
+            x_pos_ids = torch.nn.functional.pad(x_pos_ids, (0, 0, 0, x_pad_extra))
 
+        cap_pad_extra = 0
         if self.pad_tokens_multiple is not None:
-            pad_extra = (-cap_feats.shape[1]) % self.pad_tokens_multiple
+            cap_pad_extra = (-cap_feats.shape[1]) % self.pad_tokens_multiple
             cap_feats = torch.cat(
                 (
                     cap_feats,
                     self.cap_pad_token.to(device=cap_feats.device, dtype=cap_feats.dtype, copy=True)
                     .unsqueeze(0)
-                    .repeat(cap_feats.shape[0], pad_extra, 1),
+                    .repeat(cap_feats.shape[0], cap_pad_extra, 1),
                 ),
                 dim=1,
             )
-            cap_pos_ids = torch.nn.functional.pad(cap_pos_ids, (0, 0, 0, pad_extra), value=0)
-            if cap_mask is not None and pad_extra > 0:
-                cap_mask = torch.nn.functional.pad(cap_mask, (0, pad_extra), value=0)
+            cap_pos_ids = torch.nn.functional.pad(cap_pos_ids, (0, 0, 0, cap_pad_extra), value=0)
+            if cap_mask is not None and cap_pad_extra > 0:
+                cap_mask = torch.nn.functional.pad(cap_mask, (0, cap_pad_extra), value=0)
 
         freqs_cis = self.rope_embedder(torch.cat((cap_pos_ids, x_pos_ids), dim=1)).movedim(1, 2)
 
@@ -677,7 +679,14 @@ class TwinFlowZImageTransformer(nn.Module):
             )
 
         padded_full_embed = torch.cat((cap_feats, x), dim=1)
-        mask = None
+        if cap_mask is not None:
+            cap_mask_bool = cap_mask if cap_mask.dtype == torch.bool else cap_mask > 0
+            img_mask = torch.ones((bsz, x.shape[1]), device=cap_mask.device, dtype=torch.bool)
+            if x_pad_extra > 0:
+                img_mask[:, -x_pad_extra:] = False
+            mask = torch.cat((cap_mask_bool, img_mask), dim=1)
+        else:
+            mask = None
 
         img_sizes = [(H, W)] * bsz
         l_effective_cap_len = [cap_feats.shape[1]] * bsz