From 896600918166dcee320815c0e96ea0ad92399e38 Mon Sep 17 00:00:00 2001 From: patientx Date: Fri, 5 Sep 2025 22:43:39 +0300 Subject: [PATCH] added rmsnorm patch for torch's older than 2.4 --- comfy/customzluda/zluda-default.py | 116 +++++++++++++++++++++++++++++ 1 file changed, 116 insertions(+) diff --git a/comfy/customzluda/zluda-default.py b/comfy/customzluda/zluda-default.py index e043a20ee..f431914eb 100644 --- a/comfy/customzluda/zluda-default.py +++ b/comfy/customzluda/zluda-default.py @@ -293,6 +293,121 @@ if is_zluda: torch.istft = z_istft torch.jit.script = z_jit # ------------------- End Audio Patch ------------------- +# ------------------- RMSNorm Compatibility Patch ------------------- +# Fix for PyTorch < 2.4.0 which doesn't have nn.RMSNorm +import torch.nn as nn + +if not hasattr(nn, 'RMSNorm'): + print(" :: PyTorch RMSNorm not found, adding compatibility layer.") + + class RMSNorm(nn.Module): + def __init__(self, dim, eps=1e-6, elementwise_affine=True, device=None, dtype=None): + super().__init__() + self.eps = eps + self.elementwise_affine = elementwise_affine + if self.elementwise_affine: + self.weight = nn.Parameter(torch.ones(dim, device=device, dtype=dtype)) + else: + self.register_parameter('weight', None) + + def forward(self, x): + variance = x.pow(2).mean(dim=-1, keepdim=True) + x = x * torch.rsqrt(variance + self.eps) + if self.elementwise_affine: + x = x * self.weight + return x + + # Monkey patch nn.RMSNorm + nn.RMSNorm = RMSNorm + print(" :: RMSNorm compatibility layer installed.") +else: + print(" :: PyTorch RMSNorm found, no patch needed.") +# ------------------- End RMSNorm Patch ------------------- +# ------------------- RMSNorm Compatibility Patch ------------------- +# Fix for PyTorch < 2.4.0 which doesn't have nn.RMSNorm +# Based on ComfyUI's actual RMSNorm implementation +import torch.nn as nn +import numbers + +if not hasattr(nn, 'RMSNorm'): + print(" :: PyTorch RMSNorm not found, adding ComfyUI-compatible layer.") + + # Check if torch.nn.functional.rms_norm exists + rms_norm_torch = None + try: + rms_norm_torch = torch.nn.functional.rms_norm + except AttributeError: + rms_norm_torch = None + + def rms_norm_fallback(x, weight=None, eps=1e-6): + """Fallback RMSNorm implementation when native function unavailable""" + if rms_norm_torch is not None and not (torch.jit.is_tracing() or torch.jit.is_scripting()): + # Try to import comfy.model_management for proper casting + try: + import comfy.model_management + cast_fn = comfy.model_management.cast_to + except ImportError: + # Fallback casting function if comfy not available + cast_fn = lambda w, dtype, device: w.to(dtype=dtype, device=device) if w is not None else None + + if weight is None: + return rms_norm_torch(x, (x.shape[-1],), eps=eps) + else: + return rms_norm_torch(x, weight.shape, weight=cast_fn(weight, dtype=x.dtype, device=x.device), eps=eps) + else: + # Manual implementation + r = x * torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps) + if weight is None: + return r + else: + # Try to use comfy's cast function, fallback to simple casting + try: + import comfy.model_management + weight_casted = comfy.model_management.cast_to(weight, dtype=x.dtype, device=x.device) + except ImportError: + weight_casted = weight.to(dtype=x.dtype, device=x.device) if weight is not None else None + return r * weight_casted + + class RMSNorm(nn.Module): + def __init__( + self, + normalized_shape, + eps=1e-6, + elementwise_affine=True, + device=None, + dtype=None, + ): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + + # Handle both int and tuple normalized_shape (like ComfyUI does) + if isinstance(normalized_shape, numbers.Integral): + normalized_shape = (normalized_shape,) + self.normalized_shape = tuple(normalized_shape) + self.eps = eps + self.elementwise_affine = elementwise_affine + + if self.elementwise_affine: + # Use empty() like ComfyUI, not ones() + self.weight = nn.Parameter( + torch.empty(self.normalized_shape, **factory_kwargs) + ) + # Initialize like LayerNorm + nn.init.ones_(self.weight) + else: + self.register_parameter("weight", None) + + self.bias = None # RMSNorm doesn't use bias + + def forward(self, x): + return rms_norm_fallback(x, self.weight, self.eps) + + # Monkey patch nn.RMSNorm + nn.RMSNorm = RMSNorm + print(" :: ComfyUI-compatible RMSNorm layer installed.") +else: + print(" :: PyTorch RMSNorm found, no patch needed.") +# ------------------- End RMSNorm Patch ------------------- # ------------------- Top-K Fallback Patch ------------------- if is_zluda: @@ -360,3 +475,4 @@ else: print(f" :: CUDA device detected: {zluda_device_name or 'None'}") print("***--------------------------------------------------------***\n") # ------------------- End Zluda detection ------------------- +