mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-12-18 02:23:06 +08:00
Merge branch 'master' into dr-support-pip-cm
This commit is contained in:
commit
ce1df28bef
@ -11,7 +11,13 @@ class AudioEncoderModel():
|
||||
self.load_device = comfy.model_management.text_encoder_device()
|
||||
offload_device = comfy.model_management.text_encoder_offload_device()
|
||||
self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
|
||||
self.model = Wav2Vec2Model(dtype=self.dtype, device=offload_device, operations=comfy.ops.manual_cast)
|
||||
model_config = dict(config)
|
||||
model_config.update({
|
||||
"dtype": self.dtype,
|
||||
"device": offload_device,
|
||||
"operations": comfy.ops.manual_cast
|
||||
})
|
||||
self.model = Wav2Vec2Model(**model_config)
|
||||
self.model.eval()
|
||||
self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
|
||||
self.model_sample_rate = 16000
|
||||
@ -33,8 +39,32 @@ class AudioEncoderModel():
|
||||
|
||||
|
||||
def load_audio_encoder_from_sd(sd, prefix=""):
|
||||
audio_encoder = AudioEncoderModel(None)
|
||||
sd = comfy.utils.state_dict_prefix_replace(sd, {"wav2vec2.": ""})
|
||||
embed_dim = sd["encoder.layer_norm.bias"].shape[0]
|
||||
if embed_dim == 1024:# large
|
||||
config = {
|
||||
"embed_dim": 1024,
|
||||
"num_heads": 16,
|
||||
"num_layers": 24,
|
||||
"conv_norm": True,
|
||||
"conv_bias": True,
|
||||
"do_normalize": True,
|
||||
"do_stable_layer_norm": True
|
||||
}
|
||||
elif embed_dim == 768: # base
|
||||
config = {
|
||||
"embed_dim": 768,
|
||||
"num_heads": 12,
|
||||
"num_layers": 12,
|
||||
"conv_norm": False,
|
||||
"conv_bias": False,
|
||||
"do_normalize": False, # chinese-wav2vec2-base has this False
|
||||
"do_stable_layer_norm": False
|
||||
}
|
||||
else:
|
||||
raise RuntimeError("ERROR: audio encoder file is invalid or unsupported embed_dim: {}".format(embed_dim))
|
||||
|
||||
audio_encoder = AudioEncoderModel(config)
|
||||
m, u = audio_encoder.load_sd(sd)
|
||||
if len(m) > 0:
|
||||
logging.warning("missing audio encoder: {}".format(m))
|
||||
|
||||
@ -13,19 +13,49 @@ class LayerNormConv(nn.Module):
|
||||
x = self.conv(x)
|
||||
return torch.nn.functional.gelu(self.layer_norm(x.transpose(-2, -1)).transpose(-2, -1))
|
||||
|
||||
class LayerGroupNormConv(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, kernel_size, stride, bias=False, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.conv = operations.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, bias=bias, device=device, dtype=dtype)
|
||||
self.layer_norm = operations.GroupNorm(num_groups=out_channels, num_channels=out_channels, affine=True, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
return torch.nn.functional.gelu(self.layer_norm(x))
|
||||
|
||||
class ConvNoNorm(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, kernel_size, stride, bias=False, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.conv = operations.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, bias=bias, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
return torch.nn.functional.gelu(x)
|
||||
|
||||
|
||||
class ConvFeatureEncoder(nn.Module):
|
||||
def __init__(self, conv_dim, dtype=None, device=None, operations=None):
|
||||
def __init__(self, conv_dim, conv_bias=False, conv_norm=True, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.conv_layers = nn.ModuleList([
|
||||
LayerNormConv(1, conv_dim, kernel_size=10, stride=5, bias=True, device=device, dtype=dtype, operations=operations),
|
||||
LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=True, device=device, dtype=dtype, operations=operations),
|
||||
LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=True, device=device, dtype=dtype, operations=operations),
|
||||
LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=True, device=device, dtype=dtype, operations=operations),
|
||||
LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=True, device=device, dtype=dtype, operations=operations),
|
||||
LayerNormConv(conv_dim, conv_dim, kernel_size=2, stride=2, bias=True, device=device, dtype=dtype, operations=operations),
|
||||
LayerNormConv(conv_dim, conv_dim, kernel_size=2, stride=2, bias=True, device=device, dtype=dtype, operations=operations),
|
||||
])
|
||||
if conv_norm:
|
||||
self.conv_layers = nn.ModuleList([
|
||||
LayerNormConv(1, conv_dim, kernel_size=10, stride=5, bias=True, device=device, dtype=dtype, operations=operations),
|
||||
LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
|
||||
LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
|
||||
LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
|
||||
LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
|
||||
LayerNormConv(conv_dim, conv_dim, kernel_size=2, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
|
||||
LayerNormConv(conv_dim, conv_dim, kernel_size=2, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
|
||||
])
|
||||
else:
|
||||
self.conv_layers = nn.ModuleList([
|
||||
LayerGroupNormConv(1, conv_dim, kernel_size=10, stride=5, bias=conv_bias, device=device, dtype=dtype, operations=operations),
|
||||
ConvNoNorm(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
|
||||
ConvNoNorm(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
|
||||
ConvNoNorm(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
|
||||
ConvNoNorm(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
|
||||
ConvNoNorm(conv_dim, conv_dim, kernel_size=2, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
|
||||
ConvNoNorm(conv_dim, conv_dim, kernel_size=2, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
|
||||
])
|
||||
|
||||
def forward(self, x):
|
||||
x = x.unsqueeze(1)
|
||||
@ -76,6 +106,7 @@ class TransformerEncoder(nn.Module):
|
||||
num_heads=12,
|
||||
num_layers=12,
|
||||
mlp_ratio=4.0,
|
||||
do_stable_layer_norm=True,
|
||||
dtype=None, device=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
@ -86,20 +117,25 @@ class TransformerEncoder(nn.Module):
|
||||
embed_dim=embed_dim,
|
||||
num_heads=num_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
do_stable_layer_norm=do_stable_layer_norm,
|
||||
device=device, dtype=dtype, operations=operations
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
])
|
||||
|
||||
self.layer_norm = operations.LayerNorm(embed_dim, eps=1e-05, device=device, dtype=dtype)
|
||||
self.do_stable_layer_norm = do_stable_layer_norm
|
||||
|
||||
def forward(self, x, mask=None):
|
||||
x = x + self.pos_conv_embed(x)
|
||||
all_x = ()
|
||||
if not self.do_stable_layer_norm:
|
||||
x = self.layer_norm(x)
|
||||
for layer in self.layers:
|
||||
all_x += (x,)
|
||||
x = layer(x, mask)
|
||||
x = self.layer_norm(x)
|
||||
if self.do_stable_layer_norm:
|
||||
x = self.layer_norm(x)
|
||||
all_x += (x,)
|
||||
return x, all_x
|
||||
|
||||
@ -145,6 +181,7 @@ class TransformerEncoderLayer(nn.Module):
|
||||
embed_dim=768,
|
||||
num_heads=12,
|
||||
mlp_ratio=4.0,
|
||||
do_stable_layer_norm=True,
|
||||
dtype=None, device=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
@ -154,15 +191,19 @@ class TransformerEncoderLayer(nn.Module):
|
||||
self.layer_norm = operations.LayerNorm(embed_dim, device=device, dtype=dtype)
|
||||
self.feed_forward = FeedForward(embed_dim, mlp_ratio, device=device, dtype=dtype, operations=operations)
|
||||
self.final_layer_norm = operations.LayerNorm(embed_dim, device=device, dtype=dtype)
|
||||
self.do_stable_layer_norm = do_stable_layer_norm
|
||||
|
||||
def forward(self, x, mask=None):
|
||||
residual = x
|
||||
x = self.layer_norm(x)
|
||||
if self.do_stable_layer_norm:
|
||||
x = self.layer_norm(x)
|
||||
x = self.attention(x, mask=mask)
|
||||
x = residual + x
|
||||
|
||||
x = x + self.feed_forward(self.final_layer_norm(x))
|
||||
return x
|
||||
if not self.do_stable_layer_norm:
|
||||
x = self.layer_norm(x)
|
||||
return self.final_layer_norm(x + self.feed_forward(x))
|
||||
else:
|
||||
return x + self.feed_forward(self.final_layer_norm(x))
|
||||
|
||||
|
||||
class Wav2Vec2Model(nn.Module):
|
||||
@ -174,34 +215,38 @@ class Wav2Vec2Model(nn.Module):
|
||||
final_dim=256,
|
||||
num_heads=16,
|
||||
num_layers=24,
|
||||
conv_norm=True,
|
||||
conv_bias=True,
|
||||
do_normalize=True,
|
||||
do_stable_layer_norm=True,
|
||||
dtype=None, device=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
conv_dim = 512
|
||||
self.feature_extractor = ConvFeatureEncoder(conv_dim, device=device, dtype=dtype, operations=operations)
|
||||
self.feature_extractor = ConvFeatureEncoder(conv_dim, conv_norm=conv_norm, conv_bias=conv_bias, device=device, dtype=dtype, operations=operations)
|
||||
self.feature_projection = FeatureProjection(conv_dim, embed_dim, device=device, dtype=dtype, operations=operations)
|
||||
|
||||
self.masked_spec_embed = nn.Parameter(torch.empty(embed_dim, device=device, dtype=dtype))
|
||||
self.do_normalize = do_normalize
|
||||
|
||||
self.encoder = TransformerEncoder(
|
||||
embed_dim=embed_dim,
|
||||
num_heads=num_heads,
|
||||
num_layers=num_layers,
|
||||
do_stable_layer_norm=do_stable_layer_norm,
|
||||
device=device, dtype=dtype, operations=operations
|
||||
)
|
||||
|
||||
def forward(self, x, mask_time_indices=None, return_dict=False):
|
||||
|
||||
x = torch.mean(x, dim=1)
|
||||
|
||||
x = (x - x.mean()) / torch.sqrt(x.var() + 1e-7)
|
||||
if self.do_normalize:
|
||||
x = (x - x.mean()) / torch.sqrt(x.var() + 1e-7)
|
||||
|
||||
features = self.feature_extractor(x)
|
||||
features = self.feature_projection(features)
|
||||
|
||||
batch_size, seq_len, _ = features.shape
|
||||
|
||||
x, all_x = self.encoder(features)
|
||||
|
||||
return x, all_x
|
||||
|
||||
@ -185,7 +185,6 @@ class Encoder(nn.Module):
|
||||
self.regul = comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer()
|
||||
|
||||
def forward(self, x):
|
||||
x = x.unsqueeze(2)
|
||||
x = self.conv_in(x)
|
||||
|
||||
for stage in self.down:
|
||||
|
||||
21
comfy/sd.py
21
comfy/sd.py
@ -412,9 +412,12 @@ class VAE:
|
||||
self.working_dtypes = [torch.bfloat16, torch.float32]
|
||||
elif "decoder.conv_in.conv.weight" in sd and sd['decoder.conv_in.conv.weight'].shape[1] == 32:
|
||||
ddconfig = {"block_out_channels": [128, 256, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 16, "ffactor_temporal": 4, "downsample_match_channel": True, "upsample_match_channel": True}
|
||||
self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1]
|
||||
self.downscale_ratio = 16
|
||||
self.upscale_ratio = 16
|
||||
ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1]
|
||||
self.latent_channels = 64
|
||||
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 16, 16)
|
||||
self.upscale_index_formula = (4, 16, 16)
|
||||
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 16, 16)
|
||||
self.downscale_index_formula = (4, 16, 16)
|
||||
self.latent_dim = 3
|
||||
self.not_video = True
|
||||
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||
@ -684,8 +687,11 @@ class VAE:
|
||||
self.throw_exception_if_invalid()
|
||||
pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
|
||||
pixel_samples = pixel_samples.movedim(-1, 1)
|
||||
if not self.not_video and self.latent_dim == 3 and pixel_samples.ndim < 5:
|
||||
pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
|
||||
if self.latent_dim == 3 and pixel_samples.ndim < 5:
|
||||
if not self.not_video:
|
||||
pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
|
||||
else:
|
||||
pixel_samples = pixel_samples.unsqueeze(2)
|
||||
try:
|
||||
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
|
||||
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
|
||||
@ -719,7 +725,10 @@ class VAE:
|
||||
dims = self.latent_dim
|
||||
pixel_samples = pixel_samples.movedim(-1, 1)
|
||||
if dims == 3:
|
||||
pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
|
||||
if not self.not_video:
|
||||
pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
|
||||
else:
|
||||
pixel_samples = pixel_samples.unsqueeze(2)
|
||||
|
||||
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype) # TODO: calculate mem required for tile
|
||||
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user