Support more upscale models

This commit is contained in:
doctorpangloss 2024-09-26 18:08:43 -07:00
parent 667b77149e
commit f642c7cc26

View File

@ -45,11 +45,17 @@ class UpscaleModelManageable(ModelManageable):
def input_size(self, size: tuple[int, int, int]):
self._input_size = size
@property
def scale(self) -> int:
if not hasattr(self.model_descriptor, "scale"):
return 1
return self.model_descriptor.scale
@property
def output_size(self) -> tuple[int, int, int]:
return (self._input_size[0],
self._input_size[1] * self.model_descriptor.scale,
self._input_size[2] * self.model_descriptor.scale)
self._input_size[1] * self.scale,
self._input_size[2] * self.scale)
def set_input_size_from_images(self, images: RGBImageBatch):
if images.ndim != 4:
@ -152,8 +158,14 @@ class ImageUpscaleWithModel:
try:
steps = in_img.shape[0] * utils.get_tiled_scale_steps(in_img.shape[3], in_img.shape[2], tile_x=tile, tile_y=tile, overlap=overlap)
pbar = utils.ProgressBar(steps)
s = utils.tiled_scale(in_img, lambda a: upscale_model.model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.model.scale, pbar=pbar)
s = utils.tiled_scale(in_img, lambda a: upscale_model.model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar)
oom = False
except RuntimeError as exc_info:
if "have 1 channels, but got 3 channels instead" in str(exc_info):
# convert RGB to luminance (assuming sRGB)
rgb_weights = torch.tensor([0.2126, 0.7152, 0.0722], device=in_img.device, dtype=in_img.dtype)
in_img = (in_img * rgb_weights.view(1, 3, 1, 1)).sum(dim=1, keepdim=True)
continue
except model_management.OOM_EXCEPTION as e:
tile //= 2
overlap //= 2
@ -162,6 +174,10 @@ class ImageUpscaleWithModel:
# upscale_model.to("cpu")
s = torch.clamp(s.movedim(-3, -1), min=0, max=1.0)
if s.shape[-1] == 1:
s = s.expand(-1, -1, -1, 3)
return (s,)