add debug log of cpu load

strint 2025-10-17 16:33:14 +08:00
parent e5ff6a1b53
commit 5c3c6c02b2
4 changed files with 35 additions and 0 deletions

View File

@@ -911,3 +911,15 @@ class UNetModel(nn.Module):
            return self.id_predictor(h)
        else:
            return self.out(h)

    def load_state_dict(self, state_dict, strict=True):
        """Override load_state_dict() to add logging"""
        logging.info(f"UNetModel load_state_dict start, strict={strict}, state_dict keys count={len(state_dict)}")
        # Call parent's load_state_dict method
        result = super().load_state_dict(state_dict, strict=strict)
        logging.info(f"UNetModel load_state_dict end, strict={strict}, state_dict keys count={len(state_dict)}")
        return result
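The override pattern above (log, delegate to super().load_state_dict(), log again) can be exercised on its own; the MiniModel below is purely illustrative and not part of this diff:

import logging
import torch.nn as nn

logging.basicConfig(level=logging.INFO)

class MiniModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)

    def load_state_dict(self, state_dict, strict=True):
        """Log around the parent implementation, mirroring the UNetModel override."""
        logging.info(f"MiniModel load_state_dict start, strict={strict}, keys={len(state_dict)}")
        result = super().load_state_dict(state_dict, strict=strict)
        logging.info(f"MiniModel load_state_dict end, missing={len(result.missing_keys)}, unexpected={len(result.unexpected_keys)}")
        return result

sd = MiniModel().state_dict()
MiniModel().load_state_dict(sd, strict=False)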

View File

@@ -303,6 +303,8 @@ class BaseModel(torch.nn.Module):
        logging.info(f"model destination device {next(self.diffusion_model.parameters()).device}")
        to_load = self.model_config.process_unet_state_dict(to_load)
        logging.info(f"load model {self.model_config} weights process end")
        # TODO(sf): to mmap
        # diffusion_model is UNetModel
        m, u = self.diffusion_model.load_state_dict(to_load, strict=False)
        free_cpu_memory = get_free_memory(torch.device("cpu"))
        logging.info(f"load model {self.model_config} weights end, free cpu memory size {free_cpu_memory/(1024*1024*1024)} GB")
@@ -384,6 +386,21 @@ class BaseModel(torch.nn.Module):
        #TODO: this formula might be too aggressive since I tweaked the sub-quad and split algorithms to use less memory.
        area = sum(map(lambda input_shape: input_shape[0] * math.prod(input_shape[2:]), input_shapes))
        return (area * 0.15 * self.memory_usage_factor) * (1024 * 1024)

    def to(self, *args, **kwargs):
        """Override to() to add custom device management logic"""
        old_device = self.device if hasattr(self, 'device') else None
        result = super().to(*args, **kwargs)
        # Default new_device so the log line below cannot hit an unbound name
        # when to() is called without a device (e.g. dtype-only calls).
        new_device = None
        if len(args) > 0:
            if isinstance(args[0], (torch.device, str)):
                new_device = torch.device(args[0]) if isinstance(args[0], str) else args[0]
        if 'device' in kwargs:
            new_device = kwargs['device']
        logging.info(f"BaseModel moved from {old_device} to {new_device}")
        return result

    def extra_conds_shapes(self, **kwargs):
        return {}
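As a rough worked example of the area-based memory estimate in the hunk above (the memory_usage_factor value and latent shape are made up for illustration):

import math

input_shapes = [(2, 4, 64, 64)]    # batch 2, 4 latent channels, 64x64 latent
memory_usage_factor = 1.0          # hypothetical value, purely for illustration

area = sum(shape[0] * math.prod(shape[2:]) for shape in input_shapes)    # 2 * 64 * 64 = 8192
estimate_bytes = (area * 0.15 * memory_usage_factor) * (1024 * 1024)
print(f"area={area}, estimated memory {estimate_bytes / (1024 * 1024):.1f} MiB")    # ~1228.8 MiB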

View File

@@ -486,6 +486,7 @@ class ModelPatcher:
        return comfy.utils.get_attr(self.model, name)

    def model_patches_to(self, device):
        # TODO(sf): to mmap
        to = self.model_options["transformer_options"]
        if "patches" in to:
            patches = to["patches"]
@@ -783,6 +784,8 @@
            self.backup.clear()

        if device_to is not None:
            # TODO(sf): to mmap
            # self.model is what module?
            self.model.to(device_to)
            self.model.device = device_to
            self.model.model_loaded_weight_memory = 0
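The TODO(sf) notes in this file point toward memory-mapping weights instead of materializing them in host RAM. A hedged sketch of that direction (assuming PyTorch 2.1+ for the mmap flag; the file names are illustrative only):

import torch

# torch.load with mmap=True keeps tensor storages on disk and pages them in on demand,
# so free CPU memory is not consumed up front when the checkpoint is opened.
state_dict = torch.load("model.ckpt", map_location="cpu", mmap=True, weights_only=True)

# safetensors files are read via mmap by default:
# from safetensors.torch import load_file
# state_dict = load_file("model.safetensors", device="cpu")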
@@ -837,6 +840,8 @@
                bias_key = "{}.bias".format(n)
                if move_weight:
                    cast_weight = self.force_cast_weights
                    # TODO(sf): to mmap
                    # m is what module?
                    m.to(device_to)
                    module_mem += move_weight_functions(m, device_to)
                    if lowvram_possible:
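module_mem in the hunk above accumulates how much weight data each moved module carries; a standalone way to estimate that per-module footprint (a sketch, not ComfyUI's own helper) is:

import torch.nn as nn

def module_size_bytes(module: nn.Module) -> int:
    # Sum the storage of parameters and buffers owned directly by this module.
    tensors = list(module.parameters(recurse=False)) + list(module.buffers(recurse=False))
    return sum(t.numel() * t.element_size() for t in tensors)

print(module_size_bytes(nn.Linear(1024, 1024)))    # weight + bias, ~4 MiB in fp32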

View File

@@ -1321,6 +1321,7 @@ def load_diffusion_model_state_dict(sd, model_options={}):
                logging.warning("{} {}".format(diffusers_keys[k], k))

    offload_device = model_management.unet_offload_device()
    logging.info(f"loader load model to offload device: {offload_device}")
    unet_weight_dtype = list(model_config.supported_inference_dtypes)
    if model_config.scaled_fp8 is not None:
        weight_dtype = None