from comfy import model_management from comfy_api.latest import ComfyExtension, IO from typing_extensions import override import math class LTXVLatentUpsampler(IO.ComfyNode): """ Upsamples a video latent by a factor of 2. """ @classmethod def define_schema(cls): return IO.Schema( node_id="LTXVLatentUpsampler", category="latent/video", is_experimental=True, inputs=[ IO.Latent.Input("samples"), IO.LatentUpscaleModel.Input("upscale_model"), IO.Vae.Input("vae"), ], outputs=[ IO.Latent.Output(), ], ) @classmethod def execute(cls, samples, upscale_model, vae) -> IO.NodeOutput: """ Upsample the input latent using the provided model. Args: samples (dict): Input latent samples upscale_model (LatentUpsampler): Loaded upscale model vae: VAE model for normalization Returns: tuple: Tuple containing the upsampled latent """ device = model_management.get_torch_device() memory_required = model_management.module_size(upscale_model) model_dtype = next(upscale_model.parameters()).dtype latents = samples["samples"] input_dtype = latents.dtype memory_required += math.prod(latents.shape) * 3000.0 # TODO: more accurate model_management.free_memory(memory_required, device) try: upscale_model.to(device) # TODO: use the comfy model management system. latents = latents.to(dtype=model_dtype, device=device) """Upsample latents without tiling.""" latents = vae.first_stage_model.per_channel_statistics.un_normalize(latents) upsampled_latents = upscale_model(latents) finally: upscale_model.cpu() upsampled_latents = vae.first_stage_model.per_channel_statistics.normalize( upsampled_latents ) upsampled_latents = upsampled_latents.to(dtype=input_dtype, device=model_management.intermediate_device()) return_dict = samples.copy() return_dict["samples"] = upsampled_latents return_dict.pop("noise_mask", None) return IO.NodeOutput(return_dict) upsample_latent = execute # TODO: remove class LTXVLatentUpsamplerExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[IO.ComfyNode]]: return [LTXVLatentUpsampler] async def comfy_entrypoint() -> LTXVLatentUpsamplerExtension: return LTXVLatentUpsamplerExtension()