more type documentation, fix installation with dependency group

doctorpangloss 2025-09-08 11:35:19 -07:00
parent a31e5f216d
commit b62d4f05e1
4 changed files with 160 additions and 40 deletions

View File

@@ -35,8 +35,8 @@ jobs:
run: |
export UV_BREAK_SYSTEM_PACKAGES=true
export UV_SYSTEM_PYTHON=true
pip freeze | grep numpy > numpy_override.txt
uv pip install . --inexact --group dev --override numpy_override.txt
uv pip freeze | grep numpy > numpy_override.txt
uv pip install ".[dev]" --inexact --override numpy_override.txt
- name: Run tests
run: |
nvidia-smi
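
The install change replaces uv's dependency-group flag with a pip-style extra: uv pip install ".[dev]" resolves the dev dependencies from the project's extras rather than from a PEP 735 dependency group. A hedged sketch of the pyproject.toml sections each form expects (the listed dev dependencies are placeholders, not taken from this repository):

# Hypothetical pyproject.toml excerpts; contents are illustrative.
[project.optional-dependencies]   # consumed by: uv pip install ".[dev]"
dev = ["pytest"]

[dependency-groups]               # consumed by: uv pip install . --group dev
dev = ["pytest"]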

View File

@@ -1,7 +1,5 @@
from __future__ import annotations
from typing import Literal, Any, NamedTuple, Protocol, Callable
from typing import Protocol, List, Dict, Optional, NamedTuple, Callable, Literal, Any
import torch
PatchOffset = tuple[int, int, int]
@@ -31,3 +29,47 @@ class PatchTuple(NamedTuple):
ModelPatchesDictValue = list[PatchTuple | PatchWeightTuple]
class PatchSupport(Protocol):
"""
Defines the interface for a model that supports LoRA patching.
"""
def add_patches(
self,
patches: PatchDict,
strength_patch: float = 1.0,
strength_model: float = 1.0
) -> List[PatchDictKey]:
"""
Applies a set of patches (like LoRA weights) to the model.
Args:
patches (PatchDict): A dictionary containing the patch weights and metadata.
strength_patch (float): The strength multiplier for the patch itself.
strength_model (float): The strength multiplier for the original model weights.
Returns:
List[PatchDictKey]: A list of keys for the weights that were successfully patched.
"""
...
def get_key_patches(
self,
filter_prefix: Optional[str] = None
) -> Dict[str, ModelPatchesDictValue]:
"""
Retrieves all active patches, optionally filtered by a key prefix.
The returned dictionary maps a model weight's key to a list. The first
element in the list is a tuple containing the original weight, and subsequent
elements are the applied patch tuples.
Args:
filter_prefix (Optional[str]): A prefix to filter which weight patches are returned.
Returns:
Dict[str, ModelPatchesDictValue]: A dictionary of the model's patched weights.
"""
...
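
A hedged usage sketch of the protocol above; patcher and lora_patches are hypothetical stand-ins for a real PatchSupport implementation such as ModelPatcher:

# Hypothetical driver code for a PatchSupport implementation.
patched_keys = patcher.add_patches(
    patches=lora_patches,   # a PatchDict mapping weight keys to patch data
    strength_patch=0.8,     # scale applied to the patch contribution
    strength_model=1.0,     # scale applied to the original weights
)
for key, entries in patcher.get_key_patches(filter_prefix="diffusion_model.").items():
    original = entries[0]   # first element: tuple holding the original weight
    applied = entries[1:]   # remaining elements: the applied patch tuples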

View File

@@ -2,12 +2,13 @@ from __future__ import annotations
import dataclasses
from abc import ABCMeta, abstractmethod
from typing import Protocol, Optional, TypeVar, runtime_checkable, Callable, Any, NamedTuple
from typing import Protocol, Optional, TypeVar, runtime_checkable, Callable, Any, NamedTuple, override
import torch
import torch.nn
from typing_extensions import TypedDict, NotRequired
from .comfy_types import UnetWrapperFunction
from .latent_formats import LatentFormat
ModelManageableT = TypeVar('ModelManageableT', bound='ModelManageable')
@@ -25,7 +26,71 @@ class DeviceSettable(Protocol):
...
class ModelManageable(Protocol, metaclass=ABCMeta):
class HooksSupport(Protocol, metaclass=ABCMeta):
def prepare_hook_patches_current_keyframe(self, t, hook_group, model_options):
return
class TrainingSupport(Protocol, metaclass=ABCMeta):
def set_model_compute_dtype(self, dtype: torch.dtype):
return
def add_weight_wrapper(self, name, function):
return
class ModelManageableExtras(Protocol, metaclass=ABCMeta):
@property
def current_device(self) -> torch.device:
return torch.device("cpu")
class ModelManageableRequired(Protocol):
"""
The bare minimum that must be implemented to support model management when inheriting from ModelManageable.
Attributes:
load_device (torch.device): the device that this model's weights will be loaded onto for inference, typically the GPU
offload_device (torch.device): the device that this model's weights will be offloaded to when not in use for inference or when performing CPU offloading, typically the CPU
model (torch.nn.Module): in principle this can be any callable, but it should be a torch model to work with the rest of the machinery
:see: ModelManageable
:see: PatchSupport
"""
load_device: torch.device
offload_device: torch.device
model: torch.nn.Module
def patch_model(self, device_to: torch.device | None = None, lowvram_model_memory: int = 0, load_weights: bool = True, force_patch_weights: bool = False) -> torch.nn.Module:
"""
Called by ModelManageable
An implementation of this method should:
(1) load the model by moving it to the target device
(2) fuse the LoRA weights ("patches"), if applicable
:param device_to:
:param lowvram_model_memory:
:param load_weights:
:param force_patch_weights:
:return:
"""
...
def unpatch_model(self, device_to: torch.device | None = None, unpatch_weights: Optional[bool] = False) -> torch.nn.Module:
"""
Called by ModelManageable
Unloads the model by:
(1) unfusing the LoRA weights ("unpatching"), if applicable
(2) moving the weights to the provided device
:param device_to:
:param unpatch_weights:
:return:
"""
...
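
A minimal sketch, not part of the diff, of a class satisfying ModelManageableRequired; the device choices are assumptions, and a real implementation would fuse and unfuse patches rather than only moving weights:

import torch

class MinimalManagedModel:
    def __init__(self, module: torch.nn.Module):
        self.model = module
        self.load_device = torch.device("cuda")
        self.offload_device = torch.device("cpu")

    def patch_model(self, device_to: torch.device | None = None, lowvram_model_memory: int = 0, load_weights: bool = True, force_patch_weights: bool = False) -> torch.nn.Module:
        # move to the inference device; a real implementation would also fuse LoRA patches here
        self.model.to(device_to or self.load_device)
        return self.model

    def unpatch_model(self, device_to: torch.device | None = None, unpatch_weights: bool = False) -> torch.nn.Module:
        # a real implementation would unfuse patches before offloading
        self.model.to(device_to or self.offload_device)
        return self.model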
class ModelManageable(ModelManageableRequired, ModelManageableExtras, Protocol, metaclass=ABCMeta):
"""
Objects which implement this protocol can be managed by
@@ -35,13 +100,20 @@ class ModelManageable(Protocol, metaclass=ABCMeta):
>>>
>>> some_model = ModelWrapper()
>>> load_models_gpu([some_model])
The minimum required interface is declared by ModelManageableRequired.
"""
load_device: torch.device
offload_device: torch.device
model: torch.nn.Module
@property
@override
def current_device(self) -> torch.device:
"""
Only needed in Hidden Switch; does not need to be overridden.
:return:
"""
return next(self.model.parameters()).device
def is_clone(self, other: ModelManageableT) -> bool:
@@ -60,19 +132,10 @@ class ModelManageable(Protocol, metaclass=ABCMeta):
def model_dtype(self) -> torch.dtype:
return next(self.model.parameters()).dtype
def patch_model(self, device_to: torch.device | None = None, lowvram_model_memory: int = 0, load_weights: bool = True, force_patch_weights: bool = False) -> torch.nn.Module:
...
def unpatch_model(self, device_to: torch.device | None = None, unpatch_weights: Optional[bool] = False) -> torch.nn.Module:
"""
Unloads the model by moving it to the offload device
:param device_to:
:param unpatch_weights:
:return:
"""
...
def lowvram_patch_counter(self) -> int:
"""
Returns a counter related to low VRAM patching, used to decide if a reload is necessary.
"""
return 0
def partially_load(self, device_to: torch.device, extra_memory: int = 0, force_patch_weights: bool = False):
@@ -120,26 +183,27 @@ class ModelManageable(Protocol, metaclass=ABCMeta):
@property
def parent(self) -> ModelManageableT | None:
"""
Used for tracking a parent model from which this was cloned
:return:
"""
return None
def detach(self, unpatch_all: bool = True):
self.model_patches_to(self.offload_device)
if unpatch_all:
self.unpatch_model(self.offload_device, unpatch_weights=unpatch_all)
"""
Unloads the model
:param unpatch_all:
:return:
"""
self.unpatch_model(self.offload_device, unpatch_weights=unpatch_all)
return self.model
def set_model_compute_dtype(self, dtype: torch.dtype):
pass
def add_weight_wrapper(self, name, function):
pass
@property
def force_cast_weights(self) -> bool:
return False
def prepare_hook_patches_current_keyframe(self, t, hook_group, model_options):
pass
def model_patches_models(self) -> list[ModelManageableT]:
"""
Used to implement Qwen DiffSynth Controlnets (?)
:return:
"""
return []
@dataclasses.dataclass
@@ -169,24 +233,38 @@ class MemoryMeasurements:
self._device = value
class HasModels(Protocol):
"""A protocol for any object that has a .models() method returning a list."""
def models(self) -> list:
...
class HasTo(Protocol):
def to(self, device: torch.device):
...
class TransformerOptions(TypedDict, total=False):
cond_or_uncond: NotRequired[list]
patches: NotRequired[dict]
patches: NotRequired[dict[str, list[HasModels]]]
sigmas: NotRequired[torch.Tensor]
patches_replace: NotRequired[dict[str, dict[Any, HasModels]]]
class ModelOptions(TypedDict, total=False):
transformer_options: NotRequired[dict]
# signature of BaseModel.apply_model
model_function_wrapper: NotRequired[Callable]
model_function_wrapper: NotRequired[Callable | UnetWrapperFunction | HasModels | HasTo]
sampler_cfg_function: NotRequired[Callable]
sampler_post_cfg_function: NotRequired[list[Callable]]
disable_cfg1_optimization: NotRequired[bool]
denoise_mask_function: NotRequired[Callable]
patches: NotRequired[dict[str, list]]
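
A hedged sketch of populating ModelOptions; the two-argument calling convention for model_function_wrapper (the wrapped apply_model plus an args dict) is assumed from ComfyUI's usage and is only hinted at here by the signature comment:

def passthrough_wrapper(apply_model, args):
    # delegate to the wrapped model function unchanged; a real wrapper
    # would inspect or modify args["input"] / args["timestep"] first
    return apply_model(args["input"], args["timestep"], **args["c"])

options: ModelOptions = {
    "model_function_wrapper": passthrough_wrapper,
    "sampler_post_cfg_function": [],
    "disable_cfg1_optimization": False,
}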
class LoadingListItem(NamedTuple):
module_size: int
name: str
module: torch.nn.Module
params: list[str]
params: list[str]
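
A hypothetical construction of a LoadingListItem while walking a module tree; layer and the name are illustrative, not taken from this repository:

item = LoadingListItem(
    module_size=sum(p.numel() * p.element_size() for p in layer.parameters()),
    name="diffusion_model.input_blocks.0",  # hypothetical weight key
    module=layer,
    params=[n for n, _ in layer.named_parameters(recurse=False)],
)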

View File

@@ -40,9 +40,9 @@ from .component_model.deprecation import _deprecate_method
from .float import stochastic_rounding
from .gguf import move_patch_to_device, is_torch_compatible, is_quantized, GGMLOps
from .hooks import EnumHookMode, _HookRef, HookGroup, EnumHookType, WeightHook, create_transformer_options_from_hooks
from .lora_types import PatchDict, PatchDictKey, PatchTuple, PatchWeightTuple, ModelPatchesDictValue
from .lora_types import PatchDict, PatchDictKey, PatchTuple, PatchWeightTuple, ModelPatchesDictValue, PatchSupport
from .model_base import BaseModel
from .model_management_types import ModelManageable, MemoryMeasurements, ModelOptions, LatentFormatT, LoadingListItem
from .model_management_types import ModelManageable, MemoryMeasurements, ModelOptions, LatentFormatT, LoadingListItem, TrainingSupport, HooksSupport
from .patcher_extension import CallbacksMP, WrappersMP, PatcherInjection
logger = logging.getLogger(__name__)
@@ -230,7 +230,7 @@ class GGUFQuantization:
patch_on_device: bool = False
class ModelPatcher(ModelManageable):
class ModelPatcher(ModelManageable, TrainingSupport, HooksSupport, PatchSupport):
def __init__(self, model: BaseModel | torch.nn.Module, load_device: torch.device, offload_device: torch.device, size=0, weight_inplace_update=False, ckpt_name: Optional[str] = None):
self.size = size
self.model: BaseModel | torch.nn.Module = model
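
With the interface split into capability protocols, ModelPatcher inherits their default no-op bodies by subclassing explicitly, and call sites can type against only the capability they need. A hedged sketch:

def set_training_dtype(m: TrainingSupport, dtype: torch.dtype) -> None:
    # any TrainingSupport implementer qualifies, not only ModelPatcher
    m.set_model_compute_dtype(dtype)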