nodes_image_stream: implement VAE decoder node

This commit is contained in:
Rattus 2026-04-14 11:15:47 +10:00
parent 0c70446c9b
commit 1c2d37944c
3 changed files with 101 additions and 0 deletions

View File

@ -1388,6 +1388,15 @@ class VideoVAE(nn.Module):
def decode_output_shape(self, input_shape): def decode_output_shape(self, input_shape):
return self.decoder.decode_output_shape(input_shape) return self.decoder.decode_output_shape(input_shape)
def decode_start(self, x):
clear_temporal_cache_state(self.decoder)
if self.timestep_conditioning: #TODO: seed
x = torch.randn_like(x) * self.decode_noise_scale + (1.0 - self.decode_noise_scale) * x
return self.decoder.forward_start(self.per_channel_statistics.un_normalize(x), timestep=self.decode_timestep)
def decode_chunk(self, output_t: int):
return self.decoder.forward_resume(output_t)
def decode(self, x, output_buffer=None): def decode(self, x, output_buffer=None):
if self.timestep_conditioning: #TODO: seed if self.timestep_conditioning: #TODO: seed
x = torch.randn_like(x) * self.decode_noise_scale + (1.0 - self.decode_noise_scale) * x x = torch.randn_like(x) * self.decode_noise_scale + (1.0 - self.decode_noise_scale) * x

View File

@ -995,6 +995,17 @@ class VAE:
pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1) pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
return pixel_samples return pixel_samples
def decode_output_shape(self, samples_shape):
self.throw_exception_if_invalid()
if hasattr(self.first_stage_model, "decode_output_shape"):
return self.first_stage_model.decode_output_shape(samples_shape)
raise RuntimeError("This VAE does not expose decode output shape information.")
def decode_stream_start(self, samples_in):
memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
self.first_stage_model.decode_start(samples_in.to(device=self.device, dtype=self.vae_dtype))
def decode_tiled(self, samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None): def decode_tiled(self, samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None):
self.throw_exception_if_invalid() self.throw_exception_if_invalid()
memory_used = self.memory_used_decode(samples.shape, self.vae_dtype) #TODO: calculate mem required for tile memory_used = self.memory_used_decode(samples.shape, self.vae_dtype) #TODO: calculate mem required for tile

View File

@ -123,6 +123,60 @@ class PreviewingImageStream(Input.ImageStream):
return chunk return chunk
class VAEDecodedImageStream(Input.ImageStream):
def __init__(self, vae, latent: Input.Latent):
super().__init__()
self._vae = vae
self._latent = latent
vae.throw_exception_if_invalid()
if not getattr(vae.first_stage_model, "comfy_has_chunked_io", False):
raise RuntimeError("This VAE does not expose chunked decode support, so VAE Decode Stream cannot be used.")
if latent.ndim != 5:
raise RuntimeError("VAE Decode Stream expects a video latent shaped [batch, channels, frames, height, width].")
if latent.shape[0] != 1:
raise RuntimeError("VAE Decode Stream currently requires latent batch size 1.")
output_shape = vae.decode_output_shape(latent.shape)
self._channels = int(output_shape[1])
self._width = int(output_shape[4])
self._height = int(output_shape[3])
self._total_frames = int(output_shape[0] * output_shape[2])
self._frames_emitted = 0
def _update_progress(self, value: int) -> None:
current = get_executing_context()
if current is None:
return
get_progress_state().update_progress(
current.node_id,
value=float(value),
max_value=float(max(self._total_frames, 1)),
)
def get_dimensions(self) -> tuple[int, int]:
return self._width, self._height
def do_reset(self) -> None:
self._frames_emitted = 0
self._update_progress(0)
self._vae.decode_stream_start(self._latent)
def do_pull(self, max_frames: int) -> Input.Image:
chunk = self._vae.first_stage_model.decode_chunk(max_frames)
if chunk is None:
return torch.empty(
(0, self._height, self._width, self._channels),
device=self._vae.output_device,
dtype=self._vae.vae_output_dtype(),
)
chunk = chunk.to(device=self._vae.output_device, dtype=self._vae.vae_output_dtype())
chunk = self._vae.process_output(chunk).movedim(1, -1)
chunk = chunk.reshape((-1,) + tuple(chunk.shape[-3:]))
self._frames_emitted += int(chunk.shape[0])
self._update_progress(self._frames_emitted)
return chunk
class ImageBatchToStream(io.ComfyNode): class ImageBatchToStream(io.ComfyNode):
@classmethod @classmethod
def define_schema(cls): def define_schema(cls):
@ -178,6 +232,32 @@ class ImageStreamToBatch(io.ComfyNode):
return io.NodeOutput(torch.cat(chunks, dim=0)) return io.NodeOutput(torch.cat(chunks, dim=0))
class VAEDecodeStream(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="VAEDecodeStream",
display_name="VAE Decode Stream",
category="image/stream",
search_aliases=["vae stream decode", "latent to stream", "video latent stream"],
description="Decodes a latent into an IMAGE_STREAM.",
inputs=[
io.Latent.Input("samples", tooltip="The LTX latent to decode."),
io.Vae.Input("vae", tooltip="The LTX VAE used for chunked streaming decode."),
],
outputs=[
io.ImageStream.Output(display_name="stream"),
],
)
@classmethod
def execute(cls, samples: Input.Latent, vae) -> io.NodeOutput:
latent = samples["samples"]
if latent.is_nested:
latent = latent.unbind()[0]
return io.NodeOutput(VAEDecodedImageStream(vae, latent))
class PreviewImageStream(io.ComfyNode): class PreviewImageStream(io.ComfyNode):
@classmethod @classmethod
def define_schema(cls): def define_schema(cls):
@ -230,6 +310,7 @@ class ImageStreamExtension(ComfyExtension):
return [ return [
ImageBatchToStream, ImageBatchToStream,
ImageStreamToBatch, ImageStreamToBatch,
VAEDecodeStream,
PreviewImageStream, PreviewImageStream,
StreamSink, StreamSink,
] ]