Merge branch 'comfyanonymous:master' into dpr

This commit is contained in:
kali-linex 2023-08-16 01:25:47 +02:00 committed by GitHub
commit caf4ff59cc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
48 changed files with 1786 additions and 793 deletions

View File

@ -2,6 +2,13 @@ name: "Windows Release cu118 dependencies 2"
on:
workflow_dispatch:
inputs:
xformers:
description: 'xformers version'
required: true
type: string
default: "xformers"
# push:
# branches:
# - master
@ -17,7 +24,7 @@ jobs:
- shell: bash
run: |
python -m pip wheel --no-cache-dir torch torchvision torchaudio xformers --extra-index-url https://download.pytorch.org/whl/cu118 -r requirements.txt pygit2 -w ./temp_wheel_dir
python -m pip wheel --no-cache-dir torch torchvision torchaudio ${{ inputs.xformers }} --extra-index-url https://download.pytorch.org/whl/cu118 -r requirements.txt pygit2 -w ./temp_wheel_dir
python -m pip install --no-cache-dir ./temp_wheel_dir/*
echo installed basic
ls -lah temp_wheel_dir

1
CODEOWNERS Normal file
View File

@ -0,0 +1 @@
* @comfyanonymous

View File

@ -47,6 +47,7 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git
| Ctrl + O | Load workflow |
| Ctrl + A | Select all nodes |
| Ctrl + M | Mute/unmute selected nodes |
| Ctrl + B | Bypass selected nodes (acts like the node was removed from the graph and the wires reconnected through) |
| Delete/Backspace | Delete selected nodes |
| Ctrl + Delete/Backspace | Delete the current graph |
| Space | Move the canvas around when held and moving the cursor |
@ -93,8 +94,8 @@ AMD users can install rocm and pytorch with pip if you don't have it already ins
```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/rocm5.4.2```
This is the command to install the nightly with ROCm 5.5 that supports the 7000 series and might have some performance improvements:
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm5.5 -r requirements.txt```
This is the command to install the nightly with ROCm 5.6 that supports the 7000 series and might have some performance improvements:
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm5.6```
### NVIDIA
@ -126,10 +127,10 @@ After this you should have everything installed and can proceed to running Comfy
You can install ComfyUI in Apple Mac silicon (M1 or M2) with any recent macOS version.
1. Install pytorch. For instructions, read the [Accelerated PyTorch training on Mac](https://developer.apple.com/metal/pytorch/) Apple Developer guide.
1. Install pytorch nightly. For instructions, read the [Accelerated PyTorch training on Mac](https://developer.apple.com/metal/pytorch/) Apple Developer guide (make sure to install the latest pytorch nightly).
1. Follow the [ComfyUI manual installation](#manual-install-windows-linux) instructions for Windows and Linux.
1. Install the ComfyUI [dependencies](#dependencies). If you have another Stable Diffusion UI [you might be able to reuse the dependencies](#i-already-have-another-ui-for-stable-diffusion-installed-do-i-really-have-to-install-all-of-these-dependencies).
1. Launch ComfyUI by running `python main.py`.
1. Launch ComfyUI by running `python main.py --force-fp16`. Note that --force-fp16 will only work if you installed the latest pytorch nightly.
> **Note**: Remember to add your models, VAE, LoRAs etc. to the corresponding Comfy folders, as discussed in [ComfyUI manual installation](#manual-install-windows-linux).

View File

@ -13,7 +13,7 @@ from ..ldm.modules.diffusionmodules.util import (
)
from ..ldm.modules.attention import SpatialTransformer
from ..ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample, AttentionBlock
from ..ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample
from ..ldm.util import exists
@ -57,6 +57,7 @@ class ControlNet(nn.Module):
transformer_depth_middle=None,
):
super().__init__()
assert use_spatial_transformer == True, "use_spatial_transformer has to be true"
if use_spatial_transformer:
assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
@ -200,13 +201,7 @@ class ControlNet(nn.Module):
if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
layers.append(
AttentionBlock(
ch,
use_checkpoint=use_checkpoint,
num_heads=num_heads,
num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order,
) if not use_spatial_transformer else SpatialTransformer(
SpatialTransformer(
ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim,
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint
@ -259,13 +254,7 @@ class ControlNet(nn.Module):
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
),
AttentionBlock(
ch,
use_checkpoint=use_checkpoint,
num_heads=num_heads,
num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order,
) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn
SpatialTransformer( # always uses a self-attn
ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint

View File

@ -38,8 +38,14 @@ parser.add_argument("--port", type=int, default=8188, help="Set the listen port.
parser.add_argument("--enable-cors-header", type=str, default=None, metavar="ORIGIN", nargs="?", const="*", help="Enable CORS (Cross-Origin Resource Sharing) with optional origin or allow all with default '*'.")
parser.add_argument("--extra-model-paths-config", type=str, default=None, metavar="PATH", nargs='+', action='append', help="Load one or more extra_model_paths.yaml files.")
parser.add_argument("--output-directory", type=str, default=None, help="Set the ComfyUI output directory.")
parser.add_argument("--temp-directory", type=str, default=None, help="Set the ComfyUI temp directory (default is in the ComfyUI directory).")
parser.add_argument("--auto-launch", action="store_true", help="Automatically launch ComfyUI in the default browser.")
parser.add_argument("--disable-auto-launch", action="store_true", help="Disable auto launching the browser.")
parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use.")
cm_group = parser.add_mutually_exclusive_group()
cm_group.add_argument("--cuda-malloc", action="store_true", help="Enable cudaMallocAsync (enabled by default for torch 2.0 and up).")
cm_group.add_argument("--disable-cuda-malloc", action="store_true", help="Disable cudaMallocAsync.")
parser.add_argument("--dont-upcast-attention", action="store_true", help="Disable upcasting of attention. Can boost speed but increase the chances of black images.")
fp_group = parser.add_mutually_exclusive_group()
@ -80,7 +86,12 @@ parser.add_argument("--dont-print-server", action="store_true", help="Don't prin
parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")
parser.add_argument("--windows-standalone-build", action="store_true", help="Windows standalone build: Enable convenient things that most people using the standalone windows build will probably enjoy (like auto opening the page on startup).")
parser.add_argument("--disable-metadata", action="store_true", help="Disable saving prompt metadata in files.")
args = parser.parse_args()
if args.windows_standalone_build:
args.auto_launch = True
if args.disable_auto_launch:
args.auto_launch = False

View File

@ -24,8 +24,8 @@ class ClipVisionModel():
return self.model.load_state_dict(sd, strict=False)
def encode_image(self, image):
img = torch.clip((255. * image[0]), 0, 255).round().int()
inputs = self.processor(images=[img], return_tensors="pt")
img = torch.clip((255. * image), 0, 255).round().int()
inputs = self.processor(images=img, return_tensors="pt")
outputs = self.model(**inputs)
return outputs

View File

@ -148,6 +148,10 @@ vae_conversion_map_attn = [
("q.", "query."),
("k.", "key."),
("v.", "value."),
("q.", "to_q."),
("k.", "to_k."),
("v.", "to_v."),
("proj_out.", "to_out.0."),
("proj_out.", "proj_attn."),
]

View File

@ -180,7 +180,6 @@ class NoiseScheduleVP:
def model_wrapper(
model,
sampling_function,
noise_schedule,
model_type="noise",
model_kwargs={},
@ -295,7 +294,7 @@ def model_wrapper(
if t_continuous.reshape((-1,)).shape[0] == 1:
t_continuous = t_continuous.expand((x.shape[0]))
t_input = get_model_input_time(t_continuous)
output = sampling_function(model, x, t_input, **model_kwargs)
output = model(x, t_input, **model_kwargs)
if model_type == "noise":
return output
elif model_type == "x_start":
@ -843,10 +842,12 @@ def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, ex
else:
timesteps = sigmas.clone()
for s in range(timesteps.shape[0]):
timesteps[s] = (model.sigma_to_t(timesteps[s]) / 1000) + (1 / len(model.sigmas))
alphas_cumprod = model.inner_model.alphas_cumprod
ns = NoiseScheduleVP('discrete', alphas_cumprod=model.inner_model.alphas_cumprod)
for s in range(timesteps.shape[0]):
timesteps[s] = (model.sigma_to_discrete_timestep(timesteps[s]) / 1000) + (1 / len(alphas_cumprod))
ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
if image is not None:
img = image * ns.marginal_alpha(timesteps[0])
@ -859,18 +860,15 @@ def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, ex
img = noise
if to_zero:
timesteps[-1] = (1 / len(model.sigmas))
timesteps[-1] = (1 / len(alphas_cumprod))
device = noise.device
if model.parameterization == "v":
model_type = "v"
else:
model_type = "noise"
model_type = "noise"
model_fn = model_wrapper(
model.inner_model.inner_model.apply_model,
sampling_function,
model.predict_eps_discrete_timestep,
ns,
model_type=model_type,
guidance_type="uncond",

View File

@ -63,12 +63,17 @@ class DiscreteSchedule(nn.Module):
t = torch.linspace(t_max, 0, n, device=self.sigmas.device)
return sampling.append_zero(self.t_to_sigma(t))
def sigma_to_t(self, sigma, quantize=None):
quantize = self.quantize if quantize is None else quantize
def sigma_to_discrete_timestep(self, sigma):
log_sigma = sigma.log()
dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None]
return dists.abs().argmin(dim=0).view(sigma.shape)
def sigma_to_t(self, sigma, quantize=None):
quantize = self.quantize if quantize is None else quantize
if quantize:
return dists.abs().argmin(dim=0).view(sigma.shape)
return self.sigma_to_discrete_timestep(sigma)
log_sigma = sigma.log()
dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None]
low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(max=self.log_sigmas.shape[0] - 2)
high_idx = low_idx + 1
low, high = self.log_sigmas[low_idx], self.log_sigmas[high_idx]
@ -85,6 +90,12 @@ class DiscreteSchedule(nn.Module):
log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
return log_sigma.exp()
def predict_eps_discrete_timestep(self, input, t, **kwargs):
if t.dtype != torch.int64 and t.dtype != torch.int32:
t = t.round()
sigma = self.t_to_sigma(t)
input = input * ((utils.append_dims(sigma, input.ndim) ** 2 + 1.0) ** 0.5)
return (input - self(input, sigma, **kwargs)) / utils.append_dims(sigma, input.ndim)
class DiscreteEpsDDPMDenoiser(DiscreteSchedule):
"""A wrapper for discrete schedule DDPM models that output eps (the predicted

View File

@ -3,7 +3,6 @@ import math
from scipy import integrate
import torch
from torch import nn
from torchdiffeq import odeint
import torchsde
from tqdm.auto import trange, tqdm
@ -131,9 +130,9 @@ def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None,
s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable):
gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
eps = torch.randn_like(x) * s_noise
sigma_hat = sigmas[i] * (gamma + 1)
if gamma > 0:
eps = torch.randn_like(x) * s_noise
x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
denoised = model(x, sigma_hat * s_in, **extra_args)
d = to_d(x, sigma_hat, denoised)
@ -172,9 +171,9 @@ def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None,
s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable):
gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
eps = torch.randn_like(x) * s_noise
sigma_hat = sigmas[i] * (gamma + 1)
if gamma > 0:
eps = torch.randn_like(x) * s_noise
x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
denoised = model(x, sigma_hat * s_in, **extra_args)
d = to_d(x, sigma_hat, denoised)
@ -201,9 +200,9 @@ def sample_dpm_2(model, x, sigmas, extra_args=None, callback=None, disable=None,
s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable):
gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
eps = torch.randn_like(x) * s_noise
sigma_hat = sigmas[i] * (gamma + 1)
if gamma > 0:
eps = torch.randn_like(x) * s_noise
x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
denoised = model(x, sigma_hat * s_in, **extra_args)
d = to_d(x, sigma_hat, denoised)
@ -287,30 +286,6 @@ def sample_lms(model, x, sigmas, extra_args=None, callback=None, disable=None, o
return x
@torch.no_grad()
def log_likelihood(model, x, sigma_min, sigma_max, extra_args=None, atol=1e-4, rtol=1e-4):
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
v = torch.randint_like(x, 2) * 2 - 1
fevals = 0
def ode_fn(sigma, x):
nonlocal fevals
with torch.enable_grad():
x = x[0].detach().requires_grad_()
denoised = model(x, sigma * s_in, **extra_args)
d = to_d(x, sigma, denoised)
fevals += 1
grad = torch.autograd.grad((d * v).sum(), x)[0]
d_ll = (v * grad).flatten(1).sum(1)
return d.detach(), d_ll
x_min = x, x.new_zeros([x.shape[0]])
t = x.new_tensor([sigma_min, sigma_max])
sol = odeint(ode_fn, x_min, t, atol=atol, rtol=rtol, method='dopri5')
latent, delta_ll = sol[0][-1], sol[1][-1]
ll_prior = torch.distributions.Normal(0, sigma_max).log_prob(latent).flatten(1).sum(1)
return ll_prior + delta_ll, {'fevals': fevals}
class PIDStepSizeController:
"""A PID controller for ODE adaptive step size control."""
def __init__(self, h, pcoeff, icoeff, dcoeff, order=1, accept_safety=0.81, eps=1e-8):
@ -656,23 +631,78 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
elif solver_type == 'midpoint':
x = x + 0.5 * (-h - eta_h).expm1().neg() * (1 / r) * (denoised - old_denoised)
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * eta_h).expm1().neg().sqrt() * s_noise
if eta:
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * eta_h).expm1().neg().sqrt() * s_noise
old_denoised = denoised
h_last = h
return x
@torch.no_grad()
def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
"""DPM-Solver++(3M) SDE."""
seed = extra_args.get("seed", None)
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
denoised_1, denoised_2 = None, None
h_1, h_2 = None, None
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
if sigmas[i + 1] == 0:
# Denoising step
x = denoised
else:
t, s = -sigmas[i].log(), -sigmas[i + 1].log()
h = s - t
h_eta = h * (eta + 1)
x = torch.exp(-h_eta) * x + (-h_eta).expm1().neg() * denoised
if h_2 is not None:
r0 = h_1 / h
r1 = h_2 / h
d1_0 = (denoised - denoised_1) / r0
d1_1 = (denoised_1 - denoised_2) / r1
d1 = d1_0 + (d1_0 - d1_1) * r0 / (r0 + r1)
d2 = (d1_0 - d1_1) / (r0 + r1)
phi_2 = h_eta.neg().expm1() / h_eta + 1
phi_3 = phi_2 / h_eta - 0.5
x = x + phi_2 * d1 - phi_3 * d2
elif h_1 is not None:
r = h_1 / h
d = (denoised - denoised_1) / r
phi_2 = h_eta.neg().expm1() / h_eta + 1
x = x + phi_2 * d
if eta:
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * h * eta).expm1().neg().sqrt() * s_noise
denoised_1, denoised_2 = denoised, denoised_1
h_1, h_2 = h, h_1
return x
@torch.no_grad()
def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
return sample_dpmpp_3m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler)
@torch.no_grad()
def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
return sample_dpmpp_2m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type)
@torch.no_grad()
def sample_dpmpp_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
return sample_dpmpp_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, r=r)

View File

@ -14,6 +14,7 @@ class DDIMSampler(object):
self.ddpm_num_timesteps = model.num_timesteps
self.schedule = schedule
self.device = device
self.parameterization = kwargs.get("parameterization", "eps")
def register_buffer(self, name, attr):
if type(attr) == torch.Tensor:
@ -261,7 +262,7 @@ class DDIMSampler(object):
b, *_, device = *x.shape, x.device
if denoise_function is not None:
model_output = denoise_function(self.model.apply_model, x, t, **extra_args)
model_output = denoise_function(x, t, **extra_args)
elif unconditional_conditioning is None or unconditional_guidance_scale == 1.:
model_output = self.model.apply_model(x, t, c)
else:
@ -289,13 +290,13 @@ class DDIMSampler(object):
model_uncond, model_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond)
if self.model.parameterization == "v":
if self.parameterization == "v":
e_t = extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * model_output + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x
else:
e_t = model_output
if score_corrector is not None:
assert self.model.parameterization == "eps", 'not implemented'
assert self.parameterization == "eps", 'not implemented'
e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
@ -309,7 +310,7 @@ class DDIMSampler(object):
sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
# current prediction for x_0
if self.model.parameterization != "v":
if self.parameterization != "v":
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
else:
pred_x0 = extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * x - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * model_output

View File

@ -36,7 +36,7 @@ def uniq(arr):
def default(val, d):
if exists(val):
return val
return d() if isfunction(d) else d
return d
def max_neg_value(t):
@ -52,9 +52,9 @@ def init_(tensor):
# feedforward
class GEGLU(nn.Module):
def __init__(self, dim_in, dim_out, dtype=None):
def __init__(self, dim_in, dim_out, dtype=None, device=None):
super().__init__()
self.proj = comfy.ops.Linear(dim_in, dim_out * 2, dtype=dtype)
self.proj = comfy.ops.Linear(dim_in, dim_out * 2, dtype=dtype, device=device)
def forward(self, x):
x, gate = self.proj(x).chunk(2, dim=-1)
@ -62,19 +62,19 @@ class GEGLU(nn.Module):
class FeedForward(nn.Module):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0., dtype=None):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0., dtype=None, device=None):
super().__init__()
inner_dim = int(dim * mult)
dim_out = default(dim_out, dim)
project_in = nn.Sequential(
comfy.ops.Linear(dim, inner_dim, dtype=dtype),
comfy.ops.Linear(dim, inner_dim, dtype=dtype, device=device),
nn.GELU()
) if not glu else GEGLU(dim, inner_dim, dtype=dtype)
) if not glu else GEGLU(dim, inner_dim, dtype=dtype, device=device)
self.net = nn.Sequential(
project_in,
nn.Dropout(dropout),
comfy.ops.Linear(inner_dim, dim_out, dtype=dtype)
comfy.ops.Linear(inner_dim, dim_out, dtype=dtype, device=device)
)
def forward(self, x):
@ -90,8 +90,8 @@ def zero_module(module):
return module
def Normalize(in_channels, dtype=None):
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype)
def Normalize(in_channels, dtype=None, device=None):
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
class SpatialSelfAttention(nn.Module):
@ -148,7 +148,7 @@ class SpatialSelfAttention(nn.Module):
class CrossAttentionBirchSan(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)
@ -156,12 +156,12 @@ class CrossAttentionBirchSan(nn.Module):
self.scale = dim_head ** -0.5
self.heads = heads
self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype)
self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.to_out = nn.Sequential(
comfy.ops.Linear(inner_dim, query_dim, dtype=dtype),
comfy.ops.Linear(inner_dim, query_dim, dtype=dtype, device=device),
nn.Dropout(dropout)
)
@ -245,7 +245,7 @@ class CrossAttentionBirchSan(nn.Module):
class CrossAttentionDoggettx(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)
@ -253,12 +253,12 @@ class CrossAttentionDoggettx(nn.Module):
self.scale = dim_head ** -0.5
self.heads = heads
self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype)
self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.to_out = nn.Sequential(
comfy.ops.Linear(inner_dim, query_dim, dtype=dtype),
comfy.ops.Linear(inner_dim, query_dim, dtype=dtype, device=device),
nn.Dropout(dropout)
)
@ -343,7 +343,7 @@ class CrossAttentionDoggettx(nn.Module):
return self.to_out(r2)
class CrossAttention(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)
@ -351,12 +351,12 @@ class CrossAttention(nn.Module):
self.scale = dim_head ** -0.5
self.heads = heads
self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype)
self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.to_out = nn.Sequential(
comfy.ops.Linear(inner_dim, query_dim, dtype=dtype),
comfy.ops.Linear(inner_dim, query_dim, dtype=dtype, device=device),
nn.Dropout(dropout)
)
@ -399,7 +399,7 @@ class CrossAttention(nn.Module):
class MemoryEfficientCrossAttention(nn.Module):
# https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, dtype=None):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, dtype=None, device=None):
super().__init__()
print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
f"{heads} heads.")
@ -409,11 +409,11 @@ class MemoryEfficientCrossAttention(nn.Module):
self.heads = heads
self.dim_head = dim_head
self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype)
self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.to_out = nn.Sequential(comfy.ops.Linear(inner_dim, query_dim, dtype=dtype), nn.Dropout(dropout))
self.to_out = nn.Sequential(comfy.ops.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
self.attention_op: Optional[Any] = None
def forward(self, x, context=None, value=None, mask=None):
@ -450,7 +450,7 @@ class MemoryEfficientCrossAttention(nn.Module):
return self.to_out(out)
class CrossAttentionPytorch(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)
@ -458,11 +458,11 @@ class CrossAttentionPytorch(nn.Module):
self.heads = heads
self.dim_head = dim_head
self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype)
self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.to_out = nn.Sequential(comfy.ops.Linear(inner_dim, query_dim, dtype=dtype), nn.Dropout(dropout))
self.to_out = nn.Sequential(comfy.ops.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
self.attention_op: Optional[Any] = None
def forward(self, x, context=None, value=None, mask=None):
@ -508,17 +508,17 @@ else:
class BasicTransformerBlock(nn.Module):
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
disable_self_attn=False, dtype=None):
disable_self_attn=False, dtype=None, device=None):
super().__init__()
self.disable_self_attn = disable_self_attn
self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
context_dim=context_dim if self.disable_self_attn else None, dtype=dtype) # is a self-attention if not self.disable_self_attn
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff, dtype=dtype)
context_dim=context_dim if self.disable_self_attn else None, dtype=dtype, device=device) # is a self-attention if not self.disable_self_attn
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device)
self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
heads=n_heads, dim_head=d_head, dropout=dropout, dtype=dtype) # is self-attn if context is none
self.norm1 = nn.LayerNorm(dim, dtype=dtype)
self.norm2 = nn.LayerNorm(dim, dtype=dtype)
self.norm3 = nn.LayerNorm(dim, dtype=dtype)
heads=n_heads, dim_head=d_head, dropout=dropout, dtype=dtype, device=device) # is self-attn if context is none
self.norm1 = nn.LayerNorm(dim, dtype=dtype, device=device)
self.norm2 = nn.LayerNorm(dim, dtype=dtype, device=device)
self.norm3 = nn.LayerNorm(dim, dtype=dtype, device=device)
self.checkpoint = checkpoint
self.n_heads = n_heads
self.d_head = d_head
@ -648,34 +648,34 @@ class SpatialTransformer(nn.Module):
def __init__(self, in_channels, n_heads, d_head,
depth=1, dropout=0., context_dim=None,
disable_self_attn=False, use_linear=False,
use_checkpoint=True, dtype=None):
use_checkpoint=True, dtype=None, device=None):
super().__init__()
if exists(context_dim) and not isinstance(context_dim, list):
context_dim = [context_dim] * depth
self.in_channels = in_channels
inner_dim = n_heads * d_head
self.norm = Normalize(in_channels, dtype=dtype)
self.norm = Normalize(in_channels, dtype=dtype, device=device)
if not use_linear:
self.proj_in = nn.Conv2d(in_channels,
inner_dim,
kernel_size=1,
stride=1,
padding=0, dtype=dtype)
padding=0, dtype=dtype, device=device)
else:
self.proj_in = comfy.ops.Linear(in_channels, inner_dim, dtype=dtype)
self.proj_in = comfy.ops.Linear(in_channels, inner_dim, dtype=dtype, device=device)
self.transformer_blocks = nn.ModuleList(
[BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
disable_self_attn=disable_self_attn, checkpoint=use_checkpoint, dtype=dtype)
disable_self_attn=disable_self_attn, checkpoint=use_checkpoint, dtype=dtype, device=device)
for d in range(depth)]
)
if not use_linear:
self.proj_out = nn.Conv2d(inner_dim,in_channels,
kernel_size=1,
stride=1,
padding=0, dtype=dtype)
padding=0, dtype=dtype, device=device)
else:
self.proj_out = comfy.ops.Linear(in_channels, inner_dim, dtype=dtype)
self.proj_out = comfy.ops.Linear(in_channels, inner_dim, dtype=dtype, device=device)
self.use_linear = use_linear
def forward(self, x, context=None, transformer_options={}):

View File

@ -8,6 +8,7 @@ from typing import Optional, Any
from ..attention import MemoryEfficientCrossAttention
from comfy import model_management
import comfy.ops
if model_management.xformers_enabled_vae():
import xformers
@ -48,7 +49,7 @@ class Upsample(nn.Module):
super().__init__()
self.with_conv = with_conv
if self.with_conv:
self.conv = torch.nn.Conv2d(in_channels,
self.conv = comfy.ops.Conv2d(in_channels,
in_channels,
kernel_size=3,
stride=1,
@ -67,7 +68,7 @@ class Downsample(nn.Module):
self.with_conv = with_conv
if self.with_conv:
# no asymmetric padding in torch conv, must do it ourselves
self.conv = torch.nn.Conv2d(in_channels,
self.conv = comfy.ops.Conv2d(in_channels,
in_channels,
kernel_size=3,
stride=2,
@ -95,30 +96,30 @@ class ResnetBlock(nn.Module):
self.swish = torch.nn.SiLU(inplace=True)
self.norm1 = Normalize(in_channels)
self.conv1 = torch.nn.Conv2d(in_channels,
self.conv1 = comfy.ops.Conv2d(in_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1)
if temb_channels > 0:
self.temb_proj = torch.nn.Linear(temb_channels,
self.temb_proj = comfy.ops.Linear(temb_channels,
out_channels)
self.norm2 = Normalize(out_channels)
self.dropout = torch.nn.Dropout(dropout, inplace=True)
self.conv2 = torch.nn.Conv2d(out_channels,
self.conv2 = comfy.ops.Conv2d(out_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
self.conv_shortcut = torch.nn.Conv2d(in_channels,
self.conv_shortcut = comfy.ops.Conv2d(in_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1)
else:
self.nin_shortcut = torch.nn.Conv2d(in_channels,
self.nin_shortcut = comfy.ops.Conv2d(in_channels,
out_channels,
kernel_size=1,
stride=1,
@ -188,22 +189,22 @@ class AttnBlock(nn.Module):
self.in_channels = in_channels
self.norm = Normalize(in_channels)
self.q = torch.nn.Conv2d(in_channels,
self.q = comfy.ops.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.k = torch.nn.Conv2d(in_channels,
self.k = comfy.ops.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.v = torch.nn.Conv2d(in_channels,
self.v = comfy.ops.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.proj_out = torch.nn.Conv2d(in_channels,
self.proj_out = comfy.ops.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
@ -243,22 +244,22 @@ class MemoryEfficientAttnBlock(nn.Module):
self.in_channels = in_channels
self.norm = Normalize(in_channels)
self.q = torch.nn.Conv2d(in_channels,
self.q = comfy.ops.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.k = torch.nn.Conv2d(in_channels,
self.k = comfy.ops.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.v = torch.nn.Conv2d(in_channels,
self.v = comfy.ops.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.proj_out = torch.nn.Conv2d(in_channels,
self.proj_out = comfy.ops.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
@ -302,22 +303,22 @@ class MemoryEfficientAttnBlockPytorch(nn.Module):
self.in_channels = in_channels
self.norm = Normalize(in_channels)
self.q = torch.nn.Conv2d(in_channels,
self.q = comfy.ops.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.k = torch.nn.Conv2d(in_channels,
self.k = comfy.ops.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.v = torch.nn.Conv2d(in_channels,
self.v = comfy.ops.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.proj_out = torch.nn.Conv2d(in_channels,
self.proj_out = comfy.ops.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
@ -399,14 +400,14 @@ class Model(nn.Module):
# timestep embedding
self.temb = nn.Module()
self.temb.dense = nn.ModuleList([
torch.nn.Linear(self.ch,
comfy.ops.Linear(self.ch,
self.temb_ch),
torch.nn.Linear(self.temb_ch,
comfy.ops.Linear(self.temb_ch,
self.temb_ch),
])
# downsampling
self.conv_in = torch.nn.Conv2d(in_channels,
self.conv_in = comfy.ops.Conv2d(in_channels,
self.ch,
kernel_size=3,
stride=1,
@ -475,7 +476,7 @@ class Model(nn.Module):
# end
self.norm_out = Normalize(block_in)
self.conv_out = torch.nn.Conv2d(block_in,
self.conv_out = comfy.ops.Conv2d(block_in,
out_ch,
kernel_size=3,
stride=1,
@ -548,7 +549,7 @@ class Encoder(nn.Module):
self.in_channels = in_channels
# downsampling
self.conv_in = torch.nn.Conv2d(in_channels,
self.conv_in = comfy.ops.Conv2d(in_channels,
self.ch,
kernel_size=3,
stride=1,
@ -593,7 +594,7 @@ class Encoder(nn.Module):
# end
self.norm_out = Normalize(block_in)
self.conv_out = torch.nn.Conv2d(block_in,
self.conv_out = comfy.ops.Conv2d(block_in,
2*z_channels if double_z else z_channels,
kernel_size=3,
stride=1,
@ -653,7 +654,7 @@ class Decoder(nn.Module):
self.z_shape, np.prod(self.z_shape)))
# z to block_in
self.conv_in = torch.nn.Conv2d(z_channels,
self.conv_in = comfy.ops.Conv2d(z_channels,
block_in,
kernel_size=3,
stride=1,
@ -695,7 +696,7 @@ class Decoder(nn.Module):
# end
self.norm_out = Normalize(block_in)
self.conv_out = torch.nn.Conv2d(block_in,
self.conv_out = comfy.ops.Conv2d(block_in,
out_ch,
kernel_size=3,
stride=1,

View File

@ -19,45 +19,6 @@ from ..attention import SpatialTransformer
from comfy.ldm.util import exists
# dummy replace
def convert_module_to_f16(x):
pass
def convert_module_to_f32(x):
pass
## go
class AttentionPool2d(nn.Module):
"""
Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
"""
def __init__(
self,
spacial_dim: int,
embed_dim: int,
num_heads_channels: int,
output_dim: int = None,
):
super().__init__()
self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5)
self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
self.num_heads = embed_dim // num_heads_channels
self.attention = QKVAttention(self.num_heads)
def forward(self, x):
b, c, *_spatial = x.shape
x = x.reshape(b, c, -1) # NC(HW)
x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
x = self.qkv_proj(x)
x = self.attention(x)
x = self.c_proj(x)
return x[:, :, 0]
class TimestepBlock(nn.Module):
"""
Any module where forward() takes timestep embeddings as a second argument.
@ -111,14 +72,14 @@ class Upsample(nn.Module):
upsampling occurs in the inner-two dimensions.
"""
def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None):
def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None, device=None):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.dims = dims
if use_conv:
self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding, dtype=dtype)
self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding, dtype=dtype, device=device)
def forward(self, x, output_shape=None):
assert x.shape[1] == self.channels
@ -138,19 +99,6 @@ class Upsample(nn.Module):
x = self.conv(x)
return x
class TransposedUpsample(nn.Module):
'Learned 2x upsampling without padding'
def __init__(self, channels, out_channels=None, ks=5):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.up = nn.ConvTranspose2d(self.channels,self.out_channels,kernel_size=ks,stride=2)
def forward(self,x):
return self.up(x)
class Downsample(nn.Module):
"""
A downsampling layer with an optional convolution.
@ -160,7 +108,7 @@ class Downsample(nn.Module):
downsampling occurs in the inner-two dimensions.
"""
def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None):
def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None, device=None):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
@ -169,7 +117,7 @@ class Downsample(nn.Module):
stride = 2 if dims != 3 else (1, 2, 2)
if use_conv:
self.op = conv_nd(
dims, self.channels, self.out_channels, 3, stride=stride, padding=padding, dtype=dtype
dims, self.channels, self.out_channels, 3, stride=stride, padding=padding, dtype=dtype, device=device
)
else:
assert self.channels == self.out_channels
@ -208,7 +156,8 @@ class ResBlock(TimestepBlock):
use_checkpoint=False,
up=False,
down=False,
dtype=None
dtype=None,
device=None,
):
super().__init__()
self.channels = channels
@ -220,19 +169,19 @@ class ResBlock(TimestepBlock):
self.use_scale_shift_norm = use_scale_shift_norm
self.in_layers = nn.Sequential(
nn.GroupNorm(32, channels, dtype=dtype),
nn.GroupNorm(32, channels, dtype=dtype, device=device),
nn.SiLU(),
conv_nd(dims, channels, self.out_channels, 3, padding=1, dtype=dtype),
conv_nd(dims, channels, self.out_channels, 3, padding=1, dtype=dtype, device=device),
)
self.updown = up or down
if up:
self.h_upd = Upsample(channels, False, dims, dtype=dtype)
self.x_upd = Upsample(channels, False, dims, dtype=dtype)
self.h_upd = Upsample(channels, False, dims, dtype=dtype, device=device)
self.x_upd = Upsample(channels, False, dims, dtype=dtype, device=device)
elif down:
self.h_upd = Downsample(channels, False, dims, dtype=dtype)
self.x_upd = Downsample(channels, False, dims, dtype=dtype)
self.h_upd = Downsample(channels, False, dims, dtype=dtype, device=device)
self.x_upd = Downsample(channels, False, dims, dtype=dtype, device=device)
else:
self.h_upd = self.x_upd = nn.Identity()
@ -240,15 +189,15 @@ class ResBlock(TimestepBlock):
nn.SiLU(),
linear(
emb_channels,
2 * self.out_channels if use_scale_shift_norm else self.out_channels, dtype=dtype
2 * self.out_channels if use_scale_shift_norm else self.out_channels, dtype=dtype, device=device
),
)
self.out_layers = nn.Sequential(
nn.GroupNorm(32, self.out_channels, dtype=dtype),
nn.GroupNorm(32, self.out_channels, dtype=dtype, device=device),
nn.SiLU(),
nn.Dropout(p=dropout),
zero_module(
conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1, dtype=dtype)
conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1, dtype=dtype, device=device)
),
)
@ -256,10 +205,10 @@ class ResBlock(TimestepBlock):
self.skip_connection = nn.Identity()
elif use_conv:
self.skip_connection = conv_nd(
dims, channels, self.out_channels, 3, padding=1, dtype=dtype
dims, channels, self.out_channels, 3, padding=1, dtype=dtype, device=device
)
else:
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1, dtype=dtype)
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1, dtype=dtype, device=device)
def forward(self, x, emb):
"""
@ -295,142 +244,6 @@ class ResBlock(TimestepBlock):
h = self.out_layers(h)
return self.skip_connection(x) + h
class AttentionBlock(nn.Module):
"""
An attention block that allows spatial positions to attend to each other.
Originally ported from here, but adapted to the N-d case.
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
"""
def __init__(
self,
channels,
num_heads=1,
num_head_channels=-1,
use_checkpoint=False,
use_new_attention_order=False,
):
super().__init__()
self.channels = channels
if num_head_channels == -1:
self.num_heads = num_heads
else:
assert (
channels % num_head_channels == 0
), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
self.num_heads = channels // num_head_channels
self.use_checkpoint = use_checkpoint
self.norm = normalization(channels)
self.qkv = conv_nd(1, channels, channels * 3, 1)
if use_new_attention_order:
# split qkv before split heads
self.attention = QKVAttention(self.num_heads)
else:
# split heads before split qkv
self.attention = QKVAttentionLegacy(self.num_heads)
self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
def forward(self, x):
return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
#return pt_checkpoint(self._forward, x) # pytorch
def _forward(self, x):
b, c, *spatial = x.shape
x = x.reshape(b, c, -1)
qkv = self.qkv(self.norm(x))
h = self.attention(qkv)
h = self.proj_out(h)
return (x + h).reshape(b, c, *spatial)
def count_flops_attn(model, _x, y):
"""
A counter for the `thop` package to count the operations in an
attention operation.
Meant to be used like:
macs, params = thop.profile(
model,
inputs=(inputs, timestamps),
custom_ops={QKVAttention: QKVAttention.count_flops},
)
"""
b, c, *spatial = y[0].shape
num_spatial = int(np.prod(spatial))
# We perform two matmuls with the same number of ops.
# The first computes the weight matrix, the second computes
# the combination of the value vectors.
matmul_ops = 2 * b * (num_spatial ** 2) * c
model.total_ops += th.DoubleTensor([matmul_ops])
class QKVAttentionLegacy(nn.Module):
"""
A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
"""
def __init__(self, n_heads):
super().__init__()
self.n_heads = n_heads
def forward(self, qkv):
"""
Apply QKV attention.
:param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
:return: an [N x (H * C) x T] tensor after attention.
"""
bs, width, length = qkv.shape
assert width % (3 * self.n_heads) == 0
ch = width // (3 * self.n_heads)
q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
scale = 1 / math.sqrt(math.sqrt(ch))
weight = th.einsum(
"bct,bcs->bts", q * scale, k * scale
) # More stable with f16 than dividing afterwards
weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
a = th.einsum("bts,bcs->bct", weight, v)
return a.reshape(bs, -1, length)
@staticmethod
def count_flops(model, _x, y):
return count_flops_attn(model, _x, y)
class QKVAttention(nn.Module):
"""
A module which performs QKV attention and splits in a different order.
"""
def __init__(self, n_heads):
super().__init__()
self.n_heads = n_heads
def forward(self, qkv):
"""
Apply QKV attention.
:param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
:return: an [N x (H * C) x T] tensor after attention.
"""
bs, width, length = qkv.shape
assert width % (3 * self.n_heads) == 0
ch = width // (3 * self.n_heads)
q, k, v = qkv.chunk(3, dim=1)
scale = 1 / math.sqrt(math.sqrt(ch))
weight = th.einsum(
"bct,bcs->bts",
(q * scale).view(bs * self.n_heads, ch, length),
(k * scale).view(bs * self.n_heads, ch, length),
) # More stable with f16 than dividing afterwards
weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
return a.reshape(bs, -1, length)
@staticmethod
def count_flops(model, _x, y):
return count_flops_attn(model, _x, y)
class Timestep(nn.Module):
def __init__(self, dim):
super().__init__()
@ -503,8 +316,10 @@ class UNetModel(nn.Module):
use_linear_in_transformer=False,
adm_in_channels=None,
transformer_depth_middle=None,
device=None,
):
super().__init__()
assert use_spatial_transformer == True, "use_spatial_transformer has to be true"
if use_spatial_transformer:
assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
@ -564,9 +379,9 @@ class UNetModel(nn.Module):
time_embed_dim = model_channels * 4
self.time_embed = nn.Sequential(
linear(model_channels, time_embed_dim, dtype=self.dtype),
linear(model_channels, time_embed_dim, dtype=self.dtype, device=device),
nn.SiLU(),
linear(time_embed_dim, time_embed_dim, dtype=self.dtype),
linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
)
if self.num_classes is not None:
@ -579,9 +394,9 @@ class UNetModel(nn.Module):
assert adm_in_channels is not None
self.label_emb = nn.Sequential(
nn.Sequential(
linear(adm_in_channels, time_embed_dim, dtype=self.dtype),
linear(adm_in_channels, time_embed_dim, dtype=self.dtype, device=device),
nn.SiLU(),
linear(time_embed_dim, time_embed_dim, dtype=self.dtype),
linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
)
)
else:
@ -590,7 +405,7 @@ class UNetModel(nn.Module):
self.input_blocks = nn.ModuleList(
[
TimestepEmbedSequential(
conv_nd(dims, in_channels, model_channels, 3, padding=1, dtype=self.dtype)
conv_nd(dims, in_channels, model_channels, 3, padding=1, dtype=self.dtype, device=device)
)
]
)
@ -609,7 +424,8 @@ class UNetModel(nn.Module):
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
dtype=self.dtype
dtype=self.dtype,
device=device,
)
]
ch = mult * model_channels
@ -628,17 +444,10 @@ class UNetModel(nn.Module):
disabled_sa = False
if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
layers.append(
AttentionBlock(
ch,
use_checkpoint=use_checkpoint,
num_heads=num_heads,
num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order,
) if not use_spatial_transformer else SpatialTransformer(
layers.append(SpatialTransformer(
ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim,
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint, dtype=self.dtype
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device
)
)
self.input_blocks.append(TimestepEmbedSequential(*layers))
@ -657,11 +466,12 @@ class UNetModel(nn.Module):
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
down=True,
dtype=self.dtype
dtype=self.dtype,
device=device,
)
if resblock_updown
else Downsample(
ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype
ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype, device=device
)
)
)
@ -686,18 +496,13 @@ class UNetModel(nn.Module):
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
dtype=self.dtype
dtype=self.dtype,
device=device,
),
AttentionBlock(
ch,
use_checkpoint=use_checkpoint,
num_heads=num_heads,
num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order,
) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn
SpatialTransformer( # always uses a self-attn
ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint, dtype=self.dtype
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device
),
ResBlock(
ch,
@ -706,7 +511,8 @@ class UNetModel(nn.Module):
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
dtype=self.dtype
dtype=self.dtype,
device=device,
),
)
self._feature_size += ch
@ -724,7 +530,8 @@ class UNetModel(nn.Module):
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
dtype=self.dtype
dtype=self.dtype,
device=device,
)
]
ch = model_channels * mult
@ -744,16 +551,10 @@ class UNetModel(nn.Module):
if not exists(num_attention_blocks) or i < num_attention_blocks[level]:
layers.append(
AttentionBlock(
ch,
use_checkpoint=use_checkpoint,
num_heads=num_heads_upsample,
num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order,
) if not use_spatial_transformer else SpatialTransformer(
SpatialTransformer(
ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim,
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint, dtype=self.dtype
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device
)
)
if level and i == self.num_res_blocks[level]:
@ -768,43 +569,28 @@ class UNetModel(nn.Module):
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
up=True,
dtype=self.dtype
dtype=self.dtype,
device=device,
)
if resblock_updown
else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype)
else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype, device=device)
)
ds //= 2
self.output_blocks.append(TimestepEmbedSequential(*layers))
self._feature_size += ch
self.out = nn.Sequential(
nn.GroupNorm(32, ch, dtype=self.dtype),
nn.GroupNorm(32, ch, dtype=self.dtype, device=device),
nn.SiLU(),
zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1, dtype=self.dtype)),
zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1, dtype=self.dtype, device=device)),
)
if self.predict_codebook_ids:
self.id_predictor = nn.Sequential(
nn.GroupNorm(32, ch, dtype=self.dtype),
conv_nd(dims, model_channels, n_embed, 1),
nn.GroupNorm(32, ch, dtype=self.dtype, device=device),
conv_nd(dims, model_channels, n_embed, 1, dtype=self.dtype, device=device),
#nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
)
def convert_to_fp16(self):
"""
Convert the torso of the model to float16.
"""
self.input_blocks.apply(convert_module_to_f16)
self.middle_block.apply(convert_module_to_f16)
self.output_blocks.apply(convert_module_to_f16)
def convert_to_fp32(self):
"""
Convert the torso of the model to float32.
"""
self.input_blocks.apply(convert_module_to_f32)
self.middle_block.apply(convert_module_to_f32)
self.output_blocks.apply(convert_module_to_f32)
def forward(self, x, timesteps=None, context=None, y=None, control=None, transformer_options={}, **kwargs):
"""
Apply the model to an input batch.

View File

@ -4,27 +4,27 @@ from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugme
from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule
from comfy.ldm.modules.diffusionmodules.openaimodel import Timestep
import numpy as np
from enum import Enum
from . import utils
class ModelType(Enum):
EPS = 1
V_PREDICTION = 2
class BaseModel(torch.nn.Module):
def __init__(self, model_config, v_prediction=False):
def __init__(self, model_config, model_type=ModelType.EPS, device=None):
super().__init__()
unet_config = model_config.unet_config
self.latent_format = model_config.latent_format
self.model_config = model_config
self.register_schedule(given_betas=None, beta_schedule="linear", timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3)
self.diffusion_model = UNetModel(**unet_config)
self.v_prediction = v_prediction
if self.v_prediction:
self.parameterization = "v"
else:
self.parameterization = "eps"
self.diffusion_model = UNetModel(**unet_config, device=device)
self.model_type = model_type
self.adm_channels = unet_config.get("adm_in_channels", None)
if self.adm_channels is None:
self.adm_channels = 0
print("v_prediction", v_prediction)
print("model_type", model_type.name)
print("adm", self.adm_channels)
def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
@ -99,53 +99,58 @@ class BaseModel(torch.nn.Module):
if self.get_dtype() == torch.float16:
clip_state_dict = utils.convert_sd_to(clip_state_dict, torch.float16)
vae_state_dict = utils.convert_sd_to(vae_state_dict, torch.float16)
if self.model_type == ModelType.V_PREDICTION:
unet_state_dict["v_pred"] = torch.tensor([])
return {**unet_state_dict, **vae_state_dict, **clip_state_dict}
def unclip_adm(unclip_conditioning, device, noise_augmentor, noise_augment_merge=0.0):
adm_inputs = []
weights = []
noise_aug = []
for unclip_cond in unclip_conditioning:
for adm_cond in unclip_cond["clip_vision_output"].image_embeds:
weight = unclip_cond["strength"]
noise_augment = unclip_cond["noise_augmentation"]
noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment)
c_adm, noise_level_emb = noise_augmentor(adm_cond.to(device), noise_level=torch.tensor([noise_level], device=device))
adm_out = torch.cat((c_adm, noise_level_emb), 1) * weight
weights.append(weight)
noise_aug.append(noise_augment)
adm_inputs.append(adm_out)
if len(noise_aug) > 1:
adm_out = torch.stack(adm_inputs).sum(0)
noise_augment = noise_augment_merge
noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment)
c_adm, noise_level_emb = noise_augmentor(adm_out[:, :noise_augmentor.time_embed.dim], noise_level=torch.tensor([noise_level], device=device))
adm_out = torch.cat((c_adm, noise_level_emb), 1)
return adm_out
class SD21UNCLIP(BaseModel):
def __init__(self, model_config, noise_aug_config, v_prediction=True):
super().__init__(model_config, v_prediction)
def __init__(self, model_config, noise_aug_config, model_type=ModelType.V_PREDICTION, device=None):
super().__init__(model_config, model_type, device=device)
self.noise_augmentor = CLIPEmbeddingNoiseAugmentation(**noise_aug_config)
def encode_adm(self, **kwargs):
unclip_conditioning = kwargs.get("unclip_conditioning", None)
device = kwargs["device"]
if unclip_conditioning is not None:
adm_inputs = []
weights = []
noise_aug = []
for unclip_cond in unclip_conditioning:
adm_cond = unclip_cond["clip_vision_output"].image_embeds
weight = unclip_cond["strength"]
noise_augment = unclip_cond["noise_augmentation"]
noise_level = round((self.noise_augmentor.max_noise_level - 1) * noise_augment)
c_adm, noise_level_emb = self.noise_augmentor(adm_cond.to(device), noise_level=torch.tensor([noise_level], device=device))
adm_out = torch.cat((c_adm, noise_level_emb), 1) * weight
weights.append(weight)
noise_aug.append(noise_augment)
adm_inputs.append(adm_out)
if len(noise_aug) > 1:
adm_out = torch.stack(adm_inputs).sum(0)
#TODO: add a way to control this
noise_augment = 0.05
noise_level = round((self.noise_augmentor.max_noise_level - 1) * noise_augment)
c_adm, noise_level_emb = self.noise_augmentor(adm_out[:, :self.noise_augmentor.time_embed.dim], noise_level=torch.tensor([noise_level], device=device))
adm_out = torch.cat((c_adm, noise_level_emb), 1)
if unclip_conditioning is None:
return torch.zeros((1, self.adm_channels))
else:
adm_out = torch.zeros((1, self.adm_channels))
return unclip_adm(unclip_conditioning, device, self.noise_augmentor, kwargs.get("unclip_noise_augment_merge", 0.05))
return adm_out
class SDInpaint(BaseModel):
def __init__(self, model_config, v_prediction=False):
super().__init__(model_config, v_prediction)
def __init__(self, model_config, model_type=ModelType.EPS, device=None):
super().__init__(model_config, model_type, device=device)
self.concat_keys = ("mask", "masked_image")
class SDXLRefiner(BaseModel):
def __init__(self, model_config, v_prediction=False):
super().__init__(model_config, v_prediction)
def __init__(self, model_config, model_type=ModelType.EPS, device=None):
super().__init__(model_config, model_type, device=device)
self.embedder = Timestep(256)
def encode_adm(self, **kwargs):
@ -160,7 +165,6 @@ class SDXLRefiner(BaseModel):
else:
aesthetic_score = kwargs.get("aesthetic_score", 6)
print(clip_pooled.shape, width, height, crop_w, crop_h, aesthetic_score)
out = []
out.append(self.embedder(torch.Tensor([height])))
out.append(self.embedder(torch.Tensor([width])))
@ -171,8 +175,8 @@ class SDXLRefiner(BaseModel):
return torch.cat((clip_pooled.to(flat.device), flat), dim=1)
class SDXL(BaseModel):
def __init__(self, model_config, v_prediction=False):
super().__init__(model_config, v_prediction)
def __init__(self, model_config, model_type=ModelType.EPS, device=None):
super().__init__(model_config, model_type, device=device)
self.embedder = Timestep(256)
def encode_adm(self, **kwargs):
@ -184,7 +188,6 @@ class SDXL(BaseModel):
target_width = kwargs.get("target_width", width)
target_height = kwargs.get("target_height", height)
print(clip_pooled.shape, width, height, crop_w, crop_h, target_width, target_height)
out = []
out.append(self.embedder(torch.Tensor([height])))
out.append(self.embedder(torch.Tensor([width])))

View File

@ -113,8 +113,63 @@ def model_config_from_unet_config(unet_config):
if model_config.matches(unet_config):
return model_config(unet_config)
print("no match", unet_config)
return None
def model_config_from_unet(state_dict, unet_key_prefix, use_fp16):
unet_config = detect_unet_config(state_dict, unet_key_prefix, use_fp16)
return model_config_from_unet_config(unet_config)
def model_config_from_diffusers_unet(state_dict, use_fp16):
match = {}
match["context_dim"] = state_dict["down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_k.weight"].shape[1]
match["model_channels"] = state_dict["conv_in.weight"].shape[0]
match["in_channels"] = state_dict["conv_in.weight"].shape[1]
match["adm_in_channels"] = None
if "class_embedding.linear_1.weight" in state_dict:
match["adm_in_channels"] = state_dict["class_embedding.linear_1.weight"].shape[1]
elif "add_embedding.linear_1.weight" in state_dict:
match["adm_in_channels"] = state_dict["add_embedding.linear_1.weight"].shape[1]
SDXL = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 2816, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320,
'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 2, 10], 'channel_mult': [1, 2, 4],
'transformer_depth_middle': 10, 'use_linear_in_transformer': True, 'context_dim': 2048}
SDXL_refiner = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 2560, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 384,
'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 4, 4, 0], 'channel_mult': [1, 2, 4, 4],
'transformer_depth_middle': 4, 'use_linear_in_transformer': True, 'context_dim': 1280}
SD21 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'adm_in_channels': None, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2,
'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024}
SD21_uncliph = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 2048, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320,
'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024}
SD21_unclipl = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 1536, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320,
'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024}
SD15 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'adm_in_channels': None, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2,
'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
'transformer_depth_middle': 1, 'use_linear_in_transformer': False, 'context_dim': 768}
supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl]
for unet_config in supported_models:
matches = True
for k in match:
if match[k] != unet_config[k]:
matches = False
break
if matches:
return model_config_from_unet_config(unet_config)
return None

View File

@ -49,6 +49,7 @@ except:
try:
if torch.backends.mps.is_available():
cpu_state = CPUState.MPS
import torch.mps
except:
pass
@ -204,7 +205,11 @@ print(f"Set vram state to: {vram_state.name}")
def get_torch_device_name(device):
if hasattr(device, 'type'):
if device.type == "cuda":
return "{} {}".format(device, torch.cuda.get_device_name(device))
try:
allocator_backend = torch.cuda.get_allocator_backend()
except:
allocator_backend = ""
return "{} {} : {}".format(device, torch.cuda.get_device_name(device), allocator_backend)
else:
return "{}".format(device.type)
else:
@ -233,10 +238,9 @@ def unload_model():
accelerate.hooks.remove_hook_from_submodules(current_loaded_model.model)
model_accelerated = False
current_loaded_model.unpatch_model()
current_loaded_model.model.to(current_loaded_model.offload_device)
current_loaded_model.model_patches_to(current_loaded_model.offload_device)
current_loaded_model.unpatch_model()
current_loaded_model = None
if vram_state != VRAMState.HIGH_VRAM:
soft_empty_cache()
@ -258,15 +262,11 @@ def load_model_gpu(model):
if model is current_loaded_model:
return
unload_model()
try:
real_model = model.patch_model()
except Exception as e:
model.unpatch_model()
raise e
torch_dev = model.load_device
model.model_patches_to(torch_dev)
model.model_patches_to(model.model_dtype())
current_loaded_model = model
if is_device_cpu(torch_dev):
vram_set_state = VRAMState.DISABLED
@ -280,21 +280,33 @@ def load_model_gpu(model):
if model_size > (current_free_mem - minimum_inference_memory()): #only switch to lowvram if really necessary
vram_set_state = VRAMState.LOW_VRAM
current_loaded_model = model
real_model = model.model
patch_model_to = None
if vram_set_state == VRAMState.DISABLED:
pass
elif vram_set_state == VRAMState.NORMAL_VRAM or vram_set_state == VRAMState.HIGH_VRAM or vram_set_state == VRAMState.SHARED:
model_accelerated = False
real_model.to(torch_dev)
else:
if vram_set_state == VRAMState.NO_VRAM:
device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "256MiB", "cpu": "16GiB"})
elif vram_set_state == VRAMState.LOW_VRAM:
device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "{}MiB".format(lowvram_model_memory // (1024 * 1024)), "cpu": "16GiB"})
patch_model_to = torch_dev
try:
real_model = model.patch_model(device_to=patch_model_to)
except Exception as e:
model.unpatch_model()
unload_model()
raise e
if patch_model_to is not None:
real_model.to(torch_dev)
if vram_set_state == VRAMState.NO_VRAM:
device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "256MiB", "cpu": "16GiB"})
accelerate.dispatch_model(real_model, device_map=device_map, main_device=torch_dev)
model_accelerated = True
elif vram_set_state == VRAMState.LOW_VRAM:
device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "{}MiB".format(lowvram_model_memory // (1024 * 1024)), "cpu": "16GiB"})
accelerate.dispatch_model(real_model, device_map=device_map, main_device=torch_dev)
model_accelerated = True
return current_loaded_model
def load_controlnet_gpu(control_models):
@ -352,6 +364,7 @@ def text_encoder_device():
if args.gpu_only:
return get_torch_device()
elif vram_state == VRAMState.HIGH_VRAM or vram_state == VRAMState.NORMAL_VRAM:
#NOTE: on a Ryzen 5 7600X with 4080 it's faster to shift to GPU
if torch.get_num_threads() < 8: #leaving the text encoder on the CPU is faster than shifting it if the CPU is fast enough.
return get_torch_device()
else:
@ -522,7 +535,7 @@ def should_use_fp16(device=None, model_params=0):
return False
#FP16 is just broken on these cards
nvidia_16_series = ["1660", "1650", "1630", "T500", "T550", "T600"]
nvidia_16_series = ["1660", "1650", "1630", "T500", "T550", "T600", "MX550", "MX450", "CMP 30HX"]
for x in nvidia_16_series:
if x in props.name:
return False

View File

@ -6,6 +6,7 @@ from comfy import model_management
from .ldm.models.diffusion.ddim import DDIMSampler
from .ldm.modules.diffusionmodules.util import make_ddim_timesteps
import math
from comfy import model_base
def lcm(a, b): #TODO: eventually replace by math.lcm (added in python3.9)
return abs(a*b) // math.gcd(a, b)
@ -16,6 +17,14 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
def get_area_and_mult(cond, x_in, cond_concat_in, timestep_in):
area = (x_in.shape[2], x_in.shape[3], 0, 0)
strength = 1.0
if 'timestep_start' in cond[1]:
timestep_start = cond[1]['timestep_start']
if timestep_in[0] > timestep_start:
return None
if 'timestep_end' in cond[1]:
timestep_end = cond[1]['timestep_end']
if timestep_in[0] < timestep_end:
return None
if 'area' in cond[1]:
area = cond[1]['area']
if 'strength' in cond[1]:
@ -180,12 +189,13 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
continue
to_run += [(p, COND)]
for x in uncond:
p = get_area_and_mult(x, x_in, cond_concat_in, timestep)
if p is None:
continue
if uncond is not None:
for x in uncond:
p = get_area_and_mult(x, x_in, cond_concat_in, timestep)
if p is None:
continue
to_run += [(p, UNCOND)]
to_run += [(p, UNCOND)]
while len(to_run) > 0:
first = to_run[0]
@ -247,7 +257,10 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
c['transformer_options'] = transformer_options
output = model_function(input_x, timestep_, **c).chunk(batch_chunks)
if 'model_function_wrapper' in model_options:
output = model_options['model_function_wrapper'](model_function, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks)
else:
output = model_function(input_x, timestep_, **c).chunk(batch_chunks)
del input_x
model_management.throw_exception_if_processing_interrupted()
@ -270,6 +283,9 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
max_total_area = model_management.maximum_batch_area()
if math.isclose(cond_scale, 1.0):
uncond = None
cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, timestep, max_total_area, cond_concat, model_options)
if "sampler_cfg_function" in model_options:
args = {"cond": cond, "uncond": uncond, "cond_scale": cond_scale, "timestep": timestep}
@ -331,6 +347,17 @@ def ddim_scheduler(model, steps):
sigs += [0.0]
return torch.FloatTensor(sigs)
def sgm_scheduler(model, steps):
sigs = []
timesteps = torch.linspace(model.inner_model.inner_model.num_timesteps - 1, 0, steps + 1)[:-1].type(torch.int)
for x in range(len(timesteps)):
ts = timesteps[x]
if ts > 999:
ts = 999
sigs.append(model.t_to_sigma(torch.tensor(ts)))
sigs += [0.0]
return torch.FloatTensor(sigs)
def blank_inpaint_image_like(latent_image):
blank_image = torch.ones_like(latent_image)
# these are the values for "zero" in pixel space translated to latent space
@ -424,6 +451,35 @@ def create_cond_with_same_area_if_none(conds, c):
n = c[1].copy()
conds += [[smallest[0], n]]
def calculate_start_end_timesteps(model, conds):
for t in range(len(conds)):
x = conds[t]
timestep_start = None
timestep_end = None
if 'start_percent' in x[1]:
timestep_start = model.sigma_to_t(model.t_to_sigma(torch.tensor(x[1]['start_percent'] * 999.0)))
if 'end_percent' in x[1]:
timestep_end = model.sigma_to_t(model.t_to_sigma(torch.tensor(x[1]['end_percent'] * 999.0)))
if (timestep_start is not None) or (timestep_end is not None):
n = x[1].copy()
if (timestep_start is not None):
n['timestep_start'] = timestep_start
if (timestep_end is not None):
n['timestep_end'] = timestep_end
conds[t] = [x[0], n]
def pre_run_control(model, conds):
for t in range(len(conds)):
x = conds[t]
timestep_start = None
timestep_end = None
percent_to_timestep_function = lambda a: model.sigma_to_t(model.t_to_sigma(torch.tensor(a) * 999.0))
if 'control' in x[1]:
x[1]['control'].pre_run(model.inner_model, percent_to_timestep_function)
def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func):
cond_cnets = []
cond_other = []
@ -480,19 +536,19 @@ def encode_adm(model, conds, batch_size, width, height, device, prompt_type):
class KSampler:
SCHEDULERS = ["normal", "karras", "exponential", "simple", "ddim_uniform"]
SCHEDULERS = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform"]
SAMPLERS = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral",
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu",
"dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "ddim", "uni_pc", "uni_pc_bh2"]
"dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddim", "uni_pc", "uni_pc_bh2"]
def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None, model_options={}):
self.model = model
self.model_denoise = CFGNoisePredictor(self.model)
if self.model.parameterization == "v":
if self.model.model_type == model_base.ModelType.V_PREDICTION:
self.model_wrap = CompVisVDenoiser(self.model_denoise, quantize=True)
else:
self.model_wrap = k_diffusion_external.CompVisDenoiser(self.model_denoise, quantize=True)
self.model_wrap.parameterization = self.model.parameterization
self.model_k = KSamplerX0Inpaint(self.model_wrap)
self.device = device
if scheduler not in self.SCHEDULERS:
@ -525,6 +581,8 @@ class KSampler:
sigmas = simple_scheduler(self.model_wrap, steps)
elif self.scheduler == "ddim_uniform":
sigmas = ddim_scheduler(self.model_wrap, steps)
elif self.scheduler == "sgm_uniform":
sigmas = sgm_scheduler(self.model_wrap, steps)
else:
print("error invalid scheduler", self.scheduler)
@ -567,13 +625,18 @@ class KSampler:
resolve_cond_masks(positive, noise.shape[2], noise.shape[3], self.device)
resolve_cond_masks(negative, noise.shape[2], noise.shape[3], self.device)
calculate_start_end_timesteps(self.model_wrap, negative)
calculate_start_end_timesteps(self.model_wrap, positive)
#make sure each cond area has an opposite one with the same area
for c in positive:
create_cond_with_same_area_if_none(negative, c)
for c in negative:
create_cond_with_same_area_if_none(positive, c)
apply_empty_x_to_equal_area(positive, negative, 'control', lambda cond_cnets, x: cond_cnets[x])
pre_run_control(self.model_wrap, negative + positive)
apply_empty_x_to_equal_area(list(filter(lambda c: c[1].get('control_apply_to_uncond', False) == True, positive)), negative, 'control', lambda cond_cnets, x: cond_cnets[x])
apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x])
if self.model.is_adm():
@ -614,7 +677,7 @@ class KSampler:
elif self.sampler == "ddim":
timesteps = []
for s in range(sigmas.shape[0]):
timesteps.insert(0, self.model_wrap.sigma_to_t(sigmas[s]))
timesteps.insert(0, self.model_wrap.sigma_to_discrete_timestep(sigmas[s]))
noise_mask = None
if denoise_mask is not None:
noise_mask = 1.0 - denoise_mask
@ -638,7 +701,7 @@ class KSampler:
x_T=z_enc,
x0=latent_image,
img_callback=ddim_callback,
denoise_function=sampling_function,
denoise_function=self.model_wrap.predict_eps_discrete_timestep,
extra_args=extra_args,
mask=noise_mask,
to_zero=sigmas[-1]==0,

View File

@ -70,13 +70,27 @@ def load_lora(lora, to_load):
alpha = lora[alpha_name].item()
loaded_keys.add(alpha_name)
A_name = "{}.lora_up.weight".format(x)
B_name = "{}.lora_down.weight".format(x)
mid_name = "{}.lora_mid.weight".format(x)
regular_lora = "{}.lora_up.weight".format(x)
diffusers_lora = "{}_lora.up.weight".format(x)
transformers_lora = "{}.lora_linear_layer.up.weight".format(x)
A_name = None
if A_name in lora.keys():
if regular_lora in lora.keys():
A_name = regular_lora
B_name = "{}.lora_down.weight".format(x)
mid_name = "{}.lora_mid.weight".format(x)
elif diffusers_lora in lora.keys():
A_name = diffusers_lora
B_name = "{}_lora.down.weight".format(x)
mid_name = None
elif transformers_lora in lora.keys():
A_name = transformers_lora
B_name ="{}.lora_linear_layer.down.weight".format(x)
mid_name = None
if A_name is not None:
mid = None
if mid_name in lora.keys():
if mid_name is not None and mid_name in lora.keys():
mid = lora[mid_name]
loaded_keys.add(mid_name)
patch_dict[to_load[x]] = (lora[A_name], lora[B_name], alpha, mid)
@ -170,20 +184,31 @@ def model_lora_keys_clip(model, key_map={}):
if k in sdk:
lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c])
key_map[lora_key] = k
lora_key = "lora_te1_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c])
key_map[lora_key] = k
lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format(b, c) #diffusers lora
key_map[lora_key] = k
k = "clip_l.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
if k in sdk:
lora_key = "lora_te1_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #SDXL base
key_map[lora_key] = k
clip_l_present = True
lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format(b, c) #diffusers lora
key_map[lora_key] = k
k = "clip_g.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
if k in sdk:
if clip_l_present:
lora_key = "lora_te2_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #SDXL base
key_map[lora_key] = k
lora_key = "text_encoder_2.text_model.encoder.layers.{}.{}".format(b, c) #diffusers lora
key_map[lora_key] = k
else:
lora_key = "lora_te_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #TODO: test if this is correct for SDXL-Refiner
key_map[lora_key] = k
key_map[lora_key] = k
lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format(b, c) #diffusers lora
key_map[lora_key] = k
return key_map
@ -198,10 +223,26 @@ def model_lora_keys_unet(model, key_map={}):
diffusers_keys = utils.unet_to_diffusers(model.model_config.unet_config)
for k in diffusers_keys:
if k.endswith(".weight"):
unet_key = "diffusion_model.{}".format(diffusers_keys[k])
key_lora = k[:-len(".weight")].replace(".", "_")
key_map["lora_unet_{}".format(key_lora)] = "diffusion_model.{}".format(diffusers_keys[k])
key_map["lora_unet_{}".format(key_lora)] = unet_key
diffusers_lora_prefix = ["", "unet."]
for p in diffusers_lora_prefix:
diffusers_lora_key = "{}{}".format(p, k[:-len(".weight")].replace(".to_", ".processor.to_"))
if diffusers_lora_key.endswith(".to_out.0"):
diffusers_lora_key = diffusers_lora_key[:-2]
key_map[diffusers_lora_key] = unet_key
return key_map
def set_attr(obj, attr, value):
attrs = attr.split(".")
for name in attrs[:-1]:
obj = getattr(obj, name)
prev = getattr(obj, attrs[-1])
setattr(obj, attrs[-1], torch.nn.Parameter(value))
del prev
class ModelPatcher:
def __init__(self, model, load_device, offload_device, size=0):
self.size = size
@ -330,7 +371,7 @@ class ModelPatcher:
sd.pop(k)
return sd
def patch_model(self):
def patch_model(self, device_to=None):
model_sd = self.model_state_dict()
for key in self.patches:
if key not in model_sd:
@ -340,10 +381,14 @@ class ModelPatcher:
weight = model_sd[key]
if key not in self.backup:
self.backup[key] = weight.clone()
self.backup[key] = weight.to(self.offload_device)
temp_weight = weight.to(torch.float32, copy=True)
weight[:] = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
if device_to is not None:
temp_weight = weight.float().to(device_to, copy=True)
else:
temp_weight = weight.to(torch.float32, copy=True)
out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
set_attr(self.model, key, out_weight)
del temp_weight
return self.model
@ -367,15 +412,19 @@ class ModelPatcher:
else:
weight += alpha * w1.type(weight.dtype).to(weight.device)
elif len(v) == 4: #lora/locon
mat1 = v[0]
mat2 = v[1]
mat1 = v[0].float().to(weight.device)
mat2 = v[1].float().to(weight.device)
if v[2] is not None:
alpha *= v[2] / mat2.shape[0]
if v[3] is not None:
#locon mid weights, hopefully the math is fine because I didn't properly test it
final_shape = [mat2.shape[1], mat2.shape[0], v[3].shape[2], v[3].shape[3]]
mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1).float(), v[3].transpose(0, 1).flatten(start_dim=1).float()).reshape(final_shape).transpose(0, 1)
weight += (alpha * torch.mm(mat1.flatten(start_dim=1).float(), mat2.flatten(start_dim=1).float())).reshape(weight.shape).type(weight.dtype).to(weight.device)
mat3 = v[3].float().to(weight.device)
final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
try:
weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(weight.shape).type(weight.dtype)
except Exception as e:
print("ERROR", key, e)
elif len(v) == 8: #lokr
w1 = v[0]
w2 = v[1]
@ -389,20 +438,27 @@ class ModelPatcher:
if w1 is None:
dim = w1_b.shape[0]
w1 = torch.mm(w1_a.float(), w1_b.float())
else:
w1 = w1.float().to(weight.device)
if w2 is None:
dim = w2_b.shape[0]
if t2 is None:
w2 = torch.mm(w2_a.float(), w2_b.float())
w2 = torch.mm(w2_a.float().to(weight.device), w2_b.float().to(weight.device))
else:
w2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float(), w2_b.float(), w2_a.float())
w2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float().to(weight.device), w2_b.float().to(weight.device), w2_a.float().to(weight.device))
else:
w2 = w2.float().to(weight.device)
if len(w2.shape) == 4:
w1 = w1.unsqueeze(2).unsqueeze(2)
if v[2] is not None and dim is not None:
alpha *= v[2] / dim
weight += alpha * torch.kron(w1.float(), w2.float()).reshape(weight.shape).type(weight.dtype).to(weight.device)
try:
weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype)
except Exception as e:
print("ERROR", key, e)
else: #loha
w1a = v[0]
w1b = v[1]
@ -413,21 +469,24 @@ class ModelPatcher:
if v[5] is not None: #cp decomposition
t1 = v[5]
t2 = v[6]
m1 = torch.einsum('i j k l, j r, i p -> p r k l', t1.float(), w1b.float(), w1a.float())
m2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float(), w2b.float(), w2a.float())
m1 = torch.einsum('i j k l, j r, i p -> p r k l', t1.float().to(weight.device), w1b.float().to(weight.device), w1a.float().to(weight.device))
m2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float().to(weight.device), w2b.float().to(weight.device), w2a.float().to(weight.device))
else:
m1 = torch.mm(w1a.float(), w1b.float())
m2 = torch.mm(w2a.float(), w2b.float())
m1 = torch.mm(w1a.float().to(weight.device), w1b.float().to(weight.device))
m2 = torch.mm(w2a.float().to(weight.device), w2b.float().to(weight.device))
try:
weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype)
except Exception as e:
print("ERROR", key, e)
weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype).to(weight.device)
return weight
def unpatch_model(self):
model_sd = self.model_state_dict()
keys = list(self.backup.keys())
for k in keys:
model_sd[k][:] = self.backup[k]
del self.backup[k]
set_attr(self.model, k, self.backup[k])
self.backup = {}
@ -647,16 +706,57 @@ def broadcast_image_to(tensor, target_batch_size, batched_number):
else:
return torch.cat([tensor] * batched_number, dim=0)
class ControlNet:
def __init__(self, control_model, global_average_pooling=False, device=None):
self.control_model = control_model
class ControlBase:
def __init__(self, device=None):
self.cond_hint_original = None
self.cond_hint = None
self.strength = 1.0
self.timestep_percent_range = (1.0, 0.0)
self.timestep_range = None
if device is None:
device = model_management.get_torch_device()
self.device = device
self.previous_controlnet = None
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(1.0, 0.0)):
self.cond_hint_original = cond_hint
self.strength = strength
self.timestep_percent_range = timestep_percent_range
return self
def pre_run(self, model, percent_to_timestep_function):
self.timestep_range = (percent_to_timestep_function(self.timestep_percent_range[0]), percent_to_timestep_function(self.timestep_percent_range[1]))
if self.previous_controlnet is not None:
self.previous_controlnet.pre_run(model, percent_to_timestep_function)
def set_previous_controlnet(self, controlnet):
self.previous_controlnet = controlnet
return self
def cleanup(self):
if self.previous_controlnet is not None:
self.previous_controlnet.cleanup()
if self.cond_hint is not None:
del self.cond_hint
self.cond_hint = None
self.timestep_range = None
def get_models(self):
out = []
if self.previous_controlnet is not None:
out += self.previous_controlnet.get_models()
return out
def copy_to(self, c):
c.cond_hint_original = self.cond_hint_original
c.strength = self.strength
c.timestep_percent_range = self.timestep_percent_range
class ControlNet(ControlBase):
def __init__(self, control_model, global_average_pooling=False, device=None):
super().__init__(device)
self.control_model = control_model
self.global_average_pooling = global_average_pooling
def get_control(self, x_noisy, t, cond, batched_number):
@ -664,6 +764,13 @@ class ControlNet:
if self.previous_controlnet is not None:
control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)
if self.timestep_range is not None:
if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
if control_prev is not None:
return control_prev
else:
return {}
output_dtype = x_noisy.dtype
if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]:
if self.cond_hint is not None:
@ -711,37 +818,64 @@ class ControlNet:
out['input'] = control_prev['input']
return out
def set_cond_hint(self, cond_hint, strength=1.0):
self.cond_hint_original = cond_hint
self.strength = strength
return self
def set_previous_controlnet(self, controlnet):
self.previous_controlnet = controlnet
return self
def cleanup(self):
if self.previous_controlnet is not None:
self.previous_controlnet.cleanup()
if self.cond_hint is not None:
del self.cond_hint
self.cond_hint = None
def copy(self):
c = ControlNet(self.control_model, global_average_pooling=self.global_average_pooling)
c.cond_hint_original = self.cond_hint_original
c.strength = self.strength
self.copy_to(c)
return c
def get_models(self):
out = []
if self.previous_controlnet is not None:
out += self.previous_controlnet.get_models()
out = super().get_models()
out.append(self.control_model)
return out
def load_controlnet(ckpt_path, model=None):
controlnet_data = utils.load_torch_file(ckpt_path, safe_load=True)
controlnet_config = None
if "controlnet_cond_embedding.conv_in.weight" in controlnet_data: #diffusers format
use_fp16 = model_management.should_use_fp16()
controlnet_config = model_detection.model_config_from_diffusers_unet(controlnet_data, use_fp16).unet_config
diffusers_keys = utils.unet_to_diffusers(controlnet_config)
diffusers_keys["controlnet_mid_block.weight"] = "middle_block_out.0.weight"
diffusers_keys["controlnet_mid_block.bias"] = "middle_block_out.0.bias"
count = 0
loop = True
while loop:
suffix = [".weight", ".bias"]
for s in suffix:
k_in = "controlnet_down_blocks.{}{}".format(count, s)
k_out = "zero_convs.{}.0{}".format(count, s)
if k_in not in controlnet_data:
loop = False
break
diffusers_keys[k_in] = k_out
count += 1
count = 0
loop = True
while loop:
suffix = [".weight", ".bias"]
for s in suffix:
if count == 0:
k_in = "controlnet_cond_embedding.conv_in{}".format(s)
else:
k_in = "controlnet_cond_embedding.blocks.{}{}".format(count - 1, s)
k_out = "input_hint_block.{}{}".format(count * 2, s)
if k_in not in controlnet_data:
k_in = "controlnet_cond_embedding.conv_out{}".format(s)
loop = False
diffusers_keys[k_in] = k_out
count += 1
new_sd = {}
for k in diffusers_keys:
if k in controlnet_data:
new_sd[diffusers_keys[k]] = controlnet_data.pop(k)
controlnet_data = new_sd
pth_key = 'control_model.zero_convs.0.0.weight'
pth = False
key = 'zero_convs.0.0.weight'
@ -757,11 +891,11 @@ def load_controlnet(ckpt_path, model=None):
print("error checkpoint does not contain controlnet or t2i adapter data", ckpt_path)
return net
use_fp16 = model_management.should_use_fp16()
controlnet_config = model_detection.model_config_from_unet(controlnet_data, prefix, use_fp16).unet_config
if controlnet_config is None:
use_fp16 = model_management.should_use_fp16()
controlnet_config = model_detection.model_config_from_unet(controlnet_data, prefix, use_fp16).unet_config
controlnet_config.pop("out_channels")
controlnet_config["hint_channels"] = 3
controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
control_model = cldm.ControlNet(**controlnet_config)
if pth:
@ -799,24 +933,25 @@ def load_controlnet(ckpt_path, model=None):
control = ControlNet(control_model, global_average_pooling=global_average_pooling)
return control
class T2IAdapter:
class T2IAdapter(ControlBase):
def __init__(self, t2i_model, channels_in, device=None):
super().__init__(device)
self.t2i_model = t2i_model
self.channels_in = channels_in
self.strength = 1.0
if device is None:
device = model_management.get_torch_device()
self.device = device
self.previous_controlnet = None
self.control_input = None
self.cond_hint_original = None
self.cond_hint = None
def get_control(self, x_noisy, t, cond, batched_number):
control_prev = None
if self.previous_controlnet is not None:
control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)
if self.timestep_range is not None:
if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
if control_prev is not None:
return control_prev
else:
return {}
if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]:
if self.cond_hint is not None:
del self.cond_hint
@ -861,33 +996,11 @@ class T2IAdapter:
out['output'] = control_prev['output']
return out
def set_cond_hint(self, cond_hint, strength=1.0):
self.cond_hint_original = cond_hint
self.strength = strength
return self
def set_previous_controlnet(self, controlnet):
self.previous_controlnet = controlnet
return self
def copy(self):
c = T2IAdapter(self.t2i_model, self.channels_in)
c.cond_hint_original = self.cond_hint_original
c.strength = self.strength
self.copy_to(c)
return c
def cleanup(self):
if self.previous_controlnet is not None:
self.previous_controlnet.cleanup()
if self.cond_hint is not None:
del self.cond_hint
self.cond_hint = None
def get_models(self):
out = []
if self.previous_controlnet is not None:
out += self.previous_controlnet.get_models()
return out
def load_t2i_adapter(t2i_data):
keys = t2i_data.keys()
@ -997,11 +1110,11 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
if "noise_aug_config" in model_config_params:
noise_aug_config = model_config_params["noise_aug_config"]
v_prediction = False
model_type = model_base.ModelType.EPS
if "parameterization" in model_config_params:
if model_config_params["parameterization"] == "v":
v_prediction = True
model_type = model_base.ModelType.V_PREDICTION
clip = None
vae = None
@ -1021,11 +1134,11 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
model_config.latent_format = latent_formats.SD15(scale_factor=scale_factor)
if config['model']["target"].endswith("LatentInpaintDiffusion"):
model = model_base.SDInpaint(model_config, v_prediction=v_prediction)
model = model_base.SDInpaint(model_config, model_type=model_type)
elif config['model']["target"].endswith("ImageEmbeddingConditionedLatentDiffusion"):
model = model_base.SD21UNCLIP(model_config, noise_aug_config["params"], v_prediction=v_prediction)
model = model_base.SD21UNCLIP(model_config, noise_aug_config["params"], model_type=model_type)
else:
model = model_base.BaseModel(model_config, v_prediction=v_prediction)
model = model_base.BaseModel(model_config, model_type=model_type)
if fp16:
model = model.half()
@ -1087,8 +1200,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
clipvision = clip_vision.load_clipvision_from_sd(sd, model_config.clip_vision_prefix, True)
offload_device = model_management.unet_offload_device()
model = model_config.get_model(sd, "model.diffusion_model.")
model = model.to(offload_device)
model = model_config.get_model(sd, "model.diffusion_model.", device=offload_device)
model.load_model_weights(sd, "model.diffusion_model.")
if output_vae:
@ -1117,66 +1229,24 @@ def load_unet(unet_path): #load unet in diffusers format
parameters = calculate_parameters(sd, "")
fp16 = model_management.should_use_fp16(model_params=parameters)
match = {}
match["context_dim"] = sd["down_blocks.0.attentions.1.transformer_blocks.0.attn2.to_k.weight"].shape[1]
match["model_channels"] = sd["conv_in.weight"].shape[0]
match["in_channels"] = sd["conv_in.weight"].shape[1]
match["adm_in_channels"] = None
if "class_embedding.linear_1.weight" in sd:
match["adm_in_channels"] = sd["class_embedding.linear_1.weight"].shape[1]
model_config = model_detection.model_config_from_diffusers_unet(sd, fp16)
if model_config is None:
print("ERROR UNSUPPORTED UNET", unet_path)
return None
SDXL = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 2816, 'use_fp16': fp16, 'in_channels': 4, 'model_channels': 320,
'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 2, 10], 'channel_mult': [1, 2, 4],
'transformer_depth_middle': 10, 'use_linear_in_transformer': True, 'context_dim': 2048}
diffusers_keys = utils.unet_to_diffusers(model_config.unet_config)
SDXL_refiner = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 2560, 'use_fp16': fp16, 'in_channels': 4, 'model_channels': 384,
'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 4, 4, 0], 'channel_mult': [1, 2, 4, 4],
'transformer_depth_middle': 4, 'use_linear_in_transformer': True, 'context_dim': 1280}
SD21 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'adm_in_channels': None, 'use_fp16': fp16, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2,
'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024}
SD21_uncliph = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 2048, 'use_fp16': True, 'in_channels': 4, 'model_channels': 320,
'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024}
SD21_unclipl = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 1536, 'use_fp16': True, 'in_channels': 4, 'model_channels': 320,
'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024}
SD15 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'adm_in_channels': None, 'use_fp16': True, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2,
'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
'transformer_depth_middle': 1, 'use_linear_in_transformer': False, 'context_dim': 768}
supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl]
print("match", match)
for unet_config in supported_models:
matches = True
for k in match:
if match[k] != unet_config[k]:
matches = False
break
if matches:
diffusers_keys = utils.unet_to_diffusers(unet_config)
new_sd = {}
for k in diffusers_keys:
if k in sd:
new_sd[diffusers_keys[k]] = sd.pop(k)
else:
print(diffusers_keys[k], k)
offload_device = model_management.unet_offload_device()
model_config = model_detection.model_config_from_unet_config(unet_config)
model = model_config.get_model(new_sd, "")
model = model.to(offload_device)
model.load_model_weights(new_sd, "")
return ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device)
new_sd = {}
for k in diffusers_keys:
if k in sd:
new_sd[diffusers_keys[k]] = sd.pop(k)
else:
print(diffusers_keys[k], k)
offload_device = model_management.unet_offload_device()
model = model_config.get_model(new_sd, "")
model = model.to(offload_device)
model.load_model_weights(new_sd, "")
return ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device)
def save_checkpoint(output_path, model, clip, vae, metadata=None):
try:

View File

@ -91,13 +91,15 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
def set_up_textual_embeddings(self, tokens, current_embeds):
out_tokens = []
next_new_token = token_dict_size = current_embeds.weight.shape[0]
next_new_token = token_dict_size = current_embeds.weight.shape[0] - 1
embedding_weights = []
for x in tokens:
tokens_temp = []
for y in x:
if isinstance(y, int):
if y == token_dict_size: #EOS token
y = -1
tokens_temp += [y]
else:
if y.shape[0] == current_embeds.weight.shape[1]:
@ -110,15 +112,21 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
tokens_temp += [self.empty_tokens[0][-1]]
out_tokens += [tokens_temp]
n = token_dict_size
if len(embedding_weights) > 0:
new_embedding = torch.nn.Embedding(next_new_token, current_embeds.weight.shape[1], device=current_embeds.weight.device, dtype=current_embeds.weight.dtype)
new_embedding.weight[:token_dict_size] = current_embeds.weight[:]
n = token_dict_size
new_embedding = torch.nn.Embedding(next_new_token + 1, current_embeds.weight.shape[1], device=current_embeds.weight.device, dtype=current_embeds.weight.dtype)
new_embedding.weight[:token_dict_size] = current_embeds.weight[:-1]
for x in embedding_weights:
new_embedding.weight[n] = x
n += 1
new_embedding.weight[n] = current_embeds.weight[-1] #EOS embedding
self.transformer.set_input_embeddings(new_embedding)
return out_tokens
processed_tokens = []
for x in out_tokens:
processed_tokens += [list(map(lambda a: n if a == -1 else a, x))] #The EOS token should always be the largest one
return processed_tokens
def forward(self, tokens):
backup_embeds = self.transformer.get_input_embeddings()

View File

@ -53,13 +53,13 @@ class SD20(supported_models_base.BASE):
latent_format = latent_formats.SD15
def v_prediction(self, state_dict, prefix=""):
def model_type(self, state_dict, prefix=""):
if self.unet_config["in_channels"] == 4: #SD2.0 inpainting models are not v prediction
k = "{}output_blocks.11.1.transformer_blocks.0.norm1.bias".format(prefix)
out = state_dict[k]
if torch.std(out, unbiased=False) > 0.09: # not sure how well this will actually work. I guess we will find out.
return True
return False
return model_base.ModelType.V_PREDICTION
return model_base.ModelType.EPS
def process_clip_state_dict(self, state_dict):
state_dict = utils.transformers_convert(state_dict, "cond_stage_model.model.", "cond_stage_model.transformer.text_model.", 24)
@ -109,8 +109,8 @@ class SDXLRefiner(supported_models_base.BASE):
latent_format = latent_formats.SDXL
def get_model(self, state_dict, prefix=""):
return model_base.SDXLRefiner(self)
def get_model(self, state_dict, prefix="", device=None):
return model_base.SDXLRefiner(self, device=device)
def process_clip_state_dict(self, state_dict):
keys_to_replace = {}
@ -126,7 +126,8 @@ class SDXLRefiner(supported_models_base.BASE):
def process_clip_state_dict_for_saving(self, state_dict):
replace_prefix = {}
state_dict_g = diffusers_convert.convert_text_enc_state_dict_v20(state_dict, "clip_g")
state_dict_g.pop("clip_g.transformer.text_model.embeddings.position_ids")
if "clip_g.transformer.text_model.embeddings.position_ids" in state_dict_g:
state_dict_g.pop("clip_g.transformer.text_model.embeddings.position_ids")
replace_prefix["clip_g"] = "conditioner.embedders.0.model"
state_dict_g = supported_models_base.state_dict_prefix_replace(state_dict_g, replace_prefix)
return state_dict_g
@ -145,8 +146,14 @@ class SDXL(supported_models_base.BASE):
latent_format = latent_formats.SDXL
def get_model(self, state_dict, prefix=""):
return model_base.SDXL(self)
def model_type(self, state_dict, prefix=""):
if "v_pred" in state_dict:
return model_base.ModelType.V_PREDICTION
else:
return model_base.ModelType.EPS
def get_model(self, state_dict, prefix="", device=None):
return model_base.SDXL(self, model_type=self.model_type(state_dict, prefix), device=device)
def process_clip_state_dict(self, state_dict):
keys_to_replace = {}
@ -165,7 +172,8 @@ class SDXL(supported_models_base.BASE):
replace_prefix = {}
keys_to_replace = {}
state_dict_g = diffusers_convert.convert_text_enc_state_dict_v20(state_dict, "clip_g")
state_dict_g.pop("clip_g.transformer.text_model.embeddings.position_ids")
if "clip_g.transformer.text_model.embeddings.position_ids" in state_dict_g:
state_dict_g.pop("clip_g.transformer.text_model.embeddings.position_ids")
for k in state_dict:
if k.startswith("clip_l"):
state_dict_g[k] = state_dict[k]

View File

@ -41,8 +41,8 @@ class BASE:
return False
return True
def v_prediction(self, state_dict, prefix=""):
return False
def model_type(self, state_dict, prefix=""):
return model_base.ModelType.EPS
def inpaint_model(self):
return self.unet_config["in_channels"] > 4
@ -53,13 +53,13 @@ class BASE:
for x in self.unet_extra_config:
self.unet_config[x] = self.unet_extra_config[x]
def get_model(self, state_dict, prefix=""):
def get_model(self, state_dict, prefix="", device=None):
if self.inpaint_model():
return model_base.SDInpaint(self, v_prediction=self.v_prediction(state_dict, prefix))
return model_base.SDInpaint(self, model_type=self.model_type(state_dict, prefix), device=device)
elif self.noise_aug_config is not None:
return model_base.SD21UNCLIP(self, self.noise_aug_config, v_prediction=self.v_prediction(state_dict, prefix))
return model_base.SD21UNCLIP(self, self.noise_aug_config, model_type=self.model_type(state_dict, prefix), device=device)
else:
return model_base.BaseModel(self, v_prediction=self.v_prediction(state_dict, prefix))
return model_base.BaseModel(self, model_type=self.model_type(state_dict, prefix), device=device)
def process_clip_state_dict(self, state_dict):
return state_dict

View File

@ -4,18 +4,20 @@ import struct
import comfy.checkpoint_pickle
import safetensors.torch
def load_torch_file(ckpt, safe_load=False):
def load_torch_file(ckpt, safe_load=False, device=None):
if device is None:
device = torch.device("cpu")
if ckpt.lower().endswith(".safetensors"):
sd = safetensors.torch.load_file(ckpt, device="cpu")
sd = safetensors.torch.load_file(ckpt, device=device.type)
else:
if safe_load:
if not 'weights_only' in torch.load.__code__.co_varnames:
print("Warning torch.load doesn't support weights_only on this pytorch version, loading unsafely.")
safe_load = False
if safe_load:
pl_sd = torch.load(ckpt, map_location="cpu", weights_only=True)
pl_sd = torch.load(ckpt, map_location=device, weights_only=True)
else:
pl_sd = torch.load(ckpt, map_location="cpu", pickle_module=comfy.checkpoint_pickle)
pl_sd = torch.load(ckpt, map_location=device, pickle_module=comfy.checkpoint_pickle)
if "global_step" in pl_sd:
print(f"Global Step: {pl_sd['global_step']}")
if "state_dict" in pl_sd:
@ -118,20 +120,24 @@ UNET_MAP_RESNET = {
}
UNET_MAP_BASIC = {
"label_emb.0.0.weight": "class_embedding.linear_1.weight",
"label_emb.0.0.bias": "class_embedding.linear_1.bias",
"label_emb.0.2.weight": "class_embedding.linear_2.weight",
"label_emb.0.2.bias": "class_embedding.linear_2.bias",
"input_blocks.0.0.weight": "conv_in.weight",
"input_blocks.0.0.bias": "conv_in.bias",
"out.0.weight": "conv_norm_out.weight",
"out.0.bias": "conv_norm_out.bias",
"out.2.weight": "conv_out.weight",
"out.2.bias": "conv_out.bias",
"time_embed.0.weight": "time_embedding.linear_1.weight",
"time_embed.0.bias": "time_embedding.linear_1.bias",
"time_embed.2.weight": "time_embedding.linear_2.weight",
"time_embed.2.bias": "time_embedding.linear_2.bias"
("label_emb.0.0.weight", "class_embedding.linear_1.weight"),
("label_emb.0.0.bias", "class_embedding.linear_1.bias"),
("label_emb.0.2.weight", "class_embedding.linear_2.weight"),
("label_emb.0.2.bias", "class_embedding.linear_2.bias"),
("label_emb.0.0.weight", "add_embedding.linear_1.weight"),
("label_emb.0.0.bias", "add_embedding.linear_1.bias"),
("label_emb.0.2.weight", "add_embedding.linear_2.weight"),
("label_emb.0.2.bias", "add_embedding.linear_2.bias"),
("input_blocks.0.0.weight", "conv_in.weight"),
("input_blocks.0.0.bias", "conv_in.bias"),
("out.0.weight", "conv_norm_out.weight"),
("out.0.bias", "conv_norm_out.bias"),
("out.2.weight", "conv_out.weight"),
("out.2.bias", "conv_out.bias"),
("time_embed.0.weight", "time_embedding.linear_1.weight"),
("time_embed.0.bias", "time_embedding.linear_1.bias"),
("time_embed.2.weight", "time_embedding.linear_2.weight"),
("time_embed.2.bias", "time_embedding.linear_2.bias")
}
def unet_to_diffusers(unet_config):
@ -206,7 +212,7 @@ def unet_to_diffusers(unet_config):
n += 1
for k in UNET_MAP_BASIC:
diffusers_unet_map[UNET_MAP_BASIC[k]] = k
diffusers_unet_map[k[1]] = k[0]
return diffusers_unet_map

View File

@ -2,6 +2,35 @@ import torch
from nodes import MAX_RESOLUTION
def composite(destination, source, x, y, mask = None, multiplier = 8):
x = max(-source.shape[3] * multiplier, min(x, destination.shape[3] * multiplier))
y = max(-source.shape[2] * multiplier, min(y, destination.shape[2] * multiplier))
left, top = (x // multiplier, y // multiplier)
right, bottom = (left + source.shape[3], top + source.shape[2],)
if mask is None:
mask = torch.ones_like(source)
else:
mask = mask.clone()
mask = torch.nn.functional.interpolate(mask[None, None], size=(source.shape[2], source.shape[3]), mode="bilinear")
mask = mask.repeat((source.shape[0], source.shape[1], 1, 1))
# calculate the bounds of the source that will be overlapping the destination
# this prevents the source trying to overwrite latent pixels that are out of bounds
# of the destination
visible_width, visible_height = (destination.shape[3] - left + min(0, x), destination.shape[2] - top + min(0, y),)
mask = mask[:, :, :visible_height, :visible_width]
inverse_mask = torch.ones_like(mask) - mask
source_portion = mask * source[:, :, :visible_height, :visible_width]
destination_portion = inverse_mask * destination[:, :, top:bottom, left:right]
destination[:, :, top:bottom, left:right] = source_portion + destination_portion
return destination
class LatentCompositeMasked:
@classmethod
def INPUT_TYPES(s):
@ -25,36 +54,31 @@ class LatentCompositeMasked:
output = destination.copy()
destination = destination["samples"].clone()
source = source["samples"]
output["samples"] = composite(destination, source, x, y, mask, 8)
return (output,)
x = max(-source.shape[3] * 8, min(x, destination.shape[3] * 8))
y = max(-source.shape[2] * 8, min(y, destination.shape[2] * 8))
class ImageCompositeMasked:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"destination": ("IMAGE",),
"source": ("IMAGE",),
"x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
"y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
},
"optional": {
"mask": ("MASK",),
}
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "composite"
left, top = (x // 8, y // 8)
right, bottom = (left + source.shape[3], top + source.shape[2],)
if mask is None:
mask = torch.ones_like(source)
else:
mask = mask.clone()
mask = torch.nn.functional.interpolate(mask[None, None], size=(source.shape[2], source.shape[3]), mode="bilinear")
mask = mask.repeat((source.shape[0], source.shape[1], 1, 1))
# calculate the bounds of the source that will be overlapping the destination
# this prevents the source trying to overwrite latent pixels that are out of bounds
# of the destination
visible_width, visible_height = (destination.shape[3] - left + min(0, x), destination.shape[2] - top + min(0, y),)
mask = mask[:, :, :visible_height, :visible_width]
inverse_mask = torch.ones_like(mask) - mask
source_portion = mask * source[:, :, :visible_height, :visible_width]
destination_portion = inverse_mask * destination[:, :, top:bottom, left:right]
destination[:, :, top:bottom, left:right] = source_portion + destination_portion
output["samples"] = destination
CATEGORY = "image"
def composite(self, destination, source, x, y, mask = None):
destination = destination.clone().movedim(-1, 1)
output = composite(destination, source.movedim(-1, 1), x, y, mask, 1).movedim(1, -1)
return (output,)
class MaskToImage:
@ -253,6 +277,7 @@ class FeatherMask:
NODE_CLASS_MAPPINGS = {
"LatentCompositeMasked": LatentCompositeMasked,
"ImageCompositeMasked": ImageCompositeMasked,
"MaskToImage": MaskToImage,
"ImageToMask": ImageToMask,
"SolidMask": SolidMask,

View File

@ -1,9 +1,13 @@
import comfy.sd
import comfy.utils
import comfy.model_base
import folder_paths
import json
import os
from comfy.cli_args import args
class ModelMergeSimple:
@classmethod
def INPUT_TYPES(s):
@ -99,10 +103,36 @@ class CheckpointSave:
if prompt is not None:
prompt_info = json.dumps(prompt)
metadata = {"prompt": prompt_info}
if extra_pnginfo is not None:
for x in extra_pnginfo:
metadata[x] = json.dumps(extra_pnginfo[x])
metadata = {}
enable_modelspec = True
if isinstance(model.model, comfy.model_base.SDXL):
metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-base"
elif isinstance(model.model, comfy.model_base.SDXLRefiner):
metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-refiner"
else:
enable_modelspec = False
if enable_modelspec:
metadata["modelspec.sai_model_spec"] = "1.0.0"
metadata["modelspec.implementation"] = "sgm"
metadata["modelspec.title"] = "{} {}".format(filename, counter)
#TODO:
# "stable-diffusion-v1", "stable-diffusion-v1-inpainting", "stable-diffusion-v2-512",
# "stable-diffusion-v2-768-v", "stable-diffusion-v2-unclip-l", "stable-diffusion-v2-unclip-h",
# "v2-inpainting"
if model.model.model_type == comfy.model_base.ModelType.EPS:
metadata["modelspec.predict_key"] = "epsilon"
elif model.model.model_type == comfy.model_base.ModelType.V_PREDICTION:
metadata["modelspec.predict_key"] = "v"
if not args.disable_metadata:
metadata["prompt"] = prompt_info
if extra_pnginfo is not None:
for x in extra_pnginfo:
metadata[x] = json.dumps(extra_pnginfo[x])
output_checkpoint = f"{filename}_{counter:05}_.safetensors"
output_checkpoint = os.path.join(full_output_folder, output_checkpoint)

View File

@ -59,8 +59,8 @@ class Blend:
def g(self, x):
return torch.where(x <= 0.25, ((16 * x - 12) * x + 4) * x, torch.sqrt(x))
def gaussian_kernel(kernel_size: int, sigma: float):
x, y = torch.meshgrid(torch.linspace(-1, 1, kernel_size), torch.linspace(-1, 1, kernel_size), indexing="ij")
def gaussian_kernel(kernel_size: int, sigma: float, device=None):
x, y = torch.meshgrid(torch.linspace(-1, 1, kernel_size, device=device), torch.linspace(-1, 1, kernel_size, device=device), indexing="ij")
d = torch.sqrt(x * x + y * y)
g = torch.exp(-(d * d) / (2.0 * sigma * sigma))
return g / g.sum()
@ -101,7 +101,7 @@ class Blur:
batch_size, height, width, channels = image.shape
kernel_size = blur_radius * 2 + 1
kernel = gaussian_kernel(kernel_size, sigma).repeat(channels, 1, 1).unsqueeze(1)
kernel = gaussian_kernel(kernel_size, sigma, device=image.device).repeat(channels, 1, 1).unsqueeze(1)
image = image.permute(0, 3, 1, 2) # Torch wants (B, C, H, W) we use (B, H, W, C)
padded_image = F.pad(image, (blur_radius,blur_radius,blur_radius,blur_radius), 'reflect')

View File

@ -37,12 +37,23 @@ class ImageUpscaleWithModel:
device = model_management.get_torch_device()
upscale_model.to(device)
in_img = image.movedim(-1,-3).to(device)
free_memory = model_management.get_free_memory(device)
tile = 512
overlap = 32
oom = True
while oom:
try:
steps = in_img.shape[0] * comfy.utils.get_tiled_scale_steps(in_img.shape[3], in_img.shape[2], tile_x=tile, tile_y=tile, overlap=overlap)
pbar = comfy.utils.ProgressBar(steps)
s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar)
oom = False
except model_management.OOM_EXCEPTION as e:
tile //= 2
if tile < 128:
raise e
tile = 128 + 64
overlap = 8
steps = in_img.shape[0] * comfy.utils.get_tiled_scale_steps(in_img.shape[3], in_img.shape[2], tile_x=tile, tile_y=tile, overlap=overlap)
pbar = comfy.utils.ProgressBar(steps)
s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar)
upscale_model.cpu()
s = torch.clamp(s.movedim(-3,-1), min=0, max=1.0)
return (s,)

84
cuda_malloc.py Normal file
View File

@ -0,0 +1,84 @@
import os
import importlib.util
from comfy.cli_args import args
#Can't use pytorch to get the GPU names because the cuda malloc has to be set before the first import.
def get_gpu_names():
if os.name == 'nt':
import ctypes
# Define necessary C structures and types
class DISPLAY_DEVICEA(ctypes.Structure):
_fields_ = [
('cb', ctypes.c_ulong),
('DeviceName', ctypes.c_char * 32),
('DeviceString', ctypes.c_char * 128),
('StateFlags', ctypes.c_ulong),
('DeviceID', ctypes.c_char * 128),
('DeviceKey', ctypes.c_char * 128)
]
# Load user32.dll
user32 = ctypes.windll.user32
# Call EnumDisplayDevicesA
def enum_display_devices():
device_info = DISPLAY_DEVICEA()
device_info.cb = ctypes.sizeof(device_info)
device_index = 0
gpu_names = set()
while user32.EnumDisplayDevicesA(None, device_index, ctypes.byref(device_info), 0):
device_index += 1
gpu_names.add(device_info.DeviceString.decode('utf-8'))
return gpu_names
return enum_display_devices()
else:
return set()
blacklist = {"GeForce GTX TITAN X", "GeForce GTX 980", "GeForce GTX 970", "GeForce GTX 960", "GeForce GTX 950", "GeForce 945M",
"GeForce 940M", "GeForce 930M", "GeForce 920M", "GeForce 910M", "GeForce GTX 750", "GeForce GTX 745", "Quadro K620",
"Quadro K1200", "Quadro K2200", "Quadro M500", "Quadro M520", "Quadro M600", "Quadro M620", "Quadro M1000",
"Quadro M1200", "Quadro M2000", "Quadro M2200", "Quadro M3000", "Quadro M4000", "Quadro M5000", "Quadro M5500", "Quadro M6000",
"GeForce MX110", "GeForce MX130", "GeForce 830M", "GeForce 840M", "GeForce GTX 850M", "GeForce GTX 860M",
"GeForce GTX 1650", "GeForce GTX 1630"
}
def cuda_malloc_supported():
try:
names = get_gpu_names()
except:
names = set()
for x in names:
if "NVIDIA" in x:
for b in blacklist:
if b in x:
return False
return True
if not args.cuda_malloc:
try:
version = ""
torch_spec = importlib.util.find_spec("torch")
for folder in torch_spec.submodule_search_locations:
ver_file = os.path.join(folder, "version.py")
if os.path.isfile(ver_file):
spec = importlib.util.spec_from_file_location("torch_version_import", ver_file)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
version = module.__version__
if int(version[0]) >= 2: #enable by default for torch version 2.0 and up
args.cuda_malloc = cuda_malloc_supported()
except:
pass
if args.cuda_malloc and not args.disable_cuda_malloc:
env_var = os.environ.get('PYTORCH_CUDA_ALLOC_CONF', None)
if env_var is None:
env_var = "backend:cudaMallocAsync"
else:
env_var += ",backend:cudaMallocAsync"
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = env_var

View File

@ -51,9 +51,10 @@ class Example:
"default": 0,
"min": 0, #Minimum value
"max": 4096, #Maximum value
"step": 64 #Slider's step
"step": 64, #Slider's step
"display": "number" # Cosmetic only: display as "number" or "slider"
}),
"float_field": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
"float_field": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01, "display": "number"}),
"print_to_screen": (["enable", "disable"],),
"string_field": ("STRING", {
"multiline": False, #True if you want the field to look like the one on the ClipTextEncode node

View File

@ -6,7 +6,6 @@ import threading
import heapq
import traceback
import gc
import time
import torch
import nodes
@ -43,11 +42,14 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da
def map_node_over_list(obj, input_data_all, func, allow_interrupt=False):
# check if node wants the lists
intput_is_list = False
input_is_list = False
if hasattr(obj, "INPUT_IS_LIST"):
intput_is_list = obj.INPUT_IS_LIST
input_is_list = obj.INPUT_IS_LIST
max_len_input = max([len(x) for x in input_data_all.values()])
if len(input_data_all) == 0:
max_len_input = 0
else:
max_len_input = max([len(x) for x in input_data_all.values()])
# get a slice of inputs, repeat last input when list isn't long enough
def slice_dict(d, i):
@ -57,10 +59,14 @@ def map_node_over_list(obj, input_data_all, func, allow_interrupt=False):
return d_new
results = []
if intput_is_list:
if input_is_list:
if allow_interrupt:
nodes.before_node_execution()
results.append(getattr(obj, func)(**input_data_all))
elif max_len_input == 0:
if allow_interrupt:
nodes.before_node_execution()
results.append(getattr(obj, func)())
else:
for i in range(max_len_input):
if allow_interrupt:

View File

@ -43,6 +43,10 @@ def set_output_directory(output_dir):
global output_directory
output_directory = output_dir
def set_temp_directory(temp_dir):
global temp_directory
temp_directory = temp_dir
def get_output_directory():
global output_directory
return output_directory
@ -111,6 +115,8 @@ def add_model_folder_path(folder_name, full_folder_path):
global folder_names_and_paths
if folder_name in folder_names_and_paths:
folder_names_and_paths[folder_name][0].append(full_folder_path)
else:
folder_names_and_paths[folder_name] = ([full_folder_path], set())
def get_folder_paths(folder_name):
return folder_names_and_paths[folder_name][0][:]

View File

@ -1,6 +1,5 @@
import torch
from PIL import Image, ImageOps
from io import BytesIO
from PIL import Image
import struct
import numpy as np
from comfy.cli_args import args, LatentPreviewMethod
@ -15,26 +14,7 @@ class LatentPreviewer:
def decode_latent_to_preview_image(self, preview_format, x0):
preview_image = self.decode_latent_to_preview(x0)
if hasattr(Image, 'Resampling'):
resampling = Image.Resampling.BILINEAR
else:
resampling = Image.ANTIALIAS
preview_image = ImageOps.contain(preview_image, (MAX_PREVIEW_RESOLUTION, MAX_PREVIEW_RESOLUTION), resampling)
preview_type = 1
if preview_format == "JPEG":
preview_type = 1
elif preview_format == "PNG":
preview_type = 2
bytesIO = BytesIO()
header = struct.pack(">I", preview_type)
bytesIO.write(header)
preview_image.save(bytesIO, format=preview_format, quality=95)
preview_bytes = bytesIO.getvalue()
return preview_bytes
return ("JPEG", preview_image, MAX_PREVIEW_RESOLUTION)
class TAESDPreviewerImpl(LatentPreviewer):
def __init__(self, taesd):

31
main.py
View File

@ -51,7 +51,6 @@ import threading
import gc
from comfy.cli_args import args
import comfy.utils
if os.name == "nt":
import logging
@ -62,7 +61,9 @@ if __name__ == "__main__":
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device)
print("Set cuda device to:", args.cuda_device)
import cuda_malloc
import comfy.utils
import yaml
import execution
@ -71,6 +72,17 @@ from server import BinaryEventTypes
from nodes import init_custom_nodes
import comfy.model_management
def cuda_malloc_warning():
device = comfy.model_management.get_torch_device()
device_name = comfy.model_management.get_torch_device_name(device)
cuda_malloc_warning = False
if "cudaMallocAsync" in device_name:
for b in cuda_malloc.blacklist:
if b in device_name:
cuda_malloc_warning = True
if cuda_malloc_warning:
print("\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n")
def prompt_worker(q, server):
e = execution.PromptExecutor(server)
while True:
@ -91,15 +103,15 @@ async def run(server, address='', port=8188, verbose=True, call_on_start=None):
def hijack_progress(server):
def hook(value, total, preview_image_bytes):
def hook(value, total, preview_image):
server.send_sync("progress", {"value": value, "max": total}, server.client_id)
if preview_image_bytes is not None:
server.send_sync(BinaryEventTypes.PREVIEW_IMAGE, preview_image_bytes, server.client_id)
if preview_image is not None:
server.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, server.client_id)
comfy.utils.set_progress_bar_global_hook(hook)
def cleanup_temp():
temp_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp")
temp_dir = folder_paths.get_temp_directory()
if os.path.exists(temp_dir):
shutil.rmtree(temp_dir, ignore_errors=True)
@ -126,6 +138,10 @@ def load_extra_path_config(yaml_path):
if __name__ == "__main__":
if args.temp_directory:
temp_dir = os.path.join(os.path.abspath(args.temp_directory), "temp")
print(f"Setting temp directory to: {temp_dir}")
folder_paths.set_temp_directory(temp_dir)
cleanup_temp()
loop = asyncio.new_event_loop()
@ -142,6 +158,9 @@ if __name__ == "__main__":
load_extra_path_config(config_path)
init_custom_nodes()
cuda_malloc_warning()
server.add_routes()
hijack_progress(server)
@ -159,6 +178,8 @@ if __name__ == "__main__":
if args.auto_launch:
def startup_server(address, port):
import webbrowser
if os.name == 'nt' and address == '0.0.0.0':
address = '127.0.0.1'
webbrowser.open(f"http://{address}:{port}")
call_on_start = startup_server

177
nodes.py
View File

@ -26,6 +26,8 @@ import comfy.utils
import comfy.clip_vision
import comfy.model_management
from comfy.cli_args import args
import importlib
import folder_paths
@ -204,6 +206,28 @@ class ConditioningZeroOut:
c.append(n)
return (c, )
class ConditioningSetTimestepRange:
@classmethod
def INPUT_TYPES(s):
return {"required": {"conditioning": ("CONDITIONING", ),
"start": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
"end": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
}}
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "set_range"
CATEGORY = "advanced/conditioning"
def set_range(self, conditioning, start, end):
c = []
for t in conditioning:
d = t[1].copy()
d['start_percent'] = 1.0 - start
d['end_percent'] = 1.0 - end
n = [t[0], d]
c.append(n)
return (c, )
class VAEDecode:
@classmethod
def INPUT_TYPES(s):
@ -330,12 +354,22 @@ class SaveLatent:
if prompt is not None:
prompt_info = json.dumps(prompt)
metadata = {"prompt": prompt_info}
if extra_pnginfo is not None:
for x in extra_pnginfo:
metadata[x] = json.dumps(extra_pnginfo[x])
metadata = None
if not args.disable_metadata:
metadata = {"prompt": prompt_info}
if extra_pnginfo is not None:
for x in extra_pnginfo:
metadata[x] = json.dumps(extra_pnginfo[x])
file = f"{filename}_{counter:05}_.latent"
results = list()
results.append({
"filename": file,
"subfolder": subfolder,
"type": "output"
})
file = os.path.join(full_output_folder, file)
output = {}
@ -343,7 +377,7 @@ class SaveLatent:
output["latent_format_version_0"] = torch.tensor([])
comfy.utils.save_torch_file(output, file, metadata=metadata)
return {}
return { "ui": { "latents": results } }
class LoadLatent:
@ -497,7 +531,9 @@ class LoraLoader:
if self.loaded_lora[0] == lora_path:
lora = self.loaded_lora[1]
else:
del self.loaded_lora
temp = self.loaded_lora
self.loaded_lora = None
del temp
if lora is None:
lora = comfy.utils.load_torch_file(lora_path, safe_load=True)
@ -578,9 +614,58 @@ class ControlNetApply:
if 'control' in t[1]:
c_net.set_previous_controlnet(t[1]['control'])
n[1]['control'] = c_net
n[1]['control_apply_to_uncond'] = True
c.append(n)
return (c, )
class ControlNetApplyAdvanced:
@classmethod
def INPUT_TYPES(s):
return {"required": {"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"control_net": ("CONTROL_NET", ),
"image": ("IMAGE", ),
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
}}
RETURN_TYPES = ("CONDITIONING","CONDITIONING")
RETURN_NAMES = ("positive", "negative")
FUNCTION = "apply_controlnet"
CATEGORY = "conditioning"
def apply_controlnet(self, positive, negative, control_net, image, strength, start_percent, end_percent):
if strength == 0:
return (positive, negative)
control_hint = image.movedim(-1,1)
cnets = {}
out = []
for conditioning in [positive, negative]:
c = []
for t in conditioning:
d = t[1].copy()
prev_cnet = d.get('control', None)
if prev_cnet in cnets:
c_net = cnets[prev_cnet]
else:
c_net = control_net.copy().set_cond_hint(control_hint, strength, (1.0 - start_percent, 1.0 - end_percent))
c_net.set_previous_controlnet(prev_cnet)
cnets[prev_cnet] = c_net
d['control'] = c_net
d['control_apply_to_uncond'] = False
n = [t[0], d]
c.append(n)
out.append(c)
return (out[0], out[1])
class UNETLoader:
@classmethod
def INPUT_TYPES(s):
@ -686,7 +771,7 @@ class StyleModelApply:
CATEGORY = "conditioning/style_model"
def apply_stylemodel(self, clip_vision_output, style_model, conditioning):
cond = style_model.get_cond(clip_vision_output)
cond = style_model.get_cond(clip_vision_output).flatten(start_dim=0, end_dim=1).unsqueeze(dim=0)
c = []
for t in conditioning:
n = [torch.cat((t[0], cond), dim=1), t[1].copy()]
@ -970,6 +1055,47 @@ class LatentComposite:
samples_out["samples"] = s
return (samples_out,)
class LatentBlend:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"samples1": ("LATENT",),
"samples2": ("LATENT",),
"blend_factor": ("FLOAT", {
"default": 0.5,
"min": 0,
"max": 1,
"step": 0.01
}),
}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "blend"
CATEGORY = "_for_testing"
def blend(self, samples1, samples2, blend_factor:float, blend_mode: str="normal"):
samples_out = samples1.copy()
samples1 = samples1["samples"]
samples2 = samples2["samples"]
if samples1.shape != samples2.shape:
samples2.permute(0, 3, 1, 2)
samples2 = comfy.utils.common_upscale(samples2, samples1.shape[3], samples1.shape[2], 'bicubic', crop='center')
samples2.permute(0, 2, 3, 1)
samples_blended = self.blend_mode(samples1, samples2, blend_mode)
samples_blended = samples1 * blend_factor + samples_blended * (1 - blend_factor)
samples_out["samples"] = samples_blended
return (samples_out,)
def blend_mode(self, img1, img2, mode):
if mode == "normal":
return img2
else:
raise ValueError(f"Unsupported blend mode: {mode}")
class LatentCrop:
@classmethod
def INPUT_TYPES(s):
@ -1141,12 +1267,14 @@ class SaveImage:
for image in images:
i = 255. * image.cpu().numpy()
img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
metadata = PngInfo()
if prompt is not None:
metadata.add_text("prompt", json.dumps(prompt))
if extra_pnginfo is not None:
for x in extra_pnginfo:
metadata.add_text(x, json.dumps(extra_pnginfo[x]))
metadata = None
if not args.disable_metadata:
metadata = PngInfo()
if prompt is not None:
metadata.add_text("prompt", json.dumps(prompt))
if extra_pnginfo is not None:
for x in extra_pnginfo:
metadata.add_text(x, json.dumps(extra_pnginfo[x]))
file = f"{filename}_{counter:05}_.png"
img.save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=4)
@ -1320,6 +1448,22 @@ class ImageInvert:
s = 1.0 - image
return (s,)
class ImageBatch:
@classmethod
def INPUT_TYPES(s):
return {"required": { "image1": ("IMAGE",), "image2": ("IMAGE",)}}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "batch"
CATEGORY = "image"
def batch(self, image1, image2):
if image1.shape[1:] != image2.shape[1:]:
image2 = comfy.utils.common_upscale(image2.movedim(-1,1), image1.shape[2], image1.shape[1], "bilinear", "center").movedim(1,-1)
s = torch.cat((image1, image2), dim=0)
return (s,)
class ImagePadForOutpaint:
@ -1405,6 +1549,7 @@ NODE_CLASS_MAPPINGS = {
"ImageScale": ImageScale,
"ImageScaleBy": ImageScaleBy,
"ImageInvert": ImageInvert,
"ImageBatch": ImageBatch,
"ImagePadForOutpaint": ImagePadForOutpaint,
"ConditioningAverage ": ConditioningAverage ,
"ConditioningCombine": ConditioningCombine,
@ -1414,6 +1559,7 @@ NODE_CLASS_MAPPINGS = {
"KSamplerAdvanced": KSamplerAdvanced,
"SetLatentNoiseMask": SetLatentNoiseMask,
"LatentComposite": LatentComposite,
"LatentBlend": LatentBlend,
"LatentRotate": LatentRotate,
"LatentFlip": LatentFlip,
"LatentCrop": LatentCrop,
@ -1425,6 +1571,7 @@ NODE_CLASS_MAPPINGS = {
"StyleModelApply": StyleModelApply,
"unCLIPConditioning": unCLIPConditioning,
"ControlNetApply": ControlNetApply,
"ControlNetApplyAdvanced": ControlNetApplyAdvanced,
"ControlNetLoader": ControlNetLoader,
"DiffControlNetLoader": DiffControlNetLoader,
"StyleModelLoader": StyleModelLoader,
@ -1442,6 +1589,7 @@ NODE_CLASS_MAPPINGS = {
"SaveLatent": SaveLatent,
"ConditioningZeroOut": ConditioningZeroOut,
"ConditioningSetTimestepRange": ConditioningSetTimestepRange,
}
NODE_DISPLAY_NAME_MAPPINGS = {
@ -1470,6 +1618,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"ConditioningSetArea": "Conditioning (Set Area)",
"ConditioningSetMask": "Conditioning (Set Mask)",
"ControlNetApply": "Apply ControlNet",
"ControlNetApplyAdvanced": "Apply ControlNet (Advanced)",
# Latent
"VAEEncodeForInpaint": "VAE Encode (for Inpainting)",
"SetLatentNoiseMask": "Set Latent Noise Mask",
@ -1482,6 +1631,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"LatentUpscale": "Upscale Latent",
"LatentUpscaleBy": "Upscale Latent By",
"LatentComposite": "Latent Composite",
"LatentBlend": "Latent Blend",
"LatentFromBatch" : "Latent From Batch",
"RepeatLatentBatch": "Repeat Latent Batch",
# Image
@ -1494,6 +1644,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"ImageUpscaleWithModel": "Upscale Image (using Model)",
"ImageInvert": "Invert Image",
"ImagePadForOutpaint": "Pad Image for Outpainting",
"ImageBatch": "Batch Images",
# _for_testing
"VAEDecodeTiled": "VAE Decode (Tiled)",
"VAEEncodeTiled": "VAE Encode (Tiled)",

View File

@ -69,6 +69,13 @@
"source": [
"# Checkpoints\n",
"\n",
"### SDXL\n",
"### I recommend these workflow examples: https://comfyanonymous.github.io/ComfyUI_examples/sdxl/\n",
"\n",
"#!wget -c https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/resolve/main/sd_xl_base_1.0.safetensors -P ./models/checkpoints/\n",
"#!wget -c https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0/resolve/main/sd_xl_refiner_1.0.safetensors -P ./models/checkpoints/\n",
"\n",
"\n",
"# SD1.5\n",
"!wget -c https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.ckpt -P ./models/checkpoints/\n",
"\n",
@ -83,7 +90,7 @@
"#!wget -c https://huggingface.co/Linaqruf/anything-v3.0/resolve/main/anything-v3-fp16-pruned.safetensors -P ./models/checkpoints/\n",
"\n",
"# Waifu Diffusion 1.5 (anime style SD2.x 768-v)\n",
"#!wget -c https://huggingface.co/waifu-diffusion/wd-1-5-beta2/resolve/main/checkpoints/wd-1-5-beta2-fp16.safetensors -P ./models/checkpoints/\n",
"#!wget -c https://huggingface.co/waifu-diffusion/wd-1-5-beta3/resolve/main/wd-illusion-fp16.safetensors -P ./models/checkpoints/\n",
"\n",
"\n",
"# unCLIP models\n",
@ -100,6 +107,7 @@
"# Loras\n",
"#!wget -c https://civitai.com/api/download/models/10350 -O ./models/loras/theovercomer8sContrastFix_sd21768.safetensors #theovercomer8sContrastFix SD2.x 768-v\n",
"#!wget -c https://civitai.com/api/download/models/10638 -O ./models/loras/theovercomer8sContrastFix_sd15.safetensors #theovercomer8sContrastFix SD1.x\n",
"#!wget -c https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/resolve/main/sd_xl_offset_example-lora_1.0.safetensors -P ./models/loras/ #SDXL offset noise lora\n",
"\n",
"\n",
"# T2I-Adapter\n",
@ -151,13 +159,64 @@
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "kkkkkkkkkkkkkkk"
},
"source": [
"### Run ComfyUI with cloudflared (Recommended Way)\n",
"\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "jjjjjjjjjjjjjj"
},
"outputs": [],
"source": [
"!wget https://github.com/cloudflare/cloudflared/releases/latest/download/cloudflared-linux-amd64.deb\n",
"!dpkg -i cloudflared-linux-amd64.deb\n",
"\n",
"import subprocess\n",
"import threading\n",
"import time\n",
"import socket\n",
"import urllib.request\n",
"\n",
"def iframe_thread(port):\n",
" while True:\n",
" time.sleep(0.5)\n",
" sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)\n",
" result = sock.connect_ex(('127.0.0.1', port))\n",
" if result == 0:\n",
" break\n",
" sock.close()\n",
" print(\"\\nComfyUI finished loading, trying to launch cloudflared (if it gets stuck here cloudflared is having issues)\\n\")\n",
"\n",
" p = subprocess.Popen([\"cloudflared\", \"tunnel\", \"--url\", \"http://127.0.0.1:{}\".format(port)], stdout=subprocess.PIPE, stderr=subprocess.PIPE)\n",
" for line in p.stderr:\n",
" l = line.decode()\n",
" if \"trycloudflare.com \" in l:\n",
" print(\"This is the URL to access ComfyUI:\", l[l.find(\"http\"):], end='')\n",
" #print(l, end='')\n",
"\n",
"\n",
"threading.Thread(target=iframe_thread, daemon=True, args=(8188,)).start()\n",
"\n",
"!python main.py --dont-print-server"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "kkkkkkkkkkkkkk"
},
"source": [
"### Run ComfyUI with localtunnel (Recommended Way)\n",
"### Run ComfyUI with localtunnel\n",
"\n",
"\n"
]

View File

@ -1,5 +1,4 @@
torch
torchdiffeq
torchsde
einops
transformers>=4.25.1
@ -10,3 +9,4 @@ pyyaml
Pillow
scipy
tqdm
psutil

View File

@ -2,8 +2,12 @@ import json
from urllib import request, parse
import random
#this is the ComfyUI api prompt format. If you want it for a specific workflow you can copy it from the prompt section
#of the image metadata of images generated with ComfyUI
#This is the ComfyUI api prompt format.
#If you want it for a specific workflow you can "enable dev mode options"
#in the settings of the UI (gear beside the "Queue Size: ") this will enable
#a button on the UI to save workflows in api format.
#keep in mind ComfyUI is pre alpha software so this format will change a bit.
#this is the one for the default workflow

View File

@ -8,7 +8,7 @@ import uuid
import json
import glob
import struct
from PIL import Image
from PIL import Image, ImageOps
from io import BytesIO
try:
@ -29,6 +29,7 @@ import comfy.model_management
class BinaryEventTypes:
PREVIEW_IMAGE = 1
UNENCODED_PREVIEW_IMAGE = 2
async def send_socket_catch_exception(function, message):
try:
@ -344,6 +345,11 @@ class PromptServer():
vram_total, torch_vram_total = comfy.model_management.get_total_memory(device, torch_total_too=True)
vram_free, torch_vram_free = comfy.model_management.get_free_memory(device, torch_free_too=True)
system_stats = {
"system": {
"os": os.name,
"python_version": sys.version,
"embedded_python": os.path.split(os.path.split(sys.executable)[0])[1] == "python_embeded"
},
"devices": [
{
"name": device_name,
@ -498,7 +504,9 @@ class PromptServer():
return prompt_info
async def send(self, event, data, sid=None):
if isinstance(data, (bytes, bytearray)):
if event == BinaryEventTypes.UNENCODED_PREVIEW_IMAGE:
await self.send_image(data, sid=sid)
elif isinstance(data, (bytes, bytearray)):
await self.send_bytes(event, data, sid)
else:
await self.send_json(event, data, sid)
@ -512,6 +520,30 @@ class PromptServer():
message.extend(data)
return message
async def send_image(self, image_data, sid=None):
image_type = image_data[0]
image = image_data[1]
max_size = image_data[2]
if max_size is not None:
if hasattr(Image, 'Resampling'):
resampling = Image.Resampling.BILINEAR
else:
resampling = Image.ANTIALIAS
image = ImageOps.contain(image, (max_size, max_size), resampling)
type_num = 1
if image_type == "JPEG":
type_num = 1
elif image_type == "PNG":
type_num = 2
bytesIO = BytesIO()
header = struct.pack(">I", type_num)
bytesIO.write(header)
image.save(bytesIO, format=image_type, quality=95, compress_level=4)
preview_bytes = bytesIO.getvalue()
await self.send_bytes(BinaryEventTypes.PREVIEW_IMAGE, preview_bytes, sid=sid)
async def send_bytes(self, event, data, sid=None):
message = self.encode_bytes(event, data)

View File

@ -1,4 +1,4 @@
import {app} from "/scripts/app.js";
import {app} from "../../scripts/app.js";
// Adds filtering to combo context menus
@ -27,10 +27,13 @@ const ext = {
const clickedComboValue = currentNode.widgets
.filter(w => w.type === "combo" && w.options.values.length === values.length)
.find(w => w.options.values.every((v, i) => v === values[i]))
.value;
?.value;
let selectedIndex = values.findIndex(v => v === clickedComboValue);
let selectedItem = displayedItems?.[selectedIndex];
let selectedIndex = clickedComboValue ? values.findIndex(v => v === clickedComboValue) : 0;
if (selectedIndex < 0) {
selectedIndex = 0;
}
let selectedItem = displayedItems[selectedIndex];
updateSelected();
// Apply highlighting to the selected item

View File

@ -0,0 +1,25 @@
import { app } from "../../scripts/app.js";
const id = "Comfy.LinkRenderMode";
const ext = {
name: id,
async setup(app) {
app.ui.settings.addSetting({
id,
name: "Link Render Mode",
defaultValue: 2,
type: "combo",
options: LiteGraph.LINK_RENDER_MODES.map((m, i) => ({
value: i,
text: m,
selected: i == app.canvas.links_render_mode,
})),
onChange(value) {
app.canvas.links_render_mode = +value;
app.graph.setDirtyCanvas(true);
},
});
},
};
app.registerExtension(ext);

View File

@ -2,7 +2,7 @@ import { ComfyWidgets, addValueControlWidget } from "../../scripts/widgets.js";
import { app } from "../../scripts/app.js";
const CONVERTED_TYPE = "converted-widget";
const VALID_TYPES = ["STRING", "combo", "number"];
const VALID_TYPES = ["STRING", "combo", "number", "BOOLEAN"];
function isConvertableWidget(widget, config) {
return VALID_TYPES.includes(widget.type) || VALID_TYPES.includes(config[0]);

View File

@ -9766,6 +9766,7 @@ LGraphNode.prototype.executeAction = function(action)
switch (w.type) {
case "button":
ctx.fillStyle = background_color;
if (w.clicked) {
ctx.fillStyle = "#AAA";
w.clicked = false;
@ -9835,7 +9836,11 @@ LGraphNode.prototype.executeAction = function(action)
ctx.textAlign = "center";
ctx.fillStyle = text_color;
ctx.fillText(
w.label || w.name + " " + Number(w.value).toFixed(3),
w.label || w.name + " " + Number(w.value).toFixed(
w.options.precision != null
? w.options.precision
: 3
),
widget_width * 0.5,
y + H * 0.7
);
@ -13835,7 +13840,7 @@ LGraphNode.prototype.executeAction = function(action)
if (!disabled) {
element.addEventListener("click", inner_onclick);
}
if (options.autoopen) {
if (!disabled && options.autoopen) {
LiteGraph.pointerListenerAdd(element,"enter",inner_over);
}

View File

@ -264,6 +264,15 @@ class ComfyApi extends EventTarget {
}
}
/**
* Gets system & device stats
* @returns System stats such as python version, OS, per device info
*/
async getSystemStats() {
const res = await this.fetchApi("/system_stats");
return await res.json();
}
/**
* Sends a POST request to the API
* @param {*} type The endpoint to post to

View File

@ -1,3 +1,4 @@
import { ComfyLogging } from "./logging.js";
import { ComfyWidgets } from "./widgets.js";
import { ComfyUI, $el } from "./ui.js";
import { api } from "./api.js";
@ -31,6 +32,7 @@ export class ComfyApp {
constructor() {
this.ui = new ComfyUI(this);
this.logging = new ComfyLogging(this);
/**
* List of extensions that are registered with the app
@ -282,6 +284,11 @@ export class ComfyApp {
}
}
options.push({
content: "Bypass",
callback: (obj) => { if (this.mode === 4) this.mode = 0; else this.mode = 4; this.graph.change(); }
});
// prevent conflict of clipspace content
if(!ComfyApp.clipspace_return_node) {
options.push({
@ -768,6 +775,19 @@ export class ComfyApp {
}
block_default = true;
}
if (e.keyCode == 66 && e.ctrlKey) {
if (this.selected_nodes) {
for (var i in this.selected_nodes) {
if (this.selected_nodes[i].mode === 4) { // never
this.selected_nodes[i].mode = 0; // always
} else {
this.selected_nodes[i].mode = 4; // never
}
}
}
block_default = true;
}
}
this.graph.change();
@ -914,14 +934,21 @@ export class ComfyApp {
const origDrawNode = LGraphCanvas.prototype.drawNode;
LGraphCanvas.prototype.drawNode = function (node, ctx) {
var editor_alpha = this.editor_alpha;
var old_color = node.bgcolor;
if (node.mode === 2) { // never
this.editor_alpha = 0.4;
}
if (node.mode === 4) { // never
node.bgcolor = "#FF00FF";
this.editor_alpha = 0.2;
}
const res = origDrawNode.apply(this, arguments);
this.editor_alpha = editor_alpha;
node.bgcolor = old_color;
return res;
};
@ -1003,6 +1030,7 @@ export class ComfyApp {
*/
async #loadExtensions() {
const extensions = await api.getExtensions();
this.logging.addEntry("Comfy.App", "debug", { Extensions: extensions });
for (const ext of extensions) {
try {
await import(api.apiURL(ext));
@ -1286,6 +1314,9 @@ export class ComfyApp {
(t) => `<li>${t}</li>`
).join("")}</ul>Nodes that have failed to load will show as red on the graph.`
);
this.logging.addEntry("Comfy.App", "warn", {
MissingNodes: missingNodeTypes,
});
}
}
@ -1308,7 +1339,7 @@ export class ComfyApp {
continue;
}
if (node.mode === 2) {
if (node.mode === 2 || node.mode === 4) {
// Don't serialize muted nodes
continue;
}
@ -1331,12 +1362,36 @@ export class ComfyApp {
let parent = node.getInputNode(i);
if (parent) {
let link = node.getInputLink(i);
while (parent && parent.isVirtualNode) {
link = parent.getInputLink(link.origin_slot);
if (link) {
parent = parent.getInputNode(link.origin_slot);
} else {
parent = null;
while (parent.mode === 4 || parent.isVirtualNode) {
let found = false;
if (parent.isVirtualNode) {
link = parent.getInputLink(link.origin_slot);
if (link) {
parent = parent.getInputNode(link.target_slot);
if (parent) {
found = true;
}
}
} else if (link && parent.mode === 4) {
let all_inputs = [link.origin_slot];
if (parent.inputs) {
all_inputs = all_inputs.concat(Object.keys(parent.inputs))
for (let parent_input in all_inputs) {
parent_input = all_inputs[parent_input];
if (parent.inputs[parent_input].type === node.inputs[i].type) {
link = parent.getInputLink(parent_input);
if (link) {
parent = parent.getInputNode(parent_input);
}
found = true;
break;
}
}
}
}
if (!found) {
break;
}
}

367
web/scripts/logging.js Normal file
View File

@ -0,0 +1,367 @@
import { $el, ComfyDialog } from "./ui.js";
import { api } from "./api.js";
$el("style", {
textContent: `
.comfy-logging-logs {
display: grid;
color: var(--fg-color);
white-space: pre-wrap;
}
.comfy-logging-log {
display: contents;
}
.comfy-logging-title {
background: var(--tr-even-bg-color);
font-weight: bold;
margin-bottom: 5px;
text-align: center;
}
.comfy-logging-log div {
background: var(--row-bg);
padding: 5px;
}
`,
parent: document.body,
});
// Stringify function supporting max depth and removal of circular references
// https://stackoverflow.com/a/57193345
function stringify(val, depth, replacer, space, onGetObjID) {
depth = isNaN(+depth) ? 1 : depth;
var recursMap = new WeakMap();
function _build(val, depth, o, a, r) {
// (JSON.stringify() has it's own rules, which we respect here by using it for property iteration)
return !val || typeof val != "object"
? val
: ((r = recursMap.has(val)),
recursMap.set(val, true),
(a = Array.isArray(val)),
r
? (o = (onGetObjID && onGetObjID(val)) || null)
: JSON.stringify(val, function (k, v) {
if (a || depth > 0) {
if (replacer) v = replacer(k, v);
if (!k) return (a = Array.isArray(v)), (val = v);
!o && (o = a ? [] : {});
o[k] = _build(v, a ? depth : depth - 1);
}
}),
o === void 0 ? (a ? [] : {}) : o);
}
return JSON.stringify(_build(val, depth), null, space);
}
const jsonReplacer = (k, v, ui) => {
if (v instanceof Array && v.length === 1) {
v = v[0];
}
if (v instanceof Date) {
v = v.toISOString();
if (ui) {
v = v.split("T")[1];
}
}
if (v instanceof Error) {
let err = "";
if (v.name) err += v.name + "\n";
if (v.message) err += v.message + "\n";
if (v.stack) err += v.stack + "\n";
if (!err) {
err = v.toString();
}
v = err;
}
return v;
};
const fileInput = $el("input", {
type: "file",
accept: ".json",
style: { display: "none" },
parent: document.body,
});
class ComfyLoggingDialog extends ComfyDialog {
constructor(logging) {
super();
this.logging = logging;
}
clear() {
this.logging.clear();
this.show();
}
export() {
const blob = new Blob([stringify([...this.logging.entries], 20, jsonReplacer, "\t")], {
type: "application/json",
});
const url = URL.createObjectURL(blob);
const a = $el("a", {
href: url,
download: `comfyui-logs-${Date.now()}.json`,
style: { display: "none" },
parent: document.body,
});
a.click();
setTimeout(function () {
a.remove();
window.URL.revokeObjectURL(url);
}, 0);
}
import() {
fileInput.onchange = () => {
const reader = new FileReader();
reader.onload = () => {
fileInput.remove();
try {
const obj = JSON.parse(reader.result);
if (obj instanceof Array) {
this.show(obj);
} else {
throw new Error("Invalid file selected.");
}
} catch (error) {
alert("Unable to load logs: " + error.message);
}
};
reader.readAsText(fileInput.files[0]);
};
fileInput.click();
}
createButtons() {
return [
$el("button", {
type: "button",
textContent: "Clear",
onclick: () => this.clear(),
}),
$el("button", {
type: "button",
textContent: "Export logs...",
onclick: () => this.export(),
}),
$el("button", {
type: "button",
textContent: "View exported logs...",
onclick: () => this.import(),
}),
...super.createButtons(),
];
}
getTypeColor(type) {
switch (type) {
case "error":
return "red";
case "warn":
return "orange";
case "debug":
return "dodgerblue";
}
}
show(entries) {
if (!entries) entries = this.logging.entries;
this.element.style.width = "100%";
const cols = {
source: "Source",
type: "Type",
timestamp: "Timestamp",
message: "Message",
};
const keys = Object.keys(cols);
const headers = Object.values(cols).map((title) =>
$el("div.comfy-logging-title", {
textContent: title,
})
);
const rows = entries.map((entry, i) => {
return $el(
"div.comfy-logging-log",
{
$: (el) => el.style.setProperty("--row-bg", `var(--tr-${i % 2 ? "even" : "odd"}-bg-color)`),
},
keys.map((key) => {
let v = entry[key];
let color;
if (key === "type") {
color = this.getTypeColor(v);
} else {
v = jsonReplacer(key, v, true);
if (typeof v === "object") {
v = stringify(v, 5, jsonReplacer, " ");
}
}
return $el("div", {
style: {
color,
},
textContent: v,
});
})
);
});
const grid = $el(
"div.comfy-logging-logs",
{
style: {
gridTemplateColumns: `repeat(${headers.length}, 1fr)`,
},
},
[...headers, ...rows]
);
const els = [grid];
if (!this.logging.enabled) {
els.unshift(
$el("h3", {
style: { textAlign: "center" },
textContent: "Logging is disabled",
})
);
}
super.show($el("div", els));
}
}
export class ComfyLogging {
/**
* @type Array<{ source: string, type: string, timestamp: Date, message: any }>
*/
entries = [];
#enabled;
#console = {};
get enabled() {
return this.#enabled;
}
set enabled(value) {
if (value === this.#enabled) return;
if (value) {
this.patchConsole();
} else {
this.unpatchConsole();
}
this.#enabled = value;
}
constructor(app) {
this.app = app;
this.dialog = new ComfyLoggingDialog(this);
this.addSetting();
this.catchUnhandled();
this.addInitData();
}
addSetting() {
const settingId = "Comfy.Logging.Enabled";
const htmlSettingId = settingId.replaceAll(".", "-");
const setting = this.app.ui.settings.addSetting({
id: settingId,
name: settingId,
defaultValue: true,
type: (name, setter, value) => {
return $el("tr", [
$el("td", [
$el("label", {
textContent: "Logging",
for: htmlSettingId,
}),
]),
$el("td", [
$el("input", {
id: htmlSettingId,
type: "checkbox",
checked: value,
onchange: (event) => {
setter((this.enabled = event.target.checked));
},
}),
$el("button", {
textContent: "View Logs",
onclick: () => {
this.app.ui.settings.element.close();
this.dialog.show();
},
style: {
fontSize: "14px",
display: "block",
marginTop: "5px",
},
}),
]),
]);
},
});
this.enabled = setting.value;
}
patchConsole() {
// Capture common console outputs
const self = this;
for (const type of ["log", "warn", "error", "debug"]) {
const orig = console[type];
this.#console[type] = orig;
console[type] = function () {
orig.apply(console, arguments);
self.addEntry("console", type, ...arguments);
};
}
}
unpatchConsole() {
// Restore original console functions
for (const type of Object.keys(this.#console)) {
console[type] = this.#console[type];
}
this.#console = {};
}
catchUnhandled() {
// Capture uncaught errors
window.addEventListener("error", (e) => {
this.addEntry("window", "error", e.error ?? "Unknown error");
return false;
});
window.addEventListener("unhandledrejection", (e) => {
this.addEntry("unhandledrejection", "error", e.reason ?? "Unknown error");
});
}
clear() {
this.entries = [];
}
addEntry(source, type, ...args) {
if (this.enabled) {
this.entries.push({
source,
type,
timestamp: new Date(),
message: args,
});
}
}
log(source, ...args) {
this.addEntry(source, "log", ...args);
}
async addInitData() {
if (!this.enabled) return;
const source = "ComfyUI.Logging";
this.addEntry(source, "debug", { UserAgent: navigator.userAgent });
const systemStats = await api.getSystemStats();
this.addEntry(source, "debug", systemStats);
}
}

View File

@ -234,7 +234,7 @@ class ComfySettingsDialog extends ComfyDialog {
localStorage[settingId] = JSON.stringify(value);
}
addSetting({id, name, type, defaultValue, onChange, attrs = {}, tooltip = "",}) {
addSetting({id, name, type, defaultValue, onChange, attrs = {}, tooltip = "", options = undefined}) {
if (!id) {
throw new Error("Settings must have an ID");
}
@ -347,6 +347,32 @@ class ComfySettingsDialog extends ComfyDialog {
]),
]);
break;
case "combo":
element = $el("tr", [
labelCell,
$el("td", [
$el(
"select",
{
oninput: (e) => {
setter(e.target.value);
},
},
(typeof options === "function" ? options(value) : options || []).map((opt) => {
if (typeof opt === "string") {
opt = { text: opt };
}
const v = opt.value ?? opt.text;
return $el("option", {
value: v,
textContent: opt.text,
selected: value + "" === v + "",
});
})
),
]),
]);
break;
case "text":
default:
if (type !== "text") {
@ -480,7 +506,7 @@ class ComfyList {
hide() {
this.element.style.display = "none";
this.button.textContent = "See " + this.#text;
this.button.textContent = "View " + this.#text;
}
toggle() {
@ -542,6 +568,13 @@ export class ComfyUI {
defaultValue: "",
});
this.settings.addSetting({
id: "Comfy.DisableSliders",
name: "Disable sliders.",
type: "boolean",
defaultValue: false,
});
const fileInput = $el("input", {
id: "comfy-file-input",
type: "file",

View File

@ -79,8 +79,8 @@ export function addValueControlWidget(node, targetWidget, defaultValue = "random
return valueControl;
};
function seedWidget(node, inputName, inputData) {
const seed = ComfyWidgets.INT(node, inputName, inputData);
function seedWidget(node, inputName, inputData, app) {
const seed = ComfyWidgets.INT(node, inputName, inputData, app);
const seedControl = addValueControlWidget(node, seed.widget, "randomize");
seed.widget.linkedWidgets = [seedControl];
@ -250,19 +250,29 @@ function addMultilineWidget(node, name, opts, app) {
return { minWidth: 400, minHeight: 200, widget };
}
function isSlider(display, app) {
if (app.ui.settings.getSettingValue("Comfy.DisableSliders")) {
return "number"
}
return (display==="slider") ? "slider" : "number"
}
export const ComfyWidgets = {
"INT:seed": seedWidget,
"INT:noise_seed": seedWidget,
FLOAT(node, inputName, inputData) {
FLOAT(node, inputName, inputData, app) {
let widgetType = isSlider(inputData[1]["display"], app);
const { val, config } = getNumberDefaults(inputData, 0.5);
return { widget: node.addWidget("number", inputName, val, () => {}, config) };
return { widget: node.addWidget(widgetType, inputName, val, () => {}, config) };
},
INT(node, inputName, inputData) {
INT(node, inputName, inputData, app) {
let widgetType = isSlider(inputData[1]["display"], app);
const { val, config } = getNumberDefaults(inputData, 1);
Object.assign(config, { precision: 0 });
return {
widget: node.addWidget(
"number",
widgetType,
inputName,
val,
function (v) {
@ -273,6 +283,18 @@ export const ComfyWidgets = {
),
};
},
BOOLEAN(node, inputName, inputData) {
let defaultVal = inputData[1]["default"];
return {
widget: node.addWidget(
"toggle",
inputName,
defaultVal,
() => {},
{"on": inputData[1].label_on, "off": inputData[1].label_off}
)
};
},
STRING(node, inputName, inputData, app) {
const defaultVal = inputData[1].default || "";
const multiline = !!inputData[1].multiline;
@ -411,7 +433,7 @@ export const ComfyWidgets = {
// Add handler to check if an image is being dragged over our node
node.onDragOver = function (e) {
if (e.dataTransfer && e.dataTransfer.items) {
const image = [...e.dataTransfer.items].find((f) => f.kind === "file" && f.type.startsWith("image/"));
const image = [...e.dataTransfer.items].find((f) => f.kind === "file");
return !!image;
}

View File

@ -30,9 +30,7 @@ export interface ComfyExtension {
getCustomWidgets(
app: ComfyApp
): Promise<
Array<
Record<string, (node, inputName, inputData, app) => { widget?: IWidget; minWidth?: number; minHeight?: number }>
>
Record<string, (node, inputName, inputData, app) => { widget?: IWidget; minWidth?: number; minHeight?: number }>
>;
/**
* Allows the extension to add additional handling to the node before it is registered with LGraph