Merge branch 'master' into flipflop-stream

This commit is contained in:
Jedrzej Kosinski 2025-10-13 21:04:37 -07:00
commit 586a8de8da
63 changed files with 5076 additions and 3205 deletions

View File

@@ -206,14 +206,32 @@ Put your SD checkpoints (the huge ckpt/safetensors files) in: models/checkpoints
 Put your VAE in: models/vae
 
-### AMD GPUs (Linux only)
+### AMD GPUs (Linux)
 AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version:
 
 ```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.4```
 
-This is the command to install the nightly with ROCm 6.4 which might have some performance improvements:
+This is the command to install the nightly with ROCm 7.0 which might have some performance improvements:
 
-```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.4```
+```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm7.0```
+
+### AMD GPUs (Experimental: Windows and Linux), RDNA 3, 3.5 and 4 only.
+
+These have less hardware support than the builds above but they work on windows. You also need to install the pytorch version specific to your hardware.
+
+RDNA 3 (RX 7000 series):
+
+```pip install --pre torch torchvision torchaudio --index-url https://rocm.nightlies.amd.com/v2/gfx110X-dgpu/```
+
+RDNA 3.5 (Strix halo/Ryzen AI Max+ 365):
+
+```pip install --pre torch torchvision torchaudio --index-url https://rocm.nightlies.amd.com/v2/gfx1151/```
+
+RDNA 4 (RX 9000 series):
+
+```pip install --pre torch torchvision torchaudio --index-url https://rocm.nightlies.amd.com/v2/gfx120X-all/```
 
 ### Intel GPUs (Windows and Linux)
@@ -270,12 +288,6 @@ You can install ComfyUI in Apple Mac silicon (M1 or M2) with any recent macOS ve
 
 > **Note**: Remember to add your models, VAE, LoRAs etc. to the corresponding Comfy folders, as discussed in [ComfyUI manual installation](#manual-install-windows-linux).
 
-#### DirectML (AMD Cards on Windows)
-
-This is very badly supported and is not recommended. There are some unofficial builds of pytorch ROCm on windows that exist that will give you a much better experience than this. This readme will be updated once official pytorch ROCm builds for windows come out.
-
-```pip install torch-directml``` Then you can launch ComfyUI with: ```python main.py --directml```
-
 #### Ascend NPUs
 
 For models compatible with Ascend Extension for PyTorch (torch_npu). To get started, ensure your environment meets the prerequisites outlined on the [installation](https://ascend.github.io/docs/sources/ascend/quick_install.html) page. Here's a step-by-step guide tailored to your platform and installation method:
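A quick way to confirm that one of the AMD wheels above actually installed a GPU-enabled PyTorch (a minimal check; ROCm builds expose the device through the regular torch.cuda API):

```python -c "import torch; print(torch.cuda.is_available(), torch.version.hip, torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'no GPU visible')"```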

View File

@@ -23,8 +23,6 @@ class MusicDCAE(torch.nn.Module):
         else:
             self.source_sample_rate = source_sample_rate
 
-        # self.resampler = torchaudio.transforms.Resample(source_sample_rate, 44100)
-
         self.transform = transforms.Compose([
             transforms.Normalize(0.5, 0.5),
         ])
@@ -37,10 +35,6 @@ class MusicDCAE(torch.nn.Module):
         self.scale_factor = 0.1786
         self.shift_factor = -1.9091
 
-    def load_audio(self, audio_path):
-        audio, sr = torchaudio.load(audio_path)
-        return audio, sr
-
     def forward_mel(self, audios):
         mels = []
         for i in range(len(audios)):
@@ -73,10 +67,8 @@ class MusicDCAE(torch.nn.Module):
             latent = self.dcae.encoder(mel.unsqueeze(0))
             latents.append(latent)
         latents = torch.cat(latents, dim=0)
-        # latent_lengths = (audio_lengths / sr * 44100 / 512 / self.time_dimention_multiple).long()
         latents = (latents - self.shift_factor) * self.scale_factor
         return latents
-        # return latents, latent_lengths
 
     @torch.no_grad()
     def decode(self, latents, audio_lengths=None, sr=None):
@@ -91,9 +83,7 @@ class MusicDCAE(torch.nn.Module):
             wav = self.vocoder.decode(mels[0]).squeeze(1)
             if sr is not None:
-                # resampler = torchaudio.transforms.Resample(44100, sr).to(latents.device).to(latents.dtype)
                 wav = torchaudio.functional.resample(wav, 44100, sr)
-                # wav = resampler(wav)
             else:
                 sr = 44100
             pred_wavs.append(wav)
@@ -101,7 +91,6 @@ class MusicDCAE(torch.nn.Module):
         if audio_lengths is not None:
            pred_wavs = [wav[:, :length].cpu() for wav, length in zip(pred_wavs, audio_lengths)]
         return torch.stack(pred_wavs)
-        # return sr, pred_wavs
 
     def forward(self, audios, audio_lengths=None, sr=None):
         latents, latent_lengths = self.encode(audios=audios, audio_lengths=audio_lengths, sr=sr)

View File

View File

@@ -0,0 +1,120 @@
# Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
# LICENSE is in incl_licenses directory.
import torch
from torch import nn, sin, pow
from torch.nn import Parameter
import comfy.model_management
class Snake(nn.Module):
'''
Implementation of a sine-based periodic activation function
Shape:
- Input: (B, C, T)
- Output: (B, C, T), same shape as the input
Parameters:
- alpha - trainable parameter
References:
- This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
https://arxiv.org/abs/2006.08195
Examples:
>>> a1 = Snake(256)
>>> x = torch.randn(4, 256, 100)  # (B, C, T)
>>> x = a1(x)
'''
def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
'''
Initialization.
INPUT:
- in_features: shape of the input
- alpha: trainable parameter
alpha is initialized to 1 by default, higher values = higher-frequency.
alpha will be trained along with the rest of your model.
'''
super(Snake, self).__init__()
self.in_features = in_features
# initialize alpha
self.alpha_logscale = alpha_logscale
if self.alpha_logscale:
self.alpha = Parameter(torch.empty(in_features))
else:
self.alpha = Parameter(torch.empty(in_features))
self.alpha.requires_grad = alpha_trainable
self.no_div_by_zero = 0.000000001
def forward(self, x):
'''
Forward pass of the function.
Applies the function to the input elementwise.
Snake = x + 1/a * sin^2 (xa)
'''
alpha = comfy.model_management.cast_to(self.alpha, dtype=x.dtype, device=x.device).unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
if self.alpha_logscale:
alpha = torch.exp(alpha)
x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
return x
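A minimal usage sketch for the class above. In this port alpha is created with torch.empty and is normally filled from the vocoder checkpoint, so the sketch assigns it by hand before calling the module:

```python
import torch

act = Snake(in_features=8)
with torch.no_grad():
    act.alpha.fill_(1.0)          # alpha = 1 -> Snake(x) ~= x + sin(x)**2
x = torch.randn(2, 8, 16)         # (B, C, T)
y = act(x)
print(torch.allclose(y, x + torch.sin(x) ** 2, atol=1e-5))  # True
```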
class SnakeBeta(nn.Module):
'''
A modified Snake function which uses separate parameters for the magnitude of the periodic components
Shape:
- Input: (B, C, T)
- Output: (B, C, T), same shape as the input
Parameters:
- alpha - trainable parameter that controls frequency
- beta - trainable parameter that controls magnitude
References:
- This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
https://arxiv.org/abs/2006.08195
Examples:
>>> a1 = SnakeBeta(256)
>>> x = torch.randn(4, 256, 100)  # (B, C, T)
>>> x = a1(x)
'''
def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
'''
Initialization.
INPUT:
- in_features: shape of the input
- alpha - trainable parameter that controls frequency
- beta - trainable parameter that controls magnitude
alpha is initialized to 1 by default, higher values = higher-frequency.
beta is initialized to 1 by default, higher values = higher-magnitude.
alpha will be trained along with the rest of your model.
'''
super(SnakeBeta, self).__init__()
self.in_features = in_features
# initialize alpha
self.alpha_logscale = alpha_logscale
if self.alpha_logscale:
self.alpha = Parameter(torch.empty(in_features))
self.beta = Parameter(torch.empty(in_features))
else:
self.alpha = Parameter(torch.empty(in_features))
self.beta = Parameter(torch.empty(in_features))
self.alpha.requires_grad = alpha_trainable
self.beta.requires_grad = alpha_trainable
self.no_div_by_zero = 0.000000001
def forward(self, x):
'''
Forward pass of the function.
Applies the function to the input elementwise.
SnakeBeta = x + 1/b * sin^2 (xa)
'''
alpha = comfy.model_management.cast_to(self.alpha, dtype=x.dtype, device=x.device).unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
beta = comfy.model_management.cast_to(self.beta, dtype=x.dtype, device=x.device).unsqueeze(0).unsqueeze(-1)
if self.alpha_logscale:
alpha = torch.exp(alpha)
beta = torch.exp(beta)
x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
return x

View File

@@ -0,0 +1,157 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import comfy.model_management
if 'sinc' in dir(torch):
sinc = torch.sinc
else:
# This code is adopted from adefossez's julius.core.sinc under the MIT License
# https://adefossez.github.io/julius/julius/core.html
# LICENSE is in incl_licenses directory.
def sinc(x: torch.Tensor):
"""
Implementation of sinc, i.e. sin(pi * x) / (pi * x)
__Warning__: Different to julius.sinc, the input is multiplied by `pi`!
"""
return torch.where(x == 0,
torch.tensor(1., device=x.device, dtype=x.dtype),
torch.sin(math.pi * x) / math.pi / x)
# This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
# https://adefossez.github.io/julius/julius/lowpass.html
# LICENSE is in incl_licenses directory.
def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): # return filter [1,1,kernel_size]
even = (kernel_size % 2 == 0)
half_size = kernel_size // 2
#For kaiser window
delta_f = 4 * half_width
A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
if A > 50.:
beta = 0.1102 * (A - 8.7)
elif A >= 21.:
beta = 0.5842 * (A - 21)**0.4 + 0.07886 * (A - 21.)
else:
beta = 0.
window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
# ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
if even:
time = (torch.arange(-half_size, half_size) + 0.5)
else:
time = torch.arange(kernel_size) - half_size
if cutoff == 0:
filter_ = torch.zeros_like(time)
else:
filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
# Normalize filter to have sum = 1, otherwise we will have a small leakage
# of the constant component in the input signal.
filter_ /= filter_.sum()
filter = filter_.view(1, 1, kernel_size)
return filter
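A quick check of the filter design above (a small sketch, not part of the vocoder). The values are the ones UpSample1d passes for ratio=2: cutoff 0.25, half_width 0.3, kernel_size 12, and the taps sum to 1 because of the explicit normalization:

```python
import torch

lp = kaiser_sinc_filter1d(cutoff=0.25, half_width=0.3, kernel_size=12)
print(lp.shape)         # torch.Size([1, 1, 12])
print(float(lp.sum()))  # ~1.0
```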
class LowPassFilter1d(nn.Module):
def __init__(self,
cutoff=0.5,
half_width=0.6,
stride: int = 1,
padding: bool = True,
padding_mode: str = 'replicate',
kernel_size: int = 12):
# kernel_size should be even number for stylegan3 setup,
# in this implementation, odd number is also possible.
super().__init__()
if cutoff < -0.:
raise ValueError("Minimum cutoff must be larger than zero.")
if cutoff > 0.5:
raise ValueError("A cutoff above 0.5 does not make sense.")
self.kernel_size = kernel_size
self.even = (kernel_size % 2 == 0)
self.pad_left = kernel_size // 2 - int(self.even)
self.pad_right = kernel_size // 2
self.stride = stride
self.padding = padding
self.padding_mode = padding_mode
filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
self.register_buffer("filter", filter)
#input [B, C, T]
def forward(self, x):
_, C, _ = x.shape
if self.padding:
x = F.pad(x, (self.pad_left, self.pad_right),
mode=self.padding_mode)
out = F.conv1d(x, comfy.model_management.cast_to(self.filter.expand(C, -1, -1), dtype=x.dtype, device=x.device),
stride=self.stride, groups=C)
return out
class UpSample1d(nn.Module):
def __init__(self, ratio=2, kernel_size=None):
super().__init__()
self.ratio = ratio
self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
self.stride = ratio
self.pad = self.kernel_size // ratio - 1
self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
filter = kaiser_sinc_filter1d(cutoff=0.5 / ratio,
half_width=0.6 / ratio,
kernel_size=self.kernel_size)
self.register_buffer("filter", filter)
# x: [B, C, T]
def forward(self, x):
_, C, _ = x.shape
x = F.pad(x, (self.pad, self.pad), mode='replicate')
x = self.ratio * F.conv_transpose1d(
x, comfy.model_management.cast_to(self.filter.expand(C, -1, -1), dtype=x.dtype, device=x.device), stride=self.stride, groups=C)
x = x[..., self.pad_left:-self.pad_right]
return x
class DownSample1d(nn.Module):
def __init__(self, ratio=2, kernel_size=None):
super().__init__()
self.ratio = ratio
self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
self.lowpass = LowPassFilter1d(cutoff=0.5 / ratio,
half_width=0.6 / ratio,
stride=ratio,
kernel_size=self.kernel_size)
def forward(self, x):
xx = self.lowpass(x)
return xx
class Activation1d(nn.Module):
def __init__(self,
activation,
up_ratio: int = 2,
down_ratio: int = 2,
up_kernel_size: int = 12,
down_kernel_size: int = 12):
super().__init__()
self.up_ratio = up_ratio
self.down_ratio = down_ratio
self.act = activation
self.upsample = UpSample1d(up_ratio, up_kernel_size)
self.downsample = DownSample1d(down_ratio, down_kernel_size)
# x: [B,C,T]
def forward(self, x):
x = self.upsample(x)
x = self.act(x)
x = self.downsample(x)
return x
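A shape-level sketch of the wrapper above, using a plain SiLU as a stand-in for the Snake/SnakeBeta modules that BigVGAN actually plugs in. The length is preserved because the 2x upsample is undone by the filtered 2x downsample:

```python
import torch
import torch.nn as nn

act = Activation1d(activation=nn.SiLU())   # BigVGAN passes Snake or SnakeBeta here
x = torch.randn(1, 4, 64)                  # (B, C, T)
y = act(x)
print(y.shape)                             # torch.Size([1, 4, 64])
```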

View File

@@ -0,0 +1,156 @@
from typing import Literal
import torch
import torch.nn as nn
from .distributions import DiagonalGaussianDistribution
from .vae import VAE_16k
from .bigvgan import BigVGANVocoder
import logging
try:
import torchaudio
except:
logging.warning("torchaudio missing, MMAudio VAE model will be broken")
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5, *, norm_fn):
return norm_fn(torch.clamp(x, min=clip_val) * C)
def spectral_normalize_torch(magnitudes, norm_fn):
output = dynamic_range_compression_torch(magnitudes, norm_fn=norm_fn)
return output
class MelConverter(nn.Module):
def __init__(
self,
*,
sampling_rate: float,
n_fft: int,
num_mels: int,
hop_size: int,
win_size: int,
fmin: float,
fmax: float,
norm_fn,
):
super().__init__()
self.sampling_rate = sampling_rate
self.n_fft = n_fft
self.num_mels = num_mels
self.hop_size = hop_size
self.win_size = win_size
self.fmin = fmin
self.fmax = fmax
self.norm_fn = norm_fn
# mel = librosa_mel_fn(sr=self.sampling_rate,
# n_fft=self.n_fft,
# n_mels=self.num_mels,
# fmin=self.fmin,
# fmax=self.fmax)
# mel_basis = torch.from_numpy(mel).float()
mel_basis = torch.empty((num_mels, 1 + n_fft // 2))
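        # mel_basis is left uninitialized here; the real mel filterbank (see the commented librosa_mel_fn call above) is expected to arrive with the model's state dict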
hann_window = torch.hann_window(self.win_size)
self.register_buffer('mel_basis', mel_basis)
self.register_buffer('hann_window', hann_window)
@property
def device(self):
return self.mel_basis.device
def forward(self, waveform: torch.Tensor, center: bool = False) -> torch.Tensor:
waveform = waveform.clamp(min=-1., max=1.).to(self.device)
waveform = torch.nn.functional.pad(
waveform.unsqueeze(1),
[int((self.n_fft - self.hop_size) / 2),
int((self.n_fft - self.hop_size) / 2)],
mode='reflect')
waveform = waveform.squeeze(1)
spec = torch.stft(waveform,
self.n_fft,
hop_length=self.hop_size,
win_length=self.win_size,
window=self.hann_window,
center=center,
pad_mode='reflect',
normalized=False,
onesided=True,
return_complex=True)
spec = torch.view_as_real(spec)
spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
spec = torch.matmul(self.mel_basis, spec)
spec = spectral_normalize_torch(spec, self.norm_fn)
return spec
class AudioAutoencoder(nn.Module):
def __init__(
self,
*,
# ckpt_path: str,
mode: Literal['16k', '44k'] = '16k',
need_vae_encoder: bool = True,
):
super().__init__()
assert mode == "16k", "Only 16k mode is supported currently."
self.mel_converter = MelConverter(sampling_rate=16_000,
n_fft=1024,
num_mels=80,
hop_size=256,
win_size=1024,
fmin=0,
fmax=8_000,
norm_fn=torch.log10)
self.vae = VAE_16k().eval()
bigvgan_config = {
"resblock": "1",
"num_mels": 80,
"upsample_rates": [4, 4, 2, 2, 2, 2],
"upsample_kernel_sizes": [8, 8, 4, 4, 4, 4],
"upsample_initial_channel": 1536,
"resblock_kernel_sizes": [3, 7, 11],
"resblock_dilation_sizes": [
[1, 3, 5],
[1, 3, 5],
[1, 3, 5],
],
"activation": "snakebeta",
"snake_logscale": True,
}
self.vocoder = BigVGANVocoder(
bigvgan_config
).eval()
@torch.inference_mode()
def encode_audio(self, x) -> DiagonalGaussianDistribution:
# x: (B * L)
mel = self.mel_converter(x)
dist = self.vae.encode(mel)
return dist
@torch.no_grad()
def decode(self, z):
mel_decoded = self.vae.decode(z)
audio = self.vocoder(mel_decoded)
audio = torchaudio.functional.resample(audio, 16000, 44100)
return audio
@torch.no_grad()
def encode(self, audio):
audio = audio.mean(dim=1)
audio = torchaudio.functional.resample(audio, 44100, 16000)
dist = self.encode_audio(audio)
return dist.mean
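A minimal sketch of the intended round trip, assuming the VAE and vocoder weights have already been loaded into the module (a freshly constructed instance only has placeholder weights):

```python
import torch

ae = AudioAutoencoder(mode='16k')
audio = torch.randn(1, 2, 44100)   # (B, channels, samples): one second of stereo at 44.1 kHz
z = ae.encode(audio)               # mono -> 16 kHz -> mel -> VAE posterior mean
recon = ae.decode(z)               # VAE decode -> BigVGAN vocoder -> resampled back to 44.1 kHz
print(recon.shape)
```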

View File

@@ -0,0 +1,219 @@
# Copyright (c) 2022 NVIDIA CORPORATION.
# Licensed under the MIT license.
# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
# LICENSE is in incl_licenses directory.
import torch
import torch.nn as nn
from types import SimpleNamespace
from . import activations
from .alias_free_torch import Activation1d
import comfy.ops
ops = comfy.ops.disable_weight_init
def get_padding(kernel_size, dilation=1):
return int((kernel_size * dilation - dilation) / 2)
class AMPBlock1(torch.nn.Module):
def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5), activation=None):
super(AMPBlock1, self).__init__()
self.h = h
self.convs1 = nn.ModuleList([
ops.Conv1d(channels,
channels,
kernel_size,
1,
dilation=dilation[0],
padding=get_padding(kernel_size, dilation[0])),
ops.Conv1d(channels,
channels,
kernel_size,
1,
dilation=dilation[1],
padding=get_padding(kernel_size, dilation[1])),
ops.Conv1d(channels,
channels,
kernel_size,
1,
dilation=dilation[2],
padding=get_padding(kernel_size, dilation[2]))
])
self.convs2 = nn.ModuleList([
ops.Conv1d(channels,
channels,
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1)),
ops.Conv1d(channels,
channels,
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1)),
ops.Conv1d(channels,
channels,
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1))
])
self.num_layers = len(self.convs1) + len(self.convs2) # total number of conv layers
if activation == 'snake': # periodic nonlinearity with snake function and anti-aliasing
self.activations = nn.ModuleList([
Activation1d(
activation=activations.Snake(channels, alpha_logscale=h.snake_logscale))
for _ in range(self.num_layers)
])
elif activation == 'snakebeta': # periodic nonlinearity with snakebeta function and anti-aliasing
self.activations = nn.ModuleList([
Activation1d(
activation=activations.SnakeBeta(channels, alpha_logscale=h.snake_logscale))
for _ in range(self.num_layers)
])
else:
raise NotImplementedError(
"activation incorrectly specified. check the config file and look for 'activation'."
)
def forward(self, x):
acts1, acts2 = self.activations[::2], self.activations[1::2]
for c1, c2, a1, a2 in zip(self.convs1, self.convs2, acts1, acts2):
xt = a1(x)
xt = c1(xt)
xt = a2(xt)
xt = c2(xt)
x = xt + x
return x
class AMPBlock2(torch.nn.Module):
def __init__(self, h, channels, kernel_size=3, dilation=(1, 3), activation=None):
super(AMPBlock2, self).__init__()
self.h = h
self.convs = nn.ModuleList([
ops.Conv1d(channels,
channels,
kernel_size,
1,
dilation=dilation[0],
padding=get_padding(kernel_size, dilation[0])),
ops.Conv1d(channels,
channels,
kernel_size,
1,
dilation=dilation[1],
padding=get_padding(kernel_size, dilation[1]))
])
self.num_layers = len(self.convs) # total number of conv layers
if activation == 'snake': # periodic nonlinearity with snake function and anti-aliasing
self.activations = nn.ModuleList([
Activation1d(
activation=activations.Snake(channels, alpha_logscale=h.snake_logscale))
for _ in range(self.num_layers)
])
elif activation == 'snakebeta': # periodic nonlinearity with snakebeta function and anti-aliasing
self.activations = nn.ModuleList([
Activation1d(
activation=activations.SnakeBeta(channels, alpha_logscale=h.snake_logscale))
for _ in range(self.num_layers)
])
else:
raise NotImplementedError(
"activation incorrectly specified. check the config file and look for 'activation'."
)
def forward(self, x):
for c, a in zip(self.convs, self.activations):
xt = a(x)
xt = c(xt)
x = xt + x
return x
class BigVGANVocoder(torch.nn.Module):
# this is our main BigVGAN model. Applies anti-aliased periodic activation for resblocks.
def __init__(self, h):
super().__init__()
if isinstance(h, dict):
h = SimpleNamespace(**h)
self.h = h
self.num_kernels = len(h.resblock_kernel_sizes)
self.num_upsamples = len(h.upsample_rates)
# pre conv
self.conv_pre = ops.Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3)
# define which AMPBlock to use. BigVGAN uses AMPBlock1 as default
resblock = AMPBlock1 if h.resblock == '1' else AMPBlock2
# transposed conv-based upsamplers. does not apply anti-aliasing
self.ups = nn.ModuleList()
for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
self.ups.append(
nn.ModuleList([
ops.ConvTranspose1d(h.upsample_initial_channel // (2**i),
h.upsample_initial_channel // (2**(i + 1)),
k,
u,
padding=(k - u) // 2)
]))
# residual blocks using anti-aliased multi-periodicity composition modules (AMP)
self.resblocks = nn.ModuleList()
for i in range(len(self.ups)):
ch = h.upsample_initial_channel // (2**(i + 1))
for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
self.resblocks.append(resblock(h, ch, k, d, activation=h.activation))
# post conv
if h.activation == "snake": # periodic nonlinearity with snake function and anti-aliasing
activation_post = activations.Snake(ch, alpha_logscale=h.snake_logscale)
self.activation_post = Activation1d(activation=activation_post)
elif h.activation == "snakebeta": # periodic nonlinearity with snakebeta function and anti-aliasing
activation_post = activations.SnakeBeta(ch, alpha_logscale=h.snake_logscale)
self.activation_post = Activation1d(activation=activation_post)
else:
raise NotImplementedError(
"activation incorrectly specified. check the config file and look for 'activation'."
)
self.conv_post = ops.Conv1d(ch, 1, 7, 1, padding=3)
def forward(self, x):
# pre conv
x = self.conv_pre(x)
for i in range(self.num_upsamples):
# upsampling
for i_up in range(len(self.ups[i])):
x = self.ups[i][i_up](x)
# AMP blocks
xs = None
for j in range(self.num_kernels):
if xs is None:
xs = self.resblocks[i * self.num_kernels + j](x)
else:
xs += self.resblocks[i * self.num_kernels + j](x)
x = xs / self.num_kernels
# post conv
x = self.activation_post(x)
x = self.conv_post(x)
x = torch.tanh(x)
return x

View File

@@ -0,0 +1,92 @@
import torch
import numpy as np
class AbstractDistribution:
def sample(self):
raise NotImplementedError()
def mode(self):
raise NotImplementedError()
class DiracDistribution(AbstractDistribution):
def __init__(self, value):
self.value = value
def sample(self):
return self.value
def mode(self):
return self.value
class DiagonalGaussianDistribution(object):
def __init__(self, parameters, deterministic=False):
self.parameters = parameters
self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
self.deterministic = deterministic
self.std = torch.exp(0.5 * self.logvar)
self.var = torch.exp(self.logvar)
if self.deterministic:
self.var = self.std = torch.zeros_like(self.mean, device=self.parameters.device)
def sample(self):
x = self.mean + self.std * torch.randn(self.mean.shape, device=self.parameters.device)
return x
def kl(self, other=None):
if self.deterministic:
return torch.Tensor([0.])
else:
if other is None:
return 0.5 * torch.sum(torch.pow(self.mean, 2)
+ self.var - 1.0 - self.logvar,
dim=[1, 2, 3])
else:
return 0.5 * torch.sum(
torch.pow(self.mean - other.mean, 2) / other.var
+ self.var / other.var - 1.0 - self.logvar + other.logvar,
dim=[1, 2, 3])
def nll(self, sample, dims=[1,2,3]):
if self.deterministic:
return torch.Tensor([0.])
logtwopi = np.log(2.0 * np.pi)
return 0.5 * torch.sum(
logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
dim=dims)
def mode(self):
return self.mean
def normal_kl(mean1, logvar1, mean2, logvar2):
"""
source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
Compute the KL divergence between two gaussians.
Shapes are automatically broadcasted, so batches can be compared to
scalars, among other use cases.
"""
tensor = None
for obj in (mean1, logvar1, mean2, logvar2):
if isinstance(obj, torch.Tensor):
tensor = obj
break
assert tensor is not None, "at least one argument must be a Tensor"
# Force variances to be Tensors. Broadcasting helps convert scalars to
# Tensors, but it does not work for torch.exp().
logvar1, logvar2 = [
x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
for x in (logvar1, logvar2)
]
return 0.5 * (
-1.0
+ logvar2
- logvar1
+ torch.exp(logvar1 - logvar2)
+ ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
)
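Two quick sanity checks of the code above (a small sketch): the KL of a zero-mean, unit-variance posterior against N(0, I) is zero, and normal_kl of identical Gaussians is zero everywhere:

```python
import torch

params = torch.zeros(1, 8, 4, 4)                 # chunked into mean (1, 4, 4, 4) and logvar (1, 4, 4, 4), both zero
dist = DiagonalGaussianDistribution(params)
print(dist.kl())                                  # tensor([0.])
print(normal_kl(torch.zeros(3), torch.zeros(3),
                torch.zeros(3), torch.zeros(3)))  # tensor([0., 0., 0.])
```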

View File

@@ -0,0 +1,358 @@
import logging
from typing import Optional
import torch
import torch.nn as nn
from .vae_modules import (AttnBlock1D, Downsample1D, ResnetBlock1D,
Upsample1D, nonlinearity)
from .distributions import DiagonalGaussianDistribution
import comfy.ops
import comfy.model_management
ops = comfy.ops.disable_weight_init
log = logging.getLogger()
DATA_MEAN_80D = [
-1.6058, -1.3676, -1.2520, -1.2453, -1.2078, -1.2224, -1.2419, -1.2439, -1.2922, -1.2927,
-1.3170, -1.3543, -1.3401, -1.3836, -1.3907, -1.3912, -1.4313, -1.4152, -1.4527, -1.4728,
-1.4568, -1.5101, -1.5051, -1.5172, -1.5623, -1.5373, -1.5746, -1.5687, -1.6032, -1.6131,
-1.6081, -1.6331, -1.6489, -1.6489, -1.6700, -1.6738, -1.6953, -1.6969, -1.7048, -1.7280,
-1.7361, -1.7495, -1.7658, -1.7814, -1.7889, -1.8064, -1.8221, -1.8377, -1.8417, -1.8643,
-1.8857, -1.8929, -1.9173, -1.9379, -1.9531, -1.9673, -1.9824, -2.0042, -2.0215, -2.0436,
-2.0766, -2.1064, -2.1418, -2.1855, -2.2319, -2.2767, -2.3161, -2.3572, -2.3954, -2.4282,
-2.4659, -2.5072, -2.5552, -2.6074, -2.6584, -2.7107, -2.7634, -2.8266, -2.8981, -2.9673
]
DATA_STD_80D = [
1.0291, 1.0411, 1.0043, 0.9820, 0.9677, 0.9543, 0.9450, 0.9392, 0.9343, 0.9297, 0.9276, 0.9263,
0.9242, 0.9254, 0.9232, 0.9281, 0.9263, 0.9315, 0.9274, 0.9247, 0.9277, 0.9199, 0.9188, 0.9194,
0.9160, 0.9161, 0.9146, 0.9161, 0.9100, 0.9095, 0.9145, 0.9076, 0.9066, 0.9095, 0.9032, 0.9043,
0.9038, 0.9011, 0.9019, 0.9010, 0.8984, 0.8983, 0.8986, 0.8961, 0.8962, 0.8978, 0.8962, 0.8973,
0.8993, 0.8976, 0.8995, 0.9016, 0.8982, 0.8972, 0.8974, 0.8949, 0.8940, 0.8947, 0.8936, 0.8939,
0.8951, 0.8956, 0.9017, 0.9167, 0.9436, 0.9690, 1.0003, 1.0225, 1.0381, 1.0491, 1.0545, 1.0604,
1.0761, 1.0929, 1.1089, 1.1196, 1.1176, 1.1156, 1.1117, 1.1070
]
DATA_MEAN_128D = [
-3.3462, -2.6723, -2.4893, -2.3143, -2.2664, -2.3317, -2.1802, -2.4006, -2.2357, -2.4597,
-2.3717, -2.4690, -2.5142, -2.4919, -2.6610, -2.5047, -2.7483, -2.5926, -2.7462, -2.7033,
-2.7386, -2.8112, -2.7502, -2.9594, -2.7473, -3.0035, -2.8891, -2.9922, -2.9856, -3.0157,
-3.1191, -2.9893, -3.1718, -3.0745, -3.1879, -3.2310, -3.1424, -3.2296, -3.2791, -3.2782,
-3.2756, -3.3134, -3.3509, -3.3750, -3.3951, -3.3698, -3.4505, -3.4509, -3.5089, -3.4647,
-3.5536, -3.5788, -3.5867, -3.6036, -3.6400, -3.6747, -3.7072, -3.7279, -3.7283, -3.7795,
-3.8259, -3.8447, -3.8663, -3.9182, -3.9605, -3.9861, -4.0105, -4.0373, -4.0762, -4.1121,
-4.1488, -4.1874, -4.2461, -4.3170, -4.3639, -4.4452, -4.5282, -4.6297, -4.7019, -4.7960,
-4.8700, -4.9507, -5.0303, -5.0866, -5.1634, -5.2342, -5.3242, -5.4053, -5.4927, -5.5712,
-5.6464, -5.7052, -5.7619, -5.8410, -5.9188, -6.0103, -6.0955, -6.1673, -6.2362, -6.3120,
-6.3926, -6.4797, -6.5565, -6.6511, -6.8130, -6.9961, -7.1275, -7.2457, -7.3576, -7.4663,
-7.6136, -7.7469, -7.8815, -8.0132, -8.1515, -8.3071, -8.4722, -8.7418, -9.3975, -9.6628,
-9.7671, -9.8863, -9.9992, -10.0860, -10.1709, -10.5418, -11.2795, -11.3861
]
DATA_STD_128D = [
2.3804, 2.4368, 2.3772, 2.3145, 2.2803, 2.2510, 2.2316, 2.2083, 2.1996, 2.1835, 2.1769, 2.1659,
2.1631, 2.1618, 2.1540, 2.1606, 2.1571, 2.1567, 2.1612, 2.1579, 2.1679, 2.1683, 2.1634, 2.1557,
2.1668, 2.1518, 2.1415, 2.1449, 2.1406, 2.1350, 2.1313, 2.1415, 2.1281, 2.1352, 2.1219, 2.1182,
2.1327, 2.1195, 2.1137, 2.1080, 2.1179, 2.1036, 2.1087, 2.1036, 2.1015, 2.1068, 2.0975, 2.0991,
2.0902, 2.1015, 2.0857, 2.0920, 2.0893, 2.0897, 2.0910, 2.0881, 2.0925, 2.0873, 2.0960, 2.0900,
2.0957, 2.0958, 2.0978, 2.0936, 2.0886, 2.0905, 2.0845, 2.0855, 2.0796, 2.0840, 2.0813, 2.0817,
2.0838, 2.0840, 2.0917, 2.1061, 2.1431, 2.1976, 2.2482, 2.3055, 2.3700, 2.4088, 2.4372, 2.4609,
2.4731, 2.4847, 2.5072, 2.5451, 2.5772, 2.6147, 2.6529, 2.6596, 2.6645, 2.6726, 2.6803, 2.6812,
2.6899, 2.6916, 2.6931, 2.6998, 2.7062, 2.7262, 2.7222, 2.7158, 2.7041, 2.7485, 2.7491, 2.7451,
2.7485, 2.7233, 2.7297, 2.7233, 2.7145, 2.6958, 2.6788, 2.6439, 2.6007, 2.4786, 2.2469, 2.1877,
2.1392, 2.0717, 2.0107, 1.9676, 1.9140, 1.7102, 0.9101, 0.7164
]
class VAE(nn.Module):
def __init__(
self,
*,
data_dim: int,
embed_dim: int,
hidden_dim: int,
):
super().__init__()
if data_dim == 80:
self.data_mean = nn.Buffer(torch.tensor(DATA_MEAN_80D, dtype=torch.float32))
self.data_std = nn.Buffer(torch.tensor(DATA_STD_80D, dtype=torch.float32))
elif data_dim == 128:
self.data_mean = nn.Buffer(torch.tensor(DATA_MEAN_128D, dtype=torch.float32))
self.data_std = nn.Buffer(torch.tensor(DATA_STD_128D, dtype=torch.float32))
self.data_mean = self.data_mean.view(1, -1, 1)
self.data_std = self.data_std.view(1, -1, 1)
self.encoder = Encoder1D(
dim=hidden_dim,
ch_mult=(1, 2, 4),
num_res_blocks=2,
attn_layers=[3],
down_layers=[0],
in_dim=data_dim,
embed_dim=embed_dim,
)
self.decoder = Decoder1D(
dim=hidden_dim,
ch_mult=(1, 2, 4),
num_res_blocks=2,
attn_layers=[3],
down_layers=[0],
in_dim=data_dim,
out_dim=data_dim,
embed_dim=embed_dim,
)
self.embed_dim = embed_dim
# self.quant_conv = nn.Conv1d(2 * embed_dim, 2 * embed_dim, 1)
# self.post_quant_conv = nn.Conv1d(embed_dim, embed_dim, 1)
self.initialize_weights()
def initialize_weights(self):
pass
def encode(self, x: torch.Tensor, normalize: bool = True) -> DiagonalGaussianDistribution:
if normalize:
x = self.normalize(x)
moments = self.encoder(x)
posterior = DiagonalGaussianDistribution(moments)
return posterior
def decode(self, z: torch.Tensor, unnormalize: bool = True) -> torch.Tensor:
dec = self.decoder(z)
if unnormalize:
dec = self.unnormalize(dec)
return dec
def normalize(self, x: torch.Tensor) -> torch.Tensor:
return (x - comfy.model_management.cast_to(self.data_mean, dtype=x.dtype, device=x.device)) / comfy.model_management.cast_to(self.data_std, dtype=x.dtype, device=x.device)
def unnormalize(self, x: torch.Tensor) -> torch.Tensor:
return x * comfy.model_management.cast_to(self.data_std, dtype=x.dtype, device=x.device) + comfy.model_management.cast_to(self.data_mean, dtype=x.dtype, device=x.device)
def forward(
self,
x: torch.Tensor,
sample_posterior: bool = True,
rng: Optional[torch.Generator] = None,
normalize: bool = True,
unnormalize: bool = True,
) -> tuple[torch.Tensor, DiagonalGaussianDistribution]:
posterior = self.encode(x, normalize=normalize)
if sample_posterior:
z = posterior.sample(rng)
else:
z = posterior.mode()
dec = self.decode(z, unnormalize=unnormalize)
return dec, posterior
def load_weights(self, src_dict) -> None:
self.load_state_dict(src_dict, strict=True)
@property
def device(self) -> torch.device:
return next(self.parameters()).device
def get_last_layer(self):
return self.decoder.conv_out.weight
def remove_weight_norm(self):
return self
class Encoder1D(nn.Module):
def __init__(self,
*,
dim: int,
ch_mult: tuple[int] = (1, 2, 4, 8),
num_res_blocks: int,
attn_layers: list[int] = [],
down_layers: list[int] = [],
resamp_with_conv: bool = True,
in_dim: int,
embed_dim: int,
double_z: bool = True,
kernel_size: int = 3,
clip_act: float = 256.0):
super().__init__()
self.dim = dim
self.num_layers = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.in_channels = in_dim
self.clip_act = clip_act
self.down_layers = down_layers
self.attn_layers = attn_layers
self.conv_in = ops.Conv1d(in_dim, self.dim, kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
in_ch_mult = (1, ) + tuple(ch_mult)
self.in_ch_mult = in_ch_mult
# downsampling
self.down = nn.ModuleList()
for i_level in range(self.num_layers):
block = nn.ModuleList()
attn = nn.ModuleList()
block_in = dim * in_ch_mult[i_level]
block_out = dim * ch_mult[i_level]
for i_block in range(self.num_res_blocks):
block.append(
ResnetBlock1D(in_dim=block_in,
out_dim=block_out,
kernel_size=kernel_size,
use_norm=True))
block_in = block_out
if i_level in attn_layers:
attn.append(AttnBlock1D(block_in))
down = nn.Module()
down.block = block
down.attn = attn
if i_level in down_layers:
down.downsample = Downsample1D(block_in, resamp_with_conv)
self.down.append(down)
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock1D(in_dim=block_in,
out_dim=block_in,
kernel_size=kernel_size,
use_norm=True)
self.mid.attn_1 = AttnBlock1D(block_in)
self.mid.block_2 = ResnetBlock1D(in_dim=block_in,
out_dim=block_in,
kernel_size=kernel_size,
use_norm=True)
# end
self.conv_out = ops.Conv1d(block_in,
2 * embed_dim if double_z else embed_dim,
kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
self.learnable_gain = nn.Parameter(torch.zeros([]))
def forward(self, x):
# downsampling
h = self.conv_in(x)
for i_level in range(self.num_layers):
for i_block in range(self.num_res_blocks):
h = self.down[i_level].block[i_block](h)
if len(self.down[i_level].attn) > 0:
h = self.down[i_level].attn[i_block](h)
h = h.clamp(-self.clip_act, self.clip_act)
if i_level in self.down_layers:
h = self.down[i_level].downsample(h)
# middle
h = self.mid.block_1(h)
h = self.mid.attn_1(h)
h = self.mid.block_2(h)
h = h.clamp(-self.clip_act, self.clip_act)
# end
h = nonlinearity(h)
h = self.conv_out(h) * (self.learnable_gain + 1)
return h
class Decoder1D(nn.Module):
def __init__(self,
*,
dim: int,
out_dim: int,
ch_mult: tuple[int] = (1, 2, 4, 8),
num_res_blocks: int,
attn_layers: list[int] = [],
down_layers: list[int] = [],
kernel_size: int = 3,
resamp_with_conv: bool = True,
in_dim: int,
embed_dim: int,
clip_act: float = 256.0):
super().__init__()
self.ch = dim
self.num_layers = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.in_channels = in_dim
self.clip_act = clip_act
self.down_layers = [i + 1 for i in down_layers] # each downlayer add one
# compute in_ch_mult, block_in and curr_res at lowest res
block_in = dim * ch_mult[self.num_layers - 1]
# z to block_in
self.conv_in = ops.Conv1d(embed_dim, block_in, kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock1D(in_dim=block_in, out_dim=block_in, use_norm=True)
self.mid.attn_1 = AttnBlock1D(block_in)
self.mid.block_2 = ResnetBlock1D(in_dim=block_in, out_dim=block_in, use_norm=True)
# upsampling
self.up = nn.ModuleList()
for i_level in reversed(range(self.num_layers)):
block = nn.ModuleList()
attn = nn.ModuleList()
block_out = dim * ch_mult[i_level]
for i_block in range(self.num_res_blocks + 1):
block.append(ResnetBlock1D(in_dim=block_in, out_dim=block_out, use_norm=True))
block_in = block_out
if i_level in attn_layers:
attn.append(AttnBlock1D(block_in))
up = nn.Module()
up.block = block
up.attn = attn
if i_level in self.down_layers:
up.upsample = Upsample1D(block_in, resamp_with_conv)
self.up.insert(0, up) # prepend to get consistent order
# end
self.conv_out = ops.Conv1d(block_in, out_dim, kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
self.learnable_gain = nn.Parameter(torch.zeros([]))
def forward(self, z):
# z to block_in
h = self.conv_in(z)
# middle
h = self.mid.block_1(h)
h = self.mid.attn_1(h)
h = self.mid.block_2(h)
h = h.clamp(-self.clip_act, self.clip_act)
# upsampling
for i_level in reversed(range(self.num_layers)):
for i_block in range(self.num_res_blocks + 1):
h = self.up[i_level].block[i_block](h)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h)
h = h.clamp(-self.clip_act, self.clip_act)
if i_level in self.down_layers:
h = self.up[i_level].upsample(h)
h = nonlinearity(h)
h = self.conv_out(h) * (self.learnable_gain + 1)
return h
def VAE_16k(**kwargs) -> VAE:
return VAE(data_dim=80, embed_dim=20, hidden_dim=384, **kwargs)
def VAE_44k(**kwargs) -> VAE:
return VAE(data_dim=128, embed_dim=40, hidden_dim=512, **kwargs)
def get_my_vae(name: str, **kwargs) -> VAE:
if name == '16k':
return VAE_16k(**kwargs)
if name == '44k':
return VAE_44k(**kwargs)
raise ValueError(f'Unknown model: {name}')
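A shape sketch for the 16 kHz configuration (weights come from the checkpoint in practice; a fresh instance only demonstrates the tensor shapes): 80 mel bins go in, 20 latent channels come out, and the single down layer halves the time axis:

```python
import torch

vae = VAE_16k()
mel = torch.randn(1, 80, 64)            # (B, mel bins, frames)
posterior = vae.encode(mel)             # DiagonalGaussianDistribution over (1, 20, 32)
recon = vae.decode(posterior.mode())    # back to (1, 80, 64)
print(posterior.mode().shape, recon.shape)
```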

View File

@@ -0,0 +1,121 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from comfy.ldm.modules.diffusionmodules.model import vae_attention
import math
import comfy.ops
ops = comfy.ops.disable_weight_init
def nonlinearity(x):
# swish
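    # dividing by 0.596 (roughly the std of silu(x) for x ~ N(0, 1)) keeps the output near unit variance, i.e. a magnitude-preserving SiLU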
return torch.nn.functional.silu(x) / 0.596
def mp_sum(a, b, t=0.5):
return a.lerp(b, t) / math.sqrt((1 - t)**2 + t**2)
def normalize(x, dim=None, eps=1e-4):
if dim is None:
dim = list(range(1, x.ndim))
norm = torch.linalg.vector_norm(x, dim=dim, keepdim=True, dtype=torch.float32)
norm = torch.add(eps, norm, alpha=math.sqrt(norm.numel() / x.numel()))
return x / norm.to(x.dtype)
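These helpers appear to follow a magnitude-preserving recipe (the constants and formulas match EDM2-style networks). A quick check that mp_sum keeps unit variance and normalize() gives unit RMS along the chosen dim:

```python
import torch

a, b = torch.randn(4, 16, 128), torch.randn(4, 16, 128)
print(mp_sum(a, b, t=0.3).std())                              # ~1.0; a plain 0.7 * a + 0.3 * b has std ~0.76
print(normalize(b, dim=1).pow(2).mean(dim=1).sqrt().mean())   # ~1.0 per-position channel RMS
```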
class ResnetBlock1D(nn.Module):
def __init__(self, *, in_dim, out_dim=None, conv_shortcut=False, kernel_size=3, use_norm=True):
super().__init__()
self.in_dim = in_dim
out_dim = in_dim if out_dim is None else out_dim
self.out_dim = out_dim
self.use_conv_shortcut = conv_shortcut
self.use_norm = use_norm
self.conv1 = ops.Conv1d(in_dim, out_dim, kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
self.conv2 = ops.Conv1d(out_dim, out_dim, kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
if self.in_dim != self.out_dim:
if self.use_conv_shortcut:
self.conv_shortcut = ops.Conv1d(in_dim, out_dim, kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
else:
self.nin_shortcut = ops.Conv1d(in_dim, out_dim, kernel_size=1, padding=0, bias=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# pixel norm
if self.use_norm:
x = normalize(x, dim=1)
h = x
h = nonlinearity(h)
h = self.conv1(h)
h = nonlinearity(h)
h = self.conv2(h)
if self.in_dim != self.out_dim:
if self.use_conv_shortcut:
x = self.conv_shortcut(x)
else:
x = self.nin_shortcut(x)
return mp_sum(x, h, t=0.3)
class AttnBlock1D(nn.Module):
def __init__(self, in_channels, num_heads=1):
super().__init__()
self.in_channels = in_channels
self.num_heads = num_heads
self.qkv = ops.Conv1d(in_channels, in_channels * 3, kernel_size=1, padding=0, bias=False)
self.proj_out = ops.Conv1d(in_channels, in_channels, kernel_size=1, padding=0, bias=False)
self.optimized_attention = vae_attention()
def forward(self, x):
h = x
y = self.qkv(h)
y = y.reshape(y.shape[0], -1, 3, y.shape[-1])
q, k, v = normalize(y, dim=1).unbind(2)
h = self.optimized_attention(q, k, v)
h = self.proj_out(h)
return mp_sum(x, h, t=0.3)
class Upsample1D(nn.Module):
def __init__(self, in_channels, with_conv):
super().__init__()
self.with_conv = with_conv
if self.with_conv:
self.conv = ops.Conv1d(in_channels, in_channels, kernel_size=3, padding=1, bias=False)
def forward(self, x):
x = F.interpolate(x, scale_factor=2.0, mode='nearest-exact') # support 3D tensor(B,C,T)
if self.with_conv:
x = self.conv(x)
return x
class Downsample1D(nn.Module):
def __init__(self, in_channels, with_conv):
super().__init__()
self.with_conv = with_conv
if self.with_conv:
# no asymmetric padding in torch conv, must do it ourselves
self.conv1 = ops.Conv1d(in_channels, in_channels, kernel_size=1, padding=0, bias=False)
self.conv2 = ops.Conv1d(in_channels, in_channels, kernel_size=1, padding=0, bias=False)
def forward(self, x):
if self.with_conv:
x = self.conv1(x)
x = F.avg_pool1d(x, kernel_size=2, stride=2)
if self.with_conv:
x = self.conv2(x)
return x

View File

@@ -657,51 +657,51 @@ class WanVAE(nn.Module):
         )
 
     def encode(self, x):
-        self.clear_cache()
+        conv_idx = [0]
+        feat_map = [None] * count_conv3d(self.encoder)
         x = patchify(x, patch_size=2)
         t = x.shape[2]
         iter_ = 1 + (t - 1) // 4
         for i in range(iter_):
-            self._enc_conv_idx = [0]
+            conv_idx = [0]
             if i == 0:
                 out = self.encoder(
                     x[:, :, :1, :, :],
-                    feat_cache=self._enc_feat_map,
-                    feat_idx=self._enc_conv_idx,
+                    feat_cache=feat_map,
+                    feat_idx=conv_idx,
                 )
             else:
                 out_ = self.encoder(
                     x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
-                    feat_cache=self._enc_feat_map,
-                    feat_idx=self._enc_conv_idx,
+                    feat_cache=feat_map,
+                    feat_idx=conv_idx,
                 )
                 out = torch.cat([out, out_], 2)
         mu, log_var = self.conv1(out).chunk(2, dim=1)
-        self.clear_cache()
         return mu
 
     def decode(self, z):
-        self.clear_cache()
+        conv_idx = [0]
+        feat_map = [None] * count_conv3d(self.decoder)
         iter_ = z.shape[2]
         x = self.conv2(z)
         for i in range(iter_):
-            self._conv_idx = [0]
+            conv_idx = [0]
             if i == 0:
                 out = self.decoder(
                     x[:, :, i:i + 1, :, :],
-                    feat_cache=self._feat_map,
-                    feat_idx=self._conv_idx,
+                    feat_cache=feat_map,
+                    feat_idx=conv_idx,
                     first_chunk=True,
                 )
             else:
                 out_ = self.decoder(
                     x[:, :, i:i + 1, :, :],
-                    feat_cache=self._feat_map,
-                    feat_idx=self._conv_idx,
+                    feat_cache=feat_map,
+                    feat_idx=conv_idx,
                 )
                 out = torch.cat([out, out_], 2)
         out = unpatchify(out, patch_size=2)
-        self.clear_cache()
         return out
 
     def reparameterize(self, mu, log_var):
@@ -715,12 +715,3 @@ class WanVAE(nn.Module):
             return mu
         std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
         return mu + std * torch.randn_like(std)
-
-    def clear_cache(self):
-        self._conv_num = count_conv3d(self.decoder)
-        self._conv_idx = [0]
-        self._feat_map = [None] * self._conv_num
-        # cache encode
-        self._enc_conv_num = count_conv3d(self.encoder)
-        self._enc_conv_idx = [0]
-        self._enc_feat_map = [None] * self._enc_conv_num

View File

@@ -138,6 +138,7 @@ class BaseModel(torch.nn.Module):
             else:
                 operations = model_config.custom_operations
             self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
+            self.diffusion_model.eval()
             if comfy.model_management.force_channels_last():
                 self.diffusion_model.to(memory_format=torch.channels_last)
                 logging.debug("using channels last mode for diffusion model")
@@ -669,7 +670,6 @@ class Lotus(BaseModel):
 class StableCascade_C(BaseModel):
     def __init__(self, model_config, model_type=ModelType.STABLE_CASCADE, device=None):
         super().__init__(model_config, model_type, device=device, unet_model=StageC)
-        self.diffusion_model.eval().requires_grad_(False)
 
     def extra_conds(self, **kwargs):
         out = {}
@@ -698,7 +698,6 @@ class StableCascade_C(BaseModel):
 class StableCascade_B(BaseModel):
     def __init__(self, model_config, model_type=ModelType.STABLE_CASCADE, device=None):
         super().__init__(model_config, model_type, device=device, unet_model=StageB)
-        self.diffusion_model.eval().requires_grad_(False)
 
     def extra_conds(self, **kwargs):
         out = {}

View File

@@ -365,8 +365,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
         dit_config["patch_size"] = 2
         dit_config["in_channels"] = 16
         dit_config["dim"] = 2304
-        dit_config["cap_feat_dim"] = 2304
-        dit_config["n_layers"] = 26
+        dit_config["cap_feat_dim"] = state_dict['{}cap_embedder.1.weight'.format(key_prefix)].shape[1]
+        dit_config["n_layers"] = count_blocks(state_dict_keys, '{}layers.'.format(key_prefix) + '{}.')
         dit_config["n_heads"] = 24
         dit_config["n_kv_heads"] = 8
         dit_config["qk_norm"] = True

View File

@@ -332,6 +332,7 @@ except:
     SUPPORT_FP8_OPS = args.supports_fp8_compute
 try:
     if is_amd():
+        torch.backends.cudnn.enabled = False # Seems to improve things a lot on AMD
         try:
             rocm_version = tuple(map(int, str(torch.version.hip).split(".")[:2]))
         except:
@@ -344,9 +345,9 @@ try:
         if torch_version_numeric >= (2, 7):  # works on 2.6 but doesn't actually seem to improve much
             if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]):  # TODO: more arches, TODO: gfx950
                 ENABLE_PYTORCH_ATTENTION = True
-        # if torch_version_numeric >= (2, 8):
-        #     if any((a in arch) for a in ["gfx1201"]):
-        #         ENABLE_PYTORCH_ATTENTION = True
+        if rocm_version >= (7, 0):
+            if any((a in arch) for a in ["gfx1201"]):
+                ENABLE_PYTORCH_ATTENTION = True
         if torch_version_numeric >= (2, 7) and rocm_version >= (6, 4):
             if any((a in arch) for a in ["gfx1200", "gfx1201", "gfx942", "gfx950"]):  # TODO: more arches
                 SUPPORT_FP8_OPS = True
@@ -925,11 +926,7 @@ def vae_dtype(device=None, allowed_dtypes=[]):
         if d == torch.float16 and should_use_fp16(device):
             return d
-        # NOTE: bfloat16 seems to work on AMD for the VAE but is extremely slow in some cases compared to fp32
-        # slowness still a problem on pytorch nightly 2.9.0.dev20250720+rocm6.4 tested on RDNA3
-        # also a problem on RDNA4 except fp32 is also slow there.
-        # This is due to large bf16 convolutions being extremely slow.
-        if d == torch.bfloat16 and ((not is_amd()) or amd_min_version(device, min_rdna_version=4)) and should_use_bf16(device):
+        if d == torch.bfloat16 and should_use_bf16(device):
             return d
 
     return torch.float32

View File

@@ -123,16 +123,30 @@ def move_weight_functions(m, device):
     return memory
 
 class LowVramPatch:
-    def __init__(self, key, patches):
+    def __init__(self, key, patches, convert_func=None, set_func=None):
         self.key = key
         self.patches = patches
+        self.convert_func = convert_func
+        self.set_func = set_func
+
     def __call__(self, weight):
         intermediate_dtype = weight.dtype
+        if self.convert_func is not None:
+            weight = self.convert_func(weight.to(dtype=torch.float32, copy=True), inplace=True)
+
         if intermediate_dtype not in [torch.float32, torch.float16, torch.bfloat16]: #intermediate_dtype has to be one that is supported in math ops
             intermediate_dtype = torch.float32
-            return comfy.float.stochastic_rounding(comfy.lora.calculate_weight(self.patches[self.key], weight.to(intermediate_dtype), self.key, intermediate_dtype=intermediate_dtype), weight.dtype, seed=string_to_seed(self.key))
+            out = comfy.lora.calculate_weight(self.patches[self.key], weight.to(intermediate_dtype), self.key, intermediate_dtype=intermediate_dtype)
+            if self.set_func is None:
+                return comfy.float.stochastic_rounding(out, weight.dtype, seed=string_to_seed(self.key))
+            else:
+                return self.set_func(out, seed=string_to_seed(self.key), return_weight=True)
 
-        return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=intermediate_dtype)
+        out = comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=intermediate_dtype)
+        if self.set_func is not None:
+            return self.set_func(out, seed=string_to_seed(self.key), return_weight=True).to(dtype=intermediate_dtype)
+        else:
+            return out
 
 def get_key_weight(model, key):
     set_func = None
@@ -754,13 +768,15 @@ class ModelPatcher:
                     if force_patch_weights:
                         self.patch_weight_to_device(weight_key)
                     else:
-                        m.weight_function = [LowVramPatch(weight_key, self.patches)]
+                        _, set_func, convert_func = get_key_weight(self.model, weight_key)
+                        m.weight_function = [LowVramPatch(weight_key, self.patches, convert_func, set_func)]
                         patch_counter += 1
                 if bias_key in self.patches:
                     if force_patch_weights:
                         self.patch_weight_to_device(bias_key)
                     else:
-                        m.bias_function = [LowVramPatch(bias_key, self.patches)]
+                        _, set_func, convert_func = get_key_weight(self.model, bias_key)
+                        m.bias_function = [LowVramPatch(bias_key, self.patches, convert_func, set_func)]
                         patch_counter += 1
 
                 cast_weight = True
@@ -957,10 +973,12 @@ class ModelPatcher:
                 module_mem += move_weight_functions(m, device_to)
                 if lowvram_possible:
                     if weight_key in self.patches:
-                        m.weight_function.append(LowVramPatch(weight_key, self.patches))
+                        _, set_func, convert_func = get_key_weight(self.model, weight_key)
+                        m.weight_function.append(LowVramPatch(weight_key, self.patches, convert_func, set_func))
                         patch_counter += 1
                     if bias_key in self.patches:
-                        m.bias_function.append(LowVramPatch(bias_key, self.patches))
+                        _, set_func, convert_func = get_key_weight(self.model, bias_key)
+                        m.bias_function.append(LowVramPatch(bias_key, self.patches, convert_func, set_func))
                         patch_counter += 1
                     cast_weight = True

View File

@@ -21,17 +21,23 @@ def rescale_zero_terminal_snr_sigmas(sigmas):
     alphas_bar[-1] = 4.8973451890853435e-08
     return ((1 - alphas_bar) / alphas_bar) ** 0.5
 
+def reshape_sigma(sigma, noise_dim):
+    if sigma.nelement() == 1:
+        return sigma.view(())
+    else:
+        return sigma.view(sigma.shape[:1] + (1,) * (noise_dim - 1))
+
 class EPS:
     def calculate_input(self, sigma, noise):
-        sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
+        sigma = reshape_sigma(sigma, noise.ndim)
         return noise / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
 
     def calculate_denoised(self, sigma, model_output, model_input):
-        sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
+        sigma = reshape_sigma(sigma, model_output.ndim)
         return model_input - model_output * sigma
 
     def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
-        sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
+        sigma = reshape_sigma(sigma, noise.ndim)
         if max_denoise:
             noise = noise * torch.sqrt(1.0 + sigma ** 2.0)
         else:
@@ -45,12 +51,12 @@ class EPS:
 
 class V_PREDICTION(EPS):
     def calculate_denoised(self, sigma, model_output, model_input):
-        sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
+        sigma = reshape_sigma(sigma, model_output.ndim)
         return model_input * self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) - model_output * sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
 
 class EDM(V_PREDICTION):
     def calculate_denoised(self, sigma, model_output, model_input):
-        sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
+        sigma = reshape_sigma(sigma, model_output.ndim)
         return model_input * self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) + model_output * sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
 
 class CONST:
@@ -58,15 +64,15 @@ class CONST:
         return noise
 
     def calculate_denoised(self, sigma, model_output, model_input):
-        sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
+        sigma = reshape_sigma(sigma, model_output.ndim)
         return model_input - model_output * sigma
 
     def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
-        sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
+        sigma = reshape_sigma(sigma, noise.ndim)
         return sigma * noise + (1.0 - sigma) * latent_image
 
     def inverse_noise_scaling(self, sigma, latent):
-        sigma = sigma.view(sigma.shape[:1] + (1,) * (latent.ndim - 1))
+        sigma = reshape_sigma(sigma, latent.ndim)
         return latent / (1.0 - sigma)
 
 class X0(EPS):
@@ -80,16 +86,16 @@ class IMG_TO_IMG(X0):
 
 class COSMOS_RFLOW:
     def calculate_input(self, sigma, noise):
         sigma = (sigma / (sigma + 1))
-        sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
+        sigma = reshape_sigma(sigma, noise.ndim)
         return noise * (1.0 - sigma)
 
     def calculate_denoised(self, sigma, model_output, model_input):
         sigma = (sigma / (sigma + 1))
-        sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
+        sigma = reshape_sigma(sigma, model_output.ndim)
         return model_input * (1.0 - sigma) - model_output * sigma
 
     def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
-        sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
+        sigma = reshape_sigma(sigma, noise.ndim)
         noise = noise * sigma
         noise += latent_image
         return noise
View File
@@ -24,6 +24,8 @@ import comfy.float
import comfy.rmsnorm
import contextlib
+def run_every_op():
+    comfy.model_management.throw_exception_if_processing_interrupted()
def scaled_dot_product_attention(q, k, v, *args, **kwargs):
    return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)
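Each `forward` override in the hunks that follow gains a `run_every_op()` call, so a queued user interrupt is raised from inside the model rather than only between nodes. A minimal sketch of the pattern, assuming it runs alongside ComfyUI's `model_management` module (the `CastLinear` name is illustrative, not from the codebase):

```
import torch
import comfy.model_management

def run_every_op():
    # Raises an exception if the user requested an interrupt, so the current
    # sampling step aborts promptly instead of running to completion.
    comfy.model_management.throw_exception_if_processing_interrupted()

class CastLinear(torch.nn.Linear):  # illustrative subclass
    def forward(self, *args, **kwargs):
        run_every_op()  # checked on every op, so interrupts land mid-layer
        return super().forward(*args, **kwargs)
```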
@ -109,6 +111,7 @@ class disable_weight_init:
return torch.nn.functional.linear(input, weight, bias) return torch.nn.functional.linear(input, weight, bias)
def forward(self, *args, **kwargs): def forward(self, *args, **kwargs):
run_every_op()
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(*args, **kwargs) return self.forward_comfy_cast_weights(*args, **kwargs)
else: else:
@ -123,6 +126,7 @@ class disable_weight_init:
return self._conv_forward(input, weight, bias) return self._conv_forward(input, weight, bias)
def forward(self, *args, **kwargs): def forward(self, *args, **kwargs):
run_every_op()
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(*args, **kwargs) return self.forward_comfy_cast_weights(*args, **kwargs)
else: else:
@ -137,6 +141,7 @@ class disable_weight_init:
return self._conv_forward(input, weight, bias) return self._conv_forward(input, weight, bias)
def forward(self, *args, **kwargs): def forward(self, *args, **kwargs):
run_every_op()
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(*args, **kwargs) return self.forward_comfy_cast_weights(*args, **kwargs)
else: else:
@ -151,6 +156,7 @@ class disable_weight_init:
return self._conv_forward(input, weight, bias) return self._conv_forward(input, weight, bias)
def forward(self, *args, **kwargs): def forward(self, *args, **kwargs):
run_every_op()
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(*args, **kwargs) return self.forward_comfy_cast_weights(*args, **kwargs)
else: else:
@ -165,6 +171,7 @@ class disable_weight_init:
return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps) return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
def forward(self, *args, **kwargs): def forward(self, *args, **kwargs):
run_every_op()
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(*args, **kwargs) return self.forward_comfy_cast_weights(*args, **kwargs)
else: else:
@ -183,6 +190,7 @@ class disable_weight_init:
return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps) return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
def forward(self, *args, **kwargs): def forward(self, *args, **kwargs):
run_every_op()
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(*args, **kwargs) return self.forward_comfy_cast_weights(*args, **kwargs)
else: else:
@ -202,6 +210,7 @@ class disable_weight_init:
# return torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps) # return torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps)
def forward(self, *args, **kwargs): def forward(self, *args, **kwargs):
run_every_op()
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(*args, **kwargs) return self.forward_comfy_cast_weights(*args, **kwargs)
else: else:
@ -223,6 +232,7 @@ class disable_weight_init:
output_padding, self.groups, self.dilation) output_padding, self.groups, self.dilation)
def forward(self, *args, **kwargs): def forward(self, *args, **kwargs):
run_every_op()
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(*args, **kwargs) return self.forward_comfy_cast_weights(*args, **kwargs)
else: else:
@ -244,6 +254,7 @@ class disable_weight_init:
output_padding, self.groups, self.dilation) output_padding, self.groups, self.dilation)
def forward(self, *args, **kwargs): def forward(self, *args, **kwargs):
run_every_op()
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(*args, **kwargs) return self.forward_comfy_cast_weights(*args, **kwargs)
else: else:
@ -262,6 +273,7 @@ class disable_weight_init:
return torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype) return torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype)
def forward(self, *args, **kwargs): def forward(self, *args, **kwargs):
run_every_op()
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(*args, **kwargs) return self.forward_comfy_cast_weights(*args, **kwargs)
else: else:
@@ -416,8 +428,10 @@ def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None
            else:
                return weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype)
-        def set_weight(self, weight, inplace_update=False, seed=None, **kwargs):
+        def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
            weight = comfy.float.stochastic_rounding(weight / self.scale_weight.to(device=weight.device, dtype=weight.dtype), self.weight.dtype, seed=seed)
+            if return_weight:
+                return weight
            if inplace_update:
                self.weight.data.copy_(weight)
            else:
View File
@@ -18,6 +18,7 @@ import comfy.ldm.wan.vae2_2
import comfy.ldm.hunyuan3d.vae
import comfy.ldm.ace.vae.music_dcae_pipeline
import comfy.ldm.hunyuan_video.vae
+import comfy.ldm.mmaudio.vae.autoencoder
import comfy.pixel_space_convert
import yaml
import math
@@ -275,8 +276,13 @@ class VAE:
        if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
            sd = diffusers_convert.convert_vae_state_dict(sd)
-        self.memory_used_encode = lambda shape, dtype: (1767 * shape[2] * shape[3]) * model_management.dtype_size(dtype) #These are for AutoencoderKL and need tweaking (should be lower)
-        self.memory_used_decode = lambda shape, dtype: (2178 * shape[2] * shape[3] * 64) * model_management.dtype_size(dtype)
+        if model_management.is_amd():
+            VAE_KL_MEM_RATIO = 2.73
+        else:
+            VAE_KL_MEM_RATIO = 1.0
+        self.memory_used_encode = lambda shape, dtype: (1767 * shape[2] * shape[3]) * model_management.dtype_size(dtype) * VAE_KL_MEM_RATIO #These are for AutoencoderKL and need tweaking (should be lower)
+        self.memory_used_decode = lambda shape, dtype: (2178 * shape[2] * shape[3] * 64) * model_management.dtype_size(dtype) * VAE_KL_MEM_RATIO
        self.downscale_ratio = 8
        self.upscale_ratio = 8
        self.latent_channels = 4
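As a rough worked example of the estimate above, assuming `shape` is the latent shape (128x128 for a 1024x1024 image with the 8x downscale set just below) and fp16 weights so `dtype_size` is 2 bytes:

```
# Illustrative arithmetic only; the constants come from the hunk above.
bytes_non_amd = 2178 * 128 * 128 * 64 * 2     # ~4.6 GB decode estimate
bytes_amd = bytes_non_amd * 2.73              # ~12.5 GB with VAE_KL_MEM_RATIO on AMD
print(bytes_non_amd / 1e9, bytes_amd / 1e9)
```

The new ratio therefore only changes how much VRAM the scheduler reserves on AMD GPUs; the actual decode math is unchanged.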
@@ -291,6 +297,7 @@ class VAE:
        self.downscale_index_formula = None
        self.upscale_index_formula = None
        self.extra_1d_channel = None
+        self.crop_input = True
        if config is None:
            if "decoder.mid.block_1.mix_factor" in sd:
@@ -542,6 +549,25 @@ class VAE:
                self.latent_channels = 3
                self.latent_dim = 2
                self.output_channels = 3
+            elif "vocoder.activation_post.downsample.lowpass.filter" in sd: #MMAudio VAE
+                sample_rate = 16000
+                if sample_rate == 16000:
+                    mode = '16k'
+                else:
+                    mode = '44k'
+                self.first_stage_model = comfy.ldm.mmaudio.vae.autoencoder.AudioAutoencoder(mode=mode)
+                self.memory_used_encode = lambda shape, dtype: (30 * shape[2]) * model_management.dtype_size(dtype)
+                self.memory_used_decode = lambda shape, dtype: (90 * shape[2] * 1411.2) * model_management.dtype_size(dtype)
+                self.latent_channels = 20
+                self.output_channels = 2
+                self.upscale_ratio = 512 * (44100 / sample_rate)
+                self.downscale_ratio = 512 * (44100 / sample_rate)
+                self.latent_dim = 1
+                self.process_output = lambda audio: audio
+                self.process_input = lambda audio: audio
+                self.working_dtypes = [torch.float32]
+                self.crop_input = False
            else:
                logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
                self.first_stage_model = None
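A quick sanity check on the constants in the MMAudio branch above: with the hard-coded `sample_rate = 16000`, the scale factor works out to the same 1411.2 that appears in the decode memory estimate.

```
sample_rate = 16000                               # the mode hard-coded above
samples_per_latent = 512 * (44100 / sample_rate)  # == 1411.2, matching memory_used_decode
```

In other words, each 1D latent frame corresponds to roughly 1411 output samples at 44.1 kHz, which is why both `upscale_ratio` and the decode estimate use the same expression.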
@ -575,6 +601,9 @@ class VAE:
raise RuntimeError("ERROR: VAE is invalid: None\n\nIf the VAE is from a checkpoint loader node your checkpoint does not contain a valid VAE.") raise RuntimeError("ERROR: VAE is invalid: None\n\nIf the VAE is from a checkpoint loader node your checkpoint does not contain a valid VAE.")
def vae_encode_crop_pixels(self, pixels): def vae_encode_crop_pixels(self, pixels):
if not self.crop_input:
return pixels
downscale_ratio = self.spacial_compression_encode() downscale_ratio = self.spacial_compression_encode()
dims = pixels.shape[1:-1] dims = pixels.shape[1:-1]
@ -890,6 +919,7 @@ class TEModel(Enum):
QWEN25_3B = 10 QWEN25_3B = 10
QWEN25_7B = 11 QWEN25_7B = 11
BYT5_SMALL_GLYPH = 12 BYT5_SMALL_GLYPH = 12
GEMMA_3_4B = 13
def detect_te_model(sd): def detect_te_model(sd):
if "text_model.encoder.layers.30.mlp.fc1.weight" in sd: if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
@ -912,6 +942,8 @@ def detect_te_model(sd):
return TEModel.BYT5_SMALL_GLYPH return TEModel.BYT5_SMALL_GLYPH
return TEModel.T5_BASE return TEModel.T5_BASE
if 'model.layers.0.post_feedforward_layernorm.weight' in sd: if 'model.layers.0.post_feedforward_layernorm.weight' in sd:
if 'model.layers.0.self_attn.q_norm.weight' in sd:
return TEModel.GEMMA_3_4B
return TEModel.GEMMA_2_2B return TEModel.GEMMA_2_2B
if 'model.layers.0.self_attn.k_proj.bias' in sd: if 'model.layers.0.self_attn.k_proj.bias' in sd:
weight = sd['model.layers.0.self_attn.k_proj.bias'] weight = sd['model.layers.0.self_attn.k_proj.bias']
@ -1016,6 +1048,10 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data)) clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.lumina2.LuminaTokenizer clip_target.tokenizer = comfy.text_encoders.lumina2.LuminaTokenizer
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None) tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
elif te_model == TEModel.GEMMA_3_4B:
clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data), model_type="gemma3_4b")
clip_target.tokenizer = comfy.text_encoders.lumina2.NTokenizer
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
elif te_model == TEModel.LLAMA3_8: elif te_model == TEModel.LLAMA3_8:
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**llama_detect(clip_data), clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**llama_detect(clip_data),
clip_l=False, clip_g=False, t5=False, llama=True, dtype_t5=None, t5xxl_scaled_fp8=None) clip_l=False, clip_g=False, t5=False, llama=True, dtype_t5=None, t5xxl_scaled_fp8=None)
View File
@ -3,6 +3,7 @@ import torch.nn as nn
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional, Any from typing import Optional, Any
import math import math
import logging
from comfy.ldm.modules.attention import optimized_attention_for_device from comfy.ldm.modules.attention import optimized_attention_for_device
import comfy.model_management import comfy.model_management
@ -28,6 +29,9 @@ class Llama2Config:
mlp_activation = "silu" mlp_activation = "silu"
qkv_bias = False qkv_bias = False
rope_dims = None rope_dims = None
q_norm = None
k_norm = None
rope_scale = None
@dataclass @dataclass
class Qwen25_3BConfig: class Qwen25_3BConfig:
@ -46,6 +50,9 @@ class Qwen25_3BConfig:
mlp_activation = "silu" mlp_activation = "silu"
qkv_bias = True qkv_bias = True
rope_dims = None rope_dims = None
q_norm = None
k_norm = None
rope_scale = None
@dataclass @dataclass
class Qwen25_7BVLI_Config: class Qwen25_7BVLI_Config:
@ -64,6 +71,9 @@ class Qwen25_7BVLI_Config:
mlp_activation = "silu" mlp_activation = "silu"
qkv_bias = True qkv_bias = True
rope_dims = [16, 24, 24] rope_dims = [16, 24, 24]
q_norm = None
k_norm = None
rope_scale = None
@dataclass @dataclass
class Gemma2_2B_Config: class Gemma2_2B_Config:
@ -82,6 +92,32 @@ class Gemma2_2B_Config:
mlp_activation = "gelu_pytorch_tanh" mlp_activation = "gelu_pytorch_tanh"
qkv_bias = False qkv_bias = False
rope_dims = None rope_dims = None
q_norm = None
k_norm = None
sliding_attention = None
rope_scale = None
@dataclass
class Gemma3_4B_Config:
vocab_size: int = 262208
hidden_size: int = 2560
intermediate_size: int = 10240
num_hidden_layers: int = 34
num_attention_heads: int = 8
num_key_value_heads: int = 4
max_position_embeddings: int = 131072
rms_norm_eps: float = 1e-6
rope_theta = [10000.0, 1000000.0]
transformer_type: str = "gemma3"
head_dim = 256
rms_norm_add = True
mlp_activation = "gelu_pytorch_tanh"
qkv_bias = False
rope_dims = None
q_norm = "gemma3"
k_norm = "gemma3"
sliding_attention = [False, False, False, False, False, 1024]
rope_scale = [1.0, 8.0]
class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-5, add=False, device=None, dtype=None):
@@ -106,9 +142,20 @@ def rotate_half(x):
    return torch.cat((-x2, x1), dim=-1)
-def precompute_freqs_cis(head_dim, position_ids, theta, rope_dims=None, device=None):
+def precompute_freqs_cis(head_dim, position_ids, theta, rope_scale=None, rope_dims=None, device=None):
+    if not isinstance(theta, list):
+        theta = [theta]
+    out = []
+    for index, t in enumerate(theta):
-    theta_numerator = torch.arange(0, head_dim, 2, device=device).float()
-    inv_freq = 1.0 / (theta ** (theta_numerator / head_dim))
+        theta_numerator = torch.arange(0, head_dim, 2, device=device).float()
+        inv_freq = 1.0 / (t ** (theta_numerator / head_dim))
+        if rope_scale is not None:
+            if isinstance(rope_scale, list):
+                inv_freq /= rope_scale[index]
+            else:
+                inv_freq /= rope_scale
-    inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
-    position_ids_expanded = position_ids[:, None, :].float()
+        inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+        position_ids_expanded = position_ids[:, None, :].float()
@@ -123,8 +170,12 @@ def precompute_freqs_cis(head_dim, position_ids, theta, rope_dims=None, device=N
        else:
            cos = cos.unsqueeze(1)
            sin = sin.unsqueeze(1)
+        out.append((cos, sin))
-    return (cos, sin)
+    if len(out) == 1:
+        return out[0]
+    return out
def apply_rope(xq, xk, freqs_cis):
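With `Gemma3_4B_Config.rope_theta = [10000.0, 1000000.0]` and `rope_scale = [1.0, 8.0]`, the reworked `precompute_freqs_cis` now returns a list of two `(cos, sin)` pairs, one per theta; the Gemma 3 block below selects `freqs_cis[1]` for layers flagged with a sliding window and `freqs_cis[0]` otherwise. A hedged sketch of calling the patched function (the module path is assumed from the hunk context, and the shapes are illustrative):

```
import torch
from comfy.text_encoders.llama import precompute_freqs_cis  # assumed module path

position_ids = torch.arange(0, 16).unsqueeze(0)   # [batch=1, seq_len=16]
freqs = precompute_freqs_cis(256, position_ids,   # head_dim = 256 per Gemma3_4B_Config
                             theta=[10000.0, 1000000.0],
                             rope_scale=[1.0, 8.0])
# Two thetas -> a list of two (cos, sin) tuples; with a single scalar theta the
# function still returns one tuple, so existing Llama/Qwen callers are unchanged.
cos0, sin0 = freqs[0]
cos1, sin1 = freqs[1]
```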
@ -152,6 +203,14 @@ class Attention(nn.Module):
self.v_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=config.qkv_bias, device=device, dtype=dtype) self.v_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=config.qkv_bias, device=device, dtype=dtype)
self.o_proj = ops.Linear(self.inner_size, config.hidden_size, bias=False, device=device, dtype=dtype) self.o_proj = ops.Linear(self.inner_size, config.hidden_size, bias=False, device=device, dtype=dtype)
self.q_norm = None
self.k_norm = None
if config.q_norm == "gemma3":
self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
if config.k_norm == "gemma3":
self.k_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
@ -168,6 +227,11 @@ class Attention(nn.Module):
xk = xk.view(batch_size, seq_length, self.num_kv_heads, self.head_dim).transpose(1, 2) xk = xk.view(batch_size, seq_length, self.num_kv_heads, self.head_dim).transpose(1, 2)
xv = xv.view(batch_size, seq_length, self.num_kv_heads, self.head_dim).transpose(1, 2) xv = xv.view(batch_size, seq_length, self.num_kv_heads, self.head_dim).transpose(1, 2)
if self.q_norm is not None:
xq = self.q_norm(xq)
if self.k_norm is not None:
xk = self.k_norm(xk)
xq, xk = apply_rope(xq, xk, freqs_cis=freqs_cis) xq, xk = apply_rope(xq, xk, freqs_cis=freqs_cis)
xk = xk.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1) xk = xk.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
@@ -192,7 +256,7 @@ class MLP(nn.Module):
        return self.down_proj(self.activation(self.gate_proj(x)) * self.up_proj(x))
class TransformerBlock(nn.Module):
-    def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None):
+    def __init__(self, config: Llama2Config, index, device=None, dtype=None, ops: Any = None):
        super().__init__()
        self.self_attn = Attention(config, device=device, dtype=dtype, ops=ops)
        self.mlp = MLP(config, device=device, dtype=dtype, ops=ops)
@@ -226,7 +290,7 @@ class TransformerBlock(nn.Module):
        return x
class TransformerBlockGemma2(nn.Module):
-    def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None):
+    def __init__(self, config: Llama2Config, index, device=None, dtype=None, ops: Any = None):
        super().__init__()
        self.self_attn = Attention(config, device=device, dtype=dtype, ops=ops)
        self.mlp = MLP(config, device=device, dtype=dtype, ops=ops)
@@ -235,6 +299,13 @@ class TransformerBlockGemma2(nn.Module):
        self.pre_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
        self.post_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
+        if config.sliding_attention is not None: # TODO: implement. (Not that necessary since models are trained on less than 1024 tokens)
+            self.sliding_attention = config.sliding_attention[index % len(config.sliding_attention)]
+        else:
+            self.sliding_attention = False
+        self.transformer_type = config.transformer_type
    def forward(
        self,
        x: torch.Tensor,
@@ -242,6 +313,14 @@ class TransformerBlockGemma2(nn.Module):
        freqs_cis: Optional[torch.Tensor] = None,
        optimized_attention=None,
    ):
+        if self.transformer_type == 'gemma3':
+            if self.sliding_attention:
+                if x.shape[1] > self.sliding_attention:
+                    logging.warning("Warning: sliding attention not implemented, results may be incorrect")
+                freqs_cis = freqs_cis[1]
+            else:
+                freqs_cis = freqs_cis[0]
        # Self Attention
        residual = x
        x = self.input_layernorm(x)
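Putting the two additions together: with `sliding_attention = [False, False, False, False, False, 1024]` and `index % len(...)`, every sixth layer is flagged with a 1024-token sliding window and takes `freqs_cis[1]`, while all other layers take `freqs_cis[0]`. A small illustrative check of how the per-layer flags resolve for the 34-layer Gemma 3 4B config:

```
# Illustrative only; mirrors the index arithmetic in the hunk above.
sliding_attention = [False, False, False, False, False, 1024]
num_hidden_layers = 34
flags = [sliding_attention[i % len(sliding_attention)] for i in range(num_hidden_layers)]
# flags[5], flags[11], ... == 1024 -> those layers warn past 1024 tokens and use freqs_cis[1];
# every other layer stays False and uses freqs_cis[0].
```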
@@ -276,7 +355,7 @@ class Llama2_(nn.Module):
            device=device,
            dtype=dtype
        )
-        if self.config.transformer_type == "gemma2":
+        if self.config.transformer_type == "gemma2" or self.config.transformer_type == "gemma3":
            transformer = TransformerBlockGemma2
            self.normalize_in = True
        else:
@@ -284,8 +363,8 @@ class Llama2_(nn.Module):
            self.normalize_in = False
        self.layers = nn.ModuleList([
-            transformer(config, device=device, dtype=dtype, ops=ops)
-            for _ in range(config.num_hidden_layers)
+            transformer(config, index=i, device=device, dtype=dtype, ops=ops)
+            for i in range(config.num_hidden_layers)
        ])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
        # self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)
@ -305,6 +384,7 @@ class Llama2_(nn.Module):
freqs_cis = precompute_freqs_cis(self.config.head_dim, freqs_cis = precompute_freqs_cis(self.config.head_dim,
position_ids, position_ids,
self.config.rope_theta, self.config.rope_theta,
self.config.rope_scale,
self.config.rope_dims, self.config.rope_dims,
device=x.device) device=x.device)
@ -433,3 +513,12 @@ class Gemma2_2B(BaseLlama, torch.nn.Module):
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype self.dtype = dtype
class Gemma3_4B(BaseLlama, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = Gemma3_4B_Config(**config_dict)
self.num_layers = config.num_hidden_layers
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
View File
@ -11,23 +11,41 @@ class Gemma2BTokenizer(sd1_clip.SDTokenizer):
def state_dict(self): def state_dict(self):
return {"spiece_model": self.tokenizer.serialize_model()} return {"spiece_model": self.tokenizer.serialize_model()}
class Gemma3_4BTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer = tokenizer_data.get("spiece_model", None)
super().__init__(tokenizer, pad_with_end=False, embedding_size=2560, embedding_key='gemma3_4b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data)
def state_dict(self):
return {"spiece_model": self.tokenizer.serialize_model()}
class LuminaTokenizer(sd1_clip.SD1Tokenizer): class LuminaTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}): def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="gemma2_2b", tokenizer=Gemma2BTokenizer) super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="gemma2_2b", tokenizer=Gemma2BTokenizer)
class NTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="gemma3_4b", tokenizer=Gemma3_4BTokenizer)
class Gemma2_2BModel(sd1_clip.SDClipModel): class Gemma2_2BModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}): def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}):
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma2_2B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma2_2B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
class Gemma3_4BModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}):
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_4B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
class LuminaModel(sd1_clip.SD1ClipModel):
-    def __init__(self, device="cpu", dtype=None, model_options={}):
-        super().__init__(device=device, dtype=dtype, name="gemma2_2b", clip_model=Gemma2_2BModel, model_options=model_options)
+    def __init__(self, device="cpu", dtype=None, model_options={}, name="gemma2_2b", clip_model=Gemma2_2BModel):
+        super().__init__(device=device, dtype=dtype, name=name, clip_model=clip_model, model_options=model_options)
-def te(dtype_llama=None, llama_scaled_fp8=None):
+def te(dtype_llama=None, llama_scaled_fp8=None, model_type="gemma2_2b"):
+    if model_type == "gemma2_2b":
+        model = Gemma2_2BModel
+    elif model_type == "gemma3_4b":
+        model = Gemma3_4BModel
    class LuminaTEModel_(LuminaModel):
        def __init__(self, device="cpu", dtype=None, model_options={}):
            if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
@@ -35,5 +53,5 @@ def te(dtype_llama=None, llama_scaled_fp8=None):
                model_options["scaled_fp8"] = llama_scaled_fp8
            if dtype_llama is not None:
                dtype = dtype_llama
-            super().__init__(device=device, dtype=dtype, model_options=model_options)
+            super().__init__(device=device, dtype=dtype, name=model_type, model_options=model_options, clip_model=model)
    return LuminaTEModel_
View File
@@ -39,7 +39,11 @@ if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in
        pass
    ModelCheckpoint.__module__ = "pytorch_lightning.callbacks.model_checkpoint"
-    from numpy.core.multiarray import scalar
+    def scalar(*args, **kwargs):
+        from numpy.core.multiarray import scalar as sc
+        return sc(*args, **kwargs)
+    scalar.__module__ = "numpy.core.multiarray"
    from numpy import dtype
    from numpy.dtypes import Float64DType
    from _codecs import encode
View File
@@ -8,8 +8,8 @@ from comfy_api.internal.async_to_sync import create_sync_class
from comfy_api.latest._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput
from comfy_api.latest._input_impl import VideoFromFile, VideoFromComponents
from comfy_api.latest._util import VideoCodec, VideoContainer, VideoComponents
-from comfy_api.latest._io import _IO as io #noqa: F401
-from comfy_api.latest._ui import _UI as ui #noqa: F401
+from . import _io as io
+from . import _ui as ui
# from comfy_api.latest._resources import _RESOURCES as resources #noqa: F401
from comfy_execution.utils import get_executing_context
from comfy_execution.progress import get_progress_state, PreviewImageTuple
@@ -114,6 +114,8 @@ if TYPE_CHECKING:
    ComfyAPISync: Type[comfy_api.latest.generated.ComfyAPISyncStub.ComfyAPISyncStub]
ComfyAPISync = create_sync_class(ComfyAPI_latest)
+comfy_io = io # create the new alias for io
__all__ = [
    "ComfyAPI",
    "ComfyAPISync",
@@ -121,4 +123,7 @@ __all__ = [
    "InputImpl",
    "Types",
    "ComfyExtension",
+    "io",
+    "comfy_io",
+    "ui",
]
View File
@@ -1,6 +1,6 @@
from __future__ import annotations
from abc import ABC, abstractmethod
-from typing import Optional, Union
+from typing import Optional, Union, IO
import io
import av
from comfy_api.util import VideoContainer, VideoCodec, VideoComponents
@@ -23,7 +23,7 @@ class VideoInput(ABC):
    @abstractmethod
    def save_to(
        self,
-        path: str,
+        path: Union[str, IO[bytes]],
        format: VideoContainer = VideoContainer.AUTO,
        codec: VideoCodec = VideoCodec.AUTO,
        metadata: Optional[dict] = None
View File
@@ -336,11 +336,25 @@ class Combo(ComfyTypeIO):
    class Input(WidgetInput):
        """Combo input (dropdown)."""
        Type = str
-        def __init__(self, id: str, options: list[str]=None, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
-                     default: str=None, control_after_generate: bool=None,
-                     upload: UploadType=None, image_folder: FolderType=None,
-                     remote: RemoteOptions=None,
-                     socketless: bool=None):
+        def __init__(
+            self,
+            id: str,
+            options: list[str] | list[int] | type[Enum] = None,
+            display_name: str=None,
+            optional=False,
+            tooltip: str=None,
+            lazy: bool=None,
+            default: str | int | Enum = None,
+            control_after_generate: bool=None,
+            upload: UploadType=None,
+            image_folder: FolderType=None,
+            remote: RemoteOptions=None,
+            socketless: bool=None,
+        ):
+            if isinstance(options, type) and issubclass(options, Enum):
+                options = [v.value for v in options]
+            if isinstance(default, Enum):
+                default = default.value
            super().__init__(id, display_name, optional, tooltip, lazy, default, socketless)
            self.multiselect = False
            self.options = options
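With the change above, `Combo.Input` accepts an `Enum` class directly for `options` and an `Enum` member for `default`, unwrapping both to plain values. A hedged usage sketch (the enum and input id are made up for illustration):

```
from enum import Enum
from comfy_api.latest import io   # module alias set up in the __init__ hunk earlier

class SamplerChoice(str, Enum):   # hypothetical enum, not from the codebase
    EULER = "euler"
    DPMPP = "dpmpp_2m"

# The Enum class expands to its values and the Enum default unwraps to "euler".
combo = io.Combo.Input("sampler", options=SamplerChoice, default=SamplerChoice.EULER)
```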
@@ -1568,78 +1582,78 @@ class _UIOutput(ABC):
        ...
-class _IO:
-    FolderType = FolderType
-    UploadType = UploadType
-    RemoteOptions = RemoteOptions
-    NumberDisplay = NumberDisplay
-    comfytype = staticmethod(comfytype)
-    Custom = staticmethod(Custom)
-    Input = Input
-    WidgetInput = WidgetInput
-    Output = Output
-    ComfyTypeI = ComfyTypeI
-    ComfyTypeIO = ComfyTypeIO
-    #---------------------------------
-    # Supported Types
-    Boolean = Boolean
-    Int = Int
-    Float = Float
-    String = String
-    Combo = Combo
-    MultiCombo = MultiCombo
-    Image = Image
-    WanCameraEmbedding = WanCameraEmbedding
-    Webcam = Webcam
-    Mask = Mask
-    Latent = Latent
-    Conditioning = Conditioning
-    Sampler = Sampler
-    Sigmas = Sigmas
-    Noise = Noise
-    Guider = Guider
-    Clip = Clip
-    ControlNet = ControlNet
-    Vae = Vae
-    Model = Model
-    ClipVision = ClipVision
-    ClipVisionOutput = ClipVisionOutput
-    AudioEncoder = AudioEncoder
-    AudioEncoderOutput = AudioEncoderOutput
-    StyleModel = StyleModel
-    Gligen = Gligen
-    UpscaleModel = UpscaleModel
-    Audio = Audio
-    Video = Video
-    SVG = SVG
-    LoraModel = LoraModel
-    LossMap = LossMap
-    Voxel = Voxel
-    Mesh = Mesh
-    Hooks = Hooks
-    HookKeyframes = HookKeyframes
-    TimestepsRange = TimestepsRange
-    LatentOperation = LatentOperation
-    FlowControl = FlowControl
-    Accumulation = Accumulation
-    Load3DCamera = Load3DCamera
-    Load3D = Load3D
-    Load3DAnimation = Load3DAnimation
-    Photomaker = Photomaker
-    Point = Point
-    FaceAnalysis = FaceAnalysis
-    BBOX = BBOX
-    SEGS = SEGS
-    AnyType = AnyType
-    MultiType = MultiType
-    #---------------------------------
-    HiddenHolder = HiddenHolder
-    Hidden = Hidden
-    NodeInfoV1 = NodeInfoV1
-    NodeInfoV3 = NodeInfoV3
-    Schema = Schema
-    ComfyNode = ComfyNode
-    NodeOutput = NodeOutput
-    add_to_dict_v1 = staticmethod(add_to_dict_v1)
-    add_to_dict_v3 = staticmethod(add_to_dict_v3)
+__all__ = [
+    "FolderType",
+    "UploadType",
+    "RemoteOptions",
+    "NumberDisplay",
+    "comfytype",
+    "Custom",
+    "Input",
+    "WidgetInput",
+    "Output",
+    "ComfyTypeI",
+    "ComfyTypeIO",
+    # Supported Types
+    "Boolean",
+    "Int",
+    "Float",
+    "String",
+    "Combo",
+    "MultiCombo",
+    "Image",
+    "WanCameraEmbedding",
+    "Webcam",
+    "Mask",
+    "Latent",
+    "Conditioning",
+    "Sampler",
+    "Sigmas",
+    "Noise",
+    "Guider",
+    "Clip",
+    "ControlNet",
+    "Vae",
+    "Model",
+    "ClipVision",
+    "ClipVisionOutput",
+    "AudioEncoder",
+    "AudioEncoderOutput",
+    "StyleModel",
+    "Gligen",
+    "UpscaleModel",
+    "Audio",
+    "Video",
+    "SVG",
+    "LoraModel",
+    "LossMap",
+    "Voxel",
+    "Mesh",
+    "Hooks",
+    "HookKeyframes",
+    "TimestepsRange",
+    "LatentOperation",
+    "FlowControl",
+    "Accumulation",
+    "Load3DCamera",
+    "Load3D",
+    "Load3DAnimation",
+    "Photomaker",
+    "Point",
+    "FaceAnalysis",
+    "BBOX",
+    "SEGS",
+    "AnyType",
+    "MultiType",
+    # Other classes
+    "HiddenHolder",
+    "Hidden",
+    "NodeInfoV1",
+    "NodeInfoV3",
+    "Schema",
+    "ComfyNode",
+    "NodeOutput",
+    "add_to_dict_v1",
+    "add_to_dict_v3",
+]
View File
@@ -449,15 +449,16 @@ class PreviewText(_UIOutput):
        return {"text": (self.value,)}
-class _UI:
-    SavedResult = SavedResult
-    SavedImages = SavedImages
-    SavedAudios = SavedAudios
-    ImageSaveHelper = ImageSaveHelper
-    AudioSaveHelper = AudioSaveHelper
-    PreviewImage = PreviewImage
-    PreviewMask = PreviewMask
-    PreviewAudio = PreviewAudio
-    PreviewVideo = PreviewVideo
-    PreviewUI3D = PreviewUI3D
-    PreviewText = PreviewText
+__all__ = [
+    "SavedResult",
+    "SavedImages",
+    "SavedAudios",
+    "ImageSaveHelper",
+    "AudioSaveHelper",
+    "PreviewImage",
+    "PreviewMask",
+    "PreviewAudio",
+    "PreviewVideo",
+    "PreviewUI3D",
+    "PreviewText",
+]
View File
@ -18,7 +18,7 @@ from comfy_api_nodes.apis.client import (
UploadResponse, UploadResponse,
) )
from server import PromptServer from server import PromptServer
from comfy.cli_args import args
import numpy as np import numpy as np
from PIL import Image from PIL import Image
@@ -30,7 +30,9 @@ from io import BytesIO
import av
-async def download_url_to_video_output(video_url: str, timeout: int = None) -> VideoFromFile:
+async def download_url_to_video_output(
+    video_url: str, timeout: int = None, auth_kwargs: Optional[dict[str, str]] = None
+) -> VideoFromFile:
    """Downloads a video from a URL and returns a `VIDEO` output.
    Args:
@@ -39,7 +41,7 @@ async def download_url_to_video_output(video_url: str, timeout: int = None) -> V
    Returns:
        A Comfy node `VIDEO` output.
    """
-    video_io = await download_url_to_bytesio(video_url, timeout)
+    video_io = await download_url_to_bytesio(video_url, timeout, auth_kwargs=auth_kwargs)
    if video_io is None:
        error_msg = f"Failed to download video from {video_url}"
        logging.error(error_msg)
@@ -152,7 +154,7 @@ def validate_aspect_ratio(
        raise TypeError(
            f"Aspect ratio cannot reduce to any less than {minimum_ratio_str} ({minimum_ratio}), but was {aspect_ratio} ({calculated_ratio})."
        )
-    elif calculated_ratio > maximum_ratio:
+    if calculated_ratio > maximum_ratio:
        raise TypeError(
            f"Aspect ratio cannot reduce to any greater than {maximum_ratio_str} ({maximum_ratio}), but was {aspect_ratio} ({calculated_ratio})."
        )
@@ -164,7 +166,9 @@ def mimetype_to_extension(mime_type: str) -> str:
    return mime_type.split("/")[-1].lower()
-async def download_url_to_bytesio(url: str, timeout: int = None) -> BytesIO:
+async def download_url_to_bytesio(
+    url: str, timeout: int = None, auth_kwargs: Optional[dict[str, str]] = None
+) -> BytesIO:
    """Downloads content from a URL using requests and returns it as BytesIO.
    Args:
@@ -174,9 +178,18 @@ async def download_url_to_bytesio(url: str, timeout: int = None) -> BytesIO:
    Returns:
        BytesIO object containing the downloaded content.
    """
+    headers = {}
+    if url.startswith("/proxy/"):
+        url = str(args.comfy_api_base).rstrip("/") + url
+        auth_token = auth_kwargs.get("auth_token")
+        comfy_api_key = auth_kwargs.get("comfy_api_key")
+        if auth_token:
+            headers["Authorization"] = f"Bearer {auth_token}"
+        elif comfy_api_key:
+            headers["X-API-KEY"] = comfy_api_key
    timeout_cfg = aiohttp.ClientTimeout(total=timeout) if timeout else None
    async with aiohttp.ClientSession(timeout=timeout_cfg) as session:
-        async with session.get(url) as resp:
+        async with session.get(url, headers=headers) as resp:
            resp.raise_for_status() # Raises HTTPError for bad responses (4XX or 5XX)
            return BytesIO(await resp.read())
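The new `auth_kwargs` parameter lets `/proxy/...` results be fetched with the caller's Comfy API credentials. A hedged usage sketch, assuming it runs alongside the function patched above (the URL and token values are placeholders):

```
import asyncio

async def fetch():
    # auth_kwargs normally comes from the node's hidden auth inputs; shown inline here.
    buf = await download_url_to_bytesio(
        "/proxy/vidgen/result/1234",                    # placeholder proxy path
        timeout=60,
        auth_kwargs={"auth_token": "<bearer token>"},   # or {"comfy_api_key": "<key>"}
    )
    return buf.getvalue()

data = asyncio.run(fetch())
```

Non-proxy URLs take the same code path as before; the headers dict simply stays empty.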
@@ -256,7 +269,7 @@ def tensor_to_bytesio(
        mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp', 'video/mp4').
    Returns:
-        Named BytesIO object containing the image data.
+        Named BytesIO object containing the image data, with pointer set to the start of buffer.
    """
    if not mime_type:
        mime_type = "image/png"
@ -418,7 +431,7 @@ async def upload_video_to_comfyapi(
f"Video duration ({actual_duration:.2f}s) exceeds the maximum allowed ({max_duration}s)." f"Video duration ({actual_duration:.2f}s) exceeds the maximum allowed ({max_duration}s)."
) )
except Exception as e: except Exception as e:
logging.error(f"Error getting video duration: {e}") logging.error("Error getting video duration: %s", str(e))
raise ValueError(f"Could not verify video duration from source: {e}") from e raise ValueError(f"Could not verify video duration from source: {e}") from e
upload_mime_type = f"video/{container.value.lower()}" upload_mime_type = f"video/{container.value.lower()}"
View File
@@ -98,7 +98,7 @@ import io
import os
import socket
from aiohttp.client_exceptions import ClientError, ClientResponseError
-from typing import Dict, Type, Optional, Any, TypeVar, Generic, Callable, Tuple
+from typing import Type, Optional, Any, TypeVar, Generic, Callable
from enum import Enum
import json
from urllib.parse import urljoin, urlparse
@ -175,7 +175,7 @@ class ApiClient:
max_retries: int = 3, max_retries: int = 3,
retry_delay: float = 1.0, retry_delay: float = 1.0,
retry_backoff_factor: float = 2.0, retry_backoff_factor: float = 2.0,
retry_status_codes: Optional[Tuple[int, ...]] = None, retry_status_codes: Optional[tuple[int, ...]] = None,
session: Optional[aiohttp.ClientSession] = None, session: Optional[aiohttp.ClientSession] = None,
): ):
self.base_url = base_url self.base_url = base_url
@ -199,9 +199,9 @@ class ApiClient:
@staticmethod @staticmethod
def _create_json_payload_args( def _create_json_payload_args(
data: Optional[Dict[str, Any]] = None, data: Optional[dict[str, Any]] = None,
headers: Optional[Dict[str, str]] = None, headers: Optional[dict[str, str]] = None,
) -> Dict[str, Any]: ) -> dict[str, Any]:
return { return {
"json": data, "json": data,
"headers": headers, "headers": headers,
@@ -209,17 +209,20 @@ class ApiClient:
    def _create_form_data_args(
        self,
-        data: Dict[str, Any] | None,
-        files: Dict[str, Any] | None,
-        headers: Optional[Dict[str, str]] = None,
+        data: dict[str, Any] | None,
+        files: dict[str, Any] | None,
+        headers: Optional[dict[str, str]] = None,
        multipart_parser: Callable | None = None,
-    ) -> Dict[str, Any]:
+    ) -> dict[str, Any]:
        if headers and "Content-Type" in headers:
            del headers["Content-Type"]
        if multipart_parser and data:
            data = multipart_parser(data)
-        form = aiohttp.FormData(default_to_multipart=True)
+        if isinstance(data, aiohttp.FormData):
+            form = data # If the parser already returned a FormData, pass it through
+        else:
+            form = aiohttp.FormData(default_to_multipart=True)
        if data: # regular text fields
            for k, v in data.items():
@ -251,9 +254,9 @@ class ApiClient:
@staticmethod @staticmethod
def _create_urlencoded_form_data_args( def _create_urlencoded_form_data_args(
data: Dict[str, Any], data: dict[str, Any],
headers: Optional[Dict[str, str]] = None, headers: Optional[dict[str, str]] = None,
) -> Dict[str, Any]: ) -> dict[str, Any]:
headers = headers or {} headers = headers or {}
headers["Content-Type"] = "application/x-www-form-urlencoded" headers["Content-Type"] = "application/x-www-form-urlencoded"
return { return {
@ -261,7 +264,7 @@ class ApiClient:
"headers": headers, "headers": headers,
} }
def get_headers(self) -> Dict[str, str]: def get_headers(self) -> dict[str, str]:
"""Get headers for API requests, including authentication if available""" """Get headers for API requests, including authentication if available"""
headers = {"Content-Type": "application/json", "Accept": "application/json"} headers = {"Content-Type": "application/json", "Accept": "application/json"}
@ -272,7 +275,7 @@ class ApiClient:
return headers return headers
async def _check_connectivity(self, target_url: str) -> Dict[str, bool]: async def _check_connectivity(self, target_url: str) -> dict[str, bool]:
""" """
Check connectivity to determine if network issues are local or server-related. Check connectivity to determine if network issues are local or server-related.
@ -313,14 +316,14 @@ class ApiClient:
self, self,
method: str, method: str,
path: str, path: str,
params: Optional[Dict[str, Any]] = None, params: Optional[dict[str, Any]] = None,
data: Optional[Dict[str, Any]] = None, data: Optional[dict[str, Any]] = None,
files: Optional[Dict[str, Any] | list[tuple[str, Any]]] = None, files: Optional[dict[str, Any] | list[tuple[str, Any]]] = None,
headers: Optional[Dict[str, str]] = None, headers: Optional[dict[str, str]] = None,
content_type: str = "application/json", content_type: str = "application/json",
multipart_parser: Callable | None = None, multipart_parser: Callable | None = None,
retry_count: int = 0, # Used internally for tracking retries retry_count: int = 0, # Used internally for tracking retries
) -> Dict[str, Any]: ) -> dict[str, Any]:
""" """
Make an HTTP request to the API with automatic retries for transient errors. Make an HTTP request to the API with automatic retries for transient errors.
@ -356,10 +359,10 @@ class ApiClient:
if params: if params:
params = {k: v for k, v in params.items() if v is not None} # aiohttp fails to serialize None values params = {k: v for k, v in params.items() if v is not None} # aiohttp fails to serialize None values
logging.debug(f"[DEBUG] Request Headers: {request_headers}") logging.debug("[DEBUG] Request Headers: %s", request_headers)
logging.debug(f"[DEBUG] Files: {files}") logging.debug("[DEBUG] Files: %s", files)
logging.debug(f"[DEBUG] Params: {params}") logging.debug("[DEBUG] Params: %s", params)
logging.debug(f"[DEBUG] Data: {data}") logging.debug("[DEBUG] Data: %s", data)
if content_type == "application/x-www-form-urlencoded": if content_type == "application/x-www-form-urlencoded":
payload_args = self._create_urlencoded_form_data_args(data or {}, request_headers) payload_args = self._create_urlencoded_form_data_args(data or {}, request_headers)
@ -482,7 +485,7 @@ class ApiClient:
retry_delay: Initial delay between retries in seconds retry_delay: Initial delay between retries in seconds
retry_backoff_factor: Multiplier for the delay after each retry retry_backoff_factor: Multiplier for the delay after each retry
""" """
headers: Dict[str, str] = {} headers: dict[str, str] = {}
skip_auto_headers: set[str] = set() skip_auto_headers: set[str] = set()
if content_type: if content_type:
headers["Content-Type"] = content_type headers["Content-Type"] = content_type
@ -555,7 +558,7 @@ class ApiClient:
*req_meta, *req_meta,
retry_count: int, retry_count: int,
response_content: dict | str = "", response_content: dict | str = "",
) -> Dict[str, Any]: ) -> dict[str, Any]:
status_code = exc.status status_code = exc.status
if status_code == 401: if status_code == 401:
user_friendly = "Unauthorized: Please login first to use this node." user_friendly = "Unauthorized: Please login first to use this node."
@ -589,9 +592,9 @@ class ApiClient:
error_message=f"HTTP Error {exc.status}", error_message=f"HTTP Error {exc.status}",
) )
logging.debug(f"[DEBUG] API Error: {user_friendly} (Status: {status_code})") logging.debug("[DEBUG] API Error: %s (Status: %s)", user_friendly, status_code)
if response_content: if response_content:
logging.debug(f"[DEBUG] Response content: {response_content}") logging.debug("[DEBUG] Response content: %s", response_content)
# Retry if eligible # Retry if eligible
if status_code in self.retry_status_codes and retry_count < self.max_retries: if status_code in self.retry_status_codes and retry_count < self.max_retries:
@ -656,7 +659,7 @@ class ApiEndpoint(Generic[T, R]):
method: HttpMethod, method: HttpMethod,
request_model: Type[T], request_model: Type[T],
response_model: Type[R], response_model: Type[R],
query_params: Optional[Dict[str, Any]] = None, query_params: Optional[dict[str, Any]] = None,
): ):
"""Initialize an API endpoint definition. """Initialize an API endpoint definition.
@ -681,11 +684,11 @@ class SynchronousOperation(Generic[T, R]):
self, self,
endpoint: ApiEndpoint[T, R], endpoint: ApiEndpoint[T, R],
request: T, request: T,
files: Optional[Dict[str, Any] | list[tuple[str, Any]]] = None, files: Optional[dict[str, Any] | list[tuple[str, Any]]] = None,
api_base: str | None = None, api_base: str | None = None,
auth_token: Optional[str] = None, auth_token: Optional[str] = None,
comfy_api_key: Optional[str] = None, comfy_api_key: Optional[str] = None,
auth_kwargs: Optional[Dict[str, str]] = None, auth_kwargs: Optional[dict[str, str]] = None,
timeout: float = 7200.0, timeout: float = 7200.0,
verify_ssl: bool = True, verify_ssl: bool = True,
content_type: str = "application/json", content_type: str = "application/json",
@ -726,7 +729,7 @@ class SynchronousOperation(Generic[T, R]):
) )
try: try:
request_dict: Optional[Dict[str, Any]] request_dict: Optional[dict[str, Any]]
if isinstance(self.request, EmptyRequest): if isinstance(self.request, EmptyRequest):
request_dict = None request_dict = None
else: else:
@ -735,11 +738,9 @@ class SynchronousOperation(Generic[T, R]):
if isinstance(v, Enum): if isinstance(v, Enum):
request_dict[k] = v.value request_dict[k] = v.value
logging.debug( logging.debug("[DEBUG] API Request: %s %s", self.endpoint.method.value, self.endpoint.path)
f"[DEBUG] API Request: {self.endpoint.method.value} {self.endpoint.path}" logging.debug("[DEBUG] Request Data: %s", json.dumps(request_dict, indent=2))
) logging.debug("[DEBUG] Query Params: %s", self.endpoint.query_params)
logging.debug(f"[DEBUG] Request Data: {json.dumps(request_dict, indent=2)}")
logging.debug(f"[DEBUG] Query Params: {self.endpoint.query_params}")
response_json = await client.request( response_json = await client.request(
self.endpoint.method.value, self.endpoint.method.value,
@ -754,11 +755,11 @@ class SynchronousOperation(Generic[T, R]):
logging.debug("=" * 50) logging.debug("=" * 50)
logging.debug("[DEBUG] RESPONSE DETAILS:") logging.debug("[DEBUG] RESPONSE DETAILS:")
logging.debug("[DEBUG] Status Code: 200 (Success)") logging.debug("[DEBUG] Status Code: 200 (Success)")
logging.debug(f"[DEBUG] Response Body: {json.dumps(response_json, indent=2)}") logging.debug("[DEBUG] Response Body: %s", json.dumps(response_json, indent=2))
logging.debug("=" * 50) logging.debug("=" * 50)
parsed_response = self.endpoint.response_model.model_validate(response_json) parsed_response = self.endpoint.response_model.model_validate(response_json)
logging.debug(f"[DEBUG] Parsed Response: {parsed_response}") logging.debug("[DEBUG] Parsed Response: %s", parsed_response)
return parsed_response return parsed_response
finally: finally:
if owns_client: if owns_client:
@@ -781,14 +782,16 @@ class PollingOperation(Generic[T, R]):
        poll_endpoint: ApiEndpoint[EmptyRequest, R],
        completed_statuses: list[str],
        failed_statuses: list[str],
-        status_extractor: Callable[[R], str],
-        progress_extractor: Callable[[R], float] | None = None,
-        result_url_extractor: Callable[[R], str] | None = None,
+        *,
+        status_extractor: Callable[[R], Optional[str]],
+        progress_extractor: Callable[[R], Optional[float]] | None = None,
+        result_url_extractor: Callable[[R], Optional[str]] | None = None,
+        price_extractor: Callable[[R], Optional[float]] | None = None,
        request: Optional[T] = None,
        api_base: str | None = None,
        auth_token: Optional[str] = None,
        comfy_api_key: Optional[str] = None,
-        auth_kwargs: Optional[Dict[str, str]] = None,
+        auth_kwargs: Optional[dict[str, str]] = None,
        poll_interval: float = 5.0,
        max_poll_attempts: int = 120, # Default max polling attempts (10 minutes with 5s interval)
        max_retries: int = 3, # Max retries per individual API call
@ -814,10 +817,12 @@ class PollingOperation(Generic[T, R]):
self.status_extractor = status_extractor or (lambda x: getattr(x, "status", None)) self.status_extractor = status_extractor or (lambda x: getattr(x, "status", None))
self.progress_extractor = progress_extractor self.progress_extractor = progress_extractor
self.result_url_extractor = result_url_extractor self.result_url_extractor = result_url_extractor
self.price_extractor = price_extractor
self.node_id = node_id self.node_id = node_id
self.completed_statuses = completed_statuses self.completed_statuses = completed_statuses
self.failed_statuses = failed_statuses self.failed_statuses = failed_statuses
self.final_response: Optional[R] = None self.final_response: Optional[R] = None
self.extracted_price: Optional[float] = None
async def execute(self, client: Optional[ApiClient] = None) -> R: async def execute(self, client: Optional[ApiClient] = None) -> R:
owns_client = client is None owns_client = client is None
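Callers now pass the extractor callbacks as keyword arguments, and can supply a `price_extractor` whose value is prepended to the node's progress text as "Price: X$". A hedged construction sketch; the endpoint, response fields, and surrounding variables (`poll_endpoint`, `auth_kwargs`, `node_id`) are placeholders for whatever the calling node already has:

```
# Runs inside an async API-node execute() method; response fields are hypothetical.
op = PollingOperation(
    poll_endpoint=poll_endpoint,
    completed_statuses=["succeeded"],
    failed_statuses=["failed", "canceled"],
    status_extractor=lambda r: r.status,
    progress_extractor=lambda r: getattr(r, "progress", None),
    price_extractor=lambda r: getattr(r, "price", None),   # shown on the node when present
    auth_kwargs=auth_kwargs,
    node_id=node_id,
)
response = await op.execute()
```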
@ -839,6 +844,8 @@ class PollingOperation(Generic[T, R]):
def _display_text_on_node(self, text: str): def _display_text_on_node(self, text: str):
if not self.node_id: if not self.node_id:
return return
if self.extracted_price is not None:
text = f"Price: {self.extracted_price}$\n{text}"
PromptServer.instance.send_progress_text(text, self.node_id) PromptServer.instance.send_progress_text(text, self.node_id)
def _display_time_progress_on_node(self, time_completed: int | float): def _display_time_progress_on_node(self, time_completed: int | float):
@ -874,18 +881,19 @@ class PollingOperation(Generic[T, R]):
status = TaskStatus.PENDING status = TaskStatus.PENDING
for poll_count in range(1, self.max_poll_attempts + 1): for poll_count in range(1, self.max_poll_attempts + 1):
try: try:
logging.debug(f"[DEBUG] Polling attempt #{poll_count}") logging.debug("[DEBUG] Polling attempt #%s", poll_count)
request_dict = ( request_dict = None if self.request is None else self.request.model_dump(exclude_none=True)
None if self.request is None else self.request.model_dump(exclude_none=True)
)
if poll_count == 1: if poll_count == 1:
logging.debug( logging.debug(
f"[DEBUG] Poll Request: {self.poll_endpoint.method.value} {self.poll_endpoint.path}" "[DEBUG] Poll Request: %s %s",
self.poll_endpoint.method.value,
self.poll_endpoint.path,
) )
logging.debug( logging.debug(
f"[DEBUG] Poll Request Data: {json.dumps(request_dict, indent=2) if request_dict else 'None'}" "[DEBUG] Poll Request Data: %s",
json.dumps(request_dict, indent=2) if request_dict else "None",
) )
# Query task status # Query task status
@ -900,7 +908,7 @@ class PollingOperation(Generic[T, R]):
# Check if task is complete # Check if task is complete
status = self._check_task_status(response_obj) status = self._check_task_status(response_obj)
logging.debug(f"[DEBUG] Task Status: {status}") logging.debug("[DEBUG] Task Status: %s", status)
# If progress extractor is provided, extract progress # If progress extractor is provided, extract progress
if self.progress_extractor: if self.progress_extractor:
@ -908,13 +916,18 @@ class PollingOperation(Generic[T, R]):
if new_progress is not None: if new_progress is not None:
progress.update_absolute(new_progress, total=PROGRESS_BAR_MAX) progress.update_absolute(new_progress, total=PROGRESS_BAR_MAX)
if self.price_extractor:
price = self.price_extractor(response_obj)
if price is not None:
self.extracted_price = price
if status == TaskStatus.COMPLETED: if status == TaskStatus.COMPLETED:
message = "Task completed successfully" message = "Task completed successfully"
if self.result_url_extractor: if self.result_url_extractor:
result_url = self.result_url_extractor(response_obj) result_url = self.result_url_extractor(response_obj)
if result_url: if result_url:
message = f"Result URL: {result_url}" message = f"Result URL: {result_url}"
logging.debug(f"[DEBUG] {message}") logging.debug("[DEBUG] %s", message)
self._display_text_on_node(message) self._display_text_on_node(message)
self.final_response = response_obj self.final_response = response_obj
if self.progress_extractor: if self.progress_extractor:
@ -922,7 +935,7 @@ class PollingOperation(Generic[T, R]):
return self.final_response return self.final_response
if status == TaskStatus.FAILED: if status == TaskStatus.FAILED:
message = f"Task failed: {json.dumps(resp)}" message = f"Task failed: {json.dumps(resp)}"
logging.error(f"[DEBUG] {message}") logging.error("[DEBUG] %s", message)
raise Exception(message) raise Exception(message)
logging.debug("[DEBUG] Task still pending, continuing to poll...") logging.debug("[DEBUG] Task still pending, continuing to poll...")
# Task pending wait # Task pending wait
@ -936,7 +949,12 @@ class PollingOperation(Generic[T, R]):
raise Exception( raise Exception(
f"Polling aborted after {consecutive_errors} network errors: {str(e)}" f"Polling aborted after {consecutive_errors} network errors: {str(e)}"
) from e ) from e
logging.warning("Network error (%s/%s): %s", consecutive_errors, max_consecutive_errors, str(e)) logging.warning(
"Network error (%s/%s): %s",
consecutive_errors,
max_consecutive_errors,
str(e),
)
await asyncio.sleep(self.poll_interval) await asyncio.sleep(self.poll_interval)
except Exception as e: except Exception as e:
# For other errors, increment count and potentially abort # For other errors, increment count and potentially abort
@ -946,10 +964,13 @@ class PollingOperation(Generic[T, R]):
f"Polling aborted after {consecutive_errors} consecutive errors: {str(e)}" f"Polling aborted after {consecutive_errors} consecutive errors: {str(e)}"
) from e ) from e
logging.error(f"[DEBUG] Polling error: {str(e)}") logging.error("[DEBUG] Polling error: %s", str(e))
logging.warning( logging.warning(
f"Error during polling (attempt {poll_count}/{self.max_poll_attempts}): {str(e)}. " "Error during polling (attempt %s/%s): %s. Will retry in %s seconds.",
f"Will retry in {self.poll_interval} seconds." poll_count,
self.max_poll_attempts,
str(e),
self.poll_interval,
) )
await asyncio.sleep(self.poll_interval) await asyncio.sleep(self.poll_interval)
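The hunk above also threads an optional price_extractor through the polling loop, alongside the existing progress and result-URL extractors: each poll response is passed to it, and the last non-None value is kept in extracted_price. A minimal sketch of such a callable follows; the response field names here are placeholders for illustration, not the real API schema.

def example_price_extractor(response_obj) -> float | None:
    # Hypothetical fields: report the price once the backend includes it,
    # or return None so the previously extracted value is kept.
    usage = getattr(response_obj, "usage", None)
    return getattr(usage, "price", None) if usage is not None else None

# Assumed wiring: PollingOperation(..., price_extractor=example_price_extractor)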

View File

@ -1,19 +1,22 @@
from __future__ import annotations from typing import Optional
from typing import List, Optional
from comfy_api_nodes.apis import GeminiGenerationConfig, GeminiContent, GeminiSafetySetting, GeminiSystemInstructionContent, GeminiTool, GeminiVideoMetadata from comfy_api_nodes.apis import GeminiGenerationConfig, GeminiContent, GeminiSafetySetting, GeminiSystemInstructionContent, GeminiTool, GeminiVideoMetadata
from pydantic import BaseModel from pydantic import BaseModel
class GeminiImageConfig(BaseModel):
aspectRatio: Optional[str] = None
class GeminiImageGenerationConfig(GeminiGenerationConfig): class GeminiImageGenerationConfig(GeminiGenerationConfig):
responseModalities: Optional[List[str]] = None responseModalities: Optional[list[str]] = None
imageConfig: Optional[GeminiImageConfig] = None
class GeminiImageGenerateContentRequest(BaseModel): class GeminiImageGenerateContentRequest(BaseModel):
contents: List[GeminiContent] contents: list[GeminiContent]
generationConfig: Optional[GeminiImageGenerationConfig] = None generationConfig: Optional[GeminiImageGenerationConfig] = None
safetySettings: Optional[List[GeminiSafetySetting]] = None safetySettings: Optional[list[GeminiSafetySetting]] = None
systemInstruction: Optional[GeminiSystemInstructionContent] = None systemInstruction: Optional[GeminiSystemInstructionContent] = None
tools: Optional[List[GeminiTool]] = None tools: Optional[list[GeminiTool]] = None
videoMetadata: Optional[GeminiVideoMetadata] = None videoMetadata: Optional[GeminiVideoMetadata] = None

View File

@ -0,0 +1,100 @@
from typing import Optional
from enum import Enum
from pydantic import BaseModel, Field
class Pikaffect(str, Enum):
Cake_ify = "Cake-ify"
Crumble = "Crumble"
Crush = "Crush"
Decapitate = "Decapitate"
Deflate = "Deflate"
Dissolve = "Dissolve"
Explode = "Explode"
Eye_pop = "Eye-pop"
Inflate = "Inflate"
Levitate = "Levitate"
Melt = "Melt"
Peel = "Peel"
Poke = "Poke"
Squish = "Squish"
Ta_da = "Ta-da"
Tear = "Tear"
class PikaBodyGenerate22C2vGenerate22PikascenesPost(BaseModel):
aspectRatio: Optional[float] = Field(None, description='Aspect ratio (width / height)')
duration: Optional[int] = Field(5)
ingredientsMode: str = Field(...)
negativePrompt: Optional[str] = Field(None)
promptText: Optional[str] = Field(None)
resolution: Optional[str] = Field('1080p')
seed: Optional[int] = Field(None)
class PikaGenerateResponse(BaseModel):
video_id: str = Field(...)
class PikaBodyGenerate22I2vGenerate22I2vPost(BaseModel):
duration: Optional[int] = 5
negativePrompt: Optional[str] = Field(None)
promptText: Optional[str] = Field(None)
resolution: Optional[str] = '1080p'
seed: Optional[int] = Field(None)
class PikaBodyGenerate22KeyframeGenerate22PikaframesPost(BaseModel):
duration: Optional[int] = Field(None, ge=5, le=10)
negativePrompt: Optional[str] = Field(None)
promptText: str = Field(...)
resolution: Optional[str] = '1080p'
seed: Optional[int] = Field(None)
class PikaBodyGenerate22T2vGenerate22T2vPost(BaseModel):
aspectRatio: Optional[float] = Field(
1.7777777777777777,
description='Aspect ratio (width / height)',
ge=0.4,
le=2.5,
)
duration: Optional[int] = 5
negativePrompt: Optional[str] = Field(None)
promptText: str = Field(...)
resolution: Optional[str] = '1080p'
seed: Optional[int] = Field(None)
class PikaBodyGeneratePikadditionsGeneratePikadditionsPost(BaseModel):
negativePrompt: Optional[str] = Field(None)
promptText: Optional[str] = Field(None)
seed: Optional[int] = Field(None)
class PikaBodyGeneratePikaffectsGeneratePikaffectsPost(BaseModel):
negativePrompt: Optional[str] = Field(None)
pikaffect: Optional[str] = None
promptText: Optional[str] = Field(None)
seed: Optional[int] = Field(None)
class PikaBodyGeneratePikaswapsGeneratePikaswapsPost(BaseModel):
negativePrompt: Optional[str] = Field(None)
promptText: Optional[str] = Field(None)
seed: Optional[int] = Field(None)
modifyRegionRoi: Optional[str] = Field(None)
class PikaStatusEnum(str, Enum):
queued = "queued"
started = "started"
finished = "finished"
failed = "failed"
class PikaVideoResponse(BaseModel):
id: str = Field(...)
progress: Optional[int] = Field(None)
status: PikaStatusEnum
url: Optional[str] = Field(None)

View File

@ -21,7 +21,7 @@ def get_log_directory():
try: try:
os.makedirs(log_dir, exist_ok=True) os.makedirs(log_dir, exist_ok=True)
except Exception as e: except Exception as e:
logger.error(f"Error creating API log directory {log_dir}: {e}") logger.error("Error creating API log directory %s: %s", log_dir, str(e))
# Fallback to base temp directory if sub-directory creation fails # Fallback to base temp directory if sub-directory creation fails
return base_temp_dir return base_temp_dir
return log_dir return log_dir
@ -122,9 +122,9 @@ def log_request_response(
try: try:
with open(filepath, "w", encoding="utf-8") as f: with open(filepath, "w", encoding="utf-8") as f:
f.write("\n".join(log_content)) f.write("\n".join(log_content))
logger.debug(f"API log saved to: {filepath}") logger.debug("API log saved to: %s", filepath)
except Exception as e: except Exception as e:
logger.error(f"Error writing API log to {filepath}: {e}") logger.error("Error writing API log to %s: %s", filepath, str(e))
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -249,8 +249,8 @@ class ByteDanceImageNode(comfy_io.ComfyNode):
inputs=[ inputs=[
comfy_io.Combo.Input( comfy_io.Combo.Input(
"model", "model",
options=[model.value for model in Text2ImageModelName], options=Text2ImageModelName,
default=Text2ImageModelName.seedream_3.value, default=Text2ImageModelName.seedream_3,
tooltip="Model name", tooltip="Model name",
), ),
comfy_io.String.Input( comfy_io.String.Input(
@ -382,8 +382,8 @@ class ByteDanceImageEditNode(comfy_io.ComfyNode):
inputs=[ inputs=[
comfy_io.Combo.Input( comfy_io.Combo.Input(
"model", "model",
options=[model.value for model in Image2ImageModelName], options=Image2ImageModelName,
default=Image2ImageModelName.seededit_3.value, default=Image2ImageModelName.seededit_3,
tooltip="Model name", tooltip="Model name",
), ),
comfy_io.Image.Input( comfy_io.Image.Input(
@ -676,8 +676,8 @@ class ByteDanceTextToVideoNode(comfy_io.ComfyNode):
inputs=[ inputs=[
comfy_io.Combo.Input( comfy_io.Combo.Input(
"model", "model",
options=[model.value for model in Text2VideoModelName], options=Text2VideoModelName,
default=Text2VideoModelName.seedance_1_pro.value, default=Text2VideoModelName.seedance_1_pro,
tooltip="Model name", tooltip="Model name",
), ),
comfy_io.String.Input( comfy_io.String.Input(
@ -793,8 +793,8 @@ class ByteDanceImageToVideoNode(comfy_io.ComfyNode):
inputs=[ inputs=[
comfy_io.Combo.Input( comfy_io.Combo.Input(
"model", "model",
options=[model.value for model in Image2VideoModelName], options=Image2VideoModelName,
default=Image2VideoModelName.seedance_1_pro.value, default=Image2VideoModelName.seedance_1_pro,
tooltip="Model name", tooltip="Model name",
), ),
comfy_io.String.Input( comfy_io.String.Input(

View File

@ -26,7 +26,7 @@ from comfy_api_nodes.apis import (
GeminiPart, GeminiPart,
GeminiMimeType, GeminiMimeType,
) )
from comfy_api_nodes.apis.gemini_api import GeminiImageGenerationConfig, GeminiImageGenerateContentRequest from comfy_api_nodes.apis.gemini_api import GeminiImageGenerationConfig, GeminiImageGenerateContentRequest, GeminiImageConfig
from comfy_api_nodes.apis.client import ( from comfy_api_nodes.apis.client import (
ApiEndpoint, ApiEndpoint,
HttpMethod, HttpMethod,
@ -39,6 +39,7 @@ from comfy_api_nodes.apinode_utils import (
tensor_to_base64_string, tensor_to_base64_string,
bytesio_to_image_tensor, bytesio_to_image_tensor,
) )
from comfy_api.util import VideoContainer, VideoCodec
GEMINI_BASE_ENDPOINT = "/proxy/vertexai/gemini" GEMINI_BASE_ENDPOINT = "/proxy/vertexai/gemini"
@ -62,6 +63,7 @@ class GeminiImageModel(str, Enum):
""" """
gemini_2_5_flash_image_preview = "gemini-2.5-flash-image-preview" gemini_2_5_flash_image_preview = "gemini-2.5-flash-image-preview"
gemini_2_5_flash_image = "gemini-2.5-flash-image"
def get_gemini_endpoint( def get_gemini_endpoint(
@ -310,7 +312,7 @@ class GeminiNode(ComfyNodeABC):
Returns: Returns:
List of GeminiPart objects containing the encoded video. List of GeminiPart objects containing the encoded video.
""" """
from comfy_api.util import VideoContainer, VideoCodec
base_64_string = video_to_base64_string( base_64_string = video_to_base64_string(
video_input, video_input,
container_format=VideoContainer.MP4, container_format=VideoContainer.MP4,
@ -537,7 +539,7 @@ class GeminiImage(ComfyNodeABC):
{ {
"tooltip": "The Gemini model to use for generating responses.", "tooltip": "The Gemini model to use for generating responses.",
"options": [model.value for model in GeminiImageModel], "options": [model.value for model in GeminiImageModel],
"default": GeminiImageModel.gemini_2_5_flash_image_preview.value, "default": GeminiImageModel.gemini_2_5_flash_image.value,
}, },
), ),
"seed": ( "seed": (
@ -578,6 +580,14 @@ class GeminiImage(ComfyNodeABC):
# "tooltip": "How many images to generate", # "tooltip": "How many images to generate",
# }, # },
# ), # ),
"aspect_ratio": (
IO.COMBO,
{
"tooltip": "Defaults to matching the output image size to that of your input image, or otherwise generates 1:1 squares.",
"options": ["auto", "1:1", "2:3", "3:2", "3:4", "4:3", "4:5", "5:4", "9:16", "16:9", "21:9"],
"default": "auto",
},
),
}, },
"hidden": { "hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG", "auth_token": "AUTH_TOKEN_COMFY_ORG",
@ -599,15 +609,17 @@ class GeminiImage(ComfyNodeABC):
images: Optional[IO.IMAGE] = None, images: Optional[IO.IMAGE] = None,
files: Optional[list[GeminiPart]] = None, files: Optional[list[GeminiPart]] = None,
n=1, n=1,
aspect_ratio: str = "auto",
unique_id: Optional[str] = None, unique_id: Optional[str] = None,
**kwargs, **kwargs,
): ):
# Validate inputs
validate_string(prompt, strip_whitespace=True, min_length=1) validate_string(prompt, strip_whitespace=True, min_length=1)
# Create parts list with text prompt as the first part
parts: list[GeminiPart] = [create_text_part(prompt)] parts: list[GeminiPart] = [create_text_part(prompt)]
# Add other modal parts if not aspect_ratio:
aspect_ratio = "auto" # for backward compatability with old workflows; to-do remove this in December
image_config = GeminiImageConfig(aspectRatio=aspect_ratio)
if images is not None: if images is not None:
image_parts = create_image_parts(images) image_parts = create_image_parts(images)
parts.extend(image_parts) parts.extend(image_parts)
@ -624,7 +636,8 @@ class GeminiImage(ComfyNodeABC):
), ),
], ],
generationConfig=GeminiImageGenerationConfig( generationConfig=GeminiImageGenerationConfig(
responseModalities=["TEXT","IMAGE"] responseModalities=["TEXT","IMAGE"],
imageConfig=None if aspect_ratio == "auto" else image_config,
) )
), ),
auth_kwargs=kwargs, auth_kwargs=kwargs,

File diff suppressed because it is too large

View File

@ -181,11 +181,11 @@ class LumaImageGenerationNode(comfy_io.ComfyNode):
), ),
comfy_io.Combo.Input( comfy_io.Combo.Input(
"model", "model",
options=[model.value for model in LumaImageModel], options=LumaImageModel,
), ),
comfy_io.Combo.Input( comfy_io.Combo.Input(
"aspect_ratio", "aspect_ratio",
options=[ratio.value for ratio in LumaAspectRatio], options=LumaAspectRatio,
default=LumaAspectRatio.ratio_16_9, default=LumaAspectRatio.ratio_16_9,
), ),
comfy_io.Int.Input( comfy_io.Int.Input(
@ -366,7 +366,7 @@ class LumaImageModifyNode(comfy_io.ComfyNode):
), ),
comfy_io.Combo.Input( comfy_io.Combo.Input(
"model", "model",
options=[model.value for model in LumaImageModel], options=LumaImageModel,
), ),
comfy_io.Int.Input( comfy_io.Int.Input(
"seed", "seed",
@ -466,21 +466,21 @@ class LumaTextToVideoGenerationNode(comfy_io.ComfyNode):
), ),
comfy_io.Combo.Input( comfy_io.Combo.Input(
"model", "model",
options=[model.value for model in LumaVideoModel], options=LumaVideoModel,
), ),
comfy_io.Combo.Input( comfy_io.Combo.Input(
"aspect_ratio", "aspect_ratio",
options=[ratio.value for ratio in LumaAspectRatio], options=LumaAspectRatio,
default=LumaAspectRatio.ratio_16_9, default=LumaAspectRatio.ratio_16_9,
), ),
comfy_io.Combo.Input( comfy_io.Combo.Input(
"resolution", "resolution",
options=[resolution.value for resolution in LumaVideoOutputResolution], options=LumaVideoOutputResolution,
default=LumaVideoOutputResolution.res_540p, default=LumaVideoOutputResolution.res_540p,
), ),
comfy_io.Combo.Input( comfy_io.Combo.Input(
"duration", "duration",
options=[dur.value for dur in LumaVideoModelOutputDuration], options=LumaVideoModelOutputDuration,
), ),
comfy_io.Boolean.Input( comfy_io.Boolean.Input(
"loop", "loop",
@ -595,7 +595,7 @@ class LumaImageToVideoGenerationNode(comfy_io.ComfyNode):
), ),
comfy_io.Combo.Input( comfy_io.Combo.Input(
"model", "model",
options=[model.value for model in LumaVideoModel], options=LumaVideoModel,
), ),
# comfy_io.Combo.Input( # comfy_io.Combo.Input(
# "aspect_ratio", # "aspect_ratio",
@ -604,7 +604,7 @@ class LumaImageToVideoGenerationNode(comfy_io.ComfyNode):
# ), # ),
comfy_io.Combo.Input( comfy_io.Combo.Input(
"resolution", "resolution",
options=[resolution.value for resolution in LumaVideoOutputResolution], options=LumaVideoOutputResolution,
default=LumaVideoOutputResolution.res_540p, default=LumaVideoOutputResolution.res_540p,
), ),
comfy_io.Combo.Input( comfy_io.Combo.Input(

View File

@ -500,7 +500,7 @@ class MinimaxHailuoVideoNode(comfy_io.ComfyNode):
raise Exception( raise Exception(
f"No video was found in the response. Full response: {file_result.model_dump()}" f"No video was found in the response. Full response: {file_result.model_dump()}"
) )
logging.info(f"Generated video URL: {file_url}") logging.info("Generated video URL: %s", file_url)
if cls.hidden.unique_id: if cls.hidden.unique_id:
if hasattr(file_result.file, "backup_download_url"): if hasattr(file_result.file, "backup_download_url"):
message = f"Result URL: {file_url}\nBackup URL: {file_result.file.backup_download_url}" message = f"Result URL: {file_url}\nBackup URL: {file_result.file.backup_download_url}"

View File

@ -2,11 +2,7 @@ import logging
from typing import Any, Callable, Optional, TypeVar from typing import Any, Callable, Optional, TypeVar
import torch import torch
from typing_extensions import override from typing_extensions import override
from comfy_api_nodes.util.validation_utils import ( from comfy_api_nodes.util.validation_utils import validate_image_dimensions
get_image_dimensions,
validate_image_dimensions,
)
from comfy_api_nodes.apis import ( from comfy_api_nodes.apis import (
MoonvalleyTextToVideoRequest, MoonvalleyTextToVideoRequest,
@ -132,47 +128,6 @@ def validate_prompts(
return True return True
def validate_input_media(width, height, with_frame_conditioning, num_frames_in=None):
# inference validation
# T = num_frames
# in all cases, the following must be true: T divisible by 16 and H,W by 8. in addition...
# with image conditioning: H*W must be divisible by 8192
# without image conditioning: T divisible by 32
if num_frames_in and not num_frames_in % 16 == 0:
return False, ("The input video total frame count must be divisible by 16!")
if height % 8 != 0 or width % 8 != 0:
return False, (
f"Height ({height}) and width ({width}) must be " "divisible by 8"
)
if with_frame_conditioning:
if (height * width) % 8192 != 0:
return False, (
f"Height * width ({height * width}) must be "
"divisible by 8192 for frame conditioning"
)
else:
if num_frames_in and not num_frames_in % 32 == 0:
return False, ("The input video total frame count must be divisible by 32!")
def validate_input_image(
image: torch.Tensor, with_frame_conditioning: bool = False
) -> None:
"""
Validates the input image adheres to the expectations of the API:
- The image resolution should not be less than 300*300px
- The aspect ratio of the image should be between 1:2.5 ~ 2.5:1
"""
height, width = get_image_dimensions(image)
validate_input_media(width, height, with_frame_conditioning)
validate_image_dimensions(
image, min_width=300, min_height=300, max_height=MAX_HEIGHT, max_width=MAX_WIDTH
)
def validate_video_to_video_input(video: VideoInput) -> VideoInput: def validate_video_to_video_input(video: VideoInput) -> VideoInput:
""" """
Validates and processes video input for Moonvalley Video-to-Video generation. Validates and processes video input for Moonvalley Video-to-Video generation.
@ -282,7 +237,7 @@ def trim_video(video: VideoInput, duration_sec: float) -> VideoInput:
audio_stream = None audio_stream = None
for stream in input_container.streams: for stream in input_container.streams:
logging.info(f"Found stream: type={stream.type}, class={type(stream)}") logging.info("Found stream: type=%s, class=%s", stream.type, type(stream))
if isinstance(stream, av.VideoStream): if isinstance(stream, av.VideoStream):
# Create output video stream with same parameters # Create output video stream with same parameters
video_stream = output_container.add_stream( video_stream = output_container.add_stream(
@ -292,7 +247,7 @@ def trim_video(video: VideoInput, duration_sec: float) -> VideoInput:
video_stream.height = stream.height video_stream.height = stream.height
video_stream.pix_fmt = "yuv420p" video_stream.pix_fmt = "yuv420p"
logging.info( logging.info(
f"Added video stream: {stream.width}x{stream.height} @ {stream.average_rate}fps" "Added video stream: %sx%s @ %sfps", stream.width, stream.height, stream.average_rate
) )
elif isinstance(stream, av.AudioStream): elif isinstance(stream, av.AudioStream):
# Create output audio stream with same parameters # Create output audio stream with same parameters
@ -301,9 +256,7 @@ def trim_video(video: VideoInput, duration_sec: float) -> VideoInput:
) )
audio_stream.sample_rate = stream.sample_rate audio_stream.sample_rate = stream.sample_rate
audio_stream.layout = stream.layout audio_stream.layout = stream.layout
logging.info( logging.info("Added audio stream: %sHz, %s channels", stream.sample_rate, stream.channels)
f"Added audio stream: {stream.sample_rate}Hz, {stream.channels} channels"
)
# Calculate target frame count that's divisible by 16 # Calculate target frame count that's divisible by 16
fps = input_container.streams.video[0].average_rate fps = input_container.streams.video[0].average_rate
@ -333,9 +286,7 @@ def trim_video(video: VideoInput, duration_sec: float) -> VideoInput:
for packet in video_stream.encode(): for packet in video_stream.encode():
output_container.mux(packet) output_container.mux(packet)
logging.info( logging.info("Encoded %s video frames (target: %s)", frame_count, target_frames)
f"Encoded {frame_count} video frames (target: {target_frames})"
)
# Decode and re-encode audio frames # Decode and re-encode audio frames
if audio_stream: if audio_stream:
@ -353,7 +304,7 @@ def trim_video(video: VideoInput, duration_sec: float) -> VideoInput:
for packet in audio_stream.encode(): for packet in audio_stream.encode():
output_container.mux(packet) output_container.mux(packet)
logging.info(f"Encoded {audio_frame_count} audio frames") logging.info("Encoded %s audio frames", audio_frame_count)
# Close containers # Close containers
output_container.close() output_container.close()
@ -380,7 +331,7 @@ def parse_width_height_from_res(resolution: str):
"1:1 (1152 x 1152)": {"width": 1152, "height": 1152}, "1:1 (1152 x 1152)": {"width": 1152, "height": 1152},
"4:3 (1536 x 1152)": {"width": 1536, "height": 1152}, "4:3 (1536 x 1152)": {"width": 1536, "height": 1152},
"3:4 (1152 x 1536)": {"width": 1152, "height": 1536}, "3:4 (1152 x 1536)": {"width": 1152, "height": 1536},
"21:9 (2560 x 1080)": {"width": 2560, "height": 1080}, # "21:9 (2560 x 1080)": {"width": 2560, "height": 1080},
} }
return res_map.get(resolution, {"width": 1920, "height": 1080}) return res_map.get(resolution, {"width": 1920, "height": 1080})
@ -448,14 +399,14 @@ class MoonvalleyImg2VideoNode(comfy_io.ComfyNode):
"1:1 (1152 x 1152)", "1:1 (1152 x 1152)",
"4:3 (1536 x 1152)", "4:3 (1536 x 1152)",
"3:4 (1152 x 1536)", "3:4 (1152 x 1536)",
"21:9 (2560 x 1080)", # "21:9 (2560 x 1080)",
], ],
default="16:9 (1920 x 1080)", default="16:9 (1920 x 1080)",
tooltip="Resolution of the output video", tooltip="Resolution of the output video",
), ),
comfy_io.Float.Input( comfy_io.Float.Input(
"prompt_adherence", "prompt_adherence",
default=10.0, default=4.5,
min=1.0, min=1.0,
max=20.0, max=20.0,
step=1.0, step=1.0,
@ -469,10 +420,11 @@ class MoonvalleyImg2VideoNode(comfy_io.ComfyNode):
step=1, step=1,
display_mode=comfy_io.NumberDisplay.number, display_mode=comfy_io.NumberDisplay.number,
tooltip="Random seed value", tooltip="Random seed value",
control_after_generate=True,
), ),
comfy_io.Int.Input( comfy_io.Int.Input(
"steps", "steps",
default=100, default=33,
min=1, min=1,
max=100, max=100,
step=1, step=1,
@ -499,7 +451,7 @@ class MoonvalleyImg2VideoNode(comfy_io.ComfyNode):
seed: int, seed: int,
steps: int, steps: int,
) -> comfy_io.NodeOutput: ) -> comfy_io.NodeOutput:
validate_input_image(image, True) validate_image_dimensions(image, min_width=300, min_height=300, max_height=MAX_HEIGHT, max_width=MAX_WIDTH)
validate_prompts(prompt, negative_prompt, MOONVALLEY_MAREY_MAX_PROMPT_LENGTH) validate_prompts(prompt, negative_prompt, MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
width_height = parse_width_height_from_res(resolution) width_height = parse_width_height_from_res(resolution)
@ -513,12 +465,11 @@ class MoonvalleyImg2VideoNode(comfy_io.ComfyNode):
steps=steps, steps=steps,
seed=seed, seed=seed,
guidance_scale=prompt_adherence, guidance_scale=prompt_adherence,
num_frames=128,
width=width_height["width"], width=width_height["width"],
height=width_height["height"], height=width_height["height"],
use_negative_prompts=True, use_negative_prompts=True,
) )
"""Upload image to comfy backend to have a URL available for further processing"""
# Get MIME type from tensor - assuming PNG format for image tensors # Get MIME type from tensor - assuming PNG format for image tensors
mime_type = "image/png" mime_type = "image/png"
@ -608,6 +559,15 @@ class MoonvalleyVideo2VideoNode(comfy_io.ComfyNode):
tooltip="Only used if control_type is 'Motion Transfer'", tooltip="Only used if control_type is 'Motion Transfer'",
optional=True, optional=True,
), ),
comfy_io.Int.Input(
"steps",
default=33,
min=1,
max=100,
step=1,
display_mode=comfy_io.NumberDisplay.number,
tooltip="Number of inference steps",
),
], ],
outputs=[comfy_io.Video.Output()], outputs=[comfy_io.Video.Output()],
hidden=[ hidden=[
@ -627,6 +587,8 @@ class MoonvalleyVideo2VideoNode(comfy_io.ComfyNode):
video: Optional[VideoInput] = None, video: Optional[VideoInput] = None,
control_type: str = "Motion Transfer", control_type: str = "Motion Transfer",
motion_intensity: Optional[int] = 100, motion_intensity: Optional[int] = 100,
steps=33,
prompt_adherence=4.5,
) -> comfy_io.NodeOutput: ) -> comfy_io.NodeOutput:
auth = { auth = {
"auth_token": cls.hidden.auth_token_comfy_org, "auth_token": cls.hidden.auth_token_comfy_org,
@ -636,7 +598,6 @@ class MoonvalleyVideo2VideoNode(comfy_io.ComfyNode):
validated_video = validate_video_to_video_input(video) validated_video = validate_video_to_video_input(video)
video_url = await upload_video_to_comfyapi(validated_video, auth_kwargs=auth) video_url = await upload_video_to_comfyapi(validated_video, auth_kwargs=auth)
"""Validate prompts and inference input"""
validate_prompts(prompt, negative_prompt) validate_prompts(prompt, negative_prompt)
# Only include motion_intensity for Motion Transfer # Only include motion_intensity for Motion Transfer
@ -648,6 +609,8 @@ class MoonvalleyVideo2VideoNode(comfy_io.ComfyNode):
negative_prompt=negative_prompt, negative_prompt=negative_prompt,
seed=seed, seed=seed,
control_params=control_params, control_params=control_params,
steps=steps,
guidance_scale=prompt_adherence,
) )
control = parse_control_parameter(control_type) control = parse_control_parameter(control_type)
@ -721,7 +684,7 @@ class MoonvalleyTxt2VideoNode(comfy_io.ComfyNode):
), ),
comfy_io.Float.Input( comfy_io.Float.Input(
"prompt_adherence", "prompt_adherence",
default=10.0, default=4.0,
min=1.0, min=1.0,
max=20.0, max=20.0,
step=1.0, step=1.0,
@ -734,11 +697,12 @@ class MoonvalleyTxt2VideoNode(comfy_io.ComfyNode):
max=4294967295, max=4294967295,
step=1, step=1,
display_mode=comfy_io.NumberDisplay.number, display_mode=comfy_io.NumberDisplay.number,
control_after_generate=True,
tooltip="Random seed value", tooltip="Random seed value",
), ),
comfy_io.Int.Input( comfy_io.Int.Input(
"steps", "steps",
default=100, default=33,
min=1, min=1,
max=100, max=100,
step=1, step=1,

File diff suppressed because it is too large

View File

@ -85,7 +85,7 @@ class PixverseTemplateNode(comfy_io.ComfyNode):
display_name="PixVerse Template", display_name="PixVerse Template",
category="api node/video/PixVerse", category="api node/video/PixVerse",
inputs=[ inputs=[
comfy_io.Combo.Input("template", options=[list(pixverse_templates.keys())]), comfy_io.Combo.Input("template", options=list(pixverse_templates.keys())),
], ],
outputs=[comfy_io.Custom(PixverseIO.TEMPLATE).Output(display_name="pixverse_template")], outputs=[comfy_io.Custom(PixverseIO.TEMPLATE).Output(display_name="pixverse_template")],
) )
@ -120,20 +120,20 @@ class PixverseTextToVideoNode(comfy_io.ComfyNode):
), ),
comfy_io.Combo.Input( comfy_io.Combo.Input(
"aspect_ratio", "aspect_ratio",
options=[ratio.value for ratio in PixverseAspectRatio], options=PixverseAspectRatio,
), ),
comfy_io.Combo.Input( comfy_io.Combo.Input(
"quality", "quality",
options=[resolution.value for resolution in PixverseQuality], options=PixverseQuality,
default=PixverseQuality.res_540p, default=PixverseQuality.res_540p,
), ),
comfy_io.Combo.Input( comfy_io.Combo.Input(
"duration_seconds", "duration_seconds",
options=[dur.value for dur in PixverseDuration], options=PixverseDuration,
), ),
comfy_io.Combo.Input( comfy_io.Combo.Input(
"motion_mode", "motion_mode",
options=[mode.value for mode in PixverseMotionMode], options=PixverseMotionMode,
), ),
comfy_io.Int.Input( comfy_io.Int.Input(
"seed", "seed",
@ -146,7 +146,7 @@ class PixverseTextToVideoNode(comfy_io.ComfyNode):
comfy_io.String.Input( comfy_io.String.Input(
"negative_prompt", "negative_prompt",
default="", default="",
force_input=True, multiline=True,
tooltip="An optional text description of undesired elements on an image.", tooltip="An optional text description of undesired elements on an image.",
optional=True, optional=True,
), ),
@ -262,16 +262,16 @@ class PixverseImageToVideoNode(comfy_io.ComfyNode):
), ),
comfy_io.Combo.Input( comfy_io.Combo.Input(
"quality", "quality",
options=[resolution.value for resolution in PixverseQuality], options=PixverseQuality,
default=PixverseQuality.res_540p, default=PixverseQuality.res_540p,
), ),
comfy_io.Combo.Input( comfy_io.Combo.Input(
"duration_seconds", "duration_seconds",
options=[dur.value for dur in PixverseDuration], options=PixverseDuration,
), ),
comfy_io.Combo.Input( comfy_io.Combo.Input(
"motion_mode", "motion_mode",
options=[mode.value for mode in PixverseMotionMode], options=PixverseMotionMode,
), ),
comfy_io.Int.Input( comfy_io.Int.Input(
"seed", "seed",
@ -284,7 +284,7 @@ class PixverseImageToVideoNode(comfy_io.ComfyNode):
comfy_io.String.Input( comfy_io.String.Input(
"negative_prompt", "negative_prompt",
default="", default="",
force_input=True, multiline=True,
tooltip="An optional text description of undesired elements on an image.", tooltip="An optional text description of undesired elements on an image.",
optional=True, optional=True,
), ),
@ -403,16 +403,16 @@ class PixverseTransitionVideoNode(comfy_io.ComfyNode):
), ),
comfy_io.Combo.Input( comfy_io.Combo.Input(
"quality", "quality",
options=[resolution.value for resolution in PixverseQuality], options=PixverseQuality,
default=PixverseQuality.res_540p, default=PixverseQuality.res_540p,
), ),
comfy_io.Combo.Input( comfy_io.Combo.Input(
"duration_seconds", "duration_seconds",
options=[dur.value for dur in PixverseDuration], options=PixverseDuration,
), ),
comfy_io.Combo.Input( comfy_io.Combo.Input(
"motion_mode", "motion_mode",
options=[mode.value for mode in PixverseMotionMode], options=PixverseMotionMode,
), ),
comfy_io.Int.Input( comfy_io.Int.Input(
"seed", "seed",
@ -425,7 +425,7 @@ class PixverseTransitionVideoNode(comfy_io.ComfyNode):
comfy_io.String.Input( comfy_io.String.Input(
"negative_prompt", "negative_prompt",
default="", default="",
force_input=True, multiline=True,
tooltip="An optional text description of undesired elements on an image.", tooltip="An optional text description of undesired elements on an image.",
optional=True, optional=True,
), ),

View File

@ -35,6 +35,7 @@ from server import PromptServer
import torch import torch
from io import BytesIO from io import BytesIO
from PIL import UnidentifiedImageError from PIL import UnidentifiedImageError
import aiohttp
async def handle_recraft_file_request( async def handle_recraft_file_request(
@ -82,10 +83,16 @@ async def handle_recraft_file_request(
return all_bytesio return all_bytesio
def recraft_multipart_parser(data, parent_key=None, formatter: callable=None, converted_to_check: list[list]=None, is_list=False) -> dict: def recraft_multipart_parser(
data,
parent_key=None,
formatter: callable = None,
converted_to_check: list[list] = None,
is_list: bool = False,
return_mode: str = "formdata" # "dict" | "formdata"
) -> dict | aiohttp.FormData:
""" """
Formats data such that multipart/form-data will work with requests library Formats data such that multipart/form-data will work with aiohttp library when both files and data are present.
when both files and data are present.
The OpenAI client that Recraft uses has a bizarre way of serializing lists: The OpenAI client that Recraft uses has a bizarre way of serializing lists:
@ -103,23 +110,23 @@ def recraft_multipart_parser(data, parent_key=None, formatter: callable=None, co
# Modification of a function that handled a different type of multipart parsing, big ups: # Modification of a function that handled a different type of multipart parsing, big ups:
# https://gist.github.com/kazqvaizer/4cebebe5db654a414132809f9f88067b # https://gist.github.com/kazqvaizer/4cebebe5db654a414132809f9f88067b
def handle_converted_lists(data, parent_key, lists_to_check=tuple[list]): def handle_converted_lists(item, parent_key, lists_to_check=tuple[list]):
# if list already exists, just extend list with data # if list already exists, just extend list with data
for check_list in lists_to_check: for check_list in lists_to_check:
for conv_tuple in check_list: for conv_tuple in check_list:
if conv_tuple[0] == parent_key and type(conv_tuple[1]) is list: if conv_tuple[0] == parent_key and isinstance(conv_tuple[1], list):
conv_tuple[1].append(formatter(data)) conv_tuple[1].append(formatter(item))
return True return True
return False return False
if converted_to_check is None: if converted_to_check is None:
converted_to_check = [] converted_to_check = []
effective_mode = return_mode if parent_key is None else "dict"
if formatter is None: if formatter is None:
formatter = lambda v: v # Multipart representation of value formatter = lambda v: v # Multipart representation of value
if type(data) is not dict: if not isinstance(data, dict):
# if list already exists, just extend list with data # if list already exists, just extend list with data
added = handle_converted_lists(data, parent_key, converted_to_check) added = handle_converted_lists(data, parent_key, converted_to_check)
if added: if added:
@ -136,15 +143,24 @@ def recraft_multipart_parser(data, parent_key=None, formatter: callable=None, co
for key, value in data.items(): for key, value in data.items():
current_key = key if parent_key is None else f"{parent_key}[{key}]" current_key = key if parent_key is None else f"{parent_key}[{key}]"
if type(value) is dict: if isinstance(value, dict):
converted.extend(recraft_multipart_parser(value, current_key, formatter, next_check).items()) converted.extend(recraft_multipart_parser(value, current_key, formatter, next_check).items())
elif type(value) is list: elif isinstance(value, list):
for ind, list_value in enumerate(value): for ind, list_value in enumerate(value):
iter_key = f"{current_key}[]" iter_key = f"{current_key}[]"
converted.extend(recraft_multipart_parser(list_value, iter_key, formatter, next_check, is_list=True).items()) converted.extend(recraft_multipart_parser(list_value, iter_key, formatter, next_check, is_list=True).items())
else: else:
converted.append((current_key, formatter(value))) converted.append((current_key, formatter(value)))
if effective_mode == "formdata":
fd = aiohttp.FormData()
for k, v in dict(converted).items():
if isinstance(v, list):
for item in v:
fd.add_field(k, str(item))
else:
fd.add_field(k, str(v))
return fd
return dict(converted) return dict(converted)
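For reference, a usage sketch of the parser above, written as if it sat just below the function in the same module; the payload keys and the exact bracketed key strings are illustrative assumptions based on the docstring's description of nested dicts and lists.

payload = {
    "prompt": "a red chair",
    "controls": {"colors": [{"rgb": [255, 0, 0]}]},
}

# With return_mode="dict", nesting is flattened into bracketed keys, e.g. a single
# "controls[colors][][rgb][]" entry whose value gathers the list [255, 0, 0].
flat = recraft_multipart_parser(payload, return_mode="dict")

# The default return_mode="formdata" packs the same pairs into aiohttp.FormData,
# ready to be sent as multipart/form-data alongside any file fields.
form_data = recraft_multipart_parser(payload)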

View File

@ -7,14 +7,15 @@ Rodin API docs: https://developer.hyper3d.ai/
from __future__ import annotations from __future__ import annotations
from inspect import cleandoc from inspect import cleandoc
from comfy.comfy_types.node_typing import IO
import folder_paths as comfy_paths import folder_paths as comfy_paths
import aiohttp import aiohttp
import os import os
import asyncio import asyncio
import io
import logging import logging
import math import math
from typing import Optional
from io import BytesIO
from typing_extensions import override
from PIL import Image from PIL import Image
from comfy_api_nodes.apis.rodin_api import ( from comfy_api_nodes.apis.rodin_api import (
Rodin3DGenerateRequest, Rodin3DGenerateRequest,
@ -31,186 +32,29 @@ from comfy_api_nodes.apis.client import (
SynchronousOperation, SynchronousOperation,
PollingOperation, PollingOperation,
) )
from comfy_api.latest import ComfyExtension, io as comfy_io
COMMON_PARAMETERS = { COMMON_PARAMETERS = [
"Seed": ( comfy_io.Int.Input(
IO.INT, "Seed",
{ default=0,
"default":0, min=0,
"min":0, max=65535,
"max":65535, display_mode=comfy_io.NumberDisplay.number,
"display":"number" optional=True,
}
), ),
"Material_Type": ( comfy_io.Combo.Input("Material_Type", options=["PBR", "Shaded"], default="PBR", optional=True),
IO.COMBO, comfy_io.Combo.Input(
{ "Polygon_count",
"options": ["PBR", "Shaded"], options=["4K-Quad", "8K-Quad", "18K-Quad", "50K-Quad", "200K-Triangle"],
"default": "PBR" default="18K-Quad",
} optional=True,
), ),
"Polygon_count": ( ]
IO.COMBO,
{
"options": ["4K-Quad", "8K-Quad", "18K-Quad", "50K-Quad", "200K-Triangle"],
"default": "18K-Quad"
}
)
}
def create_task_error(response: Rodin3DGenerateResponse):
"""Check if the response has error"""
return hasattr(response, "error")
class Rodin3DAPI: def get_quality_mode(poly_count):
"""
Generate 3D Assets using Rodin API
"""
RETURN_TYPES = (IO.STRING,)
RETURN_NAMES = ("3D Model Path",)
CATEGORY = "api node/3d/Rodin"
DESCRIPTION = cleandoc(__doc__ or "")
FUNCTION = "api_call"
API_NODE = True
def tensor_to_filelike(self, tensor, max_pixels: int = 2048*2048):
"""
Converts a PyTorch tensor to a file-like object.
Args:
- tensor (torch.Tensor): A tensor representing an image of shape (H, W, C)
where C is the number of channels (3 for RGB), H is height, and W is width.
Returns:
- io.BytesIO: A file-like object containing the image data.
"""
array = tensor.cpu().numpy()
array = (array * 255).astype('uint8')
image = Image.fromarray(array, 'RGB')
original_width, original_height = image.size
original_pixels = original_width * original_height
if original_pixels > max_pixels:
scale = math.sqrt(max_pixels / original_pixels)
new_width = int(original_width * scale)
new_height = int(original_height * scale)
else:
new_width, new_height = original_width, original_height
if new_width != original_width or new_height != original_height:
image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
img_byte_arr = io.BytesIO()
image.save(img_byte_arr, format='PNG') # PNG is used for lossless compression
img_byte_arr.seek(0)
return img_byte_arr
def check_rodin_status(self, response: Rodin3DCheckStatusResponse) -> str:
has_failed = any(job.status == JobStatus.Failed for job in response.jobs)
all_done = all(job.status == JobStatus.Done for job in response.jobs)
status_list = [str(job.status) for job in response.jobs]
logging.info(f"[ Rodin3D API - CheckStatus ] Generate Status: {status_list}")
if has_failed:
logging.error(f"[ Rodin3D API - CheckStatus ] Generate Failed: {status_list}, Please try again.")
raise Exception("[ Rodin3D API ] Generate Failed, Please Try again.")
elif all_done:
return "DONE"
else:
return "Generating"
async def create_generate_task(self, images=None, seed=1, material="PBR", quality_override=18000, tier="Regular", mesh_mode="Quad", TAPose = False, **kwargs):
if images is None:
raise Exception("Rodin 3D generate requires at least 1 image.")
if len(images) > 5:
raise Exception("Rodin 3D generate requires up to 5 image.")
path = "/proxy/rodin/api/v2/rodin"
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=path,
method=HttpMethod.POST,
request_model=Rodin3DGenerateRequest,
response_model=Rodin3DGenerateResponse,
),
request=Rodin3DGenerateRequest(
seed=seed,
tier=tier,
material=material,
quality_override=quality_override,
mesh_mode=mesh_mode,
TAPose=TAPose,
),
files=[
(
"images",
open(image, "rb") if isinstance(image, str) else self.tensor_to_filelike(image)
)
for image in images if image is not None
],
content_type = "multipart/form-data",
auth_kwargs=kwargs,
)
response = await operation.execute()
if create_task_error(response):
error_message = f"Rodin3D Create 3D generate Task Failed. Message: {response.message}, error: {response.error}"
logging.error(error_message)
raise Exception(error_message)
logging.info("[ Rodin3D API - Submit Jobs ] Submit Generate Task Success!")
subscription_key = response.jobs.subscription_key
task_uuid = response.uuid
logging.info(f"[ Rodin3D API - Submit Jobs ] UUID: {task_uuid}")
return task_uuid, subscription_key
async def poll_for_task_status(self, subscription_key, **kwargs) -> Rodin3DCheckStatusResponse:
path = "/proxy/rodin/api/v2/status"
poll_operation = PollingOperation(
poll_endpoint=ApiEndpoint(
path = path,
method=HttpMethod.POST,
request_model=Rodin3DCheckStatusRequest,
response_model=Rodin3DCheckStatusResponse,
),
request=Rodin3DCheckStatusRequest(
subscription_key = subscription_key
),
completed_statuses=["DONE"],
failed_statuses=["FAILED"],
status_extractor=self.check_rodin_status,
poll_interval=3.0,
auth_kwargs=kwargs,
)
logging.info("[ Rodin3D API - CheckStatus ] Generate Start!")
return await poll_operation.execute()
async def get_rodin_download_list(self, uuid, **kwargs) -> Rodin3DDownloadResponse:
logging.info("[ Rodin3D API - Downloading ] Generate Successfully!")
path = "/proxy/rodin/api/v2/download"
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=path,
method=HttpMethod.POST,
request_model=Rodin3DDownloadRequest,
response_model=Rodin3DDownloadResponse,
),
request=Rodin3DDownloadRequest(
task_uuid=uuid
),
auth_kwargs=kwargs
)
return await operation.execute()
def get_quality_mode(self, poly_count):
polycount = poly_count.split("-") polycount = poly_count.split("-")
poly = polycount[1] poly = polycount[1]
count = polycount[0] count = polycount[0]
@ -242,7 +86,145 @@ class Rodin3DAPI:
return mesh_mode, quality_override return mesh_mode, quality_override
async def download_files(self, url_list, task_uuid):
def tensor_to_filelike(tensor, max_pixels: int = 2048*2048):
"""
Converts a PyTorch tensor to a file-like object.
Args:
- tensor (torch.Tensor): A tensor representing an image of shape (H, W, C)
where C is the number of channels (3 for RGB), H is height, and W is width.
Returns:
- io.BytesIO: A file-like object containing the image data.
"""
array = tensor.cpu().numpy()
array = (array * 255).astype('uint8')
image = Image.fromarray(array, 'RGB')
original_width, original_height = image.size
original_pixels = original_width * original_height
if original_pixels > max_pixels:
scale = math.sqrt(max_pixels / original_pixels)
new_width = int(original_width * scale)
new_height = int(original_height * scale)
else:
new_width, new_height = original_width, original_height
if new_width != original_width or new_height != original_height:
image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
img_byte_arr = BytesIO()
image.save(img_byte_arr, format='PNG') # PNG is used for lossless compression
img_byte_arr.seek(0)
return img_byte_arr
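A small worked example of the resize rule described in the docstring above (the input size is illustrative):

import math

# A 4096 x 3072 image has 12,582,912 pixels, above the 2048*2048 = 4,194,304 cap,
# so it is scaled by sqrt(4194304 / 12582912) ≈ 0.577 before being saved as PNG.
scale = math.sqrt((2048 * 2048) / (4096 * 3072))
print(int(4096 * scale), int(3072 * scale))  # -> 2364 1773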
async def create_generate_task(
images=None,
seed=1,
material="PBR",
quality_override=18000,
tier="Regular",
mesh_mode="Quad",
TAPose = False,
auth_kwargs: Optional[dict[str, str]] = None,
):
if images is None:
raise Exception("Rodin 3D generate requires at least 1 image.")
if len(images) > 5:
raise Exception("Rodin 3D generate requires up to 5 image.")
path = "/proxy/rodin/api/v2/rodin"
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=path,
method=HttpMethod.POST,
request_model=Rodin3DGenerateRequest,
response_model=Rodin3DGenerateResponse,
),
request=Rodin3DGenerateRequest(
seed=seed,
tier=tier,
material=material,
quality_override=quality_override,
mesh_mode=mesh_mode,
TAPose=TAPose,
),
files=[
(
"images",
open(image, "rb") if isinstance(image, str) else tensor_to_filelike(image)
)
for image in images if image is not None
],
content_type="multipart/form-data",
auth_kwargs=auth_kwargs,
)
response = await operation.execute()
if hasattr(response, "error"):
error_message = f"Rodin3D Create 3D generate Task Failed. Message: {response.message}, error: {response.error}"
logging.error(error_message)
raise Exception(error_message)
logging.info("[ Rodin3D API - Submit Jobs ] Submit Generate Task Success!")
subscription_key = response.jobs.subscription_key
task_uuid = response.uuid
logging.info("[ Rodin3D API - Submit Jobs ] UUID: %s", task_uuid)
return task_uuid, subscription_key
def check_rodin_status(response: Rodin3DCheckStatusResponse) -> str:
all_done = all(job.status == JobStatus.Done for job in response.jobs)
status_list = [str(job.status) for job in response.jobs]
logging.info("[ Rodin3D API - CheckStatus ] Generate Status: %s", status_list)
if any(job.status == JobStatus.Failed for job in response.jobs):
logging.error("[ Rodin3D API - CheckStatus ] Generate Failed: %s, Please try again.", status_list)
raise Exception("[ Rodin3D API ] Generate Failed, Please Try again.")
if all_done:
return "DONE"
return "Generating"
async def poll_for_task_status(
subscription_key, auth_kwargs: Optional[dict[str, str]] = None,
) -> Rodin3DCheckStatusResponse:
poll_operation = PollingOperation(
poll_endpoint=ApiEndpoint(
path="/proxy/rodin/api/v2/status",
method=HttpMethod.POST,
request_model=Rodin3DCheckStatusRequest,
response_model=Rodin3DCheckStatusResponse,
),
request=Rodin3DCheckStatusRequest(subscription_key=subscription_key),
completed_statuses=["DONE"],
failed_statuses=["FAILED"],
status_extractor=check_rodin_status,
poll_interval=3.0,
auth_kwargs=auth_kwargs,
)
logging.info("[ Rodin3D API - CheckStatus ] Generate Start!")
return await poll_operation.execute()
async def get_rodin_download_list(uuid, auth_kwargs: Optional[dict[str, str]] = None) -> Rodin3DDownloadResponse:
logging.info("[ Rodin3D API - Downloading ] Generate Successfully!")
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/rodin/api/v2/download",
method=HttpMethod.POST,
request_model=Rodin3DDownloadRequest,
response_model=Rodin3DDownloadResponse,
),
request=Rodin3DDownloadRequest(task_uuid=uuid),
auth_kwargs=auth_kwargs,
)
return await operation.execute()
async def download_files(url_list, task_uuid):
save_path = os.path.join(comfy_paths.get_output_directory(), f"Rodin3D_{task_uuid}") save_path = os.path.join(comfy_paths.get_output_directory(), f"Rodin3D_{task_uuid}")
os.makedirs(save_path, exist_ok=True) os.makedirs(save_path, exist_ok=True)
model_file_path = None model_file_path = None
@ -253,7 +235,7 @@ class Rodin3DAPI:
file_path = os.path.join(save_path, file_name) file_path = os.path.join(save_path, file_name)
if file_path.endswith(".glb"): if file_path.endswith(".glb"):
model_file_path = file_path model_file_path = file_path
logging.info(f"[ Rodin3D API - download_files ] Downloading file: {file_path}") logging.info("[ Rodin3D API - download_files ] Downloading file: %s", file_path)
max_retries = 5 max_retries = 5
for attempt in range(max_retries): for attempt in range(max_retries):
try: try:
@ -264,7 +246,7 @@ class Rodin3DAPI:
f.write(chunk) f.write(chunk)
break break
except Exception as e: except Exception as e:
logging.info(f"[ Rodin3D API - download_files ] Error downloading {file_path}:{e}") logging.info("[ Rodin3D API - download_files ] Error downloading %s:%s", file_path, str(e))
if attempt < max_retries - 1: if attempt < max_retries - 1:
logging.info("Retrying...") logging.info("Retrying...")
await asyncio.sleep(2) await asyncio.sleep(2)
@ -274,185 +256,212 @@ class Rodin3DAPI:
file_path, file_path,
max_retries, max_retries,
) )
return model_file_path return model_file_path
class Rodin3D_Regular(Rodin3DAPI): class Rodin3D_Regular(comfy_io.ComfyNode):
@classmethod """Generate 3D Assets using Rodin API"""
def INPUT_TYPES(s):
return {
"required": {
"Images":
(
IO.IMAGE,
{
"forceInput":True,
}
)
},
"optional": {
**COMMON_PARAMETERS
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
},
}
async def api_call( @classmethod
self, def define_schema(cls) -> comfy_io.Schema:
return comfy_io.Schema(
node_id="Rodin3D_Regular",
display_name="Rodin 3D Generate - Regular Generate",
category="api node/3d/Rodin",
description=cleandoc(cls.__doc__ or ""),
inputs=[
comfy_io.Image.Input("Images"),
*COMMON_PARAMETERS,
],
outputs=[comfy_io.String.Output(display_name="3D Model Path")],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
Images, Images,
Seed, Seed,
Material_Type, Material_Type,
Polygon_count, Polygon_count,
**kwargs ) -> comfy_io.NodeOutput:
):
tier = "Regular" tier = "Regular"
num_images = Images.shape[0] num_images = Images.shape[0]
m_images = [] m_images = []
for i in range(num_images): for i in range(num_images):
m_images.append(Images[i]) m_images.append(Images[i])
mesh_mode, quality_override = self.get_quality_mode(Polygon_count) mesh_mode, quality_override = get_quality_mode(Polygon_count)
task_uuid, subscription_key = await self.create_generate_task(images=m_images, seed=Seed, material=Material_Type, auth = {
quality_override=quality_override, tier=tier, mesh_mode=mesh_mode, "auth_token": cls.hidden.auth_token_comfy_org,
**kwargs) "comfy_api_key": cls.hidden.api_key_comfy_org,
await self.poll_for_task_status(subscription_key, **kwargs)
download_list = await self.get_rodin_download_list(task_uuid, **kwargs)
model = await self.download_files(download_list, task_uuid)
return (model,)
class Rodin3D_Detail(Rodin3DAPI):
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"Images":
(
IO.IMAGE,
{
"forceInput":True,
} }
task_uuid, subscription_key = await create_generate_task(
images=m_images,
seed=Seed,
material=Material_Type,
quality_override=quality_override,
tier=tier,
mesh_mode=mesh_mode,
auth_kwargs=auth,
) )
}, await poll_for_task_status(subscription_key, auth_kwargs=auth)
"optional": { download_list = await get_rodin_download_list(task_uuid, auth_kwargs=auth)
**COMMON_PARAMETERS model = await download_files(download_list, task_uuid)
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
},
}
async def api_call( return comfy_io.NodeOutput(model)
self,
class Rodin3D_Detail(comfy_io.ComfyNode):
"""Generate 3D Assets using Rodin API"""
@classmethod
def define_schema(cls) -> comfy_io.Schema:
return comfy_io.Schema(
node_id="Rodin3D_Detail",
display_name="Rodin 3D Generate - Detail Generate",
category="api node/3d/Rodin",
description=cleandoc(cls.__doc__ or ""),
inputs=[
comfy_io.Image.Input("Images"),
*COMMON_PARAMETERS,
],
outputs=[comfy_io.String.Output(display_name="3D Model Path")],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
Images, Images,
Seed, Seed,
Material_Type, Material_Type,
Polygon_count, Polygon_count,
**kwargs ) -> comfy_io.NodeOutput:
):
tier = "Detail" tier = "Detail"
num_images = Images.shape[0] num_images = Images.shape[0]
m_images = [] m_images = []
for i in range(num_images): for i in range(num_images):
m_images.append(Images[i]) m_images.append(Images[i])
mesh_mode, quality_override = self.get_quality_mode(Polygon_count) mesh_mode, quality_override = get_quality_mode(Polygon_count)
task_uuid, subscription_key = await self.create_generate_task(images=m_images, seed=Seed, material=Material_Type, auth = {
quality_override=quality_override, tier=tier, mesh_mode=mesh_mode, "auth_token": cls.hidden.auth_token_comfy_org,
**kwargs) "comfy_api_key": cls.hidden.api_key_comfy_org,
await self.poll_for_task_status(subscription_key, **kwargs)
download_list = await self.get_rodin_download_list(task_uuid, **kwargs)
model = await self.download_files(download_list, task_uuid)
return (model,)
class Rodin3D_Smooth(Rodin3DAPI):
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"Images":
(
IO.IMAGE,
{
"forceInput":True,
} }
task_uuid, subscription_key = await create_generate_task(
images=m_images,
seed=Seed,
material=Material_Type,
quality_override=quality_override,
tier=tier,
mesh_mode=mesh_mode,
auth_kwargs=auth,
) )
}, await poll_for_task_status(subscription_key, auth_kwargs=auth)
"optional": { download_list = await get_rodin_download_list(task_uuid, auth_kwargs=auth)
**COMMON_PARAMETERS model = await download_files(download_list, task_uuid)
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
},
}
async def api_call( return comfy_io.NodeOutput(model)
self,
class Rodin3D_Smooth(comfy_io.ComfyNode):
"""Generate 3D Assets using Rodin API"""
@classmethod
def define_schema(cls) -> comfy_io.Schema:
return comfy_io.Schema(
node_id="Rodin3D_Smooth",
display_name="Rodin 3D Generate - Smooth Generate",
category="api node/3d/Rodin",
description=cleandoc(cls.__doc__ or ""),
inputs=[
comfy_io.Image.Input("Images"),
*COMMON_PARAMETERS,
],
outputs=[comfy_io.String.Output(display_name="3D Model Path")],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
Images, Images,
Seed, Seed,
Material_Type, Material_Type,
Polygon_count, Polygon_count,
**kwargs ) -> comfy_io.NodeOutput:
):
tier = "Smooth" tier = "Smooth"
num_images = Images.shape[0] num_images = Images.shape[0]
m_images = [] m_images = []
for i in range(num_images): for i in range(num_images):
m_images.append(Images[i]) m_images.append(Images[i])
mesh_mode, quality_override = self.get_quality_mode(Polygon_count) mesh_mode, quality_override = get_quality_mode(Polygon_count)
task_uuid, subscription_key = await self.create_generate_task(images=m_images, seed=Seed, material=Material_Type, auth = {
quality_override=quality_override, tier=tier, mesh_mode=mesh_mode, "auth_token": cls.hidden.auth_token_comfy_org,
**kwargs) "comfy_api_key": cls.hidden.api_key_comfy_org,
await self.poll_for_task_status(subscription_key, **kwargs) }
download_list = await self.get_rodin_download_list(task_uuid, **kwargs) task_uuid, subscription_key = await create_generate_task(
model = await self.download_files(download_list, task_uuid) images=m_images,
seed=Seed,
material=Material_Type,
quality_override=quality_override,
tier=tier,
mesh_mode=mesh_mode,
auth_kwargs=auth,
)
await poll_for_task_status(subscription_key, auth_kwargs=auth)
download_list = await get_rodin_download_list(task_uuid, auth_kwargs=auth)
model = await download_files(download_list, task_uuid)
return (model,) return comfy_io.NodeOutput(model)
class Rodin3D_Sketch(Rodin3DAPI): class Rodin3D_Sketch(comfy_io.ComfyNode):
"""Generate 3D Assets using Rodin API"""
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls) -> comfy_io.Schema:
return { return comfy_io.Schema(
"required": { node_id="Rodin3D_Sketch",
"Images": display_name="Rodin 3D Generate - Sketch Generate",
( category="api node/3d/Rodin",
IO.IMAGE, description=cleandoc(cls.__doc__ or ""),
{ inputs=[
"forceInput":True, comfy_io.Image.Input("Images"),
} comfy_io.Int.Input(
"Seed",
default=0,
min=0,
max=65535,
display_mode=comfy_io.NumberDisplay.number,
optional=True,
),
],
outputs=[comfy_io.String.Output(display_name="3D Model Path")],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
],
is_api_node=True,
) )
},
"optional": {
"Seed":
(
IO.INT,
{
"default":0,
"min":0,
"max":65535,
"display":"number"
}
)
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
},
}
async def api_call( @classmethod
self, async def execute(
cls,
Images, Images,
Seed, Seed,
**kwargs ) -> comfy_io.NodeOutput:
):
tier = "Sketch" tier = "Sketch"
num_images = Images.shape[0] num_images = Images.shape[0]
m_images = [] m_images = []
@ -461,104 +470,110 @@ class Rodin3D_Sketch(Rodin3DAPI):
material_type = "PBR" material_type = "PBR"
quality_override = 18000 quality_override = 18000
mesh_mode = "Quad" mesh_mode = "Quad"
task_uuid, subscription_key = await self.create_generate_task( auth = {
images=m_images, seed=Seed, material=material_type, quality_override=quality_override, tier=tier, mesh_mode=mesh_mode, **kwargs "auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
task_uuid, subscription_key = await create_generate_task(
images=m_images,
seed=Seed,
material=material_type,
quality_override=quality_override,
tier=tier,
mesh_mode=mesh_mode,
auth_kwargs=auth,
) )
await self.poll_for_task_status(subscription_key, **kwargs) await poll_for_task_status(subscription_key, auth_kwargs=auth)
download_list = await self.get_rodin_download_list(task_uuid, **kwargs) download_list = await get_rodin_download_list(task_uuid, auth_kwargs=auth)
model = await self.download_files(download_list, task_uuid) model = await download_files(download_list, task_uuid)
return (model,) return comfy_io.NodeOutput(model)
class Rodin3D_Gen2(comfy_io.ComfyNode):
"""Generate 3D Assets using Rodin API"""
class Rodin3D_Gen2(Rodin3DAPI):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls) -> comfy_io.Schema:
return { return comfy_io.Schema(
"required": { node_id="Rodin3D_Gen2",
"Images": display_name="Rodin 3D Generate - Gen-2 Generate",
( category="api node/3d/Rodin",
IO.IMAGE, description=cleandoc(cls.__doc__ or ""),
{ inputs=[
"forceInput":True, comfy_io.Image.Input("Images"),
} comfy_io.Int.Input(
"Seed",
default=0,
min=0,
max=65535,
display_mode=comfy_io.NumberDisplay.number,
optional=True,
),
comfy_io.Combo.Input("Material_Type", options=["PBR", "Shaded"], default="PBR", optional=True),
comfy_io.Combo.Input(
"Polygon_count",
options=["4K-Quad", "8K-Quad", "18K-Quad", "50K-Quad", "2K-Triangle", "20K-Triangle", "150K-Triangle", "500K-Triangle"],
default="500K-Triangle",
optional=True,
),
comfy_io.Boolean.Input("TAPose", default=False),
],
outputs=[comfy_io.String.Output(display_name="3D Model Path")],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
],
is_api_node=True,
) )
},
"optional": {
"Seed": (
IO.INT,
{
"default":0,
"min":0,
"max":65535,
"display":"number"
}
),
"Material_Type": (
IO.COMBO,
{
"options": ["PBR", "Shaded"],
"default": "PBR"
}
),
"Polygon_count": (
IO.COMBO,
{
"options": ["4K-Quad", "8K-Quad", "18K-Quad", "50K-Quad", "2K-Triangle", "20K-Triangle", "150K-Triangle", "500K-Triangle"],
"default": "500K-Triangle"
}
),
"TAPose": (
IO.BOOLEAN,
{
"default": False,
}
)
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
},
}
async def api_call( @classmethod
self, async def execute(
cls,
Images, Images,
Seed, Seed,
Material_Type, Material_Type,
Polygon_count, Polygon_count,
TAPose, TAPose,
**kwargs ) -> comfy_io.NodeOutput:
):
tier = "Gen-2" tier = "Gen-2"
num_images = Images.shape[0] num_images = Images.shape[0]
m_images = [] m_images = []
for i in range(num_images): for i in range(num_images):
m_images.append(Images[i]) m_images.append(Images[i])
mesh_mode, quality_override = self.get_quality_mode(Polygon_count) mesh_mode, quality_override = get_quality_mode(Polygon_count)
task_uuid, subscription_key = await self.create_generate_task(images=m_images, seed=Seed, material=Material_Type, auth = {
quality_override=quality_override, tier=tier, mesh_mode=mesh_mode, TAPose=TAPose, "auth_token": cls.hidden.auth_token_comfy_org,
**kwargs) "comfy_api_key": cls.hidden.api_key_comfy_org,
await self.poll_for_task_status(subscription_key, **kwargs)
download_list = await self.get_rodin_download_list(task_uuid, **kwargs)
model = await self.download_files(download_list, task_uuid)
return (model,)
# A dictionary that contains all nodes you want to export with their names
# NOTE: names should be globally unique
NODE_CLASS_MAPPINGS = {
"Rodin3D_Regular": Rodin3D_Regular,
"Rodin3D_Detail": Rodin3D_Detail,
"Rodin3D_Smooth": Rodin3D_Smooth,
"Rodin3D_Sketch": Rodin3D_Sketch,
"Rodin3D_Gen2": Rodin3D_Gen2,
} }
task_uuid, subscription_key = await create_generate_task(
images=m_images,
seed=Seed,
material=Material_Type,
quality_override=quality_override,
tier=tier,
mesh_mode=mesh_mode,
TAPose=TAPose,
auth_kwargs=auth,
)
await poll_for_task_status(subscription_key, auth_kwargs=auth)
download_list = await get_rodin_download_list(task_uuid, auth_kwargs=auth)
model = await download_files(download_list, task_uuid)
# A dictionary that contains the friendly/humanly readable titles for the nodes return comfy_io.NodeOutput(model)
NODE_DISPLAY_NAME_MAPPINGS = {
"Rodin3D_Regular": "Rodin 3D Generate - Regular Generate",
"Rodin3D_Detail": "Rodin 3D Generate - Detail Generate", class Rodin3DExtension(ComfyExtension):
"Rodin3D_Smooth": "Rodin 3D Generate - Smooth Generate", @override
"Rodin3D_Sketch": "Rodin 3D Generate - Sketch Generate", async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
"Rodin3D_Gen2": "Rodin 3D Generate - Gen-2 Generate", return [
} Rodin3D_Regular,
Rodin3D_Detail,
Rodin3D_Smooth,
Rodin3D_Sketch,
Rodin3D_Gen2,
]
async def comfy_entrypoint() -> Rodin3DExtension:
return Rodin3DExtension()
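
Not part of the diff above: the Rodin hunks follow the same migration seen throughout this commit — class-level `INPUT_TYPES`/`RETURN_TYPES`/`FUNCTION` attributes become a `define_schema` classmethod returning `comfy_io.Schema`, and module-level `NODE_CLASS_MAPPINGS` become a `ComfyExtension` returned from `comfy_entrypoint`. A minimal sketch of that pattern, using only calls already visible in this commit (the node itself is hypothetical):

```python
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io as comfy_io


class ExampleInvert(comfy_io.ComfyNode):
    """Hypothetical node that inverts an image; shown only to illustrate the schema API."""

    @classmethod
    def define_schema(cls) -> comfy_io.Schema:
        return comfy_io.Schema(
            node_id="ExampleInvert",
            display_name="Example Invert",
            category="example",
            inputs=[comfy_io.Image.Input("image")],
            outputs=[comfy_io.Image.Output()],
        )

    @classmethod
    def execute(cls, image) -> comfy_io.NodeOutput:
        # ComfyUI images are float tensors in 0..1, so this is a simple invert.
        return comfy_io.NodeOutput(1.0 - image)


class ExampleExtension(ComfyExtension):
    @override
    async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
        return [ExampleInvert]


async def comfy_entrypoint() -> ExampleExtension:
    return ExampleExtension()
```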

View File

@@ -200,11 +200,11 @@ class RunwayImageToVideoNodeGen3a(comfy_io.ComfyNode):
), ),
comfy_io.Combo.Input( comfy_io.Combo.Input(
"duration", "duration",
options=[model.value for model in Duration], options=Duration,
), ),
comfy_io.Combo.Input( comfy_io.Combo.Input(
"ratio", "ratio",
options=[model.value for model in RunwayGen3aAspectRatio], options=RunwayGen3aAspectRatio,
), ),
comfy_io.Int.Input( comfy_io.Int.Input(
"seed", "seed",
@@ -300,11 +300,11 @@ class RunwayImageToVideoNodeGen4(comfy_io.ComfyNode):
), ),
comfy_io.Combo.Input( comfy_io.Combo.Input(
"duration", "duration",
options=[model.value for model in Duration], options=Duration,
), ),
comfy_io.Combo.Input( comfy_io.Combo.Input(
"ratio", "ratio",
options=[model.value for model in RunwayGen4TurboAspectRatio], options=RunwayGen4TurboAspectRatio,
), ),
comfy_io.Int.Input( comfy_io.Int.Input(
"seed", "seed",
@@ -408,11 +408,11 @@ class RunwayFirstLastFrameNode(comfy_io.ComfyNode):
), ),
comfy_io.Combo.Input( comfy_io.Combo.Input(
"duration", "duration",
options=[model.value for model in Duration], options=Duration,
), ),
comfy_io.Combo.Input( comfy_io.Combo.Input(
"ratio", "ratio",
options=[model.value for model in RunwayGen3aAspectRatio], options=RunwayGen3aAspectRatio,
), ),
comfy_io.Int.Input( comfy_io.Int.Input(
"seed", "seed",

View File

@@ -0,0 +1,175 @@
from typing import Optional
from typing_extensions import override
import torch
from pydantic import BaseModel, Field
from comfy_api.latest import ComfyExtension, io as comfy_io
from comfy_api_nodes.apis.client import (
ApiEndpoint,
HttpMethod,
SynchronousOperation,
PollingOperation,
EmptyRequest,
)
from comfy_api_nodes.util.validation_utils import get_number_of_images
from comfy_api_nodes.apinode_utils import (
download_url_to_video_output,
tensor_to_bytesio,
)
class Sora2GenerationRequest(BaseModel):
prompt: str = Field(...)
model: str = Field(...)
seconds: str = Field(...)
size: str = Field(...)
class Sora2GenerationResponse(BaseModel):
id: str = Field(...)
error: Optional[dict] = Field(None)
status: Optional[str] = Field(None)
class OpenAIVideoSora2(comfy_io.ComfyNode):
@classmethod
def define_schema(cls):
return comfy_io.Schema(
node_id="OpenAIVideoSora2",
display_name="OpenAI Sora - Video",
category="api node/video/Sora",
description="OpenAI video and audio generation.",
inputs=[
comfy_io.Combo.Input(
"model",
options=["sora-2", "sora-2-pro"],
default="sora-2",
),
comfy_io.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Guiding text; may be empty if an input image is present.",
),
comfy_io.Combo.Input(
"size",
options=[
"720x1280",
"1280x720",
"1024x1792",
"1792x1024",
],
default="1280x720",
),
comfy_io.Combo.Input(
"duration",
options=[4, 8, 12],
default=8,
),
comfy_io.Image.Input(
"image",
optional=True,
),
comfy_io.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
step=1,
display_mode=comfy_io.NumberDisplay.number,
control_after_generate=True,
optional=True,
tooltip="Seed to determine if node should re-run; "
"actual results are nondeterministic regardless of seed.",
),
],
outputs=[
comfy_io.Video.Output(),
],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
model: str,
prompt: str,
size: str = "1280x720",
duration: int = 8,
seed: int = 0,
image: Optional[torch.Tensor] = None,
):
if model == "sora-2" and size not in ("720x1280", "1280x720"):
raise ValueError("Invalid size for sora-2 model, only 720x1280 and 1280x720 are supported.")
files_input = None
if image is not None:
if get_number_of_images(image) != 1:
raise ValueError("Currently only one input image is supported.")
files_input = {"input_reference": ("image.png", tensor_to_bytesio(image), "image/png")}
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
payload = Sora2GenerationRequest(
model=model,
prompt=prompt,
seconds=str(duration),
size=size,
)
initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/openai/v1/videos",
method=HttpMethod.POST,
request_model=Sora2GenerationRequest,
response_model=Sora2GenerationResponse
),
request=payload,
files=files_input,
auth_kwargs=auth,
content_type="multipart/form-data",
)
initial_response = await initial_operation.execute()
if initial_response.error:
raise Exception(initial_response.error.get("message", str(initial_response.error)))  # error is a plain dict, not an object with a .message attribute
model_time_multiplier = 1 if model == "sora-2" else 2
poll_operation = PollingOperation(
poll_endpoint=ApiEndpoint(
path=f"/proxy/openai/v1/videos/{initial_response.id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=Sora2GenerationResponse
),
completed_statuses=["completed"],
failed_statuses=["failed"],
status_extractor=lambda x: x.status,
auth_kwargs=auth,
poll_interval=8.0,
max_poll_attempts=160,
node_id=cls.hidden.unique_id,
estimated_duration=45 * (duration / 4) * model_time_multiplier,
)
await poll_operation.execute()
return comfy_io.NodeOutput(
await download_url_to_video_output(
f"/proxy/openai/v1/videos/{initial_response.id}/content",
auth_kwargs=auth,
)
)
class OpenAISoraExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
return [
OpenAIVideoSora2,
]
async def comfy_entrypoint() -> OpenAISoraExtension:
return OpenAISoraExtension()
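
Not part of the new file above: a small sketch of the timing implied by the polling settings in `OpenAIVideoSora2.execute`, assuming the loop runs to its configured limits.

```python
# Sketch only; mirrors the arithmetic used for estimated_duration and the poll loop above.
def sora_timing(model: str, duration: int) -> tuple[float, float]:
    model_time_multiplier = 1 if model == "sora-2" else 2
    estimated_duration = 45 * (duration / 4) * model_time_multiplier   # progress-bar estimate, seconds
    max_wait = 8.0 * 160                                               # poll_interval * max_poll_attempts = 1280 s (~21 min)
    return estimated_duration, max_wait

# sora_timing("sora-2", 4)      -> (45.0, 1280.0)
# sora_timing("sora-2-pro", 12) -> (270.0, 1280.0)
```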

View File

@@ -82,8 +82,8 @@ class StabilityStableImageUltraNode(comfy_io.ComfyNode):
), ),
comfy_io.Combo.Input( comfy_io.Combo.Input(
"aspect_ratio", "aspect_ratio",
options=[x.value for x in StabilityAspectRatio], options=StabilityAspectRatio,
default=StabilityAspectRatio.ratio_1_1.value, default=StabilityAspectRatio.ratio_1_1,
tooltip="Aspect ratio of generated image.", tooltip="Aspect ratio of generated image.",
), ),
comfy_io.Combo.Input( comfy_io.Combo.Input(
@@ -217,12 +217,12 @@ class StabilityStableImageSD_3_5Node(comfy_io.ComfyNode):
), ),
comfy_io.Combo.Input( comfy_io.Combo.Input(
"model", "model",
options=[x.value for x in Stability_SD3_5_Model], options=Stability_SD3_5_Model,
), ),
comfy_io.Combo.Input( comfy_io.Combo.Input(
"aspect_ratio", "aspect_ratio",
options=[x.value for x in StabilityAspectRatio], options=StabilityAspectRatio,
default=StabilityAspectRatio.ratio_1_1.value, default=StabilityAspectRatio.ratio_1_1,
tooltip="Aspect ratio of generated image.", tooltip="Aspect ratio of generated image.",
), ),
comfy_io.Combo.Input( comfy_io.Combo.Input(

View File

@@ -215,7 +215,7 @@ class VeoVideoGenerationNode(comfy_io.ComfyNode):
initial_response = await initial_operation.execute() initial_response = await initial_operation.execute()
operation_name = initial_response.name operation_name = initial_response.name
logging.info(f"Veo generation started with operation name: {operation_name}") logging.info("Veo generation started with operation name: %s", operation_name)
# Define status extractor function # Define status extractor function
def status_extractor(response): def status_extractor(response):

View File

@@ -173,8 +173,8 @@ class ViduTextToVideoNode(comfy_io.ComfyNode):
inputs=[ inputs=[
comfy_io.Combo.Input( comfy_io.Combo.Input(
"model", "model",
options=[model.value for model in VideoModelName], options=VideoModelName,
default=VideoModelName.vidu_q1.value, default=VideoModelName.vidu_q1,
tooltip="Model name", tooltip="Model name",
), ),
comfy_io.String.Input( comfy_io.String.Input(
@@ -205,22 +205,22 @@ class ViduTextToVideoNode(comfy_io.ComfyNode):
), ),
comfy_io.Combo.Input( comfy_io.Combo.Input(
"aspect_ratio", "aspect_ratio",
options=[model.value for model in AspectRatio], options=AspectRatio,
default=AspectRatio.r_16_9.value, default=AspectRatio.r_16_9,
tooltip="The aspect ratio of the output video", tooltip="The aspect ratio of the output video",
optional=True, optional=True,
), ),
comfy_io.Combo.Input( comfy_io.Combo.Input(
"resolution", "resolution",
options=[model.value for model in Resolution], options=Resolution,
default=Resolution.r_1080p.value, default=Resolution.r_1080p,
tooltip="Supported values may vary by model & duration", tooltip="Supported values may vary by model & duration",
optional=True, optional=True,
), ),
comfy_io.Combo.Input( comfy_io.Combo.Input(
"movement_amplitude", "movement_amplitude",
options=[model.value for model in MovementAmplitude], options=MovementAmplitude,
default=MovementAmplitude.auto.value, default=MovementAmplitude.auto,
tooltip="The movement amplitude of objects in the frame", tooltip="The movement amplitude of objects in the frame",
optional=True, optional=True,
), ),
@@ -278,8 +278,8 @@ class ViduImageToVideoNode(comfy_io.ComfyNode):
inputs=[ inputs=[
comfy_io.Combo.Input( comfy_io.Combo.Input(
"model", "model",
options=[model.value for model in VideoModelName], options=VideoModelName,
default=VideoModelName.vidu_q1.value, default=VideoModelName.vidu_q1,
tooltip="Model name", tooltip="Model name",
), ),
comfy_io.Image.Input( comfy_io.Image.Input(
@@ -316,14 +316,14 @@ class ViduImageToVideoNode(comfy_io.ComfyNode):
), ),
comfy_io.Combo.Input( comfy_io.Combo.Input(
"resolution", "resolution",
options=[model.value for model in Resolution], options=Resolution,
default=Resolution.r_1080p.value, default=Resolution.r_1080p,
tooltip="Supported values may vary by model & duration", tooltip="Supported values may vary by model & duration",
optional=True, optional=True,
), ),
comfy_io.Combo.Input( comfy_io.Combo.Input(
"movement_amplitude", "movement_amplitude",
options=[model.value for model in MovementAmplitude], options=MovementAmplitude,
default=MovementAmplitude.auto.value, default=MovementAmplitude.auto.value,
tooltip="The movement amplitude of objects in the frame", tooltip="The movement amplitude of objects in the frame",
optional=True, optional=True,
@@ -388,8 +388,8 @@ class ViduReferenceVideoNode(comfy_io.ComfyNode):
inputs=[ inputs=[
comfy_io.Combo.Input( comfy_io.Combo.Input(
"model", "model",
options=[model.value for model in VideoModelName], options=VideoModelName,
default=VideoModelName.vidu_q1.value, default=VideoModelName.vidu_q1,
tooltip="Model name", tooltip="Model name",
), ),
comfy_io.Image.Input( comfy_io.Image.Input(
@@ -424,8 +424,8 @@ class ViduReferenceVideoNode(comfy_io.ComfyNode):
), ),
comfy_io.Combo.Input( comfy_io.Combo.Input(
"aspect_ratio", "aspect_ratio",
options=[model.value for model in AspectRatio], options=AspectRatio,
default=AspectRatio.r_16_9.value, default=AspectRatio.r_16_9,
tooltip="The aspect ratio of the output video", tooltip="The aspect ratio of the output video",
optional=True, optional=True,
), ),
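
The Runway, Stability, and Vidu hunks above all make the same change to Combo inputs: instead of expanding an Enum into a list of raw values, the Enum class (or member, for defaults) is passed directly and `comfy_io` resolves the values. A sketch with a hypothetical enum, not taken from this commit:

```python
from enum import Enum

class Duration(str, Enum):   # hypothetical members, for illustration only
    five = "5"
    ten = "10"

old_options = [d.value for d in Duration]   # ["5", "10"]  (pre-change style)
new_options = Duration                      # comfy_io.Combo.Input(options=Duration)
new_default = Duration.five                 # default=Duration.five instead of "5"
```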

View File

@@ -142,9 +142,10 @@ def save_audio(self, audio, filename_prefix="ComfyUI", format="flac", prompt=Non
for key, value in metadata.items(): for key, value in metadata.items():
output_container.metadata[key] = value output_container.metadata[key] = value
layout = 'mono' if waveform.shape[0] == 1 else 'stereo'
# Set up the output stream with appropriate properties # Set up the output stream with appropriate properties
if format == "opus": if format == "opus":
out_stream = output_container.add_stream("libopus", rate=sample_rate) out_stream = output_container.add_stream("libopus", rate=sample_rate, layout=layout)
if quality == "64k": if quality == "64k":
out_stream.bit_rate = 64000 out_stream.bit_rate = 64000
elif quality == "96k": elif quality == "96k":
@@ -156,7 +157,7 @@ def save_audio(self, audio, filename_prefix="ComfyUI", format="flac", prompt=Non
elif quality == "320k": elif quality == "320k":
out_stream.bit_rate = 320000 out_stream.bit_rate = 320000
elif format == "mp3": elif format == "mp3":
out_stream = output_container.add_stream("libmp3lame", rate=sample_rate) out_stream = output_container.add_stream("libmp3lame", rate=sample_rate, layout=layout)
if quality == "V0": if quality == "V0":
#TODO i would really love to support V3 and V5 but there doesn't seem to be a way to set the qscale level, the property below is a bool #TODO i would really love to support V3 and V5 but there doesn't seem to be a way to set the qscale level, the property below is a bool
out_stream.codec_context.qscale = 1 out_stream.codec_context.qscale = 1
@@ -165,9 +166,9 @@ def save_audio(self, audio, filename_prefix="ComfyUI", format="flac", prompt=Non
elif quality == "320k": elif quality == "320k":
out_stream.bit_rate = 320000 out_stream.bit_rate = 320000
else: #format == "flac": else: #format == "flac":
out_stream = output_container.add_stream("flac", rate=sample_rate) out_stream = output_container.add_stream("flac", rate=sample_rate, layout=layout)
frame = av.AudioFrame.from_ndarray(waveform.movedim(0, 1).reshape(1, -1).float().numpy(), format='flt', layout='mono' if waveform.shape[0] == 1 else 'stereo') frame = av.AudioFrame.from_ndarray(waveform.movedim(0, 1).reshape(1, -1).float().numpy(), format='flt', layout=layout)
frame.sample_rate = sample_rate frame.sample_rate = sample_rate
frame.pts = 0 frame.pts = 0
output_container.mux(out_stream.encode(frame)) output_container.mux(out_stream.encode(frame))
@@ -360,7 +361,7 @@ class RecordAudio:
def load(self, audio): def load(self, audio):
audio_path = folder_paths.get_annotated_filepath(audio) audio_path = folder_paths.get_annotated_filepath(audio)
waveform, sample_rate = torchaudio.load(audio_path) waveform, sample_rate = load(audio_path)
audio = {"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate} audio = {"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate}
return (audio, ) return (audio, )
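
Not part of the diff: the `save_audio` hunks above pass the same `layout` to both `add_stream` and `AudioFrame.from_ndarray`; previously only the frame carried it, so a mono waveform could end up muxed into a stream PyAV had defaulted to stereo. A standalone sketch of the fixed flow, assuming `add_stream` accepts `layout` as it is used in this diff:

```python
import av
import torch

def encode_flac(waveform: torch.Tensor, sample_rate: int, path: str) -> None:
    # waveform: [channels, samples], float in -1..1
    layout = "mono" if waveform.shape[0] == 1 else "stereo"
    container = av.open(path, mode="w")
    stream = container.add_stream("flac", rate=sample_rate, layout=layout)
    frame = av.AudioFrame.from_ndarray(
        waveform.movedim(0, 1).reshape(1, -1).float().numpy(), format="flt", layout=layout
    )
    frame.sample_rate = sample_rate
    frame.pts = 0
    container.mux(stream.encode(frame))
    container.mux(stream.encode(None))  # flush any buffered packets
    container.close()
```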

View File

@@ -1,6 +1,9 @@
import torch import torch
import comfy.utils import comfy.utils
from enum import Enum from enum import Enum
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
def resize_mask(mask, shape): def resize_mask(mask, shape):
return torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(shape[0], shape[1]), mode="bilinear").squeeze(1) return torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(shape[0], shape[1]), mode="bilinear").squeeze(1)
@@ -101,24 +104,28 @@ def porter_duff_composite(src_image: torch.Tensor, src_alpha: torch.Tensor, dst_
return out_image, out_alpha return out_image, out_alpha
class PorterDuffImageComposite: class PorterDuffImageComposite(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return { return io.Schema(
"required": { node_id="PorterDuffImageComposite",
"source": ("IMAGE",), display_name="Porter-Duff Image Composite",
"source_alpha": ("MASK",), category="mask/compositing",
"destination": ("IMAGE",), inputs=[
"destination_alpha": ("MASK",), io.Image.Input("source"),
"mode": ([mode.name for mode in PorterDuffMode], {"default": PorterDuffMode.DST.name}), io.Mask.Input("source_alpha"),
}, io.Image.Input("destination"),
} io.Mask.Input("destination_alpha"),
io.Combo.Input("mode", options=[mode.name for mode in PorterDuffMode], default=PorterDuffMode.DST.name),
],
outputs=[
io.Image.Output(),
io.Mask.Output(),
],
)
RETURN_TYPES = ("IMAGE", "MASK") @classmethod
FUNCTION = "composite" def execute(cls, source: torch.Tensor, source_alpha: torch.Tensor, destination: torch.Tensor, destination_alpha: torch.Tensor, mode) -> io.NodeOutput:
CATEGORY = "mask/compositing"
def composite(self, source: torch.Tensor, source_alpha: torch.Tensor, destination: torch.Tensor, destination_alpha: torch.Tensor, mode):
batch_size = min(len(source), len(source_alpha), len(destination), len(destination_alpha)) batch_size = min(len(source), len(source_alpha), len(destination), len(destination_alpha))
out_images = [] out_images = []
out_alphas = [] out_alphas = []
@@ -150,45 +157,48 @@ class PorterDuffImageComposite:
out_images.append(out_image) out_images.append(out_image)
out_alphas.append(out_alpha.squeeze(2)) out_alphas.append(out_alpha.squeeze(2))
result = (torch.stack(out_images), torch.stack(out_alphas)) return io.NodeOutput(torch.stack(out_images), torch.stack(out_alphas))
return result
class SplitImageWithAlpha: class SplitImageWithAlpha(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return { return io.Schema(
"required": { node_id="SplitImageWithAlpha",
"image": ("IMAGE",), display_name="Split Image with Alpha",
} category="mask/compositing",
} inputs=[
io.Image.Input("image"),
],
outputs=[
io.Image.Output(),
io.Mask.Output(),
],
)
CATEGORY = "mask/compositing" @classmethod
RETURN_TYPES = ("IMAGE", "MASK") def execute(cls, image: torch.Tensor) -> io.NodeOutput:
FUNCTION = "split_image_with_alpha"
def split_image_with_alpha(self, image: torch.Tensor):
out_images = [i[:,:,:3] for i in image] out_images = [i[:,:,:3] for i in image]
out_alphas = [i[:,:,3] if i.shape[2] > 3 else torch.ones_like(i[:,:,0]) for i in image] out_alphas = [i[:,:,3] if i.shape[2] > 3 else torch.ones_like(i[:,:,0]) for i in image]
result = (torch.stack(out_images), 1.0 - torch.stack(out_alphas)) return io.NodeOutput(torch.stack(out_images), 1.0 - torch.stack(out_alphas))
return result
class JoinImageWithAlpha: class JoinImageWithAlpha(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return { return io.Schema(
"required": { node_id="JoinImageWithAlpha",
"image": ("IMAGE",), display_name="Join Image with Alpha",
"alpha": ("MASK",), category="mask/compositing",
} inputs=[
} io.Image.Input("image"),
io.Mask.Input("alpha"),
],
outputs=[io.Image.Output()],
)
CATEGORY = "mask/compositing" @classmethod
RETURN_TYPES = ("IMAGE",) def execute(cls, image: torch.Tensor, alpha: torch.Tensor) -> io.NodeOutput:
FUNCTION = "join_image_with_alpha"
def join_image_with_alpha(self, image: torch.Tensor, alpha: torch.Tensor):
batch_size = min(len(image), len(alpha)) batch_size = min(len(image), len(alpha))
out_images = [] out_images = []
@@ -196,19 +206,18 @@ class JoinImageWithAlpha:
for i in range(batch_size): for i in range(batch_size):
out_images.append(torch.cat((image[i][:,:,:3], alpha[i].unsqueeze(2)), dim=2)) out_images.append(torch.cat((image[i][:,:,:3], alpha[i].unsqueeze(2)), dim=2))
result = (torch.stack(out_images),) return io.NodeOutput(torch.stack(out_images))
return result
NODE_CLASS_MAPPINGS = { class CompositingExtension(ComfyExtension):
"PorterDuffImageComposite": PorterDuffImageComposite, @override
"SplitImageWithAlpha": SplitImageWithAlpha, async def get_node_list(self) -> list[type[io.ComfyNode]]:
"JoinImageWithAlpha": JoinImageWithAlpha, return [
} PorterDuffImageComposite,
SplitImageWithAlpha,
JoinImageWithAlpha,
]
NODE_DISPLAY_NAME_MAPPINGS = { async def comfy_entrypoint() -> CompositingExtension:
"PorterDuffImageComposite": "Porter-Duff Image Composite", return CompositingExtension()
"SplitImageWithAlpha": "Split Image with Alpha",
"JoinImageWithAlpha": "Join Image with Alpha",
}
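
For context on the compositing node above (the per-mode math sits outside the shown hunks): the classic Porter-Duff "source over" operator written for ComfyUI-style tensors. Illustrative only; this is not code from the commit and the node supports the full set of `PorterDuffMode` operators.

```python
import torch

def source_over(src: torch.Tensor, src_a: torch.Tensor,
                dst: torch.Tensor, dst_a: torch.Tensor):
    # src/dst: [H, W, C] images; src_a/dst_a: [H, W] alphas in 0..1
    sa = src_a.unsqueeze(-1)
    da = dst_a.unsqueeze(-1)
    out_a = sa + da * (1.0 - sa)
    out = (src * sa + dst * da * (1.0 - sa)) / out_a.clamp(min=1e-8)
    return out, out_a.squeeze(-1)
```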

View File

@@ -1,60 +1,80 @@
import node_helpers import node_helpers
import comfy.utils import comfy.utils
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
class CLIPTextEncodeFlux:
class CLIPTextEncodeFlux(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": { return io.Schema(
"clip": ("CLIP", ), node_id="CLIPTextEncodeFlux",
"clip_l": ("STRING", {"multiline": True, "dynamicPrompts": True}), category="advanced/conditioning/flux",
"t5xxl": ("STRING", {"multiline": True, "dynamicPrompts": True}), inputs=[
"guidance": ("FLOAT", {"default": 3.5, "min": 0.0, "max": 100.0, "step": 0.1}), io.Clip.Input("clip"),
}} io.String.Input("clip_l", multiline=True, dynamic_prompts=True),
RETURN_TYPES = ("CONDITIONING",) io.String.Input("t5xxl", multiline=True, dynamic_prompts=True),
FUNCTION = "encode" io.Float.Input("guidance", default=3.5, min=0.0, max=100.0, step=0.1),
],
outputs=[
io.Conditioning.Output(),
],
)
CATEGORY = "advanced/conditioning/flux" @classmethod
def execute(cls, clip, clip_l, t5xxl, guidance) -> io.NodeOutput:
def encode(self, clip, clip_l, t5xxl, guidance):
tokens = clip.tokenize(clip_l) tokens = clip.tokenize(clip_l)
tokens["t5xxl"] = clip.tokenize(t5xxl)["t5xxl"] tokens["t5xxl"] = clip.tokenize(t5xxl)["t5xxl"]
return (clip.encode_from_tokens_scheduled(tokens, add_dict={"guidance": guidance}), ) return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens, add_dict={"guidance": guidance}))
class FluxGuidance: encode = execute # TODO: remove
class FluxGuidance(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": { return io.Schema(
"conditioning": ("CONDITIONING", ), node_id="FluxGuidance",
"guidance": ("FLOAT", {"default": 3.5, "min": 0.0, "max": 100.0, "step": 0.1}), category="advanced/conditioning/flux",
}} inputs=[
io.Conditioning.Input("conditioning"),
io.Float.Input("guidance", default=3.5, min=0.0, max=100.0, step=0.1),
],
outputs=[
io.Conditioning.Output(),
],
)
RETURN_TYPES = ("CONDITIONING",) @classmethod
FUNCTION = "append" def execute(cls, conditioning, guidance) -> io.NodeOutput:
CATEGORY = "advanced/conditioning/flux"
def append(self, conditioning, guidance):
c = node_helpers.conditioning_set_values(conditioning, {"guidance": guidance}) c = node_helpers.conditioning_set_values(conditioning, {"guidance": guidance})
return (c, ) return io.NodeOutput(c)
append = execute # TODO: remove
class FluxDisableGuidance: class FluxDisableGuidance(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": { return io.Schema(
"conditioning": ("CONDITIONING", ), node_id="FluxDisableGuidance",
}} category="advanced/conditioning/flux",
description="This node completely disables the guidance embed on Flux and Flux like models",
inputs=[
io.Conditioning.Input("conditioning"),
],
outputs=[
io.Conditioning.Output(),
],
)
RETURN_TYPES = ("CONDITIONING",) @classmethod
FUNCTION = "append" def execute(cls, conditioning) -> io.NodeOutput:
CATEGORY = "advanced/conditioning/flux"
DESCRIPTION = "This node completely disables the guidance embed on Flux and Flux like models"
def append(self, conditioning):
c = node_helpers.conditioning_set_values(conditioning, {"guidance": None}) c = node_helpers.conditioning_set_values(conditioning, {"guidance": None})
return (c, ) return io.NodeOutput(c)
append = execute # TODO: remove
PREFERED_KONTEXT_RESOLUTIONS = [ PREFERED_KONTEXT_RESOLUTIONS = [
@@ -78,52 +98,73 @@ PREFERED_KONTEXT_RESOLUTIONS = [
] ]
class FluxKontextImageScale: class FluxKontextImageScale(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": {"image": ("IMAGE", ), return io.Schema(
}, node_id="FluxKontextImageScale",
} category="advanced/conditioning/flux",
description="This node resizes the image to one that is more optimal for flux kontext.",
inputs=[
io.Image.Input("image"),
],
outputs=[
io.Image.Output(),
],
)
RETURN_TYPES = ("IMAGE",) @classmethod
FUNCTION = "scale" def execute(cls, image) -> io.NodeOutput:
CATEGORY = "advanced/conditioning/flux"
DESCRIPTION = "This node resizes the image to one that is more optimal for flux kontext."
def scale(self, image):
width = image.shape[2] width = image.shape[2]
height = image.shape[1] height = image.shape[1]
aspect_ratio = width / height aspect_ratio = width / height
_, width, height = min((abs(aspect_ratio - w / h), w, h) for w, h in PREFERED_KONTEXT_RESOLUTIONS) _, width, height = min((abs(aspect_ratio - w / h), w, h) for w, h in PREFERED_KONTEXT_RESOLUTIONS)
image = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "lanczos", "center").movedim(1, -1) image = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "lanczos", "center").movedim(1, -1)
return (image, ) return io.NodeOutput(image)
scale = execute # TODO: remove
class FluxKontextMultiReferenceLatentMethod: class FluxKontextMultiReferenceLatentMethod(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": { return io.Schema(
"conditioning": ("CONDITIONING", ), node_id="FluxKontextMultiReferenceLatentMethod",
"reference_latents_method": (("offset", "index", "uxo/uno"), ), category="advanced/conditioning/flux",
}} inputs=[
io.Conditioning.Input("conditioning"),
io.Combo.Input(
"reference_latents_method",
options=["offset", "index", "uxo/uno"],
),
],
outputs=[
io.Conditioning.Output(),
],
is_experimental=True,
)
RETURN_TYPES = ("CONDITIONING",) @classmethod
FUNCTION = "append" def execute(cls, conditioning, reference_latents_method) -> io.NodeOutput:
EXPERIMENTAL = True
CATEGORY = "advanced/conditioning/flux"
def append(self, conditioning, reference_latents_method):
if "uxo" in reference_latents_method or "uso" in reference_latents_method: if "uxo" in reference_latents_method or "uso" in reference_latents_method:
reference_latents_method = "uxo" reference_latents_method = "uxo"
c = node_helpers.conditioning_set_values(conditioning, {"reference_latents_method": reference_latents_method}) c = node_helpers.conditioning_set_values(conditioning, {"reference_latents_method": reference_latents_method})
return (c, ) return io.NodeOutput(c)
NODE_CLASS_MAPPINGS = { append = execute # TODO: remove
"CLIPTextEncodeFlux": CLIPTextEncodeFlux,
"FluxGuidance": FluxGuidance,
"FluxDisableGuidance": FluxDisableGuidance, class FluxExtension(ComfyExtension):
"FluxKontextImageScale": FluxKontextImageScale, @override
"FluxKontextMultiReferenceLatentMethod": FluxKontextMultiReferenceLatentMethod, async def get_node_list(self) -> list[type[io.ComfyNode]]:
} return [
CLIPTextEncodeFlux,
FluxGuidance,
FluxDisableGuidance,
FluxKontextImageScale,
FluxKontextMultiReferenceLatentMethod,
]
async def comfy_entrypoint() -> FluxExtension:
return FluxExtension()

View File

@@ -2,42 +2,60 @@ import nodes
import node_helpers import node_helpers
import torch import torch
import comfy.model_management import comfy.model_management
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
class CLIPTextEncodeHunyuanDiT: class CLIPTextEncodeHunyuanDiT(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": { return io.Schema(
"clip": ("CLIP", ), node_id="CLIPTextEncodeHunyuanDiT",
"bert": ("STRING", {"multiline": True, "dynamicPrompts": True}), category="advanced/conditioning",
"mt5xl": ("STRING", {"multiline": True, "dynamicPrompts": True}), inputs=[
}} io.Clip.Input("clip"),
RETURN_TYPES = ("CONDITIONING",) io.String.Input("bert", multiline=True, dynamic_prompts=True),
FUNCTION = "encode" io.String.Input("mt5xl", multiline=True, dynamic_prompts=True),
],
outputs=[
io.Conditioning.Output(),
],
)
CATEGORY = "advanced/conditioning" @classmethod
def execute(cls, clip, bert, mt5xl) -> io.NodeOutput:
def encode(self, clip, bert, mt5xl):
tokens = clip.tokenize(bert) tokens = clip.tokenize(bert)
tokens["mt5xl"] = clip.tokenize(mt5xl)["mt5xl"] tokens["mt5xl"] = clip.tokenize(mt5xl)["mt5xl"]
return (clip.encode_from_tokens_scheduled(tokens), ) return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens))
class EmptyHunyuanLatentVideo: encode = execute # TODO: remove
class EmptyHunyuanLatentVideo(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": { "width": ("INT", {"default": 848, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), return io.Schema(
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), node_id="EmptyHunyuanLatentVideo",
"length": ("INT", {"default": 25, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}), category="latent/video",
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}} inputs=[
RETURN_TYPES = ("LATENT",) io.Int.Input("width", default=848, min=16, max=nodes.MAX_RESOLUTION, step=16),
FUNCTION = "generate" io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("length", default=25, min=1, max=nodes.MAX_RESOLUTION, step=4),
io.Int.Input("batch_size", default=1, min=1, max=4096),
],
outputs=[
io.Latent.Output(),
],
)
CATEGORY = "latent/video" @classmethod
def execute(cls, width, height, length, batch_size=1) -> io.NodeOutput:
def generate(self, width, height, length, batch_size=1):
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
return ({"samples":latent}, ) return io.NodeOutput({"samples":latent})
generate = execute # TODO: remove
PROMPT_TEMPLATE_ENCODE_VIDEO_I2V = ( PROMPT_TEMPLATE_ENCODE_VIDEO_I2V = (
"<|start_header_id|>system<|end_header_id|>\n\n<image>\nDescribe the video by detailing the following aspects according to the reference image: " "<|start_header_id|>system<|end_header_id|>\n\n<image>\nDescribe the video by detailing the following aspects according to the reference image: "
@@ -50,45 +68,61 @@ PROMPT_TEMPLATE_ENCODE_VIDEO_I2V = (
"<|start_header_id|>assistant<|end_header_id|>\n\n" "<|start_header_id|>assistant<|end_header_id|>\n\n"
) )
class TextEncodeHunyuanVideo_ImageToVideo: class TextEncodeHunyuanVideo_ImageToVideo(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": { return io.Schema(
"clip": ("CLIP", ), node_id="TextEncodeHunyuanVideo_ImageToVideo",
"clip_vision_output": ("CLIP_VISION_OUTPUT", ), category="advanced/conditioning",
"prompt": ("STRING", {"multiline": True, "dynamicPrompts": True}), inputs=[
"image_interleave": ("INT", {"default": 2, "min": 1, "max": 512, "tooltip": "How much the image influences things vs the text prompt. Higher number means more influence from the text prompt."}), io.Clip.Input("clip"),
}} io.ClipVisionOutput.Input("clip_vision_output"),
RETURN_TYPES = ("CONDITIONING",) io.String.Input("prompt", multiline=True, dynamic_prompts=True),
FUNCTION = "encode" io.Int.Input(
"image_interleave",
default=2,
min=1,
max=512,
tooltip="How much the image influences things vs the text prompt. Higher number means more influence from the text prompt.",
),
],
outputs=[
io.Conditioning.Output(),
],
)
CATEGORY = "advanced/conditioning" @classmethod
def execute(cls, clip, clip_vision_output, prompt, image_interleave) -> io.NodeOutput:
def encode(self, clip, clip_vision_output, prompt, image_interleave):
tokens = clip.tokenize(prompt, llama_template=PROMPT_TEMPLATE_ENCODE_VIDEO_I2V, image_embeds=clip_vision_output.mm_projected, image_interleave=image_interleave) tokens = clip.tokenize(prompt, llama_template=PROMPT_TEMPLATE_ENCODE_VIDEO_I2V, image_embeds=clip_vision_output.mm_projected, image_interleave=image_interleave)
return (clip.encode_from_tokens_scheduled(tokens), ) return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens))
class HunyuanImageToVideo: encode = execute # TODO: remove
class HunyuanImageToVideo(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": {"positive": ("CONDITIONING", ), return io.Schema(
"vae": ("VAE", ), node_id="HunyuanImageToVideo",
"width": ("INT", {"default": 848, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), category="conditioning/video_models",
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), inputs=[
"length": ("INT", {"default": 53, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}), io.Conditioning.Input("positive"),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), io.Vae.Input("vae"),
"guidance_type": (["v1 (concat)", "v2 (replace)", "custom"], ) io.Int.Input("width", default=848, min=16, max=nodes.MAX_RESOLUTION, step=16),
}, io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
"optional": {"start_image": ("IMAGE", ), io.Int.Input("length", default=53, min=1, max=nodes.MAX_RESOLUTION, step=4),
}} io.Int.Input("batch_size", default=1, min=1, max=4096),
io.Combo.Input("guidance_type", options=["v1 (concat)", "v2 (replace)", "custom"]),
io.Image.Input("start_image", optional=True),
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Latent.Output(display_name="latent"),
],
)
RETURN_TYPES = ("CONDITIONING", "LATENT") @classmethod
RETURN_NAMES = ("positive", "latent") def execute(cls, positive, vae, width, height, length, batch_size, guidance_type, start_image=None) -> io.NodeOutput:
FUNCTION = "encode"
CATEGORY = "conditioning/video_models"
def encode(self, positive, vae, width, height, length, batch_size, guidance_type, start_image=None):
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
out_latent = {} out_latent = {}
@@ -111,51 +145,76 @@ class HunyuanImageToVideo:
positive = node_helpers.conditioning_set_values(positive, cond) positive = node_helpers.conditioning_set_values(positive, cond)
out_latent["samples"] = latent out_latent["samples"] = latent
return (positive, out_latent) return io.NodeOutput(positive, out_latent)
class EmptyHunyuanImageLatent: encode = execute # TODO: remove
class EmptyHunyuanImageLatent(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": { "width": ("INT", {"default": 2048, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}), return io.Schema(
"height": ("INT", {"default": 2048, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}), node_id="EmptyHunyuanImageLatent",
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}} category="latent",
RETURN_TYPES = ("LATENT",) inputs=[
FUNCTION = "generate" io.Int.Input("width", default=2048, min=64, max=nodes.MAX_RESOLUTION, step=32),
io.Int.Input("height", default=2048, min=64, max=nodes.MAX_RESOLUTION, step=32),
io.Int.Input("batch_size", default=1, min=1, max=4096),
],
outputs=[
io.Latent.Output(),
],
)
CATEGORY = "latent" @classmethod
def execute(cls, width, height, batch_size=1) -> io.NodeOutput:
def generate(self, width, height, batch_size=1):
latent = torch.zeros([batch_size, 64, height // 32, width // 32], device=comfy.model_management.intermediate_device()) latent = torch.zeros([batch_size, 64, height // 32, width // 32], device=comfy.model_management.intermediate_device())
return ({"samples":latent}, ) return io.NodeOutput({"samples":latent})
class HunyuanRefinerLatent: generate = execute # TODO: remove
class HunyuanRefinerLatent(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": {"positive": ("CONDITIONING", ), return io.Schema(
"negative": ("CONDITIONING", ), node_id="HunyuanRefinerLatent",
"latent": ("LATENT", ), inputs=[
"noise_augmentation": ("FLOAT", {"default": 0.10, "min": 0.0, "max": 1.0, "step": 0.01}), io.Conditioning.Input("positive"),
}} io.Conditioning.Input("negative"),
io.Latent.Input("latent"),
io.Float.Input("noise_augmentation", default=0.10, min=0.0, max=1.0, step=0.01),
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") ],
RETURN_NAMES = ("positive", "negative", "latent") outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
io.Latent.Output(display_name="latent"),
],
)
FUNCTION = "execute" @classmethod
def execute(cls, positive, negative, latent, noise_augmentation) -> io.NodeOutput:
def execute(self, positive, negative, latent, noise_augmentation):
latent = latent["samples"] latent = latent["samples"]
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": latent, "noise_augmentation": noise_augmentation}) positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": latent, "noise_augmentation": noise_augmentation})
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": latent, "noise_augmentation": noise_augmentation}) negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": latent, "noise_augmentation": noise_augmentation})
out_latent = {} out_latent = {}
out_latent["samples"] = torch.zeros([latent.shape[0], 32, latent.shape[-3], latent.shape[-2], latent.shape[-1]], device=comfy.model_management.intermediate_device()) out_latent["samples"] = torch.zeros([latent.shape[0], 32, latent.shape[-3], latent.shape[-2], latent.shape[-1]], device=comfy.model_management.intermediate_device())
return (positive, negative, out_latent) return io.NodeOutput(positive, negative, out_latent)
NODE_CLASS_MAPPINGS = { class HunyuanExtension(ComfyExtension):
"CLIPTextEncodeHunyuanDiT": CLIPTextEncodeHunyuanDiT, @override
"TextEncodeHunyuanVideo_ImageToVideo": TextEncodeHunyuanVideo_ImageToVideo, async def get_node_list(self) -> list[type[io.ComfyNode]]:
"EmptyHunyuanLatentVideo": EmptyHunyuanLatentVideo, return [
"HunyuanImageToVideo": HunyuanImageToVideo, CLIPTextEncodeHunyuanDiT,
"EmptyHunyuanImageLatent": EmptyHunyuanImageLatent, TextEncodeHunyuanVideo_ImageToVideo,
"HunyuanRefinerLatent": HunyuanRefinerLatent, EmptyHunyuanLatentVideo,
} HunyuanImageToVideo,
EmptyHunyuanImageLatent,
HunyuanRefinerLatent,
]
async def comfy_entrypoint() -> HunyuanExtension:
return HunyuanExtension()
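
Not part of the diff: the latent shapes the defaults above produce, using the same formulas as `EmptyHunyuanLatentVideo` and `EmptyHunyuanImageLatent`.

```python
width, height, length, batch_size = 848, 480, 25, 1
video_latent = [batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8]
# -> [1, 16, 7, 60, 106]: 16 channels, 7 latent frames, 8x spatial downscale

width, height = 2048, 2048
image_latent = [batch_size, 64, height // 32, width // 32]
# -> [1, 64, 64, 64]: 64 channels, 32x spatial downscale
```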

View File

@@ -2,6 +2,8 @@ import comfy.utils
import comfy_extras.nodes_post_processing import comfy_extras.nodes_post_processing
import torch import torch
import nodes import nodes
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
def reshape_latent_to(target_shape, latent, repeat_batch=True): def reshape_latent_to(target_shape, latent, repeat_batch=True):
@@ -13,17 +15,23 @@ def reshape_latent_to(target_shape, latent, repeat_batch=True):
return latent return latent
class LatentAdd: class LatentAdd(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": { "samples1": ("LATENT",), "samples2": ("LATENT",)}} return io.Schema(
node_id="LatentAdd",
category="latent/advanced",
inputs=[
io.Latent.Input("samples1"),
io.Latent.Input("samples2"),
],
outputs=[
io.Latent.Output(),
],
)
RETURN_TYPES = ("LATENT",) @classmethod
FUNCTION = "op" def execute(cls, samples1, samples2) -> io.NodeOutput:
CATEGORY = "latent/advanced"
def op(self, samples1, samples2):
samples_out = samples1.copy() samples_out = samples1.copy()
s1 = samples1["samples"] s1 = samples1["samples"]
@@ -31,19 +39,25 @@ class LatentAdd:
s2 = reshape_latent_to(s1.shape, s2) s2 = reshape_latent_to(s1.shape, s2)
samples_out["samples"] = s1 + s2 samples_out["samples"] = s1 + s2
return (samples_out,) return io.NodeOutput(samples_out)
class LatentSubtract: class LatentSubtract(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": { "samples1": ("LATENT",), "samples2": ("LATENT",)}} return io.Schema(
node_id="LatentSubtract",
category="latent/advanced",
inputs=[
io.Latent.Input("samples1"),
io.Latent.Input("samples2"),
],
outputs=[
io.Latent.Output(),
],
)
RETURN_TYPES = ("LATENT",) @classmethod
FUNCTION = "op" def execute(cls, samples1, samples2) -> io.NodeOutput:
CATEGORY = "latent/advanced"
def op(self, samples1, samples2):
samples_out = samples1.copy() samples_out = samples1.copy()
s1 = samples1["samples"] s1 = samples1["samples"]
@@ -51,41 +65,49 @@ class LatentSubtract:
s2 = reshape_latent_to(s1.shape, s2) s2 = reshape_latent_to(s1.shape, s2)
samples_out["samples"] = s1 - s2 samples_out["samples"] = s1 - s2
return (samples_out,) return io.NodeOutput(samples_out)
class LatentMultiply: class LatentMultiply(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": { "samples": ("LATENT",), return io.Schema(
"multiplier": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}), node_id="LatentMultiply",
}} category="latent/advanced",
inputs=[
io.Latent.Input("samples"),
io.Float.Input("multiplier", default=1.0, min=-10.0, max=10.0, step=0.01),
],
outputs=[
io.Latent.Output(),
],
)
RETURN_TYPES = ("LATENT",) @classmethod
FUNCTION = "op" def execute(cls, samples, multiplier) -> io.NodeOutput:
CATEGORY = "latent/advanced"
def op(self, samples, multiplier):
samples_out = samples.copy() samples_out = samples.copy()
s1 = samples["samples"] s1 = samples["samples"]
samples_out["samples"] = s1 * multiplier samples_out["samples"] = s1 * multiplier
return (samples_out,) return io.NodeOutput(samples_out)
class LatentInterpolate: class LatentInterpolate(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": { "samples1": ("LATENT",), return io.Schema(
"samples2": ("LATENT",), node_id="LatentInterpolate",
"ratio": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), category="latent/advanced",
}} inputs=[
io.Latent.Input("samples1"),
io.Latent.Input("samples2"),
io.Float.Input("ratio", default=1.0, min=0.0, max=1.0, step=0.01),
],
outputs=[
io.Latent.Output(),
],
)
RETURN_TYPES = ("LATENT",) @classmethod
FUNCTION = "op" def execute(cls, samples1, samples2, ratio) -> io.NodeOutput:
CATEGORY = "latent/advanced"
def op(self, samples1, samples2, ratio):
samples_out = samples1.copy() samples_out = samples1.copy()
s1 = samples1["samples"] s1 = samples1["samples"]
@@ -104,19 +126,26 @@ class LatentInterpolate:
st = torch.nan_to_num(t / mt) st = torch.nan_to_num(t / mt)
samples_out["samples"] = st * (m1 * ratio + m2 * (1.0 - ratio)) samples_out["samples"] = st * (m1 * ratio + m2 * (1.0 - ratio))
return (samples_out,) return io.NodeOutput(samples_out)
class LatentConcat: class LatentConcat(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": { "samples1": ("LATENT",), "samples2": ("LATENT",), "dim": (["x", "-x", "y", "-y", "t", "-t"], )}} return io.Schema(
node_id="LatentConcat",
category="latent/advanced",
inputs=[
io.Latent.Input("samples1"),
io.Latent.Input("samples2"),
io.Combo.Input("dim", options=["x", "-x", "y", "-y", "t", "-t"]),
],
outputs=[
io.Latent.Output(),
],
)
RETURN_TYPES = ("LATENT",) @classmethod
FUNCTION = "op" def execute(cls, samples1, samples2, dim) -> io.NodeOutput:
CATEGORY = "latent/advanced"
def op(self, samples1, samples2, dim):
samples_out = samples1.copy() samples_out = samples1.copy()
s1 = samples1["samples"] s1 = samples1["samples"]
@@ -136,22 +165,27 @@ class LatentConcat:
dim = -3 dim = -3
samples_out["samples"] = torch.cat(c, dim=dim) samples_out["samples"] = torch.cat(c, dim=dim)
return (samples_out,) return io.NodeOutput(samples_out)
class LatentCut: class LatentCut(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": {"samples": ("LATENT",), return io.Schema(
"dim": (["x", "y", "t"], ), node_id="LatentCut",
"index": ("INT", {"default": 0, "min": -nodes.MAX_RESOLUTION, "max": nodes.MAX_RESOLUTION, "step": 1}), category="latent/advanced",
"amount": ("INT", {"default": 1, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 1})}} inputs=[
io.Latent.Input("samples"),
io.Combo.Input("dim", options=["x", "y", "t"]),
io.Int.Input("index", default=0, min=-nodes.MAX_RESOLUTION, max=nodes.MAX_RESOLUTION, step=1),
io.Int.Input("amount", default=1, min=1, max=nodes.MAX_RESOLUTION, step=1),
],
outputs=[
io.Latent.Output(),
],
)
RETURN_TYPES = ("LATENT",) @classmethod
FUNCTION = "op" def execute(cls, samples, dim, index, amount) -> io.NodeOutput:
CATEGORY = "latent/advanced"
def op(self, samples, dim, index, amount):
samples_out = samples.copy() samples_out = samples.copy()
s1 = samples["samples"] s1 = samples["samples"]
@@ -171,19 +205,25 @@ class LatentCut:
amount = min(-index, amount) amount = min(-index, amount)
samples_out["samples"] = torch.narrow(s1, dim, index, amount) samples_out["samples"] = torch.narrow(s1, dim, index, amount)
return (samples_out,) return io.NodeOutput(samples_out)
class LatentBatch: class LatentBatch(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": { "samples1": ("LATENT",), "samples2": ("LATENT",)}} return io.Schema(
node_id="LatentBatch",
category="latent/batch",
inputs=[
io.Latent.Input("samples1"),
io.Latent.Input("samples2"),
],
outputs=[
io.Latent.Output(),
],
)
RETURN_TYPES = ("LATENT",) @classmethod
FUNCTION = "batch" def execute(cls, samples1, samples2) -> io.NodeOutput:
CATEGORY = "latent/batch"
def batch(self, samples1, samples2):
samples_out = samples1.copy() samples_out = samples1.copy()
s1 = samples1["samples"] s1 = samples1["samples"]
s2 = samples2["samples"] s2 = samples2["samples"]
@@ -192,20 +232,25 @@ class LatentBatch:
s = torch.cat((s1, s2), dim=0) s = torch.cat((s1, s2), dim=0)
samples_out["samples"] = s samples_out["samples"] = s
samples_out["batch_index"] = samples1.get("batch_index", [x for x in range(0, s1.shape[0])]) + samples2.get("batch_index", [x for x in range(0, s2.shape[0])]) samples_out["batch_index"] = samples1.get("batch_index", [x for x in range(0, s1.shape[0])]) + samples2.get("batch_index", [x for x in range(0, s2.shape[0])])
return (samples_out,) return io.NodeOutput(samples_out)
class LatentBatchSeedBehavior: class LatentBatchSeedBehavior(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": { "samples": ("LATENT",), return io.Schema(
"seed_behavior": (["random", "fixed"],{"default": "fixed"}),}} node_id="LatentBatchSeedBehavior",
category="latent/advanced",
inputs=[
io.Latent.Input("samples"),
io.Combo.Input("seed_behavior", options=["random", "fixed"], default="fixed"),
],
outputs=[
io.Latent.Output(),
],
)
RETURN_TYPES = ("LATENT",) @classmethod
FUNCTION = "op" def execute(cls, samples, seed_behavior) -> io.NodeOutput:
CATEGORY = "latent/advanced"
def op(self, samples, seed_behavior):
samples_out = samples.copy() samples_out = samples.copy()
latent = samples["samples"] latent = samples["samples"]
if seed_behavior == "random": if seed_behavior == "random":
@@ -215,41 +260,50 @@ class LatentBatchSeedBehavior:
batch_number = samples_out.get("batch_index", [0])[0] batch_number = samples_out.get("batch_index", [0])[0]
samples_out["batch_index"] = [batch_number] * latent.shape[0] samples_out["batch_index"] = [batch_number] * latent.shape[0]
return (samples_out,) return io.NodeOutput(samples_out)
class LatentApplyOperation: class LatentApplyOperation(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": { "samples": ("LATENT",), return io.Schema(
"operation": ("LATENT_OPERATION",), node_id="LatentApplyOperation",
}} category="latent/advanced/operations",
is_experimental=True,
inputs=[
io.Latent.Input("samples"),
io.LatentOperation.Input("operation"),
],
outputs=[
io.Latent.Output(),
],
)
RETURN_TYPES = ("LATENT",) @classmethod
FUNCTION = "op" def execute(cls, samples, operation) -> io.NodeOutput:
CATEGORY = "latent/advanced/operations"
EXPERIMENTAL = True
def op(self, samples, operation):
samples_out = samples.copy() samples_out = samples.copy()
s1 = samples["samples"] s1 = samples["samples"]
samples_out["samples"] = operation(latent=s1) samples_out["samples"] = operation(latent=s1)
return (samples_out,) return io.NodeOutput(samples_out)
class LatentApplyOperationCFG: class LatentApplyOperationCFG(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": { "model": ("MODEL",), return io.Schema(
"operation": ("LATENT_OPERATION",), node_id="LatentApplyOperationCFG",
}} category="latent/advanced/operations",
RETURN_TYPES = ("MODEL",) is_experimental=True,
FUNCTION = "patch" inputs=[
io.Model.Input("model"),
io.LatentOperation.Input("operation"),
],
outputs=[
io.Model.Output(),
],
)
CATEGORY = "latent/advanced/operations" @classmethod
EXPERIMENTAL = True def execute(cls, model, operation) -> io.NodeOutput:
def patch(self, model, operation):
m = model.clone() m = model.clone()
def pre_cfg_function(args): def pre_cfg_function(args):
@@ -261,21 +315,25 @@ class LatentApplyOperationCFG:
return conds_out return conds_out
m.set_model_sampler_pre_cfg_function(pre_cfg_function) m.set_model_sampler_pre_cfg_function(pre_cfg_function)
return (m, ) return io.NodeOutput(m)
class LatentOperationTonemapReinhard: class LatentOperationTonemapReinhard(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": { "multiplier": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01}), return io.Schema(
}} node_id="LatentOperationTonemapReinhard",
category="latent/advanced/operations",
is_experimental=True,
inputs=[
io.Float.Input("multiplier", default=1.0, min=0.0, max=100.0, step=0.01),
],
outputs=[
io.LatentOperation.Output(),
],
)
RETURN_TYPES = ("LATENT_OPERATION",) @classmethod
FUNCTION = "op" def execute(cls, multiplier) -> io.NodeOutput:
CATEGORY = "latent/advanced/operations"
EXPERIMENTAL = True
def op(self, multiplier):
def tonemap_reinhard(latent, **kwargs): def tonemap_reinhard(latent, **kwargs):
latent_vector_magnitude = (torch.linalg.vector_norm(latent, dim=(1)) + 0.0000000001)[:,None] latent_vector_magnitude = (torch.linalg.vector_norm(latent, dim=(1)) + 0.0000000001)[:,None]
normalized_latent = latent / latent_vector_magnitude normalized_latent = latent / latent_vector_magnitude
@@ -291,39 +349,27 @@ class LatentOperationTonemapReinhard:
new_magnitude *= top new_magnitude *= top
return normalized_latent * new_magnitude return normalized_latent * new_magnitude
return (tonemap_reinhard,) return io.NodeOutput(tonemap_reinhard)
class LatentOperationSharpen: class LatentOperationSharpen(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": { return io.Schema(
"sharpen_radius": ("INT", { node_id="LatentOperationSharpen",
"default": 9, category="latent/advanced/operations",
"min": 1, is_experimental=True,
"max": 31, inputs=[
"step": 1 io.Int.Input("sharpen_radius", default=9, min=1, max=31, step=1),
}), io.Float.Input("sigma", default=1.0, min=0.1, max=10.0, step=0.1),
"sigma": ("FLOAT", { io.Float.Input("alpha", default=0.1, min=0.0, max=5.0, step=0.01),
"default": 1.0, ],
"min": 0.1, outputs=[
"max": 10.0, io.LatentOperation.Output(),
"step": 0.1 ],
}), )
"alpha": ("FLOAT", {
"default": 0.1,
"min": 0.0,
"max": 5.0,
"step": 0.01
}),
}}
RETURN_TYPES = ("LATENT_OPERATION",) @classmethod
FUNCTION = "op" def execute(cls, sharpen_radius, sigma, alpha) -> io.NodeOutput:
CATEGORY = "latent/advanced/operations"
EXPERIMENTAL = True
def op(self, sharpen_radius, sigma, alpha):
def sharpen(latent, **kwargs): def sharpen(latent, **kwargs):
luminance = (torch.linalg.vector_norm(latent, dim=(1)) + 1e-6)[:,None] luminance = (torch.linalg.vector_norm(latent, dim=(1)) + 1e-6)[:,None]
normalized_latent = latent / luminance normalized_latent = latent / luminance
@@ -340,19 +386,27 @@ class LatentOperationSharpen:
sharpened = torch.nn.functional.conv2d(padded_image, kernel.repeat(channels, 1, 1).unsqueeze(1), padding=kernel_size // 2, groups=channels)[:,:,sharpen_radius:-sharpen_radius, sharpen_radius:-sharpen_radius] sharpened = torch.nn.functional.conv2d(padded_image, kernel.repeat(channels, 1, 1).unsqueeze(1), padding=kernel_size // 2, groups=channels)[:,:,sharpen_radius:-sharpen_radius, sharpen_radius:-sharpen_radius]
return luminance * sharpened return luminance * sharpened
return (sharpen,) return io.NodeOutput(sharpen)
NODE_CLASS_MAPPINGS = {
"LatentAdd": LatentAdd, class LatentExtension(ComfyExtension):
"LatentSubtract": LatentSubtract, @override
"LatentMultiply": LatentMultiply, async def get_node_list(self) -> list[type[io.ComfyNode]]:
"LatentInterpolate": LatentInterpolate, return [
"LatentConcat": LatentConcat, LatentAdd,
"LatentCut": LatentCut, LatentSubtract,
"LatentBatch": LatentBatch, LatentMultiply,
"LatentBatchSeedBehavior": LatentBatchSeedBehavior, LatentInterpolate,
"LatentApplyOperation": LatentApplyOperation, LatentConcat,
"LatentApplyOperationCFG": LatentApplyOperationCFG, LatentCut,
"LatentOperationTonemapReinhard": LatentOperationTonemapReinhard, LatentBatch,
"LatentOperationSharpen": LatentOperationSharpen, LatentBatchSeedBehavior,
} LatentApplyOperation,
LatentApplyOperationCFG,
LatentOperationTonemapReinhard,
LatentOperationSharpen,
]
async def comfy_entrypoint() -> LatentExtension:
return LatentExtension()
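
The `LatentInterpolate` hunk above only shows the tail of the math, so as a reconstruction (an assumption, not verbatim repo code): the node blends the per-pixel *directions* of the two latents and then rescales to an interpolated magnitude, rather than doing a plain lerp.

```python
import torch

def latent_interpolate(s1: torch.Tensor, s2: torch.Tensor, ratio: float) -> torch.Tensor:
    # s1, s2: [B, C, H, W]; norms are taken over the channel dimension
    m1 = torch.linalg.vector_norm(s1, dim=1, keepdim=True)
    m2 = torch.linalg.vector_norm(s2, dim=1, keepdim=True)
    d1 = torch.nan_to_num(s1 / m1)
    d2 = torch.nan_to_num(s2 / m2)
    t = d1 * ratio + d2 * (1.0 - ratio)
    mt = torch.linalg.vector_norm(t, dim=1, keepdim=True)
    st = torch.nan_to_num(t / mt)
    return st * (m1 * ratio + m2 * (1.0 - ratio))
```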

View File

@@ -5,6 +5,8 @@ import folder_paths
import os import os
import logging import logging
from enum import Enum from enum import Enum
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
CLAMP_QUANTILE = 0.99 CLAMP_QUANTILE = 0.99
@@ -71,32 +73,40 @@ def calc_lora_model(model_diff, rank, prefix_model, prefix_lora, output_sd, lora
output_sd["{}{}.diff_b".format(prefix_lora, k[len(prefix_model):-5])] = sd[k].contiguous().half().cpu() output_sd["{}{}.diff_b".format(prefix_lora, k[len(prefix_model):-5])] = sd[k].contiguous().half().cpu()
return output_sd return output_sd
class LoraSave: class LoraSave(io.ComfyNode):
def __init__(self): @classmethod
self.output_dir = folder_paths.get_output_directory() def define_schema(cls):
return io.Schema(
node_id="LoraSave",
display_name="Extract and Save Lora",
category="_for_testing",
inputs=[
io.String.Input("filename_prefix", default="loras/ComfyUI_extracted_lora"),
io.Int.Input("rank", default=8, min=1, max=4096, step=1),
io.Combo.Input("lora_type", options=tuple(LORA_TYPES.keys())),
io.Boolean.Input("bias_diff", default=True),
io.Model.Input(
"model_diff",
tooltip="The ModelSubtract output to be converted to a lora.",
optional=True,
),
io.Clip.Input(
"text_encoder_diff",
tooltip="The CLIPSubtract output to be converted to a lora.",
optional=True,
),
],
is_experimental=True,
is_output_node=True,
)
@classmethod @classmethod
def INPUT_TYPES(s): def execute(cls, filename_prefix, rank, lora_type, bias_diff, model_diff=None, text_encoder_diff=None) -> io.NodeOutput:
return {"required": {"filename_prefix": ("STRING", {"default": "loras/ComfyUI_extracted_lora"}),
"rank": ("INT", {"default": 8, "min": 1, "max": 4096, "step": 1}),
"lora_type": (tuple(LORA_TYPES.keys()),),
"bias_diff": ("BOOLEAN", {"default": True}),
},
"optional": {"model_diff": ("MODEL", {"tooltip": "The ModelSubtract output to be converted to a lora."}),
"text_encoder_diff": ("CLIP", {"tooltip": "The CLIPSubtract output to be converted to a lora."})},
}
RETURN_TYPES = ()
FUNCTION = "save"
OUTPUT_NODE = True
CATEGORY = "_for_testing"
def save(self, filename_prefix, rank, lora_type, bias_diff, model_diff=None, text_encoder_diff=None):
if model_diff is None and text_encoder_diff is None: if model_diff is None and text_encoder_diff is None:
return {} return io.NodeOutput()
lora_type = LORA_TYPES.get(lora_type) lora_type = LORA_TYPES.get(lora_type)
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir) full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, folder_paths.get_output_directory())
output_sd = {} output_sd = {}
if model_diff is not None: if model_diff is not None:
@ -108,12 +118,16 @@ class LoraSave:
output_checkpoint = os.path.join(full_output_folder, output_checkpoint) output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
comfy.utils.save_torch_file(output_sd, output_checkpoint, metadata=None) comfy.utils.save_torch_file(output_sd, output_checkpoint, metadata=None)
return {} return io.NodeOutput()
NODE_CLASS_MAPPINGS = {
"LoraSave": LoraSave
}
NODE_DISPLAY_NAME_MAPPINGS = { class LoraSaveExtension(ComfyExtension):
"LoraSave": "Extract and Save Lora" @override
} async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
LoraSave,
]
async def comfy_entrypoint() -> LoraSaveExtension:
return LoraSaveExtension()

View File

@ -34,6 +34,7 @@ class EmptyLTXVLatentVideo(io.ComfyNode):
latent = torch.zeros([batch_size, 128, ((length - 1) // 8) + 1, height // 32, width // 32], device=comfy.model_management.intermediate_device()) latent = torch.zeros([batch_size, 128, ((length - 1) // 8) + 1, height // 32, width // 32], device=comfy.model_management.intermediate_device())
return io.NodeOutput({"samples": latent}) return io.NodeOutput({"samples": latent})
generate = execute # TODO: remove
class LTXVImgToVideo(io.ComfyNode): class LTXVImgToVideo(io.ComfyNode):
@classmethod @classmethod
@ -77,6 +78,8 @@ class LTXVImgToVideo(io.ComfyNode):
return io.NodeOutput(positive, negative, {"samples": latent, "noise_mask": conditioning_latent_frames_mask}) return io.NodeOutput(positive, negative, {"samples": latent, "noise_mask": conditioning_latent_frames_mask})
generate = execute # TODO: remove
def conditioning_get_any_value(conditioning, key, default=None): def conditioning_get_any_value(conditioning, key, default=None):
for t in conditioning: for t in conditioning:
@ -264,6 +267,8 @@ class LTXVAddGuide(io.ComfyNode):
return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask}) return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask})
generate = execute # TODO: remove
class LTXVCropGuides(io.ComfyNode): class LTXVCropGuides(io.ComfyNode):
@classmethod @classmethod
@ -300,6 +305,8 @@ class LTXVCropGuides(io.ComfyNode):
return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask}) return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask})
crop = execute # TODO: remove
class LTXVConditioning(io.ComfyNode): class LTXVConditioning(io.ComfyNode):
@classmethod @classmethod
@ -498,6 +505,7 @@ class LTXVPreprocess(io.ComfyNode):
output_images.append(preprocess(image[i], img_compression)) output_images.append(preprocess(image[i], img_compression))
return io.NodeOutput(torch.stack(output_images)) return io.NodeOutput(torch.stack(output_images))
preprocess = execute # TODO: remove
class LtxvExtension(ComfyExtension): class LtxvExtension(ComfyExtension):
@override @override

View File

@ -1,24 +1,33 @@
from typing_extensions import override
import comfy.utils import comfy.utils
from comfy_api.latest import ComfyExtension, io
class PatchModelAddDownscale:
upscale_methods = ["bicubic", "nearest-exact", "bilinear", "area", "bislerp"] class PatchModelAddDownscale(io.ComfyNode):
UPSCALE_METHODS = ["bicubic", "nearest-exact", "bilinear", "area", "bislerp"]
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": { "model": ("MODEL",), return io.Schema(
"block_number": ("INT", {"default": 3, "min": 1, "max": 32, "step": 1}), node_id="PatchModelAddDownscale",
"downscale_factor": ("FLOAT", {"default": 2.0, "min": 0.1, "max": 9.0, "step": 0.001}), display_name="PatchModelAddDownscale (Kohya Deep Shrink)",
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), category="model_patches/unet",
"end_percent": ("FLOAT", {"default": 0.35, "min": 0.0, "max": 1.0, "step": 0.001}), inputs=[
"downscale_after_skip": ("BOOLEAN", {"default": True}), io.Model.Input("model"),
"downscale_method": (s.upscale_methods,), io.Int.Input("block_number", default=3, min=1, max=32, step=1),
"upscale_method": (s.upscale_methods,), io.Float.Input("downscale_factor", default=2.0, min=0.1, max=9.0, step=0.001),
}} io.Float.Input("start_percent", default=0.0, min=0.0, max=1.0, step=0.001),
RETURN_TYPES = ("MODEL",) io.Float.Input("end_percent", default=0.35, min=0.0, max=1.0, step=0.001),
FUNCTION = "patch" io.Boolean.Input("downscale_after_skip", default=True),
io.Combo.Input("downscale_method", options=cls.UPSCALE_METHODS),
io.Combo.Input("upscale_method", options=cls.UPSCALE_METHODS),
],
outputs=[
io.Model.Output(),
],
)
CATEGORY = "model_patches/unet" @classmethod
def execute(cls, model, block_number, downscale_factor, start_percent, end_percent, downscale_after_skip, downscale_method, upscale_method) -> io.NodeOutput:
def patch(self, model, block_number, downscale_factor, start_percent, end_percent, downscale_after_skip, downscale_method, upscale_method):
model_sampling = model.get_model_object("model_sampling") model_sampling = model.get_model_object("model_sampling")
sigma_start = model_sampling.percent_to_sigma(start_percent) sigma_start = model_sampling.percent_to_sigma(start_percent)
sigma_end = model_sampling.percent_to_sigma(end_percent) sigma_end = model_sampling.percent_to_sigma(end_percent)
@ -41,13 +50,21 @@ class PatchModelAddDownscale:
else: else:
m.set_model_input_block_patch(input_block_patch) m.set_model_input_block_patch(input_block_patch)
m.set_model_output_block_patch(output_block_patch) m.set_model_output_block_patch(output_block_patch)
return (m, ) return io.NodeOutput(m)
NODE_CLASS_MAPPINGS = {
"PatchModelAddDownscale": PatchModelAddDownscale,
}
NODE_DISPLAY_NAME_MAPPINGS = { NODE_DISPLAY_NAME_MAPPINGS = {
# Sampling # Sampling
"PatchModelAddDownscale": "PatchModelAddDownscale (Kohya Deep Shrink)", "PatchModelAddDownscale": "",
} }
class ModelDownscaleExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
PatchModelAddDownscale,
]
async def comfy_entrypoint() -> ModelDownscaleExtension:
return ModelDownscaleExtension()

View File

@ -25,7 +25,7 @@ class PreviewAny():
value = str(source) value = str(source)
elif source is not None: elif source is not None:
try: try:
value = json.dumps(source) value = json.dumps(source, indent=4)
except Exception: except Exception:
try: try:
value = str(source) value = str(source)
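The only functional change here is that PreviewAny now pretty-prints JSON. A quick sketch of the difference (the data value is a made-up example):

```python
import json

data = {"steps": 20, "sampler": "euler"}
print(json.dumps(data))            # {"steps": 20, "sampler": "euler"}
print(json.dumps(data, indent=4))  # same data, one key per line with 4-space indentation
```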

View File

@ -3,64 +3,83 @@ import comfy.sd
import comfy.model_management import comfy.model_management
import nodes import nodes
import torch import torch
import comfy_extras.nodes_slg from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
from comfy_extras.nodes_slg import SkipLayerGuidanceDiT
class TripleCLIPLoader: class TripleCLIPLoader(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ), "clip_name2": (folder_paths.get_filename_list("text_encoders"), ), "clip_name3": (folder_paths.get_filename_list("text_encoders"), ) return io.Schema(
}} node_id="TripleCLIPLoader",
RETURN_TYPES = ("CLIP",) category="advanced/loaders",
FUNCTION = "load_clip" description="[Recipes]\n\nsd3: clip-l, clip-g, t5",
inputs=[
io.Combo.Input("clip_name1", options=folder_paths.get_filename_list("text_encoders")),
io.Combo.Input("clip_name2", options=folder_paths.get_filename_list("text_encoders")),
io.Combo.Input("clip_name3", options=folder_paths.get_filename_list("text_encoders")),
],
outputs=[
io.Clip.Output(),
],
)
CATEGORY = "advanced/loaders" @classmethod
def execute(cls, clip_name1, clip_name2, clip_name3) -> io.NodeOutput:
DESCRIPTION = "[Recipes]\n\nsd3: clip-l, clip-g, t5"
def load_clip(self, clip_name1, clip_name2, clip_name3):
clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", clip_name1) clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", clip_name1)
clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", clip_name2) clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", clip_name2)
clip_path3 = folder_paths.get_full_path_or_raise("text_encoders", clip_name3) clip_path3 = folder_paths.get_full_path_or_raise("text_encoders", clip_name3)
clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2, clip_path3], embedding_directory=folder_paths.get_folder_paths("embeddings")) clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2, clip_path3], embedding_directory=folder_paths.get_folder_paths("embeddings"))
return (clip,) return io.NodeOutput(clip)
load_clip = execute # TODO: remove
class EmptySD3LatentImage: class EmptySD3LatentImage(io.ComfyNode):
def __init__(self): @classmethod
self.device = comfy.model_management.intermediate_device() def define_schema(cls):
return io.Schema(
node_id="EmptySD3LatentImage",
category="latent/sd3",
inputs=[
io.Int.Input("width", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("height", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("batch_size", default=1, min=1, max=4096),
],
outputs=[
io.Latent.Output(),
],
)
@classmethod @classmethod
def INPUT_TYPES(s): def execute(cls, width, height, batch_size=1) -> io.NodeOutput:
return {"required": { "width": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), latent = torch.zeros([batch_size, 16, height // 8, width // 8], device=comfy.model_management.intermediate_device())
"height": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), return io.NodeOutput({"samples":latent})
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "generate"
CATEGORY = "latent/sd3" generate = execute # TODO: remove
def generate(self, width, height, batch_size=1):
latent = torch.zeros([batch_size, 16, height // 8, width // 8], device=self.device)
return ({"samples":latent}, )
class CLIPTextEncodeSD3: class CLIPTextEncodeSD3(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": { return io.Schema(
"clip": ("CLIP", ), node_id="CLIPTextEncodeSD3",
"clip_l": ("STRING", {"multiline": True, "dynamicPrompts": True}), category="advanced/conditioning",
"clip_g": ("STRING", {"multiline": True, "dynamicPrompts": True}), inputs=[
"t5xxl": ("STRING", {"multiline": True, "dynamicPrompts": True}), io.Clip.Input("clip"),
"empty_padding": (["none", "empty_prompt"], ) io.String.Input("clip_l", multiline=True, dynamic_prompts=True),
}} io.String.Input("clip_g", multiline=True, dynamic_prompts=True),
RETURN_TYPES = ("CONDITIONING",) io.String.Input("t5xxl", multiline=True, dynamic_prompts=True),
FUNCTION = "encode" io.Combo.Input("empty_padding", options=["none", "empty_prompt"]),
],
outputs=[
io.Conditioning.Output(),
],
)
CATEGORY = "advanced/conditioning" @classmethod
def execute(cls, clip, clip_l, clip_g, t5xxl, empty_padding) -> io.NodeOutput:
def encode(self, clip, clip_l, clip_g, t5xxl, empty_padding):
no_padding = empty_padding == "none" no_padding = empty_padding == "none"
tokens = clip.tokenize(clip_g) tokens = clip.tokenize(clip_g)
@ -82,57 +101,112 @@ class CLIPTextEncodeSD3:
tokens["l"] += empty["l"] tokens["l"] += empty["l"]
while len(tokens["l"]) > len(tokens["g"]): while len(tokens["l"]) > len(tokens["g"]):
tokens["g"] += empty["g"] tokens["g"] += empty["g"]
return (clip.encode_from_tokens_scheduled(tokens), ) return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens))
encode = execute # TODO: remove
class ControlNetApplySD3(nodes.ControlNetApplyAdvanced): class ControlNetApplySD3(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls) -> io.Schema:
return {"required": {"positive": ("CONDITIONING", ), return io.Schema(
"negative": ("CONDITIONING", ), node_id="ControlNetApplySD3",
"control_net": ("CONTROL_NET", ), display_name="Apply Controlnet with VAE",
"vae": ("VAE", ), category="conditioning/controlnet",
"image": ("IMAGE", ), inputs=[
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), io.Conditioning.Input("positive"),
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), io.Conditioning.Input("negative"),
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}) io.ControlNet.Input("control_net"),
}} io.Vae.Input("vae"),
CATEGORY = "conditioning/controlnet" io.Image.Input("image"),
DEPRECATED = True io.Float.Input("strength", default=1.0, min=0.0, max=10.0, step=0.01),
io.Float.Input("start_percent", default=0.0, min=0.0, max=1.0, step=0.001),
io.Float.Input("end_percent", default=1.0, min=0.0, max=1.0, step=0.001),
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
],
is_deprecated=True,
)
@classmethod
def execute(cls, positive, negative, control_net, image, strength, start_percent, end_percent, vae=None) -> io.NodeOutput:
if strength == 0:
return io.NodeOutput(positive, negative)
control_hint = image.movedim(-1, 1)
cnets = {}
out = []
for conditioning in [positive, negative]:
c = []
for t in conditioning:
d = t[1].copy()
prev_cnet = d.get('control', None)
if prev_cnet in cnets:
c_net = cnets[prev_cnet]
else:
c_net = control_net.copy().set_cond_hint(control_hint, strength, (start_percent, end_percent),
vae=vae, extra_concat=[])
c_net.set_previous_controlnet(prev_cnet)
cnets[prev_cnet] = c_net
d['control'] = c_net
d['control_apply_to_uncond'] = False
n = [t[0], d]
c.append(n)
out.append(c)
return io.NodeOutput(out[0], out[1])
apply_controlnet = execute # TODO: remove
class SkipLayerGuidanceSD3(comfy_extras.nodes_slg.SkipLayerGuidanceDiT): class SkipLayerGuidanceSD3(io.ComfyNode):
''' '''
Enhance guidance towards detailed structure by having another set of CFG negative with skipped layers. Enhance guidance towards detailed structure by having another set of CFG negative with skipped layers.
Inspired by Perturbed Attention Guidance (https://arxiv.org/abs/2403.17377) Inspired by Perturbed Attention Guidance (https://arxiv.org/abs/2403.17377)
Experimental implementation by Dango233@StabilityAI. Experimental implementation by Dango233@StabilityAI.
''' '''
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": {"model": ("MODEL", ), return io.Schema(
"layers": ("STRING", {"default": "7, 8, 9", "multiline": False}), node_id="SkipLayerGuidanceSD3",
"scale": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 10.0, "step": 0.1}), category="advanced/guidance",
"start_percent": ("FLOAT", {"default": 0.01, "min": 0.0, "max": 1.0, "step": 0.001}), description="Generic version of SkipLayerGuidance node that can be used on every DiT model.",
"end_percent": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.001}) inputs=[
}} io.Model.Input("model"),
RETURN_TYPES = ("MODEL",) io.String.Input("layers", default="7, 8, 9", multiline=False),
FUNCTION = "skip_guidance_sd3" io.Float.Input("scale", default=3.0, min=0.0, max=10.0, step=0.1),
io.Float.Input("start_percent", default=0.01, min=0.0, max=1.0, step=0.001),
io.Float.Input("end_percent", default=0.15, min=0.0, max=1.0, step=0.001),
],
outputs=[
io.Model.Output(),
],
is_experimental=True,
)
CATEGORY = "advanced/guidance" @classmethod
def execute(cls, model, layers, scale, start_percent, end_percent) -> io.NodeOutput:
return SkipLayerGuidanceDiT().execute(model=model, scale=scale, start_percent=start_percent, end_percent=end_percent, double_layers=layers)
def skip_guidance_sd3(self, model, layers, scale, start_percent, end_percent): skip_guidance_sd3 = execute # TODO: remove
return self.skip_guidance(model=model, scale=scale, start_percent=start_percent, end_percent=end_percent, double_layers=layers)
NODE_CLASS_MAPPINGS = { class SD3Extension(ComfyExtension):
"TripleCLIPLoader": TripleCLIPLoader, @override
"EmptySD3LatentImage": EmptySD3LatentImage, async def get_node_list(self) -> list[type[io.ComfyNode]]:
"CLIPTextEncodeSD3": CLIPTextEncodeSD3, return [
"ControlNetApplySD3": ControlNetApplySD3, TripleCLIPLoader,
"SkipLayerGuidanceSD3": SkipLayerGuidanceSD3, EmptySD3LatentImage,
} CLIPTextEncodeSD3,
ControlNetApplySD3,
SkipLayerGuidanceSD3,
]
NODE_DISPLAY_NAME_MAPPINGS = {
# Sampling async def comfy_entrypoint() -> SD3Extension:
"ControlNetApplySD3": "Apply Controlnet with VAE", return SD3Extension()
}

View File

@ -1,33 +1,40 @@
import comfy.model_patcher import comfy.model_patcher
import comfy.samplers import comfy.samplers
import re import re
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
class SkipLayerGuidanceDiT: class SkipLayerGuidanceDiT(io.ComfyNode):
''' '''
Enhance guidance towards detailed structure by having another set of CFG negative with skipped layers. Enhance guidance towards detailed structure by having another set of CFG negative with skipped layers.
Inspired by Perturbed Attention Guidance (https://arxiv.org/abs/2403.17377) Inspired by Perturbed Attention Guidance (https://arxiv.org/abs/2403.17377)
Original experimental implementation for SD3 by Dango233@StabilityAI. Original experimental implementation for SD3 by Dango233@StabilityAI.
''' '''
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": {"model": ("MODEL", ), return io.Schema(
"double_layers": ("STRING", {"default": "7, 8, 9", "multiline": False}), node_id="SkipLayerGuidanceDiT",
"single_layers": ("STRING", {"default": "7, 8, 9", "multiline": False}), category="advanced/guidance",
"scale": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 10.0, "step": 0.1}), description="Generic version of SkipLayerGuidance node that can be used on every DiT model.",
"start_percent": ("FLOAT", {"default": 0.01, "min": 0.0, "max": 1.0, "step": 0.001}), is_experimental=True,
"end_percent": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.001}), inputs=[
"rescaling_scale": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 10.0, "step": 0.01}), io.Model.Input("model"),
}} io.String.Input("double_layers", default="7, 8, 9"),
RETURN_TYPES = ("MODEL",) io.String.Input("single_layers", default="7, 8, 9"),
FUNCTION = "skip_guidance" io.Float.Input("scale", default=3.0, min=0.0, max=10.0, step=0.1),
EXPERIMENTAL = True io.Float.Input("start_percent", default=0.01, min=0.0, max=1.0, step=0.001),
io.Float.Input("end_percent", default=0.15, min=0.0, max=1.0, step=0.001),
io.Float.Input("rescaling_scale", default=0.0, min=0.0, max=10.0, step=0.01),
],
outputs=[
io.Model.Output(),
],
)
DESCRIPTION = "Generic version of SkipLayerGuidance node that can be used on every DiT model." @classmethod
def execute(cls, model, scale, start_percent, end_percent, double_layers="", single_layers="", rescaling_scale=0) -> io.NodeOutput:
CATEGORY = "advanced/guidance"
def skip_guidance(self, model, scale, start_percent, end_percent, double_layers="", single_layers="", rescaling_scale=0):
# check if layer is comma separated integers # check if layer is comma separated integers
def skip(args, extra_args): def skip(args, extra_args):
return args return args
@ -43,7 +50,7 @@ class SkipLayerGuidanceDiT:
single_layers = [int(i) for i in single_layers] single_layers = [int(i) for i in single_layers]
if len(double_layers) == 0 and len(single_layers) == 0: if len(double_layers) == 0 and len(single_layers) == 0:
return (model, ) return io.NodeOutput(model)
def post_cfg_function(args): def post_cfg_function(args):
model = args["model"] model = args["model"]
@ -76,29 +83,36 @@ class SkipLayerGuidanceDiT:
m = model.clone() m = model.clone()
m.set_model_sampler_post_cfg_function(post_cfg_function) m.set_model_sampler_post_cfg_function(post_cfg_function)
return (m, ) return io.NodeOutput(m)
class SkipLayerGuidanceDiTSimple: skip_guidance = execute # TODO: remove
class SkipLayerGuidanceDiTSimple(io.ComfyNode):
''' '''
Simple version of the SkipLayerGuidanceDiT node that only modifies the uncond pass. Simple version of the SkipLayerGuidanceDiT node that only modifies the uncond pass.
''' '''
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": {"model": ("MODEL", ), return io.Schema(
"double_layers": ("STRING", {"default": "7, 8, 9", "multiline": False}), node_id="SkipLayerGuidanceDiTSimple",
"single_layers": ("STRING", {"default": "7, 8, 9", "multiline": False}), category="advanced/guidance",
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), description="Simple version of the SkipLayerGuidanceDiT node that only modifies the uncond pass.",
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}), is_experimental=True,
}} inputs=[
RETURN_TYPES = ("MODEL",) io.Model.Input("model"),
FUNCTION = "skip_guidance" io.String.Input("double_layers", default="7, 8, 9"),
EXPERIMENTAL = True io.String.Input("single_layers", default="7, 8, 9"),
io.Float.Input("start_percent", default=0.0, min=0.0, max=1.0, step=0.001),
io.Float.Input("end_percent", default=1.0, min=0.0, max=1.0, step=0.001),
],
outputs=[
io.Model.Output(),
],
)
DESCRIPTION = "Simple version of the SkipLayerGuidanceDiT node that only modifies the uncond pass." @classmethod
def execute(cls, model, start_percent, end_percent, double_layers="", single_layers="") -> io.NodeOutput:
CATEGORY = "advanced/guidance"
def skip_guidance(self, model, start_percent, end_percent, double_layers="", single_layers=""):
def skip(args, extra_args): def skip(args, extra_args):
return args return args
@ -113,7 +127,7 @@ class SkipLayerGuidanceDiTSimple:
single_layers = [int(i) for i in single_layers] single_layers = [int(i) for i in single_layers]
if len(double_layers) == 0 and len(single_layers) == 0: if len(double_layers) == 0 and len(single_layers) == 0:
return (model, ) return io.NodeOutput(model)
def calc_cond_batch_function(args): def calc_cond_batch_function(args):
x = args["input"] x = args["input"]
@ -144,9 +158,19 @@ class SkipLayerGuidanceDiTSimple:
m = model.clone() m = model.clone()
m.set_model_sampler_calc_cond_batch_function(calc_cond_batch_function) m.set_model_sampler_calc_cond_batch_function(calc_cond_batch_function)
return (m, ) return io.NodeOutput(m)
NODE_CLASS_MAPPINGS = { skip_guidance = execute # TODO: remove
"SkipLayerGuidanceDiT": SkipLayerGuidanceDiT,
"SkipLayerGuidanceDiTSimple": SkipLayerGuidanceDiTSimple,
} class SkipLayerGuidanceExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
SkipLayerGuidanceDiT,
SkipLayerGuidanceDiTSimple,
]
async def comfy_entrypoint() -> SkipLayerGuidanceExtension:
return SkipLayerGuidanceExtension()

View File

@ -1,6 +1,8 @@
import torch import torch
import nodes import nodes
import comfy.utils import comfy.utils
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
def camera_embeddings(elevation, azimuth): def camera_embeddings(elevation, azimuth):
elevation = torch.as_tensor([elevation]) elevation = torch.as_tensor([elevation])
@ -20,26 +22,31 @@ def camera_embeddings(elevation, azimuth):
return embeddings return embeddings
class StableZero123_Conditioning: class StableZero123_Conditioning(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": { "clip_vision": ("CLIP_VISION",), return io.Schema(
"init_image": ("IMAGE",), node_id="StableZero123_Conditioning",
"vae": ("VAE",), category="conditioning/3d_models",
"width": ("INT", {"default": 256, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}), inputs=[
"height": ("INT", {"default": 256, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}), io.ClipVision.Input("clip_vision"),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), io.Image.Input("init_image"),
"elevation": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}), io.Vae.Input("vae"),
"azimuth": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}), io.Int.Input("width", default=256, min=16, max=nodes.MAX_RESOLUTION, step=8),
}} io.Int.Input("height", default=256, min=16, max=nodes.MAX_RESOLUTION, step=8),
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") io.Int.Input("batch_size", default=1, min=1, max=4096),
RETURN_NAMES = ("positive", "negative", "latent") io.Float.Input("elevation", default=0.0, min=-180.0, max=180.0, step=0.1, round=False),
io.Float.Input("azimuth", default=0.0, min=-180.0, max=180.0, step=0.1, round=False)
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
io.Latent.Output(display_name="latent")
]
)
FUNCTION = "encode" @classmethod
def execute(cls, clip_vision, init_image, vae, width, height, batch_size, elevation, azimuth) -> io.NodeOutput:
CATEGORY = "conditioning/3d_models"
def encode(self, clip_vision, init_image, vae, width, height, batch_size, elevation, azimuth):
output = clip_vision.encode_image(init_image) output = clip_vision.encode_image(init_image)
pooled = output.image_embeds.unsqueeze(0) pooled = output.image_embeds.unsqueeze(0)
pixels = comfy.utils.common_upscale(init_image.movedim(-1,1), width, height, "bilinear", "center").movedim(1,-1) pixels = comfy.utils.common_upscale(init_image.movedim(-1,1), width, height, "bilinear", "center").movedim(1,-1)
@ -51,30 +58,35 @@ class StableZero123_Conditioning:
positive = [[cond, {"concat_latent_image": t}]] positive = [[cond, {"concat_latent_image": t}]]
negative = [[torch.zeros_like(pooled), {"concat_latent_image": torch.zeros_like(t)}]] negative = [[torch.zeros_like(pooled), {"concat_latent_image": torch.zeros_like(t)}]]
latent = torch.zeros([batch_size, 4, height // 8, width // 8]) latent = torch.zeros([batch_size, 4, height // 8, width // 8])
return (positive, negative, {"samples":latent}) return io.NodeOutput(positive, negative, {"samples":latent})
class StableZero123_Conditioning_Batched: class StableZero123_Conditioning_Batched(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": { "clip_vision": ("CLIP_VISION",), return io.Schema(
"init_image": ("IMAGE",), node_id="StableZero123_Conditioning_Batched",
"vae": ("VAE",), category="conditioning/3d_models",
"width": ("INT", {"default": 256, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}), inputs=[
"height": ("INT", {"default": 256, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}), io.ClipVision.Input("clip_vision"),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), io.Image.Input("init_image"),
"elevation": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}), io.Vae.Input("vae"),
"azimuth": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}), io.Int.Input("width", default=256, min=16, max=nodes.MAX_RESOLUTION, step=8),
"elevation_batch_increment": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}), io.Int.Input("height", default=256, min=16, max=nodes.MAX_RESOLUTION, step=8),
"azimuth_batch_increment": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}), io.Int.Input("batch_size", default=1, min=1, max=4096),
}} io.Float.Input("elevation", default=0.0, min=-180.0, max=180.0, step=0.1, round=False),
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") io.Float.Input("azimuth", default=0.0, min=-180.0, max=180.0, step=0.1, round=False),
RETURN_NAMES = ("positive", "negative", "latent") io.Float.Input("elevation_batch_increment", default=0.0, min=-180.0, max=180.0, step=0.1, round=False),
io.Float.Input("azimuth_batch_increment", default=0.0, min=-180.0, max=180.0, step=0.1, round=False)
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
io.Latent.Output(display_name="latent")
]
)
FUNCTION = "encode" @classmethod
def execute(cls, clip_vision, init_image, vae, width, height, batch_size, elevation, azimuth, elevation_batch_increment, azimuth_batch_increment) -> io.NodeOutput:
CATEGORY = "conditioning/3d_models"
def encode(self, clip_vision, init_image, vae, width, height, batch_size, elevation, azimuth, elevation_batch_increment, azimuth_batch_increment):
output = clip_vision.encode_image(init_image) output = clip_vision.encode_image(init_image)
pooled = output.image_embeds.unsqueeze(0) pooled = output.image_embeds.unsqueeze(0)
pixels = comfy.utils.common_upscale(init_image.movedim(-1,1), width, height, "bilinear", "center").movedim(1,-1) pixels = comfy.utils.common_upscale(init_image.movedim(-1,1), width, height, "bilinear", "center").movedim(1,-1)
@ -93,27 +105,32 @@ class StableZero123_Conditioning_Batched:
positive = [[cond, {"concat_latent_image": t}]] positive = [[cond, {"concat_latent_image": t}]]
negative = [[torch.zeros_like(pooled), {"concat_latent_image": torch.zeros_like(t)}]] negative = [[torch.zeros_like(pooled), {"concat_latent_image": torch.zeros_like(t)}]]
latent = torch.zeros([batch_size, 4, height // 8, width // 8]) latent = torch.zeros([batch_size, 4, height // 8, width // 8])
return (positive, negative, {"samples":latent, "batch_index": [0] * batch_size}) return io.NodeOutput(positive, negative, {"samples":latent, "batch_index": [0] * batch_size})
class SV3D_Conditioning: class SV3D_Conditioning(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": { "clip_vision": ("CLIP_VISION",), return io.Schema(
"init_image": ("IMAGE",), node_id="SV3D_Conditioning",
"vae": ("VAE",), category="conditioning/3d_models",
"width": ("INT", {"default": 576, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}), inputs=[
"height": ("INT", {"default": 576, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}), io.ClipVision.Input("clip_vision"),
"video_frames": ("INT", {"default": 21, "min": 1, "max": 4096}), io.Image.Input("init_image"),
"elevation": ("FLOAT", {"default": 0.0, "min": -90.0, "max": 90.0, "step": 0.1, "round": False}), io.Vae.Input("vae"),
}} io.Int.Input("width", default=576, min=16, max=nodes.MAX_RESOLUTION, step=8),
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") io.Int.Input("height", default=576, min=16, max=nodes.MAX_RESOLUTION, step=8),
RETURN_NAMES = ("positive", "negative", "latent") io.Int.Input("video_frames", default=21, min=1, max=4096),
io.Float.Input("elevation", default=0.0, min=-90.0, max=90.0, step=0.1, round=False)
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
io.Latent.Output(display_name="latent")
]
)
FUNCTION = "encode" @classmethod
def execute(cls, clip_vision, init_image, vae, width, height, video_frames, elevation) -> io.NodeOutput:
CATEGORY = "conditioning/3d_models"
def encode(self, clip_vision, init_image, vae, width, height, video_frames, elevation):
output = clip_vision.encode_image(init_image) output = clip_vision.encode_image(init_image)
pooled = output.image_embeds.unsqueeze(0) pooled = output.image_embeds.unsqueeze(0)
pixels = comfy.utils.common_upscale(init_image.movedim(-1,1), width, height, "bilinear", "center").movedim(1,-1) pixels = comfy.utils.common_upscale(init_image.movedim(-1,1), width, height, "bilinear", "center").movedim(1,-1)
@ -133,11 +150,17 @@ class SV3D_Conditioning:
positive = [[pooled, {"concat_latent_image": t, "elevation": elevations, "azimuth": azimuths}]] positive = [[pooled, {"concat_latent_image": t, "elevation": elevations, "azimuth": azimuths}]]
negative = [[torch.zeros_like(pooled), {"concat_latent_image": torch.zeros_like(t), "elevation": elevations, "azimuth": azimuths}]] negative = [[torch.zeros_like(pooled), {"concat_latent_image": torch.zeros_like(t), "elevation": elevations, "azimuth": azimuths}]]
latent = torch.zeros([video_frames, 4, height // 8, width // 8]) latent = torch.zeros([video_frames, 4, height // 8, width // 8])
return (positive, negative, {"samples":latent}) return io.NodeOutput(positive, negative, {"samples":latent})
NODE_CLASS_MAPPINGS = { class Stable3DExtension(ComfyExtension):
"StableZero123_Conditioning": StableZero123_Conditioning, @override
"StableZero123_Conditioning_Batched": StableZero123_Conditioning_Batched, async def get_node_list(self) -> list[type[io.ComfyNode]]:
"SV3D_Conditioning": SV3D_Conditioning, return [
} StableZero123_Conditioning,
StableZero123_Conditioning_Batched,
SV3D_Conditioning,
]
async def comfy_entrypoint() -> Stable3DExtension:
return Stable3DExtension()

View File

@ -4,6 +4,8 @@ from comfy import model_management
import torch import torch
import comfy.utils import comfy.utils
import folder_paths import folder_paths
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
try: try:
from spandrel_extra_arches import EXTRA_REGISTRY from spandrel_extra_arches import EXTRA_REGISTRY
@ -13,17 +15,23 @@ try:
except: except:
pass pass
class UpscaleModelLoader: class UpscaleModelLoader(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": { "model_name": (folder_paths.get_filename_list("upscale_models"), ), return io.Schema(
}} node_id="UpscaleModelLoader",
RETURN_TYPES = ("UPSCALE_MODEL",) display_name="Load Upscale Model",
FUNCTION = "load_model" category="loaders",
inputs=[
io.Combo.Input("model_name", options=folder_paths.get_filename_list("upscale_models")),
],
outputs=[
io.UpscaleModel.Output(),
],
)
CATEGORY = "loaders" @classmethod
def execute(cls, model_name) -> io.NodeOutput:
def load_model(self, model_name):
model_path = folder_paths.get_full_path_or_raise("upscale_models", model_name) model_path = folder_paths.get_full_path_or_raise("upscale_models", model_name)
sd = comfy.utils.load_torch_file(model_path, safe_load=True) sd = comfy.utils.load_torch_file(model_path, safe_load=True)
if "module.layers.0.residual_group.blocks.0.norm1.weight" in sd: if "module.layers.0.residual_group.blocks.0.norm1.weight" in sd:
@ -33,21 +41,29 @@ class UpscaleModelLoader:
if not isinstance(out, ImageModelDescriptor): if not isinstance(out, ImageModelDescriptor):
raise Exception("Upscale model must be a single-image model.") raise Exception("Upscale model must be a single-image model.")
return (out, ) return io.NodeOutput(out)
load_model = execute # TODO: remove
class ImageUpscaleWithModel: class ImageUpscaleWithModel(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": { "upscale_model": ("UPSCALE_MODEL",), return io.Schema(
"image": ("IMAGE",), node_id="ImageUpscaleWithModel",
}} display_name="Upscale Image (using Model)",
RETURN_TYPES = ("IMAGE",) category="image/upscaling",
FUNCTION = "upscale" inputs=[
io.UpscaleModel.Input("upscale_model"),
io.Image.Input("image"),
],
outputs=[
io.Image.Output(),
],
)
CATEGORY = "image/upscaling" @classmethod
def execute(cls, upscale_model, image) -> io.NodeOutput:
def upscale(self, upscale_model, image):
device = model_management.get_torch_device() device = model_management.get_torch_device()
memory_required = model_management.module_size(upscale_model.model) memory_required = model_management.module_size(upscale_model.model)
@ -75,9 +91,19 @@ class ImageUpscaleWithModel:
upscale_model.to("cpu") upscale_model.to("cpu")
s = torch.clamp(s.movedim(-3,-1), min=0, max=1.0) s = torch.clamp(s.movedim(-3,-1), min=0, max=1.0)
return (s,) return io.NodeOutput(s)
NODE_CLASS_MAPPINGS = { upscale = execute # TODO: remove
"UpscaleModelLoader": UpscaleModelLoader,
"ImageUpscaleWithModel": ImageUpscaleWithModel
} class UpscaleModelExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
UpscaleModelLoader,
ImageUpscaleWithModel,
]
async def comfy_entrypoint() -> UpscaleModelExtension:
return UpscaleModelExtension()

View File

@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is # This file is automatically generated by the build process when version is
# updated in pyproject.toml. # updated in pyproject.toml.
__version__ = "0.3.62" __version__ = "0.3.65"

View File

@ -1,25 +1,5 @@
#Rename this to extra_model_paths.yaml and ComfyUI will load it #Rename this to extra_model_paths.yaml and ComfyUI will load it
#config for a1111 ui
#all you have to do is change the base_path to where yours is installed
a111:
base_path: path/to/stable-diffusion-webui/
checkpoints: models/Stable-diffusion
configs: models/Stable-diffusion
vae: models/VAE
loras: |
models/Lora
models/LyCORIS
upscale_models: |
models/ESRGAN
models/RealESRGAN
models/SwinIR
embeddings: embeddings
hypernetworks: models/hypernetworks
controlnet: models/ControlNet
#config for comfyui #config for comfyui
#your base path should be either an existing comfy install or a central folder where you store all of your models, loras, etc. #your base path should be either an existing comfy install or a central folder where you store all of your models, loras, etc.
@ -28,7 +8,9 @@ a111:
# # You can use is_default to mark that these folders should be listed first, and used as the default dirs for eg downloads # # You can use is_default to mark that these folders should be listed first, and used as the default dirs for eg downloads
# #is_default: true # #is_default: true
# checkpoints: models/checkpoints/ # checkpoints: models/checkpoints/
# clip: models/clip/ # text_encoders: |
# models/text_encoders/
# models/clip/ # legacy location still supported
# clip_vision: models/clip_vision/ # clip_vision: models/clip_vision/
# configs: models/configs/ # configs: models/configs/
# controlnet: models/controlnet/ # controlnet: models/controlnet/
@ -39,6 +21,32 @@ a111:
# loras: models/loras/ # loras: models/loras/
# upscale_models: models/upscale_models/ # upscale_models: models/upscale_models/
# vae: models/vae/ # vae: models/vae/
# audio_encoders: models/audio_encoders/
# model_patches: models/model_patches/
#config for a1111 ui
#all you have to do is uncomment this (remove the #) and change the base_path to where yours is installed
#a111:
# base_path: path/to/stable-diffusion-webui/
# checkpoints: models/Stable-diffusion
# configs: models/Stable-diffusion
# vae: models/VAE
# loras: |
# models/Lora
# models/LyCORIS
# upscale_models: |
# models/ESRGAN
# models/RealESRGAN
# models/SwinIR
# embeddings: embeddings
# hypernetworks: models/hypernetworks
# controlnet: models/ControlNet
# For a full list of supported keys (style_models, vae_approx, hypernetworks, photomaker,
# model_patches, audio_encoders, classifiers, etc.) see folder_paths.py.
#other_ui: #other_ui:
# base_path: path/to/ui # base_path: path/to/ui
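With the reordering above, the ComfyUI paths now come first and the a1111 block ships commented out. Re-enabling it only requires removing the leading # characters and pointing base_path at your install; a minimal sketch, with a placeholder path, based on the block shown above:

```yaml
a111:
  base_path: /path/to/stable-diffusion-webui/
  checkpoints: models/Stable-diffusion
  vae: models/VAE
  loras: |
    models/Lora
    models/LyCORIS
  embeddings: embeddings
  controlnet: models/ControlNet
```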

View File

@ -2027,7 +2027,6 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"DiffControlNetLoader": "Load ControlNet Model (diff)", "DiffControlNetLoader": "Load ControlNet Model (diff)",
"StyleModelLoader": "Load Style Model", "StyleModelLoader": "Load Style Model",
"CLIPVisionLoader": "Load CLIP Vision", "CLIPVisionLoader": "Load CLIP Vision",
"UpscaleModelLoader": "Load Upscale Model",
"UNETLoader": "Load Diffusion Model", "UNETLoader": "Load Diffusion Model",
# Conditioning # Conditioning
"CLIPVisionEncode": "CLIP Vision Encode", "CLIPVisionEncode": "CLIP Vision Encode",
@ -2065,7 +2064,6 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"LoadImageOutput": "Load Image (from Outputs)", "LoadImageOutput": "Load Image (from Outputs)",
"ImageScale": "Upscale Image", "ImageScale": "Upscale Image",
"ImageScaleBy": "Upscale Image By", "ImageScaleBy": "Upscale Image By",
"ImageUpscaleWithModel": "Upscale Image (using Model)",
"ImageInvert": "Invert Image", "ImageInvert": "Invert Image",
"ImagePadForOutpaint": "Pad Image for Outpainting", "ImagePadForOutpaint": "Pad Image for Outpainting",
"ImageBatch": "Batch Images", "ImageBatch": "Batch Images",
@ -2358,6 +2356,7 @@ async def init_builtin_api_nodes():
"nodes_stability.py", "nodes_stability.py",
"nodes_pika.py", "nodes_pika.py",
"nodes_runway.py", "nodes_runway.py",
"nodes_sora.py",
"nodes_tripo.py", "nodes_tripo.py",
"nodes_moonvalley.py", "nodes_moonvalley.py",
"nodes_rodin.py", "nodes_rodin.py",

View File

@ -1,6 +1,6 @@
[project] [project]
name = "ComfyUI" name = "ComfyUI"
version = "0.3.62" version = "0.3.65"
readme = "README.md" readme = "README.md"
license = { file = "LICENSE" } license = { file = "LICENSE" }
requires-python = ">=3.9" requires-python = ">=3.9"
@ -57,20 +57,13 @@ messages_control.disable = [
"redefined-builtin", "redefined-builtin",
"unnecessary-lambda", "unnecessary-lambda",
"dangerous-default-value", "dangerous-default-value",
"invalid-overridden-method",
# next warnings should be fixed in future # next warnings should be fixed in future
"bad-classmethod-argument", # Class method should have 'cls' as first argument "bad-classmethod-argument", # Class method should have 'cls' as first argument
"wrong-import-order", # Standard imports should be placed before third party imports "wrong-import-order", # Standard imports should be placed before third party imports
"logging-fstring-interpolation", # Use lazy % formatting in logging functions
"ungrouped-imports", "ungrouped-imports",
"unnecessary-pass", "unnecessary-pass",
"unidiomatic-typecheck",
"unnecessary-lambda-assignment", "unnecessary-lambda-assignment",
"no-else-return", "no-else-return",
"no-else-raise",
"invalid-overridden-method",
"unused-variable", "unused-variable",
"pointless-string-statement",
"inconsistent-return-statements",
"import-outside-toplevel",
"redefined-outer-name",
] ]

View File

@ -1,6 +1,6 @@
comfyui-frontend-package==1.27.7 comfyui-frontend-package==1.27.10
comfyui-workflow-templates==0.1.91 comfyui-workflow-templates==0.1.95
comfyui-embedded-docs==0.2.6 comfyui-embedded-docs==0.3.0
torch torch
torchsde torchsde
torchvision torchvision
@ -25,6 +25,5 @@ av>=14.2.0
#non essential dependencies: #non essential dependencies:
kornia>=0.7.1 kornia>=0.7.1
spandrel spandrel
soundfile
pydantic~=2.0 pydantic~=2.0
pydantic-settings~=2.0 pydantic-settings~=2.0