Provide a protocol that plugins can use to declare models manageable by model management. Docs will be updated to specify that plugin authors should generally use ModelPatcher.

doctorpangloss 2024-05-09 16:07:18 -07:00
parent c2fa74f625
commit 779ff30c17
3 changed files with 55 additions and 2 deletions

comfy/model_management.py

@@ -12,6 +12,8 @@ from threading import RLock
 import torch
 import sys
 
+from .model_management_types import ModelManageable
+
 model_management_lock = RLock()
 
 class VRAMState(Enum):
@@ -278,7 +280,7 @@ def module_size(module)
     return module_mem
 
 class LoadedModel:
-    def __init__(self, model):
+    def __init__(self, model: ModelManageable):
         self.model = model
         self.device = model.load_device
         self.weights_loaded = False
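With the parameter annotated, a static type checker can now flag objects handed to LoadedModel that do not expose the full management surface. A small sketch of what gets caught, assuming a ComfyUI checkout on the path; the Incomplete class is hypothetical:

    import torch

    from comfy.model_management import LoadedModel


    class Incomplete:
        # hypothetical plugin object: it has a device, but none of the
        # other members ModelManageable requires
        load_device = torch.device("cpu")


    # mypy/pyright reject this call: Incomplete does not structurally
    # satisfy ModelManageable, even though no inheritance is involved.
    LoadedModel(Incomplete())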

comfy/model_management_types.py

@@ -0,0 +1,49 @@
+from __future__ import annotations
+
+from typing import Protocol, Optional
+
+import torch
+
+
+class ModelManageable(Protocol):
+    """
+    Objects which implement this protocol can be managed by comfy.model_management:
+
+    >>> import comfy.model_management
+    >>> class SomeObj(ModelManageable):
+    >>>     ...
+    >>>
+    >>> comfy.model_management.load_model_gpu(SomeObj())
+    """
+    load_device: torch.device
+    offload_device: torch.device
+    model: torch.nn.Module
+    current_device: torch.device
+
+    @property
+    def dtype(self) -> torch.dtype:
+        ...
+
+    def is_clone(self, other: torch.nn.Module) -> bool:
+        pass
+
+    def clone_has_same_weights(self, clone: torch.nn.Module) -> bool:
+        pass
+
+    def model_size(self) -> int:
+        pass
+
+    def model_patches_to(self, arg: torch.device | torch.dtype):
+        pass
+
+    def model_dtype(self) -> torch.dtype:
+        pass
+
+    def patch_model_lowvram(self, device_to: torch.device, lowvram_model_memory: int) -> torch.nn.Module:
+        pass
+
+    def patch_model(self, device_to: torch.device, patch_weights: bool) -> torch.nn.Module:
+        pass
+
+    def unpatch_model(self, offload_device: torch.device, unpatch_weights: Optional[bool] = False) -> torch.nn.Module:
+        pass
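Because ModelManageable is a typing.Protocol, a plugin model never has to inherit from a comfy base class; any object that exposes these members type-checks. What follows is a minimal plugin-side sketch, assuming a ComfyUI environment; MyPluginModel and its method bodies are illustrative, not part of this commit:

    from __future__ import annotations

    from typing import Optional

    import torch

    import comfy.model_management


    class MyPluginModel:
        """Hypothetical plugin model; satisfies ModelManageable structurally."""

        def __init__(self, model: torch.nn.Module):
            self.model = model
            self.load_device = comfy.model_management.get_torch_device()
            self.offload_device = torch.device("cpu")
            self.current_device = self.offload_device

        @property
        def dtype(self) -> torch.dtype:
            return next(self.model.parameters()).dtype

        def is_clone(self, other: torch.nn.Module) -> bool:
            return getattr(other, "model", None) is self.model

        def clone_has_same_weights(self, clone: torch.nn.Module) -> bool:
            return self.is_clone(clone)

        def model_size(self) -> int:
            # parameter bytes only; buffers are omitted for brevity
            return sum(p.numel() * p.element_size() for p in self.model.parameters())

        def model_patches_to(self, arg: torch.device | torch.dtype):
            pass  # this minimal model carries no patches to move

        def model_dtype(self) -> torch.dtype:
            return self.dtype

        def patch_model_lowvram(self, device_to: torch.device, lowvram_model_memory: int) -> torch.nn.Module:
            return self.patch_model(device_to, patch_weights=True)

        def patch_model(self, device_to: torch.device, patch_weights: bool = True) -> torch.nn.Module:
            self.model.to(device_to)
            self.current_device = device_to
            return self.model

        def unpatch_model(self, offload_device: torch.device, unpatch_weights: Optional[bool] = False) -> torch.nn.Module:
            self.model.to(offload_device)
            self.current_device = offload_device
            return self.model

model_management then treats an instance like any ModelPatcher, e.g. comfy.model_management.load_model_gpu(MyPluginModel(torch.nn.Linear(8, 8))).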

comfy/model_patcher.py

@@ -6,6 +6,8 @@ import uuid
 from . import utils
 from . import model_management
+from .model_management_types import ModelManageable
+
 
 
 def apply_weight_decompose(dora_scale, weight):
     weight_norm = (
@@ -39,7 +41,7 @@ def set_model_options_patch_replace(model_options, patch, name, block_name, numb
     model_options["transformer_options"] = to
     return model_options
 
-class ModelPatcher:
+class ModelPatcher(ModelManageable):
     def __init__(self, model, load_device, offload_device, size=0, current_device=None, weight_inplace_update=False):
         self.size = size
         self.model = model
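Because the protocol is structural, downstream code can be typed against ModelManageable rather than ModelPatcher and accept both. A short sketch, where total_size is a hypothetical helper:

    from comfy.model_management_types import ModelManageable


    def total_size(models: list[ModelManageable]) -> int:
        # works for ModelPatcher instances and plugin models alike
        return sum(m.model_size() for m in models)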