Mirror of https://github.com/comfyanonymous/ComfyUI.git
Improvements to torch.compile
- The model weights generally have to be patched ahead of time for compilation to work.
- The model downloader matches the folder_paths API a bit better.
- Tweak the logging from the execution node.
Parent: 83184916c1
Commit: 3c9d311dee
@@ -1145,7 +1145,6 @@ async def _validate_prompt(prompt_id: typing.Any, prompt: typing.Mapping[str, ty
             for reason in reasons:
                 logger.error(f" - {reason['message']}: {reason['details']}")
             node_errors[node_id]["dependent_outputs"].append(o)
-            logger.error("Output will be ignored")
 
     if len(good_outputs) == 0:
         errors_list = []
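For reference, a small sketch of the log output produced by the per-reason loop that remains after this change; the reason contents are invented, but the dict keys and message format come from the hunk above:

import logging

logger = logging.getLogger(__name__)

# Hypothetical validation failures for one output node.
reasons = [
    {"message": "Required input is missing", "details": "images"},
    {"message": "Value not in list", "details": "sampler_name: 'euler_x'"},
]
for reason in reasons:
    logger.error(f" - {reason['message']}: {reason['details']}")
# Logged:
#  - Required input is missing: images
#  - Value not in list: sampler_name: 'euler_x'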
@@ -58,6 +58,10 @@ def get_full_path_or_raise(folder_name: str, filename: str) -> str:
     return res
 
 
+def get_full_path(folder_name: str, filename: str) -> Optional[str]:
+    return get_or_download(folder_name, filename)
+
+
 def get_or_download(folder_name: str, filename: str, known_files: Optional[List[Downloadable] | KnownDownloadables] = None) -> Optional[str]:
     if known_files is None:
         known_files = _get_known_models_for_folder_name(folder_name)
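A rough sketch of the call chain the new get_full_path introduces; the folder name and checkpoint filename are purely illustrative, and the fallback behavior is inferred from the signatures in the hunk above:

# get_full_path is now a thin wrapper over get_or_download:
path = get_full_path("checkpoints", "v1-5-pruned-emaonly.safetensors")
# ...which is equivalent to...
path = get_or_download("checkpoints", "v1-5-pruned-emaonly.safetensors")
# When no explicit known_files list is passed, get_or_download consults
# _get_known_models_for_folder_name("checkpoints") to decide whether the file
# can be fetched, and returns None if it cannot be resolved (Optional[str]).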
@@ -766,10 +766,10 @@ class ModelPatcher(ModelManageable):
                     x.module.to(device_to)
 
         if lowvram_counter > 0:
-            logger.debug("loaded partially lowvram_model_memory={}MB mem_counter={}MB patch_counter={}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter))
+            logger.debug(f"loaded partially lowvram_model_memory={lowvram_model_memory / (1024 * 1024):.1f}MB mem_counter={mem_counter / (1024 * 1024):.1f}MB patch_counter={patch_counter}")
             self._memory_measurements.model_lowvram = True
         else:
-            logger.debug("loaded completely lowvram_model_memory={}MB mem_counter={}MB full_load={}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
+            logger.debug(f"loaded completely lowvram_model_memory={lowvram_model_memory / (1024 * 1024):.1f}MB mem_counter={mem_counter / (1024 * 1024):.1f}MB full_load={full_load}")
             self._memory_measurements.model_lowvram = False
             if full_load:
                 self.model.to(device_to)
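To illustrate the logging tweak above: the f-strings add a :.1f format so the megabyte figures are rounded to one decimal place instead of printing the raw float division. The numbers below are invented:

lowvram_model_memory = 4 * 1024 * 1024 * 1024  # hypothetical budget, in bytes (4 GiB)
mem_counter = 3_567_123_456                    # hypothetical bytes actually loaded
patch_counter = 2

# Old "{}".format(...) output:
#   loaded partially lowvram_model_memory=4096.0MB mem_counter=3401.8740234375MB patch_counter=2
# New f-string output:
#   loaded partially lowvram_model_memory=4096.0MB mem_counter=3401.9MB patch_counter=2
print(f"loaded partially lowvram_model_memory={lowvram_model_memory / (1024 * 1024):.1f}MB "
      f"mem_counter={mem_counter / (1024 * 1024):.1f}MB patch_counter={patch_counter}")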
@@ -1,38 +1,46 @@
 from __future__ import annotations
 
+from typing import TYPE_CHECKING, Callable, Optional
+
 import torch
 
 import comfy.utils
+from comfy import model_management
 from comfy.patcher_extension import WrappersMP
-from typing import TYPE_CHECKING, Callable, Optional
 if TYPE_CHECKING:
     from comfy.model_patcher import ModelPatcher
     from comfy.patcher_extension import WrapperExecutor
 
 
 COMPILE_KEY = "torch.compile"
 TORCH_COMPILE_KWARGS = "torch_compile_kwargs"
 
 
-def apply_torch_compile_factory(compiled_module_dict: dict[str, Callable]) -> Callable:
+def apply_torch_compile_factory(compiled_module_dict: dict[str, Callable], model_patcher: Optional[ModelPatcher] = None) -> Callable:
     '''
     Create a wrapper that will refer to the compiled_diffusion_model.
     '''
 
     def apply_torch_compile_wrapper(executor: WrapperExecutor, *args, **kwargs):
         try:
             orig_modules = {}
             for key, value in compiled_module_dict.items():
                 orig_modules[key] = comfy.utils.get_attr(executor.class_obj, key)
                 comfy.utils.set_attr(executor.class_obj, key, value)
+            # todo: compilation has to patch all weights
+            if model_patcher is not None:
+                model_patcher.patch_model(device_to=model_management.get_torch_device(), force_patch_weights=True)
             return executor(*args, **kwargs)
         finally:
             for key, value in orig_modules.items():
                 comfy.utils.set_attr(executor.class_obj, key, value)
 
     return apply_torch_compile_wrapper
 
 
-def set_torch_compile_wrapper(model: ModelPatcher, backend: str, options: Optional[dict[str,str]]=None,
-                              mode: Optional[str]=None, fullgraph=False, dynamic: Optional[bool]=None,
-                              keys: list[str]=["diffusion_model"], *args, **kwargs):
+def set_torch_compile_wrapper(model: ModelPatcher, backend: str, options: Optional[dict[str, str]] = None,
+                              mode: Optional[str] = None, fullgraph=False, dynamic: Optional[bool] = None,
+                              keys: list[str] = ["diffusion_model"], *args, **kwargs):
     '''
     Perform torch.compile that will be applied at sample time for either the whole model or specific params of the BaseModel instance.
 
@@ -56,12 +64,13 @@ def set_torch_compile_wrapper(model: ModelPatcher, backend: str, options: Option
     compiled_modules = {}
     for key in keys:
         compiled_modules[key] = torch.compile(
             model=model.get_model_object(key),
             **compile_kwargs,
         )
     # add torch.compile wrapper
     wrapper_func = apply_torch_compile_factory(
         compiled_module_dict=compiled_modules,
+        model_patcher=model,
     )
     # store wrapper to run on BaseModel's apply_model function
     model.add_wrapper_with_key(WrappersMP.APPLY_MODEL, COMPILE_KEY, wrapper_func)
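For context, a rough usage sketch of the updated helper. The import path and node wiring below are assumptions; what the diff itself establishes is that set_torch_compile_wrapper now hands the ModelPatcher to apply_torch_compile_factory, so the sample-time wrapper can call patch_model(..., force_patch_weights=True) before the compiled modules run:

# Hypothetical import path; adjust to wherever set_torch_compile_wrapper lives in this repo.
from comfy_extras.torch_compile import set_torch_compile_wrapper

def compile_model(model_patcher):
    # model_patcher: a ModelPatcher, e.g. the MODEL output of a checkpoint loader node.
    patcher = model_patcher.clone()
    # Registers the APPLY_MODEL wrapper under COMPILE_KEY. At sample time the wrapper
    # swaps in the compiled "diffusion_model", force-patches the weights onto the torch
    # device, runs the forward pass, and then restores the original module.
    set_torch_compile_wrapper(model=patcher, backend="inductor", keys=["diffusion_model"])
    return patcher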