mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-28 03:12:31 +08:00
nodes_image_stream: implement VAE decoder node
This commit is contained in:
parent
0c70446c9b
commit
1c2d37944c
@ -1388,6 +1388,15 @@ class VideoVAE(nn.Module):
|
|||||||
def decode_output_shape(self, input_shape):
    """Return the decoder's output shape for a latent of ``input_shape``.

    Thin delegate: the computation lives in ``self.decoder``.

    Args:
        input_shape: Shape of the latent that would be decoded.

    Returns:
        Whatever ``self.decoder.decode_output_shape`` returns for that shape.
    """
    decoder = self.decoder
    return decoder.decode_output_shape(input_shape)
|
||||||
|
|
||||||
|
def decode_start(self, x):
    """Begin a chunked decode of the latent ``x``.

    Steps, in order:
      1. ``clear_temporal_cache_state`` is called on the decoder so a new
         stream does not start from a previous stream's cached state.
      2. When ``self.timestep_conditioning`` is set, the latent is blended
         with fresh standard-normal noise using ``self.decode_noise_scale``
         (same formula as the non-streaming ``decode`` path).
      3. The latent is un-normalized with ``self.per_channel_statistics`` and
         handed to ``self.decoder.forward_start``.

    Args:
        x: Latent tensor to decode.

    Returns:
        The value returned by ``self.decoder.forward_start``.
    """
    clear_temporal_cache_state(self.decoder)

    if self.timestep_conditioning:  # TODO: seed
        noise_scale = self.decode_noise_scale
        x = torch.randn_like(x) * noise_scale + (1.0 - noise_scale) * x

    unnormalized = self.per_channel_statistics.un_normalize(x)
    return self.decoder.forward_start(unnormalized, timestep=self.decode_timestep)
|
||||||
|
|
||||||
|
def decode_chunk(self, output_t: int):
    """Resume a decode previously started with ``decode_start``.

    Args:
        output_t: Forwarded unchanged to ``self.decoder.forward_resume``;
            presumably the number of output frames requested for this
            chunk (confirm against the decoder implementation).

    Returns:
        The value returned by ``self.decoder.forward_resume``.
    """
    decoder = self.decoder
    return decoder.forward_resume(output_t)
|
||||||
|
|
||||||
def decode(self, x, output_buffer=None):
|
def decode(self, x, output_buffer=None):
|
||||||
if self.timestep_conditioning: #TODO: seed
|
if self.timestep_conditioning: #TODO: seed
|
||||||
x = torch.randn_like(x) * self.decode_noise_scale + (1.0 - self.decode_noise_scale) * x
|
x = torch.randn_like(x) * self.decode_noise_scale + (1.0 - self.decode_noise_scale) * x
|
||||||
|
|||||||
11
comfy/sd.py
11
comfy/sd.py
@ -995,6 +995,17 @@ class VAE:
|
|||||||
pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
|
pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
|
||||||
return pixel_samples
|
return pixel_samples
|
||||||
|
|
||||||
|
def decode_output_shape(self, samples_shape):
    """Report the shape the wrapped model would produce for a latent shape.

    Args:
        samples_shape: Shape of the latent that would be decoded.

    Returns:
        The value returned by the wrapped model's ``decode_output_shape``.

    Raises:
        RuntimeError: If the wrapped model has no ``decode_output_shape``.
            ``throw_exception_if_invalid`` is called first and may raise
            its own error.
    """
    self.throw_exception_if_invalid()

    # Single lookup instead of hasattr() followed by a second attribute access.
    shape_fn = getattr(self.first_stage_model, "decode_output_shape", None)
    if shape_fn is None:
        raise RuntimeError("This VAE does not expose decode output shape information.")
    return shape_fn(samples_shape)
|
||||||
|
|
||||||
|
def decode_stream_start(self, samples_in):
    """Load the VAE and start a streaming decode of ``samples_in``.

    The VAE is validated first, matching the other decode entry points on
    this class, so an invalid VAE fails before any model is loaded. The
    decode memory estimate is then used to load the patcher onto the GPU,
    and the latent is moved to the VAE's device and dtype before being
    passed to the wrapped model's ``decode_start``.

    Args:
        samples_in: Latent tensor to decode. Its ``shape`` feeds the memory
            estimate.

    Returns:
        None. Any value returned by ``decode_start`` is discarded.
    """
    self.throw_exception_if_invalid()

    memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
    model_management.load_models_gpu(
        [self.patcher],
        memory_required=memory_used,
        force_full_load=self.disable_offload,
    )

    latent = samples_in.to(device=self.device, dtype=self.vae_dtype)
    self.first_stage_model.decode_start(latent)
