added RMSNorm patch for torch versions older than 2.4

patientx 2025-09-05 22:43:39 +03:00 committed by GitHub
parent f9d7fcb696
commit 8966009181


@@ -293,6 +293,121 @@ if is_zluda:
    torch.istft = z_istft
    torch.jit.script = z_jit
# ------------------- End Audio Patch -------------------
# ------------------- RMSNorm Compatibility Patch -------------------
# Fix for PyTorch < 2.4.0 which doesn't have nn.RMSNorm
import torch.nn as nn
if not hasattr(nn, 'RMSNorm'):
    print(" :: PyTorch RMSNorm not found, adding compatibility layer.")
    class RMSNorm(nn.Module):
        def __init__(self, dim, eps=1e-6, elementwise_affine=True, device=None, dtype=None):
            super().__init__()
            self.eps = eps
            self.elementwise_affine = elementwise_affine
            if self.elementwise_affine:
                self.weight = nn.Parameter(torch.ones(dim, device=device, dtype=dtype))
            else:
                self.register_parameter('weight', None)
        def forward(self, x):
            variance = x.pow(2).mean(dim=-1, keepdim=True)
            x = x * torch.rsqrt(variance + self.eps)
            if self.elementwise_affine:
                x = x * self.weight
            return x
    # Monkey patch nn.RMSNorm
    nn.RMSNorm = RMSNorm
    print(" :: RMSNorm compatibility layer installed.")
else:
    print(" :: PyTorch RMSNorm found, no patch needed.")
# ------------------- End RMSNorm Patch -------------------
# ------------------- RMSNorm Compatibility Patch -------------------
# Fix for PyTorch < 2.4.0 which doesn't have nn.RMSNorm
# Based on ComfyUI's actual RMSNorm implementation
import torch.nn as nn
import numbers
if not hasattr(nn, 'RMSNorm'):
    print(" :: PyTorch RMSNorm not found, adding ComfyUI-compatible layer.")
    # Check if torch.nn.functional.rms_norm exists
    rms_norm_torch = None
    try:
        rms_norm_torch = torch.nn.functional.rms_norm
    except AttributeError:
        rms_norm_torch = None
    def rms_norm_fallback(x, weight=None, eps=1e-6):
        """RMSNorm helper: uses torch.nn.functional.rms_norm when available, otherwise falls back to a manual implementation."""
        if rms_norm_torch is not None and not (torch.jit.is_tracing() or torch.jit.is_scripting()):
            # Try to import comfy.model_management for proper casting
            try:
                import comfy.model_management
                cast_fn = comfy.model_management.cast_to
            except ImportError:
                # Fallback casting function if comfy not available
                cast_fn = lambda w, dtype, device: w.to(dtype=dtype, device=device) if w is not None else None
            if weight is None:
                return rms_norm_torch(x, (x.shape[-1],), eps=eps)
            else:
                return rms_norm_torch(x, weight.shape, weight=cast_fn(weight, dtype=x.dtype, device=x.device), eps=eps)
        else:
            # Manual implementation
            r = x * torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps)
            if weight is None:
                return r
            else:
                # Try to use comfy's cast function, fall back to simple casting
                try:
                    import comfy.model_management
                    weight_casted = comfy.model_management.cast_to(weight, dtype=x.dtype, device=x.device)
                except ImportError:
                    weight_casted = weight.to(dtype=x.dtype, device=x.device) if weight is not None else None
                return r * weight_casted
    class RMSNorm(nn.Module):
        def __init__(
            self,
            normalized_shape,
            eps=1e-6,
            elementwise_affine=True,
            device=None,
            dtype=None,
        ):
            factory_kwargs = {"device": device, "dtype": dtype}
            super().__init__()
            # Handle both int and tuple normalized_shape (like ComfyUI does)
            if isinstance(normalized_shape, numbers.Integral):
                normalized_shape = (normalized_shape,)
            self.normalized_shape = tuple(normalized_shape)
            self.eps = eps
            self.elementwise_affine = elementwise_affine
            if self.elementwise_affine:
                # Use empty() like ComfyUI, not ones()
                self.weight = nn.Parameter(
                    torch.empty(self.normalized_shape, **factory_kwargs)
                )
                # Initialize like LayerNorm
                nn.init.ones_(self.weight)
            else:
                self.register_parameter("weight", None)
            self.bias = None  # RMSNorm doesn't use bias
        def forward(self, x):
            return rms_norm_fallback(x, self.weight, self.eps)
    # Monkey patch nn.RMSNorm
    nn.RMSNorm = RMSNorm
    print(" :: ComfyUI-compatible RMSNorm layer installed.")
else:
    print(" :: PyTorch RMSNorm found, no patch needed.")
# ------------------- End RMSNorm Patch -------------------
# ------------------- Top-K Fallback Patch -------------------
if is_zluda:
@@ -360,3 +475,4 @@ else:
print(f" :: CUDA device detected: {zluda_device_name or 'None'}")
print("***--------------------------------------------------------***\n")
# ------------------- End Zluda detection -------------------
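
For reference, a minimal usage sketch of the shim on a PyTorch build older than 2.4 (the shapes, tolerance, and the assumption that the patched module above has already been imported are illustrative only; note that because both blocks guard on hasattr(nn, 'RMSNorm'), the first, simpler shim is the one that actually ends up installed on older torch builds):

# Minimal sketch (assumption: the patch file above has already been executed,
# so nn.RMSNorm resolves to the compatibility shim on torch < 2.4).
import torch
import torch.nn as nn

norm = nn.RMSNorm(64, eps=1e-6)   # shim on torch < 2.4, native module on >= 2.4
x = torch.randn(2, 16, 64)
out = norm(x)

# Reference computation: x * rsqrt(mean(x**2, dim=-1) + eps) * weight
ref = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + 1e-6) * norm.weight
print(torch.allclose(out, ref, atol=1e-5))  # expected: True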