This commit is contained in:
strint 2025-10-16 22:25:17 +08:00
parent c1eac555c0
commit 9352987e9b

View File

@ -60,6 +60,7 @@ import math
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
from comfy.model_patcher import ModelPatcher from comfy.model_patcher import ModelPatcher
from comfy.model_management import get_free_memory
class ModelType(Enum): class ModelType(Enum):
EPS = 1 EPS = 1
@ -291,18 +292,19 @@ class BaseModel(torch.nn.Module):
return out return out
def load_model_weights(self, sd, unet_prefix=""): def load_model_weights(self, sd, unet_prefix=""):
import pdb; pdb.set_trace()
to_load = {} to_load = {}
keys = list(sd.keys()) keys = list(sd.keys())
for k in keys: for k in keys:
if k.startswith(unet_prefix): if k.startswith(unet_prefix):
to_load[k[len(unet_prefix):]] = sd.pop(k) to_load[k[len(unet_prefix):]] = sd.pop(k)
logging.info(f"load model weights start, keys {keys}") free_cpu_memory = get_free_memory(torch.device("cpu"))
logging.info(f"load model weights start, free cpu memory size {free_cpu_memory/(1024*1024*1024)} GB")
to_load = self.model_config.process_unet_state_dict(to_load) to_load = self.model_config.process_unet_state_dict(to_load)
logging.info(f"load model {self.model_config} weights process end, keys {keys}") logging.info(f"load model {self.model_config} weights process end")
m, u = self.diffusion_model.load_state_dict(to_load, strict=False) m, u = self.diffusion_model.load_state_dict(to_load, strict=False)
logging.info(f"load model {self.model_config} weights end, keys {keys}") free_cpu_memory = get_free_memory(torch.device("cpu"))
logging.info(f"load model {self.model_config} weights end, free cpu memory size {free_cpu_memory/(1024*1024*1024)} GB")
if len(m) > 0: if len(m) > 0:
logging.warning("unet missing: {}".format(m)) logging.warning("unet missing: {}".format(m))