diff --git a/comfy/model_base.py b/comfy/model_base.py
index 8aeb057f1..85acdb66a 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -149,6 +149,8 @@ class BaseModel(torch.nn.Module):
         self.model_type = model_type
         self.model_sampling = model_sampling(model_config, model_type)
 
+        comfy.model_management.archive_model_dtypes(self.diffusion_model)
+
         self.adm_channels = unet_config.get("adm_in_channels", None)
         if self.adm_channels is None:
             self.adm_channels = 0
diff --git a/comfy/model_management.py b/comfy/model_management.py
index 4a3a0f886..cdb9542c0 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -774,6 +774,11 @@ def cleanup_models_gc():
             logging.warning("WARNING, memory leak with model {}. Please make sure it is not being referenced from somewhere.".format(cur.real_model().__class__.__name__))
 
 
+def archive_model_dtypes(model):
+    for name, module in model.named_modules():
+        for param_name, param in module.named_parameters(recurse=False):
+            setattr(module, f"{param_name}_comfy_model_dtype", param.dtype)
+
 def cleanup_models():
     to_delete = []
 
diff --git a/comfy/ops.py b/comfy/ops.py
index ee8b32f18..31bcc8a77 100644
--- a/comfy/ops.py
+++ b/comfy/ops.py
@@ -296,6 +296,8 @@ class disable_weight_init:
             self.weight = None
             self.bias = None
             self.comfy_need_lazy_init_bias=bias
+            self.weight_comfy_model_dtype = dtype
+            self.bias_comfy_model_dtype = dtype
 
         def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                                   missing_keys, unexpected_keys, error_msgs):
diff --git a/comfy/sd.py b/comfy/sd.py
index 42b2fd6df..28c1ee062 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -128,6 +128,8 @@ class CLIP:
                 self.cond_stage_model.to(offload_device)
                 logging.warning("Had to shift TE back.")
 
+        model_management.archive_model_dtypes(self.cond_stage_model)
+
         self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
         self.patcher = comfy.model_patcher.CoreModelPatcher(self.cond_stage_model,
                                                             load_device=load_device, offload_device=offload_device) #Match torch.float32 hardcode upcast in TE implemention
@@ -675,6 +677,8 @@ class VAE:
             self.first_stage_model = AutoencoderKL(**(config['params']))
         self.first_stage_model = self.first_stage_model.eval()
 
+        model_management.archive_model_dtypes(self.first_stage_model)
+
         if device is None:
             device = model_management.vae_device()
         self.device = device