mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-10 06:10:50 +08:00
Update to 0.3.51
This commit is contained in:
parent
664349eabf
commit
735a133ad4
@ -78,7 +78,6 @@ class ModelFileManager:
|
||||
return web.Response(status=404)
|
||||
|
||||
def get_model_file_list(self, folder_name: str):
|
||||
folder_name = folder_paths.map_legacy(folder_name)
|
||||
folders = folder_paths.folder_names_and_paths[folder_name]
|
||||
output_list: list[dict] = []
|
||||
|
||||
|
||||
@ -45,6 +45,7 @@ logging.getLogger("torch.distributed.elastic.multiprocessing.redirects").addFilt
|
||||
lambda record: log_msg_to_filter not in record.getMessage()
|
||||
)
|
||||
logging.getLogger("alembic.runtime.migration").setLevel(logging.WARNING)
|
||||
logging.getLogger("asyncio").addFilter(lambda record: 'Using selector:' not in record.getMessage())
|
||||
|
||||
from ..cli_args import args
|
||||
|
||||
|
||||
@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import platform
|
||||
import time
|
||||
from concurrent.futures import Future
|
||||
from pathlib import Path
|
||||
@ -17,16 +18,19 @@ from pebble import ThreadPool
|
||||
|
||||
from .tqdm_watcher import TqdmWatcher
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_VAR = "HF_HUB_ENABLE_HF_TRANSFER"
|
||||
_XET_VAR = "HF_XET_HIGH_PERFORMANCE"
|
||||
|
||||
os.environ[_VAR] = "True"
|
||||
|
||||
os.environ["HF_HUB_DISABLE_XET"] = "1"
|
||||
# os.environ["HF_XET_HIGH_PERFORMANCE"] = "True"
|
||||
if platform.system() == "Windows":
|
||||
os.environ["HF_HUB_DISABLE_XET"] = "1"
|
||||
logger.debug("Xet was disabled since it is currently not reliable")
|
||||
os.environ[_VAR] = "True"
|
||||
else:
|
||||
os.environ[_XET_VAR] = "True"
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.debug("Xet was disabled since it is currently not reliable")
|
||||
|
||||
|
||||
def hf_hub_download_with_disable_fast(repo_id=None, filename=None, disable_fast=None, hf_env: dict[str, str] = None, **kwargs):
|
||||
|
||||
@ -6,12 +6,13 @@ import collections
|
||||
from dataclasses import dataclass
|
||||
from abc import ABC, abstractmethod
|
||||
import logging
|
||||
import comfy.model_management
|
||||
import comfy.patcher_extension
|
||||
from .model_management import throw_exception_if_processing_interrupted
|
||||
from .patcher_extension import get_all_callbacks, WrappersMP
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from comfy.model_base import BaseModel
|
||||
from comfy.model_patcher import ModelPatcher
|
||||
from comfy.controlnet import ControlBase
|
||||
from .model_base import BaseModel
|
||||
from .model_patcher import ModelPatcher
|
||||
from .controlnet import ControlBase
|
||||
|
||||
|
||||
class ContextWindowABC(ABC):
|
||||
@ -32,6 +33,7 @@ class ContextWindowABC(ABC):
|
||||
"""
|
||||
raise NotImplementedError("Not implemented.")
|
||||
|
||||
|
||||
class ContextHandlerABC(ABC):
|
||||
def __init__(self):
|
||||
...
|
||||
@ -49,9 +51,8 @@ class ContextHandlerABC(ABC):
|
||||
raise NotImplementedError("Not implemented.")
|
||||
|
||||
|
||||
|
||||
class IndexListContextWindow(ContextWindowABC):
|
||||
def __init__(self, index_list: list[int], dim: int=0):
|
||||
def __init__(self, index_list: list[int], dim: int = 0):
|
||||
self.index_list = index_list
|
||||
self.context_length = len(index_list)
|
||||
self.dim = dim
|
||||
@ -87,14 +88,18 @@ class ContextSchedule:
|
||||
name: str
|
||||
func: Callable
|
||||
|
||||
|
||||
@dataclass
|
||||
class ContextFuseMethod:
|
||||
name: str
|
||||
func: Callable
|
||||
|
||||
|
||||
ContextResults = collections.namedtuple("ContextResults", ['window_idx', 'sub_conds_out', 'sub_conds', 'window'])
|
||||
|
||||
|
||||
class IndexListContextHandler(ContextHandlerABC):
|
||||
def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMethod, context_length: int=1, context_overlap: int=0, context_stride: int=1, closed_loop=False, dim=0):
|
||||
def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMethod, context_length: int = 1, context_overlap: int = 0, context_stride: int = 1, closed_loop=False, dim=0):
|
||||
self.context_schedule = context_schedule
|
||||
self.fuse_method = fuse_method
|
||||
self.context_length = context_length
|
||||
@ -152,7 +157,7 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
elif hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
|
||||
if cond_value.cond.ndim < self.dim and cond_value.cond.size(0) == x_in.size(self.dim):
|
||||
new_cond_item[cond_key] = cond_value._copy_with(window.get_tensor(cond_value.cond, device))
|
||||
elif cond_key == "num_video_frames": # for SVD
|
||||
elif cond_key == "num_video_frames": # for SVD
|
||||
new_cond_item[cond_key] = cond_value._copy_with(cond_value.cond)
|
||||
new_cond_item[cond_key].cond = window.context_length
|
||||
resized_actual_cond[key] = new_cond_item
|
||||
@ -171,7 +176,7 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
self._step = int(matches[0].item())
|
||||
|
||||
def get_context_windows(self, model: BaseModel, x_in: torch.Tensor, model_options: dict[str]) -> list[IndexListContextWindow]:
|
||||
full_length = x_in.size(self.dim) # TODO: choose dim based on model
|
||||
full_length = x_in.size(self.dim) # TODO: choose dim based on model
|
||||
context_windows = self.context_schedule.func(full_length, self, model_options)
|
||||
context_windows = [IndexListContextWindow(window, dim=self.dim) for window in context_windows]
|
||||
return context_windows
|
||||
@ -188,14 +193,14 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
counts_final = [torch.zeros(get_shape_for_dim(x_in, self.dim), device=x_in.device) for _ in conds]
|
||||
biases_final = [([0.0] * x_in.shape[self.dim]) for _ in conds]
|
||||
|
||||
for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_START, self.callbacks):
|
||||
for callback in get_all_callbacks(IndexListCallbacks.EXECUTE_START, self.callbacks):
|
||||
callback(self, model, x_in, conds, timestep, model_options)
|
||||
|
||||
for enum_window in enumerated_context_windows:
|
||||
results = self.evaluate_context_windows(calc_cond_batch, model, x_in, conds, timestep, [enum_window], model_options)
|
||||
for result in results:
|
||||
self.combine_context_window_results(x_in, result.sub_conds_out, result.sub_conds, result.window, result.window_idx, len(enumerated_context_windows), timestep,
|
||||
conds_final, counts_final, biases_final)
|
||||
conds_final, counts_final, biases_final)
|
||||
try:
|
||||
# finalize conds
|
||||
if self.fuse_method.name == ContextFuseMethods.RELATIVE:
|
||||
@ -209,17 +214,17 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
del counts_final
|
||||
return conds_final
|
||||
finally:
|
||||
for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_CLEANUP, self.callbacks):
|
||||
for callback in get_all_callbacks(IndexListCallbacks.EXECUTE_CLEANUP, self.callbacks):
|
||||
callback(self, model, x_in, conds, timestep, model_options)
|
||||
|
||||
def evaluate_context_windows(self, calc_cond_batch: Callable, model: BaseModel, x_in: torch.Tensor, conds, timestep: torch.Tensor, enumerated_context_windows: list[tuple[int, IndexListContextWindow]],
|
||||
model_options, device=None, first_device=None):
|
||||
model_options, device=None, first_device=None):
|
||||
results: list[ContextResults] = []
|
||||
for window_idx, window in enumerated_context_windows:
|
||||
# allow processing to end between context window executions for faster Cancel
|
||||
comfy.model_management.throw_exception_if_processing_interrupted()
|
||||
throw_exception_if_processing_interrupted()
|
||||
|
||||
for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EVALUATE_CONTEXT_WINDOWS, self.callbacks):
|
||||
for callback in get_all_callbacks(IndexListCallbacks.EVALUATE_CONTEXT_WINDOWS, self.callbacks):
|
||||
callback(self, model, x_in, conds, timestep, model_options, window_idx, window, model_options, device, first_device)
|
||||
|
||||
# update exposed params
|
||||
@ -236,9 +241,8 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
results.append(ContextResults(window_idx, sub_conds_out, sub_conds, window))
|
||||
return results
|
||||
|
||||
|
||||
def combine_context_window_results(self, x_in: torch.Tensor, sub_conds_out, sub_conds, window: IndexListContextWindow, window_idx: int, total_windows: int, timestep: torch.Tensor,
|
||||
conds_final: list[torch.Tensor], counts_final: list[torch.Tensor], biases_final: list[torch.Tensor]):
|
||||
conds_final: list[torch.Tensor], counts_final: list[torch.Tensor], biases_final: list[torch.Tensor]):
|
||||
if self.fuse_method.name == ContextFuseMethods.RELATIVE:
|
||||
for pos, idx in enumerate(window.index_list):
|
||||
# bias is the influence of a specific index in relation to the whole context window
|
||||
@ -263,7 +267,7 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
window.add_window(conds_final[i], sub_conds_out[i] * weights_tensor)
|
||||
window.add_window(counts_final[i], weights_tensor)
|
||||
|
||||
for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.COMBINE_CONTEXT_WINDOW_RESULTS, self.callbacks):
|
||||
for callback in get_all_callbacks(IndexListCallbacks.COMBINE_CONTEXT_WINDOW_RESULTS, self.callbacks):
|
||||
callback(self, x_in, sub_conds_out, sub_conds, window, window_idx, total_windows, timestep, conds_final, counts_final, biases_final)
|
||||
|
||||
|
||||
@ -281,7 +285,7 @@ def _prepare_sampling_wrapper(executor, model, noise_shape: torch.Tensor, *args,
|
||||
|
||||
def create_prepare_sampling_wrapper(model: ModelPatcher):
|
||||
model.add_wrapper_with_key(
|
||||
comfy.patcher_extension.WrappersMP.PREPARE_SAMPLING,
|
||||
WrappersMP.PREPARE_SAMPLING,
|
||||
"ContextWindows_prepare_sampling",
|
||||
_prepare_sampling_wrapper
|
||||
)
|
||||
@ -296,6 +300,7 @@ def match_weights_to_dim(weights: list[float], x_in: torch.Tensor, dim: int, dev
|
||||
weights_tensor = weights_tensor.unsqueeze(-1)
|
||||
return weights_tensor
|
||||
|
||||
|
||||
def get_shape_for_dim(x_in: torch.Tensor, dim: int) -> list[int]:
|
||||
total_dims = len(x_in.shape)
|
||||
shape = []
|
||||
@ -306,6 +311,7 @@ def get_shape_for_dim(x_in: torch.Tensor, dim: int) -> list[int]:
|
||||
shape.append(1)
|
||||
return shape
|
||||
|
||||
|
||||
class ContextSchedules:
|
||||
UNIFORM_LOOPED = "looped_uniform"
|
||||
UNIFORM_STANDARD = "standard_uniform"
|
||||
@ -325,14 +331,15 @@ def create_windows_uniform_looped(num_frames: int, handler: IndexListContextHand
|
||||
for context_step in 1 << np.arange(context_stride):
|
||||
pad = int(round(num_frames * ordered_halving(handler._step)))
|
||||
for j in range(
|
||||
int(ordered_halving(handler._step) * context_step) + pad,
|
||||
num_frames + pad + (0 if handler.closed_loop else -handler.context_overlap),
|
||||
(handler.context_length * context_step - handler.context_overlap),
|
||||
int(ordered_halving(handler._step) * context_step) + pad,
|
||||
num_frames + pad + (0 if handler.closed_loop else -handler.context_overlap),
|
||||
(handler.context_length * context_step - handler.context_overlap),
|
||||
):
|
||||
windows.append([e % num_frames for e in range(j, j + handler.context_length * context_step, context_step)])
|
||||
|
||||
return windows
|
||||
|
||||
|
||||
def create_windows_uniform_standard(num_frames: int, handler: IndexListContextHandler, model_options: dict[str]):
|
||||
# unlike looped, uniform_straight does NOT allow windows that loop back to the beginning;
|
||||
# instead, they get shifted to the corresponding end of the frames.
|
||||
@ -347,9 +354,9 @@ def create_windows_uniform_standard(num_frames: int, handler: IndexListContextHa
|
||||
for context_step in 1 << np.arange(context_stride):
|
||||
pad = int(round(num_frames * ordered_halving(handler._step)))
|
||||
for j in range(
|
||||
int(ordered_halving(handler._step) * context_step) + pad,
|
||||
num_frames + pad + (-handler.context_overlap),
|
||||
(handler.context_length * context_step - handler.context_overlap),
|
||||
int(ordered_halving(handler._step) * context_step) + pad,
|
||||
num_frames + pad + (-handler.context_overlap),
|
||||
(handler.context_length * context_step - handler.context_overlap),
|
||||
):
|
||||
windows.append([e % num_frames for e in range(j, j + handler.context_length * context_step, context_step)])
|
||||
|
||||
@ -363,9 +370,9 @@ def create_windows_uniform_standard(num_frames: int, handler: IndexListContextHa
|
||||
roll_val = windows[win_i][roll_idx] # roll_val might not be 0 for windows of higher strides
|
||||
shift_window_to_end(windows[win_i], num_frames=num_frames)
|
||||
# check if next window (cyclical) is missing roll_val
|
||||
if roll_val not in windows[(win_i+1) % len(windows)]:
|
||||
if roll_val not in windows[(win_i + 1) % len(windows)]:
|
||||
# need to insert new window here - just insert window starting at roll_val
|
||||
windows.insert(win_i+1, list(range(roll_val, roll_val + handler.context_length)))
|
||||
windows.insert(win_i + 1, list(range(roll_val, roll_val + handler.context_length)))
|
||||
# delete window if it's not unique
|
||||
for pre_i in range(0, win_i):
|
||||
if windows[win_i] == windows[pre_i]:
|
||||
@ -432,7 +439,7 @@ def get_matching_context_schedule(context_schedule: str) -> ContextSchedule:
|
||||
return ContextSchedule(context_schedule, func)
|
||||
|
||||
|
||||
def get_context_weights(length: int, full_length: int, idxs: list[int], handler: IndexListContextHandler, sigma: torch.Tensor=None):
|
||||
def get_context_weights(length: int, full_length: int, idxs: list[int], handler: IndexListContextHandler, sigma: torch.Tensor = None):
|
||||
return handler.fuse_method.func(length, sigma=sigma, handler=handler, full_length=full_length, idxs=idxs)
|
||||
|
||||
|
||||
@ -440,6 +447,7 @@ def create_weights_flat(length: int, **kwargs) -> list[float]:
|
||||
# weight is the same for all
|
||||
return [1.0] * length
|
||||
|
||||
|
||||
def create_weights_pyramid(length: int, **kwargs) -> list[float]:
|
||||
# weight is based on the distance away from the edge of the context window;
|
||||
# based on weighted average concept in FreeNoise paper
|
||||
@ -451,6 +459,7 @@ def create_weights_pyramid(length: int, **kwargs) -> list[float]:
|
||||
weight_sequence = list(range(1, max_weight, 1)) + [max_weight] + list(range(max_weight - 1, 0, -1))
|
||||
return weight_sequence
|
||||
|
||||
|
||||
def create_weights_overlap_linear(length: int, full_length: int, idxs: list[int], handler: IndexListContextHandler, **kwargs):
|
||||
# based on code in Kijai's WanVideoWrapper: https://github.com/kijai/ComfyUI-WanVideoWrapper/blob/dbb2523b37e4ccdf45127e5ae33e31362f755c8e/nodes.py#L1302
|
||||
# only expected overlap is given different weights
|
||||
@ -460,11 +469,12 @@ def create_weights_overlap_linear(length: int, full_length: int, idxs: list[int]
|
||||
ramp_up = torch.linspace(1e-37, 1, handler.context_overlap)
|
||||
weights_torch[:handler.context_overlap] = ramp_up
|
||||
# blend right-side on all except last window
|
||||
if max(idxs) < full_length-1:
|
||||
if max(idxs) < full_length - 1:
|
||||
ramp_down = torch.linspace(1, 1e-37, handler.context_overlap)
|
||||
weights_torch[-handler.context_overlap:] = ramp_down
|
||||
return weights_torch
|
||||
|
||||
|
||||
class ContextFuseMethods:
|
||||
FLAT = "flat"
|
||||
PYRAMID = "pyramid"
|
||||
@ -482,12 +492,14 @@ FUSE_MAPPING = {
|
||||
ContextFuseMethods.OVERLAP_LINEAR: create_weights_overlap_linear,
|
||||
}
|
||||
|
||||
|
||||
def get_matching_fuse_method(fuse_method: str) -> ContextFuseMethod:
|
||||
func = FUSE_MAPPING.get(fuse_method, None)
|
||||
if func is None:
|
||||
raise ValueError(f"Unknown fuse_method '{fuse_method}'.")
|
||||
return ContextFuseMethod(fuse_method, func)
|
||||
|
||||
|
||||
# Returns fraction that has denominator that is a power of 2
|
||||
def ordered_halving(val):
|
||||
# get binary value, padded with 0s for 64 bits
|
||||
|
||||
@ -662,7 +662,7 @@ def load_controlnet_qwen_instantx(sd, model_options={}):
|
||||
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd, model_options=model_options)
|
||||
control_model = QwenImageControlNetModel(operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
|
||||
control_model = controlnet_load_state_dict(control_model, sd)
|
||||
latent_format = comfy.latent_formats.Wan21()
|
||||
latent_format = latent_formats.Wan21()
|
||||
extra_conds = []
|
||||
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
|
||||
return control
|
||||
|
||||
@ -5,30 +5,37 @@ from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from diffusers.models.attention_dispatch import sageattn
|
||||
from einops import rearrange, repeat
|
||||
from torch import nn, einsum
|
||||
|
||||
from .diffusionmodules.util import AlphaBlender, timestep_embedding
|
||||
from .sub_quadratic_attention import efficient_dot_product_attention
|
||||
from ... import model_management
|
||||
from ...ops import scaled_dot_product_attention
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if model_management.xformers_enabled():
|
||||
import xformers # pylint: disable=import-error
|
||||
import xformers.ops # pylint: disable=import-error
|
||||
|
||||
sageattn = None
|
||||
if model_management.sage_attention_enabled():
|
||||
try:
|
||||
from sageattention import sageattn # pylint: disable=import-error
|
||||
except ModuleNotFoundError as e:
|
||||
if e.name == "sageattention":
|
||||
import sys
|
||||
logging.error(f"\n\nTo use the `--use-sage-attention` feature, the `sageattention` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install sageattention")
|
||||
|
||||
logger.error(f"To use the `--use-sage-attention` feature, the `sageattention` package must be installed first.")
|
||||
else:
|
||||
raise e
|
||||
sageattn = torch.nn.functional.scaled_dot_product_attention
|
||||
else:
|
||||
sageattn = torch.nn.functional.scaled_dot_product_attention
|
||||
|
||||
|
||||
flash_attn_func = None
|
||||
if model_management.flash_attention_enabled():
|
||||
from flash_attn import flash_attn_func # pylint: disable=import-error
|
||||
else:
|
||||
@ -40,7 +47,6 @@ from ... import ops
|
||||
ops = ops.disable_weight_init
|
||||
|
||||
FORCE_UPCAST_ATTENTION_DTYPE = model_management.force_upcast_attention_dtype()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_attn_precision(attn_precision, current_dtype):
|
||||
@ -480,7 +486,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
|
||||
mask = mask.unsqueeze(1)
|
||||
|
||||
if SDP_BATCH_LIMIT >= b:
|
||||
out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
|
||||
out = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
|
||||
if not skip_output_reshape:
|
||||
out = (
|
||||
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
|
||||
@ -493,7 +499,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
|
||||
if mask.shape[0] > 1:
|
||||
m = mask[i: i + SDP_BATCH_LIMIT]
|
||||
|
||||
out[i: i + SDP_BATCH_LIMIT] = comfy.ops.scaled_dot_product_attention(
|
||||
out[i: i + SDP_BATCH_LIMIT] = scaled_dot_product_attention(
|
||||
q[i: i + SDP_BATCH_LIMIT],
|
||||
k[i: i + SDP_BATCH_LIMIT],
|
||||
v[i: i + SDP_BATCH_LIMIT],
|
||||
@ -527,7 +533,7 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
|
||||
try:
|
||||
out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout)
|
||||
except Exception as e:
|
||||
logging.error("Error running sage attention: {}, using pytorch attention instead.".format(e))
|
||||
logger.error("Error running sage attention: {}, using pytorch attention instead.".format(e))
|
||||
if tensor_layout == "NHD":
|
||||
q, k, v = map(
|
||||
lambda t: t.transpose(1, 2),
|
||||
@ -551,7 +557,7 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
|
||||
try:
|
||||
@torch.library.custom_op("flash_attention::flash_attn", mutates_args=())
|
||||
def flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
||||
dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor:
|
||||
dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor:
|
||||
return flash_attn_func(q, k, v, dropout_p=dropout_p, causal=causal) # pylint: disable=possibly-used-before-assignment,used-before-assignment
|
||||
|
||||
|
||||
@ -562,8 +568,9 @@ try:
|
||||
except AttributeError as error:
|
||||
FLASH_ATTN_ERROR = error
|
||||
|
||||
|
||||
def flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
||||
dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor:
|
||||
dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor:
|
||||
assert False, f"Could not define flash_attn_wrapper: {FLASH_ATTN_ERROR}"
|
||||
|
||||
|
||||
@ -596,7 +603,7 @@ def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
|
||||
causal=False,
|
||||
).transpose(1, 2)
|
||||
except Exception as e:
|
||||
logging.warning(f"Flash Attention failed, using default SDPA: {e}")
|
||||
logger.warning(f"Flash Attention failed, using default SDPA: {e}")
|
||||
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
|
||||
if not skip_output_reshape:
|
||||
out = (
|
||||
@ -616,7 +623,7 @@ elif model_management.xformers_enabled():
|
||||
logger.debug("Using xformers attention")
|
||||
optimized_attention = attention_xformers
|
||||
elif model_management.flash_attention_enabled():
|
||||
logging.debug("Using Flash Attention")
|
||||
logger.debug("Using Flash Attention")
|
||||
optimized_attention = attention_flash
|
||||
elif model_management.pytorch_attention_enabled():
|
||||
logger.debug("Using pytorch attention")
|
||||
|
||||
@ -7,6 +7,7 @@ import torch.nn as nn
|
||||
|
||||
from .... import model_management
|
||||
from .... import ops
|
||||
from ....ops import scaled_dot_product_attention
|
||||
|
||||
ops = ops.disable_weight_init
|
||||
|
||||
@ -295,7 +296,7 @@ def pytorch_attention(q, k, v):
|
||||
)
|
||||
|
||||
try:
|
||||
out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
|
||||
out = scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
|
||||
out = out.transpose(2, 3).reshape(orig_shape)
|
||||
except model_management.OOM_EXCEPTION:
|
||||
logger.warning("scaled_dot_product_attention OOMed: switched to slice attention")
|
||||
|
||||
@ -13,15 +13,13 @@ from os.path import join
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Final, Set
|
||||
|
||||
from .component_model.hf_hub_download_with_disable_xet import hf_hub_download_with_retries
|
||||
|
||||
import tqdm
|
||||
from huggingface_hub import hf_hub_download, scan_cache_dir, snapshot_download, HfFileSystem
|
||||
from huggingface_hub import dump_environment_info
|
||||
from huggingface_hub import hf_hub_download, scan_cache_dir, snapshot_download, HfFileSystem, CacheNotFound
|
||||
from huggingface_hub.utils import GatedRepoError, LocalEntryNotFoundError
|
||||
from requests import Session
|
||||
from safetensors import safe_open
|
||||
from safetensors.torch import save_file
|
||||
from huggingface_hub import dump_environment_info
|
||||
|
||||
from .cli_args import args
|
||||
from .cmd import folder_paths
|
||||
@ -140,11 +138,11 @@ def get_or_download(folder_name: str, filename: str, known_files: Optional[List[
|
||||
try:
|
||||
logger.debug(f"{folder_name}/{filename} is being downloaded from {known_file.repo_id}/{known_file.filename} candidate_str_match={candidate_str_match} candidate_filename_match={candidate_filename_match} candidate_alternate_filenames_match={candidate_alternate_filenames_match} candidate_save_filename_match={candidate_save_filename_match}")
|
||||
path = hf_hub_download(repo_id=known_file.repo_id,
|
||||
filename=known_file.filename,
|
||||
repo_type=known_file.repo_type,
|
||||
revision=known_file.revision,
|
||||
local_dir=hf_destination_dir if args.force_hf_local_dir_mode else None,
|
||||
)
|
||||
filename=known_file.filename,
|
||||
repo_type=known_file.repo_type,
|
||||
revision=known_file.revision,
|
||||
local_dir=hf_destination_dir if args.force_hf_local_dir_mode else None,
|
||||
)
|
||||
except IOError as exc_info:
|
||||
logger.error(f"cannot reach huggingface {known_file.repo_id}/{known_file.filename}", exc_info=exc_info)
|
||||
except Exception as exc_info:
|
||||
@ -689,10 +687,14 @@ def get_huggingface_repo_list(*extra_cache_dirs: str) -> List[str]:
|
||||
extra_cache_dirs = folder_paths.get_folder_paths("huggingface_cache")
|
||||
|
||||
# all in cache directories
|
||||
try:
|
||||
default_cache_dir = [scan_cache_dir()]
|
||||
except CacheNotFound as exc_info:
|
||||
default_cache_dir = []
|
||||
existing_repo_ids = frozenset(
|
||||
cache_item.repo_id for cache_item in \
|
||||
reduce(operator.or_,
|
||||
map(lambda cache_info: cache_info.repos, [scan_cache_dir()] + [scan_cache_dir(cache_dir=cache_dir) for cache_dir in extra_cache_dirs if os.path.isdir(cache_dir)]))
|
||||
map(lambda cache_info: cache_info.repos, default_cache_dir + [scan_cache_dir(cache_dir=cache_dir) for cache_dir in extra_cache_dirs if os.path.isdir(cache_dir)]))
|
||||
if cache_item.repo_type == "model" or cache_item.repo_type == "space"
|
||||
)
|
||||
|
||||
|
||||
@ -4,6 +4,7 @@ import logging
|
||||
|
||||
from .package_typing import ExportedNodes
|
||||
from comfy_api.latest import ComfyExtension
|
||||
from comfy_api.internal.async_to_sync import AsyncToSyncConverter
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -17,15 +18,13 @@ def _comfy_entrypoint_upstream_v3_imports(module) -> ExportedNodes:
|
||||
else:
|
||||
if inspect.iscoroutinefunction(entrypoint):
|
||||
# todo: I seriously doubt anything is going to be an async entrypoint, ever
|
||||
extension_coro = entrypoint()
|
||||
extension = asyncio.run(extension_coro)
|
||||
extension = AsyncToSyncConverter.run_async_in_thread(entrypoint)
|
||||
else:
|
||||
extension = entrypoint()
|
||||
if not isinstance(extension, ComfyExtension):
|
||||
logger.debug(f"comfy_entrypoint in {module} did not return a ComfyExtension, skipping.")
|
||||
else:
|
||||
node_list_coro = extension.get_node_list()
|
||||
node_list = asyncio.run(node_list_coro)
|
||||
node_list = AsyncToSyncConverter.run_async_in_thread(extension.get_node_list)
|
||||
if not isinstance(node_list, list):
|
||||
logger.debug(f"comfy_entrypoint in {module} did not return a list of nodes, skipping.")
|
||||
else:
|
||||
|
||||
@ -23,10 +23,11 @@ _nodes_available_at_startup: ExportedNodes = ExportedNodes()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _import_nodes_in_module(exported_nodes: ExportedNodes, module: types.ModuleType):
|
||||
def _import_nodes_in_module(module: types.ModuleType) -> ExportedNodes:
|
||||
node_class_mappings = getattr(module, 'NODE_CLASS_MAPPINGS', None)
|
||||
node_display_names = getattr(module, 'NODE_DISPLAY_NAME_MAPPINGS', None)
|
||||
web_directory = getattr(module, "WEB_DIRECTORY", None)
|
||||
exported_nodes = ExportedNodes()
|
||||
if node_class_mappings:
|
||||
exported_nodes.NODE_CLASS_MAPPINGS.update(node_class_mappings)
|
||||
if node_display_names:
|
||||
@ -42,7 +43,7 @@ def _import_nodes_in_module(exported_nodes: ExportedNodes, module: types.ModuleT
|
||||
raise ImportError(path=abs_web_directory)
|
||||
exported_nodes.EXTENSION_WEB_DIRS[module.__name__] = abs_web_directory
|
||||
exported_nodes.update(_comfy_entrypoint_upstream_v3_imports(module))
|
||||
return node_class_mappings and len(node_class_mappings) > 0 or web_directory
|
||||
return exported_nodes
|
||||
|
||||
|
||||
|
||||
@ -58,17 +59,19 @@ def _import_and_enumerate_nodes_in_module(module: types.ModuleType,
|
||||
time_before = time.perf_counter()
|
||||
full_name = module.__name__
|
||||
try:
|
||||
any_content_in_module = _import_nodes_in_module(exported_nodes, module)
|
||||
module_exported_nodes = _import_nodes_in_module(module)
|
||||
span.set_attribute("full_name", full_name)
|
||||
timings.append((time.perf_counter() - time_before, full_name, True, exported_nodes))
|
||||
except Exception as exc:
|
||||
any_content_in_module = None
|
||||
module_exported_nodes = None
|
||||
logger.error(f"{full_name} import failed", exc_info=exc)
|
||||
span.set_status(Status(StatusCode.ERROR))
|
||||
span.record_exception(exc)
|
||||
exceptions.append(exc)
|
||||
if any_content_in_module is None or not any_content_in_module:
|
||||
# Iterate through all the submodules
|
||||
if module_exported_nodes:
|
||||
exported_nodes.update(module_exported_nodes)
|
||||
else:
|
||||
# iterate through all the submodules and try to find exported nodes
|
||||
for _, name, is_pkg in pkgutil.iter_modules(module.__path__):
|
||||
span: Span
|
||||
with tracer.start_as_current_span("Load Node") as span:
|
||||
|
||||
@ -195,6 +195,9 @@ class ExportedNodes:
|
||||
exported_nodes = ExportedNodes().update(self)
|
||||
return exported_nodes.update(other)
|
||||
|
||||
def __bool__(self):
|
||||
return len(self.NODE_CLASS_MAPPINGS) + len(self.NODE_DISPLAY_NAME_MAPPINGS) + len(self.EXTENSION_WEB_DIRS) > 0
|
||||
|
||||
|
||||
class _ExportedNodesAsChainMap(ExportedNodes):
|
||||
@classmethod
|
||||
|
||||
@ -27,8 +27,6 @@ from .execution_context import current_execution_context
|
||||
from .float import stochastic_rounding
|
||||
|
||||
|
||||
def scaled_dot_product_attention(q, k, v, *args, **kwargs):
|
||||
return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)
|
||||
|
||||
|
||||
try:
|
||||
@ -52,6 +50,9 @@ try:
|
||||
except (ModuleNotFoundError, TypeError):
|
||||
logging.warning("Could not set sdpa backend priority.")
|
||||
|
||||
def scaled_dot_product_attention(q, k, v, *args, **kwargs):
|
||||
return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)
|
||||
|
||||
cast_to = model_management.cast_to # TODO: remove once no more references
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -3,7 +3,7 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from typing import Optional, Tuple
|
||||
import math
|
||||
from comfy.ldm.modules.attention import optimized_attention_for_device
|
||||
from ..ldm.modules.attention import optimized_attention_for_device
|
||||
|
||||
|
||||
def process_qwen2vl_images(
|
||||
|
||||
@ -532,7 +532,7 @@ class ApiClient:
|
||||
request_method="PUT",
|
||||
request_url=upload_url,
|
||||
response_status_code=e.status if hasattr(e, "status") else None,
|
||||
response_headers=dict(e.headers) if getattr(e, "headers") else None,
|
||||
response_headers=dict(e.headers) if getattr(e, "headers") else None, # pylint: disable=no-member
|
||||
response_content=None,
|
||||
error_message=f"{type(e).__name__}: {str(e)}",
|
||||
)
|
||||
|
||||
@ -656,6 +656,7 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode):
|
||||
async def generate(
|
||||
self, prompt, negative_prompt, unique_id: Optional[str] = None, **kwargs
|
||||
):
|
||||
image_url: Optional[str] = None
|
||||
video = kwargs.get("video")
|
||||
image = kwargs.get("image", None)
|
||||
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
import comfy.context_windows
|
||||
import nodes
|
||||
from comfy.nodes import base_nodes as nodes
|
||||
|
||||
|
||||
class ContextWindowsManualNode(io.ComfyNode):
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
import torch
|
||||
import folder_paths
|
||||
from comfy.cmd import folder_paths # pylint: disable=no-name-in-module
|
||||
import comfy.utils
|
||||
import comfy.ops
|
||||
import comfy.model_management
|
||||
@ -27,12 +27,12 @@ class BlockWiseControlBlock(torch.nn.Module):
|
||||
|
||||
class QwenImageBlockWiseControlNet(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
num_layers: int = 60,
|
||||
in_dim: int = 64,
|
||||
additional_in_dim: int = 0,
|
||||
dim: int = 3072,
|
||||
device=None, dtype=None, operations=None
|
||||
self,
|
||||
num_layers: int = 60,
|
||||
in_dim: int = 64,
|
||||
additional_in_dim: int = 0,
|
||||
dim: int = 3072,
|
||||
device=None, dtype=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
self.additional_in_dim = additional_in_dim
|
||||
@ -61,8 +61,9 @@ class QwenImageBlockWiseControlNet(torch.nn.Module):
|
||||
class ModelPatchLoader:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "name": (folder_paths.get_filename_list("model_patches"), ),
|
||||
}}
|
||||
return {"required": {"name": (folder_paths.get_filename_list("model_patches"),),
|
||||
}}
|
||||
|
||||
RETURN_TYPES = ("MODEL_PATCH",)
|
||||
FUNCTION = "load_model_patch"
|
||||
EXPERIMENTAL = True
|
||||
@ -125,16 +126,18 @@ class DiffSynthCnetPatch:
|
||||
def models(self):
|
||||
return [self.model_patch]
|
||||
|
||||
|
||||
class QwenImageDiffsynthControlnet:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "model": ("MODEL",),
|
||||
"model_patch": ("MODEL_PATCH",),
|
||||
"vae": ("VAE",),
|
||||
"image": ("IMAGE",),
|
||||
"strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
|
||||
},
|
||||
return {"required": {"model": ("MODEL",),
|
||||
"model_patch": ("MODEL_PATCH",),
|
||||
"vae": ("VAE",),
|
||||
"image": ("IMAGE",),
|
||||
"strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
|
||||
},
|
||||
"optional": {"mask": ("MASK",)}}
|
||||
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
FUNCTION = "diffsynth_controlnet"
|
||||
EXPERIMENTAL = True
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
import node_helpers
|
||||
from comfy import node_helpers
|
||||
import comfy.utils
|
||||
import math
|
||||
|
||||
@ -7,11 +7,11 @@ class TextEncodeQwenImageEdit:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {
|
||||
"clip": ("CLIP", ),
|
||||
"clip": ("CLIP",),
|
||||
"prompt": ("STRING", {"multiline": True, "dynamicPrompts": True}),
|
||||
},
|
||||
"optional": {"vae": ("VAE", ),
|
||||
"image": ("IMAGE", ),}}
|
||||
},
|
||||
"optional": {"vae": ("VAE",),
|
||||
"image": ("IMAGE",), }}
|
||||
|
||||
RETURN_TYPES = ("CONDITIONING",)
|
||||
FUNCTION = "encode"
|
||||
@ -40,7 +40,7 @@ class TextEncodeQwenImageEdit:
|
||||
conditioning = clip.encode_from_tokens_scheduled(tokens)
|
||||
if ref_latent is not None:
|
||||
conditioning = node_helpers.conditioning_set_values(conditioning, {"reference_latents": [ref_latent]}, append=True)
|
||||
return (conditioning, )
|
||||
return (conditioning,)
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
|
||||
@ -119,7 +119,7 @@ where = ["."]
|
||||
include = ["comfy*"]
|
||||
namespaces = false
|
||||
|
||||
[dependency-groups]
|
||||
[project.optional-dependencies]
|
||||
dev = [
|
||||
"pytest",
|
||||
"pytest-asyncio",
|
||||
@ -137,7 +137,6 @@ dev = [
|
||||
"astroid",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
cpu = [
|
||||
"torch",
|
||||
"torchvision",
|
||||
|
||||
Loading…
Reference in New Issue
Block a user