Removed unnecessary VAE float32 upcast

This commit is contained in:
Yousef Rafat 2026-04-08 19:08:26 +02:00
parent 2cb06431e8
commit 0ebeac98a7

View File

@ -2,7 +2,6 @@ import math
import torch
import numpy as np
import torch.nn as nn
import comfy.model_management
import torch.nn.functional as F
from fractions import Fraction
from dataclasses import dataclass
@ -78,7 +77,10 @@ class LayerNorm32(nn.LayerNorm):
def forward(self, x: torch.Tensor) -> torch.Tensor:
x_dtype = x.dtype
x = x.to(torch.float32)
o = super().forward(x)
w = self.weight.to(torch.float32)
b = self.bias.to(torch.float32) if self.bias is not None else None
o = F.layer_norm(x, self.normalized_shape, w, b, self.eps)
return o.to(x_dtype)
class SparseConvNeXtBlock3d(nn.Module):
@ -102,8 +104,7 @@ class SparseConvNeXtBlock3d(nn.Module):
def _forward(self, x):
h = self.conv(x)
norm = self.norm.to(torch.float32)
h = h.replace(norm(h.feats))
h = h.replace(self.norm(h.feats))
h = h.replace(self.mlp(h.feats))
return h + x
@ -213,15 +214,13 @@ class SparseResBlockC2S3d(nn.Module):
dtype = next(self.to_subdiv.parameters()).dtype
x = x.to(dtype)
subdiv = self.to_subdiv(x)
norm1 = self.norm1.to(torch.float32)
norm2 = self.norm2.to(torch.float32)
h = x.replace(norm1(x.feats))
h = x.replace(self.norm1(x.feats))
h = h.replace(F.silu(h.feats))
h = self.conv1(h)
subdiv_binarized = subdiv.replace(subdiv.feats > 0) if subdiv is not None else None
h = self.updown(h, subdiv_binarized)
x = self.updown(x, subdiv_binarized)
h = h.replace(norm2(h.feats))
h = h.replace(self.norm2(h.feats))
h = h.replace(F.silu(h.feats))
h = self.conv2(h)
h = h + self.skip_connection(x)
@ -1300,8 +1299,6 @@ class ResBlock3d(nn.Module):
self.skip_connection = nn.Conv3d(channels, self.out_channels, 1) if channels != self.out_channels else nn.Identity()
def forward(self, x: torch.Tensor) -> torch.Tensor:
self.norm1 = self.norm1.to(torch.float32)
self.norm2 = self.norm2.to(torch.float32)
h = self.norm1(x)
h = F.silu(h)
dtype = next(self.conv1.parameters()).dtype
@ -1381,8 +1378,7 @@ class SparseStructureDecoder(nn.Module):
for block in self.blocks:
h = block(h)
h = h.to(torch.float32)
self.out_layer = self.out_layer.to(torch.float32)
h = h.type(x.dtype)
h = self.out_layer(h)
return h