mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-24 00:39:30 +08:00
Add SeedVR2 model support
This commit is contained in:
parent
6d18f4adac
commit
cd18c4460a
@ -779,6 +779,9 @@ class ACEAudio(LatentFormat):
|
|||||||
latent_channels = 8
|
latent_channels = 8
|
||||||
latent_dimensions = 2
|
latent_dimensions = 2
|
||||||
|
|
||||||
|
class SeedVR2(LatentFormat):
    """Latent format descriptor for SeedVR2 models.

    Only the channel count is overridden here; every other attribute is
    whatever ``LatentFormat`` provides.
    """

    latent_channels = 16
|
||||||
|
|
||||||
class ACEAudio15(LatentFormat):
|
class ACEAudio15(LatentFormat):
|
||||||
latent_channels = 64
|
latent_channels = 64
|
||||||
latent_dimensions = 1
|
latent_dimensions = 1
|
||||||
|
|||||||
@ -22,7 +22,7 @@ def torch_cat_if_needed(xl, dim):
|
|||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_timestep_embedding(timesteps, embedding_dim):
|
def get_timestep_embedding(timesteps, embedding_dim, flip_sin_to_cos=False, downscale_freq_shift=1):
|
||||||
"""
|
"""
|
||||||
This matches the implementation in Denoising Diffusion Probabilistic Models:
|
This matches the implementation in Denoising Diffusion Probabilistic Models:
|
||||||
From Fairseq.
|
From Fairseq.
|
||||||
@ -33,11 +33,13 @@ def get_timestep_embedding(timesteps, embedding_dim):
|
|||||||
assert len(timesteps.shape) == 1
|
assert len(timesteps.shape) == 1
|
||||||
|
|
||||||
half_dim = embedding_dim // 2
|
half_dim = embedding_dim // 2
|
||||||
emb = math.log(10000) / (half_dim - 1)
|
emb = math.log(10000) / (half_dim - downscale_freq_shift)
|
||||||
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
|
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
|
||||||
emb = emb.to(device=timesteps.device)
|
emb = emb.to(device=timesteps.device)
|
||||||
emb = timesteps.float()[:, None] * emb[None, :]
|
emb = timesteps.float()[:, None] * emb[None, :]
|
||||||
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
|
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
|
||||||
|
if flip_sin_to_cos:
|
||||||
|
emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
|
||||||
if embedding_dim % 2 == 1: # zero pad
|
if embedding_dim % 2 == 1: # zero pad
|
||||||
emb = torch.nn.functional.pad(emb, (0,1,0,0))
|
emb = torch.nn.functional.pad(emb, (0,1,0,0))
|
||||||
return emb
|
return emb
|
||||||
|
|||||||
77
comfy/ldm/seedvr/attention.py
Normal file
77
comfy/ldm/seedvr/attention.py
Normal file
@ -0,0 +1,77 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
from comfy.ldm.modules import attention as _attention
|
||||||
|
|
||||||
|
|
||||||
|
def _var_attention_qkv(q, k, v, heads, skip_reshape):
|
||||||
|
if skip_reshape:
|
||||||
|
return q, k, v, q.shape[-1]
|
||||||
|
total_tokens, embed_dim = q.shape
|
||||||
|
head_dim = embed_dim // heads
|
||||||
|
return (
|
||||||
|
q.view(total_tokens, heads, head_dim),
|
||||||
|
k.view(k.shape[0], heads, head_dim),
|
||||||
|
v.view(v.shape[0], heads, head_dim),
|
||||||
|
head_dim,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _var_attention_output(out, heads, head_dim, skip_output_reshape):
|
||||||
|
if skip_output_reshape:
|
||||||
|
return out
|
||||||
|
return out.reshape(-1, heads * head_dim)
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_split_cu_seqlens(name, cu_seqlens, token_count):
|
||||||
|
if cu_seqlens.dtype not in (torch.int32, torch.int64):
|
||||||
|
raise ValueError(f"{name} must use an integer dtype")
|
||||||
|
if cu_seqlens.ndim != 1 or cu_seqlens.numel() < 2:
|
||||||
|
raise ValueError(f"{name} must be a 1D tensor with at least two offsets")
|
||||||
|
if cu_seqlens[0].item() != 0:
|
||||||
|
raise ValueError(f"{name} must start at 0")
|
||||||
|
if (cu_seqlens[1:] <= cu_seqlens[:-1]).any().item():
|
||||||
|
raise ValueError(f"{name} must be strictly increasing")
|
||||||
|
if cu_seqlens[-1].item() != token_count:
|
||||||
|
raise ValueError(f"{name} does not match token count")
|
||||||
|
|
||||||
|
|
||||||
|
def _split_indices(cu_seqlens):
|
||||||
|
return cu_seqlens[1:-1].to(device="cpu", dtype=torch.long)
|
||||||
|
|
||||||
|
|
||||||
|
def var_attention_optimized_split(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, *args, skip_reshape=False, skip_output_reshape=False, **kwargs):
    """Variable-length attention, run as one optimized-attention call per sequence.

    The packed ``q``/``k``/``v`` tensors are cut along dim 0 at the offsets in
    ``cu_seqlens_q`` / ``cu_seqlens_k``, each piece goes through
    ``_attention.optimized_attention`` on its own, and the per-sequence
    results are concatenated back along dim 0.

    Args:
        q, k, v: Packed tensors. With ``skip_reshape=False`` ``q`` must be 2D
            and all three are viewed with an explicit head axis; with
            ``skip_reshape=True`` they are used as given (the loop below then
            indexes them as 3D — presumably ``(tokens, heads, head_dim)``,
            confirm against callers).
        heads: Number of attention heads.
        cu_seqlens_q: Cumulative offsets for ``q``; must end at ``q.shape[0]``.
        cu_seqlens_k: Cumulative offsets shared by ``k`` and ``v``; must end
            at ``k.shape[0]``.
        *args, **kwargs: Accepted and ignored.
        skip_reshape: Skip the initial head-axis reshape of the inputs.
        skip_output_reshape: Return the concatenated result without merging
            the head axes.

    Returns:
        The concatenated attention output, reshaped to
        ``(-1, heads * head_dim)`` unless ``skip_output_reshape`` is true.

    Raises:
        ValueError: If either offset tensor fails validation, ``v`` has a
            different token count than ``k``, or the two offset tensors
            describe a different number of sequences.
    """
    q, k, v, head_dim = _var_attention_qkv(q, k, v, heads, skip_reshape)

    _validate_split_cu_seqlens("cu_seqlens_q", cu_seqlens_q, q.shape[0])
    _validate_split_cu_seqlens("cu_seqlens_k", cu_seqlens_k, k.shape[0])
    # cu_seqlens_k was just verified to end at k.shape[0], so comparing the
    # two shapes is equivalent to reading cu_seqlens_k[-1] back from the
    # tensor and avoids one more scalar read.
    if k.shape[0] != v.shape[0]:
        raise ValueError("cu_seqlens_k does not match v token count")

    q_split_indices = _split_indices(cu_seqlens_q)
    k_split_indices = _split_indices(cu_seqlens_k)
    q_splits = torch.tensor_split(q, q_split_indices, dim=0)
    k_splits = torch.tensor_split(k, k_split_indices, dim=0)
    v_splits = torch.tensor_split(v, k_split_indices, dim=0)
    if len(q_splits) != len(k_splits) or len(q_splits) != len(v_splits):
        raise ValueError("cu_seqlens_q and cu_seqlens_k must describe the same sequence count")

    # Every split is a slice of q, so the dtype (and therefore the cast
    # decision) is the same for all sequences; decide once outside the loop.
    out_dtype = q.dtype
    # When the active backend is attention_sage, non-fp16/bf16 inputs are
    # cast to bfloat16 first (presumably that backend only takes half
    # precision — confirm against its implementation).
    cast_for_sage = (
        _attention.optimized_attention is _attention.attention_sage
        and out_dtype not in (torch.float16, torch.bfloat16)
    )

    out = []
    for q_i, k_i, v_i in zip(q_splits, k_splits, v_splits):
        # Swap the first two axes and add a leading batch axis of size 1.
        q_i = q_i.permute(1, 0, 2).unsqueeze(0)
        k_i = k_i.permute(1, 0, 2).unsqueeze(0)
        v_i = v_i.permute(1, 0, 2).unsqueeze(0)
        if cast_for_sage:
            q_i = q_i.to(torch.bfloat16)
            k_i = k_i.to(torch.bfloat16)
            v_i = v_i.to(torch.bfloat16)
        out_i = _attention.optimized_attention(q_i, k_i, v_i, heads, skip_reshape=True, skip_output_reshape=True)
        # Hand results back in the caller's original dtype.
        if out_i.dtype != out_dtype:
            out_i = out_i.to(out_dtype)
        # Undo the batch axis and the axis swap before collecting.
        out.append(out_i.squeeze(0).permute(1, 0, 2))

    out = torch.cat(out, dim=0)
    return _var_attention_output(out, heads, head_dim, skip_output_reshape)


# Module-level alias; currently bound to the split-based implementation.
optimized_var_attention = var_attention_optimized_split
|
||||||
72
comfy/ldm/seedvr/constants.py
Normal file
72
comfy/ldm/seedvr/constants.py
Normal file
@ -0,0 +1,72 @@
|
|||||||
|
"""Named constants for the SeedVR2 integration, grouped by provenance.

The prefix of each name tells where the value comes from:

- ``SEEDVR2_*``: introduced by this integration (no external origin); the
  rationale is given next to the constant.
- ``BYTEDANCE_*``: ported from the official ByteDance-Seed/SeedVR release;
  each one cites the upstream config/source path it was lifted from.
- unprefixed (``ROPE_THETA``, ``CIELAB_*``, ``D65_*``, ``WAVELET_*``):
  published literature / ISO / CIE values; each cites its source.
"""

# ======================================================================================
# A. Progressive-sampler chunk-size law (SEEDVR2 - this integration's VRAM experiment)
#
#    n_max(frames/chunk) = SEEDVR2_CHUNK_FRAMES_PER_GB * (free_GB - SEEDVR2_CHUNK_GB_MARGIN)
#
#    rounded to the 4n+1 grid. Fit on 22 blocked-5090 cells, validated on a real
#    RTX 4070 (3b and 7b). Resolution-independent (the VAE tiling sets the wall,
#    not the DiT).
# ======================================================================================

# Fixed VRAM overhead before chunks scale (GiB).
SEEDVR2_CHUNK_GB_MARGIN = 3
# Empirical slope: pixel frames admitted per free GiB.
SEEDVR2_CHUNK_FRAMES_PER_GB = 4

# ======================================================================================
# B. Fork heuristics (SEEDVR2 - this integration)
# ======================================================================================

# Runtime 3b-vs-7b sentinel; tested against vid_dim.
# (3072 is ByteDance's 7b vid_dim; the sentinel use is ours.)
SEEDVR2_7B_VID_DIM = 3072
# Auto-chunk OOM retry: halve the chunk and retry.
SEEDVR2_OOM_BACKOFF_DIVISOR = 2
# Per-element byte floor for memory math (fp32 worst case).
SEEDVR2_DTYPE_BYTES_FLOOR = 4
# 7b MLP token-chunk to bound peak VRAM.
SEEDVR2_7B_MLP_CHUNK = 8192
# Partial-RoPE application token-chunk.
SEEDVR2_ROPE_PARTIAL_CHUNK_TOKENS = 4096
# SeedVR2 latent channel count (== BYTEDANCE latent_channels).
SEEDVR2_LATENT_CHANNELS = 16
# Conditioning channels = vid_in_channels(33) - latent(16).
SEEDVR2_COND_CHANNELS = 17

# Color-correction memory model (fork tuning; per-frame VRAM estimate for chunk sizing).

# Fraction of free VRAM usable per color-correction chunk.
SEEDVR2_COLOR_MEM_HEADROOM = 0.75
# Per-frame byte multiplier, LAB path.
SEEDVR2_LAB_SCALE_MULTIPLIER = 13
# Per-frame byte multiplier, wavelet path.
SEEDVR2_WAVELET_SCALE_MULTIPLIER = 10
# Per-frame byte multiplier, AdaIN path.
SEEDVR2_ADAIN_SCALE_MULTIPLIER = 6

# ======================================================================================
# C. ByteDance config / source (BYTEDANCE - cite ByteDance-Seed/SeedVR)
# ======================================================================================

# configs_3b/main.yaml:57 (scaling_factor); latent denorm.
BYTEDANCE_VAE_SCALING_FACTOR = 0.9152
# infer.py (shifting_factor default); latent denorm shift.
BYTEDANCE_VAE_SHIFTING_FACTOR = 0.0
# configs_3b/main.yaml:54 (conv_max_mem).
BYTEDANCE_VAE_CONV_MEM_GIB = 0.5
# configs_3b/main.yaml:55 (norm_max_mem).
BYTEDANCE_VAE_NORM_MEM_GIB = 0.5
# video_vae_v3/modules/types.py:28.
BYTEDANCE_LOGVAR_CLAMP_MIN = -30.0
# video_vae_v3/modules/types.py:28.
BYTEDANCE_LOGVAR_CLAMP_MAX = 20.0
# causal_inflation_lib.py:351 (GroupNorm chunk count, fp16).
BYTEDANCE_GN_CHUNKS_FP16 = 4
# causal_inflation_lib.py:351 (GroupNorm chunk count, fp32).
BYTEDANCE_GN_CHUNKS_FP32 = 2
# s8_c16_t4_inflation_sd3.yaml:7-11.
BYTEDANCE_BLOCK_OUT_CHANNELS = (128, 256, 512, 512)
# s8_c16_t4_inflation_sd3.yaml:22 (slicing_sample_min_size).
BYTEDANCE_SLICING_SAMPLE_MIN = 4
# infer.py:230 (temporal_downsample_factor); the 4n+1 factor.
BYTEDANCE_VAE_TEMPORAL_DOWNSAMPLE = 4
# infer.py:231 (spatial_downsample_factor).
BYTEDANCE_VAE_SPATIAL_DOWNSAMPLE = 8
# dit_v2/window.py:32 (720p reference area for window scaling).
BYTEDANCE_720P_REF_AREA = 45 * 80
# dit_v2/window.py:35 (max temporal window frames).
BYTEDANCE_MAX_TEMPORAL_WINDOW = 30
# dit_v2/rope.py:31 (pixel-RoPE max frequency).
BYTEDANCE_ROPE_MAX_FREQ = 256
# dit_3b/nadit.py:120 (timestep sinusoidal embed dim).
BYTEDANCE_SINUSOIDAL_DIM = 256

# ======================================================================================
# D. Published standards (cite the literature)
# ======================================================================================

# RoPE base; Su et al., "RoFormer", arXiv:2104.09864.
ROPE_THETA = 10000

# CIELAB f(t) piecewise constants and D65 white point (CIE 15 colorimetry; CIE D65).

# CIE 15 (delta).
CIELAB_DELTA = 6.0 / 29.0
# CIE 15 (kappa).
CIELAB_KAPPA = (29.0 / 3.0) ** 3
# CIE D65 standard illuminant Xn (Yn = 1).
D65_WHITE_X = 0.95047
# CIE D65 standard illuminant Zn.
D65_WHITE_Z = 1.08883
# Wavelet color-fix decomposition depth (GIMP/Krita; StableSR).
WAVELET_DECOMP_LEVELS = 5

# NOTE: the sRGB<->XYZ D65 3x3 matrices (IEC 61966-2-1) remain inline in the color
# code and are named (SRGB_TO_XYZ_D65 / XYZ_TO_SRGB_D65) during the color-module
# extraction, where the exact existing coefficients move verbatim rather than being
# retyped here.
|
||||||
1487
comfy/ldm/seedvr/model.py
Normal file
1487
comfy/ldm/seedvr/model.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -54,6 +54,7 @@ import comfy.ldm.pixeldit.model
|
|||||||
import comfy.ldm.pixeldit.pid
|
import comfy.ldm.pixeldit.pid
|
||||||
import comfy.ldm.ace.model
|
import comfy.ldm.ace.model
|
||||||
import comfy.ldm.omnigen.omnigen2
|
import comfy.ldm.omnigen.omnigen2
|
||||||
|
import comfy.ldm.seedvr.model
|
||||||
import comfy.ldm.qwen_image.model
|
import comfy.ldm.qwen_image.model
|
||||||
import comfy.ldm.ideogram4.model
|
import comfy.ldm.ideogram4.model
|
||||||
import comfy.ldm.kandinsky5.model
|
import comfy.ldm.kandinsky5.model
|
||||||
@ -929,6 +930,16 @@ class HunyuanDiT(BaseModel):
|
|||||||
out['image_meta_size'] = comfy.conds.CONDRegular(torch.FloatTensor([[height, width, target_height, target_width, 0, 0]]))
|
out['image_meta_size'] = comfy.conds.CONDRegular(torch.FloatTensor([[height, width, target_height, target_width, 0, 0]]))
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
class SeedVR2(BaseModel):
    """``BaseModel`` subclass that uses ``comfy.ldm.seedvr.model.NaDiT`` as its network."""

    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
        super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.seedvr.model.NaDiT)

    def extra_conds(self, **kwargs):
        """Return the base conds, plus ``"condition"`` when one was supplied.

        A non-``None`` ``condition`` keyword argument is wrapped in
        ``comfy.conds.CONDRegular`` and stored under the ``"condition"`` key;
        otherwise the parent's result is returned as-is.
        """
        conds = super().extra_conds(**kwargs)
        condition = kwargs.get("condition")
        if condition is None:
            return conds
        conds["condition"] = comfy.conds.CONDRegular(condition)
        return conds
