Improvements to torch.compile

- the model weights generally have to be patched ahead of time for compilation to work (sketched below)
- the model downloader matches the folder_paths API a bit better
- tweak the logging from the execution node
doctorpangloss 2025-07-30 19:27:40 -07:00
parent 83184916c1
commit 3c9d311dee
4 changed files with 24 additions and 12 deletions
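
The first bullet is the substance of the commit: rather than letting weight patches (LoRA deltas and the like) be applied lazily inside the forward pass, the compile wrapper now calls `patch_model(..., force_patch_weights=True)` so every patch is baked into the weight tensors before the compiled module's first traced run. A minimal sketch of that patch-then-compile ordering (illustrative only, not this repo's code; the delta here is hypothetical):

```python
# Illustrative sketch, not this repo's code: apply weight patches eagerly,
# then compile, so the traced graph never contains patching logic.
import torch

class Tiny(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)

model = Tiny()

# "Patched ahead of time": bake a (hypothetical) LoRA-style delta into the
# weight tensor before the compiled module ever runs.
with torch.no_grad():
    model.linear.weight += 0.01 * torch.randn_like(model.linear.weight)

compiled = torch.compile(model)    # backend defaults to inductor
out = compiled(torch.randn(2, 4))  # first call triggers tracing
```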

View File

@@ -1145,7 +1145,6 @@ async def _validate_prompt(prompt_id: typing.Any, prompt: typing.Mapping[str, ty
             for reason in reasons:
                 logger.error(f" - {reason['message']}: {reason['details']}")
             node_errors[node_id]["dependent_outputs"].append(o)
-            logger.error("Output will be ignored")
     if len(good_outputs) == 0:
         errors_list = []
View File

@@ -58,6 +58,10 @@ def get_full_path_or_raise(folder_name: str, filename: str) -> str:
     return res

 def get_full_path(folder_name: str, filename: str) -> Optional[str]:
+    return get_or_download(folder_name, filename)

+def get_or_download(folder_name: str, filename: str, known_files: Optional[List[Downloadable] | KnownDownloadables] = None) -> Optional[str]:
+    if known_files is None:
+        known_files = _get_known_models_for_folder_name(folder_name)
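
The effect is that a plain `get_full_path` lookup can now fall through to the downloader using the folder's default known-model list, while callers can still pass their own `known_files`. A hedged usage sketch (the import path and filename are assumptions, not from this diff):

```python
# Hedged sketch; the folder_paths import path and the filename are assumptions.
from comfy.cmd import folder_paths

path = folder_paths.get_full_path("checkpoints", "v1-5-pruned-emaonly.safetensors")
if path is None:
    # neither on disk nor among the known downloadables for "checkpoints"
    raise FileNotFoundError("checkpoint unavailable")
```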

View File

@@ -766,10 +766,10 @@ class ModelPatcher(ModelManageable):
                     x.module.to(device_to)
         if lowvram_counter > 0:
-            logger.debug("loaded partially lowvram_model_memory={}MB mem_counter={}MB patch_counter={}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter))
+            logger.debug(f"loaded partially lowvram_model_memory={lowvram_model_memory / (1024 * 1024):.1f}MB mem_counter={mem_counter / (1024 * 1024):.1f}MB patch_counter={patch_counter}")
             self._memory_measurements.model_lowvram = True
         else:
-            logger.debug("loaded completely lowvram_model_memory={}MB mem_counter={}MB full_load={}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
+            logger.debug(f"loaded completely lowvram_model_memory={lowvram_model_memory / (1024 * 1024):.1f}MB mem_counter={mem_counter / (1024 * 1024):.1f}MB full_load={full_load}")
             self._memory_measurements.model_lowvram = False
             if full_load:
                 self.model.to(device_to)
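
The logging change is cosmetic but worth noting: the old `str.format` calls printed the full float repr of the megabyte counts, while the new f-strings round to one decimal place via `:.1f`. For example:

```python
mem_counter = 2_500_000_000  # bytes

# old style: full float repr
print("mem_counter={}MB".format(mem_counter / (1024 * 1024)))
# mem_counter=2384.185791015625MB

# new style: rounded to one decimal place
print(f"mem_counter={mem_counter / (1024 * 1024):.1f}MB")
# mem_counter=2384.2MB
```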

View File

@@ -1,38 +1,46 @@
 from __future__ import annotations
+from typing import TYPE_CHECKING, Callable, Optional
 import torch
 import comfy.utils
+from comfy import model_management
 from comfy.patcher_extension import WrappersMP
-from typing import TYPE_CHECKING, Callable, Optional
 if TYPE_CHECKING:
     from comfy.model_patcher import ModelPatcher
     from comfy.patcher_extension import WrapperExecutor
 COMPILE_KEY = "torch.compile"
 TORCH_COMPILE_KWARGS = "torch_compile_kwargs"
-def apply_torch_compile_factory(compiled_module_dict: dict[str, Callable]) -> Callable:
+def apply_torch_compile_factory(compiled_module_dict: dict[str, Callable], model_patcher: Optional[ModelPatcher] = None) -> Callable:
     '''
     Create a wrapper that will refer to the compiled_diffusion_model.
     '''
     def apply_torch_compile_wrapper(executor: WrapperExecutor, *args, **kwargs):
         try:
             orig_modules = {}
             for key, value in compiled_module_dict.items():
                 orig_modules[key] = comfy.utils.get_attr(executor.class_obj, key)
                 comfy.utils.set_attr(executor.class_obj, key, value)
+            # todo: compilation has to patch all weights
+            if model_patcher is not None:
+                model_patcher.patch_model(device_to=model_management.get_torch_device(), force_patch_weights=True)
             return executor(*args, **kwargs)
         finally:
             for key, value in orig_modules.items():
                 comfy.utils.set_attr(executor.class_obj, key, value)
     return apply_torch_compile_wrapper
-def set_torch_compile_wrapper(model: ModelPatcher, backend: str, options: Optional[dict[str,str]]=None,
-                              mode: Optional[str]=None, fullgraph=False, dynamic: Optional[bool]=None,
-                              keys: list[str]=["diffusion_model"], *args, **kwargs):
+def set_torch_compile_wrapper(model: ModelPatcher, backend: str, options: Optional[dict[str, str]] = None,
+                              mode: Optional[str] = None, fullgraph=False, dynamic: Optional[bool] = None,
+                              keys: list[str] = ["diffusion_model"], *args, **kwargs):
     '''
     Perform torch.compile that will be applied at sample time for either the whole model or specific params of the BaseModel instance.
@@ -56,12 +64,13 @@ def set_torch_compile_wrapper(model: ModelPatcher, backend: str, options: Option
     compiled_modules = {}
     for key in keys:
         compiled_modules[key] = torch.compile(
-                model=model.get_model_object(key),
-                **compile_kwargs,
-            )
+            model=model.get_model_object(key),
+            **compile_kwargs,
+        )
     # add torch.compile wrapper
     wrapper_func = apply_torch_compile_factory(
         compiled_module_dict=compiled_modules,
+        model_patcher=model,
     )
     # store wrapper to run on BaseModel's apply_model function
     model.add_wrapper_with_key(WrappersMP.APPLY_MODEL, COMPILE_KEY, wrapper_func)
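
Taken together, a hedged usage sketch of the helper after this commit (the import path and node wiring are assumptions about this fork's layout; only the function signature comes from the diff):

```python
# Hedged sketch; the import path below is an assumption.
from comfy_extras.nodes_torch_compile import set_torch_compile_wrapper

def compile_for_sampling(model):  # `model` is a ModelPatcher from a loader node
    patched = model.clone()
    set_torch_compile_wrapper(
        model=patched,
        backend="inductor",
        mode=None,
        fullgraph=False,
        dynamic=None,
        keys=["diffusion_model"],
    )
    # Sampling with `patched` now routes apply_model through the compiled
    # module, and (per this commit) the wrapper force-patches the weights
    # onto the torch device before the wrapped call runs.
    return patched
```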