mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-16 01:00:49 +08:00
This allows, for example, Wan/HunyuanVideo image-to-video to "work" by using the initial start frame for each window; otherwise, windows beyond the first will be pure T2V generations.
98 lines
6.5 KiB
Python
98 lines
6.5 KiB
Python
from __future__ import annotations
|
|
from comfy_api.latest import ComfyExtension, io
|
|
import comfy.context_windows
|
|
import nodes
|
|
|
|
|
|
class ContextWindowsManualNode(io.ComfyNode):
    """Node that attaches a manually configured context-window handler to a model."""

    @classmethod
    def define_schema(cls) -> io.Schema:
        """Return the schema (identity, inputs, outputs) for this node."""
        return io.Schema(
            node_id="ContextWindowsManual",
            display_name="Context Windows (Manual)",
            category="context",
            description="Manually set context windows.",
            inputs=[
                io.Model.Input("model", tooltip="The model to apply context windows to during sampling."),
                io.Int.Input("context_length", min=1, default=16, tooltip="The length of the context window."),
                io.Int.Input("context_overlap", min=0, default=4, tooltip="The overlap of the context window."),
                io.Combo.Input("context_schedule", options=[
                    comfy.context_windows.ContextSchedules.STATIC_STANDARD,
                    comfy.context_windows.ContextSchedules.UNIFORM_STANDARD,
                    comfy.context_windows.ContextSchedules.UNIFORM_LOOPED,
                    comfy.context_windows.ContextSchedules.BATCHED,
                ], tooltip="The schedule of the context window."),
                io.Int.Input("context_stride", min=1, default=1, tooltip="The stride of the context window; only applicable to uniform schedules."),
                io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules."),
                io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."),
                io.Int.Input("dim", min=0, max=5, default=0, tooltip="The dimension to apply the context windows to."),
                io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."),
                io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."),
            ],
            outputs=[
                io.Model.Output(tooltip="The model with context windows applied during sampling."),
            ],
            is_experimental=True,
        )

    @classmethod
    def execute(
        cls,
        model: io.Model.Type,
        context_length: int,
        context_overlap: int,
        context_schedule: str,
        context_stride: int,
        closed_loop: bool,
        fuse_method: str,
        dim: int,
        freenoise: bool,
        cond_retain_index_list: str = "",
    ) -> io.NodeOutput:
        """Clone the model and install a context-window handler on the clone.

        Args:
            model: The model to patch; it is cloned, the input is not modified here.
            context_length: Length of each context window.
            context_overlap: Overlap between context windows.
            context_schedule: Schedule name, resolved via ``get_matching_context_schedule``.
            context_stride: Stride of the context window.
            closed_loop: Whether to close the context window loop.
            fuse_method: Fuse method name, resolved via ``get_matching_fuse_method``.
            dim: Dimension the context windows are applied to.
            freenoise: Whether to also install the sampler-sample wrapper.
            cond_retain_index_list: Raw string from the schema's String input,
                forwarded unchanged to the handler. Defaults to the empty string,
                matching the schema default.

        Returns:
            A ``NodeOutput`` wrapping the patched model clone.
        """
        model = model.clone()
        model.model_options["context_handler"] = comfy.context_windows.IndexListContextHandler(
            context_schedule=comfy.context_windows.get_matching_context_schedule(context_schedule),
            fuse_method=comfy.context_windows.get_matching_fuse_method(fuse_method),
            context_length=context_length,
            context_overlap=context_overlap,
            context_stride=context_stride,
            closed_loop=closed_loop,
            dim=dim,
            freenoise=freenoise,
            cond_retain_index_list=cond_retain_index_list,
        )
        # make memory usage calculation only take into account the context window latents
        comfy.context_windows.create_prepare_sampling_wrapper(model)
        if freenoise:  # no other use for this wrapper at this time
            comfy.context_windows.create_sampler_sample_wrapper(model)
        return io.NodeOutput(model)
|
|
|
|
class WanContextWindowsManualNode(ContextWindowsManualNode):
    """Variant of the manual context-window node that fixes ``dim`` to 2."""

    @classmethod
    def define_schema(cls) -> io.Schema:
        """Return the parent schema with WAN-specific identity and inputs.

        The ``dim`` input is omitted because ``execute`` always passes ``dim=2``.
        """
        schema = super().define_schema()
        schema.node_id = "WanContextWindowsManual"
        schema.display_name = "WAN Context Windows (Manual)"
        schema.description = "Manually set context windows for WAN-like models (dim=2)."
        schema.inputs = [
            io.Model.Input("model", tooltip="The model to apply context windows to during sampling."),
            io.Int.Input("context_length", min=1, max=nodes.MAX_RESOLUTION, step=4, default=81, tooltip="The length of the context window."),
            io.Int.Input("context_overlap", min=0, default=30, tooltip="The overlap of the context window."),
            io.Combo.Input("context_schedule", options=[
                comfy.context_windows.ContextSchedules.STATIC_STANDARD,
                comfy.context_windows.ContextSchedules.UNIFORM_STANDARD,
                comfy.context_windows.ContextSchedules.UNIFORM_LOOPED,
                comfy.context_windows.ContextSchedules.BATCHED,
            ], tooltip="The schedule of the context window."),
            io.Int.Input("context_stride", min=1, default=1, tooltip="The stride of the context window; only applicable to uniform schedules."),
            io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules."),
            io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."),
            io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."),
            io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."),
        ]
        return schema

    @classmethod
    def execute(
        cls,
        model: io.Model.Type,
        context_length: int,
        context_overlap: int,
        context_schedule: str,
        context_stride: int,
        closed_loop: bool,
        fuse_method: str,
        freenoise: bool,
        cond_retain_index_list: str = "",
    ) -> io.NodeOutput:
        """Rescale the length/overlap values and delegate to the parent with ``dim=2``.

        ``context_length`` and ``context_overlap`` are each mapped through
        ``((value - 1) // 4) + 1`` and clamped to a minimum of 1 and 0
        respectively; all other arguments are forwarded unchanged.

        Returns:
            The ``NodeOutput`` produced by the parent ``execute``.
        """
        # NOTE(review): the divide-by-4 presumably converts frame counts to latent
        # frame counts for WAN-like models - confirm against the model's temporal compression.
        context_length = max(((context_length - 1) // 4) + 1, 1)  # at least length 1
        context_overlap = max(((context_overlap - 1) // 4) + 1, 0)  # at least overlap 0
        return super().execute(
            model,
            context_length,
            context_overlap,
            context_schedule,
            context_stride,
            closed_loop,
            fuse_method,
            dim=2,
            freenoise=freenoise,
            cond_retain_index_list=cond_retain_index_list,
        )
|
|
|
|
|
|
class ContextWindowsExtension(ComfyExtension):
    """Extension exposing the context-window node classes defined in this module."""

    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        """Return the node classes provided by this extension."""
        node_classes: list[type[io.ComfyNode]] = [
            ContextWindowsManualNode,
            WanContextWindowsManualNode,
        ]
        return node_classes
|
|
|
|
def comfy_entrypoint():
    """Create and return a new ``ContextWindowsExtension`` instance."""
    extension = ContextWindowsExtension()
    return extension
|