Support bounded feedback loops in the DAG execution engine

Allow sampler nodes' internal iteration variables (e.g. step_index)
to flow back upstream through ComfyMathExpression nodes to control
per-step parameters (cfg, s_noise, eta, r) without triggering a
dependency cycle error.

Architecture: Two-level cycle handling
- Static validation: _is_bounded_feedback_cycle() allows cycles
  where any node declares BOUNDED_FEEDBACK
- Graph building: _is_feedback_output() skips strong links for
  declared feedback sockets, records them in feedback_links

Multi-hop chain walking via _build_feedback_fns() resolves
expression->CFGGuider/Sampler chains with simple_eval + MATH_FUNCTIONS,
composing per-step fn(step, total_steps) callables.

Sampler functions now re-read s_noise/eta/r each iteration via
_init_dynamic_options() / _refresh_dynamic_params() / _apply_dynamic_s_noise().
KSAMPLER.sample() conditionally injects mutable extra_options ref.

Safety: _dynamic_sampler_options popped at function top before model() calls.
One-line opt-in: BOUNDED_FEEDBACK = {'step_index'} on any node.
This commit is contained in:
PR Author 2026-06-19 18:40:47 +08:00
parent bd39bbf067
commit 983d6a1566
5 changed files with 2190 additions and 354 deletions

File diff suppressed because it is too large Load Diff

View File

@ -996,6 +996,12 @@ class KSAMPLER(Sampler):
if callback is not None: if callback is not None:
k_callback = lambda x: callback(x["i"], x["denoised"], x["x"], total_steps) k_callback = lambda x: callback(x["i"], x["denoised"], x["x"], total_steps)
# Expose mutable extra_options so sampler functions can re-read
# updated values at each step (e.g. s_noise varied by feedback).
# Only inject when the sampler has per-step feedback param functions,
# otherwise _dynamic_sampler_options would leak to the model call.
if hasattr(self, '_feedback_param_fns') and self._feedback_param_fns:
extra_args["_dynamic_sampler_options"] = self.extra_options
samples = self.sampler_function(model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar, **self.extra_options) samples = self.sampler_function(model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar, **self.extra_options)
samples = model_wrap.inner_model.model_sampling.inverse_noise_scaling(sigmas[-1], samples) samples = model_wrap.inner_model.model_sampling.inverse_noise_scaling(sigmas[-1], samples)
return samples return samples

View File

@ -111,6 +111,32 @@ class TopologicalSort:
self.blocking = {} # Which nodes are blocked by this node self.blocking = {} # Which nodes are blocked by this node
self.externalBlocks = 0 self.externalBlocks = 0
self.unblockedEvent = asyncio.Event() self.unblockedEvent = asyncio.Event()
# Tracks bounded-feedback edges that were intentionally excluded from
# strong (blocking) links. Maps to_node_id -> list of (from_node_id,
# from_socket) so the execution layer can inject initial values for the
# iteration output that closes the cycle.
self.feedback_links = {}
def _is_feedback_output(self, from_node_id, from_socket):
"""Return True when *from_socket* of *from_node_id* is a declared
bounded-iteration output (``BOUNDED_FEEDBACK``)."""
try:
class_type = self.dynprompt.get_node(from_node_id)["class_type"]
class_def = nodes.NODE_CLASS_MAPPINGS.get(class_type)
except (NodeNotFoundError, KeyError):
return False
if class_def is None:
return False
bounded = getattr(class_def, 'BOUNDED_FEEDBACK', None)
if not bounded:
return False
# Map socket index to name via RETURN_NAMES, falling back to the raw index.
return_names = getattr(class_def, 'RETURN_NAMES', None)
idx = int(from_socket)
if return_names is not None and 0 <= idx < len(return_names):
return return_names[idx] in bounded
# If the socket is already a string (uncommon), check directly.
return str(from_socket) in bounded
def get_input_info(self, unique_id, input_name): def get_input_info(self, unique_id, input_name):
class_type = self.dynprompt.get_node(unique_id)["class_type"] class_type = self.dynprompt.get_node(unique_id)["class_type"]
@ -163,6 +189,24 @@ class TopologicalSort:
links.append((from_node_id, from_socket, unique_id)) links.append((from_node_id, from_socket, unique_id))
for link in links: for link in links:
from_node_id, from_socket, to_node_id = link
if self._is_feedback_output(from_node_id, from_socket):
# This edge carries an iteration variable (e.g. step_index)
# back upstream to close a bounded feedback cycle. Don't
# create a strong (blocking) link — that would deadlock the
# topological dissolve. Instead record it so the execution
# layer can seed the iteration output with an initial value.
if to_node_id not in self.feedback_links:
self.feedback_links[to_node_id] = []
self.feedback_links[to_node_id].append((from_node_id, from_socket))
# Still ensure the source node is in the graph.
self.add_node(from_node_id)
# Create a cache link so the downstream node can read the
# placeholder value injected into the output cache by the
# execution bootstrap (only available on ExecutionList).
if hasattr(self, 'cache_link'):
self.cache_link(from_node_id, to_node_id)
continue
self.add_strong_link(*link) self.add_strong_link(*link)
def add_external_block(self, node_id): def add_external_block(self, node_id):

