Merge branch 'master' of github.com:comfyanonymous/ComfyUI

This commit is contained in:
doctorpangloss 2025-07-25 12:48:05 -07:00
commit 3684cff31b
25 changed files with 1070 additions and 182 deletions

View File

@ -17,6 +17,7 @@ jobs:
- name: Check for Windows line endings (CRLF)
run: |
# Get the list of changed files in the PR
git merge origin/${{ github.base_ref }} --no-edit
CHANGED_FILES=$(git diff --name-only origin/${{ github.base_ref }}..HEAD)
# Flag to track if CRLF is found

View File

@ -42,6 +42,7 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
- Image Editing Models
- [Omnigen 2](https://comfyanonymous.github.io/ComfyUI_examples/omnigen/)
- [Flux Kontext](https://comfyanonymous.github.io/ComfyUI_examples/flux/#flux-kontext-image-editing-model)
- [HiDream E1.1](https://comfyanonymous.github.io/ComfyUI_examples/hidream/#hidream-e11)
- Video Models
- [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/)
- [Mochi](https://comfyanonymous.github.io/ComfyUI_examples/mochi/)
@ -1574,6 +1575,14 @@ If you need to use the legacy frontend for any reason, you can access it using t
This will use a snapshot of the legacy frontend preserved in the [ComfyUI Legacy Frontend repository](https://github.com/Comfy-Org/ComfyUI_legacy_frontend).
### Iluvatar Corex
For models compatible with Iluvatar Extension for PyTorch. Here's a step-by-step guide tailored to your platform and installation method:
1. Install the Iluvatar Corex Toolkit by adhering to the platform-specific instructions on the [Installation](https://support.iluvatar.com/#/DocumentCentre?id=1&nameCenter=2&productId=520117912052801536)
2. Launch ComfyUI by running `python main.py`
## Community
[Discord](https://comfy.org/discord): Try the #help or #feedback channels.

View File

@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
__version__ = "0.3.44"
__version__ = "0.3.45"

View File

@ -7,6 +7,7 @@ import os
import re
import tempfile
import zipfile
import importlib.metadata
from dataclasses import dataclass
from functools import cached_property
from pathlib import Path
@ -22,10 +23,12 @@ REQUEST_TIMEOUT = 10 # seconds
def check_frontend_version():
"""the thing this does makes no sense, so it got cut"""
return None
def frontend_install_warning_message() -> str:
"""the end user never needs to be messaged this"""
return ""
@ -135,6 +138,15 @@ def download_release_asset_zip(release: Release, destination_path: str) -> None:
class FrontendManager:
CUSTOM_FRONTENDS_ROOT = add_model_folder_path("web_custom_versions", extensions=set())
@classmethod
def get_required_frontend_version(cls) -> str:
"""Get the required frontend package version."""
try:
# this isn't used the way it says
return importlib.metadata.version("comfyui_frontend_package")
except Exception as exc_info:
return "1.23.4"
@classmethod
def default_frontend_path(cls) -> str:
try:

View File

@ -44,7 +44,8 @@ def _create_parser() -> EnhancedConfigArgParser:
help="Automatically launch ComfyUI in the default browser.")
parser.add_argument("--disable-auto-launch", action="store_true", help="Disable auto launching the browser.")
parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID",
help="Set the id of the cuda device this instance will use.")
help="Set the id of the cuda device this instance will use. All other devices will not be visible.")
parser.add_argument("--default-device", type=int, default=None, metavar="DEFAULT_DEVICE_ID", help="Set the id of the default device, all other devices will stay visible.")
cm_group = parser.add_mutually_exclusive_group()
cm_group.add_argument("--cuda-malloc", action="store_true",
help="Enable cudaMallocAsync (enabled by default for torch 2.0 and up).")

View File

@ -165,6 +165,7 @@ class Configuration(dict):
comfy_api_base (str): Set the base URL for the ComfyUI API. (default: https://api.comfy.org)
database_url (str): Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.
whitelist_custom_nodes (list[str]): Specify custom node folders to load even when --disable-all-custom-nodes is enabled.
default_device (Optional[int]): Set the id of the default device, all other devices will stay visible.
"""
def __init__(self, **kwargs):
@ -281,6 +282,7 @@ class Configuration(dict):
self.front_end_root: Optional[str] = None
self.comfy_api_base: str = "https://api.comfy.org"
self.database_url: str = db_config()
self.default_device: Optional[int] = None
for key, value in kwargs.items():
self[key] = value

View File

@ -82,7 +82,8 @@ if not args.cuda_malloc:
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
version = module.__version__
if int(version[0]) >= 2: # enable by default for torch version 2.0 and up
if int(version[0]) >= 2 and "+cu" in version: # enable by default for torch version 2.0 and up only on cuda torch
args.cuda_malloc = cuda_malloc_supported()
except:
pass

View File

@ -49,6 +49,15 @@ logging.getLogger("alembic.runtime.migration").setLevel(logging.WARNING)
from ..cli_args import args
if args.default_device is not None:
default_dev = args.default_device
devices = list(range(32))
devices.remove(default_dev)
devices.insert(0, default_dev)
devices = ','.join(map(str, devices))
os.environ['CUDA_VISIBLE_DEVICES'] = str(devices)
os.environ['HIP_VISIBLE_DEVICES'] = str(devices)
if args.cuda_device is not None:
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device)
os.environ['HIP_VISIBLE_DEVICES'] = str(args.cuda_device)

View File

@ -593,6 +593,7 @@ class PromptServer(ExecutorToClientProgress):
ram_free = model_management.get_free_memory(cpu_device)
vram_total, torch_vram_total = get_total_memory(device, torch_total_too=True)
vram_free, torch_vram_free = get_free_memory(device, torch_free_too=True)
required_frontend_version = FrontendManager.get_required_frontend_version()
system_stats = {
"system": {
@ -600,6 +601,7 @@ class PromptServer(ExecutorToClientProgress):
"ram_total": ram_total,
"ram_free": ram_free,
"comfyui_version": __version__,
"required_frontend_version": required_frontend_version,
"python_version": sys.version,
"pytorch_version": torch_version,
"embedded_python": os.path.split(os.path.split(sys.executable)[0])[1] == "python_embeded",

View File

@ -1244,42 +1244,20 @@ def sample_deis(model, x, sigmas, extra_args=None, callback=None, disable=None,
return x_next
@torch.no_grad()
def sample_euler_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None):
extra_args = {} if extra_args is None else extra_args
temp = [0]
def post_cfg_function(args):
temp[0] = args["uncond_denoised"]
return args["denoised"]
model_options = extra_args.get("model_options", {}).copy()
extra_args["model_options"] = model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable):
sigma_hat = sigmas[i]
denoised = model(x, sigma_hat * s_in, **extra_args)
d = to_d(x, sigma_hat, temp[0])
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
# Euler method
x = denoised + d * sigmas[i + 1]
return x
@torch.no_grad()
def sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
"""Ancestral sampling with Euler method steps."""
"""Ancestral sampling with Euler method steps (CFG++)."""
extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
temp = [0]
model_sampling = model.inner_model.model_patcher.get_model_object("model_sampling")
lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
uncond_denoised = None
def post_cfg_function(args):
temp[0] = args["uncond_denoised"]
nonlocal uncond_denoised
uncond_denoised = args["uncond_denoised"]
return args["denoised"]
model_options = extra_args.get("model_options", {}).copy()
@ -1288,17 +1266,33 @@ def sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=No
s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
d = to_d(x, sigmas[i], temp[0])
# Euler method
x = denoised + d * sigma_down
if sigmas[i + 1] > 0:
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
if sigmas[i + 1] == 0:
# Denoising step
x = denoised
else:
alpha_s = sigmas[i] * lambda_fn(sigmas[i]).exp()
alpha_t = sigmas[i + 1] * lambda_fn(sigmas[i + 1]).exp()
d = to_d(x, sigmas[i], alpha_s * uncond_denoised) # to noise
# DDIM stochastic sampling
sigma_down, sigma_up = get_ancestral_step(sigmas[i] / alpha_s, sigmas[i + 1] / alpha_t, eta=eta)
sigma_down = alpha_t * sigma_down
# Euler method
x = alpha_t * denoised + sigma_down * d
if eta > 0 and s_noise > 0:
x = x + alpha_t * noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
return x
@torch.no_grad()
def sample_euler_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None):
"""Euler method steps (CFG++)."""
return sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=0.0, s_noise=0.0, noise_sampler=None)
@torch.no_grad()
def sample_dpmpp_2s_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
"""Ancestral sampling with DPM-Solver++(2S) second-order steps."""

View File

@ -52,15 +52,6 @@ class RMS_norm(nn.Module):
x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma.to(x) + (self.bias.to(x) if self.bias is not None else 0)
class Upsample(nn.Upsample):
def forward(self, x):
"""
Fix bfloat16 support for nearest neighbor interpolation.
"""
return super().forward(x.float()).type_as(x)
class Resample(nn.Module):
def __init__(self, dim, mode):
@ -73,11 +64,11 @@ class Resample(nn.Module):
# layers
if mode == 'upsample2d':
self.resample = nn.Sequential(
Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
nn.Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
ops.Conv2d(dim, dim // 2, 3, padding=1))
elif mode == 'upsample3d':
self.resample = nn.Sequential(
Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
nn.Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
ops.Conv2d(dim, dim // 2, 3, padding=1))
self.time_conv = CausalConv3d(
dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))

View File

@ -125,7 +125,7 @@ if args.directml is not None:
lowvram_available = False # TODO: need to find a way to get free memory in directml before this can be enabled by default.
try:
import intel_extension_for_pytorch as ipex # pylint: disable=import-error
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, noqa: F401
_ = torch.xpu.device_count()
xpu_available = xpu_available or torch.xpu.is_available()
@ -155,6 +155,11 @@ try:
except:
mlu_available = False
try:
ixuca_available = hasattr(torch, "corex")
except:
ixuca_available = False
if args.cpu:
cpu_state = CPUState.CPU
@ -181,6 +186,12 @@ def is_mlu():
return True
return False
def is_ixuca():
global ixuca_available
if ixuca_available:
return True
return False
def get_torch_device():
global directml_device
@ -220,7 +231,11 @@ def get_total_memory(dev=None, torch_total_too=False):
mem_total = 1024 * 1024 * 1024 # TODO
mem_total_torch = mem_total
elif is_intel_xpu():
mem_total = torch.xpu.get_device_properties(dev).total_memory
stats = torch.xpu.memory_stats(dev)
mem_reserved = stats['reserved_bytes.all.current']
mem_total_xpu = torch.xpu.get_device_properties(dev).total_memory
mem_total_torch = mem_reserved
mem_total = mem_total_xpu
elif is_ascend_npu():
stats = torch.npu.memory_stats(dev)
mem_reserved = stats['reserved_bytes.all.current']
@ -334,7 +349,7 @@ try:
if torch_version_numeric[0] >= 2:
if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
ENABLE_PYTORCH_ATTENTION = True
if is_intel_xpu() or is_ascend_npu() or is_mlu():
if is_intel_xpu() or is_ascend_npu() or is_mlu() or is_ixuca():
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
ENABLE_PYTORCH_ATTENTION = True
except:
@ -352,7 +367,10 @@ try:
logging.info("ROCm version: {}".format(rocm_version))
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much
if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx1201 and gfx950
if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950
ENABLE_PYTORCH_ATTENTION = True
if torch_version_numeric >= (2, 8):
if any((a in arch) for a in ["gfx1201"]):
ENABLE_PYTORCH_ATTENTION = True
if torch_version_numeric >= (2, 7) and rocm_version >= (6, 4):
if any((a in arch) for a in ["gfx1201", "gfx942", "gfx950"]): # TODO: more arches
@ -430,6 +448,8 @@ def get_torch_device_name(device):
except:
allocator_backend = ""
return "{} {} : {}".format(device, torch.cuda.get_device_name(device), allocator_backend)
elif device.type == "xpu":
return "{} {}".format(device, torch.xpu.get_device_name(device))
else:
return "{}".format(device.type)
elif is_intel_xpu():
@ -984,6 +1004,7 @@ def vae_dtype(device=None, allowed_dtypes=[]):
return d
# NOTE: bfloat16 seems to work on AMD for the VAE but is extremely slow in some cases compared to fp32
# slowness still a problem on pytorch nightly 2.9.0.dev20250720+rocm6.4 tested on RDNA3
if d == torch.bfloat16 and (not is_amd()) and should_use_bf16(device):
return d
@ -1044,7 +1065,7 @@ def device_supports_non_blocking(device):
if is_device_mps(device):
return False # pytorch bug? mps doesn't support non blocking
if is_intel_xpu():
return False
return True
if args.deterministic: # TODO: figure out why deterministic breaks non blocking from gpu to cpu (previews)
return False
if directml_device:
@ -1087,6 +1108,8 @@ def get_offload_stream(device):
stream_counter = (stream_counter + 1) % len(ss)
if is_device_cuda(device):
ss[stream_counter].wait_stream(torch.cuda.current_stream())
elif is_device_xpu(device):
ss[stream_counter].wait_stream(torch.xpu.current_stream())
stream_counters[device] = stream_counter
return s
elif is_device_cuda(device):
@ -1098,6 +1121,15 @@ def get_offload_stream(device):
stream_counter = (stream_counter + 1) % len(ss)
stream_counters[device] = stream_counter
return s
elif is_device_xpu(device):
ss = []
for k in range(NUM_STREAMS):
ss.append(torch.xpu.Stream(device=device, priority=0))
STREAMS[device] = ss
s = ss[stream_counter]
stream_counter = (stream_counter + 1) % len(ss)
stream_counters[device] = stream_counter
return s
return None
@ -1106,6 +1138,8 @@ def sync_stream(device, stream):
return
if is_device_cuda(device):
torch.cuda.current_stream().wait_stream(stream)
elif is_device_xpu(device):
torch.xpu.current_stream().wait_stream(stream)
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None):
@ -1152,6 +1186,8 @@ def xformers_enabled():
return False
if is_mlu():
return False
if is_ixuca():
return False
if directml_device:
return False
return XFORMERS_IS_AVAILABLE
@ -1202,6 +1238,8 @@ def pytorch_attention_flash_attention():
return True
if is_amd():
return True # if you have pytorch attention enabled on AMD it probably supports at least mem efficient attention
if is_ixuca():
return True
return False
@ -1234,8 +1272,8 @@ def get_free_memory(dev=None, torch_free_too=False):
stats = torch.xpu.memory_stats(dev) # pylint: disable=no-member
mem_active = stats['active_bytes.all.current']
mem_reserved = stats['reserved_bytes.all.current']
mem_free_torch = mem_reserved - mem_active
mem_free_xpu = torch.xpu.get_device_properties(dev).total_memory - mem_reserved
mem_free_torch = mem_reserved - mem_active
mem_free_total = mem_free_xpu + mem_free_torch
elif is_ascend_npu():
stats = torch.npu.memory_stats(dev)
@ -1289,6 +1327,8 @@ def is_device_cpu(device):
def is_device_mps(device):
return is_device_type(device, 'mps')
def is_device_xpu(device):
return is_device_type(device, 'xpu')
def is_device_cuda(device):
return is_device_type(device, 'cuda')
@ -1323,7 +1363,10 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
return False
if is_intel_xpu():
return True
if torch_version_numeric < (2, 3):
return True
else:
return torch.xpu.get_device_properties(device).has_fp16
if is_ascend_npu():
return True
@ -1331,6 +1374,9 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
if is_mlu():
return True
if is_ixuca():
return True
if is_amd():
return True
try:
@ -1394,11 +1440,17 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
return False
if is_intel_xpu():
return True
if torch_version_numeric < (2, 6):
return True
else:
return torch.xpu.get_device_capability(device)['has_bfloat16_conversions']
if is_ascend_npu():
return True
if is_ixuca():
return True
if is_amd():
arch = torch.cuda.get_device_properties(device).gcnArchName
if any((a in arch) for a in ["gfx1030", "gfx1031", "gfx1010", "gfx1011", "gfx1012", "gfx906", "gfx900", "gfx803"]): # RDNA2 and older don't support bf16

View File

@ -37,6 +37,7 @@ from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine
from .ldm.wan.vae import WanVAE
from .lora_convert import convert_lora
from .model_management import load_models_gpu
from .model_patcher import ModelPatcher
from .t2i_adapter import adapter
from .taesd import taesd
from .text_encoders import aura_t5
@ -71,7 +72,7 @@ def load_lora_for_models(model, clip, _lora, strength_model, strength_clip):
_lora = convert_lora(_lora)
loaded = lora.load_lora(_lora, key_map)
if model is not None:
new_modelpatcher = model.clone()
new_modelpatcher: ModelPatcher = model.clone()
k = new_modelpatcher.add_patches(loaded, strength_model)
else:
k = ()

View File

@ -781,6 +781,27 @@ def resize_to_batch_size(tensor, batch_size):
return output
def resize_list_to_batch_size(l, batch_size):
in_batch_size = len(l)
if in_batch_size == batch_size or in_batch_size == 0:
return l
if batch_size <= 1:
return l[:batch_size]
output = []
if batch_size < in_batch_size:
scale = (in_batch_size - 1) / (batch_size - 1)
for i in range(batch_size):
output.append(l[min(round(i * scale), in_batch_size - 1)])
else:
scale = in_batch_size / batch_size
for i in range(batch_size):
output.append(l[min(math.floor((i + 0.5) * scale), in_batch_size - 1)])
return output
def convert_sd_to(state_dict, dtype):
keys = list(state_dict.keys())
for k in keys:

View File

@ -15,9 +15,20 @@ adapters: list[type[WeightAdapterBase]] = [
OFTAdapter,
BOFTAdapter,
]
adapter_maps: dict[str, type[WeightAdapterBase]] = {
"LoRA": LoRAAdapter,
"LoHa": LoHaAdapter,
"LoKr": LoKrAdapter,
"OFT": OFTAdapter,
## We disable not implemented algo for now
# "GLoRA": GLoRAAdapter,
# "BOFT": BOFTAdapter,
}
__all__ = [
"WeightAdapterBase",
"WeightAdapterTrainBase",
"adapters"
"adapters",
"adapter_maps",
] + [a.__name__ for a in adapters]

View File

@ -133,3 +133,43 @@ def tucker_weight_from_conv(up, down, mid):
def tucker_weight(wa, wb, t):
temp = torch.einsum("i j ..., j r -> i r ...", t, wb)
return torch.einsum("i j ..., i r -> r j ...", temp, wa)
def factorization(dimension: int, factor: int = -1) -> tuple[int, int]:
"""
return a tuple of two value of input dimension decomposed by the number closest to factor
second value is higher or equal than first value.
examples)
factor
-1 2 4 8 16 ...
127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127
128 -> 8, 16 128 -> 2, 64 128 -> 4, 32 128 -> 8, 16 128 -> 8, 16
250 -> 10, 25 250 -> 2, 125 250 -> 2, 125 250 -> 5, 50 250 -> 10, 25
360 -> 8, 45 360 -> 2, 180 360 -> 4, 90 360 -> 8, 45 360 -> 12, 30
512 -> 16, 32 512 -> 2, 256 512 -> 4, 128 512 -> 8, 64 512 -> 16, 32
1024 -> 32, 32 1024 -> 2, 512 1024 -> 4, 256 1024 -> 8, 128 1024 -> 16, 64
"""
if factor > 0 and (dimension % factor) == 0 and dimension >= factor**2:
m = factor
n = dimension // factor
if m > n:
n, m = m, n
return m, n
if factor < 0:
factor = dimension
m, n = 1, dimension
length = m + n
while m < n:
new_m = m + 1
while dimension % new_m != 0:
new_m += 1
new_n = dimension // new_m
if new_m + new_n > length or new_m > factor:
break
else:
m, n = new_m, new_n
if m > n:
n, m = m, n
return m, n

View File

@ -3,7 +3,120 @@ from typing import Optional
import torch
from ..model_management import cast_to_device
from .base import WeightAdapterBase, weight_decompose
from .base import WeightAdapterBase, WeightAdapterTrainBase, weight_decompose
class HadaWeight(torch.autograd.Function):
@staticmethod
def forward(ctx, w1u, w1d, w2u, w2d, scale=torch.tensor(1)):
ctx.save_for_backward(w1d, w1u, w2d, w2u, scale)
diff_weight = ((w1u @ w1d) * (w2u @ w2d)) * scale
return diff_weight
@staticmethod
def backward(ctx, grad_out):
(w1d, w1u, w2d, w2u, scale) = ctx.saved_tensors
grad_out = grad_out * scale
temp = grad_out * (w2u @ w2d)
grad_w1u = temp @ w1d.T
grad_w1d = w1u.T @ temp
temp = grad_out * (w1u @ w1d)
grad_w2u = temp @ w2d.T
grad_w2d = w2u.T @ temp
del temp
return grad_w1u, grad_w1d, grad_w2u, grad_w2d, None
class HadaWeightTucker(torch.autograd.Function):
@staticmethod
def forward(ctx, t1, w1u, w1d, t2, w2u, w2d, scale=torch.tensor(1)):
ctx.save_for_backward(t1, w1d, w1u, t2, w2d, w2u, scale)
rebuild1 = torch.einsum("i j ..., j r, i p -> p r ...", t1, w1d, w1u)
rebuild2 = torch.einsum("i j ..., j r, i p -> p r ...", t2, w2d, w2u)
return rebuild1 * rebuild2 * scale
@staticmethod
def backward(ctx, grad_out):
(t1, w1d, w1u, t2, w2d, w2u, scale) = ctx.saved_tensors
grad_out = grad_out * scale
temp = torch.einsum("i j ..., j r -> i r ...", t2, w2d)
rebuild = torch.einsum("i j ..., i r -> r j ...", temp, w2u)
grad_w = rebuild * grad_out
del rebuild
grad_w1u = torch.einsum("r j ..., i j ... -> r i", temp, grad_w)
grad_temp = torch.einsum("i j ..., i r -> r j ...", grad_w, w1u.T)
del grad_w, temp
grad_w1d = torch.einsum("i r ..., i j ... -> r j", t1, grad_temp)
grad_t1 = torch.einsum("i j ..., j r -> i r ...", grad_temp, w1d.T)
del grad_temp
temp = torch.einsum("i j ..., j r -> i r ...", t1, w1d)
rebuild = torch.einsum("i j ..., i r -> r j ...", temp, w1u)
grad_w = rebuild * grad_out
del rebuild
grad_w2u = torch.einsum("r j ..., i j ... -> r i", temp, grad_w)
grad_temp = torch.einsum("i j ..., i r -> r j ...", grad_w, w2u.T)
del grad_w, temp
grad_w2d = torch.einsum("i r ..., i j ... -> r j", t2, grad_temp)
grad_t2 = torch.einsum("i j ..., j r -> i r ...", grad_temp, w2d.T)
del grad_temp
return grad_t1, grad_w1u, grad_w1d, grad_t2, grad_w2u, grad_w2d, None
class LohaDiff(WeightAdapterTrainBase):
def __init__(self, weights):
super().__init__()
# Unpack weights tuple from LoHaAdapter
w1a, w1b, alpha, w2a, w2b, t1, t2, _ = weights
# Create trainable parameters
self.hada_w1_a = torch.nn.Parameter(w1a)
self.hada_w1_b = torch.nn.Parameter(w1b)
self.hada_w2_a = torch.nn.Parameter(w2a)
self.hada_w2_b = torch.nn.Parameter(w2b)
self.use_tucker = False
if t1 is not None and t2 is not None:
self.use_tucker = True
self.hada_t1 = torch.nn.Parameter(t1)
self.hada_t2 = torch.nn.Parameter(t2)
else:
# Keep the attributes for consistent access
self.hada_t1 = None
self.hada_t2 = None
# Store rank and non-trainable alpha
self.rank = w1b.shape[0]
self.alpha = torch.nn.Parameter(torch.tensor(alpha), requires_grad=False)
def __call__(self, w):
org_dtype = w.dtype
scale = self.alpha / self.rank
if self.use_tucker:
diff_weight = HadaWeightTucker.apply(self.hada_t1, self.hada_w1_a, self.hada_w1_b, self.hada_t2, self.hada_w2_a, self.hada_w2_b, scale)
else:
diff_weight = HadaWeight.apply(self.hada_w1_a, self.hada_w1_b, self.hada_w2_a, self.hada_w2_b, scale)
# Add the scaled difference to the original weight
weight = w.to(diff_weight) + diff_weight.reshape(w.shape)
return weight.to(org_dtype)
def passive_memory_usage(self):
"""Calculates memory usage of the trainable parameters."""
return sum(param.numel() * param.element_size() for param in self.parameters())
logger = logging.getLogger(__name__)
@ -14,6 +127,25 @@ class LoHaAdapter(WeightAdapterBase):
self.loaded_keys = loaded_keys
self.weights = weights
@classmethod
def create_train(cls, weight, rank=1, alpha=1.0):
out_dim = weight.shape[0]
in_dim = weight.shape[1:].numel()
mat1 = torch.empty(out_dim, rank, device=weight.device, dtype=weight.dtype)
mat2 = torch.empty(rank, in_dim, device=weight.device, dtype=weight.dtype)
torch.nn.init.normal_(mat1, 0.1)
torch.nn.init.constant_(mat2, 0.0)
mat3 = torch.empty(out_dim, rank, device=weight.device, dtype=weight.dtype)
mat4 = torch.empty(rank, in_dim, device=weight.device, dtype=weight.dtype)
torch.nn.init.normal_(mat3, 0.1)
torch.nn.init.normal_(mat4, 0.01)
return LohaDiff(
(mat1, mat2, alpha, mat3, mat4, None, None, None)
)
def to_train(self):
return LohaDiff(self.weights)
@classmethod
def load(
cls,

View File

@ -3,7 +3,77 @@ from typing import Optional
import torch
from ..model_management import cast_to_device
from .base import WeightAdapterBase, weight_decompose
from .base import (
WeightAdapterBase,
WeightAdapterTrainBase,
weight_decompose,
factorization,
)
class LokrDiff(WeightAdapterTrainBase):
def __init__(self, weights):
super().__init__()
(lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2, dora_scale) = weights
self.use_tucker = False
if lokr_w1_a is not None:
_, rank_a = lokr_w1_a.shape[0], lokr_w1_a.shape[1]
rank_a, _ = lokr_w1_b.shape[0], lokr_w1_b.shape[1]
self.lokr_w1_a = torch.nn.Parameter(lokr_w1_a)
self.lokr_w1_b = torch.nn.Parameter(lokr_w1_b)
self.w1_rebuild = True
self.ranka = rank_a
if lokr_w2_a is not None:
_, rank_b = lokr_w2_a.shape[0], lokr_w2_a.shape[1]
rank_b, _ = lokr_w2_b.shape[0], lokr_w2_b.shape[1]
self.lokr_w2_a = torch.nn.Parameter(lokr_w2_a)
self.lokr_w2_b = torch.nn.Parameter(lokr_w2_b)
if lokr_t2 is not None:
self.use_tucker = True
self.lokr_t2 = torch.nn.Parameter(lokr_t2)
self.w2_rebuild = True
self.rankb = rank_b
if lokr_w1 is not None:
self.lokr_w1 = torch.nn.Parameter(lokr_w1)
self.w1_rebuild = False
if lokr_w2 is not None:
self.lokr_w2 = torch.nn.Parameter(lokr_w2)
self.w2_rebuild = False
self.alpha = torch.nn.Parameter(torch.tensor(alpha), requires_grad=False)
@property
def w1(self):
if self.w1_rebuild:
return (self.lokr_w1_a @ self.lokr_w1_b) * (self.alpha / self.ranka)
else:
return self.lokr_w1
@property
def w2(self):
if self.w2_rebuild:
if self.use_tucker:
w2 = torch.einsum(
'i j k l, j r, i p -> p r k l',
self.lokr_t2,
self.lokr_w2_b,
self.lokr_w2_a
)
else:
w2 = self.lokr_w2_a @ self.lokr_w2_b
return w2 * (self.alpha / self.rankb)
else:
return self.lokr_w2
def __call__(self, w):
diff = torch.kron(self.w1, self.w2)
return w + diff.reshape(w.shape).to(w)
def passive_memory_usage(self):
return sum(param.numel() * param.element_size() for param in self.parameters())
logger = logging.getLogger(__name__)
@ -14,6 +84,20 @@ class LoKrAdapter(WeightAdapterBase):
self.loaded_keys = loaded_keys
self.weights = weights
@classmethod
def create_train(cls, weight, rank=1, alpha=1.0):
out_dim = weight.shape[0]
in_dim = weight.shape[1:].numel()
out1, out2 = factorization(out_dim, rank)
in1, in2 = factorization(in_dim, rank)
mat1 = torch.empty(out1, in1, device=weight.device, dtype=weight.dtype)
mat2 = torch.empty(out2, in2, device=weight.device, dtype=weight.dtype)
torch.nn.init.kaiming_uniform_(mat2, a=5**0.5)
torch.nn.init.constant_(mat1, 0.0)
return LokrDiff(
(mat1, mat2, alpha, None, None, None, None, None, None)
)
@classmethod
def load(
cls,

View File

@ -2,11 +2,64 @@ import logging
from typing import Optional
import torch
from .base import WeightAdapterBase, WeightAdapterTrainBase, weight_decompose, factorization
from ..model_management import cast_to_device
from .base import WeightAdapterBase, weight_decompose
logger = logging.getLogger(__name__)
class OFTDiff(WeightAdapterTrainBase):
def __init__(self, weights):
super().__init__()
# Unpack weights tuple from LoHaAdapter
blocks, rescale, alpha, _ = weights
# Create trainable parameters
self.oft_blocks = torch.nn.Parameter(blocks)
if rescale is not None:
self.rescale = torch.nn.Parameter(rescale)
self.rescaled = True
else:
self.rescaled = False
self.block_num, self.block_size, _ = blocks.shape
self.constraint = float(alpha)
self.alpha = torch.nn.Parameter(torch.tensor(alpha), requires_grad=False)
def __call__(self, w):
org_dtype = w.dtype
I = torch.eye(self.block_size, device=self.oft_blocks.device)
## generate r
# for Q = -Q^T
q = self.oft_blocks - self.oft_blocks.transpose(1, 2)
normed_q = q
if self.constraint:
q_norm = torch.norm(q) + 1e-8
if q_norm > self.constraint:
normed_q = q * self.constraint / q_norm
# use float() to prevent unsupported type
r = (I + normed_q) @ (I - normed_q).float().inverse()
## Apply chunked matmul on weight
_, *shape = w.shape
org_weight = w.to(dtype=r.dtype)
org_weight = org_weight.unflatten(0, (self.block_num, self.block_size))
# Init R=0, so add I on it to ensure the output of step0 is original model output
weight = torch.einsum(
"k n m, k n ... -> k m ...",
r,
org_weight,
).flatten(0, 1)
if self.rescaled:
weight = self.rescale * weight
return weight.to(org_dtype)
def passive_memory_usage(self):
"""Calculates memory usage of the trainable parameters."""
return sum(param.numel() * param.element_size() for param in self.parameters())
class OFTAdapter(WeightAdapterBase):
name = "oft"
@ -14,14 +67,26 @@ class OFTAdapter(WeightAdapterBase):
self.loaded_keys = loaded_keys
self.weights = weights
@classmethod
def create_train(cls, weight, rank=1, alpha=1.0):
out_dim = weight.shape[0]
block_size, block_num = factorization(out_dim, rank)
block = torch.zeros(block_num, block_size, block_size, device=weight.device, dtype=weight.dtype)
return OFTDiff(
(block, None, alpha, None)
)
def to_train(self):
return OFTDiff(self.weights)
@classmethod
def load(
cls,
x: str,
lora: dict[str, torch.Tensor],
alpha: float,
dora_scale: torch.Tensor,
loaded_keys: set[str] = None,
cls,
x: str,
lora: dict[str, torch.Tensor],
alpha: float,
dora_scale: torch.Tensor,
loaded_keys: set[str] = None,
) -> Optional["OFTAdapter"]:
if loaded_keys is None:
loaded_keys = set()
@ -47,20 +112,22 @@ class OFTAdapter(WeightAdapterBase):
return cls(loaded_keys, weights)
def calculate_weight(
self,
weight,
key,
strength,
strength_model,
offset,
function,
intermediate_dtype=torch.float32,
original_weight=None,
self,
weight,
key,
strength,
strength_model,
offset,
function,
intermediate_dtype=torch.float32,
original_weight=None,
):
v = self.weights
blocks = v[0]
rescale = v[1]
alpha = v[2]
if alpha is None:
alpha = 0
dora_scale = v[3]
blocks = cast_to_device(blocks, weight.device, intermediate_dtype)
@ -75,7 +142,7 @@ class OFTAdapter(WeightAdapterBase):
# for Q = -Q^T
q = blocks - blocks.transpose(1, 2)
normed_q = q
if alpha > 0: # alpha in oft/boft is for constraint
if alpha > 0: # alpha in oft/boft is for constraint
q_norm = torch.norm(q) + 1e-8
if q_norm > alpha:
normed_q = q * alpha / q_norm

View File

@ -1,11 +1,12 @@
from __future__ import annotations
import av
import hashlib
import io
import json
import os
import random
import av
import torch
import comfy.model_management
@ -112,7 +113,6 @@ def save_audio(self, audio, filename_prefix="ComfyUI", format="flac", prompt=Non
except ImportError as exc_info:
raise TorchAudioNotFoundError()
filename_prefix += self.prefix_append
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
results: list[FileLocator] = []
@ -178,13 +178,13 @@ def save_audio(self, audio, filename_prefix="ComfyUI", format="flac", prompt=Non
elif format == "mp3":
out_stream = output_container.add_stream("libmp3lame", rate=sample_rate)
if quality == "V0":
#TODO i would really love to support V3 and V5 but there doesn't seem to be a way to set the qscale level, the property below is a bool
# TODO i would really love to support V3 and V5 but there doesn't seem to be a way to set the qscale level, the property below is a bool
out_stream.codec_context.qscale = 1
elif quality == "128k":
out_stream.bit_rate = 128000
elif quality == "320k":
out_stream.bit_rate = 320000
else: #format == "flac":
else: # format == "flac":
out_stream = output_container.add_stream("flac", rate=sample_rate)
frame = av.AudioFrame.from_ndarray(waveform.movedim(0, 1).reshape(1, -1).float().numpy(), format='flt', layout='mono' if waveform.shape[0] == 1 else 'stereo')
@ -212,6 +212,7 @@ def save_audio(self, audio, filename_prefix="ComfyUI", format="flac", prompt=Non
return {"ui": {"audio": results}}
class SaveAudio:
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
@ -220,9 +221,9 @@ class SaveAudio:
@classmethod
def INPUT_TYPES(s):
return {"required": { "audio": ("AUDIO", ),
"filename_prefix": ("STRING", {"default": "audio/ComfyUI"}),
},
return {"required": {"audio": ("AUDIO",),
"filename_prefix": ("STRING", {"default": "audio/ComfyUI"}),
},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
}
@ -236,6 +237,7 @@ class SaveAudio:
def save_flac(self, audio, filename_prefix="ComfyUI", format="flac", prompt=None, extra_pnginfo=None):
return save_audio(self, audio, filename_prefix, format, prompt, extra_pnginfo)
class SaveAudioMP3:
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
@ -244,10 +246,10 @@ class SaveAudioMP3:
@classmethod
def INPUT_TYPES(s):
return {"required": { "audio": ("AUDIO", ),
"filename_prefix": ("STRING", {"default": "audio/ComfyUI"}),
"quality": (["V0", "128k", "320k"], {"default": "V0"}),
},
return {"required": {"audio": ("AUDIO",),
"filename_prefix": ("STRING", {"default": "audio/ComfyUI"}),
"quality": (["V0", "128k", "320k"], {"default": "V0"}),
},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
}
@ -261,6 +263,7 @@ class SaveAudioMP3:
def save_mp3(self, audio, filename_prefix="ComfyUI", format="mp3", prompt=None, extra_pnginfo=None, quality="128k"):
return save_audio(self, audio, filename_prefix, format, prompt, extra_pnginfo, quality)
class SaveAudioOpus:
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
@ -269,10 +272,10 @@ class SaveAudioOpus:
@classmethod
def INPUT_TYPES(s):
return {"required": { "audio": ("AUDIO", ),
"filename_prefix": ("STRING", {"default": "audio/ComfyUI"}),
"quality": (["64k", "96k", "128k", "192k", "320k"], {"default": "128k"}),
},
return {"required": {"audio": ("AUDIO",),
"filename_prefix": ("STRING", {"default": "audio/ComfyUI"}),
"quality": (["64k", "96k", "128k", "192k", "320k"], {"default": "128k"}),
},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
}
@ -286,6 +289,7 @@ class SaveAudioOpus:
def save_opus(self, audio, filename_prefix="ComfyUI", format="opus", prompt=None, extra_pnginfo=None, quality="V3"):
return save_audio(self, audio, filename_prefix, format, prompt, extra_pnginfo, quality)
class PreviewAudio(SaveAudio):
def __init__(self):
self.output_dir = folder_paths.get_temp_directory()
@ -300,6 +304,44 @@ class PreviewAudio(SaveAudio):
}
def f32_pcm(wav: torch.Tensor) -> torch.Tensor:
"""Convert audio to float 32 bits PCM format."""
if wav.dtype.is_floating_point:
return wav
elif wav.dtype == torch.int16:
return wav.float() / (2 ** 15)
elif wav.dtype == torch.int32:
return wav.float() / (2 ** 31)
raise ValueError(f"Unsupported wav dtype: {wav.dtype}")
def load(filepath: str) -> tuple[torch.Tensor, int]:
with av.open(filepath) as af:
if not af.streams.audio:
raise ValueError("No audio stream found in the file.")
stream = af.streams.audio[0]
sr = stream.codec_context.sample_rate
n_channels = stream.channels
frames = []
length = 0
for frame in af.decode(streams=stream.index):
buf = torch.from_numpy(frame.to_ndarray())
if buf.shape[0] != n_channels:
buf = buf.view(-1, n_channels).t()
frames.append(buf)
length += buf.shape[1]
if not frames:
raise ValueError("No audio frames decoded.")
wav = torch.cat(frames, dim=1)
wav = f32_pcm(wav)
return wav, sr
class LoadAudio:
@classmethod
def INPUT_TYPES(s):
@ -319,7 +361,7 @@ class LoadAudio:
raise TorchAudioNotFoundError()
audio_path = folder_paths.get_annotated_filepath(audio)
waveform, sample_rate = torchaudio.load(audio_path)
waveform, sample_rate = load(audio_path)
audio = {"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate}
return (audio,)

View File

@ -306,6 +306,35 @@ class ExtendIntermediateSigmas:
return (extended_sigmas,)
class SamplingPercentToSigma:
@classmethod
def INPUT_TYPES(cls) -> InputTypeDict:
return {
"required": {
"model": (IO.MODEL, {}),
"sampling_percent": (IO.FLOAT, {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.0001}),
"return_actual_sigma": (IO.BOOLEAN, {"default": False, "tooltip": "Return the actual sigma value instead of the value used for interval checks.\nThis only affects results at 0.0 and 1.0."}),
}
}
RETURN_TYPES = (IO.FLOAT,)
RETURN_NAMES = ("sigma_value",)
CATEGORY = "sampling/custom_sampling/sigmas"
FUNCTION = "get_sigma"
def get_sigma(self, model, sampling_percent, return_actual_sigma):
model_sampling = model.get_model_object("model_sampling")
sigma_val = model_sampling.percent_to_sigma(sampling_percent)
if return_actual_sigma:
if sampling_percent == 0.0:
sigma_val = model_sampling.sigma_max.item()
elif sampling_percent == 1.0:
sigma_val = model_sampling.sigma_min.item()
return (sigma_val,)
class KSamplerSelect:
@classmethod
def INPUT_TYPES(s):
@ -688,9 +717,10 @@ class CFGGuider:
return (guider,)
class Guider_DualCFG(comfy.samplers.CFGGuider):
def set_cfg(self, cfg1, cfg2):
def set_cfg(self, cfg1, cfg2, nested=False):
self.cfg1 = cfg1
self.cfg2 = cfg2
self.nested = nested
def set_conds(self, positive, middle, negative):
middle = node_helpers.conditioning_set_values(middle, {"prompt_type": "negative"})
@ -700,14 +730,20 @@ class Guider_DualCFG(comfy.samplers.CFGGuider):
negative_cond = self.conds.get("negative", None)
middle_cond = self.conds.get("middle", None)
positive_cond = self.conds.get("positive", None)
if model_options.get("disable_cfg1_optimization", False) == False:
if math.isclose(self.cfg2, 1.0):
negative_cond = None
if math.isclose(self.cfg1, 1.0):
middle_cond = None
out = comfy.samplers.calc_cond_batch(self.inner_model, [negative_cond, middle_cond, positive_cond], x, timestep, model_options)
return comfy.samplers.cfg_function(self.inner_model, out[1], out[0], self.cfg2, x, timestep, model_options=model_options, cond=middle_cond, uncond=negative_cond) + (out[2] - out[1]) * self.cfg1
if self.nested:
out = comfy.samplers.calc_cond_batch(self.inner_model, [negative_cond, middle_cond, positive_cond], x, timestep, model_options)
pred_text = comfy.samplers.cfg_function(self.inner_model, out[2], out[1], self.cfg1, x, timestep, model_options=model_options, cond=positive_cond, uncond=middle_cond)
return out[0] + self.cfg2 * (pred_text - out[0])
else:
if model_options.get("disable_cfg1_optimization", False) == False:
if math.isclose(self.cfg2, 1.0):
negative_cond = None
if math.isclose(self.cfg1, 1.0):
middle_cond = None
out = comfy.samplers.calc_cond_batch(self.inner_model, [negative_cond, middle_cond, positive_cond], x, timestep, model_options)
return comfy.samplers.cfg_function(self.inner_model, out[1], out[0], self.cfg2, x, timestep, model_options=model_options, cond=middle_cond, uncond=negative_cond) + (out[2] - out[1]) * self.cfg1
class DualCFGGuider:
@classmethod
@ -719,6 +755,7 @@ class DualCFGGuider:
"negative": ("CONDITIONING", ),
"cfg_conds": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}),
"cfg_cond2_negative": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}),
"style": (["regular", "nested"],),
}
}
@ -727,10 +764,10 @@ class DualCFGGuider:
FUNCTION = "get_guider"
CATEGORY = "sampling/custom_sampling/guiders"
def get_guider(self, model, cond1, cond2, negative, cfg_conds, cfg_cond2_negative):
def get_guider(self, model, cond1, cond2, negative, cfg_conds, cfg_cond2_negative, style):
guider = Guider_DualCFG(model)
guider.set_conds(cond1, cond2, negative)
guider.set_cfg(cfg_conds, cfg_cond2_negative)
guider.set_cfg(cfg_conds, cfg_cond2_negative, nested=(style == "nested"))
return (guider,)
class DisableNoise:
@ -878,6 +915,7 @@ NODE_CLASS_MAPPINGS = {
"FlipSigmas": FlipSigmas,
"SetFirstSigma": SetFirstSigma,
"ExtendIntermediateSigmas": ExtendIntermediateSigmas,
"SamplingPercentToSigma": SamplingPercentToSigma,
"CFGGuider": CFGGuider,
"DualCFGGuider": DualCFGGuider,

View File

@ -20,7 +20,7 @@ from comfy import node_helpers
from comfy.cli_args import args
from comfy.comfy_types.node_typing import IO
from comfy.execution_context import current_execution_context
from comfy.weight_adapter import adapters
from comfy.weight_adapter import adapters, adapter_maps
from . import nodes_custom_sampler
from .nodes_custom_sampler import Noise_RandomNoise
@ -41,13 +41,13 @@ def make_batch_extra_option_dict(d, indicies, full_size=None):
class TrainSampler(comfy.samplers.Sampler):
def __init__(self, loss_fn, optimizer, loss_callback=None, batch_size=1, total_steps=1, seed=0, training_dtype=torch.bfloat16):
def __init__(self, loss_fn, optimizer, loss_callback=None, batch_size=1, grad_acc=1, total_steps=1, seed=0, training_dtype=torch.bfloat16):
self.loss_fn = loss_fn
self.optimizer = optimizer
self.loss_callback = loss_callback
self.batch_size = batch_size
self.total_steps = total_steps
self.grad_acc = grad_acc
self.seed = seed
self.training_dtype = training_dtype
@ -94,8 +94,9 @@ class TrainSampler(comfy.samplers.Sampler):
self.loss_callback(loss.item())
pbar.set_postfix({"loss": f"{loss.item():.4f}"})
self.optimizer.step()
self.optimizer.zero_grad()
if (i+1) % self.grad_acc == 0:
self.optimizer.step()
self.optimizer.zero_grad()
torch.cuda.empty_cache()
return torch.zeros_like(latent_image)
@ -421,6 +422,16 @@ class TrainLoraNode:
"tooltip": "The batch size to use for training.",
},
),
"grad_accumulation_steps": (
IO.INT,
{
"default": 1,
"min": 1,
"max": 1024,
"step": 1,
"tooltip": "The number of gradient accumulation steps to use for training.",
}
),
"steps": (
IO.INT,
{
@ -480,6 +491,17 @@ class TrainLoraNode:
["bf16", "fp32"],
{"default": "bf16", "tooltip": "The dtype to use for lora."},
),
"algorithm": (
list(adapter_maps.keys()),
{"default": list(adapter_maps.keys())[0], "tooltip": "The algorithm to use for training."},
),
"gradient_checkpointing": (
IO.BOOLEAN,
{
"default": True,
"tooltip": "Use gradient checkpointing for training.",
}
),
"existing_lora": (
folder_paths.get_filename_list("loras") + ["[None]"],
{
@ -503,6 +525,7 @@ class TrainLoraNode:
positive,
batch_size,
steps,
grad_accumulation_steps,
learning_rate,
rank,
optimizer,
@ -510,6 +533,8 @@ class TrainLoraNode:
seed,
training_dtype,
lora_dtype,
algorithm,
gradient_checkpointing,
existing_lora,
):
mp = model.clone()
@ -560,10 +585,8 @@ class TrainLoraNode:
if existing_adapter is not None:
break
else:
# If no existing adapter found, use LoRA
# We will add algo option in the future
existing_adapter = None
adapter_cls = adapters[0]
adapter_cls = adapter_maps[algorithm]
if existing_adapter is not None:
train_adapter = existing_adapter.to_train().to(lora_dtype)
@ -619,8 +642,9 @@ class TrainLoraNode:
criterion = None
# setup models
for m in find_all_highest_child_module_with_forward(mp.model.diffusion_model):
patch(m)
if gradient_checkpointing:
for m in find_all_highest_child_module_with_forward(mp.model.diffusion_model):
patch(m)
mp.model.requires_grad_(False)
comfy.model_management.load_models_gpu([mp], memory_required=1e20, force_full_load=True)
@ -633,7 +657,8 @@ class TrainLoraNode:
optimizer,
loss_callback=loss_callback,
batch_size=batch_size,
total_steps=steps,
grad_acc=grad_accumulation_steps,
total_steps=steps*grad_accumulation_steps,
seed=seed,
training_dtype=dtype
)

View File

@ -1,26 +1,33 @@
from comfy.nodes import base_nodes as nodes
from comfy import node_helpers
import json
import math
from typing import Tuple
import numpy as np
import torch
import comfy.clip_vision
import comfy.latent_formats
import comfy.model_management
import comfy.utils
import comfy.latent_formats
import comfy.clip_vision
import nodes
from comfy import node_helpers
from comfy.nodes import base_nodes as nodes
class WanImageToVideo:
@classmethod
def INPUT_TYPES(s):
return {"required": {"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"vae": ("VAE", ),
return {"required": {"positive": ("CONDITIONING",),
"negative": ("CONDITIONING",),
"vae": ("VAE",),
"width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
},
"optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT", ),
"start_image": ("IMAGE", ),
}}
},
"optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT",),
"start_image": ("IMAGE",),
}}
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "negative", "latent")
@ -54,18 +61,18 @@ class WanImageToVideo:
class WanFunControlToVideo:
@classmethod
def INPUT_TYPES(s):
return {"required": {"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"vae": ("VAE", ),
return {"required": {"positive": ("CONDITIONING",),
"negative": ("CONDITIONING",),
"vae": ("VAE",),
"width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
},
"optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT", ),
"start_image": ("IMAGE", ),
"control_video": ("IMAGE", ),
}}
},
"optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT",),
"start_image": ("IMAGE",),
"control_video": ("IMAGE",),
}}
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "negative", "latent")
@ -82,12 +89,12 @@ class WanFunControlToVideo:
if start_image is not None:
start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
concat_latent_image = vae.encode(start_image[:, :, :, :3])
concat_latent[:,16:,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]]
concat_latent[:, 16:, :concat_latent_image.shape[2]] = concat_latent_image[:, :, :concat_latent.shape[2]]
if control_video is not None:
control_video = comfy.utils.common_upscale(control_video[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
concat_latent_image = vae.encode(control_video[:, :, :, :3])
concat_latent[:,:16,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]]
concat_latent[:, :16, :concat_latent_image.shape[2]] = concat_latent_image[:, :, :concat_latent.shape[2]]
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent})
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent})
@ -100,22 +107,23 @@ class WanFunControlToVideo:
out_latent["samples"] = latent
return (positive, negative, out_latent)
class WanFirstLastFrameToVideo:
@classmethod
def INPUT_TYPES(s):
return {"required": {"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"vae": ("VAE", ),
return {"required": {"positive": ("CONDITIONING",),
"negative": ("CONDITIONING",),
"vae": ("VAE",),
"width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
},
"optional": {"clip_vision_start_image": ("CLIP_VISION_OUTPUT", ),
"clip_vision_end_image": ("CLIP_VISION_OUTPUT", ),
"start_image": ("IMAGE", ),
"end_image": ("IMAGE", ),
}}
},
"optional": {"clip_vision_start_image": ("CLIP_VISION_OUTPUT",),
"clip_vision_end_image": ("CLIP_VISION_OUTPUT",),
"start_image": ("IMAGE",),
"end_image": ("IMAGE",),
}}
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "negative", "latent")
@ -169,18 +177,18 @@ class WanFirstLastFrameToVideo:
class WanFunInpaintToVideo:
@classmethod
def INPUT_TYPES(s):
return {"required": {"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"vae": ("VAE", ),
return {"required": {"positive": ("CONDITIONING",),
"negative": ("CONDITIONING",),
"vae": ("VAE",),
"width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
},
"optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT", ),
"start_image": ("IMAGE", ),
"end_image": ("IMAGE", ),
}}
},
"optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT",),
"start_image": ("IMAGE",),
"end_image": ("IMAGE",),
}}
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "negative", "latent")
@ -196,19 +204,19 @@ class WanFunInpaintToVideo:
class WanVaceToVideo:
@classmethod
def INPUT_TYPES(s):
return {"required": {"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"vae": ("VAE", ),
return {"required": {"positive": ("CONDITIONING",),
"negative": ("CONDITIONING",),
"vae": ("VAE",),
"width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1000.0, "step": 0.01}),
},
"optional": {"control_video": ("IMAGE", ),
"control_masks": ("MASK", ),
"reference_image": ("IMAGE", ),
}}
},
"optional": {"control_video": ("IMAGE",),
"control_masks": ("MASK",),
"reference_image": ("IMAGE",),
}}
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT", "INT")
RETURN_NAMES = ("positive", "negative", "latent", "trim_latent")
@ -277,11 +285,12 @@ class WanVaceToVideo:
out_latent["samples"] = latent
return (positive, negative, out_latent, trim_latent)
class TrimVideoLatent:
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples": ("LATENT",),
"trim_amount": ("INT", {"default": 0, "min": 0, "max": 99999}),
return {"required": {"samples": ("LATENT",),
"trim_amount": ("INT", {"default": 0, "min": 0, "max": 99999}),
}}
RETURN_TYPES = ("LATENT",)
@ -298,21 +307,22 @@ class TrimVideoLatent:
samples_out["samples"] = s1[:, :, trim_amount:]
return (samples_out,)
class WanCameraImageToVideo:
@classmethod
def INPUT_TYPES(s):
return {"required": {"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"vae": ("VAE", ),
return {"required": {"positive": ("CONDITIONING",),
"negative": ("CONDITIONING",),
"vae": ("VAE",),
"width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
},
"optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT", ),
"start_image": ("IMAGE", ),
"camera_conditions": ("WAN_CAMERA_EMBEDDING", ),
}}
},
"optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT",),
"start_image": ("IMAGE",),
"camera_conditions": ("WAN_CAMERA_EMBEDDING",),
}}
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "negative", "latent")
@ -328,7 +338,7 @@ class WanCameraImageToVideo:
if start_image is not None:
start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
concat_latent_image = vae.encode(start_image[:, :, :, :3])
concat_latent[:,:,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]]
concat_latent[:, :, :concat_latent_image.shape[2]] = concat_latent_image[:, :, :concat_latent.shape[2]]
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent})
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent})
@ -345,19 +355,20 @@ class WanCameraImageToVideo:
out_latent["samples"] = latent
return (positive, negative, out_latent)
class WanPhantomSubjectToVideo:
@classmethod
def INPUT_TYPES(s):
return {"required": {"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"vae": ("VAE", ),
return {"required": {"positive": ("CONDITIONING",),
"negative": ("CONDITIONING",),
"vae": ("VAE",),
"width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
},
"optional": {"images": ("IMAGE", ),
}}
},
"optional": {"images": ("IMAGE",),
}}
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "negative_text", "negative_img_text", "latent")
@ -383,7 +394,316 @@ class WanPhantomSubjectToVideo:
out_latent["samples"] = latent
return (positive, cond2, negative, out_latent)
def parse_json_tracks(tracks):
"""Parse JSON track data into a standardized format"""
tracks_data = []
try:
# If tracks is a string, try to parse it as JSON
if isinstance(tracks, str):
parsed = json.loads(tracks.replace("'", '"'))
tracks_data.extend(parsed)
else:
# If tracks is a list of strings, parse each one
for track_str in tracks:
parsed = json.loads(track_str.replace("'", '"'))
tracks_data.append(parsed)
# Check if we have a single track (dict with x,y) or a list of tracks
if tracks_data and isinstance(tracks_data[0], dict) and 'x' in tracks_data[0]:
# Single track detected, wrap it in a list
tracks_data = [tracks_data]
elif tracks_data and isinstance(tracks_data[0], list) and tracks_data[0] and isinstance(tracks_data[0][0], dict) and 'x' in tracks_data[0][0]:
# Already a list of tracks, nothing to do
pass
else:
# Unexpected format
pass
except json.JSONDecodeError:
tracks_data = []
return tracks_data
def process_tracks(tracks_np: np.ndarray, frame_size: Tuple[int, int], num_frames, quant_multi: int = 8, **kwargs):
# tracks: shape [t, h, w, 3] => samples align with 24 fps, model trained with 16 fps.
# frame_size: tuple (W, H)
tracks = torch.from_numpy(tracks_np).float()
if tracks.shape[1] == 121:
tracks = torch.permute(tracks, (1, 0, 2, 3))
tracks, visibles = tracks[..., :2], tracks[..., 2:3]
short_edge = min(*frame_size)
frame_center = torch.tensor([*frame_size]).type_as(tracks) / 2
tracks = tracks - frame_center
tracks = tracks / short_edge * 2
visibles = visibles * 2 - 1
trange = torch.linspace(-1, 1, tracks.shape[0]).view(-1, 1, 1, 1).expand(*visibles.shape)
out_ = torch.cat([trange, tracks, visibles], dim=-1).view(121, -1, 4)
out_0 = out_[:1]
out_l = out_[1:] # 121 => 120 | 1
a = 120 // math.gcd(120, num_frames)
b = num_frames // math.gcd(120, num_frames)
out_l = torch.repeat_interleave(out_l, b, dim=0)[1::a] # 120 => 120 * b => 120 * b / a == F
final_result = torch.cat([out_0, out_l], dim=0)
return final_result
FIXED_LENGTH = 121
def pad_pts(tr):
"""Convert list of {x,y} to (FIXED_LENGTH,1,3) array, padding/truncating."""
pts = np.array([[p['x'], p['y'], 1] for p in tr], dtype=np.float32)
n = pts.shape[0]
if n < FIXED_LENGTH:
pad = np.zeros((FIXED_LENGTH - n, 3), dtype=np.float32)
pts = np.vstack((pts, pad))
else:
pts = pts[:FIXED_LENGTH]
return pts.reshape(FIXED_LENGTH, 1, 3)
def ind_sel(target: torch.Tensor, ind: torch.Tensor, dim: int = 1):
"""Index selection utility function"""
assert (
len(ind.shape) > dim
), "Index must have the target dim, but get dim: %d, ind shape: %s" % (dim, str(ind.shape))
target = target.expand(
*tuple(
[ind.shape[k] if target.shape[k] == 1 else -1 for k in range(dim)]
+ [
-1,
]
* (len(target.shape) - dim)
)
)
ind_pad = ind
if len(target.shape) > dim + 1:
for _ in range(len(target.shape) - (dim + 1)):
ind_pad = ind_pad.unsqueeze(-1)
ind_pad = ind_pad.expand(*(-1,) * (dim + 1), *target.shape[(dim + 1)::])
return torch.gather(target, dim=dim, index=ind_pad)
def merge_final(vert_attr: torch.Tensor, weight: torch.Tensor, vert_assign: torch.Tensor):
"""Merge vertex attributes with weights"""
target_dim = len(vert_assign.shape) - 1
if len(vert_attr.shape) == 2:
assert vert_attr.shape[0] > vert_assign.max()
new_shape = [1] * target_dim + list(vert_attr.shape)
tensor = vert_attr.reshape(new_shape)
sel_attr = ind_sel(tensor, vert_assign.type(torch.long), dim=target_dim)
else:
assert vert_attr.shape[1] > vert_assign.max()
new_shape = [vert_attr.shape[0]] + [1] * (target_dim - 1) + list(vert_attr.shape[1:])
tensor = vert_attr.reshape(new_shape)
sel_attr = ind_sel(tensor, vert_assign.type(torch.long), dim=target_dim)
final_attr = torch.sum(sel_attr * weight.unsqueeze(-1), dim=-2)
return final_attr
def _patch_motion_single(
tracks: torch.FloatTensor, # (B, T, N, 4)
vid: torch.FloatTensor, # (C, T, H, W)
temperature: float,
vae_divide: tuple,
topk: int,
):
"""Apply motion patching based on tracks"""
_, T, H, W = vid.shape
N = tracks.shape[2]
_, tracks_xy, visible = torch.split(
tracks, [1, 2, 1], dim=-1
) # (B, T, N, 2) | (B, T, N, 1)
tracks_n = tracks_xy / torch.tensor([W / min(H, W), H / min(H, W)], device=tracks_xy.device)
tracks_n = tracks_n.clamp(-1, 1)
visible = visible.clamp(0, 1)
xx = torch.linspace(-W / min(H, W), W / min(H, W), W)
yy = torch.linspace(-H / min(H, W), H / min(H, W), H)
grid = torch.stack(torch.meshgrid(yy, xx, indexing="ij")[::-1], dim=-1).to(
tracks_xy.device
)
tracks_pad = tracks_xy[:, 1:]
visible_pad = visible[:, 1:]
visible_align = visible_pad.view(T - 1, 4, *visible_pad.shape[2:]).sum(1)
tracks_align = (tracks_pad * visible_pad).view(T - 1, 4, *tracks_pad.shape[2:]).sum(
1
) / (visible_align + 1e-5)
dist_ = (
(tracks_align[:, None, None] - grid[None, :, :, None]).pow(2).sum(-1)
) # T, H, W, N
weight = torch.exp(-dist_ * temperature) * visible_align.clamp(0, 1).view(
T - 1, 1, 1, N
)
vert_weight, vert_index = torch.topk(
weight, k=min(topk, weight.shape[-1]), dim=-1
)
grid_mode = "bilinear"
point_feature = torch.nn.functional.grid_sample(
vid.permute(1, 0, 2, 3)[:1],
tracks_n[:, :1].type(vid.dtype),
mode=grid_mode,
padding_mode="zeros",
align_corners=False,
)
point_feature = point_feature.squeeze(0).squeeze(1).permute(1, 0) # N, C=16
out_feature = merge_final(point_feature, vert_weight, vert_index).permute(3, 0, 1, 2) # T - 1, H, W, C => C, T - 1, H, W
out_weight = vert_weight.sum(-1) # T - 1, H, W
# out feature -> already soft weighted
mix_feature = out_feature + vid[:, 1:] * (1 - out_weight.clamp(0, 1))
out_feature_full = torch.cat([vid[:, :1], mix_feature], dim=1) # C, T, H, W
out_mask_full = torch.cat([torch.ones_like(out_weight[:1]), out_weight], dim=0) # T, H, W
return out_mask_full[None].expand(vae_divide[0], -1, -1, -1), out_feature_full
def patch_motion(
tracks: torch.FloatTensor, # (B, TB, T, N, 4)
vid: torch.FloatTensor, # (C, T, H, W)
temperature: float = 220.0,
vae_divide: tuple = (4, 16),
topk: int = 2,
):
B = len(tracks)
# Process each batch separately
out_masks = []
out_features = []
for b in range(B):
mask, feature = _patch_motion_single(
tracks[b], # (T, N, 4)
vid[b], # (C, T, H, W)
temperature,
vae_divide,
topk
)
out_masks.append(mask)
out_features.append(feature)
# Stack results: (B, C, T, H, W)
out_mask_full = torch.stack(out_masks, dim=0)
out_feature_full = torch.stack(out_features, dim=0)
return out_mask_full, out_feature_full
class WanTrackToVideo:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"positive": ("CONDITIONING",),
"negative": ("CONDITIONING",),
"vae": ("VAE",),
"tracks": ("STRING", {"multiline": True, "default": "[]"}),
"width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
"temperature": ("FLOAT", {"default": 220.0, "min": 1.0, "max": 1000.0, "step": 0.1}),
"topk": ("INT", {"default": 2, "min": 1, "max": 10}),
"start_image": ("IMAGE",),
},
"optional": {
"clip_vision_output": ("CLIP_VISION_OUTPUT",),
}}
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "negative", "latent")
FUNCTION = "encode"
CATEGORY = "conditioning/video_models"
def encode(self, positive, negative, vae, tracks, width, height, length, batch_size,
temperature, topk, start_image=None, clip_vision_output=None):
tracks_data = parse_json_tracks(tracks)
if not tracks_data:
return WanImageToVideo().encode(positive, negative, vae, width, height, length, batch_size, start_image=start_image, clip_vision_output=clip_vision_output)
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8],
device=comfy.model_management.intermediate_device())
if isinstance(tracks_data[0][0], dict):
tracks_data = [tracks_data]
processed_tracks = []
for batch in tracks_data:
arrs = []
for track in batch:
pts = pad_pts(track)
arrs.append(pts)
tracks_np = np.stack(arrs, axis=0)
processed_tracks.append(process_tracks(tracks_np, (width, height), length - 1).unsqueeze(0))
if start_image is not None:
start_image = comfy.utils.common_upscale(start_image[:batch_size].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
videos = torch.ones((start_image.shape[0], length, height, width, start_image.shape[-1]), device=start_image.device, dtype=start_image.dtype) * 0.5
for i in range(start_image.shape[0]):
videos[i, 0] = start_image[i]
latent_videos = []
videos = comfy.utils.resize_to_batch_size(videos, batch_size)
for i in range(batch_size):
latent_videos += [vae.encode(videos[i, :, :, :, :3])]
y = torch.cat(latent_videos, dim=0)
# Scale latent since patch_motion is non-linear
y = comfy.latent_formats.Wan21().process_in(y)
processed_tracks = comfy.utils.resize_list_to_batch_size(processed_tracks, batch_size)
res = patch_motion(
processed_tracks, y, temperature=temperature, topk=topk, vae_divide=(4, 16)
)
mask, concat_latent_image = res
concat_latent_image = comfy.latent_formats.Wan21().process_out(concat_latent_image)
mask = -mask + 1.0 # Invert mask to match expected format
positive = node_helpers.conditioning_set_values(positive,
{"concat_mask": mask,
"concat_latent_image": concat_latent_image})
negative = node_helpers.conditioning_set_values(negative,
{"concat_mask": mask,
"concat_latent_image": concat_latent_image})
if clip_vision_output is not None:
positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})
out_latent = {}
out_latent["samples"] = latent
return (positive, negative, out_latent)
NODE_CLASS_MAPPINGS = {
"WanTrackToVideo": WanTrackToVideo,
"WanImageToVideo": WanImageToVideo,
"WanFunControlToVideo": WanFunControlToVideo,
"WanFunInpaintToVideo": WanFunInpaintToVideo,

View File

@ -1,6 +1,6 @@
[project]
name = "comfyui"
version = "0.3.44"
version = "0.3.45"
description = "An installable version of ComfyUI"
readme = "README.md"
authors = [
@ -18,9 +18,9 @@ classifiers = [
]
dependencies = [
"comfyui-frontend-package",
"comfyui-workflow-templates",
"comfyui-embedded-docs",
"comfyui-frontend-package>=1.23.4",
"comfyui-workflow-templates>=0.1.40",
"comfyui-embedded-docs>=0.2.4",
"torch",
"torchvision",
"torchdiffeq>=0.2.3",

View File

@ -2,7 +2,7 @@ import argparse
import pytest
from requests.exceptions import HTTPError
from unittest.mock import patch
from unittest.mock import patch, mock_open
from comfy.app.frontend_management import (
FrontendManager,
@ -171,3 +171,36 @@ def test_init_frontend_fallback_on_error():
# Assert
assert frontend_path == "/default/path"
mock_check.assert_called_once()
def test_get_frontend_version():
# Arrange
expected_version = "1.25.0"
mock_requirements_content = """torch
torchsde
comfyui-frontend-package==1.25.0
other-package==1.0.0
numpy"""
# Act
with patch("builtins.open", mock_open(read_data=mock_requirements_content)):
version = FrontendManager.get_required_frontend_version()
# Assert
assert version == expected_version
def test_get_frontend_version_invalid_semver():
# Arrange
mock_requirements_content = """torch
torchsde
comfyui-frontend-package==1.29.3.75
other-package==1.0.0
numpy"""
# Act
with patch("builtins.open", mock_open(read_data=mock_requirements_content)):
version = FrontendManager.get_required_frontend_version()
# Assert
assert version is None