removed unnecessary vae float32 upcast

This commit is contained in:
Yousef Rafat 2026-04-08 19:08:26 +02:00
parent 2cb06431e8
commit 0ebeac98a7

View File

@ -2,7 +2,6 @@ import math
import torch import torch
import numpy as np import numpy as np
import torch.nn as nn import torch.nn as nn
import comfy.model_management
import torch.nn.functional as F import torch.nn.functional as F
from fractions import Fraction from fractions import Fraction
from dataclasses import dataclass from dataclasses import dataclass
@ -78,7 +77,10 @@ class LayerNorm32(nn.LayerNorm):
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
x_dtype = x.dtype x_dtype = x.dtype
x = x.to(torch.float32) x = x.to(torch.float32)
o = super().forward(x) w = self.weight.to(torch.float32)
b = self.bias.to(torch.float32) if self.bias is not None else None
o = F.layer_norm(x, self.normalized_shape, w, b, self.eps)
return o.to(x_dtype) return o.to(x_dtype)
class SparseConvNeXtBlock3d(nn.Module): class SparseConvNeXtBlock3d(nn.Module):
@ -102,8 +104,7 @@ class SparseConvNeXtBlock3d(nn.Module):
def _forward(self, x): def _forward(self, x):
h = self.conv(x) h = self.conv(x)
norm = self.norm.to(torch.float32) h = h.replace(self.norm(h.feats))
h = h.replace(norm(h.feats))
h = h.replace(self.mlp(h.feats)) h = h.replace(self.mlp(h.feats))
return h + x return h + x
@ -213,15 +214,13 @@ class SparseResBlockC2S3d(nn.Module):
dtype = next(self.to_subdiv.parameters()).dtype dtype = next(self.to_subdiv.parameters()).dtype
x = x.to(dtype) x = x.to(dtype)
subdiv = self.to_subdiv(x) subdiv = self.to_subdiv(x)
norm1 = self.norm1.to(torch.float32) h = x.replace(self.norm1(x.feats))
norm2 = self.norm2.to(torch.float32)
h = x.replace(norm1(x.feats))
h = h.replace(F.silu(h.feats)) h = h.replace(F.silu(h.feats))
h = self.conv1(h) h = self.conv1(h)
subdiv_binarized = subdiv.replace(subdiv.feats > 0) if subdiv is not None else None subdiv_binarized = subdiv.replace(subdiv.feats > 0) if subdiv is not None else None
h = self.updown(h, subdiv_binarized) h = self.updown(h, subdiv_binarized)
x = self.updown(x, subdiv_binarized) x = self.updown(x, subdiv_binarized)
h = h.replace(norm2(h.feats)) h = h.replace(self.norm2(h.feats))
h = h.replace(F.silu(h.feats)) h = h.replace(F.silu(h.feats))
h = self.conv2(h) h = self.conv2(h)
h = h + self.skip_connection(x) h = h + self.skip_connection(x)
@ -1300,8 +1299,6 @@ class ResBlock3d(nn.Module):
self.skip_connection = nn.Conv3d(channels, self.out_channels, 1) if channels != self.out_channels else nn.Identity() self.skip_connection = nn.Conv3d(channels, self.out_channels, 1) if channels != self.out_channels else nn.Identity()
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
self.norm1 = self.norm1.to(torch.float32)
self.norm2 = self.norm2.to(torch.float32)
h = self.norm1(x) h = self.norm1(x)
h = F.silu(h) h = F.silu(h)
dtype = next(self.conv1.parameters()).dtype dtype = next(self.conv1.parameters()).dtype
@ -1381,8 +1378,7 @@ class SparseStructureDecoder(nn.Module):
for block in self.blocks: for block in self.blocks:
h = block(h) h = block(h)
h = h.to(torch.float32) h = h.type(x.dtype)
self.out_layer = self.out_layer.to(torch.float32)
h = self.out_layer(h) h = self.out_layer(h)
return h return h