View File

@ -1011,6 +1011,10 @@ class RandomNoise(io.ComfyNode):
class SamplerCustomAdvanced(io.ComfyNode): class SamplerCustomAdvanced(io.ComfyNode):
# Declare which outputs are bounded iteration variables that may feed back
# through the graph to control upstream parameters (e.g. step_index -> cfg).
BOUNDED_FEEDBACK = {"step_index"}
@classmethod @classmethod
def define_schema(cls): def define_schema(cls):
return io.Schema( return io.Schema(
@ -1026,6 +1030,7 @@ class SamplerCustomAdvanced(io.ComfyNode):
outputs=[ outputs=[
io.Latent.Output(display_name="output"), io.Latent.Output(display_name="output"),
io.Latent.Output(display_name="denoised_output"), io.Latent.Output(display_name="denoised_output"),
io.Int.Output(display_name="step_index"),
] ]
) )
@ -1041,8 +1046,30 @@ class SamplerCustomAdvanced(io.ComfyNode):
if "noise_mask" in latent: if "noise_mask" in latent:
noise_mask = latent["noise_mask"] noise_mask = latent["noise_mask"]
total_steps = sigmas.shape[-1] - 1
x0_output = {} x0_output = {}
callback = latent_preview.prepare_callback(guider.model_patcher, sigmas.shape[-1] - 1, x0_output) callback = latent_preview.prepare_callback(guider.model_patcher, total_steps, x0_output)
# ---- bounded-feedback per-step updates ----
# The execution engine may have injected per-step update functions
# onto the guider and/or sampler objects. Wrap the callback to
# apply them before the *next* sampling step. The k-diffusion
# callback fires *after* the model call for step i, so we pass
# i+1 so that step N uses parameters computed with a=N.
cfg_fn = getattr(guider, '_feedback_cfg_fn', None)
param_fns = getattr(sampler, '_feedback_param_fns', None)
_has_feedback = cfg_fn is not None or param_fns
if _has_feedback:
_orig_callback = callback
def _feedback_callback(step, x0, x, total_steps):
if cfg_fn is not None:
guider.cfg = cfg_fn(step + 1, total_steps)
if param_fns is not None:
for key, fn in param_fns.items():
sampler.extra_options[key] = fn(step + 1, total_steps)
_orig_callback(step, x0, x, total_steps)
callback = _feedback_callback
# ----------------------------------------------------
disable_pbar = not comfy.utils.PROGRESS_BAR_ENABLED disable_pbar = not comfy.utils.PROGRESS_BAR_ENABLED
samples = guider.sample(noise.generate_noise(latent), latent_image, sampler, sigmas, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=noise.seed) samples = guider.sample(noise.generate_noise(latent), latent_image, sampler, sigmas, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=noise.seed)
@ -1061,7 +1088,7 @@ class SamplerCustomAdvanced(io.ComfyNode):
out_denoised["samples"] = x0_out out_denoised["samples"] = x0_out
else: else:
out_denoised = out out_denoised = out
return io.NodeOutput(out, out_denoised) return io.NodeOutput(out, out_denoised, total_steps)
sample = execute sample = execute

