This commit is contained in:
Rattus 2026-04-09 20:32:28 +10:00
parent 1f0c02eb6c
commit 7dc366adc7

View File

@ -1,7 +1,10 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from contextlib import nullcontext
from comfy_execution.utils import CurrentNodeContext, get_executing_context
from comfy_execution.progress import get_progress_state
from .basic_types import ImageInput
@ -14,17 +17,33 @@ class ImageStreamInput(ABC):
def __init__(self):
#Subclasses must call this init for future core ComfyUI change compatibilty
pass
self._ctx = get_executing_context()
def reset(self) -> None:
#This API is final. Subclasses must NOT override this for future core ComfyUI
#change compatability. Override do_reset instead.
return self.do_reset()
with (nullcontext() if self._ctx is None else
CurrentNodeContext(self._ctx.prompt_id, self._ctx.node_id, self._ctx.list_index)):
self.do_reset()
if self._ctx is not None:
get_progress_state().finish_progress(self._ctx.node_id)
def pull(self, max_frames: int) -> ImageInput:
#This API is final. Subclasses must NOT override this for future core ComfyUI
#change compatability. Override do_pull instead.
return self.do_pull(max_frames)
with (nullcontext() if self._ctx is None else
CurrentNodeContext(self._ctx.prompt_id, self._ctx.node_id, self._ctx.list_index)):
result = self.do_pull(max_frames)
if self._ctx is not None:
registry = get_progress_state()
entry = registry.nodes.get(self._ctx.node_id)
if (int(result.shape[0]) < max_frames or
(entry is not None and entry["max"] > 0 and entry["value"] >= entry["max"])):
registry.finish_progress(self._ctx.node_id)
return result
@abstractmethod
def get_dimensions(self) -> tuple[int, int]: