custom rmsnorm

Yousef R. Gamaleldin 2026-01-17 20:38:46 +02:00
parent 5f625fcc78
commit 371c319cf9

@@ -9,11 +9,46 @@ import torch
 from itertools import chain
 from comfy.ldm.modules.diffusionmodules.model import get_timestep_embedding
 from comfy.ldm.modules.attention import optimized_attention
 from comfy.rmsnorm import RMSNorm
 from torch.nn.modules.utils import _triple
 from torch import nn
 import math
 from comfy.ldm.flux.math import apply_rope1
+import numbers
+
+
+class CustomRMSNorm(nn.Module):
+    def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True, device=None, dtype=None):
+        super(CustomRMSNorm, self).__init__()
+        if isinstance(normalized_shape, numbers.Integral):
+            normalized_shape = (normalized_shape,)
+        self.normalized_shape = torch.Size(normalized_shape)
+        self.eps = eps
+        self.elementwise_affine = elementwise_affine
+        if self.elementwise_affine:
+            self.weight = nn.Parameter(torch.ones(*normalized_shape, device=device, dtype=dtype))
+        else:
+            self.register_parameter('weight', None)
+
+    def forward(self, input):
+        # Root-mean-square normalization over the trailing dims given by normalized_shape.
+        dims = tuple(range(-len(self.normalized_shape), 0))
+        variance = input.pow(2).mean(dim=dims, keepdim=True)
+        rms = torch.sqrt(variance + self.eps)
+        normalized = input / rms
+        if self.elementwise_affine:
+            # float8 weights cannot be multiplied with the input directly; upcast them first.
+            if hasattr(torch, 'float8_e4m3fn'):
+                fp8_types = (torch.float8_e4m3fn, torch.float8_e5m2)
+                if self.weight.dtype in fp8_types:
+                    weight = self.weight.to(input.dtype)
+                    return normalized * weight
+            return normalized * self.weight
+        return normalized
+
+
 class Cache:
     def __init__(self, disable=False, prefix="", cache=None):
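The class above reimplements RMS normalization with an explicit upcast path for float8 weights. A minimal sanity check, not part of this commit: assuming a PyTorch build that ships torch.nn.RMSNorm (2.4 or newer), CustomRMSNorm should match it for ordinary dtypes and should still run once its weight has been cast to float8.

    import torch

    x = torch.randn(2, 16, 64)
    custom = CustomRMSNorm(64, eps=1e-5)
    reference = torch.nn.RMSNorm(64, eps=1e-5)
    assert torch.allclose(custom(x), reference(x), atol=1e-6)

    # float8 path: the weight is upcast to the input dtype before the multiply.
    if hasattr(torch, "float8_e4m3fn"):
        custom.weight.data = custom.weight.data.to(torch.float8_e4m3fn)
        _ = custom(x)  # would hit a dtype-promotion error without the upcast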
@@ -1081,7 +1116,7 @@ class AdaSingle(nn.Module):
         emb = cache(
             f"emb_repeat_{idx}_{branch_tag}",
             lambda: slice_inputs(
-                torch.cat([e.repeat(l, *([1] * e.ndim)) for e, l in zip(emb, hid_len)]),
+                torch.repeat_interleave(emb, hid_len, dim=0),
                 dim=0,
             ),
         )
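For reference, and not part of the commit: the torch.cat over per-row .repeat calls and torch.repeat_interleave build the same tensor when hid_len holds one repeat count per row of emb; the latter simply avoids the Python-level loop. A small illustration with made-up shapes:

    import torch

    emb = torch.randn(3, 4)          # one embedding row per sequence
    hid_len = [2, 1, 3]              # hypothetical per-sequence lengths

    a = torch.cat([e.repeat(l, *([1] * e.ndim)) for e, l in zip(emb, hid_len)])
    b = torch.repeat_interleave(emb, torch.tensor(hid_len), dim=0)
    assert torch.equal(a, b)         # both give shape (6, 4)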
@@ -1093,10 +1128,25 @@ class AdaSingle(nn.Module):
             getattr(self, f"{layer}_gate", None),
         )
+        # Learned modulation parameters may be stored in float8; upcast them to the
+        # activation dtype before they are combined with the embedding-derived terms.
+        if hasattr(torch, 'float8_e4m3fn'):
+            fp8_types = (torch.float8_e4m3fn, torch.float8_e5m2)
+            target_dtype = hid.dtype
+            if shiftB is not None and shiftB.dtype in fp8_types:
+                shiftB = shiftB.to(target_dtype)
+            if scaleB is not None and scaleB.dtype in fp8_types:
+                scaleB = scaleB.to(target_dtype)
+            if gateB is not None and gateB.dtype in fp8_types:
+                gateB = gateB.to(target_dtype)
         if mode == "in":
             return hid.mul_(scaleA + scaleB).add_(shiftA + shiftB)
         if mode == "out":
-            return hid.mul_(gateA + gateB)
+            if gateB is not None:
+                return hid.mul_(gateA + gateB)
+            else:
+                return hid.mul_(gateA)
         raise NotImplementedError
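The upcasts above exist because, in current PyTorch releases, elementwise arithmetic on float8 tensors is generally not implemented, so fp8-stored parameters have to be promoted before they are added to or multiplied with the activations. A rough illustration, not part of the commit (the exact error behavior can vary by PyTorch version):

    import torch

    if hasattr(torch, "float8_e4m3fn"):
        w = torch.ones(4).to(torch.float8_e4m3fn)   # e.g. an fp8-stored parameter
        x = torch.randn(4)
        try:
            _ = x * w                               # typically unsupported for float8 operands
        except RuntimeError as err:
            print("direct fp8 multiply failed:", err)
        _ = x * w.to(x.dtype)                       # upcast first, as the code above does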
@@ -1221,8 +1271,8 @@ class NaDiT(nn.Module):
         block_type = ["mmdit_sr"] * num_layers
         window = num_layers * [(4,3,3)]
         ada = AdaSingle
-        norm = RMSNorm
-        qk_norm = RMSNorm
+        norm = CustomRMSNorm
+        qk_norm = CustomRMSNorm
         if isinstance(block_type, str):
             block_type = [block_type] * num_layers
         elif len(block_type) != num_layers:
@@ -1308,7 +1358,7 @@ class NaDiT(nn.Module):
         self.vid_out_norm = None
         if vid_out_norm is not None:
-            self.vid_out_norm = RMSNorm(
+            self.vid_out_norm = CustomRMSNorm(
                 normalized_shape=vid_dim,
                 eps=norm_eps,
                 elementwise_affine=True,
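CustomRMSNorm accepts the same keyword arguments that these call sites already pass (normalized_shape, eps, elementwise_affine, plus optional device and dtype), so swapping it in for RMSNorm needs no other changes. A small sketch, not part of the commit, with made-up stand-ins for vid_dim and norm_eps:

    norm = CustomRMSNorm(normalized_shape=1024, eps=1e-6, elementwise_affine=True)
    print(norm.weight.shape)   # torch.Size([1024])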