This commit is contained in:
Xiaoyu Xu 2025-12-14 09:31:17 +01:00 committed by GitHub
commit 98e9e004cd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -305,7 +305,9 @@ class BaseModel(torch.nn.Module):
to_load[k[len(unet_prefix):]] = sd.pop(k)
to_load = self.model_config.process_unet_state_dict(to_load)
m, u = self.diffusion_model.load_state_dict(to_load, strict=False)
# assign=True will reuse the tensor storage in state dict, this will avoid copy and saving CPU memory
# when loading large models with mmap.
m, u = self.diffusion_model.load_state_dict(to_load, strict=False, assign=True)
if len(m) > 0:
logging.warning("unet missing: {}".format(m))