mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-21 11:50:16 +08:00
fix
This commit is contained in:
parent
0488fe3748
commit
3ca065a755
@ -293,91 +293,7 @@ if is_zluda:
|
|||||||
torch.istft = z_istft
|
torch.istft = z_istft
|
||||||
torch.jit.script = z_jit
|
torch.jit.script = z_jit
|
||||||
# ------------------- End Audio Patch -------------------
|
# ------------------- End Audio Patch -------------------
|
||||||
# ------------------- RMSNorm Compatibility Patch -------------------
|
|
||||||
# Fix for PyTorch < 2.4.0 which doesn't have nn.RMSNorm
|
|
||||||
# Based on ComfyUI's actual RMSNorm implementation
|
|
||||||
import torch.nn as nn
|
|
||||||
import numbers
|
|
||||||
|
|
||||||
if not hasattr(nn, 'RMSNorm'):
|
|
||||||
print(" :: PyTorch RMSNorm not found, adding ComfyUI-compatible layer.")
|
|
||||||
|
|
||||||
# Check if torch.nn.functional.rms_norm exists
|
|
||||||
rms_norm_torch = None
|
|
||||||
try:
|
|
||||||
rms_norm_torch = torch.nn.functional.rms_norm
|
|
||||||
except AttributeError:
|
|
||||||
rms_norm_torch = None
|
|
||||||
|
|
||||||
def rms_norm_fallback(x, weight=None, eps=1e-6):
|
|
||||||
"""Fallback RMSNorm implementation when native function unavailable"""
|
|
||||||
if rms_norm_torch is not None and not (torch.jit.is_tracing() or torch.jit.is_scripting()):
|
|
||||||
# Try to import comfy.model_management for proper casting
|
|
||||||
try:
|
|
||||||
import comfy.model_management
|
|
||||||
cast_fn = comfy.model_management.cast_to
|
|
||||||
except ImportError:
|
|
||||||
# Fallback casting function if comfy not available
|
|
||||||
cast_fn = lambda w, dtype, device: w.to(dtype=dtype, device=device) if w is not None else None
|
|
||||||
|
|
||||||
if weight is None:
|
|
||||||
return rms_norm_torch(x, (x.shape[-1],), eps=eps)
|
|
||||||
else:
|
|
||||||
return rms_norm_torch(x, weight.shape, weight=cast_fn(weight, dtype=x.dtype, device=x.device), eps=eps)
|
|
||||||
else:
|
|
||||||
# Manual implementation
|
|
||||||
r = x * torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps)
|
|
||||||
if weight is None:
|
|
||||||
return r
|
|
||||||
else:
|
|
||||||
# Try to use comfy's cast function, fallback to simple casting
|
|
||||||
try:
|
|
||||||
import comfy.model_management
|
|
||||||
weight_casted = comfy.model_management.cast_to(weight, dtype=x.dtype, device=x.device)
|
|
||||||
except ImportError:
|
|
||||||
weight_casted = weight.to(dtype=x.dtype, device=x.device) if weight is not None else None
|
|
||||||
return r * weight_casted
|
|
||||||
|
|
||||||
class RMSNorm(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
normalized_shape,
|
|
||||||
eps=1e-6,
|
|
||||||
elementwise_affine=True,
|
|
||||||
device=None,
|
|
||||||
dtype=None,
|
|
||||||
):
|
|
||||||
factory_kwargs = {"device": device, "dtype": dtype}
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
# Handle both int and tuple normalized_shape (like ComfyUI does)
|
|
||||||
if isinstance(normalized_shape, numbers.Integral):
|
|
||||||
normalized_shape = (normalized_shape,)
|
|
||||||
self.normalized_shape = tuple(normalized_shape)
|
|
||||||
self.eps = eps
|
|
||||||
self.elementwise_affine = elementwise_affine
|
|
||||||
|
|
||||||
if self.elementwise_affine:
|
|
||||||
# Use empty() like ComfyUI, not ones()
|
|
||||||
self.weight = nn.Parameter(
|
|
||||||
torch.empty(self.normalized_shape, **factory_kwargs)
|
|
||||||
)
|
|
||||||
# Initialize like LayerNorm
|
|
||||||
nn.init.ones_(self.weight)
|
|
||||||
else:
|
|
||||||
self.register_parameter("weight", None)
|
|
||||||
|
|
||||||
self.bias = None # RMSNorm doesn't use bias
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
return rms_norm_fallback(x, self.weight, self.eps)
|
|
||||||
|
|
||||||
# Monkey patch nn.RMSNorm
|
|
||||||
nn.RMSNorm = RMSNorm
|
|
||||||
print(" :: ComfyUI-compatible RMSNorm layer installed.")
|
|
||||||
else:
|
|
||||||
print(" :: PyTorch RMSNorm found, no patch needed.")
|
|
||||||
# ------------------- End RMSNorm Patch -------------------
|
|
||||||
# ------------------- RMSNorm Compatibility Patch -------------------
|
# ------------------- RMSNorm Compatibility Patch -------------------
|
||||||
# Fix for PyTorch < 2.4.0 which doesn't have nn.RMSNorm
|
# Fix for PyTorch < 2.4.0 which doesn't have nn.RMSNorm
|
||||||
# Based on ComfyUI's actual RMSNorm implementation
|
# Based on ComfyUI's actual RMSNorm implementation
|
||||||
@ -532,3 +448,4 @@ else:
|
|||||||
# ------------------- End Zluda detection -------------------
|
# ------------------- End Zluda detection -------------------
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user