mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-14 20:42:31 +08:00
removed unnecessary vae float32 upcast
This commit is contained in:
parent
2cb06431e8
commit
0ebeac98a7
@ -2,7 +2,6 @@ import math
|
||||
import torch
|
||||
import numpy as np
|
||||
import torch.nn as nn
|
||||
import comfy.model_management
|
||||
import torch.nn.functional as F
|
||||
from fractions import Fraction
|
||||
from dataclasses import dataclass
|
||||
@ -78,7 +77,10 @@ class LayerNorm32(nn.LayerNorm):
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x_dtype = x.dtype
|
||||
x = x.to(torch.float32)
|
||||
o = super().forward(x)
|
||||
w = self.weight.to(torch.float32)
|
||||
b = self.bias.to(torch.float32) if self.bias is not None else None
|
||||
|
||||
o = F.layer_norm(x, self.normalized_shape, w, b, self.eps)
|
||||
return o.to(x_dtype)
|
||||
|
||||
class SparseConvNeXtBlock3d(nn.Module):
|
||||
@ -102,8 +104,7 @@ class SparseConvNeXtBlock3d(nn.Module):
|
||||
|
||||
def _forward(self, x):
|
||||
h = self.conv(x)
|
||||
norm = self.norm.to(torch.float32)
|
||||
h = h.replace(norm(h.feats))
|
||||
h = h.replace(self.norm(h.feats))
|
||||
h = h.replace(self.mlp(h.feats))
|
||||
return h + x
|
||||
|
||||
@ -213,15 +214,13 @@ class SparseResBlockC2S3d(nn.Module):
|
||||
dtype = next(self.to_subdiv.parameters()).dtype
|
||||
x = x.to(dtype)
|
||||
subdiv = self.to_subdiv(x)
|
||||
norm1 = self.norm1.to(torch.float32)
|
||||
norm2 = self.norm2.to(torch.float32)
|
||||
h = x.replace(norm1(x.feats))
|
||||
h = x.replace(self.norm1(x.feats))
|
||||
h = h.replace(F.silu(h.feats))
|
||||
h = self.conv1(h)
|
||||
subdiv_binarized = subdiv.replace(subdiv.feats > 0) if subdiv is not None else None
|
||||
h = self.updown(h, subdiv_binarized)
|
||||
x = self.updown(x, subdiv_binarized)
|
||||
h = h.replace(norm2(h.feats))
|
||||
h = h.replace(self.norm2(h.feats))
|
||||
h = h.replace(F.silu(h.feats))
|
||||
h = self.conv2(h)
|
||||
h = h + self.skip_connection(x)
|
||||
@ -1300,8 +1299,6 @@ class ResBlock3d(nn.Module):
|
||||
self.skip_connection = nn.Conv3d(channels, self.out_channels, 1) if channels != self.out_channels else nn.Identity()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
self.norm1 = self.norm1.to(torch.float32)
|
||||
self.norm2 = self.norm2.to(torch.float32)
|
||||
h = self.norm1(x)
|
||||
h = F.silu(h)
|
||||
dtype = next(self.conv1.parameters()).dtype
|
||||
@ -1381,8 +1378,7 @@ class SparseStructureDecoder(nn.Module):
|
||||
for block in self.blocks:
|
||||
h = block(h)
|
||||
|
||||
h = h.to(torch.float32)
|
||||
self.out_layer = self.out_layer.to(torch.float32)
|
||||
h = h.type(x.dtype)
|
||||
h = self.out_layer(h)
|
||||
return h
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user