Mirror of https://github.com/comfyanonymous/ComfyUI.git, synced 2026-01-11 14:50:49 +08:00
added RMSNorm patch for torch versions older than 2.4
This commit is contained in: parent f9d7fcb696, commit 8966009181
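torch.nn.RMSNorm and torch.nn.functional.rms_norm only ship with PyTorch 2.4 and later, so model code that instantiates nn.RMSNorm raises AttributeError on older builds. For illustration, a minimal sketch of the call pattern the shim below needs to support (shapes and sizes are arbitrary, not taken from the commit):

import torch
import torch.nn as nn

# On torch >= 2.4 this is the built-in class; on older builds it only works
# once the compatibility patch below has been imported.
norm = nn.RMSNorm(64, eps=1e-6)   # normalize over a last dimension of size 64
x = torch.randn(2, 16, 64)        # arbitrary (batch, tokens, channels) example
print(norm(x).shape)              # torch.Size([2, 16, 64])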
@@ -293,6 +293,121 @@ if is_zluda:
torch.istft = z_istft
torch.jit.script = z_jit
# ------------------- End Audio Patch -------------------
# ------------------- RMSNorm Compatibility Patch -------------------
# Fix for PyTorch < 2.4.0 which doesn't have nn.RMSNorm
import torch.nn as nn

if not hasattr(nn, 'RMSNorm'):
    print(" :: PyTorch RMSNorm not found, adding compatibility layer.")

    class RMSNorm(nn.Module):
        def __init__(self, dim, eps=1e-6, elementwise_affine=True, device=None, dtype=None):
            super().__init__()
            self.eps = eps
            self.elementwise_affine = elementwise_affine
            if self.elementwise_affine:
                self.weight = nn.Parameter(torch.ones(dim, device=device, dtype=dtype))
            else:
                self.register_parameter('weight', None)

        def forward(self, x):
            variance = x.pow(2).mean(dim=-1, keepdim=True)
            x = x * torch.rsqrt(variance + self.eps)
            if self.elementwise_affine:
                x = x * self.weight
            return x

    # Monkey patch nn.RMSNorm
    nn.RMSNorm = RMSNorm
    print(" :: RMSNorm compatibility layer installed.")
else:
    print(" :: PyTorch RMSNorm found, no patch needed.")
# ------------------- End RMSNorm Patch -------------------
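For illustration, a quick sanity check of the shim above (a minimal sketch, assuming the patch has already executed on a torch build without nn.RMSNorm): with the default elementwise_affine=True the weight is initialized to ones, so the output should match the plain RMS formula.

import torch
import torch.nn as nn

x = torch.randn(4, 32)
norm = nn.RMSNorm(32, eps=1e-6)   # resolves to the patched class on torch < 2.4
expected = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + 1e-6)
print(torch.allclose(norm(x), expected))   # expected: True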
# ------------------- RMSNorm Compatibility Patch -------------------
# Fix for PyTorch < 2.4.0 which doesn't have nn.RMSNorm
# Based on ComfyUI's actual RMSNorm implementation
import torch.nn as nn
import numbers

if not hasattr(nn, 'RMSNorm'):
    print(" :: PyTorch RMSNorm not found, adding ComfyUI-compatible layer.")

    # Check if torch.nn.functional.rms_norm exists
    rms_norm_torch = None
    try:
        rms_norm_torch = torch.nn.functional.rms_norm
    except AttributeError:
        rms_norm_torch = None

    def rms_norm_fallback(x, weight=None, eps=1e-6):
        """Fallback RMSNorm implementation when native function unavailable"""
        if rms_norm_torch is not None and not (torch.jit.is_tracing() or torch.jit.is_scripting()):
            # Try to import comfy.model_management for proper casting
            try:
                import comfy.model_management
                cast_fn = comfy.model_management.cast_to
            except ImportError:
                # Fallback casting function if comfy not available
                cast_fn = lambda w, dtype, device: w.to(dtype=dtype, device=device) if w is not None else None

            if weight is None:
                return rms_norm_torch(x, (x.shape[-1],), eps=eps)
            else:
                return rms_norm_torch(x, weight.shape, weight=cast_fn(weight, dtype=x.dtype, device=x.device), eps=eps)
        else:
            # Manual implementation
            r = x * torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps)
            if weight is None:
                return r
            else:
                # Try to use comfy's cast function, fallback to simple casting
                try:
                    import comfy.model_management
                    weight_casted = comfy.model_management.cast_to(weight, dtype=x.dtype, device=x.device)
                except ImportError:
                    weight_casted = weight.to(dtype=x.dtype, device=x.device) if weight is not None else None
                return r * weight_casted

    class RMSNorm(nn.Module):
        def __init__(
            self,
            normalized_shape,
            eps=1e-6,
            elementwise_affine=True,
            device=None,
            dtype=None,
        ):
            factory_kwargs = {"device": device, "dtype": dtype}
            super().__init__()

            # Handle both int and tuple normalized_shape (like ComfyUI does)
            if isinstance(normalized_shape, numbers.Integral):
                normalized_shape = (normalized_shape,)
            self.normalized_shape = tuple(normalized_shape)
            self.eps = eps
            self.elementwise_affine = elementwise_affine

            if self.elementwise_affine:
                # Use empty() like ComfyUI, not ones()
                self.weight = nn.Parameter(
                    torch.empty(self.normalized_shape, **factory_kwargs)
                )
                # Initialize like LayerNorm
                nn.init.ones_(self.weight)
            else:
                self.register_parameter("weight", None)

            self.bias = None  # RMSNorm doesn't use bias

        def forward(self, x):
            return rms_norm_fallback(x, self.weight, self.eps)

    # Monkey patch nn.RMSNorm
    nn.RMSNorm = RMSNorm
    print(" :: ComfyUI-compatible RMSNorm layer installed.")
else:
    print(" :: PyTorch RMSNorm found, no patch needed.")
# ------------------- End RMSNorm Patch -------------------
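Where torch.nn.functional.rms_norm is available (torch >= 2.4), rms_norm_fallback should agree with the native op; on older builds the try/except leaves rms_norm_torch as None and only the manual branch runs. For illustration, a rough equivalence check under the first assumption (values and shapes are arbitrary):

import torch

x = torch.randn(2, 8, 128)
w = torch.rand(128)
manual = x * torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6) * w
if hasattr(torch.nn.functional, "rms_norm"):   # only present on torch >= 2.4
    native = torch.nn.functional.rms_norm(x, w.shape, weight=w, eps=1e-6)
    print(torch.allclose(manual, native, atol=1e-5))   # expected: True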
# ------------------- Top-K Fallback Patch -------------------
if is_zluda:

@@ -360,3 +475,4 @@ else:
print(f" :: CUDA device detected: {zluda_device_name or 'None'}")
print("***--------------------------------------------------------***\n")
# ------------------- End Zluda detection -------------------