Tweak torch compile

This commit is contained in:
doctorpangloss 2025-09-12 14:58:26 -07:00
parent 112017f2d2
commit 10448e67ef

View File

@ -16,7 +16,7 @@ COMPILE_KEY = "torch.compile"
TORCH_COMPILE_KWARGS = "torch_compile_kwargs"
def apply_torch_compile_factory(compiled_module_dict: dict[str, Callable], model_patcher: Optional[ModelPatcher] = None) -> Callable:
def apply_torch_compile_factory(compiled_module_dict: dict[str, Callable]) -> Callable:
'''
Create a wrapper that will refer to the compiled_diffusion_model.
'''
@ -67,7 +67,6 @@ def set_torch_compile_wrapper(model: ModelPatcher, backend: str, options: Option
# add torch.compile wrapper
wrapper_func = apply_torch_compile_factory(
compiled_module_dict=compiled_modules,
model_patcher=model,
)
# store wrapper to run on BaseModel's apply_model function
model.add_wrapper_with_key(WrappersMP.APPLY_MODEL, COMPILE_KEY, wrapper_func)