diff --git a/comfy/app/model_manager.py b/comfy/app/model_manager.py index 3f3107046..661b4ade8 100644 --- a/comfy/app/model_manager.py +++ b/comfy/app/model_manager.py @@ -78,7 +78,6 @@ class ModelFileManager: return web.Response(status=404) def get_model_file_list(self, folder_name: str): - folder_name = folder_paths.map_legacy(folder_name) folders = folder_paths.folder_names_and_paths[folder_name] output_list: list[dict] = [] diff --git a/comfy/cmd/main_pre.py b/comfy/cmd/main_pre.py index ab5722b38..8c979986b 100644 --- a/comfy/cmd/main_pre.py +++ b/comfy/cmd/main_pre.py @@ -45,6 +45,7 @@ logging.getLogger("torch.distributed.elastic.multiprocessing.redirects").addFilt lambda record: log_msg_to_filter not in record.getMessage() ) logging.getLogger("alembic.runtime.migration").setLevel(logging.WARNING) +logging.getLogger("asyncio").addFilter(lambda record: 'Using selector:' not in record.getMessage()) from ..cli_args import args diff --git a/comfy/component_model/hf_hub_download_with_disable_xet.py b/comfy/component_model/hf_hub_download_with_disable_xet.py index 371737f3b..48e7333a0 100644 --- a/comfy/component_model/hf_hub_download_with_disable_xet.py +++ b/comfy/component_model/hf_hub_download_with_disable_xet.py @@ -2,6 +2,7 @@ from __future__ import annotations import logging import os +import platform import time from concurrent.futures import Future from pathlib import Path @@ -17,16 +18,19 @@ from pebble import ThreadPool from .tqdm_watcher import TqdmWatcher +logger = logging.getLogger(__name__) + _VAR = "HF_HUB_ENABLE_HF_TRANSFER" _XET_VAR = "HF_XET_HIGH_PERFORMANCE" -os.environ[_VAR] = "True" -os.environ["HF_HUB_DISABLE_XET"] = "1" -# os.environ["HF_XET_HIGH_PERFORMANCE"] = "True" +if platform.system() == "Windows": + os.environ["HF_HUB_DISABLE_XET"] = "1" + logger.debug("Xet was disabled since it is currently not reliable") + os.environ[_VAR] = "True" +else: + os.environ[_XET_VAR] = "True" -logger = logging.getLogger(__name__) -logger.debug("Xet was disabled since it is currently not reliable") def hf_hub_download_with_disable_fast(repo_id=None, filename=None, disable_fast=None, hf_env: dict[str, str] = None, **kwargs): diff --git a/comfy/context_windows.py b/comfy/context_windows.py index 041f380f9..ed007e8db 100644 --- a/comfy/context_windows.py +++ b/comfy/context_windows.py @@ -6,12 +6,13 @@ import collections from dataclasses import dataclass from abc import ABC, abstractmethod import logging -import comfy.model_management -import comfy.patcher_extension +from .model_management import throw_exception_if_processing_interrupted +from .patcher_extension import get_all_callbacks, WrappersMP + if TYPE_CHECKING: - from comfy.model_base import BaseModel - from comfy.model_patcher import ModelPatcher - from comfy.controlnet import ControlBase + from .model_base import BaseModel + from .model_patcher import ModelPatcher + from .controlnet import ControlBase class ContextWindowABC(ABC): @@ -32,6 +33,7 @@ class ContextWindowABC(ABC): """ raise NotImplementedError("Not implemented.") + class ContextHandlerABC(ABC): def __init__(self): ... @@ -49,9 +51,8 @@ class ContextHandlerABC(ABC): raise NotImplementedError("Not implemented.") - class IndexListContextWindow(ContextWindowABC): - def __init__(self, index_list: list[int], dim: int=0): + def __init__(self, index_list: list[int], dim: int = 0): self.index_list = index_list self.context_length = len(index_list) self.dim = dim @@ -87,14 +88,18 @@ class ContextSchedule: name: str func: Callable + @dataclass class ContextFuseMethod: name: str func: Callable + ContextResults = collections.namedtuple("ContextResults", ['window_idx', 'sub_conds_out', 'sub_conds', 'window']) + + class IndexListContextHandler(ContextHandlerABC): - def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMethod, context_length: int=1, context_overlap: int=0, context_stride: int=1, closed_loop=False, dim=0): + def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMethod, context_length: int = 1, context_overlap: int = 0, context_stride: int = 1, closed_loop=False, dim=0): self.context_schedule = context_schedule self.fuse_method = fuse_method self.context_length = context_length @@ -152,7 +157,7 @@ class IndexListContextHandler(ContextHandlerABC): elif hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor): if cond_value.cond.ndim < self.dim and cond_value.cond.size(0) == x_in.size(self.dim): new_cond_item[cond_key] = cond_value._copy_with(window.get_tensor(cond_value.cond, device)) - elif cond_key == "num_video_frames": # for SVD + elif cond_key == "num_video_frames": # for SVD new_cond_item[cond_key] = cond_value._copy_with(cond_value.cond) new_cond_item[cond_key].cond = window.context_length resized_actual_cond[key] = new_cond_item @@ -171,7 +176,7 @@ class IndexListContextHandler(ContextHandlerABC): self._step = int(matches[0].item()) def get_context_windows(self, model: BaseModel, x_in: torch.Tensor, model_options: dict[str]) -> list[IndexListContextWindow]: - full_length = x_in.size(self.dim) # TODO: choose dim based on model + full_length = x_in.size(self.dim) # TODO: choose dim based on model context_windows = self.context_schedule.func(full_length, self, model_options) context_windows = [IndexListContextWindow(window, dim=self.dim) for window in context_windows] return context_windows @@ -188,14 +193,14 @@ class IndexListContextHandler(ContextHandlerABC): counts_final = [torch.zeros(get_shape_for_dim(x_in, self.dim), device=x_in.device) for _ in conds] biases_final = [([0.0] * x_in.shape[self.dim]) for _ in conds] - for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_START, self.callbacks): + for callback in get_all_callbacks(IndexListCallbacks.EXECUTE_START, self.callbacks): callback(self, model, x_in, conds, timestep, model_options) for enum_window in enumerated_context_windows: results = self.evaluate_context_windows(calc_cond_batch, model, x_in, conds, timestep, [enum_window], model_options) for result in results: self.combine_context_window_results(x_in, result.sub_conds_out, result.sub_conds, result.window, result.window_idx, len(enumerated_context_windows), timestep, - conds_final, counts_final, biases_final) + conds_final, counts_final, biases_final) try: # finalize conds if self.fuse_method.name == ContextFuseMethods.RELATIVE: @@ -209,17 +214,17 @@ class IndexListContextHandler(ContextHandlerABC): del counts_final return conds_final finally: - for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_CLEANUP, self.callbacks): + for callback in get_all_callbacks(IndexListCallbacks.EXECUTE_CLEANUP, self.callbacks): callback(self, model, x_in, conds, timestep, model_options) def evaluate_context_windows(self, calc_cond_batch: Callable, model: BaseModel, x_in: torch.Tensor, conds, timestep: torch.Tensor, enumerated_context_windows: list[tuple[int, IndexListContextWindow]], - model_options, device=None, first_device=None): + model_options, device=None, first_device=None): results: list[ContextResults] = [] for window_idx, window in enumerated_context_windows: # allow processing to end between context window executions for faster Cancel - comfy.model_management.throw_exception_if_processing_interrupted() + throw_exception_if_processing_interrupted() - for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EVALUATE_CONTEXT_WINDOWS, self.callbacks): + for callback in get_all_callbacks(IndexListCallbacks.EVALUATE_CONTEXT_WINDOWS, self.callbacks): callback(self, model, x_in, conds, timestep, model_options, window_idx, window, model_options, device, first_device) # update exposed params @@ -236,9 +241,8 @@ class IndexListContextHandler(ContextHandlerABC): results.append(ContextResults(window_idx, sub_conds_out, sub_conds, window)) return results - def combine_context_window_results(self, x_in: torch.Tensor, sub_conds_out, sub_conds, window: IndexListContextWindow, window_idx: int, total_windows: int, timestep: torch.Tensor, - conds_final: list[torch.Tensor], counts_final: list[torch.Tensor], biases_final: list[torch.Tensor]): + conds_final: list[torch.Tensor], counts_final: list[torch.Tensor], biases_final: list[torch.Tensor]): if self.fuse_method.name == ContextFuseMethods.RELATIVE: for pos, idx in enumerate(window.index_list): # bias is the influence of a specific index in relation to the whole context window @@ -263,7 +267,7 @@ class IndexListContextHandler(ContextHandlerABC): window.add_window(conds_final[i], sub_conds_out[i] * weights_tensor) window.add_window(counts_final[i], weights_tensor) - for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.COMBINE_CONTEXT_WINDOW_RESULTS, self.callbacks): + for callback in get_all_callbacks(IndexListCallbacks.COMBINE_CONTEXT_WINDOW_RESULTS, self.callbacks): callback(self, x_in, sub_conds_out, sub_conds, window, window_idx, total_windows, timestep, conds_final, counts_final, biases_final) @@ -281,7 +285,7 @@ def _prepare_sampling_wrapper(executor, model, noise_shape: torch.Tensor, *args, def create_prepare_sampling_wrapper(model: ModelPatcher): model.add_wrapper_with_key( - comfy.patcher_extension.WrappersMP.PREPARE_SAMPLING, + WrappersMP.PREPARE_SAMPLING, "ContextWindows_prepare_sampling", _prepare_sampling_wrapper ) @@ -296,6 +300,7 @@ def match_weights_to_dim(weights: list[float], x_in: torch.Tensor, dim: int, dev weights_tensor = weights_tensor.unsqueeze(-1) return weights_tensor + def get_shape_for_dim(x_in: torch.Tensor, dim: int) -> list[int]: total_dims = len(x_in.shape) shape = [] @@ -306,6 +311,7 @@ def get_shape_for_dim(x_in: torch.Tensor, dim: int) -> list[int]: shape.append(1) return shape + class ContextSchedules: UNIFORM_LOOPED = "looped_uniform" UNIFORM_STANDARD = "standard_uniform" @@ -325,14 +331,15 @@ def create_windows_uniform_looped(num_frames: int, handler: IndexListContextHand for context_step in 1 << np.arange(context_stride): pad = int(round(num_frames * ordered_halving(handler._step))) for j in range( - int(ordered_halving(handler._step) * context_step) + pad, - num_frames + pad + (0 if handler.closed_loop else -handler.context_overlap), - (handler.context_length * context_step - handler.context_overlap), + int(ordered_halving(handler._step) * context_step) + pad, + num_frames + pad + (0 if handler.closed_loop else -handler.context_overlap), + (handler.context_length * context_step - handler.context_overlap), ): windows.append([e % num_frames for e in range(j, j + handler.context_length * context_step, context_step)]) return windows + def create_windows_uniform_standard(num_frames: int, handler: IndexListContextHandler, model_options: dict[str]): # unlike looped, uniform_straight does NOT allow windows that loop back to the beginning; # instead, they get shifted to the corresponding end of the frames. @@ -347,9 +354,9 @@ def create_windows_uniform_standard(num_frames: int, handler: IndexListContextHa for context_step in 1 << np.arange(context_stride): pad = int(round(num_frames * ordered_halving(handler._step))) for j in range( - int(ordered_halving(handler._step) * context_step) + pad, - num_frames + pad + (-handler.context_overlap), - (handler.context_length * context_step - handler.context_overlap), + int(ordered_halving(handler._step) * context_step) + pad, + num_frames + pad + (-handler.context_overlap), + (handler.context_length * context_step - handler.context_overlap), ): windows.append([e % num_frames for e in range(j, j + handler.context_length * context_step, context_step)]) @@ -363,9 +370,9 @@ def create_windows_uniform_standard(num_frames: int, handler: IndexListContextHa roll_val = windows[win_i][roll_idx] # roll_val might not be 0 for windows of higher strides shift_window_to_end(windows[win_i], num_frames=num_frames) # check if next window (cyclical) is missing roll_val - if roll_val not in windows[(win_i+1) % len(windows)]: + if roll_val not in windows[(win_i + 1) % len(windows)]: # need to insert new window here - just insert window starting at roll_val - windows.insert(win_i+1, list(range(roll_val, roll_val + handler.context_length))) + windows.insert(win_i + 1, list(range(roll_val, roll_val + handler.context_length))) # delete window if it's not unique for pre_i in range(0, win_i): if windows[win_i] == windows[pre_i]: @@ -432,7 +439,7 @@ def get_matching_context_schedule(context_schedule: str) -> ContextSchedule: return ContextSchedule(context_schedule, func) -def get_context_weights(length: int, full_length: int, idxs: list[int], handler: IndexListContextHandler, sigma: torch.Tensor=None): +def get_context_weights(length: int, full_length: int, idxs: list[int], handler: IndexListContextHandler, sigma: torch.Tensor = None): return handler.fuse_method.func(length, sigma=sigma, handler=handler, full_length=full_length, idxs=idxs) @@ -440,6 +447,7 @@ def create_weights_flat(length: int, **kwargs) -> list[float]: # weight is the same for all return [1.0] * length + def create_weights_pyramid(length: int, **kwargs) -> list[float]: # weight is based on the distance away from the edge of the context window; # based on weighted average concept in FreeNoise paper @@ -451,6 +459,7 @@ def create_weights_pyramid(length: int, **kwargs) -> list[float]: weight_sequence = list(range(1, max_weight, 1)) + [max_weight] + list(range(max_weight - 1, 0, -1)) return weight_sequence + def create_weights_overlap_linear(length: int, full_length: int, idxs: list[int], handler: IndexListContextHandler, **kwargs): # based on code in Kijai's WanVideoWrapper: https://github.com/kijai/ComfyUI-WanVideoWrapper/blob/dbb2523b37e4ccdf45127e5ae33e31362f755c8e/nodes.py#L1302 # only expected overlap is given different weights @@ -460,11 +469,12 @@ def create_weights_overlap_linear(length: int, full_length: int, idxs: list[int] ramp_up = torch.linspace(1e-37, 1, handler.context_overlap) weights_torch[:handler.context_overlap] = ramp_up # blend right-side on all except last window - if max(idxs) < full_length-1: + if max(idxs) < full_length - 1: ramp_down = torch.linspace(1, 1e-37, handler.context_overlap) weights_torch[-handler.context_overlap:] = ramp_down return weights_torch + class ContextFuseMethods: FLAT = "flat" PYRAMID = "pyramid" @@ -482,12 +492,14 @@ FUSE_MAPPING = { ContextFuseMethods.OVERLAP_LINEAR: create_weights_overlap_linear, } + def get_matching_fuse_method(fuse_method: str) -> ContextFuseMethod: func = FUSE_MAPPING.get(fuse_method, None) if func is None: raise ValueError(f"Unknown fuse_method '{fuse_method}'.") return ContextFuseMethod(fuse_method, func) + # Returns fraction that has denominator that is a power of 2 def ordered_halving(val): # get binary value, padded with 0s for 64 bits diff --git a/comfy/controlnet.py b/comfy/controlnet.py index b9b3d25d5..c0d0ae589 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -662,7 +662,7 @@ def load_controlnet_qwen_instantx(sd, model_options={}): model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd, model_options=model_options) control_model = QwenImageControlNetModel(operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config) control_model = controlnet_load_state_dict(control_model, sd) - latent_format = comfy.latent_formats.Wan21() + latent_format = latent_formats.Wan21() extra_conds = [] control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds) return control diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 5fdbeaf93..185c81958 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -5,30 +5,37 @@ from typing import Optional import torch import torch.nn.functional as F +from diffusers.models.attention_dispatch import sageattn from einops import rearrange, repeat from torch import nn, einsum from .diffusionmodules.util import AlphaBlender, timestep_embedding from .sub_quadratic_attention import efficient_dot_product_attention from ... import model_management +from ...ops import scaled_dot_product_attention + +logger = logging.getLogger(__name__) if model_management.xformers_enabled(): import xformers # pylint: disable=import-error import xformers.ops # pylint: disable=import-error +sageattn = None if model_management.sage_attention_enabled(): try: from sageattention import sageattn # pylint: disable=import-error except ModuleNotFoundError as e: if e.name == "sageattention": import sys - logging.error(f"\n\nTo use the `--use-sage-attention` feature, the `sageattention` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install sageattention") + + logger.error(f"To use the `--use-sage-attention` feature, the `sageattention` package must be installed first.") else: raise e + sageattn = torch.nn.functional.scaled_dot_product_attention else: sageattn = torch.nn.functional.scaled_dot_product_attention - +flash_attn_func = None if model_management.flash_attention_enabled(): from flash_attn import flash_attn_func # pylint: disable=import-error else: @@ -40,7 +47,6 @@ from ... import ops ops = ops.disable_weight_init FORCE_UPCAST_ATTENTION_DTYPE = model_management.force_upcast_attention_dtype() -logger = logging.getLogger(__name__) def get_attn_precision(attn_precision, current_dtype): @@ -480,7 +486,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha mask = mask.unsqueeze(1) if SDP_BATCH_LIMIT >= b: - out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False) + out = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False) if not skip_output_reshape: out = ( out.transpose(1, 2).reshape(b, -1, heads * dim_head) @@ -493,7 +499,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha if mask.shape[0] > 1: m = mask[i: i + SDP_BATCH_LIMIT] - out[i: i + SDP_BATCH_LIMIT] = comfy.ops.scaled_dot_product_attention( + out[i: i + SDP_BATCH_LIMIT] = scaled_dot_product_attention( q[i: i + SDP_BATCH_LIMIT], k[i: i + SDP_BATCH_LIMIT], v[i: i + SDP_BATCH_LIMIT], @@ -527,7 +533,7 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape= try: out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout) except Exception as e: - logging.error("Error running sage attention: {}, using pytorch attention instead.".format(e)) + logger.error("Error running sage attention: {}, using pytorch attention instead.".format(e)) if tensor_layout == "NHD": q, k, v = map( lambda t: t.transpose(1, 2), @@ -551,7 +557,7 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape= try: @torch.library.custom_op("flash_attention::flash_attn", mutates_args=()) def flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor: + dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor: return flash_attn_func(q, k, v, dropout_p=dropout_p, causal=causal) # pylint: disable=possibly-used-before-assignment,used-before-assignment @@ -562,8 +568,9 @@ try: except AttributeError as error: FLASH_ATTN_ERROR = error + def flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor: + dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor: assert False, f"Could not define flash_attn_wrapper: {FLASH_ATTN_ERROR}" @@ -596,7 +603,7 @@ def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape causal=False, ).transpose(1, 2) except Exception as e: - logging.warning(f"Flash Attention failed, using default SDPA: {e}") + logger.warning(f"Flash Attention failed, using default SDPA: {e}") out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False) if not skip_output_reshape: out = ( @@ -616,7 +623,7 @@ elif model_management.xformers_enabled(): logger.debug("Using xformers attention") optimized_attention = attention_xformers elif model_management.flash_attention_enabled(): - logging.debug("Using Flash Attention") + logger.debug("Using Flash Attention") optimized_attention = attention_flash elif model_management.pytorch_attention_enabled(): logger.debug("Using pytorch attention") diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py index 906011ad8..1f7c5d6de 100644 --- a/comfy/ldm/modules/diffusionmodules/model.py +++ b/comfy/ldm/modules/diffusionmodules/model.py @@ -7,6 +7,7 @@ import torch.nn as nn from .... import model_management from .... import ops +from ....ops import scaled_dot_product_attention ops = ops.disable_weight_init @@ -295,7 +296,7 @@ def pytorch_attention(q, k, v): ) try: - out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False) + out = scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False) out = out.transpose(2, 3).reshape(orig_shape) except model_management.OOM_EXCEPTION: logger.warning("scaled_dot_product_attention OOMed: switched to slice attention") diff --git a/comfy/model_downloader.py b/comfy/model_downloader.py index 927361dc7..329941159 100644 --- a/comfy/model_downloader.py +++ b/comfy/model_downloader.py @@ -13,15 +13,13 @@ from os.path import join from pathlib import Path from typing import List, Optional, Final, Set -from .component_model.hf_hub_download_with_disable_xet import hf_hub_download_with_retries - import tqdm -from huggingface_hub import hf_hub_download, scan_cache_dir, snapshot_download, HfFileSystem +from huggingface_hub import dump_environment_info +from huggingface_hub import hf_hub_download, scan_cache_dir, snapshot_download, HfFileSystem, CacheNotFound from huggingface_hub.utils import GatedRepoError, LocalEntryNotFoundError from requests import Session from safetensors import safe_open from safetensors.torch import save_file -from huggingface_hub import dump_environment_info from .cli_args import args from .cmd import folder_paths @@ -140,11 +138,11 @@ def get_or_download(folder_name: str, filename: str, known_files: Optional[List[ try: logger.debug(f"{folder_name}/{filename} is being downloaded from {known_file.repo_id}/{known_file.filename} candidate_str_match={candidate_str_match} candidate_filename_match={candidate_filename_match} candidate_alternate_filenames_match={candidate_alternate_filenames_match} candidate_save_filename_match={candidate_save_filename_match}") path = hf_hub_download(repo_id=known_file.repo_id, - filename=known_file.filename, - repo_type=known_file.repo_type, - revision=known_file.revision, - local_dir=hf_destination_dir if args.force_hf_local_dir_mode else None, - ) + filename=known_file.filename, + repo_type=known_file.repo_type, + revision=known_file.revision, + local_dir=hf_destination_dir if args.force_hf_local_dir_mode else None, + ) except IOError as exc_info: logger.error(f"cannot reach huggingface {known_file.repo_id}/{known_file.filename}", exc_info=exc_info) except Exception as exc_info: @@ -689,10 +687,14 @@ def get_huggingface_repo_list(*extra_cache_dirs: str) -> List[str]: extra_cache_dirs = folder_paths.get_folder_paths("huggingface_cache") # all in cache directories + try: + default_cache_dir = [scan_cache_dir()] + except CacheNotFound as exc_info: + default_cache_dir = [] existing_repo_ids = frozenset( cache_item.repo_id for cache_item in \ reduce(operator.or_, - map(lambda cache_info: cache_info.repos, [scan_cache_dir()] + [scan_cache_dir(cache_dir=cache_dir) for cache_dir in extra_cache_dirs if os.path.isdir(cache_dir)])) + map(lambda cache_info: cache_info.repos, default_cache_dir + [scan_cache_dir(cache_dir=cache_dir) for cache_dir in extra_cache_dirs if os.path.isdir(cache_dir)])) if cache_item.repo_type == "model" or cache_item.repo_type == "space" ) diff --git a/comfy/nodes/comfyui_v3_package_imports.py b/comfy/nodes/comfyui_v3_package_imports.py index 79da0c3eb..77b4d7820 100644 --- a/comfy/nodes/comfyui_v3_package_imports.py +++ b/comfy/nodes/comfyui_v3_package_imports.py @@ -4,6 +4,7 @@ import logging from .package_typing import ExportedNodes from comfy_api.latest import ComfyExtension +from comfy_api.internal.async_to_sync import AsyncToSyncConverter logger = logging.getLogger(__name__) @@ -17,15 +18,13 @@ def _comfy_entrypoint_upstream_v3_imports(module) -> ExportedNodes: else: if inspect.iscoroutinefunction(entrypoint): # todo: I seriously doubt anything is going to be an async entrypoint, ever - extension_coro = entrypoint() - extension = asyncio.run(extension_coro) + extension = AsyncToSyncConverter.run_async_in_thread(entrypoint) else: extension = entrypoint() if not isinstance(extension, ComfyExtension): logger.debug(f"comfy_entrypoint in {module} did not return a ComfyExtension, skipping.") else: - node_list_coro = extension.get_node_list() - node_list = asyncio.run(node_list_coro) + node_list = AsyncToSyncConverter.run_async_in_thread(extension.get_node_list) if not isinstance(node_list, list): logger.debug(f"comfy_entrypoint in {module} did not return a list of nodes, skipping.") else: diff --git a/comfy/nodes/package.py b/comfy/nodes/package.py index efbdfe6c3..a6dd8f8c6 100644 --- a/comfy/nodes/package.py +++ b/comfy/nodes/package.py @@ -23,10 +23,11 @@ _nodes_available_at_startup: ExportedNodes = ExportedNodes() logger = logging.getLogger(__name__) -def _import_nodes_in_module(exported_nodes: ExportedNodes, module: types.ModuleType): +def _import_nodes_in_module(module: types.ModuleType) -> ExportedNodes: node_class_mappings = getattr(module, 'NODE_CLASS_MAPPINGS', None) node_display_names = getattr(module, 'NODE_DISPLAY_NAME_MAPPINGS', None) web_directory = getattr(module, "WEB_DIRECTORY", None) + exported_nodes = ExportedNodes() if node_class_mappings: exported_nodes.NODE_CLASS_MAPPINGS.update(node_class_mappings) if node_display_names: @@ -42,7 +43,7 @@ def _import_nodes_in_module(exported_nodes: ExportedNodes, module: types.ModuleT raise ImportError(path=abs_web_directory) exported_nodes.EXTENSION_WEB_DIRS[module.__name__] = abs_web_directory exported_nodes.update(_comfy_entrypoint_upstream_v3_imports(module)) - return node_class_mappings and len(node_class_mappings) > 0 or web_directory + return exported_nodes @@ -58,17 +59,19 @@ def _import_and_enumerate_nodes_in_module(module: types.ModuleType, time_before = time.perf_counter() full_name = module.__name__ try: - any_content_in_module = _import_nodes_in_module(exported_nodes, module) + module_exported_nodes = _import_nodes_in_module(module) span.set_attribute("full_name", full_name) timings.append((time.perf_counter() - time_before, full_name, True, exported_nodes)) except Exception as exc: - any_content_in_module = None + module_exported_nodes = None logger.error(f"{full_name} import failed", exc_info=exc) span.set_status(Status(StatusCode.ERROR)) span.record_exception(exc) exceptions.append(exc) - if any_content_in_module is None or not any_content_in_module: - # Iterate through all the submodules + if module_exported_nodes: + exported_nodes.update(module_exported_nodes) + else: + # iterate through all the submodules and try to find exported nodes for _, name, is_pkg in pkgutil.iter_modules(module.__path__): span: Span with tracer.start_as_current_span("Load Node") as span: diff --git a/comfy/nodes/package_typing.py b/comfy/nodes/package_typing.py index 911db4d69..70bc94111 100644 --- a/comfy/nodes/package_typing.py +++ b/comfy/nodes/package_typing.py @@ -195,6 +195,9 @@ class ExportedNodes: exported_nodes = ExportedNodes().update(self) return exported_nodes.update(other) + def __bool__(self): + return len(self.NODE_CLASS_MAPPINGS) + len(self.NODE_DISPLAY_NAME_MAPPINGS) + len(self.EXTENSION_WEB_DIRS) > 0 + class _ExportedNodesAsChainMap(ExportedNodes): @classmethod diff --git a/comfy/ops.py b/comfy/ops.py index accf378a3..dbe910a4c 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -27,8 +27,6 @@ from .execution_context import current_execution_context from .float import stochastic_rounding -def scaled_dot_product_attention(q, k, v, *args, **kwargs): - return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs) try: @@ -52,6 +50,9 @@ try: except (ModuleNotFoundError, TypeError): logging.warning("Could not set sdpa backend priority.") + def scaled_dot_product_attention(q, k, v, *args, **kwargs): + return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs) + cast_to = model_management.cast_to # TODO: remove once no more references logger = logging.getLogger(__name__) diff --git a/comfy/text_encoders/qwen_vl.py b/comfy/text_encoders/qwen_vl.py index 3b18ce730..52db1132c 100644 --- a/comfy/text_encoders/qwen_vl.py +++ b/comfy/text_encoders/qwen_vl.py @@ -3,7 +3,7 @@ import torch.nn as nn import torch.nn.functional as F from typing import Optional, Tuple import math -from comfy.ldm.modules.attention import optimized_attention_for_device +from ..ldm.modules.attention import optimized_attention_for_device def process_qwen2vl_images( diff --git a/comfy_api_nodes/apis/client.py b/comfy_api_nodes/apis/client.py index 37ace7f28..26b3b2359 100644 --- a/comfy_api_nodes/apis/client.py +++ b/comfy_api_nodes/apis/client.py @@ -532,7 +532,7 @@ class ApiClient: request_method="PUT", request_url=upload_url, response_status_code=e.status if hasattr(e, "status") else None, - response_headers=dict(e.headers) if getattr(e, "headers") else None, + response_headers=dict(e.headers) if getattr(e, "headers") else None, # pylint: disable=no-member response_content=None, error_message=f"{type(e).__name__}: {str(e)}", ) diff --git a/comfy_api_nodes/nodes_moonvalley.py b/comfy_api_nodes/nodes_moonvalley.py index 806a70e06..17fe08134 100644 --- a/comfy_api_nodes/nodes_moonvalley.py +++ b/comfy_api_nodes/nodes_moonvalley.py @@ -656,6 +656,7 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode): async def generate( self, prompt, negative_prompt, unique_id: Optional[str] = None, **kwargs ): + image_url: Optional[str] = None video = kwargs.get("video") image = kwargs.get("image", None) diff --git a/comfy_extras/nodes/nodes_context_windows.py b/comfy_extras/nodes/nodes_context_windows.py index 1c3d9e697..4cd8370aa 100644 --- a/comfy_extras/nodes/nodes_context_windows.py +++ b/comfy_extras/nodes/nodes_context_windows.py @@ -1,7 +1,7 @@ from __future__ import annotations from comfy_api.latest import ComfyExtension, io import comfy.context_windows -import nodes +from comfy.nodes import base_nodes as nodes class ContextWindowsManualNode(io.ComfyNode): diff --git a/comfy_extras/nodes/nodes_model_patch.py b/comfy_extras/nodes/nodes_model_patch.py index 3eaada9bc..83cdca899 100644 --- a/comfy_extras/nodes/nodes_model_patch.py +++ b/comfy_extras/nodes/nodes_model_patch.py @@ -1,5 +1,5 @@ import torch -import folder_paths +from comfy.cmd import folder_paths # pylint: disable=no-name-in-module import comfy.utils import comfy.ops import comfy.model_management @@ -27,12 +27,12 @@ class BlockWiseControlBlock(torch.nn.Module): class QwenImageBlockWiseControlNet(torch.nn.Module): def __init__( - self, - num_layers: int = 60, - in_dim: int = 64, - additional_in_dim: int = 0, - dim: int = 3072, - device=None, dtype=None, operations=None + self, + num_layers: int = 60, + in_dim: int = 64, + additional_in_dim: int = 0, + dim: int = 3072, + device=None, dtype=None, operations=None ): super().__init__() self.additional_in_dim = additional_in_dim @@ -61,8 +61,9 @@ class QwenImageBlockWiseControlNet(torch.nn.Module): class ModelPatchLoader: @classmethod def INPUT_TYPES(s): - return {"required": { "name": (folder_paths.get_filename_list("model_patches"), ), - }} + return {"required": {"name": (folder_paths.get_filename_list("model_patches"),), + }} + RETURN_TYPES = ("MODEL_PATCH",) FUNCTION = "load_model_patch" EXPERIMENTAL = True @@ -125,16 +126,18 @@ class DiffSynthCnetPatch: def models(self): return [self.model_patch] + class QwenImageDiffsynthControlnet: @classmethod def INPUT_TYPES(s): - return {"required": { "model": ("MODEL",), - "model_patch": ("MODEL_PATCH",), - "vae": ("VAE",), - "image": ("IMAGE",), - "strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}), - }, + return {"required": {"model": ("MODEL",), + "model_patch": ("MODEL_PATCH",), + "vae": ("VAE",), + "image": ("IMAGE",), + "strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}), + }, "optional": {"mask": ("MASK",)}} + RETURN_TYPES = ("MODEL",) FUNCTION = "diffsynth_controlnet" EXPERIMENTAL = True diff --git a/comfy_extras/nodes/nodes_qwen.py b/comfy_extras/nodes/nodes_qwen.py index fff89556f..214e6a372 100644 --- a/comfy_extras/nodes/nodes_qwen.py +++ b/comfy_extras/nodes/nodes_qwen.py @@ -1,4 +1,4 @@ -import node_helpers +from comfy import node_helpers import comfy.utils import math @@ -7,11 +7,11 @@ class TextEncodeQwenImageEdit: @classmethod def INPUT_TYPES(s): return {"required": { - "clip": ("CLIP", ), + "clip": ("CLIP",), "prompt": ("STRING", {"multiline": True, "dynamicPrompts": True}), - }, - "optional": {"vae": ("VAE", ), - "image": ("IMAGE", ),}} + }, + "optional": {"vae": ("VAE",), + "image": ("IMAGE",), }} RETURN_TYPES = ("CONDITIONING",) FUNCTION = "encode" @@ -40,7 +40,7 @@ class TextEncodeQwenImageEdit: conditioning = clip.encode_from_tokens_scheduled(tokens) if ref_latent is not None: conditioning = node_helpers.conditioning_set_values(conditioning, {"reference_latents": [ref_latent]}, append=True) - return (conditioning, ) + return (conditioning,) NODE_CLASS_MAPPINGS = { diff --git a/pyproject.toml b/pyproject.toml index 951142513..e49abdb5e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -119,7 +119,7 @@ where = ["."] include = ["comfy*"] namespaces = false -[dependency-groups] +[project.optional-dependencies] dev = [ "pytest", "pytest-asyncio", @@ -137,7 +137,6 @@ dev = [ "astroid", ] -[project.optional-dependencies] cpu = [ "torch", "torchvision",