Merge branch 'master' of github.com:comfyanonymous/ComfyUI
Commit: 8cdc246450
@@ -185,6 +185,7 @@ def create_parser() -> argparse.ArgumentParser:
     parser.add_argument("--otel-service-name", type=str, default="comfyui", env_var="OTEL_SERVICE_NAME", help="The name of the service or application that is generating telemetry data.")
     parser.add_argument("--otel-service-version", type=str, default=__version__, env_var="OTEL_SERVICE_VERSION", help="The version of the service or application that is generating telemetry data.")
     parser.add_argument("--otel-exporter-otlp-endpoint", type=str, default=None, env_var="OTEL_EXPORTER_OTLP_ENDPOINT", help="A base endpoint URL for any signal type, with an optionally-specified port number. Helpful for when you're sending more than one signal to the same endpoint and want one environment variable to control the endpoint.")
+    parser.add_argument("--force-channels-last", action="store_true", help="Force channels last format when inferencing the models.")
 
     # now give plugins a chance to add configuration
     for entry_point in entry_points().select(group='comfyui.custom_config'):
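A hypothetical configuration sketch (editor-added, not part of the commit): the OTEL_* environment variable names come from the flags above; whether they are honored without CLI arguments depends on the parser behind create_parser (the env_var= keyword suggests configargparse-style handling), and the import of create_parser is assumed here.

import os

os.environ["OTEL_SERVICE_NAME"] = "comfyui-worker"
os.environ["OTEL_EXPORTER_OTLP_ENDPOINT"] = "http://localhost:4318"

args = create_parser().parse_args([])
print(args.otel_service_name, args.otel_exporter_otlp_endpoint, args.force_channels_last)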
@@ -89,6 +89,7 @@ class Configuration(dict):
         otel_service_name (str): The name of the service or application that is generating telemetry data. Default: "comfyui".
         otel_service_version (str): The version of the service or application that is generating telemetry data. Default: "0.0.1".
         otel_exporter_otlp_endpoint (Optional[str]): A base endpoint URL for any signal type, with an optionally-specified port number. Helpful for when you're sending more than one signal to the same endpoint and want one environment variable to control the endpoint.
+        force_channels_last (bool): Force channels last format when inferencing the models.
     """
 
     def __init__(self, **kwargs):
@@ -156,6 +157,7 @@ class Configuration(dict):
         self.external_address: Optional[str] = None
         self.disable_known_models: bool = False
         self.max_queue_size: int = 65536
+        self.force_channels_last: bool = False
 
         # from opentracing docs
         self.otel_service_name: str = "comfyui"
@@ -70,7 +70,7 @@ def get_previewer(device, latent_format):
 
     if method == LatentPreviewMethod.TAESD:
         if taesd_decoder_path:
-            taesd = TAESD(None, taesd_decoder_path).to(device)
+            taesd = TAESD(None, taesd_decoder_path, latent_channels=latent_format.latent_channels).to(device)
             previewer = TAESDPreviewerImpl(taesd)
         else:
             logging.warning("Warning: TAESD previews enabled, but could not find models/vae_approx/{}".format(latent_format.taesd_decoder_name))
@@ -129,8 +129,13 @@ def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None,
     extra_args = {} if extra_args is None else extra_args
     s_in = x.new_ones([x.shape[0]])
     for i in trange(len(sigmas) - 1, disable=disable):
-        gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
-        sigma_hat = sigmas[i] * (gamma + 1)
+        if s_churn > 0:
+            gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
+            sigma_hat = sigmas[i] * (gamma + 1)
+        else:
+            gamma = 0
+            sigma_hat = sigmas[i]
+
         if gamma > 0:
             eps = torch.randn_like(x) * s_noise
             x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
@@ -170,7 +175,13 @@ def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None,
     extra_args = {} if extra_args is None else extra_args
     s_in = x.new_ones([x.shape[0]])
     for i in trange(len(sigmas) - 1, disable=disable):
-        gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
+        if s_churn > 0:
+            gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
+            sigma_hat = sigmas[i] * (gamma + 1)
+        else:
+            gamma = 0
+            sigma_hat = sigmas[i]
+
         sigma_hat = sigmas[i] * (gamma + 1)
         if gamma > 0:
             eps = torch.randn_like(x) * s_noise
@@ -199,8 +210,13 @@ def sample_dpm_2(model, x, sigmas, extra_args=None, callback=None, disable=None,
     extra_args = {} if extra_args is None else extra_args
     s_in = x.new_ones([x.shape[0]])
     for i in trange(len(sigmas) - 1, disable=disable):
-        gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
-        sigma_hat = sigmas[i] * (gamma + 1)
+        if s_churn > 0:
+            gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
+            sigma_hat = sigmas[i] * (gamma + 1)
+        else:
+            gamma = 0
+            sigma_hat = sigmas[i]
+
         if gamma > 0:
             eps = torch.randn_like(x) * s_noise
             x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
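A minimal standalone sketch (editor-added, not part of the commit) of the control flow the three sampler hunks above introduce: with the default s_churn=0 the churn branch is skipped entirely, gamma stays 0 and sigma_hat equals sigmas[i], so no extra noise is injected. churn_step_params is a hypothetical helper name.

def churn_step_params(sigma, n_sigmas, s_churn=0., s_tmin=0., s_tmax=float('inf')):
    # mirrors the guarded block added to sample_euler / sample_heun / sample_dpm_2
    if s_churn > 0:
        gamma = min(s_churn / (n_sigmas - 1), 2 ** 0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.
        sigma_hat = sigma * (gamma + 1)
    else:
        gamma = 0
        sigma_hat = sigma
    return gamma, sigma_hat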
@@ -129,9 +129,13 @@ class SD3(LatentFormat):
             [-0.0749, -0.0634, -0.0456],
             [-0.1418, -0.1457, -0.1259]
         ]
+        self.taesd_decoder_name = "taesd3_decoder"
 
     def process_in(self, latent):
         return (latent - self.shift_factor) * self.scale_factor
 
     def process_out(self, latent):
         return (latent / self.scale_factor) + self.shift_factor
+
+class StableAudio1(LatentFormat):
+    latent_channels = 64
comfy/ldm/audio/autoencoder.py (new file, 282 lines)
@@ -0,0 +1,282 @@
# code adapted from: https://github.com/Stability-AI/stable-audio-tools

import torch
from torch import nn
from typing import Literal, Dict, Any
import math
import comfy.ops
ops = comfy.ops.disable_weight_init

def vae_sample(mean, scale):
    stdev = nn.functional.softplus(scale) + 1e-4
    var = stdev * stdev
    logvar = torch.log(var)
    latents = torch.randn_like(mean) * stdev + mean

    kl = (mean * mean + var - logvar - 1).sum(1).mean()

    return latents, kl

class VAEBottleneck(nn.Module):
    def __init__(self):
        super().__init__()
        self.is_discrete = False

    def encode(self, x, return_info=False, **kwargs):
        info = {}

        mean, scale = x.chunk(2, dim=1)

        x, kl = vae_sample(mean, scale)

        info["kl"] = kl

        if return_info:
            return x, info
        else:
            return x

    def decode(self, x):
        return x
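A usage sketch (editor-added, hypothetical sizes; assumes the module above is importable): the encoder further down emits 2 * latent_dim channels, which VAEBottleneck.encode splits into mean and scale before sampling.

import torch

bottleneck = VAEBottleneck()
h = torch.randn(1, 128, 215)                      # [batch, 2 * latent_dim, time]
z, info = bottleneck.encode(h, return_info=True)  # z: [1, 64, 215], info["kl"]: scalar KL term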
def snake_beta(x, alpha, beta):
    return x + (1.0 / (beta + 0.000000001)) * pow(torch.sin(x * alpha), 2)

# Adapted from https://github.com/NVIDIA/BigVGAN/blob/main/activations.py under MIT license
class SnakeBeta(nn.Module):

    def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True):
        super(SnakeBeta, self).__init__()
        self.in_features = in_features

        # initialize alpha
        self.alpha_logscale = alpha_logscale
        if self.alpha_logscale: # log scale alphas initialized to zeros
            self.alpha = nn.Parameter(torch.zeros(in_features) * alpha)
            self.beta = nn.Parameter(torch.zeros(in_features) * alpha)
        else: # linear scale alphas initialized to ones
            self.alpha = nn.Parameter(torch.ones(in_features) * alpha)
            self.beta = nn.Parameter(torch.ones(in_features) * alpha)

        # self.alpha.requires_grad = alpha_trainable
        # self.beta.requires_grad = alpha_trainable

        self.no_div_by_zero = 0.000000001

    def forward(self, x):
        alpha = self.alpha.unsqueeze(0).unsqueeze(-1).to(x.device) # line up with x to [B, C, T]
        beta = self.beta.unsqueeze(0).unsqueeze(-1).to(x.device)
        if self.alpha_logscale:
            alpha = torch.exp(alpha)
            beta = torch.exp(beta)
        x = snake_beta(x, alpha, beta)

        return x

def WNConv1d(*args, **kwargs):
    try:
        return torch.nn.utils.parametrizations.weight_norm(ops.Conv1d(*args, **kwargs))
    except:
        return torch.nn.utils.weight_norm(ops.Conv1d(*args, **kwargs)) #support pytorch 2.1 and older

def WNConvTranspose1d(*args, **kwargs):
    try:
        return torch.nn.utils.parametrizations.weight_norm(ops.ConvTranspose1d(*args, **kwargs))
    except:
        return torch.nn.utils.weight_norm(ops.ConvTranspose1d(*args, **kwargs)) #support pytorch 2.1 and older

def get_activation(activation: Literal["elu", "snake", "none"], antialias=False, channels=None) -> nn.Module:
    if activation == "elu":
        act = torch.nn.ELU()
    elif activation == "snake":
        act = SnakeBeta(channels)
    elif activation == "none":
        act = torch.nn.Identity()
    else:
        raise ValueError(f"Unknown activation {activation}")

    if antialias:
        act = Activation1d(act)

    return act

class ResidualUnit(nn.Module):
    def __init__(self, in_channels, out_channels, dilation, use_snake=False, antialias_activation=False):
        super().__init__()

        self.dilation = dilation

        padding = (dilation * (7-1)) // 2

        self.layers = nn.Sequential(
            get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=out_channels),
            WNConv1d(in_channels=in_channels, out_channels=out_channels,
                     kernel_size=7, dilation=dilation, padding=padding),
            get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=out_channels),
            WNConv1d(in_channels=out_channels, out_channels=out_channels,
                     kernel_size=1)
        )

    def forward(self, x):
        res = x

        #x = checkpoint(self.layers, x)
        x = self.layers(x)

        return x + res

class EncoderBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False):
        super().__init__()

        self.layers = nn.Sequential(
            ResidualUnit(in_channels=in_channels,
                         out_channels=in_channels, dilation=1, use_snake=use_snake),
            ResidualUnit(in_channels=in_channels,
                         out_channels=in_channels, dilation=3, use_snake=use_snake),
            ResidualUnit(in_channels=in_channels,
                         out_channels=in_channels, dilation=9, use_snake=use_snake),
            get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels),
            WNConv1d(in_channels=in_channels, out_channels=out_channels,
                     kernel_size=2*stride, stride=stride, padding=math.ceil(stride/2)),
        )

    def forward(self, x):
        return self.layers(x)

class DecoderBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False, use_nearest_upsample=False):
        super().__init__()

        if use_nearest_upsample:
            upsample_layer = nn.Sequential(
                nn.Upsample(scale_factor=stride, mode="nearest"),
                WNConv1d(in_channels=in_channels,
                         out_channels=out_channels,
                         kernel_size=2*stride,
                         stride=1,
                         bias=False,
                         padding='same')
            )
        else:
            upsample_layer = WNConvTranspose1d(in_channels=in_channels,
                                               out_channels=out_channels,
                                               kernel_size=2*stride, stride=stride, padding=math.ceil(stride/2))

        self.layers = nn.Sequential(
            get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels),
            upsample_layer,
            ResidualUnit(in_channels=out_channels, out_channels=out_channels,
                         dilation=1, use_snake=use_snake),
            ResidualUnit(in_channels=out_channels, out_channels=out_channels,
                         dilation=3, use_snake=use_snake),
            ResidualUnit(in_channels=out_channels, out_channels=out_channels,
                         dilation=9, use_snake=use_snake),
        )

    def forward(self, x):
        return self.layers(x)

class OobleckEncoder(nn.Module):
    def __init__(self,
                 in_channels=2,
                 channels=128,
                 latent_dim=32,
                 c_mults = [1, 2, 4, 8],
                 strides = [2, 4, 8, 8],
                 use_snake=False,
                 antialias_activation=False
                 ):
        super().__init__()

        c_mults = [1] + c_mults

        self.depth = len(c_mults)

        layers = [
            WNConv1d(in_channels=in_channels, out_channels=c_mults[0] * channels, kernel_size=7, padding=3)
        ]

        for i in range(self.depth-1):
            layers += [EncoderBlock(in_channels=c_mults[i]*channels, out_channels=c_mults[i+1]*channels, stride=strides[i], use_snake=use_snake)]

        layers += [
            get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=c_mults[-1] * channels),
            WNConv1d(in_channels=c_mults[-1]*channels, out_channels=latent_dim, kernel_size=3, padding=1)
        ]

        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)


class OobleckDecoder(nn.Module):
    def __init__(self,
                 out_channels=2,
                 channels=128,
                 latent_dim=32,
                 c_mults = [1, 2, 4, 8],
                 strides = [2, 4, 8, 8],
                 use_snake=False,
                 antialias_activation=False,
                 use_nearest_upsample=False,
                 final_tanh=True):
        super().__init__()

        c_mults = [1] + c_mults

        self.depth = len(c_mults)

        layers = [
            WNConv1d(in_channels=latent_dim, out_channels=c_mults[-1]*channels, kernel_size=7, padding=3),
        ]

        for i in range(self.depth-1, 0, -1):
            layers += [DecoderBlock(
                in_channels=c_mults[i]*channels,
                out_channels=c_mults[i-1]*channels,
                stride=strides[i-1],
                use_snake=use_snake,
                antialias_activation=antialias_activation,
                use_nearest_upsample=use_nearest_upsample
                )
            ]

        layers += [
            get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=c_mults[0] * channels),
            WNConv1d(in_channels=c_mults[0] * channels, out_channels=out_channels, kernel_size=7, padding=3, bias=False),
            nn.Tanh() if final_tanh else nn.Identity()
        ]

        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)


class AudioOobleckVAE(nn.Module):
    def __init__(self,
                 in_channels=2,
                 channels=128,
                 latent_dim=64,
                 c_mults = [1, 2, 4, 8, 16],
                 strides = [2, 4, 4, 8, 8],
                 use_snake=True,
                 antialias_activation=False,
                 use_nearest_upsample=False,
                 final_tanh=False):
        super().__init__()
        self.encoder = OobleckEncoder(in_channels, channels, latent_dim * 2, c_mults, strides, use_snake, antialias_activation)
        self.decoder = OobleckDecoder(in_channels, channels, latent_dim, c_mults, strides, use_snake, antialias_activation,
                                      use_nearest_upsample=use_nearest_upsample, final_tanh=final_tanh)
        self.bottleneck = VAEBottleneck()

    def encode(self, x):
        return self.bottleneck.encode(self.encoder(x))

    def decode(self, x):
        return self.decoder(self.bottleneck.decode(x))
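A round-trip sketch (editor-added, hypothetical sizes; assumes comfy.ops is importable so the module above can be constructed): the default strides 2*4*4*8*8 give a total downsampling factor of 2048 samples per latent frame.

import torch

vae = AudioOobleckVAE()
audio = torch.randn(1, 2, 2048 * 4)   # [batch, stereo channels, samples]
latent = vae.encode(audio)            # -> [1, 64, 4]
recon = vae.decode(latent)            # -> [1, 2, 2048 * 4]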
comfy/ldm/audio/dit.py (new file, 888 lines)
@@ -0,0 +1,888 @@
# code adapted from: https://github.com/Stability-AI/stable-audio-tools

from comfy.ldm.modules.attention import optimized_attention
import typing as tp

import torch

from einops import rearrange
from torch import nn
from torch.nn import functional as F
import math

class FourierFeatures(nn.Module):
    def __init__(self, in_features, out_features, std=1., dtype=None, device=None):
        super().__init__()
        assert out_features % 2 == 0
        self.weight = nn.Parameter(torch.empty(
            [out_features // 2, in_features], dtype=dtype, device=device))

    def forward(self, input):
        f = 2 * math.pi * input @ self.weight.T.to(dtype=input.dtype, device=input.device)
        return torch.cat([f.cos(), f.sin()], dim=-1)

# norms
class LayerNorm(nn.Module):
    def __init__(self, dim, bias=False, fix_scale=False, dtype=None, device=None):
        """
        bias-less layernorm has been shown to be more stable. most newer models have moved towards rmsnorm, also bias-less
        """
        super().__init__()

        self.gamma = nn.Parameter(torch.empty(dim, dtype=dtype, device=device))

        if bias:
            self.beta = nn.Parameter(torch.empty(dim, dtype=dtype, device=device))
        else:
            self.beta = None

    def forward(self, x):
        beta = self.beta
        if self.beta is not None:
            beta = beta.to(dtype=x.dtype, device=x.device)
        return F.layer_norm(x, x.shape[-1:], weight=self.gamma.to(dtype=x.dtype, device=x.device), bias=beta)

class GLU(nn.Module):
    def __init__(
        self,
        dim_in,
        dim_out,
        activation,
        use_conv = False,
        conv_kernel_size = 3,
        dtype=None,
        device=None,
        operations=None,
    ):
        super().__init__()
        self.act = activation
        self.proj = operations.Linear(dim_in, dim_out * 2, dtype=dtype, device=device) if not use_conv else operations.Conv1d(dim_in, dim_out * 2, conv_kernel_size, padding = (conv_kernel_size // 2), dtype=dtype, device=device)
        self.use_conv = use_conv

    def forward(self, x):
        if self.use_conv:
            x = rearrange(x, 'b n d -> b d n')
            x = self.proj(x)
            x = rearrange(x, 'b d n -> b n d')
        else:
            x = self.proj(x)

        x, gate = x.chunk(2, dim = -1)
        return x * self.act(gate)

class AbsolutePositionalEmbedding(nn.Module):
    def __init__(self, dim, max_seq_len):
        super().__init__()
        self.scale = dim ** -0.5
        self.max_seq_len = max_seq_len
        self.emb = nn.Embedding(max_seq_len, dim)

    def forward(self, x, pos = None, seq_start_pos = None):
        seq_len, device = x.shape[1], x.device
        assert seq_len <= self.max_seq_len, f'you are passing in a sequence length of {seq_len} but your absolute positional embedding has a max sequence length of {self.max_seq_len}'

        if pos is None:
            pos = torch.arange(seq_len, device = device)

        if seq_start_pos is not None:
            pos = (pos - seq_start_pos[..., None]).clamp(min = 0)

        pos_emb = self.emb(pos)
        pos_emb = pos_emb * self.scale
        return pos_emb

class ScaledSinusoidalEmbedding(nn.Module):
    def __init__(self, dim, theta = 10000):
        super().__init__()
        assert (dim % 2) == 0, 'dimension must be divisible by 2'
        self.scale = nn.Parameter(torch.ones(1) * dim ** -0.5)

        half_dim = dim // 2
        freq_seq = torch.arange(half_dim).float() / half_dim
        inv_freq = theta ** -freq_seq
        self.register_buffer('inv_freq', inv_freq, persistent = False)

    def forward(self, x, pos = None, seq_start_pos = None):
        seq_len, device = x.shape[1], x.device

        if pos is None:
            pos = torch.arange(seq_len, device = device)

        if seq_start_pos is not None:
            pos = pos - seq_start_pos[..., None]

        emb = torch.einsum('i, j -> i j', pos, self.inv_freq)
        emb = torch.cat((emb.sin(), emb.cos()), dim = -1)
        return emb * self.scale

class RotaryEmbedding(nn.Module):
    def __init__(
        self,
        dim,
        use_xpos = False,
        scale_base = 512,
        interpolation_factor = 1.,
        base = 10000,
        base_rescale_factor = 1.
    ):
        super().__init__()
        # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
        # has some connection to NTK literature
        # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
        base *= base_rescale_factor ** (dim / (dim - 2))

        inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)

        assert interpolation_factor >= 1.
        self.interpolation_factor = interpolation_factor

        if not use_xpos:
            self.register_buffer('scale', None)
            return

        scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)

        self.scale_base = scale_base
        self.register_buffer('scale', scale)

    def forward_from_seq_len(self, seq_len, device, dtype):
        # device = self.inv_freq.device

        t = torch.arange(seq_len, device=device, dtype=dtype)
        return self.forward(t)

    def forward(self, t):
        # device = self.inv_freq.device
        device = t.device
        dtype = t.dtype

        # t = t.to(torch.float32)

        t = t / self.interpolation_factor

        freqs = torch.einsum('i , j -> i j', t, self.inv_freq.to(dtype=dtype, device=device))
        freqs = torch.cat((freqs, freqs), dim = -1)

        if self.scale is None:
            return freqs, 1.

        power = (torch.arange(seq_len, device = device) - (seq_len // 2)) / self.scale_base
        scale = self.scale.to(dtype=dtype, device=device) ** rearrange(power, 'n -> n 1')
        scale = torch.cat((scale, scale), dim = -1)

        return freqs, scale

def rotate_half(x):
    x = rearrange(x, '... (j d) -> ... j d', j = 2)
    x1, x2 = x.unbind(dim = -2)
    return torch.cat((-x2, x1), dim = -1)

def apply_rotary_pos_emb(t, freqs, scale = 1):
    out_dtype = t.dtype

    # cast to float32 if necessary for numerical stability
    dtype = t.dtype #reduce(torch.promote_types, (t.dtype, freqs.dtype, torch.float32))
    rot_dim, seq_len = freqs.shape[-1], t.shape[-2]
    freqs, t = freqs.to(dtype), t.to(dtype)
    freqs = freqs[-seq_len:, :]

    if t.ndim == 4 and freqs.ndim == 3:
        freqs = rearrange(freqs, 'b n d -> b 1 n d')

    # partial rotary embeddings, Wang et al. GPT-J
    t, t_unrotated = t[..., :rot_dim], t[..., rot_dim:]
    t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)

    t, t_unrotated = t.to(out_dtype), t_unrotated.to(out_dtype)

    return torch.cat((t, t_unrotated), dim = -1)
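A shape sketch (editor-added, hypothetical sizes; assumes the names above are importable): only the first freqs.shape[-1] channels of each head are rotated, the rest pass through unchanged.

import torch

rope = RotaryEmbedding(32)            # as built in ContinuousTransformer: max(dim_heads // 2, 32)
freqs, scale = rope.forward_from_seq_len(10, device="cpu", dtype=torch.float32)
q = torch.randn(1, 4, 10, 64)         # [batch, heads, seq, dim_head]
q_rot = apply_rotary_pos_emb(q, freqs, scale)   # first 32 channels rotated, last 32 untouched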
class FeedForward(nn.Module):
    def __init__(
        self,
        dim,
        dim_out = None,
        mult = 4,
        no_bias = False,
        glu = True,
        use_conv = False,
        conv_kernel_size = 3,
        zero_init_output = True,
        dtype=None,
        device=None,
        operations=None,
    ):
        super().__init__()
        inner_dim = int(dim * mult)

        # Default to SwiGLU

        activation = nn.SiLU()

        dim_out = dim if dim_out is None else dim_out

        if glu:
            linear_in = GLU(dim, inner_dim, activation, dtype=dtype, device=device, operations=operations)
        else:
            linear_in = nn.Sequential(
                Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
                operations.Linear(dim, inner_dim, bias = not no_bias, dtype=dtype, device=device) if not use_conv else operations.Conv1d(dim, inner_dim, conv_kernel_size, padding = (conv_kernel_size // 2), bias = not no_bias, dtype=dtype, device=device),
                Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
                activation
            )

        linear_out = operations.Linear(inner_dim, dim_out, bias = not no_bias, dtype=dtype, device=device) if not use_conv else operations.Conv1d(inner_dim, dim_out, conv_kernel_size, padding = (conv_kernel_size // 2), bias = not no_bias, dtype=dtype, device=device)

        # # init last linear layer to 0
        # if zero_init_output:
        #     nn.init.zeros_(linear_out.weight)
        #     if not no_bias:
        #         nn.init.zeros_(linear_out.bias)


        self.ff = nn.Sequential(
            linear_in,
            Rearrange('b d n -> b n d') if use_conv else nn.Identity(),
            linear_out,
            Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
        )

    def forward(self, x):
        return self.ff(x)

class Attention(nn.Module):
    def __init__(
        self,
        dim,
        dim_heads = 64,
        dim_context = None,
        causal = False,
        zero_init_output=True,
        qk_norm = False,
        natten_kernel_size = None,
        dtype=None,
        device=None,
        operations=None,
    ):
        super().__init__()
        self.dim = dim
        self.dim_heads = dim_heads
        self.causal = causal

        dim_kv = dim_context if dim_context is not None else dim

        self.num_heads = dim // dim_heads
        self.kv_heads = dim_kv // dim_heads

        if dim_context is not None:
            self.to_q = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
            self.to_kv = operations.Linear(dim_kv, dim_kv * 2, bias=False, dtype=dtype, device=device)
        else:
            self.to_qkv = operations.Linear(dim, dim * 3, bias=False, dtype=dtype, device=device)

        self.to_out = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)

        # if zero_init_output:
        #     nn.init.zeros_(self.to_out.weight)

        self.qk_norm = qk_norm


    def forward(
        self,
        x,
        context = None,
        mask = None,
        context_mask = None,
        rotary_pos_emb = None,
        causal = None
    ):
        h, kv_h, has_context = self.num_heads, self.kv_heads, context is not None

        kv_input = context if has_context else x

        if hasattr(self, 'to_q'):
            # Use separate linear projections for q and k/v
            q = self.to_q(x)
            q = rearrange(q, 'b n (h d) -> b h n d', h = h)

            k, v = self.to_kv(kv_input).chunk(2, dim=-1)

            k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = kv_h), (k, v))
        else:
            # Use fused linear projection
            q, k, v = self.to_qkv(x).chunk(3, dim=-1)
            q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

        # Normalize q and k for cosine sim attention
        if self.qk_norm:
            q = F.normalize(q, dim=-1)
            k = F.normalize(k, dim=-1)

        if rotary_pos_emb is not None and not has_context:
            freqs, _ = rotary_pos_emb

            q_dtype = q.dtype
            k_dtype = k.dtype

            q = q.to(torch.float32)
            k = k.to(torch.float32)
            freqs = freqs.to(torch.float32)

            q = apply_rotary_pos_emb(q, freqs)
            k = apply_rotary_pos_emb(k, freqs)

            q = q.to(q_dtype)
            k = k.to(k_dtype)

        input_mask = context_mask

        if input_mask is None and not has_context:
            input_mask = mask

        # determine masking
        masks = []
        final_attn_mask = None # The mask that will be applied to the attention matrix, taking all masks into account

        if input_mask is not None:
            input_mask = rearrange(input_mask, 'b j -> b 1 1 j')
            masks.append(~input_mask)

        # Other masks will be added here later

        if len(masks) > 0:
            final_attn_mask = ~or_reduce(masks)

        n, device = q.shape[-2], q.device

        causal = self.causal if causal is None else causal

        if n == 1 and causal:
            causal = False

        if h != kv_h:
            # Repeat interleave kv_heads to match q_heads
            heads_per_kv_head = h // kv_h
            k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim = 1), (k, v))

        out = optimized_attention(q, k, v, h, skip_reshape=True)
        out = self.to_out(out)

        if mask is not None:
            mask = rearrange(mask, 'b n -> b n 1')
            out = out.masked_fill(~mask, 0.)

        return out
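A standalone sketch (editor-added) of the kv-head matching that Attention.forward performs when the context dimension yields fewer kv heads than query heads.

import torch

q = torch.randn(1, 8, 16, 64)   # [batch, q_heads, seq, dim_head]
k = torch.randn(1, 2, 16, 64)   # [batch, kv_heads, seq, dim_head]
heads_per_kv_head = q.shape[1] // k.shape[1]
k = k.repeat_interleave(heads_per_kv_head, dim=1)   # -> [1, 8, 16, 64], now matches q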
class ConformerModule(nn.Module):
    def __init__(
        self,
        dim,
        norm_kwargs = {},
    ):

        super().__init__()

        self.dim = dim

        self.in_norm = LayerNorm(dim, **norm_kwargs)
        self.pointwise_conv = nn.Conv1d(dim, dim, kernel_size=1, bias=False)
        self.glu = GLU(dim, dim, nn.SiLU())
        self.depthwise_conv = nn.Conv1d(dim, dim, kernel_size=17, groups=dim, padding=8, bias=False)
        self.mid_norm = LayerNorm(dim, **norm_kwargs) # This is a batch norm in the original but I don't like batch norm
        self.swish = nn.SiLU()
        self.pointwise_conv_2 = nn.Conv1d(dim, dim, kernel_size=1, bias=False)

    def forward(self, x):
        x = self.in_norm(x)
        x = rearrange(x, 'b n d -> b d n')
        x = self.pointwise_conv(x)
        x = rearrange(x, 'b d n -> b n d')
        x = self.glu(x)
        x = rearrange(x, 'b n d -> b d n')
        x = self.depthwise_conv(x)
        x = rearrange(x, 'b d n -> b n d')
        x = self.mid_norm(x)
        x = self.swish(x)
        x = rearrange(x, 'b n d -> b d n')
        x = self.pointwise_conv_2(x)
        x = rearrange(x, 'b d n -> b n d')

        return x

class TransformerBlock(nn.Module):
    def __init__(
        self,
        dim,
        dim_heads = 64,
        cross_attend = False,
        dim_context = None,
        global_cond_dim = None,
        causal = False,
        zero_init_branch_outputs = True,
        conformer = False,
        layer_ix = -1,
        remove_norms = False,
        attn_kwargs = {},
        ff_kwargs = {},
        norm_kwargs = {},
        dtype=None,
        device=None,
        operations=None,
    ):

        super().__init__()
        self.dim = dim
        self.dim_heads = dim_heads
        self.cross_attend = cross_attend
        self.dim_context = dim_context
        self.causal = causal

        self.pre_norm = LayerNorm(dim, dtype=dtype, device=device, **norm_kwargs) if not remove_norms else nn.Identity()

        self.self_attn = Attention(
            dim,
            dim_heads = dim_heads,
            causal = causal,
            zero_init_output=zero_init_branch_outputs,
            dtype=dtype,
            device=device,
            operations=operations,
            **attn_kwargs
        )

        if cross_attend:
            self.cross_attend_norm = LayerNorm(dim, dtype=dtype, device=device, **norm_kwargs) if not remove_norms else nn.Identity()
            self.cross_attn = Attention(
                dim,
                dim_heads = dim_heads,
                dim_context=dim_context,
                causal = causal,
                zero_init_output=zero_init_branch_outputs,
                dtype=dtype,
                device=device,
                operations=operations,
                **attn_kwargs
            )

        self.ff_norm = LayerNorm(dim, dtype=dtype, device=device, **norm_kwargs) if not remove_norms else nn.Identity()
        self.ff = FeedForward(dim, zero_init_output=zero_init_branch_outputs, dtype=dtype, device=device, operations=operations, **ff_kwargs)

        self.layer_ix = layer_ix

        self.conformer = ConformerModule(dim, norm_kwargs=norm_kwargs) if conformer else None

        self.global_cond_dim = global_cond_dim

        if global_cond_dim is not None:
            self.to_scale_shift_gate = nn.Sequential(
                nn.SiLU(),
                nn.Linear(global_cond_dim, dim * 6, bias=False)
            )

            nn.init.zeros_(self.to_scale_shift_gate[1].weight)
            #nn.init.zeros_(self.to_scale_shift_gate_self[1].bias)

    def forward(
        self,
        x,
        context = None,
        global_cond=None,
        mask = None,
        context_mask = None,
        rotary_pos_emb = None
    ):
        if self.global_cond_dim is not None and self.global_cond_dim > 0 and global_cond is not None:

            scale_self, shift_self, gate_self, scale_ff, shift_ff, gate_ff = self.to_scale_shift_gate(global_cond).unsqueeze(1).chunk(6, dim = -1)

            # self-attention with adaLN
            residual = x
            x = self.pre_norm(x)
            x = x * (1 + scale_self) + shift_self
            x = self.self_attn(x, mask = mask, rotary_pos_emb = rotary_pos_emb)
            x = x * torch.sigmoid(1 - gate_self)
            x = x + residual

            if context is not None:
                x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask)

            if self.conformer is not None:
                x = x + self.conformer(x)

            # feedforward with adaLN
            residual = x
            x = self.ff_norm(x)
            x = x * (1 + scale_ff) + shift_ff
            x = self.ff(x)
            x = x * torch.sigmoid(1 - gate_ff)
            x = x + residual

        else:
            x = x + self.self_attn(self.pre_norm(x), mask = mask, rotary_pos_emb = rotary_pos_emb)

            if context is not None:
                x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask)

            if self.conformer is not None:
                x = x + self.conformer(x)

            x = x + self.ff(self.ff_norm(x))

        return x
class ContinuousTransformer(nn.Module):
    def __init__(
        self,
        dim,
        depth,
        *,
        dim_in = None,
        dim_out = None,
        dim_heads = 64,
        cross_attend=False,
        cond_token_dim=None,
        global_cond_dim=None,
        causal=False,
        rotary_pos_emb=True,
        zero_init_branch_outputs=True,
        conformer=False,
        use_sinusoidal_emb=False,
        use_abs_pos_emb=False,
        abs_pos_emb_max_length=10000,
        dtype=None,
        device=None,
        operations=None,
        **kwargs
    ):

        super().__init__()

        self.dim = dim
        self.depth = depth
        self.causal = causal
        self.layers = nn.ModuleList([])

        self.project_in = operations.Linear(dim_in, dim, bias=False, dtype=dtype, device=device) if dim_in is not None else nn.Identity()
        self.project_out = operations.Linear(dim, dim_out, bias=False, dtype=dtype, device=device) if dim_out is not None else nn.Identity()

        if rotary_pos_emb:
            self.rotary_pos_emb = RotaryEmbedding(max(dim_heads // 2, 32))
        else:
            self.rotary_pos_emb = None

        self.use_sinusoidal_emb = use_sinusoidal_emb
        if use_sinusoidal_emb:
            self.pos_emb = ScaledSinusoidalEmbedding(dim)

        self.use_abs_pos_emb = use_abs_pos_emb
        if use_abs_pos_emb:
            self.pos_emb = AbsolutePositionalEmbedding(dim, abs_pos_emb_max_length)

        for i in range(depth):
            self.layers.append(
                TransformerBlock(
                    dim,
                    dim_heads = dim_heads,
                    cross_attend = cross_attend,
                    dim_context = cond_token_dim,
                    global_cond_dim = global_cond_dim,
                    causal = causal,
                    zero_init_branch_outputs = zero_init_branch_outputs,
                    conformer=conformer,
                    layer_ix=i,
                    dtype=dtype,
                    device=device,
                    operations=operations,
                    **kwargs
                )
            )

    def forward(
        self,
        x,
        mask = None,
        prepend_embeds = None,
        prepend_mask = None,
        global_cond = None,
        return_info = False,
        **kwargs
    ):
        batch, seq, device = *x.shape[:2], x.device

        info = {
            "hidden_states": [],
        }

        x = self.project_in(x)

        if prepend_embeds is not None:
            prepend_length, prepend_dim = prepend_embeds.shape[1:]

            assert prepend_dim == x.shape[-1], 'prepend dimension must match sequence dimension'

            x = torch.cat((prepend_embeds, x), dim = -2)

            if prepend_mask is not None or mask is not None:
                mask = mask if mask is not None else torch.ones((batch, seq), device = device, dtype = torch.bool)
                prepend_mask = prepend_mask if prepend_mask is not None else torch.ones((batch, prepend_length), device = device, dtype = torch.bool)

                mask = torch.cat((prepend_mask, mask), dim = -1)

        # Attention layers

        if self.rotary_pos_emb is not None:
            rotary_pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1], dtype=x.dtype, device=x.device)
        else:
            rotary_pos_emb = None

        if self.use_sinusoidal_emb or self.use_abs_pos_emb:
            x = x + self.pos_emb(x)

        # Iterate over the transformer layers
        for layer in self.layers:
            x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs)
            # x = checkpoint(layer, x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs)

            if return_info:
                info["hidden_states"].append(x)

        x = self.project_out(x)

        if return_info:
            return x, info

        return x
class AudioDiffusionTransformer(nn.Module):
    def __init__(self,
        io_channels=64,
        patch_size=1,
        embed_dim=1536,
        cond_token_dim=768,
        project_cond_tokens=False,
        global_cond_dim=1536,
        project_global_cond=True,
        input_concat_dim=0,
        prepend_cond_dim=0,
        depth=24,
        num_heads=24,
        transformer_type: tp.Literal["continuous_transformer"] = "continuous_transformer",
        global_cond_type: tp.Literal["prepend", "adaLN"] = "prepend",
        audio_model="",
        dtype=None,
        device=None,
        operations=None,
        **kwargs):

        super().__init__()

        self.dtype = dtype
        self.cond_token_dim = cond_token_dim

        # Timestep embeddings
        timestep_features_dim = 256

        self.timestep_features = FourierFeatures(1, timestep_features_dim, dtype=dtype, device=device)

        self.to_timestep_embed = nn.Sequential(
            operations.Linear(timestep_features_dim, embed_dim, bias=True, dtype=dtype, device=device),
            nn.SiLU(),
            operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device),
        )

        if cond_token_dim > 0:
            # Conditioning tokens

            cond_embed_dim = cond_token_dim if not project_cond_tokens else embed_dim
            self.to_cond_embed = nn.Sequential(
                operations.Linear(cond_token_dim, cond_embed_dim, bias=False, dtype=dtype, device=device),
                nn.SiLU(),
                operations.Linear(cond_embed_dim, cond_embed_dim, bias=False, dtype=dtype, device=device)
            )
        else:
            cond_embed_dim = 0

        if global_cond_dim > 0:
            # Global conditioning
            global_embed_dim = global_cond_dim if not project_global_cond else embed_dim
            self.to_global_embed = nn.Sequential(
                operations.Linear(global_cond_dim, global_embed_dim, bias=False, dtype=dtype, device=device),
                nn.SiLU(),
                operations.Linear(global_embed_dim, global_embed_dim, bias=False, dtype=dtype, device=device)
            )

        if prepend_cond_dim > 0:
            # Prepend conditioning
            self.to_prepend_embed = nn.Sequential(
                operations.Linear(prepend_cond_dim, embed_dim, bias=False, dtype=dtype, device=device),
                nn.SiLU(),
                operations.Linear(embed_dim, embed_dim, bias=False, dtype=dtype, device=device)
            )

        self.input_concat_dim = input_concat_dim

        dim_in = io_channels + self.input_concat_dim

        self.patch_size = patch_size

        # Transformer

        self.transformer_type = transformer_type

        self.global_cond_type = global_cond_type

        if self.transformer_type == "continuous_transformer":

            global_dim = None

            if self.global_cond_type == "adaLN":
                # The global conditioning is projected to the embed_dim already at this point
                global_dim = embed_dim

            self.transformer = ContinuousTransformer(
                dim=embed_dim,
                depth=depth,
                dim_heads=embed_dim // num_heads,
                dim_in=dim_in * patch_size,
                dim_out=io_channels * patch_size,
                cross_attend = cond_token_dim > 0,
                cond_token_dim = cond_embed_dim,
                global_cond_dim=global_dim,
                dtype=dtype,
                device=device,
                operations=operations,
                **kwargs
            )
        else:
            raise ValueError(f"Unknown transformer type: {self.transformer_type}")

        self.preprocess_conv = operations.Conv1d(dim_in, dim_in, 1, bias=False, dtype=dtype, device=device)
        self.postprocess_conv = operations.Conv1d(io_channels, io_channels, 1, bias=False, dtype=dtype, device=device)

    def _forward(
        self,
        x,
        t,
        mask=None,
        cross_attn_cond=None,
        cross_attn_cond_mask=None,
        input_concat_cond=None,
        global_embed=None,
        prepend_cond=None,
        prepend_cond_mask=None,
        return_info=False,
        **kwargs):

        if cross_attn_cond is not None:
            cross_attn_cond = self.to_cond_embed(cross_attn_cond)

        if global_embed is not None:
            # Project the global conditioning to the embedding dimension
            global_embed = self.to_global_embed(global_embed)

        prepend_inputs = None
        prepend_mask = None
        prepend_length = 0
        if prepend_cond is not None:
            # Project the prepend conditioning to the embedding dimension
            prepend_cond = self.to_prepend_embed(prepend_cond)

            prepend_inputs = prepend_cond
            if prepend_cond_mask is not None:
                prepend_mask = prepend_cond_mask

        if input_concat_cond is not None:

            # Interpolate input_concat_cond to the same length as x
            if input_concat_cond.shape[2] != x.shape[2]:
                input_concat_cond = F.interpolate(input_concat_cond, (x.shape[2], ), mode='nearest')

            x = torch.cat([x, input_concat_cond], dim=1)

        # Get the batch of timestep embeddings
        timestep_embed = self.to_timestep_embed(self.timestep_features(t[:, None]).to(x.dtype)) # (b, embed_dim)

        # Timestep embedding is considered a global embedding. Add to the global conditioning if it exists
        if global_embed is not None:
            global_embed = global_embed + timestep_embed
        else:
            global_embed = timestep_embed

        # Add the global_embed to the prepend inputs if there is no global conditioning support in the transformer
        if self.global_cond_type == "prepend":
            if prepend_inputs is None:
                # Prepend inputs are just the global embed, and the mask is all ones
                prepend_inputs = global_embed.unsqueeze(1)
                prepend_mask = torch.ones((x.shape[0], 1), device=x.device, dtype=torch.bool)
            else:
                # Prepend inputs are the prepend conditioning + the global embed
                prepend_inputs = torch.cat([prepend_inputs, global_embed.unsqueeze(1)], dim=1)
                prepend_mask = torch.cat([prepend_mask, torch.ones((x.shape[0], 1), device=x.device, dtype=torch.bool)], dim=1)

            prepend_length = prepend_inputs.shape[1]

        x = self.preprocess_conv(x) + x

        x = rearrange(x, "b c t -> b t c")

        extra_args = {}

        if self.global_cond_type == "adaLN":
            extra_args["global_cond"] = global_embed

        if self.patch_size > 1:
            x = rearrange(x, "b (t p) c -> b t (c p)", p=self.patch_size)

        if self.transformer_type == "x-transformers":
            output = self.transformer(x, prepend_embeds=prepend_inputs, context=cross_attn_cond, context_mask=cross_attn_cond_mask, mask=mask, prepend_mask=prepend_mask, **extra_args, **kwargs)
        elif self.transformer_type == "continuous_transformer":
            output = self.transformer(x, prepend_embeds=prepend_inputs, context=cross_attn_cond, context_mask=cross_attn_cond_mask, mask=mask, prepend_mask=prepend_mask, return_info=return_info, **extra_args, **kwargs)

            if return_info:
                output, info = output
        elif self.transformer_type == "mm_transformer":
            output = self.transformer(x, context=cross_attn_cond, mask=mask, context_mask=cross_attn_cond_mask, **extra_args, **kwargs)

        output = rearrange(output, "b t c -> b c t")[:,:,prepend_length:]

        if self.patch_size > 1:
            output = rearrange(output, "b (c p) t -> b c (t p)", p=self.patch_size)

        output = self.postprocess_conv(output) + output

        if return_info:
            return output, info

        return output

    def forward(
        self,
        x,
        timestep,
        context=None,
        context_mask=None,
        input_concat_cond=None,
        global_embed=None,
        negative_global_embed=None,
        prepend_cond=None,
        prepend_cond_mask=None,
        mask=None,
        return_info=False,
        control=None,
        transformer_options={},
        **kwargs):
        return self._forward(
            x,
            timestep,
            cross_attn_cond=context,
            cross_attn_cond_mask=context_mask,
            input_concat_cond=input_concat_cond,
            global_embed=global_embed,
            prepend_cond=prepend_cond,
            prepend_cond_mask=prepend_cond_mask,
            mask=mask,
            return_info=return_info,
            **kwargs
        )
comfy/ldm/audio/embedders.py (new file, 108 lines)
@@ -0,0 +1,108 @@
|
# code adapted from: https://github.com/Stability-AI/stable-audio-tools
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch import Tensor, einsum
|
||||||
|
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union
|
||||||
|
from einops import rearrange
|
||||||
|
import math
|
||||||
|
import comfy.ops
|
||||||
|
|
||||||
|
class LearnedPositionalEmbedding(nn.Module):
|
||||||
|
"""Used for continuous time"""
|
||||||
|
|
||||||
|
def __init__(self, dim: int):
|
||||||
|
super().__init__()
|
||||||
|
assert (dim % 2) == 0
|
||||||
|
half_dim = dim // 2
|
||||||
|
self.weights = nn.Parameter(torch.empty(half_dim))
|
||||||
|
|
||||||
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
|
x = rearrange(x, "b -> b 1")
|
||||||
|
freqs = x * rearrange(self.weights, "d -> 1 d") * 2 * math.pi
|
||||||
|
fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1)
|
||||||
|
fouriered = torch.cat((x, fouriered), dim=-1)
|
||||||
|
return fouriered
|
||||||
|
|
||||||
|
def TimePositionalEmbedding(dim: int, out_features: int) -> nn.Module:
|
||||||
|
return nn.Sequential(
|
||||||
|
LearnedPositionalEmbedding(dim),
|
||||||
|
comfy.ops.manual_cast.Linear(in_features=dim + 1, out_features=out_features),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class NumberEmbedder(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
features: int,
|
||||||
|
dim: int = 256,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.features = features
|
||||||
|
self.embedding = TimePositionalEmbedding(dim=dim, out_features=features)
|
||||||
|
|
||||||
|
def forward(self, x: Union[List[float], Tensor]) -> Tensor:
|
||||||
|
if not torch.is_tensor(x):
|
||||||
|
device = next(self.embedding.parameters()).device
|
||||||
|
x = torch.tensor(x, device=device)
|
||||||
|
assert isinstance(x, Tensor)
|
||||||
|
shape = x.shape
|
||||||
|
x = rearrange(x, "... -> (...)")
|
||||||
|
embedding = self.embedding(x)
|
||||||
|
x = embedding.view(*shape, self.features)
|
||||||
|
return x # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
class Conditioner(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim: int,
|
||||||
|
output_dim: int,
|
||||||
|
project_out: bool = False
|
||||||
|
):
|
||||||
|
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.dim = dim
|
||||||
|
self.output_dim = output_dim
|
||||||
|
self.proj_out = nn.Linear(dim, output_dim) if (dim != output_dim or project_out) else nn.Identity()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
class NumberConditioner(Conditioner):
|
||||||
|
'''
|
||||||
|
Conditioner that takes a list of floats, normalizes them for a given range, and returns a list of embeddings
|
||||||
|
'''
|
||||||
|
def __init__(self,
|
||||||
|
output_dim: int,
|
||||||
|
min_val: float=0,
|
||||||
|
max_val: float=1
|
||||||
|
):
|
||||||
|
super().__init__(output_dim, output_dim)
|
||||||
|
|
||||||
|
self.min_val = min_val
|
||||||
|
self.max_val = max_val
|
||||||
|
|
||||||
|
self.embedder = NumberEmbedder(features=output_dim)
|
||||||
|
|
||||||
|
def forward(self, floats, device=None):
|
||||||
|
# Cast the inputs to floats
|
||||||
|
floats = [float(x) for x in floats]
|
||||||
|
|
||||||
|
if device is None:
|
||||||
|
device = next(self.embedder.parameters()).device
|
||||||
|
|
||||||
|
floats = torch.tensor(floats).to(device)
|
||||||
|
|
||||||
|
floats = floats.clamp(self.min_val, self.max_val)
|
||||||
|
|
||||||
|
normalized_floats = (floats - self.min_val) / (self.max_val - self.min_val)
|
||||||
|
|
||||||
|
# Cast floats to same type as embedder
|
||||||
|
embedder_dtype = next(self.embedder.parameters()).dtype
|
||||||
|
normalized_floats = normalized_floats.to(embedder_dtype)
|
||||||
|
|
||||||
|
float_embeds = self.embedder(normalized_floats).unsqueeze(1)
|
||||||
|
|
||||||
|
return [float_embeds, torch.ones(float_embeds.shape[0], 1).to(device)]
|
||||||
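A minimal usage sketch of the NumberConditioner added above (not part of the commit; the 768 width and the 47.0 seconds value are illustrative, and the randomly initialized weights make this a shape check only):

import torch
from comfy.ldm.audio.embedders import NumberConditioner

# hypothetical: embed one "seconds_total" value the way the Stable Audio path does
cond = NumberConditioner(output_dim=768, min_val=0, max_val=512)
embed, mask = cond([47.0])          # values are clamped to [min_val, max_val] then normalized
print(embed.shape, mask.shape)      # torch.Size([1, 1, 768]) torch.Size([1, 1])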
@ -85,22 +85,32 @@ class FeedForward(nn.Module):
 def Normalize(in_channels, dtype=None, device=None):
     return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)

-def attention_basic(q, k, v, heads, mask=None, attn_precision=None):
+def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
     attn_precision = get_attn_precision(attn_precision)

-    b, _, dim_head = q.shape
-    dim_head //= heads
+    if skip_reshape:
+        b, _, _, dim_head = q.shape
+    else:
+        b, _, dim_head = q.shape
+        dim_head //= heads

     scale = dim_head ** -0.5

     h = heads
-    q, k, v = map(
-        lambda t: t.unsqueeze(3)
-        .reshape(b, -1, heads, dim_head)
-        .permute(0, 2, 1, 3)
-        .reshape(b * heads, -1, dim_head)
-        .contiguous(),
-        (q, k, v),
-    )
+    if skip_reshape:
+        q, k, v = map(
+            lambda t: t.reshape(b * heads, -1, dim_head),
+            (q, k, v),
+        )
+    else:
+        q, k, v = map(
+            lambda t: t.unsqueeze(3)
+            .reshape(b, -1, heads, dim_head)
+            .permute(0, 2, 1, 3)
+            .reshape(b * heads, -1, dim_head)
+            .contiguous(),
+            (q, k, v),
+        )

     # force cast to fp32 to avoid overflowing
     if attn_precision == torch.float32:
@ -137,17 +147,26 @@ def attention_basic(q, k, v, heads, mask=None, attn_precision=None):
     return out


-def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None):
+def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, skip_reshape=False):
     attn_precision = get_attn_precision(attn_precision)

-    b, _, dim_head = query.shape
-    dim_head //= heads
+    if skip_reshape:
+        b, _, _, dim_head = query.shape
+    else:
+        b, _, dim_head = query.shape
+        dim_head //= heads

     scale = dim_head ** -0.5
-    query = query.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 1, 3).reshape(b * heads, -1, dim_head)
-    value = value.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 1, 3).reshape(b * heads, -1, dim_head)
-    key = key.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 3, 1).reshape(b * heads, dim_head, -1)
+
+    if skip_reshape:
+        query = query.reshape(b * heads, -1, dim_head)
+        value = value.reshape(b * heads, -1, dim_head)
+        key = key.reshape(b * heads, -1, dim_head).movedim(1, 2)
+    else:
+        query = query.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 1, 3).reshape(b * heads, -1, dim_head)
+        value = value.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 1, 3).reshape(b * heads, -1, dim_head)
+        key = key.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 3, 1).reshape(b * heads, dim_head, -1)

     dtype = query.dtype
     upcast_attention = attn_precision == torch.float32 and query.dtype != torch.float32
@ -199,22 +218,32 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None)
     hidden_states = hidden_states.unflatten(0, (-1, heads)).transpose(1,2).flatten(start_dim=2)
     return hidden_states

-def attention_split(q, k, v, heads, mask=None, attn_precision=None):
+def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
     attn_precision = get_attn_precision(attn_precision)

-    b, _, dim_head = q.shape
-    dim_head //= heads
+    if skip_reshape:
+        b, _, _, dim_head = q.shape
+    else:
+        b, _, dim_head = q.shape
+        dim_head //= heads

     scale = dim_head ** -0.5

     h = heads
-    q, k, v = map(
-        lambda t: t.unsqueeze(3)
-        .reshape(b, -1, heads, dim_head)
-        .permute(0, 2, 1, 3)
-        .reshape(b * heads, -1, dim_head)
-        .contiguous(),
-        (q, k, v),
-    )
+    if skip_reshape:
+        q, k, v = map(
+            lambda t: t.reshape(b * heads, -1, dim_head),
+            (q, k, v),
+        )
+    else:
+        q, k, v = map(
+            lambda t: t.unsqueeze(3)
+            .reshape(b, -1, heads, dim_head)
+            .permute(0, 2, 1, 3)
+            .reshape(b * heads, -1, dim_head)
+            .contiguous(),
+            (q, k, v),
+        )

     r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)

@ -308,9 +337,12 @@ if model_management.xformers_enabled():
 # XFormers bug confirmed on all versions from 0.0.21 to 0.0.26 (q with bs bigger than 65535 gives CUDA error)
 BROKEN_XFORMERS = x_vers.startswith("0.0.2") and not x_vers.startswith("0.0.20")

-def attention_xformers(q, k, v, heads, mask=None, attn_precision=None):
-    b, _, dim_head = q.shape
-    dim_head //= heads
+def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
+    if skip_reshape:
+        b, _, _, dim_head = q.shape
+    else:
+        b, _, dim_head = q.shape
+        dim_head //= heads

     disabled_xformers = False

@ -325,10 +357,16 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None):
     if disabled_xformers:
         return attention_pytorch(q, k, v, heads, mask)

-    q, k, v = map(
-        lambda t: t.reshape(b, -1, heads, dim_head),
-        (q, k, v),
-    )
+    if skip_reshape:
+        q, k, v = map(
+            lambda t: t.reshape(b * heads, -1, dim_head),
+            (q, k, v),
+        )
+    else:
+        q, k, v = map(
+            lambda t: t.reshape(b, -1, heads, dim_head),
+            (q, k, v),
+        )

     if mask is not None:
         pad = 8 - q.shape[1] % 8
@ -338,18 +376,30 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None):

     out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)

-    out = (
-        out.reshape(b, -1, heads * dim_head)
-    )
+    if skip_reshape:
+        out = (
+            out.unsqueeze(0)
+            .reshape(b, heads, -1, dim_head)
+            .permute(0, 2, 1, 3)
+            .reshape(b, -1, heads * dim_head)
+        )
+    else:
+        out = (
+            out.reshape(b, -1, heads * dim_head)
+        )

     return out

-def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None):
-    b, _, dim_head = q.shape
-    dim_head //= heads
-    q, k, v = map(
-        lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2),
-        (q, k, v),
-    )
+def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
+    if skip_reshape:
+        b, _, _, dim_head = q.shape
+    else:
+        b, _, dim_head = q.shape
+        dim_head //= heads
+        q, k, v = map(
+            lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2),
+            (q, k, v),
+        )

     out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
     out = (
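A sketch of what the new skip_reshape flag means for callers (my reading of the change; module path and equivalence assumed, not stated in the commit): without it the functions take (batch, tokens, heads * dim_head) and split heads themselves, with it they expect heads already split out as (batch, heads, tokens, dim_head).

import torch
from comfy.ldm.modules.attention import attention_pytorch  # path assumed

b, t, heads, dim_head = 2, 16, 8, 64
fused = torch.randn(b, t, heads * dim_head)                 # q/k/v with heads fused
split = fused.view(b, t, heads, dim_head).transpose(1, 2)   # (b, heads, t, dim_head)

out_a = attention_pytorch(fused, fused, fused, heads)
out_b = attention_pytorch(split, split, split, heads, skip_reshape=True)
assert out_a.shape == out_b.shape == (b, t, heads * dim_head)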
@ -244,9 +244,9 @@ class TimestepEmbedder(nn.Module):
     half = dim // 2
     freqs = torch.exp(
         -math.log(max_period)
-        * torch.arange(start=0, end=half, dtype=torch.float32)
+        * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device)
         / half
-    ).to(device=t.device)
+    )
     args = t[:, None].float() * freqs[None]
     embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
     if dim % 2:
@ -1,5 +1,6 @@
 import logging
 from . import utils
+from . import model_base

 LORA_CLIP_MAP = {
     "mlp.fc1": "mlp_fc1",
@ -29,6 +30,8 @@ def load_lora(lora, to_load):

         regular_lora = "{}.lora_up.weight".format(x)
         diffusers_lora = "{}_lora.up.weight".format(x)
+        diffusers2_lora = "{}.lora_B.weight".format(x)
+        diffusers3_lora = "{}.lora.up.weight".format(x)
         transformers_lora = "{}.lora_linear_layer.up.weight".format(x)
         A_name = B_name = None

@ -40,6 +43,12 @@ def load_lora(lora, to_load):
         elif diffusers_lora in lora.keys():
             A_name = diffusers_lora
             B_name = "{}_lora.down.weight".format(x)
+        elif diffusers2_lora in lora.keys():
+            A_name = diffusers2_lora
+            B_name = "{}.lora_A.weight".format(x)
+        elif diffusers3_lora in lora.keys():
+            A_name = diffusers3_lora
+            B_name = "{}.lora.down.weight".format(x)
         elif transformers_lora in lora.keys():
             A_name = transformers_lora
             B_name ="{}.lora_linear_layer.down.weight".format(x)
@ -163,6 +172,7 @@ def load_lora(lora, to_load):
     for x in lora.keys():
         if x not in loaded_keys:
             logging.warning("lora key not loaded: {}".format(x))
+
     return patch_dict

 def model_lora_keys_clip(model, key_map={}):
@ -216,7 +226,8 @@ def model_lora_keys_clip(model, key_map={}):
     return key_map

 def model_lora_keys_unet(model, key_map={}):
-    sdk = model.state_dict().keys()
+    sd = model.state_dict()
+    sdk = sd.keys()

     for k in sdk:
         if k.startswith("diffusion_model.") and k.endswith(".weight"):
@ -237,4 +248,17 @@ def model_lora_keys_unet(model, key_map={}):
             if diffusers_lora_key.endswith(".to_out.0"):
                 diffusers_lora_key = diffusers_lora_key[:-2]
             key_map[diffusers_lora_key] = unet_key
+
+    if isinstance(model, model_base.SD3): #Diffusers lora SD3
+        for i in range(model.model_config.unet_config.get("depth", 0)):
+            k = "transformer.transformer_blocks.{}.attn.".format(i)
+            qkv = "diffusion_model.joint_blocks.{}.x_block.attn.qkv.weight".format(i)
+            proj = "diffusion_model.joint_blocks.{}.x_block.attn.proj.weight".format(i)
+            if qkv in sd:
+                offset = sd[qkv].shape[0] // 3
+                key_map["{}to_q".format(k)] = (qkv, (0, 0, offset))
+                key_map["{}to_k".format(k)] = (qkv, (0, offset, offset))
+                key_map["{}to_v".format(k)] = (qkv, (0, offset * 2, offset))
+                key_map["{}to_out.0".format(k)] = proj
+
     return key_map
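A sketch of the offset-style key_map entries this enables for SD3 diffusers LoRAs (the hidden size and block index are illustrative): the value is (weight_key, (dim, start, length)) instead of a plain key, so three LoRA patches can target one fused qkv weight.

# hypothetical joint block whose fused qkv weight has 3 * hidden rows
qkv_key = "diffusion_model.joint_blocks.0.x_block.attn.qkv.weight"
hidden = 1536                       # illustrative; at load time this is sd[qkv_key].shape[0] // 3

key_map = {
    "transformer.transformer_blocks.0.attn.to_q": (qkv_key, (0, 0, hidden)),
    "transformer.transformer_blocks.0.attn.to_k": (qkv_key, (0, hidden, hidden)),
    "transformer.transformer_blocks.0.attn.to_v": (qkv_key, (0, hidden * 2, hidden)),
}
# the patcher later applies each patch to weight.narrow(dim, start, length)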
@ -1,5 +1,8 @@
 import torch
 import logging
+
+from .ldm.audio.dit import AudioDiffusionTransformer
+from .ldm.audio.embedders import NumberConditioner
 from .ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
 from .ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
 from .ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
@ -12,6 +15,7 @@ from .ldm.cascade.stage_b import StageB
 from enum import Enum
 from . import utils
 from . import latent_formats
+import math

 class ModelType(Enum):
     EPS = 1
@ -20,9 +24,10 @@ class ModelType(Enum):
     STABLE_CASCADE = 4
     EDM = 5
     FLOW = 6
+    V_PREDICTION_CONTINUOUS = 7


-from .model_sampling import EPS, V_PREDICTION, EDM, ModelSamplingDiscrete, ModelSamplingContinuousEDM, StableCascadeSampling, CONST, ModelSamplingDiscreteFlow
+from .model_sampling import EPS, V_PREDICTION, EDM, ModelSamplingDiscrete, ModelSamplingContinuousEDM, StableCascadeSampling, CONST, ModelSamplingDiscreteFlow, ModelSamplingContinuousV


 def model_sampling(model_config, model_type):
@ -45,6 +50,9 @@ def model_sampling(model_config, model_type):
     elif model_type == ModelType.FLOW:
         c = CONST
         s = ModelSamplingDiscreteFlow
+    elif model_type == ModelType.V_PREDICTION_CONTINUOUS:
+        c = V_PREDICTION
+        s = ModelSamplingContinuousV

     class ModelSampling(s, c):
         pass
@ -67,6 +75,10 @@ class BaseModel(torch.nn.Module):
         else:
             operations = ops.disable_weight_init
         self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
+        if model_management.force_channels_last():
+            # todo: ???
+            self.diffusion_model.to(memory_format=torch.channels_last)
+            logging.debug("using channels last mode for diffusion model")
         self.model_type = model_type
         self.model_sampling = model_sampling(model_config, model_type)

@ -234,11 +246,11 @@ class BaseModel(torch.nn.Module):
         if self.manual_cast_dtype is not None:
             dtype = self.manual_cast_dtype
         #TODO: this needs to be tweaked
-        area = input_shape[0] * input_shape[2] * input_shape[3]
+        area = input_shape[0] * math.prod(input_shape[2:])
         return (area * model_management.dtype_size(dtype) / 50) * (1024 * 1024)
     else:
         #TODO: this formula might be too aggressive since I tweaked the sub-quad and split algorithms to use less memory.
-        area = input_shape[0] * input_shape[2] * input_shape[3]
+        area = input_shape[0] * math.prod(input_shape[2:])
         return (((area * 0.6) / 0.9) + 1024) * (1024 * 1024)


@ -591,3 +603,33 @@ class SD3(BaseModel):
     else:
         area = input_shape[0] * input_shape[2] * input_shape[3]
         return (area * 0.3) * (1024 * 1024)
+
+
+class StableAudio1(BaseModel):
+    def __init__(self, model_config, seconds_start_embedder_weights, seconds_total_embedder_weights, model_type=ModelType.V_PREDICTION_CONTINUOUS, device=None):
+        super().__init__(model_config, model_type, device=device, unet_model=AudioDiffusionTransformer)
+        self.seconds_start_embedder = NumberConditioner(768, min_val=0, max_val=512)
+        self.seconds_total_embedder = NumberConditioner(768, min_val=0, max_val=512)
+        self.seconds_start_embedder.load_state_dict(seconds_start_embedder_weights)
+        self.seconds_total_embedder.load_state_dict(seconds_total_embedder_weights)
+
+    def extra_conds(self, **kwargs):
+        out = {}
+
+        noise = kwargs.get("noise", None)
+        device = kwargs["device"]
+
+        seconds_start = kwargs.get("seconds_start", 0)
+        seconds_total = kwargs.get("seconds_total", int(noise.shape[-1] / 21.53))
+
+        seconds_start_embed = self.seconds_start_embedder([seconds_start])[0].to(device)
+        seconds_total_embed = self.seconds_total_embedder([seconds_total])[0].to(device)
+
+        global_embed = torch.cat([seconds_start_embed, seconds_total_embed], dim=-1).reshape((1, -1))
+        out['global_embed'] = conds.CONDRegular(global_embed)
+
+        cross_attn = kwargs.get("cross_attn", None)
+        if cross_attn is not None:
+            cross_attn = torch.cat([cross_attn.to(device), seconds_start_embed.repeat((cross_attn.shape[0], 1, 1)), seconds_total_embed.repeat((cross_attn.shape[0], 1, 1))], dim=1)
+            out['c_crossattn'] = conds.CONDRegular(cross_attn)
+        return out
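The math.prod change makes the memory estimate shape-agnostic; a standalone sketch of the idea (numbers illustrative, not the full formula):

import math

def estimated_area(input_shape):
    # batch times the product of all dims after the channel dim; works for
    # 4D image latents (b, c, h, w) and 3D audio latents (b, c, t) alike
    return input_shape[0] * math.prod(input_shape[2:])

print(estimated_area((2, 4, 64, 64)))   # 8192
print(estimated_area((2, 64, 1024)))    # 2048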
@ -96,6 +96,11 @@ def detect_unet_config(state_dict, key_prefix):
         unet_config['block_repeat'] = [[1, 1, 1, 1], [2, 2, 2, 2]]
         return unet_config

+    if '{}transformer.rotary_pos_emb.inv_freq'.format(key_prefix) in state_dict_keys: #stable audio dit
+        unet_config = {}
+        unet_config["audio_model"] = "dit1.0"
+        return unet_config
+
     unet_config = {
         "use_checkpoint": False,
         "image_size": 32,
@ -236,6 +241,13 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
     else:
         return model_config

+def unet_prefix_from_state_dict(state_dict):
+    if "model.model.postprocess_conv.weight" in state_dict: #audio models
+        unet_key_prefix = "model.model."
+    else:
+        unet_key_prefix = "model.diffusion_model."
+    return unet_key_prefix
+
 def convert_config(unet_config):
     new_config = unet_config.copy()
     num_res_blocks = new_config.get("num_res_blocks", None)
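A usage sketch of the new prefix helper (module path assumed, state-dict keys shortened to the one key the function checks):

from comfy import model_detection

audio_sd = {"model.model.postprocess_conv.weight": None}
image_sd = {"model.diffusion_model.input_blocks.0.0.weight": None}

assert model_detection.unet_prefix_from_state_dict(audio_sd) == "model.model."
assert model_detection.unet_prefix_from_state_dict(image_sd) == "model.diffusion_model."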
@ -198,7 +198,7 @@ if args.use_pytorch_cross_attention:
     ENABLE_PYTORCH_ATTENTION = True
     XFORMERS_IS_AVAILABLE = False

-VAE_DTYPE = torch.float32
+VAE_DTYPES = [torch.float32]

 try:
     if is_nvidia() or is_amd():
@ -207,7 +207,7 @@ try:
         if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
             ENABLE_PYTORCH_ATTENTION = True
         if torch.cuda.is_bf16_supported() and torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 8:
-            VAE_DTYPE = torch.bfloat16
+            VAE_DTYPES = [torch.bfloat16] + VAE_DTYPES
     if is_intel_xpu():
         if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
             ENABLE_PYTORCH_ATTENTION = True
@ -215,17 +215,10 @@ except:
     pass

 if is_intel_xpu():
-    VAE_DTYPE = torch.bfloat16
+    VAE_DTYPES = [torch.bfloat16] + VAE_DTYPES

 if args.cpu_vae:
-    VAE_DTYPE = torch.float32
-
-if args.fp16_vae:
-    VAE_DTYPE = torch.float16
-elif args.bf16_vae:
-    VAE_DTYPE = torch.bfloat16
-elif args.fp32_vae:
-    VAE_DTYPE = torch.float32
+    VAE_DTYPES = [torch.float32]

 if ENABLE_PYTORCH_ATTENTION:
     torch.backends.cuda.enable_math_sdp(True)
@ -294,7 +287,6 @@ try:
 except:
     logging.warning("Could not pick default device.")

-logging.info("VAE dtype: {}".format(VAE_DTYPE))

 current_loaded_models: List["LoadedModel"] = []

@ -677,9 +669,22 @@ def vae_offload_device():
     return torch.device("cpu")


-def vae_dtype():
-    global VAE_DTYPE
-    return VAE_DTYPE
+def vae_dtype(device=None, allowed_dtypes=[]):
+    global VAE_DTYPES
+    if args.fp16_vae:
+        return torch.float16
+    elif args.bf16_vae:
+        return torch.bfloat16
+    elif args.fp32_vae:
+        return torch.float32
+
+    for d in allowed_dtypes:
+        if d == torch.float16 and should_use_fp16(device, prioritize_performance=False):
+            return d
+        if d in VAE_DTYPES:
+            return d
+
+    return VAE_DTYPES[0]


 def get_autocast_device(dev):
@ -719,6 +724,8 @@ def supports_cast(device, dtype): #TODO
 def device_supports_non_blocking(device):
     if is_device_mps(device):
         return False  # pytorch bug? mps doesn't support non blocking
+    if is_intel_xpu():
+        return False
     if args.deterministic: #TODO: figure out why deterministic breaks non blocking from gpu to cpu (previews)
         return False
     if directml_device:
@ -731,6 +738,12 @@ def device_should_use_non_blocking(device):
     return False
     # return True #TODO: figure out why this causes memory issues on Nvidia and possibly others

+def force_channels_last():
+    if args.force_channels_last:
+        return True
+
+    #TODO
+    return False
+
 def cast_to_device(tensor, device, dtype, copy=False):
@ -987,7 +1000,7 @@ def unload_all_models():


 def resolve_lowvram_weight(weight, model, key):  # TODO: remove
-    print("WARNING: The comfy.model_management.resolve_lowvram_weight function will be removed soon, please stop using it.")
+    warnings.warn("The comfy.model_management.resolve_lowvram_weight function will be removed soon, please stop using it.", category=DeprecationWarning)
     return weight
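Sketched behaviour of the reworked selection, assuming none of the --fp16-vae/--bf16-vae/--fp32-vae overrides is set (vae_device() and the dtype list here are assumptions for illustration): the first entry of the VAE's preferred dtypes that the device supports wins, otherwise the global default applies.

import torch
from comfy import model_management

device = model_management.vae_device()
# the audio VAE advertises fp16 first; image VAEs keep preferring bf16/fp32
dtype = model_management.vae_dtype(device, [torch.float16, torch.bfloat16, torch.float32])
print(dtype)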
@ -221,11 +221,18 @@ class ModelPatcher(ModelManageable):
         p = set()
         model_sd = self.model.state_dict()
         for k in patches:
-            if k in model_sd:
+            offset = None
+            if isinstance(k, str):
+                key = k
+            else:
+                offset = k[1]
+                key = k[0]
+
+            if key in model_sd:
                 p.add(k)
-                current_patches = self.patches.get(k, [])
-                current_patches.append((strength_patch, patches[k], strength_model))
-                self.patches[k] = current_patches
+                current_patches = self.patches.get(key, [])
+                current_patches.append((strength_patch, patches[k], strength_model, offset))
+                self.patches[key] = current_patches

         self.patches_uuid = uuid.uuid4()
         return list(p)
@ -342,7 +349,7 @@ class ModelPatcher(ModelManageable):
                     self.patch_weight_to_device(bias_key, device_to)
                     m.to(device_to)
                     mem_counter += model_management.module_size(m)
-                    logging.debug("lowvram: loaded module regularly {}".format(m))
+                    logging.debug("lowvram: loaded module regularly {} {}".format(n, m))

         self.model_lowvram = True
         self.lowvram_patch_counter = patch_counter
@ -353,6 +360,12 @@ class ModelPatcher(ModelManageable):
             strength = p[0]
             v = p[1]
             strength_model = p[2]
+            offset = p[3]
+
+            old_weight = None
+            if offset is not None:
+                old_weight = weight
+                weight = weight.narrow(offset[0], offset[1], offset[2])

             if strength_model != 1.0:
                 weight *= strength_model
@ -504,6 +517,9 @@ class ModelPatcher(ModelManageable):
             else:
                 logging.warning("patch type not recognized {} {}".format(patch_type, key))

+            if old_weight is not None:
+                weight = old_weight
+
         return weight

     def unpatch_model(self, device_to=None, unpatch_weights=True):
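Sketch of the narrow-based patching the offset tuple enables (sizes illustrative): narrow returns a view, so adding a delta to the slice updates the fused weight in place.

import torch

qkv = torch.zeros(12, 4)              # fused q/k/v weight: three blocks of 4 rows
dim, start, length = 0, 4, 4          # the "k" slice, as stored in the offset tuple

view = qkv.narrow(dim, start, length)
view += torch.ones(length, 4)         # stand-in for a computed LoRA delta

assert qkv[4:8].eq(1).all() and qkv[:4].eq(0).all() and qkv[8:].eq(0).all()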
@ -171,6 +171,14 @@ class ModelSamplingContinuousEDM(torch.nn.Module):
         return math.exp((math.log(self.sigma_max) - log_sigma_min) * percent + log_sigma_min)


+class ModelSamplingContinuousV(ModelSamplingContinuousEDM):
+    def timestep(self, sigma):
+        return sigma.atan() / math.pi * 2
+
+    def sigma(self, timestep):
+        return (timestep * math.pi / 2).tan()
+
+
 def time_snr_shift(alpha, t):
     if alpha == 1.0:
         return t
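The new V-prediction continuous schedule maps sigma to a timestep with atan and back with tan; a quick round-trip sketch using the sigma range from the Stable Audio config below:

import math
import torch

sigma = torch.tensor([0.03, 1.0, 500.0])
timestep = sigma.atan() / math.pi * 2          # ModelSamplingContinuousV.timestep
recovered = (timestep * math.pi / 2).tan()     # ModelSamplingContinuousV.sigma
assert torch.allclose(recovered, sigma, rtol=1e-3)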
@ -631,6 +631,8 @@ class VAELoader:
         sdxl_taesd_dec = False
         sd1_taesd_enc = False
         sd1_taesd_dec = False
+        sd3_taesd_enc = False
+        sd3_taesd_dec = False

         for v in approx_vaes:
             if v.startswith("taesd_decoder."):
@ -641,10 +643,16 @@ class VAELoader:
                 sdxl_taesd_dec = True
             elif v.startswith("taesdxl_encoder."):
                 sdxl_taesd_enc = True
+            elif v.startswith("taesd3_decoder."):
+                sd3_taesd_dec = True
+            elif v.startswith("taesd3_encoder."):
+                sd3_taesd_enc = True
         if sd1_taesd_dec and sd1_taesd_enc:
             vaes.append("taesd")
         if sdxl_taesd_dec and sdxl_taesd_enc:
             vaes.append("taesdxl")
+        if sd3_taesd_dec and sd3_taesd_enc:
+            vaes.append("taesd3")
         return vaes

     @staticmethod
@ -665,8 +673,13 @@ class VAELoader:

         if name == "taesd":
             sd_["vae_scale"] = torch.tensor(0.18215)
+            sd_["vae_shift"] = torch.tensor(0.0)
         elif name == "taesdxl":
             sd_["vae_scale"] = torch.tensor(0.13025)
+            sd_["vae_shift"] = torch.tensor(0.0)
+        elif name == "taesd3":
+            sd_["vae_scale"] = torch.tensor(1.5305)
+            sd_["vae_shift"] = torch.tensor(0.0609)
         return sd_

     @classmethod
@ -679,7 +692,7 @@ class VAELoader:

         #TODO: scale factor?
     def load_vae(self, vae_name):
-        if vae_name in ["taesd", "taesdxl"]:
+        if vae_name in ["taesd", "taesdxl", "taesd3"]:
             sd_ = self.load_taesd(vae_name)
         else:
             vae_path = get_or_download("vae", vae_name, KNOWN_VAES)
@ -815,7 +828,7 @@ class CLIPLoader:
     @classmethod
     def INPUT_TYPES(s):
         return {"required": { "clip_name": (get_filename_list_with_downloadable("clip", KNOWN_CLIP_MODELS),),
-                              "type": (["stable_diffusion", "stable_cascade", "sd3"], ),
+                              "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio"], ),
                             }}
     RETURN_TYPES = ("CLIP",)
     FUNCTION = "load_clip"
@ -828,6 +841,8 @@ class CLIPLoader:
             clip_type = sd.CLIPType.STABLE_CASCADE
         elif type == "sd3":
             clip_type = sd.CLIPType.SD3
+        elif type == "stable_audio":
+            clip_type = sd.CLIPType.STABLE_AUDIO
         else:
             logging.warning(f"Unknown clip type argument passed: {type} for model {clip_name}")
comfy/ops.py | 41
@ -51,6 +51,20 @@ class disable_weight_init:
         else:
             return super().forward(*args, **kwargs)

+    class Conv1d(torch.nn.Conv1d, CastWeightBiasOp):
+        def reset_parameters(self):
+            return None
+
+        def forward_comfy_cast_weights(self, input):
+            weight, bias = cast_bias_weight(self, input)
+            return self._conv_forward(input, weight, bias)
+
+        def forward(self, *args, **kwargs):
+            if self.comfy_cast_weights:
+                return self.forward_comfy_cast_weights(*args, **kwargs)
+            else:
+                return super().forward(*args, **kwargs)
+
     class Conv2d(torch.nn.Conv2d, CastWeightBiasOp):
         def reset_parameters(self):
             return None
@ -133,6 +147,27 @@ class disable_weight_init:
         else:
             return super().forward(*args, **kwargs)

+    class ConvTranspose1d(torch.nn.ConvTranspose1d, CastWeightBiasOp):
+        def reset_parameters(self):
+            return None
+
+        def forward_comfy_cast_weights(self, input, output_size=None):
+            num_spatial_dims = 1
+            output_padding = self._output_padding(
+                input, output_size, self.stride, self.padding, self.kernel_size,
+                num_spatial_dims, self.dilation)
+
+            weight, bias = cast_bias_weight(self, input)
+            return torch.nn.functional.conv_transpose1d(
+                input, weight, bias, self.stride, self.padding,
+                output_padding, self.groups, self.dilation)
+
+        def forward(self, *args, **kwargs):
+            if self.comfy_cast_weights:
+                return self.forward_comfy_cast_weights(*args, **kwargs)
+            else:
+                return super().forward(*args, **kwargs)
+
     @classmethod
     def conv_nd(s, dims, *args, **kwargs):
         if dims == 2:
@ -147,6 +182,9 @@ class manual_cast(disable_weight_init):
     class Linear(disable_weight_init.Linear):
         comfy_cast_weights = True

+    class Conv1d(disable_weight_init.Conv1d):
+        comfy_cast_weights = True
+
     class Conv2d(disable_weight_init.Conv2d):
         comfy_cast_weights = True

@ -161,3 +199,6 @@ class manual_cast(disable_weight_init):

     class ConvTranspose2d(disable_weight_init.ConvTranspose2d):
         comfy_cast_weights = True
+
+    class ConvTranspose1d(disable_weight_init.ConvTranspose1d):
+        comfy_cast_weights = True
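Sketch of the manual-cast behaviour the new 1D ops inherit (assuming cast_bias_weight follows the input's dtype, as for the existing 2D ops); the storage dtype and sizes are illustrative:

import torch
import comfy.ops

conv = comfy.ops.manual_cast.Conv1d(64, 64, kernel_size=3, padding=1, dtype=torch.float16)
x = torch.randn(1, 64, 128)        # fp32 input
y = conv(x)                        # weights are cast to the input's dtype on the fly
print(y.dtype, y.shape)            # torch.float32 torch.Size([1, 64, 128])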
comfy/sa_t5.py | 22 (new file)
@ -0,0 +1,22 @@
+from comfy import sd1_clip
+from transformers import T5TokenizerFast
+import comfy.t5
+import os
+
+class T5BaseModel(sd1_clip.SDClipModel):
+    def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None):
+        textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_base.json")
+        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.t5.T5, enable_attention_masks=True, zero_out_masked=True)
+
+class T5BaseTokenizer(sd1_clip.SDTokenizer):
+    def __init__(self, embedding_directory=None):
+        tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
+        super().__init__(tokenizer_path, pad_with_end=False, embedding_size=768, embedding_key='t5base', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=128)
+
+class SAT5Tokenizer(sd1_clip.SD1Tokenizer):
+    def __init__(self, embedding_directory=None):
+        super().__init__(embedding_directory=embedding_directory, clip_name="t5base", tokenizer=T5BaseTokenizer)
+
+class SAT5Model(sd1_clip.SD1ClipModel):
+    def __init__(self, device="cpu", dtype=None, **kwargs):
+        super().__init__(device=device, dtype=dtype, clip_name="t5base", clip_model=T5BaseModel, **kwargs)
comfy/sd.py | 50
@ -23,9 +23,11 @@ from . import utils
 from .ldm.cascade.stage_a import StageA
 from .ldm.cascade.stage_c_coder import StageC_coder
 from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine
+from .ldm.audio.autoencoder import AudioOobleckVAE
 from .t2i_adapter import adapter
 from .taesd import taesd
 from . import sd3_clip
+from . import sa_t5


 def load_model_weights(model, sd):
@ -180,8 +182,10 @@ class VAE:
         self.downscale_ratio = 8
         self.upscale_ratio = 8
         self.latent_channels = 4
+        self.output_channels = 3
         self.process_input = lambda image: image * 2.0 - 1.0
         self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
+        self.working_dtypes = [torch.bfloat16, torch.float32]

         if config is None:
             if "decoder.mid.block_1.mix_factor" in sd:
@ -193,7 +197,8 @@ class VAE:
                     encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': encoder_config},
                     decoder_config={'target': "comfy.ldm.modules.temporal_ae.VideoDecoder", 'params': decoder_config})
             elif "taesd_decoder.1.weight" in sd:
-                self.first_stage_model = taesd.TAESD()
+                self.latent_channels = sd["taesd_decoder.1.weight"].shape[1]
+                self.first_stage_model = taesd.TAESD(latent_channels=self.latent_channels)
             elif "vquantizer.codebook.weight" in sd: # VQGan: stage a of stable cascade
                 self.first_stage_model = StageA()
                 self.downscale_ratio = 4
@ -238,6 +243,17 @@ class VAE:
                 self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
                                                             encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': ddconfig},
                                                             decoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Decoder", 'params': ddconfig})
+            elif "decoder.layers.0.weight_v" in sd:
+                self.first_stage_model = AudioOobleckVAE()
+                self.memory_used_encode = lambda shape, dtype: (1000 * shape[2]) * model_management.dtype_size(dtype)
+                self.memory_used_decode = lambda shape, dtype: (1000 * shape[2] * 2048) * model_management.dtype_size(dtype)
+                self.latent_channels = 64
+                self.output_channels = 2
+                self.upscale_ratio = 2048
+                self.downscale_ratio = 2048
+                self.process_output = lambda audio: audio
+                self.process_input = lambda audio: audio
+                self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
             else:
                 logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
                 self.first_stage_model = None
@ -258,20 +274,21 @@ class VAE:
         self.device = device
         offload_device = model_management.vae_offload_device()
         if dtype is None:
-            dtype = model_management.vae_dtype()
+            dtype = model_management.vae_dtype(self.device, self.working_dtypes)
         self.vae_dtype = dtype
         self.first_stage_model.to(self.vae_dtype)
         self.output_device = model_management.intermediate_device()

         self.patcher = model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
+        logging.debug("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype))

     def vae_encode_crop_pixels(self, pixels):
-        x = (pixels.shape[1] // self.downscale_ratio) * self.downscale_ratio
-        y = (pixels.shape[2] // self.downscale_ratio) * self.downscale_ratio
-        if pixels.shape[1] != x or pixels.shape[2] != y:
-            x_offset = (pixels.shape[1] % self.downscale_ratio) // 2
-            y_offset = (pixels.shape[2] % self.downscale_ratio) // 2
-            pixels = pixels[:, x_offset:x + x_offset, y_offset:y + y_offset, :]
+        dims = pixels.shape[1:-1]
+        for d in range(len(dims)):
+            x = (dims[d] // self.downscale_ratio) * self.downscale_ratio
+            x_offset = (dims[d] % self.downscale_ratio) // 2
+            if x != dims[d]:
+                pixels = pixels.narrow(d + 1, x_offset, x)
         return pixels

     def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap=16):
@ -309,7 +326,7 @@ class VAE:
             batch_number = int(free_memory / memory_used)
             batch_number = max(1, batch_number)

-            pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * self.upscale_ratio), round(samples_in.shape[3] * self.upscale_ratio)), device=self.output_device)
+            pixel_samples = torch.empty((samples_in.shape[0], self.output_channels) + tuple(map(lambda a: a * self.upscale_ratio, samples_in.shape[2:])), device=self.output_device)
             for x in range(0, samples_in.shape[0], batch_number):
                 samples = samples_in[x:x + batch_number].to(self.vae_dtype).to(self.device)
                 pixel_samples[x:x + batch_number] = self.process_output(self.first_stage_model.decode(samples).to(self.output_device).float())
@ -334,7 +351,7 @@ class VAE:
             free_memory = model_management.get_free_memory(self.device)
             batch_number = int(free_memory / memory_used)
             batch_number = max(1, batch_number)
-            samples = torch.empty((pixel_samples.shape[0], self.latent_channels, round(pixel_samples.shape[2] // self.downscale_ratio), round(pixel_samples.shape[3] // self.downscale_ratio)), device=self.output_device)
+            samples = torch.empty((pixel_samples.shape[0], self.latent_channels) + tuple(map(lambda a: a // self.downscale_ratio, pixel_samples.shape[2:])), device=self.output_device)
             for x in range(0, pixel_samples.shape[0], batch_number):
                 pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype).to(self.device)
                 samples[x:x + batch_number] = self.first_stage_model.encode(pixels_in).to(self.output_device).float()
@ -379,6 +396,7 @@ class CLIPType(Enum):
     STABLE_DIFFUSION = 1
     STABLE_CASCADE = 2
     SD3 = 3
+    STABLE_AUDIO = 4


 @dataclasses.dataclass
@ -418,6 +436,9 @@ def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DI
             dtype_t5 = clip_data[0]["encoder.block.23.layer.1.DenseReluDense.wi_1.weight"].dtype
             clip_target.clip = sd3_clip.sd3_clip(clip_l=False, clip_g=False, t5=True, dtype_t5=dtype_t5)
             clip_target.tokenizer = sd3_clip.SD3Tokenizer
+        elif "encoder.block.0.layer.0.SelfAttention.k.weight" in clip_data[0]:
+            clip_target.clip = sa_t5.SAT5Model
+            clip_target.tokenizer = sa_t5.SAT5Tokenizer
         else:
             clip_target.clip = sd1_clip.SD1ClipModel
             clip_target.tokenizer = sd1_clip.SD1Tokenizer
@ -489,10 +510,11 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
     _model_patcher = None
     clip_target = None

-    parameters = utils.calculate_parameters(sd, "model.diffusion_model.")
+    diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)
+    parameters = utils.calculate_parameters(sd, diffusion_model_prefix)
     load_device = model_management.get_torch_device()

-    model_config = model_detection.model_config_from_unet(sd, "model.diffusion_model.")
+    model_config = model_detection.model_config_from_unet(sd, diffusion_model_prefix)
     unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=model_config.supported_inference_dtypes)
     manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
     model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
@ -507,8 +529,8 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
     if output_model:
         inital_load_device = model_management.unet_initial_load_device(parameters, unet_dtype)
         offload_device = model_management.unet_offload_device()
-        model = model_config.get_model(sd, "model.diffusion_model.", device=inital_load_device)
-        model.load_model_weights(sd, "model.diffusion_model.")
+        model = model_config.get_model(sd, diffusion_model_prefix, device=inital_load_device)
+        model.load_model_weights(sd, diffusion_model_prefix)

     if output_vae:
         vae_sd = utils.state_dict_prefix_replace(sd, {k: "" for k in model_config.vae_key_prefix}, filter_keys=True)
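The crop helper now walks every spatial dim instead of assuming height and width; the same logic in isolation (sizes illustrative):

import torch

def crop_to_multiple(pixels, downscale_ratio):
    # trim every dim between batch and channels so it divides evenly, keeping the crop centered
    dims = pixels.shape[1:-1]
    for d in range(len(dims)):
        x = (dims[d] // downscale_ratio) * downscale_ratio
        x_offset = (dims[d] % downscale_ratio) // 2
        if x != dims[d]:
            pixels = pixels.narrow(d + 1, x_offset, x)
    return pixels

print(crop_to_multiple(torch.zeros(1, 517, 771, 3), 8).shape)   # torch.Size([1, 512, 768, 3])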
@ -6,6 +6,7 @@ from . import sd1_clip
 from . import sd2_clip
 from . import sdxl_clip
 from . import sd3_clip
+from . import sa_t5

 from . import supported_models_base
 from . import latent_formats
@ -524,7 +525,35 @@ class SD3(supported_models_base.BASE):

         return supported_models_base.ClipTarget(sd3_clip.SD3Tokenizer, sd3_clip.sd3_clip(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5))

+class StableAudio(supported_models_base.BASE):
+    unet_config = {
+        "audio_model": "dit1.0",
+    }
+
+    sampling_settings = {"sigma_max": 500.0, "sigma_min": 0.03}
+
+    unet_extra_config = {}
+    latent_format = latent_formats.StableAudio1
+
+    text_encoder_key_prefix = ["text_encoders."]
+    vae_key_prefix = ["pretransform.model."]
+
+    def get_model(self, state_dict, prefix="", device=None):
+        seconds_start_sd = utils.state_dict_prefix_replace(state_dict, {"conditioner.conditioners.seconds_start.": ""}, filter_keys=True)
+        seconds_total_sd = utils.state_dict_prefix_replace(state_dict, {"conditioner.conditioners.seconds_total.": ""}, filter_keys=True)
+        return model_base.StableAudio1(self, seconds_start_embedder_weights=seconds_start_sd, seconds_total_embedder_weights=seconds_total_sd, device=device)
+
+    def process_unet_state_dict(self, state_dict):
+        for k in list(state_dict.keys()):
+            if k.endswith(".cross_attend_norm.beta") or k.endswith(".ff_norm.beta") or k.endswith(".pre_norm.beta"): #These weights are all zero
+                state_dict.pop(k)
+        return state_dict
+
+    def clip_target(self, state_dict={}):
+        return supported_models_base.ClipTarget(sa_t5.SAT5Tokenizer, sa_t5.SAT5Model)
+
+
-models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3]
+models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio]

 models += [SVD_img2vid]
@ -25,18 +25,19 @@ class Block(nn.Module):
     def forward(self, x):
         return self.fuse(self.conv(x) + self.skip(x))

-def Encoder():
+def Encoder(latent_channels=4):
     return nn.Sequential(
         conv(3, 64), Block(64, 64),
         conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
         conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
         conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
-        conv(64, 4),
+        conv(64, latent_channels),
     )

-def Decoder():
+
+def Decoder(latent_channels=4):
     return nn.Sequential(
-        Clamp(), conv(4, 64), nn.ReLU(),
+        Clamp(), conv(latent_channels, 64), nn.ReLU(),
         Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
         Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
         Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
@ -47,12 +48,13 @@ class TAESD(nn.Module):
     latent_magnitude = 3
     latent_shift = 0.5

-    def __init__(self, encoder_path=None, decoder_path=None):
+    def __init__(self, encoder_path=None, decoder_path=None, latent_channels=4):
         """Initialize pretrained TAESD on the given device from the given checkpoints."""
         super().__init__()
-        self.taesd_encoder = Encoder()
-        self.taesd_decoder = Decoder()
+        self.taesd_encoder = Encoder(latent_channels=latent_channels)
+        self.taesd_decoder = Decoder(latent_channels=latent_channels)
         self.vae_scale = torch.nn.Parameter(torch.tensor(1.0))
+        self.vae_shift = torch.nn.Parameter(torch.tensor(0.0))
         if encoder_path is not None:
             self.taesd_encoder.load_state_dict(utils.load_torch_file(encoder_path, safe_load=True))
         if decoder_path is not None:
@ -69,9 +71,9 @@ class TAESD(nn.Module):
         return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude)

     def decode(self, x):
-        x_sample = self.taesd_decoder(x * self.vae_scale)
+        x_sample = self.taesd_decoder((x - self.vae_shift) * self.vae_scale)
         x_sample = x_sample.sub(0.5).mul(2)
         return x_sample

     def encode(self, x):
-        return self.taesd_encoder(x * 0.5 + 0.5) / self.vae_scale
+        return (self.taesd_encoder(x * 0.5 + 0.5) / self.vae_scale) + self.vae_shift
|
|||||||
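The TAESD changes above make the latent channel count configurable and add a learned vae_shift that is removed before decoding and added back after encoding. A minimal round-trip sketch of that normalization, using stand-in conv layers rather than the real pretrained TAESD weights:

import torch
import torch.nn as nn

# Stand-in encoder/decoder with matching channel counts; not the real TAESD weights.
latent_channels = 4
encoder = nn.Conv2d(3, latent_channels, 1)
decoder = nn.Conv2d(latent_channels, 3, 1)
vae_scale = torch.tensor(1.0)   # learned scale, as in the existing TAESD module
vae_shift = torch.tensor(0.0)   # the new learned shift introduced by this change

def encode(x):
    # image in [-1, 1] -> latent: rescale to [0, 1], encode, divide by scale, add shift
    return (encoder(x * 0.5 + 0.5) / vae_scale) + vae_shift

def decode(z):
    # latent -> image: remove the shift before rescaling, mirroring the patched decode()
    x = decoder((z - vae_shift) * vae_scale)
    return x.sub(0.5).mul(2)

img = torch.rand(1, 3, 64, 64) * 2 - 1
print(decode(encode(img)).shape)  # torch.Size([1, 3, 64, 64])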
@@ -2239,7 +2239,7 @@ export class ComfyApp {
             const node = LiteGraph.createNode(data.class_type);
             node.id = isNaN(+id) ? id : +id;
             node.title = data._meta?.title ?? node.title
-            graph.add(node);
+            app.graph.add(node);
         }
 
         for (const id of ids) {
128  comfy_extras/nodes/nodes_audio.py  Normal file
@@ -0,0 +1,128 @@
+import torchaudio
+import torch
+import comfy.model_management
+from comfy.cmd import folder_paths
+import os
+
+class EmptyLatentAudio:
+    def __init__(self):
+        self.device = comfy.model_management.intermediate_device()
+
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {}}
+    RETURN_TYPES = ("LATENT",)
+    FUNCTION = "generate"
+
+    CATEGORY = "_for_testing/audio"
+
+    def generate(self):
+        batch_size = 1
+        latent = torch.zeros([batch_size, 64, 1024], device=self.device)
+        return ({"samples":latent, "type": "audio"}, )
+
+class VAEEncodeAudio:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "audio": ("AUDIO", ), "vae": ("VAE", )}}
+    RETURN_TYPES = ("LATENT",)
+    FUNCTION = "encode"
+
+    CATEGORY = "_for_testing/audio"
+
+    def encode(self, vae, audio):
+        t = vae.encode(audio["waveform"].movedim(1, -1))
+        return ({"samples":t}, )
+
+class VAEDecodeAudio:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "samples": ("LATENT", ), "vae": ("VAE", )}}
+    RETURN_TYPES = ("AUDIO",)
+    FUNCTION = "decode"
+
+    CATEGORY = "_for_testing/audio"
+
+    def decode(self, vae, samples):
+        audio = vae.decode(samples["samples"]).movedim(-1, 1)
+        return ({"waveform": audio, "sample_rate": 44100}, )
+
+class SaveAudio:
+    def __init__(self):
+        self.output_dir = folder_paths.get_output_directory()
+        self.type = "output"
+        self.prefix_append = ""
+        self.compress_level = 4
+
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "audio": ("AUDIO", ),
+                              "filename_prefix": ("STRING", {"default": "audio/ComfyUI"})},
+                "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
+                }
+
+    RETURN_TYPES = ()
+    FUNCTION = "save_audio"
+
+    OUTPUT_NODE = True
+
+    CATEGORY = "_for_testing/audio"
+
+    def save_audio(self, audio, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
+        filename_prefix += self.prefix_append
+        full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
+        results = list()
+        for (batch_number, waveform) in enumerate(audio["waveform"]):
+            #TODO: metadata
+            filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
+            file = f"{filename_with_batch_num}_{counter:05}_.flac"
+            torchaudio.save(os.path.join(full_output_folder, file), waveform, audio["sample_rate"], format="FLAC")
+            results.append({
+                "filename": file,
+                "subfolder": subfolder,
+                "type": self.type
+            })
+            counter += 1
+
+        return { "ui": { "audio": results } }
+
+class LoadAudio:
+    @classmethod
+    def INPUT_TYPES(s):
+        input_dir = folder_paths.get_input_directory()
+        files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))]
+        return {"required": {"audio": [sorted(files), ]}, }
+
+    CATEGORY = "_for_testing/audio"
+
+    RETURN_TYPES = ("AUDIO", )
+    FUNCTION = "load"
+
+    def load(self, audio):
+        audio_path = folder_paths.get_annotated_filepath(audio)
+        waveform, sample_rate = torchaudio.load(audio_path)
+        multiplier = 1.0
+        audio = {"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate}
+        return (audio, )
+
+    @classmethod
+    def IS_CHANGED(s, audio):
+        image_path = folder_paths.get_annotated_filepath(audio)
+        m = hashlib.sha256()
+        with open(image_path, 'rb') as f:
+            m.update(f.read())
+        return m.digest().hex()
+
+    @classmethod
+    def VALIDATE_INPUTS(s, audio):
+        if not folder_paths.exists_annotated_filepath(audio):
+            return "Invalid audio file: {}".format(audio)
+        return True
+
+NODE_CLASS_MAPPINGS = {
+    "EmptyLatentAudio": EmptyLatentAudio,
+    "VAEEncodeAudio": VAEEncodeAudio,
+    "VAEDecodeAudio": VAEDecodeAudio,
+    "SaveAudio": SaveAudio,
+    "LoadAudio": LoadAudio,
+}
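The new audio nodes exchange audio as a dict holding a batched waveform tensor plus its sample rate; LoadAudio builds that dict from torchaudio.load and SaveAudio writes each batch item back out as FLAC. A small standalone sketch of that convention, with placeholder file paths rather than ComfyUI's input/output directories:

import torchaudio

# Build the AUDIO dict the way LoadAudio does: torchaudio.load returns a
# (channels, samples) tensor, and the node prepends a batch dimension.
waveform, sample_rate = torchaudio.load("some_input.flac")  # placeholder path
audio = {"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate}

# Write each batch item back out as FLAC, mirroring the loop in SaveAudio.save_audio().
for batch_number, wav in enumerate(audio["waveform"]):
    torchaudio.save(f"out_{batch_number:05}_.flac", wav, audio["sample_rate"], format="FLAC")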
@@ -195,6 +195,36 @@ class ModelSamplingContinuousEDM:
         m.add_object_patch("latent_format", latent_format)
         return (m, )
 
+class ModelSamplingContinuousV:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "model": ("MODEL",),
+                              "sampling": (["v_prediction"],),
+                              "sigma_max": ("FLOAT", {"default": 500.0, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}),
+                              "sigma_min": ("FLOAT", {"default": 0.03, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}),
+                              }}
+
+    RETURN_TYPES = ("MODEL",)
+    FUNCTION = "patch"
+
+    CATEGORY = "advanced/model"
+
+    def patch(self, model, sampling, sigma_max, sigma_min):
+        m = model.clone()
+
+        latent_format = None
+        sigma_data = 1.0
+        if sampling == "v_prediction":
+            sampling_type = comfy.model_sampling.V_PREDICTION
+
+        class ModelSamplingAdvanced(comfy.model_sampling.ModelSamplingContinuousV, sampling_type):
+            pass
+
+        model_sampling = ModelSamplingAdvanced(model.model.model_config)
+        model_sampling.set_parameters(sigma_min, sigma_max, sigma_data)
+        m.add_object_patch("model_sampling", model_sampling)
+        return (m, )
+
 class RescaleCFG:
     @classmethod
     def INPUT_TYPES(s):
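ModelSamplingContinuousV follows the usual clone-and-patch flow: clone the model, build a sampling object with sigma_data fixed at 1.0 and the requested sigma range, and attach it via add_object_patch. A self-contained sketch of that pattern using stand-in classes, not the real comfy.model_sampling or ModelPatcher APIs:

# Stand-in classes only, to illustrate the clone-and-patch pattern the node uses.
class FakeModelSampling:
    def set_parameters(self, sigma_min, sigma_max, sigma_data):
        self.sigma_min, self.sigma_max, self.sigma_data = sigma_min, sigma_max, sigma_data

class FakeModelPatcher:
    def __init__(self):
        self.object_patches = {}

    def clone(self):
        c = FakeModelPatcher()
        c.object_patches = dict(self.object_patches)
        return c

    def add_object_patch(self, name, obj):
        self.object_patches[name] = obj

def patch(model, sigma_max=500.0, sigma_min=0.03):
    # Same flow as ModelSamplingContinuousV.patch: clone, configure sampling, attach it.
    m = model.clone()
    model_sampling = FakeModelSampling()
    model_sampling.set_parameters(sigma_min, sigma_max, sigma_data=1.0)
    m.add_object_patch("model_sampling", model_sampling)
    return (m, )

(patched,) = patch(FakeModelPatcher())
print(patched.object_patches["model_sampling"].sigma_max)  # 500.0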
@@ -237,6 +267,7 @@ class RescaleCFG:
 NODE_CLASS_MAPPINGS = {
     "ModelSamplingDiscrete": ModelSamplingDiscrete,
     "ModelSamplingContinuousEDM": ModelSamplingContinuousEDM,
+    "ModelSamplingContinuousV": ModelSamplingContinuousV,
     "ModelSamplingStableCascade": ModelSamplingStableCascade,
     "ModelSamplingSD3": ModelSamplingSD3,
     "RescaleCFG": RescaleCFG,
@@ -1,5 +1,6 @@
 torch
 torchvision
+torchaudio
 torchdiffeq>=0.2.3
 torchsde>=0.2.6
 einops>=0.6.0
2  setup.py
@@ -120,7 +120,7 @@ def _is_linux_arm64():
 def dependencies(for_pypi=False, force_nightly: bool = False) -> List[str]:
     _dependencies = open(os.path.join(os.path.dirname(__file__), "requirements.txt")).readlines()
     if for_pypi:
-        return [dep for dep in _dependencies if dep not in {"torch", "torchvision"} and "@" not in dep]
+        return [dep for dep in _dependencies if dep not in {"torch", "torchvision", "torchaudio"} and "@" not in dep]
     # If we're installing with no build isolation, we can check if torch is already installed in the environment, and if
     # so, go ahead and use the version that is already installed.
     existing_torch: Optional[str]
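The for_pypi filter in setup.py now also drops torchaudio, so the published package does not pin the torch-family wheels. A quick illustration of the filter expression on made-up requirement strings:

# Made-up requirement strings, purely to show the filter's behavior.
_dependencies = ["torch", "torchvision", "torchaudio", "torchdiffeq>=0.2.3", "einops @ https://example.invalid/einops.whl"]
pypi_deps = [dep for dep in _dependencies if dep not in {"torch", "torchvision", "torchaudio"} and "@" not in dep]
print(pypi_deps)  # ['torchdiffeq>=0.2.3']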