diff --git a/comfy_extras/nodes_upscale_model.py b/comfy_extras/nodes_upscale_model.py index d3ee3f1c1..e2ae24422 100644 --- a/comfy_extras/nodes_upscale_model.py +++ b/comfy_extras/nodes_upscale_model.py @@ -5,8 +5,7 @@ import torch import comfy.utils import folder_paths from typing_extensions import override -from comfy_api.latest import ComfyExtension, io -import comfy.model_management +from comfy_api.latest import ComfyExtension, Input, io try: from spandrel_extra_arches import EXTRA_REGISTRY @@ -47,9 +46,29 @@ class UpscaleModelLoader(io.ComfyNode): load_model = execute # TODO: remove +class UpscaledImageStream(Input.ImageStream): + def __init__(self, upscale_model, stream: Input.ImageStream): + super().__init__() + self._upscale_model = upscale_model + self._stream = stream + + def get_dimensions(self) -> tuple[int, int]: + width, height = self._stream.get_dimensions() + scale = self._upscale_model.scale + return int(width * scale), int(height * scale) + + def do_reset(self) -> None: + self._stream.reset() + + def do_pull(self, max_frames: int) -> Input.Image: + chunk = self._stream.pull(max_frames) + return ImageUpscaleWithModel.upscale_batch(self._upscale_model, chunk) + + class ImageUpscaleWithModel(io.ComfyNode): @classmethod def define_schema(cls): + image_template = io.MatchType.Template("image_type", allowed_types=[io.Image, io.ImageStream]) return io.Schema( node_id="ImageUpscaleWithModel", display_name="Upscale Image (using Model)", @@ -57,15 +76,18 @@ class ImageUpscaleWithModel(io.ComfyNode): search_aliases=["upscale", "upscaler", "upsc", "enlarge image", "super resolution", "hires", "superres", "increase resolution"], inputs=[ io.UpscaleModel.Input("upscale_model"), - io.Image.Input("image"), + io.MatchType.Input("image", template=image_template), ], outputs=[ - io.Image.Output(), + io.MatchType.Output(template=image_template, display_name="image"), ], ) @classmethod - def execute(cls, upscale_model, image) -> io.NodeOutput: + def upscale_batch(cls, upscale_model, image: torch.Tensor) -> torch.Tensor: + if image.shape[0] == 0: + return image.clone() + device = model_management.get_torch_device() memory_required = model_management.module_size(upscale_model.model) @@ -79,7 +101,7 @@ class ImageUpscaleWithModel(io.ComfyNode): tile = 512 overlap = 32 - output_device = comfy.model_management.intermediate_device() + output_device = model_management.intermediate_device() oom = True try: @@ -97,8 +119,14 @@ class ImageUpscaleWithModel(io.ComfyNode): finally: upscale_model.to("cpu") - s = torch.clamp(s.movedim(-3,-1), min=0, max=1.0).to(comfy.model_management.intermediate_dtype()) - return io.NodeOutput(s) + return torch.clamp(s.movedim(-3,-1), min=0, max=1.0).to(model_management.intermediate_dtype()) + + @classmethod + def execute(cls, upscale_model, image) -> io.NodeOutput: + if isinstance(image, torch.Tensor): + return io.NodeOutput(cls.upscale_batch(upscale_model, image)) + + return io.NodeOutput(UpscaledImageStream(upscale_model, image)) upscale = execute # TODO: remove