archive the model defined dtypes

Scan created models and save off the dtypes as defined by the model
creation process. This is needed for assign=True, which will override
the dtypes.
This commit is contained in:
Rattus 2026-01-22 00:00:33 +10:00
parent 6e641d88ed
commit 4979c075c9
4 changed files with 13 additions and 0 deletions

View File

@ -148,6 +148,8 @@ class BaseModel(torch.nn.Module):
self.model_type = model_type
self.model_sampling = model_sampling(model_config, model_type)
comfy.model_management.archive_model_dtypes(self.diffusion_model)
self.adm_channels = unet_config.get("adm_in_channels", None)
if self.adm_channels is None:
self.adm_channels = 0

View File

@ -774,6 +774,11 @@ def cleanup_models_gc():
logging.warning("WARNING, memory leak with model {}. Please make sure it is not being referenced from somewhere.".format(cur.real_model().__class__.__name__))
def archive_model_dtypes(model):
for name, module in model.named_modules():
for param_name, param in module.named_parameters(recurse=False):
setattr(module, f"{param_name}_comfy_model_dtype", param.dtype)
def cleanup_models():
to_delete = []

View File

@ -296,6 +296,8 @@ class disable_weight_init:
self.weight = None
self.bias = None
self.comfy_need_lazy_init_bias=bias
self.weight_comfy_model_dtype = dtype
self.bias_comfy_model_dtype = dtype
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
strict, missing_keys, unexpected_keys, error_msgs):

View File

@ -127,6 +127,8 @@ class CLIP:
self.cond_stage_model.to(offload_device)
logging.warning("Had to shift TE back.")
model_management.archive_model_dtypes(self.cond_stage_model)
self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
self.patcher = comfy.model_patcher.CoreModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
#Match torch.float32 hardcode upcast in TE implemention
@ -675,6 +677,8 @@ class VAE:
self.first_stage_model = AutoencoderKL(**(config['params']))
self.first_stage_model = self.first_stage_model.eval()
model_management.archive_model_dtypes(self.first_stage_model)
if device is None:
device = model_management.vae_device()
self.device = device