From 7ce781b0fd4ee268427847e8f6a26f6ccf47130c Mon Sep 17 00:00:00 2001 From: doctorpangloss <@hiddenswitch.com> Date: Fri, 22 Aug 2025 17:43:12 -0700 Subject: [PATCH] Fix torch compile --- comfy/model_management_types.py | 7 ++++++- comfy_api/torch_helpers/torch_compile.py | 3 --- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/comfy/model_management_types.py b/comfy/model_management_types.py index 73f4bf1fa..4f7e5e34c 100644 --- a/comfy/model_management_types.py +++ b/comfy/model_management_types.py @@ -1,6 +1,7 @@ from __future__ import annotations import dataclasses +from abc import ABCMeta, abstractmethod from typing import Protocol, Optional, TypeVar, runtime_checkable, Callable, Any, NamedTuple import torch @@ -24,7 +25,7 @@ class DeviceSettable(Protocol): ... -class ModelManageable(Protocol): +class ModelManageable(Protocol, metaclass=ABCMeta): """ Objects which implement this protocol can be managed by @@ -137,6 +138,10 @@ class ModelManageable(Protocol): def force_cast_weights(self) -> bool: return False + @abstractmethod + def prepare_hook_patches_current_keyframe(self, t, hook_group, model_options): + pass + @dataclasses.dataclass class MemoryMeasurements: diff --git a/comfy_api/torch_helpers/torch_compile.py b/comfy_api/torch_helpers/torch_compile.py index cabe2bd35..721fd43aa 100644 --- a/comfy_api/torch_helpers/torch_compile.py +++ b/comfy_api/torch_helpers/torch_compile.py @@ -27,9 +27,6 @@ def apply_torch_compile_factory(compiled_module_dict: dict[str, Callable], model for key, value in compiled_module_dict.items(): orig_modules[key] = comfy.utils.get_attr(executor.class_obj, key) comfy.utils.set_attr(executor.class_obj, key, value) - # todo: compilation has to patch all weights - if model_patcher is not None: - model_patcher.patch_model(device_to=model_management.get_torch_device(), force_patch_weights=True) return executor(*args, **kwargs) finally: for key, value in orig_modules.items():