diff --git a/comfy/model_management_types.py b/comfy/model_management_types.py
index 73f4bf1fa..4f7e5e34c 100644
--- a/comfy/model_management_types.py
+++ b/comfy/model_management_types.py
@@ -1,6 +1,7 @@
 from __future__ import annotations

 import dataclasses
+from abc import ABCMeta, abstractmethod
 from typing import Protocol, Optional, TypeVar, runtime_checkable, Callable, Any, NamedTuple

 import torch
@@ -24,7 +25,7 @@ class DeviceSettable(Protocol):
         ...


-class ModelManageable(Protocol):
+class ModelManageable(Protocol, metaclass=ABCMeta):
     """
     Objects which implement this protocol can be managed by
@@ -137,6 +138,10 @@ class ModelManageable(Protocol):
     def force_cast_weights(self) -> bool:
         return False

+    @abstractmethod
+    def prepare_hook_patches_current_keyframe(self, t, hook_group, model_options):
+        pass
+

 @dataclasses.dataclass
 class MemoryMeasurements:
diff --git a/comfy_api/torch_helpers/torch_compile.py b/comfy_api/torch_helpers/torch_compile.py
index cabe2bd35..721fd43aa 100644
--- a/comfy_api/torch_helpers/torch_compile.py
+++ b/comfy_api/torch_helpers/torch_compile.py
@@ -27,9 +27,6 @@ def apply_torch_compile_factory(compiled_module_dict: dict[str, Callable], model
             for key, value in compiled_module_dict.items():
                 orig_modules[key] = comfy.utils.get_attr(executor.class_obj, key)
                 comfy.utils.set_attr(executor.class_obj, key, value)
-            # todo: compilation has to patch all weights
-            if model_patcher is not None:
-                model_patcher.patch_model(device_to=model_management.get_torch_device(), force_patch_weights=True)
             return executor(*args, **kwargs)
         finally:
             for key, value in orig_modules.items():
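For context on the first file's change, here is a minimal sketch of what the new abstract method implies for implementers. It assumes only what the hunks above show; the class names `ExamplePatcher` and `Incomplete` and the method body are hypothetical, not part of the patch. Because `typing.Protocol`'s own metaclass already derives from `ABCMeta`, the explicit `metaclass=ABCMeta` declaration is compatible; the practical effect comes from `@abstractmethod`, which blocks instantiation of classes that explicitly inherit `ModelManageable` without implementing `prepare_hook_patches_current_keyframe`.

```python
from comfy.model_management_types import ModelManageable


class ExamplePatcher(ModelManageable):
    def prepare_hook_patches_current_keyframe(self, t, hook_group, model_options):
        # Illustrative no-op: a real implementation would presumably advance
        # the hook group's keyframe state for the timestep(s) in `t`.
        pass


class Incomplete(ModelManageable):
    # Omits the abstract method, so ABCMeta refuses to instantiate it.
    pass


ExamplePatcher()  # ok
# Incomplete()    # TypeError: Can't instantiate abstract class 'Incomplete'
```

Note that this enforcement only applies to explicit subclasses; objects that merely structurally satisfy the protocol (duck typing, `isinstance` via `runtime_checkable`) are unaffected.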