diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index ffe45e8f0..8851c48b7 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -35,8 +35,8 @@ jobs:
        run: |
          export UV_BREAK_SYSTEM_PACKAGES=true
          export UV_SYSTEM_PYTHON=true
-         pip freeze | grep numpy > numpy_override.txt
-         uv pip install . --inexact --group dev --override numpy_override.txt
+         uv pip freeze | grep numpy > numpy_override.txt
+         uv pip install ".[dev]" --inexact --override numpy_override.txt
      - name: Run tests
        run: |
          nvidia-smi
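Note on the workflow change above: `uv pip freeze | grep numpy` records whichever numpy is already installed in the CI image, and passing that file to `--override` pins the resolver to exactly that line, presumably so installing the `dev` extras cannot upgrade or downgrade numpy out from under the preinstalled torch stack. A sketch of the generated override file; the version shown is hypothetical:

    # numpy_override.txt, as emitted by `uv pip freeze | grep numpy`
    numpy==1.26.4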
diff --git a/comfy/lora_types.py b/comfy/lora_types.py
index e19be55e0..631bcf87c 100644
--- a/comfy/lora_types.py
+++ b/comfy/lora_types.py
@@ -1,7 +1,5 @@
 from __future__ import annotations
-
-from typing import Literal, Any, NamedTuple, Protocol, Callable
-
+from typing import Protocol, List, Dict, Optional, NamedTuple, Callable, Literal, Any
 import torch
 
 PatchOffset = tuple[int, int, int]
@@ -31,3 +29,47 @@ class PatchTuple(NamedTuple):
 
 
 ModelPatchesDictValue = list[PatchTuple | PatchWeightTuple]
+
+
+class PatchSupport(Protocol):
+    """
+    Defines the interface for a model that supports LoRA patching.
+    """
+
+    def add_patches(
+        self,
+        patches: PatchDict,
+        strength_patch: float = 1.0,
+        strength_model: float = 1.0
+    ) -> List[PatchDictKey]:
+        """
+        Applies a set of patches (like LoRA weights) to the model.
+
+        Args:
+            patches (PatchDict): A dictionary containing the patch weights and metadata.
+            strength_patch (float): The strength multiplier for the patch itself.
+            strength_model (float): The strength multiplier for the original model weights.
+
+        Returns:
+            List[PatchDictKey]: A list of keys for the weights that were successfully patched.
+        """
+        ...
+
+    def get_key_patches(
+        self,
+        filter_prefix: Optional[str] = None
+    ) -> Dict[str, ModelPatchesDictValue]:
+        """
+        Retrieves all active patches, optionally filtered by a key prefix.
+
+        The returned dictionary maps a model weight's key to a list. The first
+        element in the list is a tuple containing the original weight, and subsequent
+        elements are the applied patch tuples.
+
+        Args:
+            filter_prefix (Optional[str]): A prefix to filter which weight patches are returned.
+
+        Returns:
+            Dict[str, ModelPatchesDictValue]: A dictionary of the model's patched weights.
+        """
+        ...
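The PatchSupport protocol above is what LoRA-loading code programs against. A minimal sketch of a caller, assuming `patcher` is any object satisfying the protocol (e.g. ModelPatcher, which adopts it explicitly below) and `loaded_patches` is a PatchDict produced elsewhere; the function name and the "diffusion_model." prefix are illustrative, not from the diff:

    # Sketch only: exercises the PatchSupport interface as documented above.
    from comfy.lora_types import PatchDict, PatchSupport


    def apply_lora(patcher: PatchSupport, loaded_patches: PatchDict, strength: float = 1.0) -> None:
        # add_patches returns only the keys it successfully patched, so the
        # difference against the input reveals LoRA keys that matched nothing.
        patched_keys = patcher.add_patches(loaded_patches, strength_patch=strength)
        missed = set(loaded_patches) - set(patched_keys)
        if missed:
            print(f"{len(missed)} patch keys did not match any model weight")

        # Per get_key_patches's docstring, each value is a list whose first
        # element holds the original weight and whose remaining elements are
        # the applied patch tuples.
        for key, value in patcher.get_key_patches(filter_prefix="diffusion_model.").items():
            original_weight, applied = value[0], value[1:]
            print(f"{key}: {len(applied)} patch(es) applied")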
+ """ return 0 def partially_load(self, device_to: torch.device, extra_memory: int = 0, force_patch_weights: bool = False): @@ -120,26 +183,27 @@ class ModelManageable(Protocol, metaclass=ABCMeta): @property def parent(self) -> ModelManageableT | None: + """ + Used for tracking a parent model from which this was cloned + :return: + """ return None def detach(self, unpatch_all: bool = True): - self.model_patches_to(self.offload_device) - if unpatch_all: - self.unpatch_model(self.offload_device, unpatch_weights=unpatch_all) + """ + Unloads the model + :param unpatch_all: + :return: + """ + self.unpatch_model(self.offload_device, unpatch_weights=unpatch_all) return self.model - def set_model_compute_dtype(self, dtype: torch.dtype): - pass - - def add_weight_wrapper(self, name, function): - pass - - @property - def force_cast_weights(self) -> bool: - return False - - def prepare_hook_patches_current_keyframe(self, t, hook_group, model_options): - pass + def model_patches_models(self) -> list[ModelManageableT]: + """ + Used to implement Qwen DiffSynth Controlnets (?) + :return: + """ + return [] @dataclasses.dataclass @@ -169,24 +233,38 @@ class MemoryMeasurements: self._device = value +class HasModels(Protocol): + """A protocol for any object that has a .models() method returning a list.""" + + def models(self) -> list: + ... + + +class HasTo(Protocol): + def to(self, device: torch.device): + ... + + class TransformerOptions(TypedDict, total=False): cond_or_uncond: NotRequired[list] - patches: NotRequired[dict] + patches: NotRequired[dict[str, list[HasModels]]] sigmas: NotRequired[torch.Tensor] + patches_replace: NotRequired[dict[str, dict[Any, HasModels]]] class ModelOptions(TypedDict, total=False): transformer_options: NotRequired[dict] # signature of BaseModel.apply_model - model_function_wrapper: NotRequired[Callable] + model_function_wrapper: NotRequired[Callable | UnetWrapperFunction | HasModels | HasTo] sampler_cfg_function: NotRequired[Callable] sampler_post_cfg_function: NotRequired[list[Callable]] disable_cfg1_optimization: NotRequired[bool] denoise_mask_function: NotRequired[Callable] patches: NotRequired[dict[str, list]] + class LoadingListItem(NamedTuple): module_size: int name: str module: torch.nn.Module - params: list[str] \ No newline at end of file + params: list[str] diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 4160e9a43..650e1b757 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -40,9 +40,9 @@ from .component_model.deprecation import _deprecate_method from .float import stochastic_rounding from .gguf import move_patch_to_device, is_torch_compatible, is_quantized, GGMLOps from .hooks import EnumHookMode, _HookRef, HookGroup, EnumHookType, WeightHook, create_transformer_options_from_hooks -from .lora_types import PatchDict, PatchDictKey, PatchTuple, PatchWeightTuple, ModelPatchesDictValue +from .lora_types import PatchDict, PatchDictKey, PatchTuple, PatchWeightTuple, ModelPatchesDictValue, PatchSupport from .model_base import BaseModel -from .model_management_types import ModelManageable, MemoryMeasurements, ModelOptions, LatentFormatT, LoadingListItem +from .model_management_types import ModelManageable, MemoryMeasurements, ModelOptions, LatentFormatT, LoadingListItem, TrainingSupport, HooksSupport from .patcher_extension import CallbacksMP, WrappersMP, PatcherInjection logger = logging.getLogger(__name__) @@ -230,7 +230,7 @@ class GGUFQuantization: patch_on_device: bool = False -class ModelPatcher(ModelManageable): +class 
diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index 4160e9a43..650e1b757 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -40,9 +40,9 @@ from .component_model.deprecation import _deprecate_method
 from .float import stochastic_rounding
 from .gguf import move_patch_to_device, is_torch_compatible, is_quantized, GGMLOps
 from .hooks import EnumHookMode, _HookRef, HookGroup, EnumHookType, WeightHook, create_transformer_options_from_hooks
-from .lora_types import PatchDict, PatchDictKey, PatchTuple, PatchWeightTuple, ModelPatchesDictValue
+from .lora_types import PatchDict, PatchDictKey, PatchTuple, PatchWeightTuple, ModelPatchesDictValue, PatchSupport
 from .model_base import BaseModel
-from .model_management_types import ModelManageable, MemoryMeasurements, ModelOptions, LatentFormatT, LoadingListItem
+from .model_management_types import ModelManageable, MemoryMeasurements, ModelOptions, LatentFormatT, LoadingListItem, TrainingSupport, HooksSupport
 from .patcher_extension import CallbacksMP, WrappersMP, PatcherInjection
 
 logger = logging.getLogger(__name__)
@@ -230,7 +230,7 @@ class GGUFQuantization:
     patch_on_device: bool = False
 
 
-class ModelPatcher(ModelManageable):
+class ModelPatcher(ModelManageable, TrainingSupport, HooksSupport, PatchSupport):
     def __init__(self, model: BaseModel | torch.nn.Module, load_device: torch.device, offload_device: torch.device, size=0, weight_inplace_update=False, ckpt_name: Optional[str] = None):
         self.size = size
         self.model: BaseModel | torch.nn.Module = model
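One consequence of ModelPatcher adopting the split protocols: call sites that only need training or hook functionality can be typed against the narrow interface instead of the whole patcher. A hedged sketch; the helper name and the fp32 choice are hypothetical, not from this diff:

    # Sketch only: any object providing set_model_compute_dtype and
    # add_weight_wrapper satisfies TrainingSupport structurally; ModelPatcher
    # also subclasses it explicitly, so this accepts a ModelPatcher directly.
    import torch
    from comfy.model_management_types import TrainingSupport


    def prepare_for_training(m: TrainingSupport) -> None:
        # upcast compute to fp32 for gradient stability (illustrative choice);
        # weight wrappers could be registered here via m.add_weight_wrapper(...)
        m.set_model_compute_dtype(torch.float32)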