Fix upscale model interacting with B&W upscaling throwing exceptions in cases where the image is weird

This commit is contained in:
doctorpangloss 2024-09-27 12:06:41 -07:00
parent f642c7cc26
commit 53056ca76f

View File

@ -73,11 +73,11 @@ class UpscaleModelManageable(ModelManageable):
def model_size(self) -> int:
model_params_size = sum(p.numel() * p.element_size() for p in self.model.parameters())
dtype_size = torch.finfo(self.model_dtype()).bits // 8
input_size = self._input_size[0] * min(self.tile, self._input_size[1]) * min(self.tile, self._input_size[2]) * self._input_channels * dtype_size
output_size = self.output_size[0] * self.output_size[1] * self.output_size[2] * self._output_channels * dtype_size
extra_memory = (input_size + output_size) * 2 # This is an estimate, adjust as needed
batch_size = self._input_size[0]
input_size = batch_size * min(self.tile, self._input_size[1]) * min(self.tile, self._input_size[2]) * self._input_channels * dtype_size
output_size = batch_size * min(self.tile * self.scale, self.output_size[1]) * min(self.tile * self.scale, self.output_size[2]) * self._output_channels * dtype_size
return model_params_size + input_size + output_size + extra_memory
return model_params_size + input_size + output_size
def model_patches_to(self, arg: torch.device | torch.dtype):
if isinstance(arg, torch.device):
@ -160,17 +160,20 @@ class ImageUpscaleWithModel:
pbar = utils.ProgressBar(steps)
s = utils.tiled_scale(in_img, lambda a: upscale_model.model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar)
oom = False
except RuntimeError as exc_info:
if "have 1 channels, but got 3 channels instead" in str(exc_info):
# convert RGB to luminance (assuming sRGB)
rgb_weights = torch.tensor([0.2126, 0.7152, 0.0722], device=in_img.device, dtype=in_img.dtype)
in_img = (in_img * rgb_weights.view(1, 3, 1, 1)).sum(dim=1, keepdim=True)
continue
except model_management.OOM_EXCEPTION as e:
tile //= 2
overlap //= 2
if tile < 64 or overlap < 4:
raise e
except RuntimeError as exc_info:
if "have 1 channels, but got 3 channels instead" in str(exc_info):
# convert RGB to luminance (assuming sRGB)
rgb_weights = torch.tensor([0.2126, 0.7152, 0.0722], device=in_img.device, dtype=in_img.dtype)
in_img = (in_img * rgb_weights.view(1, 3, 1, 1)).sum(dim=1, keepdim=True)
continue
else:
raise exc_info
# upscale_model.to("cpu")
s = torch.clamp(s.movedim(-3, -1), min=0, max=1.0)
@ -178,6 +181,7 @@ class ImageUpscaleWithModel:
if s.shape[-1] == 1:
s = s.expand(-1, -1, -1, 3)
del in_img
return (s,)