mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-11 06:40:48 +08:00
Merge branch 'master' of github.com:comfyanonymous/ComfyUI
This commit is contained in:
commit
3684cff31b
1
.github/workflows/check-line-endings.yml
vendored
1
.github/workflows/check-line-endings.yml
vendored
@ -17,6 +17,7 @@ jobs:
|
||||
- name: Check for Windows line endings (CRLF)
|
||||
run: |
|
||||
# Get the list of changed files in the PR
|
||||
git merge origin/${{ github.base_ref }} --no-edit
|
||||
CHANGED_FILES=$(git diff --name-only origin/${{ github.base_ref }}..HEAD)
|
||||
|
||||
# Flag to track if CRLF is found
|
||||
|
||||
@ -42,6 +42,7 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
|
||||
- Image Editing Models
|
||||
- [Omnigen 2](https://comfyanonymous.github.io/ComfyUI_examples/omnigen/)
|
||||
- [Flux Kontext](https://comfyanonymous.github.io/ComfyUI_examples/flux/#flux-kontext-image-editing-model)
|
||||
- [HiDream E1.1](https://comfyanonymous.github.io/ComfyUI_examples/hidream/#hidream-e11)
|
||||
- Video Models
|
||||
- [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/)
|
||||
- [Mochi](https://comfyanonymous.github.io/ComfyUI_examples/mochi/)
|
||||
@ -1574,6 +1575,14 @@ If you need to use the legacy frontend for any reason, you can access it using t
|
||||
|
||||
This will use a snapshot of the legacy frontend preserved in the [ComfyUI Legacy Frontend repository](https://github.com/Comfy-Org/ComfyUI_legacy_frontend).
|
||||
|
||||
### Iluvatar Corex
|
||||
|
||||
For models compatible with Iluvatar Extension for PyTorch. Here's a step-by-step guide tailored to your platform and installation method:
|
||||
|
||||
1. Install the Iluvatar Corex Toolkit by adhering to the platform-specific instructions on the [Installation](https://support.iluvatar.com/#/DocumentCentre?id=1&nameCenter=2&productId=520117912052801536)
|
||||
2. Launch ComfyUI by running `python main.py`
|
||||
|
||||
|
||||
## Community
|
||||
|
||||
[Discord](https://comfy.org/discord): Try the #help or #feedback channels.
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
# This file is automatically generated by the build process when version is
|
||||
# updated in pyproject.toml.
|
||||
__version__ = "0.3.44"
|
||||
__version__ = "0.3.45"
|
||||
|
||||
@ -7,6 +7,7 @@ import os
|
||||
import re
|
||||
import tempfile
|
||||
import zipfile
|
||||
import importlib.metadata
|
||||
from dataclasses import dataclass
|
||||
from functools import cached_property
|
||||
from pathlib import Path
|
||||
@ -22,10 +23,12 @@ REQUEST_TIMEOUT = 10 # seconds
|
||||
|
||||
|
||||
def check_frontend_version():
|
||||
"""the thing this does makes no sense, so it got cut"""
|
||||
return None
|
||||
|
||||
|
||||
def frontend_install_warning_message() -> str:
|
||||
"""the end user never needs to be messaged this"""
|
||||
return ""
|
||||
|
||||
|
||||
@ -135,6 +138,15 @@ def download_release_asset_zip(release: Release, destination_path: str) -> None:
|
||||
class FrontendManager:
|
||||
CUSTOM_FRONTENDS_ROOT = add_model_folder_path("web_custom_versions", extensions=set())
|
||||
|
||||
@classmethod
|
||||
def get_required_frontend_version(cls) -> str:
|
||||
"""Get the required frontend package version."""
|
||||
try:
|
||||
# this isn't used the way it says
|
||||
return importlib.metadata.version("comfyui_frontend_package")
|
||||
except Exception as exc_info:
|
||||
return "1.23.4"
|
||||
|
||||
@classmethod
|
||||
def default_frontend_path(cls) -> str:
|
||||
try:
|
||||
|
||||
@ -44,7 +44,8 @@ def _create_parser() -> EnhancedConfigArgParser:
|
||||
help="Automatically launch ComfyUI in the default browser.")
|
||||
parser.add_argument("--disable-auto-launch", action="store_true", help="Disable auto launching the browser.")
|
||||
parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID",
|
||||
help="Set the id of the cuda device this instance will use.")
|
||||
help="Set the id of the cuda device this instance will use. All other devices will not be visible.")
|
||||
parser.add_argument("--default-device", type=int, default=None, metavar="DEFAULT_DEVICE_ID", help="Set the id of the default device, all other devices will stay visible.")
|
||||
cm_group = parser.add_mutually_exclusive_group()
|
||||
cm_group.add_argument("--cuda-malloc", action="store_true",
|
||||
help="Enable cudaMallocAsync (enabled by default for torch 2.0 and up).")
|
||||
|
||||
@ -165,6 +165,7 @@ class Configuration(dict):
|
||||
comfy_api_base (str): Set the base URL for the ComfyUI API. (default: https://api.comfy.org)
|
||||
database_url (str): Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.
|
||||
whitelist_custom_nodes (list[str]): Specify custom node folders to load even when --disable-all-custom-nodes is enabled.
|
||||
default_device (Optional[int]): Set the id of the default device, all other devices will stay visible.
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
@ -281,6 +282,7 @@ class Configuration(dict):
|
||||
self.front_end_root: Optional[str] = None
|
||||
self.comfy_api_base: str = "https://api.comfy.org"
|
||||
self.database_url: str = db_config()
|
||||
self.default_device: Optional[int] = None
|
||||
|
||||
for key, value in kwargs.items():
|
||||
self[key] = value
|
||||
|
||||
@ -82,7 +82,8 @@ if not args.cuda_malloc:
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(module)
|
||||
version = module.__version__
|
||||
if int(version[0]) >= 2: # enable by default for torch version 2.0 and up
|
||||
|
||||
if int(version[0]) >= 2 and "+cu" in version: # enable by default for torch version 2.0 and up only on cuda torch
|
||||
args.cuda_malloc = cuda_malloc_supported()
|
||||
except:
|
||||
pass
|
||||
|
||||
@ -49,6 +49,15 @@ logging.getLogger("alembic.runtime.migration").setLevel(logging.WARNING)
|
||||
|
||||
from ..cli_args import args
|
||||
|
||||
if args.default_device is not None:
|
||||
default_dev = args.default_device
|
||||
devices = list(range(32))
|
||||
devices.remove(default_dev)
|
||||
devices.insert(0, default_dev)
|
||||
devices = ','.join(map(str, devices))
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = str(devices)
|
||||
os.environ['HIP_VISIBLE_DEVICES'] = str(devices)
|
||||
|
||||
if args.cuda_device is not None:
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device)
|
||||
os.environ['HIP_VISIBLE_DEVICES'] = str(args.cuda_device)
|
||||
|
||||
@ -593,6 +593,7 @@ class PromptServer(ExecutorToClientProgress):
|
||||
ram_free = model_management.get_free_memory(cpu_device)
|
||||
vram_total, torch_vram_total = get_total_memory(device, torch_total_too=True)
|
||||
vram_free, torch_vram_free = get_free_memory(device, torch_free_too=True)
|
||||
required_frontend_version = FrontendManager.get_required_frontend_version()
|
||||
|
||||
system_stats = {
|
||||
"system": {
|
||||
@ -600,6 +601,7 @@ class PromptServer(ExecutorToClientProgress):
|
||||
"ram_total": ram_total,
|
||||
"ram_free": ram_free,
|
||||
"comfyui_version": __version__,
|
||||
"required_frontend_version": required_frontend_version,
|
||||
"python_version": sys.version,
|
||||
"pytorch_version": torch_version,
|
||||
"embedded_python": os.path.split(os.path.split(sys.executable)[0])[1] == "python_embeded",
|
||||
|
||||
@ -1244,42 +1244,20 @@ def sample_deis(model, x, sigmas, extra_args=None, callback=None, disable=None,
|
||||
return x_next
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_euler_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None):
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
|
||||
temp = [0]
|
||||
|
||||
def post_cfg_function(args):
|
||||
temp[0] = args["uncond_denoised"]
|
||||
return args["denoised"]
|
||||
|
||||
model_options = extra_args.get("model_options", {}).copy()
|
||||
extra_args["model_options"] = model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
|
||||
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
sigma_hat = sigmas[i]
|
||||
denoised = model(x, sigma_hat * s_in, **extra_args)
|
||||
d = to_d(x, sigma_hat, temp[0])
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
|
||||
# Euler method
|
||||
x = denoised + d * sigmas[i + 1]
|
||||
return x
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
||||
"""Ancestral sampling with Euler method steps."""
|
||||
"""Ancestral sampling with Euler method steps (CFG++)."""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
seed = extra_args.get("seed", None)
|
||||
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
|
||||
|
||||
temp = [0]
|
||||
model_sampling = model.inner_model.model_patcher.get_model_object("model_sampling")
|
||||
lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
|
||||
|
||||
uncond_denoised = None
|
||||
def post_cfg_function(args):
|
||||
temp[0] = args["uncond_denoised"]
|
||||
nonlocal uncond_denoised
|
||||
uncond_denoised = args["uncond_denoised"]
|
||||
return args["denoised"]
|
||||
|
||||
model_options = extra_args.get("model_options", {}).copy()
|
||||
@ -1288,17 +1266,33 @@ def sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=No
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
d = to_d(x, sigmas[i], temp[0])
|
||||
# Euler method
|
||||
x = denoised + d * sigma_down
|
||||
if sigmas[i + 1] > 0:
|
||||
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
|
||||
if sigmas[i + 1] == 0:
|
||||
# Denoising step
|
||||
x = denoised
|
||||
else:
|
||||
alpha_s = sigmas[i] * lambda_fn(sigmas[i]).exp()
|
||||
alpha_t = sigmas[i + 1] * lambda_fn(sigmas[i + 1]).exp()
|
||||
d = to_d(x, sigmas[i], alpha_s * uncond_denoised) # to noise
|
||||
|
||||
# DDIM stochastic sampling
|
||||
sigma_down, sigma_up = get_ancestral_step(sigmas[i] / alpha_s, sigmas[i + 1] / alpha_t, eta=eta)
|
||||
sigma_down = alpha_t * sigma_down
|
||||
|
||||
# Euler method
|
||||
x = alpha_t * denoised + sigma_down * d
|
||||
if eta > 0 and s_noise > 0:
|
||||
x = x + alpha_t * noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
|
||||
return x
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_euler_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None):
|
||||
"""Euler method steps (CFG++)."""
|
||||
return sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=0.0, s_noise=0.0, noise_sampler=None)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_dpmpp_2s_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
||||
"""Ancestral sampling with DPM-Solver++(2S) second-order steps."""
|
||||
|
||||
@ -52,15 +52,6 @@ class RMS_norm(nn.Module):
|
||||
x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma.to(x) + (self.bias.to(x) if self.bias is not None else 0)
|
||||
|
||||
|
||||
class Upsample(nn.Upsample):
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Fix bfloat16 support for nearest neighbor interpolation.
|
||||
"""
|
||||
return super().forward(x.float()).type_as(x)
|
||||
|
||||
|
||||
class Resample(nn.Module):
|
||||
|
||||
def __init__(self, dim, mode):
|
||||
@ -73,11 +64,11 @@ class Resample(nn.Module):
|
||||
# layers
|
||||
if mode == 'upsample2d':
|
||||
self.resample = nn.Sequential(
|
||||
Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
|
||||
nn.Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
|
||||
ops.Conv2d(dim, dim // 2, 3, padding=1))
|
||||
elif mode == 'upsample3d':
|
||||
self.resample = nn.Sequential(
|
||||
Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
|
||||
nn.Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
|
||||
ops.Conv2d(dim, dim // 2, 3, padding=1))
|
||||
self.time_conv = CausalConv3d(
|
||||
dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
|
||||
|
||||
@ -125,7 +125,7 @@ if args.directml is not None:
|
||||
lowvram_available = False # TODO: need to find a way to get free memory in directml before this can be enabled by default.
|
||||
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex # pylint: disable=import-error
|
||||
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, noqa: F401
|
||||
|
||||
_ = torch.xpu.device_count()
|
||||
xpu_available = xpu_available or torch.xpu.is_available()
|
||||
@ -155,6 +155,11 @@ try:
|
||||
except:
|
||||
mlu_available = False
|
||||
|
||||
try:
|
||||
ixuca_available = hasattr(torch, "corex")
|
||||
except:
|
||||
ixuca_available = False
|
||||
|
||||
if args.cpu:
|
||||
cpu_state = CPUState.CPU
|
||||
|
||||
@ -181,6 +186,12 @@ def is_mlu():
|
||||
return True
|
||||
return False
|
||||
|
||||
def is_ixuca():
|
||||
global ixuca_available
|
||||
if ixuca_available:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def get_torch_device():
|
||||
global directml_device
|
||||
@ -220,7 +231,11 @@ def get_total_memory(dev=None, torch_total_too=False):
|
||||
mem_total = 1024 * 1024 * 1024 # TODO
|
||||
mem_total_torch = mem_total
|
||||
elif is_intel_xpu():
|
||||
mem_total = torch.xpu.get_device_properties(dev).total_memory
|
||||
stats = torch.xpu.memory_stats(dev)
|
||||
mem_reserved = stats['reserved_bytes.all.current']
|
||||
mem_total_xpu = torch.xpu.get_device_properties(dev).total_memory
|
||||
mem_total_torch = mem_reserved
|
||||
mem_total = mem_total_xpu
|
||||
elif is_ascend_npu():
|
||||
stats = torch.npu.memory_stats(dev)
|
||||
mem_reserved = stats['reserved_bytes.all.current']
|
||||
@ -334,7 +349,7 @@ try:
|
||||
if torch_version_numeric[0] >= 2:
|
||||
if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
|
||||
ENABLE_PYTORCH_ATTENTION = True
|
||||
if is_intel_xpu() or is_ascend_npu() or is_mlu():
|
||||
if is_intel_xpu() or is_ascend_npu() or is_mlu() or is_ixuca():
|
||||
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
|
||||
ENABLE_PYTORCH_ATTENTION = True
|
||||
except:
|
||||
@ -352,7 +367,10 @@ try:
|
||||
logging.info("ROCm version: {}".format(rocm_version))
|
||||
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
|
||||
if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much
|
||||
if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx1201 and gfx950
|
||||
if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950
|
||||
ENABLE_PYTORCH_ATTENTION = True
|
||||
if torch_version_numeric >= (2, 8):
|
||||
if any((a in arch) for a in ["gfx1201"]):
|
||||
ENABLE_PYTORCH_ATTENTION = True
|
||||
if torch_version_numeric >= (2, 7) and rocm_version >= (6, 4):
|
||||
if any((a in arch) for a in ["gfx1201", "gfx942", "gfx950"]): # TODO: more arches
|
||||
@ -430,6 +448,8 @@ def get_torch_device_name(device):
|
||||
except:
|
||||
allocator_backend = ""
|
||||
return "{} {} : {}".format(device, torch.cuda.get_device_name(device), allocator_backend)
|
||||
elif device.type == "xpu":
|
||||
return "{} {}".format(device, torch.xpu.get_device_name(device))
|
||||
else:
|
||||
return "{}".format(device.type)
|
||||
elif is_intel_xpu():
|
||||
@ -984,6 +1004,7 @@ def vae_dtype(device=None, allowed_dtypes=[]):
|
||||
return d
|
||||
|
||||
# NOTE: bfloat16 seems to work on AMD for the VAE but is extremely slow in some cases compared to fp32
|
||||
# slowness still a problem on pytorch nightly 2.9.0.dev20250720+rocm6.4 tested on RDNA3
|
||||
if d == torch.bfloat16 and (not is_amd()) and should_use_bf16(device):
|
||||
return d
|
||||
|
||||
@ -1044,7 +1065,7 @@ def device_supports_non_blocking(device):
|
||||
if is_device_mps(device):
|
||||
return False # pytorch bug? mps doesn't support non blocking
|
||||
if is_intel_xpu():
|
||||
return False
|
||||
return True
|
||||
if args.deterministic: # TODO: figure out why deterministic breaks non blocking from gpu to cpu (previews)
|
||||
return False
|
||||
if directml_device:
|
||||
@ -1087,6 +1108,8 @@ def get_offload_stream(device):
|
||||
stream_counter = (stream_counter + 1) % len(ss)
|
||||
if is_device_cuda(device):
|
||||
ss[stream_counter].wait_stream(torch.cuda.current_stream())
|
||||
elif is_device_xpu(device):
|
||||
ss[stream_counter].wait_stream(torch.xpu.current_stream())
|
||||
stream_counters[device] = stream_counter
|
||||
return s
|
||||
elif is_device_cuda(device):
|
||||
@ -1098,6 +1121,15 @@ def get_offload_stream(device):
|
||||
stream_counter = (stream_counter + 1) % len(ss)
|
||||
stream_counters[device] = stream_counter
|
||||
return s
|
||||
elif is_device_xpu(device):
|
||||
ss = []
|
||||
for k in range(NUM_STREAMS):
|
||||
ss.append(torch.xpu.Stream(device=device, priority=0))
|
||||
STREAMS[device] = ss
|
||||
s = ss[stream_counter]
|
||||
stream_counter = (stream_counter + 1) % len(ss)
|
||||
stream_counters[device] = stream_counter
|
||||
return s
|
||||
return None
|
||||
|
||||
|
||||
@ -1106,6 +1138,8 @@ def sync_stream(device, stream):
|
||||
return
|
||||
if is_device_cuda(device):
|
||||
torch.cuda.current_stream().wait_stream(stream)
|
||||
elif is_device_xpu(device):
|
||||
torch.xpu.current_stream().wait_stream(stream)
|
||||
|
||||
|
||||
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None):
|
||||
@ -1152,6 +1186,8 @@ def xformers_enabled():
|
||||
return False
|
||||
if is_mlu():
|
||||
return False
|
||||
if is_ixuca():
|
||||
return False
|
||||
if directml_device:
|
||||
return False
|
||||
return XFORMERS_IS_AVAILABLE
|
||||
@ -1202,6 +1238,8 @@ def pytorch_attention_flash_attention():
|
||||
return True
|
||||
if is_amd():
|
||||
return True # if you have pytorch attention enabled on AMD it probably supports at least mem efficient attention
|
||||
if is_ixuca():
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
@ -1234,8 +1272,8 @@ def get_free_memory(dev=None, torch_free_too=False):
|
||||
stats = torch.xpu.memory_stats(dev) # pylint: disable=no-member
|
||||
mem_active = stats['active_bytes.all.current']
|
||||
mem_reserved = stats['reserved_bytes.all.current']
|
||||
mem_free_torch = mem_reserved - mem_active
|
||||
mem_free_xpu = torch.xpu.get_device_properties(dev).total_memory - mem_reserved
|
||||
mem_free_torch = mem_reserved - mem_active
|
||||
mem_free_total = mem_free_xpu + mem_free_torch
|
||||
elif is_ascend_npu():
|
||||
stats = torch.npu.memory_stats(dev)
|
||||
@ -1289,6 +1327,8 @@ def is_device_cpu(device):
|
||||
def is_device_mps(device):
|
||||
return is_device_type(device, 'mps')
|
||||
|
||||
def is_device_xpu(device):
|
||||
return is_device_type(device, 'xpu')
|
||||
|
||||
def is_device_cuda(device):
|
||||
return is_device_type(device, 'cuda')
|
||||
@ -1323,7 +1363,10 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
|
||||
return False
|
||||
|
||||
if is_intel_xpu():
|
||||
return True
|
||||
if torch_version_numeric < (2, 3):
|
||||
return True
|
||||
else:
|
||||
return torch.xpu.get_device_properties(device).has_fp16
|
||||
|
||||
if is_ascend_npu():
|
||||
return True
|
||||
@ -1331,6 +1374,9 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
|
||||
if is_mlu():
|
||||
return True
|
||||
|
||||
if is_ixuca():
|
||||
return True
|
||||
|
||||
if is_amd():
|
||||
return True
|
||||
try:
|
||||
@ -1394,11 +1440,17 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
|
||||
return False
|
||||
|
||||
if is_intel_xpu():
|
||||
return True
|
||||
if torch_version_numeric < (2, 6):
|
||||
return True
|
||||
else:
|
||||
return torch.xpu.get_device_capability(device)['has_bfloat16_conversions']
|
||||
|
||||
if is_ascend_npu():
|
||||
return True
|
||||
|
||||
if is_ixuca():
|
||||
return True
|
||||
|
||||
if is_amd():
|
||||
arch = torch.cuda.get_device_properties(device).gcnArchName
|
||||
if any((a in arch) for a in ["gfx1030", "gfx1031", "gfx1010", "gfx1011", "gfx1012", "gfx906", "gfx900", "gfx803"]): # RDNA2 and older don't support bf16
|
||||
|
||||
@ -37,6 +37,7 @@ from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine
|
||||
from .ldm.wan.vae import WanVAE
|
||||
from .lora_convert import convert_lora
|
||||
from .model_management import load_models_gpu
|
||||
from .model_patcher import ModelPatcher
|
||||
from .t2i_adapter import adapter
|
||||
from .taesd import taesd
|
||||
from .text_encoders import aura_t5
|
||||
@ -71,7 +72,7 @@ def load_lora_for_models(model, clip, _lora, strength_model, strength_clip):
|
||||
_lora = convert_lora(_lora)
|
||||
loaded = lora.load_lora(_lora, key_map)
|
||||
if model is not None:
|
||||
new_modelpatcher = model.clone()
|
||||
new_modelpatcher: ModelPatcher = model.clone()
|
||||
k = new_modelpatcher.add_patches(loaded, strength_model)
|
||||
else:
|
||||
k = ()
|
||||
|
||||
@ -781,6 +781,27 @@ def resize_to_batch_size(tensor, batch_size):
|
||||
return output
|
||||
|
||||
|
||||
def resize_list_to_batch_size(l, batch_size):
|
||||
in_batch_size = len(l)
|
||||
if in_batch_size == batch_size or in_batch_size == 0:
|
||||
return l
|
||||
|
||||
if batch_size <= 1:
|
||||
return l[:batch_size]
|
||||
|
||||
output = []
|
||||
if batch_size < in_batch_size:
|
||||
scale = (in_batch_size - 1) / (batch_size - 1)
|
||||
for i in range(batch_size):
|
||||
output.append(l[min(round(i * scale), in_batch_size - 1)])
|
||||
else:
|
||||
scale = in_batch_size / batch_size
|
||||
for i in range(batch_size):
|
||||
output.append(l[min(math.floor((i + 0.5) * scale), in_batch_size - 1)])
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def convert_sd_to(state_dict, dtype):
|
||||
keys = list(state_dict.keys())
|
||||
for k in keys:
|
||||
|
||||
@ -15,9 +15,20 @@ adapters: list[type[WeightAdapterBase]] = [
|
||||
OFTAdapter,
|
||||
BOFTAdapter,
|
||||
]
|
||||
adapter_maps: dict[str, type[WeightAdapterBase]] = {
|
||||
"LoRA": LoRAAdapter,
|
||||
"LoHa": LoHaAdapter,
|
||||
"LoKr": LoKrAdapter,
|
||||
"OFT": OFTAdapter,
|
||||
## We disable not implemented algo for now
|
||||
# "GLoRA": GLoRAAdapter,
|
||||
# "BOFT": BOFTAdapter,
|
||||
}
|
||||
|
||||
|
||||
__all__ = [
|
||||
"WeightAdapterBase",
|
||||
"WeightAdapterTrainBase",
|
||||
"adapters"
|
||||
"adapters",
|
||||
"adapter_maps",
|
||||
] + [a.__name__ for a in adapters]
|
||||
|
||||
@ -133,3 +133,43 @@ def tucker_weight_from_conv(up, down, mid):
|
||||
def tucker_weight(wa, wb, t):
|
||||
temp = torch.einsum("i j ..., j r -> i r ...", t, wb)
|
||||
return torch.einsum("i j ..., i r -> r j ...", temp, wa)
|
||||
|
||||
|
||||
def factorization(dimension: int, factor: int = -1) -> tuple[int, int]:
|
||||
"""
|
||||
return a tuple of two value of input dimension decomposed by the number closest to factor
|
||||
second value is higher or equal than first value.
|
||||
|
||||
examples)
|
||||
factor
|
||||
-1 2 4 8 16 ...
|
||||
127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127
|
||||
128 -> 8, 16 128 -> 2, 64 128 -> 4, 32 128 -> 8, 16 128 -> 8, 16
|
||||
250 -> 10, 25 250 -> 2, 125 250 -> 2, 125 250 -> 5, 50 250 -> 10, 25
|
||||
360 -> 8, 45 360 -> 2, 180 360 -> 4, 90 360 -> 8, 45 360 -> 12, 30
|
||||
512 -> 16, 32 512 -> 2, 256 512 -> 4, 128 512 -> 8, 64 512 -> 16, 32
|
||||
1024 -> 32, 32 1024 -> 2, 512 1024 -> 4, 256 1024 -> 8, 128 1024 -> 16, 64
|
||||
"""
|
||||
|
||||
if factor > 0 and (dimension % factor) == 0 and dimension >= factor**2:
|
||||
m = factor
|
||||
n = dimension // factor
|
||||
if m > n:
|
||||
n, m = m, n
|
||||
return m, n
|
||||
if factor < 0:
|
||||
factor = dimension
|
||||
m, n = 1, dimension
|
||||
length = m + n
|
||||
while m < n:
|
||||
new_m = m + 1
|
||||
while dimension % new_m != 0:
|
||||
new_m += 1
|
||||
new_n = dimension // new_m
|
||||
if new_m + new_n > length or new_m > factor:
|
||||
break
|
||||
else:
|
||||
m, n = new_m, new_n
|
||||
if m > n:
|
||||
n, m = m, n
|
||||
return m, n
|
||||
|
||||
@ -3,7 +3,120 @@ from typing import Optional
|
||||
|
||||
import torch
|
||||
from ..model_management import cast_to_device
|
||||
from .base import WeightAdapterBase, weight_decompose
|
||||
from .base import WeightAdapterBase, WeightAdapterTrainBase, weight_decompose
|
||||
|
||||
|
||||
class HadaWeight(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, w1u, w1d, w2u, w2d, scale=torch.tensor(1)):
|
||||
ctx.save_for_backward(w1d, w1u, w2d, w2u, scale)
|
||||
diff_weight = ((w1u @ w1d) * (w2u @ w2d)) * scale
|
||||
return diff_weight
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_out):
|
||||
(w1d, w1u, w2d, w2u, scale) = ctx.saved_tensors
|
||||
grad_out = grad_out * scale
|
||||
temp = grad_out * (w2u @ w2d)
|
||||
grad_w1u = temp @ w1d.T
|
||||
grad_w1d = w1u.T @ temp
|
||||
|
||||
temp = grad_out * (w1u @ w1d)
|
||||
grad_w2u = temp @ w2d.T
|
||||
grad_w2d = w2u.T @ temp
|
||||
|
||||
del temp
|
||||
return grad_w1u, grad_w1d, grad_w2u, grad_w2d, None
|
||||
|
||||
|
||||
class HadaWeightTucker(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, t1, w1u, w1d, t2, w2u, w2d, scale=torch.tensor(1)):
|
||||
ctx.save_for_backward(t1, w1d, w1u, t2, w2d, w2u, scale)
|
||||
|
||||
rebuild1 = torch.einsum("i j ..., j r, i p -> p r ...", t1, w1d, w1u)
|
||||
rebuild2 = torch.einsum("i j ..., j r, i p -> p r ...", t2, w2d, w2u)
|
||||
|
||||
return rebuild1 * rebuild2 * scale
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_out):
|
||||
(t1, w1d, w1u, t2, w2d, w2u, scale) = ctx.saved_tensors
|
||||
grad_out = grad_out * scale
|
||||
|
||||
temp = torch.einsum("i j ..., j r -> i r ...", t2, w2d)
|
||||
rebuild = torch.einsum("i j ..., i r -> r j ...", temp, w2u)
|
||||
|
||||
grad_w = rebuild * grad_out
|
||||
del rebuild
|
||||
|
||||
grad_w1u = torch.einsum("r j ..., i j ... -> r i", temp, grad_w)
|
||||
grad_temp = torch.einsum("i j ..., i r -> r j ...", grad_w, w1u.T)
|
||||
del grad_w, temp
|
||||
|
||||
grad_w1d = torch.einsum("i r ..., i j ... -> r j", t1, grad_temp)
|
||||
grad_t1 = torch.einsum("i j ..., j r -> i r ...", grad_temp, w1d.T)
|
||||
del grad_temp
|
||||
|
||||
temp = torch.einsum("i j ..., j r -> i r ...", t1, w1d)
|
||||
rebuild = torch.einsum("i j ..., i r -> r j ...", temp, w1u)
|
||||
|
||||
grad_w = rebuild * grad_out
|
||||
del rebuild
|
||||
|
||||
grad_w2u = torch.einsum("r j ..., i j ... -> r i", temp, grad_w)
|
||||
grad_temp = torch.einsum("i j ..., i r -> r j ...", grad_w, w2u.T)
|
||||
del grad_w, temp
|
||||
|
||||
grad_w2d = torch.einsum("i r ..., i j ... -> r j", t2, grad_temp)
|
||||
grad_t2 = torch.einsum("i j ..., j r -> i r ...", grad_temp, w2d.T)
|
||||
del grad_temp
|
||||
return grad_t1, grad_w1u, grad_w1d, grad_t2, grad_w2u, grad_w2d, None
|
||||
|
||||
|
||||
class LohaDiff(WeightAdapterTrainBase):
|
||||
def __init__(self, weights):
|
||||
super().__init__()
|
||||
# Unpack weights tuple from LoHaAdapter
|
||||
w1a, w1b, alpha, w2a, w2b, t1, t2, _ = weights
|
||||
|
||||
# Create trainable parameters
|
||||
self.hada_w1_a = torch.nn.Parameter(w1a)
|
||||
self.hada_w1_b = torch.nn.Parameter(w1b)
|
||||
self.hada_w2_a = torch.nn.Parameter(w2a)
|
||||
self.hada_w2_b = torch.nn.Parameter(w2b)
|
||||
|
||||
self.use_tucker = False
|
||||
if t1 is not None and t2 is not None:
|
||||
self.use_tucker = True
|
||||
self.hada_t1 = torch.nn.Parameter(t1)
|
||||
self.hada_t2 = torch.nn.Parameter(t2)
|
||||
else:
|
||||
# Keep the attributes for consistent access
|
||||
self.hada_t1 = None
|
||||
self.hada_t2 = None
|
||||
|
||||
# Store rank and non-trainable alpha
|
||||
self.rank = w1b.shape[0]
|
||||
self.alpha = torch.nn.Parameter(torch.tensor(alpha), requires_grad=False)
|
||||
|
||||
def __call__(self, w):
|
||||
org_dtype = w.dtype
|
||||
|
||||
scale = self.alpha / self.rank
|
||||
if self.use_tucker:
|
||||
diff_weight = HadaWeightTucker.apply(self.hada_t1, self.hada_w1_a, self.hada_w1_b, self.hada_t2, self.hada_w2_a, self.hada_w2_b, scale)
|
||||
else:
|
||||
diff_weight = HadaWeight.apply(self.hada_w1_a, self.hada_w1_b, self.hada_w2_a, self.hada_w2_b, scale)
|
||||
|
||||
# Add the scaled difference to the original weight
|
||||
weight = w.to(diff_weight) + diff_weight.reshape(w.shape)
|
||||
|
||||
return weight.to(org_dtype)
|
||||
|
||||
def passive_memory_usage(self):
|
||||
"""Calculates memory usage of the trainable parameters."""
|
||||
return sum(param.numel() * param.element_size() for param in self.parameters())
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -14,6 +127,25 @@ class LoHaAdapter(WeightAdapterBase):
|
||||
self.loaded_keys = loaded_keys
|
||||
self.weights = weights
|
||||
|
||||
@classmethod
|
||||
def create_train(cls, weight, rank=1, alpha=1.0):
|
||||
out_dim = weight.shape[0]
|
||||
in_dim = weight.shape[1:].numel()
|
||||
mat1 = torch.empty(out_dim, rank, device=weight.device, dtype=weight.dtype)
|
||||
mat2 = torch.empty(rank, in_dim, device=weight.device, dtype=weight.dtype)
|
||||
torch.nn.init.normal_(mat1, 0.1)
|
||||
torch.nn.init.constant_(mat2, 0.0)
|
||||
mat3 = torch.empty(out_dim, rank, device=weight.device, dtype=weight.dtype)
|
||||
mat4 = torch.empty(rank, in_dim, device=weight.device, dtype=weight.dtype)
|
||||
torch.nn.init.normal_(mat3, 0.1)
|
||||
torch.nn.init.normal_(mat4, 0.01)
|
||||
return LohaDiff(
|
||||
(mat1, mat2, alpha, mat3, mat4, None, None, None)
|
||||
)
|
||||
|
||||
def to_train(self):
|
||||
return LohaDiff(self.weights)
|
||||
|
||||
@classmethod
|
||||
def load(
|
||||
cls,
|
||||
|
||||
@ -3,7 +3,77 @@ from typing import Optional
|
||||
|
||||
import torch
|
||||
from ..model_management import cast_to_device
|
||||
from .base import WeightAdapterBase, weight_decompose
|
||||
from .base import (
|
||||
WeightAdapterBase,
|
||||
WeightAdapterTrainBase,
|
||||
weight_decompose,
|
||||
factorization,
|
||||
)
|
||||
|
||||
|
||||
class LokrDiff(WeightAdapterTrainBase):
|
||||
def __init__(self, weights):
|
||||
super().__init__()
|
||||
(lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2, dora_scale) = weights
|
||||
self.use_tucker = False
|
||||
if lokr_w1_a is not None:
|
||||
_, rank_a = lokr_w1_a.shape[0], lokr_w1_a.shape[1]
|
||||
rank_a, _ = lokr_w1_b.shape[0], lokr_w1_b.shape[1]
|
||||
self.lokr_w1_a = torch.nn.Parameter(lokr_w1_a)
|
||||
self.lokr_w1_b = torch.nn.Parameter(lokr_w1_b)
|
||||
self.w1_rebuild = True
|
||||
self.ranka = rank_a
|
||||
|
||||
if lokr_w2_a is not None:
|
||||
_, rank_b = lokr_w2_a.shape[0], lokr_w2_a.shape[1]
|
||||
rank_b, _ = lokr_w2_b.shape[0], lokr_w2_b.shape[1]
|
||||
self.lokr_w2_a = torch.nn.Parameter(lokr_w2_a)
|
||||
self.lokr_w2_b = torch.nn.Parameter(lokr_w2_b)
|
||||
if lokr_t2 is not None:
|
||||
self.use_tucker = True
|
||||
self.lokr_t2 = torch.nn.Parameter(lokr_t2)
|
||||
self.w2_rebuild = True
|
||||
self.rankb = rank_b
|
||||
|
||||
if lokr_w1 is not None:
|
||||
self.lokr_w1 = torch.nn.Parameter(lokr_w1)
|
||||
self.w1_rebuild = False
|
||||
|
||||
if lokr_w2 is not None:
|
||||
self.lokr_w2 = torch.nn.Parameter(lokr_w2)
|
||||
self.w2_rebuild = False
|
||||
|
||||
self.alpha = torch.nn.Parameter(torch.tensor(alpha), requires_grad=False)
|
||||
|
||||
@property
|
||||
def w1(self):
|
||||
if self.w1_rebuild:
|
||||
return (self.lokr_w1_a @ self.lokr_w1_b) * (self.alpha / self.ranka)
|
||||
else:
|
||||
return self.lokr_w1
|
||||
|
||||
@property
|
||||
def w2(self):
|
||||
if self.w2_rebuild:
|
||||
if self.use_tucker:
|
||||
w2 = torch.einsum(
|
||||
'i j k l, j r, i p -> p r k l',
|
||||
self.lokr_t2,
|
||||
self.lokr_w2_b,
|
||||
self.lokr_w2_a
|
||||
)
|
||||
else:
|
||||
w2 = self.lokr_w2_a @ self.lokr_w2_b
|
||||
return w2 * (self.alpha / self.rankb)
|
||||
else:
|
||||
return self.lokr_w2
|
||||
|
||||
def __call__(self, w):
|
||||
diff = torch.kron(self.w1, self.w2)
|
||||
return w + diff.reshape(w.shape).to(w)
|
||||
|
||||
def passive_memory_usage(self):
|
||||
return sum(param.numel() * param.element_size() for param in self.parameters())
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -14,6 +84,20 @@ class LoKrAdapter(WeightAdapterBase):
|
||||
self.loaded_keys = loaded_keys
|
||||
self.weights = weights
|
||||
|
||||
@classmethod
|
||||
def create_train(cls, weight, rank=1, alpha=1.0):
|
||||
out_dim = weight.shape[0]
|
||||
in_dim = weight.shape[1:].numel()
|
||||
out1, out2 = factorization(out_dim, rank)
|
||||
in1, in2 = factorization(in_dim, rank)
|
||||
mat1 = torch.empty(out1, in1, device=weight.device, dtype=weight.dtype)
|
||||
mat2 = torch.empty(out2, in2, device=weight.device, dtype=weight.dtype)
|
||||
torch.nn.init.kaiming_uniform_(mat2, a=5**0.5)
|
||||
torch.nn.init.constant_(mat1, 0.0)
|
||||
return LokrDiff(
|
||||
(mat1, mat2, alpha, None, None, None, None, None, None)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def load(
|
||||
cls,
|
||||
|
||||
@ -2,11 +2,64 @@ import logging
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from .base import WeightAdapterBase, WeightAdapterTrainBase, weight_decompose, factorization
|
||||
from ..model_management import cast_to_device
|
||||
from .base import WeightAdapterBase, weight_decompose
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OFTDiff(WeightAdapterTrainBase):
|
||||
def __init__(self, weights):
|
||||
super().__init__()
|
||||
# Unpack weights tuple from LoHaAdapter
|
||||
blocks, rescale, alpha, _ = weights
|
||||
|
||||
# Create trainable parameters
|
||||
self.oft_blocks = torch.nn.Parameter(blocks)
|
||||
if rescale is not None:
|
||||
self.rescale = torch.nn.Parameter(rescale)
|
||||
self.rescaled = True
|
||||
else:
|
||||
self.rescaled = False
|
||||
self.block_num, self.block_size, _ = blocks.shape
|
||||
self.constraint = float(alpha)
|
||||
self.alpha = torch.nn.Parameter(torch.tensor(alpha), requires_grad=False)
|
||||
|
||||
def __call__(self, w):
|
||||
org_dtype = w.dtype
|
||||
I = torch.eye(self.block_size, device=self.oft_blocks.device)
|
||||
|
||||
## generate r
|
||||
# for Q = -Q^T
|
||||
q = self.oft_blocks - self.oft_blocks.transpose(1, 2)
|
||||
normed_q = q
|
||||
if self.constraint:
|
||||
q_norm = torch.norm(q) + 1e-8
|
||||
if q_norm > self.constraint:
|
||||
normed_q = q * self.constraint / q_norm
|
||||
# use float() to prevent unsupported type
|
||||
r = (I + normed_q) @ (I - normed_q).float().inverse()
|
||||
|
||||
## Apply chunked matmul on weight
|
||||
_, *shape = w.shape
|
||||
org_weight = w.to(dtype=r.dtype)
|
||||
org_weight = org_weight.unflatten(0, (self.block_num, self.block_size))
|
||||
# Init R=0, so add I on it to ensure the output of step0 is original model output
|
||||
weight = torch.einsum(
|
||||
"k n m, k n ... -> k m ...",
|
||||
r,
|
||||
org_weight,
|
||||
).flatten(0, 1)
|
||||
if self.rescaled:
|
||||
weight = self.rescale * weight
|
||||
return weight.to(org_dtype)
|
||||
|
||||
def passive_memory_usage(self):
|
||||
"""Calculates memory usage of the trainable parameters."""
|
||||
return sum(param.numel() * param.element_size() for param in self.parameters())
|
||||
|
||||
|
||||
class OFTAdapter(WeightAdapterBase):
|
||||
name = "oft"
|
||||
|
||||
@ -14,14 +67,26 @@ class OFTAdapter(WeightAdapterBase):
|
||||
self.loaded_keys = loaded_keys
|
||||
self.weights = weights
|
||||
|
||||
@classmethod
|
||||
def create_train(cls, weight, rank=1, alpha=1.0):
|
||||
out_dim = weight.shape[0]
|
||||
block_size, block_num = factorization(out_dim, rank)
|
||||
block = torch.zeros(block_num, block_size, block_size, device=weight.device, dtype=weight.dtype)
|
||||
return OFTDiff(
|
||||
(block, None, alpha, None)
|
||||
)
|
||||
|
||||
def to_train(self):
|
||||
return OFTDiff(self.weights)
|
||||
|
||||
@classmethod
|
||||
def load(
|
||||
cls,
|
||||
x: str,
|
||||
lora: dict[str, torch.Tensor],
|
||||
alpha: float,
|
||||
dora_scale: torch.Tensor,
|
||||
loaded_keys: set[str] = None,
|
||||
cls,
|
||||
x: str,
|
||||
lora: dict[str, torch.Tensor],
|
||||
alpha: float,
|
||||
dora_scale: torch.Tensor,
|
||||
loaded_keys: set[str] = None,
|
||||
) -> Optional["OFTAdapter"]:
|
||||
if loaded_keys is None:
|
||||
loaded_keys = set()
|
||||
@ -47,20 +112,22 @@ class OFTAdapter(WeightAdapterBase):
|
||||
return cls(loaded_keys, weights)
|
||||
|
||||
def calculate_weight(
|
||||
self,
|
||||
weight,
|
||||
key,
|
||||
strength,
|
||||
strength_model,
|
||||
offset,
|
||||
function,
|
||||
intermediate_dtype=torch.float32,
|
||||
original_weight=None,
|
||||
self,
|
||||
weight,
|
||||
key,
|
||||
strength,
|
||||
strength_model,
|
||||
offset,
|
||||
function,
|
||||
intermediate_dtype=torch.float32,
|
||||
original_weight=None,
|
||||
):
|
||||
v = self.weights
|
||||
blocks = v[0]
|
||||
rescale = v[1]
|
||||
alpha = v[2]
|
||||
if alpha is None:
|
||||
alpha = 0
|
||||
dora_scale = v[3]
|
||||
|
||||
blocks = cast_to_device(blocks, weight.device, intermediate_dtype)
|
||||
@ -75,7 +142,7 @@ class OFTAdapter(WeightAdapterBase):
|
||||
# for Q = -Q^T
|
||||
q = blocks - blocks.transpose(1, 2)
|
||||
normed_q = q
|
||||
if alpha > 0: # alpha in oft/boft is for constraint
|
||||
if alpha > 0: # alpha in oft/boft is for constraint
|
||||
q_norm = torch.norm(q) + 1e-8
|
||||
if q_norm > alpha:
|
||||
normed_q = q * alpha / q_norm
|
||||
|
||||
@ -1,11 +1,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import av
|
||||
import hashlib
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
|
||||
import av
|
||||
import torch
|
||||
|
||||
import comfy.model_management
|
||||
@ -112,7 +113,6 @@ def save_audio(self, audio, filename_prefix="ComfyUI", format="flac", prompt=Non
|
||||
except ImportError as exc_info:
|
||||
raise TorchAudioNotFoundError()
|
||||
|
||||
|
||||
filename_prefix += self.prefix_append
|
||||
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
|
||||
results: list[FileLocator] = []
|
||||
@ -178,13 +178,13 @@ def save_audio(self, audio, filename_prefix="ComfyUI", format="flac", prompt=Non
|
||||
elif format == "mp3":
|
||||
out_stream = output_container.add_stream("libmp3lame", rate=sample_rate)
|
||||
if quality == "V0":
|
||||
#TODO i would really love to support V3 and V5 but there doesn't seem to be a way to set the qscale level, the property below is a bool
|
||||
# TODO i would really love to support V3 and V5 but there doesn't seem to be a way to set the qscale level, the property below is a bool
|
||||
out_stream.codec_context.qscale = 1
|
||||
elif quality == "128k":
|
||||
out_stream.bit_rate = 128000
|
||||
elif quality == "320k":
|
||||
out_stream.bit_rate = 320000
|
||||
else: #format == "flac":
|
||||
else: # format == "flac":
|
||||
out_stream = output_container.add_stream("flac", rate=sample_rate)
|
||||
|
||||
frame = av.AudioFrame.from_ndarray(waveform.movedim(0, 1).reshape(1, -1).float().numpy(), format='flt', layout='mono' if waveform.shape[0] == 1 else 'stereo')
|
||||
@ -212,6 +212,7 @@ def save_audio(self, audio, filename_prefix="ComfyUI", format="flac", prompt=Non
|
||||
|
||||
return {"ui": {"audio": results}}
|
||||
|
||||
|
||||
class SaveAudio:
|
||||
def __init__(self):
|
||||
self.output_dir = folder_paths.get_output_directory()
|
||||
@ -220,9 +221,9 @@ class SaveAudio:
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "audio": ("AUDIO", ),
|
||||
"filename_prefix": ("STRING", {"default": "audio/ComfyUI"}),
|
||||
},
|
||||
return {"required": {"audio": ("AUDIO",),
|
||||
"filename_prefix": ("STRING", {"default": "audio/ComfyUI"}),
|
||||
},
|
||||
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
|
||||
}
|
||||
|
||||
@ -236,6 +237,7 @@ class SaveAudio:
|
||||
def save_flac(self, audio, filename_prefix="ComfyUI", format="flac", prompt=None, extra_pnginfo=None):
|
||||
return save_audio(self, audio, filename_prefix, format, prompt, extra_pnginfo)
|
||||
|
||||
|
||||
class SaveAudioMP3:
|
||||
def __init__(self):
|
||||
self.output_dir = folder_paths.get_output_directory()
|
||||
@ -244,10 +246,10 @@ class SaveAudioMP3:
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "audio": ("AUDIO", ),
|
||||
"filename_prefix": ("STRING", {"default": "audio/ComfyUI"}),
|
||||
"quality": (["V0", "128k", "320k"], {"default": "V0"}),
|
||||
},
|
||||
return {"required": {"audio": ("AUDIO",),
|
||||
"filename_prefix": ("STRING", {"default": "audio/ComfyUI"}),
|
||||
"quality": (["V0", "128k", "320k"], {"default": "V0"}),
|
||||
},
|
||||
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
|
||||
}
|
||||
|
||||
@ -261,6 +263,7 @@ class SaveAudioMP3:
|
||||
def save_mp3(self, audio, filename_prefix="ComfyUI", format="mp3", prompt=None, extra_pnginfo=None, quality="128k"):
|
||||
return save_audio(self, audio, filename_prefix, format, prompt, extra_pnginfo, quality)
|
||||
|
||||
|
||||
class SaveAudioOpus:
|
||||
def __init__(self):
|
||||
self.output_dir = folder_paths.get_output_directory()
|
||||
@ -269,10 +272,10 @@ class SaveAudioOpus:
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "audio": ("AUDIO", ),
|
||||
"filename_prefix": ("STRING", {"default": "audio/ComfyUI"}),
|
||||
"quality": (["64k", "96k", "128k", "192k", "320k"], {"default": "128k"}),
|
||||
},
|
||||
return {"required": {"audio": ("AUDIO",),
|
||||
"filename_prefix": ("STRING", {"default": "audio/ComfyUI"}),
|
||||
"quality": (["64k", "96k", "128k", "192k", "320k"], {"default": "128k"}),
|
||||
},
|
||||
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
|
||||
}
|
||||
|
||||
@ -286,6 +289,7 @@ class SaveAudioOpus:
|
||||
def save_opus(self, audio, filename_prefix="ComfyUI", format="opus", prompt=None, extra_pnginfo=None, quality="V3"):
|
||||
return save_audio(self, audio, filename_prefix, format, prompt, extra_pnginfo, quality)
|
||||
|
||||
|
||||
class PreviewAudio(SaveAudio):
|
||||
def __init__(self):
|
||||
self.output_dir = folder_paths.get_temp_directory()
|
||||
@ -300,6 +304,44 @@ class PreviewAudio(SaveAudio):
|
||||
}
|
||||
|
||||
|
||||
def f32_pcm(wav: torch.Tensor) -> torch.Tensor:
|
||||
"""Convert audio to float 32 bits PCM format."""
|
||||
if wav.dtype.is_floating_point:
|
||||
return wav
|
||||
elif wav.dtype == torch.int16:
|
||||
return wav.float() / (2 ** 15)
|
||||
elif wav.dtype == torch.int32:
|
||||
return wav.float() / (2 ** 31)
|
||||
raise ValueError(f"Unsupported wav dtype: {wav.dtype}")
|
||||
|
||||
|
||||
def load(filepath: str) -> tuple[torch.Tensor, int]:
|
||||
with av.open(filepath) as af:
|
||||
if not af.streams.audio:
|
||||
raise ValueError("No audio stream found in the file.")
|
||||
|
||||
stream = af.streams.audio[0]
|
||||
sr = stream.codec_context.sample_rate
|
||||
n_channels = stream.channels
|
||||
|
||||
frames = []
|
||||
length = 0
|
||||
for frame in af.decode(streams=stream.index):
|
||||
buf = torch.from_numpy(frame.to_ndarray())
|
||||
if buf.shape[0] != n_channels:
|
||||
buf = buf.view(-1, n_channels).t()
|
||||
|
||||
frames.append(buf)
|
||||
length += buf.shape[1]
|
||||
|
||||
if not frames:
|
||||
raise ValueError("No audio frames decoded.")
|
||||
|
||||
wav = torch.cat(frames, dim=1)
|
||||
wav = f32_pcm(wav)
|
||||
return wav, sr
|
||||
|
||||
|
||||
class LoadAudio:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
@ -319,7 +361,7 @@ class LoadAudio:
|
||||
raise TorchAudioNotFoundError()
|
||||
|
||||
audio_path = folder_paths.get_annotated_filepath(audio)
|
||||
waveform, sample_rate = torchaudio.load(audio_path)
|
||||
waveform, sample_rate = load(audio_path)
|
||||
audio = {"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate}
|
||||
return (audio,)
|
||||
|
||||
|
||||
@ -306,6 +306,35 @@ class ExtendIntermediateSigmas:
|
||||
|
||||
return (extended_sigmas,)
|
||||
|
||||
|
||||
class SamplingPercentToSigma:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypeDict:
|
||||
return {
|
||||
"required": {
|
||||
"model": (IO.MODEL, {}),
|
||||
"sampling_percent": (IO.FLOAT, {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.0001}),
|
||||
"return_actual_sigma": (IO.BOOLEAN, {"default": False, "tooltip": "Return the actual sigma value instead of the value used for interval checks.\nThis only affects results at 0.0 and 1.0."}),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.FLOAT,)
|
||||
RETURN_NAMES = ("sigma_value",)
|
||||
CATEGORY = "sampling/custom_sampling/sigmas"
|
||||
|
||||
FUNCTION = "get_sigma"
|
||||
|
||||
def get_sigma(self, model, sampling_percent, return_actual_sigma):
|
||||
model_sampling = model.get_model_object("model_sampling")
|
||||
sigma_val = model_sampling.percent_to_sigma(sampling_percent)
|
||||
if return_actual_sigma:
|
||||
if sampling_percent == 0.0:
|
||||
sigma_val = model_sampling.sigma_max.item()
|
||||
elif sampling_percent == 1.0:
|
||||
sigma_val = model_sampling.sigma_min.item()
|
||||
return (sigma_val,)
|
||||
|
||||
|
||||
class KSamplerSelect:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
@ -688,9 +717,10 @@ class CFGGuider:
|
||||
return (guider,)
|
||||
|
||||
class Guider_DualCFG(comfy.samplers.CFGGuider):
|
||||
def set_cfg(self, cfg1, cfg2):
|
||||
def set_cfg(self, cfg1, cfg2, nested=False):
|
||||
self.cfg1 = cfg1
|
||||
self.cfg2 = cfg2
|
||||
self.nested = nested
|
||||
|
||||
def set_conds(self, positive, middle, negative):
|
||||
middle = node_helpers.conditioning_set_values(middle, {"prompt_type": "negative"})
|
||||
@ -700,14 +730,20 @@ class Guider_DualCFG(comfy.samplers.CFGGuider):
|
||||
negative_cond = self.conds.get("negative", None)
|
||||
middle_cond = self.conds.get("middle", None)
|
||||
positive_cond = self.conds.get("positive", None)
|
||||
if model_options.get("disable_cfg1_optimization", False) == False:
|
||||
if math.isclose(self.cfg2, 1.0):
|
||||
negative_cond = None
|
||||
if math.isclose(self.cfg1, 1.0):
|
||||
middle_cond = None
|
||||
|
||||
out = comfy.samplers.calc_cond_batch(self.inner_model, [negative_cond, middle_cond, positive_cond], x, timestep, model_options)
|
||||
return comfy.samplers.cfg_function(self.inner_model, out[1], out[0], self.cfg2, x, timestep, model_options=model_options, cond=middle_cond, uncond=negative_cond) + (out[2] - out[1]) * self.cfg1
|
||||
if self.nested:
|
||||
out = comfy.samplers.calc_cond_batch(self.inner_model, [negative_cond, middle_cond, positive_cond], x, timestep, model_options)
|
||||
pred_text = comfy.samplers.cfg_function(self.inner_model, out[2], out[1], self.cfg1, x, timestep, model_options=model_options, cond=positive_cond, uncond=middle_cond)
|
||||
return out[0] + self.cfg2 * (pred_text - out[0])
|
||||
else:
|
||||
if model_options.get("disable_cfg1_optimization", False) == False:
|
||||
if math.isclose(self.cfg2, 1.0):
|
||||
negative_cond = None
|
||||
if math.isclose(self.cfg1, 1.0):
|
||||
middle_cond = None
|
||||
|
||||
out = comfy.samplers.calc_cond_batch(self.inner_model, [negative_cond, middle_cond, positive_cond], x, timestep, model_options)
|
||||
return comfy.samplers.cfg_function(self.inner_model, out[1], out[0], self.cfg2, x, timestep, model_options=model_options, cond=middle_cond, uncond=negative_cond) + (out[2] - out[1]) * self.cfg1
|
||||
|
||||
class DualCFGGuider:
|
||||
@classmethod
|
||||
@ -719,6 +755,7 @@ class DualCFGGuider:
|
||||
"negative": ("CONDITIONING", ),
|
||||
"cfg_conds": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}),
|
||||
"cfg_cond2_negative": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}),
|
||||
"style": (["regular", "nested"],),
|
||||
}
|
||||
}
|
||||
|
||||
@ -727,10 +764,10 @@ class DualCFGGuider:
|
||||
FUNCTION = "get_guider"
|
||||
CATEGORY = "sampling/custom_sampling/guiders"
|
||||
|
||||
def get_guider(self, model, cond1, cond2, negative, cfg_conds, cfg_cond2_negative):
|
||||
def get_guider(self, model, cond1, cond2, negative, cfg_conds, cfg_cond2_negative, style):
|
||||
guider = Guider_DualCFG(model)
|
||||
guider.set_conds(cond1, cond2, negative)
|
||||
guider.set_cfg(cfg_conds, cfg_cond2_negative)
|
||||
guider.set_cfg(cfg_conds, cfg_cond2_negative, nested=(style == "nested"))
|
||||
return (guider,)
|
||||
|
||||
class DisableNoise:
|
||||
@ -878,6 +915,7 @@ NODE_CLASS_MAPPINGS = {
|
||||
"FlipSigmas": FlipSigmas,
|
||||
"SetFirstSigma": SetFirstSigma,
|
||||
"ExtendIntermediateSigmas": ExtendIntermediateSigmas,
|
||||
"SamplingPercentToSigma": SamplingPercentToSigma,
|
||||
|
||||
"CFGGuider": CFGGuider,
|
||||
"DualCFGGuider": DualCFGGuider,
|
||||
|
||||
@ -20,7 +20,7 @@ from comfy import node_helpers
|
||||
from comfy.cli_args import args
|
||||
from comfy.comfy_types.node_typing import IO
|
||||
from comfy.execution_context import current_execution_context
|
||||
from comfy.weight_adapter import adapters
|
||||
from comfy.weight_adapter import adapters, adapter_maps
|
||||
from . import nodes_custom_sampler
|
||||
from .nodes_custom_sampler import Noise_RandomNoise
|
||||
|
||||
@ -41,13 +41,13 @@ def make_batch_extra_option_dict(d, indicies, full_size=None):
|
||||
|
||||
|
||||
class TrainSampler(comfy.samplers.Sampler):
|
||||
|
||||
def __init__(self, loss_fn, optimizer, loss_callback=None, batch_size=1, total_steps=1, seed=0, training_dtype=torch.bfloat16):
|
||||
def __init__(self, loss_fn, optimizer, loss_callback=None, batch_size=1, grad_acc=1, total_steps=1, seed=0, training_dtype=torch.bfloat16):
|
||||
self.loss_fn = loss_fn
|
||||
self.optimizer = optimizer
|
||||
self.loss_callback = loss_callback
|
||||
self.batch_size = batch_size
|
||||
self.total_steps = total_steps
|
||||
self.grad_acc = grad_acc
|
||||
self.seed = seed
|
||||
self.training_dtype = training_dtype
|
||||
|
||||
@ -94,8 +94,9 @@ class TrainSampler(comfy.samplers.Sampler):
|
||||
self.loss_callback(loss.item())
|
||||
pbar.set_postfix({"loss": f"{loss.item():.4f}"})
|
||||
|
||||
self.optimizer.step()
|
||||
self.optimizer.zero_grad()
|
||||
if (i+1) % self.grad_acc == 0:
|
||||
self.optimizer.step()
|
||||
self.optimizer.zero_grad()
|
||||
torch.cuda.empty_cache()
|
||||
return torch.zeros_like(latent_image)
|
||||
|
||||
@ -421,6 +422,16 @@ class TrainLoraNode:
|
||||
"tooltip": "The batch size to use for training.",
|
||||
},
|
||||
),
|
||||
"grad_accumulation_steps": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 1,
|
||||
"min": 1,
|
||||
"max": 1024,
|
||||
"step": 1,
|
||||
"tooltip": "The number of gradient accumulation steps to use for training.",
|
||||
}
|
||||
),
|
||||
"steps": (
|
||||
IO.INT,
|
||||
{
|
||||
@ -480,6 +491,17 @@ class TrainLoraNode:
|
||||
["bf16", "fp32"],
|
||||
{"default": "bf16", "tooltip": "The dtype to use for lora."},
|
||||
),
|
||||
"algorithm": (
|
||||
list(adapter_maps.keys()),
|
||||
{"default": list(adapter_maps.keys())[0], "tooltip": "The algorithm to use for training."},
|
||||
),
|
||||
"gradient_checkpointing": (
|
||||
IO.BOOLEAN,
|
||||
{
|
||||
"default": True,
|
||||
"tooltip": "Use gradient checkpointing for training.",
|
||||
}
|
||||
),
|
||||
"existing_lora": (
|
||||
folder_paths.get_filename_list("loras") + ["[None]"],
|
||||
{
|
||||
@ -503,6 +525,7 @@ class TrainLoraNode:
|
||||
positive,
|
||||
batch_size,
|
||||
steps,
|
||||
grad_accumulation_steps,
|
||||
learning_rate,
|
||||
rank,
|
||||
optimizer,
|
||||
@ -510,6 +533,8 @@ class TrainLoraNode:
|
||||
seed,
|
||||
training_dtype,
|
||||
lora_dtype,
|
||||
algorithm,
|
||||
gradient_checkpointing,
|
||||
existing_lora,
|
||||
):
|
||||
mp = model.clone()
|
||||
@ -560,10 +585,8 @@ class TrainLoraNode:
|
||||
if existing_adapter is not None:
|
||||
break
|
||||
else:
|
||||
# If no existing adapter found, use LoRA
|
||||
# We will add algo option in the future
|
||||
existing_adapter = None
|
||||
adapter_cls = adapters[0]
|
||||
adapter_cls = adapter_maps[algorithm]
|
||||
|
||||
if existing_adapter is not None:
|
||||
train_adapter = existing_adapter.to_train().to(lora_dtype)
|
||||
@ -619,8 +642,9 @@ class TrainLoraNode:
|
||||
criterion = None
|
||||
|
||||
# setup models
|
||||
for m in find_all_highest_child_module_with_forward(mp.model.diffusion_model):
|
||||
patch(m)
|
||||
if gradient_checkpointing:
|
||||
for m in find_all_highest_child_module_with_forward(mp.model.diffusion_model):
|
||||
patch(m)
|
||||
mp.model.requires_grad_(False)
|
||||
comfy.model_management.load_models_gpu([mp], memory_required=1e20, force_full_load=True)
|
||||
|
||||
@ -633,7 +657,8 @@ class TrainLoraNode:
|
||||
optimizer,
|
||||
loss_callback=loss_callback,
|
||||
batch_size=batch_size,
|
||||
total_steps=steps,
|
||||
grad_acc=grad_accumulation_steps,
|
||||
total_steps=steps*grad_accumulation_steps,
|
||||
seed=seed,
|
||||
training_dtype=dtype
|
||||
)
|
||||
|
||||
@ -1,26 +1,33 @@
|
||||
from comfy.nodes import base_nodes as nodes
|
||||
from comfy import node_helpers
|
||||
import json
|
||||
import math
|
||||
from typing import Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
import comfy.clip_vision
|
||||
import comfy.latent_formats
|
||||
import comfy.model_management
|
||||
import comfy.utils
|
||||
import comfy.latent_formats
|
||||
import comfy.clip_vision
|
||||
import nodes
|
||||
from comfy import node_helpers
|
||||
from comfy.nodes import base_nodes as nodes
|
||||
|
||||
|
||||
class WanImageToVideo:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"positive": ("CONDITIONING", ),
|
||||
"negative": ("CONDITIONING", ),
|
||||
"vae": ("VAE", ),
|
||||
return {"required": {"positive": ("CONDITIONING",),
|
||||
"negative": ("CONDITIONING",),
|
||||
"vae": ("VAE",),
|
||||
"width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
||||
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
||||
"length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
|
||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
|
||||
},
|
||||
"optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT", ),
|
||||
"start_image": ("IMAGE", ),
|
||||
}}
|
||||
},
|
||||
"optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT",),
|
||||
"start_image": ("IMAGE",),
|
||||
}}
|
||||
|
||||
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
|
||||
RETURN_NAMES = ("positive", "negative", "latent")
|
||||
@ -54,18 +61,18 @@ class WanImageToVideo:
|
||||
class WanFunControlToVideo:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"positive": ("CONDITIONING", ),
|
||||
"negative": ("CONDITIONING", ),
|
||||
"vae": ("VAE", ),
|
||||
return {"required": {"positive": ("CONDITIONING",),
|
||||
"negative": ("CONDITIONING",),
|
||||
"vae": ("VAE",),
|
||||
"width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
||||
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
||||
"length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
|
||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
|
||||
},
|
||||
"optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT", ),
|
||||
"start_image": ("IMAGE", ),
|
||||
"control_video": ("IMAGE", ),
|
||||
}}
|
||||
},
|
||||
"optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT",),
|
||||
"start_image": ("IMAGE",),
|
||||
"control_video": ("IMAGE",),
|
||||
}}
|
||||
|
||||
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
|
||||
RETURN_NAMES = ("positive", "negative", "latent")
|
||||
@ -82,12 +89,12 @@ class WanFunControlToVideo:
|
||||
if start_image is not None:
|
||||
start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
||||
concat_latent_image = vae.encode(start_image[:, :, :, :3])
|
||||
concat_latent[:,16:,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]]
|
||||
concat_latent[:, 16:, :concat_latent_image.shape[2]] = concat_latent_image[:, :, :concat_latent.shape[2]]
|
||||
|
||||
if control_video is not None:
|
||||
control_video = comfy.utils.common_upscale(control_video[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
||||
concat_latent_image = vae.encode(control_video[:, :, :, :3])
|
||||
concat_latent[:,:16,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]]
|
||||
concat_latent[:, :16, :concat_latent_image.shape[2]] = concat_latent_image[:, :, :concat_latent.shape[2]]
|
||||
|
||||
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent})
|
||||
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent})
|
||||
@ -100,22 +107,23 @@ class WanFunControlToVideo:
|
||||
out_latent["samples"] = latent
|
||||
return (positive, negative, out_latent)
|
||||
|
||||
|
||||
class WanFirstLastFrameToVideo:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"positive": ("CONDITIONING", ),
|
||||
"negative": ("CONDITIONING", ),
|
||||
"vae": ("VAE", ),
|
||||
return {"required": {"positive": ("CONDITIONING",),
|
||||
"negative": ("CONDITIONING",),
|
||||
"vae": ("VAE",),
|
||||
"width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
||||
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
||||
"length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
|
||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
|
||||
},
|
||||
"optional": {"clip_vision_start_image": ("CLIP_VISION_OUTPUT", ),
|
||||
"clip_vision_end_image": ("CLIP_VISION_OUTPUT", ),
|
||||
"start_image": ("IMAGE", ),
|
||||
"end_image": ("IMAGE", ),
|
||||
}}
|
||||
},
|
||||
"optional": {"clip_vision_start_image": ("CLIP_VISION_OUTPUT",),
|
||||
"clip_vision_end_image": ("CLIP_VISION_OUTPUT",),
|
||||
"start_image": ("IMAGE",),
|
||||
"end_image": ("IMAGE",),
|
||||
}}
|
||||
|
||||
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
|
||||
RETURN_NAMES = ("positive", "negative", "latent")
|
||||
@ -169,18 +177,18 @@ class WanFirstLastFrameToVideo:
|
||||
class WanFunInpaintToVideo:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"positive": ("CONDITIONING", ),
|
||||
"negative": ("CONDITIONING", ),
|
||||
"vae": ("VAE", ),
|
||||
return {"required": {"positive": ("CONDITIONING",),
|
||||
"negative": ("CONDITIONING",),
|
||||
"vae": ("VAE",),
|
||||
"width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
||||
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
||||
"length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
|
||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
|
||||
},
|
||||
"optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT", ),
|
||||
"start_image": ("IMAGE", ),
|
||||
"end_image": ("IMAGE", ),
|
||||
}}
|
||||
},
|
||||
"optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT",),
|
||||
"start_image": ("IMAGE",),
|
||||
"end_image": ("IMAGE",),
|
||||
}}
|
||||
|
||||
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
|
||||
RETURN_NAMES = ("positive", "negative", "latent")
|
||||
@ -196,19 +204,19 @@ class WanFunInpaintToVideo:
|
||||
class WanVaceToVideo:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"positive": ("CONDITIONING", ),
|
||||
"negative": ("CONDITIONING", ),
|
||||
"vae": ("VAE", ),
|
||||
return {"required": {"positive": ("CONDITIONING",),
|
||||
"negative": ("CONDITIONING",),
|
||||
"vae": ("VAE",),
|
||||
"width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
||||
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
||||
"length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
|
||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
|
||||
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1000.0, "step": 0.01}),
|
||||
},
|
||||
"optional": {"control_video": ("IMAGE", ),
|
||||
"control_masks": ("MASK", ),
|
||||
"reference_image": ("IMAGE", ),
|
||||
}}
|
||||
},
|
||||
"optional": {"control_video": ("IMAGE",),
|
||||
"control_masks": ("MASK",),
|
||||
"reference_image": ("IMAGE",),
|
||||
}}
|
||||
|
||||
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT", "INT")
|
||||
RETURN_NAMES = ("positive", "negative", "latent", "trim_latent")
|
||||
@ -277,11 +285,12 @@ class WanVaceToVideo:
|
||||
out_latent["samples"] = latent
|
||||
return (positive, negative, out_latent, trim_latent)
|
||||
|
||||
|
||||
class TrimVideoLatent:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "samples": ("LATENT",),
|
||||
"trim_amount": ("INT", {"default": 0, "min": 0, "max": 99999}),
|
||||
return {"required": {"samples": ("LATENT",),
|
||||
"trim_amount": ("INT", {"default": 0, "min": 0, "max": 99999}),
|
||||
}}
|
||||
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
@ -298,21 +307,22 @@ class TrimVideoLatent:
|
||||
samples_out["samples"] = s1[:, :, trim_amount:]
|
||||
return (samples_out,)
|
||||
|
||||
|
||||
class WanCameraImageToVideo:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"positive": ("CONDITIONING", ),
|
||||
"negative": ("CONDITIONING", ),
|
||||
"vae": ("VAE", ),
|
||||
return {"required": {"positive": ("CONDITIONING",),
|
||||
"negative": ("CONDITIONING",),
|
||||
"vae": ("VAE",),
|
||||
"width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
||||
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
||||
"length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
|
||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
|
||||
},
|
||||
"optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT", ),
|
||||
"start_image": ("IMAGE", ),
|
||||
"camera_conditions": ("WAN_CAMERA_EMBEDDING", ),
|
||||
}}
|
||||
},
|
||||
"optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT",),
|
||||
"start_image": ("IMAGE",),
|
||||
"camera_conditions": ("WAN_CAMERA_EMBEDDING",),
|
||||
}}
|
||||
|
||||
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
|
||||
RETURN_NAMES = ("positive", "negative", "latent")
|
||||
@ -328,7 +338,7 @@ class WanCameraImageToVideo:
|
||||
if start_image is not None:
|
||||
start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
||||
concat_latent_image = vae.encode(start_image[:, :, :, :3])
|
||||
concat_latent[:,:,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]]
|
||||
concat_latent[:, :, :concat_latent_image.shape[2]] = concat_latent_image[:, :, :concat_latent.shape[2]]
|
||||
|
||||
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent})
|
||||
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent})
|
||||
@ -345,19 +355,20 @@ class WanCameraImageToVideo:
|
||||
out_latent["samples"] = latent
|
||||
return (positive, negative, out_latent)
|
||||
|
||||
|
||||
class WanPhantomSubjectToVideo:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"positive": ("CONDITIONING", ),
|
||||
"negative": ("CONDITIONING", ),
|
||||
"vae": ("VAE", ),
|
||||
return {"required": {"positive": ("CONDITIONING",),
|
||||
"negative": ("CONDITIONING",),
|
||||
"vae": ("VAE",),
|
||||
"width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
||||
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
||||
"length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
|
||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
|
||||
},
|
||||
"optional": {"images": ("IMAGE", ),
|
||||
}}
|
||||
},
|
||||
"optional": {"images": ("IMAGE",),
|
||||
}}
|
||||
|
||||
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "CONDITIONING", "LATENT")
|
||||
RETURN_NAMES = ("positive", "negative_text", "negative_img_text", "latent")
|
||||
@ -383,7 +394,316 @@ class WanPhantomSubjectToVideo:
|
||||
out_latent["samples"] = latent
|
||||
return (positive, cond2, negative, out_latent)
|
||||
|
||||
|
||||
def parse_json_tracks(tracks):
|
||||
"""Parse JSON track data into a standardized format"""
|
||||
tracks_data = []
|
||||
try:
|
||||
# If tracks is a string, try to parse it as JSON
|
||||
if isinstance(tracks, str):
|
||||
parsed = json.loads(tracks.replace("'", '"'))
|
||||
tracks_data.extend(parsed)
|
||||
else:
|
||||
# If tracks is a list of strings, parse each one
|
||||
for track_str in tracks:
|
||||
parsed = json.loads(track_str.replace("'", '"'))
|
||||
tracks_data.append(parsed)
|
||||
|
||||
# Check if we have a single track (dict with x,y) or a list of tracks
|
||||
if tracks_data and isinstance(tracks_data[0], dict) and 'x' in tracks_data[0]:
|
||||
# Single track detected, wrap it in a list
|
||||
tracks_data = [tracks_data]
|
||||
elif tracks_data and isinstance(tracks_data[0], list) and tracks_data[0] and isinstance(tracks_data[0][0], dict) and 'x' in tracks_data[0][0]:
|
||||
# Already a list of tracks, nothing to do
|
||||
pass
|
||||
else:
|
||||
# Unexpected format
|
||||
pass
|
||||
|
||||
except json.JSONDecodeError:
|
||||
tracks_data = []
|
||||
return tracks_data
|
||||
|
||||
|
||||
def process_tracks(tracks_np: np.ndarray, frame_size: Tuple[int, int], num_frames, quant_multi: int = 8, **kwargs):
|
||||
# tracks: shape [t, h, w, 3] => samples align with 24 fps, model trained with 16 fps.
|
||||
# frame_size: tuple (W, H)
|
||||
tracks = torch.from_numpy(tracks_np).float()
|
||||
|
||||
if tracks.shape[1] == 121:
|
||||
tracks = torch.permute(tracks, (1, 0, 2, 3))
|
||||
|
||||
tracks, visibles = tracks[..., :2], tracks[..., 2:3]
|
||||
|
||||
short_edge = min(*frame_size)
|
||||
|
||||
frame_center = torch.tensor([*frame_size]).type_as(tracks) / 2
|
||||
tracks = tracks - frame_center
|
||||
|
||||
tracks = tracks / short_edge * 2
|
||||
|
||||
visibles = visibles * 2 - 1
|
||||
|
||||
trange = torch.linspace(-1, 1, tracks.shape[0]).view(-1, 1, 1, 1).expand(*visibles.shape)
|
||||
|
||||
out_ = torch.cat([trange, tracks, visibles], dim=-1).view(121, -1, 4)
|
||||
|
||||
out_0 = out_[:1]
|
||||
|
||||
out_l = out_[1:] # 121 => 120 | 1
|
||||
a = 120 // math.gcd(120, num_frames)
|
||||
b = num_frames // math.gcd(120, num_frames)
|
||||
out_l = torch.repeat_interleave(out_l, b, dim=0)[1::a] # 120 => 120 * b => 120 * b / a == F
|
||||
|
||||
final_result = torch.cat([out_0, out_l], dim=0)
|
||||
|
||||
return final_result
|
||||
|
||||
|
||||
FIXED_LENGTH = 121
|
||||
|
||||
|
||||
def pad_pts(tr):
|
||||
"""Convert list of {x,y} to (FIXED_LENGTH,1,3) array, padding/truncating."""
|
||||
pts = np.array([[p['x'], p['y'], 1] for p in tr], dtype=np.float32)
|
||||
n = pts.shape[0]
|
||||
if n < FIXED_LENGTH:
|
||||
pad = np.zeros((FIXED_LENGTH - n, 3), dtype=np.float32)
|
||||
pts = np.vstack((pts, pad))
|
||||
else:
|
||||
pts = pts[:FIXED_LENGTH]
|
||||
return pts.reshape(FIXED_LENGTH, 1, 3)
|
||||
|
||||
|
||||
def ind_sel(target: torch.Tensor, ind: torch.Tensor, dim: int = 1):
|
||||
"""Index selection utility function"""
|
||||
assert (
|
||||
len(ind.shape) > dim
|
||||
), "Index must have the target dim, but get dim: %d, ind shape: %s" % (dim, str(ind.shape))
|
||||
|
||||
target = target.expand(
|
||||
*tuple(
|
||||
[ind.shape[k] if target.shape[k] == 1 else -1 for k in range(dim)]
|
||||
+ [
|
||||
-1,
|
||||
]
|
||||
* (len(target.shape) - dim)
|
||||
)
|
||||
)
|
||||
|
||||
ind_pad = ind
|
||||
|
||||
if len(target.shape) > dim + 1:
|
||||
for _ in range(len(target.shape) - (dim + 1)):
|
||||
ind_pad = ind_pad.unsqueeze(-1)
|
||||
ind_pad = ind_pad.expand(*(-1,) * (dim + 1), *target.shape[(dim + 1)::])
|
||||
|
||||
return torch.gather(target, dim=dim, index=ind_pad)
|
||||
|
||||
|
||||
def merge_final(vert_attr: torch.Tensor, weight: torch.Tensor, vert_assign: torch.Tensor):
|
||||
"""Merge vertex attributes with weights"""
|
||||
target_dim = len(vert_assign.shape) - 1
|
||||
if len(vert_attr.shape) == 2:
|
||||
assert vert_attr.shape[0] > vert_assign.max()
|
||||
new_shape = [1] * target_dim + list(vert_attr.shape)
|
||||
tensor = vert_attr.reshape(new_shape)
|
||||
sel_attr = ind_sel(tensor, vert_assign.type(torch.long), dim=target_dim)
|
||||
else:
|
||||
assert vert_attr.shape[1] > vert_assign.max()
|
||||
new_shape = [vert_attr.shape[0]] + [1] * (target_dim - 1) + list(vert_attr.shape[1:])
|
||||
tensor = vert_attr.reshape(new_shape)
|
||||
sel_attr = ind_sel(tensor, vert_assign.type(torch.long), dim=target_dim)
|
||||
|
||||
final_attr = torch.sum(sel_attr * weight.unsqueeze(-1), dim=-2)
|
||||
return final_attr
|
||||
|
||||
|
||||
def _patch_motion_single(
|
||||
tracks: torch.FloatTensor, # (B, T, N, 4)
|
||||
vid: torch.FloatTensor, # (C, T, H, W)
|
||||
temperature: float,
|
||||
vae_divide: tuple,
|
||||
topk: int,
|
||||
):
|
||||
"""Apply motion patching based on tracks"""
|
||||
_, T, H, W = vid.shape
|
||||
N = tracks.shape[2]
|
||||
_, tracks_xy, visible = torch.split(
|
||||
tracks, [1, 2, 1], dim=-1
|
||||
) # (B, T, N, 2) | (B, T, N, 1)
|
||||
tracks_n = tracks_xy / torch.tensor([W / min(H, W), H / min(H, W)], device=tracks_xy.device)
|
||||
tracks_n = tracks_n.clamp(-1, 1)
|
||||
visible = visible.clamp(0, 1)
|
||||
|
||||
xx = torch.linspace(-W / min(H, W), W / min(H, W), W)
|
||||
yy = torch.linspace(-H / min(H, W), H / min(H, W), H)
|
||||
|
||||
grid = torch.stack(torch.meshgrid(yy, xx, indexing="ij")[::-1], dim=-1).to(
|
||||
tracks_xy.device
|
||||
)
|
||||
|
||||
tracks_pad = tracks_xy[:, 1:]
|
||||
visible_pad = visible[:, 1:]
|
||||
|
||||
visible_align = visible_pad.view(T - 1, 4, *visible_pad.shape[2:]).sum(1)
|
||||
tracks_align = (tracks_pad * visible_pad).view(T - 1, 4, *tracks_pad.shape[2:]).sum(
|
||||
1
|
||||
) / (visible_align + 1e-5)
|
||||
dist_ = (
|
||||
(tracks_align[:, None, None] - grid[None, :, :, None]).pow(2).sum(-1)
|
||||
) # T, H, W, N
|
||||
weight = torch.exp(-dist_ * temperature) * visible_align.clamp(0, 1).view(
|
||||
T - 1, 1, 1, N
|
||||
)
|
||||
vert_weight, vert_index = torch.topk(
|
||||
weight, k=min(topk, weight.shape[-1]), dim=-1
|
||||
)
|
||||
|
||||
grid_mode = "bilinear"
|
||||
point_feature = torch.nn.functional.grid_sample(
|
||||
vid.permute(1, 0, 2, 3)[:1],
|
||||
tracks_n[:, :1].type(vid.dtype),
|
||||
mode=grid_mode,
|
||||
padding_mode="zeros",
|
||||
align_corners=False,
|
||||
)
|
||||
point_feature = point_feature.squeeze(0).squeeze(1).permute(1, 0) # N, C=16
|
||||
|
||||
out_feature = merge_final(point_feature, vert_weight, vert_index).permute(3, 0, 1, 2) # T - 1, H, W, C => C, T - 1, H, W
|
||||
out_weight = vert_weight.sum(-1) # T - 1, H, W
|
||||
|
||||
# out feature -> already soft weighted
|
||||
mix_feature = out_feature + vid[:, 1:] * (1 - out_weight.clamp(0, 1))
|
||||
|
||||
out_feature_full = torch.cat([vid[:, :1], mix_feature], dim=1) # C, T, H, W
|
||||
out_mask_full = torch.cat([torch.ones_like(out_weight[:1]), out_weight], dim=0) # T, H, W
|
||||
|
||||
return out_mask_full[None].expand(vae_divide[0], -1, -1, -1), out_feature_full
|
||||
|
||||
|
||||
def patch_motion(
|
||||
tracks: torch.FloatTensor, # (B, TB, T, N, 4)
|
||||
vid: torch.FloatTensor, # (C, T, H, W)
|
||||
temperature: float = 220.0,
|
||||
vae_divide: tuple = (4, 16),
|
||||
topk: int = 2,
|
||||
):
|
||||
B = len(tracks)
|
||||
|
||||
# Process each batch separately
|
||||
out_masks = []
|
||||
out_features = []
|
||||
|
||||
for b in range(B):
|
||||
mask, feature = _patch_motion_single(
|
||||
tracks[b], # (T, N, 4)
|
||||
vid[b], # (C, T, H, W)
|
||||
temperature,
|
||||
vae_divide,
|
||||
topk
|
||||
)
|
||||
out_masks.append(mask)
|
||||
out_features.append(feature)
|
||||
|
||||
# Stack results: (B, C, T, H, W)
|
||||
out_mask_full = torch.stack(out_masks, dim=0)
|
||||
out_feature_full = torch.stack(out_features, dim=0)
|
||||
|
||||
return out_mask_full, out_feature_full
|
||||
|
||||
|
||||
class WanTrackToVideo:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {
|
||||
"positive": ("CONDITIONING",),
|
||||
"negative": ("CONDITIONING",),
|
||||
"vae": ("VAE",),
|
||||
"tracks": ("STRING", {"multiline": True, "default": "[]"}),
|
||||
"width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
||||
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
||||
"length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
|
||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
|
||||
"temperature": ("FLOAT", {"default": 220.0, "min": 1.0, "max": 1000.0, "step": 0.1}),
|
||||
"topk": ("INT", {"default": 2, "min": 1, "max": 10}),
|
||||
"start_image": ("IMAGE",),
|
||||
},
|
||||
"optional": {
|
||||
"clip_vision_output": ("CLIP_VISION_OUTPUT",),
|
||||
}}
|
||||
|
||||
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
|
||||
RETURN_NAMES = ("positive", "negative", "latent")
|
||||
FUNCTION = "encode"
|
||||
|
||||
CATEGORY = "conditioning/video_models"
|
||||
|
||||
def encode(self, positive, negative, vae, tracks, width, height, length, batch_size,
|
||||
temperature, topk, start_image=None, clip_vision_output=None):
|
||||
|
||||
tracks_data = parse_json_tracks(tracks)
|
||||
|
||||
if not tracks_data:
|
||||
return WanImageToVideo().encode(positive, negative, vae, width, height, length, batch_size, start_image=start_image, clip_vision_output=clip_vision_output)
|
||||
|
||||
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8],
|
||||
device=comfy.model_management.intermediate_device())
|
||||
|
||||
if isinstance(tracks_data[0][0], dict):
|
||||
tracks_data = [tracks_data]
|
||||
|
||||
processed_tracks = []
|
||||
for batch in tracks_data:
|
||||
arrs = []
|
||||
for track in batch:
|
||||
pts = pad_pts(track)
|
||||
arrs.append(pts)
|
||||
|
||||
tracks_np = np.stack(arrs, axis=0)
|
||||
processed_tracks.append(process_tracks(tracks_np, (width, height), length - 1).unsqueeze(0))
|
||||
|
||||
if start_image is not None:
|
||||
start_image = comfy.utils.common_upscale(start_image[:batch_size].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
||||
videos = torch.ones((start_image.shape[0], length, height, width, start_image.shape[-1]), device=start_image.device, dtype=start_image.dtype) * 0.5
|
||||
for i in range(start_image.shape[0]):
|
||||
videos[i, 0] = start_image[i]
|
||||
|
||||
latent_videos = []
|
||||
videos = comfy.utils.resize_to_batch_size(videos, batch_size)
|
||||
for i in range(batch_size):
|
||||
latent_videos += [vae.encode(videos[i, :, :, :, :3])]
|
||||
y = torch.cat(latent_videos, dim=0)
|
||||
|
||||
# Scale latent since patch_motion is non-linear
|
||||
y = comfy.latent_formats.Wan21().process_in(y)
|
||||
|
||||
processed_tracks = comfy.utils.resize_list_to_batch_size(processed_tracks, batch_size)
|
||||
res = patch_motion(
|
||||
processed_tracks, y, temperature=temperature, topk=topk, vae_divide=(4, 16)
|
||||
)
|
||||
|
||||
mask, concat_latent_image = res
|
||||
concat_latent_image = comfy.latent_formats.Wan21().process_out(concat_latent_image)
|
||||
mask = -mask + 1.0 # Invert mask to match expected format
|
||||
positive = node_helpers.conditioning_set_values(positive,
|
||||
{"concat_mask": mask,
|
||||
"concat_latent_image": concat_latent_image})
|
||||
negative = node_helpers.conditioning_set_values(negative,
|
||||
{"concat_mask": mask,
|
||||
"concat_latent_image": concat_latent_image})
|
||||
|
||||
if clip_vision_output is not None:
|
||||
positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
|
||||
negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})
|
||||
|
||||
out_latent = {}
|
||||
out_latent["samples"] = latent
|
||||
return (positive, negative, out_latent)
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"WanTrackToVideo": WanTrackToVideo,
|
||||
"WanImageToVideo": WanImageToVideo,
|
||||
"WanFunControlToVideo": WanFunControlToVideo,
|
||||
"WanFunInpaintToVideo": WanFunInpaintToVideo,
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "comfyui"
|
||||
version = "0.3.44"
|
||||
version = "0.3.45"
|
||||
description = "An installable version of ComfyUI"
|
||||
readme = "README.md"
|
||||
authors = [
|
||||
@ -18,9 +18,9 @@ classifiers = [
|
||||
]
|
||||
|
||||
dependencies = [
|
||||
"comfyui-frontend-package",
|
||||
"comfyui-workflow-templates",
|
||||
"comfyui-embedded-docs",
|
||||
"comfyui-frontend-package>=1.23.4",
|
||||
"comfyui-workflow-templates>=0.1.40",
|
||||
"comfyui-embedded-docs>=0.2.4",
|
||||
"torch",
|
||||
"torchvision",
|
||||
"torchdiffeq>=0.2.3",
|
||||
|
||||
@ -2,7 +2,7 @@ import argparse
|
||||
|
||||
import pytest
|
||||
from requests.exceptions import HTTPError
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import patch, mock_open
|
||||
|
||||
from comfy.app.frontend_management import (
|
||||
FrontendManager,
|
||||
@ -171,3 +171,36 @@ def test_init_frontend_fallback_on_error():
|
||||
# Assert
|
||||
assert frontend_path == "/default/path"
|
||||
mock_check.assert_called_once()
|
||||
|
||||
|
||||
def test_get_frontend_version():
|
||||
# Arrange
|
||||
expected_version = "1.25.0"
|
||||
mock_requirements_content = """torch
|
||||
torchsde
|
||||
comfyui-frontend-package==1.25.0
|
||||
other-package==1.0.0
|
||||
numpy"""
|
||||
|
||||
# Act
|
||||
with patch("builtins.open", mock_open(read_data=mock_requirements_content)):
|
||||
version = FrontendManager.get_required_frontend_version()
|
||||
|
||||
# Assert
|
||||
assert version == expected_version
|
||||
|
||||
|
||||
def test_get_frontend_version_invalid_semver():
|
||||
# Arrange
|
||||
mock_requirements_content = """torch
|
||||
torchsde
|
||||
comfyui-frontend-package==1.29.3.75
|
||||
other-package==1.0.0
|
||||
numpy"""
|
||||
|
||||
# Act
|
||||
with patch("builtins.open", mock_open(read_data=mock_requirements_content)):
|
||||
version = FrontendManager.get_required_frontend_version()
|
||||
|
||||
# Assert
|
||||
assert version is None
|
||||
|
||||
Loading…
Reference in New Issue
Block a user