Merge branch 'comfyanonymous:master' into master

patientx 2024-10-29 10:26:41 +03:00 committed by GitHub
commit 587b27ff26
2 changed files with 153 additions and 39 deletions

Changed file 1 of 2:

@@ -1,6 +1,6 @@
 import logging
 import math
-from typing import Dict, Optional
+from typing import Dict, Optional, List

 import numpy as np
 import torch
@ -415,6 +415,7 @@ class DismantledBlock(nn.Module):
scale_mod_only: bool = False, scale_mod_only: bool = False,
swiglu: bool = False, swiglu: bool = False,
qk_norm: Optional[str] = None, qk_norm: Optional[str] = None,
x_block_self_attn: bool = False,
dtype=None, dtype=None,
device=None, device=None,
operations=None, operations=None,
@@ -438,6 +439,24 @@
             device=device,
             operations=operations
         )
+        if x_block_self_attn:
+            assert not pre_only
+            assert not scale_mod_only
+            self.x_block_self_attn = True
+            self.attn2 = SelfAttention(
+                dim=hidden_size,
+                num_heads=num_heads,
+                qkv_bias=qkv_bias,
+                attn_mode=attn_mode,
+                pre_only=False,
+                qk_norm=qk_norm,
+                rmsnorm=rmsnorm,
+                dtype=dtype,
+                device=device,
+                operations=operations
+            )
+        else:
+            self.x_block_self_attn = False
         if not pre_only:
             if not rmsnorm:
                 self.norm2 = operations.LayerNorm(
@@ -464,7 +483,11 @@
                     multiple_of=256,
                 )
         self.scale_mod_only = scale_mod_only
-        if not scale_mod_only:
+        if x_block_self_attn:
+            assert not pre_only
+            assert not scale_mod_only
+            n_mods = 9
+        elif not scale_mod_only:
             n_mods = 6 if not pre_only else 2
         else:
             n_mods = 4 if not pre_only else 1
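The n_mods = 9 case above corresponds to three (shift, scale, gate) triples: one for the original attention branch, one for the added attn2 branch, and one for the MLP. A minimal sketch of the resulting modulation split, assuming adaLN_modulation projects the conditioning to n_mods * hidden_size as in the surrounding block code (sizes below are illustrative, not taken from the diff):

import torch

hidden_size, n_mods = 64, 9                     # illustrative sizes
c_mod = torch.randn(2, n_mods * hidden_size)    # stand-in for adaLN_modulation(c)
chunks = c_mod.chunk(n_mods, dim=1)             # the 9-way split consumed by pre_attention_x
assert len(chunks) == 9
assert chunks[0].shape == (2, hidden_size)      # each chunk is one shift/scale/gate tensor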
@@ -525,14 +548,64 @@
         )
         return x

+    def pre_attention_x(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
+        assert self.x_block_self_attn
+        (
+            shift_msa,
+            scale_msa,
+            gate_msa,
+            shift_mlp,
+            scale_mlp,
+            gate_mlp,
+            shift_msa2,
+            scale_msa2,
+            gate_msa2,
+        ) = self.adaLN_modulation(c).chunk(9, dim=1)
+        x_norm = self.norm1(x)
+        qkv = self.attn.pre_attention(modulate(x_norm, shift_msa, scale_msa))
+        qkv2 = self.attn2.pre_attention(modulate(x_norm, shift_msa2, scale_msa2))
+        return qkv, qkv2, (
+            x,
+            gate_msa,
+            shift_mlp,
+            scale_mlp,
+            gate_mlp,
+            gate_msa2,
+        )
+
+    def post_attention_x(self, attn, attn2, x, gate_msa, shift_mlp, scale_mlp, gate_mlp, gate_msa2):
+        assert not self.pre_only
+        attn1 = self.attn.post_attention(attn)
+        attn2 = self.attn2.post_attention(attn2)
+        out1 = gate_msa.unsqueeze(1) * attn1
+        out2 = gate_msa2.unsqueeze(1) * attn2
+        x = x + out1
+        x = x + out2
+        x = x + gate_mlp.unsqueeze(1) * self.mlp(
+            modulate(self.norm2(x), shift_mlp, scale_mlp)
+        )
+        return x
+
     def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
         assert not self.pre_only
-        qkv, intermediates = self.pre_attention(x, c)
-        attn = optimized_attention(
-            qkv[0], qkv[1], qkv[2],
-            heads=self.attn.num_heads,
-        )
-        return self.post_attention(attn, *intermediates)
+        if self.x_block_self_attn:
+            qkv, qkv2, intermediates = self.pre_attention_x(x, c)
+            attn, _ = optimized_attention(
+                qkv[0], qkv[1], qkv[2],
+                num_heads=self.attn.num_heads,
+            )
+            attn2, _ = optimized_attention(
+                qkv2[0], qkv2[1], qkv2[2],
+                num_heads=self.attn2.num_heads,
+            )
+            return self.post_attention_x(attn, attn2, *intermediates)
+        else:
+            qkv, intermediates = self.pre_attention(x, c)
+            attn = optimized_attention(
+                qkv[0], qkv[1], qkv[2],
+                heads=self.attn.num_heads,
+            )
+            return self.post_attention(attn, *intermediates)


 def block_mixing(*args, use_checkpoint=True, **kwargs):
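For reference, a toy sketch of the residual combine that post_attention_x performs above: both attention branches are gated separately before the (unchanged) gated MLP residual. Shapes and values are made up for illustration; attn1 and attn2 stand in for the projected outputs of attn and attn2.

import torch

batch, tokens, hidden = 2, 16, 64             # illustrative shapes
x = torch.randn(batch, tokens, hidden)
attn1 = torch.randn(batch, tokens, hidden)    # stand-in for self.attn.post_attention(attn)
attn2 = torch.randn(batch, tokens, hidden)    # stand-in for self.attn2.post_attention(attn2)
gate_msa = torch.randn(batch, hidden)
gate_msa2 = torch.randn(batch, hidden)

x = x + gate_msa.unsqueeze(1) * attn1         # original attention residual
x = x + gate_msa2.unsqueeze(1) * attn2        # added x-only self-attention residual
# ...followed by the gated MLP residual, exactly as in post_attention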
@@ -547,7 +620,10 @@ def block_mixing(*args, use_checkpoint=True, **kwargs):

 def _block_mixing(context, x, context_block, x_block, c):
     context_qkv, context_intermediates = context_block.pre_attention(context, c)
-    x_qkv, x_intermediates = x_block.pre_attention(x, c)
+    if x_block.x_block_self_attn:
+        x_qkv, x_qkv2, x_intermediates = x_block.pre_attention_x(x, c)
+    else:
+        x_qkv, x_intermediates = x_block.pre_attention(x, c)

     o = []
     for t in range(3):
@@ -568,7 +644,14 @@ def _block_mixing(context, x, context_block, x_block, c):
     else:
         context = None

-    x = x_block.post_attention(x_attn, *x_intermediates)
+    if x_block.x_block_self_attn:
+        attn2 = optimized_attention(
+            x_qkv2[0], x_qkv2[1], x_qkv2[2],
+            heads=x_block.attn2.num_heads,
+        )
+        x = x_block.post_attention_x(x_attn, attn2, *x_intermediates)
+    else:
+        x = x_block.post_attention(x_attn, *x_intermediates)

     return context, x
@@ -583,8 +666,13 @@ class JointBlock(nn.Module):
         super().__init__()
         pre_only = kwargs.pop("pre_only")
         qk_norm = kwargs.pop("qk_norm", None)
+        x_block_self_attn = kwargs.pop("x_block_self_attn", False)
         self.context_block = DismantledBlock(*args, pre_only=pre_only, qk_norm=qk_norm, **kwargs)
-        self.x_block = DismantledBlock(*args, pre_only=False, qk_norm=qk_norm, **kwargs)
+        self.x_block = DismantledBlock(*args,
+                                       pre_only=False,
+                                       qk_norm=qk_norm,
+                                       x_block_self_attn=x_block_self_attn,
+                                       **kwargs)

     def forward(self, *args, **kwargs):
         return block_mixing(
@@ -699,9 +787,12 @@ class MMDiT(nn.Module):
         qk_norm: Optional[str] = None,
         qkv_bias: bool = True,
         context_processor_layers = None,
+        x_block_self_attn: bool = False,
+        x_block_self_attn_layers: Optional[List[int]] = [],
         context_size = 4096,
         num_blocks = None,
         final_layer = True,
+        skip_blocks = False,
         dtype = None, #TODO
         device = None,
         operations = None,
@@ -716,6 +807,7 @@
         self.pos_embed_scaling_factor = pos_embed_scaling_factor
         self.pos_embed_offset = pos_embed_offset
         self.pos_embed_max_size = pos_embed_max_size
+        self.x_block_self_attn_layers = x_block_self_attn_layers

         # hidden_size = default(hidden_size, 64 * depth)
         # num_heads = default(num_heads, hidden_size // 64)
@@ -773,26 +865,28 @@
             self.pos_embed = None

         self.use_checkpoint = use_checkpoint

-        self.joint_blocks = nn.ModuleList(
-            [
-                JointBlock(
-                    self.hidden_size,
-                    num_heads,
-                    mlp_ratio=mlp_ratio,
-                    qkv_bias=qkv_bias,
-                    attn_mode=attn_mode,
-                    pre_only=(i == num_blocks - 1) and final_layer,
-                    rmsnorm=rmsnorm,
-                    scale_mod_only=scale_mod_only,
-                    swiglu=swiglu,
-                    qk_norm=qk_norm,
-                    dtype=dtype,
-                    device=device,
-                    operations=operations
-                )
-                for i in range(num_blocks)
-            ]
-        )
+        if not skip_blocks:
+            self.joint_blocks = nn.ModuleList(
+                [
+                    JointBlock(
+                        self.hidden_size,
+                        num_heads,
+                        mlp_ratio=mlp_ratio,
+                        qkv_bias=qkv_bias,
+                        attn_mode=attn_mode,
+                        pre_only=(i == num_blocks - 1) and final_layer,
+                        rmsnorm=rmsnorm,
+                        scale_mod_only=scale_mod_only,
+                        swiglu=swiglu,
+                        qk_norm=qk_norm,
+                        x_block_self_attn=(i in self.x_block_self_attn_layers) or x_block_self_attn,
+                        dtype=dtype,
+                        device=device,
+                        operations=operations,
+                    )
+                    for i in range(num_blocks)
+                ]
+            )

         if final_layer:
             self.final_layer = FinalLayer(self.hidden_size, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations)
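A small sketch of the per-block selection expression used above: a block gets the extra self-attention either because its index is listed in x_block_self_attn_layers or because the global x_block_self_attn flag is set. The values below are illustrative, not taken from any real checkpoint:

num_blocks = 24                          # illustrative
x_block_self_attn = False
x_block_self_attn_layers = [0, 1, 2]     # hypothetical: enable attn2 only on the first three blocks

enabled = [(i in x_block_self_attn_layers) or x_block_self_attn for i in range(num_blocks)]
print([i for i, on in enumerate(enabled) if on])   # -> [0, 1, 2]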
@@ -855,7 +949,9 @@
         c_mod: torch.Tensor,
         context: Optional[torch.Tensor] = None,
         control = None,
+        transformer_options = {},
     ) -> torch.Tensor:
+        patches_replace = transformer_options.get("patches_replace", {})
         if self.register_length > 0:
             context = torch.cat(
                 (
@@ -867,14 +963,25 @@
         # context is B, L', D
         # x is B, L, D

+        blocks_replace = patches_replace.get("dit", {})
         blocks = len(self.joint_blocks)
         for i in range(blocks):
-            context, x = self.joint_blocks[i](
-                context,
-                x,
-                c=c_mod,
-                use_checkpoint=self.use_checkpoint,
-            )
+            if ("double_block", i) in blocks_replace:
+                def block_wrap(args):
+                    out = {}
+                    out["txt"], out["img"] = self.joint_blocks[i](args["txt"], args["img"], c=args["vec"])
+                    return out
+
+                out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": c_mod}, {"original_block": block_wrap})
+                context = out["txt"]
+                x = out["img"]
+            else:
+                context, x = self.joint_blocks[i](
+                    context,
+                    x,
+                    c=c_mod,
+                    use_checkpoint=self.use_checkpoint,
+                )
             if control is not None:
                 control_o = control.get("output")
                 if i < len(control_o):
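The ("double_block", i) hook above follows the calling convention visible in the diff: a registered patch receives a dict with "img", "txt" and "vec" plus an "original_block" callback, and must return a dict with "txt" and "img". A hypothetical patch could therefore look like the sketch below (the function name and the block index 3 are made up for illustration):

def my_double_block_patch(args, extra):
    # args: {"img": x, "txt": context, "vec": c_mod}; extra["original_block"] runs the original block
    out = extra["original_block"](args)
    # ...inspect or modify out["img"] / out["txt"] here...
    return out

transformer_options = {
    "patches_replace": {
        "dit": {("double_block", 3): my_double_block_patch},   # hypothetical block index
    },
}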
@@ -892,6 +999,7 @@
         y: Optional[torch.Tensor] = None,
         context: Optional[torch.Tensor] = None,
         control = None,
+        transformer_options = {},
     ) -> torch.Tensor:
         """
         Forward pass of DiT.
@@ -913,7 +1021,7 @@
         if context is not None:
             context = self.context_embedder(context)

-        x = self.forward_core_with_concat(x, c, context, control)
+        x = self.forward_core_with_concat(x, c, context, control, transformer_options)

         x = self.unpatchify(x, hw=hw)  # (N, out_channels, H, W)
         return x[:,:,:hw[-2],:hw[-1]]
@@ -927,7 +1035,8 @@ class OpenAISignatureMMDITWrapper(MMDiT):
         context: Optional[torch.Tensor] = None,
         y: Optional[torch.Tensor] = None,
         control = None,
+        transformer_options = {},
         **kwargs,
     ) -> torch.Tensor:
-        return super().forward(x, timesteps, context=context, y=y, control=control)
+        return super().forward(x, timesteps, context=context, y=y, control=control, transformer_options=transformer_options)

Changed file 2 of 2:

@@ -70,6 +70,11 @@ def detect_unet_config(state_dict, key_prefix):
         context_processor = '{}context_processor.layers.0.attn.qkv.weight'.format(key_prefix)
         if context_processor in state_dict_keys:
             unet_config["context_processor_layers"] = count_blocks(state_dict_keys, '{}context_processor.layers.'.format(key_prefix) + '{}.')
+        unet_config["x_block_self_attn_layers"] = []
+        for key in state_dict_keys:
+            if key.startswith('{}joint_blocks.'.format(key_prefix)) and key.endswith('.x_block.attn2.qkv.weight'):
+                layer = key[len('{}joint_blocks.'.format(key_prefix)):-len('.x_block.attn2.qkv.weight')]
+                unet_config["x_block_self_attn_layers"].append(int(layer))
         return unet_config

     if '{}clf.1.weight'.format(key_prefix) in state_dict_keys: #stable cascade
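A toy run of the key scan added above, showing how the layer indices are recovered from parameter names. The prefix and keys below are hypothetical examples that merely follow the '{prefix}joint_blocks.N.x_block.attn2.qkv.weight' pattern the code matches:

key_prefix = 'model.diffusion_model.'                 # hypothetical prefix
state_dict_keys = [                                   # hypothetical keys
    'model.diffusion_model.joint_blocks.0.x_block.attn2.qkv.weight',
    'model.diffusion_model.joint_blocks.1.x_block.attn.qkv.weight',    # no attn2 -> ignored
    'model.diffusion_model.joint_blocks.12.x_block.attn2.qkv.weight',
]

layers = []
for key in state_dict_keys:
    if key.startswith('{}joint_blocks.'.format(key_prefix)) and key.endswith('.x_block.attn2.qkv.weight'):
        layer = key[len('{}joint_blocks.'.format(key_prefix)):-len('.x_block.attn2.qkv.weight')]
        layers.append(int(layer))

print(layers)   # -> [0, 12]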