Remove code to support RMSNorm on old pytorch. (#12499)
This commit is contained in:
parent 1978f59ffd
commit 4454fab7f0
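In short: the commit deletes the try/except shim that tolerated PyTorch builds without native RMSNorm and wires the code directly to torch.nn.RMSNorm and torch.nn.functional.rms_norm. A minimal sketch of the assumption the tree now makes (the version cutoff is my recollection of when both APIs landed, around PyTorch 2.4; it is not stated in the diff):

# Illustrative check, not part of the commit: this revision assumes a torch
# that ships native RMSNorm support (added around PyTorch 2.4).
import torch

if not hasattr(torch.nn, "RMSNorm") or not hasattr(torch.nn.functional, "rms_norm"):
    raise RuntimeError("Please update pytorch to use native RMSNorm")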
comfy/ops.py
@@ -21,7 +21,6 @@ import logging
 import comfy.model_management
 from comfy.cli_args import args, PerformanceFeature, enables_dynamic_vram
 import comfy.float
-import comfy.rmsnorm
 import json
 import comfy.memory_management
 import comfy.pinned_memory
@@ -463,7 +462,7 @@ class disable_weight_init:
             else:
                 return super().forward(*args, **kwargs)
 
-    class RMSNorm(comfy.rmsnorm.RMSNorm, CastWeightBiasOp):
+    class RMSNorm(torch.nn.RMSNorm, CastWeightBiasOp):
         def reset_parameters(self):
             self.bias = None
             return None
@@ -475,8 +474,7 @@ class disable_weight_init:
                 weight = None
                 bias = None
                 offload_stream = None
-            x = comfy.rmsnorm.rms_norm(input, weight, self.eps) # TODO: switch to commented out line when old torch is deprecated
-            # x = torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps)
+            x = torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps)
             uncast_bias_weight(self, weight, bias, offload_stream)
             return x
 
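The ops.py side swaps the base class of the cast-weight wrapper from the local polyfill to torch.nn.RMSNorm, so forward_comfy_cast_weights reduces to one native call on a weight cast to the input's dtype and device. A minimal sketch of that behavior, with ComfyUI's cast_bias_weight machinery replaced by a plain .to() cast (CastRMSNorm and the test values are mine, not ComfyUI's):

import torch

class CastRMSNorm(torch.nn.RMSNorm):
    # Stand-in for the cast_bias_weight path: cast the stored weight to the
    # input's dtype/device, then call the native functional op.
    def forward(self, input):
        weight = None
        if self.weight is not None:
            weight = self.weight.to(dtype=input.dtype, device=input.device)
        return torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps)

norm = CastRMSNorm(8, eps=1e-6, dtype=torch.float32)
x = torch.randn(2, 8, dtype=torch.float16)
print(norm(x).dtype)  # torch.float16: the op runs in the input's dtype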
comfy/rmsnorm.py
@@ -1,57 +1,10 @@
 import torch
 import comfy.model_management
-import numbers
-import logging
 
-RMSNorm = None
-
-try:
-    rms_norm_torch = torch.nn.functional.rms_norm
-    RMSNorm = torch.nn.RMSNorm
-except:
-    rms_norm_torch = None
-    logging.warning("Please update pytorch to use native RMSNorm")
-
+RMSNorm = torch.nn.RMSNorm
 
 def rms_norm(x, weight=None, eps=1e-6):
-    if rms_norm_torch is not None and not (torch.jit.is_tracing() or torch.jit.is_scripting()):
-        if weight is None:
-            return rms_norm_torch(x, (x.shape[-1],), eps=eps)
-        else:
-            return rms_norm_torch(x, weight.shape, weight=comfy.model_management.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps)
+    if weight is None:
+        return torch.nn.functional.rms_norm(x, (x.shape[-1],), eps=eps)
     else:
-        r = x * torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps)
-        if weight is None:
-            return r
-        else:
-            return r * comfy.model_management.cast_to(weight, dtype=x.dtype, device=x.device)
-
-
-if RMSNorm is None:
-    class RMSNorm(torch.nn.Module):
-        def __init__(
-            self,
-            normalized_shape,
-            eps=1e-6,
-            elementwise_affine=True,
-            device=None,
-            dtype=None,
-        ):
-            factory_kwargs = {"device": device, "dtype": dtype}
-            super().__init__()
-            if isinstance(normalized_shape, numbers.Integral):
-                # mypy error: incompatible types in assignment
-                normalized_shape = (normalized_shape,)  # type: ignore[assignment]
-            self.normalized_shape = tuple(normalized_shape)  # type: ignore[arg-type]
-            self.eps = eps
-            self.elementwise_affine = elementwise_affine
-            if self.elementwise_affine:
-                self.weight = torch.nn.Parameter(
-                    torch.empty(self.normalized_shape, **factory_kwargs)
-                )
-            else:
-                self.register_parameter("weight", None)
-            self.bias = None
-
-        def forward(self, x):
-            return rms_norm(x, self.weight, self.eps)
+        return torch.nn.functional.rms_norm(x, weight.shape, weight=comfy.model_management.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps)
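The deleted fallback computed RMSNorm by hand with rsqrt of the mean square. A quick sanity check (illustrative, not from the commit) that the native op the file now calls unconditionally agrees with that manual formula:

import torch

x = torch.randn(4, 16)
w = torch.randn(16)
eps = 1e-6

native = torch.nn.functional.rms_norm(x, w.shape, weight=w, eps=eps)
manual = x * torch.rsqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + eps) * w

print(torch.allclose(native, manual, atol=1e-6))  # expected: True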