Add context windows callback for custom cond handling (#11208)

Co-authored-by: ozbayb <17261091+ozbayb@users.noreply.github.com>
This commit is contained in:
drozbay 2025-12-15 17:06:32 -07:00 committed by GitHub
parent 43e0d4e3cc
commit 77b2f7c228
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -87,6 +87,7 @@ class IndexListCallbacks:
COMBINE_CONTEXT_WINDOW_RESULTS = "combine_context_window_results" COMBINE_CONTEXT_WINDOW_RESULTS = "combine_context_window_results"
EXECUTE_START = "execute_start" EXECUTE_START = "execute_start"
EXECUTE_CLEANUP = "execute_cleanup" EXECUTE_CLEANUP = "execute_cleanup"
RESIZE_COND_ITEM = "resize_cond_item"
def init_callbacks(self): def init_callbacks(self):
return {} return {}
@ -166,6 +167,18 @@ class IndexListContextHandler(ContextHandlerABC):
new_cond_item = cond_item.copy() new_cond_item = cond_item.copy()
# when in dictionary, look for tensors and CONDCrossAttn [comfy/conds.py] (has cond attr that is a tensor) # when in dictionary, look for tensors and CONDCrossAttn [comfy/conds.py] (has cond attr that is a tensor)
for cond_key, cond_value in new_cond_item.items(): for cond_key, cond_value in new_cond_item.items():
# Allow callbacks to handle custom conditioning items
handled = False
for callback in comfy.patcher_extension.get_all_callbacks(
IndexListCallbacks.RESIZE_COND_ITEM, self.callbacks
):
result = callback(cond_key, cond_value, window, x_in, device, new_cond_item)
if result is not None:
new_cond_item[cond_key] = result
handled = True
break
if handled:
continue
if isinstance(cond_value, torch.Tensor): if isinstance(cond_value, torch.Tensor):
if (self.dim < cond_value.ndim and cond_value(self.dim) == x_in.size(self.dim)) or \ if (self.dim < cond_value.ndim and cond_value(self.dim) == x_in.size(self.dim)) or \
(cond_value.ndim < self.dim and cond_value.size(0) == x_in.size(self.dim)): (cond_value.ndim < self.dim and cond_value.size(0) == x_in.size(self.dim)):