changed vae handling

S.S. 2025-10-10 04:26:55 +02:00 committed by GitHub
parent 81e4dac107
commit 20d7b6b3fb


@@ -271,9 +271,31 @@ class CLIP:
         return self.patcher.get_key_patches()
 
 class VAE:
-    def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None):
-        if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
+    def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None, model_options={}):
+        if sd and 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
             sd = diffusers_convert.convert_vae_state_dict(sd)
+        load_device = model_options.get("load_device", model_management.vae_device())
+        offload_device = model_options.get("offload_device", model_management.vae_offload_device())
+        if dtype is None:
+            dtype = model_management.vae_dtype(load_device)
+        initial_device = model_options.get("initial_device", model_management.text_encoder_initial_device(load_device, offload_device))
+        self.device = initial_device
+        self.dtype = dtype
+        self.model_options = model_options
+        logging.info(f"VAE model load device: {load_device}, offload device: {offload_device}, current: {self.device}, dtype: {dtype}")
         self.memory_used_encode = lambda shape, dtype: (1767 * shape[2] * shape[3]) * model_management.dtype_size(dtype) #These are for AutoencoderKL and need tweaking (should be lower)
         self.memory_used_decode = lambda shape, dtype: (2178 * shape[2] * shape[3] * 64) * model_management.dtype_size(dtype)
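
For reference, a minimal sketch of how the new model_options parameter could be used to override where the VAE is placed, instead of relying on the defaults from model_management.vae_device() / vae_offload_device(). It assumes the comfy.sd and comfy.utils modules are importable as in this repository; the checkpoint path is a placeholder.

import torch

import comfy.sd
import comfy.utils

# Load a VAE state dict from disk (the path below is a placeholder).
sd = comfy.utils.load_torch_file("models/vae/example_vae.safetensors")

# Pass explicit devices through model_options; keys match the ones read
# in the constructor above ("load_device", "offload_device", "initial_device").
vae = comfy.sd.VAE(sd=sd, model_options={
    "load_device": torch.device("cpu"),     # run VAE encode/decode on CPU
    "offload_device": torch.device("cpu"),  # keep the weights on CPU when idle
})

Leaving model_options empty keeps the previous behavior, since every key falls back to the existing model_management defaults.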