Improvements to torch.compile

- the model weights generally have to be patched ahead of time for compilation to work (see the sketch below)
- the model downloader matches the folder_paths API a bit better
- tweak the logging from the execution node
doctorpangloss 2025-07-30 19:27:40 -07:00
parent 83184916c1
commit 3c9d311dee
4 changed files with 24 additions and 12 deletions
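The first bullet is the substantive change: torch.compile traces a module as-is, so any LoRA or other weight patches a ModelPatcher is still holding must be applied before the compiled graph runs, otherwise the graph can bake in unpatched weights. A minimal sketch of the pattern, assuming `patcher` is a `comfy.model_patcher.ModelPatcher` and `compiled` is the torch.compile'd module; the helper name is hypothetical:

    import torch
    from comfy import model_management

    def run_compiled(patcher, compiled, *args, **kwargs):
        # force all pending weight patches in before the compiled graph executes,
        # mirroring the patch_model(..., force_patch_weights=True) call this commit adds
        patcher.patch_model(device_to=model_management.get_torch_device(),
                            force_patch_weights=True)
        with torch.no_grad():
            return compiled(*args, **kwargs)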

View File

@@ -1145,7 +1145,6 @@ async def _validate_prompt(prompt_id: typing.Any, prompt: typing.Mapping[str, ty
             for reason in reasons:
                 logger.error(f" - {reason['message']}: {reason['details']}")
             node_errors[node_id]["dependent_outputs"].append(o)
-            logger.error("Output will be ignored")
     if len(good_outputs) == 0:
         errors_list = []

View File

@@ -58,6 +58,10 @@ def get_full_path_or_raise(folder_name: str, filename: str) -> str:
     return res


+def get_full_path(folder_name: str, filename: str) -> Optional[str]:
+    return get_or_download(folder_name, filename)
+
+
 def get_or_download(folder_name: str, filename: str, known_files: Optional[List[Downloadable] | KnownDownloadables] = None) -> Optional[str]:
     if known_files is None:
         known_files = _get_known_models_for_folder_name(folder_name)

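With this change, get_full_path routes through get_or_download, so resolving a filename in a model folder can transparently fetch a known downloadable instead of returning None. A hedged usage sketch; the import path and filename are illustrative:

    import folder_paths  # import path varies by setup; assumed here

    # returns a local path, downloading first if the file is a known model;
    # still returns None when the name cannot be resolved at all
    path = folder_paths.get_full_path("checkpoints", "v1-5-pruned-emaonly.safetensors")
    if path is None:
        raise FileNotFoundError("not on disk and not a known downloadable")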
View File

@@ -766,10 +766,10 @@ class ModelPatcher(ModelManageable):
                     x.module.to(device_to)
         if lowvram_counter > 0:
-            logger.debug("loaded partially lowvram_model_memory={}MB mem_counter={}MB patch_counter={}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter))
+            logger.debug(f"loaded partially lowvram_model_memory={lowvram_model_memory / (1024 * 1024):.1f}MB mem_counter={mem_counter / (1024 * 1024):.1f}MB patch_counter={patch_counter}")
             self._memory_measurements.model_lowvram = True
         else:
-            logger.debug("loaded completely lowvram_model_memory={}MB mem_counter={}MB full_load={}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
+            logger.debug(f"loaded completely lowvram_model_memory={lowvram_model_memory / (1024 * 1024):.1f}MB mem_counter={mem_counter / (1024 * 1024):.1f}MB full_load={full_load}")
             self._memory_measurements.model_lowvram = False
         if full_load:
             self.model.to(device_to)

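The logging change here is cosmetic but worth noting: the old str.format calls printed the raw division result, while the new f-strings round memory figures to one decimal via :.1f. A quick illustration (byte count made up):

    mem_counter = 5_368_709_000  # hypothetical byte count
    # old style: full float repr
    print("mem_counter={}MB".format(mem_counter / (1024 * 1024)))  # mem_counter=5119.999885559082MB
    # new style: one decimal place
    print(f"mem_counter={mem_counter / (1024 * 1024):.1f}MB")      # mem_counter=5120.0MB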
View File

@@ -1,38 +1,46 @@
 from __future__ import annotations

+from typing import TYPE_CHECKING, Callable, Optional
+
 import torch

 import comfy.utils
+from comfy import model_management
 from comfy.patcher_extension import WrappersMP
-from typing import TYPE_CHECKING, Callable, Optional

 if TYPE_CHECKING:
     from comfy.model_patcher import ModelPatcher
     from comfy.patcher_extension import WrapperExecutor

 COMPILE_KEY = "torch.compile"
 TORCH_COMPILE_KWARGS = "torch_compile_kwargs"


-def apply_torch_compile_factory(compiled_module_dict: dict[str, Callable]) -> Callable:
+def apply_torch_compile_factory(compiled_module_dict: dict[str, Callable], model_patcher: Optional[ModelPatcher] = None) -> Callable:
     '''
     Create a wrapper that will refer to the compiled_diffusion_model.
     '''

     def apply_torch_compile_wrapper(executor: WrapperExecutor, *args, **kwargs):
         try:
             orig_modules = {}
             for key, value in compiled_module_dict.items():
                 orig_modules[key] = comfy.utils.get_attr(executor.class_obj, key)
                 comfy.utils.set_attr(executor.class_obj, key, value)
+            # todo: compilation has to patch all weights
+            if model_patcher is not None:
+                model_patcher.patch_model(device_to=model_management.get_torch_device(), force_patch_weights=True)
             return executor(*args, **kwargs)
         finally:
             for key, value in orig_modules.items():
                 comfy.utils.set_attr(executor.class_obj, key, value)

     return apply_torch_compile_wrapper


-def set_torch_compile_wrapper(model: ModelPatcher, backend: str, options: Optional[dict[str,str]]=None,
-                              mode: Optional[str]=None, fullgraph=False, dynamic: Optional[bool]=None,
-                              keys: list[str]=["diffusion_model"], *args, **kwargs):
+def set_torch_compile_wrapper(model: ModelPatcher, backend: str, options: Optional[dict[str, str]] = None,
+                              mode: Optional[str] = None, fullgraph=False, dynamic: Optional[bool] = None,
+                              keys: list[str] = ["diffusion_model"], *args, **kwargs):
     '''
     Perform torch.compile that will be applied at sample time for either the whole model or specific params of the BaseModel instance.
@@ -56,12 +64,13 @@ def set_torch_compile_wrapper(model: ModelPatcher, backend: str, options: Option
     compiled_modules = {}
     for key in keys:
         compiled_modules[key] = torch.compile(
             model=model.get_model_object(key),
             **compile_kwargs,
         )
     # add torch.compile wrapper
     wrapper_func = apply_torch_compile_factory(
         compiled_module_dict=compiled_modules,
+        model_patcher=model,
     )
     # store wrapper to run on BaseModel's apply_model function
     model.add_wrapper_with_key(WrappersMP.APPLY_MODEL, COMPILE_KEY, wrapper_func)
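
From the caller's side nothing changes: set_torch_compile_wrapper still installs an APPLY_MODEL wrapper on the patcher, which now also force-patches weights before the compiled modules run. A usage sketch; the module path is assumed, while clone() is the standard ModelPatcher API:

    from comfy_extras.nodes_torch_compile import set_torch_compile_wrapper  # module path assumed

    def compile_for_sampling(model_patcher):
        # clone so only this copy of the model carries the compile wrapper
        m = model_patcher.clone()
        set_torch_compile_wrapper(model=m, backend="inductor",
                                  fullgraph=False, dynamic=None,
                                  keys=["diffusion_model"])
        # at sample time, apply_model swaps in the compiled diffusion_model,
        # force-patches weights, runs, then restores the original modules
        return m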