From 0ebeac98a78885d4c13c09691e027fb141d9e581 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Wed, 8 Apr 2026 19:08:26 +0200 Subject: [PATCH] removed unnecessary vae float32 upcast --- comfy/ldm/trellis2/vae.py | 20 ++++++++------------ 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/comfy/ldm/trellis2/vae.py b/comfy/ldm/trellis2/vae.py index 2a18c496a..c42ad8d2f 100644 --- a/comfy/ldm/trellis2/vae.py +++ b/comfy/ldm/trellis2/vae.py @@ -2,7 +2,6 @@ import math import torch import numpy as np import torch.nn as nn -import comfy.model_management import torch.nn.functional as F from fractions import Fraction from dataclasses import dataclass @@ -78,7 +77,10 @@ class LayerNorm32(nn.LayerNorm): def forward(self, x: torch.Tensor) -> torch.Tensor: x_dtype = x.dtype x = x.to(torch.float32) - o = super().forward(x) + w = self.weight.to(torch.float32) + b = self.bias.to(torch.float32) if self.bias is not None else None + + o = F.layer_norm(x, self.normalized_shape, w, b, self.eps) return o.to(x_dtype) class SparseConvNeXtBlock3d(nn.Module): @@ -102,8 +104,7 @@ class SparseConvNeXtBlock3d(nn.Module): def _forward(self, x): h = self.conv(x) - norm = self.norm.to(torch.float32) - h = h.replace(norm(h.feats)) + h = h.replace(self.norm(h.feats)) h = h.replace(self.mlp(h.feats)) return h + x @@ -213,15 +214,13 @@ class SparseResBlockC2S3d(nn.Module): dtype = next(self.to_subdiv.parameters()).dtype x = x.to(dtype) subdiv = self.to_subdiv(x) - norm1 = self.norm1.to(torch.float32) - norm2 = self.norm2.to(torch.float32) - h = x.replace(norm1(x.feats)) + h = x.replace(self.norm1(x.feats)) h = h.replace(F.silu(h.feats)) h = self.conv1(h) subdiv_binarized = subdiv.replace(subdiv.feats > 0) if subdiv is not None else None h = self.updown(h, subdiv_binarized) x = self.updown(x, subdiv_binarized) - h = h.replace(norm2(h.feats)) + h = h.replace(self.norm2(h.feats)) h = h.replace(F.silu(h.feats)) h = self.conv2(h) h = h + self.skip_connection(x) @@ -1300,8 +1299,6 @@ class ResBlock3d(nn.Module): self.skip_connection = nn.Conv3d(channels, self.out_channels, 1) if channels != self.out_channels else nn.Identity() def forward(self, x: torch.Tensor) -> torch.Tensor: - self.norm1 = self.norm1.to(torch.float32) - self.norm2 = self.norm2.to(torch.float32) h = self.norm1(x) h = F.silu(h) dtype = next(self.conv1.parameters()).dtype @@ -1381,8 +1378,7 @@ class SparseStructureDecoder(nn.Module): for block in self.blocks: h = block(h) - h = h.to(torch.float32) - self.out_layer = self.out_layer.to(torch.float32) + h = h.type(x.dtype) h = self.out_layer(h) return h