Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-01-24 05:10:18 +08:00)

Improve performance and memory management of upscale models; improve messaging on models loaded and unloaded from the GPU

Commit ce5fe01768 (parent c6ce11b421)
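
The recurring API change in this commit is that single-model loads now go through the batched load_models_gpu entry point, while load_model_gpu survives only as a deprecated wrapper. A minimal call-site migration sketch, using the patcher attribute seen in the hunks below:

    # before
    model_management.load_model_gpu(self.patcher)

    # after: always pass a sequence, even for a single model
    from comfy.model_management import load_models_gpu
    load_models_gpu([self.patcher])
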
@@ -1,4 +1,5 @@
 from .component_model import files
+from .model_management import load_models_gpu
 from .utils import load_torch_file, transformers_convert, state_dict_prefix_replace
 import torch
 import json
@@ -57,7 +58,7 @@ class ClipVisionModel():
         return self.model.state_dict()

     def encode_image(self, image):
-        model_management.load_model_gpu(self.patcher)
+        load_models_gpu([self.patcher])
         pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size).float()
         out = self.model(pixel_values=pixel_values, intermediate_output=-2)

@@ -1,17 +1,7 @@
-from typing import Annotated
-
-from jaxtyping import Float, Shaped
+from jaxtyping import Float
 from torch import Tensor

-
-def channels_constraint(n: int):
-    def constraint(x: Tensor) -> bool:
-        return x.shape[-1] == n
-
-    return constraint
-
-
 ImageBatch = Float[Tensor, "batch height width channels"]
-RGBImageBatch = Annotated[ImageBatch, Shaped[channels_constraint(3)]] | Float[Tensor, "batch height width 3"]
-RGBAImageBatch = Annotated[ImageBatch, Shaped[channels_constraint(4)]] | Float[Tensor, "batch height width 4"]
+RGBImageBatch = Float[Tensor, "batch height width 3"]
+RGBAImageBatch = Float[Tensor, "batch height width 4"]
 RGBImage = Float[Tensor, "height width 3"]
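
The simplified aliases above are plain jaxtyping annotations rather than Annotated types carrying runtime channel constraints. A usage sketch for orientation; the brighten function is illustrative and not part of the commit:

    from jaxtyping import Float
    from torch import Tensor

    RGBImageBatch = Float[Tensor, "batch height width 3"]

    def brighten(images: RGBImageBatch, amount: float = 0.1) -> RGBImageBatch:
        # the annotation documents the expected shape; no runtime check is applied here
        return (images + amount).clamp(0.0, 1.0)
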
@@ -142,12 +142,13 @@ class ControlBase:


 class ControlNet(ControlBase):
-    def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None):
+    def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None, ckpt_name: str = None):
         super().__init__(device)
         self.control_model = control_model
         self.load_device = load_device
         if control_model is not None:
             self.control_model_wrapped = model_patcher.ModelPatcher(self.control_model, load_device=load_device, offload_device=model_management.unet_offload_device())
+            self.control_model_wrapped.ckpt_name = os.path.basename(ckpt_name)
         self.compression_ratio = compression_ratio
         self.global_average_pooling = global_average_pooling
         self.model_sampling_current = None
@@ -499,7 +500,7 @@ def load_controlnet(ckpt_path, model=None):
     if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling
         global_average_pooling = True

-    control = ControlNet(control_model, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
+    control = ControlNet(control_model, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype, ckpt_name=filename)
     return control

 class T2IAdapter(ControlBase):
@@ -2,6 +2,7 @@ from __future__ import annotations

 import copy
 import logging
+import pathlib
 import warnings
 from typing import Optional, Any, Callable, Union, List

@@ -175,7 +176,7 @@ class TransformersManagedModel(ModelManageable):
         if hasattr(self.processor, "to"):
             self.processor.to(device=self.load_device)

-        assert "<image>" in prompt, "You must specify a <image> token inside the prompt for it to be substituted correctly by a HuggingFace processor"
+        assert "<image>" in prompt.lower(), "You must specify a <image> token inside the prompt for it to be substituted correctly by a HuggingFace processor"
         batch_feature: BatchFeature = self.processor([prompt], images=images, padding=True, return_tensors="pt")
         if hasattr(self.processor, "to"):
             self.processor.to(device=self.offload_device)
@@ -188,3 +189,10 @@ class TransformersManagedModel(ModelManageable):
             "inputs": batch_feature["input_ids"],
             **batch_feature
         }
+
+    def __str__(self):
+        if self.repo_id is not None:
+            repo_id_as_path = pathlib.PurePath(self.repo_id)
+            return f"<TransformersManagedModel for {'/'.join(repo_id_as_path.parts[-2:])} ({self.model.__class__.__name__})>"
+        else:
+            return f"<TransformersManagedModel for {self.model.__class__.__name__}>"
@@ -29,7 +29,7 @@ _session = Session()
 _hf_fs = HfFileSystem()


-def get_filename_list_with_downloadable(folder_name: str, known_files: Optional[List[Any]] = None) -> List[str]:
+def get_filename_list_with_downloadable(folder_name: str, known_files: Optional[List[Downloadable] | KnownDownloadables] = None) -> List[str]:
     if known_files is None:
         known_files = _get_known_models_for_folder_name(folder_name)

@@ -6,7 +6,7 @@ import sys
 import warnings
 from enum import Enum
 from threading import RLock
-from typing import Literal, List
+from typing import Literal, List, Sequence

 import psutil
 import torch
@@ -15,6 +15,7 @@ from opentelemetry.trace import get_current_span
 from . import interruption
 from .cli_args import args
 from .cmd.main_pre import tracer
+from .component_model.deprecation import _deprecate_method
 from .model_management_types import ModelManageable

 model_management_lock = RLock()
@@ -439,7 +440,7 @@ def free_memory(memory_required, device, keep_loaded=[]) -> List[LoadedModel]:


 @tracer.start_as_current_span("Load Models GPU")
-def load_models_gpu(models, memory_required=0, force_patch_weights=False):
+def load_models_gpu(models: Sequence[ModelManageable], memory_required: int = 0, force_patch_weights=False) -> None:
     global vram_state
     span = get_current_span()
     if memory_required != 0:
@@ -472,59 +473,61 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False):
             models_to_load.append(loaded_model)

     models_freed: List[LoadedModel] = []
-    if len(models_to_load) == 0:
-        devs = set(map(lambda a: a.device, models_already_loaded))
-        for d in devs:
-            if d != torch.device("cpu"):
-                models_freed += free_memory(extra_mem, d, models_already_loaded)
-        return
-
-    total_memory_required = {}
-    for loaded_model in models_to_load:
-        if unload_model_clones(loaded_model.model, unload_weights_only=True, force_unload=False): # unload clones where the weights are different
-            total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
-
-    for device in total_memory_required:
-        if device != torch.device("cpu"):
-            models_freed += free_memory(total_memory_required[device] * 1.3 + extra_mem, device, models_already_loaded)
-
-    span.set_attribute("models_to_load", list(map(str, models_to_load)))
-    span.set_attribute("models_freed", list(map(str, models_freed)))
-
-    logging.info(f"Models loaded: {','.join(map(str, models_to_load))}, models freed: {','.join(map(str, models_freed))}")
-
-    for loaded_model in models_to_load:
-        weights_unloaded = unload_model_clones(loaded_model.model, unload_weights_only=False, force_unload=False) # unload the rest of the clones where the weights can stay loaded
-        if weights_unloaded is not None:
-            loaded_model.weights_loaded = not weights_unloaded
-
-    for loaded_model in models_to_load:
-        model = loaded_model.model
-        torch_dev = model.load_device
-        if is_device_cpu(torch_dev):
-            vram_set_state = VRAMState.DISABLED
-        else:
-            vram_set_state = vram_state
-        lowvram_model_memory = 0
-        if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
-            model_size = loaded_model.model_memory_required(torch_dev)
-            current_free_mem = get_free_memory(torch_dev)
-            lowvram_model_memory = int(max(64 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.3))
-            if model_size <= (current_free_mem - inference_memory): # only switch to lowvram if really necessary
-                lowvram_model_memory = 0
-
-        if vram_set_state == VRAMState.NO_VRAM:
-            lowvram_model_memory = 64 * 1024 * 1024
-
-        loaded_model.model_load(lowvram_model_memory, force_patch_weights=force_patch_weights)
-        current_loaded_models.insert(0, loaded_model)
-    return
+    try:
+        if len(models_to_load) == 0:
+            devs = set(map(lambda a: a.device, models_already_loaded))
+            for d in devs:
+                if d != torch.device("cpu"):
+                    models_freed += free_memory(extra_mem, d, models_already_loaded)
+            return
+
+        total_memory_required = {}
+        for loaded_model in models_to_load:
+            if unload_model_clones(loaded_model.model, unload_weights_only=True, force_unload=False): # unload clones where the weights are different
+                total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
+
+        for device in total_memory_required:
+            if device != torch.device("cpu"):
+                # todo: where does 1.3 come from?
+                models_freed += free_memory(total_memory_required[device] * 1.3 + extra_mem, device, models_already_loaded)
+
+        for loaded_model in models_to_load:
+            weights_unloaded = unload_model_clones(loaded_model.model, unload_weights_only=False, force_unload=False) # unload the rest of the clones where the weights can stay loaded
+            if weights_unloaded is not None:
+                loaded_model.weights_loaded = not weights_unloaded
+
+        for loaded_model in models_to_load:
+            model = loaded_model.model
+            torch_dev = model.load_device
+            if is_device_cpu(torch_dev):
+                vram_set_state = VRAMState.DISABLED
+            else:
+                vram_set_state = vram_state
+            lowvram_model_memory = 0
+            if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
+                model_size = loaded_model.model_memory_required(torch_dev)
+                current_free_mem = get_free_memory(torch_dev)
+                lowvram_model_memory = int(max(64 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.3))
+                if model_size <= (current_free_mem - inference_memory): # only switch to lowvram if really necessary
+                    lowvram_model_memory = 0
+
+            if vram_set_state == VRAMState.NO_VRAM:
+                lowvram_model_memory = 64 * 1024 * 1024
+
+            loaded_model.model_load(lowvram_model_memory, force_patch_weights=force_patch_weights)
+            current_loaded_models.insert(0, loaded_model)
+        return
+    finally:
+        span.set_attribute("models", list(map(str, models)))
+        span.set_attribute("models_to_load", list(map(str, models_to_load)))
+        span.set_attribute("models_freed", list(map(str, models_freed)))
+        logging.info(f"Requested to load {','.join(map(str, models))}, models loaded: {','.join(map(str, models_to_load))}, models freed: {','.join(map(str, models_freed))}")


+@_deprecate_method(message="Use load_models_gpu instead", version="0.0.2")
 def load_model_gpu(model):
-    with model_management_lock:
-        return load_models_gpu([model])
+    return load_models_gpu([model])


 def loaded_models(only_currently_used=False):
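
Note that the span attributes and the summary log line now live in a finally block, so the "Requested to load ..." message is emitted even when loading returns early or raises. A minimal sketch of that pattern in isolation, using illustrative names rather than the real internals:

    import logging

    def load_all(models):
        loaded, freed = [], []
        try:
            ...  # free memory, then load each model, appending to loaded/freed
        finally:
            # runs on success, early return, or exception
            logging.info(f"Requested to load {models}, models loaded: {loaded}, models freed: {freed}")
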
@@ -9,11 +9,12 @@ class ModelManageable(Protocol):
     """
     Objects which implement this protocol can be managed by

-    >>> import comfy.model_management
-    >>> class SomeObj("ModelManageable"):
+    >>> from comfy.model_management import load_models_gpu
+    >>> class ModelWrapper(ModelManageable):
     >>>     ...
     >>>
-    >>> comfy.model_management.load_model_gpu(SomeObj())
+    >>> some_model = ModelWrapper()
+    >>> load_models_gpu([some_model])
     """
     load_device: torch.device
     offload_device: torch.device
@@ -128,7 +128,9 @@ class CustomNode(Protocol):
     CATEGORY: ClassVar[str]
     OUTPUT_NODE: Optional[ClassVar[bool]]

-    IS_CHANGED: Optional[ClassVar[IsChangedMethod]]
+    @classmethod
+    def IS_CHANGED(cls, *args, **kwargs) -> str:
+        ...

     @classmethod
     def __call__(cls, *args, **kwargs) -> 'CustomNode':
@@ -20,6 +20,7 @@ from . import model_sampling
 from . import sd1_clip
 from . import sdxl_clip
 from . import utils
+from .model_management import load_models_gpu

 from .text_encoders import sd2_clip
 from .text_encoders import sd3_clip
@@ -153,7 +154,7 @@ class CLIP:
         return sd_clip

     def load_model(self):
-        model_management.load_model_gpu(self.patcher)
+        load_models_gpu([self.patcher])
         return self.patcher

     def get_key_patches(self):
@@ -337,7 +338,7 @@ class VAE:
         return pixel_samples

     def decode_tiled(self, samples, tile_x=64, tile_y=64, overlap=16):
-        model_management.load_model_gpu(self.patcher)
+        load_models_gpu([self.patcher])
         output = self.decode_tiled_(samples, tile_x, tile_y, overlap)
         return output.movedim(1, -1)

@@ -366,7 +367,7 @@ class VAE:

     def encode_tiled(self, pixel_samples, tile_x=512, tile_y=512, overlap=64):
         pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
-        model_management.load_model_gpu(self.patcher)
+        load_models_gpu([self.patcher])
         pixel_samples = pixel_samples.movedim(-1, 1)
         samples = self.encode_tiled_(pixel_samples, tile_x=tile_x, tile_y=tile_y, overlap=overlap)
         return samples
@@ -574,7 +575,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
     if output_model:
         _model_patcher = model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device(), current_device=inital_load_device, ckpt_name=os.path.basename(ckpt_path))
         if inital_load_device != torch.device("cpu"):
-            model_management.load_model_gpu(_model_patcher)
+            load_models_gpu([_model_patcher])

     return (_model_patcher, clip, vae, clipvision)

@@ -18,7 +18,7 @@ from comfy.language.chat_templates import KNOWN_CHAT_TEMPLATES
 from comfy.language.language_types import ProcessorResult
 from comfy.language.transformers_model_management import TransformersManagedModel
 from comfy.model_downloader import get_huggingface_repo_list, get_or_download_huggingface_repo
-from comfy.model_management import get_torch_device_name, load_model_gpu, unet_dtype, unet_offload_device
+from comfy.model_management import get_torch_device_name, unet_dtype, unet_offload_device, load_models_gpu
 from comfy.nodes.package_typing import CustomNode, InputTypes, ValidatedNodeResult
 from comfy.utils import comfy_tqdm, seed_for_block, comfy_progress, ProgressBar
@@ -340,7 +340,7 @@ class TransformersGenerate(CustomNode):
         tokens = copy.copy(tokens)
         sampler = sampler or {}
         generate_kwargs = copy.copy(sampler)
-        load_model_gpu(model)
+        load_models_gpu([model])
         transformers_model: PreTrainedModel = model.model
         tokenizer: PreTrainedTokenizerBase | AutoTokenizer = model.tokenizer
         # remove unused inputs
@@ -1,24 +1,135 @@
 import logging
-import torch
+from typing import Optional, Any
+
+import torch
 from spandrel import ModelLoader, ImageModelDescriptor

 from comfy import model_management
 from comfy import utils
+from comfy.component_model.tensor_types import RGBImageBatch
 from comfy.model_downloader import get_filename_list_with_downloadable, KNOWN_UPSCALERS, get_or_download
+from comfy.model_management import load_models_gpu
+from comfy.model_management_types import ModelManageable

 try:
     from spandrel_extra_arches import EXTRA_REGISTRY # pylint: disable=import-error
     from spandrel import MAIN_REGISTRY

     MAIN_REGISTRY.add(*EXTRA_REGISTRY)
     logging.info("Successfully imported spandrel_extra_arches: support for non commercial upscale models.")
 except:
     pass


+class UpscaleModelManageable(ModelManageable):
+    def __init__(self, model_descriptor: ImageModelDescriptor, ckpt_name: str):
+        self.ckpt_name = ckpt_name
+        self.model_descriptor = model_descriptor
+        self.model = model_descriptor.model
+        self.load_device = model_management.unet_offload_device()
+        self.offload_device = model_management.unet_offload_device()
+        self._current_device = self.offload_device
+        self._lowvram_patch_counter = 0
+
+        # Private properties for image sizes and channels
+        self._input_size = (1, 512, 512) # Default input size (batch, height, width)
+        self._input_channels = model_descriptor.input_channels
+        self._output_channels = model_descriptor.output_channels
+        self.tile = 512
+
+    @property
+    def current_device(self) -> torch.device:
+        return self._current_device
+
+    @property
+    def input_size(self) -> tuple[int, int, int]:
+        return self._input_size
+
+    @input_size.setter
+    def input_size(self, size: tuple[int, int, int]):
+        self._input_size = size
+
+    @property
+    def output_size(self) -> tuple[int, int, int]:
+        return (self._input_size[0],
+                self._input_size[1] * self.model_descriptor.scale,
+                self._input_size[2] * self.model_descriptor.scale)
+
+    def set_input_size_from_images(self, images: RGBImageBatch):
+        if images.ndim != 4:
+            raise ValueError("Input must be a 4D tensor (batch, height, width, channels)")
+        if images.shape[-1] != 3:
+            raise ValueError("Input must have 3 channels (RGB)")
+        self._input_size = (images.shape[0], images.shape[1], images.shape[2])
+
+    def is_clone(self, other: Any) -> bool:
+        return isinstance(other, UpscaleModelManageable) and self.model is other.model
+
+    def clone_has_same_weights(self, clone: torch.nn.Module) -> bool:
+        return self.is_clone(clone)
+
+    def model_size(self) -> int:
+        # Calculate the size of the model parameters
+        model_params_size = sum(p.numel() * p.element_size() for p in self.model.parameters())
+
+        # Get the byte size of the model's dtype
+        dtype_size = torch.finfo(self.model_dtype()).bits // 8
+
+        # Calculate the memory required for input and output images
+        input_size = self._input_size[0] * min(self.tile, self._input_size[1]) * min(self.tile, self._input_size[2]) * self._input_channels * dtype_size
+        output_size = self.output_size[0] * self.output_size[1] * self.output_size[2] * self._output_channels * dtype_size
+
+        # Add some extra memory for processing
+        extra_memory = (input_size + output_size) * 2 # This is an estimate, adjust as needed
+
+        return model_params_size + input_size + output_size + extra_memory
+
+    def model_patches_to(self, arg: torch.device | torch.dtype):
+        if isinstance(arg, torch.device):
+            self.model.to(device=arg)
+        else:
+            self.model.to(dtype=arg)
+
+    def model_dtype(self) -> torch.dtype:
+        return next(self.model.parameters()).dtype
+
+    def patch_model_lowvram(self, device_to: torch.device, lowvram_model_memory: int, force_patch_weights: Optional[bool] = False) -> torch.nn.Module:
+        self.model.to(device=device_to)
+        self._current_device = device_to
+        self._lowvram_patch_counter += 1
+        return self.model
+
+    def patch_model(self, device_to: torch.device, patch_weights: bool) -> torch.nn.Module:
+        if patch_weights:
+            self.model.to(device=device_to)
+            self._current_device = device_to
+        return self.model
+
+    def unpatch_model(self, offload_device: torch.device, unpatch_weights: Optional[bool] = False) -> torch.nn.Module:
+        if unpatch_weights:
+            self.model.to(device=offload_device)
+            self._current_device = offload_device
+        return self.model
+
+    @property
+    def lowvram_patch_counter(self) -> int:
+        return self._lowvram_patch_counter
+
+    @lowvram_patch_counter.setter
+    def lowvram_patch_counter(self, value: int):
+        self._lowvram_patch_counter = value
+
+    def __str__(self):
+        if self.ckpt_name is not None:
+            return f"<UpscaleModelManageable for {self.ckpt_name} ({self.model.__class__.__name__})>"
+        else:
+            return f"<UpscaleModelManageable for {self.model.__class__.__name__}>"
+
+
 class UpscaleModelLoader:
     @classmethod
     def INPUT_TYPES(s):
-        return {"required": {"model_name": (get_filename_list_with_downloadable("upscale_models", KNOWN_UPSCALERS),),
+        return {"required": {"model_name": (get_filename_list_with_downloadable("upscale_models"),),
                              }}

     RETURN_TYPES = ("UPSCALE_MODEL",)
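
For a sense of scale, model_size() adds the parameter bytes to the (tiled) input tensor, the full output tensor, and a 2x headroom term. A worked example with hypothetical numbers only (a 4x model with roughly 67 million fp32 parameters and a single 512x512 RGB image):

    dtype_size = 4                                       # fp32
    model_params_size = 67_000_000 * dtype_size          # ~268 MB of weights
    input_size = 1 * 512 * 512 * 3 * dtype_size          # ~3.1 MB input tile
    output_size = 1 * 2048 * 2048 * 3 * dtype_size       # ~50.3 MB upscaled output
    extra_memory = (input_size + output_size) * 2        # processing headroom
    total = model_params_size + input_size + output_size + extra_memory  # ~428 MB
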
@@ -36,7 +147,7 @@ class UpscaleModelLoader:
         if not isinstance(out, ImageModelDescriptor):
             raise Exception("Upscale model must be a single-image model.")

-        return (out,)
+        return (UpscaleModelManageable(out, model_name),)


 class ImageUpscaleWithModel:
@@ -51,33 +162,30 @@ class ImageUpscaleWithModel:

     CATEGORY = "image/upscaling"

-    def upscale(self, upscale_model, image):
-        device = model_management.get_torch_device()
-
-        memory_required = model_management.module_size(upscale_model.model)
-        memory_required += (512 * 512 * 3) * image.element_size() * max(upscale_model.scale, 1.0) * 384.0 # The 384.0 is an estimate of how much some of these models take, TODO: make it more accurate
-        memory_required += image.nelement() * image.element_size()
-        model_management.free_memory(memory_required, device)
-
-        upscale_model.to(device)
-        in_img = image.movedim(-1, -3).to(device)
-
-        tile = 512
+    def upscale(self, upscale_model: UpscaleModelManageable, image: RGBImageBatch):
+        upscale_model.set_input_size_from_images(image)
+        load_models_gpu([upscale_model])
+
+        in_img = image.movedim(-1, -3).to(upscale_model.current_device, dtype=upscale_model.model_dtype())
+
+        tile = upscale_model.tile
         overlap = 32

         oom = True
+        s = None
         while oom:
             try:
                 steps = in_img.shape[0] * utils.get_tiled_scale_steps(in_img.shape[3], in_img.shape[2], tile_x=tile, tile_y=tile, overlap=overlap)
                 pbar = utils.ProgressBar(steps)
-                s = utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar)
+                s = utils.tiled_scale(in_img, lambda a: upscale_model.model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.model.scale, pbar=pbar)
                 oom = False
             except model_management.OOM_EXCEPTION as e:
                 tile //= 2
-                if tile < 128:
+                overlap //= 2
+                if tile < 64 or overlap < 4:
                     raise e

-        upscale_model.to("cpu")
+        # upscale_model.to("cpu")
         s = torch.clamp(s.movedim(-3, -1), min=0, max=1.0)
         return (s,)
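
On an out-of-memory error the node now halves both the tile size and the overlap before retrying, rather than only the tile, so the attempted configurations run 512/32, 256/16, 128/8, 64/4 before the exception is re-raised. An illustrative trace of that progression, separate from the node itself:

    tile, overlap = 512, 32
    while True:
        print(tile, overlap)      # 512 32, 256 16, 128 8, 64 4
        tile //= 2
        overlap //= 2
        if tile < 64 or overlap < 4:
            break
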
@@ -61,3 +61,4 @@ PySoundFile
 networkx>=2.6.3
 joblib
 jaxtyping
+spandrel_extra_arches