diff --git a/comfy/ldm/pixeldit/pid.py b/comfy/ldm/pixeldit/pid.py index ceb601647..f95698d68 100644 --- a/comfy/ldm/pixeldit/pid.py +++ b/comfy/ldm/pixeldit/pid.py @@ -251,7 +251,7 @@ class PidNet(PixDiT_T2I): pos_img = self._fetch_patch_pos(Hs, Ws, x.device, x.dtype) x_patches = F.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2) - t_emb = self.t_embedder(timesteps.view(-1)).view(B, -1, self.hidden_size) + t_emb = self.t_embedder(timesteps.view(-1), x.dtype).view(B, -1, self.hidden_size) Ltxt = min(context.shape[1], self.txt_max_length) y = context[:, :Ltxt, :]