mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-27 10:52:31 +08:00
nodes_image_stream: implement VAE decoder node
This commit is contained in:
parent
0c70446c9b
commit
1c2d37944c
@ -1388,6 +1388,15 @@ class VideoVAE(nn.Module):
|
||||
def decode_output_shape(self, input_shape):
    """Report the pixel-space shape that a latent of *input_shape* decodes to.

    Pure delegation: the wrapped decoder owns the shape arithmetic.
    """
    decoder = self.decoder
    return decoder.decode_output_shape(input_shape)
|
||||
|
||||
def decode_start(self, x):
    """Begin a chunked decode of latent *x* and return the decoder's first output.

    Clears any temporal cache left over from a previous decode, optionally
    blends *x* with Gaussian noise when timestep conditioning is enabled,
    then starts the decoder on the un-normalized latent.
    """
    clear_temporal_cache_state(self.decoder)
    if self.timestep_conditioning:  # TODO: seed
        scale = self.decode_noise_scale
        noise = torch.randn_like(x)
        x = noise * scale + x * (1.0 - scale)
    latent = self.per_channel_statistics.un_normalize(x)
    return self.decoder.forward_start(latent, timestep=self.decode_timestep)
|
||||
|
||||
def decode_chunk(self, output_t: int):
    """Resume the in-progress decode, producing up to *output_t* output frames."""
    decoder = self.decoder
    return decoder.forward_resume(output_t)
|
||||
|
||||
def decode(self, x, output_buffer=None):
|
||||
if self.timestep_conditioning: #TODO: seed
|
||||
x = torch.randn_like(x) * self.decode_noise_scale + (1.0 - self.decode_noise_scale) * x
|
||||
|
||||
11
comfy/sd.py
11
comfy/sd.py
@ -995,6 +995,17 @@ class VAE:
|
||||
pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
|
||||
return pixel_samples
|
||||
|
||||
def decode_output_shape(self, samples_shape):
    """Return the pixel-space shape a latent of *samples_shape* would decode to.

    Raises:
        RuntimeError: if the wrapped first-stage model does not implement
            ``decode_output_shape``.
    """
    self.throw_exception_if_invalid()
    model = self.first_stage_model
    if not hasattr(model, "decode_output_shape"):
        raise RuntimeError("This VAE does not expose decode output shape information.")
    return model.decode_output_shape(samples_shape)
|
||||
|
||||
def decode_stream_start(self, samples_in):
    """Load the VAE onto its device and prime the model for chunked decoding.

    *samples_in* is the full latent tensor; it is moved to the VAE's device
    and dtype before the decoder's streaming state is initialized. Decoded
    chunks are subsequently pulled via ``first_stage_model.decode_chunk``.
    """
    # Fail fast on an invalid VAE, consistent with the other public decode
    # entry points on this class (decode_output_shape / decode_tiled).
    self.throw_exception_if_invalid()
    memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
    model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
    self.first_stage_model.decode_start(samples_in.to(device=self.device, dtype=self.vae_dtype))
