diff --git a/comfy_api/latest/_input/image_stream_types.py b/comfy_api/latest/_input/image_stream_types.py index 04af3562c..a582859ad 100644 --- a/comfy_api/latest/_input/image_stream_types.py +++ b/comfy_api/latest/_input/image_stream_types.py @@ -1,7 +1,10 @@ from __future__ import annotations from abc import ABC, abstractmethod +from contextlib import nullcontext +from comfy_execution.utils import CurrentNodeContext, get_executing_context +from comfy_execution.progress import get_progress_state from .basic_types import ImageInput @@ -14,17 +17,33 @@ class ImageStreamInput(ABC): def __init__(self): #Subclasses must call this init for future core ComfyUI change compatibilty - pass + self._ctx = get_executing_context() def reset(self) -> None: #This API is final. Subclasses must NOT override this for future core ComfyUI #change compatability. Override do_reset instead. - return self.do_reset() + with (nullcontext() if self._ctx is None else + CurrentNodeContext(self._ctx.prompt_id, self._ctx.node_id, self._ctx.list_index)): + self.do_reset() + + if self._ctx is not None: + get_progress_state().finish_progress(self._ctx.node_id) def pull(self, max_frames: int) -> ImageInput: #This API is final. Subclasses must NOT override this for future core ComfyUI #change compatability. Override do_pull instead. - return self.do_pull(max_frames) + with (nullcontext() if self._ctx is None else + CurrentNodeContext(self._ctx.prompt_id, self._ctx.node_id, self._ctx.list_index)): + result = self.do_pull(max_frames) + + if self._ctx is not None: + registry = get_progress_state() + entry = registry.nodes.get(self._ctx.node_id) + if (int(result.shape[0]) < max_frames or + (entry is not None and entry["max"] > 0 and entry["value"] >= entry["max"])): + registry.finish_progress(self._ctx.node_id) + + return result @abstractmethod def get_dimensions(self) -> tuple[int, int]: