mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-26 22:30:19 +08:00
Fix torch compile
This commit is contained in:
parent
735a133ad4
commit
7ce781b0fd
@ -1,6 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import dataclasses
|
import dataclasses
|
||||||
|
from abc import ABCMeta, abstractmethod
|
||||||
from typing import Protocol, Optional, TypeVar, runtime_checkable, Callable, Any, NamedTuple
|
from typing import Protocol, Optional, TypeVar, runtime_checkable, Callable, Any, NamedTuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -24,7 +25,7 @@ class DeviceSettable(Protocol):
|
|||||||
...
|
...
|
||||||
|
|
||||||
|
|
||||||
class ModelManageable(Protocol):
|
class ModelManageable(Protocol, metaclass=ABCMeta):
|
||||||
"""
|
"""
|
||||||
Objects which implement this protocol can be managed by
|
Objects which implement this protocol can be managed by
|
||||||
|
|
||||||
@ -137,6 +138,10 @@ class ModelManageable(Protocol):
|
|||||||
def force_cast_weights(self) -> bool:
|
def force_cast_weights(self) -> bool:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def prepare_hook_patches_current_keyframe(self, t, hook_group, model_options):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class MemoryMeasurements:
|
class MemoryMeasurements:
|
||||||
|
|||||||
@ -27,9 +27,6 @@ def apply_torch_compile_factory(compiled_module_dict: dict[str, Callable], model
|
|||||||
for key, value in compiled_module_dict.items():
|
for key, value in compiled_module_dict.items():
|
||||||
orig_modules[key] = comfy.utils.get_attr(executor.class_obj, key)
|
orig_modules[key] = comfy.utils.get_attr(executor.class_obj, key)
|
||||||
comfy.utils.set_attr(executor.class_obj, key, value)
|
comfy.utils.set_attr(executor.class_obj, key, value)
|
||||||
# todo: compilation has to patch all weights
|
|
||||||
if model_patcher is not None:
|
|
||||||
model_patcher.patch_model(device_to=model_management.get_torch_device(), force_patch_weights=True)
|
|
||||||
return executor(*args, **kwargs)
|
return executor(*args, **kwargs)
|
||||||
finally:
|
finally:
|
||||||
for key, value in orig_modules.items():
|
for key, value in orig_modules.items():
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user