From 6e33ee391abfa8c118688ba4f803a47d7e851132 Mon Sep 17 00:00:00 2001 From: strint Date: Thu, 16 Oct 2025 16:45:08 +0800 Subject: [PATCH 01/35] debug error --- comfy/model_management.py | 34 +++++++++++++++++++++++++++------- 1 file changed, 27 insertions(+), 7 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index d82d5b8b0..e0a097761 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -509,12 +509,29 @@ class LoadedModel: return False def model_unload(self, memory_to_free=None, unpatch_weights=True): - if memory_to_free is not None: - if memory_to_free < self.model.loaded_size(): - freed = self.model.partially_unload(self.model.offload_device, memory_to_free) - if freed >= memory_to_free: - return False - self.model.detach(unpatch_weights) + logging.info(f"model_unload: {self.model.model.__class__.__name__}") + logging.info(f"memory_to_free: {memory_to_free/(1024*1024*1024)} GB") + logging.info(f"unpatch_weights: {unpatch_weights}") + logging.info(f"loaded_size: {self.model.loaded_size()/(1024*1024*1024)} GB") + logging.info(f"offload_device: {self.model.offload_device}") + available_memory = get_free_memory(self.model.offload_device) + logging.info(f"before unload, available_memory of offload device {self.model.offload_device}: {available_memory/(1024*1024*1024)} GB") + try: + if memory_to_free is not None: + if memory_to_free < self.model.loaded_size(): + logging.info("Do partially unload") + freed = self.model.partially_unload(self.model.offload_device, memory_to_free) + logging.info(f"partially_unload freed: {freed/(1024*1024*1024)} GB") + if freed >= memory_to_free: + return False + logging.info("Do full unload") + self.model.detach(unpatch_weights) + logging.info("Do full unload done") + except Exception as e: + logging.error(f"Error in model_unload: {e}") + available_memory = get_free_memory(self.model.offload_device) + logging.info(f"after error, available_memory of offload device {self.model.offload_device}: {available_memory/(1024*1024*1024)} GB") + return False self.model_finalizer.detach() self.model_finalizer = None self.real_model = None @@ -567,6 +584,7 @@ def minimum_inference_memory(): return (1024 * 1024 * 1024) * 0.8 + extra_reserved_memory() def free_memory(memory_required, device, keep_loaded=[]): + logging.info("start to free mem") cleanup_models_gc() unloaded_model = [] can_unload = [] @@ -587,7 +605,7 @@ def free_memory(memory_required, device, keep_loaded=[]): if free_mem > memory_required: break memory_to_free = memory_required - free_mem - logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}") + logging.info(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}") if current_loaded_models[i].model_unload(memory_to_free): unloaded_model.append(i) @@ -604,6 +622,7 @@ def free_memory(memory_required, device, keep_loaded=[]): return unloaded_models def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False): + logging.info(f"start to load models") cleanup_models_gc() global vram_state @@ -625,6 +644,7 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu models_to_load = [] for x in models: + logging.info(f"loading model: {x.model.__class__.__name__}") loaded_model = LoadedModel(x) try: loaded_model_index = current_loaded_models.index(loaded_model) From fa19dd46200e5708f0e17e24622939257bfcffca Mon Sep 17 00:00:00 2001 From: strint Date: Thu, 16 Oct 2025 17:00:47 +0800 Subject: [PATCH 02/35] debug offload --- comfy/model_management.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/comfy/model_management.py b/comfy/model_management.py index e0a097761..840239a27 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -516,6 +516,9 @@ class LoadedModel: logging.info(f"offload_device: {self.model.offload_device}") available_memory = get_free_memory(self.model.offload_device) logging.info(f"before unload, available_memory of offload device {self.model.offload_device}: {available_memory/(1024*1024*1024)} GB") + if available_memory < memory_to_free: + logging.error(f"Not enough cpu memory to unload. Available: {available_memory/(1024*1024*1024)} GB, Required: {memory_to_free/(1024*1024*1024)} GB") + return False try: if memory_to_free is not None: if memory_to_free < self.model.loaded_size(): From f40e00cb357754ae99a2eac59d8fbfae2f23607a Mon Sep 17 00:00:00 2001 From: strint Date: Thu, 16 Oct 2025 19:38:13 +0800 Subject: [PATCH 03/35] add detail debug --- execution.py | 26 ++++++++++++++++++++++++++ server.py | 2 +- 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/execution.py b/execution.py index 1dc35738b..69bd53502 100644 --- a/execution.py +++ b/execution.py @@ -400,7 +400,12 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, inputs = dynprompt.get_node(unique_id)['inputs'] class_type = dynprompt.get_node(unique_id)['class_type'] class_def = nodes.NODE_CLASS_MAPPINGS[class_type] + + # Log node execution start + logging.info(f"๐Ÿ“ Node [{display_node_id}] START: {class_type}") + if caches.outputs.get(unique_id) is not None: + logging.info(f"โœ… Node [{display_node_id}] CACHED: {class_type} (using cached output)") if server.client_id is not None: cached_output = caches.ui.get(unique_id) or {} server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_output.get("output",None), "prompt_id": prompt_id }, server.client_id) @@ -446,15 +451,20 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, has_subgraph = False else: get_progress_state().start_progress(unique_id) + logging.info(f"๐Ÿ” Node [{display_node_id}] Getting input data for {class_type}") input_data_all, missing_keys, hidden_inputs = get_input_data(inputs, class_def, unique_id, caches.outputs, dynprompt, extra_data) + logging.info(f"๐Ÿ“ฅ Node [{display_node_id}] Input data ready, keys: {list(input_data_all.keys())}") if server.client_id is not None: server.last_node_id = display_node_id server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id) obj = caches.objects.get(unique_id) if obj is None: + logging.info(f"๐Ÿ—๏ธ Node [{display_node_id}] Creating new instance of {class_type}") obj = class_def() caches.objects.set(unique_id, obj) + else: + logging.info(f"โ™ป๏ธ Node [{display_node_id}] Reusing cached instance of {class_type}") if issubclass(class_def, _ComfyNodeInternal): lazy_status_present = first_real_override(class_def, "check_lazy_status") is not None @@ -493,7 +503,9 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, def pre_execute_cb(call_index): # TODO - How to handle this with async functions without contextvars (which requires Python 3.12)? GraphBuilder.set_default_prefix(unique_id, call_index, 0) + logging.info(f"โš™๏ธ Node [{display_node_id}] Executing {class_type}") output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, hidden_inputs=hidden_inputs) + logging.info(f"๐Ÿ“ค Node [{display_node_id}] Execution completed, has_subgraph: {has_subgraph}, has_pending: {has_pending_tasks}") if has_pending_tasks: pending_async_nodes[unique_id] = output_data unblock = execution_list.add_external_block(unique_id) @@ -572,6 +584,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, for name, inputs in input_data_all.items(): input_data_formatted[name] = [format_value(x) for x in inputs] + logging.error(f"โŒ Node [{display_node_id}] FAILED: {class_type}") logging.error(f"!!! Exception during processing !!! {ex}") logging.error(traceback.format_exc()) tips = "" @@ -593,6 +606,8 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, get_progress_state().finish_progress(unique_id) executed.add(unique_id) + + logging.info(f"โœ… Node [{display_node_id}] SUCCESS: {class_type} completed") return (ExecutionResult.SUCCESS, None, None) @@ -649,6 +664,7 @@ class PromptExecutor: asyncio.run(self.execute_async(prompt, prompt_id, extra_data, execute_outputs)) async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs=[]): + logging.info(f"๐Ÿš€ Workflow execution START: prompt_id={prompt_id}, nodes_count={len(prompt)}") nodes.interrupt_processing(False) if "client_id" in extra_data: @@ -672,6 +688,9 @@ class PromptExecutor: for node_id in prompt: if self.caches.outputs.get(node_id) is not None: cached_nodes.append(node_id) + + if len(cached_nodes) > 0: + logging.info(f"๐Ÿ’พ Workflow has {len(cached_nodes)} cached nodes: {cached_nodes}") comfy.model_management.cleanup_models_gc() self.add_message("execution_cached", @@ -684,6 +703,8 @@ class PromptExecutor: current_outputs = self.caches.outputs.all_node_ids() for node_id in list(execute_outputs): execution_list.add_node(node_id) + + logging.info(f"๐Ÿ“‹ Workflow execution list prepared, executing {len(execute_outputs)} output nodes") while not execution_list.is_empty(): node_id, error, ex = await execution_list.stage_node_execution() @@ -695,6 +716,7 @@ class PromptExecutor: result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes) self.success = result != ExecutionResult.FAILURE if result == ExecutionResult.FAILURE: + logging.error(f"๐Ÿ’ฅ Workflow execution FAILED at node {node_id}") self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex) break elif result == ExecutionResult.PENDING: @@ -703,6 +725,7 @@ class PromptExecutor: execution_list.complete_node_execution() else: # Only execute when the while-loop ends without break + logging.info(f"๐ŸŽ‰ Workflow execution SUCCESS: prompt_id={prompt_id}, executed_nodes={len(executed)}") self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False) ui_outputs = {} @@ -719,7 +742,10 @@ class PromptExecutor: } self.server.last_node_id = None if comfy.model_management.DISABLE_SMART_MEMORY: + logging.info("๐Ÿงน Unloading all models (DISABLE_SMART_MEMORY is enabled)") comfy.model_management.unload_all_models() + + logging.info(f"โœจ Workflow execution COMPLETED: prompt_id={prompt_id}") async def validate_inputs(prompt_id, prompt, item, validated): diff --git a/server.py b/server.py index 80e9d3fa7..515307bf6 100644 --- a/server.py +++ b/server.py @@ -673,7 +673,7 @@ class PromptServer(): @routes.post("/prompt") async def post_prompt(request): - logging.info("got prompt") + logging.info("got prompt in debug comfyui") json_data = await request.json() json_data = self.trigger_on_prompt(json_data) From 2b222962c3f9168d2333a73ea2dd525880ec215c Mon Sep 17 00:00:00 2001 From: strint Date: Thu, 16 Oct 2025 21:42:02 +0800 Subject: [PATCH 04/35] add debug log --- comfy/model_base.py | 3 +++ comfy/sd.py | 2 ++ nodes.py | 2 ++ 3 files changed, 7 insertions(+) diff --git a/comfy/model_base.py b/comfy/model_base.py index 8274c7dea..7dead0167 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -297,8 +297,11 @@ class BaseModel(torch.nn.Module): if k.startswith(unet_prefix): to_load[k[len(unet_prefix):]] = sd.pop(k) + logging.info(f"load model weights start, keys {keys}") to_load = self.model_config.process_unet_state_dict(to_load) + logging.info(f"load model {self.model_config} weights process end, keys {keys}") m, u = self.diffusion_model.load_state_dict(to_load, strict=False) + logging.info(f"load model {self.model_config} weights end, keys {keys}") if len(m) > 0: logging.warning("unet missing: {}".format(m)) diff --git a/comfy/sd.py b/comfy/sd.py index 28bee248d..16d54f08b 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -1347,7 +1347,9 @@ def load_diffusion_model_state_dict(sd, model_options={}): def load_diffusion_model(unet_path, model_options={}): sd = comfy.utils.load_torch_file(unet_path) + logging.info(f"load model start, path {unet_path}") model = load_diffusion_model_state_dict(sd, model_options=model_options) + logging.info(f"load model end, path {unet_path}") if model is None: logging.error("ERROR UNSUPPORTED DIFFUSION MODEL {}".format(unet_path)) raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(unet_path, model_detection_error_hint(unet_path, sd))) diff --git a/nodes.py b/nodes.py index 7cfa8ca14..25ccc9e42 100644 --- a/nodes.py +++ b/nodes.py @@ -922,7 +922,9 @@ class UNETLoader: model_options["dtype"] = torch.float8_e5m2 unet_path = folder_paths.get_full_path_or_raise("diffusion_models", unet_name) + logging.info(f"load unet node start, path {unet_path}") model = comfy.sd.load_diffusion_model(unet_path, model_options=model_options) + logging.info(f"load unet node end, path {unet_path}") return (model,) class CLIPLoader: From c1eac555c011f05ff4a3393ce0c86964314ccc18 Mon Sep 17 00:00:00 2001 From: strint Date: Thu, 16 Oct 2025 21:42:48 +0800 Subject: [PATCH 05/35] add debug log --- comfy/model_base.py | 1 + 1 file changed, 1 insertion(+) diff --git a/comfy/model_base.py b/comfy/model_base.py index 7dead0167..6c8ee69b4 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -291,6 +291,7 @@ class BaseModel(torch.nn.Module): return out def load_model_weights(self, sd, unet_prefix=""): + import pdb; pdb.set_trace() to_load = {} keys = list(sd.keys()) for k in keys: From 9352987e9bc625dd5b4f1acdbf059ad5c2382172 Mon Sep 17 00:00:00 2001 From: strint Date: Thu, 16 Oct 2025 22:25:17 +0800 Subject: [PATCH 06/35] add log --- comfy/model_base.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/comfy/model_base.py b/comfy/model_base.py index 6c8ee69b4..75d469221 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -60,6 +60,7 @@ import math from typing import TYPE_CHECKING if TYPE_CHECKING: from comfy.model_patcher import ModelPatcher +from comfy.model_management import get_free_memory class ModelType(Enum): EPS = 1 @@ -291,18 +292,19 @@ class BaseModel(torch.nn.Module): return out def load_model_weights(self, sd, unet_prefix=""): - import pdb; pdb.set_trace() to_load = {} keys = list(sd.keys()) for k in keys: if k.startswith(unet_prefix): to_load[k[len(unet_prefix):]] = sd.pop(k) - logging.info(f"load model weights start, keys {keys}") + free_cpu_memory = get_free_memory(torch.device("cpu")) + logging.info(f"load model weights start, free cpu memory size {free_cpu_memory/(1024*1024*1024)} GB") to_load = self.model_config.process_unet_state_dict(to_load) - logging.info(f"load model {self.model_config} weights process end, keys {keys}") + logging.info(f"load model {self.model_config} weights process end") m, u = self.diffusion_model.load_state_dict(to_load, strict=False) - logging.info(f"load model {self.model_config} weights end, keys {keys}") + free_cpu_memory = get_free_memory(torch.device("cpu")) + logging.info(f"load model {self.model_config} weights end, free cpu memory size {free_cpu_memory/(1024*1024*1024)} GB") if len(m) > 0: logging.warning("unet missing: {}".format(m)) From a207301c25e7fd83723152fc343a5ac49f983f4d Mon Sep 17 00:00:00 2001 From: strint Date: Thu, 16 Oct 2025 22:28:06 +0800 Subject: [PATCH 07/35] rm useless log --- execution.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/execution.py b/execution.py index 69bd53502..c3a4cc5fa 100644 --- a/execution.py +++ b/execution.py @@ -401,11 +401,8 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, class_type = dynprompt.get_node(unique_id)['class_type'] class_def = nodes.NODE_CLASS_MAPPINGS[class_type] - # Log node execution start - logging.info(f"๐Ÿ“ Node [{display_node_id}] START: {class_type}") if caches.outputs.get(unique_id) is not None: - logging.info(f"โœ… Node [{display_node_id}] CACHED: {class_type} (using cached output)") if server.client_id is not None: cached_output = caches.ui.get(unique_id) or {} server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_output.get("output",None), "prompt_id": prompt_id }, server.client_id) @@ -451,20 +448,15 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, has_subgraph = False else: get_progress_state().start_progress(unique_id) - logging.info(f"๐Ÿ” Node [{display_node_id}] Getting input data for {class_type}") input_data_all, missing_keys, hidden_inputs = get_input_data(inputs, class_def, unique_id, caches.outputs, dynprompt, extra_data) - logging.info(f"๐Ÿ“ฅ Node [{display_node_id}] Input data ready, keys: {list(input_data_all.keys())}") if server.client_id is not None: server.last_node_id = display_node_id server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id) obj = caches.objects.get(unique_id) if obj is None: - logging.info(f"๐Ÿ—๏ธ Node [{display_node_id}] Creating new instance of {class_type}") obj = class_def() caches.objects.set(unique_id, obj) - else: - logging.info(f"โ™ป๏ธ Node [{display_node_id}] Reusing cached instance of {class_type}") if issubclass(class_def, _ComfyNodeInternal): lazy_status_present = first_real_override(class_def, "check_lazy_status") is not None @@ -503,9 +495,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, def pre_execute_cb(call_index): # TODO - How to handle this with async functions without contextvars (which requires Python 3.12)? GraphBuilder.set_default_prefix(unique_id, call_index, 0) - logging.info(f"โš™๏ธ Node [{display_node_id}] Executing {class_type}") output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, hidden_inputs=hidden_inputs) - logging.info(f"๐Ÿ“ค Node [{display_node_id}] Execution completed, has_subgraph: {has_subgraph}, has_pending: {has_pending_tasks}") if has_pending_tasks: pending_async_nodes[unique_id] = output_data unblock = execution_list.add_external_block(unique_id) @@ -584,7 +574,6 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, for name, inputs in input_data_all.items(): input_data_formatted[name] = [format_value(x) for x in inputs] - logging.error(f"โŒ Node [{display_node_id}] FAILED: {class_type}") logging.error(f"!!! Exception during processing !!! {ex}") logging.error(traceback.format_exc()) tips = "" @@ -607,7 +596,6 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, get_progress_state().finish_progress(unique_id) executed.add(unique_id) - logging.info(f"โœ… Node [{display_node_id}] SUCCESS: {class_type} completed") return (ExecutionResult.SUCCESS, None, None) From 71b23d12e45e39fb2e94da510b823831e9a7b151 Mon Sep 17 00:00:00 2001 From: strint Date: Thu, 16 Oct 2025 22:34:55 +0800 Subject: [PATCH 08/35] rm useless log --- execution.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/execution.py b/execution.py index c3a4cc5fa..53f295357 100644 --- a/execution.py +++ b/execution.py @@ -652,7 +652,6 @@ class PromptExecutor: asyncio.run(self.execute_async(prompt, prompt_id, extra_data, execute_outputs)) async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs=[]): - logging.info(f"๐Ÿš€ Workflow execution START: prompt_id={prompt_id}, nodes_count={len(prompt)}") nodes.interrupt_processing(False) if "client_id" in extra_data: @@ -676,9 +675,6 @@ class PromptExecutor: for node_id in prompt: if self.caches.outputs.get(node_id) is not None: cached_nodes.append(node_id) - - if len(cached_nodes) > 0: - logging.info(f"๐Ÿ’พ Workflow has {len(cached_nodes)} cached nodes: {cached_nodes}") comfy.model_management.cleanup_models_gc() self.add_message("execution_cached", @@ -691,8 +687,6 @@ class PromptExecutor: current_outputs = self.caches.outputs.all_node_ids() for node_id in list(execute_outputs): execution_list.add_node(node_id) - - logging.info(f"๐Ÿ“‹ Workflow execution list prepared, executing {len(execute_outputs)} output nodes") while not execution_list.is_empty(): node_id, error, ex = await execution_list.stage_node_execution() @@ -704,7 +698,6 @@ class PromptExecutor: result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes) self.success = result != ExecutionResult.FAILURE if result == ExecutionResult.FAILURE: - logging.error(f"๐Ÿ’ฅ Workflow execution FAILED at node {node_id}") self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex) break elif result == ExecutionResult.PENDING: @@ -713,7 +706,6 @@ class PromptExecutor: execution_list.complete_node_execution() else: # Only execute when the while-loop ends without break - logging.info(f"๐ŸŽ‰ Workflow execution SUCCESS: prompt_id={prompt_id}, executed_nodes={len(executed)}") self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False) ui_outputs = {} @@ -730,10 +722,7 @@ class PromptExecutor: } self.server.last_node_id = None if comfy.model_management.DISABLE_SMART_MEMORY: - logging.info("๐Ÿงน Unloading all models (DISABLE_SMART_MEMORY is enabled)") comfy.model_management.unload_all_models() - - logging.info(f"โœจ Workflow execution COMPLETED: prompt_id={prompt_id}") async def validate_inputs(prompt_id, prompt, item, validated): From e5ff6a1b53211ce3130cc0de071ce137714e03a4 Mon Sep 17 00:00:00 2001 From: strint Date: Thu, 16 Oct 2025 22:47:03 +0800 Subject: [PATCH 09/35] refine log --- comfy/model_base.py | 1 + 1 file changed, 1 insertion(+) diff --git a/comfy/model_base.py b/comfy/model_base.py index 75d469221..b0bb0cfb0 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -300,6 +300,7 @@ class BaseModel(torch.nn.Module): free_cpu_memory = get_free_memory(torch.device("cpu")) logging.info(f"load model weights start, free cpu memory size {free_cpu_memory/(1024*1024*1024)} GB") + logging.info(f"model destination device {next(self.diffusion_model.parameters()).device}") to_load = self.model_config.process_unet_state_dict(to_load) logging.info(f"load model {self.model_config} weights process end") m, u = self.diffusion_model.load_state_dict(to_load, strict=False) From 5c3c6c02b237b3728348f90567b7236cfc45b8b7 Mon Sep 17 00:00:00 2001 From: strint Date: Fri, 17 Oct 2025 16:33:14 +0800 Subject: [PATCH 10/35] add debug log of cpu load --- .../ldm/modules/diffusionmodules/openaimodel.py | 12 ++++++++++++ comfy/model_base.py | 17 +++++++++++++++++ comfy/model_patcher.py | 5 +++++ comfy/sd.py | 1 + 4 files changed, 35 insertions(+) diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py index 4c8d53cac..ff6e96a3c 100644 --- a/comfy/ldm/modules/diffusionmodules/openaimodel.py +++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py @@ -911,3 +911,15 @@ class UNetModel(nn.Module): return self.id_predictor(h) else: return self.out(h) + + + def load_state_dict(self, state_dict, strict=True): + """Override load_state_dict() to add logging""" + logging.info(f"UNetModel load_state_dict start, strict={strict}, state_dict keys count={len(state_dict)}") + + # Call parent's load_state_dict method + result = super().load_state_dict(state_dict, strict=strict) + + logging.info(f"UNetModel load_state_dict end, strict={strict}, state_dict keys count={len(state_dict)}") + + return result diff --git a/comfy/model_base.py b/comfy/model_base.py index b0bb0cfb0..7d474a76a 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -303,6 +303,8 @@ class BaseModel(torch.nn.Module): logging.info(f"model destination device {next(self.diffusion_model.parameters()).device}") to_load = self.model_config.process_unet_state_dict(to_load) logging.info(f"load model {self.model_config} weights process end") + # TODO(sf): to mmap + # diffusion_model is UNetModel m, u = self.diffusion_model.load_state_dict(to_load, strict=False) free_cpu_memory = get_free_memory(torch.device("cpu")) logging.info(f"load model {self.model_config} weights end, free cpu memory size {free_cpu_memory/(1024*1024*1024)} GB") @@ -384,6 +386,21 @@ class BaseModel(torch.nn.Module): #TODO: this formula might be too aggressive since I tweaked the sub-quad and split algorithms to use less memory. area = sum(map(lambda input_shape: input_shape[0] * math.prod(input_shape[2:]), input_shapes)) return (area * 0.15 * self.memory_usage_factor) * (1024 * 1024) + + def to(self, *args, **kwargs): + """Override to() to add custom device management logic""" + old_device = self.device if hasattr(self, 'device') else None + + result = super().to(*args, **kwargs) + + if len(args) > 0: + if isinstance(args[0], (torch.device, str)): + new_device = torch.device(args[0]) if isinstance(args[0], str) else args[0] + if 'device' in kwargs: + new_device = kwargs['device'] + + logging.info(f"BaseModel moved from {old_device} to {new_device}") + return result def extra_conds_shapes(self, **kwargs): return {} diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index c0b68fb8c..ea91bd2c5 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -486,6 +486,7 @@ class ModelPatcher: return comfy.utils.get_attr(self.model, name) def model_patches_to(self, device): + # TODO(sf): to mmap to = self.model_options["transformer_options"] if "patches" in to: patches = to["patches"] @@ -783,6 +784,8 @@ class ModelPatcher: self.backup.clear() if device_to is not None: + # TODO(sf): to mmap + # self.model is what module? self.model.to(device_to) self.model.device = device_to self.model.model_loaded_weight_memory = 0 @@ -837,6 +840,8 @@ class ModelPatcher: bias_key = "{}.bias".format(n) if move_weight: cast_weight = self.force_cast_weights + # TODO(sf): to mmap + # m is what module? m.to(device_to) module_mem += move_weight_functions(m, device_to) if lowvram_possible: diff --git a/comfy/sd.py b/comfy/sd.py index 16d54f08b..89a1f30b8 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -1321,6 +1321,7 @@ def load_diffusion_model_state_dict(sd, model_options={}): logging.warning("{} {}".format(diffusers_keys[k], k)) offload_device = model_management.unet_offload_device() + logging.info(f"loader load model to offload device: {offload_device}") unet_weight_dtype = list(model_config.supported_inference_dtypes) if model_config.scaled_fp8 is not None: weight_dtype = None From 6583cc0142466473922a59d2e646881693cff011 Mon Sep 17 00:00:00 2001 From: strint Date: Fri, 17 Oct 2025 18:28:25 +0800 Subject: [PATCH 11/35] debug load mem --- comfy/ldm/flux/model.py | 13 +++++++++++++ comfy/ldm/modules/diffusionmodules/openaimodel.py | 1 + comfy/model_base.py | 2 ++ comfy/sd.py | 4 ++++ comfy/utils.py | 6 ++++++ 5 files changed, 26 insertions(+) diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py index 14f90cea5..263cdae26 100644 --- a/comfy/ldm/flux/model.py +++ b/comfy/ldm/flux/model.py @@ -7,6 +7,7 @@ from torch import Tensor, nn from einops import rearrange, repeat import comfy.ldm.common_dit import comfy.patcher_extension +import logging from .layers import ( DoubleStreamBlock, @@ -278,3 +279,15 @@ class Flux(nn.Module): out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None)) out = out[:, :img_tokens] return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h_orig,:w_orig] + + def load_state_dict(self, state_dict, strict=True): + import pdb; pdb.set_trace() + """Override load_state_dict() to add logging""" + logging.info(f"Flux load_state_dict start, strict={strict}, state_dict keys count={len(state_dict)}") + + # Call parent's load_state_dict method + result = super().load_state_dict(state_dict, strict=strict) + + logging.info(f"Flux load_state_dict end, strict={strict}, state_dict keys count={len(state_dict)}") + + return result diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py index ff6e96a3c..e847700c6 100644 --- a/comfy/ldm/modules/diffusionmodules/openaimodel.py +++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py @@ -914,6 +914,7 @@ class UNetModel(nn.Module): def load_state_dict(self, state_dict, strict=True): + import pdb; pdb.set_trace() """Override load_state_dict() to add logging""" logging.info(f"UNetModel load_state_dict start, strict={strict}, state_dict keys count={len(state_dict)}") diff --git a/comfy/model_base.py b/comfy/model_base.py index 7d474a76a..34dd16037 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -305,6 +305,8 @@ class BaseModel(torch.nn.Module): logging.info(f"load model {self.model_config} weights process end") # TODO(sf): to mmap # diffusion_model is UNetModel + import pdb; pdb.set_trace() + # TODO(sf): here needs to avoid load mmap into cpu mem m, u = self.diffusion_model.load_state_dict(to_load, strict=False) free_cpu_memory = get_free_memory(torch.device("cpu")) logging.info(f"load model {self.model_config} weights end, free cpu memory size {free_cpu_memory/(1024*1024*1024)} GB") diff --git a/comfy/sd.py b/comfy/sd.py index 89a1f30b8..7005a1b53 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -1338,6 +1338,7 @@ def load_diffusion_model_state_dict(sd, model_options={}): model_config.optimizations["fp8"] = True model = model_config.get_model(new_sd, "") + import pdb; pdb.set_trace() model = model.to(offload_device) model.load_model_weights(new_sd, "") left_over = sd.keys() @@ -1347,10 +1348,13 @@ def load_diffusion_model_state_dict(sd, model_options={}): def load_diffusion_model(unet_path, model_options={}): + # TODO(sf): here load file into mem sd = comfy.utils.load_torch_file(unet_path) logging.info(f"load model start, path {unet_path}") + import pdb; pdb.set_trace() model = load_diffusion_model_state_dict(sd, model_options=model_options) logging.info(f"load model end, path {unet_path}") + import pdb; pdb.set_trace() if model is None: logging.error("ERROR UNSUPPORTED DIFFUSION MODEL {}".format(unet_path)) raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(unet_path, model_detection_error_hint(unet_path, sd))) diff --git a/comfy/utils.py b/comfy/utils.py index 0fd03f165..a66402451 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -55,11 +55,15 @@ else: logging.info("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended.") def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False): + # TODO(sf): here load file into mmap + logging.info(f"load_torch_file start, ckpt={ckpt}, safe_load={safe_load}, device={device}, return_metadata={return_metadata}") if device is None: device = torch.device("cpu") metadata = None if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"): try: + if not DISABLE_MMAP: + logging.info(f"load_torch_file safetensors mmap True") with safetensors.safe_open(ckpt, framework="pt", device=device.type) as f: sd = {} for k in f.keys(): @@ -80,6 +84,7 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False): else: torch_args = {} if MMAP_TORCH_FILES: + logging.info(f"load_torch_file mmap True") torch_args["mmap"] = True if safe_load or ALWAYS_SAFE_LOAD: @@ -97,6 +102,7 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False): sd = pl_sd else: sd = pl_sd + import pdb; pdb.set_trace() return (sd, metadata) if return_metadata else sd def save_torch_file(sd, ckpt, metadata=None): From 49597bfa3e36c78635db4234106611846fbc4117 Mon Sep 17 00:00:00 2001 From: strint Date: Fri, 17 Oct 2025 21:43:49 +0800 Subject: [PATCH 12/35] load remains mmap --- comfy/ldm/flux/model.py | 6 +++--- comfy/ldm/modules/diffusionmodules/openaimodel.py | 6 +++--- comfy/model_base.py | 4 ++-- comfy/sd.py | 6 +++--- comfy/utils.py | 2 +- 5 files changed, 12 insertions(+), 12 deletions(-) diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py index 263cdae26..da46ed2ed 100644 --- a/comfy/ldm/flux/model.py +++ b/comfy/ldm/flux/model.py @@ -280,13 +280,13 @@ class Flux(nn.Module): out = out[:, :img_tokens] return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h_orig,:w_orig] - def load_state_dict(self, state_dict, strict=True): - import pdb; pdb.set_trace() + def load_state_dict(self, state_dict, strict=True, assign=False): + # import pdb; pdb.set_trace() """Override load_state_dict() to add logging""" logging.info(f"Flux load_state_dict start, strict={strict}, state_dict keys count={len(state_dict)}") # Call parent's load_state_dict method - result = super().load_state_dict(state_dict, strict=strict) + result = super().load_state_dict(state_dict, strict=strict, assign=assign) logging.info(f"Flux load_state_dict end, strict={strict}, state_dict keys count={len(state_dict)}") diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py index e847700c6..2cdf711d4 100644 --- a/comfy/ldm/modules/diffusionmodules/openaimodel.py +++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py @@ -913,13 +913,13 @@ class UNetModel(nn.Module): return self.out(h) - def load_state_dict(self, state_dict, strict=True): - import pdb; pdb.set_trace() + def load_state_dict(self, state_dict, strict=True, assign=False): + # import pdb; pdb.set_trace() """Override load_state_dict() to add logging""" logging.info(f"UNetModel load_state_dict start, strict={strict}, state_dict keys count={len(state_dict)}") # Call parent's load_state_dict method - result = super().load_state_dict(state_dict, strict=strict) + result = super().load_state_dict(state_dict, strict=strict, assign=assign) logging.info(f"UNetModel load_state_dict end, strict={strict}, state_dict keys count={len(state_dict)}") diff --git a/comfy/model_base.py b/comfy/model_base.py index 34dd16037..409e7fb87 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -305,9 +305,9 @@ class BaseModel(torch.nn.Module): logging.info(f"load model {self.model_config} weights process end") # TODO(sf): to mmap # diffusion_model is UNetModel - import pdb; pdb.set_trace() + # import pdb; pdb.set_trace() # TODO(sf): here needs to avoid load mmap into cpu mem - m, u = self.diffusion_model.load_state_dict(to_load, strict=False) + m, u = self.diffusion_model.load_state_dict(to_load, strict=False, assign=True) free_cpu_memory = get_free_memory(torch.device("cpu")) logging.info(f"load model {self.model_config} weights end, free cpu memory size {free_cpu_memory/(1024*1024*1024)} GB") if len(m) > 0: diff --git a/comfy/sd.py b/comfy/sd.py index 7005a1b53..a956884fb 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -1338,7 +1338,7 @@ def load_diffusion_model_state_dict(sd, model_options={}): model_config.optimizations["fp8"] = True model = model_config.get_model(new_sd, "") - import pdb; pdb.set_trace() + # import pdb; pdb.set_trace() model = model.to(offload_device) model.load_model_weights(new_sd, "") left_over = sd.keys() @@ -1351,10 +1351,10 @@ def load_diffusion_model(unet_path, model_options={}): # TODO(sf): here load file into mem sd = comfy.utils.load_torch_file(unet_path) logging.info(f"load model start, path {unet_path}") - import pdb; pdb.set_trace() + # import pdb; pdb.set_trace() model = load_diffusion_model_state_dict(sd, model_options=model_options) logging.info(f"load model end, path {unet_path}") - import pdb; pdb.set_trace() + # import pdb; pdb.set_trace() if model is None: logging.error("ERROR UNSUPPORTED DIFFUSION MODEL {}".format(unet_path)) raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(unet_path, model_detection_error_hint(unet_path, sd))) diff --git a/comfy/utils.py b/comfy/utils.py index a66402451..4c22f684c 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -102,7 +102,7 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False): sd = pl_sd else: sd = pl_sd - import pdb; pdb.set_trace() + # import pdb; pdb.set_trace() return (sd, metadata) if return_metadata else sd def save_torch_file(sd, ckpt, metadata=None): From 21ebcada1da1466d2a3fe91c9e517156ed5172cf Mon Sep 17 00:00:00 2001 From: strint Date: Mon, 20 Oct 2025 16:22:50 +0800 Subject: [PATCH 13/35] debug free mem --- comfy/model_management.py | 1 + 1 file changed, 1 insertion(+) diff --git a/comfy/model_management.py b/comfy/model_management.py index 840239a27..79f043419 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -588,6 +588,7 @@ def minimum_inference_memory(): def free_memory(memory_required, device, keep_loaded=[]): logging.info("start to free mem") + import pdb; pdb.set_trace() cleanup_models_gc() unloaded_model = [] can_unload = [] From 4ac827d56454838e051fb05b0047fea06359bcc7 Mon Sep 17 00:00:00 2001 From: strint Date: Mon, 20 Oct 2025 18:27:38 +0800 Subject: [PATCH 14/35] unload partial --- comfy/model_management.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 79f043419..30a509670 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -516,9 +516,17 @@ class LoadedModel: logging.info(f"offload_device: {self.model.offload_device}") available_memory = get_free_memory(self.model.offload_device) logging.info(f"before unload, available_memory of offload device {self.model.offload_device}: {available_memory/(1024*1024*1024)} GB") - if available_memory < memory_to_free: - logging.error(f"Not enough cpu memory to unload. Available: {available_memory/(1024*1024*1024)} GB, Required: {memory_to_free/(1024*1024*1024)} GB") + reserved_memory = 1024*1024*1024 # 1GB reserved memory for other usage + if available_memory < reserved_memory: + logging.error(f"Not enough cpu memory to unload. Available: {available_memory/(1024*1024*1024)} GB, Reserved: {reserved_memory/(1024*1024*1024)} GB") return False + else: + offload_memory = available_memory - reserved_memory + + if offload_memory < memory_to_free: + memory_to_free = offload_memory + logging.info(f"Not enough cpu memory to unload. Available: {available_memory/(1024*1024*1024)} GB, Reserved: {reserved_memory/(1024*1024*1024)} GB, Offload: {offload_memory/(1024*1024*1024)} GB") + logging.info(f"Set memory_to_free to {memory_to_free/(1024*1024*1024)} GB") try: if memory_to_free is not None: if memory_to_free < self.model.loaded_size(): From e9e1d2f0e82af07b701a72e20c171625cdc1f402 Mon Sep 17 00:00:00 2001 From: strint Date: Tue, 21 Oct 2025 00:40:14 +0800 Subject: [PATCH 15/35] add mmap tensor --- comfy/model_patcher.py | 38 +++++++++++++++++++++++++++++++++++--- 1 file changed, 35 insertions(+), 3 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index ea91bd2c5..e4d8507d0 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -27,6 +27,7 @@ import uuid from typing import Callable, Optional import torch +import tensordict import comfy.float import comfy.hooks @@ -37,6 +38,9 @@ import comfy.utils from comfy.comfy_types import UnetWrapperFunction from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP +def to_mmap(t: torch.Tensor) -> tensordict.MemoryMappedTensor: + return tensordict.MemoryMappedTensor.from_tensor(t) + def string_to_seed(data): crc = 0xFFFFFFFF @@ -784,9 +788,37 @@ class ModelPatcher: self.backup.clear() if device_to is not None: - # TODO(sf): to mmap - # self.model is what module? - self.model.to(device_to) + # Temporarily register to_mmap method to the model + # Reference: https://github.com/pytorch/pytorch/blob/0fabc3ba44823f257e70ce397d989c8de5e362c1/torch/nn/modules/module.py#L1244 + def _to_mmap_method(self): + """Convert all parameters and buffers to memory-mapped tensors + + This method mimics PyTorch's Module.to() behavior but converts + tensors to memory-mapped format instead. + """ + import pdb; pdb.set_trace() + logging.info(f"model {self.model.__class__.__name__} is calling to_mmap method") + def convert_fn(t): + if isinstance(t, torch.Tensor) and not isinstance(t, torch.nn.Parameter): + return to_mmap(t) + elif isinstance(t, torch.nn.Parameter): + # For parameters, convert the data and wrap back in Parameter + param_mmap = to_mmap(t.data) + return torch.nn.Parameter(param_mmap, requires_grad=t.requires_grad) + return t + + return self._apply(convert_fn) + + # Bind the method to the model instance + import types + self.model.to_mmap = types.MethodType(_to_mmap_method, self.model) + + # Call the to_mmap method + self.model.to_mmap() + + # Optionally clean up the temporary method + # delattr(self.model, 'to_mmap') + self.model.device = device_to self.model.model_loaded_weight_memory = 0 From 49561788cfccecd872808515a3975df772155a75 Mon Sep 17 00:00:00 2001 From: strint Date: Tue, 21 Oct 2025 02:03:38 +0800 Subject: [PATCH 16/35] fix log --- comfy/model_patcher.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index e4d8507d0..10ac1e7de 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -797,7 +797,7 @@ class ModelPatcher: tensors to memory-mapped format instead. """ import pdb; pdb.set_trace() - logging.info(f"model {self.model.__class__.__name__} is calling to_mmap method") + logging.info(f"model {self.__class__.__name__} is calling to_mmap method") def convert_fn(t): if isinstance(t, torch.Tensor) and not isinstance(t, torch.nn.Parameter): return to_mmap(t) From 8aeebbf7ef0e6b54e41473661fb0ea216d380e29 Mon Sep 17 00:00:00 2001 From: strint Date: Tue, 21 Oct 2025 02:27:40 +0800 Subject: [PATCH 17/35] fix to --- comfy/model_patcher.py | 25 ++++++++++++++++++------- 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 10ac1e7de..4c7cd5e3e 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -794,17 +794,28 @@ class ModelPatcher: """Convert all parameters and buffers to memory-mapped tensors This method mimics PyTorch's Module.to() behavior but converts - tensors to memory-mapped format instead. + tensors to memory-mapped format instead, using _apply() method. + + Note: For Parameters, we modify .data in-place because + MemoryMappedTensor cannot be wrapped in torch.nn.Parameter. + For buffers, _apply() will automatically update the reference. """ - import pdb; pdb.set_trace() logging.info(f"model {self.__class__.__name__} is calling to_mmap method") + def convert_fn(t): - if isinstance(t, torch.Tensor) and not isinstance(t, torch.nn.Parameter): + """Convert function for _apply() + + - For Parameters: modify .data and return the Parameter object + - For buffers (plain Tensors): return new MemoryMappedTensor + """ + if isinstance(t, torch.nn.Parameter): + # For parameters, modify data in-place and return the parameter + if isinstance(t.data, torch.Tensor): + t.data = to_mmap(t.data) + return t + elif isinstance(t, torch.Tensor): + # For buffers (plain tensors), return the converted tensor return to_mmap(t) - elif isinstance(t, torch.nn.Parameter): - # For parameters, convert the data and wrap back in Parameter - param_mmap = to_mmap(t.data) - return torch.nn.Parameter(param_mmap, requires_grad=t.requires_grad) return t return self._apply(convert_fn) From 05c2518c6dea831cc15031dc8833afddbbb5a33e Mon Sep 17 00:00:00 2001 From: strint Date: Tue, 21 Oct 2025 02:59:51 +0800 Subject: [PATCH 18/35] refact mmap --- comfy/model_patcher.py | 91 ++++++++++++++++++++++-------------------- 1 file changed, 47 insertions(+), 44 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 4c7cd5e3e..d2e3a296a 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -37,9 +37,52 @@ import comfy.patcher_extension import comfy.utils from comfy.comfy_types import UnetWrapperFunction from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP +from comfy.model_management import get_free_memory def to_mmap(t: torch.Tensor) -> tensordict.MemoryMappedTensor: return tensordict.MemoryMappedTensor.from_tensor(t) + +def model_to_mmap(model: torch.nn.Module): + """Convert all parameters and buffers to memory-mapped tensors + + This function mimics PyTorch's Module.to() behavior but converts + tensors to memory-mapped format instead, using _apply() method. + + Reference: https://github.com/pytorch/pytorch/blob/0fabc3ba44823f257e70ce397d989c8de5e362c1/torch/nn/modules/module.py#L1244 + + Note: For Parameters, we modify .data in-place because + MemoryMappedTensor cannot be wrapped in torch.nn.Parameter. + For buffers, _apply() will automatically update the reference. + + Args: + model: PyTorch module to convert + + Returns: + The same model with all tensors converted to memory-mapped format + """ + free_cpu_mem = get_free_memory(torch.device("cpu")) + logging.info(f"Converting model {model.__class__.__name__} to mmap, cpu memory: {free_cpu_mem/(1024*1024*1024)} GB") + + def convert_fn(t): + """Convert function for _apply() + + - For Parameters: modify .data and return the Parameter object + - For buffers (plain Tensors): return new MemoryMappedTensor + """ + if isinstance(t, torch.nn.Parameter): + # For parameters, modify data in-place and return the parameter + if isinstance(t.data, torch.Tensor): + t.data = to_mmap(t.data) + return t + elif isinstance(t, torch.Tensor): + # For buffers (plain tensors), return the converted tensor + return to_mmap(t) + return t + + new_model = model._apply(convert_fn) + free_cpu_mem = get_free_memory(torch.device("cpu")) + logging.info(f"Model {model.__class__.__name__} converted to mmap, cpu memory: {free_cpu_mem/(1024*1024*1024)} GB") + return new_model def string_to_seed(data): @@ -787,50 +830,9 @@ class ModelPatcher: self.model.current_weight_patches_uuid = None self.backup.clear() - if device_to is not None: - # Temporarily register to_mmap method to the model - # Reference: https://github.com/pytorch/pytorch/blob/0fabc3ba44823f257e70ce397d989c8de5e362c1/torch/nn/modules/module.py#L1244 - def _to_mmap_method(self): - """Convert all parameters and buffers to memory-mapped tensors - - This method mimics PyTorch's Module.to() behavior but converts - tensors to memory-mapped format instead, using _apply() method. - - Note: For Parameters, we modify .data in-place because - MemoryMappedTensor cannot be wrapped in torch.nn.Parameter. - For buffers, _apply() will automatically update the reference. - """ - logging.info(f"model {self.__class__.__name__} is calling to_mmap method") - - def convert_fn(t): - """Convert function for _apply() - - - For Parameters: modify .data and return the Parameter object - - For buffers (plain Tensors): return new MemoryMappedTensor - """ - if isinstance(t, torch.nn.Parameter): - # For parameters, modify data in-place and return the parameter - if isinstance(t.data, torch.Tensor): - t.data = to_mmap(t.data) - return t - elif isinstance(t, torch.Tensor): - # For buffers (plain tensors), return the converted tensor - return to_mmap(t) - return t - - return self._apply(convert_fn) - # Bind the method to the model instance - import types - self.model.to_mmap = types.MethodType(_to_mmap_method, self.model) - - # Call the to_mmap method - self.model.to_mmap() - - # Optionally clean up the temporary method - # delattr(self.model, 'to_mmap') - - self.model.device = device_to + model_to_mmap(self.model) + self.model.device = device_to self.model.model_loaded_weight_memory = 0 for m in self.model.modules(): @@ -885,7 +887,8 @@ class ModelPatcher: cast_weight = self.force_cast_weights # TODO(sf): to mmap # m is what module? - m.to(device_to) + # m.to(device_to) + model_to_mmap(m) module_mem += move_weight_functions(m, device_to) if lowvram_possible: if weight_key in self.patches: From 2f0d56656eea7da5a9dda2e5b0061b31bc5aefbd Mon Sep 17 00:00:00 2001 From: strint Date: Tue, 21 Oct 2025 11:38:17 +0800 Subject: [PATCH 19/35] refine code --- comfy/ldm/flux/model.py | 12 ----------- .../modules/diffusionmodules/openaimodel.py | 14 +------------ comfy/model_base.py | 20 +------------------ comfy/model_management.py | 1 - comfy/model_patcher.py | 10 ++++++---- 5 files changed, 8 insertions(+), 49 deletions(-) diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py index da46ed2ed..a07c3ca95 100644 --- a/comfy/ldm/flux/model.py +++ b/comfy/ldm/flux/model.py @@ -279,15 +279,3 @@ class Flux(nn.Module): out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None)) out = out[:, :img_tokens] return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h_orig,:w_orig] - - def load_state_dict(self, state_dict, strict=True, assign=False): - # import pdb; pdb.set_trace() - """Override load_state_dict() to add logging""" - logging.info(f"Flux load_state_dict start, strict={strict}, state_dict keys count={len(state_dict)}") - - # Call parent's load_state_dict method - result = super().load_state_dict(state_dict, strict=strict, assign=assign) - - logging.info(f"Flux load_state_dict end, strict={strict}, state_dict keys count={len(state_dict)}") - - return result diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py index 2cdf711d4..cd8997716 100644 --- a/comfy/ldm/modules/diffusionmodules/openaimodel.py +++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py @@ -911,16 +911,4 @@ class UNetModel(nn.Module): return self.id_predictor(h) else: return self.out(h) - - - def load_state_dict(self, state_dict, strict=True, assign=False): - # import pdb; pdb.set_trace() - """Override load_state_dict() to add logging""" - logging.info(f"UNetModel load_state_dict start, strict={strict}, state_dict keys count={len(state_dict)}") - - # Call parent's load_state_dict method - result = super().load_state_dict(state_dict, strict=strict, assign=assign) - - logging.info(f"UNetModel load_state_dict end, strict={strict}, state_dict keys count={len(state_dict)}") - - return result + \ No newline at end of file diff --git a/comfy/model_base.py b/comfy/model_base.py index 409e7fb87..d2d4aa93d 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -303,10 +303,7 @@ class BaseModel(torch.nn.Module): logging.info(f"model destination device {next(self.diffusion_model.parameters()).device}") to_load = self.model_config.process_unet_state_dict(to_load) logging.info(f"load model {self.model_config} weights process end") - # TODO(sf): to mmap - # diffusion_model is UNetModel - # import pdb; pdb.set_trace() - # TODO(sf): here needs to avoid load mmap into cpu mem + # replace tensor with mmap tensor by assign m, u = self.diffusion_model.load_state_dict(to_load, strict=False, assign=True) free_cpu_memory = get_free_memory(torch.device("cpu")) logging.info(f"load model {self.model_config} weights end, free cpu memory size {free_cpu_memory/(1024*1024*1024)} GB") @@ -389,21 +386,6 @@ class BaseModel(torch.nn.Module): area = sum(map(lambda input_shape: input_shape[0] * math.prod(input_shape[2:]), input_shapes)) return (area * 0.15 * self.memory_usage_factor) * (1024 * 1024) - def to(self, *args, **kwargs): - """Override to() to add custom device management logic""" - old_device = self.device if hasattr(self, 'device') else None - - result = super().to(*args, **kwargs) - - if len(args) > 0: - if isinstance(args[0], (torch.device, str)): - new_device = torch.device(args[0]) if isinstance(args[0], str) else args[0] - if 'device' in kwargs: - new_device = kwargs['device'] - - logging.info(f"BaseModel moved from {old_device} to {new_device}") - return result - def extra_conds_shapes(self, **kwargs): return {} diff --git a/comfy/model_management.py b/comfy/model_management.py index 30a509670..4c29b07e1 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -596,7 +596,6 @@ def minimum_inference_memory(): def free_memory(memory_required, device, keep_loaded=[]): logging.info("start to free mem") - import pdb; pdb.set_trace() cleanup_models_gc() unloaded_model = [] can_unload = [] diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index d2e3a296a..1c725663a 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -831,8 +831,11 @@ class ModelPatcher: self.backup.clear() - model_to_mmap(self.model) - self.model.device = device_to + if device_to is not None: + # offload to mmap + model_to_mmap(self.model) + self.model.device = device_to + self.model.model_loaded_weight_memory = 0 for m in self.model.modules(): @@ -885,8 +888,7 @@ class ModelPatcher: bias_key = "{}.bias".format(n) if move_weight: cast_weight = self.force_cast_weights - # TODO(sf): to mmap - # m is what module? + # offload to mmap # m.to(device_to) model_to_mmap(m) module_mem += move_weight_functions(m, device_to) From 2d010f545c9df6c3e07b7560ba7887432261947f Mon Sep 17 00:00:00 2001 From: strint Date: Tue, 21 Oct 2025 11:54:56 +0800 Subject: [PATCH 20/35] refine code --- comfy/ldm/flux/model.py | 1 - .../modules/diffusionmodules/openaimodel.py | 3 +- comfy/model_base.py | 10 +++---- comfy/model_management.py | 30 +++++++++---------- comfy/model_patcher.py | 4 +-- comfy/sd.py | 8 +---- comfy/utils.py | 7 ++--- execution.py | 3 -- nodes.py | 2 -- server.py | 2 +- 10 files changed, 27 insertions(+), 43 deletions(-) diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py index a07c3ca95..14f90cea5 100644 --- a/comfy/ldm/flux/model.py +++ b/comfy/ldm/flux/model.py @@ -7,7 +7,6 @@ from torch import Tensor, nn from einops import rearrange, repeat import comfy.ldm.common_dit import comfy.patcher_extension -import logging from .layers import ( DoubleStreamBlock, diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py index cd8997716..4963811a8 100644 --- a/comfy/ldm/modules/diffusionmodules/openaimodel.py +++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py @@ -910,5 +910,4 @@ class UNetModel(nn.Module): if self.predict_codebook_ids: return self.id_predictor(h) else: - return self.out(h) - \ No newline at end of file + return self.out(h) \ No newline at end of file diff --git a/comfy/model_base.py b/comfy/model_base.py index d2d4aa93d..d6ef644dd 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -299,14 +299,14 @@ class BaseModel(torch.nn.Module): to_load[k[len(unet_prefix):]] = sd.pop(k) free_cpu_memory = get_free_memory(torch.device("cpu")) - logging.info(f"load model weights start, free cpu memory size {free_cpu_memory/(1024*1024*1024)} GB") - logging.info(f"model destination device {next(self.diffusion_model.parameters()).device}") + logging.debug(f"load model weights start, free cpu memory size {free_cpu_memory/(1024*1024*1024)} GB") + logging.debug(f"model destination device {next(self.diffusion_model.parameters()).device}") to_load = self.model_config.process_unet_state_dict(to_load) - logging.info(f"load model {self.model_config} weights process end") + logging.debug(f"load model {self.model_config} weights process end") # replace tensor with mmap tensor by assign m, u = self.diffusion_model.load_state_dict(to_load, strict=False, assign=True) free_cpu_memory = get_free_memory(torch.device("cpu")) - logging.info(f"load model {self.model_config} weights end, free cpu memory size {free_cpu_memory/(1024*1024*1024)} GB") + logging.debug(f"load model {self.model_config} weights end, free cpu memory size {free_cpu_memory/(1024*1024*1024)} GB") if len(m) > 0: logging.warning("unet missing: {}".format(m)) @@ -385,7 +385,7 @@ class BaseModel(torch.nn.Module): #TODO: this formula might be too aggressive since I tweaked the sub-quad and split algorithms to use less memory. area = sum(map(lambda input_shape: input_shape[0] * math.prod(input_shape[2:]), input_shapes)) return (area * 0.15 * self.memory_usage_factor) * (1024 * 1024) - + def extra_conds_shapes(self, **kwargs): return {} diff --git a/comfy/model_management.py b/comfy/model_management.py index 4c29b07e1..70a5039ef 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -509,16 +509,16 @@ class LoadedModel: return False def model_unload(self, memory_to_free=None, unpatch_weights=True): - logging.info(f"model_unload: {self.model.model.__class__.__name__}") - logging.info(f"memory_to_free: {memory_to_free/(1024*1024*1024)} GB") - logging.info(f"unpatch_weights: {unpatch_weights}") - logging.info(f"loaded_size: {self.model.loaded_size()/(1024*1024*1024)} GB") - logging.info(f"offload_device: {self.model.offload_device}") + logging.debug(f"model_unload: {self.model.model.__class__.__name__}") + logging.debug(f"memory_to_free: {memory_to_free/(1024*1024*1024)} GB") + logging.debug(f"unpatch_weights: {unpatch_weights}") + logging.debug(f"loaded_size: {self.model.loaded_size()/(1024*1024*1024)} GB") + logging.debug(f"offload_device: {self.model.offload_device}") available_memory = get_free_memory(self.model.offload_device) - logging.info(f"before unload, available_memory of offload device {self.model.offload_device}: {available_memory/(1024*1024*1024)} GB") + logging.debug(f"before unload, available_memory of offload device {self.model.offload_device}: {available_memory/(1024*1024*1024)} GB") reserved_memory = 1024*1024*1024 # 1GB reserved memory for other usage if available_memory < reserved_memory: - logging.error(f"Not enough cpu memory to unload. Available: {available_memory/(1024*1024*1024)} GB, Reserved: {reserved_memory/(1024*1024*1024)} GB") + logging.warning(f"Not enough cpu memory to unload. Available: {available_memory/(1024*1024*1024)} GB, Reserved: {reserved_memory/(1024*1024*1024)} GB") return False else: offload_memory = available_memory - reserved_memory @@ -530,14 +530,14 @@ class LoadedModel: try: if memory_to_free is not None: if memory_to_free < self.model.loaded_size(): - logging.info("Do partially unload") + logging.debug("Do partially unload") freed = self.model.partially_unload(self.model.offload_device, memory_to_free) - logging.info(f"partially_unload freed: {freed/(1024*1024*1024)} GB") + logging.debug(f"partially_unload freed vram: {freed/(1024*1024*1024)} GB") if freed >= memory_to_free: return False - logging.info("Do full unload") + logging.debug("Do full unload") self.model.detach(unpatch_weights) - logging.info("Do full unload done") + logging.debug("Do full unload done") except Exception as e: logging.error(f"Error in model_unload: {e}") available_memory = get_free_memory(self.model.offload_device) @@ -595,7 +595,7 @@ def minimum_inference_memory(): return (1024 * 1024 * 1024) * 0.8 + extra_reserved_memory() def free_memory(memory_required, device, keep_loaded=[]): - logging.info("start to free mem") + logging.debug("start to free mem") cleanup_models_gc() unloaded_model = [] can_unload = [] @@ -616,7 +616,7 @@ def free_memory(memory_required, device, keep_loaded=[]): if free_mem > memory_required: break memory_to_free = memory_required - free_mem - logging.info(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}") + logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}") if current_loaded_models[i].model_unload(memory_to_free): unloaded_model.append(i) @@ -633,7 +633,7 @@ def free_memory(memory_required, device, keep_loaded=[]): return unloaded_models def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False): - logging.info(f"start to load models") + logging.debug(f"start to load models") cleanup_models_gc() global vram_state @@ -655,7 +655,7 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu models_to_load = [] for x in models: - logging.info(f"loading model: {x.model.__class__.__name__}") + logging.debug(f"start loading model to vram: {x.model.__class__.__name__}") loaded_model = LoadedModel(x) try: loaded_model_index = current_loaded_models.index(loaded_model) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 1c725663a..63bae24d3 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -61,7 +61,7 @@ def model_to_mmap(model: torch.nn.Module): The same model with all tensors converted to memory-mapped format """ free_cpu_mem = get_free_memory(torch.device("cpu")) - logging.info(f"Converting model {model.__class__.__name__} to mmap, cpu memory: {free_cpu_mem/(1024*1024*1024)} GB") + logging.debug(f"Converting model {model.__class__.__name__} to mmap, current free cpu memory: {free_cpu_mem/(1024*1024*1024)} GB") def convert_fn(t): """Convert function for _apply() @@ -81,7 +81,7 @@ def model_to_mmap(model: torch.nn.Module): new_model = model._apply(convert_fn) free_cpu_mem = get_free_memory(torch.device("cpu")) - logging.info(f"Model {model.__class__.__name__} converted to mmap, cpu memory: {free_cpu_mem/(1024*1024*1024)} GB") + logging.debug(f"Model {model.__class__.__name__} converted to mmap, current free cpu memory: {free_cpu_mem/(1024*1024*1024)} GB") return new_model diff --git a/comfy/sd.py b/comfy/sd.py index a956884fb..3651da5e7 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -1321,7 +1321,7 @@ def load_diffusion_model_state_dict(sd, model_options={}): logging.warning("{} {}".format(diffusers_keys[k], k)) offload_device = model_management.unet_offload_device() - logging.info(f"loader load model to offload device: {offload_device}") + logging.debug(f"loader load model to offload device: {offload_device}") unet_weight_dtype = list(model_config.supported_inference_dtypes) if model_config.scaled_fp8 is not None: weight_dtype = None @@ -1338,7 +1338,6 @@ def load_diffusion_model_state_dict(sd, model_options={}): model_config.optimizations["fp8"] = True model = model_config.get_model(new_sd, "") - # import pdb; pdb.set_trace() model = model.to(offload_device) model.load_model_weights(new_sd, "") left_over = sd.keys() @@ -1348,13 +1347,8 @@ def load_diffusion_model_state_dict(sd, model_options={}): def load_diffusion_model(unet_path, model_options={}): - # TODO(sf): here load file into mem sd = comfy.utils.load_torch_file(unet_path) - logging.info(f"load model start, path {unet_path}") - # import pdb; pdb.set_trace() model = load_diffusion_model_state_dict(sd, model_options=model_options) - logging.info(f"load model end, path {unet_path}") - # import pdb; pdb.set_trace() if model is None: logging.error("ERROR UNSUPPORTED DIFFUSION MODEL {}".format(unet_path)) raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(unet_path, model_detection_error_hint(unet_path, sd))) diff --git a/comfy/utils.py b/comfy/utils.py index 4c22f684c..be6ab7596 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -55,15 +55,13 @@ else: logging.info("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended.") def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False): - # TODO(sf): here load file into mmap - logging.info(f"load_torch_file start, ckpt={ckpt}, safe_load={safe_load}, device={device}, return_metadata={return_metadata}") if device is None: device = torch.device("cpu") metadata = None if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"): try: if not DISABLE_MMAP: - logging.info(f"load_torch_file safetensors mmap True") + logging.debug(f"load_torch_file of safetensors into mmap True") with safetensors.safe_open(ckpt, framework="pt", device=device.type) as f: sd = {} for k in f.keys(): @@ -84,7 +82,7 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False): else: torch_args = {} if MMAP_TORCH_FILES: - logging.info(f"load_torch_file mmap True") + logging.debug(f"load_torch_file of torch state dict into mmap True") torch_args["mmap"] = True if safe_load or ALWAYS_SAFE_LOAD: @@ -102,7 +100,6 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False): sd = pl_sd else: sd = pl_sd - # import pdb; pdb.set_trace() return (sd, metadata) if return_metadata else sd def save_torch_file(sd, ckpt, metadata=None): diff --git a/execution.py b/execution.py index 53f295357..1dc35738b 100644 --- a/execution.py +++ b/execution.py @@ -400,8 +400,6 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, inputs = dynprompt.get_node(unique_id)['inputs'] class_type = dynprompt.get_node(unique_id)['class_type'] class_def = nodes.NODE_CLASS_MAPPINGS[class_type] - - if caches.outputs.get(unique_id) is not None: if server.client_id is not None: cached_output = caches.ui.get(unique_id) or {} @@ -595,7 +593,6 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, get_progress_state().finish_progress(unique_id) executed.add(unique_id) - return (ExecutionResult.SUCCESS, None, None) diff --git a/nodes.py b/nodes.py index 25ccc9e42..7cfa8ca14 100644 --- a/nodes.py +++ b/nodes.py @@ -922,9 +922,7 @@ class UNETLoader: model_options["dtype"] = torch.float8_e5m2 unet_path = folder_paths.get_full_path_or_raise("diffusion_models", unet_name) - logging.info(f"load unet node start, path {unet_path}") model = comfy.sd.load_diffusion_model(unet_path, model_options=model_options) - logging.info(f"load unet node end, path {unet_path}") return (model,) class CLIPLoader: diff --git a/server.py b/server.py index 515307bf6..80e9d3fa7 100644 --- a/server.py +++ b/server.py @@ -673,7 +673,7 @@ class PromptServer(): @routes.post("/prompt") async def post_prompt(request): - logging.info("got prompt in debug comfyui") + logging.info("got prompt") json_data = await request.json() json_data = self.trigger_on_prompt(json_data) From fff56de63cfe9ad7057d8403c13c6428d57593c5 Mon Sep 17 00:00:00 2001 From: strint Date: Tue, 21 Oct 2025 11:59:59 +0800 Subject: [PATCH 21/35] fix format --- comfy/ldm/modules/diffusionmodules/openaimodel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py index 4963811a8..4c8d53cac 100644 --- a/comfy/ldm/modules/diffusionmodules/openaimodel.py +++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py @@ -910,4 +910,4 @@ class UNetModel(nn.Module): if self.predict_codebook_ids: return self.id_predictor(h) else: - return self.out(h) \ No newline at end of file + return self.out(h) From 08e094ed81b66e23876a5cc8be1bb9f40f213061 Mon Sep 17 00:00:00 2001 From: strint Date: Tue, 21 Oct 2025 17:00:56 +0800 Subject: [PATCH 22/35] use native mmap --- comfy/model_patcher.py | 78 +++++++- tests/execution/test_model_mmap.py | 280 +++++++++++++++++++++++++++++ 2 files changed, 355 insertions(+), 3 deletions(-) create mode 100644 tests/execution/test_model_mmap.py diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 63bae24d3..0f4445d33 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -27,7 +27,10 @@ import uuid from typing import Callable, Optional import torch -import tensordict +import os +import tempfile +import weakref +import gc import comfy.float import comfy.hooks @@ -39,8 +42,77 @@ from comfy.comfy_types import UnetWrapperFunction from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP from comfy.model_management import get_free_memory -def to_mmap(t: torch.Tensor) -> tensordict.MemoryMappedTensor: - return tensordict.MemoryMappedTensor.from_tensor(t) + +def to_mmap(t: torch.Tensor, filename: Optional[str] = None) -> torch.Tensor: + """ + Convert a tensor to a memory-mapped CPU tensor using PyTorch's native mmap support. + """ + # Move to CPU if needed + if t.is_cuda: + cpu_tensor = t.cpu() + else: + cpu_tensor = t + + # Create temporary file + if filename is None: + temp_file = tempfile.mktemp(suffix='.pt', prefix='comfy_mmap_') + else: + temp_file = filename + + # Save tensor to file + torch.save(cpu_tensor, temp_file) + + # If we created a CPU copy from CUDA, delete it to free memory + if t.is_cuda: + del cpu_tensor + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + # Load with mmap - this doesn't load all data into RAM + mmap_tensor = torch.load(temp_file, map_location='cpu', mmap=True, weights_only=False) + + # Register cleanup callback + def _cleanup(): + try: + if os.path.exists(temp_file): + os.remove(temp_file) + logging.debug(f"Cleaned up mmap file: {temp_file}") + except Exception: + pass + + weakref.finalize(mmap_tensor, _cleanup) + + # Save original 'to' method + original_to = mmap_tensor.to + + # Create custom 'to' method that cleans up file when moving to CUDA + def custom_to(*args, **kwargs): + # Determine target device + target_device = None + if len(args) > 0: + if isinstance(args[0], torch.device): + target_device = args[0] + elif isinstance(args[0], str): + target_device = torch.device(args[0]) + if 'device' in kwargs: + target_device = kwargs['device'] + if isinstance(target_device, str): + target_device = torch.device(target_device) + + # Call original 'to' method first to move data + result = original_to(*args, **kwargs) + + # If moved to CUDA, cleanup the mmap file after the move + if target_device is not None and target_device.type == 'cuda': + _cleanup() + + return result + + # Replace the 'to' method + mmap_tensor.to = custom_to + + return mmap_tensor def model_to_mmap(model: torch.nn.Module): """Convert all parameters and buffers to memory-mapped tensors diff --git a/tests/execution/test_model_mmap.py b/tests/execution/test_model_mmap.py new file mode 100644 index 000000000..65dbe01bd --- /dev/null +++ b/tests/execution/test_model_mmap.py @@ -0,0 +1,280 @@ +import pytest +import torch +import torch.nn as nn +import psutil +import os +import gc +import tempfile +from comfy.model_patcher import model_to_mmap, to_mmap + + +class LargeModel(nn.Module): + """A simple model with large parameters for testing memory mapping""" + + def __init__(self, size_gb=10): + super().__init__() + # Calculate number of float32 elements needed for target size + # 1 GB = 1024^3 bytes, float32 = 4 bytes + bytes_per_gb = 1024 * 1024 * 1024 + elements_per_gb = bytes_per_gb // 4 # float32 is 4 bytes + total_elements = int(size_gb * elements_per_gb) + + # Create a large linear layer + # Split into multiple layers to avoid single tensor size limits + self.layers = nn.ModuleList() + elements_per_layer = 500 * 1024 * 1024 # 500M elements per layer (~2GB) + num_layers = (total_elements + elements_per_layer - 1) // elements_per_layer + + for i in range(num_layers): + if i == num_layers - 1: + # Last layer gets the remaining elements + remaining = total_elements - (i * elements_per_layer) + in_features = int(remaining ** 0.5) + out_features = (remaining + in_features - 1) // in_features + else: + in_features = int(elements_per_layer ** 0.5) + out_features = (elements_per_layer + in_features - 1) // in_features + + # Create layer without bias to control size precisely + self.layers.append(nn.Linear(in_features, out_features, bias=False)) + + def forward(self, x): + for layer in self.layers: + x = layer(x) + return x + + +def get_process_memory_gb(): + """Get current process memory usage in GB""" + process = psutil.Process(os.getpid()) + mem_info = process.memory_info() + return mem_info.rss / (1024 ** 3) # Convert to GB + + +def get_model_size_gb(model): + """Calculate model size in GB""" + total_size = 0 + for param in model.parameters(): + total_size += param.nelement() * param.element_size() + for buffer in model.buffers(): + total_size += buffer.nelement() * buffer.element_size() + return total_size / (1024 ** 3) + + +def test_model_to_mmap_memory_efficiency(): + """Test that model_to_mmap reduces memory usage for a 10GB model to less than 1GB + + The typical use case is: + 1. Load a large model on CUDA + 2. Convert to mmap to offload from GPU to disk-backed memory + 3. This frees GPU memory and reduces CPU RAM usage + """ + + # Check if CUDA is available + if not torch.cuda.is_available(): + pytest.skip("CUDA is not available, skipping test") + + # Force garbage collection before starting + gc.collect() + torch.cuda.empty_cache() + + # Record initial memory + initial_cpu_memory = get_process_memory_gb() + initial_gpu_memory = torch.cuda.memory_allocated() / (1024 ** 3) + print(f"\nInitial CPU memory: {initial_cpu_memory:.2f} GB") + print(f"Initial GPU memory: {initial_gpu_memory:.2f} GB") + + # Create a 10GB model + print("Creating 10GB model...") + model = LargeModel(size_gb=10) + + # Verify model size + model_size = get_model_size_gb(model) + print(f"Model size: {model_size:.2f} GB") + assert model_size >= 9.5, f"Model size {model_size:.2f} GB is less than expected 10 GB" + + # Move model to CUDA + print("Moving model to CUDA...") + model = model.cuda() + torch.cuda.synchronize() + + # Memory after moving to CUDA + cpu_after_cuda = get_process_memory_gb() + gpu_after_cuda = torch.cuda.memory_allocated() / (1024 ** 3) + print(f"CPU memory after moving to CUDA: {cpu_after_cuda:.2f} GB") + print(f"GPU memory after moving to CUDA: {gpu_after_cuda:.2f} GB") + + # Convert to mmap (this should move model from GPU to disk-backed memory) + # Note: model_to_mmap modifies the model in-place via _apply() + # so model and model_mmap will be the same object + print("Converting model to mmap...") + model_mmap = model_to_mmap(model) + + # Verify that model and model_mmap are the same object (in-place modification) + assert model is model_mmap, "model_to_mmap should modify the model in-place" + + # Force garbage collection and clear CUDA cache + # The original CUDA tensors should be automatically freed when replaced + gc.collect() + torch.cuda.empty_cache() + torch.cuda.synchronize() + + # Memory after mmap conversion + cpu_after_mmap = get_process_memory_gb() + gpu_after_mmap = torch.cuda.memory_allocated() / (1024 ** 3) + print(f"CPU memory after mmap: {cpu_after_mmap:.2f} GB") + print(f"GPU memory after mmap: {gpu_after_mmap:.2f} GB") + + # Calculate memory changes from CUDA state (the baseline we're converting from) + cpu_increase = cpu_after_mmap - cpu_after_cuda + gpu_decrease = gpu_after_cuda - gpu_after_mmap # Should be positive (freed) + print(f"\nCPU memory increase from CUDA: {cpu_increase:.2f} GB") + print(f"GPU memory freed: {gpu_decrease:.2f} GB") + + # Verify that CPU memory usage increase is less than 1GB + # The mmap should use disk-backed storage, keeping CPU RAM usage low + # We use 1.5 GB threshold to account for overhead + assert cpu_increase < 1.5, ( + f"CPU memory increase after mmap ({cpu_increase:.2f} GB) should be less than 1.5 GB. " + f"CUDA state: {cpu_after_cuda:.2f} GB, After mmap: {cpu_after_mmap:.2f} GB" + ) + + # Verify that GPU memory has been freed + # We expect at least 9 GB to be freed (original 10GB model with some tolerance) + assert gpu_decrease > 9.0, ( + f"GPU memory should be freed after mmap. " + f"Freed: {gpu_decrease:.2f} GB (from {gpu_after_cuda:.2f} to {gpu_after_mmap:.2f} GB), expected > 9 GB" + ) + + # Verify the model is still functional (basic sanity check) + assert model_mmap is not None + assert len(list(model_mmap.parameters())) > 0 + + print(f"\nโœ“ Test passed!") + print(f" CPU memory increase: {cpu_increase:.2f} GB < 1.5 GB") + print(f" GPU memory freed: {gpu_decrease:.2f} GB > 9.0 GB") + print(f" Model successfully offloaded from GPU to disk-backed memory") + + # Cleanup (model and model_mmap are the same object) + del model, model_mmap + gc.collect() + torch.cuda.empty_cache() + + +def test_to_mmap_cuda_cycle(): + """Test CUDA -> mmap -> CUDA cycle + + This test verifies: + 1. CUDA tensor can be converted to mmap tensor + 2. CPU memory increase is minimal when using mmap (< 0.1 GB) + 3. GPU memory is freed when converting to mmap + 4. mmap tensor can be moved back to CUDA + 5. Data remains consistent throughout the cycle + 6. mmap file is automatically cleaned up when moved to CUDA + """ + + # Check if CUDA is available + if not torch.cuda.is_available(): + pytest.skip("CUDA is not available, skipping test") + + # Force garbage collection + gc.collect() + torch.cuda.empty_cache() + + print("\nTest: CUDA -> mmap -> CUDA cycle") + + # Record initial CPU memory + initial_cpu_memory = get_process_memory_gb() + print(f"Initial CPU memory: {initial_cpu_memory:.2f} GB") + + # Step 1: Create a CUDA tensor + print("\n1. Creating CUDA tensor...") + original_data = torch.randn(5000, 5000).cuda() + original_sum = original_data.sum().item() + print(f" Shape: {original_data.shape}") + print(f" Device: {original_data.device}") + print(f" Sum: {original_sum:.2f}") + + # Record GPU and CPU memory after CUDA allocation + cpu_after_cuda = get_process_memory_gb() + gpu_before_mmap = torch.cuda.memory_allocated() / (1024 ** 3) + print(f" GPU memory: {gpu_before_mmap:.2f} GB") + print(f" CPU memory: {cpu_after_cuda:.2f} GB") + + # Step 2: Convert to mmap tensor + print("\n2. Converting to mmap tensor...") + mmap_tensor = to_mmap(original_data) + del original_data + gc.collect() + torch.cuda.empty_cache() + + print(f" Device: {mmap_tensor.device}") + print(f" Sum: {mmap_tensor.sum().item():.2f}") + + # Verify GPU memory is freed + gpu_after_mmap = torch.cuda.memory_allocated() / (1024 ** 3) + cpu_after_mmap = get_process_memory_gb() + print(f" GPU memory freed: {gpu_before_mmap - gpu_after_mmap:.2f} GB") + print(f" CPU memory: {cpu_after_mmap:.2f} GB") + + # Verify GPU memory is freed + assert gpu_after_mmap < 0.1, f"GPU memory should be freed, but {gpu_after_mmap:.2f} GB still allocated" + + # Verify CPU memory increase is minimal (should be close to 0 due to mmap) + cpu_increase = cpu_after_mmap - cpu_after_cuda + print(f" CPU memory increase: {cpu_increase:.2f} GB") + assert cpu_increase < 0.1, f"CPU memory should increase minimally, but increased by {cpu_increase:.2f} GB" + + # Get the temp file path (we'll check if it gets cleaned up) + # The file should exist at this point + temp_files_before = len([f for f in os.listdir(tempfile.gettempdir()) if f.startswith('comfy_mmap_')]) + print(f" Temp mmap files exist: {temp_files_before}") + + # Step 3: Move back to CUDA + print("\n3. Moving back to CUDA...") + cuda_tensor = mmap_tensor.to('cuda') + torch.cuda.synchronize() + + print(f" Device: {cuda_tensor.device}") + final_sum = cuda_tensor.sum().item() + print(f" Sum: {final_sum:.2f}") + + # Verify GPU memory is used again + gpu_after_cuda = torch.cuda.memory_allocated() / (1024 ** 3) + print(f" GPU memory: {gpu_after_cuda:.2f} GB") + + # Step 4: Verify data consistency + print("\n4. Verifying data consistency...") + sum_diff = abs(original_sum - final_sum) + print(f" Original sum: {original_sum:.2f}") + print(f" Final sum: {final_sum:.2f}") + print(f" Difference: {sum_diff:.6f}") + assert sum_diff < 0.01, f"Data should be consistent, but difference is {sum_diff:.6f}" + + # Step 5: Verify file cleanup + print("\n5. Verifying file cleanup...") + gc.collect() + import time + time.sleep(0.1) # Give OS time to clean up + temp_files_after = len([f for f in os.listdir(tempfile.gettempdir()) if f.startswith('comfy_mmap_')]) + print(f" Temp mmap files after: {temp_files_after}") + # File should be cleaned up when moved to CUDA + assert temp_files_after <= temp_files_before, "mmap file should be cleaned up after moving to CUDA" + + print("\nโœ“ Test passed!") + print(" CUDA -> mmap -> CUDA cycle works correctly") + print(f" CPU memory increase: {cpu_increase:.2f} GB < 0.1 GB (mmap efficiency)") + print(" Data consistency maintained") + print(" File cleanup successful") + + # Cleanup + del mmap_tensor, cuda_tensor + gc.collect() + torch.cuda.empty_cache() + + +if __name__ == "__main__": + # Run the tests directly + test_model_to_mmap_memory_efficiency() + test_to_mmap_cuda_cycle() + From 80383932ec63056261d584771eba8c8c1eb51ebf Mon Sep 17 00:00:00 2001 From: strint Date: Tue, 21 Oct 2025 18:00:31 +0800 Subject: [PATCH 23/35] lazy rm file --- comfy/model_patcher.py | 55 +++++++++++++++--------------- tests/execution/test_model_mmap.py | 16 +++++---- 2 files changed, 37 insertions(+), 34 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 0f4445d33..4b0c5b9c5 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -72,7 +72,7 @@ def to_mmap(t: torch.Tensor, filename: Optional[str] = None) -> torch.Tensor: # Load with mmap - this doesn't load all data into RAM mmap_tensor = torch.load(temp_file, map_location='cpu', mmap=True, weights_only=False) - # Register cleanup callback + # Register cleanup callback - will be called when tensor is garbage collected def _cleanup(): try: if os.path.exists(temp_file): @@ -83,34 +83,35 @@ def to_mmap(t: torch.Tensor, filename: Optional[str] = None) -> torch.Tensor: weakref.finalize(mmap_tensor, _cleanup) - # Save original 'to' method - original_to = mmap_tensor.to + # # Save original 'to' method + # original_to = mmap_tensor.to - # Create custom 'to' method that cleans up file when moving to CUDA - def custom_to(*args, **kwargs): - # Determine target device - target_device = None - if len(args) > 0: - if isinstance(args[0], torch.device): - target_device = args[0] - elif isinstance(args[0], str): - target_device = torch.device(args[0]) - if 'device' in kwargs: - target_device = kwargs['device'] - if isinstance(target_device, str): - target_device = torch.device(target_device) - - # Call original 'to' method first to move data - result = original_to(*args, **kwargs) - - # If moved to CUDA, cleanup the mmap file after the move - if target_device is not None and target_device.type == 'cuda': - _cleanup() - - return result + # # Create custom 'to' method that cleans up file when moving to CUDA + # def custom_to(*args, **kwargs): + # # Determine target device + # target_device = None + # if len(args) > 0: + # if isinstance(args[0], torch.device): + # target_device = args[0] + # elif isinstance(args[0], str): + # target_device = torch.device(args[0]) + # if 'device' in kwargs: + # target_device = kwargs['device'] + # if isinstance(target_device, str): + # target_device = torch.device(target_device) + # + # # Call original 'to' method first to move data + # result = original_to(*args, **kwargs) + # + # # NOTE: Cleanup disabled to avoid blocking model load performance + # # If moved to CUDA, cleanup the mmap file after the move + # if target_device is not None and target_device.type == 'cuda': + # _cleanup() + # + # return result - # Replace the 'to' method - mmap_tensor.to = custom_to + # # Replace the 'to' method + # mmap_tensor.to = custom_to return mmap_tensor diff --git a/tests/execution/test_model_mmap.py b/tests/execution/test_model_mmap.py index 65dbe01bd..7a608c931 100644 --- a/tests/execution/test_model_mmap.py +++ b/tests/execution/test_model_mmap.py @@ -170,7 +170,7 @@ def test_to_mmap_cuda_cycle(): 3. GPU memory is freed when converting to mmap 4. mmap tensor can be moved back to CUDA 5. Data remains consistent throughout the cycle - 6. mmap file is automatically cleaned up when moved to CUDA + 6. mmap file is automatically cleaned up via garbage collection """ # Check if CUDA is available @@ -251,24 +251,26 @@ def test_to_mmap_cuda_cycle(): print(f" Difference: {sum_diff:.6f}") assert sum_diff < 0.01, f"Data should be consistent, but difference is {sum_diff:.6f}" - # Step 5: Verify file cleanup + # Step 5: Verify file cleanup (delayed until garbage collection) print("\n5. Verifying file cleanup...") + # Delete the mmap tensor reference to trigger garbage collection + del mmap_tensor gc.collect() import time time.sleep(0.1) # Give OS time to clean up temp_files_after = len([f for f in os.listdir(tempfile.gettempdir()) if f.startswith('comfy_mmap_')]) - print(f" Temp mmap files after: {temp_files_after}") - # File should be cleaned up when moved to CUDA - assert temp_files_after <= temp_files_before, "mmap file should be cleaned up after moving to CUDA" + print(f" Temp mmap files after GC: {temp_files_after}") + # File should be cleaned up after garbage collection + assert temp_files_after <= temp_files_before, "mmap file should be cleaned up after garbage collection" print("\nโœ“ Test passed!") print(" CUDA -> mmap -> CUDA cycle works correctly") print(f" CPU memory increase: {cpu_increase:.2f} GB < 0.1 GB (mmap efficiency)") print(" Data consistency maintained") - print(" File cleanup successful") + print(" File cleanup successful (via garbage collection)") # Cleanup - del mmap_tensor, cuda_tensor + del cuda_tensor # mmap_tensor already deleted in Step 5 gc.collect() torch.cuda.empty_cache() From 98ba3115110164d3c81ef01dc7b9790c67539328 Mon Sep 17 00:00:00 2001 From: strint Date: Tue, 21 Oct 2025 19:06:34 +0800 Subject: [PATCH 24/35] add env --- comfy/model_patcher.py | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 4b0c5b9c5..f379c230b 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -42,6 +42,13 @@ from comfy.comfy_types import UnetWrapperFunction from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP from comfy.model_management import get_free_memory +def need_mmap() -> bool: + free_cpu_mem = get_free_memory(torch.device("cpu")) + mmap_mem_threshold_gb = int(os.environ.get("MMAP_MEM_THRESHOLD_GB", "1024")) + if free_cpu_mem < mmap_mem_threshold_gb * 1024 * 1024 * 1024: + logging.debug(f"Enabling mmap, current free cpu memory {free_cpu_mem/(1024*1024*1024)} GB < {mmap_mem_threshold_gb} GB") + return True + return False def to_mmap(t: torch.Tensor, filename: Optional[str] = None) -> torch.Tensor: """ @@ -905,8 +912,11 @@ class ModelPatcher: if device_to is not None: - # offload to mmap - model_to_mmap(self.model) + if need_mmap(): + # offload to mmap + model_to_mmap(self.model) + else: + self.model.to(device_to) self.model.device = device_to self.model.model_loaded_weight_memory = 0 @@ -961,9 +971,11 @@ class ModelPatcher: bias_key = "{}.bias".format(n) if move_weight: cast_weight = self.force_cast_weights - # offload to mmap - # m.to(device_to) - model_to_mmap(m) + if need_mmap(): + # offload to mmap + model_to_mmap(m) + else: + m.to(device_to) module_mem += move_weight_functions(m, device_to) if lowvram_possible: if weight_key in self.patches: From aab0e244f7b221a00b9049f8dfaa0706185b22bd Mon Sep 17 00:00:00 2001 From: strint Date: Thu, 23 Oct 2025 14:44:51 +0800 Subject: [PATCH 25/35] fix MMAP_MEM_THRESHOLD_GB default --- comfy/model_patcher.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index f379c230b..115e401b3 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -44,7 +44,7 @@ from comfy.model_management import get_free_memory def need_mmap() -> bool: free_cpu_mem = get_free_memory(torch.device("cpu")) - mmap_mem_threshold_gb = int(os.environ.get("MMAP_MEM_THRESHOLD_GB", "1024")) + mmap_mem_threshold_gb = int(os.environ.get("MMAP_MEM_THRESHOLD_GB", "0")) if free_cpu_mem < mmap_mem_threshold_gb * 1024 * 1024 * 1024: logging.debug(f"Enabling mmap, current free cpu memory {free_cpu_mem/(1024*1024*1024)} GB < {mmap_mem_threshold_gb} GB") return True From 58d28edade40ecd45ac7b20272f51a978a3045d2 Mon Sep 17 00:00:00 2001 From: strint Date: Thu, 23 Oct 2025 15:50:57 +0800 Subject: [PATCH 26/35] no limit for offload size --- comfy/model_management.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 0dc471fb8..8bf4e68fb 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -526,17 +526,17 @@ class LoadedModel: logging.debug(f"offload_device: {self.model.offload_device}") available_memory = get_free_memory(self.model.offload_device) logging.debug(f"before unload, available_memory of offload device {self.model.offload_device}: {available_memory/(1024*1024*1024)} GB") - reserved_memory = 1024*1024*1024 # 1GB reserved memory for other usage - if available_memory < reserved_memory: - logging.warning(f"Not enough cpu memory to unload. Available: {available_memory/(1024*1024*1024)} GB, Reserved: {reserved_memory/(1024*1024*1024)} GB") - return False - else: - offload_memory = available_memory - reserved_memory - - if offload_memory < memory_to_free: - memory_to_free = offload_memory - logging.info(f"Not enough cpu memory to unload. Available: {available_memory/(1024*1024*1024)} GB, Reserved: {reserved_memory/(1024*1024*1024)} GB, Offload: {offload_memory/(1024*1024*1024)} GB") - logging.info(f"Set memory_to_free to {memory_to_free/(1024*1024*1024)} GB") + # reserved_memory = 1024*1024*1024 # 1GB reserved memory for other usage + # if available_memory < reserved_memory: + # logging.warning(f"Not enough cpu memory to unload. Available: {available_memory/(1024*1024*1024)} GB, Reserved: {reserved_memory/(1024*1024*1024)} GB") + # return False + # else: + # offload_memory = available_memory - reserved_memory + # + # if offload_memory < memory_to_free: + # memory_to_free = offload_memory + # logging.info(f"Not enough cpu memory to unload. Available: {available_memory/(1024*1024*1024)} GB, Reserved: {reserved_memory/(1024*1024*1024)} GB, Offload: {offload_memory/(1024*1024*1024)} GB") + # logging.info(f"Set memory_to_free to {memory_to_free/(1024*1024*1024)} GB") try: if memory_to_free is not None: if memory_to_free < self.model.loaded_size(): From c312733b8cd28010e3370716f1311d3b30067b13 Mon Sep 17 00:00:00 2001 From: strint Date: Thu, 23 Oct 2025 15:53:35 +0800 Subject: [PATCH 27/35] refine log --- comfy/model_management.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/comfy/model_management.py b/comfy/model_management.py index 8bf4e68fb..f4ed13899 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -553,6 +553,10 @@ class LoadedModel: available_memory = get_free_memory(self.model.offload_device) logging.info(f"after error, available_memory of offload device {self.model.offload_device}: {available_memory/(1024*1024*1024)} GB") return False + finally: + available_memory = get_free_memory(self.model.offload_device) + logging.debug(f"after unload, available_memory of offload device {self.model.offload_device}: {available_memory/(1024*1024*1024)} GB") + self.model_finalizer.detach() self.model_finalizer = None self.real_model = None From dc7c77e78cb219f149c448cb961ae5122be7ce6b Mon Sep 17 00:00:00 2001 From: strint Date: Thu, 23 Oct 2025 18:09:47 +0800 Subject: [PATCH 28/35] better partial unload --- comfy/model_management.py | 64 +++++++++++++++++++++++++-------------- comfy/model_patcher.py | 7 +++-- 2 files changed, 46 insertions(+), 25 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index f4ed13899..f2e23c446 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -26,6 +26,14 @@ import importlib import platform import weakref import gc +import os + +def get_mmap_mem_threshold_gb(): + mmap_mem_threshold_gb = int(os.environ.get("MMAP_MEM_THRESHOLD_GB", "0")) + return mmap_mem_threshold_gb + +def get_free_disk(): + return psutil.disk_usage("/").free class VRAMState(Enum): DISABLED = 0 #No vram present: no need to move models to vram @@ -524,9 +532,7 @@ class LoadedModel: logging.debug(f"unpatch_weights: {unpatch_weights}") logging.debug(f"loaded_size: {self.model.loaded_size()/(1024*1024*1024)} GB") logging.debug(f"offload_device: {self.model.offload_device}") - available_memory = get_free_memory(self.model.offload_device) - logging.debug(f"before unload, available_memory of offload device {self.model.offload_device}: {available_memory/(1024*1024*1024)} GB") - # reserved_memory = 1024*1024*1024 # 1GB reserved memory for other usage + # if available_memory < reserved_memory: # logging.warning(f"Not enough cpu memory to unload. Available: {available_memory/(1024*1024*1024)} GB, Reserved: {reserved_memory/(1024*1024*1024)} GB") # return False @@ -537,30 +543,42 @@ class LoadedModel: # memory_to_free = offload_memory # logging.info(f"Not enough cpu memory to unload. Available: {available_memory/(1024*1024*1024)} GB, Reserved: {reserved_memory/(1024*1024*1024)} GB, Offload: {offload_memory/(1024*1024*1024)} GB") # logging.info(f"Set memory_to_free to {memory_to_free/(1024*1024*1024)} GB") - try: - if memory_to_free is not None: - if memory_to_free < self.model.loaded_size(): - logging.debug("Do partially unload") - freed = self.model.partially_unload(self.model.offload_device, memory_to_free) - logging.debug(f"partially_unload freed vram: {freed/(1024*1024*1024)} GB") - if freed >= memory_to_free: - return False + + if memory_to_free is None: + # free the full model + memory_to_free = self.model.loaded_size() + + available_memory = get_free_memory(self.model.offload_device) + logging.debug(f"before unload, available_memory of offload device {self.model.offload_device}: {available_memory/(1024*1024*1024)} GB") + + mmap_mem_threshold = get_mmap_mem_threshold_gb() * 1024 * 1024 * 1024 # this is reserved memory for other system usage + if memory_to_free > available_memory - mmap_mem_threshold or memory_to_free < self.model.loaded_size(): + partially_unload = True + else: + partially_unload = False + + if partially_unload: + logging.debug("Do partially unload") + freed = self.model.partially_unload(self.model.offload_device, memory_to_free) + logging.debug(f"partially_unload freed vram: {freed/(1024*1024*1024)} GB") + if freed < memory_to_free: + logging.warning(f"Partially unload not enough memory, freed {freed/(1024*1024*1024)} GB, memory_to_free {memory_to_free/(1024*1024*1024)} GB") + else: logging.debug("Do full unload") self.model.detach(unpatch_weights) logging.debug("Do full unload done") - except Exception as e: - logging.error(f"Error in model_unload: {e}") - available_memory = get_free_memory(self.model.offload_device) - logging.info(f"after error, available_memory of offload device {self.model.offload_device}: {available_memory/(1024*1024*1024)} GB") - return False - finally: - available_memory = get_free_memory(self.model.offload_device) - logging.debug(f"after unload, available_memory of offload device {self.model.offload_device}: {available_memory/(1024*1024*1024)} GB") + self.model_finalizer.detach() + self.model_finalizer = None + self.real_model = None + + available_memory = get_free_memory(self.model.offload_device) + logging.debug(f"after unload, available_memory of offload device {self.model.offload_device}: {available_memory/(1024*1024*1024)} GB") + + if partially_unload: + return False + else: + return True - self.model_finalizer.detach() - self.model_finalizer = None - self.real_model = None - return True def model_use_more_vram(self, extra_memory, force_patch_weights=False): return self.model.partially_load(self.device, extra_memory, force_patch_weights=force_patch_weights) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 115e401b3..361f15e5b 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -40,11 +40,11 @@ import comfy.patcher_extension import comfy.utils from comfy.comfy_types import UnetWrapperFunction from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP -from comfy.model_management import get_free_memory +from comfy.model_management import get_free_memory, get_mmap_mem_threshold_gb, get_free_disk def need_mmap() -> bool: free_cpu_mem = get_free_memory(torch.device("cpu")) - mmap_mem_threshold_gb = int(os.environ.get("MMAP_MEM_THRESHOLD_GB", "0")) + mmap_mem_threshold_gb = get_mmap_mem_threshold_gb() if free_cpu_mem < mmap_mem_threshold_gb * 1024 * 1024 * 1024: logging.debug(f"Enabling mmap, current free cpu memory {free_cpu_mem/(1024*1024*1024)} GB < {mmap_mem_threshold_gb} GB") return True @@ -972,6 +972,9 @@ class ModelPatcher: if move_weight: cast_weight = self.force_cast_weights if need_mmap(): + if get_free_disk() < module_mem: + logging.warning(f"Not enough disk space to offload {n} to mmap, current free disk space {get_free_disk()/(1024*1024*1024)} GB < {module_mem/(1024*1024*1024)} GB") + break # offload to mmap model_to_mmap(m) else: From 5c5fbddbbe71c986e43214002069d0edd1260445 Mon Sep 17 00:00:00 2001 From: strint Date: Mon, 17 Nov 2025 15:34:50 +0800 Subject: [PATCH 29/35] debug mmap --- comfy/model_management.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/comfy/model_management.py b/comfy/model_management.py index f2e23c446..a2ad5db2a 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -28,8 +28,12 @@ import weakref import gc import os +from functools import lru_cache + +@lru_cache(maxsize=1) def get_mmap_mem_threshold_gb(): mmap_mem_threshold_gb = int(os.environ.get("MMAP_MEM_THRESHOLD_GB", "0")) + logging.debug(f"MMAP_MEM_THRESHOLD_GB: {mmap_mem_threshold_gb}") return mmap_mem_threshold_gb def get_free_disk(): From 7733d51c7670d467c5ce10fd3b40567857c74641 Mon Sep 17 00:00:00 2001 From: Xiaoyu Xu Date: Thu, 4 Dec 2025 15:45:36 +0800 Subject: [PATCH 30/35] try fix flux2 (#9) --- comfy/model_patcher.py | 50 +++++++----------------------------------- 1 file changed, 8 insertions(+), 42 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 3bba7b35b..a09c9f2c0 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -41,6 +41,7 @@ import comfy.utils from comfy.comfy_types import UnetWrapperFunction from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP from comfy.model_management import get_free_memory, get_mmap_mem_threshold_gb, get_free_disk +from comfy.quant_ops import QuantizedTensor def need_mmap() -> bool: free_cpu_mem = get_free_memory(torch.device("cpu")) @@ -54,12 +55,6 @@ def to_mmap(t: torch.Tensor, filename: Optional[str] = None) -> torch.Tensor: """ Convert a tensor to a memory-mapped CPU tensor using PyTorch's native mmap support. """ - # Move to CPU if needed - if t.is_cuda: - cpu_tensor = t.cpu() - else: - cpu_tensor = t - # Create temporary file if filename is None: temp_file = tempfile.mktemp(suffix='.pt', prefix='comfy_mmap_') @@ -67,6 +62,7 @@ def to_mmap(t: torch.Tensor, filename: Optional[str] = None) -> torch.Tensor: temp_file = filename # Save tensor to file + cpu_tensor = t.cpu() torch.save(cpu_tensor, temp_file) # If we created a CPU copy from CUDA, delete it to free memory @@ -89,37 +85,7 @@ def to_mmap(t: torch.Tensor, filename: Optional[str] = None) -> torch.Tensor: pass weakref.finalize(mmap_tensor, _cleanup) - - # # Save original 'to' method - # original_to = mmap_tensor.to - - # # Create custom 'to' method that cleans up file when moving to CUDA - # def custom_to(*args, **kwargs): - # # Determine target device - # target_device = None - # if len(args) > 0: - # if isinstance(args[0], torch.device): - # target_device = args[0] - # elif isinstance(args[0], str): - # target_device = torch.device(args[0]) - # if 'device' in kwargs: - # target_device = kwargs['device'] - # if isinstance(target_device, str): - # target_device = torch.device(target_device) - # - # # Call original 'to' method first to move data - # result = original_to(*args, **kwargs) - # - # # NOTE: Cleanup disabled to avoid blocking model load performance - # # If moved to CUDA, cleanup the mmap file after the move - # if target_device is not None and target_device.type == 'cuda': - # _cleanup() - # - # return result - - # # Replace the 'to' method - # mmap_tensor.to = custom_to - + return mmap_tensor def model_to_mmap(model: torch.nn.Module): @@ -149,13 +115,13 @@ def model_to_mmap(model: torch.nn.Module): - For Parameters: modify .data and return the Parameter object - For buffers (plain Tensors): return new MemoryMappedTensor """ - if isinstance(t, torch.nn.Parameter): - # For parameters, modify data in-place and return the parameter - if isinstance(t.data, torch.Tensor): - t.data = to_mmap(t.data) + if isinstance(t, QuantizedTensor): + logging.debug(f"QuantizedTensor detected, skipping mmap conversion, tensor meta info: size {t.size()}, dtype {t.dtype}, device {t.device}, is_contiguous {t.is_contiguous()}") return t + elif isinstance(t, torch.nn.Parameter): + new_tensor = to_mmap(t.detach()) + return torch.nn.Parameter(new_tensor, requires_grad=t.requires_grad) elif isinstance(t, torch.Tensor): - # For buffers (plain tensors), return the converted tensor return to_mmap(t) return t From 1122cd0f6bc4fadcdbd0ac22553f85112906cb92 Mon Sep 17 00:00:00 2001 From: Xiaoyu Xu Date: Tue, 9 Dec 2025 18:07:09 +0800 Subject: [PATCH 31/35] allow offload quant (#10) * allow offload quant * rm cuda * refine and pass test --- comfy/model_patcher.py | 18 +++----- comfy/quant_ops.py | 45 ++++++++++++++++++- tests-unit/comfy_quant/test_quant_registry.py | 23 ++++++++++ .../test_model_mmap.py | 5 +++ 4 files changed, 77 insertions(+), 14 deletions(-) rename tests/{execution => inference}/test_model_mmap.py (98%) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index abbcbd9f8..da047ae8b 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -57,7 +57,7 @@ def to_mmap(t: torch.Tensor, filename: Optional[str] = None) -> torch.Tensor: """ # Create temporary file if filename is None: - temp_file = tempfile.mktemp(suffix='.pt', prefix='comfy_mmap_') + temp_file = tempfile.mkstemp(suffix='.pt', prefix='comfy_mmap_')[1] else: temp_file = filename @@ -65,12 +65,10 @@ def to_mmap(t: torch.Tensor, filename: Optional[str] = None) -> torch.Tensor: cpu_tensor = t.cpu() torch.save(cpu_tensor, temp_file) - # If we created a CPU copy from CUDA, delete it to free memory - if t.is_cuda: + # If we created a CPU copy from other device, delete it to free memory + if not t.device.type == 'cpu': del cpu_tensor gc.collect() - if torch.cuda.is_available(): - torch.cuda.empty_cache() # Load with mmap - this doesn't load all data into RAM mmap_tensor = torch.load(temp_file, map_location='cpu', mmap=True, weights_only=False) @@ -110,15 +108,9 @@ def model_to_mmap(model: torch.nn.Module): logging.debug(f"Converting model {model.__class__.__name__} to mmap, current free cpu memory: {free_cpu_mem/(1024*1024*1024)} GB") def convert_fn(t): - """Convert function for _apply() - - - For Parameters: modify .data and return the Parameter object - - For buffers (plain Tensors): return new MemoryMappedTensor - """ if isinstance(t, QuantizedTensor): - logging.debug(f"QuantizedTensor detected, skipping mmap conversion, tensor meta info: size {t.size()}, dtype {t.dtype}, device {t.device}, is_contiguous {t.is_contiguous()}") - return t - elif isinstance(t, torch.nn.Parameter): + logging.debug(f"QuantizedTensor detected, tensor meta info: size {t.size()}, dtype {t.dtype}, device {t.device}, is_contiguous {t.is_contiguous()}") + if isinstance(t, torch.nn.Parameter): new_tensor = to_mmap(t.detach()) return torch.nn.Parameter(new_tensor, requires_grad=t.requires_grad) elif isinstance(t, torch.Tensor): diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index 571d3f760..2f568967b 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -130,7 +130,19 @@ class QuantizedTensor(torch.Tensor): layout_type: Layout class (subclass of QuantizedLayout) layout_params: Dict with layout-specific parameters """ - return torch.Tensor._make_wrapper_subclass(cls, qdata.shape, device=qdata.device, dtype=qdata.dtype, requires_grad=False) + # Use as_subclass so the QuantizedTensor instance shares the same + # storage and metadata as the underlying qdata tensor. This ensures + # torch.save/torch.load and the torch serialization storage scanning + # see a valid underlying storage (fixes data_ptr errors). + if not isinstance(qdata, torch.Tensor): + raise TypeError("qdata must be a torch.Tensor") + obj = qdata.as_subclass(cls) + # Ensure grad flag is consistent for quantized tensors + try: + obj.requires_grad_(False) + except Exception: + pass + return obj def __init__(self, qdata, layout_type, layout_params): self._qdata = qdata @@ -575,3 +587,34 @@ def fp8_func(func, args, kwargs): ar[0] = plain_input return QuantizedTensor(func(*ar, **kwargs), "TensorCoreFP8Layout", input_tensor._layout_params) return func(*args, **kwargs) + +def _rebuild_quantized_tensor(qdata, layout_type, layout_params): + """Rebuild QuantizedTensor during unpickling when qdata is already a tensor.""" + return QuantizedTensor(qdata, layout_type, layout_params) + + +def _rebuild_quantized_tensor_from_base(qdata_reduce, layout_type, layout_params): + """Rebuild QuantizedTensor during unpickling given the base tensor's reduce tuple. + + qdata_reduce is the tuple returned by qdata.__reduce_ex__(protocol) on the original + inner tensor. We call the provided rebuild function with its args to recreate the + inner tensor, then wrap it in QuantizedTensor. + """ + rebuild_fn, rebuild_args = qdata_reduce + qdata = rebuild_fn(*rebuild_args) + return QuantizedTensor(qdata, layout_type, layout_params) + + +# Register custom globals with torch.serialization so torch.load(..., weights_only=True) +# accepts these during unpickling. Wrapped in try/except for older PyTorch versions. +try: + import torch as _torch_serial + if hasattr(_torch_serial, "serialization") and hasattr(_torch_serial.serialization, "add_safe_globals"): + _torch_serial.serialization.add_safe_globals([ + QuantizedTensor, + _rebuild_quantized_tensor, + _rebuild_quantized_tensor_from_base, + ]) +except Exception: + # If add_safe_globals doesn't exist or registration fails, we silently continue. + pass diff --git a/tests-unit/comfy_quant/test_quant_registry.py b/tests-unit/comfy_quant/test_quant_registry.py index 9cb54ede8..51d27dd26 100644 --- a/tests-unit/comfy_quant/test_quant_registry.py +++ b/tests-unit/comfy_quant/test_quant_registry.py @@ -47,6 +47,29 @@ class TestQuantizedTensor(unittest.TestCase): self.assertEqual(dequantized.dtype, torch.float32) self.assertTrue(torch.allclose(dequantized, torch.ones(10, 20) * 3.0, rtol=0.1)) + def test_save_load(self): + """Test creating a QuantizedTensor with TensorCoreFP8Layout""" + fp8_data = torch.randn(256, 128, dtype=torch.float32).to(torch.float8_e4m3fn) + scale = torch.tensor(2.0) + layout_params = {'scale': scale, 'orig_dtype': torch.bfloat16} + + qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params) + + self.assertIsInstance(qt, QuantizedTensor) + self.assertEqual(qt.shape, (256, 128)) + self.assertEqual(qt.dtype, torch.float8_e4m3fn) + self.assertEqual(qt._layout_params['scale'], scale) + self.assertEqual(qt._layout_params['orig_dtype'], torch.bfloat16) + self.assertEqual(qt._layout_type, "TensorCoreFP8Layout") + + torch.save(qt, "test.pt") + loaded_qt = torch.load("test.pt", weights_only=False) + # loaded_qt = torch.load("test.pt", map_location='cpu', mmap=True, weights_only=False) + + self.assertEqual(loaded_qt._layout_type, "TensorCoreFP8Layout") + self.assertEqual(loaded_qt._layout_params['scale'], scale) + self.assertEqual(loaded_qt._layout_params['orig_dtype'], torch.bfloat16) + def test_from_float(self): """Test creating QuantizedTensor from float tensor""" float_tensor = torch.randn(64, 32, dtype=torch.float32) diff --git a/tests/execution/test_model_mmap.py b/tests/inference/test_model_mmap.py similarity index 98% rename from tests/execution/test_model_mmap.py rename to tests/inference/test_model_mmap.py index 7a608c931..a7bff3bfc 100644 --- a/tests/execution/test_model_mmap.py +++ b/tests/inference/test_model_mmap.py @@ -5,6 +5,11 @@ import psutil import os import gc import tempfile +import sys + +# Ensure the project root is on the Python path (so `import comfy` works when running tests from this folder) +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..'))) + from comfy.model_patcher import model_to_mmap, to_mmap From 532eb01f0a041802aed9c2d757ca2be6e4604856 Mon Sep 17 00:00:00 2001 From: strint Date: Tue, 9 Dec 2025 18:09:11 +0800 Subject: [PATCH 32/35] rm comment --- comfy/model_management.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index fe4da2751..d598a854c 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -539,17 +539,6 @@ class LoadedModel: logging.debug(f"loaded_size: {self.model.loaded_size()/(1024*1024*1024)} GB") logging.debug(f"offload_device: {self.model.offload_device}") - # if available_memory < reserved_memory: - # logging.warning(f"Not enough cpu memory to unload. Available: {available_memory/(1024*1024*1024)} GB, Reserved: {reserved_memory/(1024*1024*1024)} GB") - # return False - # else: - # offload_memory = available_memory - reserved_memory - # - # if offload_memory < memory_to_free: - # memory_to_free = offload_memory - # logging.info(f"Not enough cpu memory to unload. Available: {available_memory/(1024*1024*1024)} GB, Reserved: {reserved_memory/(1024*1024*1024)} GB, Offload: {offload_memory/(1024*1024*1024)} GB") - # logging.info(f"Set memory_to_free to {memory_to_free/(1024*1024*1024)} GB") - if memory_to_free is None: # free the full model memory_to_free = self.model.loaded_size() From 2c5b9da6c47f0eab24f59cf4bed6de02ee797f75 Mon Sep 17 00:00:00 2001 From: strint Date: Fri, 12 Dec 2025 17:50:35 +0800 Subject: [PATCH 33/35] rm debug log --- comfy/model_base.py | 10 +--------- comfy/model_management.py | 19 +------------------ comfy/sd.py | 1 - comfy/utils.py | 3 --- 4 files changed, 2 insertions(+), 31 deletions(-) diff --git a/comfy/model_base.py b/comfy/model_base.py index c3c32810b..6b8a8454d 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -61,7 +61,6 @@ import math from typing import TYPE_CHECKING if TYPE_CHECKING: from comfy.model_patcher import ModelPatcher -from comfy.model_management import get_free_memory class ModelType(Enum): EPS = 1 @@ -305,15 +304,8 @@ class BaseModel(torch.nn.Module): if k.startswith(unet_prefix): to_load[k[len(unet_prefix):]] = sd.pop(k) - free_cpu_memory = get_free_memory(torch.device("cpu")) - logging.debug(f"load model weights start, free cpu memory size {free_cpu_memory/(1024*1024*1024)} GB") - logging.debug(f"model destination device {next(self.diffusion_model.parameters()).device}") to_load = self.model_config.process_unet_state_dict(to_load) - logging.debug(f"load model {self.model_config} weights process end") - # replace tensor with mmap tensor by assign - m, u = self.diffusion_model.load_state_dict(to_load, strict=False, assign=True) - free_cpu_memory = get_free_memory(torch.device("cpu")) - logging.debug(f"load model {self.model_config} weights end, free cpu memory size {free_cpu_memory/(1024*1024*1024)} GB") + m, u = self.diffusion_model.load_state_dict(to_load, strict=False) if len(m) > 0: logging.warning("unet missing: {}".format(m)) diff --git a/comfy/model_management.py b/comfy/model_management.py index d598a854c..5105111c6 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -533,18 +533,11 @@ class LoadedModel: return False def model_unload(self, memory_to_free=None, unpatch_weights=True): - logging.debug(f"model_unload: {self.model.model.__class__.__name__}") - logging.debug(f"memory_to_free: {memory_to_free/(1024*1024*1024)} GB") - logging.debug(f"unpatch_weights: {unpatch_weights}") - logging.debug(f"loaded_size: {self.model.loaded_size()/(1024*1024*1024)} GB") - logging.debug(f"offload_device: {self.model.offload_device}") - if memory_to_free is None: # free the full model memory_to_free = self.model.loaded_size() available_memory = get_free_memory(self.model.offload_device) - logging.debug(f"before unload, available_memory of offload device {self.model.offload_device}: {available_memory/(1024*1024*1024)} GB") mmap_mem_threshold = get_mmap_mem_threshold_gb() * 1024 * 1024 * 1024 # this is reserved memory for other system usage if memory_to_free > available_memory - mmap_mem_threshold or memory_to_free < self.model.loaded_size(): @@ -553,22 +546,15 @@ class LoadedModel: partially_unload = False if partially_unload: - logging.debug("Do partially unload") freed = self.model.partially_unload(self.model.offload_device, memory_to_free) - logging.debug(f"partially_unload freed vram: {freed/(1024*1024*1024)} GB") if freed < memory_to_free: - logging.warning(f"Partially unload not enough memory, freed {freed/(1024*1024*1024)} GB, memory_to_free {memory_to_free/(1024*1024*1024)} GB") + logging.debug(f"Partially unload not enough memory, freed {freed/(1024*1024*1024)} GB, memory_to_free {memory_to_free/(1024*1024*1024)} GB") else: - logging.debug("Do full unload") self.model.detach(unpatch_weights) - logging.debug("Do full unload done") self.model_finalizer.detach() self.model_finalizer = None self.real_model = None - available_memory = get_free_memory(self.model.offload_device) - logging.debug(f"after unload, available_memory of offload device {self.model.offload_device}: {available_memory/(1024*1024*1024)} GB") - if partially_unload: return False else: @@ -622,7 +608,6 @@ def minimum_inference_memory(): return (1024 * 1024 * 1024) * 0.8 + extra_reserved_memory() def free_memory(memory_required, device, keep_loaded=[]): - logging.debug("start to free mem") cleanup_models_gc() unloaded_model = [] can_unload = [] @@ -660,7 +645,6 @@ def free_memory(memory_required, device, keep_loaded=[]): return unloaded_models def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False): - logging.debug(f"start to load models") cleanup_models_gc() global vram_state @@ -682,7 +666,6 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu models_to_load = [] for x in models: - logging.debug(f"start loading model to vram: {x.model.__class__.__name__}") loaded_model = LoadedModel(x) try: loaded_model_index = current_loaded_models.index(loaded_model) diff --git a/comfy/sd.py b/comfy/sd.py index 7c00337a6..754b1703d 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -1466,7 +1466,6 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None): logging.warning("{} {}".format(diffusers_keys[k], k)) offload_device = model_management.unet_offload_device() - logging.debug(f"loader load model to offload device: {offload_device}") unet_weight_dtype = list(model_config.supported_inference_dtypes) if model_config.quant_config is not None: weight_dtype = None diff --git a/comfy/utils.py b/comfy/utils.py index 8dc33a411..89846bc95 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -61,8 +61,6 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False): metadata = None if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"): try: - if not DISABLE_MMAP: - logging.debug(f"load_torch_file of safetensors into mmap True") with safetensors.safe_open(ckpt, framework="pt", device=device.type) as f: sd = {} for k in f.keys(): @@ -83,7 +81,6 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False): else: torch_args = {} if MMAP_TORCH_FILES: - logging.debug(f"load_torch_file of torch state dict into mmap True") torch_args["mmap"] = True if safe_load or ALWAYS_SAFE_LOAD: From 5495b55ab2f301e89703acbf5f81eb177d8fe114 Mon Sep 17 00:00:00 2001 From: strint Date: Fri, 12 Dec 2025 18:03:09 +0800 Subject: [PATCH 34/35] rm useless --- comfy/model_patcher.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index da047ae8b..d3c69f614 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -89,15 +89,6 @@ def to_mmap(t: torch.Tensor, filename: Optional[str] = None) -> torch.Tensor: def model_to_mmap(model: torch.nn.Module): """Convert all parameters and buffers to memory-mapped tensors - This function mimics PyTorch's Module.to() behavior but converts - tensors to memory-mapped format instead, using _apply() method. - - Reference: https://github.com/pytorch/pytorch/blob/0fabc3ba44823f257e70ce397d989c8de5e362c1/torch/nn/modules/module.py#L1244 - - Note: For Parameters, we modify .data in-place because - MemoryMappedTensor cannot be wrapped in torch.nn.Parameter. - For buffers, _apply() will automatically update the reference. - Args: model: PyTorch module to convert @@ -108,8 +99,6 @@ def model_to_mmap(model: torch.nn.Module): logging.debug(f"Converting model {model.__class__.__name__} to mmap, current free cpu memory: {free_cpu_mem/(1024*1024*1024)} GB") def convert_fn(t): - if isinstance(t, QuantizedTensor): - logging.debug(f"QuantizedTensor detected, tensor meta info: size {t.size()}, dtype {t.dtype}, device {t.device}, is_contiguous {t.is_contiguous()}") if isinstance(t, torch.nn.Parameter): new_tensor = to_mmap(t.detach()) return torch.nn.Parameter(new_tensor, requires_grad=t.requires_grad) From fa674cc60df330960edfa8d0d417482204710efb Mon Sep 17 00:00:00 2001 From: strint Date: Mon, 15 Dec 2025 18:47:35 +0800 Subject: [PATCH 35/35] refine --- comfy/cli_args.py | 1 + comfy/model_management.py | 20 +++++++++++++++----- comfy/model_patcher.py | 18 +++++++++++------- 3 files changed, 27 insertions(+), 12 deletions(-) diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 209fc185b..f83afa258 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -136,6 +136,7 @@ vram_group.add_argument("--novram", action="store_true", help="When lowvram isn' vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for everything (slow).") parser.add_argument("--reserve-vram", type=float, default=None, help="Set the amount of vram in GB you want to reserve for use by your OS/other software. By default some amount is reserved depending on your OS.") +parser.add_argument("--offload-reserve-ram-gb", type=float, default=None, help="Set the amount of ram in GB you want to reserve for other use. When the limit is reached, model on vram will be offloaded to mmap to save ram.") parser.add_argument("--async-offload", nargs='?', const=2, type=int, default=None, metavar="NUM_STREAMS", help="Use async weight offloading. An optional argument controls the amount of offload streams. Default is 2. Enabled by default on Nvidia.") parser.add_argument("--disable-async-offload", action="store_true", help="Disable async weight offloading.") diff --git a/comfy/model_management.py b/comfy/model_management.py index 5105111c6..c1ebb1282 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -31,10 +31,20 @@ import os from functools import lru_cache @lru_cache(maxsize=1) -def get_mmap_mem_threshold_gb(): - mmap_mem_threshold_gb = int(os.environ.get("MMAP_MEM_THRESHOLD_GB", "0")) - logging.debug(f"MMAP_MEM_THRESHOLD_GB: {mmap_mem_threshold_gb}") - return mmap_mem_threshold_gb +def get_offload_reserve_ram_gb(): + offload_reserve_ram_gb = 0 + try: + val = getattr(args, 'offload-reserve-ram-gb', None) + except Exception: + val = None + + if val is not None: + try: + offload_reserve_ram_gb = int(val) + except Exception: + logging.warning(f"Invalid args.offload-reserve-ram-gb value: {val}, defaulting to 0") + offload_reserve_ram_gb= 0 + return offload_reserve_ram_gb def get_free_disk(): return psutil.disk_usage("/").free @@ -613,7 +623,7 @@ def free_memory(memory_required, device, keep_loaded=[]): can_unload = [] unloaded_models = [] - for i in range(len(current_loaded_models) -1, -1, -1): + for i in range(len(current_loaded_models) -1, -1): shift_model = current_loaded_models[i] if shift_model.device == device: if shift_model not in keep_loaded and not shift_model.is_dead(): diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index d3c69f614..5d6330321 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -40,15 +40,19 @@ import comfy.patcher_extension import comfy.utils from comfy.comfy_types import UnetWrapperFunction from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP -from comfy.model_management import get_free_memory, get_mmap_mem_threshold_gb, get_free_disk +from comfy.model_management import get_free_memory, get_offload_reserve_ram_gb, get_free_disk from comfy.quant_ops import QuantizedTensor -def need_mmap() -> bool: +def enable_offload_to_mmap() -> bool: + if comfy.utils.DISABLE_MMAP: + return False + free_cpu_mem = get_free_memory(torch.device("cpu")) - mmap_mem_threshold_gb = get_mmap_mem_threshold_gb() - if free_cpu_mem < mmap_mem_threshold_gb * 1024 * 1024 * 1024: - logging.debug(f"Enabling mmap, current free cpu memory {free_cpu_mem/(1024*1024*1024)} GB < {mmap_mem_threshold_gb} GB") + offload_reserve_ram_gb = get_offload_reserve_ram_gb() + if free_cpu_mem <= offload_reserve_ram_gb * 1024 * 1024 * 1024: + logging.debug(f"Enabling offload to mmap, current free cpu memory {free_cpu_mem/(1024*1024*1024)} GB < {offload_reserve_ram_gb} GB") return True + return False def to_mmap(t: torch.Tensor, filename: Optional[str] = None) -> torch.Tensor: @@ -917,7 +921,7 @@ class ModelPatcher: if device_to is not None: - if need_mmap(): + if enable_offload_to_mmap(): # offload to mmap model_to_mmap(self.model) else: @@ -982,7 +986,7 @@ class ModelPatcher: bias_key = "{}.bias".format(n) if move_weight: cast_weight = self.force_cast_weights - if need_mmap(): + if enable_offload_to_mmap(): if get_free_disk() < module_mem: logging.warning(f"Not enough disk space to offload {n} to mmap, current free disk space {get_free_disk()/(1024*1024*1024)} GB < {module_mem/(1024*1024*1024)} GB") break