|
||||||
|
|
||||||
class PixArt(BaseModel):
|
class PixArt(BaseModel):
|
||||||
def __init__(self, model_config, model_type=ModelType.EPS, device=None):
|
def __init__(self, model_config, model_type=ModelType.EPS, device=None):
|
||||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.pixart.pixartms.PixArtMS)
|
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.pixart.pixartms.PixArtMS)
|
||||||
|
|||||||
@ -598,6 +598,53 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
|||||||
|
|
||||||
return dit_config
|
return dit_config
|
||||||
|
|
||||||
|
if "{}blocks.35.mlp.vid.proj_in.weight".format(key_prefix) in state_dict_keys and state_dict["{}blocks.35.mlp.vid.proj_in.weight".format(key_prefix)].shape[1] == 3072: # seedvr2 7b
|
||||||
|
dit_config = {}
|
||||||
|
dit_config["image_model"] = "seedvr2"
|
||||||
|
dit_config["vid_dim"] = 3072
|
||||||
|
dit_config["heads"] = 24
|
||||||
|
dit_config["num_layers"] = 36
|
||||||
|
# 7B uses non-shared MMModule layout (separate ``vid.`` / ``txt.``
|
||||||
|
# submodules) at EVERY block — verified by inspecting the 7B
|
||||||
|
# state_dict at ``blocks.31.ada.txt.attn_gate`` (txt. prefix means
|
||||||
|
# ``MMModule.shared_weights=False``). Native NaDiT computes
|
||||||
|
# per-block ``shared_weights = not (i < mm_layers)``, so to keep
|
||||||
|
# every block non-shared we set ``mm_layers = num_layers``.
|
||||||
|
# Without this, blocks at index >= mm_layers (default 10) try to
|
||||||
|
# load ``blocks.N.*.all.*`` keys that don't exist in the file,
|
||||||
|
# silently miss-load → all-black output.
|
||||||
|
dit_config["mm_layers"] = 36
|
||||||
|
dit_config["norm_eps"] = 1e-5
|
||||||
|
dit_config["rope_type"] = "rope3d"
|
||||||
|
dit_config["rope_dim"] = 64
|
||||||
|
dit_config["mlp_type"] = "normal"
|
||||||
|
return dit_config
|
||||||
|
elif "{}blocks.35.mlp.all.proj_in_gate.weight".format(key_prefix) in state_dict_keys: # seedvr2 7b
|
||||||
|
dit_config = {}
|
||||||
|
dit_config["image_model"] = "seedvr2"
|
||||||
|
dit_config["vid_dim"] = 3072
|
||||||
|
dit_config["heads"] = 24
|
||||||
|
dit_config["num_layers"] = 36
|
||||||
|
# This checkpoint layout carries shared ``all.`` MMModule keys.
|
||||||
|
# Preserve the historical split: the initial blocks use separate
|
||||||
|
# vid/txt modules, later blocks use shared modules.
|
||||||
|
dit_config["mm_layers"] = 10
|
||||||
|
dit_config["norm_eps"] = 1e-5
|
||||||
|
dit_config["rope_type"] = "rope3d"
|
||||||
|
dit_config["rope_dim"] = 64
|
||||||
|
dit_config["mlp_type"] = "swiglu"
|
||||||
|
return dit_config
|
||||||
|
elif "{}blocks.31.mlp.all.proj_in_gate.weight".format(key_prefix) in state_dict_keys: # seedvr2 3b
|
||||||
|
dit_config = {}
|
||||||
|
dit_config["image_model"] = "seedvr2"
|
||||||
|
dit_config["vid_dim"] = 2560
|
||||||
|
dit_config["heads"] = 20
|
||||||
|
dit_config["num_layers"] = 32
|
||||||
|
dit_config["norm_eps"] = 1.0e-05
|
||||||
|
dit_config["mlp_type"] = "swiglu"
|
||||||
|
dit_config["vid_out_norm"] = True
|
||||||
|
return dit_config
|
||||||
|
|
||||||
if '{}head.modulation'.format(key_prefix) in state_dict_keys: # Wan 2.1
|
if '{}head.modulation'.format(key_prefix) in state_dict_keys: # Wan 2.1
|
||||||
dit_config = {}
|
dit_config = {}
|
||||||
dit_config["image_model"] = "wan2.1"
|
dit_config["image_model"] = "wan2.1"
|
||||||
|
|||||||
@ -1683,6 +1683,35 @@ class Chroma(supported_models_base.BASE):
|
|||||||
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
|
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
|
||||||
return supported_models_base.ClipTarget(comfy.text_encoders.pixart_t5.PixArtTokenizer, comfy.text_encoders.pixart_t5.pixart_te(**t5_detect))
|
return supported_models_base.ClipTarget(comfy.text_encoders.pixart_t5.PixArtTokenizer, comfy.text_encoders.pixart_t5.pixart_te(**t5_detect))
|
||||||
|
|
||||||
|
class SeedVR2(supported_models_base.BASE):
    """Supported-model entry matched by ``image_model == "seedvr2"``."""

    unet_config = {"image_model": "seedvr2"}
    latent_format = comfy.latent_formats.SeedVR2

    vae_key_prefix = ["vae."]
    text_encoder_key_prefix = ["text_encoders."]
    supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
    sampling_settings = {"shift": 1.0}

    def set_inference_dtype(self, dtype, manual_cast_dtype, device=None):
        """Record the inference dtype, substituting a bf16 manual cast for fp16.

        When ``dtype`` is float16, no manual cast dtype was requested, and
        ``comfy.model_management.should_use_bf16(device)`` is true, bfloat16 is
        passed on as the manual cast dtype. The parent implementation does the
        actual bookkeeping in every case.
        """
        # should_use_bf16 stays last so it is only queried when both cheap
        # conditions already hold.
        wants_bf16_cast = (
            manual_cast_dtype is None
            and dtype == torch.float16
            and comfy.model_management.should_use_bf16(device)
        )
        if wants_bf16_cast:
            manual_cast_dtype = torch.bfloat16
        super().set_inference_dtype(dtype, manual_cast_dtype, device=device)

    def get_model(self, state_dict, prefix="", device=None):
        """Build the ``model_base.SeedVR2`` wrapper; ``state_dict``/``prefix`` are unused."""
        return model_base.SeedVR2(self, device=device)

    def clip_target(self, state_dict={}):
        """Return ``None``: this entry defines no CLIP target."""
        return None
