mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-25 22:00:19 +08:00
Provide a protocol for plugins to declare model-management-manageable models. Docs will be updated to specify that plugin authors should use ModelPatcher generally.
This commit is contained in:
parent
c2fa74f625
commit
779ff30c17
@ -12,6 +12,8 @@ from threading import RLock
|
|||||||
import torch
|
import torch
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
from .model_management_types import ModelManageable
|
||||||
|
|
||||||
model_management_lock = RLock()
|
model_management_lock = RLock()
|
||||||
|
|
||||||
class VRAMState(Enum):
|
class VRAMState(Enum):
|
||||||
@ -278,7 +280,7 @@ def module_size(module):
|
|||||||
return module_mem
|
return module_mem
|
||||||
|
|
||||||
class LoadedModel:
|
class LoadedModel:
|
||||||
def __init__(self, model):
|
def __init__(self, model: ModelManageable):
|
||||||
self.model = model
|
self.model = model
|
||||||
self.device = model.load_device
|
self.device = model.load_device
|
||||||
self.weights_loaded = False
|
self.weights_loaded = False
|
||||||
|
|||||||
49
comfy/model_management_types.py
Normal file
49
comfy/model_management_types.py
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Protocol, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
class ModelManageable(Protocol):
|
||||||
|
"""
|
||||||
|
Objects which implement this protocol can be managed by
|
||||||
|
|
||||||
|
>>> import comfy.model_management
|
||||||
|
>>> class SomeObj("ModelManageable"):
|
||||||
|
>>> ...
|
||||||
|
>>>
|
||||||
|
>>> comfy.model_management.load_model_gpu(SomeObj())
|
||||||
|
"""
|
||||||
|
load_device: torch.device
|
||||||
|
offload_device: torch.device
|
||||||
|
model: torch.nn.Module
|
||||||
|
current_device: torch.device
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dtype(self) -> torch.dtype:
|
||||||
|
...
|
||||||
|
|
||||||
|
def is_clone(self, other: torch.nn.Module) -> bool:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def clone_has_same_weights(self, clone: torch.nn.Module) -> bool:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def model_size(self) -> int:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def model_patches_to(self, arg: torch.device | torch.dtype):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def model_dtype(self) -> torch.dtype:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def patch_model_lowvram(self, device_to: torch.device, lowvram_model_memory: int) -> torch.nn.Module:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def patch_model(self, device_to: torch.device, patch_weights: bool) -> torch.nn.Module:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def unpatch_model(self, offload_device: torch.device, unpatch_weights: Optional[bool] = False) -> torch.nn.Module:
|
||||||
|
pass
|
||||||
@ -6,6 +6,8 @@ import uuid
|
|||||||
|
|
||||||
from . import utils
|
from . import utils
|
||||||
from . import model_management
|
from . import model_management
|
||||||
|
from .model_management_types import ModelManageable
|
||||||
|
|
||||||
|
|
||||||
def apply_weight_decompose(dora_scale, weight):
|
def apply_weight_decompose(dora_scale, weight):
|
||||||
weight_norm = (
|
weight_norm = (
|
||||||
@ -39,7 +41,7 @@ def set_model_options_patch_replace(model_options, patch, name, block_name, numb
|
|||||||
model_options["transformer_options"] = to
|
model_options["transformer_options"] = to
|
||||||
return model_options
|
return model_options
|
||||||
|
|
||||||
class ModelPatcher:
|
class ModelPatcher(ModelManageable):
|
||||||
def __init__(self, model, load_device, offload_device, size=0, current_device=None, weight_inplace_update=False):
|
def __init__(self, model, load_device, offload_device, size=0, current_device=None, weight_inplace_update=False):
|
||||||
self.size = size
|
self.size = size
|
||||||
self.model = model
|
self.model = model
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user