Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-01-30 08:10:21 +08:00)
custom rmsnorm
commit 371c319cf9
parent 5f625fcc78
@@ -9,11 +9,46 @@ import torch
 from itertools import chain
 from comfy.ldm.modules.diffusionmodules.model import get_timestep_embedding
 from comfy.ldm.modules.attention import optimized_attention
 from comfy.rmsnorm import RMSNorm
 from torch.nn.modules.utils import _triple
 from torch import nn
 import math
 from comfy.ldm.flux.math import apply_rope1
+import numbers
+
+
+class CustomRMSNorm(nn.Module):
+
+    def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True, device=None, dtype=None):
+        super(CustomRMSNorm, self).__init__()
+
+        if isinstance(normalized_shape, numbers.Integral):
+            normalized_shape = (normalized_shape,)
+        self.normalized_shape = torch.Size(normalized_shape)
+        self.eps = eps
+        self.elementwise_affine = elementwise_affine
+
+        if self.elementwise_affine:
+            self.weight = nn.Parameter(torch.ones(*normalized_shape, device=device, dtype=dtype))
+        else:
+            self.register_parameter('weight', None)
+
+    def forward(self, input):
+        dims = tuple(range(-len(self.normalized_shape), 0))
+
+        variance = input.pow(2).mean(dim=dims, keepdim=True)
+        rms = torch.sqrt(variance + self.eps)
+
+        normalized = input / rms
+
+        if self.elementwise_affine:
+            if hasattr(torch, 'float8_e4m3fn'):
+                fp8_types = (torch.float8_e4m3fn, torch.float8_e5m2)
+                if self.weight.dtype in fp8_types:
+                    weight = self.weight.to(input.dtype)
+                    return normalized * weight
+
+            return normalized * self.weight
+        return normalized
+
+
 class Cache:
     def __init__(self, disable=False, prefix="", cache=None):
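Note (illustrative, not part of the commit): CustomRMSNorm normalizes the input by the root mean square over the trailing normalized dimensions and, when the affine weight was stored in a float8 dtype, upcasts it to the input dtype before applying it. A minimal standalone sketch of the intended usage, assuming the class above is in scope and using made-up tensor shapes:

import torch

norm = CustomRMSNorm(64, eps=1e-5, elementwise_affine=True)
x = torch.randn(2, 16, 64)
y = norm(x)  # regular path: float32 weight

if hasattr(torch, 'float8_e4m3fn'):
    # Simulate a weight loaded from an fp8 checkpoint; forward() upcasts it
    # to x.dtype before the multiply instead of doing math in fp8 directly.
    norm.weight.data = norm.weight.data.to(torch.float8_e4m3fn)
    y_fp8 = norm(x)
    print(torch.allclose(y, y_fp8, atol=1e-2))  # matches up to fp8 rounding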
@@ -1081,7 +1116,7 @@ class AdaSingle(nn.Module):
         emb = cache(
             f"emb_repeat_{idx}_{branch_tag}",
             lambda: slice_inputs(
-                torch.cat([e.repeat(l, *([1] * e.ndim)) for e, l in zip(emb, hid_len)]),
+                torch.repeat_interleave(emb, hid_len, dim=0),
                 dim=0,
             ),
         )
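Aside (standalone sketch, not from the commit): the cat-of-repeats expression and torch.repeat_interleave produce the same tensor when emb is a 2-D batch of embeddings and hid_len holds the per-sample repeat counts; those shapes are assumptions for illustration and may differ from the model's actual inputs.

import torch

emb = torch.randn(3, 8)        # 3 conditioning embeddings, 8 channels each
hid_len = [2, 5, 1]            # hidden-sequence length per sample

old = torch.cat([e.repeat(l, *([1] * e.ndim)) for e, l in zip(emb, hid_len)])
new = torch.repeat_interleave(emb, torch.tensor(hid_len), dim=0)

print(old.shape, new.shape)    # torch.Size([8, 8]) in both cases
print(torch.equal(old, new))   # True: row i is repeated hid_len[i] times, in order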
@@ -1093,10 +1128,25 @@ class AdaSingle(nn.Module):
             getattr(self, f"{layer}_gate", None),
         )
 
+        if hasattr(torch, 'float8_e4m3fn'):
+            fp8_types = (torch.float8_e4m3fn, torch.float8_e5m2)
+            target_dtype = hid.dtype
+
+            if shiftB is not None and shiftB.dtype in fp8_types:
+                shiftB = shiftB.to(target_dtype)
+            if scaleB is not None and scaleB.dtype in fp8_types:
+                scaleB = scaleB.to(target_dtype)
+            if gateB is not None and gateB.dtype in fp8_types:
+                gateB = gateB.to(target_dtype)
+
         if mode == "in":
             return hid.mul_(scaleA + scaleB).add_(shiftA + shiftB)
         if mode == "out":
-            return hid.mul_(gateA + gateB)
+            if gateB is not None:
+                return hid.mul_(gateA + gateB)
+            else:
+                return hid.mul_(gateA)
+
         raise NotImplementedError
 
 
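Aside (standalone sketch, not the model's code path): the upcast above matters because PyTorch does not implement elementwise arithmetic for float8 tensors, so modulation tensors stored in fp8 have to be converted to the hidden state's dtype before entering the scale/shift math. The shapes and dtypes below are assumptions for illustration:

import torch

hid = torch.randn(4, 8, dtype=torch.bfloat16)
shiftA = torch.randn(4, 8, dtype=torch.bfloat16)
scaleA = torch.randn(4, 8, dtype=torch.bfloat16)

if hasattr(torch, 'float8_e4m3fn'):
    shiftB = torch.randn(4, 8).to(torch.float8_e4m3fn)   # e.g. read from an fp8 checkpoint
    scaleB = torch.randn(4, 8).to(torch.float8_e4m3fn)

    # Without these casts, scaleA + scaleB would typically fail: fp8 add/mul
    # kernels are not implemented, which is exactly what the diff guards against.
    shiftB = shiftB.to(hid.dtype)
    scaleB = scaleB.to(hid.dtype)

    hid.mul_(scaleA + scaleB).add_(shiftA + shiftB)      # the mode == "in" modulation
print(hid.dtype)  # still torch.bfloat16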
@@ -1221,8 +1271,8 @@ class NaDiT(nn.Module):
             block_type = ["mmdit_sr"] * num_layers
             window = num_layers * [(4,3,3)]
         ada = AdaSingle
-        norm = RMSNorm
-        qk_norm = RMSNorm
+        norm = CustomRMSNorm
+        qk_norm = CustomRMSNorm
         if isinstance(block_type, str):
             block_type = [block_type] * num_layers
         elif len(block_type) != num_layers:
@@ -1308,7 +1358,7 @@ class NaDiT(nn.Module):
 
         self.vid_out_norm = None
         if vid_out_norm is not None:
-            self.vid_out_norm = RMSNorm(
+            self.vid_out_norm = CustomRMSNorm(
                 normalized_shape=vid_dim,
                 eps=norm_eps,
                 elementwise_affine=True,