|
||||||
|
|
||||||
def decode_tiled(self, samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None):
|
def decode_tiled(self, samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None):
|
||||||
self.throw_exception_if_invalid()
|
self.throw_exception_if_invalid()
|
||||||
memory_used = self.memory_used_decode(samples.shape, self.vae_dtype) #TODO: calculate mem required for tile
|
memory_used = self.memory_used_decode(samples.shape, self.vae_dtype) #TODO: calculate mem required for tile
|
||||||
|
|||||||
@ -123,6 +123,60 @@ class PreviewingImageStream(Input.ImageStream):
|
|||||||
return chunk
|
return chunk
|
||||||
|
|
||||||
|
|
||||||
|
class VAEDecodedImageStream(Input.ImageStream):
    """Image stream that decodes a video latent chunk by chunk through a VAE.

    The latent is not decoded at construction time. ``do_reset`` starts a
    streaming decode on the VAE and ``do_pull`` asks the wrapped model for
    the next chunk, converting it to channels-last frames.
    """

    def __init__(self, vae, latent: torch.Tensor):
        """Validate the VAE and latent, and record the decoded geometry.

        Args:
            vae: VAE wrapper. Its ``first_stage_model`` must set
                ``comfy_has_chunked_io`` to a truthy value.
            latent: Latent tensor shaped
                [batch, channels, frames, height, width] with batch size 1.

        Raises:
            RuntimeError: If the VAE lacks chunked decode support, the
                latent is not 5-dimensional, or the batch size is not 1.
        """
        super().__init__()
        self._vae = vae
        self._latent = latent

        vae.throw_exception_if_invalid()
        if not getattr(vae.first_stage_model, "comfy_has_chunked_io", False):
            raise RuntimeError("This VAE does not expose chunked decode support, so VAE Decode Stream cannot be used.")
        if latent.ndim != 5:
            raise RuntimeError("VAE Decode Stream expects a video latent shaped [batch, channels, frames, height, width].")
        if latent.shape[0] != 1:
            raise RuntimeError("VAE Decode Stream currently requires latent batch size 1.")

        # The output shape is indexed as [batch, channels, frames, height, width].
        output_shape = vae.decode_output_shape(latent.shape)
        self._channels = int(output_shape[1])
        self._height = int(output_shape[3])
        self._width = int(output_shape[4])
        self._total_frames = int(output_shape[0] * output_shape[2])
        self._frames_emitted = 0

    def _update_progress(self, value: int) -> None:
        """Report ``value`` emitted frames for the currently executing node.

        Does nothing when there is no executing context.
        """
        current = get_executing_context()
        if current is None:
            return

        # Clamp the maximum to at least 1 so it is never zero.
        total = float(max(self._total_frames, 1))
        get_progress_state().update_progress(
            current.node_id,
            value=float(value),
            max_value=total,
        )

    def get_dimensions(self) -> tuple[int, int]:
        """Return the decoded frame size as ``(width, height)``."""
        return self._width, self._height

    def do_reset(self) -> None:
        """Rewind the stream: zero the counters and restart the VAE decode."""
        self._frames_emitted = 0
        self._update_progress(0)
        self._vae.decode_stream_start(self._latent)

    def do_pull(self, max_frames: int) -> Input.Image:
        """Decode and return the next chunk of frames.

        Args:
            max_frames: Forwarded to the wrapped model's ``decode_chunk``.

        Returns:
            A tensor shaped [frames, height, width, channels] on the VAE's
            output device. When the model returns ``None`` the result is an
            empty tensor with zero frames.
        """
        chunk = self._vae.first_stage_model.decode_chunk(max_frames)

        out_device = self._vae.output_device
        out_dtype = self._vae.vae_output_dtype()

        if chunk is None:
            return torch.empty(
                (0, self._height, self._width, self._channels),
                device=out_device,
                dtype=out_dtype,
            )

        chunk = chunk.to(device=out_device, dtype=out_dtype)
        # Move channels (dim 1) last, then fold every leading dimension into
        # a single frame axis while keeping the last three dims as-is.
        chunk = self._vae.process_output(chunk).movedim(1, -1)
        chunk = chunk.reshape((-1,) + tuple(chunk.shape[-3:]))

        self._frames_emitted += int(chunk.shape[0])
        self._update_progress(self._frames_emitted)
        return chunk
|
||||||
|
|
||||||
|
|
||||||
class ImageBatchToStream(io.ComfyNode):
|
class ImageBatchToStream(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
@ -178,6 +232,32 @@ class ImageStreamToBatch(io.ComfyNode):
|
|||||||
return io.NodeOutput(torch.cat(chunks, dim=0))
|
return io.NodeOutput(torch.cat(chunks, dim=0))
|
||||||
|
|
||||||
|
|
||||||
|
class VAEDecodeStream(io.ComfyNode):
    """Node that wraps a latent and a VAE into a lazily decoded IMAGE_STREAM."""

    @classmethod
    def define_schema(cls):
        """Describe the node: a latent and a VAE in, one image stream out."""
        return io.Schema(
            node_id="VAEDecodeStream",
            display_name="VAE Decode Stream",
            category="image/stream",
            search_aliases=["vae stream decode", "latent to stream", "video latent stream"],
            description="Decodes a latent into an IMAGE_STREAM.",
            inputs=[
                io.Latent.Input("samples", tooltip="The LTX latent to decode."),
                io.Vae.Input("vae", tooltip="The LTX VAE used for chunked streaming decode."),
            ],
            outputs=[
                io.ImageStream.Output(display_name="stream"),
            ],
        )

    @classmethod
    def execute(cls, samples: Input.Latent, vae) -> io.NodeOutput:
        """Build the decoding stream for ``samples``.

        Args:
            samples: Latent dict; the tensor is read from its ``"samples"`` key.
            vae: VAE used for the chunked decode.

        Returns:
            A node output holding a ``VAEDecodedImageStream``. Validation
            errors raised by that constructor propagate to the caller.
        """
        latent = samples["samples"]
        if latent.is_nested:
            # For a nested tensor only the first component is decoded.
            latent = latent.unbind()[0]

        stream = VAEDecodedImageStream(vae, latent)
        return io.NodeOutput(stream)
|
||||||
|
|
||||||
|
|
||||||
class PreviewImageStream(io.ComfyNode):
|
class PreviewImageStream(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
@ -230,6 +310,7 @@ class ImageStreamExtension(ComfyExtension):
|
|||||||
return [
|
return [
|
||||||
ImageBatchToStream,
|
ImageBatchToStream,
|
||||||
ImageStreamToBatch,
|
ImageStreamToBatch,
|
||||||
|
VAEDecodeStream,
|
||||||
PreviewImageStream,
|
PreviewImageStream,
|
||||||
StreamSink,
|
StreamSink,
|
||||||
]
|
]
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user