diff --git a/comfy/ldm/kandinsky5/model.py b/comfy/ldm/kandinsky5/model.py
index 064591aa3..6436dd304 100644
--- a/comfy/ldm/kandinsky5/model.py
+++ b/comfy/ldm/kandinsky5/model.py
@@ -30,15 +30,14 @@ class TimeEmbeddings(nn.Module):
         self.model_dim = model_dim
         self.max_period = max_period
         self.register_buffer("freqs", get_freqs(model_dim // 2, max_period), persistent=False)
-        self.dtype = operation_settings.get("dtype")
         operations = operation_settings.get("operations")
-        self.in_layer = operations.Linear(model_dim, time_dim, bias=True, device=operation_settings.get("device"), dtype=self.dtype)
+        self.in_layer = operations.Linear(model_dim, time_dim, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
         self.activation = nn.SiLU()
-        self.out_layer = operations.Linear(time_dim, time_dim, bias=True, device=operation_settings.get("device"), dtype=self.dtype)
+        self.out_layer = operations.Linear(time_dim, time_dim, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
 
     def forward(self, timestep):
         args = torch.outer(timestep, self.freqs.to(device=timestep.device))
-        time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1).to(self.dtype)
+        time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1).to(self.in_layer.weight.dtype)  # cast to the layers' construction dtype, not hard-coded bfloat16
         time_embed = self.out_layer(self.activation(self.in_layer(time_embed)))
         return time_embed
 