diff --git a/comfy/customzluda/zluda-default.py b/comfy/customzluda/zluda-default.py
index 6153fdd11..c64af5ed0 100644
--- a/comfy/customzluda/zluda-default.py
+++ b/comfy/customzluda/zluda-default.py
@@ -293,91 +293,7 @@ if is_zluda:
     torch.istft = z_istft
     torch.jit.script = z_jit
 # ------------------- End Audio Patch -------------------
 
-# ------------------- RMSNorm Compatibility Patch -------------------
-# Fix for PyTorch < 2.4.0 which doesn't have nn.RMSNorm
-# Based on ComfyUI's actual RMSNorm implementation
-import torch.nn as nn
-import numbers
-if not hasattr(nn, 'RMSNorm'):
-    print(" :: PyTorch RMSNorm not found, adding ComfyUI-compatible layer.")
-
-    # Check if torch.nn.functional.rms_norm exists
-    rms_norm_torch = None
-    try:
-        rms_norm_torch = torch.nn.functional.rms_norm
-    except AttributeError:
-        rms_norm_torch = None
-
-    def rms_norm_fallback(x, weight=None, eps=1e-6):
-        """Fallback RMSNorm implementation when native function unavailable"""
-        if rms_norm_torch is not None and not (torch.jit.is_tracing() or torch.jit.is_scripting()):
-            # Try to import comfy.model_management for proper casting
-            try:
-                import comfy.model_management
-                cast_fn = comfy.model_management.cast_to
-            except ImportError:
-                # Fallback casting function if comfy not available
-                cast_fn = lambda w, dtype, device: w.to(dtype=dtype, device=device) if w is not None else None
-
-            if weight is None:
-                return rms_norm_torch(x, (x.shape[-1],), eps=eps)
-            else:
-                return rms_norm_torch(x, weight.shape, weight=cast_fn(weight, dtype=x.dtype, device=x.device), eps=eps)
-        else:
-            # Manual implementation
-            r = x * torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps)
-            if weight is None:
-                return r
-            else:
-                # Try to use comfy's cast function, fallback to simple casting
-                try:
-                    import comfy.model_management
-                    weight_casted = comfy.model_management.cast_to(weight, dtype=x.dtype, device=x.device)
-                except ImportError:
-                    weight_casted = weight.to(dtype=x.dtype, device=x.device) if weight is not None else None
-                return r * weight_casted
-
-    class RMSNorm(nn.Module):
-        def __init__(
-            self,
-            normalized_shape,
-            eps=1e-6,
-            elementwise_affine=True,
-            device=None,
-            dtype=None,
-        ):
-            factory_kwargs = {"device": device, "dtype": dtype}
-            super().__init__()
-
-            # Handle both int and tuple normalized_shape (like ComfyUI does)
-            if isinstance(normalized_shape, numbers.Integral):
-                normalized_shape = (normalized_shape,)
-            self.normalized_shape = tuple(normalized_shape)
-            self.eps = eps
-            self.elementwise_affine = elementwise_affine
-
-            if self.elementwise_affine:
-                # Use empty() like ComfyUI, not ones()
-                self.weight = nn.Parameter(
-                    torch.empty(self.normalized_shape, **factory_kwargs)
-                )
-                # Initialize like LayerNorm
-                nn.init.ones_(self.weight)
-            else:
-                self.register_parameter("weight", None)
-
-            self.bias = None  # RMSNorm doesn't use bias
-
-        def forward(self, x):
-            return rms_norm_fallback(x, self.weight, self.eps)
-
-    # Monkey patch nn.RMSNorm
-    nn.RMSNorm = RMSNorm
-    print(" :: ComfyUI-compatible RMSNorm layer installed.")
-else:
-    print(" :: PyTorch RMSNorm found, no patch needed.")
-# ------------------- End RMSNorm Patch -------------------
 # ------------------- RMSNorm Compatibility Patch -------------------
 # Fix for PyTorch < 2.4.0 which doesn't have nn.RMSNorm
 # Based on ComfyUI's actual RMSNorm implementation
@@ -532,3 +448,4 @@ else:
 # ------------------- End Zluda detection -------------------
+
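
For reference, this diff only drops a duplicated copy of the RMSNorm compatibility patch; the copy kept below the hunk implements the same fallback. A minimal standalone sketch of that fallback math (not part of the patch; the helper name rms_norm_manual and the tensor shapes are illustrative assumptions): scale the input by the reciprocal root-mean-square over the last dimension, then apply the optional learned weight.

import torch

def rms_norm_manual(x, weight=None, eps=1e-6):
    # Scale by the reciprocal root-mean-square over the last dimension.
    r = x * torch.rsqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + eps)
    # Apply the optional learned scale, cast to the input's dtype/device.
    return r if weight is None else r * weight.to(dtype=x.dtype, device=x.device)

x = torch.randn(2, 4, 8)
w = torch.ones(8)
out = rms_norm_manual(x, w)

# On PyTorch >= 2.4 the native op exists and the two paths should agree.
if hasattr(torch.nn.functional, "rms_norm"):
    ref = torch.nn.functional.rms_norm(x, (8,), weight=w, eps=1e-6)
    print(torch.allclose(out, ref, atol=1e-6))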