diff --git a/comfy_extras/nodes/nodes_upscale_model.py b/comfy_extras/nodes/nodes_upscale_model.py
index 20f5b5eb5..3d157a2b0 100644
--- a/comfy_extras/nodes/nodes_upscale_model.py
+++ b/comfy_extras/nodes/nodes_upscale_model.py
@@ -73,11 +73,11 @@ class UpscaleModelManageable(ModelManageable):
     def model_size(self) -> int:
         model_params_size = sum(p.numel() * p.element_size() for p in self.model.parameters())
         dtype_size = torch.finfo(self.model_dtype()).bits // 8
-        input_size = self._input_size[0] * min(self.tile, self._input_size[1]) * min(self.tile, self._input_size[2]) * self._input_channels * dtype_size
-        output_size = self.output_size[0] * self.output_size[1] * self.output_size[2] * self._output_channels * dtype_size
-        extra_memory = (input_size + output_size) * 2  # This is an estimate, adjust as needed
+        batch_size = self._input_size[0]
+        input_size = batch_size * min(self.tile, self._input_size[1]) * min(self.tile, self._input_size[2]) * self._input_channels * dtype_size
+        output_size = batch_size * min(self.tile * self.scale, self.output_size[1]) * min(self.tile * self.scale, self.output_size[2]) * self._output_channels * dtype_size
 
-        return model_params_size + input_size + output_size + extra_memory
+        return model_params_size + input_size + output_size
 
     def model_patches_to(self, arg: torch.device | torch.dtype):
         if isinstance(arg, torch.device):
@@ -160,17 +160,20 @@ class ImageUpscaleWithModel:
                 pbar = utils.ProgressBar(steps)
                 s = utils.tiled_scale(in_img, lambda a: upscale_model.model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar)
                 oom = False
-            except RuntimeError as exc_info:
-                if "have 1 channels, but got 3 channels instead" in str(exc_info):
-                    # convert RGB to luminance (assuming sRGB)
-                    rgb_weights = torch.tensor([0.2126, 0.7152, 0.0722], device=in_img.device, dtype=in_img.dtype)
-                    in_img = (in_img * rgb_weights.view(1, 3, 1, 1)).sum(dim=1, keepdim=True)
-                    continue
             except model_management.OOM_EXCEPTION as e:
                 tile //= 2
                 overlap //= 2
                 if tile < 64 or overlap < 4:
                     raise e
+            except RuntimeError as exc_info:
+                if "have 1 channels, but got 3 channels instead" in str(exc_info):
+                    # convert RGB to luminance (assuming sRGB)
+
+                    rgb_weights = torch.tensor([0.2126, 0.7152, 0.0722], device=in_img.device, dtype=in_img.dtype)
+                    in_img = (in_img * rgb_weights.view(1, 3, 1, 1)).sum(dim=1, keepdim=True)
+                    continue
+                else:
+                    raise exc_info
 
         # upscale_model.to("cpu")
         s = torch.clamp(s.movedim(-3, -1), min=0, max=1.0)
@@ -178,6 +181,7 @@ class ImageUpscaleWithModel:
 
         if s.shape[-1] == 1:
             s = s.expand(-1, -1, -1, 3)
 
+        del in_img
         return (s,)
 
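
Notes (reviewer commentary, not part of the patch):

The revised model_size() in the first hunk budgets for the parameters plus one tile's worth of input and output activations, capping the output estimate at tile * scale per spatial dimension. As a worked example, assuming a 1x3x2048x2048 fp16 input with tile=512 and scale=4: the input tile costs 1 * 512 * 512 * 3 * 2 bytes ≈ 1.5 MiB and the output tile costs 1 * min(512*4, 8192) * min(512*4, 8192) * 3 * 2 bytes = 24 MiB, on top of the parameter bytes.

The second hunk reorders the handlers so the OOM tile-halving retry is tried first, and falls back to collapsing RGB input to single-channel luminance when the model reports it expects 1 channel. Below is a minimal standalone sketch of that conversion, using the same Rec. 709 luma coefficients as the patch; the function name and example shapes are illustrative, not from the patch:

    import torch

    def rgb_to_luminance(in_img: torch.Tensor) -> torch.Tensor:
        # in_img: (batch, 3, height, width) -> (batch, 1, height, width)
        rgb_weights = torch.tensor([0.2126, 0.7152, 0.0722],
                                   device=in_img.device, dtype=in_img.dtype)
        # broadcast the per-channel weights over H and W, then sum the channel axis
        return (in_img * rgb_weights.view(1, 3, 1, 1)).sum(dim=1, keepdim=True)

    img = torch.rand(2, 3, 64, 64)
    print(rgb_to_luminance(img).shape)  # torch.Size([2, 1, 64, 64])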