mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-08 04:22:36 +08:00
Add VAE tiled decode node for audio. (#12299)
This commit is contained in:
parent
a246cc02b2
commit
35183543e0
@ -976,7 +976,7 @@ class VAE:
|
|||||||
if overlap is not None:
|
if overlap is not None:
|
||||||
args["overlap"] = overlap
|
args["overlap"] = overlap
|
||||||
|
|
||||||
if dims == 1:
|
if dims == 1 or self.extra_1d_channel is not None:
|
||||||
args.pop("tile_y")
|
args.pop("tile_y")
|
||||||
output = self.decode_tiled_1d(samples, **args)
|
output = self.decode_tiled_1d(samples, **args)
|
||||||
elif dims == 2:
|
elif dims == 2:
|
||||||
|
|||||||
@ -94,6 +94,19 @@ class VAEEncodeAudio(IO.ComfyNode):
|
|||||||
encode = execute # TODO: remove
|
encode = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
|
def vae_decode_audio(vae, samples, tile=None, overlap=None):
|
||||||
|
if tile is not None:
|
||||||
|
audio = vae.decode_tiled(samples["samples"], tile_y=tile, overlap=overlap).movedim(-1, 1)
|
||||||
|
else:
|
||||||
|
audio = vae.decode(samples["samples"]).movedim(-1, 1)
|
||||||
|
|
||||||
|
std = torch.std(audio, dim=[1, 2], keepdim=True) * 5.0
|
||||||
|
std[std < 1.0] = 1.0
|
||||||
|
audio /= std
|
||||||
|
vae_sample_rate = getattr(vae, "audio_sample_rate", 44100)
|
||||||
|
return {"waveform": audio, "sample_rate": vae_sample_rate if "sample_rate" not in samples else samples["sample_rate"]}
|
||||||
|
|
||||||
|
|
||||||
class VAEDecodeAudio(IO.ComfyNode):
|
class VAEDecodeAudio(IO.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
@ -111,16 +124,33 @@ class VAEDecodeAudio(IO.ComfyNode):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, vae, samples) -> IO.NodeOutput:
|
def execute(cls, vae, samples) -> IO.NodeOutput:
|
||||||
audio = vae.decode(samples["samples"]).movedim(-1, 1)
|
return IO.NodeOutput(vae_decode_audio(vae, samples))
|
||||||
std = torch.std(audio, dim=[1,2], keepdim=True) * 5.0
|
|
||||||
std[std < 1.0] = 1.0
|
|
||||||
audio /= std
|
|
||||||
vae_sample_rate = getattr(vae, "audio_sample_rate", 44100)
|
|
||||||
return IO.NodeOutput({"waveform": audio, "sample_rate": vae_sample_rate if "sample_rate" not in samples else samples["sample_rate"]})
|
|
||||||
|
|
||||||
decode = execute # TODO: remove
|
decode = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
|
class VAEDecodeAudioTiled(IO.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return IO.Schema(
|
||||||
|
node_id="VAEDecodeAudioTiled",
|
||||||
|
search_aliases=["latent to audio"],
|
||||||
|
display_name="VAE Decode Audio (Tiled)",
|
||||||
|
category="latent/audio",
|
||||||
|
inputs=[
|
||||||
|
IO.Latent.Input("samples"),
|
||||||
|
IO.Vae.Input("vae"),
|
||||||
|
IO.Int.Input("tile_size", default=512, min=32, max=8192, step=8),
|
||||||
|
IO.Int.Input("overlap", default=64, min=0, max=1024, step=8),
|
||||||
|
],
|
||||||
|
outputs=[IO.Audio.Output()],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, vae, samples, tile_size, overlap) -> IO.NodeOutput:
|
||||||
|
return IO.NodeOutput(vae_decode_audio(vae, samples, tile_size, overlap))
|
||||||
|
|
||||||
|
|
||||||
class SaveAudio(IO.ComfyNode):
|
class SaveAudio(IO.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
@ -675,6 +705,7 @@ class AudioExtension(ComfyExtension):
|
|||||||
EmptyLatentAudio,
|
EmptyLatentAudio,
|
||||||
VAEEncodeAudio,
|
VAEEncodeAudio,
|
||||||
VAEDecodeAudio,
|
VAEDecodeAudio,
|
||||||
|
VAEDecodeAudioTiled,
|
||||||
SaveAudio,
|
SaveAudio,
|
||||||
SaveAudioMP3,
|
SaveAudioMP3,
|
||||||
SaveAudioOpus,
|
SaveAudioOpus,
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user