Improve performance and memory management of upscale models, improve messaging on models loaded and unloaded from the GPU

doctorpangloss 2024-07-30 17:05:53 -07:00
parent c6ce11b421
commit ce5fe01768
12 changed files with 212 additions and 96 deletions

View File

@ -1,4 +1,5 @@
from .component_model import files
from .model_management import load_models_gpu
from .utils import load_torch_file, transformers_convert, state_dict_prefix_replace
import torch
import json
@ -57,7 +58,7 @@ class ClipVisionModel():
return self.model.state_dict()
def encode_image(self, image):
model_management.load_model_gpu(self.patcher)
load_models_gpu([self.patcher])
pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size).float()
out = self.model(pixel_values=pixel_values, intermediate_output=-2)
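Note: call sites change mechanically from load_model_gpu(x) to load_models_gpu([x]), and the list form also allows several patchers to be loaded in one pass so the memory accounting runs once for the group. A minimal sketch of that pattern (the helper below is hypothetical, for illustration only):

from comfy.model_management import load_models_gpu

def encode_image_and_text(clip_vision, clip, image, tokens):
    # Hypothetical helper: load both patchers in a single call so free_memory
    # is computed against their combined requirement instead of per model.
    load_models_gpu([clip_vision.patcher, clip.patcher])
    image_embeds = clip_vision.encode_image(image)
    cond = clip.encode_from_tokens(tokens)
    return image_embeds, cond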

View File

@ -1,17 +1,7 @@
from typing import Annotated
from jaxtyping import Float, Shaped
from jaxtyping import Float
from torch import Tensor
def channels_constraint(n: int):
def constraint(x: Tensor) -> bool:
return x.shape[-1] == n
return constraint
ImageBatch = Float[Tensor, "batch height width channels"]
RGBImageBatch = Annotated[ImageBatch, Shaped[channels_constraint(3)]] | Float[Tensor, "batch height width 3"]
RGBAImageBatch = Annotated[ImageBatch, Shaped[channels_constraint(4)]] | Float[Tensor, "batch height width 4"]
RGBImageBatch = Float[Tensor, "batch height width 3"]
RGBAImageBatch = Float[Tensor, "batch height width 4"]
RGBImage = Float[Tensor, "height width 3"]
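The channel-constraint callables are gone; the aliases now carry the channel count directly in the jaxtyping shape string, so a runtime checker can enforce it without extra machinery. A sketch of how such an alias is typically checked (pairing jaxtyping with beartype is an assumption here, not something this diff wires up):

import torch
from beartype import beartype
from jaxtyping import Float, jaxtyped
from torch import Tensor

RGBImageBatch = Float[Tensor, "batch height width 3"]

@jaxtyped(typechecker=beartype)
def mean_brightness(images: RGBImageBatch) -> Float[Tensor, "batch"]:
    # The literal 3 in the shape string is validated at call time by the checker.
    return images.mean(dim=(1, 2, 3))

mean_brightness(torch.rand(2, 64, 64, 3))    # ok
# mean_brightness(torch.rand(2, 64, 64, 4))  # rejected: channel dimension mismatch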

View File

@ -142,12 +142,13 @@ class ControlBase:
class ControlNet(ControlBase):
def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None):
def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None, ckpt_name: str = None):
super().__init__(device)
self.control_model = control_model
self.load_device = load_device
if control_model is not None:
self.control_model_wrapped = model_patcher.ModelPatcher(self.control_model, load_device=load_device, offload_device=model_management.unet_offload_device())
self.control_model_wrapped.ckpt_name = os.path.basename(ckpt_name)
self.compression_ratio = compression_ratio
self.global_average_pooling = global_average_pooling
self.model_sampling_current = None
@ -499,7 +500,7 @@ def load_controlnet(ckpt_path, model=None):
if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling
global_average_pooling = True
control = ControlNet(control_model, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
control = ControlNet(control_model, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype, ckpt_name=filename)
return control
class T2IAdapter(ControlBase):

View File

@ -2,6 +2,7 @@ from __future__ import annotations
import copy
import logging
import pathlib
import warnings
from typing import Optional, Any, Callable, Union, List
@ -175,7 +176,7 @@ class TransformersManagedModel(ModelManageable):
if hasattr(self.processor, "to"):
self.processor.to(device=self.load_device)
assert "<image>" in prompt, "You must specify a &lt;image&gt; token inside the prompt for it to be substituted correctly by a HuggingFace processor"
assert "<image>" in prompt.lower(), "You must specify a &lt;image&gt; token inside the prompt for it to be substituted correctly by a HuggingFace processor"
batch_feature: BatchFeature = self.processor([prompt], images=images, padding=True, return_tensors="pt")
if hasattr(self.processor, "to"):
self.processor.to(device=self.offload_device)
@ -188,3 +189,10 @@ class TransformersManagedModel(ModelManageable):
"inputs": batch_feature["input_ids"],
**batch_feature
}
def __str__(self):
if self.repo_id is not None:
repo_id_as_path = pathlib.PurePath(self.repo_id)
return f"<TransformersManagedModel for {'/'.join(repo_id_as_path.parts[-2:])} ({self.model.__class__.__name__})>"
else:
return f"<TransformersManagedModel for {self.model.__class__.__name__}>"

View File

@ -29,7 +29,7 @@ _session = Session()
_hf_fs = HfFileSystem()
def get_filename_list_with_downloadable(folder_name: str, known_files: Optional[List[Any]] = None) -> List[str]:
def get_filename_list_with_downloadable(folder_name: str, known_files: Optional[List[Downloadable] | KnownDownloadables] = None) -> List[str]:
if known_files is None:
known_files = _get_known_models_for_folder_name(folder_name)

View File

@ -6,7 +6,7 @@ import sys
import warnings
from enum import Enum
from threading import RLock
from typing import Literal, List
from typing import Literal, List, Sequence
import psutil
import torch
@ -15,6 +15,7 @@ from opentelemetry.trace import get_current_span
from . import interruption
from .cli_args import args
from .cmd.main_pre import tracer
from .component_model.deprecation import _deprecate_method
from .model_management_types import ModelManageable
model_management_lock = RLock()
@ -439,7 +440,7 @@ def free_memory(memory_required, device, keep_loaded=[]) -> List[LoadedModel]:
@tracer.start_as_current_span("Load Models GPU")
def load_models_gpu(models, memory_required=0, force_patch_weights=False):
def load_models_gpu(models: Sequence[ModelManageable], memory_required: int = 0, force_patch_weights=False) -> None:
global vram_state
span = get_current_span()
if memory_required != 0:
@ -472,59 +473,61 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False):
models_to_load.append(loaded_model)
models_freed: List[LoadedModel] = []
if len(models_to_load) == 0:
devs = set(map(lambda a: a.device, models_already_loaded))
for d in devs:
if d != torch.device("cpu"):
models_freed += free_memory(extra_mem, d, models_already_loaded)
try:
if len(models_to_load) == 0:
devs = set(map(lambda a: a.device, models_already_loaded))
for d in devs:
if d != torch.device("cpu"):
models_freed += free_memory(extra_mem, d, models_already_loaded)
return
total_memory_required = {}
for loaded_model in models_to_load:
if unload_model_clones(loaded_model.model, unload_weights_only=True, force_unload=False): # unload clones where the weights are different
total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
for device in total_memory_required:
if device != torch.device("cpu"):
# todo: where does 1.3 come from?
models_freed += free_memory(total_memory_required[device] * 1.3 + extra_mem, device, models_already_loaded)
for loaded_model in models_to_load:
weights_unloaded = unload_model_clones(loaded_model.model, unload_weights_only=False, force_unload=False) # unload the rest of the clones where the weights can stay loaded
if weights_unloaded is not None:
loaded_model.weights_loaded = not weights_unloaded
for loaded_model in models_to_load:
model = loaded_model.model
torch_dev = model.load_device
if is_device_cpu(torch_dev):
vram_set_state = VRAMState.DISABLED
else:
vram_set_state = vram_state
lowvram_model_memory = 0
if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
model_size = loaded_model.model_memory_required(torch_dev)
current_free_mem = get_free_memory(torch_dev)
lowvram_model_memory = int(max(64 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.3))
if model_size <= (current_free_mem - inference_memory): # only switch to lowvram if really necessary
lowvram_model_memory = 0
if vram_set_state == VRAMState.NO_VRAM:
lowvram_model_memory = 64 * 1024 * 1024
loaded_model.model_load(lowvram_model_memory, force_patch_weights=force_patch_weights)
current_loaded_models.insert(0, loaded_model)
return
total_memory_required = {}
for loaded_model in models_to_load:
if unload_model_clones(loaded_model.model, unload_weights_only=True, force_unload=False): # unload clones where the weights are different
total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
for device in total_memory_required:
if device != torch.device("cpu"):
models_freed += free_memory(total_memory_required[device] * 1.3 + extra_mem, device, models_already_loaded)
span.set_attribute("models_to_load", list(map(str, models_to_load)))
span.set_attribute("models_freed", list(map(str, models_freed)))
logging.info(f"Models loaded: {','.join(map(str, models_to_load))}, models freed: {','.join(map(str, models_freed))}")
for loaded_model in models_to_load:
weights_unloaded = unload_model_clones(loaded_model.model, unload_weights_only=False, force_unload=False) # unload the rest of the clones where the weights can stay loaded
if weights_unloaded is not None:
loaded_model.weights_loaded = not weights_unloaded
for loaded_model in models_to_load:
model = loaded_model.model
torch_dev = model.load_device
if is_device_cpu(torch_dev):
vram_set_state = VRAMState.DISABLED
else:
vram_set_state = vram_state
lowvram_model_memory = 0
if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
model_size = loaded_model.model_memory_required(torch_dev)
current_free_mem = get_free_memory(torch_dev)
lowvram_model_memory = int(max(64 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.3))
if model_size <= (current_free_mem - inference_memory): # only switch to lowvram if really necessary
lowvram_model_memory = 0
if vram_set_state == VRAMState.NO_VRAM:
lowvram_model_memory = 64 * 1024 * 1024
loaded_model.model_load(lowvram_model_memory, force_patch_weights=force_patch_weights)
current_loaded_models.insert(0, loaded_model)
return
finally:
span.set_attribute("models", list(map(str, models)))
span.set_attribute("models_to_load", list(map(str, models_to_load)))
span.set_attribute("models_freed", list(map(str, models_freed)))
logging.info(f"Requested to load {','.join(map(str, models))}, models loaded: {','.join(map(str, models_to_load))}, models freed: {','.join(map(str, models_freed))}")
@_deprecate_method(message="Use load_models_gpu instead", version="0.0.2")
def load_model_gpu(model):
with model_management_lock:
return load_models_gpu([model])
return load_models_gpu([model])
def loaded_models(only_currently_used=False):
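load_model_gpu survives as a thin shim so existing callers keep working while they migrate; the decorator imported above presumably turns calls into a deprecation warning plus a delegation, though this diff does not show its implementation. A hedged sketch of handling that during the transition:

import warnings
from comfy.model_management import load_model_gpu, load_models_gpu

def load_during_migration(patcher):
    # Preferred form:
    load_models_gpu([patcher])
    # Legacy form still works; silence the warning it is assumed to emit.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", DeprecationWarning)
        load_model_gpu(patcher)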

View File

@ -9,11 +9,12 @@ class ModelManageable(Protocol):
"""
Objects which implement this protocol can be managed by
>>> import comfy.model_management
>>> class SomeObj("ModelManageable"):
>>> from comfy.model_management import load_models_gpu
>>> class ModelWrapper(ModelManageable):
>>> ...
>>>
>>> comfy.model_management.load_model_gpu(SomeObj())
>>> some_model = ModelWrapper()
>>> load_models_gpu([some_model])
"""
load_device: torch.device
offload_device: torch.device

View File

@ -128,7 +128,9 @@ class CustomNode(Protocol):
CATEGORY: ClassVar[str]
OUTPUT_NODE: Optional[ClassVar[bool]]
IS_CHANGED: Optional[ClassVar[IsChangedMethod]]
@classmethod
def IS_CHANGED(cls, *args, **kwargs) -> str:
...
@classmethod
def __call__(cls, *args, **kwargs) -> 'CustomNode':
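IS_CHANGED is now expressed as a classmethod on the protocol instead of an optional class attribute. A minimal sketch of a node using it to invalidate the cache when a file changes (the node itself is hypothetical):

import hashlib

class LoadTextFile:
    CATEGORY = "utils"
    RETURN_TYPES = ("STRING",)
    FUNCTION = "load"

    @classmethod
    def INPUT_TYPES(cls):
        return {"required": {"path": ("STRING", {"default": ""})}}

    @classmethod
    def IS_CHANGED(cls, path, **kwargs) -> str:
        # The executor re-runs the node only when this digest differs from the
        # value returned on the previous execution.
        with open(path, "rb") as f:
            return hashlib.sha256(f.read()).hexdigest()

    def load(self, path):
        with open(path, "r", encoding="utf-8") as f:
            return (f.read(),)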

View File

@ -20,6 +20,7 @@ from . import model_sampling
from . import sd1_clip
from . import sdxl_clip
from . import utils
from .model_management import load_models_gpu
from .text_encoders import sd2_clip
from .text_encoders import sd3_clip
@ -153,7 +154,7 @@ class CLIP:
return sd_clip
def load_model(self):
model_management.load_model_gpu(self.patcher)
load_models_gpu([self.patcher])
return self.patcher
def get_key_patches(self):
@ -337,7 +338,7 @@ class VAE:
return pixel_samples
def decode_tiled(self, samples, tile_x=64, tile_y=64, overlap=16):
model_management.load_model_gpu(self.patcher)
load_models_gpu([self.patcher])
output = self.decode_tiled_(samples, tile_x, tile_y, overlap)
return output.movedim(1, -1)
@ -366,7 +367,7 @@ class VAE:
def encode_tiled(self, pixel_samples, tile_x=512, tile_y=512, overlap=64):
pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
model_management.load_model_gpu(self.patcher)
load_models_gpu([self.patcher])
pixel_samples = pixel_samples.movedim(-1, 1)
samples = self.encode_tiled_(pixel_samples, tile_x=tile_x, tile_y=tile_y, overlap=overlap)
return samples
@ -574,7 +575,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
if output_model:
_model_patcher = model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device(), current_device=inital_load_device, ckpt_name=os.path.basename(ckpt_path))
if inital_load_device != torch.device("cpu"):
model_management.load_model_gpu(_model_patcher)
load_models_gpu([_model_patcher])
return (_model_patcher, clip, vae, clipvision)

View File

@ -18,7 +18,7 @@ from comfy.language.chat_templates import KNOWN_CHAT_TEMPLATES
from comfy.language.language_types import ProcessorResult
from comfy.language.transformers_model_management import TransformersManagedModel
from comfy.model_downloader import get_huggingface_repo_list, get_or_download_huggingface_repo
from comfy.model_management import get_torch_device_name, load_model_gpu, unet_dtype, unet_offload_device
from comfy.model_management import get_torch_device_name, unet_dtype, unet_offload_device, load_models_gpu
from comfy.nodes.package_typing import CustomNode, InputTypes, ValidatedNodeResult
from comfy.utils import comfy_tqdm, seed_for_block, comfy_progress, ProgressBar
@ -340,7 +340,7 @@ class TransformersGenerate(CustomNode):
tokens = copy.copy(tokens)
sampler = sampler or {}
generate_kwargs = copy.copy(sampler)
load_model_gpu(model)
load_models_gpu([model])
transformers_model: PreTrainedModel = model.model
tokenizer: PreTrainedTokenizerBase | AutoTokenizer = model.tokenizer
# remove unused inputs

View File

@ -1,24 +1,135 @@
import logging
import torch
from typing import Optional, Any
import torch
from spandrel import ModelLoader, ImageModelDescriptor
from comfy import model_management
from comfy import utils
from comfy.component_model.tensor_types import RGBImageBatch
from comfy.model_downloader import get_filename_list_with_downloadable, KNOWN_UPSCALERS, get_or_download
from comfy.model_management import load_models_gpu
from comfy.model_management_types import ModelManageable
try:
from spandrel_extra_arches import EXTRA_REGISTRY # pylint: disable=import-error
from spandrel import MAIN_REGISTRY
MAIN_REGISTRY.add(*EXTRA_REGISTRY)
logging.info("Successfully imported spandrel_extra_arches: support for non commercial upscale models.")
except:
pass
class UpscaleModelManageable(ModelManageable):
def __init__(self, model_descriptor: ImageModelDescriptor, ckpt_name: str):
self.ckpt_name = ckpt_name
self.model_descriptor = model_descriptor
self.model = model_descriptor.model
self.load_device = model_management.unet_offload_device()
self.offload_device = model_management.unet_offload_device()
self._current_device = self.offload_device
self._lowvram_patch_counter = 0
# Private properties for image sizes and channels
self._input_size = (1, 512, 512) # Default input size (batch, height, width)
self._input_channels = model_descriptor.input_channels
self._output_channels = model_descriptor.output_channels
self.tile = 512
@property
def current_device(self) -> torch.device:
return self._current_device
@property
def input_size(self) -> tuple[int, int, int]:
return self._input_size
@input_size.setter
def input_size(self, size: tuple[int, int, int]):
self._input_size = size
@property
def output_size(self) -> tuple[int, int, int]:
return (self._input_size[0],
self._input_size[1] * self.model_descriptor.scale,
self._input_size[2] * self.model_descriptor.scale)
def set_input_size_from_images(self, images: RGBImageBatch):
if images.ndim != 4:
raise ValueError("Input must be a 4D tensor (batch, height, width, channels)")
if images.shape[-1] != 3:
raise ValueError("Input must have 3 channels (RGB)")
self._input_size = (images.shape[0], images.shape[1], images.shape[2])
def is_clone(self, other: Any) -> bool:
return isinstance(other, UpscaleModelManageable) and self.model is other.model
def clone_has_same_weights(self, clone: torch.nn.Module) -> bool:
return self.is_clone(clone)
def model_size(self) -> int:
# Calculate the size of the model parameters
model_params_size = sum(p.numel() * p.element_size() for p in self.model.parameters())
# Get the byte size of the model's dtype
dtype_size = torch.finfo(self.model_dtype()).bits // 8
# Calculate the memory required for input and output images
input_size = self._input_size[0] * min(self.tile, self._input_size[1]) * min(self.tile, self._input_size[2]) * self._input_channels * dtype_size
output_size = self.output_size[0] * self.output_size[1] * self.output_size[2] * self._output_channels * dtype_size
# Add some extra memory for processing
extra_memory = (input_size + output_size) * 2 # This is an estimate, adjust as needed
return model_params_size + input_size + output_size + extra_memory
def model_patches_to(self, arg: torch.device | torch.dtype):
if isinstance(arg, torch.device):
self.model.to(device=arg)
else:
self.model.to(dtype=arg)
def model_dtype(self) -> torch.dtype:
return next(self.model.parameters()).dtype
def patch_model_lowvram(self, device_to: torch.device, lowvram_model_memory: int, force_patch_weights: Optional[bool] = False) -> torch.nn.Module:
self.model.to(device=device_to)
self._current_device = device_to
self._lowvram_patch_counter += 1
return self.model
def patch_model(self, device_to: torch.device, patch_weights: bool) -> torch.nn.Module:
if patch_weights:
self.model.to(device=device_to)
self._current_device = device_to
return self.model
def unpatch_model(self, offload_device: torch.device, unpatch_weights: Optional[bool] = False) -> torch.nn.Module:
if unpatch_weights:
self.model.to(device=offload_device)
self._current_device = offload_device
return self.model
@property
def lowvram_patch_counter(self) -> int:
return self._lowvram_patch_counter
@lowvram_patch_counter.setter
def lowvram_patch_counter(self, value: int):
self._lowvram_patch_counter = value
def __str__(self):
if self.ckpt_name is not None:
return f"<UpscaleModelManageable for {self.ckpt_name} ({self.model.__class__.__name__})>"
else:
return f"<UpscaleModelManageable for {self.model.__class__.__name__}>"
class UpscaleModelLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": {"model_name": (get_filename_list_with_downloadable("upscale_models", KNOWN_UPSCALERS),),
return {"required": {"model_name": (get_filename_list_with_downloadable("upscale_models"),),
}}
RETURN_TYPES = ("UPSCALE_MODEL",)
@ -36,7 +147,7 @@ class UpscaleModelLoader:
if not isinstance(out, ImageModelDescriptor):
raise Exception("Upscale model must be a single-image model.")
return (out,)
return (UpscaleModelManageable(out, model_name),)
class ImageUpscaleWithModel:
@ -51,33 +162,30 @@ class ImageUpscaleWithModel:
CATEGORY = "image/upscaling"
def upscale(self, upscale_model, image):
device = model_management.get_torch_device()
def upscale(self, upscale_model: UpscaleModelManageable, image: RGBImageBatch):
upscale_model.set_input_size_from_images(image)
load_models_gpu([upscale_model])
memory_required = model_management.module_size(upscale_model.model)
memory_required += (512 * 512 * 3) * image.element_size() * max(upscale_model.scale, 1.0) * 384.0 # The 384.0 is an estimate of how much some of these models take, TODO: make it more accurate
memory_required += image.nelement() * image.element_size()
model_management.free_memory(memory_required, device)
in_img = image.movedim(-1, -3).to(upscale_model.current_device, dtype=upscale_model.model_dtype())
upscale_model.to(device)
in_img = image.movedim(-1, -3).to(device)
tile = 512
tile = upscale_model.tile
overlap = 32
oom = True
s = None
while oom:
try:
steps = in_img.shape[0] * utils.get_tiled_scale_steps(in_img.shape[3], in_img.shape[2], tile_x=tile, tile_y=tile, overlap=overlap)
pbar = utils.ProgressBar(steps)
s = utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar)
s = utils.tiled_scale(in_img, lambda a: upscale_model.model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.model.scale, pbar=pbar)
oom = False
except model_management.OOM_EXCEPTION as e:
tile //= 2
if tile < 128:
overlap //= 2
if tile < 64 or overlap < 4:
raise e
upscale_model.to("cpu")
# upscale_model.to("cpu")
s = torch.clamp(s.movedim(-3, -1), min=0, max=1.0)
return (s,)
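With UpscaleModelManageable, the upscaler now goes through the shared model-management path: set_input_size_from_images feeds the expected working size into model_size(), and load_models_gpu frees enough VRAM before the tiled upscale starts. A worked example of that estimate with made-up numbers (parameter count, image size, and scale are illustrative):

dtype_size = 4                               # fp32 bytes per element
model_params_size = 67_000_000 * dtype_size  # ~256 MiB of weights

batch, height, width, scale, tile = 1, 2048, 2048, 4, 512
input_size = batch * min(tile, height) * min(tile, width) * 3 * dtype_size  # 3 MiB per tile
output_size = batch * (height * scale) * (width * scale) * 3 * dtype_size   # 768 MiB full output
extra_memory = (input_size + output_size) * 2                               # rough processing headroom

estimate = model_params_size + input_size + output_size + extra_memory
print(f"{estimate / 2**20:.0f} MiB requested from free_memory")             # ~2.5 GiB for this setup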

View File

@ -61,3 +61,4 @@ PySoundFile
networkx>=2.6.3
joblib
jaxtyping
spandrel_extra_arches