From 87256acf20ef04b1a8206fce7c02f60c72ea0615 Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Thu, 20 Nov 2025 22:38:47 +0200 Subject: [PATCH] Fix TokenRefiner for fp16 Otherwise x.sum has infs, just in case only casting if input is fp16, I don't know if necessary. --- comfy/ldm/hunyuan_video/model.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/comfy/ldm/hunyuan_video/model.py b/comfy/ldm/hunyuan_video/model.py index 68ba27f5e..f75c6e0e1 100644 --- a/comfy/ldm/hunyuan_video/model.py +++ b/comfy/ldm/hunyuan_video/model.py @@ -158,7 +158,10 @@ class TokenRefiner(nn.Module): t = self.t_embedder(timestep_embedding(timesteps, 256, time_factor=1.0).to(x.dtype)) # m = mask.float().unsqueeze(-1) # c = (x.float() * m).sum(dim=1) / m.sum(dim=1) #TODO: the following works when the x.shape is the same length as the tokens but might break otherwise - c = x.sum(dim=1) / x.shape[1] + if x.dtype == torch.float16: + c = x.float().sum(dim=1) / x.shape[1] + else: + c = x.sum(dim=1) / x.shape[1] c = t + self.c_embedder(c.to(x.dtype)) x = self.input_embedder(x)