mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-27 19:02:31 +08:00
progress
This commit is contained in:
parent
1f0c02eb6c
commit
7dc366adc7
@ -1,7 +1,10 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
from contextlib import nullcontext
|
||||||
|
|
||||||
|
from comfy_execution.utils import CurrentNodeContext, get_executing_context
|
||||||
|
from comfy_execution.progress import get_progress_state
|
||||||
from .basic_types import ImageInput
|
from .basic_types import ImageInput
|
||||||
|
|
||||||
|
|
||||||
@ -14,17 +17,33 @@ class ImageStreamInput(ABC):
|
|||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
#Subclasses must call this init for future core ComfyUI change compatibilty
|
#Subclasses must call this init for future core ComfyUI change compatibilty
|
||||||
pass
|
self._ctx = get_executing_context()
|
||||||
|
|
||||||
def reset(self) -> None:
|
def reset(self) -> None:
|
||||||
#This API is final. Subclasses must NOT override this for future core ComfyUI
|
#This API is final. Subclasses must NOT override this for future core ComfyUI
|
||||||
#change compatability. Override do_reset instead.
|
#change compatability. Override do_reset instead.
|
||||||
return self.do_reset()
|
with (nullcontext() if self._ctx is None else
|
||||||
|
CurrentNodeContext(self._ctx.prompt_id, self._ctx.node_id, self._ctx.list_index)):
|
||||||
|
self.do_reset()
|
||||||
|
|
||||||
|
if self._ctx is not None:
|
||||||
|
get_progress_state().finish_progress(self._ctx.node_id)
|
||||||
|
|
||||||
def pull(self, max_frames: int) -> ImageInput:
|
def pull(self, max_frames: int) -> ImageInput:
|
||||||
#This API is final. Subclasses must NOT override this for future core ComfyUI
|
#This API is final. Subclasses must NOT override this for future core ComfyUI
|
||||||
#change compatability. Override do_pull instead.
|
#change compatability. Override do_pull instead.
|
||||||
return self.do_pull(max_frames)
|
with (nullcontext() if self._ctx is None else
|
||||||
|
CurrentNodeContext(self._ctx.prompt_id, self._ctx.node_id, self._ctx.list_index)):
|
||||||
|
result = self.do_pull(max_frames)
|
||||||
|
|
||||||
|
if self._ctx is not None:
|
||||||
|
registry = get_progress_state()
|
||||||
|
entry = registry.nodes.get(self._ctx.node_id)
|
||||||
|
if (int(result.shape[0]) < max_frames or
|
||||||
|
(entry is not None and entry["max"] > 0 and entry["value"] >= entry["max"])):
|
||||||
|
registry.finish_progress(self._ctx.node_id)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_dimensions(self) -> tuple[int, int]:
|
def get_dimensions(self) -> tuple[int, int]:
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user