Update to 0.3.51

doctorpangloss 2025-08-22 17:29:18 -07:00
parent 664349eabf
commit 735a133ad4
19 changed files with 131 additions and 96 deletions

View File

@@ -78,7 +78,6 @@ class ModelFileManager:
return web.Response(status=404)
def get_model_file_list(self, folder_name: str):
folder_name = folder_paths.map_legacy(folder_name)
folders = folder_paths.folder_names_and_paths[folder_name]
output_list: list[dict] = []

View File

@@ -45,6 +45,7 @@ logging.getLogger("torch.distributed.elastic.multiprocessing.redirects").addFilter(
lambda record: log_msg_to_filter not in record.getMessage()
)
logging.getLogger("alembic.runtime.migration").setLevel(logging.WARNING)
logging.getLogger("asyncio").addFilter(lambda record: 'Using selector:' not in record.getMessage())
from ..cli_args import args
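The added filter drops asyncio's "Using selector:" debug record. As a minimal standalone sketch of the mechanism (stdlib logging only, not code from this commit), a filter callable that returns False discards the record:

import logging

log = logging.getLogger("asyncio")
log.addFilter(lambda record: "Using selector:" not in record.getMessage())
log.debug("Using selector: EpollSelector")  # discarded by the filter
log.debug("selector event loop started")    # passes through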

View File

@@ -2,6 +2,7 @@ from __future__ import annotations
import logging
import os
import platform
import time
from concurrent.futures import Future
from pathlib import Path
@@ -17,16 +18,19 @@ from pebble import ThreadPool
from .tqdm_watcher import TqdmWatcher
logger = logging.getLogger(__name__)
_VAR = "HF_HUB_ENABLE_HF_TRANSFER"
_XET_VAR = "HF_XET_HIGH_PERFORMANCE"
os.environ[_VAR] = "True"
os.environ["HF_HUB_DISABLE_XET"] = "1"
# os.environ["HF_XET_HIGH_PERFORMANCE"] = "True"
if platform.system() == "Windows":
os.environ["HF_HUB_DISABLE_XET"] = "1"
logger.debug("Xet was disabled since it is currently not reliable")
os.environ[_VAR] = "True"
else:
os.environ[_XET_VAR] = "True"
logger = logging.getLogger(__name__)
logger.debug("Xet was disabled since it is currently not reliable")
def hf_hub_download_with_disable_fast(repo_id=None, filename=None, disable_fast=None, hf_env: dict[str, str] = None, **kwargs):

View File

@@ -6,12 +6,13 @@ import collections
from dataclasses import dataclass
from abc import ABC, abstractmethod
import logging
import comfy.model_management
import comfy.patcher_extension
from .model_management import throw_exception_if_processing_interrupted
from .patcher_extension import get_all_callbacks, WrappersMP
if TYPE_CHECKING:
from comfy.model_base import BaseModel
from comfy.model_patcher import ModelPatcher
from comfy.controlnet import ControlBase
from .model_base import BaseModel
from .model_patcher import ModelPatcher
from .controlnet import ControlBase
class ContextWindowABC(ABC):
@@ -32,6 +33,7 @@ class ContextWindowABC(ABC):
"""
raise NotImplementedError("Not implemented.")
class ContextHandlerABC(ABC):
def __init__(self):
...
@@ -49,9 +51,8 @@ class ContextHandlerABC(ABC):
raise NotImplementedError("Not implemented.")
class IndexListContextWindow(ContextWindowABC):
def __init__(self, index_list: list[int], dim: int=0):
def __init__(self, index_list: list[int], dim: int = 0):
self.index_list = index_list
self.context_length = len(index_list)
self.dim = dim
@@ -87,14 +88,18 @@ class ContextSchedule:
name: str
func: Callable
@dataclass
class ContextFuseMethod:
name: str
func: Callable
ContextResults = collections.namedtuple("ContextResults", ['window_idx', 'sub_conds_out', 'sub_conds', 'window'])
class IndexListContextHandler(ContextHandlerABC):
def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMethod, context_length: int=1, context_overlap: int=0, context_stride: int=1, closed_loop=False, dim=0):
def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMethod, context_length: int = 1, context_overlap: int = 0, context_stride: int = 1, closed_loop=False, dim=0):
self.context_schedule = context_schedule
self.fuse_method = fuse_method
self.context_length = context_length
@@ -152,7 +157,7 @@ class IndexListContextHandler(ContextHandlerABC):
elif hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
if cond_value.cond.ndim < self.dim and cond_value.cond.size(0) == x_in.size(self.dim):
new_cond_item[cond_key] = cond_value._copy_with(window.get_tensor(cond_value.cond, device))
elif cond_key == "num_video_frames": # for SVD
elif cond_key == "num_video_frames": # for SVD
new_cond_item[cond_key] = cond_value._copy_with(cond_value.cond)
new_cond_item[cond_key].cond = window.context_length
resized_actual_cond[key] = new_cond_item
@@ -171,7 +176,7 @@ class IndexListContextHandler(ContextHandlerABC):
self._step = int(matches[0].item())
def get_context_windows(self, model: BaseModel, x_in: torch.Tensor, model_options: dict[str]) -> list[IndexListContextWindow]:
full_length = x_in.size(self.dim) # TODO: choose dim based on model
full_length = x_in.size(self.dim) # TODO: choose dim based on model
context_windows = self.context_schedule.func(full_length, self, model_options)
context_windows = [IndexListContextWindow(window, dim=self.dim) for window in context_windows]
return context_windows
@@ -188,14 +193,14 @@ class IndexListContextHandler(ContextHandlerABC):
counts_final = [torch.zeros(get_shape_for_dim(x_in, self.dim), device=x_in.device) for _ in conds]
biases_final = [([0.0] * x_in.shape[self.dim]) for _ in conds]
for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_START, self.callbacks):
for callback in get_all_callbacks(IndexListCallbacks.EXECUTE_START, self.callbacks):
callback(self, model, x_in, conds, timestep, model_options)
for enum_window in enumerated_context_windows:
results = self.evaluate_context_windows(calc_cond_batch, model, x_in, conds, timestep, [enum_window], model_options)
for result in results:
self.combine_context_window_results(x_in, result.sub_conds_out, result.sub_conds, result.window, result.window_idx, len(enumerated_context_windows), timestep,
conds_final, counts_final, biases_final)
conds_final, counts_final, biases_final)
try:
# finalize conds
if self.fuse_method.name == ContextFuseMethods.RELATIVE:
@@ -209,17 +214,17 @@ class IndexListContextHandler(ContextHandlerABC):
del counts_final
return conds_final
finally:
for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_CLEANUP, self.callbacks):
for callback in get_all_callbacks(IndexListCallbacks.EXECUTE_CLEANUP, self.callbacks):
callback(self, model, x_in, conds, timestep, model_options)
def evaluate_context_windows(self, calc_cond_batch: Callable, model: BaseModel, x_in: torch.Tensor, conds, timestep: torch.Tensor, enumerated_context_windows: list[tuple[int, IndexListContextWindow]],
model_options, device=None, first_device=None):
model_options, device=None, first_device=None):
results: list[ContextResults] = []
for window_idx, window in enumerated_context_windows:
# allow processing to end between context window executions for faster Cancel
comfy.model_management.throw_exception_if_processing_interrupted()
throw_exception_if_processing_interrupted()
for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EVALUATE_CONTEXT_WINDOWS, self.callbacks):
for callback in get_all_callbacks(IndexListCallbacks.EVALUATE_CONTEXT_WINDOWS, self.callbacks):
callback(self, model, x_in, conds, timestep, model_options, window_idx, window, model_options, device, first_device)
# update exposed params
@@ -236,9 +241,8 @@ class IndexListContextHandler(ContextHandlerABC):
results.append(ContextResults(window_idx, sub_conds_out, sub_conds, window))
return results
def combine_context_window_results(self, x_in: torch.Tensor, sub_conds_out, sub_conds, window: IndexListContextWindow, window_idx: int, total_windows: int, timestep: torch.Tensor,
conds_final: list[torch.Tensor], counts_final: list[torch.Tensor], biases_final: list[torch.Tensor]):
conds_final: list[torch.Tensor], counts_final: list[torch.Tensor], biases_final: list[torch.Tensor]):
if self.fuse_method.name == ContextFuseMethods.RELATIVE:
for pos, idx in enumerate(window.index_list):
# bias is the influence of a specific index in relation to the whole context window
@@ -263,7 +267,7 @@ class IndexListContextHandler(ContextHandlerABC):
window.add_window(conds_final[i], sub_conds_out[i] * weights_tensor)
window.add_window(counts_final[i], weights_tensor)
for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.COMBINE_CONTEXT_WINDOW_RESULTS, self.callbacks):
for callback in get_all_callbacks(IndexListCallbacks.COMBINE_CONTEXT_WINDOW_RESULTS, self.callbacks):
callback(self, x_in, sub_conds_out, sub_conds, window, window_idx, total_windows, timestep, conds_final, counts_final, biases_final)
@@ -281,7 +285,7 @@ def _prepare_sampling_wrapper(executor, model, noise_shape: torch.Tensor, *args, **kwargs):
def create_prepare_sampling_wrapper(model: ModelPatcher):
model.add_wrapper_with_key(
comfy.patcher_extension.WrappersMP.PREPARE_SAMPLING,
WrappersMP.PREPARE_SAMPLING,
"ContextWindows_prepare_sampling",
_prepare_sampling_wrapper
)
@@ -296,6 +300,7 @@ def match_weights_to_dim(weights: list[float], x_in: torch.Tensor, dim: int, device=None):
weights_tensor = weights_tensor.unsqueeze(-1)
return weights_tensor
def get_shape_for_dim(x_in: torch.Tensor, dim: int) -> list[int]:
total_dims = len(x_in.shape)
shape = []
@@ -306,6 +311,7 @@ def get_shape_for_dim(x_in: torch.Tensor, dim: int) -> list[int]:
shape.append(1)
return shape
class ContextSchedules:
UNIFORM_LOOPED = "looped_uniform"
UNIFORM_STANDARD = "standard_uniform"
@@ -325,14 +331,15 @@ def create_windows_uniform_looped(num_frames: int, handler: IndexListContextHandler, model_options: dict[str]):
for context_step in 1 << np.arange(context_stride):
pad = int(round(num_frames * ordered_halving(handler._step)))
for j in range(
int(ordered_halving(handler._step) * context_step) + pad,
num_frames + pad + (0 if handler.closed_loop else -handler.context_overlap),
(handler.context_length * context_step - handler.context_overlap),
int(ordered_halving(handler._step) * context_step) + pad,
num_frames + pad + (0 if handler.closed_loop else -handler.context_overlap),
(handler.context_length * context_step - handler.context_overlap),
):
windows.append([e % num_frames for e in range(j, j + handler.context_length * context_step, context_step)])
return windows
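# Worked example from the arguments above: num_frames=8, context_length=4,
# context_overlap=2, context_stride=1, closed_loop=False and _step=0 produce
# windows [0, 1, 2, 3], [2, 3, 4, 5], [4, 5, 6, 7]; the modulo lets windows
# generated at other steps wrap past the final frame.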
def create_windows_uniform_standard(num_frames: int, handler: IndexListContextHandler, model_options: dict[str]):
# unlike looped, uniform_straight does NOT allow windows that loop back to the beginning;
# instead, they get shifted to the corresponding end of the frames.
@@ -347,9 +354,9 @@ def create_windows_uniform_standard(num_frames: int, handler: IndexListContextHandler, model_options: dict[str]):
for context_step in 1 << np.arange(context_stride):
pad = int(round(num_frames * ordered_halving(handler._step)))
for j in range(
int(ordered_halving(handler._step) * context_step) + pad,
num_frames + pad + (-handler.context_overlap),
(handler.context_length * context_step - handler.context_overlap),
int(ordered_halving(handler._step) * context_step) + pad,
num_frames + pad + (-handler.context_overlap),
(handler.context_length * context_step - handler.context_overlap),
):
windows.append([e % num_frames for e in range(j, j + handler.context_length * context_step, context_step)])
@@ -363,9 +370,9 @@ def create_windows_uniform_standard(num_frames: int, handler: IndexListContextHandler, model_options: dict[str]):
roll_val = windows[win_i][roll_idx] # roll_val might not be 0 for windows of higher strides
shift_window_to_end(windows[win_i], num_frames=num_frames)
# check if next window (cyclical) is missing roll_val
if roll_val not in windows[(win_i+1) % len(windows)]:
if roll_val not in windows[(win_i + 1) % len(windows)]:
# need to insert new window here - just insert window starting at roll_val
windows.insert(win_i+1, list(range(roll_val, roll_val + handler.context_length)))
windows.insert(win_i + 1, list(range(roll_val, roll_val + handler.context_length)))
# delete window if it's not unique
for pre_i in range(0, win_i):
if windows[win_i] == windows[pre_i]:
@@ -432,7 +439,7 @@ def get_matching_context_schedule(context_schedule: str) -> ContextSchedule:
return ContextSchedule(context_schedule, func)
def get_context_weights(length: int, full_length: int, idxs: list[int], handler: IndexListContextHandler, sigma: torch.Tensor=None):
def get_context_weights(length: int, full_length: int, idxs: list[int], handler: IndexListContextHandler, sigma: torch.Tensor = None):
return handler.fuse_method.func(length, sigma=sigma, handler=handler, full_length=full_length, idxs=idxs)
@@ -440,6 +447,7 @@ def create_weights_flat(length: int, **kwargs) -> list[float]:
# weight is the same for all
return [1.0] * length
def create_weights_pyramid(length: int, **kwargs) -> list[float]:
# weight is based on the distance away from the edge of the context window;
# based on weighted average concept in FreeNoise paper
@@ -451,6 +459,7 @@ def create_weights_pyramid(length: int, **kwargs) -> list[float]:
weight_sequence = list(range(1, max_weight, 1)) + [max_weight] + list(range(max_weight - 1, 0, -1))
return weight_sequence
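# e.g. with length=5 and a midpoint max_weight of 3, the line above yields
# [1, 2, 3, 2, 1]: frames at a window's edges get the least weight when
# overlapping windows are averaged, per the FreeNoise weighting.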
def create_weights_overlap_linear(length: int, full_length: int, idxs: list[int], handler: IndexListContextHandler, **kwargs):
# based on code in Kijai's WanVideoWrapper: https://github.com/kijai/ComfyUI-WanVideoWrapper/blob/dbb2523b37e4ccdf45127e5ae33e31362f755c8e/nodes.py#L1302
# only expected overlap is given different weights
@@ -460,11 +469,12 @@ def create_weights_overlap_linear(length: int, full_length: int, idxs: list[int], handler: IndexListContextHandler, **kwargs):
ramp_up = torch.linspace(1e-37, 1, handler.context_overlap)
weights_torch[:handler.context_overlap] = ramp_up
# blend right-side on all except last window
if max(idxs) < full_length-1:
if max(idxs) < full_length - 1:
ramp_down = torch.linspace(1, 1e-37, handler.context_overlap)
weights_torch[-handler.context_overlap:] = ramp_down
return weights_torch
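# e.g. with length=6 and context_overlap=2, an interior window (neither the
# first nor the last) gets weights [~0, 1, 1, 1, 1, ~0]: only the overlapped
# edges ramp down, so boundary windows keep their outward-facing edge at full weight.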
class ContextFuseMethods:
FLAT = "flat"
PYRAMID = "pyramid"
@@ -482,12 +492,14 @@ FUSE_MAPPING = {
ContextFuseMethods.OVERLAP_LINEAR: create_weights_overlap_linear,
}
def get_matching_fuse_method(fuse_method: str) -> ContextFuseMethod:
func = FUSE_MAPPING.get(fuse_method, None)
if func is None:
raise ValueError(f"Unknown fuse_method '{fuse_method}'.")
return ContextFuseMethod(fuse_method, func)
# Returns fraction that has denominator that is a power of 2
def ordered_halving(val):
# get binary value, padded with 0s for 64 bits
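The hunk ends before the body of ordered_halving. A sketch consistent with the two comments above (assumed, since the body is not shown here): reverse the 64-bit binary representation of the step and scale it into [0, 1), giving a fraction whose denominator is a power of two.

def ordered_halving(val):
    # binary value, padded with 0s for 64 bits, then bit-reversed
    bin_str = f"{val:064b}"
    bin_flip = bin_str[::-1]
    # reinterpret the reversed bits as an integer and scale into [0, 1)
    return int(bin_flip, 2) / (1 << 64)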

View File

@@ -662,7 +662,7 @@ def load_controlnet_qwen_instantx(sd, model_options={}):
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd, model_options=model_options)
control_model = QwenImageControlNetModel(operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
control_model = controlnet_load_state_dict(control_model, sd)
latent_format = comfy.latent_formats.Wan21()
latent_format = latent_formats.Wan21()
extra_conds = []
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
return control

View File

@@ -5,30 +5,37 @@ from typing import Optional
import torch
import torch.nn.functional as F
from diffusers.models.attention_dispatch import sageattn
from einops import rearrange, repeat
from torch import nn, einsum
from .diffusionmodules.util import AlphaBlender, timestep_embedding
from .sub_quadratic_attention import efficient_dot_product_attention
from ... import model_management
from ...ops import scaled_dot_product_attention
logger = logging.getLogger(__name__)
if model_management.xformers_enabled():
import xformers # pylint: disable=import-error
import xformers.ops # pylint: disable=import-error
sageattn = None
if model_management.sage_attention_enabled():
try:
from sageattention import sageattn # pylint: disable=import-error
except ModuleNotFoundError as e:
if e.name == "sageattention":
import sys
logging.error(f"\n\nTo use the `--use-sage-attention` feature, the `sageattention` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install sageattention")
logger.error(f"To use the `--use-sage-attention` feature, the `sageattention` package must be installed first.")
else:
raise e
sageattn = torch.nn.functional.scaled_dot_product_attention
else:
sageattn = torch.nn.functional.scaled_dot_product_attention
flash_attn_func = None
if model_management.flash_attention_enabled():
from flash_attn import flash_attn_func # pylint: disable=import-error
else:
@@ -40,7 +47,6 @@ from ... import ops
ops = ops.disable_weight_init
FORCE_UPCAST_ATTENTION_DTYPE = model_management.force_upcast_attention_dtype()
logger = logging.getLogger(__name__)
def get_attn_precision(attn_precision, current_dtype):
@@ -480,7 +486,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
mask = mask.unsqueeze(1)
if SDP_BATCH_LIMIT >= b:
out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
out = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
if not skip_output_reshape:
out = (
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
@@ -493,7 +499,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
if mask.shape[0] > 1:
m = mask[i: i + SDP_BATCH_LIMIT]
out[i: i + SDP_BATCH_LIMIT] = comfy.ops.scaled_dot_product_attention(
out[i: i + SDP_BATCH_LIMIT] = scaled_dot_product_attention(
q[i: i + SDP_BATCH_LIMIT],
k[i: i + SDP_BATCH_LIMIT],
v[i: i + SDP_BATCH_LIMIT],
@@ -527,7 +533,7 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
try:
out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout)
except Exception as e:
logging.error("Error running sage attention: {}, using pytorch attention instead.".format(e))
logger.error("Error running sage attention: {}, using pytorch attention instead.".format(e))
if tensor_layout == "NHD":
q, k, v = map(
lambda t: t.transpose(1, 2),
@@ -551,7 +557,7 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
try:
@torch.library.custom_op("flash_attention::flash_attn", mutates_args=())
def flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor:
dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor:
return flash_attn_func(q, k, v, dropout_p=dropout_p, causal=causal) # pylint: disable=possibly-used-before-assignment,used-before-assignment
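# torch.library.custom_op (available in torch >= 2.4) registers flash_attn_func
# as a custom operator so it composes with torch.compile; the AttributeError
# branch below swaps in a stub on older torch builds that lack custom_op.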
@@ -562,8 +568,9 @@ try:
except AttributeError as error:
FLASH_ATTN_ERROR = error
def flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor:
dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor:
assert False, f"Could not define flash_attn_wrapper: {FLASH_ATTN_ERROR}"
@@ -596,7 +603,7 @@ def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
causal=False,
).transpose(1, 2)
except Exception as e:
logging.warning(f"Flash Attention failed, using default SDPA: {e}")
logger.warning(f"Flash Attention failed, using default SDPA: {e}")
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
if not skip_output_reshape:
out = (
@@ -616,7 +623,7 @@ elif model_management.xformers_enabled():
logger.debug("Using xformers attention")
optimized_attention = attention_xformers
elif model_management.flash_attention_enabled():
logging.debug("Using Flash Attention")
logger.debug("Using Flash Attention")
optimized_attention = attention_flash
elif model_management.pytorch_attention_enabled():
logger.debug("Using pytorch attention")

View File

@@ -7,6 +7,7 @@ import torch.nn as nn
from .... import model_management
from .... import ops
from ....ops import scaled_dot_product_attention
ops = ops.disable_weight_init
@@ -295,7 +296,7 @@ def pytorch_attention(q, k, v):
)
try:
out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
out = scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
out = out.transpose(2, 3).reshape(orig_shape)
except model_management.OOM_EXCEPTION:
logger.warning("scaled_dot_product_attention OOMed: switched to slice attention")

View File

@@ -13,15 +13,13 @@ from os.path import join
from pathlib import Path
from typing import List, Optional, Final, Set
from .component_model.hf_hub_download_with_disable_xet import hf_hub_download_with_retries
import tqdm
from huggingface_hub import hf_hub_download, scan_cache_dir, snapshot_download, HfFileSystem
from huggingface_hub import dump_environment_info
from huggingface_hub import hf_hub_download, scan_cache_dir, snapshot_download, HfFileSystem, CacheNotFound
from huggingface_hub.utils import GatedRepoError, LocalEntryNotFoundError
from requests import Session
from safetensors import safe_open
from safetensors.torch import save_file
from huggingface_hub import dump_environment_info
from .cli_args import args
from .cmd import folder_paths
@@ -140,11 +138,11 @@ def get_or_download(folder_name: str, filename: str, known_files: Optional[List[
try:
logger.debug(f"{folder_name}/{filename} is being downloaded from {known_file.repo_id}/{known_file.filename} candidate_str_match={candidate_str_match} candidate_filename_match={candidate_filename_match} candidate_alternate_filenames_match={candidate_alternate_filenames_match} candidate_save_filename_match={candidate_save_filename_match}")
path = hf_hub_download(repo_id=known_file.repo_id,
filename=known_file.filename,
repo_type=known_file.repo_type,
revision=known_file.revision,
local_dir=hf_destination_dir if args.force_hf_local_dir_mode else None,
)
filename=known_file.filename,
repo_type=known_file.repo_type,
revision=known_file.revision,
local_dir=hf_destination_dir if args.force_hf_local_dir_mode else None,
)
except IOError as exc_info:
logger.error(f"cannot reach huggingface {known_file.repo_id}/{known_file.filename}", exc_info=exc_info)
except Exception as exc_info:
@@ -689,10 +687,14 @@ def get_huggingface_repo_list(*extra_cache_dirs: str) -> List[str]:
extra_cache_dirs = folder_paths.get_folder_paths("huggingface_cache")
# all in cache directories
try:
default_cache_dir = [scan_cache_dir()]
except CacheNotFound as exc_info:
default_cache_dir = []
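# CacheNotFound is raised when the default Hugging Face cache directory does
# not exist yet (e.g. a fresh install); falling back to an empty list still
# lets the extra cache dirs below be scanned.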
existing_repo_ids = frozenset(
cache_item.repo_id for cache_item in \
reduce(operator.or_,
map(lambda cache_info: cache_info.repos, [scan_cache_dir()] + [scan_cache_dir(cache_dir=cache_dir) for cache_dir in extra_cache_dirs if os.path.isdir(cache_dir)]))
map(lambda cache_info: cache_info.repos, default_cache_dir + [scan_cache_dir(cache_dir=cache_dir) for cache_dir in extra_cache_dirs if os.path.isdir(cache_dir)]))
if cache_item.repo_type == "model" or cache_item.repo_type == "space"
)

View File

@@ -4,6 +4,7 @@ import logging
from .package_typing import ExportedNodes
from comfy_api.latest import ComfyExtension
from comfy_api.internal.async_to_sync import AsyncToSyncConverter
logger = logging.getLogger(__name__)
@@ -17,15 +18,13 @@ def _comfy_entrypoint_upstream_v3_imports(module) -> ExportedNodes:
else:
if inspect.iscoroutinefunction(entrypoint):
# todo: I seriously doubt anything is going to be an async entrypoint, ever
extension_coro = entrypoint()
extension = asyncio.run(extension_coro)
extension = AsyncToSyncConverter.run_async_in_thread(entrypoint)
else:
extension = entrypoint()
if not isinstance(extension, ComfyExtension):
logger.debug(f"comfy_entrypoint in {module} did not return a ComfyExtension, skipping.")
else:
node_list_coro = extension.get_node_list()
node_list = asyncio.run(node_list_coro)
node_list = AsyncToSyncConverter.run_async_in_thread(extension.get_node_list)
if not isinstance(node_list, list):
logger.debug(f"comfy_entrypoint in {module} did not return a list of nodes, skipping.")
else:
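Both call sites now go through AsyncToSyncConverter.run_async_in_thread rather than asyncio.run, which raises "asyncio.run() cannot be called from a running event loop" when the importer itself executes under a loop. A minimal sketch of that pattern (an assumption about the helper's behavior, not its actual implementation):

import asyncio
import threading

def run_async_in_thread(async_fn):
    # Run a coroutine function on a fresh event loop in a worker thread,
    # which is safe even if the calling thread already has a running loop.
    box = {}
    def runner():
        box["result"] = asyncio.run(async_fn())
    t = threading.Thread(target=runner)
    t.start()
    t.join()
    return box["result"]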

View File

@@ -23,10 +23,11 @@ _nodes_available_at_startup: ExportedNodes = ExportedNodes()
logger = logging.getLogger(__name__)
def _import_nodes_in_module(exported_nodes: ExportedNodes, module: types.ModuleType):
def _import_nodes_in_module(module: types.ModuleType) -> ExportedNodes:
node_class_mappings = getattr(module, 'NODE_CLASS_MAPPINGS', None)
node_display_names = getattr(module, 'NODE_DISPLAY_NAME_MAPPINGS', None)
web_directory = getattr(module, "WEB_DIRECTORY", None)
exported_nodes = ExportedNodes()
if node_class_mappings:
exported_nodes.NODE_CLASS_MAPPINGS.update(node_class_mappings)
if node_display_names:
@@ -42,7 +43,7 @@ def _import_nodes_in_module(exported_nodes: ExportedNodes, module: types.ModuleType):
raise ImportError(path=abs_web_directory)
exported_nodes.EXTENSION_WEB_DIRS[module.__name__] = abs_web_directory
exported_nodes.update(_comfy_entrypoint_upstream_v3_imports(module))
return node_class_mappings and len(node_class_mappings) > 0 or web_directory
return exported_nodes
@@ -58,17 +59,19 @@ def _import_and_enumerate_nodes_in_module(module: types.ModuleType,
time_before = time.perf_counter()
full_name = module.__name__
try:
any_content_in_module = _import_nodes_in_module(exported_nodes, module)
module_exported_nodes = _import_nodes_in_module(module)
span.set_attribute("full_name", full_name)
timings.append((time.perf_counter() - time_before, full_name, True, exported_nodes))
except Exception as exc:
any_content_in_module = None
module_exported_nodes = None
logger.error(f"{full_name} import failed", exc_info=exc)
span.set_status(Status(StatusCode.ERROR))
span.record_exception(exc)
exceptions.append(exc)
if any_content_in_module is None or not any_content_in_module:
# Iterate through all the submodules
if module_exported_nodes:
exported_nodes.update(module_exported_nodes)
else:
# iterate through all the submodules and try to find exported nodes
for _, name, is_pkg in pkgutil.iter_modules(module.__path__):
span: Span
with tracer.start_as_current_span("Load Node") as span:

View File

@@ -195,6 +195,9 @@ class ExportedNodes:
exported_nodes = ExportedNodes().update(self)
return exported_nodes.update(other)
def __bool__(self):
return len(self.NODE_CLASS_MAPPINGS) + len(self.NODE_DISPLAY_NAME_MAPPINGS) + len(self.EXTENSION_WEB_DIRS) > 0
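# A populated ExportedNodes is now truthy, which is what the loader's
# `if module_exported_nodes:` check relies on before falling back to
# scanning submodules.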
class _ExportedNodesAsChainMap(ExportedNodes):
@classmethod

View File

@@ -27,8 +27,6 @@ from .execution_context import current_execution_context
from .float import stochastic_rounding
def scaled_dot_product_attention(q, k, v, *args, **kwargs):
return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)
try:
@@ -52,6 +50,9 @@ try:
except (ModuleNotFoundError, TypeError):
logging.warning("Could not set sdpa backend priority.")
def scaled_dot_product_attention(q, k, v, *args, **kwargs):
return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)
cast_to = model_management.cast_to # TODO: remove once no more references
logger = logging.getLogger(__name__)

View File

@@ -3,7 +3,7 @@ import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple
import math
from comfy.ldm.modules.attention import optimized_attention_for_device
from ..ldm.modules.attention import optimized_attention_for_device
def process_qwen2vl_images(

View File

@@ -532,7 +532,7 @@ class ApiClient:
request_method="PUT",
request_url=upload_url,
response_status_code=e.status if hasattr(e, "status") else None,
response_headers=dict(e.headers) if getattr(e, "headers") else None,
response_headers=dict(e.headers) if getattr(e, "headers") else None, # pylint: disable=no-member
response_content=None,
error_message=f"{type(e).__name__}: {str(e)}",
)

View File

@@ -656,6 +656,7 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode):
async def generate(
self, prompt, negative_prompt, unique_id: Optional[str] = None, **kwargs
):
image_url: Optional[str] = None
video = kwargs.get("video")
image = kwargs.get("image", None)

View File

@@ -1,7 +1,7 @@
from __future__ import annotations
from comfy_api.latest import ComfyExtension, io
import comfy.context_windows
import nodes
from comfy.nodes import base_nodes as nodes
class ContextWindowsManualNode(io.ComfyNode):

View File

@@ -1,5 +1,5 @@
import torch
import folder_paths
from comfy.cmd import folder_paths # pylint: disable=no-name-in-module
import comfy.utils
import comfy.ops
import comfy.model_management
@@ -27,12 +27,12 @@ class BlockWiseControlBlock(torch.nn.Module):
class QwenImageBlockWiseControlNet(torch.nn.Module):
def __init__(
self,
num_layers: int = 60,
in_dim: int = 64,
additional_in_dim: int = 0,
dim: int = 3072,
device=None, dtype=None, operations=None
self,
num_layers: int = 60,
in_dim: int = 64,
additional_in_dim: int = 0,
dim: int = 3072,
device=None, dtype=None, operations=None
):
super().__init__()
self.additional_in_dim = additional_in_dim
@@ -61,8 +61,9 @@ class QwenImageBlockWiseControlNet(torch.nn.Module):
class ModelPatchLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "name": (folder_paths.get_filename_list("model_patches"), ),
}}
return {"required": {"name": (folder_paths.get_filename_list("model_patches"),),
}}
RETURN_TYPES = ("MODEL_PATCH",)
FUNCTION = "load_model_patch"
EXPERIMENTAL = True
@@ -125,16 +126,18 @@ class DiffSynthCnetPatch:
def models(self):
return [self.model_patch]
class QwenImageDiffsynthControlnet:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"model_patch": ("MODEL_PATCH",),
"vae": ("VAE",),
"image": ("IMAGE",),
"strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
},
return {"required": {"model": ("MODEL",),
"model_patch": ("MODEL_PATCH",),
"vae": ("VAE",),
"image": ("IMAGE",),
"strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
},
"optional": {"mask": ("MASK",)}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "diffsynth_controlnet"
EXPERIMENTAL = True

View File

@@ -1,4 +1,4 @@
import node_helpers
from comfy import node_helpers
import comfy.utils
import math
@@ -7,11 +7,11 @@ class TextEncodeQwenImageEdit:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"clip": ("CLIP", ),
"clip": ("CLIP",),
"prompt": ("STRING", {"multiline": True, "dynamicPrompts": True}),
},
"optional": {"vae": ("VAE", ),
"image": ("IMAGE", ),}}
},
"optional": {"vae": ("VAE",),
"image": ("IMAGE",), }}
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "encode"
@@ -40,7 +40,7 @@ class TextEncodeQwenImageEdit:
conditioning = clip.encode_from_tokens_scheduled(tokens)
if ref_latent is not None:
conditioning = node_helpers.conditioning_set_values(conditioning, {"reference_latents": [ref_latent]}, append=True)
return (conditioning, )
return (conditioning,)
NODE_CLASS_MAPPINGS = {

View File

@@ -119,7 +119,7 @@ where = ["."]
include = ["comfy*"]
namespaces = false
[dependency-groups]
[project.optional-dependencies]
dev = [
"pytest",
"pytest-asyncio",
@@ -137,7 +137,6 @@ dev = [
"astroid",
]
[project.optional-dependencies]
cpu = [
"torch",
"torchvision",