Mirror of https://github.com/comfyanonymous/ComfyUI.git, synced 2026-01-10 14:20:49 +08:00
- Experimental support for sage attention on Linux
- Diffusers loader now supports model indices
- Transformers model management now aligns with updates to ComfyUI
- Flux layers correctly use unbind
- Add float8 support for model loading in more places
- Experimental quantization approaches from Quanto and torchao
- Model upscaling interacts with memory management better

This update also disables ROCm testing because it isn't reliable enough on consumer hardware. ROCm does not really support the 7600.
188 lines
7.0 KiB
Python
import logging
from typing import Optional, Any

import torch
from spandrel import ModelLoader, ImageModelDescriptor

from comfy import model_management
from comfy import utils
from comfy.component_model.tensor_types import RGBImageBatch
from comfy.model_downloader import get_filename_list_with_downloadable, KNOWN_UPSCALERS, get_or_download
from comfy.model_management import load_models_gpu
from comfy.model_management_types import ModelManageable

try:
    from spandrel_extra_arches import EXTRA_REGISTRY  # pylint: disable=import-error
    from spandrel import MAIN_REGISTRY

    MAIN_REGISTRY.add(*EXTRA_REGISTRY)
    logging.debug("Successfully imported spandrel_extra_arches: support for non-commercial upscale models.")
except:
    pass


class UpscaleModelManageable(ModelManageable):
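    """Wraps a spandrel ``ImageModelDescriptor`` so the upscale network can be handled by
    ComfyUI's model management: it reports an estimated memory footprint, exposes its load
    and offload devices, and moves weights between them on demand."""
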
    def __init__(self, model_descriptor: ImageModelDescriptor, ckpt_name: str):
        self.ckpt_name = ckpt_name
        self.model_descriptor = model_descriptor
        self.model = model_descriptor.model
        self.load_device = model_management.get_torch_device()
        self.offload_device = model_management.unet_offload_device()
        self._input_size = (1, 512, 512)
        self._input_channels = model_descriptor.input_channels
        self._output_channels = model_descriptor.output_channels
        self.tile = 512

    @property
    def current_device(self) -> torch.device:
        return self.model_descriptor.device

    @property
    def input_size(self) -> tuple[int, int, int]:
        return self._input_size

    @input_size.setter
    def input_size(self, size: tuple[int, int, int]):
        self._input_size = size

    @property
    def scale(self) -> int:
        if not hasattr(self.model_descriptor, "scale"):
            return 1
        return self.model_descriptor.scale

    @property
    def output_size(self) -> tuple[int, int, int]:
        return (self._input_size[0],
                self._input_size[1] * self.scale,
                self._input_size[2] * self.scale)

    def set_input_size_from_images(self, images: RGBImageBatch):
        if images.ndim != 4:
            raise ValueError("Input must be a 4D tensor (batch, height, width, channels)")
        if images.shape[-1] != 3:
            raise ValueError("Input must have 3 channels (RGB)")
        self._input_size = (images.shape[0], images.shape[1], images.shape[2])

    def is_clone(self, other: Any) -> bool:
        return isinstance(other, UpscaleModelManageable) and self.model is other.model

    def clone_has_same_weights(self, clone) -> bool:
        return self.is_clone(clone)

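    # model_size() returns a rough memory estimate for this model: the parameter bytes plus
    # one input tile and one output tile of activations at the working dtype, which the
    # model-management code can use when budgeting GPU memory.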
    def model_size(self) -> int:
        model_params_size = sum(p.numel() * p.element_size() for p in self.model.parameters())
        dtype_size = torch.finfo(self.model_dtype()).bits // 8
        batch_size = self._input_size[0]
        input_size = batch_size * min(self.tile, self._input_size[1]) * min(self.tile, self._input_size[2]) * self._input_channels * dtype_size
        output_size = batch_size * min(self.tile * self.scale, self.output_size[1]) * min(self.tile * self.scale, self.output_size[2]) * self._output_channels * dtype_size

        return model_params_size + input_size + output_size

    def model_patches_to(self, arg: torch.device | torch.dtype):
        if isinstance(arg, torch.device):
            self.model.to(device=arg)
        else:
            self.model.to(dtype=arg)

    def model_dtype(self) -> torch.dtype:
        return next(self.model.parameters()).dtype

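    # patch_model/unpatch_model are the hooks model management calls when loading or
    # offloading; an upscaler has no patches to apply, so both simply move the weights
    # to the requested device.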
    def patch_model(self, device_to: torch.device | None = None, lowvram_model_memory: int = 0, load_weights: bool = True, force_patch_weights: bool = False) -> torch.nn.Module:
        self.model.to(device=device_to)
        return self.model

    def unpatch_model(self, device_to: torch.device | None = None, unpatch_weights: Optional[bool] = False) -> torch.nn.Module:
        self.model.to(device=device_to)
        return self.model

    def __str__(self):
        if self.ckpt_name is not None:
            return f"<UpscaleModelManageable for {self.ckpt_name} ({self.model.__class__.__name__})>"
        else:
            return f"<UpscaleModelManageable for {self.model.__class__.__name__}>"


class UpscaleModelLoader:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"model_name": (get_filename_list_with_downloadable("upscale_models"),),
                             }}

    RETURN_TYPES = ("UPSCALE_MODEL",)
    FUNCTION = "load_model"

    CATEGORY = "loaders"

    def load_model(self, model_name):
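        """Resolve the checkpoint path (downloading a known upscaler if needed), strip a
        leading "module." prefix left by DataParallel-style checkpoints, and wrap the loaded
        spandrel descriptor in an UpscaleModelManageable."""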
        model_path = get_or_download("upscale_models", model_name, KNOWN_UPSCALERS)
        sd = utils.load_torch_file(model_path, safe_load=True)
        if "module.layers.0.residual_group.blocks.0.norm1.weight" in sd:
            sd = utils.state_dict_prefix_replace(sd, {"module.": ""})
        out = ModelLoader().load_from_state_dict(sd).eval()

        if not isinstance(out, ImageModelDescriptor):
            raise Exception("Upscale model must be a single-image model.")

        return (UpscaleModelManageable(out, model_name),)


class ImageUpscaleWithModel:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"upscale_model": ("UPSCALE_MODEL",),
                             "image": ("IMAGE",),
                             }}

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "upscale"

    CATEGORY = "image/upscaling"

    def upscale(self, upscale_model: UpscaleModelManageable, image: RGBImageBatch):
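        """Run the upscale model over ``image`` using tiled inference, shrinking the tile
        size on out-of-memory errors, and return a (batch, height, width, 3) result with
        values clamped to [0, 1]."""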
        upscale_model.set_input_size_from_images(image)
        load_models_gpu([upscale_model])

        in_img = image.movedim(-1, -3).to(upscale_model.current_device, dtype=upscale_model.model_dtype())

        tile = upscale_model.tile
        overlap = 32

        oom = True
        s = None
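        # Retry loop: if tiled inference runs out of VRAM, halve the tile size and overlap and
        # try again, giving up once the tile drops below 64px or the overlap below 4px. A
        # channel-mismatch RuntimeError means the model expects single-channel input, so the
        # RGB batch is converted to luminance and the loop restarts.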
        while oom:
            try:
                steps = in_img.shape[0] * utils.get_tiled_scale_steps(in_img.shape[3], in_img.shape[2], tile_x=tile, tile_y=tile, overlap=overlap)
                pbar = utils.ProgressBar(steps)
                s = utils.tiled_scale(in_img, lambda a: upscale_model.model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar)
                oom = False
            except model_management.OOM_EXCEPTION as e:
                tile //= 2
                overlap //= 2
                if tile < 64 or overlap < 4:
                    raise e
            except RuntimeError as exc_info:
                if "have 1 channels, but got 3 channels instead" in str(exc_info):
                    # convert RGB to luminance (assuming sRGB)
                    rgb_weights = torch.tensor([0.2126, 0.7152, 0.0722], device=in_img.device, dtype=in_img.dtype)
                    in_img = (in_img * rgb_weights.view(1, 3, 1, 1)).sum(dim=1, keepdim=True)
                    continue
                else:
                    raise exc_info

        # upscale_model.to("cpu")
        s = torch.clamp(s.movedim(-3, -1), min=0, max=1.0)

        if s.shape[-1] == 1:
            s = s.expand(-1, -1, -1, 3)

        del in_img
        return (s,)


NODE_CLASS_MAPPINGS = {
    "UpscaleModelLoader": UpscaleModelLoader,
    "ImageUpscaleWithModel": ImageUpscaleWithModel
}
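
# Minimal usage sketch (illustration only): how the two nodes could be driven directly from
# Python outside a graph. "my_upscaler.pth" is a placeholder filename assumed to exist in, or
# be downloadable to, the upscale_models folder; `images` is assumed to be a
# (batch, height, width, 3) float tensor with values in [0, 1].
#
#   loader = UpscaleModelLoader()
#   (upscale_model,) = loader.load_model("my_upscaler.pth")
#
#   node = ImageUpscaleWithModel()
#   (upscaled,) = node.upscale(upscale_model, images)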