From 5c3c6c02b237b3728348f90567b7236cfc45b8b7 Mon Sep 17 00:00:00 2001
From: strint
Date: Fri, 17 Oct 2025 16:33:14 +0800
Subject: [PATCH] add debug logging for CPU model load

---
 .../ldm/modules/diffusionmodules/openaimodel.py | 12 ++++++++++++
 comfy/model_base.py                             | 17 +++++++++++++++++
 comfy/model_patcher.py                          |  5 +++++
 comfy/sd.py                                     |  1 +
 4 files changed, 35 insertions(+)

diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py
index 4c8d53cac..ff6e96a3c 100644
--- a/comfy/ldm/modules/diffusionmodules/openaimodel.py
+++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py
@@ -911,3 +911,15 @@ class UNetModel(nn.Module):
             return self.id_predictor(h)
         else:
             return self.out(h)
+
+
+    def load_state_dict(self, state_dict, strict=True):
+        """Override load_state_dict() to add logging"""
+        logging.info(f"UNetModel load_state_dict start, strict={strict}, state_dict keys count={len(state_dict)}")
+
+        # Call the parent's load_state_dict method
+        result = super().load_state_dict(state_dict, strict=strict)
+
+        logging.info(f"UNetModel load_state_dict end, strict={strict}, state_dict keys count={len(state_dict)}")
+
+        return result
diff --git a/comfy/model_base.py b/comfy/model_base.py
index b0bb0cfb0..7d474a76a 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -303,6 +303,8 @@ class BaseModel(torch.nn.Module):
         logging.info(f"model destination device {next(self.diffusion_model.parameters()).device}")
         to_load = self.model_config.process_unet_state_dict(to_load)
         logging.info(f"load model {self.model_config} weights process end")
+        # TODO(sf): to mmap
+        # diffusion_model is a UNetModel here
         m, u = self.diffusion_model.load_state_dict(to_load, strict=False)
         free_cpu_memory = get_free_memory(torch.device("cpu"))
         logging.info(f"load model {self.model_config} weights end, free cpu memory size {free_cpu_memory/(1024*1024*1024)} GB")
@@ -384,6 +386,21 @@ class BaseModel(torch.nn.Module):
         #TODO: this formula might be too aggressive since I tweaked the sub-quad and split algorithms to use less memory.
         area = sum(map(lambda input_shape: input_shape[0] * math.prod(input_shape[2:]), input_shapes))
         return (area * 0.15 * self.memory_usage_factor) * (1024 * 1024)
+
+    def to(self, *args, **kwargs):
+        """Override to() to log device moves"""
+        old_device = self.device if hasattr(self, 'device') else None
+
+        result = super().to(*args, **kwargs)
+
+        # Default to the old device so the log line is valid even when to() is called without a device
+        new_device = old_device
+        if len(args) > 0 and isinstance(args[0], (torch.device, str)):
+            new_device = torch.device(args[0]) if isinstance(args[0], str) else args[0]
+        if 'device' in kwargs:
+            new_device = kwargs['device']
+        logging.info(f"BaseModel moved from {old_device} to {new_device}")
+        return result
 
     def extra_conds_shapes(self, **kwargs):
         return {}
diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index c0b68fb8c..ea91bd2c5 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -486,6 +486,7 @@ class ModelPatcher:
         return comfy.utils.get_attr(self.model, name)
 
     def model_patches_to(self, device):
+        # TODO(sf): to mmap
         to = self.model_options["transformer_options"]
         if "patches" in to:
             patches = to["patches"]
@@ -783,6 +784,8 @@ class ModelPatcher:
             self.backup.clear()
 
         if device_to is not None:
+            # TODO(sf): to mmap
+            # which module is self.model?
             self.model.to(device_to)
             self.model.device = device_to
             self.model.model_loaded_weight_memory = 0
@@ -837,6 +840,8 @@ class ModelPatcher:
                 bias_key = "{}.bias".format(n)
                 if move_weight:
                     cast_weight = self.force_cast_weights
+                    # TODO(sf): to mmap
+                    # which module is m?
                     m.to(device_to)
                     module_mem += move_weight_functions(m, device_to)
                     if lowvram_possible:
diff --git a/comfy/sd.py b/comfy/sd.py
index 16d54f08b..89a1f30b8 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -1321,6 +1321,7 @@ def load_diffusion_model_state_dict(sd, model_options={}):
             logging.warning("{} {}".format(diffusers_keys[k], k))
 
     offload_device = model_management.unet_offload_device()
+    logging.info(f"loader load model to offload device: {offload_device}")
     unet_weight_dtype = list(model_config.supported_inference_dtypes)
     if model_config.scaled_fp8 is not None:
         weight_dtype = None
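
Note on the recurring "TODO(sf): to mmap" markers: one possible direction is to load
the checkpoint memory-mapped, so weights stay file-backed and the OS pages them in on
demand instead of materializing everything in anonymous CPU RAM. A minimal sketch,
assuming PyTorch >= 2.1 (for torch.load's mmap= and load_state_dict's assign=
keywords); the function name and ckpt_path are placeholders, not part of this patch:

    import logging

    import torch


    def load_unet_state_dict_mmap(diffusion_model: torch.nn.Module, ckpt_path: str):
        """Sketch: load weights from an mmap-backed checkpoint file."""
        # mmap=True maps the file instead of reading it eagerly, so free CPU
        # memory is barely touched at load time.
        state_dict = torch.load(ckpt_path, map_location="cpu", mmap=True, weights_only=True)
        logging.info(f"mmap load: {len(state_dict)} tensors from {ckpt_path}")

        # assign=True keeps the file-backed tensors instead of copying them into
        # the module's pre-allocated parameters, which would defeat the mmap.
        missing, unexpected = diffusion_model.load_state_dict(state_dict, strict=False, assign=True)
        return missing, unexpected

For safetensors checkpoints (the common ComfyUI case), safetensors.torch.load_file
already memory-maps the file on CPU, so the same assign=True trick should apply there.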