wip latent nodes can return None for graceful behavior in multi-reference-latent scenarios

Benjamin Berman 2025-10-30 12:38:02 -07:00
parent 82bffb7855
commit 6f2589f256
7 changed files with 51 additions and 13 deletions
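The change is mechanical across the touched nodes: each latent-consuming entry point gains a two-line guard that forwards None instead of indexing into a missing latent dict. A minimal runnable sketch of that pattern follows; FakeVAE and encode_node are hypothetical stand-ins for illustration, not names from this repository.

from typing import Optional

Latent = dict  # stand-in; the codebase uses a typed dict with a "samples" key

def encode_node(vae, pixels) -> tuple[Optional[Latent]]:
    # With no upstream image (e.g. an unused reference-latent branch),
    # propagate None instead of raising on pixels[:, :, :, :3].
    if pixels is None:
        return None,
    return ({"samples": vae.encode(pixels)},)

class FakeVAE:
    def encode(self, pixels):
        return pixels  # identity stand-in for a real VAE

assert encode_node(FakeVAE(), None) == (None,)

Because every downstream latent transform applies the same guard, a None latent flows through an entire optional branch of a workflow without any node raising a KeyError on samples["samples"].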

View File

@@ -5,7 +5,7 @@ from torch import Tensor
 from typing_extensions import NotRequired
 ImageBatch = Float[Tensor, "batch height width channels"]
-LatentBatch = Float[Tensor, "batch channels height width"]
+LatentBatch = Float[Tensor, "batch channels width height"]
 SD15LatentBatch = Float[Tensor, "batch 4 height width"]
 SDXLLatentBatch = Float[Tensor, "batch 8 height width"]
 SD3LatentBatch = Float[Tensor, "batch 16 height width"]

View File

@@ -1,11 +1,12 @@
 import hashlib
 from PIL import ImageFile, UnidentifiedImageError
+from comfy_api.latest import io
 from .component_model.files import get_package_as_path
 from .execution_context import current_execution_context
-def conditioning_set_values(conditioning, values: dict = None, append=False):
+def conditioning_set_values(conditioning, values: dict = None, append=False) -> io.Conditioning.CondList:
     if values is None:
         values = {}
     c = []

View File

@@ -2,6 +2,8 @@ from __future__ import annotations
 import json
 import logging
+from typing import Optional
 import math
 import os
 import random
@@ -14,6 +16,7 @@ from PIL.PngImagePlugin import PngInfo
 from huggingface_hub import snapshot_download
 from natsort import natsorted
+from comfy_api.latest import io
 from .. import clip_vision as clip_vision_module
 from .. import controlnet
 from .. import diffusers_load
@@ -27,7 +30,7 @@ from ..cli_args import args
 from ..cmd import folder_paths, latent_preview
 from ..comfy_types import IO, ComfyNodeABC, InputTypeDict, FileLocator
 from ..component_model.deprecation import _deprecate_method
-from ..component_model.tensor_types import RGBImage, RGBImageBatch, MaskBatch, RGBAImageBatch
+from ..component_model.tensor_types import RGBImage, RGBImageBatch, MaskBatch, RGBAImageBatch, Latent
 from ..execution_context import current_execution_context
 from ..images import open_image
 from ..interruption import interrupt_current_processing
@@ -308,6 +311,8 @@ class VAEDecode:
     DESCRIPTION = "Decodes latent images back into pixel space images."
     def decode(self, vae, samples):
+        if samples is None:
+            return None,
         images = vae.decode(samples["samples"])
         if len(images.shape) == 5:  # Combine batches
             images = images.reshape(-1, images.shape[-3], images.shape[-2], images.shape[-1])
@@ -331,6 +336,8 @@ class VAEDecodeTiled:
     CATEGORY = "_for_testing"
     def decode(self, vae, samples, tile_size, overlap=64, temporal_size=64, temporal_overlap=8):
+        if samples is None:
+            return None,
         if tile_size < overlap * 4:
             overlap = tile_size // 4
         if temporal_size < temporal_overlap * 2:
@@ -360,9 +367,11 @@ class VAEEncode:
     CATEGORY = "latent"
-    def encode(self, vae: VAE, pixels):
+    def encode(self, vae: VAE, pixels) -> tuple[Optional[Latent]]:
+        if pixels is None:
+            return None,
         t = vae.encode(pixels[:, :, :, :3])
-        return ({"samples": t},)
+        return (Latent(**{"samples": t}),)
 class VAEEncodeTiled:
@@ -380,9 +389,11 @@ class VAEEncodeTiled:
     CATEGORY = "_for_testing"
-    def encode(self, vae, pixels, tile_size, overlap, temporal_size=64, temporal_overlap=8):
+    def encode(self, vae, pixels, tile_size, overlap, temporal_size=64, temporal_overlap=8) -> tuple[Optional[Latent]]:
+        if pixels is None:
+            return None,
         t = vae.encode_tiled(pixels[:, :, :, :3], tile_x=tile_size, tile_y=tile_size, overlap=overlap, tile_t=temporal_size, overlap_t=temporal_overlap)
-        return ({"samples": t},)
+        return (Latent(**{"samples": t}),)
 class VAEEncodeForInpaint:
@@ -395,7 +406,9 @@ class VAEEncodeForInpaint:
     CATEGORY = "latent/inpaint"
-    def encode(self, vae, pixels, mask, grow_mask_by=6):
+    def encode(self, vae, pixels, mask, grow_mask_by=6) -> tuple[Optional[Latent]]:
+        if pixels is None:
+            return None,
         x = (pixels.shape[1] // vae.downscale_ratio) * vae.downscale_ratio
         y = (pixels.shape[2] // vae.downscale_ratio) * vae.downscale_ratio
         mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(pixels.shape[1], pixels.shape[2]), mode="bilinear")
@@ -423,7 +436,7 @@ class VAEEncodeForInpaint:
         pixels[:, :, :, i] += 0.5
         t = vae.encode(pixels)
-        return ({"samples": t, "noise_mask": (mask_erosion[:, :, :x, :y].round())},)
+        return (Latent(**{"samples": t, "noise_mask": (mask_erosion[:, :, :x, :y].round())}),)
 class InpaintModelConditioning:
@@ -443,7 +456,9 @@ class InpaintModelConditioning:
     CATEGORY = "conditioning/inpaint"
-    def encode(self, positive, negative, pixels, vae, mask, noise_mask=True):
+    def encode(self, positive: io.Conditioning.CondList, negative: io.Conditioning.CondList, pixels, vae, mask, noise_mask=True) -> tuple[io.Conditioning.CondList, io.Conditioning.CondList, Optional[Latent]]:
+        if pixels is None:
+            return positive, negative, None
         x = (pixels.shape[1] // 8) * 8
         y = (pixels.shape[2] // 8) * 8
         mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(pixels.shape[1], pixels.shape[2]), mode="bilinear")
@@ -464,13 +479,12 @@ class InpaintModelConditioning:
         concat_latent = vae.encode(pixels)
         orig_latent = vae.encode(orig_pixels)
-        out_latent = {}
-        out_latent["samples"] = orig_latent
+        out_latent: Latent = {"samples": orig_latent}
         if noise_mask:
             out_latent["noise_mask"] = mask
-        out = []
+        out: list[io.Conditioning.CondList] = []
         for conditioning in [positive, negative]:
             c = node_helpers.conditioning_set_values(conditioning, {"concat_latent_image": concat_latent,
                                                                     "concat_mask": mask})
@@ -1291,6 +1305,8 @@ class LatentFromBatch:
     CATEGORY = "latent/batch"
     def frombatch(self, samples, batch_index, length):
+        if samples is None:
+            return None,
         s = samples.copy()
         s_in = samples["samples"]
         batch_index = min(s_in.shape[0] - 1, batch_index)
@@ -1324,6 +1340,8 @@ class RepeatLatentBatch:
     CATEGORY = "latent/batch"
     def repeat(self, samples, amount):
+        if samples is None:
+            return None,
         s = samples.copy()
         s_in = samples["samples"]
@@ -1356,6 +1374,8 @@ class LatentUpscale:
     CATEGORY = "latent"
     def upscale(self, samples, upscale_method, width, height, crop):
+        if samples is None:
+            return None,
         if width == 0 and height == 0:
             s = samples
         else:
@@ -1389,6 +1409,8 @@ class LatentUpscaleBy:
     CATEGORY = "latent"
     def upscale(self, samples, upscale_method, scale_by):
+        if samples is None:
+            return None,
         s = samples.copy()
         width = round(samples["samples"].shape[-1] * scale_by)
         height = round(samples["samples"].shape[-2] * scale_by)
@@ -1409,6 +1431,8 @@ class LatentRotate:
     CATEGORY = "latent/transform"
     def rotate(self, samples, rotation):
+        if samples is None:
+            return None,
         s = samples.copy()
         rotate_by = 0
         if rotation.startswith("90"):
@@ -1435,6 +1459,8 @@ class LatentFlip:
     CATEGORY = "latent/transform"
     def flip(self, samples, flip_method):
+        if samples is None:
+            return None,
         s = samples.copy()
         if flip_method.startswith("x"):
             s["samples"] = torch.flip(samples["samples"], dims=[2])
@@ -1546,6 +1572,8 @@ class LatentCrop:
     CATEGORY = "latent/transform"
     def crop(self, samples, width, height, x, y):
+        if samples is None:
+            return None,
         s = samples.copy()
         samples = samples['samples']
         x = x // 8
@@ -1578,6 +1606,8 @@ class SetLatentNoiseMask:
     CATEGORY = "latent/inpaint"
     def set_mask(self, samples, mask):
+        if samples is None:
+            return None,
         s = samples.copy()
         s["noise_mask"] = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1]))
         return (s,)

View File

@@ -801,6 +801,7 @@ class VAE:
         pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
         pixel_samples = pixel_samples.movedim(-1, 1)
         do_tile = False
+        samples = None
         if self.latent_dim == 3 and pixel_samples.ndim < 5:
             if not self.not_video:
                 pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
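Pre-seeding samples = None here presumably keeps the name bound even when none of the encode branches below assigns it, so VAE.encode can fall through to returning None rather than hitting an UnboundLocalError. A toy illustration of that failure mode, not the real method body:

def encode(pixels, do_tile: bool):
    samples = None  # bound up front, mirroring the added line above
    if do_tile:
        samples = [p * 0.5 for p in pixels]  # stand-in for the tiled encode path
    return samples  # None when no branch ran, instead of UnboundLocalError

assert encode([1.0], do_tile=False) is None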

View File

@@ -101,6 +101,8 @@ class VAEDecodeAudio:
     CATEGORY = "latent/audio"
     def decode(self, vae, samples):
+        if samples is None:
+            return None,
         audio = vae.decode(samples["samples"]).movedim(-1, 1)
         std = torch.std(audio, dim=[1, 2], keepdim=True) * 5.0
         std[std < 1.0] = 1.0

View File

@@ -97,6 +97,8 @@ class VAEDecodeHunyuan3D:
     CATEGORY = "latent/3d"
     def decode(self, vae, samples, num_chunks, octree_resolution):
+        if samples is None:
+            return None,
         voxels = VOXEL(vae.decode(samples["samples"], vae_options={"num_chunks": num_chunks, "octree_resolution": octree_resolution}))
         return (voxels, )

View File

@@ -293,6 +293,8 @@ class LatentAddNoiseChannels(io.ComfyNode):
     @classmethod
     def execute(cls, samples: Latent, std_dev, seed: int, slice_i: int, slice_j: int):
+        if samples is None:
+            return None,
         s = samples.copy()
         latent = samples["samples"]