From 3c9d311dee8570b915a380c8828958d675158b38 Mon Sep 17 00:00:00 2001
From: doctorpangloss <@hiddenswitch.com>
Date: Wed, 30 Jul 2025 19:27:40 -0700
Subject: [PATCH] Improvements to torch.compile

- the model weights generally have to be patched ahead of time for
  compilation to work
- the model downloader matches the folder_paths API a bit better
- tweak the logging from the execution node
---
 comfy/cmd/execution.py                   |  1 -
 comfy/model_downloader.py                |  4 ++++
 comfy/model_patcher.py                   |  4 ++--
 comfy_api/torch_helpers/torch_compile.py | 27 ++++++++++++++++--------
 4 files changed, 24 insertions(+), 12 deletions(-)

diff --git a/comfy/cmd/execution.py b/comfy/cmd/execution.py
index 0f2caa428..e6f833432 100644
--- a/comfy/cmd/execution.py
+++ b/comfy/cmd/execution.py
@@ -1145,7 +1145,6 @@ async def _validate_prompt(prompt_id: typing.Any, prompt: typing.Mapping[str, ty
                         for reason in reasons:
                             logger.error(f"  - {reason['message']}: {reason['details']}")
                     node_errors[node_id]["dependent_outputs"].append(o)
-            logger.error("Output will be ignored")
 
     if len(good_outputs) == 0:
         errors_list = []
diff --git a/comfy/model_downloader.py b/comfy/model_downloader.py
index 53b38141e..b43d75ab2 100644
--- a/comfy/model_downloader.py
+++ b/comfy/model_downloader.py
@@ -58,6 +58,10 @@ def get_full_path_or_raise(folder_name: str, filename: str) -> str:
     return res
 
 
+def get_full_path(folder_name: str, filename: str) -> Optional[str]:
+    return get_or_download(folder_name, filename)
+
+
 def get_or_download(folder_name: str, filename: str, known_files: Optional[List[Downloadable] | KnownDownloadables] = None) -> Optional[str]:
     if known_files is None:
         known_files = _get_known_models_for_folder_name(folder_name)
diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index 5dc31c8e4..0075c2a35 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -766,10 +766,10 @@ class ModelPatcher(ModelManageable):
                     x.module.to(device_to)
 
         if lowvram_counter > 0:
-            logger.debug("loaded partially lowvram_model_memory={}MB mem_counter={}MB patch_counter={}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter))
+            logger.debug(f"loaded partially lowvram_model_memory={lowvram_model_memory / (1024 * 1024):.1f}MB mem_counter={mem_counter / (1024 * 1024):.1f}MB patch_counter={patch_counter}")
             self._memory_measurements.model_lowvram = True
         else:
-            logger.debug("loaded completely lowvram_model_memory={}MB mem_counter={}MB full_load={}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
+            logger.debug(f"loaded completely lowvram_model_memory={lowvram_model_memory / (1024 * 1024):.1f}MB mem_counter={mem_counter / (1024 * 1024):.1f}MB full_load={full_load}")
             self._memory_measurements.model_lowvram = False
             if full_load:
                 self.model.to(device_to)
diff --git a/comfy_api/torch_helpers/torch_compile.py b/comfy_api/torch_helpers/torch_compile.py
index 9223f58db..cabe2bd35 100644
--- a/comfy_api/torch_helpers/torch_compile.py
+++ b/comfy_api/torch_helpers/torch_compile.py
@@ -1,38 +1,46 @@
 from __future__ import annotations
+
+from typing import TYPE_CHECKING, Callable, Optional
+
 import torch
 
 import comfy.utils
+from comfy import model_management
 from comfy.patcher_extension import WrappersMP
-from typing import TYPE_CHECKING, Callable, Optional
+
 if TYPE_CHECKING:
     from comfy.model_patcher import ModelPatcher
     from comfy.patcher_extension import WrapperExecutor
-
 
 COMPILE_KEY = "torch.compile"
 TORCH_COMPILE_KWARGS = "torch_compile_kwargs"
 
 
-def apply_torch_compile_factory(compiled_module_dict: dict[str, Callable]) -> Callable:
+def apply_torch_compile_factory(compiled_module_dict: dict[str, Callable], model_patcher: Optional[ModelPatcher] = None) -> Callable:
     '''
     Create a wrapper that will refer to the compiled_diffusion_model.
     '''
+
     def apply_torch_compile_wrapper(executor: WrapperExecutor, *args, **kwargs):
         try:
             orig_modules = {}
             for key, value in compiled_module_dict.items():
                 orig_modules[key] = comfy.utils.get_attr(executor.class_obj, key)
                 comfy.utils.set_attr(executor.class_obj, key, value)
+            # todo: compilation has to patch all weights
+            if model_patcher is not None:
+                model_patcher.patch_model(device_to=model_management.get_torch_device(), force_patch_weights=True)
             return executor(*args, **kwargs)
         finally:
             for key, value in orig_modules.items():
                 comfy.utils.set_attr(executor.class_obj, key, value)
+
     return apply_torch_compile_wrapper
 
 
-def set_torch_compile_wrapper(model: ModelPatcher, backend: str, options: Optional[dict[str,str]]=None,
-                              mode: Optional[str]=None, fullgraph=False, dynamic: Optional[bool]=None,
-                              keys: list[str]=["diffusion_model"], *args, **kwargs):
+def set_torch_compile_wrapper(model: ModelPatcher, backend: str, options: Optional[dict[str, str]] = None,
+                              mode: Optional[str] = None, fullgraph=False, dynamic: Optional[bool] = None,
+                              keys: list[str] = ["diffusion_model"], *args, **kwargs):
     '''
     Perform torch.compile that will be applied at sample time for either the whole model or specific params of
     the BaseModel instance.
@@ -56,12 +64,13 @@ def set_torch_compile_wrapper(model: ModelPatcher, backend: str, options: Option
     compiled_modules = {}
     for key in keys:
         compiled_modules[key] = torch.compile(
-            model=model.get_model_object(key),
-            **compile_kwargs,
-        )
+                model=model.get_model_object(key),
+                **compile_kwargs,
+            )
     # add torch.compile wrapper
     wrapper_func = apply_torch_compile_factory(
         compiled_module_dict=compiled_modules,
+        model_patcher=model,
     )
     # store wrapper to run on BaseModel's apply_model function
     model.add_wrapper_with_key(WrappersMP.APPLY_MODEL, COMPILE_KEY, wrapper_func)
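
Usage sketch (illustrative, not part of the patch): with model_patcher threaded
through to the wrapper, callers pass the same ModelPatcher whose submodule is
being compiled, so its weight patches are applied eagerly with
force_patch_weights=True before the compiled graph first runs, rather than
lazily at sample time. Only set_torch_compile_wrapper's signature comes from
this patch; the helper name and argument values below are hypothetical.

    from comfy_api.torch_helpers.torch_compile import set_torch_compile_wrapper

    def compile_for_sampling(model_patcher):
        # model_patcher: a comfy.model_patcher.ModelPatcher wrapping a loaded
        # BaseModel (e.g. from a checkpoint loader node).
        set_torch_compile_wrapper(
            model=model_patcher,
            backend="inductor",        # any torch.compile backend
            mode="default",            # optional torch.compile mode
            fullgraph=False,
            dynamic=None,
            keys=["diffusion_model"],  # BaseModel attributes to compile
        )
        return model_patcher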