mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-28 07:10:15 +08:00
Support more upscale models
This commit is contained in:
parent
667b77149e
commit
f642c7cc26
@ -45,11 +45,17 @@ class UpscaleModelManageable(ModelManageable):
|
|||||||
def input_size(self, size: tuple[int, int, int]):
|
def input_size(self, size: tuple[int, int, int]):
|
||||||
self._input_size = size
|
self._input_size = size
|
||||||
|
|
||||||
|
@property
|
||||||
|
def scale(self) -> int:
|
||||||
|
if not hasattr(self.model_descriptor, "scale"):
|
||||||
|
return 1
|
||||||
|
return self.model_descriptor.scale
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def output_size(self) -> tuple[int, int, int]:
|
def output_size(self) -> tuple[int, int, int]:
|
||||||
return (self._input_size[0],
|
return (self._input_size[0],
|
||||||
self._input_size[1] * self.model_descriptor.scale,
|
self._input_size[1] * self.scale,
|
||||||
self._input_size[2] * self.model_descriptor.scale)
|
self._input_size[2] * self.scale)
|
||||||
|
|
||||||
def set_input_size_from_images(self, images: RGBImageBatch):
|
def set_input_size_from_images(self, images: RGBImageBatch):
|
||||||
if images.ndim != 4:
|
if images.ndim != 4:
|
||||||
@ -152,8 +158,14 @@ class ImageUpscaleWithModel:
|
|||||||
try:
|
try:
|
||||||
steps = in_img.shape[0] * utils.get_tiled_scale_steps(in_img.shape[3], in_img.shape[2], tile_x=tile, tile_y=tile, overlap=overlap)
|
steps = in_img.shape[0] * utils.get_tiled_scale_steps(in_img.shape[3], in_img.shape[2], tile_x=tile, tile_y=tile, overlap=overlap)
|
||||||
pbar = utils.ProgressBar(steps)
|
pbar = utils.ProgressBar(steps)
|
||||||
s = utils.tiled_scale(in_img, lambda a: upscale_model.model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.model.scale, pbar=pbar)
|
s = utils.tiled_scale(in_img, lambda a: upscale_model.model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar)
|
||||||
oom = False
|
oom = False
|
||||||
|
except RuntimeError as exc_info:
|
||||||
|
if "have 1 channels, but got 3 channels instead" in str(exc_info):
|
||||||
|
# convert RGB to luminance (assuming sRGB)
|
||||||
|
rgb_weights = torch.tensor([0.2126, 0.7152, 0.0722], device=in_img.device, dtype=in_img.dtype)
|
||||||
|
in_img = (in_img * rgb_weights.view(1, 3, 1, 1)).sum(dim=1, keepdim=True)
|
||||||
|
continue
|
||||||
except model_management.OOM_EXCEPTION as e:
|
except model_management.OOM_EXCEPTION as e:
|
||||||
tile //= 2
|
tile //= 2
|
||||||
overlap //= 2
|
overlap //= 2
|
||||||
@ -162,6 +174,10 @@ class ImageUpscaleWithModel:
|
|||||||
|
|
||||||
# upscale_model.to("cpu")
|
# upscale_model.to("cpu")
|
||||||
s = torch.clamp(s.movedim(-3, -1), min=0, max=1.0)
|
s = torch.clamp(s.movedim(-3, -1), min=0, max=1.0)
|
||||||
|
|
||||||
|
if s.shape[-1] == 1:
|
||||||
|
s = s.expand(-1, -1, -1, 3)
|
||||||
|
|
||||||
return (s,)
|
return (s,)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user