Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-01-31 00:30:21 +08:00)

commit 371c319cf9 (parent 5f625fcc78)

    custom rmsnorm
@@ -9,11 +9,46 @@ import torch
 from itertools import chain
 from comfy.ldm.modules.diffusionmodules.model import get_timestep_embedding
 from comfy.ldm.modules.attention import optimized_attention
-from comfy.rmsnorm import RMSNorm
 from torch.nn.modules.utils import _triple
 from torch import nn
 import math
 from comfy.ldm.flux.math import apply_rope1
+import numbers
+
+class CustomRMSNorm(nn.Module):
+
+    def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True, device=None, dtype=None):
+        super(CustomRMSNorm, self).__init__()
+
+        if isinstance(normalized_shape, numbers.Integral):
+            normalized_shape = (normalized_shape,)
+        self.normalized_shape = torch.Size(normalized_shape)
+        self.eps = eps
+        self.elementwise_affine = elementwise_affine
+
+        if self.elementwise_affine:
+            self.weight = nn.Parameter(torch.ones(*normalized_shape, device=device, dtype=dtype))
+        else:
+            self.register_parameter('weight', None)
+
+    def forward(self, input):
+
+        dims = tuple(range(-len(self.normalized_shape), 0))
+
+        variance = input.pow(2).mean(dim=dims, keepdim=True)
+        rms = torch.sqrt(variance + self.eps)
+
+        normalized = input / rms
+
+        if self.elementwise_affine:
+            if hasattr(torch, 'float8_e4m3fn'):
+                fp8_types = (torch.float8_e4m3fn, torch.float8_e5m2)
+                if self.weight.dtype in fp8_types:
+                    weight = self.weight.to(input.dtype)
+                    return normalized * weight
+
+            return normalized * self.weight
+        return normalized
 
 class Cache:
     def __init__(self, disable=False, prefix="", cache=None):
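The forward pass above is the standard RMSNorm computation, y = x / sqrt(mean(x^2) + eps) * weight, with the mean taken over the trailing dims in normalized_shape and fp8 affine weights upcast to the input dtype before the multiply. A minimal sanity check, assuming the CustomRMSNorm class from this hunk is in scope and using made-up toy shapes:

import torch

norm = CustomRMSNorm(64, eps=1e-5)           # hypothetical feature size
x = torch.randn(2, 16, 64)

# Hand-written reference: normalize over the last dim, then apply the (all-ones) weight.
ref = x / torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + 1e-5) * norm.weight
assert torch.allclose(norm(x), ref)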
@@ -1081,7 +1116,7 @@ class AdaSingle(nn.Module):
         emb = cache(
             f"emb_repeat_{idx}_{branch_tag}",
             lambda: slice_inputs(
-                torch.cat([e.repeat(l, *([1] * e.ndim)) for e, l in zip(emb, hid_len)]),
+                torch.repeat_interleave(emb, hid_len, dim=0),
                 dim=0,
             ),
         )
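The only change in this hunk swaps the cat-of-repeats for torch.repeat_interleave, which repeats each leading-dim slice of emb hid_len[i] times in place. A toy equivalence check (assumed shapes; in the real forward pass emb and hid_len come from the caller, and whether hid_len is a list or a tensor there is not shown here):

import torch

emb = torch.arange(6.0).reshape(3, 2)        # 3 items, feature dim 2 (made-up sizes)
hid_len = [2, 1, 3]                          # repeats per item

old = torch.cat([e.repeat(l, *([1] * e.ndim)) for e, l in zip(emb, hid_len)])
new = torch.repeat_interleave(emb, torch.tensor(hid_len), dim=0)
assert torch.equal(old, new)                 # both give shape [6, 2], rows repeated per item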
@@ -1093,10 +1128,25 @@ class AdaSingle(nn.Module):
             getattr(self, f"{layer}_gate", None),
         )
 
+        if hasattr(torch, 'float8_e4m3fn'):
+            fp8_types = (torch.float8_e4m3fn, torch.float8_e5m2)
+            target_dtype = hid.dtype
+
+            if shiftB is not None and shiftB.dtype in fp8_types:
+                shiftB = shiftB.to(target_dtype)
+            if scaleB is not None and scaleB.dtype in fp8_types:
+                scaleB = scaleB.to(target_dtype)
+            if gateB is not None and gateB.dtype in fp8_types:
+                gateB = gateB.to(target_dtype)
+
         if mode == "in":
             return hid.mul_(scaleA + scaleB).add_(shiftA + shiftB)
         if mode == "out":
-            return hid.mul_(gateA + gateB)
+            if gateB is not None:
+                return hid.mul_(gateA + gateB)
+            else:
+                return hid.mul_(gateA)
+
         raise NotImplementedError
 
 
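The added guard upcasts any fp8 modulation tensors to the hidden state's dtype before the in-place mul/add, since elementwise math on float8 tensors is generally not supported and those dtypes only exist in newer PyTorch builds (hence the hasattr check). A small sketch of the same pattern with made-up tensors:

import torch

if hasattr(torch, 'float8_e4m3fn'):
    hid = torch.randn(4, 8, dtype=torch.bfloat16)
    scaleB = torch.randn(4, 8).to(torch.float8_e4m3fn)   # stand-in for an fp8-quantized modulation tensor
    if scaleB.dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
        scaleB = scaleB.to(hid.dtype)                     # upcast before doing arithmetic
    out = hid * scaleB                                    # plain bfloat16 multiply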
@@ -1221,8 +1271,8 @@ class NaDiT(nn.Module):
             block_type = ["mmdit_sr"] * num_layers
             window = num_layers * [(4,3,3)]
             ada = AdaSingle
-            norm = RMSNorm
-            qk_norm = RMSNorm
+            norm = CustomRMSNorm
+            qk_norm = CustomRMSNorm
         if isinstance(block_type, str):
             block_type = [block_type] * num_layers
         elif len(block_type) != num_layers:
@@ -1308,7 +1358,7 @@ class NaDiT(nn.Module):
 
         self.vid_out_norm = None
         if vid_out_norm is not None:
-            self.vid_out_norm = RMSNorm(
+            self.vid_out_norm = CustomRMSNorm(
                 normalized_shape=vid_dim,
                 eps=norm_eps,
                 elementwise_affine=True,
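CustomRMSNorm mirrors the keyword interface used at this call site (normalized_shape, eps, elementwise_affine), so the swap is a drop-in replacement. A quick construction sketch with placeholder values standing in for vid_dim and norm_eps:

norm = CustomRMSNorm(normalized_shape=1024, eps=1e-5, elementwise_affine=True)
print(norm.weight.shape)                     # torch.Size([1024])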