Mirror of https://github.com/comfyanonymous/ComfyUI.git, synced 2025-12-21 20:10:48 +08:00
add debug log of cpu load
parent e5ff6a1b53
commit 5c3c6c02b2
@@ -911,3 +911,15 @@ class UNetModel(nn.Module):
             return self.id_predictor(h)
         else:
             return self.out(h)
+
+
+    def load_state_dict(self, state_dict, strict=True):
+        """Override load_state_dict() to add logging"""
+        logging.info(f"UNetModel load_state_dict start, strict={strict}, state_dict keys count={len(state_dict)}")
+
+        # Call parent's load_state_dict method
+        result = super().load_state_dict(state_dict, strict=strict)
+
+        logging.info(f"UNetModel load_state_dict end, strict={strict}, state_dict keys count={len(state_dict)}")
+
+        return result
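The override brackets the parent call with start/end log lines, so a slow state-dict copy shows up with timestamps on both sides. A minimal, self-contained sketch of the same pattern on a toy module (TinyNet and its layer size are illustrative, not part of the commit):

import logging

import torch.nn as nn

logging.basicConfig(level=logging.INFO)

class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 4)

    def load_state_dict(self, state_dict, strict=True):
        # Same bracketing pattern as the UNetModel override above.
        logging.info(f"TinyNet load_state_dict start, strict={strict}, state_dict keys count={len(state_dict)}")
        result = super().load_state_dict(state_dict, strict=strict)
        logging.info(f"TinyNet load_state_dict end, missing={len(result.missing_keys)}, unexpected={len(result.unexpected_keys)}")
        return result

net = TinyNet()
net.load_state_dict(TinyNet().state_dict())  # the two log lines bracket the weight copy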
@@ -303,6 +303,8 @@ class BaseModel(torch.nn.Module):
         logging.info(f"model destination device {next(self.diffusion_model.parameters()).device}")
         to_load = self.model_config.process_unet_state_dict(to_load)
         logging.info(f"load model {self.model_config} weights process end")
+        # TODO(sf): to mmap
+        # diffusion_model is UNetModel
         m, u = self.diffusion_model.load_state_dict(to_load, strict=False)
         free_cpu_memory = get_free_memory(torch.device("cpu"))
         logging.info(f"load model {self.model_config} weights end, free cpu memory size {free_cpu_memory/(1024*1024*1024)} GB")
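The free-memory probe is what gives the commit its CPU-load debug output. A hedged sketch of what the CPU branch of get_free_memory can look like; ComfyUI's actual implementation lives in comfy.model_management, and the psutil-based version below is an illustration, not the project's code:

import psutil
import torch

def get_free_memory_cpu(device: torch.device) -> int:
    # Available system RAM in bytes; the CUDA branch is omitted here.
    if device.type == "cpu":
        return psutil.virtual_memory().available
    raise NotImplementedError("sketch covers the CPU path only")

free = get_free_memory_cpu(torch.device("cpu"))
print(f"free cpu memory size {free / (1024 * 1024 * 1024):.2f} GB")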
@@ -384,6 +386,21 @@ class BaseModel(torch.nn.Module):
         #TODO: this formula might be too aggressive since I tweaked the sub-quad and split algorithms to use less memory.
         area = sum(map(lambda input_shape: input_shape[0] * math.prod(input_shape[2:]), input_shapes))
         return (area * 0.15 * self.memory_usage_factor) * (1024 * 1024)
+
+    def to(self, *args, **kwargs):
+        """Override to() to add custom device management logic"""
+        old_device = self.device if hasattr(self, 'device') else None
+
+        result = super().to(*args, **kwargs)
+
+        new_device = old_device  # stays bound for dtype-only to() calls
+        if len(args) > 0 and isinstance(args[0], (torch.device, str)):
+            new_device = torch.device(args[0]) if isinstance(args[0], str) else args[0]
+        if 'device' in kwargs:
+            new_device = kwargs['device']
+
+        logging.info(f"BaseModel moved from {old_device} to {new_device}")
+        return result

     def extra_conds_shapes(self, **kwargs):
         return {}
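Illustrative call patterns for the new override (assumes model is an already-constructed BaseModel; construction is elided because it needs a model_config):

model.to(torch.device("cuda:0"))  # positional form: logs "BaseModel moved from cpu to cuda:0"
model.to(device="cpu")            # keyword form: resolved via kwargs['device']
model.to(dtype=torch.float16)     # dtype-only call: new_device stays equal to old_device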
@@ -486,6 +486,7 @@ class ModelPatcher:
         return comfy.utils.get_attr(self.model, name)

     def model_patches_to(self, device):
+        # TODO(sf): to mmap
         to = self.model_options["transformer_options"]
         if "patches" in to:
             patches = to["patches"]
@@ -783,6 +784,8 @@ class ModelPatcher:
         self.backup.clear()

         if device_to is not None:
+            # TODO(sf): to mmap
+            # self.model is what module?
             self.model.to(device_to)
             self.model.device = device_to
             self.model.model_loaded_weight_memory = 0
@@ -837,6 +840,8 @@ class ModelPatcher:
             bias_key = "{}.bias".format(n)
             if move_weight:
                 cast_weight = self.force_cast_weights
+                # TODO(sf): to mmap
+                # m is what module?
                 m.to(device_to)
                 module_mem += move_weight_functions(m, device_to)
                 if lowvram_possible:
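The recurring "TODO(sf): to mmap" notes point toward memory-mapped weight loading, so tensors stay file-backed until first touched instead of being copied into RAM up front. A sketch using real PyTorch/safetensors APIs; the file paths are hypothetical and whether either call fits these code paths is an assumption:

import torch
from safetensors.torch import load_file

# PyTorch >= 2.1 can memory-map a torch.save()-format checkpoint.
sd = torch.load("model.ckpt", map_location="cpu", mmap=True, weights_only=True)

# safetensors memory-maps the file by default when loading to CPU.
sd = load_file("model.safetensors", device="cpu")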
@@ -1321,6 +1321,7 @@ def load_diffusion_model_state_dict(sd, model_options={}):
                 logging.warning("{} {}".format(diffusers_keys[k], k))

     offload_device = model_management.unet_offload_device()
+    logging.info(f"loader load model to offload device: {offload_device}")
     unet_weight_dtype = list(model_config.supported_inference_dtypes)
     if model_config.scaled_fp8 is not None:
         weight_dtype = None
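All of the new lines are emitted through the standard logging module at INFO level, so they only appear once a handler is configured. A minimal setup for surfacing them when embedding ComfyUI as a library (assumption: nothing else has already configured the root logger):

import logging

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(message)s",
)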