View File

@ -110,6 +110,21 @@ class CacheType(Enum):
RAM_PRESSURE = 3 RAM_PRESSURE = 3
# Initial values for bounded-feedback iteration outputs keyed by ComfyUI type
# string. When the DAG contains a feedback loop (e.g. step_index → … → cfg
# → guider → sampler) the execution engine seeds the iteration output with
# the default listed here so the downstream chain can evaluate before the
# iteration-producing node runs.
_FEEDBACK_DEFAULTS = {
"INT": 0,
"FLOAT": 0.0,
"BOOLEAN": False,
"STRING": "",
"NUMBER": 0,
"PRIMITIVE": 0,
}
class CacheSet: class CacheSet:
def __init__(self, cache_type=None, cache_args={}): def __init__(self, cache_type=None, cache_args={}):
if cache_type == CacheType.NONE: if cache_type == CacheType.NONE:
@ -176,12 +191,28 @@ def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=
continue # This might be a lazily-evaluated input continue # This might be a lazily-evaluated input
cached = execution_list.get_cache(input_unique_id, unique_id) cached = execution_list.get_cache(input_unique_id, unique_id)
if cached is None or cached.outputs is None: if cached is None or cached.outputs is None:
mark_missing() # If this is a bounded-feedback link whose source hasn't
# executed yet, supply the type-appropriate initial value
# (e.g. step_index=0) so the feedback chain can evaluate
# before the iteration-producing node runs.
if _is_feedback_link(execution_list, unique_id, input_unique_id, output_index):
default_val = _get_feedback_default(dynprompt, input_unique_id, output_index)
obj = default_val
if isinstance(obj, (int, float, bool, str)):
obj = (obj,)
input_data_all[x] = obj
else:
mark_missing()
continue continue
if output_index >= len(cached.outputs): if output_index >= len(cached.outputs):
mark_missing() mark_missing()
continue continue
obj = cached.outputs[output_index] obj = cached.outputs[output_index]
# Wrap atomic types (int, float, bool, str) in a tuple so
# _async_map_node_over_list can call len() on every input.
# The slice_dict helper then unwraps: (val,)[0] == val.
if isinstance(obj, (int, float, bool, str)):
obj = (obj,)
input_data_all[x] = obj input_data_all[x] = obj
elif input_category is not None or (is_v3 and class_def.ACCEPT_ALL_INPUTS): elif input_category is not None or (is_v3 and class_def.ACCEPT_ALL_INPUTS):
input_data_all[x] = [input_data] input_data_all[x] = [input_data]
@ -658,6 +689,209 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
return (ExecutionResult.SUCCESS, None, None) return (ExecutionResult.SUCCESS, None, None)
def _is_feedback_link(execution_list, to_node_id, from_node_id, from_socket):
"""Return True when *to_node_id* receives *from_node_id*:*from_socket*
through a bounded-feedback edge (recorded during graph construction)."""
edges = execution_list.feedback_links.get(to_node_id, [])
return (from_node_id, from_socket) in edges
def _get_feedback_default(dynprompt, from_node_id, from_socket):
"""Return the type-appropriate initial value for a feedback iteration
output (e.g. 0 for INT, 0.0 for FLOAT)."""
try:
class_type = dynprompt.get_node(from_node_id)["class_type"]
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
return_types = class_def.RETURN_TYPES
except Exception:
return 0
if from_socket < len(return_types):
return _FEEDBACK_DEFAULTS.get(return_types[from_socket], 0)
return 0
def _build_feedback_fns(dynamic_prompt, from_node_id, from_socket, to_node_id,
cfg_injections, sampler_injections):
"""Try to build per-step update functions from a feedback edge.
Walks forward from the feedback-receiving node through intermediate
ComfyMathExpression nodes to find targets that need per-step callables.
Handles two target types:
* **CFGGuider** populates *cfg_injections* keyed by guider node id
with a ``cfg_fn(step, total_steps)`` callable.
* **Sampler-producing nodes** (any node whose class_type starts with
"Sampler" except the iteration node itself) populates
*sampler_injections* keyed by (sampler_node_id, param_name) with a
``param_fn(step, total_steps)`` callable.
Supports multi-hop chains like::
iteration_node (step_index) MathExpr_A MathExpr_B CFGGuider
SamplerXXX
...
"""
try:
prompt = dynamic_prompt.original_prompt
except Exception:
return
from simpleeval import simple_eval
from comfy_extras.nodes_math import MATH_FUNCTIONS
# ---- helpers ----
def _find_consumers(source_id):
consumers = []
for nid, n in prompt.items():
for iname, ival in n.get("inputs", {}).items():
if isinstance(ival, list) and len(ival) == 2 \
and ival[0] == source_id and ival[1] == 0:
consumers.append((nid, n.get("class_type"), iname))
return consumers
def _is_sampler_target(class_type):
# Sampler-producing nodes whose parameters can be updated per-step
# via KSAMPLER.extra_options.
return (class_type is not None
and "Sampler" in class_type
and class_type != "SamplerCustomAdvanced")
def _resolve_input_value(source_node_id, source_socket):
"""Try to resolve a non-feedback linked input to a static value.
First checks the source node's ``inputs`` dict (API format) for a
direct scalar value at the socket. Falls back to ``widgets_values``
positional mapping (workflow-file format). Returns the resolved
value, or None if unresolvable.
"""
try:
snode = prompt.get(str(source_node_id))
if snode is None:
return None
class_type = snode.get("class_type", "")
inputs = snode.get("inputs", {})
# API format: inputs are named — find the name that maps to
# *source_socket* via the class's INPUT_TYPES ordering.
cls = nodes.NODE_CLASS_MAPPINGS.get(class_type)
if cls is not None:
try:
input_types = cls.INPUT_TYPES()
except Exception:
input_types = {}
required = input_types.get("required", {})
req_names = list(required.keys())
if source_socket < len(req_names):
name = req_names[source_socket]
val = inputs.get(name)
if val is not None and not isinstance(val, list):
return val
# Fallback: widgets_values positional mapping (workflow-file format)
wv = snode.get("widgets_values", [])
if wv:
if class_type in ("PrimitiveInt", "PrimitiveFloat", "PrimitiveBool"):
if source_socket == 0 and len(wv) > 0:
return wv[0]
if cls is not None and source_socket < len(req_names) and source_socket < len(wv):
return wv[source_socket]
return None
except Exception:
return None
def _collect_extra_names(node_id, feedback_from_node, feedback_from_socket,
feedback_var_name):
"""Collect non-feedback linked inputs from a MathExpression node
and resolve them to values. Returns dict of namevalue."""
extra = {}
try:
snode = prompt.get(str(node_id))
if snode is None:
return extra
for inp_name, inp_val in snode.get("inputs", {}).items():
if not isinstance(inp_val, list) or len(inp_val) != 2:
continue
src_id, src_socket = inp_val[0], inp_val[1]
# Skip the feedback-linked input — that's the iteration variable
if (src_id == str(feedback_from_node)
and int(src_socket) == int(feedback_from_socket)):
continue
# This is an additional linked input — try to resolve it
val = _resolve_input_value(src_id, src_socket)
if val is not None:
var_name = inp_name.rsplit(".", 1)[-1]
extra[var_name] = val
except Exception:
pass
return extra
# Each chain element is now (expression, feedback_var, extra_names_dict)
# ---- depth-first search ----
def _dfs(start_id, from_node, from_socket, chain):
"""Walk the MathExpr chain looking for any target node that needs
per-step updates. Returns a list of (target_type, target_id,
input_name, full_chain) tuples, where target_type is 'guider'
or 'sampler'."""
try:
node = dynamic_prompt.get_node(start_id)
except Exception:
return []
if node.get("class_type") != "ComfyMathExpression":
return []
expression = node.get("inputs", {}).get("expression", "")
if not expression or not expression.strip():
return []
var_name = None
for input_name, input_val in node.get("inputs", {}).items():
if isinstance(input_val, list) and len(input_val) == 2 \
and input_val[0] == from_node and input_val[1] == from_socket:
var_name = input_name.rsplit(".", 1)[-1]
break
if var_name is None:
return []
# Collect additional (non-feedback) input values for this node
extra_names = _collect_extra_names(start_id, from_node, from_socket,
var_name)
new_chain = chain + [(expression, var_name, extra_names)]
results = []
for cid, ctype, ciname in _find_consumers(start_id):
if ctype == "CFGGuider":
results.append(("guider", cid, None, new_chain))
elif _is_sampler_target(ctype):
results.append(("sampler", cid, ciname, new_chain))
elif ctype == "ComfyMathExpression":
results.extend(_dfs(cid, start_id, 0, new_chain))
return results
# ---- compose functions from discovered chains ----
for target_type, target_id, param_name, chain in \
_dfs(to_node_id, from_node_id, from_socket, []):
if not chain:
continue
def _make_fn(_chain):
def _fn(step, total_steps):
val = step
for expr_str, var, extra_names in _chain:
ctx = dict(extra_names) if extra_names else {}
ctx[var] = val
val = float(simple_eval(expr_str, names=ctx, functions=MATH_FUNCTIONS))
return val
return _fn
if target_type == "guider":
cfg_injections[target_id] = _make_fn(chain)
elif target_type == "sampler" and param_name:
sampler_injections[target_id] = sampler_injections.get(target_id, {})
sampler_injections[target_id][param_name] = _make_fn(chain)
class PromptExecutor: class PromptExecutor:
def __init__(self, server, cache_type=False, cache_args=None): def __init__(self, server, cache_type=False, cache_args=None):
self.cache_args = cache_args self.cache_args = cache_args
@ -774,6 +1008,26 @@ class PromptExecutor:
for node_id in list(execute_outputs): for node_id in list(execute_outputs):
execution_list.add_node(node_id) execution_list.add_node(node_id)
# ---- bounded-feedback bootstrap ---------------------------------
# Build per-step update functions for feedback chains that
# pass through ComfyMathExpression → CFGGuider / SamplerXXX.
# These are injected into the guider / sampler after the
# target node executes so the sampler can vary parameters
# (cfg, s_noise, ...) with step_index.
_feedback_cfg_injections = {} # guider_node_id → cfg_fn
_feedback_sampler_injections = {} # sampler_node_id → {param: fn}
for to_node_id, edges in execution_list.feedback_links.items():
for from_node_id, from_socket in edges:
try:
_build_feedback_fns(
dynamic_prompt, from_node_id, from_socket,
to_node_id, _feedback_cfg_injections,
_feedback_sampler_injections,
)
except Exception:
pass # non-critical feedback just wonʼt vary per step
# -----------------------------------------------------------------
while not execution_list.is_empty(): while not execution_list.is_empty():
node_id, error, ex = await execution_list.stage_node_execution() node_id, error, ex = await execution_list.stage_node_execution()
if error is not None: if error is not None:
@ -789,6 +1043,29 @@ class PromptExecutor:
elif result == ExecutionResult.PENDING: elif result == ExecutionResult.PENDING:
execution_list.unstage_node_execution() execution_list.unstage_node_execution()
else: # result == ExecutionResult.SUCCESS: else: # result == ExecutionResult.SUCCESS:
# ---- bounded-feedback injection ----
# If this node just produced a guider or sampler
# that is part of a feedback cycle, inject per-step
# update function(s).
if node_id in _feedback_cfg_injections:
try:
output = self.caches.outputs.get_local(node_id)
if output is not None and output.outputs is not None \
and len(output.outputs) > 0 and len(output.outputs[0]) > 0:
guider = output.outputs[0][0]
guider._feedback_cfg_fn = _feedback_cfg_injections[node_id]
except Exception:
pass
if node_id in _feedback_sampler_injections:
try:
output = self.caches.outputs.get_local(node_id)
if output is not None and output.outputs is not None \
and len(output.outputs) > 0 and len(output.outputs[0]) > 0:
sampler_obj = output.outputs[0][0]
sampler_obj._feedback_param_fns = _feedback_sampler_injections[node_id]
except Exception:
pass
# ---------------------------------------
execution_list.complete_node_execution() execution_list.complete_node_execution()
if self.cache_type == CacheType.RAM_PRESSURE: if self.cache_type == CacheType.RAM_PRESSURE:
@ -831,6 +1108,34 @@ class PromptExecutor:
self._notify_prompt_lifecycle("end", prompt_id) self._notify_prompt_lifecycle("end", prompt_id)
def _is_bounded_feedback_cycle(prompt, visiting, unique_id):
"""Check whether a detected dependency cycle is a *bounded* feedback loop.
A cycle is bounded when at least one node in it declares ``BOUNDED_FEEDBACK``,
i.e. the node has a finite internal iteration whose step / index variable
feeds back upstream to control its own parameters (e.g. a sampler's
``step_index`` flowing through a math expression to set ``cfg``).
Because the iteration is bounded (N steps, then terminates) this isn't an
infinite cycle the DAG can safely allow it and the execution engine will
break the feedback edge by seeding the iteration output with an initial value.
"""
cycle_nodes = visiting[visiting.index(unique_id):] + [unique_id]
for node_id in cycle_nodes:
if node_id not in prompt:
continue
class_type = prompt[node_id].get('class_type')
if class_type is None:
continue
obj_class = nodes.NODE_CLASS_MAPPINGS.get(class_type)
if obj_class is None:
continue
bounded = getattr(obj_class, 'BOUNDED_FEEDBACK', None)
if bounded:
return True
return False
async def validate_inputs(prompt_id, prompt, item, validated, visiting=None): async def validate_inputs(prompt_id, prompt, item, validated, visiting=None):
if visiting is None: if visiting is None:
visiting = [] visiting = []
@ -842,6 +1147,19 @@ async def validate_inputs(prompt_id, prompt, item, validated, visiting=None):
if unique_id in visiting: if unique_id in visiting:
cycle_path_nodes = visiting[visiting.index(unique_id):] + [unique_id] cycle_path_nodes = visiting[visiting.index(unique_id):] + [unique_id]
cycle_nodes = list(dict.fromkeys(cycle_path_nodes)) cycle_nodes = list(dict.fromkeys(cycle_path_nodes))
# A bounded feedback cycle is one where at least one node in the cycle
# declares BOUNDED_FEEDBACK — meaning its internal iteration is finite
# and its iteration output(s) can safely flow back upstream without
# causing an infinite loop (e.g. a sampler's step_index controlling cfg).
if _is_bounded_feedback_cycle(prompt, visiting, unique_id):
# Mark the repeated node as valid and continue the traversal on
# other branches. The execution layer handles the feedback edge
# by breaking it and seeding the iteration output with an initial
# value (e.g. step_index = 0).
validated[unique_id] = (True, [], unique_id)
return validated[unique_id]
cycle_path = " -> ".join(f"{node_id} ({prompt[node_id]['class_type']})" for node_id in cycle_path_nodes) cycle_path = " -> ".join(f"{node_id} ({prompt[node_id]['class_type']})" for node_id in cycle_path_nodes)
for node_id in cycle_nodes: for node_id in cycle_nodes:
validated[node_id] = (False, [{ validated[node_id] = (False, [{