diff --git a/comfy/model_management.py b/comfy/model_management.py index 0ae010ac6..a9118361e 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -12,6 +12,8 @@ from threading import RLock import torch import sys +from .model_management_types import ModelManageable + model_management_lock = RLock() class VRAMState(Enum): @@ -278,7 +280,7 @@ def module_size(module): return module_mem class LoadedModel: - def __init__(self, model): + def __init__(self, model: ModelManageable): self.model = model self.device = model.load_device self.weights_loaded = False diff --git a/comfy/model_management_types.py b/comfy/model_management_types.py new file mode 100644 index 000000000..e8c096095 --- /dev/null +++ b/comfy/model_management_types.py @@ -0,0 +1,49 @@ +from __future__ import annotations + +from typing import Protocol, Optional + +import torch + + +class ModelManageable(Protocol): + """ + Objects which implement this protocol can be managed by + + >>> import comfy.model_management + >>> class SomeObj("ModelManageable"): + >>> ... + >>> + >>> comfy.model_management.load_model_gpu(SomeObj()) + """ + load_device: torch.device + offload_device: torch.device + model: torch.nn.Module + current_device: torch.device + + @property + def dtype(self) -> torch.dtype: + ... + + def is_clone(self, other: torch.nn.Module) -> bool: + pass + + def clone_has_same_weights(self, clone: torch.nn.Module) -> bool: + pass + + def model_size(self) -> int: + pass + + def model_patches_to(self, arg: torch.device | torch.dtype): + pass + + def model_dtype(self) -> torch.dtype: + pass + + def patch_model_lowvram(self, device_to: torch.device, lowvram_model_memory: int) -> torch.nn.Module: + pass + + def patch_model(self, device_to: torch.device, patch_weights: bool) -> torch.nn.Module: + pass + + def unpatch_model(self, offload_device: torch.device, unpatch_weights: Optional[bool] = False) -> torch.nn.Module: + pass diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 7ab71ef5f..c8eba6b73 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -6,6 +6,8 @@ import uuid from . import utils from . import model_management +from .model_management_types import ModelManageable + def apply_weight_decompose(dora_scale, weight): weight_norm = ( @@ -39,7 +41,7 @@ def set_model_options_patch_replace(model_options, patch, name, block_name, numb model_options["transformer_options"] = to return model_options -class ModelPatcher: +class ModelPatcher(ModelManageable): def __init__(self, model, load_device, offload_device, size=0, current_device=None, weight_inplace_update=False): self.size = size self.model = model