mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-22 07:49:33 +08:00
Merge branch 'comfyanonymous:master' into dpr
This commit is contained in:
commit
caf4ff59cc
@ -2,6 +2,13 @@ name: "Windows Release cu118 dependencies 2"
|
|||||||
|
|
||||||
on:
|
on:
|
||||||
workflow_dispatch:
|
workflow_dispatch:
|
||||||
|
inputs:
|
||||||
|
xformers:
|
||||||
|
description: 'xformers version'
|
||||||
|
required: true
|
||||||
|
type: string
|
||||||
|
default: "xformers"
|
||||||
|
|
||||||
# push:
|
# push:
|
||||||
# branches:
|
# branches:
|
||||||
# - master
|
# - master
|
||||||
@ -17,7 +24,7 @@ jobs:
|
|||||||
|
|
||||||
- shell: bash
|
- shell: bash
|
||||||
run: |
|
run: |
|
||||||
python -m pip wheel --no-cache-dir torch torchvision torchaudio xformers --extra-index-url https://download.pytorch.org/whl/cu118 -r requirements.txt pygit2 -w ./temp_wheel_dir
|
python -m pip wheel --no-cache-dir torch torchvision torchaudio ${{ inputs.xformers }} --extra-index-url https://download.pytorch.org/whl/cu118 -r requirements.txt pygit2 -w ./temp_wheel_dir
|
||||||
python -m pip install --no-cache-dir ./temp_wheel_dir/*
|
python -m pip install --no-cache-dir ./temp_wheel_dir/*
|
||||||
echo installed basic
|
echo installed basic
|
||||||
ls -lah temp_wheel_dir
|
ls -lah temp_wheel_dir
|
||||||
|
|||||||
1
CODEOWNERS
Normal file
1
CODEOWNERS
Normal file
@ -0,0 +1 @@
|
|||||||
|
* @comfyanonymous
|
||||||
@ -47,6 +47,7 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git
|
|||||||
| Ctrl + O | Load workflow |
|
| Ctrl + O | Load workflow |
|
||||||
| Ctrl + A | Select all nodes |
|
| Ctrl + A | Select all nodes |
|
||||||
| Ctrl + M | Mute/unmute selected nodes |
|
| Ctrl + M | Mute/unmute selected nodes |
|
||||||
|
| Ctrl + B | Bypass selected nodes (acts like the node was removed from the graph and the wires reconnected through) |
|
||||||
| Delete/Backspace | Delete selected nodes |
|
| Delete/Backspace | Delete selected nodes |
|
||||||
| Ctrl + Delete/Backspace | Delete the current graph |
|
| Ctrl + Delete/Backspace | Delete the current graph |
|
||||||
| Space | Move the canvas around when held and moving the cursor |
|
| Space | Move the canvas around when held and moving the cursor |
|
||||||
@ -93,8 +94,8 @@ AMD users can install rocm and pytorch with pip if you don't have it already ins
|
|||||||
|
|
||||||
```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/rocm5.4.2```
|
```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/rocm5.4.2```
|
||||||
|
|
||||||
This is the command to install the nightly with ROCm 5.5 that supports the 7000 series and might have some performance improvements:
|
This is the command to install the nightly with ROCm 5.6 that supports the 7000 series and might have some performance improvements:
|
||||||
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm5.5 -r requirements.txt```
|
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm5.6```
|
||||||
|
|
||||||
### NVIDIA
|
### NVIDIA
|
||||||
|
|
||||||
@ -126,10 +127,10 @@ After this you should have everything installed and can proceed to running Comfy
|
|||||||
|
|
||||||
You can install ComfyUI in Apple Mac silicon (M1 or M2) with any recent macOS version.
|
You can install ComfyUI in Apple Mac silicon (M1 or M2) with any recent macOS version.
|
||||||
|
|
||||||
1. Install pytorch. For instructions, read the [Accelerated PyTorch training on Mac](https://developer.apple.com/metal/pytorch/) Apple Developer guide.
|
1. Install pytorch nightly. For instructions, read the [Accelerated PyTorch training on Mac](https://developer.apple.com/metal/pytorch/) Apple Developer guide (make sure to install the latest pytorch nightly).
|
||||||
1. Follow the [ComfyUI manual installation](#manual-install-windows-linux) instructions for Windows and Linux.
|
1. Follow the [ComfyUI manual installation](#manual-install-windows-linux) instructions for Windows and Linux.
|
||||||
1. Install the ComfyUI [dependencies](#dependencies). If you have another Stable Diffusion UI [you might be able to reuse the dependencies](#i-already-have-another-ui-for-stable-diffusion-installed-do-i-really-have-to-install-all-of-these-dependencies).
|
1. Install the ComfyUI [dependencies](#dependencies). If you have another Stable Diffusion UI [you might be able to reuse the dependencies](#i-already-have-another-ui-for-stable-diffusion-installed-do-i-really-have-to-install-all-of-these-dependencies).
|
||||||
1. Launch ComfyUI by running `python main.py`.
|
1. Launch ComfyUI by running `python main.py --force-fp16`. Note that --force-fp16 will only work if you installed the latest pytorch nightly.
|
||||||
|
|
||||||
> **Note**: Remember to add your models, VAE, LoRAs etc. to the corresponding Comfy folders, as discussed in [ComfyUI manual installation](#manual-install-windows-linux).
|
> **Note**: Remember to add your models, VAE, LoRAs etc. to the corresponding Comfy folders, as discussed in [ComfyUI manual installation](#manual-install-windows-linux).
|
||||||
|
|
||||||
|
|||||||
@ -13,7 +13,7 @@ from ..ldm.modules.diffusionmodules.util import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
from ..ldm.modules.attention import SpatialTransformer
|
from ..ldm.modules.attention import SpatialTransformer
|
||||||
from ..ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample, AttentionBlock
|
from ..ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample
|
||||||
from ..ldm.util import exists
|
from ..ldm.util import exists
|
||||||
|
|
||||||
|
|
||||||
@ -57,6 +57,7 @@ class ControlNet(nn.Module):
|
|||||||
transformer_depth_middle=None,
|
transformer_depth_middle=None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
assert use_spatial_transformer == True, "use_spatial_transformer has to be true"
|
||||||
if use_spatial_transformer:
|
if use_spatial_transformer:
|
||||||
assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
|
assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
|
||||||
|
|
||||||
@ -200,13 +201,7 @@ class ControlNet(nn.Module):
|
|||||||
|
|
||||||
if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
|
if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
|
||||||
layers.append(
|
layers.append(
|
||||||
AttentionBlock(
|
SpatialTransformer(
|
||||||
ch,
|
|
||||||
use_checkpoint=use_checkpoint,
|
|
||||||
num_heads=num_heads,
|
|
||||||
num_head_channels=dim_head,
|
|
||||||
use_new_attention_order=use_new_attention_order,
|
|
||||||
) if not use_spatial_transformer else SpatialTransformer(
|
|
||||||
ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim,
|
ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim,
|
||||||
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
|
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
|
||||||
use_checkpoint=use_checkpoint
|
use_checkpoint=use_checkpoint
|
||||||
@ -259,13 +254,7 @@ class ControlNet(nn.Module):
|
|||||||
use_checkpoint=use_checkpoint,
|
use_checkpoint=use_checkpoint,
|
||||||
use_scale_shift_norm=use_scale_shift_norm,
|
use_scale_shift_norm=use_scale_shift_norm,
|
||||||
),
|
),
|
||||||
AttentionBlock(
|
SpatialTransformer( # always uses a self-attn
|
||||||
ch,
|
|
||||||
use_checkpoint=use_checkpoint,
|
|
||||||
num_heads=num_heads,
|
|
||||||
num_head_channels=dim_head,
|
|
||||||
use_new_attention_order=use_new_attention_order,
|
|
||||||
) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn
|
|
||||||
ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
|
ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
|
||||||
disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
|
disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
|
||||||
use_checkpoint=use_checkpoint
|
use_checkpoint=use_checkpoint
|
||||||
|
|||||||
@ -38,8 +38,14 @@ parser.add_argument("--port", type=int, default=8188, help="Set the listen port.
|
|||||||
parser.add_argument("--enable-cors-header", type=str, default=None, metavar="ORIGIN", nargs="?", const="*", help="Enable CORS (Cross-Origin Resource Sharing) with optional origin or allow all with default '*'.")
|
parser.add_argument("--enable-cors-header", type=str, default=None, metavar="ORIGIN", nargs="?", const="*", help="Enable CORS (Cross-Origin Resource Sharing) with optional origin or allow all with default '*'.")
|
||||||
parser.add_argument("--extra-model-paths-config", type=str, default=None, metavar="PATH", nargs='+', action='append', help="Load one or more extra_model_paths.yaml files.")
|
parser.add_argument("--extra-model-paths-config", type=str, default=None, metavar="PATH", nargs='+', action='append', help="Load one or more extra_model_paths.yaml files.")
|
||||||
parser.add_argument("--output-directory", type=str, default=None, help="Set the ComfyUI output directory.")
|
parser.add_argument("--output-directory", type=str, default=None, help="Set the ComfyUI output directory.")
|
||||||
|
parser.add_argument("--temp-directory", type=str, default=None, help="Set the ComfyUI temp directory (default is in the ComfyUI directory).")
|
||||||
parser.add_argument("--auto-launch", action="store_true", help="Automatically launch ComfyUI in the default browser.")
|
parser.add_argument("--auto-launch", action="store_true", help="Automatically launch ComfyUI in the default browser.")
|
||||||
|
parser.add_argument("--disable-auto-launch", action="store_true", help="Disable auto launching the browser.")
|
||||||
parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use.")
|
parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use.")
|
||||||
|
cm_group = parser.add_mutually_exclusive_group()
|
||||||
|
cm_group.add_argument("--cuda-malloc", action="store_true", help="Enable cudaMallocAsync (enabled by default for torch 2.0 and up).")
|
||||||
|
cm_group.add_argument("--disable-cuda-malloc", action="store_true", help="Disable cudaMallocAsync.")
|
||||||
|
|
||||||
parser.add_argument("--dont-upcast-attention", action="store_true", help="Disable upcasting of attention. Can boost speed but increase the chances of black images.")
|
parser.add_argument("--dont-upcast-attention", action="store_true", help="Disable upcasting of attention. Can boost speed but increase the chances of black images.")
|
||||||
|
|
||||||
fp_group = parser.add_mutually_exclusive_group()
|
fp_group = parser.add_mutually_exclusive_group()
|
||||||
@ -80,7 +86,12 @@ parser.add_argument("--dont-print-server", action="store_true", help="Don't prin
|
|||||||
parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")
|
parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")
|
||||||
parser.add_argument("--windows-standalone-build", action="store_true", help="Windows standalone build: Enable convenient things that most people using the standalone windows build will probably enjoy (like auto opening the page on startup).")
|
parser.add_argument("--windows-standalone-build", action="store_true", help="Windows standalone build: Enable convenient things that most people using the standalone windows build will probably enjoy (like auto opening the page on startup).")
|
||||||
|
|
||||||
|
parser.add_argument("--disable-metadata", action="store_true", help="Disable saving prompt metadata in files.")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if args.windows_standalone_build:
|
if args.windows_standalone_build:
|
||||||
args.auto_launch = True
|
args.auto_launch = True
|
||||||
|
|
||||||
|
if args.disable_auto_launch:
|
||||||
|
args.auto_launch = False
|
||||||
|
|||||||
@ -24,8 +24,8 @@ class ClipVisionModel():
|
|||||||
return self.model.load_state_dict(sd, strict=False)
|
return self.model.load_state_dict(sd, strict=False)
|
||||||
|
|
||||||
def encode_image(self, image):
|
def encode_image(self, image):
|
||||||
img = torch.clip((255. * image[0]), 0, 255).round().int()
|
img = torch.clip((255. * image), 0, 255).round().int()
|
||||||
inputs = self.processor(images=[img], return_tensors="pt")
|
inputs = self.processor(images=img, return_tensors="pt")
|
||||||
outputs = self.model(**inputs)
|
outputs = self.model(**inputs)
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
|||||||
@ -148,6 +148,10 @@ vae_conversion_map_attn = [
|
|||||||
("q.", "query."),
|
("q.", "query."),
|
||||||
("k.", "key."),
|
("k.", "key."),
|
||||||
("v.", "value."),
|
("v.", "value."),
|
||||||
|
("q.", "to_q."),
|
||||||
|
("k.", "to_k."),
|
||||||
|
("v.", "to_v."),
|
||||||
|
("proj_out.", "to_out.0."),
|
||||||
("proj_out.", "proj_attn."),
|
("proj_out.", "proj_attn."),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
@ -180,7 +180,6 @@ class NoiseScheduleVP:
|
|||||||
|
|
||||||
def model_wrapper(
|
def model_wrapper(
|
||||||
model,
|
model,
|
||||||
sampling_function,
|
|
||||||
noise_schedule,
|
noise_schedule,
|
||||||
model_type="noise",
|
model_type="noise",
|
||||||
model_kwargs={},
|
model_kwargs={},
|
||||||
@ -295,7 +294,7 @@ def model_wrapper(
|
|||||||
if t_continuous.reshape((-1,)).shape[0] == 1:
|
if t_continuous.reshape((-1,)).shape[0] == 1:
|
||||||
t_continuous = t_continuous.expand((x.shape[0]))
|
t_continuous = t_continuous.expand((x.shape[0]))
|
||||||
t_input = get_model_input_time(t_continuous)
|
t_input = get_model_input_time(t_continuous)
|
||||||
output = sampling_function(model, x, t_input, **model_kwargs)
|
output = model(x, t_input, **model_kwargs)
|
||||||
if model_type == "noise":
|
if model_type == "noise":
|
||||||
return output
|
return output
|
||||||
elif model_type == "x_start":
|
elif model_type == "x_start":
|
||||||
@ -843,10 +842,12 @@ def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, ex
|
|||||||
else:
|
else:
|
||||||
timesteps = sigmas.clone()
|
timesteps = sigmas.clone()
|
||||||
|
|
||||||
for s in range(timesteps.shape[0]):
|
alphas_cumprod = model.inner_model.alphas_cumprod
|
||||||
timesteps[s] = (model.sigma_to_t(timesteps[s]) / 1000) + (1 / len(model.sigmas))
|
|
||||||
|
|
||||||
ns = NoiseScheduleVP('discrete', alphas_cumprod=model.inner_model.alphas_cumprod)
|
for s in range(timesteps.shape[0]):
|
||||||
|
timesteps[s] = (model.sigma_to_discrete_timestep(timesteps[s]) / 1000) + (1 / len(alphas_cumprod))
|
||||||
|
|
||||||
|
ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
|
||||||
|
|
||||||
if image is not None:
|
if image is not None:
|
||||||
img = image * ns.marginal_alpha(timesteps[0])
|
img = image * ns.marginal_alpha(timesteps[0])
|
||||||
@ -859,18 +860,15 @@ def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, ex
|
|||||||
img = noise
|
img = noise
|
||||||
|
|
||||||
if to_zero:
|
if to_zero:
|
||||||
timesteps[-1] = (1 / len(model.sigmas))
|
timesteps[-1] = (1 / len(alphas_cumprod))
|
||||||
|
|
||||||
device = noise.device
|
device = noise.device
|
||||||
|
|
||||||
if model.parameterization == "v":
|
|
||||||
model_type = "v"
|
model_type = "noise"
|
||||||
else:
|
|
||||||
model_type = "noise"
|
|
||||||
|
|
||||||
model_fn = model_wrapper(
|
model_fn = model_wrapper(
|
||||||
model.inner_model.inner_model.apply_model,
|
model.predict_eps_discrete_timestep,
|
||||||
sampling_function,
|
|
||||||
ns,
|
ns,
|
||||||
model_type=model_type,
|
model_type=model_type,
|
||||||
guidance_type="uncond",
|
guidance_type="uncond",
|
||||||
|
|||||||
@ -63,12 +63,17 @@ class DiscreteSchedule(nn.Module):
|
|||||||
t = torch.linspace(t_max, 0, n, device=self.sigmas.device)
|
t = torch.linspace(t_max, 0, n, device=self.sigmas.device)
|
||||||
return sampling.append_zero(self.t_to_sigma(t))
|
return sampling.append_zero(self.t_to_sigma(t))
|
||||||
|
|
||||||
def sigma_to_t(self, sigma, quantize=None):
|
def sigma_to_discrete_timestep(self, sigma):
|
||||||
quantize = self.quantize if quantize is None else quantize
|
|
||||||
log_sigma = sigma.log()
|
log_sigma = sigma.log()
|
||||||
dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None]
|
dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None]
|
||||||
|
return dists.abs().argmin(dim=0).view(sigma.shape)
|
||||||
|
|
||||||
|
def sigma_to_t(self, sigma, quantize=None):
|
||||||
|
quantize = self.quantize if quantize is None else quantize
|
||||||
if quantize:
|
if quantize:
|
||||||
return dists.abs().argmin(dim=0).view(sigma.shape)
|
return self.sigma_to_discrete_timestep(sigma)
|
||||||
|
log_sigma = sigma.log()
|
||||||
|
dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None]
|
||||||
low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(max=self.log_sigmas.shape[0] - 2)
|
low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(max=self.log_sigmas.shape[0] - 2)
|
||||||
high_idx = low_idx + 1
|
high_idx = low_idx + 1
|
||||||
low, high = self.log_sigmas[low_idx], self.log_sigmas[high_idx]
|
low, high = self.log_sigmas[low_idx], self.log_sigmas[high_idx]
|
||||||
@ -85,6 +90,12 @@ class DiscreteSchedule(nn.Module):
|
|||||||
log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
|
log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
|
||||||
return log_sigma.exp()
|
return log_sigma.exp()
|
||||||
|
|
||||||
|
def predict_eps_discrete_timestep(self, input, t, **kwargs):
|
||||||
|
if t.dtype != torch.int64 and t.dtype != torch.int32:
|
||||||
|
t = t.round()
|
||||||
|
sigma = self.t_to_sigma(t)
|
||||||
|
input = input * ((utils.append_dims(sigma, input.ndim) ** 2 + 1.0) ** 0.5)
|
||||||
|
return (input - self(input, sigma, **kwargs)) / utils.append_dims(sigma, input.ndim)
|
||||||
|
|
||||||
class DiscreteEpsDDPMDenoiser(DiscreteSchedule):
|
class DiscreteEpsDDPMDenoiser(DiscreteSchedule):
|
||||||
"""A wrapper for discrete schedule DDPM models that output eps (the predicted
|
"""A wrapper for discrete schedule DDPM models that output eps (the predicted
|
||||||
|
|||||||
@ -3,7 +3,6 @@ import math
|
|||||||
from scipy import integrate
|
from scipy import integrate
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torchdiffeq import odeint
|
|
||||||
import torchsde
|
import torchsde
|
||||||
from tqdm.auto import trange, tqdm
|
from tqdm.auto import trange, tqdm
|
||||||
|
|
||||||
@ -131,9 +130,9 @@ def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None,
|
|||||||
s_in = x.new_ones([x.shape[0]])
|
s_in = x.new_ones([x.shape[0]])
|
||||||
for i in trange(len(sigmas) - 1, disable=disable):
|
for i in trange(len(sigmas) - 1, disable=disable):
|
||||||
gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
|
gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
|
||||||
eps = torch.randn_like(x) * s_noise
|
|
||||||
sigma_hat = sigmas[i] * (gamma + 1)
|
sigma_hat = sigmas[i] * (gamma + 1)
|
||||||
if gamma > 0:
|
if gamma > 0:
|
||||||
|
eps = torch.randn_like(x) * s_noise
|
||||||
x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
|
x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
|
||||||
denoised = model(x, sigma_hat * s_in, **extra_args)
|
denoised = model(x, sigma_hat * s_in, **extra_args)
|
||||||
d = to_d(x, sigma_hat, denoised)
|
d = to_d(x, sigma_hat, denoised)
|
||||||
@ -172,9 +171,9 @@ def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None,
|
|||||||
s_in = x.new_ones([x.shape[0]])
|
s_in = x.new_ones([x.shape[0]])
|
||||||
for i in trange(len(sigmas) - 1, disable=disable):
|
for i in trange(len(sigmas) - 1, disable=disable):
|
||||||
gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
|
gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
|
||||||
eps = torch.randn_like(x) * s_noise
|
|
||||||
sigma_hat = sigmas[i] * (gamma + 1)
|
sigma_hat = sigmas[i] * (gamma + 1)
|
||||||
if gamma > 0:
|
if gamma > 0:
|
||||||
|
eps = torch.randn_like(x) * s_noise
|
||||||
x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
|
x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
|
||||||
denoised = model(x, sigma_hat * s_in, **extra_args)
|
denoised = model(x, sigma_hat * s_in, **extra_args)
|
||||||
d = to_d(x, sigma_hat, denoised)
|
d = to_d(x, sigma_hat, denoised)
|
||||||
@ -201,9 +200,9 @@ def sample_dpm_2(model, x, sigmas, extra_args=None, callback=None, disable=None,
|
|||||||
s_in = x.new_ones([x.shape[0]])
|
s_in = x.new_ones([x.shape[0]])
|
||||||
for i in trange(len(sigmas) - 1, disable=disable):
|
for i in trange(len(sigmas) - 1, disable=disable):
|
||||||
gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
|
gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
|
||||||
eps = torch.randn_like(x) * s_noise
|
|
||||||
sigma_hat = sigmas[i] * (gamma + 1)
|
sigma_hat = sigmas[i] * (gamma + 1)
|
||||||
if gamma > 0:
|
if gamma > 0:
|
||||||
|
eps = torch.randn_like(x) * s_noise
|
||||||
x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
|
x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
|
||||||
denoised = model(x, sigma_hat * s_in, **extra_args)
|
denoised = model(x, sigma_hat * s_in, **extra_args)
|
||||||
d = to_d(x, sigma_hat, denoised)
|
d = to_d(x, sigma_hat, denoised)
|
||||||
@ -287,30 +286,6 @@ def sample_lms(model, x, sigmas, extra_args=None, callback=None, disable=None, o
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def log_likelihood(model, x, sigma_min, sigma_max, extra_args=None, atol=1e-4, rtol=1e-4):
|
|
||||||
extra_args = {} if extra_args is None else extra_args
|
|
||||||
s_in = x.new_ones([x.shape[0]])
|
|
||||||
v = torch.randint_like(x, 2) * 2 - 1
|
|
||||||
fevals = 0
|
|
||||||
def ode_fn(sigma, x):
|
|
||||||
nonlocal fevals
|
|
||||||
with torch.enable_grad():
|
|
||||||
x = x[0].detach().requires_grad_()
|
|
||||||
denoised = model(x, sigma * s_in, **extra_args)
|
|
||||||
d = to_d(x, sigma, denoised)
|
|
||||||
fevals += 1
|
|
||||||
grad = torch.autograd.grad((d * v).sum(), x)[0]
|
|
||||||
d_ll = (v * grad).flatten(1).sum(1)
|
|
||||||
return d.detach(), d_ll
|
|
||||||
x_min = x, x.new_zeros([x.shape[0]])
|
|
||||||
t = x.new_tensor([sigma_min, sigma_max])
|
|
||||||
sol = odeint(ode_fn, x_min, t, atol=atol, rtol=rtol, method='dopri5')
|
|
||||||
latent, delta_ll = sol[0][-1], sol[1][-1]
|
|
||||||
ll_prior = torch.distributions.Normal(0, sigma_max).log_prob(latent).flatten(1).sum(1)
|
|
||||||
return ll_prior + delta_ll, {'fevals': fevals}
|
|
||||||
|
|
||||||
|
|
||||||
class PIDStepSizeController:
|
class PIDStepSizeController:
|
||||||
"""A PID controller for ODE adaptive step size control."""
|
"""A PID controller for ODE adaptive step size control."""
|
||||||
def __init__(self, h, pcoeff, icoeff, dcoeff, order=1, accept_safety=0.81, eps=1e-8):
|
def __init__(self, h, pcoeff, icoeff, dcoeff, order=1, accept_safety=0.81, eps=1e-8):
|
||||||
@ -656,23 +631,78 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
|
|||||||
elif solver_type == 'midpoint':
|
elif solver_type == 'midpoint':
|
||||||
x = x + 0.5 * (-h - eta_h).expm1().neg() * (1 / r) * (denoised - old_denoised)
|
x = x + 0.5 * (-h - eta_h).expm1().neg() * (1 / r) * (denoised - old_denoised)
|
||||||
|
|
||||||
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * eta_h).expm1().neg().sqrt() * s_noise
|
if eta:
|
||||||
|
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * eta_h).expm1().neg().sqrt() * s_noise
|
||||||
|
|
||||||
old_denoised = denoised
|
old_denoised = denoised
|
||||||
h_last = h
|
h_last = h
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
||||||
|
"""DPM-Solver++(3M) SDE."""
|
||||||
|
|
||||||
|
seed = extra_args.get("seed", None)
|
||||||
|
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
|
||||||
|
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
|
||||||
|
extra_args = {} if extra_args is None else extra_args
|
||||||
|
s_in = x.new_ones([x.shape[0]])
|
||||||
|
|
||||||
|
denoised_1, denoised_2 = None, None
|
||||||
|
h_1, h_2 = None, None
|
||||||
|
|
||||||
|
for i in trange(len(sigmas) - 1, disable=disable):
|
||||||
|
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||||
|
if callback is not None:
|
||||||
|
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||||
|
if sigmas[i + 1] == 0:
|
||||||
|
# Denoising step
|
||||||
|
x = denoised
|
||||||
|
else:
|
||||||
|
t, s = -sigmas[i].log(), -sigmas[i + 1].log()
|
||||||
|
h = s - t
|
||||||
|
h_eta = h * (eta + 1)
|
||||||
|
|
||||||
|
x = torch.exp(-h_eta) * x + (-h_eta).expm1().neg() * denoised
|
||||||
|
|
||||||
|
if h_2 is not None:
|
||||||
|
r0 = h_1 / h
|
||||||
|
r1 = h_2 / h
|
||||||
|
d1_0 = (denoised - denoised_1) / r0
|
||||||
|
d1_1 = (denoised_1 - denoised_2) / r1
|
||||||
|
d1 = d1_0 + (d1_0 - d1_1) * r0 / (r0 + r1)
|
||||||
|
d2 = (d1_0 - d1_1) / (r0 + r1)
|
||||||
|
phi_2 = h_eta.neg().expm1() / h_eta + 1
|
||||||
|
phi_3 = phi_2 / h_eta - 0.5
|
||||||
|
x = x + phi_2 * d1 - phi_3 * d2
|
||||||
|
elif h_1 is not None:
|
||||||
|
r = h_1 / h
|
||||||
|
d = (denoised - denoised_1) / r
|
||||||
|
phi_2 = h_eta.neg().expm1() / h_eta + 1
|
||||||
|
x = x + phi_2 * d
|
||||||
|
|
||||||
|
if eta:
|
||||||
|
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * h * eta).expm1().neg().sqrt() * s_noise
|
||||||
|
|
||||||
|
denoised_1, denoised_2 = denoised, denoised_1
|
||||||
|
h_1, h_2 = h, h_1
|
||||||
|
return x
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
||||||
|
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
|
||||||
|
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
|
||||||
|
return sample_dpmpp_3m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
|
def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
|
||||||
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
|
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
|
||||||
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
|
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
|
||||||
return sample_dpmpp_2m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type)
|
return sample_dpmpp_2m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type)
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def sample_dpmpp_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
|
def sample_dpmpp_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
|
||||||
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
|
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
|
||||||
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
|
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
|
||||||
return sample_dpmpp_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, r=r)
|
return sample_dpmpp_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, r=r)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -14,6 +14,7 @@ class DDIMSampler(object):
|
|||||||
self.ddpm_num_timesteps = model.num_timesteps
|
self.ddpm_num_timesteps = model.num_timesteps
|
||||||
self.schedule = schedule
|
self.schedule = schedule
|
||||||
self.device = device
|
self.device = device
|
||||||
|
self.parameterization = kwargs.get("parameterization", "eps")
|
||||||
|
|
||||||
def register_buffer(self, name, attr):
|
def register_buffer(self, name, attr):
|
||||||
if type(attr) == torch.Tensor:
|
if type(attr) == torch.Tensor:
|
||||||
@ -261,7 +262,7 @@ class DDIMSampler(object):
|
|||||||
b, *_, device = *x.shape, x.device
|
b, *_, device = *x.shape, x.device
|
||||||
|
|
||||||
if denoise_function is not None:
|
if denoise_function is not None:
|
||||||
model_output = denoise_function(self.model.apply_model, x, t, **extra_args)
|
model_output = denoise_function(x, t, **extra_args)
|
||||||
elif unconditional_conditioning is None or unconditional_guidance_scale == 1.:
|
elif unconditional_conditioning is None or unconditional_guidance_scale == 1.:
|
||||||
model_output = self.model.apply_model(x, t, c)
|
model_output = self.model.apply_model(x, t, c)
|
||||||
else:
|
else:
|
||||||
@ -289,13 +290,13 @@ class DDIMSampler(object):
|
|||||||
model_uncond, model_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
|
model_uncond, model_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
|
||||||
model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond)
|
model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond)
|
||||||
|
|
||||||
if self.model.parameterization == "v":
|
if self.parameterization == "v":
|
||||||
e_t = extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * model_output + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x
|
e_t = extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * model_output + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x
|
||||||
else:
|
else:
|
||||||
e_t = model_output
|
e_t = model_output
|
||||||
|
|
||||||
if score_corrector is not None:
|
if score_corrector is not None:
|
||||||
assert self.model.parameterization == "eps", 'not implemented'
|
assert self.parameterization == "eps", 'not implemented'
|
||||||
e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
|
e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
|
||||||
|
|
||||||
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
|
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
|
||||||
@ -309,7 +310,7 @@ class DDIMSampler(object):
|
|||||||
sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
|
sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
|
||||||
|
|
||||||
# current prediction for x_0
|
# current prediction for x_0
|
||||||
if self.model.parameterization != "v":
|
if self.parameterization != "v":
|
||||||
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
||||||
else:
|
else:
|
||||||
pred_x0 = extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * x - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * model_output
|
pred_x0 = extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * x - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * model_output
|
||||||
|
|||||||
@ -36,7 +36,7 @@ def uniq(arr):
|
|||||||
def default(val, d):
|
def default(val, d):
|
||||||
if exists(val):
|
if exists(val):
|
||||||
return val
|
return val
|
||||||
return d() if isfunction(d) else d
|
return d
|
||||||
|
|
||||||
|
|
||||||
def max_neg_value(t):
|
def max_neg_value(t):
|
||||||
@ -52,9 +52,9 @@ def init_(tensor):
|
|||||||
|
|
||||||
# feedforward
|
# feedforward
|
||||||
class GEGLU(nn.Module):
|
class GEGLU(nn.Module):
|
||||||
def __init__(self, dim_in, dim_out, dtype=None):
|
def __init__(self, dim_in, dim_out, dtype=None, device=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.proj = comfy.ops.Linear(dim_in, dim_out * 2, dtype=dtype)
|
self.proj = comfy.ops.Linear(dim_in, dim_out * 2, dtype=dtype, device=device)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x, gate = self.proj(x).chunk(2, dim=-1)
|
x, gate = self.proj(x).chunk(2, dim=-1)
|
||||||
@ -62,19 +62,19 @@ class GEGLU(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class FeedForward(nn.Module):
|
class FeedForward(nn.Module):
|
||||||
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0., dtype=None):
|
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0., dtype=None, device=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
inner_dim = int(dim * mult)
|
inner_dim = int(dim * mult)
|
||||||
dim_out = default(dim_out, dim)
|
dim_out = default(dim_out, dim)
|
||||||
project_in = nn.Sequential(
|
project_in = nn.Sequential(
|
||||||
comfy.ops.Linear(dim, inner_dim, dtype=dtype),
|
comfy.ops.Linear(dim, inner_dim, dtype=dtype, device=device),
|
||||||
nn.GELU()
|
nn.GELU()
|
||||||
) if not glu else GEGLU(dim, inner_dim, dtype=dtype)
|
) if not glu else GEGLU(dim, inner_dim, dtype=dtype, device=device)
|
||||||
|
|
||||||
self.net = nn.Sequential(
|
self.net = nn.Sequential(
|
||||||
project_in,
|
project_in,
|
||||||
nn.Dropout(dropout),
|
nn.Dropout(dropout),
|
||||||
comfy.ops.Linear(inner_dim, dim_out, dtype=dtype)
|
comfy.ops.Linear(inner_dim, dim_out, dtype=dtype, device=device)
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
@ -90,8 +90,8 @@ def zero_module(module):
|
|||||||
return module
|
return module
|
||||||
|
|
||||||
|
|
||||||
def Normalize(in_channels, dtype=None):
|
def Normalize(in_channels, dtype=None, device=None):
|
||||||
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype)
|
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
|
||||||
|
|
||||||
|
|
||||||
class SpatialSelfAttention(nn.Module):
|
class SpatialSelfAttention(nn.Module):
|
||||||
@ -148,7 +148,7 @@ class SpatialSelfAttention(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class CrossAttentionBirchSan(nn.Module):
|
class CrossAttentionBirchSan(nn.Module):
|
||||||
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None):
|
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
inner_dim = dim_head * heads
|
inner_dim = dim_head * heads
|
||||||
context_dim = default(context_dim, query_dim)
|
context_dim = default(context_dim, query_dim)
|
||||||
@ -156,12 +156,12 @@ class CrossAttentionBirchSan(nn.Module):
|
|||||||
self.scale = dim_head ** -0.5
|
self.scale = dim_head ** -0.5
|
||||||
self.heads = heads
|
self.heads = heads
|
||||||
|
|
||||||
self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype)
|
self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
|
||||||
self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
|
self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
|
||||||
self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
|
self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
|
||||||
|
|
||||||
self.to_out = nn.Sequential(
|
self.to_out = nn.Sequential(
|
||||||
comfy.ops.Linear(inner_dim, query_dim, dtype=dtype),
|
comfy.ops.Linear(inner_dim, query_dim, dtype=dtype, device=device),
|
||||||
nn.Dropout(dropout)
|
nn.Dropout(dropout)
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -245,7 +245,7 @@ class CrossAttentionBirchSan(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class CrossAttentionDoggettx(nn.Module):
|
class CrossAttentionDoggettx(nn.Module):
|
||||||
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None):
|
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
inner_dim = dim_head * heads
|
inner_dim = dim_head * heads
|
||||||
context_dim = default(context_dim, query_dim)
|
context_dim = default(context_dim, query_dim)
|
||||||
@ -253,12 +253,12 @@ class CrossAttentionDoggettx(nn.Module):
|
|||||||
self.scale = dim_head ** -0.5
|
self.scale = dim_head ** -0.5
|
||||||
self.heads = heads
|
self.heads = heads
|
||||||
|
|
||||||
self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype)
|
self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
|
||||||
self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
|
self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
|
||||||
self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
|
self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
|
||||||
|
|
||||||
self.to_out = nn.Sequential(
|
self.to_out = nn.Sequential(
|
||||||
comfy.ops.Linear(inner_dim, query_dim, dtype=dtype),
|
comfy.ops.Linear(inner_dim, query_dim, dtype=dtype, device=device),
|
||||||
nn.Dropout(dropout)
|
nn.Dropout(dropout)
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -343,7 +343,7 @@ class CrossAttentionDoggettx(nn.Module):
|
|||||||
return self.to_out(r2)
|
return self.to_out(r2)
|
||||||
|
|
||||||
class CrossAttention(nn.Module):
|
class CrossAttention(nn.Module):
|
||||||
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None):
|
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
inner_dim = dim_head * heads
|
inner_dim = dim_head * heads
|
||||||
context_dim = default(context_dim, query_dim)
|
context_dim = default(context_dim, query_dim)
|
||||||
@ -351,12 +351,12 @@ class CrossAttention(nn.Module):
|
|||||||
self.scale = dim_head ** -0.5
|
self.scale = dim_head ** -0.5
|
||||||
self.heads = heads
|
self.heads = heads
|
||||||
|
|
||||||
self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype)
|
self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
|
||||||
self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
|
self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
|
||||||
self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
|
self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
|
||||||
|
|
||||||
self.to_out = nn.Sequential(
|
self.to_out = nn.Sequential(
|
||||||
comfy.ops.Linear(inner_dim, query_dim, dtype=dtype),
|
comfy.ops.Linear(inner_dim, query_dim, dtype=dtype, device=device),
|
||||||
nn.Dropout(dropout)
|
nn.Dropout(dropout)
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -399,7 +399,7 @@ class CrossAttention(nn.Module):
|
|||||||
|
|
||||||
class MemoryEfficientCrossAttention(nn.Module):
|
class MemoryEfficientCrossAttention(nn.Module):
|
||||||
# https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
|
# https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
|
||||||
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, dtype=None):
|
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, dtype=None, device=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
|
print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
|
||||||
f"{heads} heads.")
|
f"{heads} heads.")
|
||||||
@ -409,11 +409,11 @@ class MemoryEfficientCrossAttention(nn.Module):
|
|||||||
self.heads = heads
|
self.heads = heads
|
||||||
self.dim_head = dim_head
|
self.dim_head = dim_head
|
||||||
|
|
||||||
self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype)
|
self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
|
||||||
self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
|
self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
|
||||||
self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
|
self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
|
||||||
|
|
||||||
self.to_out = nn.Sequential(comfy.ops.Linear(inner_dim, query_dim, dtype=dtype), nn.Dropout(dropout))
|
self.to_out = nn.Sequential(comfy.ops.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
|
||||||
self.attention_op: Optional[Any] = None
|
self.attention_op: Optional[Any] = None
|
||||||
|
|
||||||
def forward(self, x, context=None, value=None, mask=None):
|
def forward(self, x, context=None, value=None, mask=None):
|
||||||
@ -450,7 +450,7 @@ class MemoryEfficientCrossAttention(nn.Module):
|
|||||||
return self.to_out(out)
|
return self.to_out(out)
|
||||||
|
|
||||||
class CrossAttentionPytorch(nn.Module):
|
class CrossAttentionPytorch(nn.Module):
|
||||||
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None):
|
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
inner_dim = dim_head * heads
|
inner_dim = dim_head * heads
|
||||||
context_dim = default(context_dim, query_dim)
|
context_dim = default(context_dim, query_dim)
|
||||||
@ -458,11 +458,11 @@ class CrossAttentionPytorch(nn.Module):
|
|||||||
self.heads = heads
|
self.heads = heads
|
||||||
self.dim_head = dim_head
|
self.dim_head = dim_head
|
||||||
|
|
||||||
self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype)
|
self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
|
||||||
self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
|
self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
|
||||||
self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
|
self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
|
||||||
|
|
||||||
self.to_out = nn.Sequential(comfy.ops.Linear(inner_dim, query_dim, dtype=dtype), nn.Dropout(dropout))
|
self.to_out = nn.Sequential(comfy.ops.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
|
||||||
self.attention_op: Optional[Any] = None
|
self.attention_op: Optional[Any] = None
|
||||||
|
|
||||||
def forward(self, x, context=None, value=None, mask=None):
|
def forward(self, x, context=None, value=None, mask=None):
|
||||||
@ -508,17 +508,17 @@ else:
|
|||||||
|
|
||||||
class BasicTransformerBlock(nn.Module):
|
class BasicTransformerBlock(nn.Module):
|
||||||
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
|
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
|
||||||
disable_self_attn=False, dtype=None):
|
disable_self_attn=False, dtype=None, device=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.disable_self_attn = disable_self_attn
|
self.disable_self_attn = disable_self_attn
|
||||||
self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
|
self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
|
||||||
context_dim=context_dim if self.disable_self_attn else None, dtype=dtype) # is a self-attention if not self.disable_self_attn
|
context_dim=context_dim if self.disable_self_attn else None, dtype=dtype, device=device) # is a self-attention if not self.disable_self_attn
|
||||||
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff, dtype=dtype)
|
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device)
|
||||||
self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
|
self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
|
||||||
heads=n_heads, dim_head=d_head, dropout=dropout, dtype=dtype) # is self-attn if context is none
|
heads=n_heads, dim_head=d_head, dropout=dropout, dtype=dtype, device=device) # is self-attn if context is none
|
||||||
self.norm1 = nn.LayerNorm(dim, dtype=dtype)
|
self.norm1 = nn.LayerNorm(dim, dtype=dtype, device=device)
|
||||||
self.norm2 = nn.LayerNorm(dim, dtype=dtype)
|
self.norm2 = nn.LayerNorm(dim, dtype=dtype, device=device)
|
||||||
self.norm3 = nn.LayerNorm(dim, dtype=dtype)
|
self.norm3 = nn.LayerNorm(dim, dtype=dtype, device=device)
|
||||||
self.checkpoint = checkpoint
|
self.checkpoint = checkpoint
|
||||||
self.n_heads = n_heads
|
self.n_heads = n_heads
|
||||||
self.d_head = d_head
|
self.d_head = d_head
|
||||||
@ -648,34 +648,34 @@ class SpatialTransformer(nn.Module):
|
|||||||
def __init__(self, in_channels, n_heads, d_head,
|
def __init__(self, in_channels, n_heads, d_head,
|
||||||
depth=1, dropout=0., context_dim=None,
|
depth=1, dropout=0., context_dim=None,
|
||||||
disable_self_attn=False, use_linear=False,
|
disable_self_attn=False, use_linear=False,
|
||||||
use_checkpoint=True, dtype=None):
|
use_checkpoint=True, dtype=None, device=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
if exists(context_dim) and not isinstance(context_dim, list):
|
if exists(context_dim) and not isinstance(context_dim, list):
|
||||||
context_dim = [context_dim] * depth
|
context_dim = [context_dim] * depth
|
||||||
self.in_channels = in_channels
|
self.in_channels = in_channels
|
||||||
inner_dim = n_heads * d_head
|
inner_dim = n_heads * d_head
|
||||||
self.norm = Normalize(in_channels, dtype=dtype)
|
self.norm = Normalize(in_channels, dtype=dtype, device=device)
|
||||||
if not use_linear:
|
if not use_linear:
|
||||||
self.proj_in = nn.Conv2d(in_channels,
|
self.proj_in = nn.Conv2d(in_channels,
|
||||||
inner_dim,
|
inner_dim,
|
||||||
kernel_size=1,
|
kernel_size=1,
|
||||||
stride=1,
|
stride=1,
|
||||||
padding=0, dtype=dtype)
|
padding=0, dtype=dtype, device=device)
|
||||||
else:
|
else:
|
||||||
self.proj_in = comfy.ops.Linear(in_channels, inner_dim, dtype=dtype)
|
self.proj_in = comfy.ops.Linear(in_channels, inner_dim, dtype=dtype, device=device)
|
||||||
|
|
||||||
self.transformer_blocks = nn.ModuleList(
|
self.transformer_blocks = nn.ModuleList(
|
||||||
[BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
|
[BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
|
||||||
disable_self_attn=disable_self_attn, checkpoint=use_checkpoint, dtype=dtype)
|
disable_self_attn=disable_self_attn, checkpoint=use_checkpoint, dtype=dtype, device=device)
|
||||||
for d in range(depth)]
|
for d in range(depth)]
|
||||||
)
|
)
|
||||||
if not use_linear:
|
if not use_linear:
|
||||||
self.proj_out = nn.Conv2d(inner_dim,in_channels,
|
self.proj_out = nn.Conv2d(inner_dim,in_channels,
|
||||||
kernel_size=1,
|
kernel_size=1,
|
||||||
stride=1,
|
stride=1,
|
||||||
padding=0, dtype=dtype)
|
padding=0, dtype=dtype, device=device)
|
||||||
else:
|
else:
|
||||||
self.proj_out = comfy.ops.Linear(in_channels, inner_dim, dtype=dtype)
|
self.proj_out = comfy.ops.Linear(in_channels, inner_dim, dtype=dtype, device=device)
|
||||||
self.use_linear = use_linear
|
self.use_linear = use_linear
|
||||||
|
|
||||||
def forward(self, x, context=None, transformer_options={}):
|
def forward(self, x, context=None, transformer_options={}):
|
||||||
|
|||||||
@ -8,6 +8,7 @@ from typing import Optional, Any
|
|||||||
|
|
||||||
from ..attention import MemoryEfficientCrossAttention
|
from ..attention import MemoryEfficientCrossAttention
|
||||||
from comfy import model_management
|
from comfy import model_management
|
||||||
|
import comfy.ops
|
||||||
|
|
||||||
if model_management.xformers_enabled_vae():
|
if model_management.xformers_enabled_vae():
|
||||||
import xformers
|
import xformers
|
||||||
@ -48,7 +49,7 @@ class Upsample(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.with_conv = with_conv
|
self.with_conv = with_conv
|
||||||
if self.with_conv:
|
if self.with_conv:
|
||||||
self.conv = torch.nn.Conv2d(in_channels,
|
self.conv = comfy.ops.Conv2d(in_channels,
|
||||||
in_channels,
|
in_channels,
|
||||||
kernel_size=3,
|
kernel_size=3,
|
||||||
stride=1,
|
stride=1,
|
||||||
@ -67,7 +68,7 @@ class Downsample(nn.Module):
|
|||||||
self.with_conv = with_conv
|
self.with_conv = with_conv
|
||||||
if self.with_conv:
|
if self.with_conv:
|
||||||
# no asymmetric padding in torch conv, must do it ourselves
|
# no asymmetric padding in torch conv, must do it ourselves
|
||||||
self.conv = torch.nn.Conv2d(in_channels,
|
self.conv = comfy.ops.Conv2d(in_channels,
|
||||||
in_channels,
|
in_channels,
|
||||||
kernel_size=3,
|
kernel_size=3,
|
||||||
stride=2,
|
stride=2,
|
||||||
@ -95,30 +96,30 @@ class ResnetBlock(nn.Module):
|
|||||||
|
|
||||||
self.swish = torch.nn.SiLU(inplace=True)
|
self.swish = torch.nn.SiLU(inplace=True)
|
||||||
self.norm1 = Normalize(in_channels)
|
self.norm1 = Normalize(in_channels)
|
||||||
self.conv1 = torch.nn.Conv2d(in_channels,
|
self.conv1 = comfy.ops.Conv2d(in_channels,
|
||||||
out_channels,
|
out_channels,
|
||||||
kernel_size=3,
|
kernel_size=3,
|
||||||
stride=1,
|
stride=1,
|
||||||
padding=1)
|
padding=1)
|
||||||
if temb_channels > 0:
|
if temb_channels > 0:
|
||||||
self.temb_proj = torch.nn.Linear(temb_channels,
|
self.temb_proj = comfy.ops.Linear(temb_channels,
|
||||||
out_channels)
|
out_channels)
|
||||||
self.norm2 = Normalize(out_channels)
|
self.norm2 = Normalize(out_channels)
|
||||||
self.dropout = torch.nn.Dropout(dropout, inplace=True)
|
self.dropout = torch.nn.Dropout(dropout, inplace=True)
|
||||||
self.conv2 = torch.nn.Conv2d(out_channels,
|
self.conv2 = comfy.ops.Conv2d(out_channels,
|
||||||
out_channels,
|
out_channels,
|
||||||
kernel_size=3,
|
kernel_size=3,
|
||||||
stride=1,
|
stride=1,
|
||||||
padding=1)
|
padding=1)
|
||||||
if self.in_channels != self.out_channels:
|
if self.in_channels != self.out_channels:
|
||||||
if self.use_conv_shortcut:
|
if self.use_conv_shortcut:
|
||||||
self.conv_shortcut = torch.nn.Conv2d(in_channels,
|
self.conv_shortcut = comfy.ops.Conv2d(in_channels,
|
||||||
out_channels,
|
out_channels,
|
||||||
kernel_size=3,
|
kernel_size=3,
|
||||||
stride=1,
|
stride=1,
|
||||||
padding=1)
|
padding=1)
|
||||||
else:
|
else:
|
||||||
self.nin_shortcut = torch.nn.Conv2d(in_channels,
|
self.nin_shortcut = comfy.ops.Conv2d(in_channels,
|
||||||
out_channels,
|
out_channels,
|
||||||
kernel_size=1,
|
kernel_size=1,
|
||||||
stride=1,
|
stride=1,
|
||||||
@ -188,22 +189,22 @@ class AttnBlock(nn.Module):
|
|||||||
self.in_channels = in_channels
|
self.in_channels = in_channels
|
||||||
|
|
||||||
self.norm = Normalize(in_channels)
|
self.norm = Normalize(in_channels)
|
||||||
self.q = torch.nn.Conv2d(in_channels,
|
self.q = comfy.ops.Conv2d(in_channels,
|
||||||
in_channels,
|
in_channels,
|
||||||
kernel_size=1,
|
kernel_size=1,
|
||||||
stride=1,
|
stride=1,
|
||||||
padding=0)
|
padding=0)
|
||||||
self.k = torch.nn.Conv2d(in_channels,
|
self.k = comfy.ops.Conv2d(in_channels,
|
||||||
in_channels,
|
in_channels,
|
||||||
kernel_size=1,
|
kernel_size=1,
|
||||||
stride=1,
|
stride=1,
|
||||||
padding=0)
|
padding=0)
|
||||||
self.v = torch.nn.Conv2d(in_channels,
|
self.v = comfy.ops.Conv2d(in_channels,
|
||||||
in_channels,
|
in_channels,
|
||||||
kernel_size=1,
|
kernel_size=1,
|
||||||
stride=1,
|
stride=1,
|
||||||
padding=0)
|
padding=0)
|
||||||
self.proj_out = torch.nn.Conv2d(in_channels,
|
self.proj_out = comfy.ops.Conv2d(in_channels,
|
||||||
in_channels,
|
in_channels,
|
||||||
kernel_size=1,
|
kernel_size=1,
|
||||||
stride=1,
|
stride=1,
|
||||||
@ -243,22 +244,22 @@ class MemoryEfficientAttnBlock(nn.Module):
|
|||||||
self.in_channels = in_channels
|
self.in_channels = in_channels
|
||||||
|
|
||||||
self.norm = Normalize(in_channels)
|
self.norm = Normalize(in_channels)
|
||||||
self.q = torch.nn.Conv2d(in_channels,
|
self.q = comfy.ops.Conv2d(in_channels,
|
||||||
in_channels,
|
in_channels,
|
||||||
kernel_size=1,
|
kernel_size=1,
|
||||||
stride=1,
|
stride=1,
|
||||||
padding=0)
|
padding=0)
|
||||||
self.k = torch.nn.Conv2d(in_channels,
|
self.k = comfy.ops.Conv2d(in_channels,
|
||||||
in_channels,
|
in_channels,
|
||||||
kernel_size=1,
|
kernel_size=1,
|
||||||
stride=1,
|
stride=1,
|
||||||
padding=0)
|
padding=0)
|
||||||
self.v = torch.nn.Conv2d(in_channels,
|
self.v = comfy.ops.Conv2d(in_channels,
|
||||||
in_channels,
|
in_channels,
|
||||||
kernel_size=1,
|
kernel_size=1,
|
||||||
stride=1,
|
stride=1,
|
||||||
padding=0)
|
padding=0)
|
||||||
self.proj_out = torch.nn.Conv2d(in_channels,
|
self.proj_out = comfy.ops.Conv2d(in_channels,
|
||||||
in_channels,
|
in_channels,
|
||||||
kernel_size=1,
|
kernel_size=1,
|
||||||
stride=1,
|
stride=1,
|
||||||
@ -302,22 +303,22 @@ class MemoryEfficientAttnBlockPytorch(nn.Module):
|
|||||||
self.in_channels = in_channels
|
self.in_channels = in_channels
|
||||||
|
|
||||||
self.norm = Normalize(in_channels)
|
self.norm = Normalize(in_channels)
|
||||||
self.q = torch.nn.Conv2d(in_channels,
|
self.q = comfy.ops.Conv2d(in_channels,
|
||||||
in_channels,
|
in_channels,
|
||||||
kernel_size=1,
|
kernel_size=1,
|
||||||
stride=1,
|
stride=1,
|
||||||
padding=0)
|
padding=0)
|
||||||
self.k = torch.nn.Conv2d(in_channels,
|
self.k = comfy.ops.Conv2d(in_channels,
|
||||||
in_channels,
|
in_channels,
|
||||||
kernel_size=1,
|
kernel_size=1,
|
||||||
stride=1,
|
stride=1,
|
||||||
padding=0)
|
padding=0)
|
||||||
self.v = torch.nn.Conv2d(in_channels,
|
self.v = comfy.ops.Conv2d(in_channels,
|
||||||
in_channels,
|
in_channels,
|
||||||
kernel_size=1,
|
kernel_size=1,
|
||||||
stride=1,
|
stride=1,
|
||||||
padding=0)
|
padding=0)
|
||||||
self.proj_out = torch.nn.Conv2d(in_channels,
|
self.proj_out = comfy.ops.Conv2d(in_channels,
|
||||||
in_channels,
|
in_channels,
|
||||||
kernel_size=1,
|
kernel_size=1,
|
||||||
stride=1,
|
stride=1,
|
||||||
@ -399,14 +400,14 @@ class Model(nn.Module):
|
|||||||
# timestep embedding
|
# timestep embedding
|
||||||
self.temb = nn.Module()
|
self.temb = nn.Module()
|
||||||
self.temb.dense = nn.ModuleList([
|
self.temb.dense = nn.ModuleList([
|
||||||
torch.nn.Linear(self.ch,
|
comfy.ops.Linear(self.ch,
|
||||||
self.temb_ch),
|
self.temb_ch),
|
||||||
torch.nn.Linear(self.temb_ch,
|
comfy.ops.Linear(self.temb_ch,
|
||||||
self.temb_ch),
|
self.temb_ch),
|
||||||
])
|
])
|
||||||
|
|
||||||
# downsampling
|
# downsampling
|
||||||
self.conv_in = torch.nn.Conv2d(in_channels,
|
self.conv_in = comfy.ops.Conv2d(in_channels,
|
||||||
self.ch,
|
self.ch,
|
||||||
kernel_size=3,
|
kernel_size=3,
|
||||||
stride=1,
|
stride=1,
|
||||||
@ -475,7 +476,7 @@ class Model(nn.Module):
|
|||||||
|
|
||||||
# end
|
# end
|
||||||
self.norm_out = Normalize(block_in)
|
self.norm_out = Normalize(block_in)
|
||||||
self.conv_out = torch.nn.Conv2d(block_in,
|
self.conv_out = comfy.ops.Conv2d(block_in,
|
||||||
out_ch,
|
out_ch,
|
||||||
kernel_size=3,
|
kernel_size=3,
|
||||||
stride=1,
|
stride=1,
|
||||||
@ -548,7 +549,7 @@ class Encoder(nn.Module):
|
|||||||
self.in_channels = in_channels
|
self.in_channels = in_channels
|
||||||
|
|
||||||
# downsampling
|
# downsampling
|
||||||
self.conv_in = torch.nn.Conv2d(in_channels,
|
self.conv_in = comfy.ops.Conv2d(in_channels,
|
||||||
self.ch,
|
self.ch,
|
||||||
kernel_size=3,
|
kernel_size=3,
|
||||||
stride=1,
|
stride=1,
|
||||||
@ -593,7 +594,7 @@ class Encoder(nn.Module):
|
|||||||
|
|
||||||
# end
|
# end
|
||||||
self.norm_out = Normalize(block_in)
|
self.norm_out = Normalize(block_in)
|
||||||
self.conv_out = torch.nn.Conv2d(block_in,
|
self.conv_out = comfy.ops.Conv2d(block_in,
|
||||||
2*z_channels if double_z else z_channels,
|
2*z_channels if double_z else z_channels,
|
||||||
kernel_size=3,
|
kernel_size=3,
|
||||||
stride=1,
|
stride=1,
|
||||||
@ -653,7 +654,7 @@ class Decoder(nn.Module):
|
|||||||
self.z_shape, np.prod(self.z_shape)))
|
self.z_shape, np.prod(self.z_shape)))
|
||||||
|
|
||||||
# z to block_in
|
# z to block_in
|
||||||
self.conv_in = torch.nn.Conv2d(z_channels,
|
self.conv_in = comfy.ops.Conv2d(z_channels,
|
||||||
block_in,
|
block_in,
|
||||||
kernel_size=3,
|
kernel_size=3,
|
||||||
stride=1,
|
stride=1,
|
||||||
@ -695,7 +696,7 @@ class Decoder(nn.Module):
|
|||||||
|
|
||||||
# end
|
# end
|
||||||
self.norm_out = Normalize(block_in)
|
self.norm_out = Normalize(block_in)
|
||||||
self.conv_out = torch.nn.Conv2d(block_in,
|
self.conv_out = comfy.ops.Conv2d(block_in,
|
||||||
out_ch,
|
out_ch,
|
||||||
kernel_size=3,
|
kernel_size=3,
|
||||||
stride=1,
|
stride=1,
|
||||||
|
|||||||
@ -19,45 +19,6 @@ from ..attention import SpatialTransformer
|
|||||||
from comfy.ldm.util import exists
|
from comfy.ldm.util import exists
|
||||||
|
|
||||||
|
|
||||||
# dummy replace
|
|
||||||
def convert_module_to_f16(x):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def convert_module_to_f32(x):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
## go
|
|
||||||
class AttentionPool2d(nn.Module):
|
|
||||||
"""
|
|
||||||
Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
spacial_dim: int,
|
|
||||||
embed_dim: int,
|
|
||||||
num_heads_channels: int,
|
|
||||||
output_dim: int = None,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5)
|
|
||||||
self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
|
|
||||||
self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
|
|
||||||
self.num_heads = embed_dim // num_heads_channels
|
|
||||||
self.attention = QKVAttention(self.num_heads)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
b, c, *_spatial = x.shape
|
|
||||||
x = x.reshape(b, c, -1) # NC(HW)
|
|
||||||
x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
|
|
||||||
x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
|
|
||||||
x = self.qkv_proj(x)
|
|
||||||
x = self.attention(x)
|
|
||||||
x = self.c_proj(x)
|
|
||||||
return x[:, :, 0]
|
|
||||||
|
|
||||||
|
|
||||||
class TimestepBlock(nn.Module):
|
class TimestepBlock(nn.Module):
|
||||||
"""
|
"""
|
||||||
Any module where forward() takes timestep embeddings as a second argument.
|
Any module where forward() takes timestep embeddings as a second argument.
|
||||||
@ -111,14 +72,14 @@ class Upsample(nn.Module):
|
|||||||
upsampling occurs in the inner-two dimensions.
|
upsampling occurs in the inner-two dimensions.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None):
|
def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None, device=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.channels = channels
|
self.channels = channels
|
||||||
self.out_channels = out_channels or channels
|
self.out_channels = out_channels or channels
|
||||||
self.use_conv = use_conv
|
self.use_conv = use_conv
|
||||||
self.dims = dims
|
self.dims = dims
|
||||||
if use_conv:
|
if use_conv:
|
||||||
self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding, dtype=dtype)
|
self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding, dtype=dtype, device=device)
|
||||||
|
|
||||||
def forward(self, x, output_shape=None):
|
def forward(self, x, output_shape=None):
|
||||||
assert x.shape[1] == self.channels
|
assert x.shape[1] == self.channels
|
||||||
@ -138,19 +99,6 @@ class Upsample(nn.Module):
|
|||||||
x = self.conv(x)
|
x = self.conv(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
class TransposedUpsample(nn.Module):
|
|
||||||
'Learned 2x upsampling without padding'
|
|
||||||
def __init__(self, channels, out_channels=None, ks=5):
|
|
||||||
super().__init__()
|
|
||||||
self.channels = channels
|
|
||||||
self.out_channels = out_channels or channels
|
|
||||||
|
|
||||||
self.up = nn.ConvTranspose2d(self.channels,self.out_channels,kernel_size=ks,stride=2)
|
|
||||||
|
|
||||||
def forward(self,x):
|
|
||||||
return self.up(x)
|
|
||||||
|
|
||||||
|
|
||||||
class Downsample(nn.Module):
|
class Downsample(nn.Module):
|
||||||
"""
|
"""
|
||||||
A downsampling layer with an optional convolution.
|
A downsampling layer with an optional convolution.
|
||||||
@ -160,7 +108,7 @@ class Downsample(nn.Module):
|
|||||||
downsampling occurs in the inner-two dimensions.
|
downsampling occurs in the inner-two dimensions.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None):
|
def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None, device=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.channels = channels
|
self.channels = channels
|
||||||
self.out_channels = out_channels or channels
|
self.out_channels = out_channels or channels
|
||||||
@ -169,7 +117,7 @@ class Downsample(nn.Module):
|
|||||||
stride = 2 if dims != 3 else (1, 2, 2)
|
stride = 2 if dims != 3 else (1, 2, 2)
|
||||||
if use_conv:
|
if use_conv:
|
||||||
self.op = conv_nd(
|
self.op = conv_nd(
|
||||||
dims, self.channels, self.out_channels, 3, stride=stride, padding=padding, dtype=dtype
|
dims, self.channels, self.out_channels, 3, stride=stride, padding=padding, dtype=dtype, device=device
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
assert self.channels == self.out_channels
|
assert self.channels == self.out_channels
|
||||||
@ -208,7 +156,8 @@ class ResBlock(TimestepBlock):
|
|||||||
use_checkpoint=False,
|
use_checkpoint=False,
|
||||||
up=False,
|
up=False,
|
||||||
down=False,
|
down=False,
|
||||||
dtype=None
|
dtype=None,
|
||||||
|
device=None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.channels = channels
|
self.channels = channels
|
||||||
@ -220,19 +169,19 @@ class ResBlock(TimestepBlock):
|
|||||||
self.use_scale_shift_norm = use_scale_shift_norm
|
self.use_scale_shift_norm = use_scale_shift_norm
|
||||||
|
|
||||||
self.in_layers = nn.Sequential(
|
self.in_layers = nn.Sequential(
|
||||||
nn.GroupNorm(32, channels, dtype=dtype),
|
nn.GroupNorm(32, channels, dtype=dtype, device=device),
|
||||||
nn.SiLU(),
|
nn.SiLU(),
|
||||||
conv_nd(dims, channels, self.out_channels, 3, padding=1, dtype=dtype),
|
conv_nd(dims, channels, self.out_channels, 3, padding=1, dtype=dtype, device=device),
|
||||||
)
|
)
|
||||||
|
|
||||||
self.updown = up or down
|
self.updown = up or down
|
||||||
|
|
||||||
if up:
|
if up:
|
||||||
self.h_upd = Upsample(channels, False, dims, dtype=dtype)
|
self.h_upd = Upsample(channels, False, dims, dtype=dtype, device=device)
|
||||||
self.x_upd = Upsample(channels, False, dims, dtype=dtype)
|
self.x_upd = Upsample(channels, False, dims, dtype=dtype, device=device)
|
||||||
elif down:
|
elif down:
|
||||||
self.h_upd = Downsample(channels, False, dims, dtype=dtype)
|
self.h_upd = Downsample(channels, False, dims, dtype=dtype, device=device)
|
||||||
self.x_upd = Downsample(channels, False, dims, dtype=dtype)
|
self.x_upd = Downsample(channels, False, dims, dtype=dtype, device=device)
|
||||||
else:
|
else:
|
||||||
self.h_upd = self.x_upd = nn.Identity()
|
self.h_upd = self.x_upd = nn.Identity()
|
||||||
|
|
||||||
@ -240,15 +189,15 @@ class ResBlock(TimestepBlock):
|
|||||||
nn.SiLU(),
|
nn.SiLU(),
|
||||||
linear(
|
linear(
|
||||||
emb_channels,
|
emb_channels,
|
||||||
2 * self.out_channels if use_scale_shift_norm else self.out_channels, dtype=dtype
|
2 * self.out_channels if use_scale_shift_norm else self.out_channels, dtype=dtype, device=device
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
self.out_layers = nn.Sequential(
|
self.out_layers = nn.Sequential(
|
||||||
nn.GroupNorm(32, self.out_channels, dtype=dtype),
|
nn.GroupNorm(32, self.out_channels, dtype=dtype, device=device),
|
||||||
nn.SiLU(),
|
nn.SiLU(),
|
||||||
nn.Dropout(p=dropout),
|
nn.Dropout(p=dropout),
|
||||||
zero_module(
|
zero_module(
|
||||||
conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1, dtype=dtype)
|
conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1, dtype=dtype, device=device)
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -256,10 +205,10 @@ class ResBlock(TimestepBlock):
|
|||||||
self.skip_connection = nn.Identity()
|
self.skip_connection = nn.Identity()
|
||||||
elif use_conv:
|
elif use_conv:
|
||||||
self.skip_connection = conv_nd(
|
self.skip_connection = conv_nd(
|
||||||
dims, channels, self.out_channels, 3, padding=1, dtype=dtype
|
dims, channels, self.out_channels, 3, padding=1, dtype=dtype, device=device
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1, dtype=dtype)
|
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1, dtype=dtype, device=device)
|
||||||
|
|
||||||
def forward(self, x, emb):
|
def forward(self, x, emb):
|
||||||
"""
|
"""
|
||||||
@ -295,142 +244,6 @@ class ResBlock(TimestepBlock):
|
|||||||
h = self.out_layers(h)
|
h = self.out_layers(h)
|
||||||
return self.skip_connection(x) + h
|
return self.skip_connection(x) + h
|
||||||
|
|
||||||
|
|
||||||
class AttentionBlock(nn.Module):
|
|
||||||
"""
|
|
||||||
An attention block that allows spatial positions to attend to each other.
|
|
||||||
Originally ported from here, but adapted to the N-d case.
|
|
||||||
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
channels,
|
|
||||||
num_heads=1,
|
|
||||||
num_head_channels=-1,
|
|
||||||
use_checkpoint=False,
|
|
||||||
use_new_attention_order=False,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.channels = channels
|
|
||||||
if num_head_channels == -1:
|
|
||||||
self.num_heads = num_heads
|
|
||||||
else:
|
|
||||||
assert (
|
|
||||||
channels % num_head_channels == 0
|
|
||||||
), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
|
|
||||||
self.num_heads = channels // num_head_channels
|
|
||||||
self.use_checkpoint = use_checkpoint
|
|
||||||
self.norm = normalization(channels)
|
|
||||||
self.qkv = conv_nd(1, channels, channels * 3, 1)
|
|
||||||
if use_new_attention_order:
|
|
||||||
# split qkv before split heads
|
|
||||||
self.attention = QKVAttention(self.num_heads)
|
|
||||||
else:
|
|
||||||
# split heads before split qkv
|
|
||||||
self.attention = QKVAttentionLegacy(self.num_heads)
|
|
||||||
|
|
||||||
self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
|
|
||||||
#return pt_checkpoint(self._forward, x) # pytorch
|
|
||||||
|
|
||||||
def _forward(self, x):
|
|
||||||
b, c, *spatial = x.shape
|
|
||||||
x = x.reshape(b, c, -1)
|
|
||||||
qkv = self.qkv(self.norm(x))
|
|
||||||
h = self.attention(qkv)
|
|
||||||
h = self.proj_out(h)
|
|
||||||
return (x + h).reshape(b, c, *spatial)
|
|
||||||
|
|
||||||
|
|
||||||
def count_flops_attn(model, _x, y):
|
|
||||||
"""
|
|
||||||
A counter for the `thop` package to count the operations in an
|
|
||||||
attention operation.
|
|
||||||
Meant to be used like:
|
|
||||||
macs, params = thop.profile(
|
|
||||||
model,
|
|
||||||
inputs=(inputs, timestamps),
|
|
||||||
custom_ops={QKVAttention: QKVAttention.count_flops},
|
|
||||||
)
|
|
||||||
"""
|
|
||||||
b, c, *spatial = y[0].shape
|
|
||||||
num_spatial = int(np.prod(spatial))
|
|
||||||
# We perform two matmuls with the same number of ops.
|
|
||||||
# The first computes the weight matrix, the second computes
|
|
||||||
# the combination of the value vectors.
|
|
||||||
matmul_ops = 2 * b * (num_spatial ** 2) * c
|
|
||||||
model.total_ops += th.DoubleTensor([matmul_ops])
|
|
||||||
|
|
||||||
|
|
||||||
class QKVAttentionLegacy(nn.Module):
|
|
||||||
"""
|
|
||||||
A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, n_heads):
|
|
||||||
super().__init__()
|
|
||||||
self.n_heads = n_heads
|
|
||||||
|
|
||||||
def forward(self, qkv):
|
|
||||||
"""
|
|
||||||
Apply QKV attention.
|
|
||||||
:param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
|
|
||||||
:return: an [N x (H * C) x T] tensor after attention.
|
|
||||||
"""
|
|
||||||
bs, width, length = qkv.shape
|
|
||||||
assert width % (3 * self.n_heads) == 0
|
|
||||||
ch = width // (3 * self.n_heads)
|
|
||||||
q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
|
|
||||||
scale = 1 / math.sqrt(math.sqrt(ch))
|
|
||||||
weight = th.einsum(
|
|
||||||
"bct,bcs->bts", q * scale, k * scale
|
|
||||||
) # More stable with f16 than dividing afterwards
|
|
||||||
weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
|
|
||||||
a = th.einsum("bts,bcs->bct", weight, v)
|
|
||||||
return a.reshape(bs, -1, length)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def count_flops(model, _x, y):
|
|
||||||
return count_flops_attn(model, _x, y)
|
|
||||||
|
|
||||||
|
|
||||||
class QKVAttention(nn.Module):
|
|
||||||
"""
|
|
||||||
A module which performs QKV attention and splits in a different order.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, n_heads):
|
|
||||||
super().__init__()
|
|
||||||
self.n_heads = n_heads
|
|
||||||
|
|
||||||
def forward(self, qkv):
|
|
||||||
"""
|
|
||||||
Apply QKV attention.
|
|
||||||
:param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
|
|
||||||
:return: an [N x (H * C) x T] tensor after attention.
|
|
||||||
"""
|
|
||||||
bs, width, length = qkv.shape
|
|
||||||
assert width % (3 * self.n_heads) == 0
|
|
||||||
ch = width // (3 * self.n_heads)
|
|
||||||
q, k, v = qkv.chunk(3, dim=1)
|
|
||||||
scale = 1 / math.sqrt(math.sqrt(ch))
|
|
||||||
weight = th.einsum(
|
|
||||||
"bct,bcs->bts",
|
|
||||||
(q * scale).view(bs * self.n_heads, ch, length),
|
|
||||||
(k * scale).view(bs * self.n_heads, ch, length),
|
|
||||||
) # More stable with f16 than dividing afterwards
|
|
||||||
weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
|
|
||||||
a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
|
|
||||||
return a.reshape(bs, -1, length)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def count_flops(model, _x, y):
|
|
||||||
return count_flops_attn(model, _x, y)
|
|
||||||
|
|
||||||
|
|
||||||
class Timestep(nn.Module):
|
class Timestep(nn.Module):
|
||||||
def __init__(self, dim):
|
def __init__(self, dim):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -503,8 +316,10 @@ class UNetModel(nn.Module):
|
|||||||
use_linear_in_transformer=False,
|
use_linear_in_transformer=False,
|
||||||
adm_in_channels=None,
|
adm_in_channels=None,
|
||||||
transformer_depth_middle=None,
|
transformer_depth_middle=None,
|
||||||
|
device=None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
assert use_spatial_transformer == True, "use_spatial_transformer has to be true"
|
||||||
if use_spatial_transformer:
|
if use_spatial_transformer:
|
||||||
assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
|
assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
|
||||||
|
|
||||||
@ -564,9 +379,9 @@ class UNetModel(nn.Module):
|
|||||||
|
|
||||||
time_embed_dim = model_channels * 4
|
time_embed_dim = model_channels * 4
|
||||||
self.time_embed = nn.Sequential(
|
self.time_embed = nn.Sequential(
|
||||||
linear(model_channels, time_embed_dim, dtype=self.dtype),
|
linear(model_channels, time_embed_dim, dtype=self.dtype, device=device),
|
||||||
nn.SiLU(),
|
nn.SiLU(),
|
||||||
linear(time_embed_dim, time_embed_dim, dtype=self.dtype),
|
linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.num_classes is not None:
|
if self.num_classes is not None:
|
||||||
@ -579,9 +394,9 @@ class UNetModel(nn.Module):
|
|||||||
assert adm_in_channels is not None
|
assert adm_in_channels is not None
|
||||||
self.label_emb = nn.Sequential(
|
self.label_emb = nn.Sequential(
|
||||||
nn.Sequential(
|
nn.Sequential(
|
||||||
linear(adm_in_channels, time_embed_dim, dtype=self.dtype),
|
linear(adm_in_channels, time_embed_dim, dtype=self.dtype, device=device),
|
||||||
nn.SiLU(),
|
nn.SiLU(),
|
||||||
linear(time_embed_dim, time_embed_dim, dtype=self.dtype),
|
linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@ -590,7 +405,7 @@ class UNetModel(nn.Module):
|
|||||||
self.input_blocks = nn.ModuleList(
|
self.input_blocks = nn.ModuleList(
|
||||||
[
|
[
|
||||||
TimestepEmbedSequential(
|
TimestepEmbedSequential(
|
||||||
conv_nd(dims, in_channels, model_channels, 3, padding=1, dtype=self.dtype)
|
conv_nd(dims, in_channels, model_channels, 3, padding=1, dtype=self.dtype, device=device)
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
@ -609,7 +424,8 @@ class UNetModel(nn.Module):
|
|||||||
dims=dims,
|
dims=dims,
|
||||||
use_checkpoint=use_checkpoint,
|
use_checkpoint=use_checkpoint,
|
||||||
use_scale_shift_norm=use_scale_shift_norm,
|
use_scale_shift_norm=use_scale_shift_norm,
|
||||||
dtype=self.dtype
|
dtype=self.dtype,
|
||||||
|
device=device,
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
ch = mult * model_channels
|
ch = mult * model_channels
|
||||||
@ -628,17 +444,10 @@ class UNetModel(nn.Module):
|
|||||||
disabled_sa = False
|
disabled_sa = False
|
||||||
|
|
||||||
if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
|
if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
|
||||||
layers.append(
|
layers.append(SpatialTransformer(
|
||||||
AttentionBlock(
|
|
||||||
ch,
|
|
||||||
use_checkpoint=use_checkpoint,
|
|
||||||
num_heads=num_heads,
|
|
||||||
num_head_channels=dim_head,
|
|
||||||
use_new_attention_order=use_new_attention_order,
|
|
||||||
) if not use_spatial_transformer else SpatialTransformer(
|
|
||||||
ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim,
|
ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim,
|
||||||
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
|
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
|
||||||
use_checkpoint=use_checkpoint, dtype=self.dtype
|
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
self.input_blocks.append(TimestepEmbedSequential(*layers))
|
self.input_blocks.append(TimestepEmbedSequential(*layers))
|
||||||
@ -657,11 +466,12 @@ class UNetModel(nn.Module):
|
|||||||
use_checkpoint=use_checkpoint,
|
use_checkpoint=use_checkpoint,
|
||||||
use_scale_shift_norm=use_scale_shift_norm,
|
use_scale_shift_norm=use_scale_shift_norm,
|
||||||
down=True,
|
down=True,
|
||||||
dtype=self.dtype
|
dtype=self.dtype,
|
||||||
|
device=device,
|
||||||
)
|
)
|
||||||
if resblock_updown
|
if resblock_updown
|
||||||
else Downsample(
|
else Downsample(
|
||||||
ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype
|
ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype, device=device
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@ -686,18 +496,13 @@ class UNetModel(nn.Module):
|
|||||||
dims=dims,
|
dims=dims,
|
||||||
use_checkpoint=use_checkpoint,
|
use_checkpoint=use_checkpoint,
|
||||||
use_scale_shift_norm=use_scale_shift_norm,
|
use_scale_shift_norm=use_scale_shift_norm,
|
||||||
dtype=self.dtype
|
dtype=self.dtype,
|
||||||
|
device=device,
|
||||||
),
|
),
|
||||||
AttentionBlock(
|
SpatialTransformer( # always uses a self-attn
|
||||||
ch,
|
|
||||||
use_checkpoint=use_checkpoint,
|
|
||||||
num_heads=num_heads,
|
|
||||||
num_head_channels=dim_head,
|
|
||||||
use_new_attention_order=use_new_attention_order,
|
|
||||||
) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn
|
|
||||||
ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
|
ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
|
||||||
disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
|
disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
|
||||||
use_checkpoint=use_checkpoint, dtype=self.dtype
|
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device
|
||||||
),
|
),
|
||||||
ResBlock(
|
ResBlock(
|
||||||
ch,
|
ch,
|
||||||
@ -706,7 +511,8 @@ class UNetModel(nn.Module):
|
|||||||
dims=dims,
|
dims=dims,
|
||||||
use_checkpoint=use_checkpoint,
|
use_checkpoint=use_checkpoint,
|
||||||
use_scale_shift_norm=use_scale_shift_norm,
|
use_scale_shift_norm=use_scale_shift_norm,
|
||||||
dtype=self.dtype
|
dtype=self.dtype,
|
||||||
|
device=device,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
self._feature_size += ch
|
self._feature_size += ch
|
||||||
@ -724,7 +530,8 @@ class UNetModel(nn.Module):
|
|||||||
dims=dims,
|
dims=dims,
|
||||||
use_checkpoint=use_checkpoint,
|
use_checkpoint=use_checkpoint,
|
||||||
use_scale_shift_norm=use_scale_shift_norm,
|
use_scale_shift_norm=use_scale_shift_norm,
|
||||||
dtype=self.dtype
|
dtype=self.dtype,
|
||||||
|
device=device,
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
ch = model_channels * mult
|
ch = model_channels * mult
|
||||||
@ -744,16 +551,10 @@ class UNetModel(nn.Module):
|
|||||||
|
|
||||||
if not exists(num_attention_blocks) or i < num_attention_blocks[level]:
|
if not exists(num_attention_blocks) or i < num_attention_blocks[level]:
|
||||||
layers.append(
|
layers.append(
|
||||||
AttentionBlock(
|
SpatialTransformer(
|
||||||
ch,
|
|
||||||
use_checkpoint=use_checkpoint,
|
|
||||||
num_heads=num_heads_upsample,
|
|
||||||
num_head_channels=dim_head,
|
|
||||||
use_new_attention_order=use_new_attention_order,
|
|
||||||
) if not use_spatial_transformer else SpatialTransformer(
|
|
||||||
ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim,
|
ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim,
|
||||||
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
|
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
|
||||||
use_checkpoint=use_checkpoint, dtype=self.dtype
|
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
if level and i == self.num_res_blocks[level]:
|
if level and i == self.num_res_blocks[level]:
|
||||||
@ -768,43 +569,28 @@ class UNetModel(nn.Module):
|
|||||||
use_checkpoint=use_checkpoint,
|
use_checkpoint=use_checkpoint,
|
||||||
use_scale_shift_norm=use_scale_shift_norm,
|
use_scale_shift_norm=use_scale_shift_norm,
|
||||||
up=True,
|
up=True,
|
||||||
dtype=self.dtype
|
dtype=self.dtype,
|
||||||
|
device=device,
|
||||||
)
|
)
|
||||||
if resblock_updown
|
if resblock_updown
|
||||||
else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype)
|
else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype, device=device)
|
||||||
)
|
)
|
||||||
ds //= 2
|
ds //= 2
|
||||||
self.output_blocks.append(TimestepEmbedSequential(*layers))
|
self.output_blocks.append(TimestepEmbedSequential(*layers))
|
||||||
self._feature_size += ch
|
self._feature_size += ch
|
||||||
|
|
||||||
self.out = nn.Sequential(
|
self.out = nn.Sequential(
|
||||||
nn.GroupNorm(32, ch, dtype=self.dtype),
|
nn.GroupNorm(32, ch, dtype=self.dtype, device=device),
|
||||||
nn.SiLU(),
|
nn.SiLU(),
|
||||||
zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1, dtype=self.dtype)),
|
zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1, dtype=self.dtype, device=device)),
|
||||||
)
|
)
|
||||||
if self.predict_codebook_ids:
|
if self.predict_codebook_ids:
|
||||||
self.id_predictor = nn.Sequential(
|
self.id_predictor = nn.Sequential(
|
||||||
nn.GroupNorm(32, ch, dtype=self.dtype),
|
nn.GroupNorm(32, ch, dtype=self.dtype, device=device),
|
||||||
conv_nd(dims, model_channels, n_embed, 1),
|
conv_nd(dims, model_channels, n_embed, 1, dtype=self.dtype, device=device),
|
||||||
#nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
|
#nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
|
||||||
)
|
)
|
||||||
|
|
||||||
def convert_to_fp16(self):
|
|
||||||
"""
|
|
||||||
Convert the torso of the model to float16.
|
|
||||||
"""
|
|
||||||
self.input_blocks.apply(convert_module_to_f16)
|
|
||||||
self.middle_block.apply(convert_module_to_f16)
|
|
||||||
self.output_blocks.apply(convert_module_to_f16)
|
|
||||||
|
|
||||||
def convert_to_fp32(self):
|
|
||||||
"""
|
|
||||||
Convert the torso of the model to float32.
|
|
||||||
"""
|
|
||||||
self.input_blocks.apply(convert_module_to_f32)
|
|
||||||
self.middle_block.apply(convert_module_to_f32)
|
|
||||||
self.output_blocks.apply(convert_module_to_f32)
|
|
||||||
|
|
||||||
def forward(self, x, timesteps=None, context=None, y=None, control=None, transformer_options={}, **kwargs):
|
def forward(self, x, timesteps=None, context=None, y=None, control=None, transformer_options={}, **kwargs):
|
||||||
"""
|
"""
|
||||||
Apply the model to an input batch.
|
Apply the model to an input batch.
|
||||||
|
|||||||
@ -4,27 +4,27 @@ from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugme
|
|||||||
from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule
|
from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule
|
||||||
from comfy.ldm.modules.diffusionmodules.openaimodel import Timestep
|
from comfy.ldm.modules.diffusionmodules.openaimodel import Timestep
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from enum import Enum
|
||||||
from . import utils
|
from . import utils
|
||||||
|
|
||||||
|
class ModelType(Enum):
|
||||||
|
EPS = 1
|
||||||
|
V_PREDICTION = 2
|
||||||
|
|
||||||
class BaseModel(torch.nn.Module):
|
class BaseModel(torch.nn.Module):
|
||||||
def __init__(self, model_config, v_prediction=False):
|
def __init__(self, model_config, model_type=ModelType.EPS, device=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
unet_config = model_config.unet_config
|
unet_config = model_config.unet_config
|
||||||
self.latent_format = model_config.latent_format
|
self.latent_format = model_config.latent_format
|
||||||
self.model_config = model_config
|
self.model_config = model_config
|
||||||
self.register_schedule(given_betas=None, beta_schedule="linear", timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3)
|
self.register_schedule(given_betas=None, beta_schedule="linear", timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3)
|
||||||
self.diffusion_model = UNetModel(**unet_config)
|
self.diffusion_model = UNetModel(**unet_config, device=device)
|
||||||
self.v_prediction = v_prediction
|
self.model_type = model_type
|
||||||
if self.v_prediction:
|
|
||||||
self.parameterization = "v"
|
|
||||||
else:
|
|
||||||
self.parameterization = "eps"
|
|
||||||
|
|
||||||
self.adm_channels = unet_config.get("adm_in_channels", None)
|
self.adm_channels = unet_config.get("adm_in_channels", None)
|
||||||
if self.adm_channels is None:
|
if self.adm_channels is None:
|
||||||
self.adm_channels = 0
|
self.adm_channels = 0
|
||||||
print("v_prediction", v_prediction)
|
print("model_type", model_type.name)
|
||||||
print("adm", self.adm_channels)
|
print("adm", self.adm_channels)
|
||||||
|
|
||||||
def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
|
def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
|
||||||
@ -99,53 +99,58 @@ class BaseModel(torch.nn.Module):
|
|||||||
if self.get_dtype() == torch.float16:
|
if self.get_dtype() == torch.float16:
|
||||||
clip_state_dict = utils.convert_sd_to(clip_state_dict, torch.float16)
|
clip_state_dict = utils.convert_sd_to(clip_state_dict, torch.float16)
|
||||||
vae_state_dict = utils.convert_sd_to(vae_state_dict, torch.float16)
|
vae_state_dict = utils.convert_sd_to(vae_state_dict, torch.float16)
|
||||||
|
|
||||||
|
if self.model_type == ModelType.V_PREDICTION:
|
||||||
|
unet_state_dict["v_pred"] = torch.tensor([])
|
||||||
|
|
||||||
return {**unet_state_dict, **vae_state_dict, **clip_state_dict}
|
return {**unet_state_dict, **vae_state_dict, **clip_state_dict}
|
||||||
|
|
||||||
|
def unclip_adm(unclip_conditioning, device, noise_augmentor, noise_augment_merge=0.0):
|
||||||
|
adm_inputs = []
|
||||||
|
weights = []
|
||||||
|
noise_aug = []
|
||||||
|
for unclip_cond in unclip_conditioning:
|
||||||
|
for adm_cond in unclip_cond["clip_vision_output"].image_embeds:
|
||||||
|
weight = unclip_cond["strength"]
|
||||||
|
noise_augment = unclip_cond["noise_augmentation"]
|
||||||
|
noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment)
|
||||||
|
c_adm, noise_level_emb = noise_augmentor(adm_cond.to(device), noise_level=torch.tensor([noise_level], device=device))
|
||||||
|
adm_out = torch.cat((c_adm, noise_level_emb), 1) * weight
|
||||||
|
weights.append(weight)
|
||||||
|
noise_aug.append(noise_augment)
|
||||||
|
adm_inputs.append(adm_out)
|
||||||
|
|
||||||
|
if len(noise_aug) > 1:
|
||||||
|
adm_out = torch.stack(adm_inputs).sum(0)
|
||||||
|
noise_augment = noise_augment_merge
|
||||||
|
noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment)
|
||||||
|
c_adm, noise_level_emb = noise_augmentor(adm_out[:, :noise_augmentor.time_embed.dim], noise_level=torch.tensor([noise_level], device=device))
|
||||||
|
adm_out = torch.cat((c_adm, noise_level_emb), 1)
|
||||||
|
|
||||||
|
return adm_out
|
||||||
|
|
||||||
class SD21UNCLIP(BaseModel):
|
class SD21UNCLIP(BaseModel):
|
||||||
def __init__(self, model_config, noise_aug_config, v_prediction=True):
|
def __init__(self, model_config, noise_aug_config, model_type=ModelType.V_PREDICTION, device=None):
|
||||||
super().__init__(model_config, v_prediction)
|
super().__init__(model_config, model_type, device=device)
|
||||||
self.noise_augmentor = CLIPEmbeddingNoiseAugmentation(**noise_aug_config)
|
self.noise_augmentor = CLIPEmbeddingNoiseAugmentation(**noise_aug_config)
|
||||||
|
|
||||||
def encode_adm(self, **kwargs):
|
def encode_adm(self, **kwargs):
|
||||||
unclip_conditioning = kwargs.get("unclip_conditioning", None)
|
unclip_conditioning = kwargs.get("unclip_conditioning", None)
|
||||||
device = kwargs["device"]
|
device = kwargs["device"]
|
||||||
|
if unclip_conditioning is None:
|
||||||
if unclip_conditioning is not None:
|
return torch.zeros((1, self.adm_channels))
|
||||||
adm_inputs = []
|
|
||||||
weights = []
|
|
||||||
noise_aug = []
|
|
||||||
for unclip_cond in unclip_conditioning:
|
|
||||||
adm_cond = unclip_cond["clip_vision_output"].image_embeds
|
|
||||||
weight = unclip_cond["strength"]
|
|
||||||
noise_augment = unclip_cond["noise_augmentation"]
|
|
||||||
noise_level = round((self.noise_augmentor.max_noise_level - 1) * noise_augment)
|
|
||||||
c_adm, noise_level_emb = self.noise_augmentor(adm_cond.to(device), noise_level=torch.tensor([noise_level], device=device))
|
|
||||||
adm_out = torch.cat((c_adm, noise_level_emb), 1) * weight
|
|
||||||
weights.append(weight)
|
|
||||||
noise_aug.append(noise_augment)
|
|
||||||
adm_inputs.append(adm_out)
|
|
||||||
|
|
||||||
if len(noise_aug) > 1:
|
|
||||||
adm_out = torch.stack(adm_inputs).sum(0)
|
|
||||||
#TODO: add a way to control this
|
|
||||||
noise_augment = 0.05
|
|
||||||
noise_level = round((self.noise_augmentor.max_noise_level - 1) * noise_augment)
|
|
||||||
c_adm, noise_level_emb = self.noise_augmentor(adm_out[:, :self.noise_augmentor.time_embed.dim], noise_level=torch.tensor([noise_level], device=device))
|
|
||||||
adm_out = torch.cat((c_adm, noise_level_emb), 1)
|
|
||||||
else:
|
else:
|
||||||
adm_out = torch.zeros((1, self.adm_channels))
|
return unclip_adm(unclip_conditioning, device, self.noise_augmentor, kwargs.get("unclip_noise_augment_merge", 0.05))
|
||||||
|
|
||||||
return adm_out
|
|
||||||
|
|
||||||
class SDInpaint(BaseModel):
|
class SDInpaint(BaseModel):
|
||||||
def __init__(self, model_config, v_prediction=False):
|
def __init__(self, model_config, model_type=ModelType.EPS, device=None):
|
||||||
super().__init__(model_config, v_prediction)
|
super().__init__(model_config, model_type, device=device)
|
||||||
self.concat_keys = ("mask", "masked_image")
|
self.concat_keys = ("mask", "masked_image")
|
||||||
|
|
||||||
class SDXLRefiner(BaseModel):
|
class SDXLRefiner(BaseModel):
|
||||||
def __init__(self, model_config, v_prediction=False):
|
def __init__(self, model_config, model_type=ModelType.EPS, device=None):
|
||||||
super().__init__(model_config, v_prediction)
|
super().__init__(model_config, model_type, device=device)
|
||||||
self.embedder = Timestep(256)
|
self.embedder = Timestep(256)
|
||||||
|
|
||||||
def encode_adm(self, **kwargs):
|
def encode_adm(self, **kwargs):
|
||||||
@ -160,7 +165,6 @@ class SDXLRefiner(BaseModel):
|
|||||||
else:
|
else:
|
||||||
aesthetic_score = kwargs.get("aesthetic_score", 6)
|
aesthetic_score = kwargs.get("aesthetic_score", 6)
|
||||||
|
|
||||||
print(clip_pooled.shape, width, height, crop_w, crop_h, aesthetic_score)
|
|
||||||
out = []
|
out = []
|
||||||
out.append(self.embedder(torch.Tensor([height])))
|
out.append(self.embedder(torch.Tensor([height])))
|
||||||
out.append(self.embedder(torch.Tensor([width])))
|
out.append(self.embedder(torch.Tensor([width])))
|
||||||
@ -171,8 +175,8 @@ class SDXLRefiner(BaseModel):
|
|||||||
return torch.cat((clip_pooled.to(flat.device), flat), dim=1)
|
return torch.cat((clip_pooled.to(flat.device), flat), dim=1)
|
||||||
|
|
||||||
class SDXL(BaseModel):
|
class SDXL(BaseModel):
|
||||||
def __init__(self, model_config, v_prediction=False):
|
def __init__(self, model_config, model_type=ModelType.EPS, device=None):
|
||||||
super().__init__(model_config, v_prediction)
|
super().__init__(model_config, model_type, device=device)
|
||||||
self.embedder = Timestep(256)
|
self.embedder = Timestep(256)
|
||||||
|
|
||||||
def encode_adm(self, **kwargs):
|
def encode_adm(self, **kwargs):
|
||||||
@ -184,7 +188,6 @@ class SDXL(BaseModel):
|
|||||||
target_width = kwargs.get("target_width", width)
|
target_width = kwargs.get("target_width", width)
|
||||||
target_height = kwargs.get("target_height", height)
|
target_height = kwargs.get("target_height", height)
|
||||||
|
|
||||||
print(clip_pooled.shape, width, height, crop_w, crop_h, target_width, target_height)
|
|
||||||
out = []
|
out = []
|
||||||
out.append(self.embedder(torch.Tensor([height])))
|
out.append(self.embedder(torch.Tensor([height])))
|
||||||
out.append(self.embedder(torch.Tensor([width])))
|
out.append(self.embedder(torch.Tensor([width])))
|
||||||
|
|||||||
@ -113,8 +113,63 @@ def model_config_from_unet_config(unet_config):
|
|||||||
if model_config.matches(unet_config):
|
if model_config.matches(unet_config):
|
||||||
return model_config(unet_config)
|
return model_config(unet_config)
|
||||||
|
|
||||||
|
print("no match", unet_config)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def model_config_from_unet(state_dict, unet_key_prefix, use_fp16):
|
def model_config_from_unet(state_dict, unet_key_prefix, use_fp16):
|
||||||
unet_config = detect_unet_config(state_dict, unet_key_prefix, use_fp16)
|
unet_config = detect_unet_config(state_dict, unet_key_prefix, use_fp16)
|
||||||
return model_config_from_unet_config(unet_config)
|
return model_config_from_unet_config(unet_config)
|
||||||
|
|
||||||
|
|
||||||
|
def model_config_from_diffusers_unet(state_dict, use_fp16):
|
||||||
|
match = {}
|
||||||
|
match["context_dim"] = state_dict["down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_k.weight"].shape[1]
|
||||||
|
match["model_channels"] = state_dict["conv_in.weight"].shape[0]
|
||||||
|
match["in_channels"] = state_dict["conv_in.weight"].shape[1]
|
||||||
|
match["adm_in_channels"] = None
|
||||||
|
if "class_embedding.linear_1.weight" in state_dict:
|
||||||
|
match["adm_in_channels"] = state_dict["class_embedding.linear_1.weight"].shape[1]
|
||||||
|
elif "add_embedding.linear_1.weight" in state_dict:
|
||||||
|
match["adm_in_channels"] = state_dict["add_embedding.linear_1.weight"].shape[1]
|
||||||
|
|
||||||
|
SDXL = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||||
|
'num_classes': 'sequential', 'adm_in_channels': 2816, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320,
|
||||||
|
'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 2, 10], 'channel_mult': [1, 2, 4],
|
||||||
|
'transformer_depth_middle': 10, 'use_linear_in_transformer': True, 'context_dim': 2048}
|
||||||
|
|
||||||
|
SDXL_refiner = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||||
|
'num_classes': 'sequential', 'adm_in_channels': 2560, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 384,
|
||||||
|
'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 4, 4, 0], 'channel_mult': [1, 2, 4, 4],
|
||||||
|
'transformer_depth_middle': 4, 'use_linear_in_transformer': True, 'context_dim': 1280}
|
||||||
|
|
||||||
|
SD21 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||||
|
'adm_in_channels': None, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2,
|
||||||
|
'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
|
||||||
|
'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024}
|
||||||
|
|
||||||
|
SD21_uncliph = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||||
|
'num_classes': 'sequential', 'adm_in_channels': 2048, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320,
|
||||||
|
'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
|
||||||
|
'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024}
|
||||||
|
|
||||||
|
SD21_unclipl = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||||
|
'num_classes': 'sequential', 'adm_in_channels': 1536, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320,
|
||||||
|
'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
|
||||||
|
'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024}
|
||||||
|
|
||||||
|
SD15 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||||
|
'adm_in_channels': None, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2,
|
||||||
|
'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
|
||||||
|
'transformer_depth_middle': 1, 'use_linear_in_transformer': False, 'context_dim': 768}
|
||||||
|
|
||||||
|
supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl]
|
||||||
|
|
||||||
|
for unet_config in supported_models:
|
||||||
|
matches = True
|
||||||
|
for k in match:
|
||||||
|
if match[k] != unet_config[k]:
|
||||||
|
matches = False
|
||||||
|
break
|
||||||
|
if matches:
|
||||||
|
return model_config_from_unet_config(unet_config)
|
||||||
|
return None
|
||||||
|
|||||||
@ -49,6 +49,7 @@ except:
|
|||||||
try:
|
try:
|
||||||
if torch.backends.mps.is_available():
|
if torch.backends.mps.is_available():
|
||||||
cpu_state = CPUState.MPS
|
cpu_state = CPUState.MPS
|
||||||
|
import torch.mps
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -204,7 +205,11 @@ print(f"Set vram state to: {vram_state.name}")
|
|||||||
def get_torch_device_name(device):
|
def get_torch_device_name(device):
|
||||||
if hasattr(device, 'type'):
|
if hasattr(device, 'type'):
|
||||||
if device.type == "cuda":
|
if device.type == "cuda":
|
||||||
return "{} {}".format(device, torch.cuda.get_device_name(device))
|
try:
|
||||||
|
allocator_backend = torch.cuda.get_allocator_backend()
|
||||||
|
except:
|
||||||
|
allocator_backend = ""
|
||||||
|
return "{} {} : {}".format(device, torch.cuda.get_device_name(device), allocator_backend)
|
||||||
else:
|
else:
|
||||||
return "{}".format(device.type)
|
return "{}".format(device.type)
|
||||||
else:
|
else:
|
||||||
@ -233,10 +238,9 @@ def unload_model():
|
|||||||
accelerate.hooks.remove_hook_from_submodules(current_loaded_model.model)
|
accelerate.hooks.remove_hook_from_submodules(current_loaded_model.model)
|
||||||
model_accelerated = False
|
model_accelerated = False
|
||||||
|
|
||||||
|
current_loaded_model.unpatch_model()
|
||||||
current_loaded_model.model.to(current_loaded_model.offload_device)
|
current_loaded_model.model.to(current_loaded_model.offload_device)
|
||||||
current_loaded_model.model_patches_to(current_loaded_model.offload_device)
|
current_loaded_model.model_patches_to(current_loaded_model.offload_device)
|
||||||
current_loaded_model.unpatch_model()
|
|
||||||
current_loaded_model = None
|
current_loaded_model = None
|
||||||
if vram_state != VRAMState.HIGH_VRAM:
|
if vram_state != VRAMState.HIGH_VRAM:
|
||||||
soft_empty_cache()
|
soft_empty_cache()
|
||||||
@ -258,15 +262,11 @@ def load_model_gpu(model):
|
|||||||
if model is current_loaded_model:
|
if model is current_loaded_model:
|
||||||
return
|
return
|
||||||
unload_model()
|
unload_model()
|
||||||
try:
|
|
||||||
real_model = model.patch_model()
|
|
||||||
except Exception as e:
|
|
||||||
model.unpatch_model()
|
|
||||||
raise e
|
|
||||||
|
|
||||||
torch_dev = model.load_device
|
torch_dev = model.load_device
|
||||||
model.model_patches_to(torch_dev)
|
model.model_patches_to(torch_dev)
|
||||||
model.model_patches_to(model.model_dtype())
|
model.model_patches_to(model.model_dtype())
|
||||||
|
current_loaded_model = model
|
||||||
|
|
||||||
if is_device_cpu(torch_dev):
|
if is_device_cpu(torch_dev):
|
||||||
vram_set_state = VRAMState.DISABLED
|
vram_set_state = VRAMState.DISABLED
|
||||||
@ -280,21 +280,33 @@ def load_model_gpu(model):
|
|||||||
if model_size > (current_free_mem - minimum_inference_memory()): #only switch to lowvram if really necessary
|
if model_size > (current_free_mem - minimum_inference_memory()): #only switch to lowvram if really necessary
|
||||||
vram_set_state = VRAMState.LOW_VRAM
|
vram_set_state = VRAMState.LOW_VRAM
|
||||||
|
|
||||||
current_loaded_model = model
|
real_model = model.model
|
||||||
|
patch_model_to = None
|
||||||
if vram_set_state == VRAMState.DISABLED:
|
if vram_set_state == VRAMState.DISABLED:
|
||||||
pass
|
pass
|
||||||
elif vram_set_state == VRAMState.NORMAL_VRAM or vram_set_state == VRAMState.HIGH_VRAM or vram_set_state == VRAMState.SHARED:
|
elif vram_set_state == VRAMState.NORMAL_VRAM or vram_set_state == VRAMState.HIGH_VRAM or vram_set_state == VRAMState.SHARED:
|
||||||
model_accelerated = False
|
model_accelerated = False
|
||||||
real_model.to(torch_dev)
|
patch_model_to = torch_dev
|
||||||
else:
|
|
||||||
if vram_set_state == VRAMState.NO_VRAM:
|
|
||||||
device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "256MiB", "cpu": "16GiB"})
|
|
||||||
elif vram_set_state == VRAMState.LOW_VRAM:
|
|
||||||
device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "{}MiB".format(lowvram_model_memory // (1024 * 1024)), "cpu": "16GiB"})
|
|
||||||
|
|
||||||
|
try:
|
||||||
|
real_model = model.patch_model(device_to=patch_model_to)
|
||||||
|
except Exception as e:
|
||||||
|
model.unpatch_model()
|
||||||
|
unload_model()
|
||||||
|
raise e
|
||||||
|
|
||||||
|
if patch_model_to is not None:
|
||||||
|
real_model.to(torch_dev)
|
||||||
|
|
||||||
|
if vram_set_state == VRAMState.NO_VRAM:
|
||||||
|
device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "256MiB", "cpu": "16GiB"})
|
||||||
accelerate.dispatch_model(real_model, device_map=device_map, main_device=torch_dev)
|
accelerate.dispatch_model(real_model, device_map=device_map, main_device=torch_dev)
|
||||||
model_accelerated = True
|
model_accelerated = True
|
||||||
|
elif vram_set_state == VRAMState.LOW_VRAM:
|
||||||
|
device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "{}MiB".format(lowvram_model_memory // (1024 * 1024)), "cpu": "16GiB"})
|
||||||
|
accelerate.dispatch_model(real_model, device_map=device_map, main_device=torch_dev)
|
||||||
|
model_accelerated = True
|
||||||
|
|
||||||
return current_loaded_model
|
return current_loaded_model
|
||||||
|
|
||||||
def load_controlnet_gpu(control_models):
|
def load_controlnet_gpu(control_models):
|
||||||
@ -352,6 +364,7 @@ def text_encoder_device():
|
|||||||
if args.gpu_only:
|
if args.gpu_only:
|
||||||
return get_torch_device()
|
return get_torch_device()
|
||||||
elif vram_state == VRAMState.HIGH_VRAM or vram_state == VRAMState.NORMAL_VRAM:
|
elif vram_state == VRAMState.HIGH_VRAM or vram_state == VRAMState.NORMAL_VRAM:
|
||||||
|
#NOTE: on a Ryzen 5 7600X with 4080 it's faster to shift to GPU
|
||||||
if torch.get_num_threads() < 8: #leaving the text encoder on the CPU is faster than shifting it if the CPU is fast enough.
|
if torch.get_num_threads() < 8: #leaving the text encoder on the CPU is faster than shifting it if the CPU is fast enough.
|
||||||
return get_torch_device()
|
return get_torch_device()
|
||||||
else:
|
else:
|
||||||
@ -522,7 +535,7 @@ def should_use_fp16(device=None, model_params=0):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
#FP16 is just broken on these cards
|
#FP16 is just broken on these cards
|
||||||
nvidia_16_series = ["1660", "1650", "1630", "T500", "T550", "T600"]
|
nvidia_16_series = ["1660", "1650", "1630", "T500", "T550", "T600", "MX550", "MX450", "CMP 30HX"]
|
||||||
for x in nvidia_16_series:
|
for x in nvidia_16_series:
|
||||||
if x in props.name:
|
if x in props.name:
|
||||||
return False
|
return False
|
||||||
|
|||||||
@ -6,6 +6,7 @@ from comfy import model_management
|
|||||||
from .ldm.models.diffusion.ddim import DDIMSampler
|
from .ldm.models.diffusion.ddim import DDIMSampler
|
||||||
from .ldm.modules.diffusionmodules.util import make_ddim_timesteps
|
from .ldm.modules.diffusionmodules.util import make_ddim_timesteps
|
||||||
import math
|
import math
|
||||||
|
from comfy import model_base
|
||||||
|
|
||||||
def lcm(a, b): #TODO: eventually replace by math.lcm (added in python3.9)
|
def lcm(a, b): #TODO: eventually replace by math.lcm (added in python3.9)
|
||||||
return abs(a*b) // math.gcd(a, b)
|
return abs(a*b) // math.gcd(a, b)
|
||||||
@ -16,6 +17,14 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
|
|||||||
def get_area_and_mult(cond, x_in, cond_concat_in, timestep_in):
|
def get_area_and_mult(cond, x_in, cond_concat_in, timestep_in):
|
||||||
area = (x_in.shape[2], x_in.shape[3], 0, 0)
|
area = (x_in.shape[2], x_in.shape[3], 0, 0)
|
||||||
strength = 1.0
|
strength = 1.0
|
||||||
|
if 'timestep_start' in cond[1]:
|
||||||
|
timestep_start = cond[1]['timestep_start']
|
||||||
|
if timestep_in[0] > timestep_start:
|
||||||
|
return None
|
||||||
|
if 'timestep_end' in cond[1]:
|
||||||
|
timestep_end = cond[1]['timestep_end']
|
||||||
|
if timestep_in[0] < timestep_end:
|
||||||
|
return None
|
||||||
if 'area' in cond[1]:
|
if 'area' in cond[1]:
|
||||||
area = cond[1]['area']
|
area = cond[1]['area']
|
||||||
if 'strength' in cond[1]:
|
if 'strength' in cond[1]:
|
||||||
@ -180,12 +189,13 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
to_run += [(p, COND)]
|
to_run += [(p, COND)]
|
||||||
for x in uncond:
|
if uncond is not None:
|
||||||
p = get_area_and_mult(x, x_in, cond_concat_in, timestep)
|
for x in uncond:
|
||||||
if p is None:
|
p = get_area_and_mult(x, x_in, cond_concat_in, timestep)
|
||||||
continue
|
if p is None:
|
||||||
|
continue
|
||||||
|
|
||||||
to_run += [(p, UNCOND)]
|
to_run += [(p, UNCOND)]
|
||||||
|
|
||||||
while len(to_run) > 0:
|
while len(to_run) > 0:
|
||||||
first = to_run[0]
|
first = to_run[0]
|
||||||
@ -247,7 +257,10 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
|
|||||||
|
|
||||||
c['transformer_options'] = transformer_options
|
c['transformer_options'] = transformer_options
|
||||||
|
|
||||||
output = model_function(input_x, timestep_, **c).chunk(batch_chunks)
|
if 'model_function_wrapper' in model_options:
|
||||||
|
output = model_options['model_function_wrapper'](model_function, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks)
|
||||||
|
else:
|
||||||
|
output = model_function(input_x, timestep_, **c).chunk(batch_chunks)
|
||||||
del input_x
|
del input_x
|
||||||
|
|
||||||
model_management.throw_exception_if_processing_interrupted()
|
model_management.throw_exception_if_processing_interrupted()
|
||||||
@ -270,6 +283,9 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
|
|||||||
|
|
||||||
|
|
||||||
max_total_area = model_management.maximum_batch_area()
|
max_total_area = model_management.maximum_batch_area()
|
||||||
|
if math.isclose(cond_scale, 1.0):
|
||||||
|
uncond = None
|
||||||
|
|
||||||
cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, timestep, max_total_area, cond_concat, model_options)
|
cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, timestep, max_total_area, cond_concat, model_options)
|
||||||
if "sampler_cfg_function" in model_options:
|
if "sampler_cfg_function" in model_options:
|
||||||
args = {"cond": cond, "uncond": uncond, "cond_scale": cond_scale, "timestep": timestep}
|
args = {"cond": cond, "uncond": uncond, "cond_scale": cond_scale, "timestep": timestep}
|
||||||
@ -331,6 +347,17 @@ def ddim_scheduler(model, steps):
|
|||||||
sigs += [0.0]
|
sigs += [0.0]
|
||||||
return torch.FloatTensor(sigs)
|
return torch.FloatTensor(sigs)
|
||||||
|
|
||||||
|
def sgm_scheduler(model, steps):
|
||||||
|
sigs = []
|
||||||
|
timesteps = torch.linspace(model.inner_model.inner_model.num_timesteps - 1, 0, steps + 1)[:-1].type(torch.int)
|
||||||
|
for x in range(len(timesteps)):
|
||||||
|
ts = timesteps[x]
|
||||||
|
if ts > 999:
|
||||||
|
ts = 999
|
||||||
|
sigs.append(model.t_to_sigma(torch.tensor(ts)))
|
||||||
|
sigs += [0.0]
|
||||||
|
return torch.FloatTensor(sigs)
|
||||||
|
|
||||||
def blank_inpaint_image_like(latent_image):
|
def blank_inpaint_image_like(latent_image):
|
||||||
blank_image = torch.ones_like(latent_image)
|
blank_image = torch.ones_like(latent_image)
|
||||||
# these are the values for "zero" in pixel space translated to latent space
|
# these are the values for "zero" in pixel space translated to latent space
|
||||||
@ -424,6 +451,35 @@ def create_cond_with_same_area_if_none(conds, c):
|
|||||||
n = c[1].copy()
|
n = c[1].copy()
|
||||||
conds += [[smallest[0], n]]
|
conds += [[smallest[0], n]]
|
||||||
|
|
||||||
|
def calculate_start_end_timesteps(model, conds):
|
||||||
|
for t in range(len(conds)):
|
||||||
|
x = conds[t]
|
||||||
|
|
||||||
|
timestep_start = None
|
||||||
|
timestep_end = None
|
||||||
|
if 'start_percent' in x[1]:
|
||||||
|
timestep_start = model.sigma_to_t(model.t_to_sigma(torch.tensor(x[1]['start_percent'] * 999.0)))
|
||||||
|
if 'end_percent' in x[1]:
|
||||||
|
timestep_end = model.sigma_to_t(model.t_to_sigma(torch.tensor(x[1]['end_percent'] * 999.0)))
|
||||||
|
|
||||||
|
if (timestep_start is not None) or (timestep_end is not None):
|
||||||
|
n = x[1].copy()
|
||||||
|
if (timestep_start is not None):
|
||||||
|
n['timestep_start'] = timestep_start
|
||||||
|
if (timestep_end is not None):
|
||||||
|
n['timestep_end'] = timestep_end
|
||||||
|
conds[t] = [x[0], n]
|
||||||
|
|
||||||
|
def pre_run_control(model, conds):
|
||||||
|
for t in range(len(conds)):
|
||||||
|
x = conds[t]
|
||||||
|
|
||||||
|
timestep_start = None
|
||||||
|
timestep_end = None
|
||||||
|
percent_to_timestep_function = lambda a: model.sigma_to_t(model.t_to_sigma(torch.tensor(a) * 999.0))
|
||||||
|
if 'control' in x[1]:
|
||||||
|
x[1]['control'].pre_run(model.inner_model, percent_to_timestep_function)
|
||||||
|
|
||||||
def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func):
|
def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func):
|
||||||
cond_cnets = []
|
cond_cnets = []
|
||||||
cond_other = []
|
cond_other = []
|
||||||
@ -480,19 +536,19 @@ def encode_adm(model, conds, batch_size, width, height, device, prompt_type):
|
|||||||
|
|
||||||
|
|
||||||
class KSampler:
|
class KSampler:
|
||||||
SCHEDULERS = ["normal", "karras", "exponential", "simple", "ddim_uniform"]
|
SCHEDULERS = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform"]
|
||||||
SAMPLERS = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral",
|
SAMPLERS = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral",
|
||||||
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu",
|
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu",
|
||||||
"dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "ddim", "uni_pc", "uni_pc_bh2"]
|
"dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddim", "uni_pc", "uni_pc_bh2"]
|
||||||
|
|
||||||
def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None, model_options={}):
|
def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None, model_options={}):
|
||||||
self.model = model
|
self.model = model
|
||||||
self.model_denoise = CFGNoisePredictor(self.model)
|
self.model_denoise = CFGNoisePredictor(self.model)
|
||||||
if self.model.parameterization == "v":
|
if self.model.model_type == model_base.ModelType.V_PREDICTION:
|
||||||
self.model_wrap = CompVisVDenoiser(self.model_denoise, quantize=True)
|
self.model_wrap = CompVisVDenoiser(self.model_denoise, quantize=True)
|
||||||
else:
|
else:
|
||||||
self.model_wrap = k_diffusion_external.CompVisDenoiser(self.model_denoise, quantize=True)
|
self.model_wrap = k_diffusion_external.CompVisDenoiser(self.model_denoise, quantize=True)
|
||||||
self.model_wrap.parameterization = self.model.parameterization
|
|
||||||
self.model_k = KSamplerX0Inpaint(self.model_wrap)
|
self.model_k = KSamplerX0Inpaint(self.model_wrap)
|
||||||
self.device = device
|
self.device = device
|
||||||
if scheduler not in self.SCHEDULERS:
|
if scheduler not in self.SCHEDULERS:
|
||||||
@ -525,6 +581,8 @@ class KSampler:
|
|||||||
sigmas = simple_scheduler(self.model_wrap, steps)
|
sigmas = simple_scheduler(self.model_wrap, steps)
|
||||||
elif self.scheduler == "ddim_uniform":
|
elif self.scheduler == "ddim_uniform":
|
||||||
sigmas = ddim_scheduler(self.model_wrap, steps)
|
sigmas = ddim_scheduler(self.model_wrap, steps)
|
||||||
|
elif self.scheduler == "sgm_uniform":
|
||||||
|
sigmas = sgm_scheduler(self.model_wrap, steps)
|
||||||
else:
|
else:
|
||||||
print("error invalid scheduler", self.scheduler)
|
print("error invalid scheduler", self.scheduler)
|
||||||
|
|
||||||
@ -567,13 +625,18 @@ class KSampler:
|
|||||||
resolve_cond_masks(positive, noise.shape[2], noise.shape[3], self.device)
|
resolve_cond_masks(positive, noise.shape[2], noise.shape[3], self.device)
|
||||||
resolve_cond_masks(negative, noise.shape[2], noise.shape[3], self.device)
|
resolve_cond_masks(negative, noise.shape[2], noise.shape[3], self.device)
|
||||||
|
|
||||||
|
calculate_start_end_timesteps(self.model_wrap, negative)
|
||||||
|
calculate_start_end_timesteps(self.model_wrap, positive)
|
||||||
|
|
||||||
#make sure each cond area has an opposite one with the same area
|
#make sure each cond area has an opposite one with the same area
|
||||||
for c in positive:
|
for c in positive:
|
||||||
create_cond_with_same_area_if_none(negative, c)
|
create_cond_with_same_area_if_none(negative, c)
|
||||||
for c in negative:
|
for c in negative:
|
||||||
create_cond_with_same_area_if_none(positive, c)
|
create_cond_with_same_area_if_none(positive, c)
|
||||||
|
|
||||||
apply_empty_x_to_equal_area(positive, negative, 'control', lambda cond_cnets, x: cond_cnets[x])
|
pre_run_control(self.model_wrap, negative + positive)
|
||||||
|
|
||||||
|
apply_empty_x_to_equal_area(list(filter(lambda c: c[1].get('control_apply_to_uncond', False) == True, positive)), negative, 'control', lambda cond_cnets, x: cond_cnets[x])
|
||||||
apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x])
|
apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x])
|
||||||
|
|
||||||
if self.model.is_adm():
|
if self.model.is_adm():
|
||||||
@ -614,7 +677,7 @@ class KSampler:
|
|||||||
elif self.sampler == "ddim":
|
elif self.sampler == "ddim":
|
||||||
timesteps = []
|
timesteps = []
|
||||||
for s in range(sigmas.shape[0]):
|
for s in range(sigmas.shape[0]):
|
||||||
timesteps.insert(0, self.model_wrap.sigma_to_t(sigmas[s]))
|
timesteps.insert(0, self.model_wrap.sigma_to_discrete_timestep(sigmas[s]))
|
||||||
noise_mask = None
|
noise_mask = None
|
||||||
if denoise_mask is not None:
|
if denoise_mask is not None:
|
||||||
noise_mask = 1.0 - denoise_mask
|
noise_mask = 1.0 - denoise_mask
|
||||||
@ -638,7 +701,7 @@ class KSampler:
|
|||||||
x_T=z_enc,
|
x_T=z_enc,
|
||||||
x0=latent_image,
|
x0=latent_image,
|
||||||
img_callback=ddim_callback,
|
img_callback=ddim_callback,
|
||||||
denoise_function=sampling_function,
|
denoise_function=self.model_wrap.predict_eps_discrete_timestep,
|
||||||
extra_args=extra_args,
|
extra_args=extra_args,
|
||||||
mask=noise_mask,
|
mask=noise_mask,
|
||||||
to_zero=sigmas[-1]==0,
|
to_zero=sigmas[-1]==0,
|
||||||
|
|||||||
372
comfy/sd.py
372
comfy/sd.py
@ -70,13 +70,27 @@ def load_lora(lora, to_load):
|
|||||||
alpha = lora[alpha_name].item()
|
alpha = lora[alpha_name].item()
|
||||||
loaded_keys.add(alpha_name)
|
loaded_keys.add(alpha_name)
|
||||||
|
|
||||||
A_name = "{}.lora_up.weight".format(x)
|
regular_lora = "{}.lora_up.weight".format(x)
|
||||||
B_name = "{}.lora_down.weight".format(x)
|
diffusers_lora = "{}_lora.up.weight".format(x)
|
||||||
mid_name = "{}.lora_mid.weight".format(x)
|
transformers_lora = "{}.lora_linear_layer.up.weight".format(x)
|
||||||
|
A_name = None
|
||||||
|
|
||||||
if A_name in lora.keys():
|
if regular_lora in lora.keys():
|
||||||
|
A_name = regular_lora
|
||||||
|
B_name = "{}.lora_down.weight".format(x)
|
||||||
|
mid_name = "{}.lora_mid.weight".format(x)
|
||||||
|
elif diffusers_lora in lora.keys():
|
||||||
|
A_name = diffusers_lora
|
||||||
|
B_name = "{}_lora.down.weight".format(x)
|
||||||
|
mid_name = None
|
||||||
|
elif transformers_lora in lora.keys():
|
||||||
|
A_name = transformers_lora
|
||||||
|
B_name ="{}.lora_linear_layer.down.weight".format(x)
|
||||||
|
mid_name = None
|
||||||
|
|
||||||
|
if A_name is not None:
|
||||||
mid = None
|
mid = None
|
||||||
if mid_name in lora.keys():
|
if mid_name is not None and mid_name in lora.keys():
|
||||||
mid = lora[mid_name]
|
mid = lora[mid_name]
|
||||||
loaded_keys.add(mid_name)
|
loaded_keys.add(mid_name)
|
||||||
patch_dict[to_load[x]] = (lora[A_name], lora[B_name], alpha, mid)
|
patch_dict[to_load[x]] = (lora[A_name], lora[B_name], alpha, mid)
|
||||||
@ -170,20 +184,31 @@ def model_lora_keys_clip(model, key_map={}):
|
|||||||
if k in sdk:
|
if k in sdk:
|
||||||
lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c])
|
lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c])
|
||||||
key_map[lora_key] = k
|
key_map[lora_key] = k
|
||||||
|
lora_key = "lora_te1_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c])
|
||||||
|
key_map[lora_key] = k
|
||||||
|
lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format(b, c) #diffusers lora
|
||||||
|
key_map[lora_key] = k
|
||||||
|
|
||||||
k = "clip_l.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
|
k = "clip_l.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
|
||||||
if k in sdk:
|
if k in sdk:
|
||||||
lora_key = "lora_te1_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #SDXL base
|
lora_key = "lora_te1_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #SDXL base
|
||||||
key_map[lora_key] = k
|
key_map[lora_key] = k
|
||||||
clip_l_present = True
|
clip_l_present = True
|
||||||
|
lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format(b, c) #diffusers lora
|
||||||
|
key_map[lora_key] = k
|
||||||
|
|
||||||
k = "clip_g.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
|
k = "clip_g.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
|
||||||
if k in sdk:
|
if k in sdk:
|
||||||
if clip_l_present:
|
if clip_l_present:
|
||||||
lora_key = "lora_te2_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #SDXL base
|
lora_key = "lora_te2_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #SDXL base
|
||||||
|
key_map[lora_key] = k
|
||||||
|
lora_key = "text_encoder_2.text_model.encoder.layers.{}.{}".format(b, c) #diffusers lora
|
||||||
|
key_map[lora_key] = k
|
||||||
else:
|
else:
|
||||||
lora_key = "lora_te_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #TODO: test if this is correct for SDXL-Refiner
|
lora_key = "lora_te_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #TODO: test if this is correct for SDXL-Refiner
|
||||||
key_map[lora_key] = k
|
key_map[lora_key] = k
|
||||||
|
lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format(b, c) #diffusers lora
|
||||||
|
key_map[lora_key] = k
|
||||||
|
|
||||||
return key_map
|
return key_map
|
||||||
|
|
||||||
@ -198,10 +223,26 @@ def model_lora_keys_unet(model, key_map={}):
|
|||||||
diffusers_keys = utils.unet_to_diffusers(model.model_config.unet_config)
|
diffusers_keys = utils.unet_to_diffusers(model.model_config.unet_config)
|
||||||
for k in diffusers_keys:
|
for k in diffusers_keys:
|
||||||
if k.endswith(".weight"):
|
if k.endswith(".weight"):
|
||||||
|
unet_key = "diffusion_model.{}".format(diffusers_keys[k])
|
||||||
key_lora = k[:-len(".weight")].replace(".", "_")
|
key_lora = k[:-len(".weight")].replace(".", "_")
|
||||||
key_map["lora_unet_{}".format(key_lora)] = "diffusion_model.{}".format(diffusers_keys[k])
|
key_map["lora_unet_{}".format(key_lora)] = unet_key
|
||||||
|
|
||||||
|
diffusers_lora_prefix = ["", "unet."]
|
||||||
|
for p in diffusers_lora_prefix:
|
||||||
|
diffusers_lora_key = "{}{}".format(p, k[:-len(".weight")].replace(".to_", ".processor.to_"))
|
||||||
|
if diffusers_lora_key.endswith(".to_out.0"):
|
||||||
|
diffusers_lora_key = diffusers_lora_key[:-2]
|
||||||
|
key_map[diffusers_lora_key] = unet_key
|
||||||
return key_map
|
return key_map
|
||||||
|
|
||||||
|
def set_attr(obj, attr, value):
|
||||||
|
attrs = attr.split(".")
|
||||||
|
for name in attrs[:-1]:
|
||||||
|
obj = getattr(obj, name)
|
||||||
|
prev = getattr(obj, attrs[-1])
|
||||||
|
setattr(obj, attrs[-1], torch.nn.Parameter(value))
|
||||||
|
del prev
|
||||||
|
|
||||||
class ModelPatcher:
|
class ModelPatcher:
|
||||||
def __init__(self, model, load_device, offload_device, size=0):
|
def __init__(self, model, load_device, offload_device, size=0):
|
||||||
self.size = size
|
self.size = size
|
||||||
@ -330,7 +371,7 @@ class ModelPatcher:
|
|||||||
sd.pop(k)
|
sd.pop(k)
|
||||||
return sd
|
return sd
|
||||||
|
|
||||||
def patch_model(self):
|
def patch_model(self, device_to=None):
|
||||||
model_sd = self.model_state_dict()
|
model_sd = self.model_state_dict()
|
||||||
for key in self.patches:
|
for key in self.patches:
|
||||||
if key not in model_sd:
|
if key not in model_sd:
|
||||||
@ -340,10 +381,14 @@ class ModelPatcher:
|
|||||||
weight = model_sd[key]
|
weight = model_sd[key]
|
||||||
|
|
||||||
if key not in self.backup:
|
if key not in self.backup:
|
||||||
self.backup[key] = weight.clone()
|
self.backup[key] = weight.to(self.offload_device)
|
||||||
|
|
||||||
temp_weight = weight.to(torch.float32, copy=True)
|
if device_to is not None:
|
||||||
weight[:] = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
|
temp_weight = weight.float().to(device_to, copy=True)
|
||||||
|
else:
|
||||||
|
temp_weight = weight.to(torch.float32, copy=True)
|
||||||
|
out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
|
||||||
|
set_attr(self.model, key, out_weight)
|
||||||
del temp_weight
|
del temp_weight
|
||||||
return self.model
|
return self.model
|
||||||
|
|
||||||
@ -367,15 +412,19 @@ class ModelPatcher:
|
|||||||
else:
|
else:
|
||||||
weight += alpha * w1.type(weight.dtype).to(weight.device)
|
weight += alpha * w1.type(weight.dtype).to(weight.device)
|
||||||
elif len(v) == 4: #lora/locon
|
elif len(v) == 4: #lora/locon
|
||||||
mat1 = v[0]
|
mat1 = v[0].float().to(weight.device)
|
||||||
mat2 = v[1]
|
mat2 = v[1].float().to(weight.device)
|
||||||
if v[2] is not None:
|
if v[2] is not None:
|
||||||
alpha *= v[2] / mat2.shape[0]
|
alpha *= v[2] / mat2.shape[0]
|
||||||
if v[3] is not None:
|
if v[3] is not None:
|
||||||
#locon mid weights, hopefully the math is fine because I didn't properly test it
|
#locon mid weights, hopefully the math is fine because I didn't properly test it
|
||||||
final_shape = [mat2.shape[1], mat2.shape[0], v[3].shape[2], v[3].shape[3]]
|
mat3 = v[3].float().to(weight.device)
|
||||||
mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1).float(), v[3].transpose(0, 1).flatten(start_dim=1).float()).reshape(final_shape).transpose(0, 1)
|
final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
|
||||||
weight += (alpha * torch.mm(mat1.flatten(start_dim=1).float(), mat2.flatten(start_dim=1).float())).reshape(weight.shape).type(weight.dtype).to(weight.device)
|
mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
|
||||||
|
try:
|
||||||
|
weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(weight.shape).type(weight.dtype)
|
||||||
|
except Exception as e:
|
||||||
|
print("ERROR", key, e)
|
||||||
elif len(v) == 8: #lokr
|
elif len(v) == 8: #lokr
|
||||||
w1 = v[0]
|
w1 = v[0]
|
||||||
w2 = v[1]
|
w2 = v[1]
|
||||||
@ -389,20 +438,27 @@ class ModelPatcher:
|
|||||||
if w1 is None:
|
if w1 is None:
|
||||||
dim = w1_b.shape[0]
|
dim = w1_b.shape[0]
|
||||||
w1 = torch.mm(w1_a.float(), w1_b.float())
|
w1 = torch.mm(w1_a.float(), w1_b.float())
|
||||||
|
else:
|
||||||
|
w1 = w1.float().to(weight.device)
|
||||||
|
|
||||||
if w2 is None:
|
if w2 is None:
|
||||||
dim = w2_b.shape[0]
|
dim = w2_b.shape[0]
|
||||||
if t2 is None:
|
if t2 is None:
|
||||||
w2 = torch.mm(w2_a.float(), w2_b.float())
|
w2 = torch.mm(w2_a.float().to(weight.device), w2_b.float().to(weight.device))
|
||||||
else:
|
else:
|
||||||
w2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float(), w2_b.float(), w2_a.float())
|
w2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float().to(weight.device), w2_b.float().to(weight.device), w2_a.float().to(weight.device))
|
||||||
|
else:
|
||||||
|
w2 = w2.float().to(weight.device)
|
||||||
|
|
||||||
if len(w2.shape) == 4:
|
if len(w2.shape) == 4:
|
||||||
w1 = w1.unsqueeze(2).unsqueeze(2)
|
w1 = w1.unsqueeze(2).unsqueeze(2)
|
||||||
if v[2] is not None and dim is not None:
|
if v[2] is not None and dim is not None:
|
||||||
alpha *= v[2] / dim
|
alpha *= v[2] / dim
|
||||||
|
|
||||||
weight += alpha * torch.kron(w1.float(), w2.float()).reshape(weight.shape).type(weight.dtype).to(weight.device)
|
try:
|
||||||
|
weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype)
|
||||||
|
except Exception as e:
|
||||||
|
print("ERROR", key, e)
|
||||||
else: #loha
|
else: #loha
|
||||||
w1a = v[0]
|
w1a = v[0]
|
||||||
w1b = v[1]
|
w1b = v[1]
|
||||||
@ -413,21 +469,24 @@ class ModelPatcher:
|
|||||||
if v[5] is not None: #cp decomposition
|
if v[5] is not None: #cp decomposition
|
||||||
t1 = v[5]
|
t1 = v[5]
|
||||||
t2 = v[6]
|
t2 = v[6]
|
||||||
m1 = torch.einsum('i j k l, j r, i p -> p r k l', t1.float(), w1b.float(), w1a.float())
|
m1 = torch.einsum('i j k l, j r, i p -> p r k l', t1.float().to(weight.device), w1b.float().to(weight.device), w1a.float().to(weight.device))
|
||||||
m2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float(), w2b.float(), w2a.float())
|
m2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float().to(weight.device), w2b.float().to(weight.device), w2a.float().to(weight.device))
|
||||||
else:
|
else:
|
||||||
m1 = torch.mm(w1a.float(), w1b.float())
|
m1 = torch.mm(w1a.float().to(weight.device), w1b.float().to(weight.device))
|
||||||
m2 = torch.mm(w2a.float(), w2b.float())
|
m2 = torch.mm(w2a.float().to(weight.device), w2b.float().to(weight.device))
|
||||||
|
|
||||||
|
try:
|
||||||
|
weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype)
|
||||||
|
except Exception as e:
|
||||||
|
print("ERROR", key, e)
|
||||||
|
|
||||||
weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype).to(weight.device)
|
|
||||||
return weight
|
return weight
|
||||||
|
|
||||||
def unpatch_model(self):
|
def unpatch_model(self):
|
||||||
model_sd = self.model_state_dict()
|
|
||||||
keys = list(self.backup.keys())
|
keys = list(self.backup.keys())
|
||||||
|
|
||||||
for k in keys:
|
for k in keys:
|
||||||
model_sd[k][:] = self.backup[k]
|
set_attr(self.model, k, self.backup[k])
|
||||||
del self.backup[k]
|
|
||||||
|
|
||||||
self.backup = {}
|
self.backup = {}
|
||||||
|
|
||||||
@ -647,16 +706,57 @@ def broadcast_image_to(tensor, target_batch_size, batched_number):
|
|||||||
else:
|
else:
|
||||||
return torch.cat([tensor] * batched_number, dim=0)
|
return torch.cat([tensor] * batched_number, dim=0)
|
||||||
|
|
||||||
class ControlNet:
|
class ControlBase:
|
||||||
def __init__(self, control_model, global_average_pooling=False, device=None):
|
def __init__(self, device=None):
|
||||||
self.control_model = control_model
|
|
||||||
self.cond_hint_original = None
|
self.cond_hint_original = None
|
||||||
self.cond_hint = None
|
self.cond_hint = None
|
||||||
self.strength = 1.0
|
self.strength = 1.0
|
||||||
|
self.timestep_percent_range = (1.0, 0.0)
|
||||||
|
self.timestep_range = None
|
||||||
|
|
||||||
if device is None:
|
if device is None:
|
||||||
device = model_management.get_torch_device()
|
device = model_management.get_torch_device()
|
||||||
self.device = device
|
self.device = device
|
||||||
self.previous_controlnet = None
|
self.previous_controlnet = None
|
||||||
|
|
||||||
|
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(1.0, 0.0)):
|
||||||
|
self.cond_hint_original = cond_hint
|
||||||
|
self.strength = strength
|
||||||
|
self.timestep_percent_range = timestep_percent_range
|
||||||
|
return self
|
||||||
|
|
||||||
|
def pre_run(self, model, percent_to_timestep_function):
|
||||||
|
self.timestep_range = (percent_to_timestep_function(self.timestep_percent_range[0]), percent_to_timestep_function(self.timestep_percent_range[1]))
|
||||||
|
if self.previous_controlnet is not None:
|
||||||
|
self.previous_controlnet.pre_run(model, percent_to_timestep_function)
|
||||||
|
|
||||||
|
def set_previous_controlnet(self, controlnet):
|
||||||
|
self.previous_controlnet = controlnet
|
||||||
|
return self
|
||||||
|
|
||||||
|
def cleanup(self):
|
||||||
|
if self.previous_controlnet is not None:
|
||||||
|
self.previous_controlnet.cleanup()
|
||||||
|
if self.cond_hint is not None:
|
||||||
|
del self.cond_hint
|
||||||
|
self.cond_hint = None
|
||||||
|
self.timestep_range = None
|
||||||
|
|
||||||
|
def get_models(self):
|
||||||
|
out = []
|
||||||
|
if self.previous_controlnet is not None:
|
||||||
|
out += self.previous_controlnet.get_models()
|
||||||
|
return out
|
||||||
|
|
||||||
|
def copy_to(self, c):
|
||||||
|
c.cond_hint_original = self.cond_hint_original
|
||||||
|
c.strength = self.strength
|
||||||
|
c.timestep_percent_range = self.timestep_percent_range
|
||||||
|
|
||||||
|
class ControlNet(ControlBase):
|
||||||
|
def __init__(self, control_model, global_average_pooling=False, device=None):
|
||||||
|
super().__init__(device)
|
||||||
|
self.control_model = control_model
|
||||||
self.global_average_pooling = global_average_pooling
|
self.global_average_pooling = global_average_pooling
|
||||||
|
|
||||||
def get_control(self, x_noisy, t, cond, batched_number):
|
def get_control(self, x_noisy, t, cond, batched_number):
|
||||||
@ -664,6 +764,13 @@ class ControlNet:
|
|||||||
if self.previous_controlnet is not None:
|
if self.previous_controlnet is not None:
|
||||||
control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)
|
control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)
|
||||||
|
|
||||||
|
if self.timestep_range is not None:
|
||||||
|
if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
|
||||||
|
if control_prev is not None:
|
||||||
|
return control_prev
|
||||||
|
else:
|
||||||
|
return {}
|
||||||
|
|
||||||
output_dtype = x_noisy.dtype
|
output_dtype = x_noisy.dtype
|
||||||
if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]:
|
if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]:
|
||||||
if self.cond_hint is not None:
|
if self.cond_hint is not None:
|
||||||
@ -711,37 +818,64 @@ class ControlNet:
|
|||||||
out['input'] = control_prev['input']
|
out['input'] = control_prev['input']
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def set_cond_hint(self, cond_hint, strength=1.0):
|
|
||||||
self.cond_hint_original = cond_hint
|
|
||||||
self.strength = strength
|
|
||||||
return self
|
|
||||||
|
|
||||||
def set_previous_controlnet(self, controlnet):
|
|
||||||
self.previous_controlnet = controlnet
|
|
||||||
return self
|
|
||||||
|
|
||||||
def cleanup(self):
|
|
||||||
if self.previous_controlnet is not None:
|
|
||||||
self.previous_controlnet.cleanup()
|
|
||||||
if self.cond_hint is not None:
|
|
||||||
del self.cond_hint
|
|
||||||
self.cond_hint = None
|
|
||||||
|
|
||||||
def copy(self):
|
def copy(self):
|
||||||
c = ControlNet(self.control_model, global_average_pooling=self.global_average_pooling)
|
c = ControlNet(self.control_model, global_average_pooling=self.global_average_pooling)
|
||||||
c.cond_hint_original = self.cond_hint_original
|
self.copy_to(c)
|
||||||
c.strength = self.strength
|
|
||||||
return c
|
return c
|
||||||
|
|
||||||
def get_models(self):
|
def get_models(self):
|
||||||
out = []
|
out = super().get_models()
|
||||||
if self.previous_controlnet is not None:
|
|
||||||
out += self.previous_controlnet.get_models()
|
|
||||||
out.append(self.control_model)
|
out.append(self.control_model)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
def load_controlnet(ckpt_path, model=None):
|
def load_controlnet(ckpt_path, model=None):
|
||||||
controlnet_data = utils.load_torch_file(ckpt_path, safe_load=True)
|
controlnet_data = utils.load_torch_file(ckpt_path, safe_load=True)
|
||||||
|
|
||||||
|
controlnet_config = None
|
||||||
|
if "controlnet_cond_embedding.conv_in.weight" in controlnet_data: #diffusers format
|
||||||
|
use_fp16 = model_management.should_use_fp16()
|
||||||
|
controlnet_config = model_detection.model_config_from_diffusers_unet(controlnet_data, use_fp16).unet_config
|
||||||
|
diffusers_keys = utils.unet_to_diffusers(controlnet_config)
|
||||||
|
diffusers_keys["controlnet_mid_block.weight"] = "middle_block_out.0.weight"
|
||||||
|
diffusers_keys["controlnet_mid_block.bias"] = "middle_block_out.0.bias"
|
||||||
|
|
||||||
|
count = 0
|
||||||
|
loop = True
|
||||||
|
while loop:
|
||||||
|
suffix = [".weight", ".bias"]
|
||||||
|
for s in suffix:
|
||||||
|
k_in = "controlnet_down_blocks.{}{}".format(count, s)
|
||||||
|
k_out = "zero_convs.{}.0{}".format(count, s)
|
||||||
|
if k_in not in controlnet_data:
|
||||||
|
loop = False
|
||||||
|
break
|
||||||
|
diffusers_keys[k_in] = k_out
|
||||||
|
count += 1
|
||||||
|
|
||||||
|
count = 0
|
||||||
|
loop = True
|
||||||
|
while loop:
|
||||||
|
suffix = [".weight", ".bias"]
|
||||||
|
for s in suffix:
|
||||||
|
if count == 0:
|
||||||
|
k_in = "controlnet_cond_embedding.conv_in{}".format(s)
|
||||||
|
else:
|
||||||
|
k_in = "controlnet_cond_embedding.blocks.{}{}".format(count - 1, s)
|
||||||
|
k_out = "input_hint_block.{}{}".format(count * 2, s)
|
||||||
|
if k_in not in controlnet_data:
|
||||||
|
k_in = "controlnet_cond_embedding.conv_out{}".format(s)
|
||||||
|
loop = False
|
||||||
|
diffusers_keys[k_in] = k_out
|
||||||
|
count += 1
|
||||||
|
|
||||||
|
new_sd = {}
|
||||||
|
for k in diffusers_keys:
|
||||||
|
if k in controlnet_data:
|
||||||
|
new_sd[diffusers_keys[k]] = controlnet_data.pop(k)
|
||||||
|
|
||||||
|
controlnet_data = new_sd
|
||||||
|
|
||||||
pth_key = 'control_model.zero_convs.0.0.weight'
|
pth_key = 'control_model.zero_convs.0.0.weight'
|
||||||
pth = False
|
pth = False
|
||||||
key = 'zero_convs.0.0.weight'
|
key = 'zero_convs.0.0.weight'
|
||||||
@ -757,11 +891,11 @@ def load_controlnet(ckpt_path, model=None):
|
|||||||
print("error checkpoint does not contain controlnet or t2i adapter data", ckpt_path)
|
print("error checkpoint does not contain controlnet or t2i adapter data", ckpt_path)
|
||||||
return net
|
return net
|
||||||
|
|
||||||
use_fp16 = model_management.should_use_fp16()
|
if controlnet_config is None:
|
||||||
|
use_fp16 = model_management.should_use_fp16()
|
||||||
controlnet_config = model_detection.model_config_from_unet(controlnet_data, prefix, use_fp16).unet_config
|
controlnet_config = model_detection.model_config_from_unet(controlnet_data, prefix, use_fp16).unet_config
|
||||||
controlnet_config.pop("out_channels")
|
controlnet_config.pop("out_channels")
|
||||||
controlnet_config["hint_channels"] = 3
|
controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
|
||||||
control_model = cldm.ControlNet(**controlnet_config)
|
control_model = cldm.ControlNet(**controlnet_config)
|
||||||
|
|
||||||
if pth:
|
if pth:
|
||||||
@ -799,24 +933,25 @@ def load_controlnet(ckpt_path, model=None):
|
|||||||
control = ControlNet(control_model, global_average_pooling=global_average_pooling)
|
control = ControlNet(control_model, global_average_pooling=global_average_pooling)
|
||||||
return control
|
return control
|
||||||
|
|
||||||
class T2IAdapter:
|
class T2IAdapter(ControlBase):
|
||||||
def __init__(self, t2i_model, channels_in, device=None):
|
def __init__(self, t2i_model, channels_in, device=None):
|
||||||
|
super().__init__(device)
|
||||||
self.t2i_model = t2i_model
|
self.t2i_model = t2i_model
|
||||||
self.channels_in = channels_in
|
self.channels_in = channels_in
|
||||||
self.strength = 1.0
|
|
||||||
if device is None:
|
|
||||||
device = model_management.get_torch_device()
|
|
||||||
self.device = device
|
|
||||||
self.previous_controlnet = None
|
|
||||||
self.control_input = None
|
self.control_input = None
|
||||||
self.cond_hint_original = None
|
|
||||||
self.cond_hint = None
|
|
||||||
|
|
||||||
def get_control(self, x_noisy, t, cond, batched_number):
|
def get_control(self, x_noisy, t, cond, batched_number):
|
||||||
control_prev = None
|
control_prev = None
|
||||||
if self.previous_controlnet is not None:
|
if self.previous_controlnet is not None:
|
||||||
control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)
|
control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)
|
||||||
|
|
||||||
|
if self.timestep_range is not None:
|
||||||
|
if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
|
||||||
|
if control_prev is not None:
|
||||||
|
return control_prev
|
||||||
|
else:
|
||||||
|
return {}
|
||||||
|
|
||||||
if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]:
|
if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]:
|
||||||
if self.cond_hint is not None:
|
if self.cond_hint is not None:
|
||||||
del self.cond_hint
|
del self.cond_hint
|
||||||
@ -861,33 +996,11 @@ class T2IAdapter:
|
|||||||
out['output'] = control_prev['output']
|
out['output'] = control_prev['output']
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def set_cond_hint(self, cond_hint, strength=1.0):
|
|
||||||
self.cond_hint_original = cond_hint
|
|
||||||
self.strength = strength
|
|
||||||
return self
|
|
||||||
|
|
||||||
def set_previous_controlnet(self, controlnet):
|
|
||||||
self.previous_controlnet = controlnet
|
|
||||||
return self
|
|
||||||
|
|
||||||
def copy(self):
|
def copy(self):
|
||||||
c = T2IAdapter(self.t2i_model, self.channels_in)
|
c = T2IAdapter(self.t2i_model, self.channels_in)
|
||||||
c.cond_hint_original = self.cond_hint_original
|
self.copy_to(c)
|
||||||
c.strength = self.strength
|
|
||||||
return c
|
return c
|
||||||
|
|
||||||
def cleanup(self):
|
|
||||||
if self.previous_controlnet is not None:
|
|
||||||
self.previous_controlnet.cleanup()
|
|
||||||
if self.cond_hint is not None:
|
|
||||||
del self.cond_hint
|
|
||||||
self.cond_hint = None
|
|
||||||
|
|
||||||
def get_models(self):
|
|
||||||
out = []
|
|
||||||
if self.previous_controlnet is not None:
|
|
||||||
out += self.previous_controlnet.get_models()
|
|
||||||
return out
|
|
||||||
|
|
||||||
def load_t2i_adapter(t2i_data):
|
def load_t2i_adapter(t2i_data):
|
||||||
keys = t2i_data.keys()
|
keys = t2i_data.keys()
|
||||||
@ -997,11 +1110,11 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
|
|||||||
if "noise_aug_config" in model_config_params:
|
if "noise_aug_config" in model_config_params:
|
||||||
noise_aug_config = model_config_params["noise_aug_config"]
|
noise_aug_config = model_config_params["noise_aug_config"]
|
||||||
|
|
||||||
v_prediction = False
|
model_type = model_base.ModelType.EPS
|
||||||
|
|
||||||
if "parameterization" in model_config_params:
|
if "parameterization" in model_config_params:
|
||||||
if model_config_params["parameterization"] == "v":
|
if model_config_params["parameterization"] == "v":
|
||||||
v_prediction = True
|
model_type = model_base.ModelType.V_PREDICTION
|
||||||
|
|
||||||
clip = None
|
clip = None
|
||||||
vae = None
|
vae = None
|
||||||
@ -1021,11 +1134,11 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
|
|||||||
model_config.latent_format = latent_formats.SD15(scale_factor=scale_factor)
|
model_config.latent_format = latent_formats.SD15(scale_factor=scale_factor)
|
||||||
|
|
||||||
if config['model']["target"].endswith("LatentInpaintDiffusion"):
|
if config['model']["target"].endswith("LatentInpaintDiffusion"):
|
||||||
model = model_base.SDInpaint(model_config, v_prediction=v_prediction)
|
model = model_base.SDInpaint(model_config, model_type=model_type)
|
||||||
elif config['model']["target"].endswith("ImageEmbeddingConditionedLatentDiffusion"):
|
elif config['model']["target"].endswith("ImageEmbeddingConditionedLatentDiffusion"):
|
||||||
model = model_base.SD21UNCLIP(model_config, noise_aug_config["params"], v_prediction=v_prediction)
|
model = model_base.SD21UNCLIP(model_config, noise_aug_config["params"], model_type=model_type)
|
||||||
else:
|
else:
|
||||||
model = model_base.BaseModel(model_config, v_prediction=v_prediction)
|
model = model_base.BaseModel(model_config, model_type=model_type)
|
||||||
|
|
||||||
if fp16:
|
if fp16:
|
||||||
model = model.half()
|
model = model.half()
|
||||||
@ -1087,8 +1200,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
|||||||
clipvision = clip_vision.load_clipvision_from_sd(sd, model_config.clip_vision_prefix, True)
|
clipvision = clip_vision.load_clipvision_from_sd(sd, model_config.clip_vision_prefix, True)
|
||||||
|
|
||||||
offload_device = model_management.unet_offload_device()
|
offload_device = model_management.unet_offload_device()
|
||||||
model = model_config.get_model(sd, "model.diffusion_model.")
|
model = model_config.get_model(sd, "model.diffusion_model.", device=offload_device)
|
||||||
model = model.to(offload_device)
|
|
||||||
model.load_model_weights(sd, "model.diffusion_model.")
|
model.load_model_weights(sd, "model.diffusion_model.")
|
||||||
|
|
||||||
if output_vae:
|
if output_vae:
|
||||||
@ -1117,66 +1229,24 @@ def load_unet(unet_path): #load unet in diffusers format
|
|||||||
parameters = calculate_parameters(sd, "")
|
parameters = calculate_parameters(sd, "")
|
||||||
fp16 = model_management.should_use_fp16(model_params=parameters)
|
fp16 = model_management.should_use_fp16(model_params=parameters)
|
||||||
|
|
||||||
match = {}
|
model_config = model_detection.model_config_from_diffusers_unet(sd, fp16)
|
||||||
match["context_dim"] = sd["down_blocks.0.attentions.1.transformer_blocks.0.attn2.to_k.weight"].shape[1]
|
if model_config is None:
|
||||||
match["model_channels"] = sd["conv_in.weight"].shape[0]
|
print("ERROR UNSUPPORTED UNET", unet_path)
|
||||||
match["in_channels"] = sd["conv_in.weight"].shape[1]
|
return None
|
||||||
match["adm_in_channels"] = None
|
|
||||||
if "class_embedding.linear_1.weight" in sd:
|
|
||||||
match["adm_in_channels"] = sd["class_embedding.linear_1.weight"].shape[1]
|
|
||||||
|
|
||||||
SDXL = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
diffusers_keys = utils.unet_to_diffusers(model_config.unet_config)
|
||||||
'num_classes': 'sequential', 'adm_in_channels': 2816, 'use_fp16': fp16, 'in_channels': 4, 'model_channels': 320,
|
|
||||||
'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 2, 10], 'channel_mult': [1, 2, 4],
|
|
||||||
'transformer_depth_middle': 10, 'use_linear_in_transformer': True, 'context_dim': 2048}
|
|
||||||
|
|
||||||
SDXL_refiner = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
new_sd = {}
|
||||||
'num_classes': 'sequential', 'adm_in_channels': 2560, 'use_fp16': fp16, 'in_channels': 4, 'model_channels': 384,
|
for k in diffusers_keys:
|
||||||
'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 4, 4, 0], 'channel_mult': [1, 2, 4, 4],
|
if k in sd:
|
||||||
'transformer_depth_middle': 4, 'use_linear_in_transformer': True, 'context_dim': 1280}
|
new_sd[diffusers_keys[k]] = sd.pop(k)
|
||||||
|
else:
|
||||||
SD21 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
print(diffusers_keys[k], k)
|
||||||
'adm_in_channels': None, 'use_fp16': fp16, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2,
|
offload_device = model_management.unet_offload_device()
|
||||||
'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
|
model = model_config.get_model(new_sd, "")
|
||||||
'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024}
|
model = model.to(offload_device)
|
||||||
|
model.load_model_weights(new_sd, "")
|
||||||
SD21_uncliph = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
return ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device)
|
||||||
'num_classes': 'sequential', 'adm_in_channels': 2048, 'use_fp16': True, 'in_channels': 4, 'model_channels': 320,
|
|
||||||
'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
|
|
||||||
'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024}
|
|
||||||
|
|
||||||
SD21_unclipl = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
|
||||||
'num_classes': 'sequential', 'adm_in_channels': 1536, 'use_fp16': True, 'in_channels': 4, 'model_channels': 320,
|
|
||||||
'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
|
|
||||||
'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024}
|
|
||||||
|
|
||||||
SD15 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
|
||||||
'adm_in_channels': None, 'use_fp16': True, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2,
|
|
||||||
'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
|
|
||||||
'transformer_depth_middle': 1, 'use_linear_in_transformer': False, 'context_dim': 768}
|
|
||||||
|
|
||||||
supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl]
|
|
||||||
print("match", match)
|
|
||||||
for unet_config in supported_models:
|
|
||||||
matches = True
|
|
||||||
for k in match:
|
|
||||||
if match[k] != unet_config[k]:
|
|
||||||
matches = False
|
|
||||||
break
|
|
||||||
if matches:
|
|
||||||
diffusers_keys = utils.unet_to_diffusers(unet_config)
|
|
||||||
new_sd = {}
|
|
||||||
for k in diffusers_keys:
|
|
||||||
if k in sd:
|
|
||||||
new_sd[diffusers_keys[k]] = sd.pop(k)
|
|
||||||
else:
|
|
||||||
print(diffusers_keys[k], k)
|
|
||||||
offload_device = model_management.unet_offload_device()
|
|
||||||
model_config = model_detection.model_config_from_unet_config(unet_config)
|
|
||||||
model = model_config.get_model(new_sd, "")
|
|
||||||
model = model.to(offload_device)
|
|
||||||
model.load_model_weights(new_sd, "")
|
|
||||||
return ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device)
|
|
||||||
|
|
||||||
def save_checkpoint(output_path, model, clip, vae, metadata=None):
|
def save_checkpoint(output_path, model, clip, vae, metadata=None):
|
||||||
try:
|
try:
|
||||||
|
|||||||
@ -91,13 +91,15 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
|||||||
|
|
||||||
def set_up_textual_embeddings(self, tokens, current_embeds):
|
def set_up_textual_embeddings(self, tokens, current_embeds):
|
||||||
out_tokens = []
|
out_tokens = []
|
||||||
next_new_token = token_dict_size = current_embeds.weight.shape[0]
|
next_new_token = token_dict_size = current_embeds.weight.shape[0] - 1
|
||||||
embedding_weights = []
|
embedding_weights = []
|
||||||
|
|
||||||
for x in tokens:
|
for x in tokens:
|
||||||
tokens_temp = []
|
tokens_temp = []
|
||||||
for y in x:
|
for y in x:
|
||||||
if isinstance(y, int):
|
if isinstance(y, int):
|
||||||
|
if y == token_dict_size: #EOS token
|
||||||
|
y = -1
|
||||||
tokens_temp += [y]
|
tokens_temp += [y]
|
||||||
else:
|
else:
|
||||||
if y.shape[0] == current_embeds.weight.shape[1]:
|
if y.shape[0] == current_embeds.weight.shape[1]:
|
||||||
@ -110,15 +112,21 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
|||||||
tokens_temp += [self.empty_tokens[0][-1]]
|
tokens_temp += [self.empty_tokens[0][-1]]
|
||||||
out_tokens += [tokens_temp]
|
out_tokens += [tokens_temp]
|
||||||
|
|
||||||
|
n = token_dict_size
|
||||||
if len(embedding_weights) > 0:
|
if len(embedding_weights) > 0:
|
||||||
new_embedding = torch.nn.Embedding(next_new_token, current_embeds.weight.shape[1], device=current_embeds.weight.device, dtype=current_embeds.weight.dtype)
|
new_embedding = torch.nn.Embedding(next_new_token + 1, current_embeds.weight.shape[1], device=current_embeds.weight.device, dtype=current_embeds.weight.dtype)
|
||||||
new_embedding.weight[:token_dict_size] = current_embeds.weight[:]
|
new_embedding.weight[:token_dict_size] = current_embeds.weight[:-1]
|
||||||
n = token_dict_size
|
|
||||||
for x in embedding_weights:
|
for x in embedding_weights:
|
||||||
new_embedding.weight[n] = x
|
new_embedding.weight[n] = x
|
||||||
n += 1
|
n += 1
|
||||||
|
new_embedding.weight[n] = current_embeds.weight[-1] #EOS embedding
|
||||||
self.transformer.set_input_embeddings(new_embedding)
|
self.transformer.set_input_embeddings(new_embedding)
|
||||||
return out_tokens
|
|
||||||
|
processed_tokens = []
|
||||||
|
for x in out_tokens:
|
||||||
|
processed_tokens += [list(map(lambda a: n if a == -1 else a, x))] #The EOS token should always be the largest one
|
||||||
|
|
||||||
|
return processed_tokens
|
||||||
|
|
||||||
def forward(self, tokens):
|
def forward(self, tokens):
|
||||||
backup_embeds = self.transformer.get_input_embeddings()
|
backup_embeds = self.transformer.get_input_embeddings()
|
||||||
|
|||||||
@ -53,13 +53,13 @@ class SD20(supported_models_base.BASE):
|
|||||||
|
|
||||||
latent_format = latent_formats.SD15
|
latent_format = latent_formats.SD15
|
||||||
|
|
||||||
def v_prediction(self, state_dict, prefix=""):
|
def model_type(self, state_dict, prefix=""):
|
||||||
if self.unet_config["in_channels"] == 4: #SD2.0 inpainting models are not v prediction
|
if self.unet_config["in_channels"] == 4: #SD2.0 inpainting models are not v prediction
|
||||||
k = "{}output_blocks.11.1.transformer_blocks.0.norm1.bias".format(prefix)
|
k = "{}output_blocks.11.1.transformer_blocks.0.norm1.bias".format(prefix)
|
||||||
out = state_dict[k]
|
out = state_dict[k]
|
||||||
if torch.std(out, unbiased=False) > 0.09: # not sure how well this will actually work. I guess we will find out.
|
if torch.std(out, unbiased=False) > 0.09: # not sure how well this will actually work. I guess we will find out.
|
||||||
return True
|
return model_base.ModelType.V_PREDICTION
|
||||||
return False
|
return model_base.ModelType.EPS
|
||||||
|
|
||||||
def process_clip_state_dict(self, state_dict):
|
def process_clip_state_dict(self, state_dict):
|
||||||
state_dict = utils.transformers_convert(state_dict, "cond_stage_model.model.", "cond_stage_model.transformer.text_model.", 24)
|
state_dict = utils.transformers_convert(state_dict, "cond_stage_model.model.", "cond_stage_model.transformer.text_model.", 24)
|
||||||
@ -109,8 +109,8 @@ class SDXLRefiner(supported_models_base.BASE):
|
|||||||
|
|
||||||
latent_format = latent_formats.SDXL
|
latent_format = latent_formats.SDXL
|
||||||
|
|
||||||
def get_model(self, state_dict, prefix=""):
|
def get_model(self, state_dict, prefix="", device=None):
|
||||||
return model_base.SDXLRefiner(self)
|
return model_base.SDXLRefiner(self, device=device)
|
||||||
|
|
||||||
def process_clip_state_dict(self, state_dict):
|
def process_clip_state_dict(self, state_dict):
|
||||||
keys_to_replace = {}
|
keys_to_replace = {}
|
||||||
@ -126,7 +126,8 @@ class SDXLRefiner(supported_models_base.BASE):
|
|||||||
def process_clip_state_dict_for_saving(self, state_dict):
|
def process_clip_state_dict_for_saving(self, state_dict):
|
||||||
replace_prefix = {}
|
replace_prefix = {}
|
||||||
state_dict_g = diffusers_convert.convert_text_enc_state_dict_v20(state_dict, "clip_g")
|
state_dict_g = diffusers_convert.convert_text_enc_state_dict_v20(state_dict, "clip_g")
|
||||||
state_dict_g.pop("clip_g.transformer.text_model.embeddings.position_ids")
|
if "clip_g.transformer.text_model.embeddings.position_ids" in state_dict_g:
|
||||||
|
state_dict_g.pop("clip_g.transformer.text_model.embeddings.position_ids")
|
||||||
replace_prefix["clip_g"] = "conditioner.embedders.0.model"
|
replace_prefix["clip_g"] = "conditioner.embedders.0.model"
|
||||||
state_dict_g = supported_models_base.state_dict_prefix_replace(state_dict_g, replace_prefix)
|
state_dict_g = supported_models_base.state_dict_prefix_replace(state_dict_g, replace_prefix)
|
||||||
return state_dict_g
|
return state_dict_g
|
||||||
@ -145,8 +146,14 @@ class SDXL(supported_models_base.BASE):
|
|||||||
|
|
||||||
latent_format = latent_formats.SDXL
|
latent_format = latent_formats.SDXL
|
||||||
|
|
||||||
def get_model(self, state_dict, prefix=""):
|
def model_type(self, state_dict, prefix=""):
|
||||||
return model_base.SDXL(self)
|
if "v_pred" in state_dict:
|
||||||
|
return model_base.ModelType.V_PREDICTION
|
||||||
|
else:
|
||||||
|
return model_base.ModelType.EPS
|
||||||
|
|
||||||
|
def get_model(self, state_dict, prefix="", device=None):
|
||||||
|
return model_base.SDXL(self, model_type=self.model_type(state_dict, prefix), device=device)
|
||||||
|
|
||||||
def process_clip_state_dict(self, state_dict):
|
def process_clip_state_dict(self, state_dict):
|
||||||
keys_to_replace = {}
|
keys_to_replace = {}
|
||||||
@ -165,7 +172,8 @@ class SDXL(supported_models_base.BASE):
|
|||||||
replace_prefix = {}
|
replace_prefix = {}
|
||||||
keys_to_replace = {}
|
keys_to_replace = {}
|
||||||
state_dict_g = diffusers_convert.convert_text_enc_state_dict_v20(state_dict, "clip_g")
|
state_dict_g = diffusers_convert.convert_text_enc_state_dict_v20(state_dict, "clip_g")
|
||||||
state_dict_g.pop("clip_g.transformer.text_model.embeddings.position_ids")
|
if "clip_g.transformer.text_model.embeddings.position_ids" in state_dict_g:
|
||||||
|
state_dict_g.pop("clip_g.transformer.text_model.embeddings.position_ids")
|
||||||
for k in state_dict:
|
for k in state_dict:
|
||||||
if k.startswith("clip_l"):
|
if k.startswith("clip_l"):
|
||||||
state_dict_g[k] = state_dict[k]
|
state_dict_g[k] = state_dict[k]
|
||||||
|
|||||||
@ -41,8 +41,8 @@ class BASE:
|
|||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def v_prediction(self, state_dict, prefix=""):
|
def model_type(self, state_dict, prefix=""):
|
||||||
return False
|
return model_base.ModelType.EPS
|
||||||
|
|
||||||
def inpaint_model(self):
|
def inpaint_model(self):
|
||||||
return self.unet_config["in_channels"] > 4
|
return self.unet_config["in_channels"] > 4
|
||||||
@ -53,13 +53,13 @@ class BASE:
|
|||||||
for x in self.unet_extra_config:
|
for x in self.unet_extra_config:
|
||||||
self.unet_config[x] = self.unet_extra_config[x]
|
self.unet_config[x] = self.unet_extra_config[x]
|
||||||
|
|
||||||
def get_model(self, state_dict, prefix=""):
|
def get_model(self, state_dict, prefix="", device=None):
|
||||||
if self.inpaint_model():
|
if self.inpaint_model():
|
||||||
return model_base.SDInpaint(self, v_prediction=self.v_prediction(state_dict, prefix))
|
return model_base.SDInpaint(self, model_type=self.model_type(state_dict, prefix), device=device)
|
||||||
elif self.noise_aug_config is not None:
|
elif self.noise_aug_config is not None:
|
||||||
return model_base.SD21UNCLIP(self, self.noise_aug_config, v_prediction=self.v_prediction(state_dict, prefix))
|
return model_base.SD21UNCLIP(self, self.noise_aug_config, model_type=self.model_type(state_dict, prefix), device=device)
|
||||||
else:
|
else:
|
||||||
return model_base.BaseModel(self, v_prediction=self.v_prediction(state_dict, prefix))
|
return model_base.BaseModel(self, model_type=self.model_type(state_dict, prefix), device=device)
|
||||||
|
|
||||||
def process_clip_state_dict(self, state_dict):
|
def process_clip_state_dict(self, state_dict):
|
||||||
return state_dict
|
return state_dict
|
||||||
|
|||||||
@ -4,18 +4,20 @@ import struct
|
|||||||
import comfy.checkpoint_pickle
|
import comfy.checkpoint_pickle
|
||||||
import safetensors.torch
|
import safetensors.torch
|
||||||
|
|
||||||
def load_torch_file(ckpt, safe_load=False):
|
def load_torch_file(ckpt, safe_load=False, device=None):
|
||||||
|
if device is None:
|
||||||
|
device = torch.device("cpu")
|
||||||
if ckpt.lower().endswith(".safetensors"):
|
if ckpt.lower().endswith(".safetensors"):
|
||||||
sd = safetensors.torch.load_file(ckpt, device="cpu")
|
sd = safetensors.torch.load_file(ckpt, device=device.type)
|
||||||
else:
|
else:
|
||||||
if safe_load:
|
if safe_load:
|
||||||
if not 'weights_only' in torch.load.__code__.co_varnames:
|
if not 'weights_only' in torch.load.__code__.co_varnames:
|
||||||
print("Warning torch.load doesn't support weights_only on this pytorch version, loading unsafely.")
|
print("Warning torch.load doesn't support weights_only on this pytorch version, loading unsafely.")
|
||||||
safe_load = False
|
safe_load = False
|
||||||
if safe_load:
|
if safe_load:
|
||||||
pl_sd = torch.load(ckpt, map_location="cpu", weights_only=True)
|
pl_sd = torch.load(ckpt, map_location=device, weights_only=True)
|
||||||
else:
|
else:
|
||||||
pl_sd = torch.load(ckpt, map_location="cpu", pickle_module=comfy.checkpoint_pickle)
|
pl_sd = torch.load(ckpt, map_location=device, pickle_module=comfy.checkpoint_pickle)
|
||||||
if "global_step" in pl_sd:
|
if "global_step" in pl_sd:
|
||||||
print(f"Global Step: {pl_sd['global_step']}")
|
print(f"Global Step: {pl_sd['global_step']}")
|
||||||
if "state_dict" in pl_sd:
|
if "state_dict" in pl_sd:
|
||||||
@ -118,20 +120,24 @@ UNET_MAP_RESNET = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
UNET_MAP_BASIC = {
|
UNET_MAP_BASIC = {
|
||||||
"label_emb.0.0.weight": "class_embedding.linear_1.weight",
|
("label_emb.0.0.weight", "class_embedding.linear_1.weight"),
|
||||||
"label_emb.0.0.bias": "class_embedding.linear_1.bias",
|
("label_emb.0.0.bias", "class_embedding.linear_1.bias"),
|
||||||
"label_emb.0.2.weight": "class_embedding.linear_2.weight",
|
("label_emb.0.2.weight", "class_embedding.linear_2.weight"),
|
||||||
"label_emb.0.2.bias": "class_embedding.linear_2.bias",
|
("label_emb.0.2.bias", "class_embedding.linear_2.bias"),
|
||||||
"input_blocks.0.0.weight": "conv_in.weight",
|
("label_emb.0.0.weight", "add_embedding.linear_1.weight"),
|
||||||
"input_blocks.0.0.bias": "conv_in.bias",
|
("label_emb.0.0.bias", "add_embedding.linear_1.bias"),
|
||||||
"out.0.weight": "conv_norm_out.weight",
|
("label_emb.0.2.weight", "add_embedding.linear_2.weight"),
|
||||||
"out.0.bias": "conv_norm_out.bias",
|
("label_emb.0.2.bias", "add_embedding.linear_2.bias"),
|
||||||
"out.2.weight": "conv_out.weight",
|
("input_blocks.0.0.weight", "conv_in.weight"),
|
||||||
"out.2.bias": "conv_out.bias",
|
("input_blocks.0.0.bias", "conv_in.bias"),
|
||||||
"time_embed.0.weight": "time_embedding.linear_1.weight",
|
("out.0.weight", "conv_norm_out.weight"),
|
||||||
"time_embed.0.bias": "time_embedding.linear_1.bias",
|
("out.0.bias", "conv_norm_out.bias"),
|
||||||
"time_embed.2.weight": "time_embedding.linear_2.weight",
|
("out.2.weight", "conv_out.weight"),
|
||||||
"time_embed.2.bias": "time_embedding.linear_2.bias"
|
("out.2.bias", "conv_out.bias"),
|
||||||
|
("time_embed.0.weight", "time_embedding.linear_1.weight"),
|
||||||
|
("time_embed.0.bias", "time_embedding.linear_1.bias"),
|
||||||
|
("time_embed.2.weight", "time_embedding.linear_2.weight"),
|
||||||
|
("time_embed.2.bias", "time_embedding.linear_2.bias")
|
||||||
}
|
}
|
||||||
|
|
||||||
def unet_to_diffusers(unet_config):
|
def unet_to_diffusers(unet_config):
|
||||||
@ -206,7 +212,7 @@ def unet_to_diffusers(unet_config):
|
|||||||
n += 1
|
n += 1
|
||||||
|
|
||||||
for k in UNET_MAP_BASIC:
|
for k in UNET_MAP_BASIC:
|
||||||
diffusers_unet_map[UNET_MAP_BASIC[k]] = k
|
diffusers_unet_map[k[1]] = k[0]
|
||||||
|
|
||||||
return diffusers_unet_map
|
return diffusers_unet_map
|
||||||
|
|
||||||
|
|||||||
@ -2,6 +2,35 @@ import torch
|
|||||||
|
|
||||||
from nodes import MAX_RESOLUTION
|
from nodes import MAX_RESOLUTION
|
||||||
|
|
||||||
|
def composite(destination, source, x, y, mask = None, multiplier = 8):
|
||||||
|
x = max(-source.shape[3] * multiplier, min(x, destination.shape[3] * multiplier))
|
||||||
|
y = max(-source.shape[2] * multiplier, min(y, destination.shape[2] * multiplier))
|
||||||
|
|
||||||
|
left, top = (x // multiplier, y // multiplier)
|
||||||
|
right, bottom = (left + source.shape[3], top + source.shape[2],)
|
||||||
|
|
||||||
|
|
||||||
|
if mask is None:
|
||||||
|
mask = torch.ones_like(source)
|
||||||
|
else:
|
||||||
|
mask = mask.clone()
|
||||||
|
mask = torch.nn.functional.interpolate(mask[None, None], size=(source.shape[2], source.shape[3]), mode="bilinear")
|
||||||
|
mask = mask.repeat((source.shape[0], source.shape[1], 1, 1))
|
||||||
|
|
||||||
|
# calculate the bounds of the source that will be overlapping the destination
|
||||||
|
# this prevents the source trying to overwrite latent pixels that are out of bounds
|
||||||
|
# of the destination
|
||||||
|
visible_width, visible_height = (destination.shape[3] - left + min(0, x), destination.shape[2] - top + min(0, y),)
|
||||||
|
|
||||||
|
mask = mask[:, :, :visible_height, :visible_width]
|
||||||
|
inverse_mask = torch.ones_like(mask) - mask
|
||||||
|
|
||||||
|
source_portion = mask * source[:, :, :visible_height, :visible_width]
|
||||||
|
destination_portion = inverse_mask * destination[:, :, top:bottom, left:right]
|
||||||
|
|
||||||
|
destination[:, :, top:bottom, left:right] = source_portion + destination_portion
|
||||||
|
return destination
|
||||||
|
|
||||||
class LatentCompositeMasked:
|
class LatentCompositeMasked:
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
@ -25,36 +54,31 @@ class LatentCompositeMasked:
|
|||||||
output = destination.copy()
|
output = destination.copy()
|
||||||
destination = destination["samples"].clone()
|
destination = destination["samples"].clone()
|
||||||
source = source["samples"]
|
source = source["samples"]
|
||||||
|
output["samples"] = composite(destination, source, x, y, mask, 8)
|
||||||
|
return (output,)
|
||||||
|
|
||||||
x = max(-source.shape[3] * 8, min(x, destination.shape[3] * 8))
|
class ImageCompositeMasked:
|
||||||
y = max(-source.shape[2] * 8, min(y, destination.shape[2] * 8))
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"destination": ("IMAGE",),
|
||||||
|
"source": ("IMAGE",),
|
||||||
|
"x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
|
||||||
|
"y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
"mask": ("MASK",),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
RETURN_TYPES = ("IMAGE",)
|
||||||
|
FUNCTION = "composite"
|
||||||
|
|
||||||
left, top = (x // 8, y // 8)
|
CATEGORY = "image"
|
||||||
right, bottom = (left + source.shape[3], top + source.shape[2],)
|
|
||||||
|
|
||||||
|
|
||||||
if mask is None:
|
|
||||||
mask = torch.ones_like(source)
|
|
||||||
else:
|
|
||||||
mask = mask.clone()
|
|
||||||
mask = torch.nn.functional.interpolate(mask[None, None], size=(source.shape[2], source.shape[3]), mode="bilinear")
|
|
||||||
mask = mask.repeat((source.shape[0], source.shape[1], 1, 1))
|
|
||||||
|
|
||||||
# calculate the bounds of the source that will be overlapping the destination
|
|
||||||
# this prevents the source trying to overwrite latent pixels that are out of bounds
|
|
||||||
# of the destination
|
|
||||||
visible_width, visible_height = (destination.shape[3] - left + min(0, x), destination.shape[2] - top + min(0, y),)
|
|
||||||
|
|
||||||
mask = mask[:, :, :visible_height, :visible_width]
|
|
||||||
inverse_mask = torch.ones_like(mask) - mask
|
|
||||||
|
|
||||||
source_portion = mask * source[:, :, :visible_height, :visible_width]
|
|
||||||
destination_portion = inverse_mask * destination[:, :, top:bottom, left:right]
|
|
||||||
|
|
||||||
destination[:, :, top:bottom, left:right] = source_portion + destination_portion
|
|
||||||
|
|
||||||
output["samples"] = destination
|
|
||||||
|
|
||||||
|
def composite(self, destination, source, x, y, mask = None):
|
||||||
|
destination = destination.clone().movedim(-1, 1)
|
||||||
|
output = composite(destination, source.movedim(-1, 1), x, y, mask, 1).movedim(1, -1)
|
||||||
return (output,)
|
return (output,)
|
||||||
|
|
||||||
class MaskToImage:
|
class MaskToImage:
|
||||||
@ -253,6 +277,7 @@ class FeatherMask:
|
|||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
NODE_CLASS_MAPPINGS = {
|
||||||
"LatentCompositeMasked": LatentCompositeMasked,
|
"LatentCompositeMasked": LatentCompositeMasked,
|
||||||
|
"ImageCompositeMasked": ImageCompositeMasked,
|
||||||
"MaskToImage": MaskToImage,
|
"MaskToImage": MaskToImage,
|
||||||
"ImageToMask": ImageToMask,
|
"ImageToMask": ImageToMask,
|
||||||
"SolidMask": SolidMask,
|
"SolidMask": SolidMask,
|
||||||
|
|||||||
@ -1,9 +1,13 @@
|
|||||||
import comfy.sd
|
import comfy.sd
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
|
import comfy.model_base
|
||||||
|
|
||||||
import folder_paths
|
import folder_paths
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
from comfy.cli_args import args
|
||||||
|
|
||||||
class ModelMergeSimple:
|
class ModelMergeSimple:
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
@ -99,10 +103,36 @@ class CheckpointSave:
|
|||||||
if prompt is not None:
|
if prompt is not None:
|
||||||
prompt_info = json.dumps(prompt)
|
prompt_info = json.dumps(prompt)
|
||||||
|
|
||||||
metadata = {"prompt": prompt_info}
|
metadata = {}
|
||||||
if extra_pnginfo is not None:
|
|
||||||
for x in extra_pnginfo:
|
enable_modelspec = True
|
||||||
metadata[x] = json.dumps(extra_pnginfo[x])
|
if isinstance(model.model, comfy.model_base.SDXL):
|
||||||
|
metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-base"
|
||||||
|
elif isinstance(model.model, comfy.model_base.SDXLRefiner):
|
||||||
|
metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-refiner"
|
||||||
|
else:
|
||||||
|
enable_modelspec = False
|
||||||
|
|
||||||
|
if enable_modelspec:
|
||||||
|
metadata["modelspec.sai_model_spec"] = "1.0.0"
|
||||||
|
metadata["modelspec.implementation"] = "sgm"
|
||||||
|
metadata["modelspec.title"] = "{} {}".format(filename, counter)
|
||||||
|
|
||||||
|
#TODO:
|
||||||
|
# "stable-diffusion-v1", "stable-diffusion-v1-inpainting", "stable-diffusion-v2-512",
|
||||||
|
# "stable-diffusion-v2-768-v", "stable-diffusion-v2-unclip-l", "stable-diffusion-v2-unclip-h",
|
||||||
|
# "v2-inpainting"
|
||||||
|
|
||||||
|
if model.model.model_type == comfy.model_base.ModelType.EPS:
|
||||||
|
metadata["modelspec.predict_key"] = "epsilon"
|
||||||
|
elif model.model.model_type == comfy.model_base.ModelType.V_PREDICTION:
|
||||||
|
metadata["modelspec.predict_key"] = "v"
|
||||||
|
|
||||||
|
if not args.disable_metadata:
|
||||||
|
metadata["prompt"] = prompt_info
|
||||||
|
if extra_pnginfo is not None:
|
||||||
|
for x in extra_pnginfo:
|
||||||
|
metadata[x] = json.dumps(extra_pnginfo[x])
|
||||||
|
|
||||||
output_checkpoint = f"{filename}_{counter:05}_.safetensors"
|
output_checkpoint = f"{filename}_{counter:05}_.safetensors"
|
||||||
output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
|
output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
|
||||||
|
|||||||
@ -59,8 +59,8 @@ class Blend:
|
|||||||
def g(self, x):
|
def g(self, x):
|
||||||
return torch.where(x <= 0.25, ((16 * x - 12) * x + 4) * x, torch.sqrt(x))
|
return torch.where(x <= 0.25, ((16 * x - 12) * x + 4) * x, torch.sqrt(x))
|
||||||
|
|
||||||
def gaussian_kernel(kernel_size: int, sigma: float):
|
def gaussian_kernel(kernel_size: int, sigma: float, device=None):
|
||||||
x, y = torch.meshgrid(torch.linspace(-1, 1, kernel_size), torch.linspace(-1, 1, kernel_size), indexing="ij")
|
x, y = torch.meshgrid(torch.linspace(-1, 1, kernel_size, device=device), torch.linspace(-1, 1, kernel_size, device=device), indexing="ij")
|
||||||
d = torch.sqrt(x * x + y * y)
|
d = torch.sqrt(x * x + y * y)
|
||||||
g = torch.exp(-(d * d) / (2.0 * sigma * sigma))
|
g = torch.exp(-(d * d) / (2.0 * sigma * sigma))
|
||||||
return g / g.sum()
|
return g / g.sum()
|
||||||
@ -101,7 +101,7 @@ class Blur:
|
|||||||
batch_size, height, width, channels = image.shape
|
batch_size, height, width, channels = image.shape
|
||||||
|
|
||||||
kernel_size = blur_radius * 2 + 1
|
kernel_size = blur_radius * 2 + 1
|
||||||
kernel = gaussian_kernel(kernel_size, sigma).repeat(channels, 1, 1).unsqueeze(1)
|
kernel = gaussian_kernel(kernel_size, sigma, device=image.device).repeat(channels, 1, 1).unsqueeze(1)
|
||||||
|
|
||||||
image = image.permute(0, 3, 1, 2) # Torch wants (B, C, H, W) we use (B, H, W, C)
|
image = image.permute(0, 3, 1, 2) # Torch wants (B, C, H, W) we use (B, H, W, C)
|
||||||
padded_image = F.pad(image, (blur_radius,blur_radius,blur_radius,blur_radius), 'reflect')
|
padded_image = F.pad(image, (blur_radius,blur_radius,blur_radius,blur_radius), 'reflect')
|
||||||
|
|||||||
@ -37,12 +37,23 @@ class ImageUpscaleWithModel:
|
|||||||
device = model_management.get_torch_device()
|
device = model_management.get_torch_device()
|
||||||
upscale_model.to(device)
|
upscale_model.to(device)
|
||||||
in_img = image.movedim(-1,-3).to(device)
|
in_img = image.movedim(-1,-3).to(device)
|
||||||
|
free_memory = model_management.get_free_memory(device)
|
||||||
|
|
||||||
|
tile = 512
|
||||||
|
overlap = 32
|
||||||
|
|
||||||
|
oom = True
|
||||||
|
while oom:
|
||||||
|
try:
|
||||||
|
steps = in_img.shape[0] * comfy.utils.get_tiled_scale_steps(in_img.shape[3], in_img.shape[2], tile_x=tile, tile_y=tile, overlap=overlap)
|
||||||
|
pbar = comfy.utils.ProgressBar(steps)
|
||||||
|
s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar)
|
||||||
|
oom = False
|
||||||
|
except model_management.OOM_EXCEPTION as e:
|
||||||
|
tile //= 2
|
||||||
|
if tile < 128:
|
||||||
|
raise e
|
||||||
|
|
||||||
tile = 128 + 64
|
|
||||||
overlap = 8
|
|
||||||
steps = in_img.shape[0] * comfy.utils.get_tiled_scale_steps(in_img.shape[3], in_img.shape[2], tile_x=tile, tile_y=tile, overlap=overlap)
|
|
||||||
pbar = comfy.utils.ProgressBar(steps)
|
|
||||||
s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar)
|
|
||||||
upscale_model.cpu()
|
upscale_model.cpu()
|
||||||
s = torch.clamp(s.movedim(-3,-1), min=0, max=1.0)
|
s = torch.clamp(s.movedim(-3,-1), min=0, max=1.0)
|
||||||
return (s,)
|
return (s,)
|
||||||
|
|||||||
84
cuda_malloc.py
Normal file
84
cuda_malloc.py
Normal file
@ -0,0 +1,84 @@
|
|||||||
|
import os
|
||||||
|
import importlib.util
|
||||||
|
from comfy.cli_args import args
|
||||||
|
|
||||||
|
#Can't use pytorch to get the GPU names because the cuda malloc has to be set before the first import.
|
||||||
|
def get_gpu_names():
|
||||||
|
if os.name == 'nt':
|
||||||
|
import ctypes
|
||||||
|
|
||||||
|
# Define necessary C structures and types
|
||||||
|
class DISPLAY_DEVICEA(ctypes.Structure):
|
||||||
|
_fields_ = [
|
||||||
|
('cb', ctypes.c_ulong),
|
||||||
|
('DeviceName', ctypes.c_char * 32),
|
||||||
|
('DeviceString', ctypes.c_char * 128),
|
||||||
|
('StateFlags', ctypes.c_ulong),
|
||||||
|
('DeviceID', ctypes.c_char * 128),
|
||||||
|
('DeviceKey', ctypes.c_char * 128)
|
||||||
|
]
|
||||||
|
|
||||||
|
# Load user32.dll
|
||||||
|
user32 = ctypes.windll.user32
|
||||||
|
|
||||||
|
# Call EnumDisplayDevicesA
|
||||||
|
def enum_display_devices():
|
||||||
|
device_info = DISPLAY_DEVICEA()
|
||||||
|
device_info.cb = ctypes.sizeof(device_info)
|
||||||
|
device_index = 0
|
||||||
|
gpu_names = set()
|
||||||
|
|
||||||
|
while user32.EnumDisplayDevicesA(None, device_index, ctypes.byref(device_info), 0):
|
||||||
|
device_index += 1
|
||||||
|
gpu_names.add(device_info.DeviceString.decode('utf-8'))
|
||||||
|
return gpu_names
|
||||||
|
return enum_display_devices()
|
||||||
|
else:
|
||||||
|
return set()
|
||||||
|
|
||||||
|
blacklist = {"GeForce GTX TITAN X", "GeForce GTX 980", "GeForce GTX 970", "GeForce GTX 960", "GeForce GTX 950", "GeForce 945M",
|
||||||
|
"GeForce 940M", "GeForce 930M", "GeForce 920M", "GeForce 910M", "GeForce GTX 750", "GeForce GTX 745", "Quadro K620",
|
||||||
|
"Quadro K1200", "Quadro K2200", "Quadro M500", "Quadro M520", "Quadro M600", "Quadro M620", "Quadro M1000",
|
||||||
|
"Quadro M1200", "Quadro M2000", "Quadro M2200", "Quadro M3000", "Quadro M4000", "Quadro M5000", "Quadro M5500", "Quadro M6000",
|
||||||
|
"GeForce MX110", "GeForce MX130", "GeForce 830M", "GeForce 840M", "GeForce GTX 850M", "GeForce GTX 860M",
|
||||||
|
"GeForce GTX 1650", "GeForce GTX 1630"
|
||||||
|
}
|
||||||
|
|
||||||
|
def cuda_malloc_supported():
|
||||||
|
try:
|
||||||
|
names = get_gpu_names()
|
||||||
|
except:
|
||||||
|
names = set()
|
||||||
|
for x in names:
|
||||||
|
if "NVIDIA" in x:
|
||||||
|
for b in blacklist:
|
||||||
|
if b in x:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
if not args.cuda_malloc:
|
||||||
|
try:
|
||||||
|
version = ""
|
||||||
|
torch_spec = importlib.util.find_spec("torch")
|
||||||
|
for folder in torch_spec.submodule_search_locations:
|
||||||
|
ver_file = os.path.join(folder, "version.py")
|
||||||
|
if os.path.isfile(ver_file):
|
||||||
|
spec = importlib.util.spec_from_file_location("torch_version_import", ver_file)
|
||||||
|
module = importlib.util.module_from_spec(spec)
|
||||||
|
spec.loader.exec_module(module)
|
||||||
|
version = module.__version__
|
||||||
|
if int(version[0]) >= 2: #enable by default for torch version 2.0 and up
|
||||||
|
args.cuda_malloc = cuda_malloc_supported()
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
if args.cuda_malloc and not args.disable_cuda_malloc:
|
||||||
|
env_var = os.environ.get('PYTORCH_CUDA_ALLOC_CONF', None)
|
||||||
|
if env_var is None:
|
||||||
|
env_var = "backend:cudaMallocAsync"
|
||||||
|
else:
|
||||||
|
env_var += ",backend:cudaMallocAsync"
|
||||||
|
|
||||||
|
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = env_var
|
||||||
@ -51,9 +51,10 @@ class Example:
|
|||||||
"default": 0,
|
"default": 0,
|
||||||
"min": 0, #Minimum value
|
"min": 0, #Minimum value
|
||||||
"max": 4096, #Maximum value
|
"max": 4096, #Maximum value
|
||||||
"step": 64 #Slider's step
|
"step": 64, #Slider's step
|
||||||
|
"display": "number" # Cosmetic only: display as "number" or "slider"
|
||||||
}),
|
}),
|
||||||
"float_field": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
"float_field": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01, "display": "number"}),
|
||||||
"print_to_screen": (["enable", "disable"],),
|
"print_to_screen": (["enable", "disable"],),
|
||||||
"string_field": ("STRING", {
|
"string_field": ("STRING", {
|
||||||
"multiline": False, #True if you want the field to look like the one on the ClipTextEncode node
|
"multiline": False, #True if you want the field to look like the one on the ClipTextEncode node
|
||||||
|
|||||||
18
execution.py
18
execution.py
@ -6,7 +6,6 @@ import threading
|
|||||||
import heapq
|
import heapq
|
||||||
import traceback
|
import traceback
|
||||||
import gc
|
import gc
|
||||||
import time
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import nodes
|
import nodes
|
||||||
@ -43,11 +42,14 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da
|
|||||||
|
|
||||||
def map_node_over_list(obj, input_data_all, func, allow_interrupt=False):
|
def map_node_over_list(obj, input_data_all, func, allow_interrupt=False):
|
||||||
# check if node wants the lists
|
# check if node wants the lists
|
||||||
intput_is_list = False
|
input_is_list = False
|
||||||
if hasattr(obj, "INPUT_IS_LIST"):
|
if hasattr(obj, "INPUT_IS_LIST"):
|
||||||
intput_is_list = obj.INPUT_IS_LIST
|
input_is_list = obj.INPUT_IS_LIST
|
||||||
|
|
||||||
max_len_input = max([len(x) for x in input_data_all.values()])
|
if len(input_data_all) == 0:
|
||||||
|
max_len_input = 0
|
||||||
|
else:
|
||||||
|
max_len_input = max([len(x) for x in input_data_all.values()])
|
||||||
|
|
||||||
# get a slice of inputs, repeat last input when list isn't long enough
|
# get a slice of inputs, repeat last input when list isn't long enough
|
||||||
def slice_dict(d, i):
|
def slice_dict(d, i):
|
||||||
@ -57,11 +59,15 @@ def map_node_over_list(obj, input_data_all, func, allow_interrupt=False):
|
|||||||
return d_new
|
return d_new
|
||||||
|
|
||||||
results = []
|
results = []
|
||||||
if intput_is_list:
|
if input_is_list:
|
||||||
if allow_interrupt:
|
if allow_interrupt:
|
||||||
nodes.before_node_execution()
|
nodes.before_node_execution()
|
||||||
results.append(getattr(obj, func)(**input_data_all))
|
results.append(getattr(obj, func)(**input_data_all))
|
||||||
else:
|
elif max_len_input == 0:
|
||||||
|
if allow_interrupt:
|
||||||
|
nodes.before_node_execution()
|
||||||
|
results.append(getattr(obj, func)())
|
||||||
|
else:
|
||||||
for i in range(max_len_input):
|
for i in range(max_len_input):
|
||||||
if allow_interrupt:
|
if allow_interrupt:
|
||||||
nodes.before_node_execution()
|
nodes.before_node_execution()
|
||||||
|
|||||||
@ -43,6 +43,10 @@ def set_output_directory(output_dir):
|
|||||||
global output_directory
|
global output_directory
|
||||||
output_directory = output_dir
|
output_directory = output_dir
|
||||||
|
|
||||||
|
def set_temp_directory(temp_dir):
|
||||||
|
global temp_directory
|
||||||
|
temp_directory = temp_dir
|
||||||
|
|
||||||
def get_output_directory():
|
def get_output_directory():
|
||||||
global output_directory
|
global output_directory
|
||||||
return output_directory
|
return output_directory
|
||||||
@ -111,6 +115,8 @@ def add_model_folder_path(folder_name, full_folder_path):
|
|||||||
global folder_names_and_paths
|
global folder_names_and_paths
|
||||||
if folder_name in folder_names_and_paths:
|
if folder_name in folder_names_and_paths:
|
||||||
folder_names_and_paths[folder_name][0].append(full_folder_path)
|
folder_names_and_paths[folder_name][0].append(full_folder_path)
|
||||||
|
else:
|
||||||
|
folder_names_and_paths[folder_name] = ([full_folder_path], set())
|
||||||
|
|
||||||
def get_folder_paths(folder_name):
|
def get_folder_paths(folder_name):
|
||||||
return folder_names_and_paths[folder_name][0][:]
|
return folder_names_and_paths[folder_name][0][:]
|
||||||
|
|||||||
@ -1,6 +1,5 @@
|
|||||||
import torch
|
import torch
|
||||||
from PIL import Image, ImageOps
|
from PIL import Image
|
||||||
from io import BytesIO
|
|
||||||
import struct
|
import struct
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from comfy.cli_args import args, LatentPreviewMethod
|
from comfy.cli_args import args, LatentPreviewMethod
|
||||||
@ -15,26 +14,7 @@ class LatentPreviewer:
|
|||||||
|
|
||||||
def decode_latent_to_preview_image(self, preview_format, x0):
|
def decode_latent_to_preview_image(self, preview_format, x0):
|
||||||
preview_image = self.decode_latent_to_preview(x0)
|
preview_image = self.decode_latent_to_preview(x0)
|
||||||
|
return ("JPEG", preview_image, MAX_PREVIEW_RESOLUTION)
|
||||||
if hasattr(Image, 'Resampling'):
|
|
||||||
resampling = Image.Resampling.BILINEAR
|
|
||||||
else:
|
|
||||||
resampling = Image.ANTIALIAS
|
|
||||||
|
|
||||||
preview_image = ImageOps.contain(preview_image, (MAX_PREVIEW_RESOLUTION, MAX_PREVIEW_RESOLUTION), resampling)
|
|
||||||
|
|
||||||
preview_type = 1
|
|
||||||
if preview_format == "JPEG":
|
|
||||||
preview_type = 1
|
|
||||||
elif preview_format == "PNG":
|
|
||||||
preview_type = 2
|
|
||||||
|
|
||||||
bytesIO = BytesIO()
|
|
||||||
header = struct.pack(">I", preview_type)
|
|
||||||
bytesIO.write(header)
|
|
||||||
preview_image.save(bytesIO, format=preview_format, quality=95)
|
|
||||||
preview_bytes = bytesIO.getvalue()
|
|
||||||
return preview_bytes
|
|
||||||
|
|
||||||
class TAESDPreviewerImpl(LatentPreviewer):
|
class TAESDPreviewerImpl(LatentPreviewer):
|
||||||
def __init__(self, taesd):
|
def __init__(self, taesd):
|
||||||
|
|||||||
31
main.py
31
main.py
@ -51,7 +51,6 @@ import threading
|
|||||||
import gc
|
import gc
|
||||||
|
|
||||||
from comfy.cli_args import args
|
from comfy.cli_args import args
|
||||||
import comfy.utils
|
|
||||||
|
|
||||||
if os.name == "nt":
|
if os.name == "nt":
|
||||||
import logging
|
import logging
|
||||||
@ -62,7 +61,9 @@ if __name__ == "__main__":
|
|||||||
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device)
|
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device)
|
||||||
print("Set cuda device to:", args.cuda_device)
|
print("Set cuda device to:", args.cuda_device)
|
||||||
|
|
||||||
|
import cuda_malloc
|
||||||
|
|
||||||
|
import comfy.utils
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
import execution
|
import execution
|
||||||
@ -71,6 +72,17 @@ from server import BinaryEventTypes
|
|||||||
from nodes import init_custom_nodes
|
from nodes import init_custom_nodes
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
|
|
||||||
|
def cuda_malloc_warning():
|
||||||
|
device = comfy.model_management.get_torch_device()
|
||||||
|
device_name = comfy.model_management.get_torch_device_name(device)
|
||||||
|
cuda_malloc_warning = False
|
||||||
|
if "cudaMallocAsync" in device_name:
|
||||||
|
for b in cuda_malloc.blacklist:
|
||||||
|
if b in device_name:
|
||||||
|
cuda_malloc_warning = True
|
||||||
|
if cuda_malloc_warning:
|
||||||
|
print("\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n")
|
||||||
|
|
||||||
def prompt_worker(q, server):
|
def prompt_worker(q, server):
|
||||||
e = execution.PromptExecutor(server)
|
e = execution.PromptExecutor(server)
|
||||||
while True:
|
while True:
|
||||||
@ -91,15 +103,15 @@ async def run(server, address='', port=8188, verbose=True, call_on_start=None):
|
|||||||
|
|
||||||
|
|
||||||
def hijack_progress(server):
|
def hijack_progress(server):
|
||||||
def hook(value, total, preview_image_bytes):
|
def hook(value, total, preview_image):
|
||||||
server.send_sync("progress", {"value": value, "max": total}, server.client_id)
|
server.send_sync("progress", {"value": value, "max": total}, server.client_id)
|
||||||
if preview_image_bytes is not None:
|
if preview_image is not None:
|
||||||
server.send_sync(BinaryEventTypes.PREVIEW_IMAGE, preview_image_bytes, server.client_id)
|
server.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, server.client_id)
|
||||||
comfy.utils.set_progress_bar_global_hook(hook)
|
comfy.utils.set_progress_bar_global_hook(hook)
|
||||||
|
|
||||||
|
|
||||||
def cleanup_temp():
|
def cleanup_temp():
|
||||||
temp_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp")
|
temp_dir = folder_paths.get_temp_directory()
|
||||||
if os.path.exists(temp_dir):
|
if os.path.exists(temp_dir):
|
||||||
shutil.rmtree(temp_dir, ignore_errors=True)
|
shutil.rmtree(temp_dir, ignore_errors=True)
|
||||||
|
|
||||||
@ -126,6 +138,10 @@ def load_extra_path_config(yaml_path):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
if args.temp_directory:
|
||||||
|
temp_dir = os.path.join(os.path.abspath(args.temp_directory), "temp")
|
||||||
|
print(f"Setting temp directory to: {temp_dir}")
|
||||||
|
folder_paths.set_temp_directory(temp_dir)
|
||||||
cleanup_temp()
|
cleanup_temp()
|
||||||
|
|
||||||
loop = asyncio.new_event_loop()
|
loop = asyncio.new_event_loop()
|
||||||
@ -142,6 +158,9 @@ if __name__ == "__main__":
|
|||||||
load_extra_path_config(config_path)
|
load_extra_path_config(config_path)
|
||||||
|
|
||||||
init_custom_nodes()
|
init_custom_nodes()
|
||||||
|
|
||||||
|
cuda_malloc_warning()
|
||||||
|
|
||||||
server.add_routes()
|
server.add_routes()
|
||||||
hijack_progress(server)
|
hijack_progress(server)
|
||||||
|
|
||||||
@ -159,6 +178,8 @@ if __name__ == "__main__":
|
|||||||
if args.auto_launch:
|
if args.auto_launch:
|
||||||
def startup_server(address, port):
|
def startup_server(address, port):
|
||||||
import webbrowser
|
import webbrowser
|
||||||
|
if os.name == 'nt' and address == '0.0.0.0':
|
||||||
|
address = '127.0.0.1'
|
||||||
webbrowser.open(f"http://{address}:{port}")
|
webbrowser.open(f"http://{address}:{port}")
|
||||||
call_on_start = startup_server
|
call_on_start = startup_server
|
||||||
|
|
||||||
|
|||||||
177
nodes.py
177
nodes.py
@ -26,6 +26,8 @@ import comfy.utils
|
|||||||
import comfy.clip_vision
|
import comfy.clip_vision
|
||||||
|
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
|
from comfy.cli_args import args
|
||||||
|
|
||||||
import importlib
|
import importlib
|
||||||
|
|
||||||
import folder_paths
|
import folder_paths
|
||||||
@ -204,6 +206,28 @@ class ConditioningZeroOut:
|
|||||||
c.append(n)
|
c.append(n)
|
||||||
return (c, )
|
return (c, )
|
||||||
|
|
||||||
|
class ConditioningSetTimestepRange:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": {"conditioning": ("CONDITIONING", ),
|
||||||
|
"start": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
|
||||||
|
"end": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
|
||||||
|
}}
|
||||||
|
RETURN_TYPES = ("CONDITIONING",)
|
||||||
|
FUNCTION = "set_range"
|
||||||
|
|
||||||
|
CATEGORY = "advanced/conditioning"
|
||||||
|
|
||||||
|
def set_range(self, conditioning, start, end):
|
||||||
|
c = []
|
||||||
|
for t in conditioning:
|
||||||
|
d = t[1].copy()
|
||||||
|
d['start_percent'] = 1.0 - start
|
||||||
|
d['end_percent'] = 1.0 - end
|
||||||
|
n = [t[0], d]
|
||||||
|
c.append(n)
|
||||||
|
return (c, )
|
||||||
|
|
||||||
class VAEDecode:
|
class VAEDecode:
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
@ -330,12 +354,22 @@ class SaveLatent:
|
|||||||
if prompt is not None:
|
if prompt is not None:
|
||||||
prompt_info = json.dumps(prompt)
|
prompt_info = json.dumps(prompt)
|
||||||
|
|
||||||
metadata = {"prompt": prompt_info}
|
metadata = None
|
||||||
if extra_pnginfo is not None:
|
if not args.disable_metadata:
|
||||||
for x in extra_pnginfo:
|
metadata = {"prompt": prompt_info}
|
||||||
metadata[x] = json.dumps(extra_pnginfo[x])
|
if extra_pnginfo is not None:
|
||||||
|
for x in extra_pnginfo:
|
||||||
|
metadata[x] = json.dumps(extra_pnginfo[x])
|
||||||
|
|
||||||
file = f"{filename}_{counter:05}_.latent"
|
file = f"{filename}_{counter:05}_.latent"
|
||||||
|
|
||||||
|
results = list()
|
||||||
|
results.append({
|
||||||
|
"filename": file,
|
||||||
|
"subfolder": subfolder,
|
||||||
|
"type": "output"
|
||||||
|
})
|
||||||
|
|
||||||
file = os.path.join(full_output_folder, file)
|
file = os.path.join(full_output_folder, file)
|
||||||
|
|
||||||
output = {}
|
output = {}
|
||||||
@ -343,7 +377,7 @@ class SaveLatent:
|
|||||||
output["latent_format_version_0"] = torch.tensor([])
|
output["latent_format_version_0"] = torch.tensor([])
|
||||||
|
|
||||||
comfy.utils.save_torch_file(output, file, metadata=metadata)
|
comfy.utils.save_torch_file(output, file, metadata=metadata)
|
||||||
return {}
|
return { "ui": { "latents": results } }
|
||||||
|
|
||||||
|
|
||||||
class LoadLatent:
|
class LoadLatent:
|
||||||
@ -497,7 +531,9 @@ class LoraLoader:
|
|||||||
if self.loaded_lora[0] == lora_path:
|
if self.loaded_lora[0] == lora_path:
|
||||||
lora = self.loaded_lora[1]
|
lora = self.loaded_lora[1]
|
||||||
else:
|
else:
|
||||||
del self.loaded_lora
|
temp = self.loaded_lora
|
||||||
|
self.loaded_lora = None
|
||||||
|
del temp
|
||||||
|
|
||||||
if lora is None:
|
if lora is None:
|
||||||
lora = comfy.utils.load_torch_file(lora_path, safe_load=True)
|
lora = comfy.utils.load_torch_file(lora_path, safe_load=True)
|
||||||
@ -578,9 +614,58 @@ class ControlNetApply:
|
|||||||
if 'control' in t[1]:
|
if 'control' in t[1]:
|
||||||
c_net.set_previous_controlnet(t[1]['control'])
|
c_net.set_previous_controlnet(t[1]['control'])
|
||||||
n[1]['control'] = c_net
|
n[1]['control'] = c_net
|
||||||
|
n[1]['control_apply_to_uncond'] = True
|
||||||
c.append(n)
|
c.append(n)
|
||||||
return (c, )
|
return (c, )
|
||||||
|
|
||||||
|
|
||||||
|
class ControlNetApplyAdvanced:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": {"positive": ("CONDITIONING", ),
|
||||||
|
"negative": ("CONDITIONING", ),
|
||||||
|
"control_net": ("CONTROL_NET", ),
|
||||||
|
"image": ("IMAGE", ),
|
||||||
|
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||||
|
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
|
||||||
|
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
|
||||||
|
}}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("CONDITIONING","CONDITIONING")
|
||||||
|
RETURN_NAMES = ("positive", "negative")
|
||||||
|
FUNCTION = "apply_controlnet"
|
||||||
|
|
||||||
|
CATEGORY = "conditioning"
|
||||||
|
|
||||||
|
def apply_controlnet(self, positive, negative, control_net, image, strength, start_percent, end_percent):
|
||||||
|
if strength == 0:
|
||||||
|
return (positive, negative)
|
||||||
|
|
||||||
|
control_hint = image.movedim(-1,1)
|
||||||
|
cnets = {}
|
||||||
|
|
||||||
|
out = []
|
||||||
|
for conditioning in [positive, negative]:
|
||||||
|
c = []
|
||||||
|
for t in conditioning:
|
||||||
|
d = t[1].copy()
|
||||||
|
|
||||||
|
prev_cnet = d.get('control', None)
|
||||||
|
if prev_cnet in cnets:
|
||||||
|
c_net = cnets[prev_cnet]
|
||||||
|
else:
|
||||||
|
c_net = control_net.copy().set_cond_hint(control_hint, strength, (1.0 - start_percent, 1.0 - end_percent))
|
||||||
|
c_net.set_previous_controlnet(prev_cnet)
|
||||||
|
cnets[prev_cnet] = c_net
|
||||||
|
|
||||||
|
d['control'] = c_net
|
||||||
|
d['control_apply_to_uncond'] = False
|
||||||
|
n = [t[0], d]
|
||||||
|
c.append(n)
|
||||||
|
out.append(c)
|
||||||
|
return (out[0], out[1])
|
||||||
|
|
||||||
|
|
||||||
class UNETLoader:
|
class UNETLoader:
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
@ -686,7 +771,7 @@ class StyleModelApply:
|
|||||||
CATEGORY = "conditioning/style_model"
|
CATEGORY = "conditioning/style_model"
|
||||||
|
|
||||||
def apply_stylemodel(self, clip_vision_output, style_model, conditioning):
|
def apply_stylemodel(self, clip_vision_output, style_model, conditioning):
|
||||||
cond = style_model.get_cond(clip_vision_output)
|
cond = style_model.get_cond(clip_vision_output).flatten(start_dim=0, end_dim=1).unsqueeze(dim=0)
|
||||||
c = []
|
c = []
|
||||||
for t in conditioning:
|
for t in conditioning:
|
||||||
n = [torch.cat((t[0], cond), dim=1), t[1].copy()]
|
n = [torch.cat((t[0], cond), dim=1), t[1].copy()]
|
||||||
@ -970,6 +1055,47 @@ class LatentComposite:
|
|||||||
samples_out["samples"] = s
|
samples_out["samples"] = s
|
||||||
return (samples_out,)
|
return (samples_out,)
|
||||||
|
|
||||||
|
class LatentBlend:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": {
|
||||||
|
"samples1": ("LATENT",),
|
||||||
|
"samples2": ("LATENT",),
|
||||||
|
"blend_factor": ("FLOAT", {
|
||||||
|
"default": 0.5,
|
||||||
|
"min": 0,
|
||||||
|
"max": 1,
|
||||||
|
"step": 0.01
|
||||||
|
}),
|
||||||
|
}}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("LATENT",)
|
||||||
|
FUNCTION = "blend"
|
||||||
|
|
||||||
|
CATEGORY = "_for_testing"
|
||||||
|
|
||||||
|
def blend(self, samples1, samples2, blend_factor:float, blend_mode: str="normal"):
|
||||||
|
|
||||||
|
samples_out = samples1.copy()
|
||||||
|
samples1 = samples1["samples"]
|
||||||
|
samples2 = samples2["samples"]
|
||||||
|
|
||||||
|
if samples1.shape != samples2.shape:
|
||||||
|
samples2.permute(0, 3, 1, 2)
|
||||||
|
samples2 = comfy.utils.common_upscale(samples2, samples1.shape[3], samples1.shape[2], 'bicubic', crop='center')
|
||||||
|
samples2.permute(0, 2, 3, 1)
|
||||||
|
|
||||||
|
samples_blended = self.blend_mode(samples1, samples2, blend_mode)
|
||||||
|
samples_blended = samples1 * blend_factor + samples_blended * (1 - blend_factor)
|
||||||
|
samples_out["samples"] = samples_blended
|
||||||
|
return (samples_out,)
|
||||||
|
|
||||||
|
def blend_mode(self, img1, img2, mode):
|
||||||
|
if mode == "normal":
|
||||||
|
return img2
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported blend mode: {mode}")
|
||||||
|
|
||||||
class LatentCrop:
|
class LatentCrop:
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
@ -1141,12 +1267,14 @@ class SaveImage:
|
|||||||
for image in images:
|
for image in images:
|
||||||
i = 255. * image.cpu().numpy()
|
i = 255. * image.cpu().numpy()
|
||||||
img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
|
img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
|
||||||
metadata = PngInfo()
|
metadata = None
|
||||||
if prompt is not None:
|
if not args.disable_metadata:
|
||||||
metadata.add_text("prompt", json.dumps(prompt))
|
metadata = PngInfo()
|
||||||
if extra_pnginfo is not None:
|
if prompt is not None:
|
||||||
for x in extra_pnginfo:
|
metadata.add_text("prompt", json.dumps(prompt))
|
||||||
metadata.add_text(x, json.dumps(extra_pnginfo[x]))
|
if extra_pnginfo is not None:
|
||||||
|
for x in extra_pnginfo:
|
||||||
|
metadata.add_text(x, json.dumps(extra_pnginfo[x]))
|
||||||
|
|
||||||
file = f"{filename}_{counter:05}_.png"
|
file = f"{filename}_{counter:05}_.png"
|
||||||
img.save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=4)
|
img.save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=4)
|
||||||
@ -1320,6 +1448,22 @@ class ImageInvert:
|
|||||||
s = 1.0 - image
|
s = 1.0 - image
|
||||||
return (s,)
|
return (s,)
|
||||||
|
|
||||||
|
class ImageBatch:
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": { "image1": ("IMAGE",), "image2": ("IMAGE",)}}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("IMAGE",)
|
||||||
|
FUNCTION = "batch"
|
||||||
|
|
||||||
|
CATEGORY = "image"
|
||||||
|
|
||||||
|
def batch(self, image1, image2):
|
||||||
|
if image1.shape[1:] != image2.shape[1:]:
|
||||||
|
image2 = comfy.utils.common_upscale(image2.movedim(-1,1), image1.shape[2], image1.shape[1], "bilinear", "center").movedim(1,-1)
|
||||||
|
s = torch.cat((image1, image2), dim=0)
|
||||||
|
return (s,)
|
||||||
|
|
||||||
class ImagePadForOutpaint:
|
class ImagePadForOutpaint:
|
||||||
|
|
||||||
@ -1405,6 +1549,7 @@ NODE_CLASS_MAPPINGS = {
|
|||||||
"ImageScale": ImageScale,
|
"ImageScale": ImageScale,
|
||||||
"ImageScaleBy": ImageScaleBy,
|
"ImageScaleBy": ImageScaleBy,
|
||||||
"ImageInvert": ImageInvert,
|
"ImageInvert": ImageInvert,
|
||||||
|
"ImageBatch": ImageBatch,
|
||||||
"ImagePadForOutpaint": ImagePadForOutpaint,
|
"ImagePadForOutpaint": ImagePadForOutpaint,
|
||||||
"ConditioningAverage ": ConditioningAverage ,
|
"ConditioningAverage ": ConditioningAverage ,
|
||||||
"ConditioningCombine": ConditioningCombine,
|
"ConditioningCombine": ConditioningCombine,
|
||||||
@ -1414,6 +1559,7 @@ NODE_CLASS_MAPPINGS = {
|
|||||||
"KSamplerAdvanced": KSamplerAdvanced,
|
"KSamplerAdvanced": KSamplerAdvanced,
|
||||||
"SetLatentNoiseMask": SetLatentNoiseMask,
|
"SetLatentNoiseMask": SetLatentNoiseMask,
|
||||||
"LatentComposite": LatentComposite,
|
"LatentComposite": LatentComposite,
|
||||||
|
"LatentBlend": LatentBlend,
|
||||||
"LatentRotate": LatentRotate,
|
"LatentRotate": LatentRotate,
|
||||||
"LatentFlip": LatentFlip,
|
"LatentFlip": LatentFlip,
|
||||||
"LatentCrop": LatentCrop,
|
"LatentCrop": LatentCrop,
|
||||||
@ -1425,6 +1571,7 @@ NODE_CLASS_MAPPINGS = {
|
|||||||
"StyleModelApply": StyleModelApply,
|
"StyleModelApply": StyleModelApply,
|
||||||
"unCLIPConditioning": unCLIPConditioning,
|
"unCLIPConditioning": unCLIPConditioning,
|
||||||
"ControlNetApply": ControlNetApply,
|
"ControlNetApply": ControlNetApply,
|
||||||
|
"ControlNetApplyAdvanced": ControlNetApplyAdvanced,
|
||||||
"ControlNetLoader": ControlNetLoader,
|
"ControlNetLoader": ControlNetLoader,
|
||||||
"DiffControlNetLoader": DiffControlNetLoader,
|
"DiffControlNetLoader": DiffControlNetLoader,
|
||||||
"StyleModelLoader": StyleModelLoader,
|
"StyleModelLoader": StyleModelLoader,
|
||||||
@ -1442,6 +1589,7 @@ NODE_CLASS_MAPPINGS = {
|
|||||||
"SaveLatent": SaveLatent,
|
"SaveLatent": SaveLatent,
|
||||||
|
|
||||||
"ConditioningZeroOut": ConditioningZeroOut,
|
"ConditioningZeroOut": ConditioningZeroOut,
|
||||||
|
"ConditioningSetTimestepRange": ConditioningSetTimestepRange,
|
||||||
}
|
}
|
||||||
|
|
||||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||||
@ -1470,6 +1618,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
|||||||
"ConditioningSetArea": "Conditioning (Set Area)",
|
"ConditioningSetArea": "Conditioning (Set Area)",
|
||||||
"ConditioningSetMask": "Conditioning (Set Mask)",
|
"ConditioningSetMask": "Conditioning (Set Mask)",
|
||||||
"ControlNetApply": "Apply ControlNet",
|
"ControlNetApply": "Apply ControlNet",
|
||||||
|
"ControlNetApplyAdvanced": "Apply ControlNet (Advanced)",
|
||||||
# Latent
|
# Latent
|
||||||
"VAEEncodeForInpaint": "VAE Encode (for Inpainting)",
|
"VAEEncodeForInpaint": "VAE Encode (for Inpainting)",
|
||||||
"SetLatentNoiseMask": "Set Latent Noise Mask",
|
"SetLatentNoiseMask": "Set Latent Noise Mask",
|
||||||
@ -1482,6 +1631,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
|||||||
"LatentUpscale": "Upscale Latent",
|
"LatentUpscale": "Upscale Latent",
|
||||||
"LatentUpscaleBy": "Upscale Latent By",
|
"LatentUpscaleBy": "Upscale Latent By",
|
||||||
"LatentComposite": "Latent Composite",
|
"LatentComposite": "Latent Composite",
|
||||||
|
"LatentBlend": "Latent Blend",
|
||||||
"LatentFromBatch" : "Latent From Batch",
|
"LatentFromBatch" : "Latent From Batch",
|
||||||
"RepeatLatentBatch": "Repeat Latent Batch",
|
"RepeatLatentBatch": "Repeat Latent Batch",
|
||||||
# Image
|
# Image
|
||||||
@ -1494,6 +1644,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
|||||||
"ImageUpscaleWithModel": "Upscale Image (using Model)",
|
"ImageUpscaleWithModel": "Upscale Image (using Model)",
|
||||||
"ImageInvert": "Invert Image",
|
"ImageInvert": "Invert Image",
|
||||||
"ImagePadForOutpaint": "Pad Image for Outpainting",
|
"ImagePadForOutpaint": "Pad Image for Outpainting",
|
||||||
|
"ImageBatch": "Batch Images",
|
||||||
# _for_testing
|
# _for_testing
|
||||||
"VAEDecodeTiled": "VAE Decode (Tiled)",
|
"VAEDecodeTiled": "VAE Decode (Tiled)",
|
||||||
"VAEEncodeTiled": "VAE Encode (Tiled)",
|
"VAEEncodeTiled": "VAE Encode (Tiled)",
|
||||||
|
|||||||
@ -69,6 +69,13 @@
|
|||||||
"source": [
|
"source": [
|
||||||
"# Checkpoints\n",
|
"# Checkpoints\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
"### SDXL\n",
|
||||||
|
"### I recommend these workflow examples: https://comfyanonymous.github.io/ComfyUI_examples/sdxl/\n",
|
||||||
|
"\n",
|
||||||
|
"#!wget -c https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/resolve/main/sd_xl_base_1.0.safetensors -P ./models/checkpoints/\n",
|
||||||
|
"#!wget -c https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0/resolve/main/sd_xl_refiner_1.0.safetensors -P ./models/checkpoints/\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
"# SD1.5\n",
|
"# SD1.5\n",
|
||||||
"!wget -c https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.ckpt -P ./models/checkpoints/\n",
|
"!wget -c https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.ckpt -P ./models/checkpoints/\n",
|
||||||
"\n",
|
"\n",
|
||||||
@ -83,7 +90,7 @@
|
|||||||
"#!wget -c https://huggingface.co/Linaqruf/anything-v3.0/resolve/main/anything-v3-fp16-pruned.safetensors -P ./models/checkpoints/\n",
|
"#!wget -c https://huggingface.co/Linaqruf/anything-v3.0/resolve/main/anything-v3-fp16-pruned.safetensors -P ./models/checkpoints/\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# Waifu Diffusion 1.5 (anime style SD2.x 768-v)\n",
|
"# Waifu Diffusion 1.5 (anime style SD2.x 768-v)\n",
|
||||||
"#!wget -c https://huggingface.co/waifu-diffusion/wd-1-5-beta2/resolve/main/checkpoints/wd-1-5-beta2-fp16.safetensors -P ./models/checkpoints/\n",
|
"#!wget -c https://huggingface.co/waifu-diffusion/wd-1-5-beta3/resolve/main/wd-illusion-fp16.safetensors -P ./models/checkpoints/\n",
|
||||||
"\n",
|
"\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# unCLIP models\n",
|
"# unCLIP models\n",
|
||||||
@ -100,6 +107,7 @@
|
|||||||
"# Loras\n",
|
"# Loras\n",
|
||||||
"#!wget -c https://civitai.com/api/download/models/10350 -O ./models/loras/theovercomer8sContrastFix_sd21768.safetensors #theovercomer8sContrastFix SD2.x 768-v\n",
|
"#!wget -c https://civitai.com/api/download/models/10350 -O ./models/loras/theovercomer8sContrastFix_sd21768.safetensors #theovercomer8sContrastFix SD2.x 768-v\n",
|
||||||
"#!wget -c https://civitai.com/api/download/models/10638 -O ./models/loras/theovercomer8sContrastFix_sd15.safetensors #theovercomer8sContrastFix SD1.x\n",
|
"#!wget -c https://civitai.com/api/download/models/10638 -O ./models/loras/theovercomer8sContrastFix_sd15.safetensors #theovercomer8sContrastFix SD1.x\n",
|
||||||
|
"#!wget -c https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/resolve/main/sd_xl_offset_example-lora_1.0.safetensors -P ./models/loras/ #SDXL offset noise lora\n",
|
||||||
"\n",
|
"\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# T2I-Adapter\n",
|
"# T2I-Adapter\n",
|
||||||
@ -151,13 +159,64 @@
|
|||||||
"\n"
|
"\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {
|
||||||
|
"id": "kkkkkkkkkkkkkkk"
|
||||||
|
},
|
||||||
|
"source": [
|
||||||
|
"### Run ComfyUI with cloudflared (Recommended Way)\n",
|
||||||
|
"\n",
|
||||||
|
"\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"id": "jjjjjjjjjjjjjj"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"!wget https://github.com/cloudflare/cloudflared/releases/latest/download/cloudflared-linux-amd64.deb\n",
|
||||||
|
"!dpkg -i cloudflared-linux-amd64.deb\n",
|
||||||
|
"\n",
|
||||||
|
"import subprocess\n",
|
||||||
|
"import threading\n",
|
||||||
|
"import time\n",
|
||||||
|
"import socket\n",
|
||||||
|
"import urllib.request\n",
|
||||||
|
"\n",
|
||||||
|
"def iframe_thread(port):\n",
|
||||||
|
" while True:\n",
|
||||||
|
" time.sleep(0.5)\n",
|
||||||
|
" sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)\n",
|
||||||
|
" result = sock.connect_ex(('127.0.0.1', port))\n",
|
||||||
|
" if result == 0:\n",
|
||||||
|
" break\n",
|
||||||
|
" sock.close()\n",
|
||||||
|
" print(\"\\nComfyUI finished loading, trying to launch cloudflared (if it gets stuck here cloudflared is having issues)\\n\")\n",
|
||||||
|
"\n",
|
||||||
|
" p = subprocess.Popen([\"cloudflared\", \"tunnel\", \"--url\", \"http://127.0.0.1:{}\".format(port)], stdout=subprocess.PIPE, stderr=subprocess.PIPE)\n",
|
||||||
|
" for line in p.stderr:\n",
|
||||||
|
" l = line.decode()\n",
|
||||||
|
" if \"trycloudflare.com \" in l:\n",
|
||||||
|
" print(\"This is the URL to access ComfyUI:\", l[l.find(\"http\"):], end='')\n",
|
||||||
|
" #print(l, end='')\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"threading.Thread(target=iframe_thread, daemon=True, args=(8188,)).start()\n",
|
||||||
|
"\n",
|
||||||
|
"!python main.py --dont-print-server"
|
||||||
|
]
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"id": "kkkkkkkkkkkkkk"
|
"id": "kkkkkkkkkkkkkk"
|
||||||
},
|
},
|
||||||
"source": [
|
"source": [
|
||||||
"### Run ComfyUI with localtunnel (Recommended Way)\n",
|
"### Run ComfyUI with localtunnel\n",
|
||||||
"\n",
|
"\n",
|
||||||
"\n"
|
"\n"
|
||||||
]
|
]
|
||||||
|
|||||||
@ -1,5 +1,4 @@
|
|||||||
torch
|
torch
|
||||||
torchdiffeq
|
|
||||||
torchsde
|
torchsde
|
||||||
einops
|
einops
|
||||||
transformers>=4.25.1
|
transformers>=4.25.1
|
||||||
@ -10,3 +9,4 @@ pyyaml
|
|||||||
Pillow
|
Pillow
|
||||||
scipy
|
scipy
|
||||||
tqdm
|
tqdm
|
||||||
|
psutil
|
||||||
|
|||||||
@ -2,8 +2,12 @@ import json
|
|||||||
from urllib import request, parse
|
from urllib import request, parse
|
||||||
import random
|
import random
|
||||||
|
|
||||||
#this is the ComfyUI api prompt format. If you want it for a specific workflow you can copy it from the prompt section
|
#This is the ComfyUI api prompt format.
|
||||||
#of the image metadata of images generated with ComfyUI
|
|
||||||
|
#If you want it for a specific workflow you can "enable dev mode options"
|
||||||
|
#in the settings of the UI (gear beside the "Queue Size: ") this will enable
|
||||||
|
#a button on the UI to save workflows in api format.
|
||||||
|
|
||||||
#keep in mind ComfyUI is pre alpha software so this format will change a bit.
|
#keep in mind ComfyUI is pre alpha software so this format will change a bit.
|
||||||
|
|
||||||
#this is the one for the default workflow
|
#this is the one for the default workflow
|
||||||
|
|||||||
36
server.py
36
server.py
@ -8,7 +8,7 @@ import uuid
|
|||||||
import json
|
import json
|
||||||
import glob
|
import glob
|
||||||
import struct
|
import struct
|
||||||
from PIL import Image
|
from PIL import Image, ImageOps
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -29,6 +29,7 @@ import comfy.model_management
|
|||||||
|
|
||||||
class BinaryEventTypes:
|
class BinaryEventTypes:
|
||||||
PREVIEW_IMAGE = 1
|
PREVIEW_IMAGE = 1
|
||||||
|
UNENCODED_PREVIEW_IMAGE = 2
|
||||||
|
|
||||||
async def send_socket_catch_exception(function, message):
|
async def send_socket_catch_exception(function, message):
|
||||||
try:
|
try:
|
||||||
@ -344,6 +345,11 @@ class PromptServer():
|
|||||||
vram_total, torch_vram_total = comfy.model_management.get_total_memory(device, torch_total_too=True)
|
vram_total, torch_vram_total = comfy.model_management.get_total_memory(device, torch_total_too=True)
|
||||||
vram_free, torch_vram_free = comfy.model_management.get_free_memory(device, torch_free_too=True)
|
vram_free, torch_vram_free = comfy.model_management.get_free_memory(device, torch_free_too=True)
|
||||||
system_stats = {
|
system_stats = {
|
||||||
|
"system": {
|
||||||
|
"os": os.name,
|
||||||
|
"python_version": sys.version,
|
||||||
|
"embedded_python": os.path.split(os.path.split(sys.executable)[0])[1] == "python_embeded"
|
||||||
|
},
|
||||||
"devices": [
|
"devices": [
|
||||||
{
|
{
|
||||||
"name": device_name,
|
"name": device_name,
|
||||||
@ -498,7 +504,9 @@ class PromptServer():
|
|||||||
return prompt_info
|
return prompt_info
|
||||||
|
|
||||||
async def send(self, event, data, sid=None):
|
async def send(self, event, data, sid=None):
|
||||||
if isinstance(data, (bytes, bytearray)):
|
if event == BinaryEventTypes.UNENCODED_PREVIEW_IMAGE:
|
||||||
|
await self.send_image(data, sid=sid)
|
||||||
|
elif isinstance(data, (bytes, bytearray)):
|
||||||
await self.send_bytes(event, data, sid)
|
await self.send_bytes(event, data, sid)
|
||||||
else:
|
else:
|
||||||
await self.send_json(event, data, sid)
|
await self.send_json(event, data, sid)
|
||||||
@ -512,6 +520,30 @@ class PromptServer():
|
|||||||
message.extend(data)
|
message.extend(data)
|
||||||
return message
|
return message
|
||||||
|
|
||||||
|
async def send_image(self, image_data, sid=None):
|
||||||
|
image_type = image_data[0]
|
||||||
|
image = image_data[1]
|
||||||
|
max_size = image_data[2]
|
||||||
|
if max_size is not None:
|
||||||
|
if hasattr(Image, 'Resampling'):
|
||||||
|
resampling = Image.Resampling.BILINEAR
|
||||||
|
else:
|
||||||
|
resampling = Image.ANTIALIAS
|
||||||
|
|
||||||
|
image = ImageOps.contain(image, (max_size, max_size), resampling)
|
||||||
|
type_num = 1
|
||||||
|
if image_type == "JPEG":
|
||||||
|
type_num = 1
|
||||||
|
elif image_type == "PNG":
|
||||||
|
type_num = 2
|
||||||
|
|
||||||
|
bytesIO = BytesIO()
|
||||||
|
header = struct.pack(">I", type_num)
|
||||||
|
bytesIO.write(header)
|
||||||
|
image.save(bytesIO, format=image_type, quality=95, compress_level=4)
|
||||||
|
preview_bytes = bytesIO.getvalue()
|
||||||
|
await self.send_bytes(BinaryEventTypes.PREVIEW_IMAGE, preview_bytes, sid=sid)
|
||||||
|
|
||||||
async def send_bytes(self, event, data, sid=None):
|
async def send_bytes(self, event, data, sid=None):
|
||||||
message = self.encode_bytes(event, data)
|
message = self.encode_bytes(event, data)
|
||||||
|
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
import {app} from "/scripts/app.js";
|
import {app} from "../../scripts/app.js";
|
||||||
|
|
||||||
// Adds filtering to combo context menus
|
// Adds filtering to combo context menus
|
||||||
|
|
||||||
@ -27,10 +27,13 @@ const ext = {
|
|||||||
const clickedComboValue = currentNode.widgets
|
const clickedComboValue = currentNode.widgets
|
||||||
.filter(w => w.type === "combo" && w.options.values.length === values.length)
|
.filter(w => w.type === "combo" && w.options.values.length === values.length)
|
||||||
.find(w => w.options.values.every((v, i) => v === values[i]))
|
.find(w => w.options.values.every((v, i) => v === values[i]))
|
||||||
.value;
|
?.value;
|
||||||
|
|
||||||
let selectedIndex = values.findIndex(v => v === clickedComboValue);
|
let selectedIndex = clickedComboValue ? values.findIndex(v => v === clickedComboValue) : 0;
|
||||||
let selectedItem = displayedItems?.[selectedIndex];
|
if (selectedIndex < 0) {
|
||||||
|
selectedIndex = 0;
|
||||||
|
}
|
||||||
|
let selectedItem = displayedItems[selectedIndex];
|
||||||
updateSelected();
|
updateSelected();
|
||||||
|
|
||||||
// Apply highlighting to the selected item
|
// Apply highlighting to the selected item
|
||||||
|
|||||||
25
web/extensions/core/linkRenderMode.js
Normal file
25
web/extensions/core/linkRenderMode.js
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
import { app } from "../../scripts/app.js";
|
||||||
|
|
||||||
|
const id = "Comfy.LinkRenderMode";
|
||||||
|
const ext = {
|
||||||
|
name: id,
|
||||||
|
async setup(app) {
|
||||||
|
app.ui.settings.addSetting({
|
||||||
|
id,
|
||||||
|
name: "Link Render Mode",
|
||||||
|
defaultValue: 2,
|
||||||
|
type: "combo",
|
||||||
|
options: LiteGraph.LINK_RENDER_MODES.map((m, i) => ({
|
||||||
|
value: i,
|
||||||
|
text: m,
|
||||||
|
selected: i == app.canvas.links_render_mode,
|
||||||
|
})),
|
||||||
|
onChange(value) {
|
||||||
|
app.canvas.links_render_mode = +value;
|
||||||
|
app.graph.setDirtyCanvas(true);
|
||||||
|
},
|
||||||
|
});
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
app.registerExtension(ext);
|
||||||
@ -2,7 +2,7 @@ import { ComfyWidgets, addValueControlWidget } from "../../scripts/widgets.js";
|
|||||||
import { app } from "../../scripts/app.js";
|
import { app } from "../../scripts/app.js";
|
||||||
|
|
||||||
const CONVERTED_TYPE = "converted-widget";
|
const CONVERTED_TYPE = "converted-widget";
|
||||||
const VALID_TYPES = ["STRING", "combo", "number"];
|
const VALID_TYPES = ["STRING", "combo", "number", "BOOLEAN"];
|
||||||
|
|
||||||
function isConvertableWidget(widget, config) {
|
function isConvertableWidget(widget, config) {
|
||||||
return VALID_TYPES.includes(widget.type) || VALID_TYPES.includes(config[0]);
|
return VALID_TYPES.includes(widget.type) || VALID_TYPES.includes(config[0]);
|
||||||
|
|||||||
@ -9766,6 +9766,7 @@ LGraphNode.prototype.executeAction = function(action)
|
|||||||
|
|
||||||
switch (w.type) {
|
switch (w.type) {
|
||||||
case "button":
|
case "button":
|
||||||
|
ctx.fillStyle = background_color;
|
||||||
if (w.clicked) {
|
if (w.clicked) {
|
||||||
ctx.fillStyle = "#AAA";
|
ctx.fillStyle = "#AAA";
|
||||||
w.clicked = false;
|
w.clicked = false;
|
||||||
@ -9835,7 +9836,11 @@ LGraphNode.prototype.executeAction = function(action)
|
|||||||
ctx.textAlign = "center";
|
ctx.textAlign = "center";
|
||||||
ctx.fillStyle = text_color;
|
ctx.fillStyle = text_color;
|
||||||
ctx.fillText(
|
ctx.fillText(
|
||||||
w.label || w.name + " " + Number(w.value).toFixed(3),
|
w.label || w.name + " " + Number(w.value).toFixed(
|
||||||
|
w.options.precision != null
|
||||||
|
? w.options.precision
|
||||||
|
: 3
|
||||||
|
),
|
||||||
widget_width * 0.5,
|
widget_width * 0.5,
|
||||||
y + H * 0.7
|
y + H * 0.7
|
||||||
);
|
);
|
||||||
@ -13835,7 +13840,7 @@ LGraphNode.prototype.executeAction = function(action)
|
|||||||
if (!disabled) {
|
if (!disabled) {
|
||||||
element.addEventListener("click", inner_onclick);
|
element.addEventListener("click", inner_onclick);
|
||||||
}
|
}
|
||||||
if (options.autoopen) {
|
if (!disabled && options.autoopen) {
|
||||||
LiteGraph.pointerListenerAdd(element,"enter",inner_over);
|
LiteGraph.pointerListenerAdd(element,"enter",inner_over);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -264,6 +264,15 @@ class ComfyApi extends EventTarget {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Gets system & device stats
|
||||||
|
* @returns System stats such as python version, OS, per device info
|
||||||
|
*/
|
||||||
|
async getSystemStats() {
|
||||||
|
const res = await this.fetchApi("/system_stats");
|
||||||
|
return await res.json();
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Sends a POST request to the API
|
* Sends a POST request to the API
|
||||||
* @param {*} type The endpoint to post to
|
* @param {*} type The endpoint to post to
|
||||||
|
|||||||
@ -1,3 +1,4 @@
|
|||||||
|
import { ComfyLogging } from "./logging.js";
|
||||||
import { ComfyWidgets } from "./widgets.js";
|
import { ComfyWidgets } from "./widgets.js";
|
||||||
import { ComfyUI, $el } from "./ui.js";
|
import { ComfyUI, $el } from "./ui.js";
|
||||||
import { api } from "./api.js";
|
import { api } from "./api.js";
|
||||||
@ -31,6 +32,7 @@ export class ComfyApp {
|
|||||||
|
|
||||||
constructor() {
|
constructor() {
|
||||||
this.ui = new ComfyUI(this);
|
this.ui = new ComfyUI(this);
|
||||||
|
this.logging = new ComfyLogging(this);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* List of extensions that are registered with the app
|
* List of extensions that are registered with the app
|
||||||
@ -282,6 +284,11 @@ export class ComfyApp {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
options.push({
|
||||||
|
content: "Bypass",
|
||||||
|
callback: (obj) => { if (this.mode === 4) this.mode = 0; else this.mode = 4; this.graph.change(); }
|
||||||
|
});
|
||||||
|
|
||||||
// prevent conflict of clipspace content
|
// prevent conflict of clipspace content
|
||||||
if(!ComfyApp.clipspace_return_node) {
|
if(!ComfyApp.clipspace_return_node) {
|
||||||
options.push({
|
options.push({
|
||||||
@ -768,6 +775,19 @@ export class ComfyApp {
|
|||||||
}
|
}
|
||||||
block_default = true;
|
block_default = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (e.keyCode == 66 && e.ctrlKey) {
|
||||||
|
if (this.selected_nodes) {
|
||||||
|
for (var i in this.selected_nodes) {
|
||||||
|
if (this.selected_nodes[i].mode === 4) { // never
|
||||||
|
this.selected_nodes[i].mode = 0; // always
|
||||||
|
} else {
|
||||||
|
this.selected_nodes[i].mode = 4; // never
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
block_default = true;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
this.graph.change();
|
this.graph.change();
|
||||||
@ -914,14 +934,21 @@ export class ComfyApp {
|
|||||||
const origDrawNode = LGraphCanvas.prototype.drawNode;
|
const origDrawNode = LGraphCanvas.prototype.drawNode;
|
||||||
LGraphCanvas.prototype.drawNode = function (node, ctx) {
|
LGraphCanvas.prototype.drawNode = function (node, ctx) {
|
||||||
var editor_alpha = this.editor_alpha;
|
var editor_alpha = this.editor_alpha;
|
||||||
|
var old_color = node.bgcolor;
|
||||||
|
|
||||||
if (node.mode === 2) { // never
|
if (node.mode === 2) { // never
|
||||||
this.editor_alpha = 0.4;
|
this.editor_alpha = 0.4;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (node.mode === 4) { // never
|
||||||
|
node.bgcolor = "#FF00FF";
|
||||||
|
this.editor_alpha = 0.2;
|
||||||
|
}
|
||||||
|
|
||||||
const res = origDrawNode.apply(this, arguments);
|
const res = origDrawNode.apply(this, arguments);
|
||||||
|
|
||||||
this.editor_alpha = editor_alpha;
|
this.editor_alpha = editor_alpha;
|
||||||
|
node.bgcolor = old_color;
|
||||||
|
|
||||||
return res;
|
return res;
|
||||||
};
|
};
|
||||||
@ -1003,6 +1030,7 @@ export class ComfyApp {
|
|||||||
*/
|
*/
|
||||||
async #loadExtensions() {
|
async #loadExtensions() {
|
||||||
const extensions = await api.getExtensions();
|
const extensions = await api.getExtensions();
|
||||||
|
this.logging.addEntry("Comfy.App", "debug", { Extensions: extensions });
|
||||||
for (const ext of extensions) {
|
for (const ext of extensions) {
|
||||||
try {
|
try {
|
||||||
await import(api.apiURL(ext));
|
await import(api.apiURL(ext));
|
||||||
@ -1286,6 +1314,9 @@ export class ComfyApp {
|
|||||||
(t) => `<li>${t}</li>`
|
(t) => `<li>${t}</li>`
|
||||||
).join("")}</ul>Nodes that have failed to load will show as red on the graph.`
|
).join("")}</ul>Nodes that have failed to load will show as red on the graph.`
|
||||||
);
|
);
|
||||||
|
this.logging.addEntry("Comfy.App", "warn", {
|
||||||
|
MissingNodes: missingNodeTypes,
|
||||||
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1308,7 +1339,7 @@ export class ComfyApp {
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (node.mode === 2) {
|
if (node.mode === 2 || node.mode === 4) {
|
||||||
// Don't serialize muted nodes
|
// Don't serialize muted nodes
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
@ -1331,12 +1362,36 @@ export class ComfyApp {
|
|||||||
let parent = node.getInputNode(i);
|
let parent = node.getInputNode(i);
|
||||||
if (parent) {
|
if (parent) {
|
||||||
let link = node.getInputLink(i);
|
let link = node.getInputLink(i);
|
||||||
while (parent && parent.isVirtualNode) {
|
while (parent.mode === 4 || parent.isVirtualNode) {
|
||||||
link = parent.getInputLink(link.origin_slot);
|
let found = false;
|
||||||
if (link) {
|
if (parent.isVirtualNode) {
|
||||||
parent = parent.getInputNode(link.origin_slot);
|
link = parent.getInputLink(link.origin_slot);
|
||||||
} else {
|
if (link) {
|
||||||
parent = null;
|
parent = parent.getInputNode(link.target_slot);
|
||||||
|
if (parent) {
|
||||||
|
found = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if (link && parent.mode === 4) {
|
||||||
|
let all_inputs = [link.origin_slot];
|
||||||
|
if (parent.inputs) {
|
||||||
|
all_inputs = all_inputs.concat(Object.keys(parent.inputs))
|
||||||
|
for (let parent_input in all_inputs) {
|
||||||
|
parent_input = all_inputs[parent_input];
|
||||||
|
if (parent.inputs[parent_input].type === node.inputs[i].type) {
|
||||||
|
link = parent.getInputLink(parent_input);
|
||||||
|
if (link) {
|
||||||
|
parent = parent.getInputNode(parent_input);
|
||||||
|
}
|
||||||
|
found = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!found) {
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
367
web/scripts/logging.js
Normal file
367
web/scripts/logging.js
Normal file
@ -0,0 +1,367 @@
|
|||||||
|
import { $el, ComfyDialog } from "./ui.js";
|
||||||
|
import { api } from "./api.js";
|
||||||
|
|
||||||
|
$el("style", {
|
||||||
|
textContent: `
|
||||||
|
.comfy-logging-logs {
|
||||||
|
display: grid;
|
||||||
|
color: var(--fg-color);
|
||||||
|
white-space: pre-wrap;
|
||||||
|
}
|
||||||
|
.comfy-logging-log {
|
||||||
|
display: contents;
|
||||||
|
}
|
||||||
|
.comfy-logging-title {
|
||||||
|
background: var(--tr-even-bg-color);
|
||||||
|
font-weight: bold;
|
||||||
|
margin-bottom: 5px;
|
||||||
|
text-align: center;
|
||||||
|
}
|
||||||
|
.comfy-logging-log div {
|
||||||
|
background: var(--row-bg);
|
||||||
|
padding: 5px;
|
||||||
|
}
|
||||||
|
`,
|
||||||
|
parent: document.body,
|
||||||
|
});
|
||||||
|
|
||||||
|
// Stringify function supporting max depth and removal of circular references
|
||||||
|
// https://stackoverflow.com/a/57193345
|
||||||
|
function stringify(val, depth, replacer, space, onGetObjID) {
|
||||||
|
depth = isNaN(+depth) ? 1 : depth;
|
||||||
|
var recursMap = new WeakMap();
|
||||||
|
function _build(val, depth, o, a, r) {
|
||||||
|
// (JSON.stringify() has it's own rules, which we respect here by using it for property iteration)
|
||||||
|
return !val || typeof val != "object"
|
||||||
|
? val
|
||||||
|
: ((r = recursMap.has(val)),
|
||||||
|
recursMap.set(val, true),
|
||||||
|
(a = Array.isArray(val)),
|
||||||
|
r
|
||||||
|
? (o = (onGetObjID && onGetObjID(val)) || null)
|
||||||
|
: JSON.stringify(val, function (k, v) {
|
||||||
|
if (a || depth > 0) {
|
||||||
|
if (replacer) v = replacer(k, v);
|
||||||
|
if (!k) return (a = Array.isArray(v)), (val = v);
|
||||||
|
!o && (o = a ? [] : {});
|
||||||
|
o[k] = _build(v, a ? depth : depth - 1);
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
o === void 0 ? (a ? [] : {}) : o);
|
||||||
|
}
|
||||||
|
return JSON.stringify(_build(val, depth), null, space);
|
||||||
|
}
|
||||||
|
|
||||||
|
const jsonReplacer = (k, v, ui) => {
|
||||||
|
if (v instanceof Array && v.length === 1) {
|
||||||
|
v = v[0];
|
||||||
|
}
|
||||||
|
if (v instanceof Date) {
|
||||||
|
v = v.toISOString();
|
||||||
|
if (ui) {
|
||||||
|
v = v.split("T")[1];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (v instanceof Error) {
|
||||||
|
let err = "";
|
||||||
|
if (v.name) err += v.name + "\n";
|
||||||
|
if (v.message) err += v.message + "\n";
|
||||||
|
if (v.stack) err += v.stack + "\n";
|
||||||
|
if (!err) {
|
||||||
|
err = v.toString();
|
||||||
|
}
|
||||||
|
v = err;
|
||||||
|
}
|
||||||
|
return v;
|
||||||
|
};
|
||||||
|
|
||||||
|
const fileInput = $el("input", {
|
||||||
|
type: "file",
|
||||||
|
accept: ".json",
|
||||||
|
style: { display: "none" },
|
||||||
|
parent: document.body,
|
||||||
|
});
|
||||||
|
|
||||||
|
class ComfyLoggingDialog extends ComfyDialog {
|
||||||
|
constructor(logging) {
|
||||||
|
super();
|
||||||
|
this.logging = logging;
|
||||||
|
}
|
||||||
|
|
||||||
|
clear() {
|
||||||
|
this.logging.clear();
|
||||||
|
this.show();
|
||||||
|
}
|
||||||
|
|
||||||
|
export() {
|
||||||
|
const blob = new Blob([stringify([...this.logging.entries], 20, jsonReplacer, "\t")], {
|
||||||
|
type: "application/json",
|
||||||
|
});
|
||||||
|
const url = URL.createObjectURL(blob);
|
||||||
|
const a = $el("a", {
|
||||||
|
href: url,
|
||||||
|
download: `comfyui-logs-${Date.now()}.json`,
|
||||||
|
style: { display: "none" },
|
||||||
|
parent: document.body,
|
||||||
|
});
|
||||||
|
a.click();
|
||||||
|
setTimeout(function () {
|
||||||
|
a.remove();
|
||||||
|
window.URL.revokeObjectURL(url);
|
||||||
|
}, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
import() {
|
||||||
|
fileInput.onchange = () => {
|
||||||
|
const reader = new FileReader();
|
||||||
|
reader.onload = () => {
|
||||||
|
fileInput.remove();
|
||||||
|
try {
|
||||||
|
const obj = JSON.parse(reader.result);
|
||||||
|
if (obj instanceof Array) {
|
||||||
|
this.show(obj);
|
||||||
|
} else {
|
||||||
|
throw new Error("Invalid file selected.");
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
alert("Unable to load logs: " + error.message);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
reader.readAsText(fileInput.files[0]);
|
||||||
|
};
|
||||||
|
fileInput.click();
|
||||||
|
}
|
||||||
|
|
||||||
|
createButtons() {
|
||||||
|
return [
|
||||||
|
$el("button", {
|
||||||
|
type: "button",
|
||||||
|
textContent: "Clear",
|
||||||
|
onclick: () => this.clear(),
|
||||||
|
}),
|
||||||
|
$el("button", {
|
||||||
|
type: "button",
|
||||||
|
textContent: "Export logs...",
|
||||||
|
onclick: () => this.export(),
|
||||||
|
}),
|
||||||
|
$el("button", {
|
||||||
|
type: "button",
|
||||||
|
textContent: "View exported logs...",
|
||||||
|
onclick: () => this.import(),
|
||||||
|
}),
|
||||||
|
...super.createButtons(),
|
||||||
|
];
|
||||||
|
}
|
||||||
|
|
||||||
|
getTypeColor(type) {
|
||||||
|
switch (type) {
|
||||||
|
case "error":
|
||||||
|
return "red";
|
||||||
|
case "warn":
|
||||||
|
return "orange";
|
||||||
|
case "debug":
|
||||||
|
return "dodgerblue";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
show(entries) {
|
||||||
|
if (!entries) entries = this.logging.entries;
|
||||||
|
this.element.style.width = "100%";
|
||||||
|
const cols = {
|
||||||
|
source: "Source",
|
||||||
|
type: "Type",
|
||||||
|
timestamp: "Timestamp",
|
||||||
|
message: "Message",
|
||||||
|
};
|
||||||
|
const keys = Object.keys(cols);
|
||||||
|
const headers = Object.values(cols).map((title) =>
|
||||||
|
$el("div.comfy-logging-title", {
|
||||||
|
textContent: title,
|
||||||
|
})
|
||||||
|
);
|
||||||
|
const rows = entries.map((entry, i) => {
|
||||||
|
return $el(
|
||||||
|
"div.comfy-logging-log",
|
||||||
|
{
|
||||||
|
$: (el) => el.style.setProperty("--row-bg", `var(--tr-${i % 2 ? "even" : "odd"}-bg-color)`),
|
||||||
|
},
|
||||||
|
keys.map((key) => {
|
||||||
|
let v = entry[key];
|
||||||
|
let color;
|
||||||
|
if (key === "type") {
|
||||||
|
color = this.getTypeColor(v);
|
||||||
|
} else {
|
||||||
|
v = jsonReplacer(key, v, true);
|
||||||
|
|
||||||
|
if (typeof v === "object") {
|
||||||
|
v = stringify(v, 5, jsonReplacer, " ");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return $el("div", {
|
||||||
|
style: {
|
||||||
|
color,
|
||||||
|
},
|
||||||
|
textContent: v,
|
||||||
|
});
|
||||||
|
})
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
const grid = $el(
|
||||||
|
"div.comfy-logging-logs",
|
||||||
|
{
|
||||||
|
style: {
|
||||||
|
gridTemplateColumns: `repeat(${headers.length}, 1fr)`,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
[...headers, ...rows]
|
||||||
|
);
|
||||||
|
const els = [grid];
|
||||||
|
if (!this.logging.enabled) {
|
||||||
|
els.unshift(
|
||||||
|
$el("h3", {
|
||||||
|
style: { textAlign: "center" },
|
||||||
|
textContent: "Logging is disabled",
|
||||||
|
})
|
||||||
|
);
|
||||||
|
}
|
||||||
|
super.show($el("div", els));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
export class ComfyLogging {
|
||||||
|
/**
|
||||||
|
* @type Array<{ source: string, type: string, timestamp: Date, message: any }>
|
||||||
|
*/
|
||||||
|
entries = [];
|
||||||
|
|
||||||
|
#enabled;
|
||||||
|
#console = {};
|
||||||
|
|
||||||
|
get enabled() {
|
||||||
|
return this.#enabled;
|
||||||
|
}
|
||||||
|
|
||||||
|
set enabled(value) {
|
||||||
|
if (value === this.#enabled) return;
|
||||||
|
if (value) {
|
||||||
|
this.patchConsole();
|
||||||
|
} else {
|
||||||
|
this.unpatchConsole();
|
||||||
|
}
|
||||||
|
this.#enabled = value;
|
||||||
|
}
|
||||||
|
|
||||||
|
constructor(app) {
|
||||||
|
this.app = app;
|
||||||
|
|
||||||
|
this.dialog = new ComfyLoggingDialog(this);
|
||||||
|
this.addSetting();
|
||||||
|
this.catchUnhandled();
|
||||||
|
this.addInitData();
|
||||||
|
}
|
||||||
|
|
||||||
|
addSetting() {
|
||||||
|
const settingId = "Comfy.Logging.Enabled";
|
||||||
|
const htmlSettingId = settingId.replaceAll(".", "-");
|
||||||
|
const setting = this.app.ui.settings.addSetting({
|
||||||
|
id: settingId,
|
||||||
|
name: settingId,
|
||||||
|
defaultValue: true,
|
||||||
|
type: (name, setter, value) => {
|
||||||
|
return $el("tr", [
|
||||||
|
$el("td", [
|
||||||
|
$el("label", {
|
||||||
|
textContent: "Logging",
|
||||||
|
for: htmlSettingId,
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
$el("td", [
|
||||||
|
$el("input", {
|
||||||
|
id: htmlSettingId,
|
||||||
|
type: "checkbox",
|
||||||
|
checked: value,
|
||||||
|
onchange: (event) => {
|
||||||
|
setter((this.enabled = event.target.checked));
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
$el("button", {
|
||||||
|
textContent: "View Logs",
|
||||||
|
onclick: () => {
|
||||||
|
this.app.ui.settings.element.close();
|
||||||
|
this.dialog.show();
|
||||||
|
},
|
||||||
|
style: {
|
||||||
|
fontSize: "14px",
|
||||||
|
display: "block",
|
||||||
|
marginTop: "5px",
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
]);
|
||||||
|
},
|
||||||
|
});
|
||||||
|
this.enabled = setting.value;
|
||||||
|
}
|
||||||
|
|
||||||
|
patchConsole() {
|
||||||
|
// Capture common console outputs
|
||||||
|
const self = this;
|
||||||
|
for (const type of ["log", "warn", "error", "debug"]) {
|
||||||
|
const orig = console[type];
|
||||||
|
this.#console[type] = orig;
|
||||||
|
console[type] = function () {
|
||||||
|
orig.apply(console, arguments);
|
||||||
|
self.addEntry("console", type, ...arguments);
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
unpatchConsole() {
|
||||||
|
// Restore original console functions
|
||||||
|
for (const type of Object.keys(this.#console)) {
|
||||||
|
console[type] = this.#console[type];
|
||||||
|
}
|
||||||
|
this.#console = {};
|
||||||
|
}
|
||||||
|
|
||||||
|
catchUnhandled() {
|
||||||
|
// Capture uncaught errors
|
||||||
|
window.addEventListener("error", (e) => {
|
||||||
|
this.addEntry("window", "error", e.error ?? "Unknown error");
|
||||||
|
return false;
|
||||||
|
});
|
||||||
|
|
||||||
|
window.addEventListener("unhandledrejection", (e) => {
|
||||||
|
this.addEntry("unhandledrejection", "error", e.reason ?? "Unknown error");
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
clear() {
|
||||||
|
this.entries = [];
|
||||||
|
}
|
||||||
|
|
||||||
|
addEntry(source, type, ...args) {
|
||||||
|
if (this.enabled) {
|
||||||
|
this.entries.push({
|
||||||
|
source,
|
||||||
|
type,
|
||||||
|
timestamp: new Date(),
|
||||||
|
message: args,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
log(source, ...args) {
|
||||||
|
this.addEntry(source, "log", ...args);
|
||||||
|
}
|
||||||
|
|
||||||
|
async addInitData() {
|
||||||
|
if (!this.enabled) return;
|
||||||
|
const source = "ComfyUI.Logging";
|
||||||
|
this.addEntry(source, "debug", { UserAgent: navigator.userAgent });
|
||||||
|
const systemStats = await api.getSystemStats();
|
||||||
|
this.addEntry(source, "debug", systemStats);
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -234,7 +234,7 @@ class ComfySettingsDialog extends ComfyDialog {
|
|||||||
localStorage[settingId] = JSON.stringify(value);
|
localStorage[settingId] = JSON.stringify(value);
|
||||||
}
|
}
|
||||||
|
|
||||||
addSetting({id, name, type, defaultValue, onChange, attrs = {}, tooltip = "",}) {
|
addSetting({id, name, type, defaultValue, onChange, attrs = {}, tooltip = "", options = undefined}) {
|
||||||
if (!id) {
|
if (!id) {
|
||||||
throw new Error("Settings must have an ID");
|
throw new Error("Settings must have an ID");
|
||||||
}
|
}
|
||||||
@ -347,6 +347,32 @@ class ComfySettingsDialog extends ComfyDialog {
|
|||||||
]),
|
]),
|
||||||
]);
|
]);
|
||||||
break;
|
break;
|
||||||
|
case "combo":
|
||||||
|
element = $el("tr", [
|
||||||
|
labelCell,
|
||||||
|
$el("td", [
|
||||||
|
$el(
|
||||||
|
"select",
|
||||||
|
{
|
||||||
|
oninput: (e) => {
|
||||||
|
setter(e.target.value);
|
||||||
|
},
|
||||||
|
},
|
||||||
|
(typeof options === "function" ? options(value) : options || []).map((opt) => {
|
||||||
|
if (typeof opt === "string") {
|
||||||
|
opt = { text: opt };
|
||||||
|
}
|
||||||
|
const v = opt.value ?? opt.text;
|
||||||
|
return $el("option", {
|
||||||
|
value: v,
|
||||||
|
textContent: opt.text,
|
||||||
|
selected: value + "" === v + "",
|
||||||
|
});
|
||||||
|
})
|
||||||
|
),
|
||||||
|
]),
|
||||||
|
]);
|
||||||
|
break;
|
||||||
case "text":
|
case "text":
|
||||||
default:
|
default:
|
||||||
if (type !== "text") {
|
if (type !== "text") {
|
||||||
@ -480,7 +506,7 @@ class ComfyList {
|
|||||||
|
|
||||||
hide() {
|
hide() {
|
||||||
this.element.style.display = "none";
|
this.element.style.display = "none";
|
||||||
this.button.textContent = "See " + this.#text;
|
this.button.textContent = "View " + this.#text;
|
||||||
}
|
}
|
||||||
|
|
||||||
toggle() {
|
toggle() {
|
||||||
@ -542,6 +568,13 @@ export class ComfyUI {
|
|||||||
defaultValue: "",
|
defaultValue: "",
|
||||||
});
|
});
|
||||||
|
|
||||||
|
this.settings.addSetting({
|
||||||
|
id: "Comfy.DisableSliders",
|
||||||
|
name: "Disable sliders.",
|
||||||
|
type: "boolean",
|
||||||
|
defaultValue: false,
|
||||||
|
});
|
||||||
|
|
||||||
const fileInput = $el("input", {
|
const fileInput = $el("input", {
|
||||||
id: "comfy-file-input",
|
id: "comfy-file-input",
|
||||||
type: "file",
|
type: "file",
|
||||||
|
|||||||
@ -79,8 +79,8 @@ export function addValueControlWidget(node, targetWidget, defaultValue = "random
|
|||||||
return valueControl;
|
return valueControl;
|
||||||
};
|
};
|
||||||
|
|
||||||
function seedWidget(node, inputName, inputData) {
|
function seedWidget(node, inputName, inputData, app) {
|
||||||
const seed = ComfyWidgets.INT(node, inputName, inputData);
|
const seed = ComfyWidgets.INT(node, inputName, inputData, app);
|
||||||
const seedControl = addValueControlWidget(node, seed.widget, "randomize");
|
const seedControl = addValueControlWidget(node, seed.widget, "randomize");
|
||||||
|
|
||||||
seed.widget.linkedWidgets = [seedControl];
|
seed.widget.linkedWidgets = [seedControl];
|
||||||
@ -250,19 +250,29 @@ function addMultilineWidget(node, name, opts, app) {
|
|||||||
return { minWidth: 400, minHeight: 200, widget };
|
return { minWidth: 400, minHeight: 200, widget };
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function isSlider(display, app) {
|
||||||
|
if (app.ui.settings.getSettingValue("Comfy.DisableSliders")) {
|
||||||
|
return "number"
|
||||||
|
}
|
||||||
|
|
||||||
|
return (display==="slider") ? "slider" : "number"
|
||||||
|
}
|
||||||
|
|
||||||
export const ComfyWidgets = {
|
export const ComfyWidgets = {
|
||||||
"INT:seed": seedWidget,
|
"INT:seed": seedWidget,
|
||||||
"INT:noise_seed": seedWidget,
|
"INT:noise_seed": seedWidget,
|
||||||
FLOAT(node, inputName, inputData) {
|
FLOAT(node, inputName, inputData, app) {
|
||||||
|
let widgetType = isSlider(inputData[1]["display"], app);
|
||||||
const { val, config } = getNumberDefaults(inputData, 0.5);
|
const { val, config } = getNumberDefaults(inputData, 0.5);
|
||||||
return { widget: node.addWidget("number", inputName, val, () => {}, config) };
|
return { widget: node.addWidget(widgetType, inputName, val, () => {}, config) };
|
||||||
},
|
},
|
||||||
INT(node, inputName, inputData) {
|
INT(node, inputName, inputData, app) {
|
||||||
|
let widgetType = isSlider(inputData[1]["display"], app);
|
||||||
const { val, config } = getNumberDefaults(inputData, 1);
|
const { val, config } = getNumberDefaults(inputData, 1);
|
||||||
Object.assign(config, { precision: 0 });
|
Object.assign(config, { precision: 0 });
|
||||||
return {
|
return {
|
||||||
widget: node.addWidget(
|
widget: node.addWidget(
|
||||||
"number",
|
widgetType,
|
||||||
inputName,
|
inputName,
|
||||||
val,
|
val,
|
||||||
function (v) {
|
function (v) {
|
||||||
@ -273,6 +283,18 @@ export const ComfyWidgets = {
|
|||||||
),
|
),
|
||||||
};
|
};
|
||||||
},
|
},
|
||||||
|
BOOLEAN(node, inputName, inputData) {
|
||||||
|
let defaultVal = inputData[1]["default"];
|
||||||
|
return {
|
||||||
|
widget: node.addWidget(
|
||||||
|
"toggle",
|
||||||
|
inputName,
|
||||||
|
defaultVal,
|
||||||
|
() => {},
|
||||||
|
{"on": inputData[1].label_on, "off": inputData[1].label_off}
|
||||||
|
)
|
||||||
|
};
|
||||||
|
},
|
||||||
STRING(node, inputName, inputData, app) {
|
STRING(node, inputName, inputData, app) {
|
||||||
const defaultVal = inputData[1].default || "";
|
const defaultVal = inputData[1].default || "";
|
||||||
const multiline = !!inputData[1].multiline;
|
const multiline = !!inputData[1].multiline;
|
||||||
@ -411,7 +433,7 @@ export const ComfyWidgets = {
|
|||||||
// Add handler to check if an image is being dragged over our node
|
// Add handler to check if an image is being dragged over our node
|
||||||
node.onDragOver = function (e) {
|
node.onDragOver = function (e) {
|
||||||
if (e.dataTransfer && e.dataTransfer.items) {
|
if (e.dataTransfer && e.dataTransfer.items) {
|
||||||
const image = [...e.dataTransfer.items].find((f) => f.kind === "file" && f.type.startsWith("image/"));
|
const image = [...e.dataTransfer.items].find((f) => f.kind === "file");
|
||||||
return !!image;
|
return !!image;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
4
web/types/comfy.d.ts
vendored
4
web/types/comfy.d.ts
vendored
@ -30,9 +30,7 @@ export interface ComfyExtension {
|
|||||||
getCustomWidgets(
|
getCustomWidgets(
|
||||||
app: ComfyApp
|
app: ComfyApp
|
||||||
): Promise<
|
): Promise<
|
||||||
Array<
|
Record<string, (node, inputName, inputData, app) => { widget?: IWidget; minWidth?: number; minHeight?: number }>
|
||||||
Record<string, (node, inputName, inputData, app) => { widget?: IWidget; minWidth?: number; minHeight?: number }>
|
|
||||||
>
|
|
||||||
>;
|
>;
|
||||||
/**
|
/**
|
||||||
* Allows the extension to add additional handling to the node before it is registered with LGraph
|
* Allows the extension to add additional handling to the node before it is registered with LGraph
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user