mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-11 06:40:48 +08:00
Fix linting issues, improve HooksSupport dummy implementation
This commit is contained in:
parent
f5e29f0e61
commit
4f6f3e9197
@ -2,7 +2,7 @@ from typing import Optional, Any
|
||||
|
||||
import torch
|
||||
# only imported when sage attention is enabled
|
||||
from sageattention import * # pylint: disable=import-error
|
||||
from sageattention import sageattn_qk_int8_pv_fp16_cuda, sageattn_qk_int8_pv_fp8_cuda, sageattn_qk_int8_pv_fp8_cuda_sm90 # pylint: disable=import-error
|
||||
|
||||
|
||||
def get_cuda_arch_versions():
|
||||
|
||||
@ -44,7 +44,6 @@ class HooksSupport(Protocol, metaclass=ABCMeta):
|
||||
setattr(self, "_hook_mode", EnumHookMode.MaxSpeed)
|
||||
return getattr(self, "_hook_mode")
|
||||
|
||||
|
||||
@hook_mode.setter
|
||||
def hook_mode(self, value):
|
||||
setattr(self, "_hook_mode", value)
|
||||
@ -76,9 +75,12 @@ class HooksSupport(Protocol, metaclass=ABCMeta):
|
||||
pass
|
||||
|
||||
def pre_run(self):
|
||||
from .model_base import BaseModel
|
||||
if hasattr(self, "model") and isinstance(self.model, BaseModel):
|
||||
self.model.current_patcher = self
|
||||
if hasattr(self, "model"):
|
||||
model = getattr(self, "model")
|
||||
from .model_base import BaseModel
|
||||
|
||||
if isinstance(model, BaseModel) or hasattr(model, "current_patcher") and isinstance(self, ModelManageable):
|
||||
model.current_patcher = self
|
||||
|
||||
def prepare_state(self, *args, **kwargs):
|
||||
pass
|
||||
@ -107,7 +109,7 @@ class ModelManageableExtras(Protocol, metaclass=ABCMeta):
|
||||
return torch.device("cpu")
|
||||
|
||||
|
||||
class ModelManageableRequired(Protocol):
|
||||
class ModelManageableRequired(Protocol, metaclass=ABCMeta):
|
||||
"""
|
||||
The bare minimum that must be implemented to support model management when inheriting from ModelManageable
|
||||
|
||||
@ -122,6 +124,7 @@ class ModelManageableRequired(Protocol):
|
||||
offload_device: torch.device
|
||||
model: torch.nn.Module
|
||||
|
||||
@abstractmethod
|
||||
def patch_model(self, device_to: torch.device | None = None, lowvram_model_memory: int = 0, load_weights: bool = True, force_patch_weights: bool = False) -> torch.nn.Module:
|
||||
"""
|
||||
Called by ModelManageable
|
||||
@ -138,6 +141,7 @@ class ModelManageableRequired(Protocol):
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def unpatch_model(self, device_to: torch.device | None = None, unpatch_weights: Optional[bool] = False) -> torch.nn.Module:
|
||||
"""
|
||||
Called by ModelManageable
|
||||
@ -152,6 +156,7 @@ class ModelManageableRequired(Protocol):
|
||||
...
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class ModelManageable(ModelManageableRequired, ModelManageableExtras, Protocol, metaclass=ABCMeta):
|
||||
"""
|
||||
Objects which implement this protocol can be managed by
|
||||
|
||||
Loading…
Reference in New Issue
Block a user