diff --git a/.github/workflows/stable-release.yml b/.github/workflows/stable-release.yml
index a5a1ed2d0..2bc8e5905 100644
--- a/.github/workflows/stable-release.yml
+++ b/.github/workflows/stable-release.yml
@@ -90,7 +90,7 @@ jobs:
cd ..
- "C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=512m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable
+ "C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=768m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable
mv ComfyUI_windows_portable.7z ComfyUI/ComfyUI_windows_portable_nvidia.7z
cd ComfyUI_windows_portable
diff --git a/.github/workflows/windows_release_package.yml b/.github/workflows/windows_release_package.yml
index 3334e6839..46375698e 100644
--- a/.github/workflows/windows_release_package.yml
+++ b/.github/workflows/windows_release_package.yml
@@ -86,7 +86,7 @@ jobs:
cd ..
- "C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=512m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable
+ "C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=768m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable
mv ComfyUI_windows_portable.7z ComfyUI/new_ComfyUI_windows_portable_nvidia_cu${{ inputs.cu }}_or_cpu.7z
cd ComfyUI_windows_portable
diff --git a/README.md b/README.md
index e4cff01a9..fa99a8cbe 100644
--- a/README.md
+++ b/README.md
@@ -39,7 +39,7 @@ ComfyUI lets you design and execute advanced stable diffusion pipelines using a
## Get Started
#### [Desktop Application](https://www.comfy.org/download)
-- The easiest way to get started.
+- The easiest way to get started.
- Available on Windows & macOS.
#### [Windows Portable Package](#installing)
@@ -211,27 +211,19 @@ This is the command to install the nightly with ROCm 6.4 which might have some p
### Intel GPUs (Windows and Linux)
-(Option 1) Intel Arc GPU users can install native PyTorch with torch.xpu support using pip (currently available in PyTorch nightly builds). More information can be found [here](https://pytorch.org/docs/main/notes/get_start_xpu.html)
-
-1. To install PyTorch nightly, use the following command:
+(Option 1) Intel Arc GPU users can install native PyTorch with torch.xpu support using pip. More information can be found [here](https://pytorch.org/docs/main/notes/get_start_xpu.html)
+
+1. To install PyTorch with xpu support, use the following command:
+
+```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/xpu```
+
+This is the command to install the PyTorch xpu nightly, which might have some performance improvements:
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/xpu```
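+
+To check that the xpu build is working, you can run: ```python -c "import torch; print(torch.xpu.is_available())"```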
-2. Launch ComfyUI by running `python main.py`
-
-
(Option 2) Alternatively, Intel GPUs supported by Intel Extension for PyTorch (IPEX) can leverage IPEX for improved performance.
-1. For Intel® Arc™ A-Series Graphics utilizing IPEX, create a conda environment and use the commands below:
-
-```
-conda install libuv
-pip install torch==2.3.1.post0+cxx11.abi torchvision==0.18.1.post0+cxx11.abi torchaudio==2.3.1.post0+cxx11.abi intel-extension-for-pytorch==2.3.110.post0+xpu --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/cn/
-```
-
-For other supported Intel GPUs with IPEX, visit [Installation](https://intel.github.io/intel-extension-for-pytorch/index.html#installation?platform=gpu) for more information.
-
-Additional discussion and help can be found [here](https://github.com/comfyanonymous/ComfyUI/discussions/476).
+1. Visit [Installation](https://intel.github.io/intel-extension-for-pytorch/index.html#installation?platform=gpu) for more information.
### NVIDIA
@@ -352,7 +344,7 @@ Generate a self-signed certificate (not appropriate for shared/production use) a
Use `--tls-keyfile key.pem --tls-certfile cert.pem` to enable TLS/SSL, the app will now be accessible with `https://...` instead of `http://...`.
-> Note: Windows users can use [alexisrolland/docker-openssl](https://github.com/alexisrolland/docker-openssl) or one of the [3rd party binary distributions](https://wiki.openssl.org/index.php/Binaries) to run the command example above.
+> Note: Windows users can use [alexisrolland/docker-openssl](https://github.com/alexisrolland/docker-openssl) or one of the [3rd party binary distributions](https://wiki.openssl.org/index.php/Binaries) to run the command example above.
If you use a container, note that the volume mount `-v` can be a relative path so `... -v ".\:/openssl-certs" ...` would create the key & cert files in the current directory of your command prompt or powershell terminal.
## Support and dev channel
diff --git a/comfy/cli_args.py b/comfy/cli_args.py
index 60e88860e..2e140f767 100644
--- a/comfy/cli_args.py
+++ b/comfy/cli_args.py
@@ -138,6 +138,8 @@ parser.add_argument("--reserve-vram", type=float, default=None, help="Set the am
parser.add_argument("--async-offload", action="store_true", help="Use async weight offloading.")
+parser.add_argument("--force-non-blocking", action="store_true", help="Force ComfyUI to use non-blocking operations for all applicable tensors. This may improve performance on some non-Nvidia systems but can cause issues with some workflows.")
+
parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha1', 'sha256', 'sha512'], default='sha256', help="Allows you to choose the hash function to use for duplicate filename / contents comparison. Default is sha256.")
parser.add_argument("--disable-smart-memory", action="store_true", help="Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can.")
diff --git a/comfy/context_windows.py b/comfy/context_windows.py
new file mode 100644
index 000000000..928b111df
--- /dev/null
+++ b/comfy/context_windows.py
@@ -0,0 +1,537 @@
+from __future__ import annotations
+from typing import TYPE_CHECKING, Callable
+import torch
+import numpy as np
+import collections
+from dataclasses import dataclass
+from abc import ABC, abstractmethod
+import logging
+import comfy.model_management
+import comfy.patcher_extension
+if TYPE_CHECKING:
+ from comfy.model_base import BaseModel
+ from comfy.model_patcher import ModelPatcher
+ from comfy.controlnet import ControlBase
+
+
+class ContextWindowABC(ABC):
+ def __init__(self):
+ ...
+
+ @abstractmethod
+ def get_tensor(self, full: torch.Tensor) -> torch.Tensor:
+ """
+ Get torch.Tensor applicable to current window.
+ """
+ raise NotImplementedError("Not implemented.")
+
+ @abstractmethod
+ def add_window(self, full: torch.Tensor, to_add: torch.Tensor) -> torch.Tensor:
+ """
+ Apply torch.Tensor of window to the full tensor, in place. Returns reference to updated full tensor, not a copy.
+ """
+ raise NotImplementedError("Not implemented.")
+
+class ContextHandlerABC(ABC):
+ def __init__(self):
+ ...
+
+ @abstractmethod
+ def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool:
+ raise NotImplementedError("Not implemented.")
+
+ @abstractmethod
+ def get_resized_cond(self, cond_in: list[dict], x_in: torch.Tensor, window: ContextWindowABC, device=None) -> list:
+ raise NotImplementedError("Not implemented.")
+
+ @abstractmethod
+ def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
+ raise NotImplementedError("Not implemented.")
+
+
+
+class IndexListContextWindow(ContextWindowABC):
+ def __init__(self, index_list: list[int], dim: int=0):
+ self.index_list = index_list
+ self.context_length = len(index_list)
+ self.dim = dim
+
+ def get_tensor(self, full: torch.Tensor, device=None, dim=None) -> torch.Tensor:
+ if dim is None:
+ dim = self.dim
+ if dim == 0 and full.shape[dim] == 1:
+ return full
+ idx = [slice(None)] * dim + [self.index_list]
+ return full[idx].to(device)
+
+ def add_window(self, full: torch.Tensor, to_add: torch.Tensor, dim=None) -> torch.Tensor:
+ if dim is None:
+ dim = self.dim
+ idx = [slice(None)] * dim + [self.index_list]
+ full[idx] += to_add
+ return full
+
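+# Example (hypothetical values): with dim=0 and index_list=[2, 3, 4],
+# window.get_tensor(x) selects frames 2-4 of x, and window.add_window(full, out)
+# accumulates the window's result back into those same frames of the full tensor.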
+
+class IndexListCallbacks:
+ EVALUATE_CONTEXT_WINDOWS = "evaluate_context_windows"
+ COMBINE_CONTEXT_WINDOW_RESULTS = "combine_context_window_results"
+ EXECUTE_START = "execute_start"
+ EXECUTE_CLEANUP = "execute_cleanup"
+
+ def init_callbacks(self):
+ return {}
+
+
+@dataclass
+class ContextSchedule:
+ name: str
+ func: Callable
+
+@dataclass
+class ContextFuseMethod:
+ name: str
+ func: Callable
+
+ContextResults = collections.namedtuple("ContextResults", ['window_idx', 'sub_conds_out', 'sub_conds', 'window'])
+class IndexListContextHandler(ContextHandlerABC):
+ def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMethod, context_length: int=1, context_overlap: int=0, context_stride: int=1, closed_loop=False, dim=0):
+ self.context_schedule = context_schedule
+ self.fuse_method = fuse_method
+ self.context_length = context_length
+ self.context_overlap = context_overlap
+ self.context_stride = context_stride
+ self.closed_loop = closed_loop
+ self.dim = dim
+ self._step = 0
+
+ self.callbacks = {}
+
+ def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool:
+ # for now, assume first dim is batch - should be stored on BaseModel in an actual implementation
+ if x_in.size(self.dim) > self.context_length:
+ logging.info(f"Using context windows of length {self.context_length} for {x_in.size(self.dim)} frames.")
+ return True
+ return False
+
+ def prepare_control_objects(self, control: ControlBase, device=None) -> ControlBase:
+ if control.previous_controlnet is not None:
+ self.prepare_control_objects(control.previous_controlnet, device)
+ return control
+
+ def get_resized_cond(self, cond_in: list[dict], x_in: torch.Tensor, window: IndexListContextWindow, device=None) -> list:
+ if cond_in is None:
+ return None
+ # reuse or resize cond items to match context requirements
+ resized_cond = []
+ # cond object is a list containing a dict - outer list is irrelevant, so just loop through it
+ for actual_cond in cond_in:
+ resized_actual_cond = actual_cond.copy()
+ # now we are in the inner dict - "pooled_output" is a tensor, "control" is a ControlBase object, "model_conds" is dictionary
+ for key in actual_cond:
+ try:
+ cond_item = actual_cond[key]
+ if isinstance(cond_item, torch.Tensor):
+ # check that tensor is the expected length - x.size(0)
+ if self.dim < cond_item.ndim and cond_item.size(self.dim) == x_in.size(self.dim):
+ # if so, it's subsetting time - tell controls the expected indices so they can handle them
+ actual_cond_item = window.get_tensor(cond_item)
+ resized_actual_cond[key] = actual_cond_item.to(device)
+ else:
+ resized_actual_cond[key] = cond_item.to(device)
+ # look for control
+ elif key == "control":
+ resized_actual_cond[key] = self.prepare_control_objects(cond_item, device)
+ elif isinstance(cond_item, dict):
+ new_cond_item = cond_item.copy()
+ # when in dictionary, look for tensors and CONDCrossAttn [comfy/conds.py] (has cond attr that is a tensor)
+ for cond_key, cond_value in new_cond_item.items():
+ if isinstance(cond_value, torch.Tensor):
+ if self.dim < cond_value.ndim and cond_value.size(self.dim) == x_in.size(self.dim):
+ new_cond_item[cond_key] = window.get_tensor(cond_value, device)
+ # if has cond that is a Tensor, check if needs to be subset
+ elif hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
+ if self.dim < cond_value.cond.ndim and cond_value.cond.size(self.dim) == x_in.size(self.dim):
+ new_cond_item[cond_key] = cond_value._copy_with(window.get_tensor(cond_value.cond, device))
+ elif cond_key == "num_video_frames": # for SVD
+ new_cond_item[cond_key] = cond_value._copy_with(cond_value.cond)
+ new_cond_item[cond_key].cond = window.context_length
+ resized_actual_cond[key] = new_cond_item
+ else:
+ resized_actual_cond[key] = cond_item
+ finally:
+ del cond_item # just in case to prevent VRAM issues
+ resized_cond.append(resized_actual_cond)
+ return resized_cond
+
+ def set_step(self, timestep: torch.Tensor, model_options: dict[str]):
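+ # find which sampling step we are on by matching the current sigma against the full sigma schedule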
+ indexes = torch.where(model_options["transformer_options"]["sample_sigmas"] == timestep[0])
+ self._step = int(indexes[0])
+
+ def get_context_windows(self, model: BaseModel, x_in: torch.Tensor, model_options: dict[str]) -> list[IndexListContextWindow]:
+ full_length = x_in.size(self.dim) # TODO: choose dim based on model
+ context_windows = self.context_schedule.func(full_length, self, model_options)
+ context_windows = [IndexListContextWindow(window, dim=self.dim) for window in context_windows]
+ return context_windows
+
+ def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
+ self.set_step(timestep, model_options)
+ context_windows = self.get_context_windows(model, x_in, model_options)
+ enumerated_context_windows = list(enumerate(context_windows))
+
+ conds_final = [torch.zeros_like(x_in) for _ in conds]
+ if self.fuse_method.name == ContextFuseMethods.RELATIVE:
+ counts_final = [torch.ones(get_shape_for_dim(x_in, self.dim), device=x_in.device) for _ in conds]
+ else:
+ counts_final = [torch.zeros(get_shape_for_dim(x_in, self.dim), device=x_in.device) for _ in conds]
+ biases_final = [([0.0] * x_in.shape[self.dim]) for _ in conds]
+
+ for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_START, self.callbacks):
+ callback(self, model, x_in, conds, timestep, model_options)
+
+ for enum_window in enumerated_context_windows:
+ results = self.evaluate_context_windows(calc_cond_batch, model, x_in, conds, timestep, [enum_window], model_options)
+ for result in results:
+ self.combine_context_window_results(x_in, result.sub_conds_out, result.sub_conds, result.window, result.window_idx, len(enumerated_context_windows), timestep,
+ conds_final, counts_final, biases_final)
+ try:
+ # finalize conds
+ if self.fuse_method.name == ContextFuseMethods.RELATIVE:
+ # relative is already normalized, so return as is
+ del counts_final
+ return conds_final
+ else:
+ # normalize conds via division by context usage counts
+ for i in range(len(conds_final)):
+ conds_final[i] /= counts_final[i]
+ del counts_final
+ return conds_final
+ finally:
+ for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_CLEANUP, self.callbacks):
+ callback(self, model, x_in, conds, timestep, model_options)
+
+ def evaluate_context_windows(self, calc_cond_batch: Callable, model: BaseModel, x_in: torch.Tensor, conds, timestep: torch.Tensor, enumerated_context_windows: list[tuple[int, IndexListContextWindow]],
+ model_options, device=None, first_device=None):
+ results: list[ContextResults] = []
+ for window_idx, window in enumerated_context_windows:
+ # allow processing to end between context window executions for faster Cancel
+ comfy.model_management.throw_exception_if_processing_interrupted()
+
+ for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EVALUATE_CONTEXT_WINDOWS, self.callbacks):
+ callback(self, model, x_in, conds, timestep, model_options, window_idx, window, model_options, device, first_device)
+
+ # update exposed params
+ model_options["transformer_options"]["context_window"] = window
+ # get subsections of x, timestep, conds
+ sub_x = window.get_tensor(x_in, device)
+ sub_timestep = window.get_tensor(timestep, device, dim=0)
+ sub_conds = [self.get_resized_cond(cond, x_in, window, device) for cond in conds]
+
+ sub_conds_out = calc_cond_batch(model, sub_conds, sub_x, sub_timestep, model_options)
+ if device is not None:
+ for i in range(len(sub_conds_out)):
+ sub_conds_out[i] = sub_conds_out[i].to(x_in.device)
+ results.append(ContextResults(window_idx, sub_conds_out, sub_conds, window))
+ return results
+
+
+ def combine_context_window_results(self, x_in: torch.Tensor, sub_conds_out, sub_conds, window: IndexListContextWindow, window_idx: int, total_windows: int, timestep: torch.Tensor,
+ conds_final: list[torch.Tensor], counts_final: list[torch.Tensor], biases_final: list[torch.Tensor]):
+ if self.fuse_method.name == ContextFuseMethods.RELATIVE:
+ for pos, idx in enumerate(window.index_list):
+ # bias is the influence of a specific index in relation to the whole context window
+ bias = 1 - abs(idx - (window.index_list[0] + window.index_list[-1]) / 2) / ((window.index_list[-1] - window.index_list[0] + 1e-2) / 2)
+ bias = max(1e-2, bias)
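+ # e.g. for a window covering indexes 0..15, central indexes get bias close to 1.0
+ # while the first/last indexes fall to the 1e-2 floor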
+ # take weighted average relative to total bias of current idx
+ for i in range(len(sub_conds_out)):
+ bias_total = biases_final[i][idx]
+ prev_weight = (bias_total / (bias_total + bias))
+ new_weight = (bias / (bias_total + bias))
+ # account for dims of tensors
+ idx_window = [slice(None)] * self.dim + [idx]
+ pos_window = [slice(None)] * self.dim + [pos]
+ # apply new values
+ conds_final[i][idx_window] = conds_final[i][idx_window] * prev_weight + sub_conds_out[i][pos_window] * new_weight
+ biases_final[i][idx] = bias_total + bias
+ else:
+ # add conds and counts based on weights of fuse method
+ weights = get_context_weights(window.context_length, x_in.shape[self.dim], window.index_list, self, sigma=timestep)
+ weights_tensor = match_weights_to_dim(weights, x_in, self.dim, device=x_in.device)
+ for i in range(len(sub_conds_out)):
+ window.add_window(conds_final[i], sub_conds_out[i] * weights_tensor)
+ window.add_window(counts_final[i], weights_tensor)
+
+ for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.COMBINE_CONTEXT_WINDOW_RESULTS, self.callbacks):
+ callback(self, x_in, sub_conds_out, sub_conds, window, window_idx, total_windows, timestep, conds_final, counts_final, biases_final)
+
+
+def _prepare_sampling_wrapper(executor, model, noise_shape: torch.Tensor, *args, **kwargs):
+ # limit noise_shape length to context_length for more accurate vram use estimation
+ model_options = kwargs.get("model_options", None)
+ if model_options is None:
+ raise Exception("model_options not found in prepare_sampling_wrapper; this should never happen, something went wrong.")
+ handler: IndexListContextHandler = model_options.get("context_handler", None)
+ if handler is not None:
+ noise_shape = list(noise_shape)
+ noise_shape[handler.dim] = min(noise_shape[handler.dim], handler.context_length)
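+ # e.g. with context_length=16, an 81-frame noise shape is budgeted as if it were 16 frames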
+ return executor(model, noise_shape, *args, **kwargs)
+
+
+def create_prepare_sampling_wrapper(model: ModelPatcher):
+ model.add_wrapper_with_key(
+ comfy.patcher_extension.WrappersMP.PREPARE_SAMPLING,
+ "ContextWindows_prepare_sampling",
+ _prepare_sampling_wrapper
+ )
+
+
+def match_weights_to_dim(weights: list[float], x_in: torch.Tensor, dim: int, device=None) -> torch.Tensor:
+ total_dims = len(x_in.shape)
+ weights_tensor = torch.Tensor(weights).to(device=device)
+ for _ in range(dim):
+ weights_tensor = weights_tensor.unsqueeze(0)
+ for _ in range(total_dims - dim - 1):
+ weights_tensor = weights_tensor.unsqueeze(-1)
+ return weights_tensor
+
+def get_shape_for_dim(x_in: torch.Tensor, dim: int) -> list[int]:
+ total_dims = len(x_in.shape)
+ shape = []
+ for _ in range(dim):
+ shape.append(1)
+ shape.append(x_in.shape[dim])
+ for _ in range(total_dims - dim - 1):
+ shape.append(1)
+ return shape
+
+class ContextSchedules:
+ UNIFORM_LOOPED = "looped_uniform"
+ UNIFORM_STANDARD = "standard_uniform"
+ STATIC_STANDARD = "standard_static"
+ BATCHED = "batched"
+
+
+# from https://github.com/neggles/animatediff-cli/blob/main/src/animatediff/pipelines/context.py
+def create_windows_uniform_looped(num_frames: int, handler: IndexListContextHandler, model_options: dict[str]):
+ windows = []
+ if num_frames < handler.context_length:
+ windows.append(list(range(num_frames)))
+ return windows
+
+ context_stride = min(handler.context_stride, int(np.ceil(np.log2(num_frames / handler.context_length))) + 1)
+ # obtain uniform windows as normal, looping and all
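+ # context_step runs over powers of 2: e.g. with context_stride=2, 1 << np.arange(2) -> [1, 2],
+ # producing one pass of contiguous windows and one pass that samples every 2nd frame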
+ for context_step in 1 << np.arange(context_stride):
+ pad = int(round(num_frames * ordered_halving(handler._step)))
+ for j in range(
+ int(ordered_halving(handler._step) * context_step) + pad,
+ num_frames + pad + (0 if handler.closed_loop else -handler.context_overlap),
+ (handler.context_length * context_step - handler.context_overlap),
+ ):
+ windows.append([e % num_frames for e in range(j, j + handler.context_length * context_step, context_step)])
+
+ return windows
+
+def create_windows_uniform_standard(num_frames: int, handler: IndexListContextHandler, model_options: dict[str]):
+ # unlike looped, standard_uniform does NOT allow windows that loop back to the beginning;
+ # instead, they get shifted to the corresponding end of the frames.
+ # in the case that a window (shifted or not) is identical to the previous one, it gets skipped.
+ windows = []
+ if num_frames <= handler.context_length:
+ windows.append(list(range(num_frames)))
+ return windows
+
+ context_stride = min(handler.context_stride, int(np.ceil(np.log2(num_frames / handler.context_length))) + 1)
+ # first, obtain uniform windows as normal, looping and all
+ for context_step in 1 << np.arange(context_stride):
+ pad = int(round(num_frames * ordered_halving(handler._step)))
+ for j in range(
+ int(ordered_halving(handler._step) * context_step) + pad,
+ num_frames + pad + (-handler.context_overlap),
+ (handler.context_length * context_step - handler.context_overlap),
+ ):
+ windows.append([e % num_frames for e in range(j, j + handler.context_length * context_step, context_step)])
+
+ # now that windows are created, shift any windows that loop, and delete duplicate windows
+ delete_idxs = []
+ win_i = 0
+ while win_i < len(windows):
+ # if window rolls over itself, need to shift it
+ is_roll, roll_idx = does_window_roll_over(windows[win_i], num_frames)
+ if is_roll:
+ roll_val = windows[win_i][roll_idx] # roll_val might not be 0 for windows of higher strides
+ shift_window_to_end(windows[win_i], num_frames=num_frames)
+ # check if next window (cyclical) is missing roll_val
+ if roll_val not in windows[(win_i+1) % len(windows)]:
+ # need to insert new window here - just insert window starting at roll_val
+ windows.insert(win_i+1, list(range(roll_val, roll_val + handler.context_length)))
+ # delete window if it's not unique
+ for pre_i in range(0, win_i):
+ if windows[win_i] == windows[pre_i]:
+ delete_idxs.append(win_i)
+ break
+ win_i += 1
+
+ # reverse delete_idxs so that they will be deleted in an order that doesn't break idx correlation
+ delete_idxs.reverse()
+ for i in delete_idxs:
+ windows.pop(i)
+
+ return windows
+
+
+def create_windows_static_standard(num_frames: int, handler: IndexListContextHandler, model_options: dict[str]):
+ windows = []
+ if num_frames <= handler.context_length:
+ windows.append(list(range(num_frames)))
+ return windows
+ # always return the same set of windows
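+ # e.g. num_frames=10, context_length=4, context_overlap=1 -> delta=3 ->
+ # windows [0..3], [3..6], [6..9] (the final window is shifted back to keep full context_length)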
+ delta = handler.context_length - handler.context_overlap
+ for start_idx in range(0, num_frames, delta):
+ # if past the end of frames, move start_idx back to allow same context_length
+ ending = start_idx + handler.context_length
+ if ending >= num_frames:
+ final_delta = ending - num_frames
+ final_start_idx = start_idx - final_delta
+ windows.append(list(range(final_start_idx, final_start_idx + handler.context_length)))
+ break
+ windows.append(list(range(start_idx, start_idx + handler.context_length)))
+ return windows
+
+
+def create_windows_batched(num_frames: int, handler: IndexListContextHandler, model_options: dict[str]):
+ windows = []
+ if num_frames <= handler.context_length:
+ windows.append(list(range(num_frames)))
+ return windows
+ # always return the same set of windows;
+ # no overlap, just cut up based on context_length;
+ # last window size will be different if num_frames % handler.context_length != 0
+ for start_idx in range(0, num_frames, handler.context_length):
+ windows.append(list(range(start_idx, min(start_idx + handler.context_length, num_frames))))
+ return windows
+
+
+def create_windows_default(num_frames: int, handler: IndexListContextHandler):
+ return [list(range(num_frames))]
+
+
+CONTEXT_MAPPING = {
+ ContextSchedules.UNIFORM_LOOPED: create_windows_uniform_looped,
+ ContextSchedules.UNIFORM_STANDARD: create_windows_uniform_standard,
+ ContextSchedules.STATIC_STANDARD: create_windows_static_standard,
+ ContextSchedules.BATCHED: create_windows_batched,
+}
+
+
+def get_matching_context_schedule(context_schedule: str) -> ContextSchedule:
+ func = CONTEXT_MAPPING.get(context_schedule, None)
+ if func is None:
+ raise ValueError(f"Unknown context_schedule '{context_schedule}'.")
+ return ContextSchedule(context_schedule, func)
+
+
+def get_context_weights(length: int, full_length: int, idxs: list[int], handler: IndexListContextHandler, sigma: torch.Tensor=None):
+ return handler.fuse_method.func(length, sigma=sigma, handler=handler, full_length=full_length, idxs=idxs)
+
+
+def create_weights_flat(length: int, **kwargs) -> list[float]:
+ # weight is the same for all
+ return [1.0] * length
+
+def create_weights_pyramid(length: int, **kwargs) -> list[float]:
+ # weight is based on the distance away from the edge of the context window;
+ # based on weighted average concept in FreeNoise paper
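+ # e.g. length=4 -> [1, 2, 2, 1]; length=5 -> [1, 2, 3, 2, 1]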
+ if length % 2 == 0:
+ max_weight = length // 2
+ weight_sequence = list(range(1, max_weight + 1, 1)) + list(range(max_weight, 0, -1))
+ else:
+ max_weight = (length + 1) // 2
+ weight_sequence = list(range(1, max_weight, 1)) + [max_weight] + list(range(max_weight - 1, 0, -1))
+ return weight_sequence
+
+def create_weights_overlap_linear(length: int, full_length: int, idxs: list[int], handler: IndexListContextHandler, **kwargs):
+ # based on code in Kijai's WanVideoWrapper: https://github.com/kijai/ComfyUI-WanVideoWrapper/blob/dbb2523b37e4ccdf45127e5ae33e31362f755c8e/nodes.py#L1302
+ # only the expected overlap is given different weights
+ weights_torch = torch.ones((length))
+ # blend left-side on all except first window
+ if min(idxs) > 0:
+ ramp_up = torch.linspace(1e-37, 1, handler.context_overlap)
+ weights_torch[:handler.context_overlap] = ramp_up
+ # blend right-side on all except last window
+ if max(idxs) < full_length-1:
+ ramp_down = torch.linspace(1, 1e-37, handler.context_overlap)
+ weights_torch[-handler.context_overlap:] = ramp_down
+ return weights_torch
+
+class ContextFuseMethods:
+ FLAT = "flat"
+ PYRAMID = "pyramid"
+ RELATIVE = "relative"
+ OVERLAP_LINEAR = "overlap-linear"
+
+ LIST = [PYRAMID, FLAT, OVERLAP_LINEAR]
+ LIST_STATIC = [PYRAMID, RELATIVE, FLAT, OVERLAP_LINEAR]
+
+
+FUSE_MAPPING = {
+ ContextFuseMethods.FLAT: create_weights_flat,
+ ContextFuseMethods.PYRAMID: create_weights_pyramid,
+ ContextFuseMethods.RELATIVE: create_weights_pyramid,
+ ContextFuseMethods.OVERLAP_LINEAR: create_weights_overlap_linear,
+}
+
+def get_matching_fuse_method(fuse_method: str) -> ContextFuseMethod:
+ func = FUSE_MAPPING.get(fuse_method, None)
+ if func is None:
+ raise ValueError(f"Unknown fuse_method '{fuse_method}'.")
+ return ContextFuseMethod(fuse_method, func)
+
+# Returns a fraction whose denominator is a power of 2
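+# e.g. ordered_halving(1) == 0.5, ordered_halving(2) == 0.25, ordered_halving(3) == 0.75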
+def ordered_halving(val):
+ # get binary value, padded with 0s for 64 bits
+ bin_str = f"{val:064b}"
+ # flip binary value, padding included
+ bin_flip = bin_str[::-1]
+ # convert binary to int
+ as_int = int(bin_flip, 2)
+ # divide by 1 << 64, equivalent to 2**64, or 18446744073709551616 (in binary: a 1 followed by 64 zeros)
+ return as_int / (1 << 64)
+
+
+def get_missing_indexes(windows: list[list[int]], num_frames: int) -> list[int]:
+ all_indexes = list(range(num_frames))
+ for w in windows:
+ for val in w:
+ try:
+ all_indexes.remove(val)
+ except ValueError:
+ pass
+ return all_indexes
+
+
+def does_window_roll_over(window: list[int], num_frames: int) -> tuple[bool, int]:
+ prev_val = -1
+ for i, val in enumerate(window):
+ val = val % num_frames
+ if val < prev_val:
+ return True, i
+ prev_val = val
+ return False, -1
+
+
+def shift_window_to_start(window: list[int], num_frames: int):
+ start_val = window[0]
+ for i in range(len(window)):
+ # 1) subtract each element by start_val to move vals relative to the start of all frames
+ # 2) add num_frames and take modulus to get adjusted vals
+ window[i] = ((window[i] - start_val) + num_frames) % num_frames
+
+
+def shift_window_to_end(window: list[int], num_frames: int):
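+ # e.g. with num_frames=16, a rolled-over window [14, 15, 0, 1] becomes [12, 13, 14, 15]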
+ # 1) shift window to start
+ shift_window_to_start(window, num_frames)
+ end_val = window[-1]
+ end_delta = num_frames - end_val - 1
+ for i in range(len(window)):
+ # 2) add end_delta to each val to slide windows to end
+ window[i] = window[i] + end_delta
diff --git a/comfy/ldm/hunyuan3d/vae.py b/comfy/ldm/hunyuan3d/vae.py
index 5eb2c6548..6e8cbf1d9 100644
--- a/comfy/ldm/hunyuan3d/vae.py
+++ b/comfy/ldm/hunyuan3d/vae.py
@@ -178,7 +178,7 @@ class FourierEmbedder(nn.Module):
class CrossAttentionProcessor:
def __call__(self, attn, q, k, v):
- out = F.scaled_dot_product_attention(q, k, v)
+ out = comfy.ops.scaled_dot_product_attention(q, k, v)
return out
diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index 35d2270ee..043df28df 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -448,7 +448,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
mask = mask.unsqueeze(1)
if SDP_BATCH_LIMIT >= b:
- out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
+ out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
if not skip_output_reshape:
out = (
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
@@ -461,7 +461,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
if mask.shape[0] > 1:
m = mask[i : i + SDP_BATCH_LIMIT]
- out[i : i + SDP_BATCH_LIMIT] = torch.nn.functional.scaled_dot_product_attention(
+ out[i : i + SDP_BATCH_LIMIT] = comfy.ops.scaled_dot_product_attention(
q[i : i + SDP_BATCH_LIMIT],
k[i : i + SDP_BATCH_LIMIT],
v[i : i + SDP_BATCH_LIMIT],
diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py
index 5c0373b74..1fd12b35a 100644
--- a/comfy/ldm/modules/diffusionmodules/model.py
+++ b/comfy/ldm/modules/diffusionmodules/model.py
@@ -285,7 +285,7 @@ def pytorch_attention(q, k, v):
)
try:
- out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
+ out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
out = out.transpose(2, 3).reshape(orig_shape)
except model_management.OOM_EXCEPTION:
logging.warning("scaled_dot_product_attention OOMed: switched to slice attention")
diff --git a/comfy/model_management.py b/comfy/model_management.py
index c08f759e5..2a9f18068 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -78,7 +78,6 @@ try:
torch_version = torch.version.__version__
temp = torch_version.split(".")
torch_version_numeric = (int(temp[0]), int(temp[1]))
- xpu_available = (torch_version_numeric[0] < 2 or (torch_version_numeric[0] == 2 and torch_version_numeric[1] <= 4)) and torch.xpu.is_available()
except:
pass
@@ -102,10 +101,14 @@ if args.directml is not None:
try:
import intel_extension_for_pytorch as ipex # noqa: F401
- _ = torch.xpu.device_count()
- xpu_available = xpu_available or torch.xpu.is_available()
except:
- xpu_available = xpu_available or (hasattr(torch, "xpu") and torch.xpu.is_available())
+ pass
+
+try:
+ _ = torch.xpu.device_count()
+ xpu_available = torch.xpu.is_available()
+except:
+ xpu_available = False
try:
if torch.backends.mps.is_available():
@@ -946,10 +949,12 @@ def pick_weight_dtype(dtype, fallback_dtype, device=None):
return dtype
def device_supports_non_blocking(device):
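+ # non-blocking copies let host-device transfers overlap with other work,
+ # but are only reliable/fast on some backends, hence the checks below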
+ if args.force_non_blocking:
+ return True
if is_device_mps(device):
return False #pytorch bug? mps doesn't support non blocking
- if is_intel_xpu():
- return True
+ if is_intel_xpu(): #xpu does support non blocking but it is slower on iGPUs for some reason so disable by default until situation changes
+ return False
if args.deterministic: #TODO: figure out why deterministic breaks non blocking from gpu to cpu (previews)
return False
if directml_enabled:
@@ -1282,10 +1287,10 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
return False
if is_intel_xpu():
- if torch_version_numeric < (2, 6):
+ if torch_version_numeric < (2, 3):
return True
else:
- return torch.xpu.get_device_capability(device)['has_bfloat16_conversions']
+ return torch.xpu.is_bf16_supported()
if is_ascend_npu():
return True
diff --git a/comfy/ops.py b/comfy/ops.py
index 2cc9bbc27..be312d714 100644
--- a/comfy/ops.py
+++ b/comfy/ops.py
@@ -24,6 +24,29 @@ import comfy.float
import comfy.rmsnorm
import contextlib
+
+def scaled_dot_product_attention(q, k, v, *args, **kwargs):
+ return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)
+
+
+try:
+ if torch.cuda.is_available():
+ from torch.nn.attention import SDPBackend, sdpa_kernel
+
+ SDPA_BACKEND_PRIORITY = [
+ SDPBackend.FLASH_ATTENTION,
+ SDPBackend.EFFICIENT_ATTENTION,
+ SDPBackend.MATH,
+ ]
+
+ SDPA_BACKEND_PRIORITY.insert(0, SDPBackend.CUDNN_ATTENTION)
+
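+ # sdpa_kernel tries the listed backends in priority order, falling back to the
+ # next entry when a backend cannot handle the given inputs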
+ @sdpa_kernel(backends=SDPA_BACKEND_PRIORITY, set_priority=True)
+ def scaled_dot_product_attention(q, k, v, *args, **kwargs):
+ return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)
+except (ModuleNotFoundError, TypeError):
+ logging.warning("Could not set sdpa backend priority.")
+
cast_to = comfy.model_management.cast_to #TODO: remove once no more references
def cast_to_input(weight, input, non_blocking=False, copy=True):
diff --git a/comfy/sampler_helpers.py b/comfy/sampler_helpers.py
index 8dbc41455..e46971afb 100644
--- a/comfy/sampler_helpers.py
+++ b/comfy/sampler_helpers.py
@@ -149,7 +149,7 @@ def cleanup_models(conds, models):
cleanup_additional_models(set(control_cleanup))
-def prepare_model_patcher(model: 'ModelPatcher', conds, model_options: dict):
+def prepare_model_patcher(model: ModelPatcher, conds, model_options: dict):
'''
Registers hooks from conds.
'''
@@ -158,8 +158,8 @@ def prepare_model_patcher(model: 'ModelPatcher', conds, model_options: dict):
for k in conds:
get_hooks_from_cond(conds[k], hooks)
# add wrappers and callbacks from ModelPatcher to transformer_options
- model_options["transformer_options"]["wrappers"] = comfy.patcher_extension.copy_nested_dicts(model.wrappers)
- model_options["transformer_options"]["callbacks"] = comfy.patcher_extension.copy_nested_dicts(model.callbacks)
+ comfy.patcher_extension.merge_nested_dicts(model_options["transformer_options"].setdefault("wrappers", {}), model.wrappers, copy_dict1=False)
+ comfy.patcher_extension.merge_nested_dicts(model_options["transformer_options"].setdefault("callbacks", {}), model.callbacks, copy_dict1=False)
# begin registering hooks
registered = comfy.hooks.HookGroup()
target_dict = comfy.hooks.create_target_dict(comfy.hooks.EnumWeightTarget.Model)
diff --git a/comfy/samplers.py b/comfy/samplers.py
index ad2f40cdc..d5390d64e 100644
--- a/comfy/samplers.py
+++ b/comfy/samplers.py
@@ -16,6 +16,7 @@ import comfy.sampler_helpers
import comfy.model_patcher
import comfy.patcher_extension
import comfy.hooks
+import comfy.context_windows
import scipy.stats
import numpy
@@ -198,14 +199,20 @@ def finalize_default_conds(model: 'BaseModel', hooked_to_run: dict[comfy.hooks.H
hooked_to_run.setdefault(p.hooks, list())
hooked_to_run[p.hooks] += [(p, i)]
-def calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
+def calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options: dict[str]):
+ handler: comfy.context_windows.ContextHandlerABC = model_options.get("context_handler", None)
+ if handler is None or not handler.should_use_context(model, conds, x_in, timestep, model_options):
+ return _calc_cond_batch_outer(model, conds, x_in, timestep, model_options)
+ return handler.execute(_calc_cond_batch_outer, model, conds, x_in, timestep, model_options)
+
+def _calc_cond_batch_outer(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
executor = comfy.patcher_extension.WrapperExecutor.new_executor(
_calc_cond_batch,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.CALC_COND_BATCH, model_options, is_model_options=True)
)
return executor.execute(model, conds, x_in, timestep, model_options)
-def _calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
+def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
out_conds = []
out_counts = []
# separate conds by matching hooks
diff --git a/comfy_extras/nodes_context_windows.py b/comfy_extras/nodes_context_windows.py
new file mode 100644
index 000000000..1c3d9e697
--- /dev/null
+++ b/comfy_extras/nodes_context_windows.py
@@ -0,0 +1,89 @@
+from __future__ import annotations
+from comfy_api.latest import ComfyExtension, io
+import comfy.context_windows
+import nodes
+
+
+class ContextWindowsManualNode(io.ComfyNode):
+ @classmethod
+ def define_schema(cls) -> io.Schema:
+ return io.Schema(
+ node_id="ContextWindowsManual",
+ display_name="Context Windows (Manual)",
+ category="context",
+ description="Manually set context windows.",
+ inputs=[
+ io.Model.Input("model", tooltip="The model to apply context windows to during sampling."),
+ io.Int.Input("context_length", min=1, default=16, tooltip="The length of the context window."),
+ io.Int.Input("context_overlap", min=0, default=4, tooltip="The overlap of the context window."),
+ io.Combo.Input("context_schedule", options=[
+ comfy.context_windows.ContextSchedules.STATIC_STANDARD,
+ comfy.context_windows.ContextSchedules.UNIFORM_STANDARD,
+ comfy.context_windows.ContextSchedules.UNIFORM_LOOPED,
+ comfy.context_windows.ContextSchedules.BATCHED,
+ ], tooltip="The schedule of the context window."),
+ io.Int.Input("context_stride", min=1, default=1, tooltip="The stride of the context window; only applicable to uniform schedules."),
+ io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules."),
+ io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."),
+ io.Int.Input("dim", min=0, max=5, default=0, tooltip="The dimension to apply the context windows to."),
+ ],
+ outputs=[
+ io.Model.Output(tooltip="The model with context windows applied during sampling."),
+ ],
+ is_experimental=True,
+ )
+
+ @classmethod
+ def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, dim: int) -> io.Model:
+ model = model.clone()
+ model.model_options["context_handler"] = comfy.context_windows.IndexListContextHandler(
+ context_schedule=comfy.context_windows.get_matching_context_schedule(context_schedule),
+ fuse_method=comfy.context_windows.get_matching_fuse_method(fuse_method),
+ context_length=context_length,
+ context_overlap=context_overlap,
+ context_stride=context_stride,
+ closed_loop=closed_loop,
+ dim=dim)
+ # make memory usage calculation only take into account the context window latents
+ comfy.context_windows.create_prepare_sampling_wrapper(model)
+ return io.NodeOutput(model)
+
+class WanContextWindowsManualNode(ContextWindowsManualNode):
+ @classmethod
+ def define_schema(cls) -> io.Schema:
+ schema = super().define_schema()
+ schema.node_id = "WanContextWindowsManual"
+ schema.display_name = "WAN Context Windows (Manual)"
+ schema.description = "Manually set context windows for WAN-like models (dim=2)."
+ schema.inputs = [
+ io.Model.Input("model", tooltip="The model to apply context windows to during sampling."),
+ io.Int.Input("context_length", min=1, max=nodes.MAX_RESOLUTION, step=4, default=81, tooltip="The length of the context window."),
+ io.Int.Input("context_overlap", min=0, default=30, tooltip="The overlap of the context window."),
+ io.Combo.Input("context_schedule", options=[
+ comfy.context_windows.ContextSchedules.STATIC_STANDARD,
+ comfy.context_windows.ContextSchedules.UNIFORM_STANDARD,
+ comfy.context_windows.ContextSchedules.UNIFORM_LOOPED,
+ comfy.context_windows.ContextSchedules.BATCHED,
+ ], tooltip="The schedule of the context window."),
+ io.Int.Input("context_stride", min=1, default=1, tooltip="The stride of the context window; only applicable to uniform schedules."),
+ io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules."),
+ io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."),
+ ]
+ return schema
+
+ @classmethod
+ def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str) -> io.Model:
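+ # WAN-style VAEs compress time 4x, so convert pixel-frame counts to latent frames:
+ # e.g. the default context_length=81 -> ((81 - 1) // 4) + 1 = 21 latent frames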
+ context_length = max(((context_length - 1) // 4) + 1, 1) # at least length 1
+ context_overlap = max(((context_overlap - 1) // 4) + 1, 0) # at least overlap 0
+ return super().execute(model, context_length, context_overlap, context_schedule, context_stride, closed_loop, fuse_method, dim=2)
+
+
+class ContextWindowsExtension(ComfyExtension):
+ async def get_node_list(self) -> list[type[io.ComfyNode]]:
+ return [
+ ContextWindowsManualNode,
+ WanContextWindowsManualNode,
+ ]
+
+def comfy_entrypoint():
+ return ContextWindowsExtension()
diff --git a/nodes.py b/nodes.py
index 704d39dcd..1ec588487 100644
--- a/nodes.py
+++ b/nodes.py
@@ -2329,6 +2329,7 @@ async def init_builtin_extra_nodes():
"nodes_camera_trajectory.py",
"nodes_edit_model.py",
"nodes_tcfg.py",
+ "nodes_context_windows.py",
]
import_failed = []