mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-12-21 20:10:48 +08:00
add debug log of cpu load
This commit is contained in:
parent
e5ff6a1b53
commit
5c3c6c02b2
@ -911,3 +911,15 @@ class UNetModel(nn.Module):
|
||||
return self.id_predictor(h)
|
||||
else:
|
||||
return self.out(h)
|
||||
|
||||
|
||||
def load_state_dict(self, state_dict, strict=True):
|
||||
"""Override load_state_dict() to add logging"""
|
||||
logging.info(f"UNetModel load_state_dict start, strict={strict}, state_dict keys count={len(state_dict)}")
|
||||
|
||||
# Call parent's load_state_dict method
|
||||
result = super().load_state_dict(state_dict, strict=strict)
|
||||
|
||||
logging.info(f"UNetModel load_state_dict end, strict={strict}, state_dict keys count={len(state_dict)}")
|
||||
|
||||
return result
|
||||
|
||||
@ -303,6 +303,8 @@ class BaseModel(torch.nn.Module):
|
||||
logging.info(f"model destination device {next(self.diffusion_model.parameters()).device}")
|
||||
to_load = self.model_config.process_unet_state_dict(to_load)
|
||||
logging.info(f"load model {self.model_config} weights process end")
|
||||
# TODO(sf): to mmap
|
||||
# diffusion_model is UNetModel
|
||||
m, u = self.diffusion_model.load_state_dict(to_load, strict=False)
|
||||
free_cpu_memory = get_free_memory(torch.device("cpu"))
|
||||
logging.info(f"load model {self.model_config} weights end, free cpu memory size {free_cpu_memory/(1024*1024*1024)} GB")
|
||||
@ -384,6 +386,21 @@ class BaseModel(torch.nn.Module):
|
||||
#TODO: this formula might be too aggressive since I tweaked the sub-quad and split algorithms to use less memory.
|
||||
area = sum(map(lambda input_shape: input_shape[0] * math.prod(input_shape[2:]), input_shapes))
|
||||
return (area * 0.15 * self.memory_usage_factor) * (1024 * 1024)
|
||||
|
||||
def to(self, *args, **kwargs):
|
||||
"""Override to() to add custom device management logic"""
|
||||
old_device = self.device if hasattr(self, 'device') else None
|
||||
|
||||
result = super().to(*args, **kwargs)
|
||||
|
||||
if len(args) > 0:
|
||||
if isinstance(args[0], (torch.device, str)):
|
||||
new_device = torch.device(args[0]) if isinstance(args[0], str) else args[0]
|
||||
if 'device' in kwargs:
|
||||
new_device = kwargs['device']
|
||||
|
||||
logging.info(f"BaseModel moved from {old_device} to {new_device}")
|
||||
return result
|
||||
|
||||
def extra_conds_shapes(self, **kwargs):
|
||||
return {}
|
||||
|
||||
@ -486,6 +486,7 @@ class ModelPatcher:
|
||||
return comfy.utils.get_attr(self.model, name)
|
||||
|
||||
def model_patches_to(self, device):
|
||||
# TODO(sf): to mmap
|
||||
to = self.model_options["transformer_options"]
|
||||
if "patches" in to:
|
||||
patches = to["patches"]
|
||||
@ -783,6 +784,8 @@ class ModelPatcher:
|
||||
self.backup.clear()
|
||||
|
||||
if device_to is not None:
|
||||
# TODO(sf): to mmap
|
||||
# self.model is what module?
|
||||
self.model.to(device_to)
|
||||
self.model.device = device_to
|
||||
self.model.model_loaded_weight_memory = 0
|
||||
@ -837,6 +840,8 @@ class ModelPatcher:
|
||||
bias_key = "{}.bias".format(n)
|
||||
if move_weight:
|
||||
cast_weight = self.force_cast_weights
|
||||
# TODO(sf): to mmap
|
||||
# m is what module?
|
||||
m.to(device_to)
|
||||
module_mem += move_weight_functions(m, device_to)
|
||||
if lowvram_possible:
|
||||
|
||||
@ -1321,6 +1321,7 @@ def load_diffusion_model_state_dict(sd, model_options={}):
|
||||
logging.warning("{} {}".format(diffusers_keys[k], k))
|
||||
|
||||
offload_device = model_management.unet_offload_device()
|
||||
logging.info(f"loader load model to offload device: {offload_device}")
|
||||
unet_weight_dtype = list(model_config.supported_inference_dtypes)
|
||||
if model_config.scaled_fp8 is not None:
|
||||
weight_dtype = None
|
||||
|
||||
Loading…
Reference in New Issue
Block a user