mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-27 09:27:24 +08:00
Add SeedVR2 model and VAE support
This commit is contained in:
parent
57414dadfe
commit
15a500ff6b
1
.gitignore
vendored
1
.gitignore
vendored
@ -13,6 +13,7 @@ extra_model_paths.yaml
|
||||
.idea/
|
||||
venv*/
|
||||
.venv/
|
||||
.pyisolate_venvs/
|
||||
/web/extensions/*
|
||||
!/web/extensions/logging.js.example
|
||||
!/web/extensions/core/
|
||||
|
||||
@ -4,6 +4,7 @@ class LatentFormat:
|
||||
scale_factor = 1.0
|
||||
latent_channels = 4
|
||||
latent_dimensions = 2
|
||||
preserve_empty_channel_multiples = False
|
||||
latent_rgb_factors = None
|
||||
latent_rgb_factors_bias = None
|
||||
latent_rgb_factors_reshape = None
|
||||
@ -769,6 +770,10 @@ class ACEAudio(LatentFormat):
|
||||
latent_channels = 8
|
||||
latent_dimensions = 2
|
||||
|
||||
class SeedVR2(LatentFormat):
|
||||
latent_channels = 16
|
||||
preserve_empty_channel_multiples = True
|
||||
|
||||
class ACEAudio15(LatentFormat):
|
||||
latent_channels = 64
|
||||
latent_dimensions = 1
|
||||
|
||||
1742
comfy/ldm/seedvr/model.py
Normal file
1742
comfy/ldm/seedvr/model.py
Normal file
File diff suppressed because it is too large
Load Diff
2421
comfy/ldm/seedvr/vae.py
Normal file
2421
comfy/ldm/seedvr/vae.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -50,6 +50,8 @@ import comfy.ldm.chroma.model
|
||||
import comfy.ldm.chroma_radiance.model
|
||||
import comfy.ldm.ace.model
|
||||
import comfy.ldm.omnigen.omnigen2
|
||||
import comfy.ldm.seedvr.model
|
||||
|
||||
import comfy.ldm.qwen_image.model
|
||||
import comfy.ldm.kandinsky5.model
|
||||
import comfy.ldm.anima.model
|
||||
@ -923,6 +925,16 @@ class HunyuanDiT(BaseModel):
|
||||
out['image_meta_size'] = comfy.conds.CONDRegular(torch.FloatTensor([[height, width, target_height, target_width, 0, 0]]))
|
||||
return out
|
||||
|
||||
class SeedVR2(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device, comfy.ldm.seedvr.model.NaDiT)
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
condition = kwargs.get("condition", None)
|
||||
if condition is not None:
|
||||
out["condition"] = comfy.conds.CONDRegular(condition)
|
||||
return out
|
||||
|
||||
class PixArt(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.EPS, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.pixart.pixartms.PixArtMS)
|
||||
|
||||
@ -577,6 +577,54 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
|
||||
return dit_config
|
||||
|
||||
if "{}blocks.35.mlp.vid.proj_in.weight".format(key_prefix) in state_dict_keys and state_dict["{}blocks.35.mlp.vid.proj_in.weight".format(key_prefix)].shape[1] == 3072: # seedvr2 7b
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "seedvr2"
|
||||
dit_config["vid_dim"] = 3072
|
||||
dit_config["heads"] = 24
|
||||
dit_config["num_layers"] = 36
|
||||
# 7B uses non-shared MMModule layout (separate ``vid.`` / ``txt.``
|
||||
# submodules) at EVERY block — verified by inspecting the 7B
|
||||
# state_dict at ``blocks.31.ada.txt.attn_gate`` (txt. prefix means
|
||||
# ``MMModule.shared_weights=False``). Native NaDiT computes
|
||||
# per-block ``shared_weights = not (i < mm_layers)``, so to keep
|
||||
# every block non-shared we set ``mm_layers = num_layers``.
|
||||
# Without this, blocks at index >= mm_layers (default 10) try to
|
||||
# load ``blocks.N.*.all.*`` keys that don't exist in the file,
|
||||
# silently miss-load → all-black output.
|
||||
dit_config["mm_layers"] = 36
|
||||
dit_config["norm_eps"] = 1e-5
|
||||
dit_config["qk_rope"] = True
|
||||
dit_config["rope_type"] = "rope3d"
|
||||
dit_config["rope_dim"] = 64
|
||||
dit_config["mlp_type"] = "normal"
|
||||
return dit_config
|
||||
elif "{}blocks.35.mlp.all.proj_in_gate.weight".format(key_prefix) in state_dict_keys: # seedvr2 7b
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "seedvr2"
|
||||
dit_config["vid_dim"] = 3072
|
||||
dit_config["heads"] = 24
|
||||
dit_config["num_layers"] = 36
|
||||
# This checkpoint layout carries shared ``all.`` MMModule keys.
|
||||
# Preserve the historical split: the initial blocks use separate
|
||||
# vid/txt modules, later blocks use shared modules.
|
||||
dit_config["mm_layers"] = 10
|
||||
dit_config["norm_eps"] = 1e-5
|
||||
dit_config["qk_rope"] = True
|
||||
dit_config["mlp_type"] = "swiglu"
|
||||
return dit_config
|
||||
elif "{}blocks.31.mlp.all.proj_in_gate.weight".format(key_prefix) in state_dict_keys: # seedvr2 3b
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "seedvr2"
|
||||
dit_config["vid_dim"] = 2560
|
||||
dit_config["heads"] = 20
|
||||
dit_config["num_layers"] = 32
|
||||
dit_config["norm_eps"] = 1.0e-05
|
||||
dit_config["qk_rope"] = None
|
||||
dit_config["mlp_type"] = "swiglu"
|
||||
dit_config["vid_out_norm"] = True
|
||||
return dit_config
|
||||
|
||||
if '{}head.modulation'.format(key_prefix) in state_dict_keys: # Wan 2.1
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "wan2.1"
|
||||
|
||||
@ -44,7 +44,13 @@ def fix_empty_latent_channels(model, latent_image, downscale_ratio_spacial=None,
|
||||
is_empty = torch.count_nonzero(latent_image) == 0
|
||||
if is_empty:
|
||||
if latent_format.latent_channels != latent_image.shape[1]:
|
||||
latent_image = comfy.utils.repeat_to_batch_size(latent_image, latent_format.latent_channels, dim=1)
|
||||
preserves_collapsed_channels = (
|
||||
getattr(latent_format, "preserve_empty_channel_multiples", False)
|
||||
and latent_image.ndim == 4
|
||||
and latent_image.shape[1] % latent_format.latent_channels == 0
|
||||
)
|
||||
if not preserves_collapsed_channels:
|
||||
latent_image = comfy.utils.repeat_to_batch_size(latent_image, latent_format.latent_channels, dim=1)
|
||||
if downscale_ratio_spacial is not None:
|
||||
if downscale_ratio_spacial != latent_format.spacial_downscale_ratio:
|
||||
ratio = downscale_ratio_spacial / latent_format.spacial_downscale_ratio
|
||||
|
||||
237
comfy/sd.py
237
comfy/sd.py
@ -1,3 +1,4 @@
|
||||
import inspect
|
||||
import json
|
||||
import torch
|
||||
from enum import Enum
|
||||
@ -16,6 +17,7 @@ import comfy.ldm.cosmos.vae
|
||||
import comfy.ldm.wan.vae
|
||||
import comfy.ldm.wan.vae2_2
|
||||
import comfy.ldm.hunyuan3d.vae
|
||||
import comfy.ldm.seedvr.vae
|
||||
import comfy.ldm.ace.vae.music_dcae_pipeline
|
||||
import comfy.ldm.cogvideo.vae
|
||||
import comfy.ldm.hunyuan_video.vae
|
||||
@ -80,6 +82,36 @@ import comfy.latent_formats
|
||||
|
||||
import comfy.ldm.flux.redux
|
||||
|
||||
SEEDVR2_VAE_DECODE_BYTES_PER_OUTPUT_PIXEL = 160
|
||||
|
||||
|
||||
def _seedvr2_vae_decode_output_pixels(latent_t, latent_h, latent_w):
|
||||
output_t = max(1, (latent_t - 1) * 4 + 1)
|
||||
return output_t * latent_h * 8 * latent_w * 8
|
||||
|
||||
|
||||
def _seedvr2_vae_decode_memory_used(shape):
|
||||
if len(shape) == 5:
|
||||
candidates = []
|
||||
if shape[1] == 16:
|
||||
candidates.append((shape[2], shape[3], shape[4]))
|
||||
if shape[-1] == 16:
|
||||
candidates.append((shape[1], shape[2], shape[3]))
|
||||
if len(candidates) == 0:
|
||||
candidates.append((shape[2], shape[3], shape[4]))
|
||||
output_pixels = max(_seedvr2_vae_decode_output_pixels(*candidate) for candidate in candidates)
|
||||
elif len(shape) == 4:
|
||||
latent_t = max(1, (shape[1] + 15) // 16)
|
||||
latent_h, latent_w = shape[2], shape[3]
|
||||
output_pixels = _seedvr2_vae_decode_output_pixels(latent_t, latent_h, latent_w)
|
||||
else:
|
||||
latent_t, latent_h, latent_w = 1, shape[-2], shape[-1]
|
||||
output_pixels = _seedvr2_vae_decode_output_pixels(latent_t, latent_h, latent_w)
|
||||
# SeedVR2 decode performs full-frame LAB histogram matching: fp32 channels
|
||||
# plus int64 sort indices dominate peak memory, not the VAE weight dtype.
|
||||
return output_pixels * SEEDVR2_VAE_DECODE_BYTES_PER_OUTPUT_PIXEL
|
||||
|
||||
|
||||
def load_lora_for_models(model, clip, lora, strength_model, strength_clip, lora_metadata=None):
|
||||
key_map = {}
|
||||
if model is not None:
|
||||
@ -463,8 +495,10 @@ class CLIP:
|
||||
|
||||
class VAE:
|
||||
def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None):
|
||||
if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
|
||||
sd = diffusers_convert.convert_vae_state_dict(sd)
|
||||
is_seedvr2_vae = "decoder.up_blocks.2.upsamplers.0.upscale_conv.weight" in sd
|
||||
if not is_seedvr2_vae and 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
|
||||
if metadata is None or metadata.get("keep_diffusers_format") != "true":
|
||||
sd = diffusers_convert.convert_vae_state_dict(sd)
|
||||
|
||||
if model_management.is_amd():
|
||||
VAE_KL_MEM_RATIO = 2.73
|
||||
@ -536,6 +570,20 @@ class VAE:
|
||||
self.first_stage_model = StageC_coder()
|
||||
self.downscale_ratio = 32
|
||||
self.latent_channels = 16
|
||||
elif "decoder.up_blocks.2.upsamplers.0.upscale_conv.weight" in sd: # seedvr2
|
||||
self.first_stage_model = comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper()
|
||||
self.latent_channels = 16
|
||||
self.latent_dim = 3
|
||||
self.disable_offload = True
|
||||
self.memory_used_decode = lambda shape, dtype: _seedvr2_vae_decode_memory_used(shape)
|
||||
self.memory_used_encode = lambda shape, dtype: (max(shape[2], 5) * shape[3] * shape[4] * 64) * model_management.dtype_size(dtype)
|
||||
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8)
|
||||
self.downscale_index_formula = (4, 8, 8)
|
||||
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
|
||||
self.upscale_index_formula = (4, 8, 8)
|
||||
self.process_input = lambda image: image * 2.0 - 1.0
|
||||
self.crop_input = False
|
||||
elif "decoder.conv_in.weight" in sd:
|
||||
if sd['decoder.conv_in.weight'].shape[1] == 64:
|
||||
ddconfig = {"block_out_channels": [128, 256, 512, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 32, "downsample_match_channel": True, "upsample_match_channel": True}
|
||||
@ -663,6 +711,7 @@ class VAE:
|
||||
self.downscale_ratio = (lambda a: max(0, math.floor((a + 7) / 8)), 32, 32)
|
||||
self.downscale_index_formula = (8, 32, 32)
|
||||
self.working_dtypes = [torch.bfloat16, torch.float32]
|
||||
|
||||
elif "decoder.conv_in.conv.weight" in sd and sd['decoder.conv_in.conv.weight'].shape[1] == 32:
|
||||
ddconfig = {"block_out_channels": [128, 256, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 16, "ffactor_temporal": 4, "downsample_match_channel": True, "upsample_match_channel": True}
|
||||
ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1]
|
||||
@ -992,6 +1041,40 @@ class VAE:
|
||||
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
|
||||
return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, index_formulas=self.upscale_index_formula, output_device=self.output_device))
|
||||
|
||||
def decode_tiled_seedvr2(self, samples, tile_x=32, tile_y=32, overlap=8, tile_t=16, overlap_t=4):
|
||||
sf_s = getattr(self.first_stage_model, "spatial_downsample_factor", 8)
|
||||
sf_t = getattr(self.first_stage_model, "temporal_downsample_factor", 4)
|
||||
if tile_t is None:
|
||||
tile_t = 16
|
||||
if overlap_t is None:
|
||||
overlap_t = 4
|
||||
if tile_t > 0:
|
||||
temporal_size = tile_t * sf_t
|
||||
temporal_overlap = max(0, overlap_t) * sf_t
|
||||
else:
|
||||
temporal_size = 0
|
||||
temporal_overlap = 0
|
||||
args = {
|
||||
"enable_tiling": True,
|
||||
"tile_size": (tile_y * sf_s, tile_x * sf_s),
|
||||
"tile_overlap": (overlap * sf_s, overlap * sf_s),
|
||||
"temporal_size": temporal_size,
|
||||
"temporal_overlap": temporal_overlap,
|
||||
}
|
||||
output = self.first_stage_model.decode(
|
||||
samples.to(self.vae_dtype).to(self.device),
|
||||
seedvr2_tiling=args,
|
||||
)
|
||||
return self.process_output(output.to(device=self.output_device, dtype=self.vae_output_dtype(), copy=True))
|
||||
|
||||
def _format_seedvr2_encoded_samples(self, samples):
|
||||
if isinstance(self.first_stage_model, comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper):
|
||||
if samples.ndim == 4:
|
||||
samples = samples.unsqueeze(2)
|
||||
samples = samples.contiguous()
|
||||
samples = samples * 0.9152
|
||||
return samples
|
||||
|
||||
def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
|
||||
steps = pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap)
|
||||
steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x // 2, tile_y * 2, overlap)
|
||||
@ -1028,6 +1111,36 @@ class VAE:
|
||||
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
|
||||
return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, index_formulas=self.downscale_index_formula, output_device=self.output_device)
|
||||
|
||||
def encode_tiled_seedvr2(self, pixel_samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None):
|
||||
if tile_y is None:
|
||||
tile_y = 512
|
||||
if tile_x is None:
|
||||
tile_x = 512
|
||||
if overlap is None:
|
||||
overlap_y = 64
|
||||
overlap_x = 64
|
||||
else:
|
||||
overlap_y = overlap
|
||||
overlap_x = overlap
|
||||
if tile_t is None:
|
||||
tile_t = 9999
|
||||
if overlap_t is None:
|
||||
overlap_t = 0
|
||||
overlap_y = min(overlap_y, max(0, tile_y - 8))
|
||||
overlap_x = min(overlap_x, max(0, tile_x - 8))
|
||||
self.first_stage_model.device = self.device
|
||||
x = self.process_input(pixel_samples).to(self.vae_dtype).to(self.device)
|
||||
output = comfy.ldm.seedvr.vae.tiled_vae(
|
||||
x,
|
||||
self.first_stage_model,
|
||||
tile_size=(tile_y, tile_x),
|
||||
tile_overlap=(overlap_y, overlap_x),
|
||||
temporal_size=tile_t,
|
||||
temporal_overlap=overlap_t,
|
||||
encode=True,
|
||||
)
|
||||
return output.to(device=self.output_device, dtype=self.vae_output_dtype())
|
||||
|
||||
def decode(self, samples_in, vae_options={}):
|
||||
self.throw_exception_if_invalid()
|
||||
pixel_samples = None
|
||||
@ -1075,16 +1188,40 @@ class VAE:
|
||||
if dims == 1 or self.extra_1d_channel is not None:
|
||||
pixel_samples = self.decode_tiled_1d(samples_in)
|
||||
elif dims == 2:
|
||||
pixel_samples = self.decode_tiled_(samples_in)
|
||||
# SeedVR2 latents arrive in 4D collapsed form ``(B, 16*T, H, W)``
|
||||
# downstream of ``SeedVR2Conditioning`` (which performs the
|
||||
# ``rearrange(b c t h w -> b (c t) h w)`` collapse). The
|
||||
# generic ``decode_tiled_`` would treat the channel dim as
|
||||
# spatial-only and crash on the collapsed (16, T) layout
|
||||
# under ``tiled_scale``'s mask broadcast; route SeedVR2 4D
|
||||
# latents to ``decode_tiled_seedvr2`` instead, whose wrapper
|
||||
# dispatch handles both 4D and 5D inputs.
|
||||
if isinstance(self.first_stage_model, comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper):
|
||||
tile = 256 // self.spacial_compression_decode()
|
||||
overlap = tile // 4
|
||||
pixel_samples = self.decode_tiled_seedvr2(samples_in, tile_x=tile, tile_y=tile, overlap=overlap)
|
||||
else:
|
||||
pixel_samples = self.decode_tiled_(samples_in)
|
||||
elif dims == 3:
|
||||
tile = 256 // self.spacial_compression_decode()
|
||||
overlap = tile // 4
|
||||
pixel_samples = self.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
|
||||
if isinstance(self.first_stage_model, comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper):
|
||||
pixel_samples = self.decode_tiled_seedvr2(samples_in, tile_x=tile, tile_y=tile, overlap=overlap)
|
||||
else:
|
||||
pixel_samples = self.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
|
||||
|
||||
pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
|
||||
return pixel_samples
|
||||
|
||||
def decode_tiled(self, samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None):
|
||||
def decode_tiled(
|
||||
self,
|
||||
samples,
|
||||
tile_x=None,
|
||||
tile_y=None,
|
||||
overlap=None,
|
||||
tile_t=None,
|
||||
overlap_t=None,
|
||||
):
|
||||
self.throw_exception_if_invalid()
|
||||
memory_used = self.memory_used_decode(samples.shape, self.vae_dtype) #TODO: calculate mem required for tile
|
||||
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
|
||||
@ -1098,7 +1235,20 @@ class VAE:
|
||||
args["overlap"] = overlap
|
||||
|
||||
with model_management.cuda_device_context(self.device):
|
||||
if dims == 1 or self.extra_1d_channel is not None:
|
||||
if isinstance(self.first_stage_model, comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper) and dims in (2, 3):
|
||||
seedvr2_args = {}
|
||||
if tile_x is not None:
|
||||
seedvr2_args["tile_x"] = tile_x
|
||||
if tile_y is not None:
|
||||
seedvr2_args["tile_y"] = tile_y
|
||||
if overlap is not None:
|
||||
seedvr2_args["overlap"] = overlap
|
||||
if tile_t is not None:
|
||||
seedvr2_args["tile_t"] = tile_t
|
||||
if overlap_t is not None:
|
||||
seedvr2_args["overlap_t"] = overlap_t
|
||||
output = self.decode_tiled_seedvr2(samples, **seedvr2_args)
|
||||
elif dims == 1 or self.extra_1d_channel is not None:
|
||||
args.pop("tile_y")
|
||||
output = self.decode_tiled_1d(samples, **args)
|
||||
elif dims == 2:
|
||||
@ -1140,6 +1290,8 @@ class VAE:
|
||||
else:
|
||||
pixels_in = pixels_in.to(self.device)
|
||||
out = self.first_stage_model.encode(pixels_in)
|
||||
if isinstance(out, tuple):
|
||||
out = out[0]
|
||||
out = out.to(self.output_device).to(dtype=self.vae_output_dtype())
|
||||
if samples is None:
|
||||
samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
|
||||
@ -1159,20 +1311,23 @@ class VAE:
|
||||
if self.latent_dim == 3:
|
||||
tile = 256
|
||||
overlap = tile // 4
|
||||
samples = self.encode_tiled_3d(pixel_samples, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
|
||||
if isinstance(self.first_stage_model, comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper):
|
||||
samples = self.encode_tiled_seedvr2(pixel_samples, tile_x=tile, tile_y=tile, overlap=overlap)
|
||||
else:
|
||||
samples = self.encode_tiled_3d(pixel_samples, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
|
||||
elif self.latent_dim == 1 or self.extra_1d_channel is not None:
|
||||
samples = self.encode_tiled_1d(pixel_samples)
|
||||
else:
|
||||
samples = self.encode_tiled_(pixel_samples)
|
||||
|
||||
return samples
|
||||
return self._format_seedvr2_encoded_samples(samples)
|
||||
|
||||
def encode_tiled(self, pixel_samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None):
|
||||
self.throw_exception_if_invalid()
|
||||
pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
|
||||
dims = self.latent_dim
|
||||
pixel_samples = pixel_samples.movedim(-1, 1)
|
||||
if dims == 3:
|
||||
if dims == 3 and pixel_samples.ndim < 5:
|
||||
if not self.not_video:
|
||||
pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
|
||||
else:
|
||||
@ -1196,22 +1351,47 @@ class VAE:
|
||||
elif dims == 2:
|
||||
samples = self.encode_tiled_(pixel_samples, **args)
|
||||
elif dims == 3:
|
||||
if tile_t is not None:
|
||||
tile_t_latent = max(2, self.downscale_ratio[0](tile_t))
|
||||
if isinstance(self.first_stage_model, comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper):
|
||||
seedvr2_args = {}
|
||||
if tile_x is not None:
|
||||
seedvr2_args["tile_x"] = tile_x
|
||||
else:
|
||||
seedvr2_args["tile_x"] = 512
|
||||
if tile_y is not None:
|
||||
seedvr2_args["tile_y"] = tile_y
|
||||
else:
|
||||
seedvr2_args["tile_y"] = 512
|
||||
if overlap is not None:
|
||||
seedvr2_args["overlap"] = overlap
|
||||
else:
|
||||
seedvr2_args["overlap"] = 64
|
||||
if tile_t is not None:
|
||||
seedvr2_args["tile_t"] = tile_t
|
||||
else:
|
||||
seedvr2_args["tile_t"] = 9999
|
||||
if overlap_t is not None:
|
||||
seedvr2_args["overlap_t"] = overlap_t
|
||||
else:
|
||||
seedvr2_args["overlap_t"] = 0
|
||||
samples = self.encode_tiled_seedvr2(pixel_samples, **seedvr2_args)
|
||||
else:
|
||||
tile_t_latent = 9999
|
||||
args["tile_t"] = self.upscale_ratio[0](tile_t_latent)
|
||||
if tile_t is not None:
|
||||
tile_t_latent = max(2, self.downscale_ratio[0](tile_t))
|
||||
else:
|
||||
tile_t_latent = 9999
|
||||
args["tile_t"] = self.upscale_ratio[0](tile_t_latent)
|
||||
|
||||
if overlap_t is None:
|
||||
args["overlap"] = (1, overlap, overlap)
|
||||
else:
|
||||
args["overlap"] = (self.upscale_ratio[0](max(1, min(tile_t_latent // 2, self.downscale_ratio[0](overlap_t)))), overlap, overlap)
|
||||
maximum = pixel_samples.shape[2]
|
||||
maximum = self.upscale_ratio[0](self.downscale_ratio[0](maximum))
|
||||
spatial_overlap = overlap if overlap is not None else 64
|
||||
if overlap_t is None:
|
||||
args["overlap"] = (1, spatial_overlap, spatial_overlap)
|
||||
else:
|
||||
args["overlap"] = (self.upscale_ratio[0](max(1, min(tile_t_latent // 2, self.downscale_ratio[0](overlap_t)))), spatial_overlap, spatial_overlap)
|
||||
maximum = pixel_samples.shape[2]
|
||||
maximum = self.upscale_ratio[0](self.downscale_ratio[0](maximum))
|
||||
|
||||
samples = self.encode_tiled_3d(pixel_samples[:,:,:maximum], **args)
|
||||
samples = self.encode_tiled_3d(pixel_samples[:,:,:maximum], **args)
|
||||
|
||||
return samples
|
||||
return self._format_seedvr2_encoded_samples(samples)
|
||||
|
||||
def get_sd(self):
|
||||
return self.first_stage_model.state_dict()
|
||||
@ -1719,6 +1899,17 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
|
||||
|
||||
return (model, clip, vae)
|
||||
|
||||
|
||||
def _set_model_config_inference_dtype(model_config, dtype, manual_cast_dtype, device):
|
||||
set_dtype = model_config.set_inference_dtype
|
||||
parameters = inspect.signature(set_dtype).parameters
|
||||
supports_device = "device" in parameters or any(p.kind == inspect.Parameter.VAR_KEYWORD for p in parameters.values())
|
||||
if supports_device:
|
||||
set_dtype(dtype, manual_cast_dtype, device=device)
|
||||
else:
|
||||
set_dtype(dtype, manual_cast_dtype)
|
||||
|
||||
|
||||
def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, disable_dynamic=False):
|
||||
sd, metadata = comfy.utils.load_torch_file(ckpt_path, return_metadata=True)
|
||||
out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata, disable_dynamic=disable_dynamic)
|
||||
@ -1826,7 +2017,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
|
||||
manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes)
|
||||
else:
|
||||
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
|
||||
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
|
||||
_set_model_config_inference_dtype(model_config, unet_dtype, manual_cast_dtype, load_device)
|
||||
|
||||
if model_config.clip_vision_prefix is not None:
|
||||
if output_clipvision:
|
||||
@ -1967,7 +2158,7 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None, disable
|
||||
manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes)
|
||||
else:
|
||||
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
|
||||
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
|
||||
_set_model_config_inference_dtype(model_config, unet_dtype, manual_cast_dtype, load_device)
|
||||
|
||||
if custom_operations is not None:
|
||||
model_config.custom_operations = custom_operations
|
||||
|
||||
@ -1536,6 +1536,35 @@ class Chroma(supported_models_base.BASE):
|
||||
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.pixart_t5.PixArtTokenizer, comfy.text_encoders.pixart_t5.pixart_te(**t5_detect))
|
||||
|
||||
class SeedVR2(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "seedvr2"
|
||||
}
|
||||
latent_format = comfy.latent_formats.SeedVR2
|
||||
|
||||
vae_key_prefix = ["vae."]
|
||||
text_encoder_key_prefix = ["text_encoders."]
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
|
||||
sampling_settings = {
|
||||
"shift": 1.0,
|
||||
}
|
||||
|
||||
def set_inference_dtype(self, dtype, manual_cast_dtype, device=None):
|
||||
if (
|
||||
dtype == torch.float16
|
||||
and manual_cast_dtype is None
|
||||
and comfy.model_management.should_use_bf16(device)
|
||||
):
|
||||
manual_cast_dtype = torch.bfloat16
|
||||
super().set_inference_dtype(dtype, manual_cast_dtype, device=device)
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.SeedVR2(self, device=device)
|
||||
return out
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
return None
|
||||
|
||||
class ChromaRadiance(Chroma):
|
||||
unet_config = {
|
||||
"image_model": "chroma_radiance",
|
||||
@ -1855,7 +1884,6 @@ class LongCatImage(supported_models_base.BASE):
|
||||
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.longcat_image.LongCatImageTokenizer, comfy.text_encoders.longcat_image.te(**hunyuan_detect))
|
||||
|
||||
|
||||
class RT_DETR_v4(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "RT_DETR_v4",
|
||||
@ -2090,6 +2118,7 @@ models = [
|
||||
HiDream,
|
||||
HiDreamO1,
|
||||
Chroma,
|
||||
SeedVR2,
|
||||
ChromaRadiance,
|
||||
ACEStep,
|
||||
ACEStep15,
|
||||
|
||||
@ -115,7 +115,7 @@ class BASE:
|
||||
replace_prefix = {"": self.vae_key_prefix[0]}
|
||||
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
|
||||
|
||||
def set_inference_dtype(self, dtype, manual_cast_dtype):
|
||||
def set_inference_dtype(self, dtype, manual_cast_dtype, device=None):
|
||||
self.unet_config['dtype'] = dtype
|
||||
self.manual_cast_dtype = manual_cast_dtype
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user