From 0488fe3748dde0c1a1e5710b3eaf33e84b0698c4 Mon Sep 17 00:00:00 2001
From: patientx
Date: Fri, 5 Sep 2025 23:10:27 +0300
Subject: [PATCH] rmsnorm patch second try

---
 comfy/customzluda/zluda-default.py | 76 ++++++++++++++++++++++++++----
 1 file changed, 66 insertions(+), 10 deletions(-)

diff --git a/comfy/customzluda/zluda-default.py b/comfy/customzluda/zluda-default.py
index f431914eb..6153fdd11 100644
--- a/comfy/customzluda/zluda-default.py
+++ b/comfy/customzluda/zluda-default.py
@@ -295,31 +295,86 @@ if is_zluda:
     # ------------------- End Audio Patch -------------------
     # ------------------- RMSNorm Compatibility Patch -------------------
     # Fix for PyTorch < 2.4.0 which doesn't have nn.RMSNorm
+    # Based on ComfyUI's actual RMSNorm implementation
     import torch.nn as nn
+    import numbers
 
     if not hasattr(nn, 'RMSNorm'):
-        print(" :: PyTorch RMSNorm not found, adding compatibility layer.")
+        print(" :: PyTorch RMSNorm not found, adding ComfyUI-compatible layer.")
+
+        # Check if torch.nn.functional.rms_norm exists
+        rms_norm_torch = None
+        try:
+            rms_norm_torch = torch.nn.functional.rms_norm
+        except AttributeError:
+            rms_norm_torch = None
+
+        def rms_norm_fallback(x, weight=None, eps=1e-6):
+            """Fallback RMSNorm implementation when native function unavailable"""
+            if rms_norm_torch is not None and not (torch.jit.is_tracing() or torch.jit.is_scripting()):
+                # Try to import comfy.model_management for proper casting
+                try:
+                    import comfy.model_management
+                    cast_fn = comfy.model_management.cast_to
+                except ImportError:
+                    # Fallback casting function if comfy not available
+                    cast_fn = lambda w, dtype, device: w.to(dtype=dtype, device=device) if w is not None else None
+
+                if weight is None:
+                    return rms_norm_torch(x, (x.shape[-1],), eps=eps)
+                else:
+                    return rms_norm_torch(x, weight.shape, weight=cast_fn(weight, dtype=x.dtype, device=x.device), eps=eps)
+            else:
+                # Manual implementation
+                r = x * torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps)
+                if weight is None:
+                    return r
+                else:
+                    # Try to use comfy's cast function, fallback to simple casting
+                    try:
+                        import comfy.model_management
+                        weight_casted = comfy.model_management.cast_to(weight, dtype=x.dtype, device=x.device)
+                    except ImportError:
+                        weight_casted = weight.to(dtype=x.dtype, device=x.device) if weight is not None else None
+                    return r * weight_casted
 
         class RMSNorm(nn.Module):
-            def __init__(self, dim, eps=1e-6, elementwise_affine=True, device=None, dtype=None):
+            def __init__(
+                self,
+                normalized_shape,
+                eps=1e-6,
+                elementwise_affine=True,
+                device=None,
+                dtype=None,
+            ):
+                factory_kwargs = {"device": device, "dtype": dtype}
                 super().__init__()
+
+                # Handle both int and tuple normalized_shape (like ComfyUI does)
+                if isinstance(normalized_shape, numbers.Integral):
+                    normalized_shape = (normalized_shape,)
+                self.normalized_shape = tuple(normalized_shape)
                 self.eps = eps
                 self.elementwise_affine = elementwise_affine
+
                 if self.elementwise_affine:
-                    self.weight = nn.Parameter(torch.ones(dim, device=device, dtype=dtype))
+                    # Use empty() like ComfyUI, not ones()
+                    self.weight = nn.Parameter(
+                        torch.empty(self.normalized_shape, **factory_kwargs)
+                    )
+                    # Initialize like LayerNorm
+                    nn.init.ones_(self.weight)
                 else:
-                    self.register_parameter('weight', None)
+                    self.register_parameter("weight", None)
+
+                self.bias = None # RMSNorm doesn't use bias
 
             def forward(self, x):
-                variance = x.pow(2).mean(dim=-1, keepdim=True)
-                x = x * torch.rsqrt(variance + self.eps)
-                if self.elementwise_affine:
-                    x = x * self.weight
-                return x
+                return rms_norm_fallback(x, self.weight, self.eps)
 
         # Monkey patch nn.RMSNorm
         nn.RMSNorm = RMSNorm
-        print(" :: RMSNorm compatibility layer installed.")
+        print(" :: ComfyUI-compatible RMSNorm layer installed.")
     else:
         print(" :: PyTorch RMSNorm found, no patch needed.")
     # ------------------- End RMSNorm Patch -------------------
@@ -476,3 +531,4 @@ else:
     print("***--------------------------------------------------------***\n")
 
 # ------------------- End Zluda detection -------------------
+
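
Sanity check (separate from the patch itself): a minimal sketch, assuming only that torch is installed, that reproduces the manual branch of rms_norm_fallback and compares it against the native nn.RMSNorm where that layer exists (PyTorch >= 2.4). The helper name rms_norm_manual and the test shapes are made up for illustration and are not part of zluda-default.py.

import torch
import torch.nn as nn

def rms_norm_manual(x, weight=None, eps=1e-6):
    # Same math as the patch's fallback branch: x * rsqrt(mean(x^2, dim=-1) + eps), then scale
    r = x * torch.rsqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + eps)
    return r if weight is None else r * weight.to(dtype=x.dtype, device=x.device)

if __name__ == "__main__":
    x = torch.randn(2, 8, 64)
    w = torch.randn(64)
    manual = rms_norm_manual(x, w)
    if hasattr(nn, "RMSNorm"):
        # Native layer available: load the same weight and compare outputs
        native = nn.RMSNorm(64, eps=1e-6)
        with torch.no_grad():
            native.weight.copy_(w)
        print("max abs diff vs nn.RMSNorm:", (native(x) - manual).abs().max().item())
    else:
        # Older PyTorch: only the fallback path is exercised
        print("nn.RMSNorm unavailable; manual RMSNorm output shape:", tuple(manual.shape))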