Improve performance and memory management of upscale models; improve messaging about models loaded onto and unloaded from the GPU

This commit is contained in:
doctorpangloss 2024-07-30 17:05:53 -07:00
parent c6ce11b421
commit ce5fe01768
12 changed files with 212 additions and 96 deletions
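
The change that recurs across the files below is the call-site migration from the single-model loader to the batched one. A minimal sketch, where `patcher` stands in for any object satisfying the ModelManageable protocol (an illustrative name, not taken from the diff):

    from comfy.model_management import load_models_gpu

    # Previously: comfy.model_management.load_model_gpu(patcher)
    # load_model_gpu still works but is marked deprecated in this commit;
    # the batched form frees memory for, and loads, several models in one pass.
    load_models_gpu([patcher])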

View File

@@ -1,4 +1,5 @@
 from .component_model import files
+from .model_management import load_models_gpu
 from .utils import load_torch_file, transformers_convert, state_dict_prefix_replace
 import torch
 import json
@@ -57,7 +58,7 @@ class ClipVisionModel():
         return self.model.state_dict()
 
     def encode_image(self, image):
-        model_management.load_model_gpu(self.patcher)
+        load_models_gpu([self.patcher])
         pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size).float()
         out = self.model(pixel_values=pixel_values, intermediate_output=-2)

View File

@@ -1,17 +1,7 @@
-from typing import Annotated
-from jaxtyping import Float, Shaped
+from jaxtyping import Float
 from torch import Tensor
-def channels_constraint(n: int):
-    def constraint(x: Tensor) -> bool:
-        return x.shape[-1] == n
-    return constraint
 ImageBatch = Float[Tensor, "batch height width channels"]
-RGBImageBatch = Annotated[ImageBatch, Shaped[channels_constraint(3)]] | Float[Tensor, "batch height width 3"]
+RGBImageBatch = Float[Tensor, "batch height width 3"]
-RGBAImageBatch = Annotated[ImageBatch, Shaped[channels_constraint(4)]] | Float[Tensor, "batch height width 4"]
+RGBAImageBatch = Float[Tensor, "batch height width 4"]
 RGBImage = Float[Tensor, "height width 3"]
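
The aliases are now plain jaxtyping annotations and can be used directly in signatures; a small illustrative example (the function itself is not part of the commit):

    from torch import Tensor
    from jaxtyping import Float

    RGBImageBatch = Float[Tensor, "batch height width 3"]

    def brightness(images: RGBImageBatch) -> Float[Tensor, "batch"]:
        # Mean over height, width and channel dimensions.
        return images.mean(dim=(1, 2, 3))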

View File

@@ -142,12 +142,13 @@ class ControlBase:
 class ControlNet(ControlBase):
-    def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None):
+    def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None, ckpt_name: str = None):
         super().__init__(device)
         self.control_model = control_model
         self.load_device = load_device
         if control_model is not None:
             self.control_model_wrapped = model_patcher.ModelPatcher(self.control_model, load_device=load_device, offload_device=model_management.unet_offload_device())
+            self.control_model_wrapped.ckpt_name = os.path.basename(ckpt_name)
         self.compression_ratio = compression_ratio
         self.global_average_pooling = global_average_pooling
         self.model_sampling_current = None
@@ -499,7 +500,7 @@ def load_controlnet(ckpt_path, model=None):
     if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling
         global_average_pooling = True
-    control = ControlNet(control_model, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
+    control = ControlNet(control_model, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype, ckpt_name=filename)
     return control
 
 class T2IAdapter(ControlBase):
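
Threading ckpt_name through to the wrapped ModelPatcher lets controlnets identify themselves by checkpoint name in the new load/unload messages; a sketch, with the path illustrative and the direct load_models_gpu call shown only to make the effect visible:

    from comfy.controlnet import load_controlnet
    from comfy.model_management import load_models_gpu

    control = load_controlnet("models/controlnet/control_v11p_sd15_canny.pth")  # example path
    # The wrapped patcher now carries a human-readable name derived from the checkpoint,
    # which is what the "Requested to load ..." log line stringifies.
    load_models_gpu([control.control_model_wrapped])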

View File

@@ -2,6 +2,7 @@ from __future__ import annotations
 import copy
 import logging
+import pathlib
 import warnings
 from typing import Optional, Any, Callable, Union, List
@@ -175,7 +176,7 @@ class TransformersManagedModel(ModelManageable):
         if hasattr(self.processor, "to"):
             self.processor.to(device=self.load_device)
-        assert "<image>" in prompt, "You must specify a <image> token inside the prompt for it to be substituted correctly by a HuggingFace processor"
+        assert "<image>" in prompt.lower(), "You must specify a <image> token inside the prompt for it to be substituted correctly by a HuggingFace processor"
         batch_feature: BatchFeature = self.processor([prompt], images=images, padding=True, return_tensors="pt")
         if hasattr(self.processor, "to"):
             self.processor.to(device=self.offload_device)
@@ -188,3 +189,10 @@
             "inputs": batch_feature["input_ids"],
             **batch_feature
         }
+
+    def __str__(self):
+        if self.repo_id is not None:
+            repo_id_as_path = pathlib.PurePath(self.repo_id)
+            return f"<TransformersManagedModel for {'/'.join(repo_id_as_path.parts[-2:])} ({self.model.__class__.__name__})>"
+        else:
+            return f"<TransformersManagedModel for {self.model.__class__.__name__}>"

View File

@@ -29,7 +29,7 @@ _session = Session()
 _hf_fs = HfFileSystem()
 
-def get_filename_list_with_downloadable(folder_name: str, known_files: Optional[List[Any]] = None) -> List[str]:
+def get_filename_list_with_downloadable(folder_name: str, known_files: Optional[List[Downloadable] | KnownDownloadables] = None) -> List[str]:
     if known_files is None:
         known_files = _get_known_models_for_folder_name(folder_name)

View File

@@ -6,7 +6,7 @@ import sys
 import warnings
 from enum import Enum
 from threading import RLock
-from typing import Literal, List
+from typing import Literal, List, Sequence
 
 import psutil
 import torch
@@ -15,6 +15,7 @@ from opentelemetry.trace import get_current_span
 from . import interruption
 from .cli_args import args
 from .cmd.main_pre import tracer
+from .component_model.deprecation import _deprecate_method
 from .model_management_types import ModelManageable
 
 model_management_lock = RLock()
@@ -439,7 +440,7 @@ def free_memory(memory_required, device, keep_loaded=[]) -> List[LoadedModel]:
 
 @tracer.start_as_current_span("Load Models GPU")
-def load_models_gpu(models, memory_required=0, force_patch_weights=False):
+def load_models_gpu(models: Sequence[ModelManageable], memory_required: int = 0, force_patch_weights=False) -> None:
     global vram_state
     span = get_current_span()
     if memory_required != 0:
@@ -472,59 +473,61 @@
         models_to_load.append(loaded_model)
 
     models_freed: List[LoadedModel] = []
-    if len(models_to_load) == 0:
-        devs = set(map(lambda a: a.device, models_already_loaded))
-        for d in devs:
-            if d != torch.device("cpu"):
-                models_freed += free_memory(extra_mem, d, models_already_loaded)
-        return
-
-    total_memory_required = {}
-    for loaded_model in models_to_load:
-        if unload_model_clones(loaded_model.model, unload_weights_only=True, force_unload=False):  # unload clones where the weights are different
-            total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
-
-    for device in total_memory_required:
-        if device != torch.device("cpu"):
-            models_freed += free_memory(total_memory_required[device] * 1.3 + extra_mem, device, models_already_loaded)
-
-    span.set_attribute("models_to_load", list(map(str, models_to_load)))
-    span.set_attribute("models_freed", list(map(str, models_freed)))
-    logging.info(f"Models loaded: {','.join(map(str, models_to_load))}, models freed: {','.join(map(str, models_freed))}")
-
-    for loaded_model in models_to_load:
-        weights_unloaded = unload_model_clones(loaded_model.model, unload_weights_only=False, force_unload=False)  # unload the rest of the clones where the weights can stay loaded
-        if weights_unloaded is not None:
-            loaded_model.weights_loaded = not weights_unloaded
-
-    for loaded_model in models_to_load:
-        model = loaded_model.model
-        torch_dev = model.load_device
-        if is_device_cpu(torch_dev):
-            vram_set_state = VRAMState.DISABLED
-        else:
-            vram_set_state = vram_state
-        lowvram_model_memory = 0
-        if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
-            model_size = loaded_model.model_memory_required(torch_dev)
-            current_free_mem = get_free_memory(torch_dev)
-            lowvram_model_memory = int(max(64 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.3))
-            if model_size <= (current_free_mem - inference_memory):  # only switch to lowvram if really necessary
-                lowvram_model_memory = 0
-
-        if vram_set_state == VRAMState.NO_VRAM:
-            lowvram_model_memory = 64 * 1024 * 1024
-
-        loaded_model.model_load(lowvram_model_memory, force_patch_weights=force_patch_weights)
-        current_loaded_models.insert(0, loaded_model)
-    return
+    try:
+        if len(models_to_load) == 0:
+            devs = set(map(lambda a: a.device, models_already_loaded))
+            for d in devs:
+                if d != torch.device("cpu"):
+                    models_freed += free_memory(extra_mem, d, models_already_loaded)
+            return
+
+        total_memory_required = {}
+        for loaded_model in models_to_load:
+            if unload_model_clones(loaded_model.model, unload_weights_only=True, force_unload=False):  # unload clones where the weights are different
+                total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
+
+        for device in total_memory_required:
+            if device != torch.device("cpu"):
+                # todo: where does 1.3 come from?
+                models_freed += free_memory(total_memory_required[device] * 1.3 + extra_mem, device, models_already_loaded)
+
+        for loaded_model in models_to_load:
+            weights_unloaded = unload_model_clones(loaded_model.model, unload_weights_only=False, force_unload=False)  # unload the rest of the clones where the weights can stay loaded
+            if weights_unloaded is not None:
+                loaded_model.weights_loaded = not weights_unloaded
+
+        for loaded_model in models_to_load:
+            model = loaded_model.model
+            torch_dev = model.load_device
+            if is_device_cpu(torch_dev):
+                vram_set_state = VRAMState.DISABLED
+            else:
+                vram_set_state = vram_state
+            lowvram_model_memory = 0
+            if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
+                model_size = loaded_model.model_memory_required(torch_dev)
+                current_free_mem = get_free_memory(torch_dev)
+                lowvram_model_memory = int(max(64 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.3))
+                if model_size <= (current_free_mem - inference_memory):  # only switch to lowvram if really necessary
+                    lowvram_model_memory = 0
+
+            if vram_set_state == VRAMState.NO_VRAM:
+                lowvram_model_memory = 64 * 1024 * 1024
+
+            loaded_model.model_load(lowvram_model_memory, force_patch_weights=force_patch_weights)
+            current_loaded_models.insert(0, loaded_model)
+        return
+    finally:
+        span.set_attribute("models", list(map(str, models)))
+        span.set_attribute("models_to_load", list(map(str, models_to_load)))
+        span.set_attribute("models_freed", list(map(str, models_freed)))
+        logging.info(f"Requested to load {','.join(map(str, models))}, models loaded: {','.join(map(str, models_to_load))}, models freed: {','.join(map(str, models_freed))}")
 
 
+@_deprecate_method(message="Use load_models_gpu instead", version="0.0.2")
 def load_model_gpu(model):
-    with model_management_lock:
-        return load_models_gpu([model])
+    return load_models_gpu([model])
 
 
 def loaded_models(only_currently_used=False):
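
Because the whole body now runs inside try/finally, the span attributes and the consolidated INFO line are emitted even on the early return taken when everything requested is already resident. Roughly, a caller sees the following (the patcher names are illustrative):

    from comfy.model_management import load_models_gpu

    load_models_gpu([clip_patcher, vae_patcher])
    # INFO: Requested to load <...clip...>,<...vae...>, models loaded: <only models not already resident>, models freed: <models evicted to make room>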

View File

@@ -9,11 +9,12 @@ class ModelManageable(Protocol):
     """
     Objects which implement this protocol can be managed by
-    >>> import comfy.model_management
-    >>> class SomeObj("ModelManageable"):
+    >>> from comfy.model_management import load_models_gpu
+    >>> class ModelWrapper(ModelManageable):
     >>>   ...
     >>>
-    >>> comfy.model_management.load_model_gpu(SomeObj())
+    >>> some_model = ModelWrapper()
+    >>> load_models_gpu([some_model])
     """
     load_device: torch.device
     offload_device: torch.device

View File

@@ -128,7 +128,9 @@ class CustomNode(Protocol):
     CATEGORY: ClassVar[str]
     OUTPUT_NODE: Optional[ClassVar[bool]]
-    IS_CHANGED: Optional[ClassVar[IsChangedMethod]]
+    @classmethod
+    def IS_CHANGED(cls, *args, **kwargs) -> str:
+        ...
 
     @classmethod
     def __call__(cls, *args, **kwargs) -> 'CustomNode':
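
With IS_CHANGED now declared as a classmethod returning a string rather than an optional ClassVar, a conforming node looks roughly like this sketch (the node and its hashing scheme are illustrative, not part of the commit):

    import hashlib

    class ExampleLoadTextFile:
        CATEGORY = "example"

        @classmethod
        def IS_CHANGED(cls, path: str, **kwargs) -> str:
            # Any change in the returned string invalidates the cached output
            # and forces the node to execute again.
            with open(path, "rb") as f:
                return hashlib.sha256(f.read()).hexdigest()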

View File

@@ -20,6 +20,7 @@ from . import model_sampling
 from . import sd1_clip
 from . import sdxl_clip
 from . import utils
+from .model_management import load_models_gpu
 from .text_encoders import sd2_clip
 from .text_encoders import sd3_clip
@@ -153,7 +154,7 @@ class CLIP:
         return sd_clip
 
     def load_model(self):
-        model_management.load_model_gpu(self.patcher)
+        load_models_gpu([self.patcher])
         return self.patcher
 
     def get_key_patches(self):
@@ -337,7 +338,7 @@ class VAE:
         return pixel_samples
 
     def decode_tiled(self, samples, tile_x=64, tile_y=64, overlap=16):
-        model_management.load_model_gpu(self.patcher)
+        load_models_gpu([self.patcher])
         output = self.decode_tiled_(samples, tile_x, tile_y, overlap)
         return output.movedim(1, -1)
@@ -366,7 +367,7 @@ class VAE:
     def encode_tiled(self, pixel_samples, tile_x=512, tile_y=512, overlap=64):
         pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
-        model_management.load_model_gpu(self.patcher)
+        load_models_gpu([self.patcher])
         pixel_samples = pixel_samples.movedim(-1, 1)
         samples = self.encode_tiled_(pixel_samples, tile_x=tile_x, tile_y=tile_y, overlap=overlap)
         return samples
@@ -574,7 +575,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
     if output_model:
         _model_patcher = model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device(), current_device=inital_load_device, ckpt_name=os.path.basename(ckpt_path))
         if inital_load_device != torch.device("cpu"):
-            model_management.load_model_gpu(_model_patcher)
+            load_models_gpu([_model_patcher])
     return (_model_patcher, clip, vae, clipvision)

View File

@@ -18,7 +18,7 @@ from comfy.language.chat_templates import KNOWN_CHAT_TEMPLATES
 from comfy.language.language_types import ProcessorResult
 from comfy.language.transformers_model_management import TransformersManagedModel
 from comfy.model_downloader import get_huggingface_repo_list, get_or_download_huggingface_repo
-from comfy.model_management import get_torch_device_name, load_model_gpu, unet_dtype, unet_offload_device
+from comfy.model_management import get_torch_device_name, unet_dtype, unet_offload_device, load_models_gpu
 from comfy.nodes.package_typing import CustomNode, InputTypes, ValidatedNodeResult
 from comfy.utils import comfy_tqdm, seed_for_block, comfy_progress, ProgressBar
@@ -340,7 +340,7 @@ class TransformersGenerate(CustomNode):
         tokens = copy.copy(tokens)
         sampler = sampler or {}
         generate_kwargs = copy.copy(sampler)
-        load_model_gpu(model)
+        load_models_gpu([model])
         transformers_model: PreTrainedModel = model.model
         tokenizer: PreTrainedTokenizerBase | AutoTokenizer = model.tokenizer
         # remove unused inputs

View File

@@ -1,24 +1,135 @@
 import logging
+from typing import Optional, Any
 import torch
 from spandrel import ModelLoader, ImageModelDescriptor
 from comfy import model_management
 from comfy import utils
+from comfy.component_model.tensor_types import RGBImageBatch
 from comfy.model_downloader import get_filename_list_with_downloadable, KNOWN_UPSCALERS, get_or_download
+from comfy.model_management import load_models_gpu
+from comfy.model_management_types import ModelManageable
 
 try:
     from spandrel_extra_arches import EXTRA_REGISTRY  # pylint: disable=import-error
     from spandrel import MAIN_REGISTRY
     MAIN_REGISTRY.add(*EXTRA_REGISTRY)
     logging.info("Successfully imported spandrel_extra_arches: support for non commercial upscale models.")
 except:
     pass
+
+
+class UpscaleModelManageable(ModelManageable):
+    def __init__(self, model_descriptor: ImageModelDescriptor, ckpt_name: str):
+        self.ckpt_name = ckpt_name
+        self.model_descriptor = model_descriptor
+        self.model = model_descriptor.model
+        self.load_device = model_management.unet_offload_device()
+        self.offload_device = model_management.unet_offload_device()
+        self._current_device = self.offload_device
+        self._lowvram_patch_counter = 0
+        # Private properties for image sizes and channels
+        self._input_size = (1, 512, 512)  # Default input size (batch, height, width)
+        self._input_channels = model_descriptor.input_channels
+        self._output_channels = model_descriptor.output_channels
+        self.tile = 512
+
+    @property
+    def current_device(self) -> torch.device:
+        return self._current_device
+
+    @property
+    def input_size(self) -> tuple[int, int, int]:
+        return self._input_size
+
+    @input_size.setter
+    def input_size(self, size: tuple[int, int, int]):
+        self._input_size = size
+
+    @property
+    def output_size(self) -> tuple[int, int, int]:
+        return (self._input_size[0],
+                self._input_size[1] * self.model_descriptor.scale,
+                self._input_size[2] * self.model_descriptor.scale)
+
+    def set_input_size_from_images(self, images: RGBImageBatch):
+        if images.ndim != 4:
+            raise ValueError("Input must be a 4D tensor (batch, height, width, channels)")
+        if images.shape[-1] != 3:
+            raise ValueError("Input must have 3 channels (RGB)")
+        self._input_size = (images.shape[0], images.shape[1], images.shape[2])
+
+    def is_clone(self, other: Any) -> bool:
+        return isinstance(other, UpscaleModelManageable) and self.model is other.model
+
+    def clone_has_same_weights(self, clone: torch.nn.Module) -> bool:
+        return self.is_clone(clone)
+
+    def model_size(self) -> int:
+        # Calculate the size of the model parameters
+        model_params_size = sum(p.numel() * p.element_size() for p in self.model.parameters())
+        # Get the byte size of the model's dtype
+        dtype_size = torch.finfo(self.model_dtype()).bits // 8
+        # Calculate the memory required for input and output images
+        input_size = self._input_size[0] * min(self.tile, self._input_size[1]) * min(self.tile, self._input_size[2]) * self._input_channels * dtype_size
+        output_size = self.output_size[0] * self.output_size[1] * self.output_size[2] * self._output_channels * dtype_size
+        # Add some extra memory for processing
+        extra_memory = (input_size + output_size) * 2  # This is an estimate, adjust as needed
+        return model_params_size + input_size + output_size + extra_memory
+
+    def model_patches_to(self, arg: torch.device | torch.dtype):
+        if isinstance(arg, torch.device):
+            self.model.to(device=arg)
+        else:
+            self.model.to(dtype=arg)
+
+    def model_dtype(self) -> torch.dtype:
+        return next(self.model.parameters()).dtype
+
+    def patch_model_lowvram(self, device_to: torch.device, lowvram_model_memory: int, force_patch_weights: Optional[bool] = False) -> torch.nn.Module:
+        self.model.to(device=device_to)
+        self._current_device = device_to
+        self._lowvram_patch_counter += 1
+        return self.model
+
+    def patch_model(self, device_to: torch.device, patch_weights: bool) -> torch.nn.Module:
+        if patch_weights:
+            self.model.to(device=device_to)
+            self._current_device = device_to
+        return self.model
+
+    def unpatch_model(self, offload_device: torch.device, unpatch_weights: Optional[bool] = False) -> torch.nn.Module:
+        if unpatch_weights:
+            self.model.to(device=offload_device)
+            self._current_device = offload_device
+        return self.model
+
+    @property
+    def lowvram_patch_counter(self) -> int:
+        return self._lowvram_patch_counter
+
+    @lowvram_patch_counter.setter
+    def lowvram_patch_counter(self, value: int):
+        self._lowvram_patch_counter = value
+
+    def __str__(self):
+        if self.ckpt_name is not None:
+            return f"<UpscaleModelManageable for {self.ckpt_name} ({self.model.__class__.__name__})>"
+        else:
+            return f"<UpscaleModelManageable for {self.model.__class__.__name__}>"
 
 
 class UpscaleModelLoader:
     @classmethod
     def INPUT_TYPES(s):
-        return {"required": {"model_name": (get_filename_list_with_downloadable("upscale_models", KNOWN_UPSCALERS),),
+        return {"required": {"model_name": (get_filename_list_with_downloadable("upscale_models"),),
                              }}
 
     RETURN_TYPES = ("UPSCALE_MODEL",)
@@ -36,7 +147,7 @@ class UpscaleModelLoader:
         if not isinstance(out, ImageModelDescriptor):
             raise Exception("Upscale model must be a single-image model.")
-        return (out,)
+        return (UpscaleModelManageable(out, model_name),)
 
 
 class ImageUpscaleWithModel:
@@ -51,33 +162,30 @@ class ImageUpscaleWithModel:
     CATEGORY = "image/upscaling"
 
-    def upscale(self, upscale_model, image):
-        device = model_management.get_torch_device()
-
-        memory_required = model_management.module_size(upscale_model.model)
-        memory_required += (512 * 512 * 3) * image.element_size() * max(upscale_model.scale, 1.0) * 384.0  # The 384.0 is an estimate of how much some of these models take, TODO: make it more accurate
-        memory_required += image.nelement() * image.element_size()
-        model_management.free_memory(memory_required, device)
-
-        upscale_model.to(device)
-        in_img = image.movedim(-1, -3).to(device)
-
-        tile = 512
+    def upscale(self, upscale_model: UpscaleModelManageable, image: RGBImageBatch):
+        upscale_model.set_input_size_from_images(image)
+        load_models_gpu([upscale_model])
+        in_img = image.movedim(-1, -3).to(upscale_model.current_device, dtype=upscale_model.model_dtype())
+
+        tile = upscale_model.tile
         overlap = 32
 
         oom = True
+        s = None
         while oom:
             try:
                 steps = in_img.shape[0] * utils.get_tiled_scale_steps(in_img.shape[3], in_img.shape[2], tile_x=tile, tile_y=tile, overlap=overlap)
                 pbar = utils.ProgressBar(steps)
-                s = utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar)
+                s = utils.tiled_scale(in_img, lambda a: upscale_model.model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.model.scale, pbar=pbar)
                 oom = False
             except model_management.OOM_EXCEPTION as e:
                 tile //= 2
-                if tile < 128:
+                overlap //= 2
+                if tile < 64 or overlap < 4:
                     raise e
 
-        upscale_model.to("cpu")
+        # upscale_model.to("cpu")
         s = torch.clamp(s.movedim(-3, -1), min=0, max=1.0)
         return (s,)
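
Taken together, the loader now returns a managed wrapper rather than a bare spandrel descriptor, the explicit .to(device)/.to("cpu") moves are gone, and load_models_gpu decides when the upscaler occupies the GPU using the model_size() estimate above (for a single 512x512 RGB fp32 input and a 4x model, that estimate budgets roughly 3 MiB of input tile, 48 MiB of output and about 102 MiB of headroom on top of the parameters). A usage sketch follows; the filename, the variable names, and the assumption that the loader's entry point is still called load_model are illustrative:

    # Hypothetical wiring of the two nodes outside of a graph, for illustration only.
    loader = UpscaleModelLoader()
    (upscale_model,) = loader.load_model("4x-UltraSharp.pth")   # returns an UpscaleModelManageable

    upscaler = ImageUpscaleWithModel()
    (upscaled,) = upscaler.upscale(upscale_model, image)        # image: float tensor, (batch, height, width, 3)

    # On CUDA out-of-memory the node now halves both tile and overlap (512/32 -> 256/16 -> ...)
    # and only gives up once tile < 64 or overlap < 4, instead of stopping at tile < 128.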

View File

@@ -61,3 +61,4 @@ PySoundFile
 networkx>=2.6.3
 joblib
 jaxtyping
+spandrel_extra_arches