mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-12 07:10:52 +08:00
Provide a protocol for plugins to declare model-management-manageable models. Docs will be updated to specify that plugin authors should use ModelPatcher generally.
This commit is contained in:
parent
c2fa74f625
commit
779ff30c17
@ -12,6 +12,8 @@ from threading import RLock
|
||||
import torch
|
||||
import sys
|
||||
|
||||
from .model_management_types import ModelManageable
|
||||
|
||||
model_management_lock = RLock()
|
||||
|
||||
class VRAMState(Enum):
|
||||
@ -278,7 +280,7 @@ def module_size(module):
|
||||
return module_mem
|
||||
|
||||
class LoadedModel:
|
||||
def __init__(self, model):
|
||||
def __init__(self, model: ModelManageable):
|
||||
self.model = model
|
||||
self.device = model.load_device
|
||||
self.weights_loaded = False
|
||||
|
||||
49
comfy/model_management_types.py
Normal file
49
comfy/model_management_types.py
Normal file
@ -0,0 +1,49 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Protocol, Optional
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class ModelManageable(Protocol):
|
||||
"""
|
||||
Objects which implement this protocol can be managed by
|
||||
|
||||
>>> import comfy.model_management
|
||||
>>> class SomeObj("ModelManageable"):
|
||||
>>> ...
|
||||
>>>
|
||||
>>> comfy.model_management.load_model_gpu(SomeObj())
|
||||
"""
|
||||
load_device: torch.device
|
||||
offload_device: torch.device
|
||||
model: torch.nn.Module
|
||||
current_device: torch.device
|
||||
|
||||
@property
|
||||
def dtype(self) -> torch.dtype:
|
||||
...
|
||||
|
||||
def is_clone(self, other: torch.nn.Module) -> bool:
|
||||
pass
|
||||
|
||||
def clone_has_same_weights(self, clone: torch.nn.Module) -> bool:
|
||||
pass
|
||||
|
||||
def model_size(self) -> int:
|
||||
pass
|
||||
|
||||
def model_patches_to(self, arg: torch.device | torch.dtype):
|
||||
pass
|
||||
|
||||
def model_dtype(self) -> torch.dtype:
|
||||
pass
|
||||
|
||||
def patch_model_lowvram(self, device_to: torch.device, lowvram_model_memory: int) -> torch.nn.Module:
|
||||
pass
|
||||
|
||||
def patch_model(self, device_to: torch.device, patch_weights: bool) -> torch.nn.Module:
|
||||
pass
|
||||
|
||||
def unpatch_model(self, offload_device: torch.device, unpatch_weights: Optional[bool] = False) -> torch.nn.Module:
|
||||
pass
|
||||
@ -6,6 +6,8 @@ import uuid
|
||||
|
||||
from . import utils
|
||||
from . import model_management
|
||||
from .model_management_types import ModelManageable
|
||||
|
||||
|
||||
def apply_weight_decompose(dora_scale, weight):
|
||||
weight_norm = (
|
||||
@ -39,7 +41,7 @@ def set_model_options_patch_replace(model_options, patch, name, block_name, numb
|
||||
model_options["transformer_options"] = to
|
||||
return model_options
|
||||
|
||||
class ModelPatcher:
|
||||
class ModelPatcher(ModelManageable):
|
||||
def __init__(self, model, load_device, offload_device, size=0, current_device=None, weight_inplace_update=False):
|
||||
self.size = size
|
||||
self.model = model
|
||||
|
||||
Loading…
Reference in New Issue
Block a user