archive the model defined dtypes

Scan created models and save off the dtypes as defined by the model creation process. This is needed for loading with assign=True, which overrides those dtypes with the ones from the state dict.

parent 6e641d88ed
commit 4979c075c9
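
For context, a minimal standalone sketch of the behavior the commit message refers to (assuming PyTorch >= 2.1, where the assign flag of load_state_dict was introduced; the layer and dtypes are illustrative, not taken from ComfyUI):

    import torch

    # The "model defined" dtype: the layer is created in float16.
    layer = torch.nn.Linear(4, 4, dtype=torch.float16)
    print(layer.weight.dtype)  # torch.float16

    # A checkpoint stored in float32, loaded with assign=True, replaces the
    # parameters outright, so the creation-time dtype is lost unless it was
    # recorded beforehand.
    state = {k: v.to(torch.float32) for k, v in layer.state_dict().items()}
    layer.load_state_dict(state, assign=True)
    print(layer.weight.dtype)  # torch.float32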
@@ -148,6 +148,8 @@ class BaseModel(torch.nn.Module):
         self.model_type = model_type
         self.model_sampling = model_sampling(model_config, model_type)
 
+        comfy.model_management.archive_model_dtypes(self.diffusion_model)
+
         self.adm_channels = unet_config.get("adm_in_channels", None)
         if self.adm_channels is None:
             self.adm_channels = 0
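
Once archive_model_dtypes (added to model_management below) has run, every module carries one "<param>_comfy_model_dtype" attribute per directly owned parameter. A quick standalone illustration, using a plain layer rather than a diffusion model:

    import torch

    def archive_model_dtypes(model):  # same body as the helper added below
        for name, module in model.named_modules():
            for param_name, param in module.named_parameters(recurse=False):
                setattr(module, f"{param_name}_comfy_model_dtype", param.dtype)

    conv = torch.nn.Conv2d(3, 8, 3, dtype=torch.bfloat16)
    archive_model_dtypes(conv)
    print(conv.weight_comfy_model_dtype)  # torch.bfloat16
    print(conv.bias_comfy_model_dtype)    # torch.bfloat16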
@@ -774,6 +774,11 @@ def cleanup_models_gc():
             logging.warning("WARNING, memory leak with model {}. Please make sure it is not being referenced from somewhere.".format(cur.real_model().__class__.__name__))
 
 
+def archive_model_dtypes(model):
+    for name, module in model.named_modules():
+        for param_name, param in module.named_parameters(recurse=False):
+            setattr(module, f"{param_name}_comfy_model_dtype", param.dtype)
+
 
 def cleanup_models():
     to_delete = []
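
The commit only records the dtypes; no consumer is shown here. A hypothetical counterpart (the name restore_model_dtypes and its behavior are my assumption, not part of this commit) could cast parameters back to their archived dtypes after an assign=True load:

    import torch

    def restore_model_dtypes(model):
        # Hypothetical helper: undo the dtype override performed by
        # load_state_dict(..., assign=True), using the archived attributes.
        for module in model.modules():
            for param_name, param in list(module.named_parameters(recurse=False)):
                archived = getattr(module, f"{param_name}_comfy_model_dtype", None)
                if archived is not None and param.dtype != archived:
                    setattr(module, param_name, torch.nn.Parameter(
                        param.data.to(archived), requires_grad=param.requires_grad))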
@@ -296,6 +296,8 @@ class disable_weight_init:
             self.weight = None
             self.bias = None
             self.comfy_need_lazy_init_bias=bias
+            self.weight_comfy_model_dtype = dtype
+            self.bias_comfy_model_dtype = dtype
 
         def _load_from_state_dict(self, state_dict, prefix, local_metadata,
                                   strict, missing_keys, unexpected_keys, error_msgs):
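
archive_model_dtypes walks named_parameters(), which skips parameters that are None, so it records nothing for these lazily initialized ops whose weight and bias do not exist yet at construction time; presumably that is why this path stores the dtype attributes explicitly from the constructor's dtype argument. A standalone illustration of the gap:

    import torch

    m = torch.nn.Linear(4, 4)
    m.weight = None  # mimic the lazy-init state: parameters not materialized yet
    m.bias = None
    # named_parameters() yields nothing, so there would be nothing to archive.
    print(list(m.named_parameters(recurse=False)))  # []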
@@ -127,6 +127,8 @@ class CLIP:
                 self.cond_stage_model.to(offload_device)
                 logging.warning("Had to shift TE back.")
 
+        model_management.archive_model_dtypes(self.cond_stage_model)
+
         self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
         self.patcher = comfy.model_patcher.CoreModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
         #Match torch.float32 hardcode upcast in TE implemention
@@ -675,6 +677,8 @@ class VAE:
                 self.first_stage_model = AutoencoderKL(**(config['params']))
             self.first_stage_model = self.first_stage_model.eval()
 
+            model_management.archive_model_dtypes(self.first_stage_model)
+
         if device is None:
             device = model_management.vae_device()
         self.device = device