diff --git a/comfy/rmsnorm.py b/comfy/rmsnorm.py index 9d82bee1a..bb2f6fa74 100644 --- a/comfy/rmsnorm.py +++ b/comfy/rmsnorm.py @@ -12,6 +12,10 @@ except: def rms_norm(x, weight=None, eps=1e-6): + + if eps is None: # Ensure eps is not None + eps = 1e-6 + if rms_norm_torch is not None and not (torch.jit.is_tracing() or torch.jit.is_scripting()): if weight is None: return rms_norm_torch(x, (x.shape[-1],), eps=eps) @@ -30,7 +34,7 @@ if RMSNorm is None: def __init__( self, normalized_shape, - eps=None, + eps=1e-6, # Changed default from None to 1e-6 elementwise_affine=True, device=None, dtype=None,