From 6583cc0142466473922a59d2e646881693cff011 Mon Sep 17 00:00:00 2001
From: strint
Date: Fri, 17 Oct 2025 18:28:25 +0800
Subject: [PATCH] debug load mem

---
 comfy/ldm/flux/model.py                           | 13 +++++++++++++
 comfy/ldm/modules/diffusionmodules/openaimodel.py |  1 +
 comfy/model_base.py                               |  2 ++
 comfy/sd.py                                       |  4 ++++
 comfy/utils.py                                    |  6 ++++++
 5 files changed, 26 insertions(+)

diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py
index 14f90cea5..263cdae26 100644
--- a/comfy/ldm/flux/model.py
+++ b/comfy/ldm/flux/model.py
@@ -7,6 +7,7 @@ from torch import Tensor, nn
 from einops import rearrange, repeat
 import comfy.ldm.common_dit
 import comfy.patcher_extension
+import logging
 
 from .layers import (
     DoubleStreamBlock,
@@ -278,3 +279,15 @@ class Flux(nn.Module):
         out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
         out = out[:, :img_tokens]
         return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h_orig,:w_orig]
+
+    def load_state_dict(self, state_dict, strict=True):
+        """Override load_state_dict() to add logging."""
+        import pdb; pdb.set_trace()
+        logging.info(f"Flux load_state_dict start, strict={strict}, state_dict keys count={len(state_dict)}")
+
+        # Call the parent's load_state_dict method
+        result = super().load_state_dict(state_dict, strict=strict)
+
+        logging.info(f"Flux load_state_dict end, strict={strict}, state_dict keys count={len(state_dict)}")
+
+        return result
diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py
index ff6e96a3c..e847700c6 100644
--- a/comfy/ldm/modules/diffusionmodules/openaimodel.py
+++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py
@@ -914,6 +914,7 @@ class UNetModel(nn.Module):
 
     def load_state_dict(self, state_dict, strict=True):
         """Override load_state_dict() to add logging"""
+        import pdb; pdb.set_trace()
         logging.info(f"UNetModel load_state_dict start, strict={strict}, state_dict keys count={len(state_dict)}")
diff --git a/comfy/model_base.py b/comfy/model_base.py
index 7d474a76a..34dd16037 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -305,6 +305,8 @@ class BaseModel(torch.nn.Module):
         logging.info(f"load model {self.model_config} weights process end")
         # TODO(sf): to mmap
         # diffusion_model is UNetModel
+        # TODO(sf): avoid copying the mmap-backed weights into CPU memory here
+        import pdb; pdb.set_trace()
         m, u = self.diffusion_model.load_state_dict(to_load, strict=False)
         free_cpu_memory = get_free_memory(torch.device("cpu"))
         logging.info(f"load model {self.model_config} weights end, free cpu memory size {free_cpu_memory/(1024*1024*1024)} GB")
diff --git a/comfy/sd.py b/comfy/sd.py
index 89a1f30b8..7005a1b53 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -1338,6 +1338,7 @@ def load_diffusion_model_state_dict(sd, model_options={}):
         model_config.optimizations["fp8"] = True
 
     model = model_config.get_model(new_sd, "")
+    import pdb; pdb.set_trace()
     model = model.to(offload_device)
     model.load_model_weights(new_sd, "")
     left_over = sd.keys()
@@ -1347,10 +1348,13 @@
 
 
 def load_diffusion_model(unet_path, model_options={}):
+    # TODO(sf): this loads the checkpoint file into memory
    sd = comfy.utils.load_torch_file(unet_path)
     logging.info(f"load model start, path {unet_path}")
+    import pdb; pdb.set_trace()
     model = load_diffusion_model_state_dict(sd, model_options=model_options)
     logging.info(f"load model end, path {unet_path}")
+    import pdb; pdb.set_trace()
     if model is None:
         logging.error("ERROR UNSUPPORTED DIFFUSION MODEL {}".format(unet_path))
         raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(unet_path, model_detection_error_hint(unet_path, sd)))
diff --git a/comfy/utils.py b/comfy/utils.py
index 0fd03f165..a66402451 100644
--- a/comfy/utils.py
+++ b/comfy/utils.py
@@ -55,11 +55,15 @@ else:
     logging.info("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended.")
 
 def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
+    # TODO(sf): load the file via mmap here
+    logging.info(f"load_torch_file start, ckpt={ckpt}, safe_load={safe_load}, device={device}, return_metadata={return_metadata}")
     if device is None:
         device = torch.device("cpu")
     metadata = None
     if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"):
         try:
+            if not DISABLE_MMAP:
+                logging.info("load_torch_file: loading safetensors with mmap")
             with safetensors.safe_open(ckpt, framework="pt", device=device.type) as f:
                 sd = {}
                 for k in f.keys():
@@ -80,6 +84,7 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
         else:
             torch_args = {}
             if MMAP_TORCH_FILES:
+                logging.info("load_torch_file: loading with torch.load mmap=True")
                 torch_args["mmap"] = True
 
         if safe_load or ALWAYS_SAFE_LOAD:
@@ -97,6 +102,7 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
                     sd = pl_sd
             else:
                 sd = pl_sd
+        import pdb; pdb.set_trace()
     return (sd, metadata) if return_metadata else sd
 
 def save_torch_file(sd, ckpt, metadata=None):
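
A note on the "avoid copying the mmap-backed weights into CPU memory" TODO in
comfy/model_base.py: below is a minimal sketch of one way to do this, not part
of the patch itself. It assumes PyTorch >= 2.1 (both the mmap=True argument to
torch.load and the assign=True argument to load_state_dict come from that
version), and load_weights_lazily is a hypothetical helper name:

    import torch
    import torch.nn as nn

    def load_weights_lazily(model: nn.Module, ckpt_path: str) -> nn.Module:
        # File-backed state dict: tensor data stays on disk and pages are
        # faulted into memory only when first accessed.
        sd = torch.load(ckpt_path, map_location="cpu", mmap=True, weights_only=True)
        # assign=True swaps the state_dict tensors in rather than copying them
        # element-wise into preallocated CPU parameters, so the mmap backing
        # is preserved until a tensor is actually written to.
        model.load_state_dict(sd, strict=False, assign=True)
        return model

With the default assign=False, load_state_dict copies every tensor into the
module's existing CPU parameters, which is exactly the allocation this patch's
TODO is trying to avoid.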
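
To confirm where memory actually goes in load_torch_file, process RSS is a
quick check. A minimal sketch, assuming psutil is installed and using a
hypothetical checkpoint path; for an mmap-backed safetensors load, RSS should
barely move until tensor data is actually read:

    import psutil
    import comfy.utils

    def rss_gb() -> float:
        # Resident set size of the current process, in GB.
        return psutil.Process().memory_info().rss / (1024 ** 3)

    base = rss_gb()
    sd = comfy.utils.load_torch_file("/path/to/model.safetensors")  # hypothetical path
    print(f"RSS delta after load: {rss_gb() - base:.2f} GB")

    # Reading a tensor's data faults its pages in, so RSS grows here.
    t = next(iter(sd.values()))
    _ = t.float().sum()
    print(f"RSS delta after touching one tensor: {rss_gb() - base:.2f} GB")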