Mirror of https://github.com/comfyanonymous/ComfyUI.git
Improvements to torch.compile

- the model weights generally have to be patched ahead of time for compilation to work
- the model downloader matches the folder_paths API a bit better
- tweak the logging from the execution node
This commit is contained in:
parent 83184916c1
commit 3c9d311dee
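The first bullet is the core of the change: torch.compile should trace against weights that are already in their final, patched state, rather than relying on weights being swapped or patched during the forward pass. A minimal stand-alone sketch of that ordering (not code from this repository; the module and the patch function are illustrative):

import torch


class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(8, 8)

    def forward(self, x):
        return self.proj(x)


def apply_weight_patches(module: torch.nn.Module) -> None:
    # Stand-in for applying LoRA/weight patches in place before compilation.
    with torch.no_grad():
        for p in module.parameters():
            p.add_(0.01)


model = TinyModel()
apply_weight_patches(model)      # patch ahead of time...
compiled = torch.compile(model)  # ...then compile against the patched weights
out = compiled(torch.randn(2, 8))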
@@ -1145,7 +1145,6 @@ async def _validate_prompt(prompt_id: typing.Any, prompt: typing.Mapping[str, ty
                 for reason in reasons:
                     logger.error(f"  - {reason['message']}: {reason['details']}")
             node_errors[node_id]["dependent_outputs"].append(o)
-            logger.error("Output will be ignored")
 
     if len(good_outputs) == 0:
         errors_list = []
@@ -58,6 +58,10 @@ def get_full_path_or_raise(folder_name: str, filename: str) -> str:
     return res
 
 
+def get_full_path(folder_name: str, filename: str) -> Optional[str]:
+    return get_or_download(folder_name, filename)
+
+
 def get_or_download(folder_name: str, filename: str, known_files: Optional[List[Downloadable] | KnownDownloadables] = None) -> Optional[str]:
     if known_files is None:
         known_files = _get_known_models_for_folder_name(folder_name)
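For orientation, a hedged sketch of how a caller might use the folder_paths-style entry point shown above. The folder name and filename are placeholders, and the import name is an assumption rather than something shown in the diff:

import folder_paths  # assumed import name; the diff does not show the module path

# Resolves a local file if present; otherwise get_or_download() gets a chance to
# fetch it from the folder's known downloadables (assumed behavior of this API).
ckpt_path = folder_paths.get_full_path("checkpoints", "example-model.safetensors")
if ckpt_path is None:
    raise FileNotFoundError("not found locally and not a known downloadable model")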
@@ -766,10 +766,10 @@ class ModelPatcher(ModelManageable):
                     x.module.to(device_to)
 
         if lowvram_counter > 0:
-            logger.debug("loaded partially lowvram_model_memory={}MB mem_counter={}MB patch_counter={}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter))
+            logger.debug(f"loaded partially lowvram_model_memory={lowvram_model_memory / (1024 * 1024):.1f}MB mem_counter={mem_counter / (1024 * 1024):.1f}MB patch_counter={patch_counter}")
             self._memory_measurements.model_lowvram = True
         else:
-            logger.debug("loaded completely lowvram_model_memory={}MB mem_counter={}MB full_load={}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
+            logger.debug(f"loaded completely lowvram_model_memory={lowvram_model_memory / (1024 * 1024):.1f}MB mem_counter={mem_counter / (1024 * 1024):.1f}MB full_load={full_load}")
             self._memory_measurements.model_lowvram = False
             if full_load:
                 self.model.to(device_to)
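The logging change above only affects formatting: the old str.format call printed the raw float, while the new f-string rounds the MB figures to one decimal place. A tiny illustration with a made-up value:

mem_counter = 3_245_000_000  # bytes, illustrative value only

print("mem_counter={}MB".format(mem_counter / (1024 * 1024)))
# prints the float's full repr, e.g. mem_counter=3094.67315673828...MB

print(f"mem_counter={mem_counter / (1024 * 1024):.1f}MB")
# prints mem_counter=3094.7MB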
@@ -1,38 +1,46 @@
 from __future__ import annotations
 
+from typing import TYPE_CHECKING, Callable, Optional
+
 import torch
 
 import comfy.utils
+from comfy import model_management
 from comfy.patcher_extension import WrappersMP
-from typing import TYPE_CHECKING, Callable, Optional
 
 if TYPE_CHECKING:
     from comfy.model_patcher import ModelPatcher
     from comfy.patcher_extension import WrapperExecutor
 
 
 COMPILE_KEY = "torch.compile"
 TORCH_COMPILE_KWARGS = "torch_compile_kwargs"
 
 
-def apply_torch_compile_factory(compiled_module_dict: dict[str, Callable]) -> Callable:
+def apply_torch_compile_factory(compiled_module_dict: dict[str, Callable], model_patcher: Optional[ModelPatcher] = None) -> Callable:
     '''
     Create a wrapper that will refer to the compiled_diffusion_model.
     '''
 
     def apply_torch_compile_wrapper(executor: WrapperExecutor, *args, **kwargs):
         try:
             orig_modules = {}
             for key, value in compiled_module_dict.items():
                 orig_modules[key] = comfy.utils.get_attr(executor.class_obj, key)
                 comfy.utils.set_attr(executor.class_obj, key, value)
+            # todo: compilation has to patch all weights
+            if model_patcher is not None:
+                model_patcher.patch_model(device_to=model_management.get_torch_device(), force_patch_weights=True)
             return executor(*args, **kwargs)
         finally:
             for key, value in orig_modules.items():
                 comfy.utils.set_attr(executor.class_obj, key, value)
 
     return apply_torch_compile_wrapper
 
 
-def set_torch_compile_wrapper(model: ModelPatcher, backend: str, options: Optional[dict[str,str]]=None,
-                              mode: Optional[str]=None, fullgraph=False, dynamic: Optional[bool]=None,
-                              keys: list[str]=["diffusion_model"], *args, **kwargs):
+def set_torch_compile_wrapper(model: ModelPatcher, backend: str, options: Optional[dict[str, str]] = None,
+                              mode: Optional[str] = None, fullgraph=False, dynamic: Optional[bool] = None,
+                              keys: list[str] = ["diffusion_model"], *args, **kwargs):
     '''
     Perform torch.compile that will be applied at sample time for either the whole model or specific params of the BaseModel instance.
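apply_torch_compile_wrapper above temporarily swaps the compiled modules in with comfy.utils.get_attr/set_attr, restores the originals in a finally block, and (after this commit) patches the weights before the swapped-in modules run. A minimal, generic sketch of that swap-and-restore pattern using plain getattr/setattr (names are illustrative, not the repository's helpers):

import torch


def run_with_swapped_attr(module: torch.nn.Module, attr: str, replacement, fn, *args, **kwargs):
    """Temporarily replace module.<attr> while fn runs, then restore the original."""
    original = getattr(module, attr)
    try:
        setattr(module, attr, replacement)
        return fn(*args, **kwargs)
    finally:
        # restore even if fn raises, mirroring the try/finally in the wrapper above
        setattr(module, attr, original)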
@@ -56,12 +64,13 @@ def set_torch_compile_wrapper(model: ModelPatcher, backend: str, options: Option
     compiled_modules = {}
     for key in keys:
         compiled_modules[key] = torch.compile(
             model=model.get_model_object(key),
             **compile_kwargs,
         )
     # add torch.compile wrapper
     wrapper_func = apply_torch_compile_factory(
         compiled_module_dict=compiled_modules,
+        model_patcher=model,
     )
     # store wrapper to run on BaseModel's apply_model function
     model.add_wrapper_with_key(WrappersMP.APPLY_MODEL, COMPILE_KEY, wrapper_func)
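Finally, a hedged sketch of how this API is typically driven; `model` is assumed to be a ModelPatcher from a checkpoint loader, and the import path follows upstream ComfyUI rather than anything shown in this diff:

from comfy_api.torch_helpers import set_torch_compile_wrapper  # import path assumed

# `model` is assumed to be a comfy.model_patcher.ModelPatcher, e.g. the MODEL output
# of a checkpoint loader node.
set_torch_compile_wrapper(
    model,
    backend="inductor",        # torch.compile's default backend
    keys=["diffusion_model"],  # compile the whole diffusion model
)
# At sample time the registered APPLY_MODEL wrapper swaps in the compiled module and,
# after this commit, patches the model weights up front (force_patch_weights=True).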