mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-01-09 13:50:49 +08:00)
wip: latent nodes can return None for graceful behavior in multi-reference-latent scenarios
This commit is contained in:
parent 82bffb7855
commit 6f2589f256
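The commit applies one recurring guard at the top of each latent node entry point: when an optional upstream latent (or pixel) input is absent, the node short-circuits and propagates None instead of raising on samples["samples"]. A minimal, self-contained sketch of the pattern follows; the node and helper names here are illustrative, not taken from the diff.

```python
from typing import Optional

import torch

# Hypothetical stand-in for the latent payload ComfyUI nodes pass around:
# a dict carrying a "samples" tensor (and sometimes a "noise_mask").
Latent = dict


def upscale_by(samples: Optional[Latent], scale_by: float) -> tuple[Optional[Latent]]:
    # The guard this commit adds to each latent node: a missing reference
    # latent flows through as None instead of crashing the node.
    if samples is None:
        return None,  # one-element tuple, matching the node's single output
    s = samples.copy()  # shallow copy so the caller's dict is not mutated
    s["samples"] = torch.nn.functional.interpolate(
        samples["samples"], scale_factor=scale_by, mode="nearest")
    return s,


# None passes through untouched; real latents are processed normally.
assert upscale_by(None, 2.0) == (None,)
out, = upscale_by({"samples": torch.zeros(1, 4, 8, 8)}, 2.0)
print(out["samples"].shape)  # torch.Size([1, 4, 16, 16])
```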
@@ -5,7 +5,7 @@ from torch import Tensor
 from typing_extensions import NotRequired

 ImageBatch = Float[Tensor, "batch height width channels"]
-LatentBatch = Float[Tensor, "batch channels height width"]
+LatentBatch = Float[Tensor, "batch channels width height"]
 SD15LatentBatch = Float[Tensor, "batch 4 height width"]
 SDXLLatentBatch = Float[Tensor, "batch 8 height width"]
 SD3LatentBatch = Float[Tensor, "batch 16 height width"]
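These aliases use jaxtyping's string-shape syntax: named axes document the expected layout, and literal numbers pin an axis size (4 latent channels for SD1.5, 16 for SD3). A small sketch of how such an alias reads at a call site, assuming jaxtyping is available:

```python
import torch
from torch import Tensor
from jaxtyping import Float

# Mirrors the style of the aliases above: SD1.5 latents carry 4 channels.
SD15LatentBatch = Float[Tensor, "batch 4 height width"]


def batch_size(latents: SD15LatentBatch) -> int:
    # The annotation documents the layout; runtime enforcement needs a
    # jaxtyping-aware checker (e.g. the jaxtyped decorator plus beartype).
    return latents.shape[0]


print(batch_size(torch.zeros(2, 4, 64, 64)))  # 2
```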
@@ -1,11 +1,12 @@
 import hashlib
 from PIL import ImageFile, UnidentifiedImageError

+from comfy_api.latest import io
 from .component_model.files import get_package_as_path
 from .execution_context import current_execution_context


-def conditioning_set_values(conditioning, values: dict = None, append=False):
+def conditioning_set_values(conditioning, values: dict = None, append=False) -> io.Conditioning.CondList:
     if values is None:
         values = {}
     c = []
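For context on the new io.Conditioning.CondList annotation: a ComfyUI conditioning value is a list of [tensor, options] pairs, and conditioning_set_values returns new pairs with extra option keys merged in. The sketch below captures that contract in simplified form; it is not the fork's exact implementation.

```python
import torch


def conditioning_set_values(conditioning, values: dict = None, append=False):
    # Simplified: each entry is a [cond_tensor, options] pair. We copy the
    # options and merge in the supplied values; with append=True the value
    # is accumulated into a list under its key instead of overwritten.
    if values is None:
        values = {}
    c = []
    for cond_tensor, options in conditioning:
        merged = options.copy()
        for k, v in values.items():
            merged[k] = merged.get(k, []) + [v] if append else v
        c.append([cond_tensor, merged])
    return c


cond = [[torch.zeros(1, 77, 768), {}]]
print(conditioning_set_values(cond, {"strength": 0.5})[0][1])  # {'strength': 0.5}
```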
@@ -2,6 +2,8 @@ from __future__ import annotations

 import json
+import logging
+from typing import Optional

 import math
 import os
 import random
@@ -14,6 +16,7 @@ from PIL.PngImagePlugin import PngInfo
 from huggingface_hub import snapshot_download
 from natsort import natsorted

+from comfy_api.latest import io
 from .. import clip_vision as clip_vision_module
 from .. import controlnet
 from .. import diffusers_load
@@ -27,7 +30,7 @@ from ..cli_args import args
 from ..cmd import folder_paths, latent_preview
 from ..comfy_types import IO, ComfyNodeABC, InputTypeDict, FileLocator
 from ..component_model.deprecation import _deprecate_method
-from ..component_model.tensor_types import RGBImage, RGBImageBatch, MaskBatch, RGBAImageBatch
+from ..component_model.tensor_types import RGBImage, RGBImageBatch, MaskBatch, RGBAImageBatch, Latent
 from ..execution_context import current_execution_context
 from ..images import open_image
 from ..interruption import interrupt_current_processing
@@ -308,6 +311,8 @@ class VAEDecode:
     DESCRIPTION = "Decodes latent images back into pixel space images."

     def decode(self, vae, samples):
+        if samples is None:
+            return None,
         images = vae.decode(samples["samples"])
         if len(images.shape) == 5: # Combine batches
             images = images.reshape(-1, images.shape[-3], images.shape[-2], images.shape[-1])
@@ -331,6 +336,8 @@ class VAEDecodeTiled:
     CATEGORY = "_for_testing"

     def decode(self, vae, samples, tile_size, overlap=64, temporal_size=64, temporal_overlap=8):
+        if samples is None:
+            return None,
         if tile_size < overlap * 4:
             overlap = tile_size // 4
         if temporal_size < temporal_overlap * 2:
@@ -360,9 +367,11 @@ class VAEEncode:

     CATEGORY = "latent"

-    def encode(self, vae: VAE, pixels):
+    def encode(self, vae: VAE, pixels) -> tuple[Optional[Latent]]:
+        if pixels is None:
+            return None,
         t = vae.encode(pixels[:, :, :, :3])
-        return ({"samples": t},)
+        return (Latent(**{"samples": t}),)


 class VAEEncodeTiled:
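On the new return shape: given the NotRequired import in the first hunk, Latent is presumably a TypedDict, so Latent(**{"samples": t}) builds an ordinary dict at runtime and downstream samples["samples"] access is unchanged; the wrapper mainly buys static checking. A sketch under that assumption:

```python
import torch
from typing import TypedDict
from typing_extensions import NotRequired


class Latent(TypedDict):
    samples: torch.Tensor
    noise_mask: NotRequired[torch.Tensor]  # present only on inpainting paths


latent = Latent(**{"samples": torch.zeros(1, 4, 64, 64)})
print(type(latent) is dict)     # True: a TypedDict is a plain dict at runtime
print(latent["samples"].shape)  # torch.Size([1, 4, 64, 64])
```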
@@ -380,9 +389,11 @@ class VAEEncodeTiled:

     CATEGORY = "_for_testing"

-    def encode(self, vae, pixels, tile_size, overlap, temporal_size=64, temporal_overlap=8):
+    def encode(self, vae, pixels, tile_size, overlap, temporal_size=64, temporal_overlap=8) -> tuple[Optional[Latent]]:
+        if pixels is None:
+            return None,
         t = vae.encode_tiled(pixels[:, :, :, :3], tile_x=tile_size, tile_y=tile_size, overlap=overlap, tile_t=temporal_size, overlap_t=temporal_overlap)
-        return ({"samples": t},)
+        return (Latent(**{"samples": t}),)


 class VAEEncodeForInpaint:
@@ -395,7 +406,9 @@ class VAEEncodeForInpaint:

     CATEGORY = "latent/inpaint"

-    def encode(self, vae, pixels, mask, grow_mask_by=6):
+    def encode(self, vae, pixels, mask, grow_mask_by=6) -> tuple[Optional[Latent]]:
+        if pixels is None:
+            return None,
         x = (pixels.shape[1] // vae.downscale_ratio) * vae.downscale_ratio
         y = (pixels.shape[2] // vae.downscale_ratio) * vae.downscale_ratio
         mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(pixels.shape[1], pixels.shape[2]), mode="bilinear")
@@ -423,7 +436,7 @@ class VAEEncodeForInpaint:
             pixels[:, :, :, i] += 0.5
         t = vae.encode(pixels)

-        return ({"samples": t, "noise_mask": (mask_erosion[:, :, :x, :y].round())},)
+        return (Latent(**{"samples": t, "noise_mask": (mask_erosion[:, :, :x, :y].round())}),)


 class InpaintModelConditioning:
@@ -443,7 +456,9 @@ class InpaintModelConditioning:

     CATEGORY = "conditioning/inpaint"

-    def encode(self, positive, negative, pixels, vae, mask, noise_mask=True):
+    def encode(self, positive: io.Conditioning.CondList, negative: io.Conditioning.CondList, pixels, vae, mask, noise_mask=True) -> tuple[io.Conditioning.CondList, io.Conditioning.CondList, Optional[Latent]]:
+        if pixels is None:
+            return positive, negative, None
         x = (pixels.shape[1] // 8) * 8
         y = (pixels.shape[2] // 8) * 8
         mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(pixels.shape[1], pixels.shape[2]), mode="bilinear")
@@ -464,13 +479,12 @@ class InpaintModelConditioning:
         concat_latent = vae.encode(pixels)
         orig_latent = vae.encode(orig_pixels)

-        out_latent = {}
+        out_latent: Latent = {"samples": orig_latent}

-        out_latent["samples"] = orig_latent
         if noise_mask:
             out_latent["noise_mask"] = mask

-        out = []
+        out: list[io.Conditioning.CondList] = []
         for conditioning in [positive, negative]:
             c = node_helpers.conditioning_set_values(conditioning, {"concat_latent_image": concat_latent,
                                                                     "concat_mask": mask})
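Note that the guard takes a different shape in InpaintModelConditioning: the node has three outputs, so when pixels is missing it forwards both conditioning lists untouched and returns None only in the latent slot. A minimal sketch of that contract, with hypothetical stand-in values:

```python
from typing import Optional


def encode(positive: list, negative: list, pixels) -> tuple[list, list, Optional[dict]]:
    # Three-output guard: conditioning passes through, only the latent is None.
    if pixels is None:
        return positive, negative, None
    raise NotImplementedError("normal VAE-encode path elided in this sketch")


pos, neg, lat = encode([["cond", {}]], [["uncond", {}]], None)
print(lat)  # None
```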
@@ -1291,6 +1305,8 @@ class LatentFromBatch:
     CATEGORY = "latent/batch"

     def frombatch(self, samples, batch_index, length):
+        if samples is None:
+            return None,
         s = samples.copy()
         s_in = samples["samples"]
         batch_index = min(s_in.shape[0] - 1, batch_index)
@@ -1324,6 +1340,8 @@ class RepeatLatentBatch:
     CATEGORY = "latent/batch"

     def repeat(self, samples, amount):
+        if samples is None:
+            return None,
         s = samples.copy()
         s_in = samples["samples"]

@@ -1356,6 +1374,8 @@ class LatentUpscale:
     CATEGORY = "latent"

     def upscale(self, samples, upscale_method, width, height, crop):
+        if samples is None:
+            return None,
         if width == 0 and height == 0:
             s = samples
         else:
@@ -1389,6 +1409,8 @@ class LatentUpscaleBy:
     CATEGORY = "latent"

     def upscale(self, samples, upscale_method, scale_by):
+        if samples is None:
+            return None,
         s = samples.copy()
         width = round(samples["samples"].shape[-1] * scale_by)
         height = round(samples["samples"].shape[-2] * scale_by)
@@ -1409,6 +1431,8 @@ class LatentRotate:
     CATEGORY = "latent/transform"

     def rotate(self, samples, rotation):
+        if samples is None:
+            return None,
         s = samples.copy()
         rotate_by = 0
         if rotation.startswith("90"):
@@ -1435,6 +1459,8 @@ class LatentFlip:
     CATEGORY = "latent/transform"

     def flip(self, samples, flip_method):
+        if samples is None:
+            return None,
         s = samples.copy()
         if flip_method.startswith("x"):
             s["samples"] = torch.flip(samples["samples"], dims=[2])
@@ -1546,6 +1572,8 @@ class LatentCrop:
     CATEGORY = "latent/transform"

     def crop(self, samples, width, height, x, y):
+        if samples is None:
+            return None,
         s = samples.copy()
         samples = samples['samples']
         x = x // 8
@@ -1578,6 +1606,8 @@ class SetLatentNoiseMask:
     CATEGORY = "latent/inpaint"

     def set_mask(self, samples, mask):
+        if samples is None:
+            return None,
         s = samples.copy()
         s["noise_mask"] = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1]))
         return (s,)
@@ -801,6 +801,7 @@ class VAE:
         pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
         pixel_samples = pixel_samples.movedim(-1, 1)
         do_tile = False
+        samples = None
        if self.latent_dim == 3 and pixel_samples.ndim < 5:
             if not self.not_video:
                 pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
@@ -101,6 +101,8 @@ class VAEDecodeAudio:
     CATEGORY = "latent/audio"

     def decode(self, vae, samples):
+        if samples is None:
+            return None,
         audio = vae.decode(samples["samples"]).movedim(-1, 1)
         std = torch.std(audio, dim=[1, 2], keepdim=True) * 5.0
         std[std < 1.0] = 1.0
@@ -97,6 +97,8 @@ class VAEDecodeHunyuan3D:
     CATEGORY = "latent/3d"

     def decode(self, vae, samples, num_chunks, octree_resolution):
+        if samples is None:
+            return None,
         voxels = VOXEL(vae.decode(samples["samples"], vae_options={"num_chunks": num_chunks, "octree_resolution": octree_resolution}))
         return (voxels, )

@@ -293,6 +293,8 @@ class LatentAddNoiseChannels(io.ComfyNode):

     @classmethod
     def execute(cls, samples: Latent, std_dev, seed: int, slice_i: int, slice_j: int):
+        if samples is None:
+            return None,
         s = samples.copy()

         latent = samples["samples"]
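Taken together, the guards let an entire chain of latent nodes degrade gracefully when an optional reference latent is absent, which is the multi-reference-latent scenario named in the commit message. An illustrative chain with stand-in encode/decode bodies (hypothetical wiring, not graph-executor code):

```python
import torch


def vae_encode(pixels):
    if pixels is None:
        return None,
    return {"samples": torch.zeros(1, 4, 8, 8)},  # stand-in for vae.encode


def latent_flip(samples):
    if samples is None:
        return None,
    s = samples.copy()
    s["samples"] = torch.flip(samples["samples"], dims=[2])
    return s,


def vae_decode(samples):
    if samples is None:
        return None,
    return torch.zeros(1, 64, 64, 3),  # stand-in for vae.decode


# A missing reference image yields None at every stage instead of an exception.
latent, = vae_encode(None)
flipped, = latent_flip(latent)
image, = vae_decode(flipped)
print(image)  # None
```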