Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-01-11 06:40:48 +08:00)
Support more upscale models

parent 667b77149e
commit f642c7cc26
@@ -45,11 +45,17 @@ class UpscaleModelManageable(ModelManageable):
     def input_size(self, size: tuple[int, int, int]):
         self._input_size = size
 
+    @property
+    def scale(self) -> int:
+        if not hasattr(self.model_descriptor, "scale"):
+            return 1
+        return self.model_descriptor.scale
+
     @property
     def output_size(self) -> tuple[int, int, int]:
         return (self._input_size[0],
-                self._input_size[1] * self.model_descriptor.scale,
-                self._input_size[2] * self.model_descriptor.scale)
+                self._input_size[1] * self.scale,
+                self._input_size[2] * self.scale)
 
     def set_input_size_from_images(self, images: RGBImageBatch):
         if images.ndim != 4:
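The new scale property falls back to 1x when the loaded model descriptor does not expose a scale attribute, and output_size now reads the factor through it instead of touching model_descriptor.scale directly; that fallback is what lets additional upscale model types load. A minimal sketch of the hasattr-based fallback, with illustrative class names that are not part of ComfyUI:

# Minimal sketch of the scale fallback (illustrative names only, not ComfyUI classes).
class FixedScaleDescriptor:
    scale = 4          # e.g. a 4x ESRGAN-style model

class NoScaleDescriptor:
    pass               # a model that does not report a scale factor

def effective_scale(descriptor) -> int:
    # Same logic as the new property: default to 1x when no scale is reported.
    if not hasattr(descriptor, "scale"):
        return 1
    return descriptor.scale

assert effective_scale(FixedScaleDescriptor()) == 4
assert effective_scale(NoScaleDescriptor()) == 1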
@@ -152,8 +158,14 @@ class ImageUpscaleWithModel:
             try:
                 steps = in_img.shape[0] * utils.get_tiled_scale_steps(in_img.shape[3], in_img.shape[2], tile_x=tile, tile_y=tile, overlap=overlap)
                 pbar = utils.ProgressBar(steps)
-                s = utils.tiled_scale(in_img, lambda a: upscale_model.model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.model.scale, pbar=pbar)
+                s = utils.tiled_scale(in_img, lambda a: upscale_model.model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar)
                 oom = False
+            except RuntimeError as exc_info:
+                if "have 1 channels, but got 3 channels instead" in str(exc_info):
+                    # convert RGB to luminance (assuming sRGB)
+                    rgb_weights = torch.tensor([0.2126, 0.7152, 0.0722], device=in_img.device, dtype=in_img.dtype)
+                    in_img = (in_img * rgb_weights.view(1, 3, 1, 1)).sum(dim=1, keepdim=True)
+                    continue
             except model_management.OOM_EXCEPTION as e:
                 tile //= 2
                 overlap //= 2
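The added RuntimeError branch handles single-channel (grayscale) upscale models: when the model reports that it expected 1 channel but received 3, the input batch is collapsed from RGB to luminance using Rec. 709 weights and the surrounding retry loop runs the tiled upscale again via continue, the same loop the OOM handler uses to shrink the tile size. A standalone sketch of that conversion, with an illustrative input tensor:

# Sketch of the RGB -> luminance conversion applied before retrying a 1-channel model.
# The tensor is NCHW, as in the node; the weights are the Rec. 709 luma coefficients.
import torch

in_img = torch.rand(2, 3, 64, 64)  # illustrative RGB batch
rgb_weights = torch.tensor([0.2126, 0.7152, 0.0722], device=in_img.device, dtype=in_img.dtype)
luma = (in_img * rgb_weights.view(1, 3, 1, 1)).sum(dim=1, keepdim=True)
assert luma.shape == (2, 1, 64, 64)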
@@ -162,6 +174,10 @@ class ImageUpscaleWithModel:
 
         # upscale_model.to("cpu")
         s = torch.clamp(s.movedim(-3, -1), min=0, max=1.0)
+
+        if s.shape[-1] == 1:
+            s = s.expand(-1, -1, -1, 3)
+
         return (s,)
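Because the grayscale fallback leaves a single-channel result, the output is broadcast back to three channels after movedim puts it in channels-last (NHWC) layout, so downstream nodes always receive an RGB IMAGE tensor. A small sketch of that broadcast with an illustrative shape:

# Expanding a single-channel upscaled result back to RGB in NHWC layout.
import torch

s = torch.rand(1, 128, 128, 1)   # illustrative output of a grayscale model
if s.shape[-1] == 1:
    s = s.expand(-1, -1, -1, 3)  # broadcast view; the channel data is not copied
assert s.shape == (1, 128, 128, 3)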