diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py
index 216a864fe..f1dc4d34d 100644
--- a/comfy/clip_vision.py
+++ b/comfy/clip_vision.py
@@ -1,4 +1,5 @@
 from .component_model import files
+from .model_management import load_models_gpu
 from .utils import load_torch_file, transformers_convert, state_dict_prefix_replace
 import torch
 import json
@@ -57,7 +58,7 @@ class ClipVisionModel():
         return self.model.state_dict()
 
     def encode_image(self, image):
-        model_management.load_model_gpu(self.patcher)
+        load_models_gpu([self.patcher])
         pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size).float()
         out = self.model(pixel_values=pixel_values, intermediate_output=-2)
diff --git a/comfy/component_model/tensor_types.py b/comfy/component_model/tensor_types.py
index 33d4d07f6..def0e21f1 100644
--- a/comfy/component_model/tensor_types.py
+++ b/comfy/component_model/tensor_types.py
@@ -1,17 +1,7 @@
-from typing import Annotated
-
-from jaxtyping import Float, Shaped
+from jaxtyping import Float
 from torch import Tensor
 
-
-def channels_constraint(n: int):
-    def constraint(x: Tensor) -> bool:
-        return x.shape[-1] == n
-
-    return constraint
-
-
 ImageBatch = Float[Tensor, "batch height width channels"]
-RGBImageBatch = Annotated[ImageBatch, Shaped[channels_constraint(3)]] | Float[Tensor, "batch height width 3"]
-RGBAImageBatch = Annotated[ImageBatch, Shaped[channels_constraint(4)]] | Float[Tensor, "batch height width 4"]
+RGBImageBatch = Float[Tensor, "batch height width 3"]
+RGBAImageBatch = Float[Tensor, "batch height width 4"]
 RGBImage = Float[Tensor, "height width 3"]
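Note on the `tensor_types.py` change above: the `Annotated`/`Shaped` union was dead weight, since jaxtyping can pin a channel count directly in the shape string. A minimal sketch of how the simplified aliases behave under a runtime checker (assumes `beartype` is installed, which this diff does not add):

```python
import torch
from beartype import beartype
from jaxtyping import Float, jaxtyped
from torch import Tensor

RGBImageBatch = Float[Tensor, "batch height width 3"]


@jaxtyped(typechecker=beartype)
def mean_brightness(images: RGBImageBatch) -> Float[Tensor, "batch"]:
    # reduce over height, width and the three RGB channels
    return images.mean(dim=(1, 2, 3))


mean_brightness(torch.rand(2, 64, 64, 3))    # passes the shape check
# mean_brightness(torch.rand(2, 64, 64, 4))  # raises: channel dim is 4, not 3
```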
diff --git a/comfy/controlnet.py b/comfy/controlnet.py
index 252308b24..c7b6f222b 100644
--- a/comfy/controlnet.py
+++ b/comfy/controlnet.py
@@ -142,12 +142,13 @@ class ControlBase:
 
 
 class ControlNet(ControlBase):
-    def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None):
+    def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None, ckpt_name: str = None):
         super().__init__(device)
         self.control_model = control_model
         self.load_device = load_device
         if control_model is not None:
             self.control_model_wrapped = model_patcher.ModelPatcher(self.control_model, load_device=load_device, offload_device=model_management.unet_offload_device())
+            self.control_model_wrapped.ckpt_name = os.path.basename(ckpt_name) if ckpt_name is not None else None
         self.compression_ratio = compression_ratio
         self.global_average_pooling = global_average_pooling
         self.model_sampling_current = None
@@ -499,7 +500,7 @@ def load_controlnet(ckpt_path, model=None):
     if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling
         global_average_pooling = True
 
-    control = ControlNet(control_model, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
+    control = ControlNet(control_model, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype, ckpt_name=filename)
     return control
 
 class T2IAdapter(ControlBase):
diff --git a/comfy/language/transformers_model_management.py b/comfy/language/transformers_model_management.py
index b2da0b0b2..8a4982b5d 100644
--- a/comfy/language/transformers_model_management.py
+++ b/comfy/language/transformers_model_management.py
@@ -2,6 +2,7 @@ from __future__ import annotations
 
 import copy
 import logging
+import pathlib
 import warnings
 from typing import Optional, Any, Callable, Union, List
 
@@ -175,7 +176,7 @@ class TransformersManagedModel(ModelManageable):
         if hasattr(self.processor, "to"):
             self.processor.to(device=self.load_device)
 
-        assert "<image>" in prompt, "You must specify a <image> token inside the prompt for it to be substituted correctly by a HuggingFace processor"
+        assert "<image>" in prompt.lower(), "You must specify a <image> token inside the prompt for it to be substituted correctly by a HuggingFace processor"
         batch_feature: BatchFeature = self.processor([prompt], images=images, padding=True, return_tensors="pt")
         if hasattr(self.processor, "to"):
             self.processor.to(device=self.offload_device)
@@ -188,3 +189,10 @@ class TransformersManagedModel(ModelManageable):
             "inputs": batch_feature["input_ids"],
             **batch_feature
         }
+
+    def __str__(self):
+        if self.repo_id is not None:
+            repo_id_as_path = pathlib.PurePath(self.repo_id)
+            return f"<TransformersManagedModel {repo_id_as_path.name}>"
+        else:
+            return f"<TransformersManagedModel {self.model.__class__.__name__}>"
diff --git a/comfy/model_downloader.py b/comfy/model_downloader.py
index af85d102f..002d9d715 100644
--- a/comfy/model_downloader.py
+++ b/comfy/model_downloader.py
@@ -29,7 +29,7 @@
 _session = Session()
 _hf_fs = HfFileSystem()
 
 
-def get_filename_list_with_downloadable(folder_name: str, known_files: Optional[List[Any]] = None) -> List[str]:
+def get_filename_list_with_downloadable(folder_name: str, known_files: Optional[List[Downloadable] | KnownDownloadables] = None) -> List[str]:
     if known_files is None:
         known_files = _get_known_models_for_folder_name(folder_name)
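The `model_management.py` changes below import `_deprecate_method` from `comfy/component_model/deprecation.py`, which is not part of this diff. A plausible minimal implementation, for orientation only (the real decorator may differ):

```python
import functools
import warnings


def _deprecate_method(message: str, version: str):
    """Wrap a function so every call emits a DeprecationWarning."""
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            warnings.warn(
                f"{fn.__name__} is deprecated since {version}. {message}",
                DeprecationWarning,
                stacklevel=2,
            )
            return fn(*args, **kwargs)
        return wrapper
    return decorator
```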
diff --git a/comfy/model_management.py b/comfy/model_management.py
index b8250af85..39f3f2db6 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -6,7 +6,7 @@
 import sys
 import warnings
 from enum import Enum
 from threading import RLock
-from typing import Literal, List
+from typing import Literal, List, Sequence
 
 import psutil
 import torch
@@ -15,6 +15,7 @@ from opentelemetry.trace import get_current_span
 from . import interruption
 from .cli_args import args
 from .cmd.main_pre import tracer
+from .component_model.deprecation import _deprecate_method
 from .model_management_types import ModelManageable
 
 model_management_lock = RLock()
@@ -439,7 +440,7 @@ def free_memory(memory_required, device, keep_loaded=[]) -> List[LoadedModel]:
 
 
 @tracer.start_as_current_span("Load Models GPU")
-def load_models_gpu(models, memory_required=0, force_patch_weights=False):
+def load_models_gpu(models: Sequence[ModelManageable], memory_required: int = 0, force_patch_weights=False) -> None:
     global vram_state
     span = get_current_span()
     if memory_required != 0:
@@ -472,59 +473,61 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False):
             models_to_load.append(loaded_model)
 
     models_freed: List[LoadedModel] = []
-    if len(models_to_load) == 0:
-        devs = set(map(lambda a: a.device, models_already_loaded))
-        for d in devs:
-            if d != torch.device("cpu"):
-                models_freed += free_memory(extra_mem, d, models_already_loaded)
+    try:
+        if len(models_to_load) == 0:
+            devs = set(map(lambda a: a.device, models_already_loaded))
+            for d in devs:
+                if d != torch.device("cpu"):
+                    models_freed += free_memory(extra_mem, d, models_already_loaded)
+            return
+
+        total_memory_required = {}
+        for loaded_model in models_to_load:
+            if unload_model_clones(loaded_model.model, unload_weights_only=True, force_unload=False):  # unload clones where the weights are different
+                total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
+
+        for device in total_memory_required:
+            if device != torch.device("cpu"):
+                # todo: where does 1.3 come from?
+                models_freed += free_memory(total_memory_required[device] * 1.3 + extra_mem, device, models_already_loaded)
+
+        for loaded_model in models_to_load:
+            weights_unloaded = unload_model_clones(loaded_model.model, unload_weights_only=False, force_unload=False)  # unload the rest of the clones where the weights can stay loaded
+            if weights_unloaded is not None:
+                loaded_model.weights_loaded = not weights_unloaded
+
+        for loaded_model in models_to_load:
+            model = loaded_model.model
+            torch_dev = model.load_device
+            if is_device_cpu(torch_dev):
+                vram_set_state = VRAMState.DISABLED
+            else:
+                vram_set_state = vram_state
+            lowvram_model_memory = 0
+            if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
+                model_size = loaded_model.model_memory_required(torch_dev)
+                current_free_mem = get_free_memory(torch_dev)
+                lowvram_model_memory = int(max(64 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.3))
+                if model_size <= (current_free_mem - inference_memory):  # only switch to lowvram if really necessary
+                    lowvram_model_memory = 0
+
+            if vram_set_state == VRAMState.NO_VRAM:
+                lowvram_model_memory = 64 * 1024 * 1024
+
+            loaded_model.model_load(lowvram_model_memory, force_patch_weights=force_patch_weights)
+            current_loaded_models.insert(0, loaded_model)
         return
-
-    total_memory_required = {}
-    for loaded_model in models_to_load:
-        if unload_model_clones(loaded_model.model, unload_weights_only=True, force_unload=False):  # unload clones where the weights are different
-            total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
-
-    for device in total_memory_required:
-        if device != torch.device("cpu"):
-            models_freed += free_memory(total_memory_required[device] * 1.3 + extra_mem, device, models_already_loaded)
-
-    span.set_attribute("models_to_load", list(map(str, models_to_load)))
-    span.set_attribute("models_freed", list(map(str, models_freed)))
-
-    logging.info(f"Models loaded: {','.join(map(str, models_to_load))}, models freed: {','.join(map(str, models_freed))}")
-
-    for loaded_model in models_to_load:
-        weights_unloaded = unload_model_clones(loaded_model.model, unload_weights_only=False, force_unload=False)  # unload the rest of the clones where the weights can stay loaded
-        if weights_unloaded is not None:
-            loaded_model.weights_loaded = not weights_unloaded
-
-    for loaded_model in models_to_load:
-        model = loaded_model.model
-        torch_dev = model.load_device
-        if is_device_cpu(torch_dev):
-            vram_set_state = VRAMState.DISABLED
-        else:
-            vram_set_state = vram_state
-        lowvram_model_memory = 0
-        if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
-            model_size = loaded_model.model_memory_required(torch_dev)
-            current_free_mem = get_free_memory(torch_dev)
-            lowvram_model_memory = int(max(64 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.3))
-            if model_size <= (current_free_mem - inference_memory):  # only switch to lowvram if really necessary
-                lowvram_model_memory = 0
-
-        if vram_set_state == VRAMState.NO_VRAM:
-            lowvram_model_memory = 64 * 1024 * 1024
-
-        loaded_model.model_load(lowvram_model_memory, force_patch_weights=force_patch_weights)
-        current_loaded_models.insert(0, loaded_model)
-    return
+    finally:
+        span.set_attribute("models", list(map(str, models)))
+        span.set_attribute("models_to_load", list(map(str, models_to_load)))
+        span.set_attribute("models_freed", list(map(str, models_freed)))
+        logging.info(f"Requested to load {','.join(map(str, models))}, models loaded: {','.join(map(str, models_to_load))}, models freed: {','.join(map(str, models_freed))}")
 
 
+@_deprecate_method(message="Use load_models_gpu instead", version="0.0.2")
 def load_model_gpu(model):
-    with model_management_lock:
-        return load_models_gpu([model])
+    return load_models_gpu([model])
 
 
 def loaded_models(only_currently_used=False):
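Migrating call sites is mechanical: wrap the single argument in a list. Passing several models at once also lets `load_models_gpu` budget and free memory for the whole group in one pass (`unet_patcher` and `clip_patcher` below stand in for any `ModelManageable` instances):

```python
from comfy.model_management import load_models_gpu

# before (now deprecated):
#   load_model_gpu(unet_patcher)
load_models_gpu([unet_patcher])

# models used together in one step can be loaded together
load_models_gpu([unet_patcher, clip_patcher])
```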
diff --git a/comfy/model_management_types.py b/comfy/model_management_types.py
index 119eeafcd..80a357cb0 100644
--- a/comfy/model_management_types.py
+++ b/comfy/model_management_types.py
@@ -9,11 +9,12 @@ class ModelManageable(Protocol):
     """
     Objects which implement this protocol can be managed by
 
-    >>> import comfy.model_management
-    >>> class SomeObj("ModelManageable"):
+    >>> from comfy.model_management import load_models_gpu
+    >>> class ModelWrapper(ModelManageable):
     >>>     ...
     >>>
-    >>> comfy.model_management.load_model_gpu(SomeObj())
+    >>> some_model = ModelWrapper()
+    >>> load_models_gpu([some_model])
     """
     load_device: torch.device
     offload_device: torch.device
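Fleshing out the doctest above, a bare-bones implementor of the protocol might look like the sketch below; only a subset of the members is shown (the full set, including the patching hooks, appears in `nodes_upscale_model.py` further down):

```python
import torch

from comfy import model_management
from comfy.model_management_types import ModelManageable


class ModelWrapper(ModelManageable):
    """Wraps a plain nn.Module so the memory manager can move it between devices."""

    def __init__(self, model: torch.nn.Module):
        self.model = model
        self.load_device = model_management.get_torch_device()
        self.offload_device = torch.device("cpu")

    def model_size(self) -> int:
        # parameter bytes only; real implementations also budget activations
        return sum(p.numel() * p.element_size() for p in self.model.parameters())

    def model_dtype(self) -> torch.dtype:
        return next(self.model.parameters()).dtype

    def is_clone(self, other) -> bool:
        return isinstance(other, ModelWrapper) and other.model is self.model
```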
diff --git a/comfy/nodes/package_typing.py b/comfy/nodes/package_typing.py
index cc6cfd003..005b76132 100644
--- a/comfy/nodes/package_typing.py
+++ b/comfy/nodes/package_typing.py
@@ -128,7 +128,9 @@ class CustomNode(Protocol):
     CATEGORY: ClassVar[str]
     OUTPUT_NODE: Optional[ClassVar[bool]]
 
-    IS_CHANGED: Optional[ClassVar[IsChangedMethod]]
+    @classmethod
+    def IS_CHANGED(cls, *args, **kwargs) -> str:
+        ...
 
     @classmethod
     def __call__(cls, *args, **kwargs) -> 'CustomNode':
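With `IS_CHANGED` declared as a real classmethod on the protocol rather than an optional `ClassVar`, implementors override it like any other method; returning a different string invalidates the cached output. A hypothetical node that re-executes whenever a file on disk changes:

```python
import hashlib

from comfy.nodes.package_typing import CustomNode, InputTypes


class LoadTextFile(CustomNode):  # hypothetical example node
    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        return {"required": {"path": ("STRING", {"default": ""})}}

    RETURN_TYPES = ("STRING",)
    FUNCTION = "load"
    CATEGORY = "examples"

    def load(self, path: str):
        with open(path, encoding="utf-8") as f:
            return (f.read(),)

    @classmethod
    def IS_CHANGED(cls, path: str, **kwargs) -> str:
        # hash the file contents; a new digest forces re-execution
        with open(path, "rb") as f:
            return hashlib.sha256(f.read()).hexdigest()
```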
diff --git a/comfy/sd.py b/comfy/sd.py
index d39cd3249..da342ce9f 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -20,6 +20,7 @@ from . import model_sampling
 from . import sd1_clip
 from . import sdxl_clip
 from . import utils
+from .model_management import load_models_gpu
 from .text_encoders import sd2_clip
 from .text_encoders import sd3_clip
 
@@ -153,7 +154,7 @@ class CLIP:
         return sd_clip
 
     def load_model(self):
-        model_management.load_model_gpu(self.patcher)
+        load_models_gpu([self.patcher])
         return self.patcher
 
     def get_key_patches(self):
@@ -337,7 +338,7 @@ class VAE:
         return pixel_samples
 
     def decode_tiled(self, samples, tile_x=64, tile_y=64, overlap=16):
-        model_management.load_model_gpu(self.patcher)
+        load_models_gpu([self.patcher])
         output = self.decode_tiled_(samples, tile_x, tile_y, overlap)
         return output.movedim(1, -1)
 
@@ -366,7 +367,7 @@ def encode_tiled(self, pixel_samples, tile_x=512, tile_y=512, overlap=64):
         pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
-        model_management.load_model_gpu(self.patcher)
+        load_models_gpu([self.patcher])
         pixel_samples = pixel_samples.movedim(-1, 1)
         samples = self.encode_tiled_(pixel_samples, tile_x=tile_x, tile_y=tile_y, overlap=overlap)
         return samples
@@ -574,7 +575,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
 
     if output_model:
         _model_patcher = model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device(), current_device=inital_load_device, ckpt_name=os.path.basename(ckpt_path))
         if inital_load_device != torch.device("cpu"):
-            model_management.load_model_gpu(_model_patcher)
+            load_models_gpu([_model_patcher])
 
     return (_model_patcher, clip, vae, clipvision)
diff --git a/comfy_extras/nodes/nodes_language.py b/comfy_extras/nodes/nodes_language.py
index b67236b1d..09a5e33eb 100644
--- a/comfy_extras/nodes/nodes_language.py
+++ b/comfy_extras/nodes/nodes_language.py
@@ -18,7 +18,7 @@ from comfy.language.chat_templates import KNOWN_CHAT_TEMPLATES
 from comfy.language.language_types import ProcessorResult
 from comfy.language.transformers_model_management import TransformersManagedModel
 from comfy.model_downloader import get_huggingface_repo_list, get_or_download_huggingface_repo
-from comfy.model_management import get_torch_device_name, load_model_gpu, unet_dtype, unet_offload_device
+from comfy.model_management import get_torch_device_name, unet_dtype, unet_offload_device, load_models_gpu
 from comfy.nodes.package_typing import CustomNode, InputTypes, ValidatedNodeResult
 from comfy.utils import comfy_tqdm, seed_for_block, comfy_progress, ProgressBar
 
@@ -340,7 +340,7 @@ class TransformersGenerate(CustomNode):
         tokens = copy.copy(tokens)
         sampler = sampler or {}
         generate_kwargs = copy.copy(sampler)
-        load_model_gpu(model)
+        load_models_gpu([model])
         transformers_model: PreTrainedModel = model.model
         tokenizer: PreTrainedTokenizerBase | AutoTokenizer = model.tokenizer
         # remove unused inputs
diff --git a/comfy_extras/nodes/nodes_upscale_model.py b/comfy_extras/nodes/nodes_upscale_model.py
index b380e71ff..1160585bd 100644
--- a/comfy_extras/nodes/nodes_upscale_model.py
+++ b/comfy_extras/nodes/nodes_upscale_model.py
@@ -1,24 +1,135 @@
 import logging
-import torch
+from typing import Optional, Any
 
+import torch
 from spandrel import ModelLoader, ImageModelDescriptor
+
 from comfy import model_management
 from comfy import utils
+from comfy.component_model.tensor_types import RGBImageBatch
 from comfy.model_downloader import get_filename_list_with_downloadable, KNOWN_UPSCALERS, get_or_download
-
+from comfy.model_management import load_models_gpu
+from comfy.model_management_types import ModelManageable
 try:
     from spandrel_extra_arches import EXTRA_REGISTRY  # pylint: disable=import-error
     from spandrel import MAIN_REGISTRY
+
     MAIN_REGISTRY.add(*EXTRA_REGISTRY)
     logging.info("Successfully imported spandrel_extra_arches: support for non commercial upscale models.")
 except:
     pass
 
+
+class UpscaleModelManageable(ModelManageable):
+    def __init__(self, model_descriptor: ImageModelDescriptor, ckpt_name: str):
+        self.ckpt_name = ckpt_name
+        self.model_descriptor = model_descriptor
+        self.model = model_descriptor.model
+        self.load_device = model_management.get_torch_device()
+        self.offload_device = model_management.unet_offload_device()
+        self._current_device = self.offload_device
+        self._lowvram_patch_counter = 0
+
+        # Private properties for image sizes and channels
+        self._input_size = (1, 512, 512)  # Default input size (batch, height, width)
+        self._input_channels = model_descriptor.input_channels
+        self._output_channels = model_descriptor.output_channels
+        self.tile = 512
+
+    @property
+    def current_device(self) -> torch.device:
+        return self._current_device
+
+    @property
+    def input_size(self) -> tuple[int, int, int]:
+        return self._input_size
+
+    @input_size.setter
+    def input_size(self, size: tuple[int, int, int]):
+        self._input_size = size
+
+    @property
+    def output_size(self) -> tuple[int, int, int]:
+        return (self._input_size[0],
+                self._input_size[1] * self.model_descriptor.scale,
+                self._input_size[2] * self.model_descriptor.scale)
+
+    def set_input_size_from_images(self, images: RGBImageBatch):
+        if images.ndim != 4:
+            raise ValueError("Input must be a 4D tensor (batch, height, width, channels)")
+        if images.shape[-1] != 3:
+            raise ValueError("Input must have 3 channels (RGB)")
+        self._input_size = (images.shape[0], images.shape[1], images.shape[2])
+
+    def is_clone(self, other: Any) -> bool:
+        return isinstance(other, UpscaleModelManageable) and self.model is other.model
+
+    def clone_has_same_weights(self, clone: torch.nn.Module) -> bool:
+        return self.is_clone(clone)
+
+    def model_size(self) -> int:
+        # Calculate the size of the model parameters
+        model_params_size = sum(p.numel() * p.element_size() for p in self.model.parameters())
+
+        # Get the byte size of the model's dtype
+        dtype_size = torch.finfo(self.model_dtype()).bits // 8
+
+        # Calculate the memory required for input and output images
+        input_size = self._input_size[0] * min(self.tile, self._input_size[1]) * min(self.tile, self._input_size[2]) * self._input_channels * dtype_size
+        output_size = self.output_size[0] * self.output_size[1] * self.output_size[2] * self._output_channels * dtype_size
+
+        # Add some extra memory for processing
+        extra_memory = (input_size + output_size) * 2  # This is an estimate, adjust as needed
+
+        return model_params_size + input_size + output_size + extra_memory
+
+    def model_patches_to(self, arg: torch.device | torch.dtype):
+        if isinstance(arg, torch.device):
+            self.model.to(device=arg)
+        else:
+            self.model.to(dtype=arg)
+
+    def model_dtype(self) -> torch.dtype:
+        return next(self.model.parameters()).dtype
+
+    def patch_model_lowvram(self, device_to: torch.device, lowvram_model_memory: int, force_patch_weights: Optional[bool] = False) -> torch.nn.Module:
+        self.model.to(device=device_to)
+        self._current_device = device_to
+        self._lowvram_patch_counter += 1
+        return self.model
+
+    def patch_model(self, device_to: torch.device, patch_weights: bool) -> torch.nn.Module:
+        if patch_weights:
+            self.model.to(device=device_to)
+            self._current_device = device_to
+        return self.model
+
+    def unpatch_model(self, offload_device: torch.device, unpatch_weights: Optional[bool] = False) -> torch.nn.Module:
+        if unpatch_weights:
+            self.model.to(device=offload_device)
+            self._current_device = offload_device
+        return self.model
+
+    @property
+    def lowvram_patch_counter(self) -> int:
+        return self._lowvram_patch_counter
+
+    @lowvram_patch_counter.setter
+    def lowvram_patch_counter(self, value: int):
+        self._lowvram_patch_counter = value
+
+    def __str__(self):
+        if self.ckpt_name is not None:
+            return f"<UpscaleModelManageable {self.ckpt_name}>"
+        else:
+            return f"<UpscaleModelManageable {self.model.__class__.__name__}>"
+
+
 class UpscaleModelLoader:
     @classmethod
     def INPUT_TYPES(s):
-        return {"required": {"model_name": (get_filename_list_with_downloadable("upscale_models", KNOWN_UPSCALERS),),
+        return {"required": {"model_name": (get_filename_list_with_downloadable("upscale_models"),),
                              }}
 
     RETURN_TYPES = ("UPSCALE_MODEL",)
@@ -36,7 +147,7 @@ class UpscaleModelLoader:
         if not isinstance(out, ImageModelDescriptor):
             raise Exception("Upscale model must be a single-image model.")
 
-        return (out,)
+        return (UpscaleModelManageable(out, model_name),)
 
 
 class ImageUpscaleWithModel:
@@ -51,33 +162,30 @@ class ImageUpscaleWithModel:
 
     CATEGORY = "image/upscaling"
 
-    def upscale(self, upscale_model, image):
-        device = model_management.get_torch_device()
+    def upscale(self, upscale_model: UpscaleModelManageable, image: RGBImageBatch):
+        upscale_model.set_input_size_from_images(image)
+        load_models_gpu([upscale_model])
 
-        memory_required = model_management.module_size(upscale_model.model)
-        memory_required += (512 * 512 * 3) * image.element_size() * max(upscale_model.scale, 1.0) * 384.0  # The 384.0 is an estimate of how much some of these models take, TODO: make it more accurate
-        memory_required += image.nelement() * image.element_size()
-        model_management.free_memory(memory_required, device)
+        in_img = image.movedim(-1, -3).to(upscale_model.current_device, dtype=upscale_model.model_dtype())
 
-        upscale_model.to(device)
-        in_img = image.movedim(-1, -3).to(device)
-
-        tile = 512
+        tile = upscale_model.tile
         overlap = 32
 
         oom = True
+        s = None
         while oom:
            try:
                steps = in_img.shape[0] * utils.get_tiled_scale_steps(in_img.shape[3], in_img.shape[2], tile_x=tile, tile_y=tile, overlap=overlap)
                pbar = utils.ProgressBar(steps)
-                s = utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar)
+                s = utils.tiled_scale(in_img, lambda a: upscale_model.model_descriptor(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.model_descriptor.scale, pbar=pbar)
                oom = False
            except model_management.OOM_EXCEPTION as e:
                tile //= 2
-                if tile < 128:
+                overlap //= 2
+                if tile < 64 or overlap < 4:
                    raise e
 
-        upscale_model.to("cpu")
+        # offloading back to the CPU is handled by the model management system
        s = torch.clamp(s.movedim(-3, -1), min=0, max=1.0)
        return (s,)
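`model_size` is what lets `load_models_gpu` budget VRAM for the upscaler. A standalone check of the arithmetic (illustrative numbers, not part of the diff): a model with 2 MiB of fp16 parameters upscaling one 512×512 RGB tile at 4× claims roughly 2 MiB (params) + 1.5 MiB (input tile) + 24 MiB (output) + 51 MiB (2× I/O headroom) ≈ 78.5 MiB:

```python
def estimate_bytes(param_bytes: int, batch: int, h: int, w: int, scale: int,
                   channels: int = 3, dtype_size: int = 2, tile: int = 512) -> int:
    # mirrors UpscaleModelManageable.model_size: the input is capped at one tile,
    # the output uses the full upscaled size, and I/O is doubled as headroom
    input_size = batch * min(tile, h) * min(tile, w) * channels * dtype_size
    output_size = batch * (h * scale) * (w * scale) * channels * dtype_size
    extra_memory = (input_size + output_size) * 2
    return param_bytes + input_size + output_size + extra_memory


print(estimate_bytes(2 * 1024 * 1024, 1, 512, 512, scale=4) / 1024 ** 2)  # ≈ 78.5
```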
diff --git a/requirements.txt b/requirements.txt
index 7966dacb3..1e83cd36e 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -61,3 +61,4 @@ PySoundFile
 networkx>=2.6.3
 joblib
 jaxtyping
+spandrel_extra_arches