Merge branch 'comfyanonymous:master' into master

Author: patientx
Date: 2025-05-04 16:14:44 +03:00
Committed by: GitHub
Commit: 634c398fd5
4 changed files with 17 additions and 20 deletions

View File

@@ -23,7 +23,6 @@ from einops import rearrange, repeat
 from einops.layers.torch import Rearrange
 from torch import nn
-from comfy.ldm.modules.diffusionmodules.mmdit import RMSNorm
 from comfy.ldm.modules.attention import optimized_attention
@@ -37,11 +36,11 @@ def apply_rotary_pos_emb(
     return t_out

-def get_normalization(name: str, channels: int, weight_args={}):
+def get_normalization(name: str, channels: int, weight_args={}, operations=None):
     if name == "I":
         return nn.Identity()
     elif name == "R":
-        return RMSNorm(channels, elementwise_affine=True, eps=1e-6, **weight_args)
+        return operations.RMSNorm(channels, elementwise_affine=True, eps=1e-6, **weight_args)
     else:
         raise ValueError(f"Normalization {name} not found")
@@ -120,15 +119,15 @@ class Attention(nn.Module):
         self.to_q = nn.Sequential(
             operations.Linear(query_dim, inner_dim, bias=qkv_bias, **weight_args),
-            get_normalization(qkv_norm[0], norm_dim),
+            get_normalization(qkv_norm[0], norm_dim, weight_args=weight_args, operations=operations),
         )
         self.to_k = nn.Sequential(
             operations.Linear(context_dim, inner_dim, bias=qkv_bias, **weight_args),
-            get_normalization(qkv_norm[1], norm_dim),
+            get_normalization(qkv_norm[1], norm_dim, weight_args=weight_args, operations=operations),
         )
         self.to_v = nn.Sequential(
             operations.Linear(context_dim, inner_dim, bias=qkv_bias, **weight_args),
-            get_normalization(qkv_norm[2], norm_dim),
+            get_normalization(qkv_norm[2], norm_dim, weight_args=weight_args, operations=operations),
         )
         self.to_out = nn.Sequential(
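Every hunk in this file makes the same change: the module-level RMSNorm import is dropped and the caller's `operations` namespace supplies the class instead, so the normalization layer is built by whichever ops set the model was loaded with. Below is a minimal, self-contained sketch of that dependency-injection pattern; `SimpleOps` and the `RMSNorm` class here are illustrative stand-ins, not ComfyUI's actual `comfy.ops` implementations.

import torch
from torch import nn

class RMSNorm(nn.Module):
    # Stand-in RMSNorm for illustration; comfy.ops ships its own version.
    def __init__(self, dim, eps=1e-6, elementwise_affine=True, device=None, dtype=None):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim, device=device, dtype=dtype)) if elementwise_affine else None

    def forward(self, x):
        out = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return out * self.weight if self.weight is not None else out

class SimpleOps:
    # Toy ops namespace; a real ops set also wraps Linear, LayerNorm, etc.
    Linear = nn.Linear
    RMSNorm = RMSNorm

def get_normalization(name: str, channels: int, weight_args={}, operations=None):
    # Mirrors the patched helper above: the injected ops namespace decides
    # which RMSNorm implementation is instantiated.
    if name == "I":
        return nn.Identity()
    elif name == "R":
        return operations.RMSNorm(channels, elementwise_affine=True, eps=1e-6, **weight_args)
    else:
        raise ValueError(f"Normalization {name} not found")

norm = get_normalization("R", 64, weight_args={"dtype": torch.float32}, operations=SimpleOps)
print(norm(torch.randn(2, 64)).shape)  # torch.Size([2, 64])

With a different ops set (for example one that skips weight init or casts weights lazily), the same call site picks up that behavior without touching the block code.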

View File

@@ -27,8 +27,6 @@ from torchvision import transforms
 from enum import Enum
 import logging

-from comfy.ldm.modules.diffusionmodules.mmdit import RMSNorm
-
 from .blocks import (
     FinalLayer,
     GeneralDITTransformerBlock,
@@ -195,7 +193,7 @@ class GeneralDIT(nn.Module):
         if self.affline_emb_norm:
             logging.debug("Building affine embedding normalization layer")
-            self.affline_norm = RMSNorm(model_channels, elementwise_affine=True, eps=1e-6)
+            self.affline_norm = operations.RMSNorm(model_channels, elementwise_affine=True, eps=1e-6, device=device, dtype=dtype)
         else:
             self.affline_norm = nn.Identity()

View File

@@ -3,7 +3,7 @@ import torch
 import torch.nn as nn
 import comfy.ops
-from comfy.ldm.modules.diffusionmodules.mmdit import Mlp, TimestepEmbedder, PatchEmbed, RMSNorm
+from comfy.ldm.modules.diffusionmodules.mmdit import Mlp, TimestepEmbedder, PatchEmbed
 from comfy.ldm.modules.diffusionmodules.util import timestep_embedding
 from torch.utils import checkpoint
@@ -51,7 +51,7 @@ class HunYuanDiTBlock(nn.Module):
         if norm_type == "layer":
             norm_layer = operations.LayerNorm
         elif norm_type == "rms":
-            norm_layer = RMSNorm
+            norm_layer = operations.RMSNorm
         else:
             raise ValueError(f"Unknown norm_type: {norm_type}")

View File

@@ -8,7 +8,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 import comfy.ldm.common_dit
-from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder, RMSNorm
+from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder
 from comfy.ldm.modules.attention import optimized_attention_masked
 from comfy.ldm.flux.layers import EmbedND
@@ -64,8 +64,8 @@ class JointAttention(nn.Module):
         )
         if qk_norm:
-            self.q_norm = RMSNorm(self.head_dim, elementwise_affine=True, **operation_settings)
-            self.k_norm = RMSNorm(self.head_dim, elementwise_affine=True, **operation_settings)
+            self.q_norm = operation_settings.get("operations").RMSNorm(self.head_dim, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+            self.k_norm = operation_settings.get("operations").RMSNorm(self.head_dim, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
         else:
             self.q_norm = self.k_norm = nn.Identity()
@@ -242,11 +242,11 @@ class JointTransformerBlock(nn.Module):
             operation_settings=operation_settings,
         )
         self.layer_id = layer_id
-        self.attention_norm1 = RMSNorm(dim, eps=norm_eps, elementwise_affine=True, **operation_settings)
-        self.ffn_norm1 = RMSNorm(dim, eps=norm_eps, elementwise_affine=True, **operation_settings)
-        self.attention_norm2 = RMSNorm(dim, eps=norm_eps, elementwise_affine=True, **operation_settings)
-        self.ffn_norm2 = RMSNorm(dim, eps=norm_eps, elementwise_affine=True, **operation_settings)
+        self.attention_norm1 = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+        self.ffn_norm1 = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+        self.attention_norm2 = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+        self.ffn_norm2 = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
         self.modulation = modulation
         if modulation:
@@ -431,7 +431,7 @@ class NextDiT(nn.Module):
         self.t_embedder = TimestepEmbedder(min(dim, 1024), **operation_settings)
         self.cap_embedder = nn.Sequential(
-            RMSNorm(cap_feat_dim, eps=norm_eps, elementwise_affine=True, **operation_settings),
+            operation_settings.get("operations").RMSNorm(cap_feat_dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")),
             operation_settings.get("operations").Linear(
                 cap_feat_dim,
                 dim,
@@ -457,7 +457,7 @@ class NextDiT(nn.Module):
                 for layer_id in range(n_layers)
             ]
         )
-        self.norm_final = RMSNorm(dim, eps=norm_eps, elementwise_affine=True, **operation_settings)
+        self.norm_final = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
         self.final_layer = FinalLayer(dim, patch_size, self.out_channels, operation_settings=operation_settings)
         assert (dim // n_heads) == sum(axes_dims)
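The hunks in this last file use a slightly different convention: an `operation_settings` dict carries the ops namespace together with the target device and dtype, and each RMSNorm is fetched from `operation_settings.get("operations")` with device/dtype forwarded explicitly instead of the class being imported at module level. A short illustrative sketch of that convention follows; `ToyOps` is a stand-in for an ops set such as comfy.ops, and `torch.nn.RMSNorm` (available in recent PyTorch releases) stands in for the project's own implementation.

import torch
from torch import nn

class ToyOps:
    # Stand-in ops namespace; a real ops set would also provide Linear,
    # LayerNorm, Conv layers, etc., with weight-casting behavior.
    Linear = nn.Linear
    RMSNorm = nn.RMSNorm  # requires a PyTorch version that ships nn.RMSNorm

operation_settings = {"operations": ToyOps, "dtype": torch.float32, "device": None}

# The construction pattern used throughout the hunks above: pull the class
# from the ops namespace and forward device/dtype explicitly.
norm_final = operation_settings.get("operations").RMSNorm(
    64,
    eps=1e-5,
    elementwise_affine=True,
    device=operation_settings.get("device"),
    dtype=operation_settings.get("dtype"),
)
print(norm_final(torch.randn(2, 16, 64)).shape)  # torch.Size([2, 16, 64])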