Fix torch compile

doctorpangloss 2025-08-22 17:43:12 -07:00
parent 735a133ad4
commit 7ce781b0fd
2 changed files with 6 additions and 4 deletions


@@ -1,6 +1,7 @@
 from __future__ import annotations
 import dataclasses
+from abc import ABCMeta, abstractmethod
 from typing import Protocol, Optional, TypeVar, runtime_checkable, Callable, Any, NamedTuple
 import torch
@@ -24,7 +25,7 @@ class DeviceSettable(Protocol):
     ...
-class ModelManageable(Protocol):
+class ModelManageable(Protocol, metaclass=ABCMeta):
     """
     Objects which implement this protocol can be managed by
@@ -137,6 +138,10 @@ class ModelManageable(Protocol):
     def force_cast_weights(self) -> bool:
         return False
 
+    @abstractmethod
+    def prepare_hook_patches_current_keyframe(self, t, hook_group, model_options):
+        pass
+
 @dataclasses.dataclass
 class MemoryMeasurements:

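The first file's change combines typing.Protocol with ABCMeta so that ModelManageable stays structurally checkable while gaining an enforced abstract method. A minimal, self-contained sketch of that pattern follows; the method name mirrors the diff, but the class names and bodies are illustrative, not ComfyUI's actual code:

from abc import ABCMeta, abstractmethod
from typing import Protocol, runtime_checkable


@runtime_checkable
class Manageable(Protocol, metaclass=ABCMeta):
    # Illustrative stand-in for ModelManageable.
    @abstractmethod
    def prepare_hook_patches_current_keyframe(self, t, hook_group, model_options):
        ...


class ConcreteModel(Manageable):
    # Explicit subclasses get ABCMeta's enforcement: leaving the abstract
    # method unimplemented makes instantiation raise TypeError.
    def prepare_hook_patches_current_keyframe(self, t, hook_group, model_options):
        return None


model = ConcreteModel()                  # ok: abstract method implemented
assert isinstance(model, Manageable)     # structural isinstance() still works

The explicit metaclass=ABCMeta resolves without conflict because Protocol's own metaclass already derives from ABCMeta, so Python keeps the more derived one; the declaration makes the abstract-member contract explicit for direct subclasses.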
@@ -27,9 +27,6 @@ def apply_torch_compile_factory(compiled_module_dict: dict[str, Callable], model
             for key, value in compiled_module_dict.items():
                 orig_modules[key] = comfy.utils.get_attr(executor.class_obj, key)
                 comfy.utils.set_attr(executor.class_obj, key, value)
-            # todo: compilation has to patch all weights
-            if model_patcher is not None:
-                model_patcher.patch_model(device_to=model_management.get_torch_device(), force_patch_weights=True)
             return executor(*args, **kwargs)
         finally:
             for key, value in orig_modules.items():
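With the weight patching removed, the wrapper in the second file reduces to a swap-and-restore pattern: compiled replacements are installed by dotted attribute path before the executor runs, and the originals are put back in a finally block even if the call raises. A self-contained sketch of that pattern under assumed names; get_attr/set_attr here are local stand-ins for the comfy.utils helpers the diff calls:

import functools
import torch


def get_attr(obj, path: str):
    # Resolve a dotted attribute path such as "diffusion_model.out".
    return functools.reduce(getattr, path.split("."), obj)


def set_attr(obj, path: str, value):
    *parents, leaf = path.split(".")
    setattr(functools.reduce(getattr, parents, obj), leaf, value)


def run_with_compiled(model: torch.nn.Module, compiled: dict[str, torch.nn.Module], fn, *args, **kwargs):
    orig_modules = {}
    try:
        # Install the compiled replacements, remembering the originals.
        for path, module in compiled.items():
            orig_modules[path] = get_attr(model, path)
            set_attr(model, path, module)
        return fn(*args, **kwargs)
    finally:
        # Restore whatever was swapped, leaving the model unchanged on exit,
        # even when fn raises.
        for path, module in orig_modules.items():
            set_attr(model, path, module)

Building the compiled dict once, e.g. {"linear": torch.compile(model.linear)}, and reusing it across calls lets torch.compile serve its cached graphs instead of recompiling on every invocation.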