Mirror of https://github.com/comfyanonymous/ComfyUI.git

Commit e11af79bcc: Merge branch 'comfyanonymous:master' into master

@@ -94,7 +94,7 @@ AMD users can install rocm and pytorch with pip if you don't have it already installed
 ```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/rocm5.4.2```

 This is the command to install the nightly with ROCm 5.6 that supports the 7000 series and might have some performance improvements:
-```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm5.6 -r requirements.txt```
+```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm5.6```

 ### NVIDIA

@@ -126,10 +126,10 @@ After this you should have everything installed and can proceed to running ComfyUI

 You can install ComfyUI in Apple Mac silicon (M1 or M2) with any recent macOS version.

-1. Install pytorch. For instructions, read the [Accelerated PyTorch training on Mac](https://developer.apple.com/metal/pytorch/) Apple Developer guide.
+1. Install pytorch nightly. For instructions, read the [Accelerated PyTorch training on Mac](https://developer.apple.com/metal/pytorch/) Apple Developer guide (make sure to install the latest pytorch nightly).
 1. Follow the [ComfyUI manual installation](#manual-install-windows-linux) instructions for Windows and Linux.
 1. Install the ComfyUI [dependencies](#dependencies). If you have another Stable Diffusion UI [you might be able to reuse the dependencies](#i-already-have-another-ui-for-stable-diffusion-installed-do-i-really-have-to-install-all-of-these-dependencies).
-1. Launch ComfyUI by running `python main.py`.
+1. Launch ComfyUI by running `python main.py --force-fp16`. Note that --force-fp16 will only work if you installed the latest pytorch nightly.

 > **Note**: Remember to add your models, VAE, LoRAs etc. to the corresponding Comfy folders, as discussed in [ComfyUI manual installation](#manual-install-windows-linux).

@@ -13,7 +13,7 @@ from ..ldm.modules.diffusionmodules.util import (
 )

 from ..ldm.modules.attention import SpatialTransformer
-from ..ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample, AttentionBlock
+from ..ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample
 from ..ldm.util import exists


@@ -57,6 +57,7 @@ class ControlNet(nn.Module):
         transformer_depth_middle=None,
     ):
         super().__init__()
+        assert use_spatial_transformer == True, "use_spatial_transformer has to be true"
         if use_spatial_transformer:
             assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'

@@ -200,13 +201,7 @@ class ControlNet(nn.Module):

                 if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
                     layers.append(
-                        AttentionBlock(
-                            ch,
-                            use_checkpoint=use_checkpoint,
-                            num_heads=num_heads,
-                            num_head_channels=dim_head,
-                            use_new_attention_order=use_new_attention_order,
-                        ) if not use_spatial_transformer else SpatialTransformer(
+                        SpatialTransformer(
                             ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim,
                             disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
                             use_checkpoint=use_checkpoint

@@ -259,13 +254,7 @@ class ControlNet(nn.Module):
                 use_checkpoint=use_checkpoint,
                 use_scale_shift_norm=use_scale_shift_norm,
             ),
-            AttentionBlock(
-                ch,
-                use_checkpoint=use_checkpoint,
-                num_heads=num_heads,
-                num_head_channels=dim_head,
-                use_new_attention_order=use_new_attention_order,
-            ) if not use_spatial_transformer else SpatialTransformer(  # always uses a self-attn
+            SpatialTransformer(  # always uses a self-attn
                 ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
                 disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
                 use_checkpoint=use_checkpoint

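The two ControlNet hunks above pair with the new `assert use_spatial_transformer == True` added to `__init__`: once that flag is guaranteed true, the old `AttentionBlock(...) if not use_spatial_transformer else SpatialTransformer(...)` conditional can only ever take its `else` arm, so the diff collapses it. A minimal sketch of the pattern with illustrative names, not ComfyUI code:

```python
def build_attention_old(use_spatial_transformer, make_attn_block, make_spatial_transformer):
    # Old form: the AttentionBlock arm is kept alive only by the conditional.
    return make_attn_block() if not use_spatial_transformer else make_spatial_transformer()

def build_attention_new(make_spatial_transformer):
    # New form: an assert upstream guarantees use_spatial_transformer is True,
    # so only the SpatialTransformer arm can ever run and the dead branch is dropped.
    return make_spatial_transformer()
```
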
@@ -39,6 +39,7 @@ parser.add_argument("--enable-cors-header", type=str, default=None, metavar="ORIGIN"
 parser.add_argument("--extra-model-paths-config", type=str, default=None, metavar="PATH", nargs='+', action='append', help="Load one or more extra_model_paths.yaml files.")
 parser.add_argument("--output-directory", type=str, default=None, help="Set the ComfyUI output directory.")
 parser.add_argument("--auto-launch", action="store_true", help="Automatically launch ComfyUI in the default browser.")
+parser.add_argument("--disable-auto-launch", action="store_true", help="Disable auto launching the browser.")
 parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use.")
 cm_group = parser.add_mutually_exclusive_group()
 cm_group.add_argument("--cuda-malloc", action="store_true", help="Enable cudaMallocAsync (enabled by default for torch 2.0 and up).")

@@ -84,7 +85,12 @@ parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
 parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")
 parser.add_argument("--windows-standalone-build", action="store_true", help="Windows standalone build: Enable convenient things that most people using the standalone windows build will probably enjoy (like auto opening the page on startup).")

+parser.add_argument("--disable-metadata", action="store_true", help="Disable saving prompt metadata in files.")
+
 args = parser.parse_args()

 if args.windows_standalone_build:
     args.auto_launch = True
+
+if args.disable_auto_launch:
+    args.auto_launch = False

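Because `args.disable_auto_launch` is checked after the `--windows-standalone-build` handling, the new flag wins even when the standalone build would otherwise force the browser open. A self-contained sketch of that ordering, using standard `argparse` with a stripped-down set of flags:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--auto-launch", action="store_true")
parser.add_argument("--disable-auto-launch", action="store_true")
parser.add_argument("--windows-standalone-build", action="store_true")

args = parser.parse_args(["--windows-standalone-build", "--disable-auto-launch"])

if args.windows_standalone_build:
    args.auto_launch = True   # standalone build defaults to auto-launch
if args.disable_auto_launch:
    args.auto_launch = False  # applied second, so it always wins

print(args.auto_launch)  # False
```
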
@@ -3,7 +3,6 @@ import math
 from scipy import integrate
 import torch
 from torch import nn
-from torchdiffeq import odeint
 import torchsde
 from tqdm.auto import trange, tqdm

@@ -287,30 +286,6 @@ def sample_lms(model, x, sigmas, extra_args=None, callback=None, disable=None, order=4):
     return x


-@torch.no_grad()
-def log_likelihood(model, x, sigma_min, sigma_max, extra_args=None, atol=1e-4, rtol=1e-4):
-    extra_args = {} if extra_args is None else extra_args
-    s_in = x.new_ones([x.shape[0]])
-    v = torch.randint_like(x, 2) * 2 - 1
-    fevals = 0
-    def ode_fn(sigma, x):
-        nonlocal fevals
-        with torch.enable_grad():
-            x = x[0].detach().requires_grad_()
-            denoised = model(x, sigma * s_in, **extra_args)
-            d = to_d(x, sigma, denoised)
-            fevals += 1
-            grad = torch.autograd.grad((d * v).sum(), x)[0]
-            d_ll = (v * grad).flatten(1).sum(1)
-        return d.detach(), d_ll
-    x_min = x, x.new_zeros([x.shape[0]])
-    t = x.new_tensor([sigma_min, sigma_max])
-    sol = odeint(ode_fn, x_min, t, atol=atol, rtol=rtol, method='dopri5')
-    latent, delta_ll = sol[0][-1], sol[1][-1]
-    ll_prior = torch.distributions.Normal(0, sigma_max).log_prob(latent).flatten(1).sum(1)
-    return ll_prior + delta_ll, {'fevals': fevals}
-
-
 class PIDStepSizeController:
     """A PID controller for ODE adaptive step size control."""
     def __init__(self, h, pcoeff, icoeff, dcoeff, order=1, accept_safety=0.81, eps=1e-8):

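The deleted `log_likelihood` helper was the only caller of `torchdiffeq.odeint`, which is why the import above can go too. For reference, it estimated the log-likelihood by integrating the probability-flow ODE with `dopri5`, using a Hutchinson trace estimate with a Rademacher vector `v` (the `torch.randint_like(x, 2) * 2 - 1` line), roughly:

```latex
\log p\big(x_{\sigma_{\min}}\big)
  = \log \mathcal{N}\big(x_{\sigma_{\max}};\, 0,\, \sigma_{\max}^2 I\big)
  + \int_{\sigma_{\min}}^{\sigma_{\max}} \nabla_x \cdot d(x_\sigma, \sigma)\, \mathrm{d}\sigma,
\qquad
\nabla_x \cdot d \;\approx\; v^\top \frac{\partial d}{\partial x}\, v,
\quad v \sim \{\pm 1\}^n.
```
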
@@ -52,9 +52,9 @@ def init_(tensor):

 # feedforward
 class GEGLU(nn.Module):
-    def __init__(self, dim_in, dim_out, dtype=None):
+    def __init__(self, dim_in, dim_out, dtype=None, device=None):
         super().__init__()
-        self.proj = comfy.ops.Linear(dim_in, dim_out * 2, dtype=dtype)
+        self.proj = comfy.ops.Linear(dim_in, dim_out * 2, dtype=dtype, device=device)

     def forward(self, x):
         x, gate = self.proj(x).chunk(2, dim=-1)

@@ -62,19 +62,19 @@ class GEGLU(nn.Module):


 class FeedForward(nn.Module):
-    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0., dtype=None):
+    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0., dtype=None, device=None):
         super().__init__()
         inner_dim = int(dim * mult)
         dim_out = default(dim_out, dim)
         project_in = nn.Sequential(
-            comfy.ops.Linear(dim, inner_dim, dtype=dtype),
+            comfy.ops.Linear(dim, inner_dim, dtype=dtype, device=device),
             nn.GELU()
-        ) if not glu else GEGLU(dim, inner_dim, dtype=dtype)
+        ) if not glu else GEGLU(dim, inner_dim, dtype=dtype, device=device)

         self.net = nn.Sequential(
             project_in,
             nn.Dropout(dropout),
-            comfy.ops.Linear(inner_dim, dim_out, dtype=dtype)
+            comfy.ops.Linear(inner_dim, dim_out, dtype=dtype, device=device)
         )

     def forward(self, x):

@@ -90,8 +90,8 @@ def zero_module(module):
     return module


-def Normalize(in_channels, dtype=None):
-    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype)
+def Normalize(in_channels, dtype=None, device=None):
+    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)


 class SpatialSelfAttention(nn.Module):

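Every `device=None` parameter threaded through these attention-module hunks has the same motivation: `torch.nn` modules accept a `device` keyword (PyTorch 1.9+), and constructing weights directly on the target device avoids a CPU allocation followed by a `.to(device)` copy. A minimal sketch of the difference:

```python
import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"

# Old pattern: parameters are initialized on the CPU, then copied over.
layer_moved = nn.Linear(512, 512).to(device)

# New pattern: parameters are allocated on the target device from the start,
# skipping the intermediate CPU tensor entirely.
layer_direct = nn.Linear(512, 512, device=device)
```
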
@@ -148,7 +148,7 @@ class SpatialSelfAttention(nn.Module):


 class CrossAttentionBirchSan(nn.Module):
-    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None):
+    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None):
         super().__init__()
         inner_dim = dim_head * heads
         context_dim = default(context_dim, query_dim)

@@ -156,12 +156,12 @@ class CrossAttentionBirchSan(nn.Module):
         self.scale = dim_head ** -0.5
         self.heads = heads

-        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype)
-        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
-        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
+        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)

         self.to_out = nn.Sequential(
-            comfy.ops.Linear(inner_dim, query_dim, dtype=dtype),
+            comfy.ops.Linear(inner_dim, query_dim, dtype=dtype, device=device),
             nn.Dropout(dropout)
         )

@@ -245,7 +245,7 @@ class CrossAttentionBirchSan(nn.Module):


 class CrossAttentionDoggettx(nn.Module):
-    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None):
+    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None):
         super().__init__()
         inner_dim = dim_head * heads
         context_dim = default(context_dim, query_dim)

@@ -253,12 +253,12 @@ class CrossAttentionDoggettx(nn.Module):
         self.scale = dim_head ** -0.5
         self.heads = heads

-        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype)
-        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
-        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
+        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)

         self.to_out = nn.Sequential(
-            comfy.ops.Linear(inner_dim, query_dim, dtype=dtype),
+            comfy.ops.Linear(inner_dim, query_dim, dtype=dtype, device=device),
             nn.Dropout(dropout)
         )

@@ -343,7 +343,7 @@ class CrossAttentionDoggettx(nn.Module):
         return self.to_out(r2)

 class CrossAttention(nn.Module):
-    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None):
+    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None):
         super().__init__()
         inner_dim = dim_head * heads
         context_dim = default(context_dim, query_dim)

@@ -351,12 +351,12 @@ class CrossAttention(nn.Module):
         self.scale = dim_head ** -0.5
         self.heads = heads

-        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype)
-        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
-        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
+        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)

         self.to_out = nn.Sequential(
-            comfy.ops.Linear(inner_dim, query_dim, dtype=dtype),
+            comfy.ops.Linear(inner_dim, query_dim, dtype=dtype, device=device),
             nn.Dropout(dropout)
         )

@@ -399,7 +399,7 @@ class CrossAttention(nn.Module):

 class MemoryEfficientCrossAttention(nn.Module):
     # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
-    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, dtype=None):
+    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, dtype=None, device=None):
         super().__init__()
         print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
               f"{heads} heads.")

@@ -409,11 +409,11 @@ class MemoryEfficientCrossAttention(nn.Module):
         self.heads = heads
         self.dim_head = dim_head

-        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype)
-        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
-        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
+        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)

-        self.to_out = nn.Sequential(comfy.ops.Linear(inner_dim, query_dim, dtype=dtype), nn.Dropout(dropout))
+        self.to_out = nn.Sequential(comfy.ops.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
         self.attention_op: Optional[Any] = None

     def forward(self, x, context=None, value=None, mask=None):

@@ -450,7 +450,7 @@ class MemoryEfficientCrossAttention(nn.Module):
         return self.to_out(out)

 class CrossAttentionPytorch(nn.Module):
-    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None):
+    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None):
         super().__init__()
         inner_dim = dim_head * heads
         context_dim = default(context_dim, query_dim)

@@ -458,11 +458,11 @@ class CrossAttentionPytorch(nn.Module):
         self.heads = heads
         self.dim_head = dim_head

-        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype)
-        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
-        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
+        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)

-        self.to_out = nn.Sequential(comfy.ops.Linear(inner_dim, query_dim, dtype=dtype), nn.Dropout(dropout))
+        self.to_out = nn.Sequential(comfy.ops.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
         self.attention_op: Optional[Any] = None

     def forward(self, x, context=None, value=None, mask=None):

@@ -508,17 +508,17 @@ else:

 class BasicTransformerBlock(nn.Module):
     def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
-                 disable_self_attn=False, dtype=None):
+                 disable_self_attn=False, dtype=None, device=None):
         super().__init__()
         self.disable_self_attn = disable_self_attn
         self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
-                              context_dim=context_dim if self.disable_self_attn else None, dtype=dtype)  # is a self-attention if not self.disable_self_attn
+                              context_dim=context_dim if self.disable_self_attn else None, dtype=dtype, device=device)  # is a self-attention if not self.disable_self_attn
-        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff, dtype=dtype)
+        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device)
         self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
-                              heads=n_heads, dim_head=d_head, dropout=dropout, dtype=dtype)  # is self-attn if context is none
+                              heads=n_heads, dim_head=d_head, dropout=dropout, dtype=dtype, device=device)  # is self-attn if context is none
-        self.norm1 = nn.LayerNorm(dim, dtype=dtype)
+        self.norm1 = nn.LayerNorm(dim, dtype=dtype, device=device)
-        self.norm2 = nn.LayerNorm(dim, dtype=dtype)
+        self.norm2 = nn.LayerNorm(dim, dtype=dtype, device=device)
-        self.norm3 = nn.LayerNorm(dim, dtype=dtype)
+        self.norm3 = nn.LayerNorm(dim, dtype=dtype, device=device)
         self.checkpoint = checkpoint
         self.n_heads = n_heads
         self.d_head = d_head

@@ -648,34 +648,34 @@ class SpatialTransformer(nn.Module):
     def __init__(self, in_channels, n_heads, d_head,
                  depth=1, dropout=0., context_dim=None,
                  disable_self_attn=False, use_linear=False,
-                 use_checkpoint=True, dtype=None):
+                 use_checkpoint=True, dtype=None, device=None):
         super().__init__()
         if exists(context_dim) and not isinstance(context_dim, list):
             context_dim = [context_dim] * depth
         self.in_channels = in_channels
         inner_dim = n_heads * d_head
-        self.norm = Normalize(in_channels, dtype=dtype)
+        self.norm = Normalize(in_channels, dtype=dtype, device=device)
         if not use_linear:
             self.proj_in = nn.Conv2d(in_channels,
                                      inner_dim,
                                      kernel_size=1,
                                      stride=1,
-                                     padding=0, dtype=dtype)
+                                     padding=0, dtype=dtype, device=device)
         else:
-            self.proj_in = comfy.ops.Linear(in_channels, inner_dim, dtype=dtype)
+            self.proj_in = comfy.ops.Linear(in_channels, inner_dim, dtype=dtype, device=device)

         self.transformer_blocks = nn.ModuleList(
             [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
-                                   disable_self_attn=disable_self_attn, checkpoint=use_checkpoint, dtype=dtype)
+                                   disable_self_attn=disable_self_attn, checkpoint=use_checkpoint, dtype=dtype, device=device)
                 for d in range(depth)]
         )
         if not use_linear:
             self.proj_out = nn.Conv2d(inner_dim,in_channels,
                                       kernel_size=1,
                                       stride=1,
-                                      padding=0, dtype=dtype)
+                                      padding=0, dtype=dtype, device=device)
         else:
-            self.proj_out = comfy.ops.Linear(in_channels, inner_dim, dtype=dtype)
+            self.proj_out = comfy.ops.Linear(in_channels, inner_dim, dtype=dtype, device=device)
         self.use_linear = use_linear

     def forward(self, x, context=None, transformer_options={}):

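SpatialTransformer is where the threading pays off: it forwards `device` into `Normalize`, both projections, and every `BasicTransformerBlock`, which in turn forwards it to the attention and feed-forward layers. A self-contained sketch of the propagation pattern, with illustrative classes rather than the real ones:

```python
import torch
import torch.nn as nn

class Block(nn.Module):
    def __init__(self, dim, dtype=None, device=None):
        super().__init__()
        # each child receives the same kwargs its parent was constructed with
        self.norm = nn.LayerNorm(dim, dtype=dtype, device=device)
        self.proj = nn.Linear(dim, dim, dtype=dtype, device=device)

class Stack(nn.Module):
    def __init__(self, dim, depth, dtype=None, device=None):
        super().__init__()
        # one device/dtype argument at the top level places the whole tree
        self.blocks = nn.ModuleList(
            [Block(dim, dtype=dtype, device=device) for _ in range(depth)]
        )

s = Stack(64, depth=2, dtype=torch.float16, device="cpu")
print(next(s.parameters()).dtype)  # torch.float16
```
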
@@ -8,6 +8,7 @@ from typing import Optional, Any

 from ..attention import MemoryEfficientCrossAttention
 from comfy import model_management
+import comfy.ops

 if model_management.xformers_enabled_vae():
     import xformers

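`comfy.ops` itself is not part of this diff. At the time it appears to have been (treat this as an assumption, not the verbatim module) a pair of thin wrappers whose only job is to skip PyTorch's default weight initialization, since the values are immediately overwritten when a checkpoint is loaded:

```python
import torch

class Linear(torch.nn.Linear):
    def reset_parameters(self):
        return None  # skip the default init; weights come from the state dict

class Conv2d(torch.nn.Conv2d):
    def reset_parameters(self):
        return None  # same idea for convolutions
```

That would explain why the VAE hunks below swap `torch.nn.Conv2d` and `torch.nn.Linear` for the `comfy.ops` equivalents: the layers behave identically, but construction is cheaper.
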
@@ -48,7 +49,7 @@ class Upsample(nn.Module):
         super().__init__()
         self.with_conv = with_conv
         if self.with_conv:
-            self.conv = torch.nn.Conv2d(in_channels,
+            self.conv = comfy.ops.Conv2d(in_channels,
                                         in_channels,
                                         kernel_size=3,
                                         stride=1,

@@ -67,7 +68,7 @@ class Downsample(nn.Module):
         self.with_conv = with_conv
         if self.with_conv:
             # no asymmetric padding in torch conv, must do it ourselves
-            self.conv = torch.nn.Conv2d(in_channels,
+            self.conv = comfy.ops.Conv2d(in_channels,
                                         in_channels,
                                         kernel_size=3,
                                         stride=2,

@@ -95,30 +96,30 @@ class ResnetBlock(nn.Module):

         self.swish = torch.nn.SiLU(inplace=True)
         self.norm1 = Normalize(in_channels)
-        self.conv1 = torch.nn.Conv2d(in_channels,
+        self.conv1 = comfy.ops.Conv2d(in_channels,
                                      out_channels,
                                      kernel_size=3,
                                      stride=1,
                                      padding=1)
         if temb_channels > 0:
-            self.temb_proj = torch.nn.Linear(temb_channels,
+            self.temb_proj = comfy.ops.Linear(temb_channels,
                                              out_channels)
         self.norm2 = Normalize(out_channels)
         self.dropout = torch.nn.Dropout(dropout, inplace=True)
-        self.conv2 = torch.nn.Conv2d(out_channels,
+        self.conv2 = comfy.ops.Conv2d(out_channels,
                                      out_channels,
                                      kernel_size=3,
                                      stride=1,
                                      padding=1)
         if self.in_channels != self.out_channels:
             if self.use_conv_shortcut:
-                self.conv_shortcut = torch.nn.Conv2d(in_channels,
+                self.conv_shortcut = comfy.ops.Conv2d(in_channels,
                                                      out_channels,
                                                      kernel_size=3,
                                                      stride=1,
                                                      padding=1)
             else:
-                self.nin_shortcut = torch.nn.Conv2d(in_channels,
+                self.nin_shortcut = comfy.ops.Conv2d(in_channels,
                                                     out_channels,
                                                     kernel_size=1,
                                                     stride=1,

@@ -188,22 +189,22 @@ class AttnBlock(nn.Module):
         self.in_channels = in_channels

         self.norm = Normalize(in_channels)
-        self.q = torch.nn.Conv2d(in_channels,
+        self.q = comfy.ops.Conv2d(in_channels,
                                  in_channels,
                                  kernel_size=1,
                                  stride=1,
                                  padding=0)
-        self.k = torch.nn.Conv2d(in_channels,
+        self.k = comfy.ops.Conv2d(in_channels,
                                  in_channels,
                                  kernel_size=1,
                                  stride=1,
                                  padding=0)
-        self.v = torch.nn.Conv2d(in_channels,
+        self.v = comfy.ops.Conv2d(in_channels,
                                  in_channels,
                                  kernel_size=1,
                                  stride=1,
                                  padding=0)
-        self.proj_out = torch.nn.Conv2d(in_channels,
+        self.proj_out = comfy.ops.Conv2d(in_channels,
                                         in_channels,
                                         kernel_size=1,
                                         stride=1,

@@ -243,22 +244,22 @@ class MemoryEfficientAttnBlock(nn.Module):
         self.in_channels = in_channels

         self.norm = Normalize(in_channels)
-        self.q = torch.nn.Conv2d(in_channels,
+        self.q = comfy.ops.Conv2d(in_channels,
                                  in_channels,
                                  kernel_size=1,
                                  stride=1,
                                  padding=0)
-        self.k = torch.nn.Conv2d(in_channels,
+        self.k = comfy.ops.Conv2d(in_channels,
                                  in_channels,
                                  kernel_size=1,
                                  stride=1,
                                  padding=0)
-        self.v = torch.nn.Conv2d(in_channels,
+        self.v = comfy.ops.Conv2d(in_channels,
                                  in_channels,
                                  kernel_size=1,
                                  stride=1,
                                  padding=0)
-        self.proj_out = torch.nn.Conv2d(in_channels,
+        self.proj_out = comfy.ops.Conv2d(in_channels,
                                         in_channels,
                                         kernel_size=1,
                                         stride=1,

@@ -302,22 +303,22 @@ class MemoryEfficientAttnBlockPytorch(nn.Module):
         self.in_channels = in_channels

         self.norm = Normalize(in_channels)
-        self.q = torch.nn.Conv2d(in_channels,
+        self.q = comfy.ops.Conv2d(in_channels,
                                  in_channels,
                                  kernel_size=1,
                                  stride=1,
                                  padding=0)
-        self.k = torch.nn.Conv2d(in_channels,
+        self.k = comfy.ops.Conv2d(in_channels,
                                  in_channels,
                                  kernel_size=1,
                                  stride=1,
                                  padding=0)
-        self.v = torch.nn.Conv2d(in_channels,
+        self.v = comfy.ops.Conv2d(in_channels,
                                  in_channels,
                                  kernel_size=1,
                                  stride=1,
                                  padding=0)
-        self.proj_out = torch.nn.Conv2d(in_channels,
+        self.proj_out = comfy.ops.Conv2d(in_channels,
                                         in_channels,
                                         kernel_size=1,
                                         stride=1,

@@ -399,14 +400,14 @@ class Model(nn.Module):
         # timestep embedding
         self.temb = nn.Module()
         self.temb.dense = nn.ModuleList([
-            torch.nn.Linear(self.ch,
+            comfy.ops.Linear(self.ch,
                             self.temb_ch),
-            torch.nn.Linear(self.temb_ch,
+            comfy.ops.Linear(self.temb_ch,
                             self.temb_ch),
         ])

         # downsampling
-        self.conv_in = torch.nn.Conv2d(in_channels,
+        self.conv_in = comfy.ops.Conv2d(in_channels,
                                        self.ch,
                                        kernel_size=3,
                                        stride=1,

@@ -475,7 +476,7 @@ class Model(nn.Module):

         # end
         self.norm_out = Normalize(block_in)
-        self.conv_out = torch.nn.Conv2d(block_in,
+        self.conv_out = comfy.ops.Conv2d(block_in,
                                         out_ch,
                                         kernel_size=3,
                                         stride=1,

@@ -548,7 +549,7 @@ class Encoder(nn.Module):
         self.in_channels = in_channels

         # downsampling
-        self.conv_in = torch.nn.Conv2d(in_channels,
+        self.conv_in = comfy.ops.Conv2d(in_channels,
                                        self.ch,
                                        kernel_size=3,
                                        stride=1,

@@ -593,7 +594,7 @@ class Encoder(nn.Module):

         # end
         self.norm_out = Normalize(block_in)
-        self.conv_out = torch.nn.Conv2d(block_in,
+        self.conv_out = comfy.ops.Conv2d(block_in,
                                         2*z_channels if double_z else z_channels,
                                         kernel_size=3,
                                         stride=1,

@@ -653,7 +654,7 @@ class Decoder(nn.Module):
         self.z_shape, np.prod(self.z_shape)))

         # z to block_in
-        self.conv_in = torch.nn.Conv2d(z_channels,
+        self.conv_in = comfy.ops.Conv2d(z_channels,
                                        block_in,
                                        kernel_size=3,
                                        stride=1,

@@ -695,7 +696,7 @@ class Decoder(nn.Module):

         # end
         self.norm_out = Normalize(block_in)
-        self.conv_out = torch.nn.Conv2d(block_in,
+        self.conv_out = comfy.ops.Conv2d(block_in,
                                         out_ch,
                                         kernel_size=3,
                                         stride=1,

@@ -19,45 +19,6 @@ from ..attention import SpatialTransformer
 from comfy.ldm.util import exists


-# dummy replace
-def convert_module_to_f16(x):
-    pass
-
-def convert_module_to_f32(x):
-    pass
-
-
-## go
-class AttentionPool2d(nn.Module):
-    """
-    Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
-    """
-
-    def __init__(
-        self,
-        spacial_dim: int,
-        embed_dim: int,
-        num_heads_channels: int,
-        output_dim: int = None,
-    ):
-        super().__init__()
-        self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5)
-        self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
-        self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
-        self.num_heads = embed_dim // num_heads_channels
-        self.attention = QKVAttention(self.num_heads)
-
-    def forward(self, x):
-        b, c, *_spatial = x.shape
-        x = x.reshape(b, c, -1)  # NC(HW)
-        x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1)  # NC(HW+1)
-        x = x + self.positional_embedding[None, :, :].to(x.dtype)  # NC(HW+1)
-        x = self.qkv_proj(x)
-        x = self.attention(x)
-        x = self.c_proj(x)
-        return x[:, :, 0]
-
-
 class TimestepBlock(nn.Module):
     """
     Any module where forward() takes timestep embeddings as a second argument.

@@ -111,14 +72,14 @@ class Upsample(nn.Module):
     upsampling occurs in the inner-two dimensions.
     """

-    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None):
+    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None, device=None):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        if use_conv:
-            self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding, dtype=dtype)
+            self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding, dtype=dtype, device=device)

     def forward(self, x, output_shape=None):
         assert x.shape[1] == self.channels

@@ -138,19 +99,6 @@ class Upsample(nn.Module):
             x = self.conv(x)
         return x

-class TransposedUpsample(nn.Module):
-    'Learned 2x upsampling without padding'
-    def __init__(self, channels, out_channels=None, ks=5):
-        super().__init__()
-        self.channels = channels
-        self.out_channels = out_channels or channels
-
-        self.up = nn.ConvTranspose2d(self.channels,self.out_channels,kernel_size=ks,stride=2)
-
-    def forward(self,x):
-        return self.up(x)
-
-
 class Downsample(nn.Module):
     """
     A downsampling layer with an optional convolution.

@@ -160,7 +108,7 @@ class Downsample(nn.Module):
     downsampling occurs in the inner-two dimensions.
     """

-    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None):
+    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None, device=None):
         super().__init__()
         self.channels = channels
         self.out_channels = out_channels or channels

@@ -169,7 +117,7 @@ class Downsample(nn.Module):
         stride = 2 if dims != 3 else (1, 2, 2)
         if use_conv:
             self.op = conv_nd(
-                dims, self.channels, self.out_channels, 3, stride=stride, padding=padding, dtype=dtype
+                dims, self.channels, self.out_channels, 3, stride=stride, padding=padding, dtype=dtype, device=device
             )
         else:
             assert self.channels == self.out_channels

@@ -208,7 +156,8 @@ class ResBlock(TimestepBlock):
         use_checkpoint=False,
         up=False,
         down=False,
-        dtype=None
+        dtype=None,
+        device=None,
     ):
         super().__init__()
         self.channels = channels

@@ -220,19 +169,19 @@ class ResBlock(TimestepBlock):
         self.use_scale_shift_norm = use_scale_shift_norm

         self.in_layers = nn.Sequential(
-            nn.GroupNorm(32, channels, dtype=dtype),
+            nn.GroupNorm(32, channels, dtype=dtype, device=device),
             nn.SiLU(),
-            conv_nd(dims, channels, self.out_channels, 3, padding=1, dtype=dtype),
+            conv_nd(dims, channels, self.out_channels, 3, padding=1, dtype=dtype, device=device),
         )

         self.updown = up or down

         if up:
-            self.h_upd = Upsample(channels, False, dims, dtype=dtype)
+            self.h_upd = Upsample(channels, False, dims, dtype=dtype, device=device)
-            self.x_upd = Upsample(channels, False, dims, dtype=dtype)
+            self.x_upd = Upsample(channels, False, dims, dtype=dtype, device=device)
         elif down:
-            self.h_upd = Downsample(channels, False, dims, dtype=dtype)
+            self.h_upd = Downsample(channels, False, dims, dtype=dtype, device=device)
-            self.x_upd = Downsample(channels, False, dims, dtype=dtype)
+            self.x_upd = Downsample(channels, False, dims, dtype=dtype, device=device)
         else:
             self.h_upd = self.x_upd = nn.Identity()

@@ -240,15 +189,15 @@ class ResBlock(TimestepBlock):
             nn.SiLU(),
             linear(
                 emb_channels,
-                2 * self.out_channels if use_scale_shift_norm else self.out_channels, dtype=dtype
+                2 * self.out_channels if use_scale_shift_norm else self.out_channels, dtype=dtype, device=device
             ),
         )
         self.out_layers = nn.Sequential(
-            nn.GroupNorm(32, self.out_channels, dtype=dtype),
+            nn.GroupNorm(32, self.out_channels, dtype=dtype, device=device),
             nn.SiLU(),
             nn.Dropout(p=dropout),
             zero_module(
-                conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1, dtype=dtype)
+                conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1, dtype=dtype, device=device)
             ),
         )

@@ -256,10 +205,10 @@ class ResBlock(TimestepBlock):
             self.skip_connection = nn.Identity()
         elif use_conv:
             self.skip_connection = conv_nd(
-                dims, channels, self.out_channels, 3, padding=1, dtype=dtype
+                dims, channels, self.out_channels, 3, padding=1, dtype=dtype, device=device
             )
         else:
-            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1, dtype=dtype)
+            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1, dtype=dtype, device=device)

     def forward(self, x, emb):
         """

@@ -295,142 +244,6 @@ class ResBlock(TimestepBlock):
         h = self.out_layers(h)
         return self.skip_connection(x) + h

-
-class AttentionBlock(nn.Module):
-    """
-    An attention block that allows spatial positions to attend to each other.
-    Originally ported from here, but adapted to the N-d case.
-    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
-    """
-
-    def __init__(
-        self,
-        channels,
-        num_heads=1,
-        num_head_channels=-1,
-        use_checkpoint=False,
-        use_new_attention_order=False,
-    ):
-        super().__init__()
-        self.channels = channels
-        if num_head_channels == -1:
-            self.num_heads = num_heads
-        else:
-            assert (
-                channels % num_head_channels == 0
-            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
-            self.num_heads = channels // num_head_channels
-        self.use_checkpoint = use_checkpoint
-        self.norm = normalization(channels)
-        self.qkv = conv_nd(1, channels, channels * 3, 1)
-        if use_new_attention_order:
-            # split qkv before split heads
-            self.attention = QKVAttention(self.num_heads)
-        else:
-            # split heads before split qkv
-            self.attention = QKVAttentionLegacy(self.num_heads)
-
-        self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
-
-    def forward(self, x):
-        return checkpoint(self._forward, (x,), self.parameters(), True)  # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
-        #return pt_checkpoint(self._forward, x)  # pytorch
-
-    def _forward(self, x):
-        b, c, *spatial = x.shape
-        x = x.reshape(b, c, -1)
-        qkv = self.qkv(self.norm(x))
-        h = self.attention(qkv)
-        h = self.proj_out(h)
-        return (x + h).reshape(b, c, *spatial)
-
-
-def count_flops_attn(model, _x, y):
-    """
-    A counter for the `thop` package to count the operations in an
-    attention operation.
-    Meant to be used like:
-        macs, params = thop.profile(
-            model,
-            inputs=(inputs, timestamps),
-            custom_ops={QKVAttention: QKVAttention.count_flops},
-        )
-    """
-    b, c, *spatial = y[0].shape
-    num_spatial = int(np.prod(spatial))
-    # We perform two matmuls with the same number of ops.
-    # The first computes the weight matrix, the second computes
-    # the combination of the value vectors.
-    matmul_ops = 2 * b * (num_spatial ** 2) * c
-    model.total_ops += th.DoubleTensor([matmul_ops])
-
-
-class QKVAttentionLegacy(nn.Module):
-    """
-    A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
-    """
-
-    def __init__(self, n_heads):
-        super().__init__()
-        self.n_heads = n_heads
-
-    def forward(self, qkv):
-        """
-        Apply QKV attention.
-        :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
-        :return: an [N x (H * C) x T] tensor after attention.
-        """
-        bs, width, length = qkv.shape
-        assert width % (3 * self.n_heads) == 0
-        ch = width // (3 * self.n_heads)
-        q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
-        scale = 1 / math.sqrt(math.sqrt(ch))
-        weight = th.einsum(
-            "bct,bcs->bts", q * scale, k * scale
-        )  # More stable with f16 than dividing afterwards
-        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
-        a = th.einsum("bts,bcs->bct", weight, v)
-        return a.reshape(bs, -1, length)
-
-    @staticmethod
-    def count_flops(model, _x, y):
-        return count_flops_attn(model, _x, y)
-
-
-class QKVAttention(nn.Module):
-    """
-    A module which performs QKV attention and splits in a different order.
-    """
-
-    def __init__(self, n_heads):
-        super().__init__()
-        self.n_heads = n_heads
-
-    def forward(self, qkv):
-        """
-        Apply QKV attention.
-        :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
-        :return: an [N x (H * C) x T] tensor after attention.
-        """
-        bs, width, length = qkv.shape
-        assert width % (3 * self.n_heads) == 0
-        ch = width // (3 * self.n_heads)
-        q, k, v = qkv.chunk(3, dim=1)
-        scale = 1 / math.sqrt(math.sqrt(ch))
-        weight = th.einsum(
-            "bct,bcs->bts",
-            (q * scale).view(bs * self.n_heads, ch, length),
-            (k * scale).view(bs * self.n_heads, ch, length),
-        )  # More stable with f16 than dividing afterwards
-        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
-        a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
-        return a.reshape(bs, -1, length)
-
-    @staticmethod
-    def count_flops(model, _x, y):
-        return count_flops_attn(model, _x, y)
-
-
 class Timestep(nn.Module):
     def __init__(self, dim):
         super().__init__()

@@ -503,8 +316,10 @@ class UNetModel(nn.Module):
         use_linear_in_transformer=False,
         adm_in_channels=None,
         transformer_depth_middle=None,
+        device=None,
     ):
         super().__init__()
+        assert use_spatial_transformer == True, "use_spatial_transformer has to be true"
         if use_spatial_transformer:
             assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'

@@ -564,9 +379,9 @@ class UNetModel(nn.Module):

         time_embed_dim = model_channels * 4
         self.time_embed = nn.Sequential(
-            linear(model_channels, time_embed_dim, dtype=self.dtype),
+            linear(model_channels, time_embed_dim, dtype=self.dtype, device=device),
             nn.SiLU(),
-            linear(time_embed_dim, time_embed_dim, dtype=self.dtype),
+            linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
         )

         if self.num_classes is not None:

@@ -579,9 +394,9 @@ class UNetModel(nn.Module):
                 assert adm_in_channels is not None
                 self.label_emb = nn.Sequential(
                     nn.Sequential(
-                        linear(adm_in_channels, time_embed_dim, dtype=self.dtype),
+                        linear(adm_in_channels, time_embed_dim, dtype=self.dtype, device=device),
                         nn.SiLU(),
-                        linear(time_embed_dim, time_embed_dim, dtype=self.dtype),
+                        linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
                     )
                 )
             else:

@@ -590,7 +405,7 @@ class UNetModel(nn.Module):
         self.input_blocks = nn.ModuleList(
             [
                 TimestepEmbedSequential(
-                    conv_nd(dims, in_channels, model_channels, 3, padding=1, dtype=self.dtype)
+                    conv_nd(dims, in_channels, model_channels, 3, padding=1, dtype=self.dtype, device=device)
                 )
             ]
         )

@@ -609,7 +424,8 @@ class UNetModel(nn.Module):
                         dims=dims,
                         use_checkpoint=use_checkpoint,
                         use_scale_shift_norm=use_scale_shift_norm,
-                        dtype=self.dtype
+                        dtype=self.dtype,
+                        device=device,
                     )
                 ]
                 ch = mult * model_channels

@@ -628,17 +444,10 @@ class UNetModel(nn.Module):
                     disabled_sa = False

                 if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
-                    layers.append(
-                        AttentionBlock(
-                            ch,
-                            use_checkpoint=use_checkpoint,
-                            num_heads=num_heads,
-                            num_head_channels=dim_head,
-                            use_new_attention_order=use_new_attention_order,
-                        ) if not use_spatial_transformer else SpatialTransformer(
+                    layers.append(SpatialTransformer(
                             ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim,
                             disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
-                            use_checkpoint=use_checkpoint, dtype=self.dtype
+                            use_checkpoint=use_checkpoint, dtype=self.dtype, device=device
                         )
                     )
                 self.input_blocks.append(TimestepEmbedSequential(*layers))

@@ -657,11 +466,12 @@ class UNetModel(nn.Module):
                             use_checkpoint=use_checkpoint,
                             use_scale_shift_norm=use_scale_shift_norm,
                             down=True,
-                            dtype=self.dtype
+                            dtype=self.dtype,
+                            device=device,
                         )
                         if resblock_updown
                         else Downsample(
-                            ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype
+                            ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype, device=device
                         )
                     )
                 )

@@ -686,18 +496,13 @@ class UNetModel(nn.Module):
                 dims=dims,
                 use_checkpoint=use_checkpoint,
                 use_scale_shift_norm=use_scale_shift_norm,
-                dtype=self.dtype
+                dtype=self.dtype,
+                device=device,
             ),
-            AttentionBlock(
-                ch,
-                use_checkpoint=use_checkpoint,
-                num_heads=num_heads,
-                num_head_channels=dim_head,
-                use_new_attention_order=use_new_attention_order,
-            ) if not use_spatial_transformer else SpatialTransformer(  # always uses a self-attn
+            SpatialTransformer(  # always uses a self-attn
                 ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
                 disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
-                use_checkpoint=use_checkpoint, dtype=self.dtype
+                use_checkpoint=use_checkpoint, dtype=self.dtype, device=device
             ),
             ResBlock(
                 ch,

@ -706,7 +511,8 @@ class UNetModel(nn.Module):
|
|||||||
dims=dims,
|
dims=dims,
|
||||||
use_checkpoint=use_checkpoint,
|
use_checkpoint=use_checkpoint,
|
||||||
use_scale_shift_norm=use_scale_shift_norm,
|
use_scale_shift_norm=use_scale_shift_norm,
|
||||||
dtype=self.dtype
|
dtype=self.dtype,
|
||||||
|
device=device,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
self._feature_size += ch
|
self._feature_size += ch
|
||||||
@ -724,7 +530,8 @@ class UNetModel(nn.Module):
|
|||||||
dims=dims,
|
dims=dims,
|
||||||
use_checkpoint=use_checkpoint,
|
use_checkpoint=use_checkpoint,
|
||||||
use_scale_shift_norm=use_scale_shift_norm,
|
use_scale_shift_norm=use_scale_shift_norm,
|
||||||
dtype=self.dtype
|
dtype=self.dtype,
|
||||||
|
device=device,
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
ch = model_channels * mult
|
ch = model_channels * mult
|
||||||
@ -744,16 +551,10 @@ class UNetModel(nn.Module):
|
|||||||
|
|
||||||
if not exists(num_attention_blocks) or i < num_attention_blocks[level]:
|
if not exists(num_attention_blocks) or i < num_attention_blocks[level]:
|
||||||
layers.append(
|
layers.append(
|
||||||
AttentionBlock(
|
SpatialTransformer(
|
||||||
ch,
|
|
||||||
use_checkpoint=use_checkpoint,
|
|
||||||
num_heads=num_heads_upsample,
|
|
||||||
num_head_channels=dim_head,
|
|
||||||
use_new_attention_order=use_new_attention_order,
|
|
||||||
) if not use_spatial_transformer else SpatialTransformer(
|
|
||||||
ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim,
|
ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim,
|
||||||
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
|
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
|
||||||
use_checkpoint=use_checkpoint, dtype=self.dtype
|
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
if level and i == self.num_res_blocks[level]:
|
if level and i == self.num_res_blocks[level]:
|
||||||
@ -768,43 +569,28 @@ class UNetModel(nn.Module):
|
|||||||
use_checkpoint=use_checkpoint,
|
use_checkpoint=use_checkpoint,
|
||||||
use_scale_shift_norm=use_scale_shift_norm,
|
use_scale_shift_norm=use_scale_shift_norm,
|
||||||
up=True,
|
up=True,
|
||||||
dtype=self.dtype
|
dtype=self.dtype,
|
||||||
|
device=device,
|
||||||
)
|
)
|
||||||
if resblock_updown
|
if resblock_updown
|
||||||
else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype)
|
else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype, device=device)
|
||||||
)
|
)
|
||||||
ds //= 2
|
ds //= 2
|
||||||
self.output_blocks.append(TimestepEmbedSequential(*layers))
|
self.output_blocks.append(TimestepEmbedSequential(*layers))
|
||||||
self._feature_size += ch
|
self._feature_size += ch
|
||||||
|
|
||||||
self.out = nn.Sequential(
|
self.out = nn.Sequential(
|
||||||
nn.GroupNorm(32, ch, dtype=self.dtype),
|
nn.GroupNorm(32, ch, dtype=self.dtype, device=device),
|
||||||
nn.SiLU(),
|
nn.SiLU(),
|
||||||
zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1, dtype=self.dtype)),
|
zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1, dtype=self.dtype, device=device)),
|
||||||
)
|
)
|
||||||
if self.predict_codebook_ids:
|
if self.predict_codebook_ids:
|
||||||
self.id_predictor = nn.Sequential(
|
self.id_predictor = nn.Sequential(
|
||||||
nn.GroupNorm(32, ch, dtype=self.dtype),
|
nn.GroupNorm(32, ch, dtype=self.dtype, device=device),
|
||||||
conv_nd(dims, model_channels, n_embed, 1),
|
conv_nd(dims, model_channels, n_embed, 1, dtype=self.dtype, device=device),
|
||||||
#nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
|
#nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
|
||||||
)
|
)
|
||||||
|
|
||||||
def convert_to_fp16(self):
|
|
||||||
"""
|
|
||||||
Convert the torso of the model to float16.
|
|
||||||
"""
|
|
||||||
self.input_blocks.apply(convert_module_to_f16)
|
|
||||||
self.middle_block.apply(convert_module_to_f16)
|
|
||||||
self.output_blocks.apply(convert_module_to_f16)
|
|
||||||
|
|
||||||
def convert_to_fp32(self):
|
|
||||||
"""
|
|
||||||
Convert the torso of the model to float32.
|
|
||||||
"""
|
|
||||||
self.input_blocks.apply(convert_module_to_f32)
|
|
||||||
self.middle_block.apply(convert_module_to_f32)
|
|
||||||
self.output_blocks.apply(convert_module_to_f32)
|
|
||||||
|
|
||||||
def forward(self, x, timesteps=None, context=None, y=None, control=None, transformer_options={}, **kwargs):
|
def forward(self, x, timesteps=None, context=None, y=None, control=None, transformer_options={}, **kwargs):
|
||||||
"""
|
"""
|
||||||
Apply the model to an input batch.
|
Apply the model to an input batch.
|
||||||
|
|||||||
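The hunks above thread a `device` argument through every module `UNetModel` constructs, so weights are allocated directly on the target device instead of being built on the CPU and copied over afterwards with `.to()`. A minimal sketch of the pattern, with illustrative layer sizes (not ComfyUI's API):

```
import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Allocating on the target device up front avoids the second full copy
# of the weights that a later .to(device) would create.
conv = nn.Conv2d(4, 320, 3, padding=1, dtype=torch.float16, device=device)

# Equivalent result, but builds on CPU first and then copies everything:
conv_slow = nn.Conv2d(4, 320, 3, padding=1, dtype=torch.float16).to(device)
```

The removed `convert_to_fp16`/`convert_to_fp32` helpers appear to be made redundant by the same change, since the dtype is now fixed at construction time via `dtype=self.dtype`.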
@@ -12,14 +12,14 @@ class ModelType(Enum):
     V_PREDICTION = 2

 class BaseModel(torch.nn.Module):
-    def __init__(self, model_config, model_type=ModelType.EPS):
+    def __init__(self, model_config, model_type=ModelType.EPS, device=None):
         super().__init__()

         unet_config = model_config.unet_config
         self.latent_format = model_config.latent_format
         self.model_config = model_config
         self.register_schedule(given_betas=None, beta_schedule="linear", timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3)
-        self.diffusion_model = UNetModel(**unet_config)
+        self.diffusion_model = UNetModel(**unet_config, device=device)
         self.model_type = model_type
         self.adm_channels = unet_config.get("adm_in_channels", None)
         if self.adm_channels is None:
@@ -107,8 +107,8 @@ class BaseModel(torch.nn.Module):


 class SD21UNCLIP(BaseModel):
-    def __init__(self, model_config, noise_aug_config, model_type=ModelType.V_PREDICTION):
-        super().__init__(model_config, model_type)
+    def __init__(self, model_config, noise_aug_config, model_type=ModelType.V_PREDICTION, device=None):
+        super().__init__(model_config, model_type, device=device)
         self.noise_augmentor = CLIPEmbeddingNoiseAugmentation(**noise_aug_config)

     def encode_adm(self, **kwargs):
@@ -143,13 +143,13 @@ class SD21UNCLIP(BaseModel):
         return adm_out

 class SDInpaint(BaseModel):
-    def __init__(self, model_config, model_type=ModelType.EPS):
-        super().__init__(model_config, model_type)
+    def __init__(self, model_config, model_type=ModelType.EPS, device=None):
+        super().__init__(model_config, model_type, device=device)
         self.concat_keys = ("mask", "masked_image")

 class SDXLRefiner(BaseModel):
-    def __init__(self, model_config, model_type=ModelType.EPS):
-        super().__init__(model_config, model_type)
+    def __init__(self, model_config, model_type=ModelType.EPS, device=None):
+        super().__init__(model_config, model_type, device=device)
         self.embedder = Timestep(256)

     def encode_adm(self, **kwargs):
@@ -164,7 +164,6 @@ class SDXLRefiner(BaseModel):
         else:
             aesthetic_score = kwargs.get("aesthetic_score", 6)

-        print(clip_pooled.shape, width, height, crop_w, crop_h, aesthetic_score)
         out = []
         out.append(self.embedder(torch.Tensor([height])))
         out.append(self.embedder(torch.Tensor([width])))
@@ -175,8 +174,8 @@ class SDXLRefiner(BaseModel):
         return torch.cat((clip_pooled.to(flat.device), flat), dim=1)

 class SDXL(BaseModel):
-    def __init__(self, model_config, model_type=ModelType.EPS):
-        super().__init__(model_config, model_type)
+    def __init__(self, model_config, model_type=ModelType.EPS, device=None):
+        super().__init__(model_config, model_type, device=device)
         self.embedder = Timestep(256)

     def encode_adm(self, **kwargs):
@@ -188,7 +187,6 @@ class SDXL(BaseModel):
         target_width = kwargs.get("target_width", width)
         target_height = kwargs.get("target_height", height)

-        print(clip_pooled.shape, width, height, crop_w, crop_h, target_width, target_height)
         out = []
         out.append(self.embedder(torch.Tensor([height])))
         out.append(self.embedder(torch.Tensor([width])))
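A `device=None` default keeps existing call sites working unchanged, while `load_checkpoint_guess_config` (in `comfy/sd.py` below) can now route the offload device through `get_model`. Illustrative usage, assuming a prepared `model_config`:

```
model = model_base.BaseModel(model_config)                # as before: default device
model = model_base.BaseModel(model_config, device="cpu")  # weights built on the offload device
```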
@@ -364,6 +364,7 @@ def text_encoder_device():
    if args.gpu_only:
        return get_torch_device()
    elif vram_state == VRAMState.HIGH_VRAM or vram_state == VRAMState.NORMAL_VRAM:
+        #NOTE: on a Ryzen 5 7600X with 4080 it's faster to shift to GPU
        if torch.get_num_threads() < 8: #leaving the text encoder on the CPU is faster than shifting it if the CPU is fast enough.
            return get_torch_device()
    else:
@@ -534,7 +535,7 @@ def should_use_fp16(device=None, model_params=0):
        return False

    #FP16 is just broken on these cards
-    nvidia_16_series = ["1660", "1650", "1630", "T500", "T550", "T600", "MX550", "MX450"]
+    nvidia_16_series = ["1660", "1650", "1630", "T500", "T550", "T600", "MX550", "MX450", "CMP 30HX"]
    for x in nvidia_16_series:
        if x in props.name:
            return False
@@ -19,11 +19,11 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
            strength = 1.0
            if 'timestep_start' in cond[1]:
                timestep_start = cond[1]['timestep_start']
-                if timestep_in > timestep_start:
+                if timestep_in[0] > timestep_start:
                    return None
            if 'timestep_end' in cond[1]:
                timestep_end = cond[1]['timestep_end']
-                if timestep_in < timestep_end:
+                if timestep_in[0] < timestep_end:
                    return None
            if 'area' in cond[1]:
                area = cond[1]['area']
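`timestep_in` carries one sigma per batch element, so comparing the whole tensor against a scalar produces a multi-element boolean tensor whose truth value is ambiguous; indexing the first element makes the range check well defined. A small sketch of the pitfall:

```
import torch

timestep_in = torch.tensor([0.5, 0.5])  # one value per batch element
timestep_start = 0.7

# if timestep_in > timestep_start: ...  ->  RuntimeError: Boolean value of
# Tensor with more than one element is ambiguous
if timestep_in[0] > timestep_start:
    print("cond disabled for this timestep range")
```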
comfy/sd.py
@@ -70,13 +70,22 @@ def load_lora(lora, to_load):
            alpha = lora[alpha_name].item()
            loaded_keys.add(alpha_name)

-        A_name = "{}.lora_up.weight".format(x)
-        B_name = "{}.lora_down.weight".format(x)
-        mid_name = "{}.lora_mid.weight".format(x)
+        regular_lora = "{}.lora_up.weight".format(x)
+        diffusers_lora = "{}_lora.up.weight".format(x)
+        A_name = None

-        if A_name in lora.keys():
+        if regular_lora in lora.keys():
+            A_name = regular_lora
+            B_name = "{}.lora_down.weight".format(x)
+            mid_name = "{}.lora_mid.weight".format(x)
+        elif diffusers_lora in lora.keys():
+            A_name = diffusers_lora
+            B_name = "{}_lora.down.weight".format(x)
+            mid_name = None
+
+        if A_name is not None:
            mid = None
-            if mid_name in lora.keys():
+            if mid_name is not None and mid_name in lora.keys():
                mid = lora[mid_name]
                loaded_keys.add(mid_name)
            patch_dict[to_load[x]] = (lora[A_name], lora[B_name], alpha, mid)
@@ -202,6 +211,11 @@ def model_lora_keys_unet(model, key_map={}):
        if k.endswith(".weight"):
            key_lora = k[:-len(".weight")].replace(".", "_")
            key_map["lora_unet_{}".format(key_lora)] = "diffusion_model.{}".format(diffusers_keys[k])

+            diffusers_lora_key = "unet.{}".format(k[:-len(".weight")].replace(".to_", ".processor.to_"))
+            if diffusers_lora_key.endswith(".to_out.0"):
+                diffusers_lora_key = diffusers_lora_key[:-2]
+            key_map[diffusers_lora_key] = "diffusion_model.{}".format(diffusers_keys[k])
    return key_map

def set_attr(obj, attr, value):
@@ -864,7 +878,7 @@ def load_controlnet(ckpt_path, model=None):
        use_fp16 = model_management.should_use_fp16()
        controlnet_config = model_detection.model_config_from_unet(controlnet_data, prefix, use_fp16).unet_config
        controlnet_config.pop("out_channels")
-        controlnet_config["hint_channels"] = 3
+        controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
        control_model = cldm.ControlNet(**controlnet_config)

    if pth:
@@ -1169,8 +1183,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
        clipvision = clip_vision.load_clipvision_from_sd(sd, model_config.clip_vision_prefix, True)

    offload_device = model_management.unet_offload_device()
-    model = model_config.get_model(sd, "model.diffusion_model.")
-    model = model.to(offload_device)
+    model = model_config.get_model(sd, "model.diffusion_model.", device=offload_device)
    model.load_model_weights(sd, "model.diffusion_model.")

    if output_vae:
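`load_lora` now accepts two key layouts for the same logical weights: the existing `lora_up`/`lora_down` naming and the diffusers-style `_lora.up`/`_lora.down` naming, which carries no mid weight. A sketch of the two layouts for a single patched key stem `x` (the stems are illustrative):

```
x = "lora_unet_example_to_q"  # hypothetical key stem

# existing layout
regular = {
    f"{x}.lora_up.weight": ...,    # A matrix
    f"{x}.lora_down.weight": ...,  # B matrix
    f"{x}.lora_mid.weight": ...,   # optional mid weight
}

# diffusers layout for the same stem: different separators, no mid weight
diffusers = {
    f"{x}_lora.up.weight": ...,
    f"{x}_lora.down.weight": ...,
}
```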
@@ -91,13 +91,15 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):

    def set_up_textual_embeddings(self, tokens, current_embeds):
        out_tokens = []
-        next_new_token = token_dict_size = current_embeds.weight.shape[0]
+        next_new_token = token_dict_size = current_embeds.weight.shape[0] - 1
        embedding_weights = []

        for x in tokens:
            tokens_temp = []
            for y in x:
                if isinstance(y, int):
+                    if y == token_dict_size: #EOS token
+                        y = -1
                    tokens_temp += [y]
                else:
                    if y.shape[0] == current_embeds.weight.shape[1]:
@@ -110,15 +112,21 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
                    tokens_temp += [self.empty_tokens[0][-1]]
            out_tokens += [tokens_temp]

+        n = token_dict_size
        if len(embedding_weights) > 0:
-            new_embedding = torch.nn.Embedding(next_new_token, current_embeds.weight.shape[1], device=current_embeds.weight.device, dtype=current_embeds.weight.dtype)
-            new_embedding.weight[:token_dict_size] = current_embeds.weight[:]
-            n = token_dict_size
+            new_embedding = torch.nn.Embedding(next_new_token + 1, current_embeds.weight.shape[1], device=current_embeds.weight.device, dtype=current_embeds.weight.dtype)
+            new_embedding.weight[:token_dict_size] = current_embeds.weight[:-1]
            for x in embedding_weights:
                new_embedding.weight[n] = x
                n += 1
+            new_embedding.weight[n] = current_embeds.weight[-1] #EOS embedding
            self.transformer.set_input_embeddings(new_embedding)
-        return out_tokens
+
+        processed_tokens = []
+        for x in out_tokens:
+            processed_tokens += [list(map(lambda a: n if a == -1 else a, x))] #The EOS token should always be the largest one
+
+        return processed_tokens

    def forward(self, tokens):
        backup_embeds = self.transformer.get_input_embeddings()
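The tokenizer change keeps the EOS embedding as the last row of the embedding matrix even after textual-inversion embeddings are appended: new rows are inserted before the final row, any stored token equal to the old EOS id is collected as `-1`, and the placeholders are rewritten to the new largest id at the end. A toy sketch of the resulting layout:

```
import torch

V, D, k = 8, 4, 2                     # vocab size, embedding dim, new embeddings
old = torch.nn.Embedding(V, D)
new = torch.nn.Embedding(V + k, D)
with torch.no_grad():
    new.weight[:V - 1] = old.weight[:-1]             # originals minus EOS
    new.weight[V - 1:V - 1 + k] = torch.randn(k, D)  # appended embeddings
    new.weight[V - 1 + k] = old.weight[-1]           # EOS stays the largest id
```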
@@ -109,8 +109,8 @@ class SDXLRefiner(supported_models_base.BASE):

    latent_format = latent_formats.SDXL

-    def get_model(self, state_dict, prefix=""):
-        return model_base.SDXLRefiner(self)
+    def get_model(self, state_dict, prefix="", device=None):
+        return model_base.SDXLRefiner(self, device=device)

    def process_clip_state_dict(self, state_dict):
        keys_to_replace = {}
@@ -152,8 +152,8 @@ class SDXL(supported_models_base.BASE):
        else:
            return model_base.ModelType.EPS

-    def get_model(self, state_dict, prefix=""):
-        return model_base.SDXL(self, model_type=self.model_type(state_dict, prefix))
+    def get_model(self, state_dict, prefix="", device=None):
+        return model_base.SDXL(self, model_type=self.model_type(state_dict, prefix), device=device)

    def process_clip_state_dict(self, state_dict):
        keys_to_replace = {}
@@ -53,13 +53,13 @@ class BASE:
        for x in self.unet_extra_config:
            self.unet_config[x] = self.unet_extra_config[x]

-    def get_model(self, state_dict, prefix=""):
+    def get_model(self, state_dict, prefix="", device=None):
        if self.inpaint_model():
-            return model_base.SDInpaint(self, model_type=self.model_type(state_dict, prefix))
+            return model_base.SDInpaint(self, model_type=self.model_type(state_dict, prefix), device=device)
        elif self.noise_aug_config is not None:
-            return model_base.SD21UNCLIP(self, self.noise_aug_config, model_type=self.model_type(state_dict, prefix))
+            return model_base.SD21UNCLIP(self, self.noise_aug_config, model_type=self.model_type(state_dict, prefix), device=device)
        else:
-            return model_base.BaseModel(self, model_type=self.model_type(state_dict, prefix))
+            return model_base.BaseModel(self, model_type=self.model_type(state_dict, prefix), device=device)

    def process_clip_state_dict(self, state_dict):
        return state_dict
@@ -6,6 +6,8 @@ import folder_paths
import json
import os

+from comfy.cli_args import args
+
class ModelMergeSimple:
    @classmethod
    def INPUT_TYPES(s):
@@ -101,8 +103,7 @@ class CheckpointSave:
        if prompt is not None:
            prompt_info = json.dumps(prompt)

-        metadata = {"prompt": prompt_info}
-
+        metadata = {}

        enable_modelspec = True
        if isinstance(model.model, comfy.model_base.SDXL):
@@ -127,9 +128,11 @@ class CheckpointSave:
        elif model.model.model_type == comfy.model_base.ModelType.V_PREDICTION:
            metadata["modelspec.predict_key"] = "v"

-        if extra_pnginfo is not None:
-            for x in extra_pnginfo:
-                metadata[x] = json.dumps(extra_pnginfo[x])
+        if not args.disable_metadata:
+            metadata["prompt"] = prompt_info
+            if extra_pnginfo is not None:
+                for x in extra_pnginfo:
+                    metadata[x] = json.dumps(extra_pnginfo[x])

        output_checkpoint = f"{filename}_{counter:05}_.safetensors"
        output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
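`CheckpointSave` now only embeds the prompt and workflow when metadata is enabled, and `SaveLatent`/`SaveImage` in `nodes.py` below gain the same gate. Assuming the flag is exposed on the command line as `--disable-metadata` (it is read from `comfy.cli_args`), a metadata-free run would be launched with:

```python main.py --disable-metadata```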
@@ -40,7 +40,8 @@ def cuda_malloc_supported():
    blacklist = {"GeForce GTX TITAN X", "GeForce GTX 980", "GeForce GTX 970", "GeForce GTX 960", "GeForce GTX 950", "GeForce 945M",
                "GeForce 940M", "GeForce 930M", "GeForce 920M", "GeForce 910M", "GeForce GTX 750", "GeForce GTX 745", "Quadro K620",
                "Quadro K1200", "Quadro K2200", "Quadro M500", "Quadro M520", "Quadro M600", "Quadro M620", "Quadro M1000",
-                "Quadro M1200", "Quadro M2000", "Quadro M2200", "Quadro M3000", "Quadro M4000", "Quadro M5000", "Quadro M5500", "Quadro M6000"}
+                "Quadro M1200", "Quadro M2000", "Quadro M2200", "Quadro M3000", "Quadro M4000", "Quadro M5000", "Quadro M5500", "Quadro M6000",
+                "GeForce MX110", "GeForce MX130", "GeForce 830M", "GeForce 840M", "GeForce GTX 850M", "GeForce GTX 860M"}

    try:
        names = get_gpu_names()
execution.py
@@ -42,11 +42,14 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da

def map_node_over_list(obj, input_data_all, func, allow_interrupt=False):
    # check if node wants the lists
-    intput_is_list = False
+    input_is_list = False
    if hasattr(obj, "INPUT_IS_LIST"):
-        intput_is_list = obj.INPUT_IS_LIST
+        input_is_list = obj.INPUT_IS_LIST

-    max_len_input = max([len(x) for x in input_data_all.values()])
+    if len(input_data_all) == 0:
+        max_len_input = 0
+    else:
+        max_len_input = max([len(x) for x in input_data_all.values()])

    # get a slice of inputs, repeat last input when list isn't long enough
    def slice_dict(d, i):
@@ -56,11 +59,15 @@ def map_node_over_list(obj, input_data_all, func, allow_interrupt=False):
        return d_new

    results = []
-    if intput_is_list:
+    if input_is_list:
        if allow_interrupt:
            nodes.before_node_execution()
        results.append(getattr(obj, func)(**input_data_all))
-    else:
+    elif max_len_input == 0:
+        if allow_interrupt:
+            nodes.before_node_execution()
+        results.append(getattr(obj, func)())
+    else:
        for i in range(max_len_input):
            if allow_interrupt:
                nodes.before_node_execution()
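The new branch matters because `max()` over an empty sequence raises, so a node with no inputs previously crashed the executor; now the node function is simply invoked once with no arguments. For example:

```
input_data_all = {}

# max([len(x) for x in input_data_all.values()])  ->  ValueError: max() arg is an empty sequence
max_len_input = 0 if len(input_data_all) == 0 else max(len(x) for x in input_data_all.values())
```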
main.py
@@ -160,6 +160,8 @@ if __name__ == "__main__":
    if args.auto_launch:
        def startup_server(address, port):
            import webbrowser
+            if os.name == 'nt' and address == '0.0.0.0':
+                address = '127.0.0.1'
            webbrowser.open(f"http://{address}:{port}")
        call_on_start = startup_server

nodes.py
@@ -26,6 +26,8 @@ import comfy.utils
import comfy.clip_vision

import comfy.model_management
+from comfy.cli_args import args

import importlib

import folder_paths
@@ -352,12 +354,22 @@ class SaveLatent:
        if prompt is not None:
            prompt_info = json.dumps(prompt)

-        metadata = {"prompt": prompt_info}
-        if extra_pnginfo is not None:
-            for x in extra_pnginfo:
-                metadata[x] = json.dumps(extra_pnginfo[x])
+        metadata = None
+        if not args.disable_metadata:
+            metadata = {"prompt": prompt_info}
+            if extra_pnginfo is not None:
+                for x in extra_pnginfo:
+                    metadata[x] = json.dumps(extra_pnginfo[x])

        file = f"{filename}_{counter:05}_.latent"

+        results = list()
+        results.append({
+            "filename": file,
+            "subfolder": subfolder,
+            "type": "output"
+        })
+
        file = os.path.join(full_output_folder, file)

        output = {}
@@ -365,7 +377,7 @@ class SaveLatent:
        output["latent_format_version_0"] = torch.tensor([])

        comfy.utils.save_torch_file(output, file, metadata=metadata)
-        return {}
+        return { "ui": { "latents": results } }


class LoadLatent:
@@ -1043,6 +1055,47 @@ class LatentComposite:
        samples_out["samples"] = s
        return (samples_out,)

+class LatentBlend:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {
+            "samples1": ("LATENT",),
+            "samples2": ("LATENT",),
+            "blend_factor": ("FLOAT", {
+                "default": 0.5,
+                "min": 0,
+                "max": 1,
+                "step": 0.01
+            }),
+        }}
+
+    RETURN_TYPES = ("LATENT",)
+    FUNCTION = "blend"
+
+    CATEGORY = "_for_testing"
+
+    def blend(self, samples1, samples2, blend_factor:float, blend_mode: str="normal"):
+
+        samples_out = samples1.copy()
+        samples1 = samples1["samples"]
+        samples2 = samples2["samples"]
+
+        if samples1.shape != samples2.shape:
+            samples2.permute(0, 3, 1, 2)
+            samples2 = comfy.utils.common_upscale(samples2, samples1.shape[3], samples1.shape[2], 'bicubic', crop='center')
+            samples2.permute(0, 2, 3, 1)
+
+        samples_blended = self.blend_mode(samples1, samples2, blend_mode)
+        samples_blended = samples1 * blend_factor + samples_blended * (1 - blend_factor)
+        samples_out["samples"] = samples_blended
+        return (samples_out,)
+
+    def blend_mode(self, img1, img2, mode):
+        if mode == "normal":
+            return img2
+        else:
+            raise ValueError(f"Unsupported blend mode: {mode}")
+
class LatentCrop:
    @classmethod
    def INPUT_TYPES(s):
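`LatentBlend` mixes two latent batches by linear interpolation; the factor weights the first input, so `blend_factor=1.0` returns `samples1` unchanged and `0.0` returns the blended second input (only the "normal" mode exists so far). The core arithmetic as a standalone sketch:

```
import torch

samples1 = torch.randn(1, 4, 64, 64)
samples2 = torch.randn(1, 4, 64, 64)
blend_factor = 0.5

# "normal" mode uses samples2 as-is; other modes would transform it first.
blended = samples1 * blend_factor + samples2 * (1 - blend_factor)
```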
@@ -1214,12 +1267,14 @@ class SaveImage:
        for image in images:
            i = 255. * image.cpu().numpy()
            img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
-            metadata = PngInfo()
-            if prompt is not None:
-                metadata.add_text("prompt", json.dumps(prompt))
-            if extra_pnginfo is not None:
-                for x in extra_pnginfo:
-                    metadata.add_text(x, json.dumps(extra_pnginfo[x]))
+            metadata = None
+            if not args.disable_metadata:
+                metadata = PngInfo()
+                if prompt is not None:
+                    metadata.add_text("prompt", json.dumps(prompt))
+                if extra_pnginfo is not None:
+                    for x in extra_pnginfo:
+                        metadata.add_text(x, json.dumps(extra_pnginfo[x]))

            file = f"{filename}_{counter:05}_.png"
            img.save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=4)
@@ -1487,6 +1542,7 @@ NODE_CLASS_MAPPINGS = {
    "KSamplerAdvanced": KSamplerAdvanced,
    "SetLatentNoiseMask": SetLatentNoiseMask,
    "LatentComposite": LatentComposite,
+    "LatentBlend": LatentBlend,
    "LatentRotate": LatentRotate,
    "LatentFlip": LatentFlip,
    "LatentCrop": LatentCrop,
@@ -1558,6 +1614,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
    "LatentUpscale": "Upscale Latent",
    "LatentUpscaleBy": "Upscale Latent By",
    "LatentComposite": "Latent Composite",
+    "LatentBlend": "Latent Blend",
    "LatentFromBatch" : "Latent From Batch",
    "RepeatLatentBatch": "Repeat Latent Batch",
    # Image
@@ -159,13 +159,64 @@
    "\n"
   ]
  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "kkkkkkkkkkkkkkk"
+   },
+   "source": [
+    "### Run ComfyUI with cloudflared (Recommended Way)\n",
+    "\n",
+    "\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "jjjjjjjjjjjjjj"
+   },
+   "outputs": [],
+   "source": [
+    "!wget https://github.com/cloudflare/cloudflared/releases/latest/download/cloudflared-linux-amd64.deb\n",
+    "!dpkg -i cloudflared-linux-amd64.deb\n",
+    "\n",
+    "import subprocess\n",
+    "import threading\n",
+    "import time\n",
+    "import socket\n",
+    "import urllib.request\n",
+    "\n",
+    "def iframe_thread(port):\n",
+    "  while True:\n",
+    "    time.sleep(0.5)\n",
+    "    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)\n",
+    "    result = sock.connect_ex(('127.0.0.1', port))\n",
+    "    if result == 0:\n",
+    "      break\n",
+    "    sock.close()\n",
+    "  print(\"\\nComfyUI finished loading, trying to launch cloudflared (if it gets stuck here cloudflared is having issues)\\n\")\n",
+    "\n",
+    "  p = subprocess.Popen([\"cloudflared\", \"tunnel\", \"--url\", \"http://127.0.0.1:{}\".format(port)], stdout=subprocess.PIPE, stderr=subprocess.PIPE)\n",
+    "  for line in p.stderr:\n",
+    "    l = line.decode()\n",
+    "    if \"trycloudflare.com \" in l:\n",
+    "      print(\"This is the URL to access ComfyUI:\", l[l.find(\"http\"):], end='')\n",
+    "    #print(l, end='')\n",
+    "\n",
+    "\n",
+    "threading.Thread(target=iframe_thread, daemon=True, args=(8188,)).start()\n",
+    "\n",
+    "!python main.py --dont-print-server"
+   ]
+  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "kkkkkkkkkkkkkk"
   },
   "source": [
-    "### Run ComfyUI with localtunnel (Recommended Way)\n",
+    "### Run ComfyUI with localtunnel\n",
    "\n",
    "\n"
   ]
@@ -1,5 +1,4 @@
torch
-torchdiffeq
torchsde
einops
transformers>=4.25.1
@@ -345,6 +345,11 @@ class PromptServer():
            vram_total, torch_vram_total = comfy.model_management.get_total_memory(device, torch_total_too=True)
            vram_free, torch_vram_free = comfy.model_management.get_free_memory(device, torch_free_too=True)
            system_stats = {
+                "system": {
+                    "os": os.name,
+                    "python_version": sys.version,
+                    "embedded_python": os.path.split(os.path.split(sys.executable)[0])[1] == "python_embeded"
+                },
                "devices": [
                    {
                        "name": device_name,
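The `/system_stats` endpoint therefore reports basic host information next to the per-device memory numbers. A sketch of querying it (field names taken from the hunk; the port assumes the default 8188):

```
import json, urllib.request

with urllib.request.urlopen("http://127.0.0.1:8188/system_stats") as r:
    stats = json.load(r)

print(stats["system"]["os"], stats["system"]["python_version"])
print(stats["devices"][0]["name"])
```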
@@ -1,4 +1,4 @@
-import {app} from "/scripts/app.js";
+import {app} from "../../scripts/app.js";

// Adds filtering to combo context menus

@@ -27,10 +27,13 @@ const ext = {
            const clickedComboValue = currentNode.widgets
                .filter(w => w.type === "combo" && w.options.values.length === values.length)
                .find(w => w.options.values.every((v, i) => v === values[i]))
-                .value;
+                ?.value;

-            let selectedIndex = values.findIndex(v => v === clickedComboValue);
-            let selectedItem = displayedItems?.[selectedIndex];
+            let selectedIndex = clickedComboValue ? values.findIndex(v => v === clickedComboValue) : 0;
+            if (selectedIndex < 0) {
+                selectedIndex = 0;
+            }
+            let selectedItem = displayedItems[selectedIndex];
            updateSelected();

            // Apply highlighting to the selected item
web/extensions/core/linkRenderMode.js (new file)
@@ -0,0 +1,25 @@
+import { app } from "/scripts/app.js";
+
+const id = "Comfy.LinkRenderMode";
+const ext = {
+    name: id,
+    async setup(app) {
+        app.ui.settings.addSetting({
+            id,
+            name: "Link Render Mode",
+            defaultValue: 2,
+            type: "combo",
+            options: LiteGraph.LINK_RENDER_MODES.map((m, i) => ({
+                value: i,
+                text: m,
+                selected: i == app.canvas.links_render_mode,
+            })),
+            onChange(value) {
+                app.canvas.links_render_mode = +value;
+                app.graph.setDirtyCanvas(true);
+            },
+        });
+    },
+};
+
+app.registerExtension(ext);
@@ -2,7 +2,7 @@ import { ComfyWidgets, addValueControlWidget } from "../../scripts/widgets.js";
import { app } from "../../scripts/app.js";

const CONVERTED_TYPE = "converted-widget";
-const VALID_TYPES = ["STRING", "combo", "number"];
+const VALID_TYPES = ["STRING", "combo", "number", "BOOLEAN"];

function isConvertableWidget(widget, config) {
    return VALID_TYPES.includes(widget.type) || VALID_TYPES.includes(config[0]);
@@ -9835,7 +9835,11 @@ LGraphNode.prototype.executeAction = function(action)
                ctx.textAlign = "center";
                ctx.fillStyle = text_color;
                ctx.fillText(
-                    w.label || w.name + " " + Number(w.value).toFixed(3),
+                    w.label || w.name + " " + Number(w.value).toFixed(
+                        w.options.precision != null
+                            ? w.options.precision
+                            : 3
+                    ),
                    widget_width * 0.5,
                    y + H * 0.7
                );
@@ -13835,7 +13839,7 @@ LGraphNode.prototype.executeAction = function(action)
        if (!disabled) {
            element.addEventListener("click", inner_onclick);
        }
-        if (options.autoopen) {
+        if (!disabled && options.autoopen) {
            LiteGraph.pointerListenerAdd(element,"enter",inner_over);
        }

@@ -264,6 +264,15 @@ class ComfyApi extends EventTarget {
            }
        }
    }

+    /**
+     * Gets system & device stats
+     * @returns System stats such as python version, OS, per device info
+     */
+    async getSystemStats() {
+        const res = await this.fetchApi("/system_stats");
+        return await res.json();
+    }
+
    /**
     * Sends a POST request to the API
     * @param {*} type The endpoint to post to
@@ -1,3 +1,4 @@
+import { ComfyLogging } from "./logging.js";
import { ComfyWidgets } from "./widgets.js";
import { ComfyUI, $el } from "./ui.js";
import { api } from "./api.js";
@@ -31,6 +32,7 @@ export class ComfyApp {

    constructor() {
        this.ui = new ComfyUI(this);
+        this.logging = new ComfyLogging(this);

        /**
         * List of extensions that are registered with the app
@@ -768,6 +770,19 @@ export class ComfyApp {
                }
                block_default = true;
            }
+
+            if (e.keyCode == 66 && e.ctrlKey) {
+                if (this.selected_nodes) {
+                    for (var i in this.selected_nodes) {
+                        if (this.selected_nodes[i].mode === 4) { // never
+                            this.selected_nodes[i].mode = 0; // always
+                        } else {
+                            this.selected_nodes[i].mode = 4; // never
+                        }
+                    }
+                }
+                block_default = true;
+            }
        }

        this.graph.change();
@@ -914,14 +929,21 @@ export class ComfyApp {
        const origDrawNode = LGraphCanvas.prototype.drawNode;
        LGraphCanvas.prototype.drawNode = function (node, ctx) {
            var editor_alpha = this.editor_alpha;
+            var old_color = node.bgcolor;

            if (node.mode === 2) { // never
                this.editor_alpha = 0.4;
            }

+            if (node.mode === 4) { // never
+                node.bgcolor = "#FF00FF";
+                this.editor_alpha = 0.2;
+            }
+
            const res = origDrawNode.apply(this, arguments);

            this.editor_alpha = editor_alpha;
+            node.bgcolor = old_color;

            return res;
        };
@@ -1003,6 +1025,7 @@ export class ComfyApp {
     */
    async #loadExtensions() {
        const extensions = await api.getExtensions();
+        this.logging.addEntry("Comfy.App", "debug", { Extensions: extensions });
        for (const ext of extensions) {
            try {
                await import(api.apiURL(ext));
@@ -1286,6 +1309,9 @@ export class ComfyApp {
                    (t) => `<li>${t}</li>`
                ).join("")}</ul>Nodes that have failed to load will show as red on the graph.`
            );
+            this.logging.addEntry("Comfy.App", "warn", {
+                MissingNodes: missingNodeTypes,
+            });
        }
    }

@@ -1308,7 +1334,7 @@ export class ComfyApp {
                continue;
            }

-            if (node.mode === 2) {
+            if (node.mode === 2 || node.mode === 4) {
                // Don't serialize muted nodes
                continue;
            }
@@ -1331,12 +1357,36 @@ export class ComfyApp {
                    let parent = node.getInputNode(i);
                    if (parent) {
                        let link = node.getInputLink(i);
-                        while (parent && parent.isVirtualNode) {
-                            link = parent.getInputLink(link.origin_slot);
-                            if (link) {
-                                parent = parent.getInputNode(link.origin_slot);
-                            } else {
-                                parent = null;
+                        while (parent.mode === 4 || parent.isVirtualNode) {
+                            let found = false;
+                            if (parent.isVirtualNode) {
+                                link = parent.getInputLink(link.origin_slot);
+                                if (link) {
+                                    parent = parent.getInputNode(link.target_slot);
+                                    if (parent) {
+                                        found = true;
+                                    }
+                                }
+                            } else if (link && parent.mode === 4) {
+                                let all_inputs = [link.origin_slot];
+                                if (parent.inputs) {
+                                    all_inputs = all_inputs.concat(Object.keys(parent.inputs))
+                                    for (let parent_input in all_inputs) {
+                                        parent_input = all_inputs[parent_input];
+                                        if (parent.inputs[parent_input].type === node.inputs[i].type) {
+                                            link = parent.getInputLink(parent_input);
+                                            if (link) {
+                                                parent = parent.getInputNode(parent_input);
+                                            }
+                                            found = true;
+                                            break;
+                                        }
+                                    }
+                                }
+                            }
+
+                            if (!found) {
+                                break;
+                            }
                        }
                    }

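When serializing the graph to a prompt, bypassed nodes (`mode === 4`) are now skipped by walking upstream until a live producer is found: a virtual node is traversed through its single incoming link, while a bypassed node is searched for the first input whose type matches the consuming input. A simplified sketch of that search, written in Python for brevity (the helper names are illustrative, not the frontend's API):

```
def resolve_parent(node, i, get_input_node, get_input_link):
    """Follow links past virtual and bypassed nodes to a live producer."""
    parent = get_input_node(node, i)
    link = get_input_link(node, i)
    while parent is not None and (parent.mode == 4 or parent.is_virtual):
        found = False
        if parent.is_virtual:
            link = get_input_link(parent, link.origin_slot)
            if link:
                parent = get_input_node(parent, link.target_slot)
                found = parent is not None
        elif link and parent.mode == 4:
            # pick the first input of the bypassed node with a matching type
            for slot in parent.inputs:
                if slot.type == node.inputs[i].type:
                    link = get_input_link(parent, slot.name)
                    if link:
                        parent = get_input_node(parent, slot.name)
                    found = True
                    break
        if not found:
            break
    return parent
```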
web/scripts/logging.js (new file)
@@ -0,0 +1,367 @@
+import { $el, ComfyDialog } from "./ui.js";
+import { api } from "./api.js";
+
+$el("style", {
+    textContent: `
+    .comfy-logging-logs {
+        display: grid;
+        color: var(--fg-color);
+        white-space: pre-wrap;
+    }
+    .comfy-logging-log {
+        display: contents;
+    }
+    .comfy-logging-title {
+        background: var(--tr-even-bg-color);
+        font-weight: bold;
+        margin-bottom: 5px;
+        text-align: center;
+    }
+    .comfy-logging-log div {
+        background: var(--row-bg);
+        padding: 5px;
+    }
+    `,
+    parent: document.body,
+});
+
+// Stringify function supporting max depth and removal of circular references
+// https://stackoverflow.com/a/57193345
+function stringify(val, depth, replacer, space, onGetObjID) {
+    depth = isNaN(+depth) ? 1 : depth;
+    var recursMap = new WeakMap();
+    function _build(val, depth, o, a, r) {
+        // (JSON.stringify() has it's own rules, which we respect here by using it for property iteration)
+        return !val || typeof val != "object"
+            ? val
+            : ((r = recursMap.has(val)),
+              recursMap.set(val, true),
+              (a = Array.isArray(val)),
+              r
+                ? (o = (onGetObjID && onGetObjID(val)) || null)
+                : JSON.stringify(val, function (k, v) {
+                        if (a || depth > 0) {
+                            if (replacer) v = replacer(k, v);
+                            if (!k) return (a = Array.isArray(v)), (val = v);
+                            !o && (o = a ? [] : {});
+                            o[k] = _build(v, a ? depth : depth - 1);
+                        }
+                  }),
+              o === void 0 ? (a ? [] : {}) : o);
+    }
+    return JSON.stringify(_build(val, depth), null, space);
+}
+
+const jsonReplacer = (k, v, ui) => {
+    if (v instanceof Array && v.length === 1) {
+        v = v[0];
+    }
+    if (v instanceof Date) {
+        v = v.toISOString();
+        if (ui) {
+            v = v.split("T")[1];
+        }
+    }
+    if (v instanceof Error) {
+        let err = "";
+        if (v.name) err += v.name + "\n";
+        if (v.message) err += v.message + "\n";
+        if (v.stack) err += v.stack + "\n";
+        if (!err) {
+            err = v.toString();
+        }
+        v = err;
+    }
+    return v;
+};
+
+const fileInput = $el("input", {
+    type: "file",
+    accept: ".json",
+    style: { display: "none" },
+    parent: document.body,
+});
+
+class ComfyLoggingDialog extends ComfyDialog {
+    constructor(logging) {
+        super();
+        this.logging = logging;
+    }
+
+    clear() {
+        this.logging.clear();
+        this.show();
+    }
+
+    export() {
+        const blob = new Blob([stringify([...this.logging.entries], 20, jsonReplacer, "\t")], {
+            type: "application/json",
+        });
+        const url = URL.createObjectURL(blob);
+        const a = $el("a", {
+            href: url,
+            download: `comfyui-logs-${Date.now()}.json`,
+            style: { display: "none" },
+            parent: document.body,
+        });
+        a.click();
+        setTimeout(function () {
+            a.remove();
+            window.URL.revokeObjectURL(url);
+        }, 0);
+    }
+
+    import() {
+        fileInput.onchange = () => {
+            const reader = new FileReader();
+            reader.onload = () => {
+                fileInput.remove();
+                try {
+                    const obj = JSON.parse(reader.result);
+                    if (obj instanceof Array) {
+                        this.show(obj);
+                    } else {
+                        throw new Error("Invalid file selected.");
+                    }
+                } catch (error) {
+                    alert("Unable to load logs: " + error.message);
+                }
+            };
+            reader.readAsText(fileInput.files[0]);
+        };
+        fileInput.click();
+    }
+
+    createButtons() {
+        return [
+            $el("button", {
+                type: "button",
+                textContent: "Clear",
+                onclick: () => this.clear(),
+            }),
+            $el("button", {
+                type: "button",
+                textContent: "Export logs...",
+                onclick: () => this.export(),
+            }),
+            $el("button", {
+                type: "button",
+                textContent: "View exported logs...",
+                onclick: () => this.import(),
+            }),
+            ...super.createButtons(),
+        ];
+    }
+
+    getTypeColor(type) {
+        switch (type) {
+            case "error":
+                return "red";
+            case "warn":
+                return "orange";
+            case "debug":
+                return "dodgerblue";
+        }
+    }
+
+    show(entries) {
+        if (!entries) entries = this.logging.entries;
+        this.element.style.width = "100%";
+        const cols = {
+            source: "Source",
+            type: "Type",
+            timestamp: "Timestamp",
+            message: "Message",
+        };
+        const keys = Object.keys(cols);
+        const headers = Object.values(cols).map((title) =>
+            $el("div.comfy-logging-title", {
+                textContent: title,
+            })
+        );
+        const rows = entries.map((entry, i) => {
+            return $el(
+                "div.comfy-logging-log",
+                {
+                    $: (el) => el.style.setProperty("--row-bg", `var(--tr-${i % 2 ? "even" : "odd"}-bg-color)`),
+                },
+                keys.map((key) => {
+                    let v = entry[key];
+                    let color;
+                    if (key === "type") {
+                        color = this.getTypeColor(v);
+                    } else {
+                        v = jsonReplacer(key, v, true);
+
+                        if (typeof v === "object") {
+                            v = stringify(v, 5, jsonReplacer, "  ");
+                        }
+                    }
+
+                    return $el("div", {
+                        style: {
+                            color,
+                        },
+                        textContent: v,
+                    });
+                })
+            );
+        });
+
+        const grid = $el(
+            "div.comfy-logging-logs",
+            {
+                style: {
+                    gridTemplateColumns: `repeat(${headers.length}, 1fr)`,
+                },
+            },
+            [...headers, ...rows]
+        );
+        const els = [grid];
+        if (!this.logging.enabled) {
+            els.unshift(
+                $el("h3", {
+                    style: { textAlign: "center" },
+                    textContent: "Logging is disabled",
+                })
+            );
+        }
+        super.show($el("div", els));
+    }
+}
+
+export class ComfyLogging {
+    /**
+     * @type Array<{ source: string, type: string, timestamp: Date, message: any }>
+     */
+    entries = [];
+
+    #enabled;
+    #console = {};
+
+    get enabled() {
+        return this.#enabled;
+    }
+
+    set enabled(value) {
+        if (value === this.#enabled) return;
+        if (value) {
+            this.patchConsole();
+        } else {
+            this.unpatchConsole();
+        }
+        this.#enabled = value;
+    }
+
+    constructor(app) {
+        this.app = app;
+
+        this.dialog = new ComfyLoggingDialog(this);
+        this.addSetting();
+        this.catchUnhandled();
+        this.addInitData();
+    }
+
+    addSetting() {
+        const settingId = "Comfy.Logging.Enabled";
+        const htmlSettingId = settingId.replaceAll(".", "-");
+        const setting = this.app.ui.settings.addSetting({
+            id: settingId,
+            name: settingId,
+            defaultValue: true,
+            type: (name, setter, value) => {
+                return $el("tr", [
+                    $el("td", [
+                        $el("label", {
+                            textContent: "Logging",
+                            for: htmlSettingId,
+                        }),
+                    ]),
+                    $el("td", [
+                        $el("input", {
+                            id: htmlSettingId,
+                            type: "checkbox",
+                            checked: value,
+                            onchange: (event) => {
+                                setter((this.enabled = event.target.checked));
+                            },
+                        }),
+                        $el("button", {
+                            textContent: "View Logs",
+                            onclick: () => {
+                                this.app.ui.settings.element.close();
+                                this.dialog.show();
+                            },
+                            style: {
+                                fontSize: "14px",
+                                display: "block",
+                                marginTop: "5px",
+                            },
+                        }),
+                    ]),
+                ]);
+            },
+        });
+        this.enabled = setting.value;
+    }
+
+    patchConsole() {
+        // Capture common console outputs
+        const self = this;
+        for (const type of ["log", "warn", "error", "debug"]) {
+            const orig = console[type];
+            this.#console[type] = orig;
+            console[type] = function () {
+                orig.apply(console, arguments);
+                self.addEntry("console", type, ...arguments);
+            };
+        }
+    }
+
+    unpatchConsole() {
+        // Restore original console functions
+        for (const type of Object.keys(this.#console)) {
+            console[type] = this.#console[type];
+        }
+        this.#console = {};
+    }
+
+    catchUnhandled() {
+        // Capture uncaught errors
+        window.addEventListener("error", (e) => {
+            this.addEntry("window", "error", e.error ?? "Unknown error");
+            return false;
+        });
+
+        window.addEventListener("unhandledrejection", (e) => {
+            this.addEntry("unhandledrejection", "error", e.reason ?? "Unknown error");
+        });
+    }
+
+    clear() {
+        this.entries = [];
+    }
+
+    addEntry(source, type, ...args) {
+        if (this.enabled) {
+            this.entries.push({
+                source,
+                type,
+                timestamp: new Date(),
+                message: args,
+            });
+        }
+    }
+
+    log(source, ...args) {
+        this.addEntry(source, "log", ...args);
+    }
+
+    async addInitData() {
+        if (!this.enabled) return;
+        const source = "ComfyUI.Logging";
+        this.addEntry(source, "debug", { UserAgent: navigator.userAgent });
+        const systemStats = await api.getSystemStats();
+        this.addEntry(source, "debug", systemStats);
+    }
+}
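The new `logging.js` keeps an in-browser log: it wraps `console.log`/`warn`/`error`/`debug`, records uncaught errors and unhandled promise rejections, and exposes the entries through a settings dialog with JSON export and import. The console-wrapping idea, sketched in Python (illustrative; the extension itself is JavaScript):

```
import builtins

entries = []
_orig_print = builtins.print

def print_and_record(*args, **kwargs):
    # call through to the original, then keep a copy of the arguments
    _orig_print(*args, **kwargs)
    entries.append(("console", "log", args))

builtins.print = print_and_record
```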
@@ -234,7 +234,7 @@ class ComfySettingsDialog extends ComfyDialog {
        localStorage[settingId] = JSON.stringify(value);
    }

-    addSetting({id, name, type, defaultValue, onChange, attrs = {}, tooltip = "",}) {
+    addSetting({id, name, type, defaultValue, onChange, attrs = {}, tooltip = "", options = undefined}) {
        if (!id) {
            throw new Error("Settings must have an ID");
        }
@@ -347,6 +347,32 @@ class ComfySettingsDialog extends ComfyDialog {
						]),
					]);
					break;
+				case "combo":
+					element = $el("tr", [
+						labelCell,
+						$el("td", [
+							$el(
+								"select",
+								{
+									oninput: (e) => {
+										setter(e.target.value);
+									},
+								},
+								(typeof options === "function" ? options(value) : options || []).map((opt) => {
+									if (typeof opt === "string") {
+										opt = { text: opt };
+									}
+									const v = opt.value ?? opt.text;
+									return $el("option", {
+										value: v,
+										textContent: opt.text,
+										selected: value + "" === v + "",
+									});
+								})
+							),
+						]),
+					]);
+					break;
				case "text":
				default:
					if (type !== "text") {
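The `combo` branch above consumes the `options` parameter added to the `addSetting` signature in the previous hunk: entries may be plain strings, `{ text, value }` objects, or a function of the current value. A hedged registration sketch — the setting id and option names are illustrative, not part of this change:

```
// Hypothetical combo setting — id and options are made up for illustration.
app.ui.settings.addSetting({
	id: "Example.Theme",
	name: "Example.Theme",
	type: "combo",
	defaultValue: "dark",
	options: ["dark", "light", { text: "High contrast", value: "hc" }],
	// options may also be a function receiving the current value:
	// options: (value) => ["dark", "light"],
});
```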
@@ -480,7 +506,7 @@ class ComfyList {

	hide() {
		this.element.style.display = "none";
-		this.button.textContent = "See " + this.#text;
+		this.button.textContent = "View " + this.#text;
	}

	toggle() {
@@ -542,6 +568,13 @@ export class ComfyUI {
			defaultValue: "",
		});

+		this.settings.addSetting({
+			id: "Comfy.DisableSliders",
+			name: "Disable sliders.",
+			type: "boolean",
+			defaultValue: false,
+		});
+
		const fileInput = $el("input", {
			id: "comfy-file-input",
			type: "file",
@@ -79,8 +79,8 @@ export function addValueControlWidget(node, targetWidget, defaultValue = "random
	return valueControl;
};

-function seedWidget(node, inputName, inputData) {
-	const seed = ComfyWidgets.INT(node, inputName, inputData);
+function seedWidget(node, inputName, inputData, app) {
+	const seed = ComfyWidgets.INT(node, inputName, inputData, app);
	const seedControl = addValueControlWidget(node, seed.widget, "randomize");

	seed.widget.linkedWidgets = [seedControl];
@@ -250,19 +250,29 @@ function addMultilineWidget(node, name, opts, app) {
	return { minWidth: 400, minHeight: 200, widget };
}

+function isSlider(display, app) {
+	if (app.ui.settings.getSettingValue("Comfy.DisableSliders")) {
+		return "number"
+	}
+
+	return (display==="slider") ? "slider" : "number"
+}
+
export const ComfyWidgets = {
	"INT:seed": seedWidget,
	"INT:noise_seed": seedWidget,
-	FLOAT(node, inputName, inputData) {
+	FLOAT(node, inputName, inputData, app) {
+		let widgetType = isSlider(inputData[1]["display"], app);
		const { val, config } = getNumberDefaults(inputData, 0.5);
-		return { widget: node.addWidget("number", inputName, val, () => {}, config) };
+		return { widget: node.addWidget(widgetType, inputName, val, () => {}, config) };
	},
-	INT(node, inputName, inputData) {
+	INT(node, inputName, inputData, app) {
+		let widgetType = isSlider(inputData[1]["display"], app);
		const { val, config } = getNumberDefaults(inputData, 1);
		Object.assign(config, { precision: 0 });
		return {
			widget: node.addWidget(
-				"number",
+				widgetType,
				inputName,
				val,
				function (v) {
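`isSlider` picks the LiteGraph widget type: an input whose config declares `display: "slider"` renders as a slider unless the user enables `Comfy.DisableSliders`, which forces the plain number widget — hence `app` now being threaded through `seedWidget`, `FLOAT`, and `INT`. A sketch of the `inputData` shape `FLOAT` receives (the concrete values are illustrative):

```
// Hypothetical inputData for ComfyWidgets.FLOAT — index 1 is the config object
// a node definition declares for this input.
const inputData = ["FLOAT", { default: 0.5, min: 0.0, max: 1.0, step: 0.01, display: "slider" }];

// isSlider(inputData[1]["display"], app) → "slider" normally,
// or "number" when Comfy.DisableSliders is enabled.
```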
@@ -273,6 +283,18 @@ export const ComfyWidgets = {
			),
		};
	},
+	BOOLEAN(node, inputName, inputData) {
+		let defaultVal = inputData[1]["default"];
+		return {
+			widget: node.addWidget(
+				"toggle",
+				inputName,
+				defaultVal,
+				() => {},
+				{"on": inputData[1].label_on, "off": inputData[1].label_off}
+			)
+		};
+	},
	STRING(node, inputName, inputData, app) {
		const defaultVal = inputData[1].default || "";
		const multiline = !!inputData[1].multiline;
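The new `BOOLEAN` entry maps a boolean input to LiteGraph's `toggle` widget; optional `label_on`/`label_off` strings replace the default on/off text. A sketch of how it would be invoked — the node and field names are illustrative:

```
// Hypothetical invocation — `node` is an existing LiteGraph node instance,
// and "add_noise" with its labels is an illustrative input definition.
const inputData = ["BOOLEAN", { default: true, label_on: "enabled", label_off: "disabled" }];
const { widget } = ComfyWidgets.BOOLEAN(node, "add_noise", inputData);
// widget.value holds the boolean; the toggle shows "enabled"/"disabled".
```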