Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-01-11 06:40:48 +08:00)
more type documentation, fix installation with dependency group

commit b62d4f05e1 (parent a31e5f216d)
.github/workflows/test.yml (vendored), 4 changed lines:
@@ -35,8 +35,8 @@ jobs:
       run: |
         export UV_BREAK_SYSTEM_PACKAGES=true
         export UV_SYSTEM_PYTHON=true
-        pip freeze | grep numpy > numpy_override.txt
-        uv pip install . --inexact --group dev --override numpy_override.txt
+        uv pip freeze | grep numpy > numpy_override.txt
+        uv pip install ".[dev]" --inexact --override numpy_override.txt
     - name: Run tests
       run: |
         nvidia-smi
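The install fix swaps a PEP 735 dependency-group flag for an extras specifier: `uv pip install . --group dev` reads a `[dependency-groups]` table in pyproject.toml, while this project evidently declares its dev dependencies under `[project.optional-dependencies]`, which is installed with `".[dev]"`. A minimal sketch of checking which form a given pyproject.toml actually defines (assuming Python 3.11+ for the stdlib tomllib):

    import tomllib

    # Load the project metadata; `--group dev` and ".[dev]" read different tables.
    with open("pyproject.toml", "rb") as f:
        meta = tomllib.load(f)

    extras = meta.get("project", {}).get("optional-dependencies", {})
    groups = meta.get("dependency-groups", {})
    print("dev is an extra:", "dev" in extras)            # install with: uv pip install ".[dev]"
    print("dev is a dependency group:", "dev" in groups)  # install with: uv pip install . --group dev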
lora_types.py:

@@ -1,7 +1,5 @@
 from __future__ import annotations

-from typing import Protocol, List, Dict, Optional, NamedTuple, Callable, Literal, Any
+from typing import Literal, Any, NamedTuple, Protocol, Callable

 import torch

 PatchOffset = tuple[int, int, int]
@@ -31,3 +29,47 @@ class PatchTuple(NamedTuple):

 ModelPatchesDictValue = list[PatchTuple | PatchWeightTuple]
+
+
+class PatchSupport(Protocol):
+    """
+    Defines the interface for a model that supports LoRA patching.
+    """
+
+    def add_patches(
+        self,
+        patches: PatchDict,
+        strength_patch: float = 1.0,
+        strength_model: float = 1.0
+    ) -> list[PatchDictKey]:
+        """
+        Applies a set of patches (like LoRA weights) to the model.
+
+        Args:
+            patches (PatchDict): A dictionary containing the patch weights and metadata.
+            strength_patch (float): The strength multiplier for the patch itself.
+            strength_model (float): The strength multiplier for the original model weights.
+
+        Returns:
+            list[PatchDictKey]: A list of keys for the weights that were successfully patched.
+        """
+        ...
+
+    def get_key_patches(
+        self,
+        filter_prefix: str | None = None
+    ) -> dict[str, ModelPatchesDictValue]:
+        """
+        Retrieves all active patches, optionally filtered by a key prefix.
+
+        The returned dictionary maps a model weight's key to a list. The first
+        element in the list is a tuple containing the original weight; subsequent
+        elements are the applied patch tuples.
+
+        Args:
+            filter_prefix (str | None): A prefix to filter which weight patches are returned.
+
+        Returns:
+            dict[str, ModelPatchesDictValue]: A dictionary of the model's patched weights.
+        """
+        ...
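To make the contract concrete, here is a toy, hypothetical implementation of PatchSupport that merely records patches instead of fusing them (real implementations such as ModelPatcher fuse patch tensors into the weights):

    import torch

    class ToyPatchable:
        def __init__(self, model: torch.nn.Module):
            self.model = model
            self.patches: dict[str, list] = {}

        def add_patches(self, patches: dict, strength_patch: float = 1.0,
                        strength_model: float = 1.0) -> list[str]:
            patched = []
            model_keys = set(dict(self.model.named_parameters()))
            for key, patch in patches.items():
                if key not in model_keys:
                    continue  # only keys that exist on the model count as patched
                self.patches.setdefault(key, []).append((strength_patch, patch, strength_model))
                patched.append(key)
            return patched  # the contract: return the keys actually patched

        def get_key_patches(self, filter_prefix: str | None = None) -> dict[str, list]:
            params = dict(self.model.named_parameters())
            out = {}
            for key, plist in self.patches.items():
                if filter_prefix is not None and not key.startswith(filter_prefix):
                    continue
                # first element wraps the original weight, then the applied patches
                out[key] = [(params[key],)] + plist
            return out

    p = ToyPatchable(torch.nn.Linear(4, 4))
    print(p.add_patches({"weight": torch.zeros(4, 4)}, strength_patch=0.5))  # ['weight']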
model_management_types.py:

@@ -2,12 +2,13 @@ from __future__ import annotations

 import dataclasses
 from abc import ABCMeta, abstractmethod
 from typing import Protocol, Optional, TypeVar, runtime_checkable, Callable, Any, NamedTuple

 import torch
 import torch.nn
-from typing_extensions import TypedDict, NotRequired
+from typing_extensions import TypedDict, NotRequired, override

+from .comfy_types import UnetWrapperFunction
 from .latent_formats import LatentFormat

 ModelManageableT = TypeVar('ModelManageableT', bound='ModelManageable')
@@ -25,7 +26,71 @@ class DeviceSettable(Protocol):
     ...


-class ModelManageable(Protocol, metaclass=ABCMeta):
+class HooksSupport(Protocol, metaclass=ABCMeta):
+    def prepare_hook_patches_current_keyframe(self, t, hook_group, model_options):
+        return
+
+
+class TrainingSupport(Protocol, metaclass=ABCMeta):
+    def set_model_compute_dtype(self, dtype: torch.dtype):
+        return
+
+    def add_weight_wrapper(self, name, function):
+        return
+
+
+class ModelManageableExtras(Protocol, metaclass=ABCMeta):
+    @property
+    def current_device(self) -> torch.device:
+        return torch.device("cpu")
+
+
+class ModelManageableRequired(Protocol):
+    """
+    The bare minimum that must be implemented to support model management when inheriting from ModelManageable.
+
+    Attributes:
+        load_device (torch.device): the device that this model's weights will be loaded onto for inference, typically the GPU
+        offload_device (torch.device): the device that this model's weights will be offloaded to when not being used for inference or when performing CPU offloading, typically the CPU
+        model (torch.nn.Module): in principle this can be any callable, but it should be a torch model to work with the rest of the machinery
+
+    :see: ModelManageable
+    :see: PatchSupport
+    """
+    load_device: torch.device
+    offload_device: torch.device
+    model: torch.nn.Module
+
+    def patch_model(self, device_to: torch.device | None = None, lowvram_model_memory: int = 0, load_weights: bool = True, force_patch_weights: bool = False) -> torch.nn.Module:
+        """
+        Called by ModelManageable.
+
+        An implementation of this method should
+        (1) load the model by moving it to the target device, and
+        (2) fuse the LoRA weights ("patches"), if applicable.
+
+        :param device_to:
+        :param lowvram_model_memory:
+        :param load_weights:
+        :param force_patch_weights:
+        :return:
+        """
+        ...
+
+    def unpatch_model(self, device_to: torch.device | None = None, unpatch_weights: Optional[bool] = False) -> torch.nn.Module:
+        """
+        Called by ModelManageable.
+
+        Unloads the model by
+        (1) unfusing the LoRA weights ("unpatching", if applicable), then
+        (2) moving the weights to the provided device.
+
+        :param device_to:
+        :param unpatch_weights:
+        :return:
+        """
+        ...
+
+
+class ModelManageable(ModelManageableRequired, ModelManageableExtras, Protocol, metaclass=ABCMeta):
     """
     Objects which implement this protocol can be managed by
@@ -35,13 +100,20 @@ class ModelManageable(Protocol, metaclass=ABCMeta):
     >>>
     >>> some_model = ModelWrapper()
     >>> load_models_gpu([some_model])
+
+    The minimum required interface is defined by ModelManageableRequired.
     """
     load_device: torch.device
     offload_device: torch.device
     model: torch.nn.Module

     @property
+    @override
     def current_device(self) -> torch.device:
+        """
+        Only needed in Hidden Switch; does not need to be overridden.
+        :return:
+        """
         return next(self.model.parameters()).device

     def is_clone(self, other: ModelManageableT) -> bool:
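A minimal sketch of a class satisfying ModelManageableRequired, assuming a plain model with no LoRA support so that patching is just a device move; the class and names here are illustrative, not part of the module:

    import torch

    class SimpleWrapper:
        def __init__(self, model: torch.nn.Module):
            self.model = model
            self.load_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            self.offload_device = torch.device("cpu")

        def patch_model(self, device_to: torch.device | None = None,
                        lowvram_model_memory: int = 0, load_weights: bool = True,
                        force_patch_weights: bool = False) -> torch.nn.Module:
            # (1) load: move weights to the target device; (2) nothing to fuse here
            return self.model.to(device_to or self.load_device)

        def unpatch_model(self, device_to: torch.device | None = None,
                          unpatch_weights: bool | None = False) -> torch.nn.Module:
            # nothing to unfuse; just move the weights back
            return self.model.to(device_to or self.offload_device)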
@@ -60,19 +132,10 @@ class ModelManageable(Protocol, metaclass=ABCMeta):
     def model_dtype(self) -> torch.dtype:
         return next(self.model.parameters()).dtype

-    def patch_model(self, device_to: torch.device | None = None, lowvram_model_memory: int = 0, load_weights: bool = True, force_patch_weights: bool = False) -> torch.nn.Module:
-        ...
-
-    def unpatch_model(self, device_to: torch.device | None = None, unpatch_weights: Optional[bool] = False) -> torch.nn.Module:
-        """
-        Unloads the model by moving it to the offload device
-        :param device_to:
-        :param unpatch_weights:
-        :return:
-        """
-        ...
-
     def lowvram_patch_counter(self) -> int:
+        """
+        Returns a counter related to low VRAM patching, used to decide if a reload is necessary.
+        """
         return 0

     def partially_load(self, device_to: torch.device, extra_memory: int = 0, force_patch_weights: bool = False):
@@ -120,26 +183,27 @@ class ModelManageable(Protocol, metaclass=ABCMeta):

     @property
     def parent(self) -> ModelManageableT | None:
+        """
+        Used for tracking a parent model from which this was cloned.
+        :return:
+        """
         return None

     def detach(self, unpatch_all: bool = True):
-        self.model_patches_to(self.offload_device)
-        if unpatch_all:
-            self.unpatch_model(self.offload_device, unpatch_weights=unpatch_all)
+        """
+        Unloads the model.
+        :param unpatch_all:
+        :return:
+        """
+        self.unpatch_model(self.offload_device, unpatch_weights=unpatch_all)
         return self.model

-    def set_model_compute_dtype(self, dtype: torch.dtype):
-        pass
-
-    def add_weight_wrapper(self, name, function):
-        pass
-
-    @property
-    def force_cast_weights(self) -> bool:
-        return False
-
-    def prepare_hook_patches_current_keyframe(self, t, hook_group, model_options):
-        pass
+    def model_patches_models(self) -> list[ModelManageableT]:
+        """
+        Used to implement Qwen DiffSynth Controlnets (?)
+        :return:
+        """
+        return []


 @dataclasses.dataclass
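A hedged sketch of the load/unload lifecycle these methods imply, reusing the hypothetical SimpleWrapper from the earlier sketch:

    import torch

    wrapper = SimpleWrapper(torch.nn.Linear(8, 8))  # illustrative, from the sketch above
    wrapper.patch_model()                           # load: weights moved to load_device
    wrapper.unpatch_model(wrapper.offload_device)   # detach() reduces to this plus returning model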
@@ -169,24 +233,38 @@ class MemoryMeasurements:
         self._device = value


+class HasModels(Protocol):
+    """A protocol for any object that has a .models() method returning a list."""
+
+    def models(self) -> list:
+        ...
+
+
+class HasTo(Protocol):
+    def to(self, device: torch.device):
+        ...
+
+
 class TransformerOptions(TypedDict, total=False):
     cond_or_uncond: NotRequired[list]
-    patches: NotRequired[dict]
+    patches: NotRequired[dict[str, list[HasModels]]]
     sigmas: NotRequired[torch.Tensor]
+    patches_replace: NotRequired[dict[str, dict[Any, HasModels]]]


 class ModelOptions(TypedDict, total=False):
     transformer_options: NotRequired[dict]
     # signature of BaseModel.apply_model
-    model_function_wrapper: NotRequired[Callable]
+    model_function_wrapper: NotRequired[Callable | UnetWrapperFunction | HasModels | HasTo]
     sampler_cfg_function: NotRequired[Callable]
     sampler_post_cfg_function: NotRequired[list[Callable]]
     disable_cfg1_optimization: NotRequired[bool]
     denoise_mask_function: NotRequired[Callable]
     patches: NotRequired[dict[str, list]]


 class LoadingListItem(NamedTuple):
     module_size: int
     name: str
     module: torch.nn.Module
     params: list[str]
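A minimal sketch of building these option dictionaries, assuming the TypedDicts above are importable; `my_wrapper` is a hypothetical stand-in, and the "input"/"timestep"/"c" argument keys follow the usual ComfyUI wrapper convention:

    import torch

    def my_wrapper(apply_model, args: dict):
        # hypothetical model_function_wrapper: ComfyUI passes the inner apply_model
        # plus a dict of its arguments, so the trivial wrapper just forwards them
        return apply_model(args["input"], args["timestep"], **args["c"])

    transformer_options: TransformerOptions = {
        "sigmas": torch.tensor([14.6]),
        "patches": {},  # dict[str, list[HasModels]]
    }
    model_options: ModelOptions = {
        "transformer_options": transformer_options,
        "model_function_wrapper": my_wrapper,
        "disable_cfg1_optimization": False,
    }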
model_patcher.py:

@@ -40,9 +40,9 @@ from .component_model.deprecation import _deprecate_method
 from .float import stochastic_rounding
 from .gguf import move_patch_to_device, is_torch_compatible, is_quantized, GGMLOps
 from .hooks import EnumHookMode, _HookRef, HookGroup, EnumHookType, WeightHook, create_transformer_options_from_hooks
-from .lora_types import PatchDict, PatchDictKey, PatchTuple, PatchWeightTuple, ModelPatchesDictValue
+from .lora_types import PatchDict, PatchDictKey, PatchTuple, PatchWeightTuple, ModelPatchesDictValue, PatchSupport
 from .model_base import BaseModel
-from .model_management_types import ModelManageable, MemoryMeasurements, ModelOptions, LatentFormatT, LoadingListItem
+from .model_management_types import ModelManageable, MemoryMeasurements, ModelOptions, LatentFormatT, LoadingListItem, TrainingSupport, HooksSupport
 from .patcher_extension import CallbacksMP, WrappersMP, PatcherInjection

 logger = logging.getLogger(__name__)
@@ -230,7 +230,7 @@ class GGUFQuantization:
     patch_on_device: bool = False


-class ModelPatcher(ModelManageable):
+class ModelPatcher(ModelManageable, TrainingSupport, HooksSupport, PatchSupport):
     def __init__(self, model: BaseModel | torch.nn.Module, load_device: torch.device, offload_device: torch.device, size=0, weight_inplace_update=False, ckpt_name: Optional[str] = None):
         self.size = size
         self.model: BaseModel | torch.nn.Module = model
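With the monolithic protocol split up, call sites can depend on just the capability they need rather than all of ModelManageable. A sketch of a hypothetical consumer that requires only TrainingSupport; the protocol leaves the wrapper callable's signature unspecified, so the no-op lambda is illustrative only:

    import torch

    def configure_for_training(model: TrainingSupport, dtype: torch.dtype = torch.float16) -> None:
        # only the TrainingSupport surface is required of `model`; any object with
        # these two methods type-checks, whether or not it is a full ModelPatcher
        model.set_model_compute_dtype(dtype)
        model.add_weight_wrapper("identity", lambda *args: None)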