Fix TokenRefiner for fp16

Otherwise x.sum produces infs. Just in case, the cast is only applied when the input is fp16; I don't know if that restriction is necessary.
kijai 2025-11-20 22:38:47 +02:00 committed by comfyanonymous
parent b1652790d2
commit 87256acf20


@@ -158,6 +158,9 @@ class TokenRefiner(nn.Module):
         t = self.t_embedder(timestep_embedding(timesteps, 256, time_factor=1.0).to(x.dtype))
         # m = mask.float().unsqueeze(-1)
         # c = (x.float() * m).sum(dim=1) / m.sum(dim=1) #TODO: the following works when the x.shape is the same length as the tokens but might break otherwise
-        c = x.sum(dim=1) / x.shape[1]
+        if x.dtype == torch.float16:
+            c = x.float().sum(dim=1) / x.shape[1]
+        else:
+            c = x.sum(dim=1) / x.shape[1]
         c = t + self.c_embedder(c.to(x.dtype))
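For context, a minimal sketch (not part of the commit) of the overflow the cast avoids: fp16 tops out around 65504, so summing over the token dimension can hit inf before the division even happens. The tensor shape below is hypothetical.

import torch

# hypothetical fp16 activations: batch 1, 4096 tokens, 1024 channels
x = torch.full((1, 4096, 1024), 32.0, dtype=torch.float16)

bad = x.sum(dim=1) / x.shape[1]            # 4096 * 32 = 131072 > fp16 max -> inf
good = x.float().sum(dim=1) / x.shape[1]   # accumulate in fp32, stays finite

print(bad.isinf().any().item())   # True
print(good.isinf().any().item())  # False

The mean itself (32.0) fits in fp16 fine; it is only the intermediate sum that overflows, which is why casting before the reduction is enough.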