diff --git a/comfy/ldm/cosmos/blocks.py b/comfy/ldm/cosmos/blocks.py
index 84fd6d839..a12f892d2 100644
--- a/comfy/ldm/cosmos/blocks.py
+++ b/comfy/ldm/cosmos/blocks.py
@@ -23,7 +23,6 @@ from einops import rearrange, repeat
 from einops.layers.torch import Rearrange
 from torch import nn
 
-from comfy.ldm.modules.diffusionmodules.mmdit import RMSNorm
 from comfy.ldm.modules.attention import optimized_attention
 
 
@@ -37,11 +36,11 @@ def apply_rotary_pos_emb(
     return t_out
 
 
-def get_normalization(name: str, channels: int, weight_args={}):
+def get_normalization(name: str, channels: int, weight_args={}, operations=None):
     if name == "I":
         return nn.Identity()
     elif name == "R":
-        return RMSNorm(channels, elementwise_affine=True, eps=1e-6, **weight_args)
+        return operations.RMSNorm(channels, elementwise_affine=True, eps=1e-6, **weight_args)
     else:
         raise ValueError(f"Normalization {name} not found")
 
@@ -120,15 +119,15 @@ class Attention(nn.Module):
 
         self.to_q = nn.Sequential(
             operations.Linear(query_dim, inner_dim, bias=qkv_bias, **weight_args),
-            get_normalization(qkv_norm[0], norm_dim),
+            get_normalization(qkv_norm[0], norm_dim, weight_args=weight_args, operations=operations),
         )
         self.to_k = nn.Sequential(
             operations.Linear(context_dim, inner_dim, bias=qkv_bias, **weight_args),
-            get_normalization(qkv_norm[1], norm_dim),
+            get_normalization(qkv_norm[1], norm_dim, weight_args=weight_args, operations=operations),
         )
         self.to_v = nn.Sequential(
             operations.Linear(context_dim, inner_dim, bias=qkv_bias, **weight_args),
-            get_normalization(qkv_norm[2], norm_dim),
+            get_normalization(qkv_norm[2], norm_dim, weight_args=weight_args, operations=operations),
         )
 
         self.to_out = nn.Sequential(
diff --git a/comfy/ldm/cosmos/model.py b/comfy/ldm/cosmos/model.py
index 06d0baef3..4836e0b69 100644
--- a/comfy/ldm/cosmos/model.py
+++ b/comfy/ldm/cosmos/model.py
@@ -27,8 +27,6 @@ from torchvision import transforms
 from enum import Enum
 import logging
 
-from comfy.ldm.modules.diffusionmodules.mmdit import RMSNorm
-
 from .blocks import (
     FinalLayer,
     GeneralDITTransformerBlock,
@@ -195,7 +193,7 @@ class GeneralDIT(nn.Module):
 
         if self.affline_emb_norm:
             logging.debug("Building affine embedding normalization layer")
-            self.affline_norm = RMSNorm(model_channels, elementwise_affine=True, eps=1e-6)
+            self.affline_norm = operations.RMSNorm(model_channels, elementwise_affine=True, eps=1e-6, device=device, dtype=dtype)
         else:
             self.affline_norm = nn.Identity()
 
diff --git a/comfy/ldm/hydit/models.py b/comfy/ldm/hydit/models.py
index 359f6a965..5ba2b76e0 100644
--- a/comfy/ldm/hydit/models.py
+++ b/comfy/ldm/hydit/models.py
@@ -3,7 +3,7 @@ import torch
 import torch.nn as nn
 import comfy.ops
 
-from comfy.ldm.modules.diffusionmodules.mmdit import Mlp, TimestepEmbedder, PatchEmbed, RMSNorm
+from comfy.ldm.modules.diffusionmodules.mmdit import Mlp, TimestepEmbedder, PatchEmbed
 from comfy.ldm.modules.diffusionmodules.util import timestep_embedding
 from torch.utils import checkpoint
 
@@ -51,7 +51,7 @@ class HunYuanDiTBlock(nn.Module):
         if norm_type == "layer":
             norm_layer = operations.LayerNorm
         elif norm_type == "rms":
-            norm_layer = RMSNorm
+            norm_layer = operations.RMSNorm
         else:
             raise ValueError(f"Unknown norm_type: {norm_type}")
 
diff --git a/comfy/ldm/lumina/model.py b/comfy/ldm/lumina/model.py
index ccd5d2c0e..f8dc4d7db 100644
--- a/comfy/ldm/lumina/model.py
+++ b/comfy/ldm/lumina/model.py
@@ -8,7 +8,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 import comfy.ldm.common_dit
 
-from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder, RMSNorm
+from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder
 from comfy.ldm.modules.attention import optimized_attention_masked
 from comfy.ldm.flux.layers import EmbedND
 
@@ -64,8 +64,8 @@ class JointAttention(nn.Module):
         )
 
         if qk_norm:
-            self.q_norm = RMSNorm(self.head_dim, elementwise_affine=True, **operation_settings)
-            self.k_norm = RMSNorm(self.head_dim, elementwise_affine=True, **operation_settings)
+            self.q_norm = operation_settings.get("operations").RMSNorm(self.head_dim, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+            self.k_norm = operation_settings.get("operations").RMSNorm(self.head_dim, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
         else:
             self.q_norm = self.k_norm = nn.Identity()
 
@@ -242,11 +242,11 @@ class JointTransformerBlock(nn.Module):
             operation_settings=operation_settings,
         )
         self.layer_id = layer_id
-        self.attention_norm1 = RMSNorm(dim, eps=norm_eps, elementwise_affine=True, **operation_settings)
-        self.ffn_norm1 = RMSNorm(dim, eps=norm_eps, elementwise_affine=True, **operation_settings)
+        self.attention_norm1 = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+        self.ffn_norm1 = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
 
-        self.attention_norm2 = RMSNorm(dim, eps=norm_eps, elementwise_affine=True, **operation_settings)
-        self.ffn_norm2 = RMSNorm(dim, eps=norm_eps, elementwise_affine=True, **operation_settings)
+        self.attention_norm2 = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+        self.ffn_norm2 = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
 
         self.modulation = modulation
         if modulation:
@@ -431,7 +431,7 @@ class NextDiT(nn.Module):
 
         self.t_embedder = TimestepEmbedder(min(dim, 1024), **operation_settings)
         self.cap_embedder = nn.Sequential(
-            RMSNorm(cap_feat_dim, eps=norm_eps, elementwise_affine=True, **operation_settings),
+            operation_settings.get("operations").RMSNorm(cap_feat_dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")),
             operation_settings.get("operations").Linear(
                 cap_feat_dim,
                 dim,
@@ -457,7 +457,7 @@ class NextDiT(nn.Module):
                 for layer_id in range(n_layers)
             ]
         )
-        self.norm_final = RMSNorm(dim, eps=norm_eps, elementwise_affine=True, **operation_settings)
+        self.norm_final = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
         self.final_layer = FinalLayer(dim, patch_size, self.out_channels, operation_settings=operation_settings)
 
         assert (dim // n_heads) == sum(axes_dims)
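
Note (illustrative, not part of the patch): every call site above now builds RMSNorm through the `operations` namespace instead of importing the copy from comfy.ldm.modules.diffusionmodules.mmdit, so the norm layer comes from the same configurable ops classes (e.g. comfy.ops.disable_weight_init) that already supply Linear and LayerNorm. The diff assumes those classes expose an RMSNorm, presumably added in a part of the patch not shown here. A minimal sketch of how the updated get_normalization() from comfy/ldm/cosmos/blocks.py would be driven, using a hypothetical FakeOps stand-in for a comfy.ops namespace and torch.nn.RMSNorm (PyTorch >= 2.4) as the underlying implementation:

    import torch
    import torch.nn as nn

    class FakeOps:
        """Hypothetical stand-in for a comfy.ops namespace (e.g. comfy.ops.disable_weight_init)."""
        Linear = nn.Linear
        LayerNorm = nn.LayerNorm
        RMSNorm = nn.RMSNorm  # assumes PyTorch >= 2.4

    # Same logic as the updated get_normalization() in comfy/ldm/cosmos/blocks.py.
    def get_normalization(name: str, channels: int, weight_args={}, operations=None):
        if name == "I":
            return nn.Identity()
        elif name == "R":
            return operations.RMSNorm(channels, elementwise_affine=True, eps=1e-6, **weight_args)
        else:
            raise ValueError(f"Normalization {name} not found")

    # device/dtype now flow in through weight_args and land on the ops-provided RMSNorm,
    # mirroring how the Attention block forwards weight_args=weight_args, operations=operations.
    norm = get_normalization("R", 64, weight_args={"device": "cpu", "dtype": torch.float32}, operations=FakeOps)
    print(norm(torch.randn(2, 8, 64)).shape)  # torch.Size([2, 8, 64])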