From 5f1346ccd107b48a5ef15f30b4525229d6446830 Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Mon, 24 Nov 2025 22:47:54 +0200 Subject: [PATCH] Fix fp8 --- comfy/ldm/kandinsky5/model.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/comfy/ldm/kandinsky5/model.py b/comfy/ldm/kandinsky5/model.py index 064591aa3..6436dd304 100644 --- a/comfy/ldm/kandinsky5/model.py +++ b/comfy/ldm/kandinsky5/model.py @@ -30,15 +30,14 @@ class TimeEmbeddings(nn.Module): self.model_dim = model_dim self.max_period = max_period self.register_buffer("freqs", get_freqs(model_dim // 2, max_period), persistent=False) - self.dtype = operation_settings.get("dtype") operations = operation_settings.get("operations") - self.in_layer = operations.Linear(model_dim, time_dim, bias=True, device=operation_settings.get("device"), dtype=self.dtype) + self.in_layer = operations.Linear(model_dim, time_dim, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) self.activation = nn.SiLU() - self.out_layer = operations.Linear(time_dim, time_dim, bias=True, device=operation_settings.get("device"), dtype=self.dtype) + self.out_layer = operations.Linear(time_dim, time_dim, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) def forward(self, timestep): args = torch.outer(timestep, self.freqs.to(device=timestep.device)) - time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1).to(self.dtype) + time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1).to(torch.bfloat16) #todo dtype time_embed = self.out_layer(self.activation(self.in_layer(time_embed))) return time_embed