Merge branch 'comfyanonymous:master' into master

This commit is contained in:
patientx 2025-05-08 14:52:54 +03:00 committed by GitHub
commit e973632f11
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 183 additions and 13 deletions

View File

@ -1,7 +1,12 @@
# Original from: https://github.com/ace-step/ACE-Step/blob/main/music_dcae/music_dcae_pipeline.py
import torch
from .autoencoder_dc import AutoencoderDC
import torchaudio
import logging
try:
import torchaudio
except:
logging.warning("torchaudio missing, ACE model will be broken")
import torchvision.transforms as transforms
from .music_vocoder import ADaMoSHiFiGANV1

View File

@ -2,7 +2,12 @@
import torch
import torch.nn as nn
from torch import Tensor
from torchaudio.transforms import MelScale
import logging
try:
from torchaudio.transforms import MelScale
except:
logging.warning("torchaudio missing, ACE model will be broken")
import comfy.model_management
class LinearSpectrogram(nn.Module):

View File

@ -222,6 +222,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
if '{}adaln_single.emb.timestep_embedder.linear_1.bias'.format(key_prefix) in state_dict_keys: #Lightricks ltxv
dit_config = {}
dit_config["image_model"] = "ltxv"
dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.')
shape = state_dict['{}transformer_blocks.0.attn2.to_k.weight'.format(key_prefix)].shape
dit_config["attention_head_dim"] = shape[0] // 32
dit_config["cross_attention_dim"] = shape[1]
if metadata is not None and "config" in metadata:
dit_config.update(json.loads(metadata["config"]).get("transformer", {}))
return dit_config

View File

