Support more upscale models

This commit is contained in:
doctorpangloss 2024-09-26 18:08:43 -07:00
parent 667b77149e
commit f642c7cc26

View File

@ -45,11 +45,17 @@ class UpscaleModelManageable(ModelManageable):
def input_size(self, size: tuple[int, int, int]): def input_size(self, size: tuple[int, int, int]):
self._input_size = size self._input_size = size
@property
def scale(self) -> int:
if not hasattr(self.model_descriptor, "scale"):
return 1
return self.model_descriptor.scale
@property @property
def output_size(self) -> tuple[int, int, int]: def output_size(self) -> tuple[int, int, int]:
return (self._input_size[0], return (self._input_size[0],
self._input_size[1] * self.model_descriptor.scale, self._input_size[1] * self.scale,
self._input_size[2] * self.model_descriptor.scale) self._input_size[2] * self.scale)
def set_input_size_from_images(self, images: RGBImageBatch): def set_input_size_from_images(self, images: RGBImageBatch):
if images.ndim != 4: if images.ndim != 4:
@ -152,8 +158,14 @@ class ImageUpscaleWithModel:
try: try:
steps = in_img.shape[0] * utils.get_tiled_scale_steps(in_img.shape[3], in_img.shape[2], tile_x=tile, tile_y=tile, overlap=overlap) steps = in_img.shape[0] * utils.get_tiled_scale_steps(in_img.shape[3], in_img.shape[2], tile_x=tile, tile_y=tile, overlap=overlap)
pbar = utils.ProgressBar(steps) pbar = utils.ProgressBar(steps)
s = utils.tiled_scale(in_img, lambda a: upscale_model.model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.model.scale, pbar=pbar) s = utils.tiled_scale(in_img, lambda a: upscale_model.model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar)
oom = False oom = False
except RuntimeError as exc_info:
if "have 1 channels, but got 3 channels instead" in str(exc_info):
# convert RGB to luminance (assuming sRGB)
rgb_weights = torch.tensor([0.2126, 0.7152, 0.0722], device=in_img.device, dtype=in_img.dtype)
in_img = (in_img * rgb_weights.view(1, 3, 1, 1)).sum(dim=1, keepdim=True)
continue
except model_management.OOM_EXCEPTION as e: except model_management.OOM_EXCEPTION as e:
tile //= 2 tile //= 2
overlap //= 2 overlap //= 2
@ -162,6 +174,10 @@ class ImageUpscaleWithModel:
# upscale_model.to("cpu") # upscale_model.to("cpu")
s = torch.clamp(s.movedim(-3, -1), min=0, max=1.0) s = torch.clamp(s.movedim(-3, -1), min=0, max=1.0)
if s.shape[-1] == 1:
s = s.expand(-1, -1, -1, 3)
return (s,) return (s,)