wip latent nodes can return None for graceful behavior in multi-reference-latent scenarios

Benjamin Berman 2025-10-30 12:38:02 -07:00
parent 82bffb7855
commit 6f2589f256
7 changed files with 51 additions and 13 deletions

@@ -5,7 +5,7 @@ from torch import Tensor
from typing_extensions import NotRequired

ImageBatch = Float[Tensor, "batch height width channels"]
LatentBatch = Float[Tensor, "batch channels height width"]
LatentBatch = Float[Tensor, "batch channels width height"]
SD15LatentBatch = Float[Tensor, "batch 4 height width"]
SDXLLatentBatch = Float[Tensor, "batch 8 height width"]
SD3LatentBatch = Float[Tensor, "batch 16 height width"]
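These aliases are jaxtyping-style shape annotations: the axis string documents the expected layout and is only enforced when a runtime typechecker is attached. A minimal sketch of checking the revised LatentBatch alias at runtime; the scale_latent function and the typeguard checker are illustrative assumptions, not part of this commit:

import torch
from jaxtyping import Float, jaxtyped
from torch import Tensor
from typeguard import typechecked

LatentBatch = Float[Tensor, "batch channels width height"]

@jaxtyped(typechecker=typechecked)
def scale_latent(samples: LatentBatch, factor: float) -> LatentBatch:
    # Element-wise scaling preserves the annotated axes.
    return samples * factor

scale_latent(torch.zeros(1, 4, 64, 64), 0.5)  # accepted: rank-4 float tensor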

@@ -1,11 +1,12 @@
import hashlib
from PIL import ImageFile, UnidentifiedImageError
from comfy_api.latest import io
from .component_model.files import get_package_as_path
from .execution_context import current_execution_context
def conditioning_set_values(conditioning, values: dict = None, append=False):
def conditioning_set_values(conditioning, values: dict = None, append=False) -> io.Conditioning.CondList:
    if values is None:
        values = {}
    c = []
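For context, upstream ComfyUI's version of this helper proceeds by copying each conditioning entry and merging the supplied values into its options dict. A sketch under that assumption (paraphrased, not code from this hunk):

def conditioning_set_values(conditioning, values: dict = None, append=False):
    if values is None:
        values = {}
    c = []
    for t in conditioning:
        # Each entry is a [tensor, options_dict] pair; copy the dict so the
        # caller's conditioning list is left untouched.
        n = [t[0], t[1].copy()]
        for k in values:
            val = values[k]
            if append:
                # Extend an existing value (expected to be a list) instead of
                # overwriting it.
                old_val = n[1].get(k, None)
                if old_val is not None:
                    val = old_val + val
            n[1][k] = val
        c.append(n)
    return c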

@@ -2,6 +2,8 @@ from __future__ import annotations
import json
import logging
from typing import Optional
import math
import os
import random
@@ -14,6 +16,7 @@ from PIL.PngImagePlugin import PngInfo
from huggingface_hub import snapshot_download
from natsort import natsorted
from comfy_api.latest import io
from .. import clip_vision as clip_vision_module
from .. import controlnet
from .. import diffusers_load
@@ -27,7 +30,7 @@ from ..cli_args import args
from ..cmd import folder_paths, latent_preview
from ..comfy_types import IO, ComfyNodeABC, InputTypeDict, FileLocator
from ..component_model.deprecation import _deprecate_method
from ..component_model.tensor_types import RGBImage, RGBImageBatch, MaskBatch, RGBAImageBatch
from ..component_model.tensor_types import RGBImage, RGBImageBatch, MaskBatch, RGBAImageBatch, Latent
from ..execution_context import current_execution_context
from ..images import open_image
from ..interruption import interrupt_current_processing
@@ -308,6 +311,8 @@ class VAEDecode:
    DESCRIPTION = "Decodes latent images back into pixel space images."

    def decode(self, vae, samples):
        if samples is None:
            return None,
        images = vae.decode(samples["samples"])
        if len(images.shape) == 5:  # Combine batches
            images = images.reshape(-1, images.shape[-3], images.shape[-2], images.shape[-1])
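The trailing comma in return None, makes the return value a one-element tuple, matching the convention that node methods always return tuples of outputs; a downstream node guarded the same way receives the None and passes it along instead of indexing into a missing latent. Illustrative only:

result = None,             # equivalent to (None,)
assert result == (None,)
images = result[0]         # downstream sees None and short-circuits in turn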
@@ -331,6 +336,8 @@ class VAEDecodeTiled:
    CATEGORY = "_for_testing"

    def decode(self, vae, samples, tile_size, overlap=64, temporal_size=64, temporal_overlap=8):
        if samples is None:
            return None,
        if tile_size < overlap * 4:
            overlap = tile_size // 4
        if temporal_size < temporal_overlap * 2:
@@ -360,9 +367,11 @@ class VAEEncode:
    CATEGORY = "latent"

    def encode(self, vae: VAE, pixels):
    def encode(self, vae: VAE, pixels) -> tuple[Optional[Latent]]:
        if pixels is None:
            return None,
        t = vae.encode(pixels[:, :, :, :3])
        return ({"samples": t},)
        return (Latent(**{"samples": t}),)
class VAEEncodeTiled:
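Latent(**{"samples": t}) implies that Latent is a constructible mapping over the latent dict, most plausibly a TypedDict given the NotRequired import in the tensor-types file above. A hypothetical sketch of such a type; the exact field set is a guess from this diff:

import torch
from torch import Tensor
from typing import TypedDict
from typing_extensions import NotRequired

class Latent(TypedDict):
    samples: Tensor
    noise_mask: NotRequired[Tensor]  # only set by inpaint-style nodes

t = torch.zeros(1, 4, 64, 64)
# At runtime a TypedDict constructor produces a plain dict.
assert Latent(**{"samples": t}) == {"samples": t}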
@@ -380,9 +389,11 @@ class VAEEncodeTiled:
    CATEGORY = "_for_testing"

    def encode(self, vae, pixels, tile_size, overlap, temporal_size=64, temporal_overlap=8):
    def encode(self, vae, pixels, tile_size, overlap, temporal_size=64, temporal_overlap=8) -> tuple[Optional[Latent]]:
        if pixels is None:
            return None,
        t = vae.encode_tiled(pixels[:, :, :, :3], tile_x=tile_size, tile_y=tile_size, overlap=overlap, tile_t=temporal_size, overlap_t=temporal_overlap)
        return ({"samples": t},)
        return (Latent(**{"samples": t}),)
class VAEEncodeForInpaint:
@@ -395,7 +406,9 @@ class VAEEncodeForInpaint:
    CATEGORY = "latent/inpaint"

    def encode(self, vae, pixels, mask, grow_mask_by=6):
    def encode(self, vae, pixels, mask, grow_mask_by=6) -> tuple[Optional[Latent]]:
        if pixels is None:
            return None,
        x = (pixels.shape[1] // vae.downscale_ratio) * vae.downscale_ratio
        y = (pixels.shape[2] // vae.downscale_ratio) * vae.downscale_ratio
        mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(pixels.shape[1], pixels.shape[2]), mode="bilinear")
@@ -423,7 +436,7 @@
            pixels[:, :, :, i] += 0.5
        t = vae.encode(pixels)
        return ({"samples": t, "noise_mask": (mask_erosion[:, :, :x, :y].round())},)
        return (Latent(**{"samples": t, "noise_mask": (mask_erosion[:, :, :x, :y].round())}),)
class InpaintModelConditioning:
@@ -443,7 +456,9 @@ class InpaintModelConditioning:
    CATEGORY = "conditioning/inpaint"

    def encode(self, positive, negative, pixels, vae, mask, noise_mask=True):
    def encode(self, positive: io.Conditioning.CondList, negative: io.Conditioning.CondList, pixels, vae, mask, noise_mask=True) -> tuple[io.Conditioning.CondList, io.Conditioning.CondList, Optional[Latent]]:
        if pixels is None:
            return positive, negative, None
        x = (pixels.shape[1] // 8) * 8
        y = (pixels.shape[2] // 8) * 8
        mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(pixels.shape[1], pixels.shape[2]), mode="bilinear")
@@ -464,13 +479,12 @@
        concat_latent = vae.encode(pixels)
        orig_latent = vae.encode(orig_pixels)

        out_latent = {}
        out_latent: Latent = {"samples": orig_latent}

        out_latent["samples"] = orig_latent
        if noise_mask:
            out_latent["noise_mask"] = mask

        out = []
        out: list[io.Conditioning.CondList] = []
        for conditioning in [positive, negative]:
            c = node_helpers.conditioning_set_values(conditioning, {"concat_latent_image": concat_latent,
                                                                    "concat_mask": mask})
@@ -1291,6 +1305,8 @@ class LatentFromBatch:
    CATEGORY = "latent/batch"

    def frombatch(self, samples, batch_index, length):
        if samples is None:
            return None,
        s = samples.copy()
        s_in = samples["samples"]
        batch_index = min(s_in.shape[0] - 1, batch_index)
@@ -1324,6 +1340,8 @@ class RepeatLatentBatch:
    CATEGORY = "latent/batch"

    def repeat(self, samples, amount):
        if samples is None:
            return None,
        s = samples.copy()
        s_in = samples["samples"]
@@ -1356,6 +1374,8 @@ class LatentUpscale:
    CATEGORY = "latent"

    def upscale(self, samples, upscale_method, width, height, crop):
        if samples is None:
            return None,
        if width == 0 and height == 0:
            s = samples
        else:
@@ -1389,6 +1409,8 @@ class LatentUpscaleBy:
    CATEGORY = "latent"

    def upscale(self, samples, upscale_method, scale_by):
        if samples is None:
            return None,
        s = samples.copy()
        width = round(samples["samples"].shape[-1] * scale_by)
        height = round(samples["samples"].shape[-2] * scale_by)
@@ -1409,6 +1431,8 @@ class LatentRotate:
    CATEGORY = "latent/transform"

    def rotate(self, samples, rotation):
        if samples is None:
            return None,
        s = samples.copy()
        rotate_by = 0
        if rotation.startswith("90"):
@@ -1435,6 +1459,8 @@ class LatentFlip:
    CATEGORY = "latent/transform"

    def flip(self, samples, flip_method):
        if samples is None:
            return None,
        s = samples.copy()
        if flip_method.startswith("x"):
            s["samples"] = torch.flip(samples["samples"], dims=[2])
@@ -1546,6 +1572,8 @@ class LatentCrop:
    CATEGORY = "latent/transform"

    def crop(self, samples, width, height, x, y):
        if samples is None:
            return None,
        s = samples.copy()
        samples = samples['samples']
        x = x // 8
@@ -1578,6 +1606,8 @@ class SetLatentNoiseMask:
    CATEGORY = "latent/inpaint"

    def set_mask(self, samples, mask):
        if samples is None:
            return None,
        s = samples.copy()
        s["noise_mask"] = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1]))
        return (s,)
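Every latent node above now opens with the same three-line guard. If the repetition ever becomes a burden it could be factored into a decorator; this is a sketch of that refactor, not something this commit does:

import functools

def none_passthrough(method):
    # Skip the node body and emit a single None output when the incoming
    # latent is absent, mirroring the guards added in this commit.
    @functools.wraps(method)
    def wrapper(self, samples, *args, **kwargs):
        if samples is None:
            return None,
        return method(self, samples, *args, **kwargs)
    return wrapper

class SetLatentNoiseMaskSketch:
    @none_passthrough
    def set_mask(self, samples, mask):
        s = samples.copy()
        s["noise_mask"] = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1]))
        return (s,)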

@@ -801,6 +801,7 @@ class VAE:
        pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
        pixel_samples = pixel_samples.movedim(-1, 1)
        do_tile = False
        samples = None
        if self.latent_dim == 3 and pixel_samples.ndim < 5:
            if not self.not_video:
                pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
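Pre-seeding samples = None presumably lets the branches below fall through without a NameError when no branch produces an encoding, so encode can return None in the same graceful style as the nodes above.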

@@ -101,6 +101,8 @@ class VAEDecodeAudio:
    CATEGORY = "latent/audio"

    def decode(self, vae, samples):
        if samples is None:
            return None,
        audio = vae.decode(samples["samples"]).movedim(-1, 1)
        std = torch.std(audio, dim=[1, 2], keepdim=True) * 5.0
        std[std < 1.0] = 1.0

@@ -97,6 +97,8 @@ class VAEDecodeHunyuan3D:
    CATEGORY = "latent/3d"

    def decode(self, vae, samples, num_chunks, octree_resolution):
        if samples is None:
            return None,
        voxels = VOXEL(vae.decode(samples["samples"], vae_options={"num_chunks": num_chunks, "octree_resolution": octree_resolution}))
        return (voxels, )

@@ -293,6 +293,8 @@ class LatentAddNoiseChannels(io.ComfyNode):
    @classmethod
    def execute(cls, samples: Latent, std_dev, seed: int, slice_i: int, slice_j: int):
        if samples is None:
            return None,
        s = samples.copy()
        latent = samples["samples"]
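Taken together, a multi-reference workflow with one reference input absent now degrades into a chain of no-ops rather than an exception. The wiring below is hypothetical and only the None passthrough itself is established by this diff:

vae = ...  # any loaded VAE instance
latent = VAEEncode().encode(vae, pixels=None)[0]                     # -> None
latent = LatentUpscaleBy().upscale(latent, "nearest-exact", 1.5)[0]  # -> None
image = VAEDecode().decode(vae, samples=latent)[0]                   # -> None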