@ -282,6 +282,7 @@ class VAE:
self.downscale_index_formula = None
self.upscale_index_formula = None
self.extra_1d_channel = None
if config is None:
if "decoder.mid.block_1.mix_factor" in sd:
@ -441,17 +442,18 @@ class VAE:
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
elif "vocoder.backbone.channel_layers.0.0.bias" in sd: #Ace Step Audio
self.first_stage_model = comfy.ldm.ace.vae.music_dcae_pipeline.MusicDCAE(source_sample_rate=44100)
self.memory_used_encode = lambda shape, dtype: (shape[2] * 300) * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: (shape[2] * shape[3] * 72000) * model_management.dtype_size(dtype)
self.memory_used_encode = lambda shape, dtype: (shape[2] * 330) * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: (shape[2] * shape[3] * 87000) * model_management.dtype_size(dtype)
self.latent_channels = 8
self.output_channels = 2
# self.upscale_ratio = 2048
# self.downscale_ratio = 2048
self.upscale_ratio = 4096
self.downscale_ratio = 4096
self.latent_dim = 2
self.process_output = lambda audio: audio
self.process_input = lambda audio: audio
self.working_dtypes = [torch.bfloat16, torch.float32]
self.disable_offload = True
self.extra_1d_channel = 16
else:
logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
self.first_stage_model = None
@ -510,7 +512,13 @@ class VAE:
return output
def decode_tiled_1d(self, samples, tile_x=128, overlap=32):
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
if samples.ndim == 3:
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
else:
og_shape = samples.shape
samples = samples.reshape((og_shape[0], og_shape[1] * og_shape[2], -1))
decode_fn = lambda a: self.first_stage_model.decode(a.reshape((-1, og_shape[1], og_shape[2], a.shape[-1])).to(self.vae_dtype).to(self.device)).float()
return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, output_device=self.output_device))
def decode_tiled_3d(self, samples, tile_t=999, tile_x=32, tile_y=32, overlap=(1, 8, 8)):
@ -530,9 +538,24 @@ class VAE:
samples /= 3.0
return samples
def encode_tiled_1d(self, samples, tile_x=128 * 2048, overlap=32 * 2048):
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=(1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device)
def encode_tiled_1d(self, samples, tile_x=256 * 2048, overlap=64 * 2048):
if self.latent_dim == 1:
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
out_channels = self.latent_channels
upscale_amount = 1 / self.downscale_ratio
else:
extra_channel_size = self.extra_1d_channel
out_channels = self.latent_channels * extra_channel_size
tile_x = tile_x // extra_channel_size
overlap = overlap // extra_channel_size
upscale_amount = 1 / self.downscale_ratio
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).reshape(1, out_channels, -1).float()
out = comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=self.output_device)
if self.latent_dim == 1:
return out
else:
return out.reshape(samples.shape[0], self.latent_channels, extra_channel_size, -1)
def encode_tiled_3d(self, samples, tile_t=9999, tile_x=512, tile_y=512, overlap=(1, 64, 64)):
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
@ -557,7 +580,7 @@ class VAE:
except model_management.OOM_EXCEPTION:
logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
dims = samples_in.ndim - 2
if dims == 1:
if dims == 1 or self.extra_1d_channel is not None:
pixel_samples = self.decode_tiled_1d(samples_in)
elif dims == 2:
pixel_samples = self.decode_tiled_(samples_in)
@ -624,7 +647,7 @@ class VAE:
tile = 256
overlap = tile // 4
samples = self.encode_tiled_3d(pixel_samples, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
elif self.latent_dim == 1:
elif self.latent_dim == 1 or self.extra_1d_channel is not None:
samples = self.encode_tiled_1d(pixel_samples)
else:
samples = self.encode_tiled_(pixel_samples)

View File

@ -7,7 +7,7 @@ import torch
import logging
from tokenizers import Tokenizer
from .ace_text_cleaners import multilingual_cleaners
from .ace_text_cleaners import multilingual_cleaners, japanese_to_romaji
SUPPORT_LANGUAGES = {
"en": 259, "de": 260, "fr": 262, "es": 284, "it": 285,
@ -65,6 +65,14 @@ class VoiceBpeTokenizer:
if "spa" in lang:
lang = "es"
try:
line_out = japanese_to_romaji(line)
if line_out != line:
lang = "ja"
line = line_out
except:
pass
try:
if structure_pattern.match(line):
token_idx = self.encode(line, "en")

View File

@ -4,6 +4,131 @@
import re
def japanese_to_romaji(japanese_text):
"""
Convert Japanese hiragana and katakana to romaji (Latin alphabet representation).
Args:
japanese_text (str): Text containing hiragana and/or katakana characters
Returns:
str: The romaji (Latin alphabet) equivalent
"""
# Dictionary mapping kana characters to their romaji equivalents
kana_map = {
# Katakana characters
'': 'a', '': 'i', '': 'u', '': 'e', '': 'o',
'': 'ka', '': 'ki', '': 'ku', '': 'ke', '': 'ko',
'': 'sa', '': 'shi', '': 'su', '': 'se', '': 'so',
'': 'ta', '': 'chi', '': 'tsu', '': 'te', '': 'to',
'': 'na', '': 'ni', '': 'nu', '': 'ne', '': 'no',
'': 'ha', '': 'hi', '': 'fu', '': 'he', '': 'ho',
'': 'ma', '': 'mi', '': 'mu', '': 'me', '': 'mo',
'': 'ya', '': 'yu', '': 'yo',
'': 'ra', '': 'ri', '': 'ru', '': 're', '': 'ro',
'': 'wa', '': 'wo', '': 'n',
# Katakana voiced consonants
'': 'ga', '': 'gi', '': 'gu', '': 'ge', '': 'go',
'': 'za', '': 'ji', '': 'zu', '': 'ze', '': 'zo',
'': 'da', '': 'ji', '': 'zu', '': 'de', '': 'do',
'': 'ba', '': 'bi', '': 'bu', '': 'be', '': 'bo',
'': 'pa', '': 'pi', '': 'pu', '': 'pe', '': 'po',
# Katakana combinations
'キャ': 'kya', 'キュ': 'kyu', 'キョ': 'kyo',
'シャ': 'sha', 'シュ': 'shu', 'ショ': 'sho',
'チャ': 'cha', 'チュ': 'chu', 'チョ': 'cho',
'ニャ': 'nya', 'ニュ': 'nyu', 'ニョ': 'nyo',
'ヒャ': 'hya', 'ヒュ': 'hyu', 'ヒョ': 'hyo',
'ミャ': 'mya', 'ミュ': 'myu', 'ミョ': 'myo',
'リャ': 'rya', 'リュ': 'ryu', 'リョ': 'ryo',
'ギャ': 'gya', 'ギュ': 'gyu', 'ギョ': 'gyo',
'ジャ': 'ja', 'ジュ': 'ju', 'ジョ': 'jo',
'ビャ': 'bya', 'ビュ': 'byu', 'ビョ': 'byo',
'ピャ': 'pya', 'ピュ': 'pyu', 'ピョ': 'pyo',
# Katakana small characters and special cases
'': '', # Small tsu (doubles the following consonant)
'': 'ya', '': 'yu', '': 'yo',
# Katakana extras
'': 'vu', 'ファ': 'fa', 'フィ': 'fi', 'フェ': 'fe', 'フォ': 'fo',
'ウィ': 'wi', 'ウェ': 'we', 'ウォ': 'wo',
# Hiragana characters
'': 'a', '': 'i', '': 'u', '': 'e', '': 'o',
'': 'ka', '': 'ki', '': 'ku', '': 'ke', '': 'ko',
'': 'sa', '': 'shi', '': 'su', '': 'se', '': 'so',
'': 'ta', '': 'chi', '': 'tsu', '': 'te', '': 'to',
'': 'na', '': 'ni', '': 'nu', '': 'ne', '': 'no',
'': 'ha', '': 'hi', '': 'fu', '': 'he', '': 'ho',
'': 'ma', '': 'mi', '': 'mu', '': 'me', '': 'mo',
'': 'ya', '': 'yu', '': 'yo',
'': 'ra', '': 'ri', '': 'ru', '': 're', '': 'ro',
'': 'wa', '': 'wo', '': 'n',
# Hiragana voiced consonants
'': 'ga', '': 'gi', '': 'gu', '': 'ge', '': 'go',
'': 'za', '': 'ji', '': 'zu', '': 'ze', '': 'zo',
'': 'da', '': 'ji', '': 'zu', '': 'de', '': 'do',
'': 'ba', '': 'bi', '': 'bu', '': 'be', '': 'bo',
'': 'pa', '': 'pi', '': 'pu', '': 'pe', '': 'po',
# Hiragana combinations
'きゃ': 'kya', 'きゅ': 'kyu', 'きょ': 'kyo',
'しゃ': 'sha', 'しゅ': 'shu', 'しょ': 'sho',
'ちゃ': 'cha', 'ちゅ': 'chu', 'ちょ': 'cho',
'にゃ': 'nya', 'にゅ': 'nyu', 'にょ': 'nyo',
'ひゃ': 'hya', 'ひゅ': 'hyu', 'ひょ': 'hyo',
'みゃ': 'mya', 'みゅ': 'myu', 'みょ': 'myo',
'りゃ': 'rya', 'りゅ': 'ryu', 'りょ': 'ryo',
'ぎゃ': 'gya', 'ぎゅ': 'gyu', 'ぎょ': 'gyo',
'じゃ': 'ja', 'じゅ': 'ju', 'じょ': 'jo',
'びゃ': 'bya', 'びゅ': 'byu', 'びょ': 'byo',
'ぴゃ': 'pya', 'ぴゅ': 'pyu', 'ぴょ': 'pyo',
# Hiragana small characters and special cases
'': '', # Small tsu (doubles the following consonant)
'': 'ya', '': 'yu', '': 'yo',
# Common punctuation and spaces
' ': ' ', # Japanese space
'': ', ', '': '. ',
}
result = []
i = 0
while i < len(japanese_text):
# Check for small tsu (doubling the following consonant)
if i < len(japanese_text) - 1 and (japanese_text[i] == '' or japanese_text[i] == ''):
if i < len(japanese_text) - 1 and japanese_text[i+1] in kana_map:
next_romaji = kana_map[japanese_text[i+1]]
if next_romaji and next_romaji[0] not in 'aiueon':
result.append(next_romaji[0]) # Double the consonant
i += 1
continue
# Check for combinations with small ya, yu, yo
if i < len(japanese_text) - 1 and japanese_text[i+1] in ('', '', '', '', '', ''):
combo = japanese_text[i:i+2]
if combo in kana_map:
result.append(kana_map[combo])
i += 2
continue
# Regular character
if japanese_text[i] in kana_map:
result.append(kana_map[japanese_text[i]])
else:
# If it's not in our map, keep it as is (might be kanji, romaji, etc.)
result.append(japanese_text[i])
i += 1
return ''.join(result)
def number_to_text(num, ordinal=False):
"""
Convert a number (int or float) to its text representation.