Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-01-22 20:30:25 +08:00)
Also support video taes in previews
Only first frame for now as live preview playback is currently only available through VHS custom nodes.
Commit 7ef46f66ff (parent cd1e1efb21)
comfy/latent_formats.py

@@ -382,6 +382,7 @@ class HunyuanVideo(LatentFormat):
         ]

     latent_rgb_factors_bias = [ 0.0259, -0.0192, -0.0761]
+    taesd_decoder_name = "taehv"

 class Cosmos1CV8x8x8(LatentFormat):
     latent_channels = 16
@@ -445,7 +446,7 @@ class Wan21(LatentFormat):
         ]).view(1, self.latent_channels, 1, 1, 1)


-        self.taesd_decoder_name = None #TODO
+        self.taesd_decoder_name = "lighttaew2_1"

     def process_in(self, latent):
         latents_mean = self.latents_mean.to(latent.device, latent.dtype)
@@ -516,6 +517,7 @@ class Wan22(Wan21):

     def __init__(self):
         self.scale_factor = 1.0
+        self.taesd_decoder_name = "lighttaew2_2"
         self.latents_mean = torch.tensor([
             -0.2289, -0.0052, -0.1323, -0.2339, -0.2799, 0.0174, 0.1838, 0.1557,
             -0.1382, 0.0542, 0.2813, 0.0891, 0.1570, -0.0098, 0.0375, -0.1825,
@@ -670,6 +672,7 @@ class HunyuanVideo15(LatentFormat):
     latent_channels = 32
     latent_dimensions = 3
     scale_factor = 1.03682
+    taesd_decoder_name = "lighttaehy1_5"

 class Hunyuan3Dv2(LatentFormat):
     latent_channels = 64
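These taesd_decoder_name values are what get_previewer (further down, in latent_preview.py) uses to look up an approximate decoder under models/vae_approx/ and to decide between the image and video preview paths. A minimal sketch of that lookup, assuming a hypothetical find_vae_approx_file helper rather than ComfyUI's folder_paths machinery:

# Minimal sketch, not ComfyUI's implementation. find_vae_approx_file is a
# hypothetical helper that searches models/vae_approx/ for a file whose name
# starts with the given prefix and returns its path (or None).
import logging

VIDEO_TAES = ["taehv", "lighttaew2_2", "lighttaew2_1", "lighttaehy1_5"]

def resolve_preview_decoder(latent_format, find_vae_approx_file):
    name = getattr(latent_format, "taesd_decoder_name", None)
    if name is None:
        return None  # this latent format has no approximate decoder registered
    path = find_vae_approx_file(name)
    if path is None:
        logging.warning("could not find models/vae_approx/%s", name)
        return None
    kind = "video" if name in VIDEO_TAES else "image"
    return kind, path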
comfy/taesd/taehv.py

@@ -112,13 +112,14 @@ def apply_model_with_memblocks(model, x, parallel, show_progress_bar):


 class TAEHV(nn.Module):
-    def __init__(self, latent_channels, parallel=False, decoder_time_upscale=(True, True), decoder_space_upscale=(True, True, True), latent_format=None):
+    def __init__(self, latent_channels, parallel=False, decoder_time_upscale=(True, True), decoder_space_upscale=(True, True, True), latent_format=None, show_progress_bar=True):
         super().__init__()
         self.image_channels = 3
         self.patch_size = 1
         self.latent_channels = latent_channels
         self.parallel = parallel
         self.latent_format = latent_format
+        self.show_progress_bar = show_progress_bar
         self.process_in = latent_format().process_in if latent_format is not None else (lambda x: x)
         self.process_out = latent_format().process_out if latent_format is not None else (lambda x: x)
         if self.latent_channels in [48, 32]: # Wan 2.2 and HunyuanVideo1.5
@@ -144,8 +145,15 @@ class TAEHV(nn.Module):
             MemBlock(n_f[2], n_f[2], act_func), MemBlock(n_f[2], n_f[2], act_func), MemBlock(n_f[2], n_f[2], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[2] else 1), TGrow(n_f[2], 2 if decoder_time_upscale[1] else 1), conv(n_f[2], n_f[3], bias=False),
             act_func, conv(n_f[3], self.image_channels*self.patch_size**2),
         )
+    @property
+    def show_progress_bar(self):
+        return self._show_progress_bar

-    def encode(self, x, show_progress_bar=True, **kwargs):
+    @show_progress_bar.setter
+    def show_progress_bar(self, value):
+        self._show_progress_bar = value
+
+    def encode(self, x, **kwargs):
         if self.patch_size > 1: x = F.pixel_unshuffle(x, self.patch_size)
         x = x.movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W]
         if x.shape[1] % 4 != 0:
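Storing the flag behind a property/setter pair means code that only holds a reference to the module can still flip it after construction, which is what the previewer does later with taesd.first_stage_model.show_progress_bar = False. A self-contained sketch of the same pattern (not ComfyUI code):

# Self-contained sketch of the property pattern above (not ComfyUI code).
class ProgressAware:
    def __init__(self, show_progress_bar=True):
        # assignment goes through the setter below, so _show_progress_bar always exists
        self.show_progress_bar = show_progress_bar

    @property
    def show_progress_bar(self):
        return self._show_progress_bar

    @show_progress_bar.setter
    def show_progress_bar(self, value):
        self._show_progress_bar = bool(value)

w = ProgressAware()
w.show_progress_bar = False   # e.g. the previewer silences the bar after loading
assert w.show_progress_bar is False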
@@ -153,11 +161,11 @@ class TAEHV(nn.Module):
             n_pad = 4 - x.shape[1] % 4
             padding = x[:, -1:].repeat_interleave(n_pad, dim=1)
             x = torch.cat([x, padding], 1)
-        x = apply_model_with_memblocks(self.encoder, x, self.parallel, show_progress_bar).movedim(2, 1)
+        x = apply_model_with_memblocks(self.encoder, x, self.parallel, self.show_progress_bar).movedim(2, 1)
         return self.process_out(x)

-    def decode(self, x, show_progress_bar=True, **kwargs):
+    def decode(self, x, **kwargs):
         x = self.process_in(x).movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W]
-        x = apply_model_with_memblocks(self.decoder, x, self.parallel, show_progress_bar)
+        x = apply_model_with_memblocks(self.decoder, x, self.parallel, self.show_progress_bar)
         if self.patch_size > 1: x = F.pixel_shuffle(x, self.patch_size)
         return x[:, self.frames_to_trim:].movedim(2, 1)
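For reference, the layout swap and the frame padding that encode relies on, shown in isolation (a standalone sketch with arbitrary sizes):

# Standalone sketch of the layout swap and frame padding used by TAEHV.encode.
import torch

x = torch.randn(1, 3, 6, 64, 64)     # [B, C, T, H, W] with 6 frames
x = x.movedim(2, 1)                   # -> [B, T, C, H, W]

if x.shape[1] % 4 != 0:               # pad to a multiple of 4 frames
    n_pad = 4 - x.shape[1] % 4
    padding = x[:, -1:].repeat_interleave(n_pad, dim=1)  # repeat the last frame
    x = torch.cat([x, padding], 1)

assert x.shape == (1, 8, 3, 64, 64)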
latent_preview.py

@@ -2,17 +2,24 @@ import torch
 from PIL import Image
 from comfy.cli_args import args, LatentPreviewMethod
 from comfy.taesd.taesd import TAESD
+from comfy.sd import VAE
 import comfy.model_management
 import folder_paths
 import comfy.utils
 import logging

 MAX_PREVIEW_RESOLUTION = args.preview_size
+VIDEO_TAES = ["taehv", "lighttaew2_2", "lighttaew2_1", "lighttaehy1_5"]

-def preview_to_image(latent_image):
-        latents_ubyte = (((latent_image + 1.0) / 2.0).clamp(0, 1) # change scale from -1..1 to 0..1
-                            .mul(0xFF) # to 0..255
-                            )
+def preview_to_image(latent_image, do_scale=True):
+        if do_scale:
+            latents_ubyte = (((latent_image + 1.0) / 2.0).clamp(0, 1) # change scale from -1..1 to 0..1
+                                .mul(0xFF) # to 0..255
+                                )
+        else:
+            latents_ubyte = (latent_image.clamp(0, 1)
+                                .mul(0xFF) # to 0..255
+                                )
         if comfy.model_management.directml_enabled:
             latents_ubyte = latents_ubyte.to(dtype=torch.uint8)
         latents_ubyte = latents_ubyte.to(device="cpu", dtype=torch.uint8, non_blocking=comfy.model_management.device_supports_non_blocking(latent_image.device))
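The new do_scale flag distinguishes decoders whose output is in [-1, 1] (the existing TAESD path) from the video TAEs used below, whose decoded frames are passed through with do_scale=False and are therefore expected to already be in [0, 1]. A standalone sketch of the two mappings:

# Standalone sketch of the two scaling paths in preview_to_image.
import torch

def to_ubyte(img, do_scale=True):
    if do_scale:
        img = ((img + 1.0) / 2.0).clamp(0, 1)   # decoder output in [-1, 1]
    else:
        img = img.clamp(0, 1)                   # decoder output already in [0, 1]
    return img.mul(0xFF).to(torch.uint8)

print(to_ubyte(torch.tensor([-1.0, 0.0, 1.0])))                 # tensor([  0, 127, 255], dtype=torch.uint8)
print(to_ubyte(torch.tensor([0.0, 0.5, 1.0]), do_scale=False))  # tensor([  0, 127, 255], dtype=torch.uint8)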
@@ -35,6 +42,10 @@ class TAESDPreviewerImpl(LatentPreviewer):
         x_sample = self.taesd.decode(x0[:1])[0].movedim(0, 2)
         return preview_to_image(x_sample)

+class TAEHVPreviewerImpl(TAESDPreviewerImpl):
+    def decode_latent_to_preview(self, x0):
+        x_sample = self.taesd.decode(x0[:1, :, :1])[0][0]
+        return preview_to_image(x_sample, do_scale=False)

 class Latent2RGBPreviewer(LatentPreviewer):
     def __init__(self, latent_rgb_factors, latent_rgb_factors_bias=None):
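The slice x0[:1, :, :1] keeps one batch item, every latent channel, and only the first temporal frame of the [B, C, T, H, W] video latent, which is why previews currently show only the first frame. In isolation (arbitrary sizes):

# Illustration of the first-frame slice used by TAEHVPreviewerImpl (arbitrary sizes).
import torch

x0 = torch.randn(2, 16, 9, 30, 52)    # [B, C, T, H, W] video latent
first = x0[:1, :, :1]                  # one sample, all channels, first frame only
assert first.shape == (1, 16, 1, 30, 52)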
@@ -78,8 +89,13 @@ def get_previewer(device, latent_format):

         if method == LatentPreviewMethod.TAESD:
             if taesd_decoder_path:
-                taesd = TAESD(None, taesd_decoder_path, latent_channels=latent_format.latent_channels).to(device)
-                previewer = TAESDPreviewerImpl(taesd)
+                if latent_format.taesd_decoder_name in VIDEO_TAES:
+                    taesd = VAE(comfy.utils.load_torch_file(taesd_decoder_path))
+                    taesd.first_stage_model.show_progress_bar = False
+                    previewer = TAEHVPreviewerImpl(taesd)
+                else:
+                    taesd = TAESD(None, taesd_decoder_path, latent_channels=latent_format.latent_channels).to(device)
+                    previewer = TAESDPreviewerImpl(taesd)
             else:
                 logging.warning("Warning: TAESD previews enabled, but could not find models/vae_approx/{}".format(latent_format.taesd_decoder_name))

nodes.py (18 changed lines)
@@ -692,8 +692,10 @@ class LoraLoaderModelOnly(LoraLoader):
         return (self.load_lora(model, None, lora_name, strength_model, 0)[0],)

 class VAELoader:
+    video_taes = ["taehv", "lighttaew2_2", "lighttaew2_1", "lighttaehy1_5"]
+    image_taes = ["taesd", "taesdxl", "taesd3", "taef1"]
     @staticmethod
-    def vae_list():
+    def vae_list(s):
         vaes = folder_paths.get_filename_list("vae")
         approx_vaes = folder_paths.get_filename_list("vae_approx")
         sdxl_taesd_enc = False
@@ -722,6 +724,11 @@ class VAELoader:
                 f1_taesd_dec = True
             elif v.startswith("taef1_decoder."):
                 f1_taesd_enc = True
+            else:
+                for tae in s.video_taes:
+                    if v.startswith(tae):
+                        vaes.append(v)
+
         if sd1_taesd_dec and sd1_taesd_enc:
             vaes.append("taesd")
         if sdxl_taesd_dec and sdxl_taesd_enc:
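With this change, vae_list also surfaces files in the vae_approx folder whose names start with one of the video TAE prefixes, so they appear in the VAELoader dropdown. A minimal sketch of that filtering with plain lists standing in for folder_paths (example filenames):

# Minimal sketch of the filtering added to vae_list; plain lists stand in for
# folder_paths.get_filename_list, and the filenames are made up for the example.
video_taes = ["taehv", "lighttaew2_2", "lighttaew2_1", "lighttaehy1_5"]

def list_vaes(vae_files, vae_approx_files):
    vaes = list(vae_files)
    for v in vae_approx_files:
        if any(v.startswith(tae) for tae in video_taes):
            vaes.append(v)
    return vaes

print(list_vaes(["some_full_vae.safetensors"],
                ["taesd_decoder.pth", "lighttaew2_2.safetensors"]))
# ['some_full_vae.safetensors', 'lighttaew2_2.safetensors']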
@@ -765,7 +772,7 @@ class VAELoader:

     @classmethod
     def INPUT_TYPES(s):
-        return {"required": { "vae_name": (s.vae_list(), )}}
+        return {"required": { "vae_name": (s.vae_list(s), )}}
     RETURN_TYPES = ("VAE",)
     FUNCTION = "load_vae"

@@ -776,10 +783,13 @@ class VAELoader:
         if vae_name == "pixel_space":
             sd = {}
             sd["pixel_space_vae"] = torch.tensor(1.0)
-        elif vae_name in ["taesd", "taesdxl", "taesd3", "taef1"]:
+        elif vae_name in self.image_taes:
             sd = self.load_taesd(vae_name)
         else:
-            vae_path = folder_paths.get_full_path_or_raise("vae", vae_name)
+            if os.path.splitext(vae_name)[0] in self.video_taes:
+                vae_path = folder_paths.get_full_path_or_raise("vae_approx", vae_name)
+            else:
+                vae_path = folder_paths.get_full_path_or_raise("vae", vae_name)
             sd = comfy.utils.load_torch_file(vae_path)
         vae = comfy.sd.VAE(sd=sd)
         vae.throw_exception_if_invalid()
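load_vae picks the folder to resolve against by stripping the file extension and checking the stem against video_taes; a quick illustration (example filename):

# Illustration of the extension-stripping check in load_vae (example filename).
import os

video_taes = ["taehv", "lighttaew2_2", "lighttaew2_1", "lighttaehy1_5"]

vae_name = "lighttaew2_2.safetensors"
stem = os.path.splitext(vae_name)[0]                    # 'lighttaew2_2'
folder = "vae_approx" if stem in video_taes else "vae"
print(stem, folder)                                     # lighttaew2_2 vae_approx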