Update rmsnorm.py (Ensuring eps is not None)

This commit is contained in:
patientx 2025-05-05 01:46:04 +03:00 committed by GitHub
parent 634c398fd5
commit 080b5d0df4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -12,6 +12,10 @@ except:
def rms_norm(x, weight=None, eps=1e-6):
if eps is None: # Ensure eps is not None
eps = 1e-6
if rms_norm_torch is not None and not (torch.jit.is_tracing() or torch.jit.is_scripting()):
if weight is None:
return rms_norm_torch(x, (x.shape[-1],), eps=eps)
@ -30,7 +34,7 @@ if RMSNorm is None:
def __init__(
self,
normalized_shape,
eps=None,
eps=1e-6, # Changed default from None to 1e-6
elementwise_affine=True,
device=None,
dtype=None,