From 4f6f3e919797fc48117b0a94abd4216cfb21ab72 Mon Sep 17 00:00:00 2001 From: doctorpangloss <@hiddenswitch.com> Date: Tue, 9 Sep 2025 13:26:46 -0700 Subject: [PATCH] Fix linting issues, improve HooksSupport dummy implementation --- comfy/ldm/modules/sage_attention_dispatcher.py | 2 +- comfy/model_management_types.py | 15 ++++++++++----- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/comfy/ldm/modules/sage_attention_dispatcher.py b/comfy/ldm/modules/sage_attention_dispatcher.py index 88f71fb20..4dc77c638 100644 --- a/comfy/ldm/modules/sage_attention_dispatcher.py +++ b/comfy/ldm/modules/sage_attention_dispatcher.py @@ -2,7 +2,7 @@ from typing import Optional, Any import torch # only imported when sage attention is enabled -from sageattention import * # pylint: disable=import-error +from sageattention import sageattn_qk_int8_pv_fp16_cuda, sageattn_qk_int8_pv_fp8_cuda, sageattn_qk_int8_pv_fp8_cuda_sm90 # pylint: disable=import-error def get_cuda_arch_versions(): diff --git a/comfy/model_management_types.py b/comfy/model_management_types.py index 57b66517f..2383f369a 100644 --- a/comfy/model_management_types.py +++ b/comfy/model_management_types.py @@ -44,7 +44,6 @@ class HooksSupport(Protocol, metaclass=ABCMeta): setattr(self, "_hook_mode", EnumHookMode.MaxSpeed) return getattr(self, "_hook_mode") - @hook_mode.setter def hook_mode(self, value): setattr(self, "_hook_mode", value) @@ -76,9 +75,12 @@ class HooksSupport(Protocol, metaclass=ABCMeta): pass def pre_run(self): - from .model_base import BaseModel - if hasattr(self, "model") and isinstance(self.model, BaseModel): - self.model.current_patcher = self + if hasattr(self, "model"): + model = getattr(self, "model") + from .model_base import BaseModel + + if isinstance(model, BaseModel) or hasattr(model, "current_patcher") and isinstance(self, ModelManageable): + model.current_patcher = self def prepare_state(self, *args, **kwargs): pass @@ -107,7 +109,7 @@ class ModelManageableExtras(Protocol, metaclass=ABCMeta): return torch.device("cpu") -class ModelManageableRequired(Protocol): +class ModelManageableRequired(Protocol, metaclass=ABCMeta): """ The bare minimum that must be implemented to support model management when inheriting from ModelManageable @@ -122,6 +124,7 @@ class ModelManageableRequired(Protocol): offload_device: torch.device model: torch.nn.Module + @abstractmethod def patch_model(self, device_to: torch.device | None = None, lowvram_model_memory: int = 0, load_weights: bool = True, force_patch_weights: bool = False) -> torch.nn.Module: """ Called by ModelManageable @@ -138,6 +141,7 @@ class ModelManageableRequired(Protocol): """ ... + @abstractmethod def unpatch_model(self, device_to: torch.device | None = None, unpatch_weights: Optional[bool] = False) -> torch.nn.Module: """ Called by ModelManageable @@ -152,6 +156,7 @@ class ModelManageableRequired(Protocol): ... +@runtime_checkable class ModelManageable(ModelManageableRequired, ModelManageableExtras, Protocol, metaclass=ABCMeta): """ Objects which implement this protocol can be managed by