Mirror of https://github.com/comfyanonymous/ComfyUI.git
changed vae handling
This commit is contained in:
parent 81e4dac107
commit 20d7b6b3fb

comfy/sd.py (26 lines changed)
@@ -271,9 +271,31 @@ class CLIP:
         return self.patcher.get_key_patches()
 
 
 class VAE:
-    def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None):
-        if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
+    def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None, model_options={}):
+        if sd and 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
             sd = diffusers_convert.convert_vae_state_dict(sd)
+
+        load_device = model_options.get("load_device", model_management.vae_device())
+        offload_device = model_options.get("offload_device", model_management.vae_offload_device())
+
+        if dtype is None:
+            dtype = model_management.vae_dtype(load_device)
+
+        initial_device = model_options.get(
+            "initial_device",
+            model_management.text_encoder_initial_device(load_device, offload_device)
+        )
+
+        self.device = initial_device
+        self.dtype = dtype
+        self.model_options = model_options
+
+        logging.info(
+            f"VAE model load device: {load_device}, offload device: {offload_device}, "
+            f"current: {self.device}, dtype: {dtype}"
+        )
+
+
         self.memory_used_encode = lambda shape, dtype: (1767 * shape[2] * shape[3]) * model_management.dtype_size(dtype) #These are for AutoencoderKL and need tweaking (should be lower)
         self.memory_used_decode = lambda shape, dtype: (2178 * shape[2] * shape[3] * 64) * model_management.dtype_size(dtype)
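With this change, the VAE constructor accepts a model_options dict that can override device placement instead of always using the model_management defaults. A minimal usage sketch under that assumption follows; the checkpoint path and chosen devices are placeholders, comfy.utils.load_torch_file is used only as a convenient loader from the same codebase, and only the "load_device", "offload_device", and "initial_device" keys shown in the diff above are assumed to be recognized.

import torch
import comfy.sd
import comfy.utils

# Load a VAE state dict from disk (the path is a placeholder).
sd = comfy.utils.load_torch_file("models/vae/example_vae.safetensors")

# Pin the VAE to specific devices via model_options; any key left out
# falls back to the model_management defaults seen in the diff above.
vae = comfy.sd.VAE(
    sd=sd,
    model_options={
        "load_device": torch.device("cuda:0"),
        "offload_device": torch.device("cpu"),
    },
)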