Merge branch 'master' of github.com:comfyanonymous/ComfyUI

This commit is contained in:
doctorpangloss 2025-07-25 12:48:05 -07:00
commit 3684cff31b
25 changed files with 1070 additions and 182 deletions

View File

@@ -17,6 +17,7 @@ jobs:
- name: Check for Windows line endings (CRLF)
run: |
# Get the list of changed files in the PR
git merge origin/${{ github.base_ref }} --no-edit
CHANGED_FILES=$(git diff --name-only origin/${{ github.base_ref }}..HEAD)
# Flag to track if CRLF is found

View File

@@ -42,6 +42,7 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
- Image Editing Models
- [Omnigen 2](https://comfyanonymous.github.io/ComfyUI_examples/omnigen/)
- [Flux Kontext](https://comfyanonymous.github.io/ComfyUI_examples/flux/#flux-kontext-image-editing-model)
- [HiDream E1.1](https://comfyanonymous.github.io/ComfyUI_examples/hidream/#hidream-e11)
- Video Models
- [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/)
- [Mochi](https://comfyanonymous.github.io/ComfyUI_examples/mochi/)
@@ -1574,6 +1575,14 @@ If you need to use the legacy frontend for any reason, you can access it using t
This will use a snapshot of the legacy frontend preserved in the [ComfyUI Legacy Frontend repository](https://github.com/Comfy-Org/ComfyUI_legacy_frontend).
### Iluvatar Corex
For models compatible with the Iluvatar Extension for PyTorch, follow this step-by-step guide for your platform and installation method:
1. Install the Iluvatar Corex Toolkit by following the platform-specific instructions on the [Installation](https://support.iluvatar.com/#/DocumentCentre?id=1&nameCenter=2&productId=520117912052801536) page
2. Launch ComfyUI by running `python main.py`
## Community
[Discord](https://comfy.org/discord): Try the #help or #feedback channels.

View File

@@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
__version__ = "0.3.44"
__version__ = "0.3.45"

View File

@@ -7,6 +7,7 @@ import os
import re
import tempfile
import zipfile
import importlib.metadata
from dataclasses import dataclass
from functools import cached_property
from pathlib import Path
@@ -22,10 +23,12 @@ REQUEST_TIMEOUT = 10 # seconds
def check_frontend_version():
"""the thing this does makes no sense, so it got cut"""
return None
def frontend_install_warning_message() -> str:
"""the end user never needs to be messaged this"""
return ""
@@ -135,6 +138,15 @@ def download_release_asset_zip(release: Release, destination_path: str) -> None:
class FrontendManager:
CUSTOM_FRONTENDS_ROOT = add_model_folder_path("web_custom_versions", extensions=set())
@classmethod
def get_required_frontend_version(cls) -> str:
"""Get the required frontend package version."""
try:
# this isn't used the way it says
return importlib.metadata.version("comfyui_frontend_package")
except Exception as exc_info:
return "1.23.4"
@classmethod
def default_frontend_path(cls) -> str:
try:

View File

@@ -44,7 +44,8 @@ def _create_parser() -> EnhancedConfigArgParser:
help="Automatically launch ComfyUI in the default browser.")
parser.add_argument("--disable-auto-launch", action="store_true", help="Disable auto launching the browser.")
parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID",
help="Set the id of the cuda device this instance will use.")
help="Set the id of the cuda device this instance will use. All other devices will not be visible.")
parser.add_argument("--default-device", type=int, default=None, metavar="DEFAULT_DEVICE_ID", help="Set the id of the default device, all other devices will stay visible.")
cm_group = parser.add_mutually_exclusive_group()
cm_group.add_argument("--cuda-malloc", action="store_true",
help="Enable cudaMallocAsync (enabled by default for torch 2.0 and up).")

View File

@@ -165,6 +165,7 @@ class Configuration(dict):
comfy_api_base (str): Set the base URL for the ComfyUI API. (default: https://api.comfy.org)
database_url (str): Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.
whitelist_custom_nodes (list[str]): Specify custom node folders to load even when --disable-all-custom-nodes is enabled.
default_device (Optional[int]): Set the id of the default device, all other devices will stay visible.
"""
def __init__(self, **kwargs):
@@ -281,6 +282,7 @@ class Configuration(dict):
self.front_end_root: Optional[str] = None
self.comfy_api_base: str = "https://api.comfy.org"
self.database_url: str = db_config()
self.default_device: Optional[int] = None
for key, value in kwargs.items():
self[key] = value

View File

@@ -82,7 +82,8 @@ if not args.cuda_malloc:
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
version = module.__version__
if int(version[0]) >= 2: # enable by default for torch version 2.0 and up
if int(version[0]) >= 2 and "+cu" in version: # enable by default for torch version 2.0 and up only on cuda torch
args.cuda_malloc = cuda_malloc_supported()
except:
pass
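The added `"+cu" in version` check keeps cudaMallocAsync off by default on non-CUDA builds of torch, whose version strings lack the CUDA suffix. A small illustration with assumed version strings:

# illustrative only; the version strings below are assumptions, not taken from the patch
for version in ("2.4.0+cu121", "2.4.0+rocm6.1", "2.4.0"):
    enabled_by_default = int(version[0]) >= 2 and "+cu" in version
    print(version, enabled_by_default)  # True only for the +cu121 build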

View File

@@ -49,6 +49,15 @@ logging.getLogger("alembic.runtime.migration").setLevel(logging.WARNING)
from ..cli_args import args
if args.default_device is not None:
default_dev = args.default_device
devices = list(range(32))
devices.remove(default_dev)
devices.insert(0, default_dev)
devices = ','.join(map(str, devices))
os.environ['CUDA_VISIBLE_DEVICES'] = str(devices)
os.environ['HIP_VISIBLE_DEVICES'] = str(devices)
if args.cuda_device is not None:
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device)
os.environ['HIP_VISIBLE_DEVICES'] = str(args.cuda_device)
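Unlike --cuda-device, which exposes only the chosen GPU, --default-device keeps every device visible and just moves the chosen id to the front of CUDA_VISIBLE_DEVICES/HIP_VISIBLE_DEVICES so it becomes device 0. A minimal sketch of the resulting value, assuming --default-device 2:

# illustrative only: reproduces the reordering above for default_device = 2
default_dev = 2
devices = list(range(32))
devices.remove(default_dev)
devices.insert(0, default_dev)
print(','.join(map(str, devices)))  # "2,0,1,3,4,5,...,31": device 2 is now cuda:0, the rest stay visible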

View File

@@ -593,6 +593,7 @@ class PromptServer(ExecutorToClientProgress):
ram_free = model_management.get_free_memory(cpu_device)
vram_total, torch_vram_total = get_total_memory(device, torch_total_too=True)
vram_free, torch_vram_free = get_free_memory(device, torch_free_too=True)
required_frontend_version = FrontendManager.get_required_frontend_version()
system_stats = {
"system": {
@@ -600,6 +601,7 @@ class PromptServer(ExecutorToClientProgress):
"ram_total": ram_total,
"ram_free": ram_free,
"comfyui_version": __version__,
"required_frontend_version": required_frontend_version,
"python_version": sys.version,
"pytorch_version": torch_version,
"embedded_python": os.path.split(os.path.split(sys.executable)[0])[1] == "python_embeded",

View File

@@ -1244,42 +1244,20 @@ def sample_deis(model, x, sigmas, extra_args=None, callback=None, disable=None,
return x_next
@torch.no_grad()
def sample_euler_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None):
extra_args = {} if extra_args is None else extra_args
temp = [0]
def post_cfg_function(args):
temp[0] = args["uncond_denoised"]
return args["denoised"]
model_options = extra_args.get("model_options", {}).copy()
extra_args["model_options"] = model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable):
sigma_hat = sigmas[i]
denoised = model(x, sigma_hat * s_in, **extra_args)
d = to_d(x, sigma_hat, temp[0])
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
# Euler method
x = denoised + d * sigmas[i + 1]
return x
@torch.no_grad()
def sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
"""Ancestral sampling with Euler method steps."""
"""Ancestral sampling with Euler method steps (CFG++)."""
extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
temp = [0]
model_sampling = model.inner_model.model_patcher.get_model_object("model_sampling")
lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
uncond_denoised = None
def post_cfg_function(args):
temp[0] = args["uncond_denoised"]
nonlocal uncond_denoised
uncond_denoised = args["uncond_denoised"]
return args["denoised"]
model_options = extra_args.get("model_options", {}).copy()
@@ -1288,17 +1266,33 @@ def sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=No
s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
d = to_d(x, sigmas[i], temp[0])
# Euler method
x = denoised + d * sigma_down
if sigmas[i + 1] > 0:
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
if sigmas[i + 1] == 0:
# Denoising step
x = denoised
else:
alpha_s = sigmas[i] * lambda_fn(sigmas[i]).exp()
alpha_t = sigmas[i + 1] * lambda_fn(sigmas[i + 1]).exp()
d = to_d(x, sigmas[i], alpha_s * uncond_denoised) # to noise
# DDIM stochastic sampling
sigma_down, sigma_up = get_ancestral_step(sigmas[i] / alpha_s, sigmas[i + 1] / alpha_t, eta=eta)
sigma_down = alpha_t * sigma_down
# Euler method
x = alpha_t * denoised + sigma_down * d
if eta > 0 and s_noise > 0:
x = x + alpha_t * noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
return x
@torch.no_grad()
def sample_euler_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None):
"""Euler method steps (CFG++)."""
return sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=0.0, s_noise=0.0, noise_sampler=None)
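The rewritten sampler works in the model's native parameterization: since the half-log-SNR is log(alpha/sigma), alpha can be recovered as sigma * exp(lambda), which is what alpha_s and alpha_t compute above, and with eta and s_noise set to 0 the ancestral branch collapses to plain Euler CFG++, so sample_euler_cfg_pp simply delegates to it. A minimal sketch of that relationship, assuming lambda_fn behaves like sigma_to_half_log_snr:

# assumption: lambda(sigma) = log(alpha / sigma), i.e. the half-log-SNR used above
def alpha_from_sigma(sigma, lambda_fn):
    return sigma * lambda_fn(sigma).exp()  # the alpha_s / alpha_t terms in the loop above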
@torch.no_grad()
def sample_dpmpp_2s_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
"""Ancestral sampling with DPM-Solver++(2S) second-order steps."""

View File

@ -52,15 +52,6 @@ class RMS_norm(nn.Module):
x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma.to(x) + (self.bias.to(x) if self.bias is not None else 0) x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma.to(x) + (self.bias.to(x) if self.bias is not None else 0)
class Upsample(nn.Upsample):
def forward(self, x):
"""
Fix bfloat16 support for nearest neighbor interpolation.
"""
return super().forward(x.float()).type_as(x)
class Resample(nn.Module): class Resample(nn.Module):
def __init__(self, dim, mode): def __init__(self, dim, mode):
@@ -73,11 +64,11 @@ class Resample(nn.Module):
# layers
if mode == 'upsample2d':
self.resample = nn.Sequential(
Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
nn.Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
ops.Conv2d(dim, dim // 2, 3, padding=1))
elif mode == 'upsample3d':
self.resample = nn.Sequential(
Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
nn.Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
ops.Conv2d(dim, dim // 2, 3, padding=1))
self.time_conv = CausalConv3d(
dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))

View File

@ -125,7 +125,7 @@ if args.directml is not None:
lowvram_available = False # TODO: need to find a way to get free memory in directml before this can be enabled by default. lowvram_available = False # TODO: need to find a way to get free memory in directml before this can be enabled by default.
try: try:
import intel_extension_for_pytorch as ipex # pylint: disable=import-error
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, noqa: F401
_ = torch.xpu.device_count() _ = torch.xpu.device_count()
xpu_available = xpu_available or torch.xpu.is_available() xpu_available = xpu_available or torch.xpu.is_available()
@ -155,6 +155,11 @@ try:
except: except:
mlu_available = False mlu_available = False
try:
ixuca_available = hasattr(torch, "corex")
except:
ixuca_available = False
if args.cpu: if args.cpu:
cpu_state = CPUState.CPU cpu_state = CPUState.CPU
@ -181,6 +186,12 @@ def is_mlu():
return True return True
return False return False
def is_ixuca():
global ixuca_available
if ixuca_available:
return True
return False
def get_torch_device(): def get_torch_device():
global directml_device global directml_device
@ -220,7 +231,11 @@ def get_total_memory(dev=None, torch_total_too=False):
mem_total = 1024 * 1024 * 1024 # TODO mem_total = 1024 * 1024 * 1024 # TODO
mem_total_torch = mem_total mem_total_torch = mem_total
elif is_intel_xpu():
mem_total = torch.xpu.get_device_properties(dev).total_memory
stats = torch.xpu.memory_stats(dev)
mem_reserved = stats['reserved_bytes.all.current']
mem_total_xpu = torch.xpu.get_device_properties(dev).total_memory
mem_total_torch = mem_reserved
mem_total = mem_total_xpu
elif is_ascend_npu(): elif is_ascend_npu():
stats = torch.npu.memory_stats(dev) stats = torch.npu.memory_stats(dev)
mem_reserved = stats['reserved_bytes.all.current'] mem_reserved = stats['reserved_bytes.all.current']
@@ -334,7 +349,7 @@ try:
if torch_version_numeric[0] >= 2:
if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
ENABLE_PYTORCH_ATTENTION = True
if is_intel_xpu() or is_ascend_npu() or is_mlu():
if is_intel_xpu() or is_ascend_npu() or is_mlu() or is_ixuca():
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
ENABLE_PYTORCH_ATTENTION = True
except: except:
@@ -352,7 +367,10 @@ try:
logging.info("ROCm version: {}".format(rocm_version))
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much
if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx1201 and gfx950
if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950
ENABLE_PYTORCH_ATTENTION = True
if torch_version_numeric >= (2, 8):
if any((a in arch) for a in ["gfx1201"]):
ENABLE_PYTORCH_ATTENTION = True
if torch_version_numeric >= (2, 7) and rocm_version >= (6, 4):
if any((a in arch) for a in ["gfx1201", "gfx942", "gfx950"]): # TODO: more arches
@ -430,6 +448,8 @@ def get_torch_device_name(device):
except: except:
allocator_backend = "" allocator_backend = ""
return "{} {} : {}".format(device, torch.cuda.get_device_name(device), allocator_backend) return "{} {} : {}".format(device, torch.cuda.get_device_name(device), allocator_backend)
elif device.type == "xpu":
return "{} {}".format(device, torch.xpu.get_device_name(device))
else: else:
return "{}".format(device.type) return "{}".format(device.type)
elif is_intel_xpu(): elif is_intel_xpu():
@ -984,6 +1004,7 @@ def vae_dtype(device=None, allowed_dtypes=[]):
return d return d
# NOTE: bfloat16 seems to work on AMD for the VAE but is extremely slow in some cases compared to fp32 # NOTE: bfloat16 seems to work on AMD for the VAE but is extremely slow in some cases compared to fp32
# slowness still a problem on pytorch nightly 2.9.0.dev20250720+rocm6.4 tested on RDNA3
if d == torch.bfloat16 and (not is_amd()) and should_use_bf16(device): if d == torch.bfloat16 and (not is_amd()) and should_use_bf16(device):
return d return d
@ -1044,7 +1065,7 @@ def device_supports_non_blocking(device):
if is_device_mps(device): if is_device_mps(device):
return False # pytorch bug? mps doesn't support non blocking return False # pytorch bug? mps doesn't support non blocking
if is_intel_xpu():
return False
return True
if args.deterministic: # TODO: figure out why deterministic breaks non blocking from gpu to cpu (previews) if args.deterministic: # TODO: figure out why deterministic breaks non blocking from gpu to cpu (previews)
return False return False
if directml_device: if directml_device:
@ -1087,6 +1108,8 @@ def get_offload_stream(device):
stream_counter = (stream_counter + 1) % len(ss) stream_counter = (stream_counter + 1) % len(ss)
if is_device_cuda(device): if is_device_cuda(device):
ss[stream_counter].wait_stream(torch.cuda.current_stream()) ss[stream_counter].wait_stream(torch.cuda.current_stream())
elif is_device_xpu(device):
ss[stream_counter].wait_stream(torch.xpu.current_stream())
stream_counters[device] = stream_counter stream_counters[device] = stream_counter
return s return s
elif is_device_cuda(device): elif is_device_cuda(device):
@ -1098,6 +1121,15 @@ def get_offload_stream(device):
stream_counter = (stream_counter + 1) % len(ss) stream_counter = (stream_counter + 1) % len(ss)
stream_counters[device] = stream_counter stream_counters[device] = stream_counter
return s return s
elif is_device_xpu(device):
ss = []
for k in range(NUM_STREAMS):
ss.append(torch.xpu.Stream(device=device, priority=0))
STREAMS[device] = ss
s = ss[stream_counter]
stream_counter = (stream_counter + 1) % len(ss)
stream_counters[device] = stream_counter
return s
return None return None
@ -1106,6 +1138,8 @@ def sync_stream(device, stream):
return return
if is_device_cuda(device): if is_device_cuda(device):
torch.cuda.current_stream().wait_stream(stream) torch.cuda.current_stream().wait_stream(stream)
elif is_device_xpu(device):
torch.xpu.current_stream().wait_stream(stream)
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None): def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None):
@ -1152,6 +1186,8 @@ def xformers_enabled():
return False return False
if is_mlu(): if is_mlu():
return False return False
if is_ixuca():
return False
if directml_device: if directml_device:
return False return False
return XFORMERS_IS_AVAILABLE return XFORMERS_IS_AVAILABLE
@ -1202,6 +1238,8 @@ def pytorch_attention_flash_attention():
return True return True
if is_amd(): if is_amd():
return True # if you have pytorch attention enabled on AMD it probably supports at least mem efficient attention return True # if you have pytorch attention enabled on AMD it probably supports at least mem efficient attention
if is_ixuca():
return True
return False return False
@ -1234,8 +1272,8 @@ def get_free_memory(dev=None, torch_free_too=False):
stats = torch.xpu.memory_stats(dev) # pylint: disable=no-member stats = torch.xpu.memory_stats(dev) # pylint: disable=no-member
mem_active = stats['active_bytes.all.current'] mem_active = stats['active_bytes.all.current']
mem_reserved = stats['reserved_bytes.all.current'] mem_reserved = stats['reserved_bytes.all.current']
mem_free_torch = mem_reserved - mem_active
mem_free_xpu = torch.xpu.get_device_properties(dev).total_memory - mem_reserved mem_free_xpu = torch.xpu.get_device_properties(dev).total_memory - mem_reserved
mem_free_torch = mem_reserved - mem_active
mem_free_total = mem_free_xpu + mem_free_torch mem_free_total = mem_free_xpu + mem_free_torch
elif is_ascend_npu(): elif is_ascend_npu():
stats = torch.npu.memory_stats(dev) stats = torch.npu.memory_stats(dev)
@ -1289,6 +1327,8 @@ def is_device_cpu(device):
def is_device_mps(device): def is_device_mps(device):
return is_device_type(device, 'mps') return is_device_type(device, 'mps')
def is_device_xpu(device):
return is_device_type(device, 'xpu')
def is_device_cuda(device): def is_device_cuda(device):
return is_device_type(device, 'cuda') return is_device_type(device, 'cuda')
@ -1323,7 +1363,10 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
return False return False
if is_intel_xpu():
return True
if torch_version_numeric < (2, 3):
return True
else:
return torch.xpu.get_device_properties(device).has_fp16
if is_ascend_npu(): if is_ascend_npu():
return True return True
@ -1331,6 +1374,9 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
if is_mlu(): if is_mlu():
return True return True
if is_ixuca():
return True
if is_amd(): if is_amd():
return True return True
try: try:
@ -1394,11 +1440,17 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
return False return False
if is_intel_xpu():
return True
if torch_version_numeric < (2, 6):
return True
else:
return torch.xpu.get_device_capability(device)['has_bfloat16_conversions']
if is_ascend_npu(): if is_ascend_npu():
return True return True
if is_ixuca():
return True
if is_amd(): if is_amd():
arch = torch.cuda.get_device_properties(device).gcnArchName arch = torch.cuda.get_device_properties(device).gcnArchName
if any((a in arch) for a in ["gfx1030", "gfx1031", "gfx1010", "gfx1011", "gfx1012", "gfx906", "gfx900", "gfx803"]): # RDNA2 and older don't support bf16 if any((a in arch) for a in ["gfx1030", "gfx1031", "gfx1010", "gfx1011", "gfx1012", "gfx906", "gfx900", "gfx803"]): # RDNA2 and older don't support bf16

View File

@ -37,6 +37,7 @@ from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine
from .ldm.wan.vae import WanVAE from .ldm.wan.vae import WanVAE
from .lora_convert import convert_lora from .lora_convert import convert_lora
from .model_management import load_models_gpu from .model_management import load_models_gpu
from .model_patcher import ModelPatcher
from .t2i_adapter import adapter from .t2i_adapter import adapter
from .taesd import taesd from .taesd import taesd
from .text_encoders import aura_t5 from .text_encoders import aura_t5
@ -71,7 +72,7 @@ def load_lora_for_models(model, clip, _lora, strength_model, strength_clip):
_lora = convert_lora(_lora) _lora = convert_lora(_lora)
loaded = lora.load_lora(_lora, key_map) loaded = lora.load_lora(_lora, key_map)
if model is not None: if model is not None:
new_modelpatcher = model.clone()
new_modelpatcher: ModelPatcher = model.clone()
k = new_modelpatcher.add_patches(loaded, strength_model) k = new_modelpatcher.add_patches(loaded, strength_model)
else: else:
k = () k = ()

View File

@ -781,6 +781,27 @@ def resize_to_batch_size(tensor, batch_size):
return output return output
def resize_list_to_batch_size(l, batch_size):
in_batch_size = len(l)
if in_batch_size == batch_size or in_batch_size == 0:
return l
if batch_size <= 1:
return l[:batch_size]
output = []
if batch_size < in_batch_size:
scale = (in_batch_size - 1) / (batch_size - 1)
for i in range(batch_size):
output.append(l[min(round(i * scale), in_batch_size - 1)])
else:
scale = in_batch_size / batch_size
for i in range(batch_size):
output.append(l[min(math.floor((i + 0.5) * scale), in_batch_size - 1)])
return output
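resize_list_to_batch_size mirrors the tensor helper above for plain Python lists: shrinking samples evenly across the input, growing repeats entries at a fractional stride. A small usage sketch with made-up values:

# illustrative only
resize_list_to_batch_size([10, 20, 30, 40, 50], 3)  # -> [10, 30, 50]
resize_list_to_batch_size([10, 20, 30], 6)          # -> [10, 10, 20, 20, 30, 30]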
def convert_sd_to(state_dict, dtype): def convert_sd_to(state_dict, dtype):
keys = list(state_dict.keys()) keys = list(state_dict.keys())
for k in keys: for k in keys:

View File

@ -15,9 +15,20 @@ adapters: list[type[WeightAdapterBase]] = [
OFTAdapter, OFTAdapter,
BOFTAdapter, BOFTAdapter,
] ]
adapter_maps: dict[str, type[WeightAdapterBase]] = {
"LoRA": LoRAAdapter,
"LoHa": LoHaAdapter,
"LoKr": LoKrAdapter,
"OFT": OFTAdapter,
## We disable not implemented algo for now
# "GLoRA": GLoRAAdapter,
# "BOFT": BOFTAdapter,
}
__all__ = [
"WeightAdapterBase",
"WeightAdapterTrainBase",
"adapters"
"adapters",
"adapter_maps",
] + [a.__name__ for a in adapters]

View File

@ -133,3 +133,43 @@ def tucker_weight_from_conv(up, down, mid):
def tucker_weight(wa, wb, t): def tucker_weight(wa, wb, t):
temp = torch.einsum("i j ..., j r -> i r ...", t, wb) temp = torch.einsum("i j ..., j r -> i r ...", t, wb)
return torch.einsum("i j ..., i r -> r j ...", temp, wa) return torch.einsum("i j ..., i r -> r j ...", temp, wa)
def factorization(dimension: int, factor: int = -1) -> tuple[int, int]:
"""
return a tuple of two value of input dimension decomposed by the number closest to factor
second value is higher or equal than first value.
examples)
factor
-1 2 4 8 16 ...
127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127
128 -> 8, 16 128 -> 2, 64 128 -> 4, 32 128 -> 8, 16 128 -> 8, 16
250 -> 10, 25 250 -> 2, 125 250 -> 2, 125 250 -> 5, 50 250 -> 10, 25
360 -> 8, 45 360 -> 2, 180 360 -> 4, 90 360 -> 8, 45 360 -> 12, 30
512 -> 16, 32 512 -> 2, 256 512 -> 4, 128 512 -> 8, 64 512 -> 16, 32
1024 -> 32, 32 1024 -> 2, 512 1024 -> 4, 256 1024 -> 8, 128 1024 -> 16, 64
"""
if factor > 0 and (dimension % factor) == 0 and dimension >= factor**2:
m = factor
n = dimension // factor
if m > n:
n, m = m, n
return m, n
if factor < 0:
factor = dimension
m, n = 1, dimension
length = m + n
while m < n:
new_m = m + 1
while dimension % new_m != 0:
new_m += 1
new_n = dimension // new_m
if new_m + new_n > length or new_m > factor:
break
else:
m, n = new_m, new_n
if m > n:
n, m = m, n
return m, n
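A quick sanity check of the docstring table, assuming factorization is imported from this module:

# illustrative only; results match the docstring table above
factorization(128)             # -> (8, 16)
factorization(128, factor=4)   # -> (4, 32)
factorization(127)             # -> (1, 127)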

View File

@ -3,7 +3,120 @@ from typing import Optional
import torch import torch
from ..model_management import cast_to_device
from .base import WeightAdapterBase, weight_decompose
from .base import WeightAdapterBase, WeightAdapterTrainBase, weight_decompose
class HadaWeight(torch.autograd.Function):
@staticmethod
def forward(ctx, w1u, w1d, w2u, w2d, scale=torch.tensor(1)):
ctx.save_for_backward(w1d, w1u, w2d, w2u, scale)
diff_weight = ((w1u @ w1d) * (w2u @ w2d)) * scale
return diff_weight
@staticmethod
def backward(ctx, grad_out):
(w1d, w1u, w2d, w2u, scale) = ctx.saved_tensors
grad_out = grad_out * scale
temp = grad_out * (w2u @ w2d)
grad_w1u = temp @ w1d.T
grad_w1d = w1u.T @ temp
temp = grad_out * (w1u @ w1d)
grad_w2u = temp @ w2d.T
grad_w2d = w2u.T @ temp
del temp
return grad_w1u, grad_w1d, grad_w2u, grad_w2d, None
class HadaWeightTucker(torch.autograd.Function):
@staticmethod
def forward(ctx, t1, w1u, w1d, t2, w2u, w2d, scale=torch.tensor(1)):
ctx.save_for_backward(t1, w1d, w1u, t2, w2d, w2u, scale)
rebuild1 = torch.einsum("i j ..., j r, i p -> p r ...", t1, w1d, w1u)
rebuild2 = torch.einsum("i j ..., j r, i p -> p r ...", t2, w2d, w2u)
return rebuild1 * rebuild2 * scale
@staticmethod
def backward(ctx, grad_out):
(t1, w1d, w1u, t2, w2d, w2u, scale) = ctx.saved_tensors
grad_out = grad_out * scale
temp = torch.einsum("i j ..., j r -> i r ...", t2, w2d)
rebuild = torch.einsum("i j ..., i r -> r j ...", temp, w2u)
grad_w = rebuild * grad_out
del rebuild
grad_w1u = torch.einsum("r j ..., i j ... -> r i", temp, grad_w)
grad_temp = torch.einsum("i j ..., i r -> r j ...", grad_w, w1u.T)
del grad_w, temp
grad_w1d = torch.einsum("i r ..., i j ... -> r j", t1, grad_temp)
grad_t1 = torch.einsum("i j ..., j r -> i r ...", grad_temp, w1d.T)
del grad_temp
temp = torch.einsum("i j ..., j r -> i r ...", t1, w1d)
rebuild = torch.einsum("i j ..., i r -> r j ...", temp, w1u)
grad_w = rebuild * grad_out
del rebuild
grad_w2u = torch.einsum("r j ..., i j ... -> r i", temp, grad_w)
grad_temp = torch.einsum("i j ..., i r -> r j ...", grad_w, w2u.T)
del grad_w, temp
grad_w2d = torch.einsum("i r ..., i j ... -> r j", t2, grad_temp)
grad_t2 = torch.einsum("i j ..., j r -> i r ...", grad_temp, w2d.T)
del grad_temp
return grad_t1, grad_w1u, grad_w1d, grad_t2, grad_w2u, grad_w2d, None
class LohaDiff(WeightAdapterTrainBase):
def __init__(self, weights):
super().__init__()
# Unpack weights tuple from LoHaAdapter
w1a, w1b, alpha, w2a, w2b, t1, t2, _ = weights
# Create trainable parameters
self.hada_w1_a = torch.nn.Parameter(w1a)
self.hada_w1_b = torch.nn.Parameter(w1b)
self.hada_w2_a = torch.nn.Parameter(w2a)
self.hada_w2_b = torch.nn.Parameter(w2b)
self.use_tucker = False
if t1 is not None and t2 is not None:
self.use_tucker = True
self.hada_t1 = torch.nn.Parameter(t1)
self.hada_t2 = torch.nn.Parameter(t2)
else:
# Keep the attributes for consistent access
self.hada_t1 = None
self.hada_t2 = None
# Store rank and non-trainable alpha
self.rank = w1b.shape[0]
self.alpha = torch.nn.Parameter(torch.tensor(alpha), requires_grad=False)
def __call__(self, w):
org_dtype = w.dtype
scale = self.alpha / self.rank
if self.use_tucker:
diff_weight = HadaWeightTucker.apply(self.hada_t1, self.hada_w1_a, self.hada_w1_b, self.hada_t2, self.hada_w2_a, self.hada_w2_b, scale)
else:
diff_weight = HadaWeight.apply(self.hada_w1_a, self.hada_w1_b, self.hada_w2_a, self.hada_w2_b, scale)
# Add the scaled difference to the original weight
weight = w.to(diff_weight) + diff_weight.reshape(w.shape)
return weight.to(org_dtype)
def passive_memory_usage(self):
"""Calculates memory usage of the trainable parameters."""
return sum(param.numel() * param.element_size() for param in self.parameters())
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -14,6 +127,25 @@ class LoHaAdapter(WeightAdapterBase):
self.loaded_keys = loaded_keys self.loaded_keys = loaded_keys
self.weights = weights self.weights = weights
@classmethod
def create_train(cls, weight, rank=1, alpha=1.0):
out_dim = weight.shape[0]
in_dim = weight.shape[1:].numel()
mat1 = torch.empty(out_dim, rank, device=weight.device, dtype=weight.dtype)
mat2 = torch.empty(rank, in_dim, device=weight.device, dtype=weight.dtype)
torch.nn.init.normal_(mat1, 0.1)
torch.nn.init.constant_(mat2, 0.0)
mat3 = torch.empty(out_dim, rank, device=weight.device, dtype=weight.dtype)
mat4 = torch.empty(rank, in_dim, device=weight.device, dtype=weight.dtype)
torch.nn.init.normal_(mat3, 0.1)
torch.nn.init.normal_(mat4, 0.01)
return LohaDiff(
(mat1, mat2, alpha, mat3, mat4, None, None, None)
)
def to_train(self):
return LohaDiff(self.weights)
@classmethod @classmethod
def load( def load(
cls, cls,

View File

@ -3,7 +3,77 @@ from typing import Optional
import torch import torch
from ..model_management import cast_to_device
from .base import WeightAdapterBase, weight_decompose
from .base import (
WeightAdapterBase,
WeightAdapterTrainBase,
weight_decompose,
factorization,
)
class LokrDiff(WeightAdapterTrainBase):
def __init__(self, weights):
super().__init__()
(lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2, dora_scale) = weights
self.use_tucker = False
if lokr_w1_a is not None:
_, rank_a = lokr_w1_a.shape[0], lokr_w1_a.shape[1]
rank_a, _ = lokr_w1_b.shape[0], lokr_w1_b.shape[1]
self.lokr_w1_a = torch.nn.Parameter(lokr_w1_a)
self.lokr_w1_b = torch.nn.Parameter(lokr_w1_b)
self.w1_rebuild = True
self.ranka = rank_a
if lokr_w2_a is not None:
_, rank_b = lokr_w2_a.shape[0], lokr_w2_a.shape[1]
rank_b, _ = lokr_w2_b.shape[0], lokr_w2_b.shape[1]
self.lokr_w2_a = torch.nn.Parameter(lokr_w2_a)
self.lokr_w2_b = torch.nn.Parameter(lokr_w2_b)
if lokr_t2 is not None:
self.use_tucker = True
self.lokr_t2 = torch.nn.Parameter(lokr_t2)
self.w2_rebuild = True
self.rankb = rank_b
if lokr_w1 is not None:
self.lokr_w1 = torch.nn.Parameter(lokr_w1)
self.w1_rebuild = False
if lokr_w2 is not None:
self.lokr_w2 = torch.nn.Parameter(lokr_w2)
self.w2_rebuild = False
self.alpha = torch.nn.Parameter(torch.tensor(alpha), requires_grad=False)
@property
def w1(self):
if self.w1_rebuild:
return (self.lokr_w1_a @ self.lokr_w1_b) * (self.alpha / self.ranka)
else:
return self.lokr_w1
@property
def w2(self):
if self.w2_rebuild:
if self.use_tucker:
w2 = torch.einsum(
'i j k l, j r, i p -> p r k l',
self.lokr_t2,
self.lokr_w2_b,
self.lokr_w2_a
)
else:
w2 = self.lokr_w2_a @ self.lokr_w2_b
return w2 * (self.alpha / self.rankb)
else:
return self.lokr_w2
def __call__(self, w):
diff = torch.kron(self.w1, self.w2)
return w + diff.reshape(w.shape).to(w)
def passive_memory_usage(self):
return sum(param.numel() * param.element_size() for param in self.parameters())
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -14,6 +84,20 @@ class LoKrAdapter(WeightAdapterBase):
self.loaded_keys = loaded_keys self.loaded_keys = loaded_keys
self.weights = weights self.weights = weights
@classmethod
def create_train(cls, weight, rank=1, alpha=1.0):
out_dim = weight.shape[0]
in_dim = weight.shape[1:].numel()
out1, out2 = factorization(out_dim, rank)
in1, in2 = factorization(in_dim, rank)
mat1 = torch.empty(out1, in1, device=weight.device, dtype=weight.dtype)
mat2 = torch.empty(out2, in2, device=weight.device, dtype=weight.dtype)
torch.nn.init.kaiming_uniform_(mat2, a=5**0.5)
torch.nn.init.constant_(mat1, 0.0)
return LokrDiff(
(mat1, mat2, alpha, None, None, None, None, None, None)
)
@classmethod @classmethod
def load( def load(
cls, cls,

View File

@ -2,11 +2,64 @@ import logging
from typing import Optional from typing import Optional
import torch import torch
from .base import WeightAdapterBase, WeightAdapterTrainBase, weight_decompose, factorization
from ..model_management import cast_to_device
from .base import WeightAdapterBase, weight_decompose
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class OFTDiff(WeightAdapterTrainBase):
def __init__(self, weights):
super().__init__()
# Unpack weights tuple from LoHaAdapter
blocks, rescale, alpha, _ = weights
# Create trainable parameters
self.oft_blocks = torch.nn.Parameter(blocks)
if rescale is not None:
self.rescale = torch.nn.Parameter(rescale)
self.rescaled = True
else:
self.rescaled = False
self.block_num, self.block_size, _ = blocks.shape
self.constraint = float(alpha)
self.alpha = torch.nn.Parameter(torch.tensor(alpha), requires_grad=False)
def __call__(self, w):
org_dtype = w.dtype
I = torch.eye(self.block_size, device=self.oft_blocks.device)
## generate r
# for Q = -Q^T
q = self.oft_blocks - self.oft_blocks.transpose(1, 2)
normed_q = q
if self.constraint:
q_norm = torch.norm(q) + 1e-8
if q_norm > self.constraint:
normed_q = q * self.constraint / q_norm
# use float() to prevent unsupported type
r = (I + normed_q) @ (I - normed_q).float().inverse()
## Apply chunked matmul on weight
_, *shape = w.shape
org_weight = w.to(dtype=r.dtype)
org_weight = org_weight.unflatten(0, (self.block_num, self.block_size))
# Init R=0, so add I on it to ensure the output of step0 is original model output
weight = torch.einsum(
"k n m, k n ... -> k m ...",
r,
org_weight,
).flatten(0, 1)
if self.rescaled:
weight = self.rescale * weight
return weight.to(org_dtype)
def passive_memory_usage(self):
"""Calculates memory usage of the trainable parameters."""
return sum(param.numel() * param.element_size() for param in self.parameters())
class OFTAdapter(WeightAdapterBase): class OFTAdapter(WeightAdapterBase):
name = "oft" name = "oft"
@ -14,14 +67,26 @@ class OFTAdapter(WeightAdapterBase):
self.loaded_keys = loaded_keys self.loaded_keys = loaded_keys
self.weights = weights self.weights = weights
@classmethod
def create_train(cls, weight, rank=1, alpha=1.0):
out_dim = weight.shape[0]
block_size, block_num = factorization(out_dim, rank)
block = torch.zeros(block_num, block_size, block_size, device=weight.device, dtype=weight.dtype)
return OFTDiff(
(block, None, alpha, None)
)
def to_train(self):
return OFTDiff(self.weights)
@classmethod @classmethod
def load( def load(
cls, cls,
x: str, x: str,
lora: dict[str, torch.Tensor], lora: dict[str, torch.Tensor],
alpha: float, alpha: float,
dora_scale: torch.Tensor, dora_scale: torch.Tensor,
loaded_keys: set[str] = None, loaded_keys: set[str] = None,
) -> Optional["OFTAdapter"]: ) -> Optional["OFTAdapter"]:
if loaded_keys is None: if loaded_keys is None:
loaded_keys = set() loaded_keys = set()
@ -47,20 +112,22 @@ class OFTAdapter(WeightAdapterBase):
return cls(loaded_keys, weights) return cls(loaded_keys, weights)
def calculate_weight( def calculate_weight(
self, self,
weight, weight,
key, key,
strength, strength,
strength_model, strength_model,
offset, offset,
function, function,
intermediate_dtype=torch.float32, intermediate_dtype=torch.float32,
original_weight=None, original_weight=None,
): ):
v = self.weights v = self.weights
blocks = v[0] blocks = v[0]
rescale = v[1] rescale = v[1]
alpha = v[2] alpha = v[2]
if alpha is None:
alpha = 0
dora_scale = v[3] dora_scale = v[3]
blocks = cast_to_device(blocks, weight.device, intermediate_dtype) blocks = cast_to_device(blocks, weight.device, intermediate_dtype)
@ -75,7 +142,7 @@ class OFTAdapter(WeightAdapterBase):
# for Q = -Q^T # for Q = -Q^T
q = blocks - blocks.transpose(1, 2) q = blocks - blocks.transpose(1, 2)
normed_q = q normed_q = q
if alpha > 0: # alpha in oft/boft is for constraint if alpha > 0: # alpha in oft/boft is for constraint
q_norm = torch.norm(q) + 1e-8 q_norm = torch.norm(q) + 1e-8
if q_norm > alpha: if q_norm > alpha:
normed_q = q * alpha / q_norm normed_q = q * alpha / q_norm

View File

@ -1,11 +1,12 @@
from __future__ import annotations from __future__ import annotations
import av
import hashlib import hashlib
import io import io
import json import json
import os import os
import random import random
import av
import torch import torch
import comfy.model_management import comfy.model_management
@ -112,7 +113,6 @@ def save_audio(self, audio, filename_prefix="ComfyUI", format="flac", prompt=Non
except ImportError as exc_info: except ImportError as exc_info:
raise TorchAudioNotFoundError() raise TorchAudioNotFoundError()
filename_prefix += self.prefix_append filename_prefix += self.prefix_append
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir) full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
results: list[FileLocator] = [] results: list[FileLocator] = []
@ -178,13 +178,13 @@ def save_audio(self, audio, filename_prefix="ComfyUI", format="flac", prompt=Non
elif format == "mp3": elif format == "mp3":
out_stream = output_container.add_stream("libmp3lame", rate=sample_rate) out_stream = output_container.add_stream("libmp3lame", rate=sample_rate)
if quality == "V0": if quality == "V0":
#TODO i would really love to support V3 and V5 but there doesn't seem to be a way to set the qscale level, the property below is a bool # TODO i would really love to support V3 and V5 but there doesn't seem to be a way to set the qscale level, the property below is a bool
out_stream.codec_context.qscale = 1 out_stream.codec_context.qscale = 1
elif quality == "128k": elif quality == "128k":
out_stream.bit_rate = 128000 out_stream.bit_rate = 128000
elif quality == "320k": elif quality == "320k":
out_stream.bit_rate = 320000 out_stream.bit_rate = 320000
else: #format == "flac": else: # format == "flac":
out_stream = output_container.add_stream("flac", rate=sample_rate) out_stream = output_container.add_stream("flac", rate=sample_rate)
frame = av.AudioFrame.from_ndarray(waveform.movedim(0, 1).reshape(1, -1).float().numpy(), format='flt', layout='mono' if waveform.shape[0] == 1 else 'stereo') frame = av.AudioFrame.from_ndarray(waveform.movedim(0, 1).reshape(1, -1).float().numpy(), format='flt', layout='mono' if waveform.shape[0] == 1 else 'stereo')
@ -212,6 +212,7 @@ def save_audio(self, audio, filename_prefix="ComfyUI", format="flac", prompt=Non
return {"ui": {"audio": results}} return {"ui": {"audio": results}}
class SaveAudio: class SaveAudio:
def __init__(self): def __init__(self):
self.output_dir = folder_paths.get_output_directory() self.output_dir = folder_paths.get_output_directory()
@ -220,9 +221,9 @@ class SaveAudio:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required": { "audio": ("AUDIO", ), return {"required": {"audio": ("AUDIO",),
"filename_prefix": ("STRING", {"default": "audio/ComfyUI"}), "filename_prefix": ("STRING", {"default": "audio/ComfyUI"}),
}, },
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
} }
@ -236,6 +237,7 @@ class SaveAudio:
def save_flac(self, audio, filename_prefix="ComfyUI", format="flac", prompt=None, extra_pnginfo=None): def save_flac(self, audio, filename_prefix="ComfyUI", format="flac", prompt=None, extra_pnginfo=None):
return save_audio(self, audio, filename_prefix, format, prompt, extra_pnginfo) return save_audio(self, audio, filename_prefix, format, prompt, extra_pnginfo)
class SaveAudioMP3: class SaveAudioMP3:
def __init__(self): def __init__(self):
self.output_dir = folder_paths.get_output_directory() self.output_dir = folder_paths.get_output_directory()
@ -244,10 +246,10 @@ class SaveAudioMP3:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required": { "audio": ("AUDIO", ), return {"required": {"audio": ("AUDIO",),
"filename_prefix": ("STRING", {"default": "audio/ComfyUI"}), "filename_prefix": ("STRING", {"default": "audio/ComfyUI"}),
"quality": (["V0", "128k", "320k"], {"default": "V0"}), "quality": (["V0", "128k", "320k"], {"default": "V0"}),
}, },
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
} }
@ -261,6 +263,7 @@ class SaveAudioMP3:
def save_mp3(self, audio, filename_prefix="ComfyUI", format="mp3", prompt=None, extra_pnginfo=None, quality="128k"): def save_mp3(self, audio, filename_prefix="ComfyUI", format="mp3", prompt=None, extra_pnginfo=None, quality="128k"):
return save_audio(self, audio, filename_prefix, format, prompt, extra_pnginfo, quality) return save_audio(self, audio, filename_prefix, format, prompt, extra_pnginfo, quality)
class SaveAudioOpus: class SaveAudioOpus:
def __init__(self): def __init__(self):
self.output_dir = folder_paths.get_output_directory() self.output_dir = folder_paths.get_output_directory()
@ -269,10 +272,10 @@ class SaveAudioOpus:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required": { "audio": ("AUDIO", ), return {"required": {"audio": ("AUDIO",),
"filename_prefix": ("STRING", {"default": "audio/ComfyUI"}), "filename_prefix": ("STRING", {"default": "audio/ComfyUI"}),
"quality": (["64k", "96k", "128k", "192k", "320k"], {"default": "128k"}), "quality": (["64k", "96k", "128k", "192k", "320k"], {"default": "128k"}),
}, },
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
} }
@ -286,6 +289,7 @@ class SaveAudioOpus:
def save_opus(self, audio, filename_prefix="ComfyUI", format="opus", prompt=None, extra_pnginfo=None, quality="V3"): def save_opus(self, audio, filename_prefix="ComfyUI", format="opus", prompt=None, extra_pnginfo=None, quality="V3"):
return save_audio(self, audio, filename_prefix, format, prompt, extra_pnginfo, quality) return save_audio(self, audio, filename_prefix, format, prompt, extra_pnginfo, quality)
class PreviewAudio(SaveAudio): class PreviewAudio(SaveAudio):
def __init__(self): def __init__(self):
self.output_dir = folder_paths.get_temp_directory() self.output_dir = folder_paths.get_temp_directory()
@ -300,6 +304,44 @@ class PreviewAudio(SaveAudio):
} }
def f32_pcm(wav: torch.Tensor) -> torch.Tensor:
"""Convert audio to float 32 bits PCM format."""
if wav.dtype.is_floating_point:
return wav
elif wav.dtype == torch.int16:
return wav.float() / (2 ** 15)
elif wav.dtype == torch.int32:
return wav.float() / (2 ** 31)
raise ValueError(f"Unsupported wav dtype: {wav.dtype}")
def load(filepath: str) -> tuple[torch.Tensor, int]:
with av.open(filepath) as af:
if not af.streams.audio:
raise ValueError("No audio stream found in the file.")
stream = af.streams.audio[0]
sr = stream.codec_context.sample_rate
n_channels = stream.channels
frames = []
length = 0
for frame in af.decode(streams=stream.index):
buf = torch.from_numpy(frame.to_ndarray())
if buf.shape[0] != n_channels:
buf = buf.view(-1, n_channels).t()
frames.append(buf)
length += buf.shape[1]
if not frames:
raise ValueError("No audio frames decoded.")
wav = torch.cat(frames, dim=1)
wav = f32_pcm(wav)
return wav, sr
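The new loader decodes through PyAV and normalizes integer PCM to float32, so LoadAudio no longer depends on torchaudio for decoding. A minimal usage sketch, assuming an 'example.mp3' file on disk:

# illustrative only
waveform, sample_rate = load("example.mp3")          # the PyAV-based helper above
print(waveform.dtype, waveform.shape, sample_rate)   # e.g. torch.float32, (channels, samples), 44100
audio = {"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate}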
class LoadAudio: class LoadAudio:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
@@ -319,7 +361,7 @@ class LoadAudio:
raise TorchAudioNotFoundError()
audio_path = folder_paths.get_annotated_filepath(audio)
waveform, sample_rate = torchaudio.load(audio_path)
waveform, sample_rate = load(audio_path)
audio = {"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate}
return (audio,)

View File

@ -306,6 +306,35 @@ class ExtendIntermediateSigmas:
return (extended_sigmas,) return (extended_sigmas,)
class SamplingPercentToSigma:
@classmethod
def INPUT_TYPES(cls) -> InputTypeDict:
return {
"required": {
"model": (IO.MODEL, {}),
"sampling_percent": (IO.FLOAT, {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.0001}),
"return_actual_sigma": (IO.BOOLEAN, {"default": False, "tooltip": "Return the actual sigma value instead of the value used for interval checks.\nThis only affects results at 0.0 and 1.0."}),
}
}
RETURN_TYPES = (IO.FLOAT,)
RETURN_NAMES = ("sigma_value",)
CATEGORY = "sampling/custom_sampling/sigmas"
FUNCTION = "get_sigma"
def get_sigma(self, model, sampling_percent, return_actual_sigma):
model_sampling = model.get_model_object("model_sampling")
sigma_val = model_sampling.percent_to_sigma(sampling_percent)
if return_actual_sigma:
if sampling_percent == 0.0:
sigma_val = model_sampling.sigma_max.item()
elif sampling_percent == 1.0:
sigma_val = model_sampling.sigma_min.item()
return (sigma_val,)
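percent_to_sigma maps 0.0 to the very start of sampling and 1.0 to the end, and at those endpoints it returns sentinel values meant for interval checks, which is why the node can optionally substitute the model's true sigma_max/sigma_min. A hedged sketch of the underlying call, assuming a MODEL object from a loader:

# illustrative only
model_sampling = model.get_model_object("model_sampling")
print(model_sampling.percent_to_sigma(0.5))                              # sigma halfway through the schedule
print(model_sampling.sigma_max.item(), model_sampling.sigma_min.item())  # the "actual" endpoint sigmas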
class KSamplerSelect: class KSamplerSelect:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
@ -688,9 +717,10 @@ class CFGGuider:
return (guider,) return (guider,)
class Guider_DualCFG(comfy.samplers.CFGGuider):
def set_cfg(self, cfg1, cfg2):
def set_cfg(self, cfg1, cfg2, nested=False):
self.cfg1 = cfg1
self.cfg2 = cfg2
self.nested = nested
def set_conds(self, positive, middle, negative): def set_conds(self, positive, middle, negative):
middle = node_helpers.conditioning_set_values(middle, {"prompt_type": "negative"}) middle = node_helpers.conditioning_set_values(middle, {"prompt_type": "negative"})
@ -700,14 +730,20 @@ class Guider_DualCFG(comfy.samplers.CFGGuider):
negative_cond = self.conds.get("negative", None) negative_cond = self.conds.get("negative", None)
middle_cond = self.conds.get("middle", None) middle_cond = self.conds.get("middle", None)
positive_cond = self.conds.get("positive", None) positive_cond = self.conds.get("positive", None)
if model_options.get("disable_cfg1_optimization", False) == False:
if math.isclose(self.cfg2, 1.0):
negative_cond = None
if math.isclose(self.cfg1, 1.0):
middle_cond = None
out = comfy.samplers.calc_cond_batch(self.inner_model, [negative_cond, middle_cond, positive_cond], x, timestep, model_options)
return comfy.samplers.cfg_function(self.inner_model, out[1], out[0], self.cfg2, x, timestep, model_options=model_options, cond=middle_cond, uncond=negative_cond) + (out[2] - out[1]) * self.cfg1
if self.nested:
out = comfy.samplers.calc_cond_batch(self.inner_model, [negative_cond, middle_cond, positive_cond], x, timestep, model_options)
pred_text = comfy.samplers.cfg_function(self.inner_model, out[2], out[1], self.cfg1, x, timestep, model_options=model_options, cond=positive_cond, uncond=middle_cond)
return out[0] + self.cfg2 * (pred_text - out[0])
else:
if model_options.get("disable_cfg1_optimization", False) == False:
if math.isclose(self.cfg2, 1.0):
negative_cond = None
if math.isclose(self.cfg1, 1.0):
middle_cond = None
out = comfy.samplers.calc_cond_batch(self.inner_model, [negative_cond, middle_cond, positive_cond], x, timestep, model_options)
return comfy.samplers.cfg_function(self.inner_model, out[1], out[0], self.cfg2, x, timestep, model_options=model_options, cond=middle_cond, uncond=negative_cond) + (out[2] - out[1]) * self.cfg1
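In the nested style the positive prediction is first guided against the middle cond with cfg_conds, and that result is then guided against the negative cond with cfg_cond2_negative. Written out, assuming cfg_function reduces to the plain uncond + scale * (cond - uncond) combination (post-CFG hooks aside):

# conceptual sketch of the nested branch above, with neg/mid/pos = out[0]/out[1]/out[2]
pred_text = mid + cfg1 * (pos - mid)        # inner CFG of positive against middle
result = neg + cfg2 * (pred_text - neg)     # outer CFG of that result against negative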
class DualCFGGuider: class DualCFGGuider:
@classmethod @classmethod
@ -719,6 +755,7 @@ class DualCFGGuider:
"negative": ("CONDITIONING", ), "negative": ("CONDITIONING", ),
"cfg_conds": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}), "cfg_conds": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}),
"cfg_cond2_negative": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}), "cfg_cond2_negative": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}),
"style": (["regular", "nested"],),
} }
} }
@ -727,10 +764,10 @@ class DualCFGGuider:
FUNCTION = "get_guider" FUNCTION = "get_guider"
CATEGORY = "sampling/custom_sampling/guiders" CATEGORY = "sampling/custom_sampling/guiders"
def get_guider(self, model, cond1, cond2, negative, cfg_conds, cfg_cond2_negative):
def get_guider(self, model, cond1, cond2, negative, cfg_conds, cfg_cond2_negative, style):
guider = Guider_DualCFG(model)
guider.set_conds(cond1, cond2, negative)
guider.set_cfg(cfg_conds, cfg_cond2_negative)
guider.set_cfg(cfg_conds, cfg_cond2_negative, nested=(style == "nested"))
return (guider,) return (guider,)
class DisableNoise: class DisableNoise:
@ -878,6 +915,7 @@ NODE_CLASS_MAPPINGS = {
"FlipSigmas": FlipSigmas, "FlipSigmas": FlipSigmas,
"SetFirstSigma": SetFirstSigma, "SetFirstSigma": SetFirstSigma,
"ExtendIntermediateSigmas": ExtendIntermediateSigmas, "ExtendIntermediateSigmas": ExtendIntermediateSigmas,
"SamplingPercentToSigma": SamplingPercentToSigma,
"CFGGuider": CFGGuider, "CFGGuider": CFGGuider,
"DualCFGGuider": DualCFGGuider, "DualCFGGuider": DualCFGGuider,

View File

@ -20,7 +20,7 @@ from comfy import node_helpers
from comfy.cli_args import args from comfy.cli_args import args
from comfy.comfy_types.node_typing import IO from comfy.comfy_types.node_typing import IO
from comfy.execution_context import current_execution_context from comfy.execution_context import current_execution_context
from comfy.weight_adapter import adapters
from comfy.weight_adapter import adapters, adapter_maps
from . import nodes_custom_sampler from . import nodes_custom_sampler
from .nodes_custom_sampler import Noise_RandomNoise from .nodes_custom_sampler import Noise_RandomNoise
@ -41,13 +41,13 @@ def make_batch_extra_option_dict(d, indicies, full_size=None):
class TrainSampler(comfy.samplers.Sampler): class TrainSampler(comfy.samplers.Sampler):
def __init__(self, loss_fn, optimizer, loss_callback=None, batch_size=1, grad_acc=1, total_steps=1, seed=0, training_dtype=torch.bfloat16):
def __init__(self, loss_fn, optimizer, loss_callback=None, batch_size=1, total_steps=1, seed=0, training_dtype=torch.bfloat16):
self.loss_fn = loss_fn self.loss_fn = loss_fn
self.optimizer = optimizer self.optimizer = optimizer
self.loss_callback = loss_callback self.loss_callback = loss_callback
self.batch_size = batch_size self.batch_size = batch_size
self.total_steps = total_steps self.total_steps = total_steps
self.grad_acc = grad_acc
self.seed = seed self.seed = seed
self.training_dtype = training_dtype self.training_dtype = training_dtype
@ -94,8 +94,9 @@ class TrainSampler(comfy.samplers.Sampler):
self.loss_callback(loss.item()) self.loss_callback(loss.item())
pbar.set_postfix({"loss": f"{loss.item():.4f}"}) pbar.set_postfix({"loss": f"{loss.item():.4f}"})
self.optimizer.step()
self.optimizer.zero_grad()
if (i+1) % self.grad_acc == 0:
self.optimizer.step()
self.optimizer.zero_grad()
torch.cuda.empty_cache() torch.cuda.empty_cache()
return torch.zeros_like(latent_image) return torch.zeros_like(latent_image)
@ -421,6 +422,16 @@ class TrainLoraNode:
"tooltip": "The batch size to use for training.", "tooltip": "The batch size to use for training.",
}, },
), ),
"grad_accumulation_steps": (
IO.INT,
{
"default": 1,
"min": 1,
"max": 1024,
"step": 1,
"tooltip": "The number of gradient accumulation steps to use for training.",
}
),
"steps": ( "steps": (
IO.INT, IO.INT,
{ {
@ -480,6 +491,17 @@ class TrainLoraNode:
["bf16", "fp32"], ["bf16", "fp32"],
{"default": "bf16", "tooltip": "The dtype to use for lora."}, {"default": "bf16", "tooltip": "The dtype to use for lora."},
), ),
"algorithm": (
list(adapter_maps.keys()),
{"default": list(adapter_maps.keys())[0], "tooltip": "The algorithm to use for training."},
),
"gradient_checkpointing": (
IO.BOOLEAN,
{
"default": True,
"tooltip": "Use gradient checkpointing for training.",
}
),
"existing_lora": ( "existing_lora": (
folder_paths.get_filename_list("loras") + ["[None]"], folder_paths.get_filename_list("loras") + ["[None]"],
{ {
@ -503,6 +525,7 @@ class TrainLoraNode:
positive, positive,
batch_size, batch_size,
steps, steps,
grad_accumulation_steps,
learning_rate, learning_rate,
rank, rank,
optimizer, optimizer,
@ -510,6 +533,8 @@ class TrainLoraNode:
seed, seed,
training_dtype, training_dtype,
lora_dtype, lora_dtype,
algorithm,
gradient_checkpointing,
existing_lora, existing_lora,
): ):
mp = model.clone() mp = model.clone()
@ -560,10 +585,8 @@ class TrainLoraNode:
if existing_adapter is not None: if existing_adapter is not None:
break break
else: else:
# If no existing adapter found, use LoRA
# We will add algo option in the future
existing_adapter = None existing_adapter = None
adapter_cls = adapters[0] adapter_cls = adapter_maps[algorithm]
if existing_adapter is not None: if existing_adapter is not None:
train_adapter = existing_adapter.to_train().to(lora_dtype) train_adapter = existing_adapter.to_train().to(lora_dtype)
@ -619,8 +642,9 @@ class TrainLoraNode:
criterion = None criterion = None
# setup models # setup models
for m in find_all_highest_child_module_with_forward(mp.model.diffusion_model): if gradient_checkpointing:
patch(m) for m in find_all_highest_child_module_with_forward(mp.model.diffusion_model):
patch(m)
mp.model.requires_grad_(False) mp.model.requires_grad_(False)
comfy.model_management.load_models_gpu([mp], memory_required=1e20, force_full_load=True) comfy.model_management.load_models_gpu([mp], memory_required=1e20, force_full_load=True)
@ -633,7 +657,8 @@ class TrainLoraNode:
optimizer, optimizer,
loss_callback=loss_callback, loss_callback=loss_callback,
batch_size=batch_size, batch_size=batch_size,
total_steps=steps, grad_acc=grad_accumulation_steps,
total_steps=steps*grad_accumulation_steps,
seed=seed, seed=seed,
training_dtype=dtype training_dtype=dtype
) )
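Worked example of the step accounting above: with steps = 100 and grad_accumulation_steps = 4 the sampler receives total_steps = 100 * 4 = 400, so it runs 400 forward/backward micro-batches while the (i + 1) % grad_acc guard limits the optimizer to 100 actual weight updates.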
View File
@ -1,26 +1,33 @@
from comfy.nodes import base_nodes as nodes import json
from comfy import node_helpers import math
from typing import Tuple
import numpy as np
import torch import torch
import comfy.clip_vision
import comfy.latent_formats
import comfy.model_management import comfy.model_management
import comfy.utils import comfy.utils
import comfy.latent_formats import nodes
import comfy.clip_vision from comfy import node_helpers
from comfy.nodes import base_nodes as nodes
class WanImageToVideo: class WanImageToVideo:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required": {"positive": ("CONDITIONING", ), return {"required": {"positive": ("CONDITIONING",),
"negative": ("CONDITIONING", ), "negative": ("CONDITIONING",),
"vae": ("VAE", ), "vae": ("VAE",),
"width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), "width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}), "length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
}, },
"optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT", ), "optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT",),
"start_image": ("IMAGE", ), "start_image": ("IMAGE",),
}} }}
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "negative", "latent") RETURN_NAMES = ("positive", "negative", "latent")
@ -54,18 +61,18 @@ class WanImageToVideo:
class WanFunControlToVideo: class WanFunControlToVideo:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required": {"positive": ("CONDITIONING", ), return {"required": {"positive": ("CONDITIONING",),
"negative": ("CONDITIONING", ), "negative": ("CONDITIONING",),
"vae": ("VAE", ), "vae": ("VAE",),
"width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), "width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}), "length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
}, },
"optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT", ), "optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT",),
"start_image": ("IMAGE", ), "start_image": ("IMAGE",),
"control_video": ("IMAGE", ), "control_video": ("IMAGE",),
}} }}
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "negative", "latent") RETURN_NAMES = ("positive", "negative", "latent")
@ -82,12 +89,12 @@ class WanFunControlToVideo:
if start_image is not None: if start_image is not None:
start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
concat_latent_image = vae.encode(start_image[:, :, :, :3]) concat_latent_image = vae.encode(start_image[:, :, :, :3])
concat_latent[:,16:,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]] concat_latent[:, 16:, :concat_latent_image.shape[2]] = concat_latent_image[:, :, :concat_latent.shape[2]]
if control_video is not None: if control_video is not None:
control_video = comfy.utils.common_upscale(control_video[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) control_video = comfy.utils.common_upscale(control_video[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
concat_latent_image = vae.encode(control_video[:, :, :, :3]) concat_latent_image = vae.encode(control_video[:, :, :, :3])
concat_latent[:,:16,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]] concat_latent[:, :16, :concat_latent_image.shape[2]] = concat_latent_image[:, :, :concat_latent.shape[2]]
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent}) positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent})
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent}) negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent})
@ -100,22 +107,23 @@ class WanFunControlToVideo:
out_latent["samples"] = latent out_latent["samples"] = latent
return (positive, negative, out_latent) return (positive, negative, out_latent)
class WanFirstLastFrameToVideo: class WanFirstLastFrameToVideo:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required": {"positive": ("CONDITIONING", ), return {"required": {"positive": ("CONDITIONING",),
"negative": ("CONDITIONING", ), "negative": ("CONDITIONING",),
"vae": ("VAE", ), "vae": ("VAE",),
"width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), "width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}), "length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
}, },
"optional": {"clip_vision_start_image": ("CLIP_VISION_OUTPUT", ), "optional": {"clip_vision_start_image": ("CLIP_VISION_OUTPUT",),
"clip_vision_end_image": ("CLIP_VISION_OUTPUT", ), "clip_vision_end_image": ("CLIP_VISION_OUTPUT",),
"start_image": ("IMAGE", ), "start_image": ("IMAGE",),
"end_image": ("IMAGE", ), "end_image": ("IMAGE",),
}} }}
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "negative", "latent") RETURN_NAMES = ("positive", "negative", "latent")
@ -169,18 +177,18 @@ class WanFirstLastFrameToVideo:
class WanFunInpaintToVideo: class WanFunInpaintToVideo:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required": {"positive": ("CONDITIONING", ), return {"required": {"positive": ("CONDITIONING",),
"negative": ("CONDITIONING", ), "negative": ("CONDITIONING",),
"vae": ("VAE", ), "vae": ("VAE",),
"width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), "width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}), "length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
}, },
"optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT", ), "optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT",),
"start_image": ("IMAGE", ), "start_image": ("IMAGE",),
"end_image": ("IMAGE", ), "end_image": ("IMAGE",),
}} }}
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "negative", "latent") RETURN_NAMES = ("positive", "negative", "latent")
@ -196,19 +204,19 @@ class WanFunInpaintToVideo:
class WanVaceToVideo: class WanVaceToVideo:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required": {"positive": ("CONDITIONING", ), return {"required": {"positive": ("CONDITIONING",),
"negative": ("CONDITIONING", ), "negative": ("CONDITIONING",),
"vae": ("VAE", ), "vae": ("VAE",),
"width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), "width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}), "length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1000.0, "step": 0.01}), "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1000.0, "step": 0.01}),
}, },
"optional": {"control_video": ("IMAGE", ), "optional": {"control_video": ("IMAGE",),
"control_masks": ("MASK", ), "control_masks": ("MASK",),
"reference_image": ("IMAGE", ), "reference_image": ("IMAGE",),
}} }}
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT", "INT") RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT", "INT")
RETURN_NAMES = ("positive", "negative", "latent", "trim_latent") RETURN_NAMES = ("positive", "negative", "latent", "trim_latent")
@ -277,11 +285,12 @@ class WanVaceToVideo:
out_latent["samples"] = latent out_latent["samples"] = latent
return (positive, negative, out_latent, trim_latent) return (positive, negative, out_latent, trim_latent)
class TrimVideoLatent: class TrimVideoLatent:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required": { "samples": ("LATENT",), return {"required": {"samples": ("LATENT",),
"trim_amount": ("INT", {"default": 0, "min": 0, "max": 99999}), "trim_amount": ("INT", {"default": 0, "min": 0, "max": 99999}),
}} }}
RETURN_TYPES = ("LATENT",) RETURN_TYPES = ("LATENT",)
@ -298,21 +307,22 @@ class TrimVideoLatent:
samples_out["samples"] = s1[:, :, trim_amount:] samples_out["samples"] = s1[:, :, trim_amount:]
return (samples_out,) return (samples_out,)
class WanCameraImageToVideo: class WanCameraImageToVideo:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required": {"positive": ("CONDITIONING", ), return {"required": {"positive": ("CONDITIONING",),
"negative": ("CONDITIONING", ), "negative": ("CONDITIONING",),
"vae": ("VAE", ), "vae": ("VAE",),
"width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), "width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}), "length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
}, },
"optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT", ), "optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT",),
"start_image": ("IMAGE", ), "start_image": ("IMAGE",),
"camera_conditions": ("WAN_CAMERA_EMBEDDING", ), "camera_conditions": ("WAN_CAMERA_EMBEDDING",),
}} }}
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "negative", "latent") RETURN_NAMES = ("positive", "negative", "latent")
@ -328,7 +338,7 @@ class WanCameraImageToVideo:
if start_image is not None: if start_image is not None:
start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
concat_latent_image = vae.encode(start_image[:, :, :, :3]) concat_latent_image = vae.encode(start_image[:, :, :, :3])
concat_latent[:,:,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]] concat_latent[:, :, :concat_latent_image.shape[2]] = concat_latent_image[:, :, :concat_latent.shape[2]]
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent}) positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent})
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent}) negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent})
@ -345,19 +355,20 @@ class WanCameraImageToVideo:
out_latent["samples"] = latent out_latent["samples"] = latent
return (positive, negative, out_latent) return (positive, negative, out_latent)
class WanPhantomSubjectToVideo: class WanPhantomSubjectToVideo:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required": {"positive": ("CONDITIONING", ), return {"required": {"positive": ("CONDITIONING",),
"negative": ("CONDITIONING", ), "negative": ("CONDITIONING",),
"vae": ("VAE", ), "vae": ("VAE",),
"width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), "width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}), "length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
}, },
"optional": {"images": ("IMAGE", ), "optional": {"images": ("IMAGE",),
}} }}
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "CONDITIONING", "LATENT") RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "negative_text", "negative_img_text", "latent") RETURN_NAMES = ("positive", "negative_text", "negative_img_text", "latent")
@ -383,7 +394,316 @@ class WanPhantomSubjectToVideo:
out_latent["samples"] = latent out_latent["samples"] = latent
return (positive, cond2, negative, out_latent) return (positive, cond2, negative, out_latent)
def parse_json_tracks(tracks):
"""Parse JSON track data into a standardized format"""
tracks_data = []
try:
# If tracks is a string, try to parse it as JSON
if isinstance(tracks, str):
parsed = json.loads(tracks.replace("'", '"'))
tracks_data.extend(parsed)
else:
# If tracks is a list of strings, parse each one
for track_str in tracks:
parsed = json.loads(track_str.replace("'", '"'))
tracks_data.append(parsed)
# Check if we have a single track (dict with x,y) or a list of tracks
if tracks_data and isinstance(tracks_data[0], dict) and 'x' in tracks_data[0]:
# Single track detected, wrap it in a list
tracks_data = [tracks_data]
elif tracks_data and isinstance(tracks_data[0], list) and tracks_data[0] and isinstance(tracks_data[0][0], dict) and 'x' in tracks_data[0][0]:
# Already a list of tracks, nothing to do
pass
else:
# Unexpected format
pass
except json.JSONDecodeError:
tracks_data = []
return tracks_data
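For reference, the two input shapes the parser accepts, written with the single-quote JSON that the replace("'", '"') call tolerates (coordinates are made up; parse_json_tracks is the function above):

single = "[{'x': 100, 'y': 200}, {'x': 110, 'y': 205}]"          # one track given directly
assert parse_json_tracks(single) == [[{'x': 100, 'y': 200}, {'x': 110, 'y': 205}]]

multi = "[[{'x': 100, 'y': 200}], [{'x': 300, 'y': 400}]]"        # already a list of tracks
assert parse_json_tracks(multi) == [[{'x': 100, 'y': 200}], [{'x': 300, 'y': 400}]]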
def process_tracks(tracks_np: np.ndarray, frame_size: Tuple[int, int], num_frames, quant_multi: int = 8, **kwargs):
# tracks: shape [t, h, w, 3] => samples align with 24 fps, model trained with 16 fps.
# frame_size: tuple (W, H)
tracks = torch.from_numpy(tracks_np).float()
if tracks.shape[1] == 121:
tracks = torch.permute(tracks, (1, 0, 2, 3))
tracks, visibles = tracks[..., :2], tracks[..., 2:3]
short_edge = min(*frame_size)
frame_center = torch.tensor([*frame_size]).type_as(tracks) / 2
tracks = tracks - frame_center
tracks = tracks / short_edge * 2
visibles = visibles * 2 - 1
trange = torch.linspace(-1, 1, tracks.shape[0]).view(-1, 1, 1, 1).expand(*visibles.shape)
out_ = torch.cat([trange, tracks, visibles], dim=-1).view(121, -1, 4)
out_0 = out_[:1]
out_l = out_[1:]  # remaining frames after the first: 121 => 1 + 120
a = 120 // math.gcd(120, num_frames)
b = num_frames // math.gcd(120, num_frames)
out_l = torch.repeat_interleave(out_l, b, dim=0)[1::a] # 120 => 120 * b => 120 * b / a == F
final_result = torch.cat([out_0, out_l], dim=0)
return final_result
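The tail of the track (120 samples after the first frame) is resampled down to num_frames with the gcd trick above. A worked check for the default 81-frame setting, where the node calls this with num_frames = length - 1 = 80:

import math
num_frames = 80
g = math.gcd(120, num_frames)                 # 40
a, b = 120 // g, num_frames // g              # 3, 2
# repeat_interleave by b: 120 -> 240 samples; slicing [1::a] keeps exactly num_frames of them
assert len(range(120 * b)[1::a]) == num_frames
# prepending out_0 then gives 1 + 80 = 81 frames, matching the requested length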
FIXED_LENGTH = 121
def pad_pts(tr):
"""Convert list of {x,y} to (FIXED_LENGTH,1,3) array, padding/truncating."""
pts = np.array([[p['x'], p['y'], 1] for p in tr], dtype=np.float32)
n = pts.shape[0]
if n < FIXED_LENGTH:
pad = np.zeros((FIXED_LENGTH - n, 3), dtype=np.float32)
pts = np.vstack((pts, pad))
else:
pts = pts[:FIXED_LENGTH]
return pts.reshape(FIXED_LENGTH, 1, 3)
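A quick sanity check of the padding helper: the third column is the validity/visibility flag that process_tracks later splits off as visibles (the coordinates here are arbitrary):

pts = pad_pts([{'x': 10, 'y': 20}, {'x': 12, 'y': 21}, {'x': 14, 'y': 22}])
assert pts.shape == (121, 1, 3)
assert pts[:3, 0, 2].tolist() == [1.0, 1.0, 1.0]   # real points carry a 1 flag
assert pts[3:, 0, 2].sum() == 0.0                  # zero padding beyond the real points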
def ind_sel(target: torch.Tensor, ind: torch.Tensor, dim: int = 1):
"""Index selection utility function"""
assert (
len(ind.shape) > dim
), "Index must have the target dim, but get dim: %d, ind shape: %s" % (dim, str(ind.shape))
target = target.expand(
*tuple(
[ind.shape[k] if target.shape[k] == 1 else -1 for k in range(dim)]
+ [
-1,
]
* (len(target.shape) - dim)
)
)
ind_pad = ind
if len(target.shape) > dim + 1:
for _ in range(len(target.shape) - (dim + 1)):
ind_pad = ind_pad.unsqueeze(-1)
ind_pad = ind_pad.expand(*(-1,) * (dim + 1), *target.shape[(dim + 1)::])
return torch.gather(target, dim=dim, index=ind_pad)
def merge_final(vert_attr: torch.Tensor, weight: torch.Tensor, vert_assign: torch.Tensor):
"""Merge vertex attributes with weights"""
target_dim = len(vert_assign.shape) - 1
if len(vert_attr.shape) == 2:
assert vert_attr.shape[0] > vert_assign.max()
new_shape = [1] * target_dim + list(vert_attr.shape)
tensor = vert_attr.reshape(new_shape)
sel_attr = ind_sel(tensor, vert_assign.type(torch.long), dim=target_dim)
else:
assert vert_attr.shape[1] > vert_assign.max()
new_shape = [vert_attr.shape[0]] + [1] * (target_dim - 1) + list(vert_attr.shape[1:])
tensor = vert_attr.reshape(new_shape)
sel_attr = ind_sel(tensor, vert_assign.type(torch.long), dim=target_dim)
final_attr = torch.sum(sel_attr * weight.unsqueeze(-1), dim=-2)
return final_attr
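Shape contract, as exercised by _patch_motion_single below: vert_attr is the per-track feature table, vert_assign the top-k track indices per pixel, weight their soft weights, and the result is the weighted sum of the selected features. A small sketch under those assumptions (relies on the ind_sel/merge_final definitions above; sizes are arbitrary):

import torch
N, C, K = 8, 16, 2                         # tracks, channels, top-k
T, H, W = 20, 60, 104                      # latent frames and spatial size
point_feature = torch.randn(N, C)          # one feature vector per track
vert_weight = torch.rand(T, H, W, K)       # soft weights of the K closest tracks
vert_index = torch.randint(0, N, (T, H, W, K))
out = merge_final(point_feature, vert_weight, vert_index)
assert out.shape == (T, H, W, C)           # weighted blend per pixel and frame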
def _patch_motion_single(
tracks: torch.FloatTensor, # (B, T, N, 4)
vid: torch.FloatTensor, # (C, T, H, W)
temperature: float,
vae_divide: tuple,
topk: int,
):
"""Apply motion patching based on tracks"""
_, T, H, W = vid.shape
N = tracks.shape[2]
_, tracks_xy, visible = torch.split(
tracks, [1, 2, 1], dim=-1
) # (B, T, N, 2) | (B, T, N, 1)
tracks_n = tracks_xy / torch.tensor([W / min(H, W), H / min(H, W)], device=tracks_xy.device)
tracks_n = tracks_n.clamp(-1, 1)
visible = visible.clamp(0, 1)
xx = torch.linspace(-W / min(H, W), W / min(H, W), W)
yy = torch.linspace(-H / min(H, W), H / min(H, W), H)
grid = torch.stack(torch.meshgrid(yy, xx, indexing="ij")[::-1], dim=-1).to(
tracks_xy.device
)
tracks_pad = tracks_xy[:, 1:]
visible_pad = visible[:, 1:]
visible_align = visible_pad.view(T - 1, 4, *visible_pad.shape[2:]).sum(1)
tracks_align = (tracks_pad * visible_pad).view(T - 1, 4, *tracks_pad.shape[2:]).sum(
1
) / (visible_align + 1e-5)
dist_ = (
(tracks_align[:, None, None] - grid[None, :, :, None]).pow(2).sum(-1)
) # T, H, W, N
weight = torch.exp(-dist_ * temperature) * visible_align.clamp(0, 1).view(
T - 1, 1, 1, N
)
vert_weight, vert_index = torch.topk(
weight, k=min(topk, weight.shape[-1]), dim=-1
)
grid_mode = "bilinear"
point_feature = torch.nn.functional.grid_sample(
vid.permute(1, 0, 2, 3)[:1],
tracks_n[:, :1].type(vid.dtype),
mode=grid_mode,
padding_mode="zeros",
align_corners=False,
)
point_feature = point_feature.squeeze(0).squeeze(1).permute(1, 0) # N, C=16
out_feature = merge_final(point_feature, vert_weight, vert_index).permute(3, 0, 1, 2) # T - 1, H, W, C => C, T - 1, H, W
out_weight = vert_weight.sum(-1) # T - 1, H, W
# out feature -> already soft weighted
mix_feature = out_feature + vid[:, 1:] * (1 - out_weight.clamp(0, 1))
out_feature_full = torch.cat([vid[:, :1], mix_feature], dim=1) # C, T, H, W
out_mask_full = torch.cat([torch.ones_like(out_weight[:1]), out_weight], dim=0) # T, H, W
return out_mask_full[None].expand(vae_divide[0], -1, -1, -1), out_feature_full
def patch_motion(
tracks: torch.FloatTensor, # batch of track tensors, one (1, T_track, N, 4) per latent
vid: torch.FloatTensor, # (B, C, T, H, W)
temperature: float = 220.0,
vae_divide: tuple = (4, 16),
topk: int = 2,
):
B = len(tracks)
# Process each batch separately
out_masks = []
out_features = []
for b in range(B):
mask, feature = _patch_motion_single(
tracks[b], # (T, N, 4)
vid[b], # (C, T, H, W)
temperature,
vae_divide,
topk
)
out_masks.append(mask)
out_features.append(feature)
# Stack results: (B, C, T, H, W)
out_mask_full = torch.stack(out_masks, dim=0)
out_feature_full = torch.stack(out_features, dim=0)
return out_mask_full, out_feature_full
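For orientation, the shapes this wrapper passes around under the node's default 81-frame settings (derived from the code above, so treat them as a reading aid rather than a spec):

# tracks[b]: (1, 4 * (T - 1) + 1, N, 4), e.g. (1, 81, N, 4) track samples per latent
# vid[b]:    (C, T, H, W),               e.g. (16, 21, 60, 104) for an 832x480, 81-frame video
# out_mask_full:    (B, vae_divide[0], T, H, W) -> becomes the concat mask
# out_feature_full: (B, C, T, H, W)             -> becomes the concat_latent_image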
class WanTrackToVideo:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"positive": ("CONDITIONING",),
"negative": ("CONDITIONING",),
"vae": ("VAE",),
"tracks": ("STRING", {"multiline": True, "default": "[]"}),
"width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
"temperature": ("FLOAT", {"default": 220.0, "min": 1.0, "max": 1000.0, "step": 0.1}),
"topk": ("INT", {"default": 2, "min": 1, "max": 10}),
"start_image": ("IMAGE",),
},
"optional": {
"clip_vision_output": ("CLIP_VISION_OUTPUT",),
}}
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "negative", "latent")
FUNCTION = "encode"
CATEGORY = "conditioning/video_models"
def encode(self, positive, negative, vae, tracks, width, height, length, batch_size,
temperature, topk, start_image=None, clip_vision_output=None):
tracks_data = parse_json_tracks(tracks)
if not tracks_data:
return WanImageToVideo().encode(positive, negative, vae, width, height, length, batch_size, start_image=start_image, clip_vision_output=clip_vision_output)
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8],
device=comfy.model_management.intermediate_device())
if isinstance(tracks_data[0][0], dict):
tracks_data = [tracks_data]
processed_tracks = []
for batch in tracks_data:
arrs = []
for track in batch:
pts = pad_pts(track)
arrs.append(pts)
tracks_np = np.stack(arrs, axis=0)
processed_tracks.append(process_tracks(tracks_np, (width, height), length - 1).unsqueeze(0))
if start_image is not None:
start_image = comfy.utils.common_upscale(start_image[:batch_size].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
videos = torch.ones((start_image.shape[0], length, height, width, start_image.shape[-1]), device=start_image.device, dtype=start_image.dtype) * 0.5
for i in range(start_image.shape[0]):
videos[i, 0] = start_image[i]
latent_videos = []
videos = comfy.utils.resize_to_batch_size(videos, batch_size)
for i in range(batch_size):
latent_videos += [vae.encode(videos[i, :, :, :, :3])]
y = torch.cat(latent_videos, dim=0)
# Scale latent since patch_motion is non-linear
y = comfy.latent_formats.Wan21().process_in(y)
processed_tracks = comfy.utils.resize_list_to_batch_size(processed_tracks, batch_size)
res = patch_motion(
processed_tracks, y, temperature=temperature, topk=topk, vae_divide=(4, 16)
)
mask, concat_latent_image = res
concat_latent_image = comfy.latent_formats.Wan21().process_out(concat_latent_image)
mask = -mask + 1.0 # Invert mask to match expected format
positive = node_helpers.conditioning_set_values(positive,
{"concat_mask": mask,
"concat_latent_image": concat_latent_image})
negative = node_helpers.conditioning_set_values(negative,
{"concat_mask": mask,
"concat_latent_image": concat_latent_image})
if clip_vision_output is not None:
positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})
out_latent = {}
out_latent["samples"] = latent
return (positive, negative, out_latent)
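A minimal sketch of driving the new node directly from Python, assuming positive/negative conditioning, a Wan vae and a start_image tensor were produced elsewhere (the names and coordinates are placeholders, not part of this diff; an empty "[]" tracks string falls back to WanImageToVideo as shown above):

tracks_json = "[{'x': 128, 'y': 256}, {'x': 132, 'y': 250}]"   # one short track
positive_out, negative_out, latent = WanTrackToVideo().encode(
    positive, negative, vae, tracks_json,
    width=832, height=480, length=81, batch_size=1,
    temperature=220.0, topk=2, start_image=start_image,
)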
NODE_CLASS_MAPPINGS = { NODE_CLASS_MAPPINGS = {
"WanTrackToVideo": WanTrackToVideo,
"WanImageToVideo": WanImageToVideo, "WanImageToVideo": WanImageToVideo,
"WanFunControlToVideo": WanFunControlToVideo, "WanFunControlToVideo": WanFunControlToVideo,
"WanFunInpaintToVideo": WanFunInpaintToVideo, "WanFunInpaintToVideo": WanFunInpaintToVideo,
View File
@ -1,6 +1,6 @@
[project] [project]
name = "comfyui" name = "comfyui"
version = "0.3.44" version = "0.3.45"
description = "An installable version of ComfyUI" description = "An installable version of ComfyUI"
readme = "README.md" readme = "README.md"
authors = [ authors = [
@ -18,9 +18,9 @@ classifiers = [
] ]
dependencies = [ dependencies = [
"comfyui-frontend-package", "comfyui-frontend-package>=1.23.4",
"comfyui-workflow-templates", "comfyui-workflow-templates>=0.1.40",
"comfyui-embedded-docs", "comfyui-embedded-docs>=0.2.4",
"torch", "torch",
"torchvision", "torchvision",
"torchdiffeq>=0.2.3", "torchdiffeq>=0.2.3",
View File
@ -2,7 +2,7 @@ import argparse
import pytest import pytest
from requests.exceptions import HTTPError from requests.exceptions import HTTPError
from unittest.mock import patch from unittest.mock import patch, mock_open
from comfy.app.frontend_management import ( from comfy.app.frontend_management import (
FrontendManager, FrontendManager,
@ -171,3 +171,36 @@ def test_init_frontend_fallback_on_error():
# Assert # Assert
assert frontend_path == "/default/path" assert frontend_path == "/default/path"
mock_check.assert_called_once() mock_check.assert_called_once()
def test_get_frontend_version():
# Arrange
expected_version = "1.25.0"
mock_requirements_content = """torch
torchsde
comfyui-frontend-package==1.25.0
other-package==1.0.0
numpy"""
# Act
with patch("builtins.open", mock_open(read_data=mock_requirements_content)):
version = FrontendManager.get_required_frontend_version()
# Assert
assert version == expected_version
def test_get_frontend_version_invalid_semver():
# Arrange
mock_requirements_content = """torch
torchsde
comfyui-frontend-package==1.29.3.75
other-package==1.0.0
numpy"""
# Act
with patch("builtins.open", mock_open(read_data=mock_requirements_content)):
version = FrontendManager.get_required_frontend_version()
# Assert
assert version is None
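The tests pin the expected behaviour without showing the parser itself; a rough stand-in that satisfies both cases (an illustrative assumption, not the code in frontend_management.py):

import re

def required_frontend_version(requirements_text: str):
    # accept only a strict MAJOR.MINOR.PATCH pin of comfyui-frontend-package
    m = re.search(r"^comfyui-frontend-package==(.+)$", requirements_text, re.MULTILINE)
    if m and re.fullmatch(r"\d+\.\d+\.\d+", m.group(1).strip()):
        return m.group(1).strip()
    return None

assert required_frontend_version("comfyui-frontend-package==1.25.0\nnumpy") == "1.25.0"
assert required_frontend_version("comfyui-frontend-package==1.29.3.75\nnumpy") is None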