From 371c319cf917fa21cbdefffc9b4ae3b7d62f7663 Mon Sep 17 00:00:00 2001
From: "Yousef R. Gamaleldin"
Date: Sat, 17 Jan 2026 20:38:46 +0200
Subject: [PATCH] custom rmsnorm

---
 comfy/ldm/seedvr/model.py | 62 +++++++++++++++++++++++++++++++++++----
 1 file changed, 56 insertions(+), 6 deletions(-)

diff --git a/comfy/ldm/seedvr/model.py b/comfy/ldm/seedvr/model.py
index eb2237eee..f42dcb1e2 100644
--- a/comfy/ldm/seedvr/model.py
+++ b/comfy/ldm/seedvr/model.py
@@ -9,11 +9,46 @@ import torch
 from itertools import chain
 from comfy.ldm.modules.diffusionmodules.model import get_timestep_embedding
 from comfy.ldm.modules.attention import optimized_attention
-from comfy.rmsnorm import RMSNorm
 from torch.nn.modules.utils import _triple
 from torch import nn
 import math
 from comfy.ldm.flux.math import apply_rope1
+import numbers
+
+class CustomRMSNorm(nn.Module):
+
+    def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True, device=None, dtype=None):
+        super(CustomRMSNorm, self).__init__()
+
+        if isinstance(normalized_shape, numbers.Integral):
+            normalized_shape = (normalized_shape,)
+        self.normalized_shape = torch.Size(normalized_shape)
+        self.eps = eps
+        self.elementwise_affine = elementwise_affine
+
+        if self.elementwise_affine:
+            self.weight = nn.Parameter(torch.ones(*normalized_shape, device=device, dtype=dtype))
+        else:
+            self.register_parameter('weight', None)
+
+    def forward(self, input):
+
+        dims = tuple(range(-len(self.normalized_shape), 0))
+
+        variance = input.pow(2).mean(dim=dims, keepdim=True)
+        rms = torch.sqrt(variance + self.eps)
+
+        normalized = input / rms
+
+        if self.elementwise_affine:
+            if hasattr(torch, 'float8_e4m3fn'):
+                fp8_types = (torch.float8_e4m3fn, torch.float8_e5m2)
+                if self.weight.dtype in fp8_types:
+                    weight = self.weight.to(input.dtype)
+                    return normalized * weight
+
+            return normalized * self.weight
+        return normalized
 
 class Cache:
     def __init__(self, disable=False, prefix="", cache=None):
@@ -1081,7 +1116,7 @@ class AdaSingle(nn.Module):
         emb = cache(
             f"emb_repeat_{idx}_{branch_tag}",
             lambda: slice_inputs(
-                torch.cat([e.repeat(l, *([1] * e.ndim)) for e, l in zip(emb, hid_len)]),
+                torch.repeat_interleave(emb, hid_len, dim=0),
                 dim=0,
             ),
         )
@@ -1093,10 +1128,25 @@ class AdaSingle(nn.Module):
             getattr(self, f"{layer}_gate", None),
         )
 
+        if hasattr(torch, 'float8_e4m3fn'):
+            fp8_types = (torch.float8_e4m3fn, torch.float8_e5m2)
+            target_dtype = hid.dtype
+
+            if shiftB is not None and shiftB.dtype in fp8_types:
+                shiftB = shiftB.to(target_dtype)
+            if scaleB is not None and scaleB.dtype in fp8_types:
+                scaleB = scaleB.to(target_dtype)
+            if gateB is not None and gateB.dtype in fp8_types:
+                gateB = gateB.to(target_dtype)
+
         if mode == "in":
             return hid.mul_(scaleA + scaleB).add_(shiftA + shiftB)
         if mode == "out":
-            return hid.mul_(gateA + gateB)
+            if gateB is not None:
+                return hid.mul_(gateA + gateB)
+            else:
+                return hid.mul_(gateA)
+
         raise NotImplementedError


@@ -1221,8 +1271,8 @@ class NaDiT(nn.Module):
             block_type = ["mmdit_sr"] * num_layers
             window = num_layers * [(4,3,3)]
             ada = AdaSingle
-            norm = RMSNorm
-            qk_norm = RMSNorm
+            norm = CustomRMSNorm
+            qk_norm = CustomRMSNorm
         if isinstance(block_type, str):
             block_type = [block_type] * num_layers
         elif len(block_type) != num_layers:
@@ -1308,7 +1358,7 @@ class NaDiT(nn.Module):

         self.vid_out_norm = None
         if vid_out_norm is not None:
-            self.vid_out_norm = RMSNorm(
+            self.vid_out_norm = CustomRMSNorm(
                 normalized_shape=vid_dim,
                 eps=norm_eps,
                 elementwise_affine=True,
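
Not part of the patch itself: a minimal smoke test of the new layer, assuming the patched module is importable as comfy.ldm.seedvr.model and the running PyTorch build exposes torch.float8_e4m3fn; the dimension, shapes and dtypes below are illustrative only.

    import torch
    from comfy.ldm.seedvr.model import CustomRMSNorm  # assumes this patch is applied

    # Normalize over the last dimension, as the SeedVR blocks do.
    norm = CustomRMSNorm(64, eps=1e-5, elementwise_affine=True)

    # Simulate an fp8-quantized checkpoint: the affine weight is stored in
    # float8 while activations arrive in bfloat16.
    if hasattr(torch, "float8_e4m3fn"):
        norm.weight.data = norm.weight.data.to(torch.float8_e4m3fn)

    x = torch.randn(2, 128, 64, dtype=torch.bfloat16)
    y = norm(x)

    # forward() up-casts the fp8 weight to the activation dtype before the
    # elementwise multiply, so the output keeps the activation dtype.
    print(y.shape, y.dtype)  # torch.Size([2, 128, 64]) torch.bfloat16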