mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-11 06:40:48 +08:00
Fix torch compile
This commit is contained in:
parent
735a133ad4
commit
7ce781b0fd
@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from typing import Protocol, Optional, TypeVar, runtime_checkable, Callable, Any, NamedTuple
|
||||
|
||||
import torch
|
||||
@ -24,7 +25,7 @@ class DeviceSettable(Protocol):
|
||||
...
|
||||
|
||||
|
||||
class ModelManageable(Protocol):
|
||||
class ModelManageable(Protocol, metaclass=ABCMeta):
|
||||
"""
|
||||
Objects which implement this protocol can be managed by
|
||||
|
||||
@ -137,6 +138,10 @@ class ModelManageable(Protocol):
|
||||
def force_cast_weights(self) -> bool:
|
||||
return False
|
||||
|
||||
@abstractmethod
|
||||
def prepare_hook_patches_current_keyframe(self, t, hook_group, model_options):
|
||||
pass
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class MemoryMeasurements:
|
||||
|
||||
@ -27,9 +27,6 @@ def apply_torch_compile_factory(compiled_module_dict: dict[str, Callable], model
|
||||
for key, value in compiled_module_dict.items():
|
||||
orig_modules[key] = comfy.utils.get_attr(executor.class_obj, key)
|
||||
comfy.utils.set_attr(executor.class_obj, key, value)
|
||||
# todo: compilation has to patch all weights
|
||||
if model_patcher is not None:
|
||||
model_patcher.patch_model(device_to=model_management.get_torch_device(), force_patch_weights=True)
|
||||
return executor(*args, **kwargs)
|
||||
finally:
|
||||
for key, value in orig_modules.items():
|
||||
|
||||
Loading…
Reference in New Issue
Block a user