Fix TokenRefiner for fp16
Otherwise x.sum produces infs. Just in case, the cast is only applied when the input is fp16; I don't know if that restriction is necessary.
parent b1652790d2
commit 87256acf20
@@ -158,7 +158,10 @@ class TokenRefiner(nn.Module):
         t = self.t_embedder(timestep_embedding(timesteps, 256, time_factor=1.0).to(x.dtype))
         # m = mask.float().unsqueeze(-1)
         # c = (x.float() * m).sum(dim=1) / m.sum(dim=1) #TODO: the following works when the x.shape is the same length as the tokens but might break otherwise
-        c = x.sum(dim=1) / x.shape[1]
+        if x.dtype == torch.float16:
+            c = x.float().sum(dim=1) / x.shape[1]
+        else:
+            c = x.sum(dim=1) / x.shape[1]

         c = t + self.c_embedder(c.to(x.dtype))
         x = self.input_embedder(x)
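
Why the cast matters: fp16 cannot represent values above ~65504, so summing over a long token dimension can overflow to inf before the divide ever happens. A minimal standalone sketch of the failure mode (hypothetical shapes and values for illustration only, not ComfyUI code):

import torch

# batch, tokens, channels -- values chosen so the fp16 sum overflows
x = torch.full((1, 4096, 64), 32.0, dtype=torch.float16)

# exact sum is 4096 * 32 = 131072 > 65504, so the fp16 result is inf,
# and inf / 4096 is still inf
c_fp16 = x.sum(dim=1) / x.shape[1]

# accumulating in float32 keeps the sum representable and the mean finite
c_fp32 = x.float().sum(dim=1) / x.shape[1]

print(torch.isinf(c_fp16).any())  # tensor(True)
print(torch.isinf(c_fp32).any())  # tensor(False)

This is why the patch branches on x.dtype: upcasting to float32 only for fp16 inputs avoids the overflow while leaving fp32/bf16 paths untouched.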