|
||||||
|
|
||||||
class ChromaRadiance(Chroma):
|
class ChromaRadiance(Chroma):
|
||||||
unet_config = {
|
unet_config = {
|
||||||
"image_model": "chroma_radiance",
|
"image_model": "chroma_radiance",
|
||||||
@ -2296,6 +2325,7 @@ models = [
|
|||||||
HiDream,
|
HiDream,
|
||||||
HiDreamO1,
|
HiDreamO1,
|
||||||
Chroma,
|
Chroma,
|
||||||
|
SeedVR2,
|
||||||
ChromaRadiance,
|
ChromaRadiance,
|
||||||
ACEStep,
|
ACEStep,
|
||||||
ACEStep15,
|
ACEStep15,
|
||||||
|
|||||||
@ -115,7 +115,7 @@ class BASE:
|
|||||||
replace_prefix = {"": self.vae_key_prefix[0]}
|
replace_prefix = {"": self.vae_key_prefix[0]}
|
||||||
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
|
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
|
||||||
|
|
||||||
def set_inference_dtype(self, dtype, manual_cast_dtype):
|
def set_inference_dtype(self, dtype, manual_cast_dtype, device=None):
|
||||||
self.unet_config['dtype'] = dtype
|
self.unet_config['dtype'] = dtype
|
||||||
self.manual_cast_dtype = manual_cast_dtype
|
self.manual_cast_dtype = manual_cast_dtype
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user