From f642c7cc260a443362a8457da82a16ed5a4b7848 Mon Sep 17 00:00:00 2001
From: doctorpangloss <@hiddenswitch.com>
Date: Thu, 26 Sep 2024 18:08:43 -0700
Subject: [PATCH] Support more upscale models

Fall back to a scale factor of 1 when the model descriptor does not
report one, retry with RGB converted to luminance when a model expects
single-channel input, and expand single-channel outputs back to RGB.

---
 comfy_extras/nodes/nodes_upscale_model.py | 23 ++++++++++++++++++++---
 1 file changed, 20 insertions(+), 3 deletions(-)

diff --git a/comfy_extras/nodes/nodes_upscale_model.py b/comfy_extras/nodes/nodes_upscale_model.py
index 0d29d2368..20f5b5eb5 100644
--- a/comfy_extras/nodes/nodes_upscale_model.py
+++ b/comfy_extras/nodes/nodes_upscale_model.py
@@ -45,11 +45,17 @@ class UpscaleModelManageable(ModelManageable):
     def input_size(self, size: tuple[int, int, int]):
         self._input_size = size
 
+    @property
+    def scale(self) -> int:
+        if not hasattr(self.model_descriptor, "scale"):
+            return 1
+        return self.model_descriptor.scale
+
     @property
     def output_size(self) -> tuple[int, int, int]:
         return (self._input_size[0],
-                self._input_size[1] * self.model_descriptor.scale,
-                self._input_size[2] * self.model_descriptor.scale)
+                self._input_size[1] * self.scale,
+                self._input_size[2] * self.scale)
 
     def set_input_size_from_images(self, images: RGBImageBatch):
         if images.ndim != 4:
@@ -152,8 +158,15 @@ class ImageUpscaleWithModel:
             try:
                 steps = in_img.shape[0] * utils.get_tiled_scale_steps(in_img.shape[3], in_img.shape[2], tile_x=tile, tile_y=tile, overlap=overlap)
                 pbar = utils.ProgressBar(steps)
-                s = utils.tiled_scale(in_img, lambda a: upscale_model.model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.model.scale, pbar=pbar)
+                s = utils.tiled_scale(in_img, lambda a: upscale_model.model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar)
                 oom = False
+            except RuntimeError as exc:
+                if "have 1 channels, but got 3 channels instead" not in str(exc):
+                    raise
+                # convert RGB to luminance (Rec. 709 weights, assuming sRGB) and retry
+                rgb_weights = torch.tensor([0.2126, 0.7152, 0.0722], device=in_img.device, dtype=in_img.dtype)
+                in_img = (in_img * rgb_weights.view(1, 3, 1, 1)).sum(dim=1, keepdim=True)
+                continue
             except model_management.OOM_EXCEPTION as e:
                 tile //= 2
                 overlap //= 2
@@ -162,6 +175,10 @@ class ImageUpscaleWithModel:
 
         # upscale_model.to("cpu")
 
         s = torch.clamp(s.movedim(-3, -1), min=0, max=1.0)
+
+        if s.shape[-1] == 1:
+            s = s.expand(-1, -1, -1, 3)
+
         return (s,)
 
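Review notes, not part of the patch:

The new scale property lets descriptors that do not report a scale factor behave as 1x instead of raising AttributeError in output_size. It is equivalent to a getattr with a default, sketched here with model_descriptor standing in for any descriptor object:

    scale = getattr(model_descriptor, "scale", 1)  # 1 when the attribute is absent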
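The new RuntimeError branch matches the channel-mismatch text that torch.nn.Conv2d raises when a single-channel (grayscale) model receives RGB input. A minimal reproduction of the message the handler keys off, assuming a recent PyTorch (the exact wording can vary between releases):

    import torch

    conv = torch.nn.Conv2d(in_channels=1, out_channels=8, kernel_size=3)  # grayscale-only first layer
    try:
        conv(torch.rand(1, 3, 64, 64))  # RGB batch in NCHW layout
    except RuntimeError as exc:
        # e.g. "Given groups=1, weight of size [8, 1, 3, 3], expected input[1, 3, 64, 64]
        # to have 1 channels, but got 3 channels instead"
        assert "have 1 channels, but got 3 channels instead" in str(exc)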
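The channel handling round-trips: RGB collapses to luminance with Rec. 709 weights before the retry, and a single-channel result is broadcast back to RGB after movedim(-3, -1). A standalone sketch of both steps on dummy tensors (the shapes are illustrative, not taken from the node):

    import torch

    in_img = torch.rand(2, 3, 64, 64)  # batch of RGB images, NCHW

    # RGB -> luminance, mirroring the patch (Rec. 709 / sRGB luma weights)
    rgb_weights = torch.tensor([0.2126, 0.7152, 0.0722], device=in_img.device, dtype=in_img.dtype)
    luma = (in_img * rgb_weights.view(1, 3, 1, 1)).sum(dim=1, keepdim=True)
    assert luma.shape == (2, 1, 64, 64)

    # single-channel output back to RGB after the NCHW -> NHWC movedim, mirroring the patch
    s = luma.movedim(-3, -1)           # (2, 64, 64, 1)
    if s.shape[-1] == 1:
        s = s.expand(-1, -1, -1, 3)    # (2, 64, 64, 3)
    assert s.shape == (2, 64, 64, 3)

Note that expand returns a broadcasted view rather than a copy; a downstream consumer that needs contiguous memory would call .contiguous() (or use repeat) first.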