rmsnorm patch second try

parent 8966009181
commit 0488fe3748
@@ -295,31 +295,86 @@ if is_zluda:
     # ------------------- End Audio Patch -------------------

     # ------------------- RMSNorm Compatibility Patch -------------------
     # Fix for PyTorch < 2.4.0 which doesn't have nn.RMSNorm
+    # Based on ComfyUI's actual RMSNorm implementation
     import torch.nn as nn
+    import numbers

     if not hasattr(nn, 'RMSNorm'):
-        print(" :: PyTorch RMSNorm not found, adding compatibility layer.")
+        print(" :: PyTorch RMSNorm not found, adding ComfyUI-compatible layer.")
+
+        # Check if torch.nn.functional.rms_norm exists
+        rms_norm_torch = None
+        try:
+            rms_norm_torch = torch.nn.functional.rms_norm
+        except AttributeError:
+            rms_norm_torch = None
+
+        def rms_norm_fallback(x, weight=None, eps=1e-6):
+            """Fallback RMSNorm implementation when native function unavailable"""
+            if rms_norm_torch is not None and not (torch.jit.is_tracing() or torch.jit.is_scripting()):
+                # Try to import comfy.model_management for proper casting
+                try:
+                    import comfy.model_management
+                    cast_fn = comfy.model_management.cast_to
+                except ImportError:
+                    # Fallback casting function if comfy not available
+                    cast_fn = lambda w, dtype, device: w.to(dtype=dtype, device=device) if w is not None else None
+
+                if weight is None:
+                    return rms_norm_torch(x, (x.shape[-1],), eps=eps)
+                else:
+                    return rms_norm_torch(x, weight.shape, weight=cast_fn(weight, dtype=x.dtype, device=x.device), eps=eps)
+            else:
+                # Manual implementation
+                r = x * torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps)
+                if weight is None:
+                    return r
+                else:
+                    # Try to use comfy's cast function, fallback to simple casting
+                    try:
+                        import comfy.model_management
+                        weight_casted = comfy.model_management.cast_to(weight, dtype=x.dtype, device=x.device)
+                    except ImportError:
+                        weight_casted = weight.to(dtype=x.dtype, device=x.device) if weight is not None else None
+                    return r * weight_casted

         class RMSNorm(nn.Module):
-            def __init__(self, dim, eps=1e-6, elementwise_affine=True, device=None, dtype=None):
+            def __init__(
+                self,
+                normalized_shape,
+                eps=1e-6,
+                elementwise_affine=True,
+                device=None,
+                dtype=None,
+            ):
+                factory_kwargs = {"device": device, "dtype": dtype}
                 super().__init__()
+
+                # Handle both int and tuple normalized_shape (like ComfyUI does)
+                if isinstance(normalized_shape, numbers.Integral):
+                    normalized_shape = (normalized_shape,)
+                self.normalized_shape = tuple(normalized_shape)
                 self.eps = eps
                 self.elementwise_affine = elementwise_affine

                 if self.elementwise_affine:
-                    self.weight = nn.Parameter(torch.ones(dim, device=device, dtype=dtype))
+                    # Use empty() like ComfyUI, not ones()
+                    self.weight = nn.Parameter(
+                        torch.empty(self.normalized_shape, **factory_kwargs)
+                    )
+                    # Initialize like LayerNorm
+                    nn.init.ones_(self.weight)
                 else:
-                    self.register_parameter('weight', None)
+                    self.register_parameter("weight", None)
+
+                self.bias = None  # RMSNorm doesn't use bias

             def forward(self, x):
-                variance = x.pow(2).mean(dim=-1, keepdim=True)
-                x = x * torch.rsqrt(variance + self.eps)
-                if self.elementwise_affine:
-                    x = x * self.weight
-                return x
+                return rms_norm_fallback(x, self.weight, self.eps)

         # Monkey patch nn.RMSNorm
         nn.RMSNorm = RMSNorm
-        print(" :: RMSNorm compatibility layer installed.")
+        print(" :: ComfyUI-compatible RMSNorm layer installed.")
     else:
         print(" :: PyTorch RMSNorm found, no patch needed.")
     # ------------------- End RMSNorm Patch -------------------
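For context, a minimal smoke test of the compatibility layer above (an illustration, not part of the commit): it assumes the patched startup file has already run, so nn.RMSNorm resolves either to the native module on torch >= 2.4.0 or to the shim on older versions. The tensor shapes and tolerance below are arbitrary.

    import torch
    import torch.nn as nn

    # Either the built-in nn.RMSNorm or the monkey-patched class from the hunk above.
    norm = nn.RMSNorm(8, eps=1e-6)

    x = torch.randn(2, 4, 8)
    out = norm(x)

    # Reference RMSNorm math; the weight is initialised to ones, so it is a no-op here.
    ref = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + 1e-6)

    assert out.shape == x.shape
    assert torch.allclose(out, ref, atol=1e-5)
    print("RMSNorm smoke test passed")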
@@ -476,3 +531,4 @@ else:
     print("***--------------------------------------------------------***\n")
 # ------------------- End Zluda detection -------------------

+