diff --git a/comfy/component_model/tensor_types.py b/comfy/component_model/tensor_types.py
index f244e381c..63d96b60d 100644
--- a/comfy/component_model/tensor_types.py
+++ b/comfy/component_model/tensor_types.py
@@ -5,7 +5,7 @@ from torch import Tensor
 from typing_extensions import NotRequired
 
 ImageBatch = Float[Tensor, "batch height width channels"]
-LatentBatch = Float[Tensor, "batch channels width height"]
+LatentBatch = Float[Tensor, "batch channels height width"]
 SD15LatentBatch = Float[Tensor, "batch 4 height width"]
 SDXLLatentBatch = Float[Tensor, "batch 8 height width"]
 SD3LatentBatch = Float[Tensor, "batch 16 height width"]
diff --git a/comfy/node_helpers.py b/comfy/node_helpers.py
index 4278d293a..71482a9d3 100644
--- a/comfy/node_helpers.py
+++ b/comfy/node_helpers.py
@@ -1,11 +1,12 @@
 import hashlib
 from PIL import ImageFile, UnidentifiedImageError
 
+from comfy_api.latest import io
 from .component_model.files import get_package_as_path
 from .execution_context import current_execution_context
 
 
-def conditioning_set_values(conditioning, values: dict = None, append=False):
+def conditioning_set_values(conditioning, values: dict = None, append=False) -> io.Conditioning.CondList:
     if values is None:
         values = {}
     c = []
diff --git a/comfy/nodes/base_nodes.py b/comfy/nodes/base_nodes.py
index 93c76ab87..ab013a9eb 100644
--- a/comfy/nodes/base_nodes.py
+++ b/comfy/nodes/base_nodes.py
@@ -2,6 +2,8 @@ from __future__ import annotations
 
 import json
 import logging
+from typing import Optional
+
 import math
 import os
 import random
@@ -14,6 +16,7 @@ from PIL.PngImagePlugin import PngInfo
 from huggingface_hub import snapshot_download
 from natsort import natsorted
 
+from comfy_api.latest import io
 from .. import clip_vision as clip_vision_module
 from .. import controlnet
 from .. import diffusers_load
@@ -27,7 +30,7 @@ from ..cli_args import args
 from ..cmd import folder_paths, latent_preview
 from ..comfy_types import IO, ComfyNodeABC, InputTypeDict, FileLocator
 from ..component_model.deprecation import _deprecate_method
-from ..component_model.tensor_types import RGBImage, RGBImageBatch, MaskBatch, RGBAImageBatch
+from ..component_model.tensor_types import RGBImage, RGBImageBatch, MaskBatch, RGBAImageBatch, Latent
 from ..execution_context import current_execution_context
 from ..images import open_image
 from ..interruption import interrupt_current_processing
@@ -308,6 +311,8 @@ class VAEDecode:
     DESCRIPTION = "Decodes latent images back into pixel space images."
 
     def decode(self, vae, samples):
+        if samples is None:
+            return None,
         images = vae.decode(samples["samples"])
         if len(images.shape) == 5:  # Combine batches
             images = images.reshape(-1, images.shape[-3], images.shape[-2], images.shape[-1])
@@ -331,6 +336,8 @@ class VAEDecodeTiled:
     CATEGORY = "_for_testing"
 
     def decode(self, vae, samples, tile_size, overlap=64, temporal_size=64, temporal_overlap=8):
+        if samples is None:
+            return None,
         if tile_size < overlap * 4:
             overlap = tile_size // 4
         if temporal_size < temporal_overlap * 2:
@@ -360,9 +367,11 @@ class VAEEncode:
 
     CATEGORY = "latent"
 
-    def encode(self, vae: VAE, pixels):
+    def encode(self, vae: VAE, pixels) -> tuple[Optional[Latent]]:
+        if pixels is None:
+            return None,
         t = vae.encode(pixels[:, :, :, :3])
-        return ({"samples": t},)
+        return (Latent(**{"samples": t}),)
 
 
 class VAEEncodeTiled:
@@ -380,9 +389,11 @@ class VAEEncodeTiled:
 
     CATEGORY = "_for_testing"
 
-    def encode(self, vae, pixels, tile_size, overlap, temporal_size=64, temporal_overlap=8):
+    def encode(self, vae, pixels, tile_size, overlap, temporal_size=64, temporal_overlap=8) -> tuple[Optional[Latent]]:
+        if pixels is None:
+            return None,
         t = vae.encode_tiled(pixels[:, :, :, :3], tile_x=tile_size, tile_y=tile_size, overlap=overlap, tile_t=temporal_size, overlap_t=temporal_overlap)
-        return ({"samples": t},)
+        return (Latent(**{"samples": t}),)
 
 
 class VAEEncodeForInpaint:
@@ -395,7 +406,9 @@ class VAEEncodeForInpaint:
 
     CATEGORY = "latent/inpaint"
 
