Mirror of https://github.com/comfyanonymous/ComfyUI.git, synced 2026-01-10 06:10:50 +08:00
more type documentation, fix installation with dependency group

commit b62d4f05e1 (parent a31e5f216d)
.github/workflows/test.yml (vendored, 4 changed lines)
@@ -35,8 +35,8 @@ jobs:
       run: |
         export UV_BREAK_SYSTEM_PACKAGES=true
         export UV_SYSTEM_PYTHON=true
-        pip freeze | grep numpy > numpy_override.txt
-        uv pip install . --inexact --group dev --override numpy_override.txt
+        uv pip freeze | grep numpy > numpy_override.txt
+        uv pip install ".[dev]" --inexact --override numpy_override.txt
     - name: Run tests
       run: |
         nvidia-smi
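This is the fix the commit message refers to: the numpy override file is now produced by `uv pip freeze` (the same tool that consumes it), and the dev dependencies are installed as the `dev` extra via `".[dev]"` instead of the `--group dev` dependency group. `--override numpy_override.txt` pins numpy at whatever version is already present, presumably so that installing the extra cannot upgrade or downgrade it.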
@@ -1,7 +1,5 @@
 from __future__ import annotations
-
-from typing import Literal, Any, NamedTuple, Protocol, Callable
-
+from typing import Protocol, List, Dict, Optional, NamedTuple, Callable, Literal, Any
 import torch

 PatchOffset = tuple[int, int, int]
@@ -31,3 +29,47 @@ class PatchTuple(NamedTuple):


 ModelPatchesDictValue = list[PatchTuple | PatchWeightTuple]
+
+
+class PatchSupport(Protocol):
+    """
+    Defines the interface for a model that supports LoRA patching.
+    """
+
+    def add_patches(
+            self,
+            patches: PatchDict,
+            strength_patch: float = 1.0,
+            strength_model: float = 1.0
+    ) -> List[PatchDictKey]:
+        """
+        Applies a set of patches (like LoRA weights) to the model.
+
+        Args:
+            patches (PatchDict): A dictionary containing the patch weights and metadata.
+            strength_patch (float): The strength multiplier for the patch itself.
+            strength_model (float): The strength multiplier for the original model weights.
+
+        Returns:
+            List[PatchDictKey]: A list of keys for the weights that were successfully patched.
+        """
+        ...
+
+    def get_key_patches(
+            self,
+            filter_prefix: Optional[str] = None
+    ) -> Dict[str, ModelPatchesDictValue]:
+        """
+        Retrieves all active patches, optionally filtered by a key prefix.
+
+        The returned dictionary maps a model weight's key to a list. The first
+        element in the list is a tuple containing the original weight, and subsequent
+        elements are the applied patch tuples.
+
+        Args:
+            filter_prefix (Optional[str]): A prefix to filter which weight patches are returned.
+
+        Returns:
+            Dict[str, ModelPatchesDictValue]: A dictionary of the model's patched weights.
+        """
+        ...
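Because PatchSupport is a structural protocol, any class with matching methods conforms; nothing needs to subclass it. A minimal sketch of a conforming class follows — the InMemoryPatchable name and its dict-backed store are hypothetical, and it assumes patch keys are strings:

from typing import Optional, List, Dict


class InMemoryPatchable:
    """Hypothetical dict-backed object satisfying the PatchSupport protocol."""

    def __init__(self):
        self._patches: Dict[str, list] = {}

    def add_patches(self, patches: dict, strength_patch: float = 1.0,
                    strength_model: float = 1.0) -> List[str]:
        patched = []
        for key, value in patches.items():
            # store the patch with its strengths; a real model would
            # validate the key against its own state_dict first
            self._patches.setdefault(key, []).append((strength_patch, value, strength_model))
            patched.append(key)
        return patched

    def get_key_patches(self, filter_prefix: Optional[str] = None) -> Dict[str, list]:
        if filter_prefix is None:
            return dict(self._patches)
        return {k: v for k, v in self._patches.items() if k.startswith(filter_prefix)}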
@@ -2,12 +2,13 @@ from __future__ import annotations

 import dataclasses
 from abc import ABCMeta, abstractmethod
-from typing import Protocol, Optional, TypeVar, runtime_checkable, Callable, Any, NamedTuple
+from typing import Protocol, Optional, TypeVar, runtime_checkable, Callable, Any, NamedTuple, override

 import torch
 import torch.nn
 from typing_extensions import TypedDict, NotRequired

 from .comfy_types import UnetWrapperFunction
 from .latent_formats import LatentFormat

 ModelManageableT = TypeVar('ModelManageableT', bound='ModelManageable')
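Note that `override` landed in the standard `typing` module only in Python 3.12 (PEP 698); on older interpreters it is provided by `typing_extensions`, which this module already imports from. A compatibility sketch, not part of the diff:

try:
    from typing import override  # Python 3.12+
except ImportError:
    from typing_extensions import override  # older interpreters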
@@ -25,7 +26,71 @@ class DeviceSettable(Protocol):
         ...


-class ModelManageable(Protocol, metaclass=ABCMeta):
+class HooksSupport(Protocol, metaclass=ABCMeta):
+    def prepare_hook_patches_current_keyframe(self, t, hook_group, model_options):
+        return
+
+
+class TrainingSupport(Protocol, metaclass=ABCMeta):
+    def set_model_compute_dtype(self, dtype: torch.dtype):
+        return
+
+    def add_weight_wrapper(self, name, function):
+        return
+
+
+class ModelManageableExtras(Protocol, metaclass=ABCMeta):
+    @property
+    def current_device(self) -> torch.device:
+        return torch.device("cpu")
+
+
+class ModelManageableRequired(Protocol):
+    """
+    The bare minimum that must be implemented to support model management when inheriting from ModelManageable.
+
+    Attributes:
+        load_device (torch.device): the device that this model's weights will be loaded onto for inference, typically the GPU
+        offload_device (torch.device): the device that this model's weights will be offloaded onto when not being used for inference or when performing CPU offloading, typically the CPU
+        model (torch.nn.Module): in principle this can be any callable, but it should be a torch module to work with the rest of the machinery
+
+    :see: ModelManageable
+    :see: PatchSupport
+    """
+    load_device: torch.device
+    offload_device: torch.device
+    model: torch.nn.Module
+
+    def patch_model(self, device_to: torch.device | None = None, lowvram_model_memory: int = 0, load_weights: bool = True, force_patch_weights: bool = False) -> torch.nn.Module:
+        """
+        Called by ModelManageable.
+
+        An implementation of this method should:
+        (1) load the model by moving it to the target device, and
+        (2) fuse the LoRA weights ("patches"), if applicable.
+
+        :param device_to:
+        :param lowvram_model_memory:
+        :param load_weights:
+        :param force_patch_weights:
+        :return:
+        """
+        ...
+
+    def unpatch_model(self, device_to: torch.device | None = None, unpatch_weights: Optional[bool] = False) -> torch.nn.Module:
+        """
+        Called by ModelManageable.
+
+        Unloads the model by:
+        (1) unfusing the LoRA weights ("unpatching"), if applicable, and
+        (2) moving the weights to the provided device.
+
+        :param device_to:
+        :param unpatch_weights:
+        :return:
+        """
+        ...
+
+
+class ModelManageable(ModelManageableRequired, ModelManageableExtras, Protocol, metaclass=ABCMeta):
     """
     Objects which implement this protocol can be managed by
@@ -35,13 +100,20 @@ class ModelManageable(Protocol, metaclass=ABCMeta):
     >>>
     >>> some_model = ModelWrapper()
     >>> load_models_gpu([some_model])
+
+    The minimum required
     """
     load_device: torch.device
     offload_device: torch.device
     model: torch.nn.Module

     @property
+    @override
     def current_device(self) -> torch.device:
+        """
+        Only needed in Hidden Switch; does not need to be overridden.
+        :return:
+        """
         return next(self.model.parameters()).device

     def is_clone(self, other: ModelManageableT) -> bool:
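Taken together, ModelManageableRequired plus the defaults above mean a wrapper only has to supply the two devices, a module, and the two patch methods. A minimal sketch — the TinyWrapper name and its trivial method bodies are illustrative, not part of this commit:

import torch


class TinyWrapper:
    """Hypothetical smallest-possible ModelManageableRequired implementation."""

    def __init__(self, model: torch.nn.Module):
        self.load_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.offload_device = torch.device("cpu")
        self.model = model

    def patch_model(self, device_to=None, lowvram_model_memory=0,
                    load_weights=True, force_patch_weights=False) -> torch.nn.Module:
        # nothing to fuse in this sketch; just move to the target device
        return self.model.to(device_to or self.load_device)

    def unpatch_model(self, device_to=None, unpatch_weights=False) -> torch.nn.Module:
        # no patches to unfuse; move the weights back
        return self.model.to(device_to or self.offload_device)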
@@ -60,19 +132,10 @@ class ModelManageable(Protocol, metaclass=ABCMeta):
     def model_dtype(self) -> torch.dtype:
         return next(self.model.parameters()).dtype

-    def patch_model(self, device_to: torch.device | None = None, lowvram_model_memory: int = 0, load_weights: bool = True, force_patch_weights: bool = False) -> torch.nn.Module:
-        ...
-
-    def unpatch_model(self, device_to: torch.device | None = None, unpatch_weights: Optional[bool] = False) -> torch.nn.Module:
-        """
-        Unloads the model by moving it to the offload device
-        :param device_to:
-        :param unpatch_weights:
-        :return:
-        """
-        ...
-
     def lowvram_patch_counter(self) -> int:
+        """
+        Returns a counter related to low VRAM patching, used to decide if a reload is necessary.
+        """
         return 0

     def partially_load(self, device_to: torch.device, extra_memory: int = 0, force_patch_weights: bool = False):
@@ -120,26 +183,27 @@ class ModelManageable(Protocol, metaclass=ABCMeta):

     @property
     def parent(self) -> ModelManageableT | None:
         """
         Used for tracking a parent model from which this was cloned
         :return:
         """
         return None

     def detach(self, unpatch_all: bool = True):
-        self.model_patches_to(self.offload_device)
-        if unpatch_all:
-            self.unpatch_model(self.offload_device, unpatch_weights=unpatch_all)
+        """
+        Unloads the model
+        :param unpatch_all:
+        :return:
+        """
+        self.unpatch_model(self.offload_device, unpatch_weights=unpatch_all)
         return self.model

-    def set_model_compute_dtype(self, dtype: torch.dtype):
-        pass
-
-    def add_weight_wrapper(self, name, function):
-        pass
-
     @property
     def force_cast_weights(self) -> bool:
         return False

-    def prepare_hook_patches_current_keyframe(self, t, hook_group, model_options):
-        pass
+    def model_patches_models(self) -> list[ModelManageableT]:
+        """
+        Used to implement Qwen DiffSynth Controlnets (?)
+        :return:
+        """
+        return []


 @dataclasses.dataclass
@@ -169,24 +233,38 @@ class MemoryMeasurements:
         self._device = value


+class HasModels(Protocol):
+    """A protocol for any object that has a .models() method returning a list."""
+
+    def models(self) -> list:
+        ...
+
+
+class HasTo(Protocol):
+    def to(self, device: torch.device):
+        ...
+
+
 class TransformerOptions(TypedDict, total=False):
     cond_or_uncond: NotRequired[list]
-    patches: NotRequired[dict]
+    patches: NotRequired[dict[str, list[HasModels]]]
     sigmas: NotRequired[torch.Tensor]
+    patches_replace: NotRequired[dict[str, dict[Any, HasModels]]]


 class ModelOptions(TypedDict, total=False):
     transformer_options: NotRequired[dict]
     # signature of BaseModel.apply_model
-    model_function_wrapper: NotRequired[Callable]
+    model_function_wrapper: NotRequired[Callable | UnetWrapperFunction | HasModels | HasTo]
     sampler_cfg_function: NotRequired[Callable]
     sampler_post_cfg_function: NotRequired[list[Callable]]
     disable_cfg1_optimization: NotRequired[bool]
     denoise_mask_function: NotRequired[Callable]
     patches: NotRequired[dict[str, list]]


 class LoadingListItem(NamedTuple):
     module_size: int
     name: str
     module: torch.nn.Module
     params: list[str]
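Since both classes are declared with total=False and NotRequired, every key is optional and any subset of keys type-checks. A small usage sketch — it assumes the module is importable as comfy.model_management_types, and the values are dummies:

import torch

from comfy.model_management_types import ModelOptions, TransformerOptions

transformer_options: TransformerOptions = {
    "cond_or_uncond": [0, 1],
    "sigmas": torch.tensor([14.6146, 7.0]),
}

model_options: ModelOptions = {
    "transformer_options": dict(transformer_options),
    "disable_cfg1_optimization": True,
    # model_function_wrapper may now also be an object exposing .models() or .to()
}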
@@ -40,9 +40,9 @@ from .component_model.deprecation import _deprecate_method
 from .float import stochastic_rounding
 from .gguf import move_patch_to_device, is_torch_compatible, is_quantized, GGMLOps
 from .hooks import EnumHookMode, _HookRef, HookGroup, EnumHookType, WeightHook, create_transformer_options_from_hooks
-from .lora_types import PatchDict, PatchDictKey, PatchTuple, PatchWeightTuple, ModelPatchesDictValue
+from .lora_types import PatchDict, PatchDictKey, PatchTuple, PatchWeightTuple, ModelPatchesDictValue, PatchSupport
 from .model_base import BaseModel
-from .model_management_types import ModelManageable, MemoryMeasurements, ModelOptions, LatentFormatT, LoadingListItem
+from .model_management_types import ModelManageable, MemoryMeasurements, ModelOptions, LatentFormatT, LoadingListItem, TrainingSupport, HooksSupport
 from .patcher_extension import CallbacksMP, WrappersMP, PatcherInjection

 logger = logging.getLogger(__name__)
@@ -230,7 +230,7 @@ class GGUFQuantization:
     patch_on_device: bool = False


-class ModelPatcher(ModelManageable):
+class ModelPatcher(ModelManageable, TrainingSupport, HooksSupport, PatchSupport):
     def __init__(self, model: BaseModel | torch.nn.Module, load_device: torch.device, offload_device: torch.device, size=0, weight_inplace_update=False, ckpt_name: Optional[str] = None):
         self.size = size
         self.model: BaseModel | torch.nn.Module = model
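Because ModelPatcher now inherits these protocols explicitly, it picks up their default no-op bodies through the MRO and only overrides what it actually implements. A sketch of that pattern in isolation — the MyPatcher class name is illustrative, not from the codebase:

import torch
from abc import ABCMeta
from typing import Protocol


class TrainingSupport(Protocol, metaclass=ABCMeta):
    # default no-op, exactly as in the diff above
    def set_model_compute_dtype(self, dtype: torch.dtype):
        return


class MyPatcher(TrainingSupport):
    """Explicit subclass: inherits the no-op default, overrides selectively."""

    def set_model_compute_dtype(self, dtype: torch.dtype):
        self.compute_dtype = dtype  # record the requested dtype


p = MyPatcher()  # explicit protocol subclasses are ordinary, instantiable classes
p.set_model_compute_dtype(torch.float16)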