diff --git a/comfy/ldm/ace/vae/music_dcae_pipeline.py b/comfy/ldm/ace/vae/music_dcae_pipeline.py
index 3188bc770..af81280eb 100644
--- a/comfy/ldm/ace/vae/music_dcae_pipeline.py
+++ b/comfy/ldm/ace/vae/music_dcae_pipeline.py
@@ -1,7 +1,12 @@
 # Original from: https://github.com/ace-step/ACE-Step/blob/main/music_dcae/music_dcae_pipeline.py
 import torch
 from .autoencoder_dc import AutoencoderDC
-import torchaudio
+import logging
+try:
+    import torchaudio
+except:
+    logging.warning("torchaudio missing, ACE model will be broken")
+
 
 import torchvision.transforms as transforms
 from .music_vocoder import ADaMoSHiFiGANV1
diff --git a/comfy/ldm/ace/vae/music_log_mel.py b/comfy/ldm/ace/vae/music_log_mel.py
index d73d3f8e8..9c584eb7f 100755
--- a/comfy/ldm/ace/vae/music_log_mel.py
+++ b/comfy/ldm/ace/vae/music_log_mel.py
@@ -2,7 +2,12 @@
 import torch
 import torch.nn as nn
 from torch import Tensor
-from torchaudio.transforms import MelScale
+import logging
+try:
+    from torchaudio.transforms import MelScale
+except:
+    logging.warning("torchaudio missing, ACE model will be broken")
+
 import comfy.model_management
 
 class LinearSpectrogram(nn.Module):
diff --git a/comfy/model_detection.py b/comfy/model_detection.py
index ff4c29d7e..28c586389 100644
--- a/comfy/model_detection.py
+++ b/comfy/model_detection.py
@@ -222,6 +222,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
     if '{}adaln_single.emb.timestep_embedder.linear_1.bias'.format(key_prefix) in state_dict_keys: #Lightricks ltxv
         dit_config = {}
         dit_config["image_model"] = "ltxv"
+        dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.')
+        shape = state_dict['{}transformer_blocks.0.attn2.to_k.weight'.format(key_prefix)].shape
+        dit_config["attention_head_dim"] = shape[0] // 32
+        dit_config["cross_attention_dim"] = shape[1]
         if metadata is not None and "config" in metadata:
             dit_config.update(json.loads(metadata["config"]).get("transformer", {}))
         return dit_config
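Annotation (not part of the patch): the model_detection.py hunk above stops relying on hardcoded ltxv transformer geometry and infers it from the checkpoint itself: counting transformer_blocks gives num_layers, and the shape of the first cross-attention to_k weight gives the head dimension (assuming 32 heads, per the fixed divisor in the hunk) and the cross-attention width. A minimal sketch of that inference against a hypothetical state dict (the keys shown mirror the hunk, but the shapes below are made up):

```python
import torch

# Hypothetical checkpoint fragment; real ltxv checkpoints have many more keys.
state_dict = {
    "transformer_blocks.0.attn2.to_k.weight": torch.zeros(2048, 4096),
    "transformer_blocks.1.attn2.to_k.weight": torch.zeros(2048, 4096),
}

shape = state_dict["transformer_blocks.0.attn2.to_k.weight"].shape
dit_config = {
    "image_model": "ltxv",
    "attention_head_dim": shape[0] // 32,  # inner dim 2048 / 32 heads -> 64
    "cross_attention_dim": shape[1],       # width of the conditioning features
}
print(dit_config)  # {'image_model': 'ltxv', 'attention_head_dim': 64, 'cross_attention_dim': 4096}
```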
diff --git a/comfy/sd.py b/comfy/sd.py
index 50af243ba..ee350d5b5 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -282,6 +282,7 @@ class VAE:
 
         self.downscale_index_formula = None
         self.upscale_index_formula = None
+        self.extra_1d_channel = None
 
         if config is None:
             if "decoder.mid.block_1.mix_factor" in sd:
@@ -441,17 +442,18 @@ class VAE:
                 self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
             elif "vocoder.backbone.channel_layers.0.0.bias" in sd: #Ace Step Audio
                 self.first_stage_model = comfy.ldm.ace.vae.music_dcae_pipeline.MusicDCAE(source_sample_rate=44100)
-                self.memory_used_encode = lambda shape, dtype: (shape[2] * 300) * model_management.dtype_size(dtype)
-                self.memory_used_decode = lambda shape, dtype: (shape[2] * shape[3] * 72000) * model_management.dtype_size(dtype)
+                self.memory_used_encode = lambda shape, dtype: (shape[2] * 330) * model_management.dtype_size(dtype)
+                self.memory_used_decode = lambda shape, dtype: (shape[2] * shape[3] * 87000) * model_management.dtype_size(dtype)
                 self.latent_channels = 8
                 self.output_channels = 2
-                # self.upscale_ratio = 2048
-                # self.downscale_ratio = 2048
+                self.upscale_ratio = 4096
+                self.downscale_ratio = 4096
                 self.latent_dim = 2
                 self.process_output = lambda audio: audio
                 self.process_input = lambda audio: audio
                 self.working_dtypes = [torch.bfloat16, torch.float32]
                 self.disable_offload = True
+                self.extra_1d_channel = 16
             else:
                 logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
                 self.first_stage_model = None
@@ -510,7 +512,13 @@ class VAE:
         return output
 
     def decode_tiled_1d(self, samples, tile_x=128, overlap=32):
-        decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
+        if samples.ndim == 3:
+            decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
+        else:
+            og_shape = samples.shape
+            samples = samples.reshape((og_shape[0], og_shape[1] * og_shape[2], -1))
+            decode_fn = lambda a: self.first_stage_model.decode(a.reshape((-1, og_shape[1], og_shape[2], a.shape[-1])).to(self.vae_dtype).to(self.device)).float()
+
         return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, output_device=self.output_device))
 
     def decode_tiled_3d(self, samples, tile_t=999, tile_x=32, tile_y=32, overlap=(1, 8, 8)):
@@ -530,9 +538,24 @@ class VAE:
             samples /= 3.0
         return samples
 
-    def encode_tiled_1d(self, samples, tile_x=128 * 2048, overlap=32 * 2048):
-        encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
-        return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=(1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device)
+    def encode_tiled_1d(self, samples, tile_x=256 * 2048, overlap=64 * 2048):
+        if self.latent_dim == 1:
+            encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
+            out_channels = self.latent_channels
+            upscale_amount = 1 / self.downscale_ratio
+        else:
+            extra_channel_size = self.extra_1d_channel
+            out_channels = self.latent_channels * extra_channel_size
+            tile_x = tile_x // extra_channel_size
+            overlap = overlap // extra_channel_size
+            upscale_amount = 1 / self.downscale_ratio
+            encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).reshape(1, out_channels, -1).float()
+
+        out = comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=self.output_device)
+        if self.latent_dim == 1:
+            return out
+        else:
+            return out.reshape(samples.shape[0], self.latent_channels, extra_channel_size, -1)
 
     def encode_tiled_3d(self, samples, tile_t=9999, tile_x=512, tile_y=512, overlap=(1, 64, 64)):
         encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
@@ -557,7 +580,7 @@ class VAE:
         except model_management.OOM_EXCEPTION:
             logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
             dims = samples_in.ndim - 2
-            if dims == 1:
+            if dims == 1 or self.extra_1d_channel is not None:
                 pixel_samples = self.decode_tiled_1d(samples_in)
             elif dims == 2:
                 pixel_samples = self.decode_tiled_(samples_in)
@@ -624,7 +647,7 @@ class VAE:
                 tile = 256
                 overlap = tile // 4
                 samples = self.encode_tiled_3d(pixel_samples, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
-            elif self.latent_dim == 1:
+            elif self.latent_dim == 1 or self.extra_1d_channel is not None:
                 samples = self.encode_tiled_1d(pixel_samples)
             else:
                 samples = self.encode_tiled_(pixel_samples)
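Annotation (not part of the patch): the sd.py changes above teach the 1D tiling path to handle the ACE Step VAE, whose latents are 2D per audio sample: shape (batch, 8 latent channels, extra_1d_channel = 16, time). The trick is to flatten the two channel axes into 128 flat channels so comfy.utils.tiled_scale_multidim can slide tiles along the time axis alone, then restore the 2D layout inside each tile before calling the real encoder/decoder. A minimal sketch of just the reshapes, with hypothetical sizes:

```python
import torch

latent_channels, extra_1d_channel = 8, 16
samples = torch.randn(1, latent_channels, extra_1d_channel, 1000)  # (B, C, H, T)

# Flatten (C, H) so a 1D tiler only has to window along the time axis.
og_shape = samples.shape
flat = samples.reshape(og_shape[0], og_shape[1] * og_shape[2], -1)  # (1, 128, 1000)

# Inside a tile, undo the flattening before handing the chunk to the model.
tile = flat[:, :, :256]                                             # (1, 128, 256)
restored = tile.reshape(-1, og_shape[1], og_shape[2], tile.shape[-1])
assert restored.shape == (1, latent_channels, extra_1d_channel, 256)
```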
diff --git a/comfy/text_encoders/ace.py b/comfy/text_encoders/ace.py
index b6fe451bd..d650bb10d 100644
--- a/comfy/text_encoders/ace.py
+++ b/comfy/text_encoders/ace.py
@@ -7,7 +7,7 @@ import torch
 import logging
 
 from tokenizers import Tokenizer
-from .ace_text_cleaners import multilingual_cleaners
+from .ace_text_cleaners import multilingual_cleaners, japanese_to_romaji
 
 SUPPORT_LANGUAGES = {
     "en": 259, "de": 260, "fr": 262, "es": 284, "it": 285,
@@ -65,6 +65,14 @@ class VoiceBpeTokenizer:
             if "spa" in lang:
                 lang = "es"

+            try:
+                line_out = japanese_to_romaji(line)
+                if line_out != line:
+                    lang = "ja"
+                    line = line_out
+            except:
+                pass
+
             try:
                 if structure_pattern.match(line):
                     token_idx = self.encode(line, "en")
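Annotation (not part of the patch): the tokenizer hunk above adds a cheap language probe for lyrics: each line is run through japanese_to_romaji, and if the text changes, it must have contained kana, so the line is re-tagged "ja" and the romanized form is tokenized instead. A standalone sketch of that heuristic (detect_lang and its signature are hypothetical; the patch inlines this logic):

```python
def detect_lang(line, lang, romanize):
    """Re-tag a lyric line as Japanese if romanization changes it."""
    try:
        line_out = romanize(line)
        if line_out != line:  # kana was converted, so the line is Japanese
            return "ja", line_out
    except Exception:         # the patch uses a bare except for the same fallback
        pass
    return lang, line

# With the helper added in the next file:
# detect_lang("こんにちは", "en", japanese_to_romaji) -> ("ja", "konnichiha")
```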
diff --git a/comfy/text_encoders/ace_text_cleaners.py b/comfy/text_encoders/ace_text_cleaners.py
index ad3612e5f..cd31d8d8c 100644
--- a/comfy/text_encoders/ace_text_cleaners.py
+++ b/comfy/text_encoders/ace_text_cleaners.py
@@ -4,6 +4,131 @@
 
 import re
 
+def japanese_to_romaji(japanese_text):
+    """
+    Convert Japanese hiragana and katakana to romaji (Latin alphabet representation).
+
+    Args:
+        japanese_text (str): Text containing hiragana and/or katakana characters
+
+    Returns:
+        str: The romaji (Latin alphabet) equivalent
+    """
+    # Dictionary mapping kana characters to their romaji equivalents
+    kana_map = {
+        # Katakana characters
+        'ア': 'a', 'イ': 'i', 'ウ': 'u', 'エ': 'e', 'オ': 'o',
+        'カ': 'ka', 'キ': 'ki', 'ク': 'ku', 'ケ': 'ke', 'コ': 'ko',
+        'サ': 'sa', 'シ': 'shi', 'ス': 'su', 'セ': 'se', 'ソ': 'so',
+        'タ': 'ta', 'チ': 'chi', 'ツ': 'tsu', 'テ': 'te', 'ト': 'to',
+        'ナ': 'na', 'ニ': 'ni', 'ヌ': 'nu', 'ネ': 'ne', 'ノ': 'no',
+        'ハ': 'ha', 'ヒ': 'hi', 'フ': 'fu', 'ヘ': 'he', 'ホ': 'ho',
+        'マ': 'ma', 'ミ': 'mi', 'ム': 'mu', 'メ': 'me', 'モ': 'mo',
+        'ヤ': 'ya', 'ユ': 'yu', 'ヨ': 'yo',
+        'ラ': 'ra', 'リ': 'ri', 'ル': 'ru', 'レ': 're', 'ロ': 'ro',
+        'ワ': 'wa', 'ヲ': 'wo', 'ン': 'n',
+
+        # Katakana voiced consonants
+        'ガ': 'ga', 'ギ': 'gi', 'グ': 'gu', 'ゲ': 'ge', 'ゴ': 'go',
+        'ザ': 'za', 'ジ': 'ji', 'ズ': 'zu', 'ゼ': 'ze', 'ゾ': 'zo',
+        'ダ': 'da', 'ヂ': 'ji', 'ヅ': 'zu', 'デ': 'de', 'ド': 'do',
+        'バ': 'ba', 'ビ': 'bi', 'ブ': 'bu', 'ベ': 'be', 'ボ': 'bo',
+        'パ': 'pa', 'ピ': 'pi', 'プ': 'pu', 'ペ': 'pe', 'ポ': 'po',
+
+        # Katakana combinations
+        'キャ': 'kya', 'キュ': 'kyu', 'キョ': 'kyo',
+        'シャ': 'sha', 'シュ': 'shu', 'ショ': 'sho',
+        'チャ': 'cha', 'チュ': 'chu', 'チョ': 'cho',
+        'ニャ': 'nya', 'ニュ': 'nyu', 'ニョ': 'nyo',
+        'ヒャ': 'hya', 'ヒュ': 'hyu', 'ヒョ': 'hyo',
+        'ミャ': 'mya', 'ミュ': 'myu', 'ミョ': 'myo',
+        'リャ': 'rya', 'リュ': 'ryu', 'リョ': 'ryo',
+        'ギャ': 'gya', 'ギュ': 'gyu', 'ギョ': 'gyo',
+        'ジャ': 'ja', 'ジュ': 'ju', 'ジョ': 'jo',
+        'ビャ': 'bya', 'ビュ': 'byu', 'ビョ': 'byo',
+        'ピャ': 'pya', 'ピュ': 'pyu', 'ピョ': 'pyo',
+
+        # Katakana small characters and special cases
+        'ッ': '',  # Small tsu (doubles the following consonant)
+        'ャ': 'ya', 'ュ': 'yu', 'ョ': 'yo',
+
+        # Katakana extras
+        'ヴ': 'vu', 'ファ': 'fa', 'フィ': 'fi', 'フェ': 'fe', 'フォ': 'fo',
+        'ウィ': 'wi', 'ウェ': 'we', 'ウォ': 'wo',
+
+        # Hiragana characters
+        'あ': 'a', 'い': 'i', 'う': 'u', 'え': 'e', 'お': 'o',
+        'か': 'ka', 'き': 'ki', 'く': 'ku', 'け': 'ke', 'こ': 'ko',
+        'さ': 'sa', 'し': 'shi', 'す': 'su', 'せ': 'se', 'そ': 'so',
+        'た': 'ta', 'ち': 'chi', 'つ': 'tsu', 'て': 'te', 'と': 'to',
+        'な': 'na', 'に': 'ni', 'ぬ': 'nu', 'ね': 'ne', 'の': 'no',
+        'は': 'ha', 'ひ': 'hi', 'ふ': 'fu', 'へ': 'he', 'ほ': 'ho',
+        'ま': 'ma', 'み': 'mi', 'む': 'mu', 'め': 'me', 'も': 'mo',
+        'や': 'ya', 'ゆ': 'yu', 'よ': 'yo',
+        'ら': 'ra', 'り': 'ri', 'る': 'ru', 'れ': 're', 'ろ': 'ro',
+        'わ': 'wa', 'を': 'wo', 'ん': 'n',
+
+        # Hiragana voiced consonants
+        'が': 'ga', 'ぎ': 'gi', 'ぐ': 'gu', 'げ': 'ge', 'ご': 'go',
+        'ざ': 'za', 'じ': 'ji', 'ず': 'zu', 'ぜ': 'ze', 'ぞ': 'zo',
+        'だ': 'da', 'ぢ': 'ji', 'づ': 'zu', 'で': 'de', 'ど': 'do',
+        'ば': 'ba', 'び': 'bi', 'ぶ': 'bu', 'べ': 'be', 'ぼ': 'bo',
+        'ぱ': 'pa', 'ぴ': 'pi', 'ぷ': 'pu', 'ぺ': 'pe', 'ぽ': 'po',
+
+        # Hiragana combinations
+        'きゃ': 'kya', 'きゅ': 'kyu', 'きょ': 'kyo',
+        'しゃ': 'sha', 'しゅ': 'shu', 'しょ': 'sho',
+        'ちゃ': 'cha', 'ちゅ': 'chu', 'ちょ': 'cho',
+        'にゃ': 'nya', 'にゅ': 'nyu', 'にょ': 'nyo',
+        'ひゃ': 'hya', 'ひゅ': 'hyu', 'ひょ': 'hyo',
+        'みゃ': 'mya', 'みゅ': 'myu', 'みょ': 'myo',
+        'りゃ': 'rya', 'りゅ': 'ryu', 'りょ': 'ryo',
+        'ぎゃ': 'gya', 'ぎゅ': 'gyu', 'ぎょ': 'gyo',
+        'じゃ': 'ja', 'じゅ': 'ju', 'じょ': 'jo',
+        'びゃ': 'bya', 'びゅ': 'byu', 'びょ': 'byo',
+        'ぴゃ': 'pya', 'ぴゅ': 'pyu', 'ぴょ': 'pyo',
+
+        # Hiragana small characters and special cases
+        'っ': '',  # Small tsu (doubles the following consonant)
+        'ゃ': 'ya', 'ゅ': 'yu', 'ょ': 'yo',
+
+        # Common punctuation and spaces
+        '　': ' ',  # Japanese space
+        '、': ', ', '。': '. ',
+    }
+
+    result = []
+    i = 0
+
+    while i < len(japanese_text):
+        # Check for small tsu (doubling the following consonant)
+        if i < len(japanese_text) - 1 and (japanese_text[i] == 'っ' or japanese_text[i] == 'ッ'):
+            if i < len(japanese_text) - 1 and japanese_text[i+1] in kana_map:
+                next_romaji = kana_map[japanese_text[i+1]]
+                if next_romaji and next_romaji[0] not in 'aiueon':
+                    result.append(next_romaji[0])  # Double the consonant
+            i += 1
+            continue
+
+        # Check for combinations with small ya, yu, yo
+        if i < len(japanese_text) - 1 and japanese_text[i+1] in ('ゃ', 'ゅ', 'ょ', 'ャ', 'ュ', 'ョ'):
+            combo = japanese_text[i:i+2]
+            if combo in kana_map:
+                result.append(kana_map[combo])
+                i += 2
+                continue
+
+        # Regular character
+        if japanese_text[i] in kana_map:
+            result.append(kana_map[japanese_text[i]])
+        else:
+            # If it's not in our map, keep it as is (might be kanji, romaji, etc.)
+            result.append(japanese_text[i])
+
+        i += 1
+
+    return ''.join(result)
+
 def number_to_text(num, ordinal=False):
     """
     Convert a number (int or float) to its text representation.
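Annotation (not part of the patch): a quick usage sketch of the new helper, assuming the patched comfy.text_encoders.ace_text_cleaners module is importable. The expected outputs follow from the mapping table and loop above:

```python
from comfy.text_encoders.ace_text_cleaners import japanese_to_romaji

print(japanese_to_romaji("こんにちは"))  # "konnichiha" (literal map; no particle-aware wa/e/o)
print(japanese_to_romaji("きょうと"))    # "kyouto" (きょ matches as a two-character combo)
print(japanese_to_romaji("がっこう"))    # "gakkou" (っ doubles the following consonant)
print(japanese_to_romaji("ラーメン"))    # "raーmen" (the long-vowel mark ー is not in the map)
print(japanese_to_romaji("hello 123"))  # "hello 123" (non-kana passes through unchanged)
```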