-    def encode(self, vae, pixels, mask, grow_mask_by=6):
+    def encode(self, vae, pixels, mask, grow_mask_by=6) -> tuple[Optional[Latent]]:
+        if pixels is None:
+            return None,
         x = (pixels.shape[1] // vae.downscale_ratio) * vae.downscale_ratio
         y = (pixels.shape[2] // vae.downscale_ratio) * vae.downscale_ratio
         mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(pixels.shape[1], pixels.shape[2]), mode="bilinear")
@@ -423,7 +436,7 @@ class VAEEncodeForInpaint:
             pixels[:, :, :, i] += 0.5
 
         t = vae.encode(pixels)
-        return ({"samples": t, "noise_mask": (mask_erosion[:, :, :x, :y].round())},)
+        return (Latent(**{"samples": t, "noise_mask": (mask_erosion[:, :, :x, :y].round())}),)
 
 
 class InpaintModelConditioning:
@@ -443,7 +456,9 @@ class InpaintModelConditioning:
 
     CATEGORY = "conditioning/inpaint"
 
-    def encode(self, positive, negative, pixels, vae, mask, noise_mask=True):
+    def encode(self, positive: io.Conditioning.CondList, negative: io.Conditioning.CondList, pixels, vae, mask, noise_mask=True) -> tuple[io.Conditioning.CondList, io.Conditioning.CondList, Optional[Latent]]:
+        if pixels is None:
+            return positive, negative, None
         x = (pixels.shape[1] // 8) * 8
         y = (pixels.shape[2] // 8) * 8
         mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(pixels.shape[1], pixels.shape[2]), mode="bilinear")
@@ -464,13 +479,12 @@ class InpaintModelConditioning:
         concat_latent = vae.encode(pixels)
         orig_latent = vae.encode(orig_pixels)
 
-        out_latent = {}
+        out_latent: Latent = {"samples": orig_latent}
 
-        out_latent["samples"] = orig_latent
         if noise_mask:
             out_latent["noise_mask"] = mask
 
-        out = []
+        out: list[io.Conditioning.CondList] = []
         for conditioning in [positive, negative]:
             c = node_helpers.conditioning_set_values(conditioning, {"concat_latent_image": concat_latent,
                                                                     "concat_mask": mask})
@@ -1291,6 +1305,8 @@ class LatentFromBatch:
     CATEGORY = "latent/batch"
 
     def frombatch(self, samples, batch_index, length):
+        if samples is None:
+            return None,
         s = samples.copy()
samples["samples"] batch_index = min(s_in.shape[0] - 1, batch_index) @@ -1324,6 +1340,8 @@ class RepeatLatentBatch: CATEGORY = "latent/batch" def repeat(self, samples, amount): + if samples is None: + return None, s = samples.copy() s_in = samples["samples"] @@ -1356,6 +1374,8 @@ class LatentUpscale: CATEGORY = "latent" def upscale(self, samples, upscale_method, width, height, crop): + if samples is None: + return None, if width == 0 and height == 0: s = samples else: @@ -1389,6 +1409,8 @@ class LatentUpscaleBy: CATEGORY = "latent" def upscale(self, samples, upscale_method, scale_by): + if samples is None: + return None, s = samples.copy() width = round(samples["samples"].shape[-1] * scale_by) height = round(samples["samples"].shape[-2] * scale_by) @@ -1409,6 +1431,8 @@ class LatentRotate: CATEGORY = "latent/transform" def rotate(self, samples, rotation): + if samples is None: + return None, s = samples.copy() rotate_by = 0 if rotation.startswith("90"): @@ -1435,6 +1459,8 @@ class LatentFlip: CATEGORY = "latent/transform" def flip(self, samples, flip_method): + if samples is None: + return None, s = samples.copy() if flip_method.startswith("x"): s["samples"] = torch.flip(samples["samples"], dims=[2]) @@ -1546,6 +1572,8 @@ class LatentCrop: CATEGORY = "latent/transform" def crop(self, samples, width, height, x, y): + if samples is None: + return None, s = samples.copy() samples = samples['samples'] x = x // 8 @@ -1578,6 +1606,8 @@ class SetLatentNoiseMask: CATEGORY = "latent/inpaint" def set_mask(self, samples, mask): + if samples is None: + return None, s = samples.copy() s["noise_mask"] = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])) return (s,) diff --git a/comfy/sd.py b/comfy/sd.py index 7cd8e5d40..b258563e0 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -801,6 +801,7 @@ class VAE: pixel_samples = self.vae_encode_crop_pixels(pixel_samples) pixel_samples = pixel_samples.movedim(-1, 1) do_tile = False + samples = None if self.latent_dim == 3 and pixel_samples.ndim < 5: if not self.not_video: pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0) diff --git a/comfy_extras/nodes/nodes_audio.py b/comfy_extras/nodes/nodes_audio.py index 804de3f91..5344bb53c 100644 --- a/comfy_extras/nodes/nodes_audio.py +++ b/comfy_extras/nodes/nodes_audio.py @@ -101,6 +101,8 @@ class VAEDecodeAudio: CATEGORY = "latent/audio" def decode(self, vae, samples): + if samples is None: + return None, audio = vae.decode(samples["samples"]).movedim(-1, 1) std = torch.std(audio, dim=[1, 2], keepdim=True) * 5.0 std[std < 1.0] = 1.0 diff --git a/comfy_extras/nodes/nodes_hunyuan3d.py b/comfy_extras/nodes/nodes_hunyuan3d.py index c6b924634..dc30179de 100644 --- a/comfy_extras/nodes/nodes_hunyuan3d.py +++ b/comfy_extras/nodes/nodes_hunyuan3d.py @@ -97,6 +97,8 @@ class VAEDecodeHunyuan3D: CATEGORY = "latent/3d" def decode(self, vae, samples, num_chunks, octree_resolution): + if samples is None: + return None, voxels = VOXEL(vae.decode(samples["samples"], vae_options={"num_chunks": num_chunks, "octree_resolution": octree_resolution})) return (voxels, ) diff --git a/comfy_extras/nodes/nodes_latent.py b/comfy_extras/nodes/nodes_latent.py index 368c7433e..773bdcd7f 100644 --- a/comfy_extras/nodes/nodes_latent.py +++ b/comfy_extras/nodes/nodes_latent.py @@ -293,6 +293,8 @@ class LatentAddNoiseChannels(io.ComfyNode): @classmethod def execute(cls, samples: Latent, std_dev, seed: int, slice_i: int, slice_j: int): + if samples is None: + return None, s = samples.copy() latent = samples["samples"]