|
||||
|
||||
def decode_tiled(self, samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None):
|
||||
self.throw_exception_if_invalid()
|
||||
memory_used = self.memory_used_decode(samples.shape, self.vae_dtype) #TODO: calculate mem required for tile
|
||||
|
||||
@ -123,6 +123,60 @@ class PreviewingImageStream(Input.ImageStream):
|
||||
return chunk
|
||||
|
||||
|
||||
class VAEDecodedImageStream(Input.ImageStream):
    """An ImageStream whose frames come from a chunked VAE decode of one latent.

    Frames are decoded lazily as the stream is pulled, and decode progress is
    reported to the currently executing node (one update per emitted chunk).
    Only VAEs whose first-stage model advertises ``comfy_has_chunked_io`` are
    supported, and only for a single video latent (batch size 1, rank 5).
    """

    def __init__(self, vae, latent: Input.Latent):
        super().__init__()
        self._vae = vae
        self._latent = latent
        vae.throw_exception_if_invalid()
        if not getattr(vae.first_stage_model, "comfy_has_chunked_io", False):
            raise RuntimeError("This VAE does not expose chunked decode support, so VAE Decode Stream cannot be used.")
        if latent.ndim != 5:
            raise RuntimeError("VAE Decode Stream expects a video latent shaped [batch, channels, frames, height, width].")
        if latent.shape[0] != 1:
            raise RuntimeError("VAE Decode Stream currently requires latent batch size 1.")
        # Decoded output shape is [batch, channels, frames, height, width].
        shape = vae.decode_output_shape(latent.shape)
        self._channels = int(shape[1])
        self._height = int(shape[3])
        self._width = int(shape[4])
        # Fold batch into the frame count (batch is 1 per the check above).
        self._total_frames = int(shape[0] * shape[2])
        self._frames_emitted = 0

    def _update_progress(self, value: int) -> None:
        """Report *value* out of the total frame count to the executing node.

        Silently does nothing when called outside a node execution context.
        """
        ctx = get_executing_context()
        if ctx is None:
            return
        denominator = float(max(self._total_frames, 1))
        get_progress_state().update_progress(ctx.node_id, value=float(value), max_value=denominator)

    def get_dimensions(self) -> tuple[int, int]:
        """Return the (width, height) of decoded frames."""
        return (self._width, self._height)

    def do_reset(self) -> None:
        """Restart decoding from the first frame and zero the progress bar."""
        self._frames_emitted = 0
        self._update_progress(0)
        self._vae.decode_stream_start(self._latent)

    def do_pull(self, max_frames: int) -> Input.Image:
        """Decode and return up to *max_frames* frames shaped [frames, H, W, C].

        Returns an empty (zero-frame) tensor once the decoder is exhausted.
        """
        vae = self._vae
        raw = vae.first_stage_model.decode_chunk(max_frames)
        out_device = vae.output_device
        out_dtype = vae.vae_output_dtype()
        if raw is None:
            # Decoder exhausted: signal end-of-stream with an empty batch.
            empty_shape = (0, self._height, self._width, self._channels)
            return torch.empty(empty_shape, device=out_device, dtype=out_dtype)
        frames = vae.process_output(raw.to(device=out_device, dtype=out_dtype)).movedim(1, -1)
        # Flatten any leading batch/chunk axes into a single frame axis.
        frames = frames.reshape((-1,) + tuple(frames.shape[-3:]))
        self._frames_emitted += int(frames.shape[0])
        self._update_progress(self._frames_emitted)
        return frames
|
||||
|
||||
|
||||
class ImageBatchToStream(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
@ -178,6 +232,32 @@ class ImageStreamToBatch(io.ComfyNode):
|
||||
return io.NodeOutput(torch.cat(chunks, dim=0))
|
||||
|
||||
|
||||
class VAEDecodeStream(io.ComfyNode):
    """Node that turns a video latent into a lazily-decoded IMAGE_STREAM."""

    @classmethod
    def define_schema(cls):
        node_inputs = [
            io.Latent.Input("samples", tooltip="The LTX latent to decode."),
            io.Vae.Input("vae", tooltip="The LTX VAE used for chunked streaming decode."),
        ]
        node_outputs = [io.ImageStream.Output(display_name="stream")]
        return io.Schema(
            node_id="VAEDecodeStream",
            display_name="VAE Decode Stream",
            category="image/stream",
            search_aliases=["vae stream decode", "latent to stream", "video latent stream"],
            description="Decodes a latent into an IMAGE_STREAM.",
            inputs=node_inputs,
            outputs=node_outputs,
        )

    @classmethod
    def execute(cls, samples: Input.Latent, vae) -> io.NodeOutput:
        latent = samples["samples"]
        # Nested latents carry per-item tensors; stream only the first one.
        if latent.is_nested:
            latent = latent.unbind()[0]
        stream = VAEDecodedImageStream(vae, latent)
        return io.NodeOutput(stream)
|
||||
|
||||
|
||||
class PreviewImageStream(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
@ -230,6 +310,7 @@ class ImageStreamExtension(ComfyExtension):
|
||||
return [
|
||||
ImageBatchToStream,
|
||||
ImageStreamToBatch,
|
||||
VAEDecodeStream,
|
||||
PreviewImageStream,
|
||||
StreamSink,
|
||||
]
|
||||
|
||||
Loading…
Reference in New Issue
Block a user