Update to 0.3.51

doctorpangloss 2025-08-22 17:29:18 -07:00
parent 664349eabf
commit 735a133ad4
19 changed files with 131 additions and 96 deletions

View File

@@ -78,7 +78,6 @@ class ModelFileManager:
         return web.Response(status=404)

     def get_model_file_list(self, folder_name: str):
-        folder_name = folder_paths.map_legacy(folder_name)
         folders = folder_paths.folder_names_and_paths[folder_name]
         output_list: list[dict] = []

View File

@@ -45,6 +45,7 @@ logging.getLogger("torch.distributed.elastic.multiprocessing.redirects").addFilt
     lambda record: log_msg_to_filter not in record.getMessage()
 )
 logging.getLogger("alembic.runtime.migration").setLevel(logging.WARNING)
+logging.getLogger("asyncio").addFilter(lambda record: 'Using selector:' not in record.getMessage())

 from ..cli_args import args
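Note: asyncio logs a line like "Using selector: EpollSelector" at DEBUG level whenever an event loop is created, and this filter suppresses only that record while keeping asyncio's other output. A minimal sketch of the behavior in isolation:

import asyncio
import logging

logging.basicConfig(level=logging.DEBUG)
# drop only asyncio's selector announcement, keep its other DEBUG records
logging.getLogger("asyncio").addFilter(
    lambda record: 'Using selector:' not in record.getMessage()
)
asyncio.run(asyncio.sleep(0))  # would otherwise emit "Using selector: ..."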

View File

@@ -2,6 +2,7 @@ from __future__ import annotations

 import logging
 import os
+import platform
 import time
 from concurrent.futures import Future
 from pathlib import Path
@@ -17,16 +18,19 @@ from pebble import ThreadPool
 from .tqdm_watcher import TqdmWatcher

+logger = logging.getLogger(__name__)
+
 _VAR = "HF_HUB_ENABLE_HF_TRANSFER"
 _XET_VAR = "HF_XET_HIGH_PERFORMANCE"
-os.environ[_VAR] = "True"
-os.environ["HF_HUB_DISABLE_XET"] = "1"
-# os.environ["HF_XET_HIGH_PERFORMANCE"] = "True"
-
-logger = logging.getLogger(__name__)
-logger.debug("Xet was disabled since it is currently not reliable")
+if platform.system() == "Windows":
+    os.environ["HF_HUB_DISABLE_XET"] = "1"
+    logger.debug("Xet was disabled since it is currently not reliable")
+    os.environ[_VAR] = "True"
+else:
+    os.environ[_XET_VAR] = "True"

 def hf_hub_download_with_disable_fast(repo_id=None, filename=None, disable_fast=None, hf_env: dict[str, str] = None, **kwargs):
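Note: huggingface_hub reads these flags when downloads run, so they have to be set before the library is exercised; that is why the module sets os.environ at import time. A standalone sketch of the same platform gate (flag names taken from the diff; the comments are inferred intent, not the author's):

import logging
import os
import platform

logger = logging.getLogger(__name__)

if platform.system() == "Windows":
    # gate off Xet-backed downloads on Windows until they are reliable,
    # and use the hf_transfer accelerator instead
    os.environ["HF_HUB_DISABLE_XET"] = "1"
    logger.debug("Xet was disabled since it is currently not reliable")
    os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "True"
else:
    os.environ["HF_XET_HIGH_PERFORMANCE"] = "True"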

View File

@@ -6,12 +6,13 @@ import collections
 from dataclasses import dataclass
 from abc import ABC, abstractmethod
 import logging

-import comfy.model_management
-import comfy.patcher_extension
+from .model_management import throw_exception_if_processing_interrupted
+from .patcher_extension import get_all_callbacks, WrappersMP

 if TYPE_CHECKING:
-    from comfy.model_base import BaseModel
-    from comfy.model_patcher import ModelPatcher
-    from comfy.controlnet import ControlBase
+    from .model_base import BaseModel
+    from .model_patcher import ModelPatcher
+    from .controlnet import ControlBase

 class ContextWindowABC(ABC):
@@ -32,6 +33,7 @@ class ContextWindowABC(ABC):
         """
         raise NotImplementedError("Not implemented.")

+
 class ContextHandlerABC(ABC):
     def __init__(self):
         ...
@@ -49,9 +51,8 @@ class ContextHandlerABC(ABC):
         raise NotImplementedError("Not implemented.")

-
 class IndexListContextWindow(ContextWindowABC):
-    def __init__(self, index_list: list[int], dim: int=0):
+    def __init__(self, index_list: list[int], dim: int = 0):
         self.index_list = index_list
         self.context_length = len(index_list)
         self.dim = dim
@@ -87,14 +88,18 @@ class ContextSchedule:
     name: str
     func: Callable

+
 @dataclass
 class ContextFuseMethod:
     name: str
     func: Callable

+
 ContextResults = collections.namedtuple("ContextResults", ['window_idx', 'sub_conds_out', 'sub_conds', 'window'])

+
 class IndexListContextHandler(ContextHandlerABC):
-    def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMethod, context_length: int=1, context_overlap: int=0, context_stride: int=1, closed_loop=False, dim=0):
+    def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMethod, context_length: int = 1, context_overlap: int = 0, context_stride: int = 1, closed_loop=False, dim=0):
         self.context_schedule = context_schedule
         self.fuse_method = fuse_method
         self.context_length = context_length
@@ -152,7 +157,7 @@ class IndexListContextHandler(ContextHandlerABC):
                 elif hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
                     if cond_value.cond.ndim < self.dim and cond_value.cond.size(0) == x_in.size(self.dim):
                         new_cond_item[cond_key] = cond_value._copy_with(window.get_tensor(cond_value.cond, device))
-                elif cond_key == "num_video_frames": # for SVD
+                elif cond_key == "num_video_frames":  # for SVD
                     new_cond_item[cond_key] = cond_value._copy_with(cond_value.cond)
                     new_cond_item[cond_key].cond = window.context_length
             resized_actual_cond[key] = new_cond_item
@@ -171,7 +176,7 @@ class IndexListContextHandler(ContextHandlerABC):
             self._step = int(matches[0].item())

     def get_context_windows(self, model: BaseModel, x_in: torch.Tensor, model_options: dict[str]) -> list[IndexListContextWindow]:
-        full_length = x_in.size(self.dim) # TODO: choose dim based on model
+        full_length = x_in.size(self.dim)  # TODO: choose dim based on model
         context_windows = self.context_schedule.func(full_length, self, model_options)
         context_windows = [IndexListContextWindow(window, dim=self.dim) for window in context_windows]
         return context_windows
@@ -188,14 +193,14 @@ class IndexListContextHandler(ContextHandlerABC):
         counts_final = [torch.zeros(get_shape_for_dim(x_in, self.dim), device=x_in.device) for _ in conds]
         biases_final = [([0.0] * x_in.shape[self.dim]) for _ in conds]

-        for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_START, self.callbacks):
+        for callback in get_all_callbacks(IndexListCallbacks.EXECUTE_START, self.callbacks):
             callback(self, model, x_in, conds, timestep, model_options)

         for enum_window in enumerated_context_windows:
             results = self.evaluate_context_windows(calc_cond_batch, model, x_in, conds, timestep, [enum_window], model_options)
             for result in results:
                 self.combine_context_window_results(x_in, result.sub_conds_out, result.sub_conds, result.window, result.window_idx, len(enumerated_context_windows), timestep,
                                                     conds_final, counts_final, biases_final)
         try:
             # finalize conds
             if self.fuse_method.name == ContextFuseMethods.RELATIVE:
@@ -209,17 +214,17 @@ class IndexListContextHandler(ContextHandlerABC):
                 del counts_final
                 return conds_final
         finally:
-            for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_CLEANUP, self.callbacks):
+            for callback in get_all_callbacks(IndexListCallbacks.EXECUTE_CLEANUP, self.callbacks):
                 callback(self, model, x_in, conds, timestep, model_options)

     def evaluate_context_windows(self, calc_cond_batch: Callable, model: BaseModel, x_in: torch.Tensor, conds, timestep: torch.Tensor, enumerated_context_windows: list[tuple[int, IndexListContextWindow]],
                                  model_options, device=None, first_device=None):
         results: list[ContextResults] = []
         for window_idx, window in enumerated_context_windows:
             # allow processing to end between context window executions for faster Cancel
-            comfy.model_management.throw_exception_if_processing_interrupted()
+            throw_exception_if_processing_interrupted()

-            for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EVALUATE_CONTEXT_WINDOWS, self.callbacks):
+            for callback in get_all_callbacks(IndexListCallbacks.EVALUATE_CONTEXT_WINDOWS, self.callbacks):
                 callback(self, model, x_in, conds, timestep, model_options, window_idx, window, model_options, device, first_device)

             # update exposed params
@@ -236,9 +241,8 @@ class IndexListContextHandler(ContextHandlerABC):
             results.append(ContextResults(window_idx, sub_conds_out, sub_conds, window))
         return results

-
     def combine_context_window_results(self, x_in: torch.Tensor, sub_conds_out, sub_conds, window: IndexListContextWindow, window_idx: int, total_windows: int, timestep: torch.Tensor,
                                        conds_final: list[torch.Tensor], counts_final: list[torch.Tensor], biases_final: list[torch.Tensor]):
         if self.fuse_method.name == ContextFuseMethods.RELATIVE:
             for pos, idx in enumerate(window.index_list):
                 # bias is the influence of a specific index in relation to the whole context window
@@ -263,7 +267,7 @@ class IndexListContextHandler(ContextHandlerABC):
                 window.add_window(conds_final[i], sub_conds_out[i] * weights_tensor)
                 window.add_window(counts_final[i], weights_tensor)

-        for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.COMBINE_CONTEXT_WINDOW_RESULTS, self.callbacks):
+        for callback in get_all_callbacks(IndexListCallbacks.COMBINE_CONTEXT_WINDOW_RESULTS, self.callbacks):
             callback(self, x_in, sub_conds_out, sub_conds, window, window_idx, total_windows, timestep, conds_final, counts_final, biases_final)
@@ -281,7 +285,7 @@ def _prepare_sampling_wrapper(executor, model, noise_shape: torch.Tensor, *args,

 def create_prepare_sampling_wrapper(model: ModelPatcher):
     model.add_wrapper_with_key(
-        comfy.patcher_extension.WrappersMP.PREPARE_SAMPLING,
+        WrappersMP.PREPARE_SAMPLING,
         "ContextWindows_prepare_sampling",
         _prepare_sampling_wrapper
     )
@@ -296,6 +300,7 @@ def match_weights_to_dim(weights: list[float], x_in: torch.Tensor, dim: int, dev
         weights_tensor = weights_tensor.unsqueeze(-1)
     return weights_tensor

+
 def get_shape_for_dim(x_in: torch.Tensor, dim: int) -> list[int]:
     total_dims = len(x_in.shape)
     shape = []
@@ -306,6 +311,7 @@ def get_shape_for_dim(x_in: torch.Tensor, dim: int) -> list[int]:
             shape.append(1)
     return shape

+
 class ContextSchedules:
     UNIFORM_LOOPED = "looped_uniform"
     UNIFORM_STANDARD = "standard_uniform"
@@ -325,14 +331,15 @@ def create_windows_uniform_looped(num_frames: int, handler: IndexListContextHand
     for context_step in 1 << np.arange(context_stride):
         pad = int(round(num_frames * ordered_halving(handler._step)))
         for j in range(
                 int(ordered_halving(handler._step) * context_step) + pad,
                 num_frames + pad + (0 if handler.closed_loop else -handler.context_overlap),
                 (handler.context_length * context_step - handler.context_overlap),
         ):
             windows.append([e % num_frames for e in range(j, j + handler.context_length * context_step, context_step)])

     return windows

+
 def create_windows_uniform_standard(num_frames: int, handler: IndexListContextHandler, model_options: dict[str]):
     # unlike looped, uniform_straight does NOT allow windows that loop back to the beginning;
     # instead, they get shifted to the corresponding end of the frames.
@@ -347,9 +354,9 @@ def create_windows_uniform_standard(num_frames: int, handler: IndexListContextHa
     for context_step in 1 << np.arange(context_stride):
         pad = int(round(num_frames * ordered_halving(handler._step)))
         for j in range(
                 int(ordered_halving(handler._step) * context_step) + pad,
                 num_frames + pad + (-handler.context_overlap),
                 (handler.context_length * context_step - handler.context_overlap),
         ):
             windows.append([e % num_frames for e in range(j, j + handler.context_length * context_step, context_step)])
@@ -363,9 +370,9 @@ def create_windows_uniform_standard(num_frames: int, handler: IndexListContextHa
         roll_val = windows[win_i][roll_idx]  # roll_val might not be 0 for windows of higher strides
         shift_window_to_end(windows[win_i], num_frames=num_frames)
         # check if next window (cyclical) is missing roll_val
-        if roll_val not in windows[(win_i+1) % len(windows)]:
+        if roll_val not in windows[(win_i + 1) % len(windows)]:
             # need to insert new window here - just insert window starting at roll_val
-            windows.insert(win_i+1, list(range(roll_val, roll_val + handler.context_length)))
+            windows.insert(win_i + 1, list(range(roll_val, roll_val + handler.context_length)))
     # delete window if it's not unique
     for pre_i in range(0, win_i):
         if windows[win_i] == windows[pre_i]:
@@ -432,7 +439,7 @@ def get_matching_context_schedule(context_schedule: str) -> ContextSchedule:
     return ContextSchedule(context_schedule, func)

-def get_context_weights(length: int, full_length: int, idxs: list[int], handler: IndexListContextHandler, sigma: torch.Tensor=None):
+def get_context_weights(length: int, full_length: int, idxs: list[int], handler: IndexListContextHandler, sigma: torch.Tensor = None):
     return handler.fuse_method.func(length, sigma=sigma, handler=handler, full_length=full_length, idxs=idxs)
@@ -440,6 +447,7 @@ def create_weights_flat(length: int, **kwargs) -> list[float]:
     # weight is the same for all
     return [1.0] * length

+
 def create_weights_pyramid(length: int, **kwargs) -> list[float]:
     # weight is based on the distance away from the edge of the context window;
     # based on weighted average concept in FreeNoise paper
@@ -451,6 +459,7 @@ def create_weights_pyramid(length: int, **kwargs) -> list[float]:
         weight_sequence = list(range(1, max_weight, 1)) + [max_weight] + list(range(max_weight - 1, 0, -1))
     return weight_sequence

+
 def create_weights_overlap_linear(length: int, full_length: int, idxs: list[int], handler: IndexListContextHandler, **kwargs):
     # based on code in Kijai's WanVideoWrapper: https://github.com/kijai/ComfyUI-WanVideoWrapper/blob/dbb2523b37e4ccdf45127e5ae33e31362f755c8e/nodes.py#L1302
     # only expected overlap is given different weights
@@ -460,11 +469,12 @@ def create_weights_overlap_linear(length: int, full_length: int, idxs: list[int]
         ramp_up = torch.linspace(1e-37, 1, handler.context_overlap)
         weights_torch[:handler.context_overlap] = ramp_up
     # blend right-side on all except last window
-    if max(idxs) < full_length-1:
+    if max(idxs) < full_length - 1:
         ramp_down = torch.linspace(1, 1e-37, handler.context_overlap)
         weights_torch[-handler.context_overlap:] = ramp_down
     return weights_torch

+
 class ContextFuseMethods:
     FLAT = "flat"
     PYRAMID = "pyramid"
@@ -482,12 +492,14 @@ FUSE_MAPPING = {
     ContextFuseMethods.OVERLAP_LINEAR: create_weights_overlap_linear,
 }

+
 def get_matching_fuse_method(fuse_method: str) -> ContextFuseMethod:
     func = FUSE_MAPPING.get(fuse_method, None)
     if func is None:
         raise ValueError(f"Unknown fuse_method '{fuse_method}'.")
     return ContextFuseMethod(fuse_method, func)

+
 # Returns fraction that has denominator that is a power of 2
 def ordered_halving(val):
     # get binary value, padded with 0s for 64 bits
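Note: the fuse methods being touched here blend overlapping context windows by accumulating weighted outputs and weight counts per frame, then dividing. A standalone sketch of that idea with the pyramid weighting (the helper name, the 8-frame setup, and the random stand-in output are illustrative, not from this file):

import torch

def pyramid_weights(length: int) -> list[float]:
    # weight peaks at the window center and tapers toward the edges,
    # mirroring the FreeNoise-style weighted average referenced above
    if length % 2 == 0:
        max_weight = length // 2
        return list(range(1, max_weight + 1)) + list(range(max_weight, 0, -1))
    max_weight = (length + 1) // 2
    return list(range(1, max_weight)) + [max_weight] + list(range(max_weight - 1, 0, -1))

# two windows over an 8-frame sequence, overlapping on frames 2..5
windows = [[0, 1, 2, 3, 4, 5], [2, 3, 4, 5, 6, 7]]
acc, counts = torch.zeros(8), torch.zeros(8)
for idxs in windows:
    out = torch.randn(len(idxs))  # stand-in for one window's denoised output
    weights = torch.tensor(pyramid_weights(len(idxs)), dtype=torch.float32)
    for pos, frame in enumerate(idxs):
        acc[frame] += out[pos] * weights[pos]
        counts[frame] += weights[pos]
fused = acc / counts  # overlap frames become a weighted average of both windows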

View File

@@ -662,7 +662,7 @@ def load_controlnet_qwen_instantx(sd, model_options={}):
     model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd, model_options=model_options)
     control_model = QwenImageControlNetModel(operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
     control_model = controlnet_load_state_dict(control_model, sd)
-    latent_format = comfy.latent_formats.Wan21()
+    latent_format = latent_formats.Wan21()
     extra_conds = []
     control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
     return control

View File

@@ -5,30 +5,37 @@ from typing import Optional

 import torch
 import torch.nn.functional as F
-from diffusers.models.attention_dispatch import sageattn
 from einops import rearrange, repeat
 from torch import nn, einsum

 from .diffusionmodules.util import AlphaBlender, timestep_embedding
 from .sub_quadratic_attention import efficient_dot_product_attention
 from ... import model_management
+from ...ops import scaled_dot_product_attention
+
+logger = logging.getLogger(__name__)

 if model_management.xformers_enabled():
     import xformers  # pylint: disable=import-error
     import xformers.ops  # pylint: disable=import-error

+sageattn = None
 if model_management.sage_attention_enabled():
     try:
         from sageattention import sageattn  # pylint: disable=import-error
     except ModuleNotFoundError as e:
         if e.name == "sageattention":
             import sys
-            logging.error(f"\n\nTo use the `--use-sage-attention` feature, the `sageattention` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install sageattention")
+            logger.error(f"To use the `--use-sage-attention` feature, the `sageattention` package must be installed first.")
         else:
             raise e
+        sageattn = torch.nn.functional.scaled_dot_product_attention
 else:
     sageattn = torch.nn.functional.scaled_dot_product_attention

+flash_attn_func = None
 if model_management.flash_attention_enabled():
     from flash_attn import flash_attn_func  # pylint: disable=import-error
 else:
@@ -40,7 +47,6 @@ from ... import ops

 ops = ops.disable_weight_init

 FORCE_UPCAST_ATTENTION_DTYPE = model_management.force_upcast_attention_dtype()
-logger = logging.getLogger(__name__)

 def get_attn_precision(attn_precision, current_dtype):
@@ -480,7 +486,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
             mask = mask.unsqueeze(1)

     if SDP_BATCH_LIMIT >= b:
-        out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
+        out = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
         if not skip_output_reshape:
             out = (
                 out.transpose(1, 2).reshape(b, -1, heads * dim_head)
@@ -493,7 +499,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
             if mask.shape[0] > 1:
                 m = mask[i: i + SDP_BATCH_LIMIT]

-            out[i: i + SDP_BATCH_LIMIT] = comfy.ops.scaled_dot_product_attention(
+            out[i: i + SDP_BATCH_LIMIT] = scaled_dot_product_attention(
                 q[i: i + SDP_BATCH_LIMIT],
                 k[i: i + SDP_BATCH_LIMIT],
                 v[i: i + SDP_BATCH_LIMIT],
@@ -527,7 +533,7 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
     try:
         out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout)
     except Exception as e:
-        logging.error("Error running sage attention: {}, using pytorch attention instead.".format(e))
+        logger.error("Error running sage attention: {}, using pytorch attention instead.".format(e))
         if tensor_layout == "NHD":
             q, k, v = map(
                 lambda t: t.transpose(1, 2),
@@ -551,7 +557,7 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=

 try:
     @torch.library.custom_op("flash_attention::flash_attn", mutates_args=())
     def flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                            dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor:
         return flash_attn_func(q, k, v, dropout_p=dropout_p, causal=causal)  # pylint: disable=possibly-used-before-assignment,used-before-assignment
@@ -562,8 +568,9 @@ try:
 except AttributeError as error:
     FLASH_ATTN_ERROR = error

+
     def flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                            dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor:
         assert False, f"Could not define flash_attn_wrapper: {FLASH_ATTN_ERROR}"
@@ -596,7 +603,7 @@ def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
             causal=False,
         ).transpose(1, 2)
     except Exception as e:
-        logging.warning(f"Flash Attention failed, using default SDPA: {e}")
+        logger.warning(f"Flash Attention failed, using default SDPA: {e}")
         out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
     if not skip_output_reshape:
         out = (
@@ -616,7 +623,7 @@ elif model_management.xformers_enabled():
     logger.debug("Using xformers attention")
     optimized_attention = attention_xformers
 elif model_management.flash_attention_enabled():
-    logging.debug("Using Flash Attention")
+    logger.debug("Using Flash Attention")
     optimized_attention = attention_flash
 elif model_management.pytorch_attention_enabled():
     logger.debug("Using pytorch attention")

View File

@@ -7,6 +7,7 @@ import torch.nn as nn

 from .... import model_management
 from .... import ops
+from ....ops import scaled_dot_product_attention

 ops = ops.disable_weight_init
@@ -295,7 +296,7 @@ def pytorch_attention(q, k, v):
         )
     try:
-        out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
+        out = scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
         out = out.transpose(2, 3).reshape(orig_shape)
     except model_management.OOM_EXCEPTION:
         logger.warning("scaled_dot_product_attention OOMed: switched to slice attention")

View File

@@ -13,15 +13,13 @@ from os.path import join
 from pathlib import Path
 from typing import List, Optional, Final, Set

-from .component_model.hf_hub_download_with_disable_xet import hf_hub_download_with_retries
 import tqdm
-from huggingface_hub import hf_hub_download, scan_cache_dir, snapshot_download, HfFileSystem
+from huggingface_hub import dump_environment_info
+from huggingface_hub import hf_hub_download, scan_cache_dir, snapshot_download, HfFileSystem, CacheNotFound
 from huggingface_hub.utils import GatedRepoError, LocalEntryNotFoundError
 from requests import Session
 from safetensors import safe_open
 from safetensors.torch import save_file
-from huggingface_hub import dump_environment_info

 from .cli_args import args
 from .cmd import folder_paths
@@ -140,11 +138,11 @@ def get_or_download(folder_name: str, filename: str, known_files: Optional[List[
             try:
                 logger.debug(f"{folder_name}/{filename} is being downloaded from {known_file.repo_id}/{known_file.filename} candidate_str_match={candidate_str_match} candidate_filename_match={candidate_filename_match} candidate_alternate_filenames_match={candidate_alternate_filenames_match} candidate_save_filename_match={candidate_save_filename_match}")
                 path = hf_hub_download(repo_id=known_file.repo_id,
                                        filename=known_file.filename,
                                        repo_type=known_file.repo_type,
                                        revision=known_file.revision,
                                        local_dir=hf_destination_dir if args.force_hf_local_dir_mode else None,
                                        )
             except IOError as exc_info:
                 logger.error(f"cannot reach huggingface {known_file.repo_id}/{known_file.filename}", exc_info=exc_info)
             except Exception as exc_info:
@@ -689,10 +687,14 @@ def get_huggingface_repo_list(*extra_cache_dirs: str) -> List[str]:
         extra_cache_dirs = folder_paths.get_folder_paths("huggingface_cache")

     # all in cache directories
+    try:
+        default_cache_dir = [scan_cache_dir()]
+    except CacheNotFound as exc_info:
+        default_cache_dir = []
     existing_repo_ids = frozenset(
         cache_item.repo_id for cache_item in \
         reduce(operator.or_,
-               map(lambda cache_info: cache_info.repos, [scan_cache_dir()] + [scan_cache_dir(cache_dir=cache_dir) for cache_dir in extra_cache_dirs if os.path.isdir(cache_dir)]))
+               map(lambda cache_info: cache_info.repos, default_cache_dir + [scan_cache_dir(cache_dir=cache_dir) for cache_dir in extra_cache_dirs if os.path.isdir(cache_dir)]))
         if cache_item.repo_type == "model" or cache_item.repo_type == "space"
     )
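Note: on a fresh machine with no Hugging Face cache directory yet, scan_cache_dir() raises CacheNotFound instead of returning an empty report, which is what the new try/except accounts for. The pattern in isolation (attribute names match the generator expression above):

from huggingface_hub import scan_cache_dir, CacheNotFound

try:
    cache_infos = [scan_cache_dir()]  # raises if the default cache dir is missing
except CacheNotFound:
    cache_infos = []  # fresh install: treat as "no repos cached yet"

repo_ids = {repo.repo_id for info in cache_infos for repo in info.repos}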

View File

@@ -4,6 +4,7 @@ import logging

 from .package_typing import ExportedNodes
 from comfy_api.latest import ComfyExtension
+from comfy_api.internal.async_to_sync import AsyncToSyncConverter

 logger = logging.getLogger(__name__)
@@ -17,15 +18,13 @@ def _comfy_entrypoint_upstream_v3_imports(module) -> ExportedNodes:
     else:
         if inspect.iscoroutinefunction(entrypoint):
             # todo: I seriously doubt anything is going to be an async entrypoint, ever
-            extension_coro = entrypoint()
-            extension = asyncio.run(extension_coro)
+            extension = AsyncToSyncConverter.run_async_in_thread(entrypoint)
         else:
             extension = entrypoint()
         if not isinstance(extension, ComfyExtension):
             logger.debug(f"comfy_entrypoint in {module} did not return a ComfyExtension, skipping.")
         else:
-            node_list_coro = extension.get_node_list()
-            node_list = asyncio.run(node_list_coro)
+            node_list = AsyncToSyncConverter.run_async_in_thread(extension.get_node_list)
             if not isinstance(node_list, list):
                 logger.debug(f"comfy_entrypoint in {module} did not return a list of nodes, skipping.")
             else:
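Note: asyncio.run raises RuntimeError when the calling thread already has a running event loop, which can happen while nodes are being imported; running the coroutine on a private loop in a worker thread sidesteps that. A rough standard-library sketch of the idea (the actual AsyncToSyncConverter implementation may differ):

import asyncio
from concurrent.futures import ThreadPoolExecutor

def run_async_in_thread(coro_fn):
    # run the coroutine on its own event loop in a worker thread, so it
    # works even if the caller's thread already has a running loop
    with ThreadPoolExecutor(max_workers=1) as pool:
        return pool.submit(lambda: asyncio.run(coro_fn())).result()

async def get_node_list():
    return ["NodeA", "NodeB"]

print(run_async_in_thread(get_node_list))  # ['NodeA', 'NodeB']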

View File

@@ -23,10 +23,11 @@ _nodes_available_at_startup: ExportedNodes = ExportedNodes()
 logger = logging.getLogger(__name__)

-def _import_nodes_in_module(exported_nodes: ExportedNodes, module: types.ModuleType):
+def _import_nodes_in_module(module: types.ModuleType) -> ExportedNodes:
     node_class_mappings = getattr(module, 'NODE_CLASS_MAPPINGS', None)
     node_display_names = getattr(module, 'NODE_DISPLAY_NAME_MAPPINGS', None)
     web_directory = getattr(module, "WEB_DIRECTORY", None)
+    exported_nodes = ExportedNodes()
     if node_class_mappings:
         exported_nodes.NODE_CLASS_MAPPINGS.update(node_class_mappings)
     if node_display_names:
@@ -42,7 +43,7 @@ def _import_nodes_in_module(exported_nodes: ExportedNodes, module: types.ModuleT
                 raise ImportError(path=abs_web_directory)
             exported_nodes.EXTENSION_WEB_DIRS[module.__name__] = abs_web_directory
     exported_nodes.update(_comfy_entrypoint_upstream_v3_imports(module))
-    return node_class_mappings and len(node_class_mappings) > 0 or web_directory
+    return exported_nodes
@@ -58,17 +59,19 @@ def _import_and_enumerate_nodes_in_module(module: types.ModuleType,
         time_before = time.perf_counter()
         full_name = module.__name__
         try:
-            any_content_in_module = _import_nodes_in_module(exported_nodes, module)
+            module_exported_nodes = _import_nodes_in_module(module)
             span.set_attribute("full_name", full_name)
             timings.append((time.perf_counter() - time_before, full_name, True, exported_nodes))
         except Exception as exc:
-            any_content_in_module = None
+            module_exported_nodes = None
             logger.error(f"{full_name} import failed", exc_info=exc)
             span.set_status(Status(StatusCode.ERROR))
             span.record_exception(exc)
             exceptions.append(exc)
-        if any_content_in_module is None or not any_content_in_module:
-            # Iterate through all the submodules
+        if module_exported_nodes:
+            exported_nodes.update(module_exported_nodes)
+        else:
+            # iterate through all the submodules and try to find exported nodes
             for _, name, is_pkg in pkgutil.iter_modules(module.__path__):
                 span: Span
                 with tracer.start_as_current_span("Load Node") as span:

View File

@@ -195,6 +195,9 @@ class ExportedNodes:
         exported_nodes = ExportedNodes().update(self)
         return exported_nodes.update(other)

+    def __bool__(self):
+        return len(self.NODE_CLASS_MAPPINGS) + len(self.NODE_DISPLAY_NAME_MAPPINGS) + len(self.EXTENSION_WEB_DIRS) > 0
+

 class _ExportedNodesAsChainMap(ExportedNodes):
     @classmethod
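Note: this __bool__ is what lets the loader above write `if module_exported_nodes:` — an ExportedNodes with no class mappings, display names, or web dirs is falsy, so the importer falls through to scanning submodules. A trimmed illustration of the truthiness contract (the real class has more fields; this stand-in uses plain dicts):

class ExportedNodes:
    def __init__(self):
        self.NODE_CLASS_MAPPINGS = {}
        self.NODE_DISPLAY_NAME_MAPPINGS = {}
        self.EXTENSION_WEB_DIRS = {}

    def __bool__(self):
        return len(self.NODE_CLASS_MAPPINGS) + len(self.NODE_DISPLAY_NAME_MAPPINGS) + len(self.EXTENSION_WEB_DIRS) > 0

empty = ExportedNodes()
assert not empty  # nothing exported: loader keeps searching submodules

found = ExportedNodes()
found.NODE_CLASS_MAPPINGS["MyNode"] = object
assert found      # exports present: loader records this module's nodes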

View File

@@ -27,8 +27,6 @@ from .execution_context import current_execution_context
 from .float import stochastic_rounding

-def scaled_dot_product_attention(q, k, v, *args, **kwargs):
-    return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)

 try:
@@ -52,6 +50,9 @@ try:
 except (ModuleNotFoundError, TypeError):
     logging.warning("Could not set sdpa backend priority.")

+def scaled_dot_product_attention(q, k, v, *args, **kwargs):
+    return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)
+
 cast_to = model_management.cast_to  # TODO: remove once no more references

 logger = logging.getLogger(__name__)

View File

@@ -3,7 +3,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from typing import Optional, Tuple
 import math
-from comfy.ldm.modules.attention import optimized_attention_for_device
+from ..ldm.modules.attention import optimized_attention_for_device

 def process_qwen2vl_images(

View File

@@ -532,7 +532,7 @@ class ApiClient:
                 request_method="PUT",
                 request_url=upload_url,
                 response_status_code=e.status if hasattr(e, "status") else None,
-                response_headers=dict(e.headers) if getattr(e, "headers") else None,
+                response_headers=dict(e.headers) if getattr(e, "headers") else None,  # pylint: disable=no-member
                 response_content=None,
                 error_message=f"{type(e).__name__}: {str(e)}",
             )

View File

@@ -656,6 +656,7 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode):
     async def generate(
         self, prompt, negative_prompt, unique_id: Optional[str] = None, **kwargs
     ):
+        image_url: Optional[str] = None
         video = kwargs.get("video")
         image = kwargs.get("image", None)

View File

@@ -1,7 +1,7 @@
 from __future__ import annotations

 from comfy_api.latest import ComfyExtension, io
 import comfy.context_windows
-import nodes
+from comfy.nodes import base_nodes as nodes

 class ContextWindowsManualNode(io.ComfyNode):

View File

@@ -1,5 +1,5 @@
 import torch
-import folder_paths
+from comfy.cmd import folder_paths  # pylint: disable=no-name-in-module
 import comfy.utils
 import comfy.ops
 import comfy.model_management
@@ -27,12 +27,12 @@ class BlockWiseControlBlock(torch.nn.Module):

 class QwenImageBlockWiseControlNet(torch.nn.Module):
     def __init__(
             self,
             num_layers: int = 60,
             in_dim: int = 64,
             additional_in_dim: int = 0,
             dim: int = 3072,
             device=None, dtype=None, operations=None
     ):
         super().__init__()
         self.additional_in_dim = additional_in_dim
@@ -61,8 +61,9 @@ class QwenImageBlockWiseControlNet(torch.nn.Module):

 class ModelPatchLoader:
     @classmethod
     def INPUT_TYPES(s):
-        return {"required": { "name": (folder_paths.get_filename_list("model_patches"), ),
+        return {"required": {"name": (folder_paths.get_filename_list("model_patches"),),
                               }}

     RETURN_TYPES = ("MODEL_PATCH",)
     FUNCTION = "load_model_patch"
     EXPERIMENTAL = True
@@ -125,16 +126,18 @@ class DiffSynthCnetPatch:
     def models(self):
         return [self.model_patch]

+
 class QwenImageDiffsynthControlnet:
     @classmethod
     def INPUT_TYPES(s):
-        return {"required": { "model": ("MODEL",),
+        return {"required": {"model": ("MODEL",),
                              "model_patch": ("MODEL_PATCH",),
                              "vae": ("VAE",),
                              "image": ("IMAGE",),
                              "strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
                              },
                 "optional": {"mask": ("MASK",)}}

     RETURN_TYPES = ("MODEL",)
     FUNCTION = "diffsynth_controlnet"
     EXPERIMENTAL = True

View File

@@ -1,4 +1,4 @@
-import node_helpers
+from comfy import node_helpers
 import comfy.utils
 import math
@@ -7,11 +7,11 @@ class TextEncodeQwenImageEdit:
     @classmethod
     def INPUT_TYPES(s):
         return {"required": {
-                    "clip": ("CLIP", ),
+                    "clip": ("CLIP",),
                     "prompt": ("STRING", {"multiline": True, "dynamicPrompts": True}),
                     },
-                "optional": {"vae": ("VAE", ),
-                             "image": ("IMAGE", ),}}
+                "optional": {"vae": ("VAE",),
+                             "image": ("IMAGE",), }}

     RETURN_TYPES = ("CONDITIONING",)
     FUNCTION = "encode"
@@ -40,7 +40,7 @@ class TextEncodeQwenImageEdit:
         conditioning = clip.encode_from_tokens_scheduled(tokens)
         if ref_latent is not None:
             conditioning = node_helpers.conditioning_set_values(conditioning, {"reference_latents": [ref_latent]}, append=True)
-        return (conditioning, )
+        return (conditioning,)

 NODE_CLASS_MAPPINGS = {

View File

@@ -119,7 +119,7 @@ where = ["."]
 include = ["comfy*"]
 namespaces = false

-[dependency-groups]
+[project.optional-dependencies]
 dev = [
     "pytest",
     "pytest-asyncio",
@@ -137,7 +136,6 @@ dev = [
     "astroid",
 ]

-[project.optional-dependencies]
 cpu = [
     "torch",
     "torchvision",