Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-01-11 14:50:49 +08:00)

Improve performance and memory management of upscale models, improve messaging on models loaded and unloaded from the GPU

This commit is contained in:
parent c6ce11b421
commit ce5fe01768
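The recurring change across the call sites below is an API migration: the deprecated single-model helper model_management.load_model_gpu(model) is replaced by load_models_gpu([model]), which takes a sequence of ModelManageable objects, and several wrappers now carry a ckpt_name so the new load/unload log messages can name what was moved. A minimal sketch of the call-site change, assuming any object that satisfies ModelManageable (the helper name here is illustrative, not part of the commit):

    from comfy.model_management import load_models_gpu
    from comfy.model_management_types import ModelManageable

    def ensure_on_gpu(model: ModelManageable) -> None:
        # Old call sites used the deprecated model_management.load_model_gpu(model).
        # The replacement always takes a sequence, so one model is a one-element list.
        load_models_gpu([model])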
@@ -1,4 +1,5 @@
from .component_model import files
from .model_management import load_models_gpu
from .utils import load_torch_file, transformers_convert, state_dict_prefix_replace
import torch
import json
@@ -57,7 +58,7 @@ class ClipVisionModel():
return self.model.state_dict()

def encode_image(self, image):
model_management.load_model_gpu(self.patcher)
load_models_gpu([self.patcher])
pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size).float()
out = self.model(pixel_values=pixel_values, intermediate_output=-2)


@@ -1,17 +1,7 @@
from typing import Annotated

from jaxtyping import Float, Shaped
from jaxtyping import Float
from torch import Tensor


def channels_constraint(n: int):
def constraint(x: Tensor) -> bool:
return x.shape[-1] == n

return constraint


ImageBatch = Float[Tensor, "batch height width channels"]
RGBImageBatch = Annotated[ImageBatch, Shaped[channels_constraint(3)]] | Float[Tensor, "batch height width 3"]
RGBAImageBatch = Annotated[ImageBatch, Shaped[channels_constraint(4)]] | Float[Tensor, "batch height width 4"]
RGBImageBatch = Float[Tensor, "batch height width 3"]
RGBAImageBatch = Float[Tensor, "batch height width 4"]
RGBImage = Float[Tensor, "height width 3"]

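The simplification above drops the Annotated/Shaped constraint helpers because the jaxtyping shape string already pins the channel count. A small sketch, assuming jaxtyping is installed (it is added to the requirements later in this commit); note that jaxtyping annotations are only enforced at runtime when paired with a checker such as beartype or typeguard, otherwise they serve as documentation:

    import torch
    from jaxtyping import Float
    from torch import Tensor

    # "batch height width 3" already constrains the last dimension to 3 channels,
    # which is what channels_constraint(3) used to assert.
    RGBImageBatch = Float[Tensor, "batch height width 3"]

    def first_batch_dim(images: RGBImageBatch) -> int:
        return images.shape[0]

    first_batch_dim(torch.zeros(2, 64, 64, 3))  # returns 2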
@@ -142,12 +142,13 @@ class ControlBase:


class ControlNet(ControlBase):
def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None):
def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None, ckpt_name: str = None):
super().__init__(device)
self.control_model = control_model
self.load_device = load_device
if control_model is not None:
self.control_model_wrapped = model_patcher.ModelPatcher(self.control_model, load_device=load_device, offload_device=model_management.unet_offload_device())
self.control_model_wrapped.ckpt_name = os.path.basename(ckpt_name)
self.compression_ratio = compression_ratio
self.global_average_pooling = global_average_pooling
self.model_sampling_current = None
@@ -499,7 +500,7 @@ def load_controlnet(ckpt_path, model=None):
if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling
global_average_pooling = True

control = ControlNet(control_model, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
control = ControlNet(control_model, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype, ckpt_name=filename)
return control

class T2IAdapter(ControlBase):

@@ -2,6 +2,7 @@ from __future__ import annotations

import copy
import logging
import pathlib
import warnings
from typing import Optional, Any, Callable, Union, List

@@ -175,7 +176,7 @@ class TransformersManagedModel(ModelManageable):
if hasattr(self.processor, "to"):
self.processor.to(device=self.load_device)

assert "<image>" in prompt, "You must specify a <image> token inside the prompt for it to be substituted correctly by a HuggingFace processor"
assert "<image>" in prompt.lower(), "You must specify a <image> token inside the prompt for it to be substituted correctly by a HuggingFace processor"
batch_feature: BatchFeature = self.processor([prompt], images=images, padding=True, return_tensors="pt")
if hasattr(self.processor, "to"):
self.processor.to(device=self.offload_device)
@@ -188,3 +189,10 @@ class TransformersManagedModel(ModelManageable):
"inputs": batch_feature["input_ids"],
**batch_feature
}

def __str__(self):
if self.repo_id is not None:
repo_id_as_path = pathlib.PurePath(self.repo_id)
return f"<TransformersManagedModel for {'/'.join(repo_id_as_path.parts[-2:])} ({self.model.__class__.__name__})>"
else:
return f"<TransformersManagedModel for {self.model.__class__.__name__}>"

@@ -29,7 +29,7 @@ _session = Session()
_hf_fs = HfFileSystem()


def get_filename_list_with_downloadable(folder_name: str, known_files: Optional[List[Any]] = None) -> List[str]:
def get_filename_list_with_downloadable(folder_name: str, known_files: Optional[List[Downloadable] | KnownDownloadables] = None) -> List[str]:
if known_files is None:
known_files = _get_known_models_for_folder_name(folder_name)


@@ -6,7 +6,7 @@ import sys
import warnings
from enum import Enum
from threading import RLock
from typing import Literal, List
from typing import Literal, List, Sequence

import psutil
import torch
@@ -15,6 +15,7 @@ from opentelemetry.trace import get_current_span
from . import interruption
from .cli_args import args
from .cmd.main_pre import tracer
from .component_model.deprecation import _deprecate_method
from .model_management_types import ModelManageable

model_management_lock = RLock()
@@ -439,7 +440,7 @@ def free_memory(memory_required, device, keep_loaded=[]) -> List[LoadedModel]:


@tracer.start_as_current_span("Load Models GPU")
def load_models_gpu(models, memory_required=0, force_patch_weights=False):
def load_models_gpu(models: Sequence[ModelManageable], memory_required: int = 0, force_patch_weights=False) -> None:
global vram_state
span = get_current_span()
if memory_required != 0:
@@ -472,59 +473,61 @@
models_to_load.append(loaded_model)

models_freed: List[LoadedModel] = []
if len(models_to_load) == 0:
devs = set(map(lambda a: a.device, models_already_loaded))
for d in devs:
if d != torch.device("cpu"):
models_freed += free_memory(extra_mem, d, models_already_loaded)
try:
if len(models_to_load) == 0:
devs = set(map(lambda a: a.device, models_already_loaded))
for d in devs:
if d != torch.device("cpu"):
models_freed += free_memory(extra_mem, d, models_already_loaded)
return

total_memory_required = {}
for loaded_model in models_to_load:
if unload_model_clones(loaded_model.model, unload_weights_only=True, force_unload=False): # unload clones where the weights are different
total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)

for device in total_memory_required:
if device != torch.device("cpu"):
# todo: where does 1.3 come from?
models_freed += free_memory(total_memory_required[device] * 1.3 + extra_mem, device, models_already_loaded)

for loaded_model in models_to_load:
weights_unloaded = unload_model_clones(loaded_model.model, unload_weights_only=False, force_unload=False) # unload the rest of the clones where the weights can stay loaded
if weights_unloaded is not None:
loaded_model.weights_loaded = not weights_unloaded

for loaded_model in models_to_load:
model = loaded_model.model
torch_dev = model.load_device
if is_device_cpu(torch_dev):
vram_set_state = VRAMState.DISABLED
else:
vram_set_state = vram_state
lowvram_model_memory = 0
if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
model_size = loaded_model.model_memory_required(torch_dev)
current_free_mem = get_free_memory(torch_dev)
lowvram_model_memory = int(max(64 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.3))
if model_size <= (current_free_mem - inference_memory): # only switch to lowvram if really necessary

lowvram_model_memory = 0

if vram_set_state == VRAMState.NO_VRAM:
lowvram_model_memory = 64 * 1024 * 1024

loaded_model.model_load(lowvram_model_memory, force_patch_weights=force_patch_weights)
current_loaded_models.insert(0, loaded_model)
return

total_memory_required = {}
for loaded_model in models_to_load:
if unload_model_clones(loaded_model.model, unload_weights_only=True, force_unload=False): # unload clones where the weights are different
total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)

for device in total_memory_required:
if device != torch.device("cpu"):
models_freed += free_memory(total_memory_required[device] * 1.3 + extra_mem, device, models_already_loaded)

span.set_attribute("models_to_load", list(map(str, models_to_load)))
span.set_attribute("models_freed", list(map(str, models_freed)))

logging.info(f"Models loaded: {','.join(map(str, models_to_load))}, models freed: {','.join(map(str, models_freed))}")

for loaded_model in models_to_load:
weights_unloaded = unload_model_clones(loaded_model.model, unload_weights_only=False, force_unload=False) # unload the rest of the clones where the weights can stay loaded
if weights_unloaded is not None:
loaded_model.weights_loaded = not weights_unloaded

for loaded_model in models_to_load:
model = loaded_model.model
torch_dev = model.load_device
if is_device_cpu(torch_dev):
vram_set_state = VRAMState.DISABLED
else:
vram_set_state = vram_state
lowvram_model_memory = 0
if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
model_size = loaded_model.model_memory_required(torch_dev)
current_free_mem = get_free_memory(torch_dev)
lowvram_model_memory = int(max(64 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.3))
if model_size <= (current_free_mem - inference_memory): # only switch to lowvram if really necessary

lowvram_model_memory = 0

if vram_set_state == VRAMState.NO_VRAM:
lowvram_model_memory = 64 * 1024 * 1024

loaded_model.model_load(lowvram_model_memory, force_patch_weights=force_patch_weights)
current_loaded_models.insert(0, loaded_model)
return
finally:
span.set_attribute("models", list(map(str, models)))
span.set_attribute("models_to_load", list(map(str, models_to_load)))
span.set_attribute("models_freed", list(map(str, models_freed)))
logging.info(f"Requested to load {','.join(map(str, models))}, models loaded: {','.join(map(str, models_to_load))}, models freed: {','.join(map(str, models_freed))}")


@_deprecate_method(message="Use load_models_gpu instead", version="0.0.2")
def load_model_gpu(model):
with model_management_lock:
return load_models_gpu([model])
return load_models_gpu([model])


def loaded_models(only_currently_used=False):

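The net effect of the restructuring above is that both exit paths of load_models_gpu, the early return when everything is already resident and the full load path, now fall through a single finally block that records the span attributes and logs which models were requested, loaded, and freed. A control-flow sketch of that pattern, with the real bookkeeping replaced by stand-ins (this is not the function's actual implementation):

    import logging
    from typing import Sequence

    def load_models_gpu_sketch(models: Sequence[object]) -> None:
        models_to_load = [m for m in models if not getattr(m, "loaded", False)]
        models_freed: list = []
        try:
            if not models_to_load:
                return  # fast path: the finally block below still runs
            for m in models_to_load:
                m.loaded = True  # stand-in for the real model_load()
            return
        finally:
            # Emitted on every exit, which is the "improved messaging" in the commit title.
            logging.info(
                "Requested to load %s, models loaded: %s, models freed: %s",
                ",".join(map(str, models)),
                ",".join(map(str, models_to_load)),
                ",".join(map(str, models_freed)),
            )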
@@ -9,11 +9,12 @@ class ModelManageable(Protocol):
"""
Objects which implement this protocol can be managed by

>>> import comfy.model_management
>>> class SomeObj("ModelManageable"):
>>> from comfy.model_management import load_models_gpu
>>> class ModelWrapper(ModelManageable):
>>> ...
>>>
>>> comfy.model_management.load_model_gpu(SomeObj())
>>> some_model = ModelWrapper()
>>> load_models_gpu([some_model])
"""
load_device: torch.device
offload_device: torch.device

@@ -128,7 +128,9 @@ class CustomNode(Protocol):
CATEGORY: ClassVar[str]
OUTPUT_NODE: Optional[ClassVar[bool]]

IS_CHANGED: Optional[ClassVar[IsChangedMethod]]
@classmethod
def IS_CHANGED(cls, *args, **kwargs) -> str:
...

@classmethod
def __call__(cls, *args, **kwargs) -> 'CustomNode':

@@ -20,6 +20,7 @@ from . import model_sampling
from . import sd1_clip
from . import sdxl_clip
from . import utils
from .model_management import load_models_gpu

from .text_encoders import sd2_clip
from .text_encoders import sd3_clip
@@ -153,7 +154,7 @@ class CLIP:
return sd_clip

def load_model(self):
model_management.load_model_gpu(self.patcher)
load_models_gpu([self.patcher])
return self.patcher

def get_key_patches(self):
@@ -337,7 +338,7 @@ class VAE:
return pixel_samples

def decode_tiled(self, samples, tile_x=64, tile_y=64, overlap=16):
model_management.load_model_gpu(self.patcher)
load_models_gpu([self.patcher])
output = self.decode_tiled_(samples, tile_x, tile_y, overlap)
return output.movedim(1, -1)

@@ -366,7 +367,7 @@ class VAE:

def encode_tiled(self, pixel_samples, tile_x=512, tile_y=512, overlap=64):
pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
model_management.load_model_gpu(self.patcher)
load_models_gpu([self.patcher])
pixel_samples = pixel_samples.movedim(-1, 1)
samples = self.encode_tiled_(pixel_samples, tile_x=tile_x, tile_y=tile_y, overlap=overlap)
return samples
@@ -574,7 +575,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
if output_model:
_model_patcher = model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device(), current_device=inital_load_device, ckpt_name=os.path.basename(ckpt_path))
if inital_load_device != torch.device("cpu"):
model_management.load_model_gpu(_model_patcher)
load_models_gpu([_model_patcher])

return (_model_patcher, clip, vae, clipvision)


@@ -18,7 +18,7 @@ from comfy.language.chat_templates import KNOWN_CHAT_TEMPLATES
from comfy.language.language_types import ProcessorResult
from comfy.language.transformers_model_management import TransformersManagedModel
from comfy.model_downloader import get_huggingface_repo_list, get_or_download_huggingface_repo
from comfy.model_management import get_torch_device_name, load_model_gpu, unet_dtype, unet_offload_device
from comfy.model_management import get_torch_device_name, unet_dtype, unet_offload_device, load_models_gpu
from comfy.nodes.package_typing import CustomNode, InputTypes, ValidatedNodeResult
from comfy.utils import comfy_tqdm, seed_for_block, comfy_progress, ProgressBar

@@ -340,7 +340,7 @@ class TransformersGenerate(CustomNode):
tokens = copy.copy(tokens)
sampler = sampler or {}
generate_kwargs = copy.copy(sampler)
load_model_gpu(model)
load_models_gpu([model])
transformers_model: PreTrainedModel = model.model
tokenizer: PreTrainedTokenizerBase | AutoTokenizer = model.tokenizer
# remove unused inputs

@@ -1,24 +1,135 @@
import logging
import torch
from typing import Optional, Any

import torch
from spandrel import ModelLoader, ImageModelDescriptor

from comfy import model_management
from comfy import utils
from comfy.component_model.tensor_types import RGBImageBatch
from comfy.model_downloader import get_filename_list_with_downloadable, KNOWN_UPSCALERS, get_or_download

from comfy.model_management import load_models_gpu
from comfy.model_management_types import ModelManageable

try:
from spandrel_extra_arches import EXTRA_REGISTRY # pylint: disable=import-error
from spandrel import MAIN_REGISTRY

MAIN_REGISTRY.add(*EXTRA_REGISTRY)
logging.info("Successfully imported spandrel_extra_arches: support for non commercial upscale models.")
except:
pass


class UpscaleModelManageable(ModelManageable):
def __init__(self, model_descriptor: ImageModelDescriptor, ckpt_name: str):
self.ckpt_name = ckpt_name
self.model_descriptor = model_descriptor
self.model = model_descriptor.model
self.load_device = model_management.unet_offload_device()
self.offload_device = model_management.unet_offload_device()
self._current_device = self.offload_device
self._lowvram_patch_counter = 0

# Private properties for image sizes and channels
self._input_size = (1, 512, 512) # Default input size (batch, height, width)
self._input_channels = model_descriptor.input_channels
self._output_channels = model_descriptor.output_channels
self.tile = 512

@property
def current_device(self) -> torch.device:
return self._current_device

@property
def input_size(self) -> tuple[int, int, int]:
return self._input_size

@input_size.setter
def input_size(self, size: tuple[int, int, int]):
self._input_size = size

@property
def output_size(self) -> tuple[int, int, int]:
return (self._input_size[0],
self._input_size[1] * self.model_descriptor.scale,
self._input_size[2] * self.model_descriptor.scale)

def set_input_size_from_images(self, images: RGBImageBatch):
if images.ndim != 4:
raise ValueError("Input must be a 4D tensor (batch, height, width, channels)")
if images.shape[-1] != 3:
raise ValueError("Input must have 3 channels (RGB)")
self._input_size = (images.shape[0], images.shape[1], images.shape[2])

def is_clone(self, other: Any) -> bool:
return isinstance(other, UpscaleModelManageable) and self.model is other.model

def clone_has_same_weights(self, clone: torch.nn.Module) -> bool:
return self.is_clone(clone)

def model_size(self) -> int:
# Calculate the size of the model parameters
model_params_size = sum(p.numel() * p.element_size() for p in self.model.parameters())

# Get the byte size of the model's dtype
dtype_size = torch.finfo(self.model_dtype()).bits // 8

# Calculate the memory required for input and output images
input_size = self._input_size[0] * min(self.tile, self._input_size[1]) * min(self.tile, self._input_size[2]) * self._input_channels * dtype_size
output_size = self.output_size[0] * self.output_size[1] * self.output_size[2] * self._output_channels * dtype_size

# Add some extra memory for processing
extra_memory = (input_size + output_size) * 2 # This is an estimate, adjust as needed

return model_params_size + input_size + output_size + extra_memory

def model_patches_to(self, arg: torch.device | torch.dtype):
if isinstance(arg, torch.device):
self.model.to(device=arg)
else:
self.model.to(dtype=arg)

def model_dtype(self) -> torch.dtype:
return next(self.model.parameters()).dtype

def patch_model_lowvram(self, device_to: torch.device, lowvram_model_memory: int, force_patch_weights: Optional[bool] = False) -> torch.nn.Module:
self.model.to(device=device_to)
self._current_device = device_to
self._lowvram_patch_counter += 1
return self.model

def patch_model(self, device_to: torch.device, patch_weights: bool) -> torch.nn.Module:
if patch_weights:
self.model.to(device=device_to)
self._current_device = device_to
return self.model

def unpatch_model(self, offload_device: torch.device, unpatch_weights: Optional[bool] = False) -> torch.nn.Module:
if unpatch_weights:
self.model.to(device=offload_device)
self._current_device = offload_device
return self.model

@property
def lowvram_patch_counter(self) -> int:
return self._lowvram_patch_counter

@lowvram_patch_counter.setter
def lowvram_patch_counter(self, value: int):
self._lowvram_patch_counter = value

def __str__(self):
if self.ckpt_name is not None:
return f"<UpscaleModelManageable for {self.ckpt_name} ({self.model.__class__.__name__})>"
else:
return f"<UpscaleModelManageable for {self.model.__class__.__name__}>"


class UpscaleModelLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": {"model_name": (get_filename_list_with_downloadable("upscale_models", KNOWN_UPSCALERS),),
return {"required": {"model_name": (get_filename_list_with_downloadable("upscale_models"),),
}}

RETURN_TYPES = ("UPSCALE_MODEL",)
@@ -36,7 +147,7 @@ class UpscaleModelLoader:
if not isinstance(out, ImageModelDescriptor):
raise Exception("Upscale model must be a single-image model.")

return (out,)
return (UpscaleModelManageable(out, model_name),)


class ImageUpscaleWithModel:
@@ -51,33 +162,30 @@ class ImageUpscaleWithModel:

CATEGORY = "image/upscaling"

def upscale(self, upscale_model, image):
device = model_management.get_torch_device()
def upscale(self, upscale_model: UpscaleModelManageable, image: RGBImageBatch):
upscale_model.set_input_size_from_images(image)
load_models_gpu([upscale_model])

memory_required = model_management.module_size(upscale_model.model)
memory_required += (512 * 512 * 3) * image.element_size() * max(upscale_model.scale, 1.0) * 384.0 # The 384.0 is an estimate of how much some of these models take, TODO: make it more accurate
memory_required += image.nelement() * image.element_size()
model_management.free_memory(memory_required, device)
in_img = image.movedim(-1, -3).to(upscale_model.current_device, dtype=upscale_model.model_dtype())

upscale_model.to(device)
in_img = image.movedim(-1, -3).to(device)

tile = 512
tile = upscale_model.tile
overlap = 32

oom = True
s = None
while oom:
try:
steps = in_img.shape[0] * utils.get_tiled_scale_steps(in_img.shape[3], in_img.shape[2], tile_x=tile, tile_y=tile, overlap=overlap)
pbar = utils.ProgressBar(steps)
s = utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar)
s = utils.tiled_scale(in_img, lambda a: upscale_model.model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.model.scale, pbar=pbar)
oom = False
except model_management.OOM_EXCEPTION as e:
tile //= 2
if tile < 128:
overlap //= 2
if tile < 64 or overlap < 4:
raise e

upscale_model.to("cpu")
# upscale_model.to("cpu")
s = torch.clamp(s.movedim(-3, -1), min=0, max=1.0)
return (s,)


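To put the new model_size() accounting above in concrete terms: with the default 512-pixel tile, fp32 weights (4 bytes per element), and a 4x model upscaling a single 1024x1024 RGB image, the input term covers one tile (about 3 MiB), the output term covers the full 4096x4096 result (about 192 MiB), and the extra-processing term doubles their sum (about 390 MiB), so roughly 585 MiB is budgeted on top of the parameter size before the model is loaded. A small sketch of that arithmetic with illustrative numbers (not values taken from the commit):

    # Illustrative inputs: 4x model, fp32, one 1024x1024 RGB image, 512-pixel tile.
    dtype_size = 4  # bytes per fp32 element
    batch, height, width, channels = 1, 1024, 1024, 3
    tile, scale = 512, 4

    input_bytes = batch * min(tile, height) * min(tile, width) * channels * dtype_size
    output_bytes = batch * (height * scale) * (width * scale) * channels * dtype_size
    extra_bytes = (input_bytes + output_bytes) * 2

    print(f"input  ~ {input_bytes / 2**20:.0f} MiB")   # 3 MiB
    print(f"output ~ {output_bytes / 2**20:.0f} MiB")  # 192 MiB
    print(f"extra  ~ {extra_bytes / 2**20:.0f} MiB")   # 390 MiB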
@@ -61,3 +61,4 @@ PySoundFile
networkx>=2.6.3
joblib
jaxtyping
spandrel_extra_arches