Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-01-28 23:30:16 +08:00)
Merge branch 'comfyanonymous:master' into master
Commit 634c398fd5
@@ -23,7 +23,6 @@ from einops import rearrange, repeat
 from einops.layers.torch import Rearrange
 from torch import nn
 
-from comfy.ldm.modules.diffusionmodules.mmdit import RMSNorm
 from comfy.ldm.modules.attention import optimized_attention
 
 
@@ -37,11 +36,11 @@ def apply_rotary_pos_emb(
     return t_out
 
 
-def get_normalization(name: str, channels: int, weight_args={}):
+def get_normalization(name: str, channels: int, weight_args={}, operations=None):
     if name == "I":
         return nn.Identity()
     elif name == "R":
-        return RMSNorm(channels, elementwise_affine=True, eps=1e-6, **weight_args)
+        return operations.RMSNorm(channels, elementwise_affine=True, eps=1e-6, **weight_args)
     else:
         raise ValueError(f"Normalization {name} not found")
 
@@ -120,15 +119,15 @@ class Attention(nn.Module):
 
         self.to_q = nn.Sequential(
             operations.Linear(query_dim, inner_dim, bias=qkv_bias, **weight_args),
-            get_normalization(qkv_norm[0], norm_dim),
+            get_normalization(qkv_norm[0], norm_dim, weight_args=weight_args, operations=operations),
         )
         self.to_k = nn.Sequential(
             operations.Linear(context_dim, inner_dim, bias=qkv_bias, **weight_args),
-            get_normalization(qkv_norm[1], norm_dim),
+            get_normalization(qkv_norm[1], norm_dim, weight_args=weight_args, operations=operations),
         )
         self.to_v = nn.Sequential(
             operations.Linear(context_dim, inner_dim, bias=qkv_bias, **weight_args),
-            get_normalization(qkv_norm[2], norm_dim),
+            get_normalization(qkv_norm[2], norm_dim, weight_args=weight_args, operations=operations),
         )
 
         self.to_out = nn.Sequential(
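The hunks above stop hard-wiring the RMSNorm class and instead look it up on an injected operations namespace, so callers can substitute layer implementations (for example, ones that skip weight initialization or create weights in a specific dtype). Below is a minimal, self-contained sketch of that pattern; SimpleRMSNorm and default_ops are illustrative stand-ins, not ComfyUI's actual classes.

import torch
import torch.nn as nn

# SimpleRMSNorm and default_ops are stand-ins for illustration only.
class SimpleRMSNorm(nn.Module):
    def __init__(self, dim, elementwise_affine=True, eps=1e-6, device=None, dtype=None):
        super().__init__()
        self.eps = eps
        # weight_args (device/dtype) flow straight into parameter creation
        self.weight = nn.Parameter(torch.ones(dim, device=device, dtype=dtype)) if elementwise_affine else None

    def forward(self, x):
        x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return x * self.weight if self.weight is not None else x

class default_ops:
    RMSNorm = SimpleRMSNorm  # an "operations" namespace just maps layer names to classes

def get_normalization(name, channels, weight_args={}, operations=default_ops):
    if name == "I":
        return nn.Identity()
    elif name == "R":
        return operations.RMSNorm(channels, elementwise_affine=True, eps=1e-6, **weight_args)
    raise ValueError(f"Normalization {name} not found")

norm = get_normalization("R", 64, weight_args={"dtype": torch.bfloat16})
print(norm(torch.randn(2, 64, dtype=torch.bfloat16)).shape)  # torch.Size([2, 64])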
@@ -27,8 +27,6 @@ from torchvision import transforms
 from enum import Enum
 import logging
 
-from comfy.ldm.modules.diffusionmodules.mmdit import RMSNorm
-
 from .blocks import (
     FinalLayer,
     GeneralDITTransformerBlock,
@@ -195,7 +193,7 @@ class GeneralDIT(nn.Module):
 
         if self.affline_emb_norm:
             logging.debug("Building affine embedding normalization layer")
-            self.affline_norm = RMSNorm(model_channels, elementwise_affine=True, eps=1e-6)
+            self.affline_norm = operations.RMSNorm(model_channels, elementwise_affine=True, eps=1e-6, device=device, dtype=dtype)
         else:
             self.affline_norm = nn.Identity()
 
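The device=device, dtype=dtype arguments added above let the norm's weight be allocated directly on the target device in the target dtype rather than as default float32 on CPU. A rough stand-alone illustration using torch.nn.RMSNorm (torch >= 2.4) as a stand-in for operations.RMSNorm:

import torch
import torch.nn as nn

# torch.nn.RMSNorm stands in for ComfyUI's operations.RMSNorm here.
model_channels = 256
dtype, device = torch.bfloat16, "cpu"

affline_norm = nn.RMSNorm(model_channels, eps=1e-6, elementwise_affine=True,
                          device=device, dtype=dtype)
print(affline_norm.weight.dtype, affline_norm.weight.device)  # torch.bfloat16 cpu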
@@ -3,7 +3,7 @@ import torch
 import torch.nn as nn
 
 import comfy.ops
-from comfy.ldm.modules.diffusionmodules.mmdit import Mlp, TimestepEmbedder, PatchEmbed, RMSNorm
+from comfy.ldm.modules.diffusionmodules.mmdit import Mlp, TimestepEmbedder, PatchEmbed
 from comfy.ldm.modules.diffusionmodules.util import timestep_embedding
 from torch.utils import checkpoint
 
@@ -51,7 +51,7 @@ class HunYuanDiTBlock(nn.Module):
         if norm_type == "layer":
             norm_layer = operations.LayerNorm
         elif norm_type == "rms":
-            norm_layer = RMSNorm
+            norm_layer = operations.RMSNorm
         else:
             raise ValueError(f"Unknown norm_type: {norm_type}")
 
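In the HunYuanDiTBlock hunk, norm_layer is a class reference picked from the operations namespace and instantiated later, so the "layer" and "rms" branches construct their norm the same way. A hedged sketch with plain torch classes standing in for ComfyUI's operations object:

import torch
import torch.nn as nn

class ops:  # illustrative stand-in, not comfy.ops
    LayerNorm = nn.LayerNorm
    RMSNorm = nn.RMSNorm  # torch >= 2.4

def pick_norm(norm_type, hidden_size=128):
    if norm_type == "layer":
        norm_layer = ops.LayerNorm
    elif norm_type == "rms":
        norm_layer = ops.RMSNorm
    else:
        raise ValueError(f"Unknown norm_type: {norm_type}")
    # the class is only instantiated here, with the block's hidden size
    return norm_layer(hidden_size, eps=1e-6)

print(pick_norm("rms")(torch.randn(2, 128)).shape)  # torch.Size([2, 128])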
@@ -8,7 +8,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 import comfy.ldm.common_dit
 
-from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder, RMSNorm
+from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder
 from comfy.ldm.modules.attention import optimized_attention_masked
 from comfy.ldm.flux.layers import EmbedND
 
@@ -64,8 +64,8 @@ class JointAttention(nn.Module):
         )
 
         if qk_norm:
-            self.q_norm = RMSNorm(self.head_dim, elementwise_affine=True, **operation_settings)
-            self.k_norm = RMSNorm(self.head_dim, elementwise_affine=True, **operation_settings)
+            self.q_norm = operation_settings.get("operations").RMSNorm(self.head_dim, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+            self.k_norm = operation_settings.get("operations").RMSNorm(self.head_dim, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
         else:
             self.q_norm = self.k_norm = nn.Identity()
 
@@ -242,11 +242,11 @@ class JointTransformerBlock(nn.Module):
             operation_settings=operation_settings,
         )
         self.layer_id = layer_id
-        self.attention_norm1 = RMSNorm(dim, eps=norm_eps, elementwise_affine=True, **operation_settings)
-        self.ffn_norm1 = RMSNorm(dim, eps=norm_eps, elementwise_affine=True, **operation_settings)
+        self.attention_norm1 = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+        self.ffn_norm1 = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
 
-        self.attention_norm2 = RMSNorm(dim, eps=norm_eps, elementwise_affine=True, **operation_settings)
-        self.ffn_norm2 = RMSNorm(dim, eps=norm_eps, elementwise_affine=True, **operation_settings)
+        self.attention_norm2 = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+        self.ffn_norm2 = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
 
         self.modulation = modulation
         if modulation:
@@ -431,7 +431,7 @@ class NextDiT(nn.Module):
 
         self.t_embedder = TimestepEmbedder(min(dim, 1024), **operation_settings)
         self.cap_embedder = nn.Sequential(
-            RMSNorm(cap_feat_dim, eps=norm_eps, elementwise_affine=True, **operation_settings),
+            operation_settings.get("operations").RMSNorm(cap_feat_dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")),
             operation_settings.get("operations").Linear(
                 cap_feat_dim,
                 dim,
@@ -457,7 +457,7 @@ class NextDiT(nn.Module):
                 for layer_id in range(n_layers)
             ]
         )
-        self.norm_final = RMSNorm(dim, eps=norm_eps, elementwise_affine=True, **operation_settings)
+        self.norm_final = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
         self.final_layer = FinalLayer(dim, patch_size, self.out_channels, operation_settings=operation_settings)
 
         assert (dim // n_heads) == sum(axes_dims)
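The Lumina hunks follow a slightly different convention: an operation_settings dict bundles the operations namespace together with the target device and dtype, and each call site pulls out the pieces it needs. A minimal sketch of that shape, with torch.nn.RMSNorm (torch >= 2.4) standing in for the real operations.RMSNorm and all names illustrative:

import torch
import torch.nn as nn

class patched_ops:  # stand-in for ComfyUI's operations namespace
    RMSNorm = nn.RMSNorm
    Linear = nn.Linear

# one dict carries the layer classes plus placement info
operation_settings = {"operations": patched_ops, "device": "cpu", "dtype": torch.float32}

norm_final = operation_settings.get("operations").RMSNorm(
    128,
    eps=1e-6,
    elementwise_affine=True,
    device=operation_settings.get("device"),
    dtype=operation_settings.get("dtype"),
)
print(norm_final(torch.randn(4, 128)).shape)  # torch.Size([4, 128])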