Mirror of https://github.com/comfyanonymous/ComfyUI.git

Commit e11af79bcc: Merge branch 'comfyanonymous:master' into master
@@ -94,7 +94,7 @@ AMD users can install rocm and pytorch with pip if you don't have it already ins

 ```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/rocm5.4.2```

 This is the command to install the nightly with ROCm 5.6 that supports the 7000 series and might have some performance improvements:

-```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm5.6 -r requirements.txt```
+```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm5.6```

 ### NVIDIA

@@ -126,10 +126,10 @@ After this you should have everything installed and can proceed to running Comfy

 You can install ComfyUI in Apple Mac silicon (M1 or M2) with any recent macOS version.

-1. Install pytorch. For instructions, read the [Accelerated PyTorch training on Mac](https://developer.apple.com/metal/pytorch/) Apple Developer guide.
+1. Install pytorch nightly. For instructions, read the [Accelerated PyTorch training on Mac](https://developer.apple.com/metal/pytorch/) Apple Developer guide (make sure to install the latest pytorch nightly).
 1. Follow the [ComfyUI manual installation](#manual-install-windows-linux) instructions for Windows and Linux.
 1. Install the ComfyUI [dependencies](#dependencies). If you have another Stable Diffusion UI [you might be able to reuse the dependencies](#i-already-have-another-ui-for-stable-diffusion-installed-do-i-really-have-to-install-all-of-these-dependencies).
-1. Launch ComfyUI by running `python main.py`.
+1. Launch ComfyUI by running `python main.py --force-fp16`. Note that --force-fp16 will only work if you installed the latest pytorch nightly.

 > **Note**: Remember to add your models, VAE, LoRAs etc. to the corresponding Comfy folders, as discussed in [ComfyUI manual installation](#manual-install-windows-linux).
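A quick way to confirm the nightly requirement in the Apple silicon steps above: the snippet below (an illustrative addition, not part of the README) checks that the installed PyTorch exposes the MPS backend and can allocate fp16 tensors, which is what `--force-fp16` relies on.

```python
# Sanity check for the Apple silicon install steps above (illustrative, not ComfyUI code).
import torch

print(torch.__version__)                  # nightly builds carry a ".dev" suffix
print(torch.backends.mps.is_available())  # True on Apple silicon with a recent nightly

device = "mps" if torch.backends.mps.is_available() else "cpu"
x = torch.ones(4, dtype=torch.float16, device=device)  # fp16 allocation, as --force-fp16 needs
print(x.dtype, x.device)
```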
@@ -13,7 +13,7 @@ from ..ldm.modules.diffusionmodules.util import (
 )

 from ..ldm.modules.attention import SpatialTransformer
-from ..ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample, AttentionBlock
+from ..ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample
 from ..ldm.util import exists

@@ -57,6 +57,7 @@ class ControlNet(nn.Module):
             transformer_depth_middle=None,
         ):
         super().__init__()
+        assert use_spatial_transformer == True, "use_spatial_transformer has to be true"
         if use_spatial_transformer:
             assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'

@@ -200,13 +201,7 @@ class ControlNet(nn.Module):
                 if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
                     layers.append(
-                        AttentionBlock(
-                            ch,
-                            use_checkpoint=use_checkpoint,
-                            num_heads=num_heads,
-                            num_head_channels=dim_head,
-                            use_new_attention_order=use_new_attention_order,
-                        ) if not use_spatial_transformer else SpatialTransformer(
+                        SpatialTransformer(
                             ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim,
                             disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
                             use_checkpoint=use_checkpoint

@@ -259,13 +254,7 @@ class ControlNet(nn.Module):
                 use_checkpoint=use_checkpoint,
                 use_scale_shift_norm=use_scale_shift_norm,
             ),
-            AttentionBlock(
-                ch,
-                use_checkpoint=use_checkpoint,
-                num_heads=num_heads,
-                num_head_channels=dim_head,
-                use_new_attention_order=use_new_attention_order,
-            ) if not use_spatial_transformer else SpatialTransformer(  # always uses a self-attn
+            SpatialTransformer(  # always uses a self-attn
                 ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
                 disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
                 use_checkpoint=use_checkpoint
@@ -39,6 +39,7 @@ parser.add_argument("--enable-cors-header", type=str, default=None, metavar="ORI
 parser.add_argument("--extra-model-paths-config", type=str, default=None, metavar="PATH", nargs='+', action='append', help="Load one or more extra_model_paths.yaml files.")
 parser.add_argument("--output-directory", type=str, default=None, help="Set the ComfyUI output directory.")
 parser.add_argument("--auto-launch", action="store_true", help="Automatically launch ComfyUI in the default browser.")
+parser.add_argument("--disable-auto-launch", action="store_true", help="Disable auto launching the browser.")
 parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use.")
 cm_group = parser.add_mutually_exclusive_group()
 cm_group.add_argument("--cuda-malloc", action="store_true", help="Enable cudaMallocAsync (enabled by default for torch 2.0 and up).")

@@ -84,7 +85,12 @@ parser.add_argument("--dont-print-server", action="store_true", help="Don't prin
 parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")
 parser.add_argument("--windows-standalone-build", action="store_true", help="Windows standalone build: Enable convenient things that most people using the standalone windows build will probably enjoy (like auto opening the page on startup).")
+
+parser.add_argument("--disable-metadata", action="store_true", help="Disable saving prompt metadata in files.")

 args = parser.parse_args()

 if args.windows_standalone_build:
     args.auto_launch = True
+
+if args.disable_auto_launch:
+    args.auto_launch = False
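The interplay of the two new flags with the existing ones is easiest to see in isolation. Below is a self-contained sketch using a fresh parser rather than ComfyUI's actual `comfy/cli_args.py`; the flag names match the diff, everything else is illustrative.

```python
# Standalone sketch of the flag precedence added above.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--auto-launch", action="store_true")
parser.add_argument("--disable-auto-launch", action="store_true")
parser.add_argument("--windows-standalone-build", action="store_true")

args = parser.parse_args(["--windows-standalone-build", "--disable-auto-launch"])

# Standalone builds opt into auto-launch by default...
if args.windows_standalone_build:
    args.auto_launch = True
# ...but the new flag lets users opt back out, winning over that default.
if args.disable_auto_launch:
    args.auto_launch = False

print(args.auto_launch)  # False
```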
@@ -3,7 +3,6 @@ import math
 from scipy import integrate
 import torch
 from torch import nn
-from torchdiffeq import odeint
 import torchsde
 from tqdm.auto import trange, tqdm

@@ -287,30 +286,6 @@ def sample_lms(model, x, sigmas, extra_args=None, callback=None, disable=None, o
     return x


-@torch.no_grad()
-def log_likelihood(model, x, sigma_min, sigma_max, extra_args=None, atol=1e-4, rtol=1e-4):
-    extra_args = {} if extra_args is None else extra_args
-    s_in = x.new_ones([x.shape[0]])
-    v = torch.randint_like(x, 2) * 2 - 1
-    fevals = 0
-    def ode_fn(sigma, x):
-        nonlocal fevals
-        with torch.enable_grad():
-            x = x[0].detach().requires_grad_()
-            denoised = model(x, sigma * s_in, **extra_args)
-            d = to_d(x, sigma, denoised)
-            fevals += 1
-            grad = torch.autograd.grad((d * v).sum(), x)[0]
-            d_ll = (v * grad).flatten(1).sum(1)
-        return d.detach(), d_ll
-    x_min = x, x.new_zeros([x.shape[0]])
-    t = x.new_tensor([sigma_min, sigma_max])
-    sol = odeint(ode_fn, x_min, t, atol=atol, rtol=rtol, method='dopri5')
-    latent, delta_ll = sol[0][-1], sol[1][-1]
-    ll_prior = torch.distributions.Normal(0, sigma_max).log_prob(latent).flatten(1).sum(1)
-    return ll_prior + delta_ll, {'fevals': fevals}
-
-
 class PIDStepSizeController:
     """A PID controller for ODE adaptive step size control."""
     def __init__(self, h, pcoeff, icoeff, dcoeff, order=1, accept_safety=0.81, eps=1e-8):
@@ -52,9 +52,9 @@ def init_(tensor):

 # feedforward
 class GEGLU(nn.Module):
-    def __init__(self, dim_in, dim_out, dtype=None):
+    def __init__(self, dim_in, dim_out, dtype=None, device=None):
         super().__init__()
-        self.proj = comfy.ops.Linear(dim_in, dim_out * 2, dtype=dtype)
+        self.proj = comfy.ops.Linear(dim_in, dim_out * 2, dtype=dtype, device=device)

     def forward(self, x):
         x, gate = self.proj(x).chunk(2, dim=-1)

@@ -62,19 +62,19 @@ class GEGLU(nn.Module):


 class FeedForward(nn.Module):
-    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0., dtype=None):
+    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0., dtype=None, device=None):
         super().__init__()
         inner_dim = int(dim * mult)
         dim_out = default(dim_out, dim)
         project_in = nn.Sequential(
-            comfy.ops.Linear(dim, inner_dim, dtype=dtype),
+            comfy.ops.Linear(dim, inner_dim, dtype=dtype, device=device),
             nn.GELU()
-        ) if not glu else GEGLU(dim, inner_dim, dtype=dtype)
+        ) if not glu else GEGLU(dim, inner_dim, dtype=dtype, device=device)

         self.net = nn.Sequential(
             project_in,
             nn.Dropout(dropout),
-            comfy.ops.Linear(inner_dim, dim_out, dtype=dtype)
+            comfy.ops.Linear(inner_dim, dim_out, dtype=dtype, device=device)
         )

     def forward(self, x):

@@ -90,8 +90,8 @@ def zero_module(module):
     return module


-def Normalize(in_channels, dtype=None):
-    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype)
+def Normalize(in_channels, dtype=None, device=None):
+    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)


 class SpatialSelfAttention(nn.Module):

@@ -148,7 +148,7 @@ class SpatialSelfAttention(nn.Module):


 class CrossAttentionBirchSan(nn.Module):
-    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None):
+    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None):
         super().__init__()
         inner_dim = dim_head * heads
         context_dim = default(context_dim, query_dim)

@@ -156,12 +156,12 @@ class CrossAttentionBirchSan(nn.Module):
         self.scale = dim_head ** -0.5
         self.heads = heads

-        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype)
-        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
-        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
+        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)

         self.to_out = nn.Sequential(
-            comfy.ops.Linear(inner_dim, query_dim, dtype=dtype),
+            comfy.ops.Linear(inner_dim, query_dim, dtype=dtype, device=device),
             nn.Dropout(dropout)
         )

@@ -245,7 +245,7 @@ class CrossAttentionBirchSan(nn.Module):


 class CrossAttentionDoggettx(nn.Module):
-    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None):
+    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None):
         super().__init__()
         inner_dim = dim_head * heads
         context_dim = default(context_dim, query_dim)

@@ -253,12 +253,12 @@ class CrossAttentionDoggettx(nn.Module):
         self.scale = dim_head ** -0.5
         self.heads = heads

-        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype)
-        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
-        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
+        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)

         self.to_out = nn.Sequential(
-            comfy.ops.Linear(inner_dim, query_dim, dtype=dtype),
+            comfy.ops.Linear(inner_dim, query_dim, dtype=dtype, device=device),
             nn.Dropout(dropout)
         )

@@ -343,7 +343,7 @@ class CrossAttentionDoggettx(nn.Module):
         return self.to_out(r2)

 class CrossAttention(nn.Module):
-    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None):
+    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None):
         super().__init__()
         inner_dim = dim_head * heads
         context_dim = default(context_dim, query_dim)

@@ -351,12 +351,12 @@ class CrossAttention(nn.Module):
         self.scale = dim_head ** -0.5
         self.heads = heads

-        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype)
-        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
-        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
+        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)

         self.to_out = nn.Sequential(
-            comfy.ops.Linear(inner_dim, query_dim, dtype=dtype),
+            comfy.ops.Linear(inner_dim, query_dim, dtype=dtype, device=device),
             nn.Dropout(dropout)
         )

@@ -399,7 +399,7 @@ class CrossAttention(nn.Module):

 class MemoryEfficientCrossAttention(nn.Module):
     # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
-    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, dtype=None):
+    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, dtype=None, device=None):
         super().__init__()
         print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
               f"{heads} heads.")

@@ -409,11 +409,11 @@ class MemoryEfficientCrossAttention(nn.Module):
         self.heads = heads
         self.dim_head = dim_head

-        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype)
-        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
-        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
+        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)

-        self.to_out = nn.Sequential(comfy.ops.Linear(inner_dim, query_dim, dtype=dtype), nn.Dropout(dropout))
+        self.to_out = nn.Sequential(comfy.ops.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
         self.attention_op: Optional[Any] = None

     def forward(self, x, context=None, value=None, mask=None):

@@ -450,7 +450,7 @@ class MemoryEfficientCrossAttention(nn.Module):
         return self.to_out(out)

 class CrossAttentionPytorch(nn.Module):
-    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None):
+    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None):
         super().__init__()
         inner_dim = dim_head * heads
         context_dim = default(context_dim, query_dim)

@@ -458,11 +458,11 @@ class CrossAttentionPytorch(nn.Module):
         self.heads = heads
         self.dim_head = dim_head

-        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype)
-        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
-        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
+        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)

-        self.to_out = nn.Sequential(comfy.ops.Linear(inner_dim, query_dim, dtype=dtype), nn.Dropout(dropout))
+        self.to_out = nn.Sequential(comfy.ops.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
         self.attention_op: Optional[Any] = None

     def forward(self, x, context=None, value=None, mask=None):

@@ -508,17 +508,17 @@ else:

 class BasicTransformerBlock(nn.Module):
     def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
-                 disable_self_attn=False, dtype=None):
+                 disable_self_attn=False, dtype=None, device=None):
         super().__init__()
         self.disable_self_attn = disable_self_attn
         self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
-                              context_dim=context_dim if self.disable_self_attn else None, dtype=dtype)  # is a self-attention if not self.disable_self_attn
-        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff, dtype=dtype)
+                              context_dim=context_dim if self.disable_self_attn else None, dtype=dtype, device=device)  # is a self-attention if not self.disable_self_attn
+        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device)
         self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
-                              heads=n_heads, dim_head=d_head, dropout=dropout, dtype=dtype)  # is self-attn if context is none
-        self.norm1 = nn.LayerNorm(dim, dtype=dtype)
-        self.norm2 = nn.LayerNorm(dim, dtype=dtype)
-        self.norm3 = nn.LayerNorm(dim, dtype=dtype)
+                              heads=n_heads, dim_head=d_head, dropout=dropout, dtype=dtype, device=device)  # is self-attn if context is none
+        self.norm1 = nn.LayerNorm(dim, dtype=dtype, device=device)
+        self.norm2 = nn.LayerNorm(dim, dtype=dtype, device=device)
+        self.norm3 = nn.LayerNorm(dim, dtype=dtype, device=device)
         self.checkpoint = checkpoint
         self.n_heads = n_heads
         self.d_head = d_head

@@ -648,34 +648,34 @@ class SpatialTransformer(nn.Module):
     def __init__(self, in_channels, n_heads, d_head,
                  depth=1, dropout=0., context_dim=None,
                  disable_self_attn=False, use_linear=False,
-                 use_checkpoint=True, dtype=None):
+                 use_checkpoint=True, dtype=None, device=None):
         super().__init__()
         if exists(context_dim) and not isinstance(context_dim, list):
             context_dim = [context_dim] * depth
         self.in_channels = in_channels
         inner_dim = n_heads * d_head
-        self.norm = Normalize(in_channels, dtype=dtype)
+        self.norm = Normalize(in_channels, dtype=dtype, device=device)
         if not use_linear:
             self.proj_in = nn.Conv2d(in_channels,
                                      inner_dim,
                                      kernel_size=1,
                                      stride=1,
-                                     padding=0, dtype=dtype)
+                                     padding=0, dtype=dtype, device=device)
         else:
-            self.proj_in = comfy.ops.Linear(in_channels, inner_dim, dtype=dtype)
+            self.proj_in = comfy.ops.Linear(in_channels, inner_dim, dtype=dtype, device=device)

         self.transformer_blocks = nn.ModuleList(
             [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
-                                   disable_self_attn=disable_self_attn, checkpoint=use_checkpoint, dtype=dtype)
+                                   disable_self_attn=disable_self_attn, checkpoint=use_checkpoint, dtype=dtype, device=device)
                 for d in range(depth)]
         )
         if not use_linear:
             self.proj_out = nn.Conv2d(inner_dim,in_channels,
                                       kernel_size=1,
                                       stride=1,
-                                      padding=0, dtype=dtype)
+                                      padding=0, dtype=dtype, device=device)
         else:
-            self.proj_out = comfy.ops.Linear(in_channels, inner_dim, dtype=dtype)
+            self.proj_out = comfy.ops.Linear(in_channels, inner_dim, dtype=dtype, device=device)
         self.use_linear = use_linear

     def forward(self, x, context=None, transformer_options={}):
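All of these changes thread a `device=` keyword through the attention constructors. The point is visible with plain `torch.nn`: constructing a layer directly on the target device skips the CPU allocation that a later `.to(device)` would otherwise have to copy from. A minimal sketch (not ComfyUI code):

```python
# Why thread device= through constructors: avoid allocate-on-CPU-then-copy.
import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Allocate-then-move: weights are first built on the CPU, then copied over.
slow = nn.Linear(4096, 4096).to(device)

# Direct construction on the device: no intermediate CPU tensor.
fast = nn.Linear(4096, 4096, device=device)

print(slow.weight.device, fast.weight.device)  # both on `device`, built differently
```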
@@ -8,6 +8,7 @@ from typing import Optional, Any

 from ..attention import MemoryEfficientCrossAttention
 from comfy import model_management
+import comfy.ops

 if model_management.xformers_enabled_vae():
     import xformers

@@ -48,7 +49,7 @@ class Upsample(nn.Module):
         super().__init__()
         self.with_conv = with_conv
         if self.with_conv:
-            self.conv = torch.nn.Conv2d(in_channels,
+            self.conv = comfy.ops.Conv2d(in_channels,
                                         in_channels,
                                         kernel_size=3,
                                         stride=1,

@@ -67,7 +68,7 @@ class Downsample(nn.Module):
         self.with_conv = with_conv
         if self.with_conv:
             # no asymmetric padding in torch conv, must do it ourselves
-            self.conv = torch.nn.Conv2d(in_channels,
+            self.conv = comfy.ops.Conv2d(in_channels,
                                         in_channels,
                                         kernel_size=3,
                                         stride=2,

@@ -95,30 +96,30 @@ class ResnetBlock(nn.Module):

         self.swish = torch.nn.SiLU(inplace=True)
         self.norm1 = Normalize(in_channels)
-        self.conv1 = torch.nn.Conv2d(in_channels,
+        self.conv1 = comfy.ops.Conv2d(in_channels,
                                      out_channels,
                                      kernel_size=3,
                                      stride=1,
                                      padding=1)
         if temb_channels > 0:
-            self.temb_proj = torch.nn.Linear(temb_channels,
+            self.temb_proj = comfy.ops.Linear(temb_channels,
                                              out_channels)
         self.norm2 = Normalize(out_channels)
         self.dropout = torch.nn.Dropout(dropout, inplace=True)
-        self.conv2 = torch.nn.Conv2d(out_channels,
+        self.conv2 = comfy.ops.Conv2d(out_channels,
                                      out_channels,
                                      kernel_size=3,
                                      stride=1,
                                      padding=1)
         if self.in_channels != self.out_channels:
             if self.use_conv_shortcut:
-                self.conv_shortcut = torch.nn.Conv2d(in_channels,
+                self.conv_shortcut = comfy.ops.Conv2d(in_channels,
                                                      out_channels,
                                                      kernel_size=3,
                                                      stride=1,
                                                      padding=1)
             else:
-                self.nin_shortcut = torch.nn.Conv2d(in_channels,
+                self.nin_shortcut = comfy.ops.Conv2d(in_channels,
                                                     out_channels,
                                                     kernel_size=1,
                                                     stride=1,

@@ -188,22 +189,22 @@ class AttnBlock(nn.Module):
         self.in_channels = in_channels

         self.norm = Normalize(in_channels)
-        self.q = torch.nn.Conv2d(in_channels,
+        self.q = comfy.ops.Conv2d(in_channels,
                                  in_channels,
                                  kernel_size=1,
                                  stride=1,
                                  padding=0)
-        self.k = torch.nn.Conv2d(in_channels,
+        self.k = comfy.ops.Conv2d(in_channels,
                                  in_channels,
                                  kernel_size=1,
                                  stride=1,
                                  padding=0)
-        self.v = torch.nn.Conv2d(in_channels,
+        self.v = comfy.ops.Conv2d(in_channels,
                                  in_channels,
                                  kernel_size=1,
                                  stride=1,
                                  padding=0)
-        self.proj_out = torch.nn.Conv2d(in_channels,
+        self.proj_out = comfy.ops.Conv2d(in_channels,
                                         in_channels,
                                         kernel_size=1,
                                         stride=1,

@@ -243,22 +244,22 @@ class MemoryEfficientAttnBlock(nn.Module):
         self.in_channels = in_channels

         self.norm = Normalize(in_channels)
-        self.q = torch.nn.Conv2d(in_channels,
+        self.q = comfy.ops.Conv2d(in_channels,
                                  in_channels,
                                  kernel_size=1,
                                  stride=1,
                                  padding=0)
-        self.k = torch.nn.Conv2d(in_channels,
+        self.k = comfy.ops.Conv2d(in_channels,
                                  in_channels,
                                  kernel_size=1,
                                  stride=1,
                                  padding=0)
-        self.v = torch.nn.Conv2d(in_channels,
+        self.v = comfy.ops.Conv2d(in_channels,
                                  in_channels,
                                  kernel_size=1,
                                  stride=1,
                                  padding=0)
-        self.proj_out = torch.nn.Conv2d(in_channels,
+        self.proj_out = comfy.ops.Conv2d(in_channels,
                                         in_channels,
                                         kernel_size=1,
                                         stride=1,

@@ -302,22 +303,22 @@ class MemoryEfficientAttnBlockPytorch(nn.Module):
         self.in_channels = in_channels

         self.norm = Normalize(in_channels)
-        self.q = torch.nn.Conv2d(in_channels,
+        self.q = comfy.ops.Conv2d(in_channels,
                                  in_channels,
                                  kernel_size=1,
                                  stride=1,
                                  padding=0)
-        self.k = torch.nn.Conv2d(in_channels,
+        self.k = comfy.ops.Conv2d(in_channels,
                                  in_channels,
                                  kernel_size=1,
                                  stride=1,
                                  padding=0)
-        self.v = torch.nn.Conv2d(in_channels,
+        self.v = comfy.ops.Conv2d(in_channels,
                                  in_channels,
                                  kernel_size=1,
                                  stride=1,
                                  padding=0)
-        self.proj_out = torch.nn.Conv2d(in_channels,
+        self.proj_out = comfy.ops.Conv2d(in_channels,
                                         in_channels,
                                         kernel_size=1,
                                         stride=1,

@@ -399,14 +400,14 @@ class Model(nn.Module):
         # timestep embedding
         self.temb = nn.Module()
         self.temb.dense = nn.ModuleList([
-            torch.nn.Linear(self.ch,
+            comfy.ops.Linear(self.ch,
                             self.temb_ch),
-            torch.nn.Linear(self.temb_ch,
+            comfy.ops.Linear(self.temb_ch,
                             self.temb_ch),
         ])

         # downsampling
-        self.conv_in = torch.nn.Conv2d(in_channels,
+        self.conv_in = comfy.ops.Conv2d(in_channels,
                                        self.ch,
                                        kernel_size=3,
                                        stride=1,

@@ -475,7 +476,7 @@ class Model(nn.Module):

         # end
         self.norm_out = Normalize(block_in)
-        self.conv_out = torch.nn.Conv2d(block_in,
+        self.conv_out = comfy.ops.Conv2d(block_in,
                                         out_ch,
                                         kernel_size=3,
                                         stride=1,

@@ -548,7 +549,7 @@ class Encoder(nn.Module):
         self.in_channels = in_channels

         # downsampling
-        self.conv_in = torch.nn.Conv2d(in_channels,
+        self.conv_in = comfy.ops.Conv2d(in_channels,
                                        self.ch,
                                        kernel_size=3,
                                        stride=1,

@@ -593,7 +594,7 @@ class Encoder(nn.Module):

         # end
         self.norm_out = Normalize(block_in)
-        self.conv_out = torch.nn.Conv2d(block_in,
+        self.conv_out = comfy.ops.Conv2d(block_in,
                                         2*z_channels if double_z else z_channels,
                                         kernel_size=3,
                                         stride=1,

@@ -653,7 +654,7 @@ class Decoder(nn.Module):
             self.z_shape, np.prod(self.z_shape)))

         # z to block_in
-        self.conv_in = torch.nn.Conv2d(z_channels,
+        self.conv_in = comfy.ops.Conv2d(z_channels,
                                        block_in,
                                        kernel_size=3,
                                        stride=1,

@@ -695,7 +696,7 @@ class Decoder(nn.Module):

         # end
         self.norm_out = Normalize(block_in)
-        self.conv_out = torch.nn.Conv2d(block_in,
+        self.conv_out = comfy.ops.Conv2d(block_in,
                                         out_ch,
                                         kernel_size=3,
                                         stride=1,
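This file swaps `torch.nn.Conv2d`/`torch.nn.Linear` for `comfy.ops` equivalents. `comfy.ops` itself is not shown in this diff; the sketch below is one plausible reading of what such a wrapper buys, namely skipping random weight initialization for layers whose parameters are about to be overwritten by checkpoint weights. Treat the class bodies as assumptions, not the repository's actual implementation.

```python
# Hypothetical no-init layer wrappers in the spirit of the comfy.ops swap above.
import torch

class Linear(torch.nn.Linear):
    def reset_parameters(self):
        return None  # skip kaiming init; the checkpoint load overwrites weights anyway

class Conv2d(torch.nn.Conv2d):
    def reset_parameters(self):
        return None  # same idea for convolutions

# Construction is now nearly free; load_state_dict fills in the real values.
layer = Linear(512, 512)
layer.load_state_dict({"weight": torch.zeros(512, 512), "bias": torch.zeros(512)})
```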
@@ -19,45 +19,6 @@ from ..attention import SpatialTransformer
 from comfy.ldm.util import exists


-# dummy replace
-def convert_module_to_f16(x):
-    pass
-
-def convert_module_to_f32(x):
-    pass
-
-
-## go
-class AttentionPool2d(nn.Module):
-    """
-    Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
-    """
-
-    def __init__(
-        self,
-        spacial_dim: int,
-        embed_dim: int,
-        num_heads_channels: int,
-        output_dim: int = None,
-    ):
-        super().__init__()
-        self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5)
-        self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
-        self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
-        self.num_heads = embed_dim // num_heads_channels
-        self.attention = QKVAttention(self.num_heads)
-
-    def forward(self, x):
-        b, c, *_spatial = x.shape
-        x = x.reshape(b, c, -1)  # NC(HW)
-        x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1)  # NC(HW+1)
-        x = x + self.positional_embedding[None, :, :].to(x.dtype)  # NC(HW+1)
-        x = self.qkv_proj(x)
-        x = self.attention(x)
-        x = self.c_proj(x)
-        return x[:, :, 0]
-
-
 class TimestepBlock(nn.Module):
     """
     Any module where forward() takes timestep embeddings as a second argument.

@@ -111,14 +72,14 @@ class Upsample(nn.Module):
                 upsampling occurs in the inner-two dimensions.
     """

-    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None):
+    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None, device=None):
         super().__init__()
         self.channels = channels
         self.out_channels = out_channels or channels
         self.use_conv = use_conv
         self.dims = dims
         if use_conv:
-            self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding, dtype=dtype)
+            self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding, dtype=dtype, device=device)

     def forward(self, x, output_shape=None):
         assert x.shape[1] == self.channels

@@ -138,19 +99,6 @@ class Upsample(nn.Module):
             x = self.conv(x)
         return x

-class TransposedUpsample(nn.Module):
-    'Learned 2x upsampling without padding'
-    def __init__(self, channels, out_channels=None, ks=5):
-        super().__init__()
-        self.channels = channels
-        self.out_channels = out_channels or channels
-
-        self.up = nn.ConvTranspose2d(self.channels,self.out_channels,kernel_size=ks,stride=2)
-
-    def forward(self,x):
-        return self.up(x)
-
-
 class Downsample(nn.Module):
     """
     A downsampling layer with an optional convolution.

@@ -160,7 +108,7 @@ class Downsample(nn.Module):
                 downsampling occurs in the inner-two dimensions.
     """

-    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None):
+    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None, device=None):
         super().__init__()
         self.channels = channels
         self.out_channels = out_channels or channels

@@ -169,7 +117,7 @@ class Downsample(nn.Module):
         stride = 2 if dims != 3 else (1, 2, 2)
         if use_conv:
             self.op = conv_nd(
-                dims, self.channels, self.out_channels, 3, stride=stride, padding=padding, dtype=dtype
+                dims, self.channels, self.out_channels, 3, stride=stride, padding=padding, dtype=dtype, device=device
             )
         else:
             assert self.channels == self.out_channels

@@ -208,7 +156,8 @@ class ResBlock(TimestepBlock):
         use_checkpoint=False,
         up=False,
         down=False,
-        dtype=None
+        dtype=None,
+        device=None,
     ):
         super().__init__()
         self.channels = channels

@@ -220,19 +169,19 @@ class ResBlock(TimestepBlock):
         self.use_scale_shift_norm = use_scale_shift_norm

         self.in_layers = nn.Sequential(
-            nn.GroupNorm(32, channels, dtype=dtype),
+            nn.GroupNorm(32, channels, dtype=dtype, device=device),
             nn.SiLU(),
-            conv_nd(dims, channels, self.out_channels, 3, padding=1, dtype=dtype),
+            conv_nd(dims, channels, self.out_channels, 3, padding=1, dtype=dtype, device=device),
         )

         self.updown = up or down

         if up:
-            self.h_upd = Upsample(channels, False, dims, dtype=dtype)
-            self.x_upd = Upsample(channels, False, dims, dtype=dtype)
+            self.h_upd = Upsample(channels, False, dims, dtype=dtype, device=device)
+            self.x_upd = Upsample(channels, False, dims, dtype=dtype, device=device)
         elif down:
-            self.h_upd = Downsample(channels, False, dims, dtype=dtype)
-            self.x_upd = Downsample(channels, False, dims, dtype=dtype)
+            self.h_upd = Downsample(channels, False, dims, dtype=dtype, device=device)
+            self.x_upd = Downsample(channels, False, dims, dtype=dtype, device=device)
         else:
             self.h_upd = self.x_upd = nn.Identity()

@@ -240,15 +189,15 @@ class ResBlock(TimestepBlock):
             nn.SiLU(),
             linear(
                 emb_channels,
-                2 * self.out_channels if use_scale_shift_norm else self.out_channels, dtype=dtype
+                2 * self.out_channels if use_scale_shift_norm else self.out_channels, dtype=dtype, device=device
             ),
         )
         self.out_layers = nn.Sequential(
-            nn.GroupNorm(32, self.out_channels, dtype=dtype),
+            nn.GroupNorm(32, self.out_channels, dtype=dtype, device=device),
             nn.SiLU(),
             nn.Dropout(p=dropout),
             zero_module(
-                conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1, dtype=dtype)
+                conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1, dtype=dtype, device=device)
             ),
         )

@@ -256,10 +205,10 @@ class ResBlock(TimestepBlock):
             self.skip_connection = nn.Identity()
         elif use_conv:
             self.skip_connection = conv_nd(
-                dims, channels, self.out_channels, 3, padding=1, dtype=dtype
+                dims, channels, self.out_channels, 3, padding=1, dtype=dtype, device=device
             )
         else:
-            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1, dtype=dtype)
+            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1, dtype=dtype, device=device)

     def forward(self, x, emb):
         """

@@ -295,142 +244,6 @@ class ResBlock(TimestepBlock):
         h = self.out_layers(h)
         return self.skip_connection(x) + h

-
-class AttentionBlock(nn.Module):
-    """
-    An attention block that allows spatial positions to attend to each other.
-    Originally ported from here, but adapted to the N-d case.
-    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
-    """
-
-    def __init__(
-        self,
-        channels,
-        num_heads=1,
-        num_head_channels=-1,
-        use_checkpoint=False,
-        use_new_attention_order=False,
-    ):
-        super().__init__()
-        self.channels = channels
-        if num_head_channels == -1:
-            self.num_heads = num_heads
-        else:
-            assert (
-                channels % num_head_channels == 0
-            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
-            self.num_heads = channels // num_head_channels
-        self.use_checkpoint = use_checkpoint
-        self.norm = normalization(channels)
-        self.qkv = conv_nd(1, channels, channels * 3, 1)
-        if use_new_attention_order:
-            # split qkv before split heads
-            self.attention = QKVAttention(self.num_heads)
-        else:
-            # split heads before split qkv
-            self.attention = QKVAttentionLegacy(self.num_heads)
-
-        self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
-
-    def forward(self, x):
-        return checkpoint(self._forward, (x,), self.parameters(), True)  # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
-        #return pt_checkpoint(self._forward, x)  # pytorch
-
-    def _forward(self, x):
-        b, c, *spatial = x.shape
-        x = x.reshape(b, c, -1)
-        qkv = self.qkv(self.norm(x))
-        h = self.attention(qkv)
-        h = self.proj_out(h)
-        return (x + h).reshape(b, c, *spatial)
-
-
-def count_flops_attn(model, _x, y):
-    """
-    A counter for the `thop` package to count the operations in an
-    attention operation.
-    Meant to be used like:
-        macs, params = thop.profile(
-            model,
-            inputs=(inputs, timestamps),
-            custom_ops={QKVAttention: QKVAttention.count_flops},
-        )
-    """
-    b, c, *spatial = y[0].shape
-    num_spatial = int(np.prod(spatial))
-    # We perform two matmuls with the same number of ops.
-    # The first computes the weight matrix, the second computes
-    # the combination of the value vectors.
-    matmul_ops = 2 * b * (num_spatial ** 2) * c
-    model.total_ops += th.DoubleTensor([matmul_ops])
-
-
-class QKVAttentionLegacy(nn.Module):
-    """
-    A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
-    """
-
-    def __init__(self, n_heads):
-        super().__init__()
-        self.n_heads = n_heads
-
-    def forward(self, qkv):
-        """
-        Apply QKV attention.
-        :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
-        :return: an [N x (H * C) x T] tensor after attention.
-        """
-        bs, width, length = qkv.shape
-        assert width % (3 * self.n_heads) == 0
-        ch = width // (3 * self.n_heads)
-        q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
-        scale = 1 / math.sqrt(math.sqrt(ch))
-        weight = th.einsum(
-            "bct,bcs->bts", q * scale, k * scale
-        )  # More stable with f16 than dividing afterwards
-        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
-        a = th.einsum("bts,bcs->bct", weight, v)
-        return a.reshape(bs, -1, length)
-
-    @staticmethod
-    def count_flops(model, _x, y):
-        return count_flops_attn(model, _x, y)
-
-
-class QKVAttention(nn.Module):
-    """
-    A module which performs QKV attention and splits in a different order.
-    """
-
-    def __init__(self, n_heads):
-        super().__init__()
-        self.n_heads = n_heads
-
-    def forward(self, qkv):
-        """
-        Apply QKV attention.
-        :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
-        :return: an [N x (H * C) x T] tensor after attention.
-        """
-        bs, width, length = qkv.shape
-        assert width % (3 * self.n_heads) == 0
-        ch = width // (3 * self.n_heads)
-        q, k, v = qkv.chunk(3, dim=1)
-        scale = 1 / math.sqrt(math.sqrt(ch))
-        weight = th.einsum(
-            "bct,bcs->bts",
-            (q * scale).view(bs * self.n_heads, ch, length),
-            (k * scale).view(bs * self.n_heads, ch, length),
-        )  # More stable with f16 than dividing afterwards
-        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
-        a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
-        return a.reshape(bs, -1, length)
-
-    @staticmethod
-    def count_flops(model, _x, y):
-        return count_flops_attn(model, _x, y)
-
-
 class Timestep(nn.Module):
     def __init__(self, dim):
         super().__init__()

@@ -503,8 +316,10 @@ class UNetModel(nn.Module):
         use_linear_in_transformer=False,
         adm_in_channels=None,
         transformer_depth_middle=None,
+        device=None,
     ):
         super().__init__()
+        assert use_spatial_transformer == True, "use_spatial_transformer has to be true"
         if use_spatial_transformer:
             assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'

@@ -564,9 +379,9 @@ class UNetModel(nn.Module):

         time_embed_dim = model_channels * 4
         self.time_embed = nn.Sequential(
-            linear(model_channels, time_embed_dim, dtype=self.dtype),
+            linear(model_channels, time_embed_dim, dtype=self.dtype, device=device),
             nn.SiLU(),
-            linear(time_embed_dim, time_embed_dim, dtype=self.dtype),
+            linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
         )

         if self.num_classes is not None:

@@ -579,9 +394,9 @@ class UNetModel(nn.Module):
                 assert adm_in_channels is not None
                 self.label_emb = nn.Sequential(
                     nn.Sequential(
-                        linear(adm_in_channels, time_embed_dim, dtype=self.dtype),
+                        linear(adm_in_channels, time_embed_dim, dtype=self.dtype, device=device),
                         nn.SiLU(),
-                        linear(time_embed_dim, time_embed_dim, dtype=self.dtype),
+                        linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
                     )
                 )
             else:

@@ -590,7 +405,7 @@ class UNetModel(nn.Module):
         self.input_blocks = nn.ModuleList(
             [
                 TimestepEmbedSequential(
-                    conv_nd(dims, in_channels, model_channels, 3, padding=1, dtype=self.dtype)
+                    conv_nd(dims, in_channels, model_channels, 3, padding=1, dtype=self.dtype, device=device)
                 )
             ]
         )

@@ -609,7 +424,8 @@ class UNetModel(nn.Module):
                         dims=dims,
                         use_checkpoint=use_checkpoint,
                         use_scale_shift_norm=use_scale_shift_norm,
-                        dtype=self.dtype
+                        dtype=self.dtype,
+                        device=device,
                     )
                 ]
                 ch = mult * model_channels

@@ -628,17 +444,10 @@ class UNetModel(nn.Module):
                         disabled_sa = False

                 if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
-                    layers.append(
-                        AttentionBlock(
-                            ch,
-                            use_checkpoint=use_checkpoint,
-                            num_heads=num_heads,
-                            num_head_channels=dim_head,
-                            use_new_attention_order=use_new_attention_order,
-                        ) if not use_spatial_transformer else SpatialTransformer(
+                    layers.append(SpatialTransformer(
                             ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim,
                             disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
-                            use_checkpoint=use_checkpoint, dtype=self.dtype
+                            use_checkpoint=use_checkpoint, dtype=self.dtype, device=device
                         )
                     )
                 self.input_blocks.append(TimestepEmbedSequential(*layers))

@@ -657,11 +466,12 @@ class UNetModel(nn.Module):
                             use_checkpoint=use_checkpoint,
                             use_scale_shift_norm=use_scale_shift_norm,
                             down=True,
-                            dtype=self.dtype
+                            dtype=self.dtype,
+                            device=device,
                         )
                         if resblock_updown
                         else Downsample(
-                            ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype
+                            ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype, device=device
                         )
                     )
                 )

@@ -686,18 +496,13 @@ class UNetModel(nn.Module):
                 dims=dims,
                 use_checkpoint=use_checkpoint,
                 use_scale_shift_norm=use_scale_shift_norm,
-                dtype=self.dtype
+                dtype=self.dtype,
+                device=device,
             ),
-            AttentionBlock(
-                ch,
-                use_checkpoint=use_checkpoint,
-                num_heads=num_heads,
-                num_head_channels=dim_head,
-                use_new_attention_order=use_new_attention_order,
-            ) if not use_spatial_transformer else SpatialTransformer(  # always uses a self-attn
+            SpatialTransformer(  # always uses a self-attn
                 ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
                 disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
-                use_checkpoint=use_checkpoint, dtype=self.dtype
+                use_checkpoint=use_checkpoint, dtype=self.dtype, device=device
             ),
             ResBlock(
                 ch,

@@ -706,7 +511,8 @@ class UNetModel(nn.Module):
                 dims=dims,
                 use_checkpoint=use_checkpoint,
                 use_scale_shift_norm=use_scale_shift_norm,
-                dtype=self.dtype
+                dtype=self.dtype,
+                device=device,
             ),
         )
         self._feature_size += ch

@@ -724,7 +530,8 @@ class UNetModel(nn.Module):
                         dims=dims,
                         use_checkpoint=use_checkpoint,
                         use_scale_shift_norm=use_scale_shift_norm,
-                        dtype=self.dtype
+                        dtype=self.dtype,
+                        device=device,
                     )
                 ]
                 ch = model_channels * mult

@@ -744,16 +551,10 @@ class UNetModel(nn.Module):

                 if not exists(num_attention_blocks) or i < num_attention_blocks[level]:
                     layers.append(
-                        AttentionBlock(
-                            ch,
-                            use_checkpoint=use_checkpoint,
-                            num_heads=num_heads_upsample,
-                            num_head_channels=dim_head,
-                            use_new_attention_order=use_new_attention_order,
-                        ) if not use_spatial_transformer else SpatialTransformer(
+                        SpatialTransformer(
                             ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim,
                             disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
-                            use_checkpoint=use_checkpoint, dtype=self.dtype
+                            use_checkpoint=use_checkpoint, dtype=self.dtype, device=device
                         )
                     )
                 if level and i == self.num_res_blocks[level]:

@@ -768,43 +569,28 @@ class UNetModel(nn.Module):
                             use_checkpoint=use_checkpoint,
                             use_scale_shift_norm=use_scale_shift_norm,
                             up=True,
-                            dtype=self.dtype
+                            dtype=self.dtype,
+                            device=device,
                         )
                         if resblock_updown
-                        else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype)
+                        else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype, device=device)
                     )
                     ds //= 2
                 self.output_blocks.append(TimestepEmbedSequential(*layers))
                 self._feature_size += ch

         self.out = nn.Sequential(
-            nn.GroupNorm(32, ch, dtype=self.dtype),
+            nn.GroupNorm(32, ch, dtype=self.dtype, device=device),
             nn.SiLU(),
-            zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1, dtype=self.dtype)),
+            zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1, dtype=self.dtype, device=device)),
         )
         if self.predict_codebook_ids:
             self.id_predictor = nn.Sequential(
-                nn.GroupNorm(32, ch, dtype=self.dtype),
-                conv_nd(dims, model_channels, n_embed, 1),
+                nn.GroupNorm(32, ch, dtype=self.dtype, device=device),
+                conv_nd(dims, model_channels, n_embed, 1, dtype=self.dtype, device=device),
                 #nn.LogSoftmax(dim=1)  # change to cross_entropy and produce non-normalized logits
             )

-    def convert_to_fp16(self):
-        """
-        Convert the torso of the model to float16.
-        """
-        self.input_blocks.apply(convert_module_to_f16)
-        self.middle_block.apply(convert_module_to_f16)
-        self.output_blocks.apply(convert_module_to_f16)
-
-    def convert_to_fp32(self):
-        """
-        Convert the torso of the model to float32.
-        """
-        self.input_blocks.apply(convert_module_to_f32)
-        self.middle_block.apply(convert_module_to_f32)
-        self.output_blocks.apply(convert_module_to_f32)
-
     def forward(self, x, timesteps=None, context=None, y=None, control=None, transformer_options={}, **kwargs):
         """
         Apply the model to an input batch.
@@ -12,14 +12,14 @@ class ModelType(Enum):
     V_PREDICTION = 2

 class BaseModel(torch.nn.Module):
-    def __init__(self, model_config, model_type=ModelType.EPS):
+    def __init__(self, model_config, model_type=ModelType.EPS, device=None):
         super().__init__()

         unet_config = model_config.unet_config
         self.latent_format = model_config.latent_format
         self.model_config = model_config
         self.register_schedule(given_betas=None, beta_schedule="linear", timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3)
-        self.diffusion_model = UNetModel(**unet_config)
+        self.diffusion_model = UNetModel(**unet_config, device=device)
         self.model_type = model_type
         self.adm_channels = unet_config.get("adm_in_channels", None)
         if self.adm_channels is None:

@@ -107,8 +107,8 @@ class BaseModel(torch.nn.Module):


 class SD21UNCLIP(BaseModel):
-    def __init__(self, model_config, noise_aug_config, model_type=ModelType.V_PREDICTION):
-        super().__init__(model_config, model_type)
+    def __init__(self, model_config, noise_aug_config, model_type=ModelType.V_PREDICTION, device=None):
+        super().__init__(model_config, model_type, device=device)
         self.noise_augmentor = CLIPEmbeddingNoiseAugmentation(**noise_aug_config)

     def encode_adm(self, **kwargs):

@@ -143,13 +143,13 @@ class SD21UNCLIP(BaseModel):
         return adm_out

 class SDInpaint(BaseModel):
-    def __init__(self, model_config, model_type=ModelType.EPS):
-        super().__init__(model_config, model_type)
+    def __init__(self, model_config, model_type=ModelType.EPS, device=None):
+        super().__init__(model_config, model_type, device=device)
         self.concat_keys = ("mask", "masked_image")

 class SDXLRefiner(BaseModel):
-    def __init__(self, model_config, model_type=ModelType.EPS):
-        super().__init__(model_config, model_type)
+    def __init__(self, model_config, model_type=ModelType.EPS, device=None):
+        super().__init__(model_config, model_type, device=device)
         self.embedder = Timestep(256)

     def encode_adm(self, **kwargs):

@@ -164,7 +164,6 @@ class SDXLRefiner(BaseModel):
         else:
             aesthetic_score = kwargs.get("aesthetic_score", 6)

-        print(clip_pooled.shape, width, height, crop_w, crop_h, aesthetic_score)
         out = []
         out.append(self.embedder(torch.Tensor([height])))
         out.append(self.embedder(torch.Tensor([width])))

@@ -175,8 +174,8 @@ class SDXLRefiner(BaseModel):
         return torch.cat((clip_pooled.to(flat.device), flat), dim=1)

 class SDXL(BaseModel):
-    def __init__(self, model_config, model_type=ModelType.EPS):
-        super().__init__(model_config, model_type)
+    def __init__(self, model_config, model_type=ModelType.EPS, device=None):
+        super().__init__(model_config, model_type, device=device)
         self.embedder = Timestep(256)

     def encode_adm(self, **kwargs):

@@ -188,7 +187,6 @@ class SDXL(BaseModel):
         target_width = kwargs.get("target_width", width)
         target_height = kwargs.get("target_height", height)

-        print(clip_pooled.shape, width, height, crop_w, crop_h, target_width, target_height)
         out = []
         out.append(self.embedder(torch.Tensor([height])))
         out.append(self.embedder(torch.Tensor([width])))
@@ -364,6 +364,7 @@ def text_encoder_device():
     if args.gpu_only:
         return get_torch_device()
    elif vram_state == VRAMState.HIGH_VRAM or vram_state == VRAMState.NORMAL_VRAM:
+        #NOTE: on a Ryzen 5 7600X with 4080 it's faster to shift to GPU
        if torch.get_num_threads() < 8: #leaving the text encoder on the CPU is faster than shifting it if the CPU is fast enough.
            return get_torch_device()
        else:

@@ -534,7 +535,7 @@ def should_use_fp16(device=None, model_params=0):
        return False

    #FP16 is just broken on these cards
-    nvidia_16_series = ["1660", "1650", "1630", "T500", "T550", "T600", "MX550", "MX450"]
+    nvidia_16_series = ["1660", "1650", "1630", "T500", "T550", "T600", "MX550", "MX450", "CMP 30HX"]
    for x in nvidia_16_series:
        if x in props.name:
            return False
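The `nvidia_16_series` list is matched by substring against the reported device name, so a single entry covers every variant of a board. A hedged, standalone illustration (the device names below are made up):

```python
# Substring blacklist check, mirroring the should_use_fp16 logic above.
nvidia_16_series = ["1660", "1650", "1630", "T500", "T550", "T600",
                    "MX550", "MX450", "CMP 30HX"]

def fp16_ok(device_name: str) -> bool:
    # FP16 is broken on these cards, so force fp32 for any of them.
    return not any(x in device_name for x in nvidia_16_series)

print(fp16_ok("NVIDIA GeForce GTX 1660 SUPER"))  # False
print(fp16_ok("NVIDIA GeForce RTX 3090"))        # True
```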
@@ -19,11 +19,11 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
             strength = 1.0
             if 'timestep_start' in cond[1]:
                 timestep_start = cond[1]['timestep_start']
-                if timestep_in > timestep_start:
+                if timestep_in[0] > timestep_start:
                     return None
             if 'timestep_end' in cond[1]:
                 timestep_end = cond[1]['timestep_end']
-                if timestep_in < timestep_end:
+                if timestep_in[0] < timestep_end:
                     return None
             if 'area' in cond[1]:
                 area = cond[1]['area']
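The `[0]` indexing matters because `timestep_in` is batched: comparing a multi-element tensor against a scalar produces a boolean tensor, and `if` on that raises. Indexing the first element reduces it to a scalar comparison, which is safe here since every batch entry shares the same sigma. A standalone demonstration:

```python
# Why `timestep_in[0]` is needed: truthiness of a multi-element tensor is ambiguous.
import torch

timestep_in = torch.tensor([0.5, 0.5])  # one sigma per batch entry
timestep_start = 0.7

try:
    if timestep_in > timestep_start:  # old code: boolean tensor, not a bool
        pass
except RuntimeError as e:
    print("ambiguous:", e)

if timestep_in[0] > timestep_start:  # fixed code: plain scalar comparison
    print("skip this cond")
else:
    print("keep this cond")
```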
comfy/sd.py (29 changes)

@@ -70,13 +70,22 @@ def load_lora(lora, to_load):
             alpha = lora[alpha_name].item()
             loaded_keys.add(alpha_name)

-        A_name = "{}.lora_up.weight".format(x)
-        B_name = "{}.lora_down.weight".format(x)
-        mid_name = "{}.lora_mid.weight".format(x)
+        regular_lora = "{}.lora_up.weight".format(x)
+        diffusers_lora = "{}_lora.up.weight".format(x)
+        A_name = None

-        if A_name in lora.keys():
+        if regular_lora in lora.keys():
+            A_name = regular_lora
+            B_name = "{}.lora_down.weight".format(x)
+            mid_name = "{}.lora_mid.weight".format(x)
+        elif diffusers_lora in lora.keys():
+            A_name = diffusers_lora
+            B_name = "{}_lora.down.weight".format(x)
+            mid_name = None
+
+        if A_name is not None:
             mid = None
-            if mid_name in lora.keys():
+            if mid_name is not None and mid_name in lora.keys():
                 mid = lora[mid_name]
                 loaded_keys.add(mid_name)
             patch_dict[to_load[x]] = (lora[A_name], lora[B_name], alpha, mid)

@@ -202,6 +211,11 @@ def model_lora_keys_unet(model, key_map={}):
         if k.endswith(".weight"):
             key_lora = k[:-len(".weight")].replace(".", "_")
             key_map["lora_unet_{}".format(key_lora)] = "diffusion_model.{}".format(diffusers_keys[k])
+
+            diffusers_lora_key = "unet.{}".format(k[:-len(".weight")].replace(".to_", ".processor.to_"))
+            if diffusers_lora_key.endswith(".to_out.0"):
+                diffusers_lora_key = diffusers_lora_key[:-2]
+            key_map[diffusers_lora_key] = "diffusion_model.{}".format(diffusers_keys[k])
     return key_map

 def set_attr(obj, attr, value):

@@ -864,7 +878,7 @@ def load_controlnet(ckpt_path, model=None):
         use_fp16 = model_management.should_use_fp16()
         controlnet_config = model_detection.model_config_from_unet(controlnet_data, prefix, use_fp16).unet_config
         controlnet_config.pop("out_channels")
-        controlnet_config["hint_channels"] = 3
+        controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
         control_model = cldm.ControlNet(**controlnet_config)

         if pth:

@@ -1169,8 +1183,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
         clipvision = clip_vision.load_clipvision_from_sd(sd, model_config.clip_vision_prefix, True)

     offload_device = model_management.unet_offload_device()
-    model = model_config.get_model(sd, "model.diffusion_model.")
-    model = model.to(offload_device)
+    model = model_config.get_model(sd, "model.diffusion_model.", device=offload_device)
     model.load_model_weights(sd, "model.diffusion_model.")

     if output_vae:
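The two LoRA naming schemes the loader now recognizes differ only in their key suffixes. The layer key below is made up for illustration:

```python
# Regular (Kohya-style) vs diffusers LoRA key layouts, on a hypothetical layer.
layer = "lora_unet_input_blocks_1_1_proj_in"

regular = {
    f"{layer}.lora_up.weight":   "...",
    f"{layer}.lora_down.weight": "...",
    f"{layer}.lora_mid.weight":  "...",  # optional (LoCon-style mid weight)
}
diffusers = {
    f"{layer}_lora.up.weight":   "...",
    f"{layer}_lora.down.weight": "...",  # no mid weight in this format
}

print(sorted(regular), sorted(diffusers), sep="\n")
```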
@@ -91,13 +91,15 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):

     def set_up_textual_embeddings(self, tokens, current_embeds):
         out_tokens = []
-        next_new_token = token_dict_size = current_embeds.weight.shape[0]
+        next_new_token = token_dict_size = current_embeds.weight.shape[0] - 1
         embedding_weights = []

         for x in tokens:
             tokens_temp = []
             for y in x:
                 if isinstance(y, int):
+                    if y == token_dict_size: #EOS token
+                        y = -1
                     tokens_temp += [y]
                 else:
                     if y.shape[0] == current_embeds.weight.shape[1]:

@@ -110,15 +112,21 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
                     tokens_temp += [self.empty_tokens[0][-1]]
             out_tokens += [tokens_temp]

+        n = token_dict_size
         if len(embedding_weights) > 0:
-            new_embedding = torch.nn.Embedding(next_new_token, current_embeds.weight.shape[1], device=current_embeds.weight.device, dtype=current_embeds.weight.dtype)
-            new_embedding.weight[:token_dict_size] = current_embeds.weight[:]
-            n = token_dict_size
+            new_embedding = torch.nn.Embedding(next_new_token + 1, current_embeds.weight.shape[1], device=current_embeds.weight.device, dtype=current_embeds.weight.dtype)
+            new_embedding.weight[:token_dict_size] = current_embeds.weight[:-1]
             for x in embedding_weights:
                 new_embedding.weight[n] = x
                 n += 1
+            new_embedding.weight[n] = current_embeds.weight[-1] #EOS embedding
             self.transformer.set_input_embeddings(new_embedding)
-        return out_tokens
+
+        processed_tokens = []
+        for x in out_tokens:
+            processed_tokens += [list(map(lambda a: n if a == -1 else a, x))] #The EOS token should always be the largest one
+
+        return processed_tokens

     def forward(self, tokens):
         backup_embeds = self.transformer.get_input_embeddings()
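The embedding fix reserves the last row of the token table (the EOS embedding) instead of letting textual-inversion vectors displace it. A standalone sketch of the new layout (sizes are illustrative, not the CLIP vocab):

```python
# EOS-preserving embedding expansion, in miniature.
import torch

vocab, dim = 8, 4
current = torch.nn.Embedding(vocab, dim)
custom = [torch.randn(dim), torch.randn(dim)]  # textual-inversion vectors

token_dict_size = current.weight.shape[0] - 1  # last row is the EOS embedding
new_emb = torch.nn.Embedding(token_dict_size + len(custom) + 1, dim)
with torch.no_grad():
    new_emb.weight[:token_dict_size] = current.weight[:-1]  # copy everything but EOS
    for i, w in enumerate(custom):
        new_emb.weight[token_dict_size + i] = w             # splice in custom rows
    new_emb.weight[-1] = current.weight[-1]                 # EOS stays the largest id

print(new_emb.weight.shape)  # torch.Size([10, 4])
```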
@@ -109,8 +109,8 @@ class SDXLRefiner(supported_models_base.BASE):

     latent_format = latent_formats.SDXL

-    def get_model(self, state_dict, prefix=""):
-        return model_base.SDXLRefiner(self)
+    def get_model(self, state_dict, prefix="", device=None):
+        return model_base.SDXLRefiner(self, device=device)

     def process_clip_state_dict(self, state_dict):
         keys_to_replace = {}

@@ -152,8 +152,8 @@ class SDXL(supported_models_base.BASE):
         else:
             return model_base.ModelType.EPS

-    def get_model(self, state_dict, prefix=""):
-        return model_base.SDXL(self, model_type=self.model_type(state_dict, prefix))
+    def get_model(self, state_dict, prefix="", device=None):
+        return model_base.SDXL(self, model_type=self.model_type(state_dict, prefix), device=device)

     def process_clip_state_dict(self, state_dict):
         keys_to_replace = {}
@ -53,13 +53,13 @@ class BASE:
|
||||
for x in self.unet_extra_config:
|
||||
self.unet_config[x] = self.unet_extra_config[x]
|
||||
|
||||
def get_model(self, state_dict, prefix=""):
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
if self.inpaint_model():
|
||||
return model_base.SDInpaint(self, model_type=self.model_type(state_dict, prefix))
|
||||
return model_base.SDInpaint(self, model_type=self.model_type(state_dict, prefix), device=device)
|
||||
elif self.noise_aug_config is not None:
|
||||
return model_base.SD21UNCLIP(self, self.noise_aug_config, model_type=self.model_type(state_dict, prefix))
|
||||
return model_base.SD21UNCLIP(self, self.noise_aug_config, model_type=self.model_type(state_dict, prefix), device=device)
|
||||
else:
|
||||
return model_base.BaseModel(self, model_type=self.model_type(state_dict, prefix))
|
||||
return model_base.BaseModel(self, model_type=self.model_type(state_dict, prefix), device=device)
|
||||
|
||||
def process_clip_state_dict(self, state_dict):
|
||||
return state_dict
|
||||
|
||||
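Every get_model override above gains an optional device argument, so models can be constructed directly on the offload device instead of being created on the default device and moved with .to() afterwards. A toy sketch of the calling convention (FakeConfig/FakeModel are hypothetical stand-ins, not ComfyUI classes):

```python
# Hypothetical stand-ins to show the new calling convention; the real classes
# live in comfy/supported_models_base.py and comfy/model_base.py.
class FakeModel:
    def __init__(self, device=None):
        self.device = device

class FakeConfig:
    def get_model(self, state_dict, prefix="", device=None):
        # The device travels with construction instead of a later .to() call.
        return FakeModel(device=device)

model = FakeConfig().get_model({}, "model.diffusion_model.", device="cpu")
assert model.device == "cpu"
```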
@ -6,6 +6,8 @@ import folder_paths

import json
import os

from comfy.cli_args import args

class ModelMergeSimple:
    @classmethod
    def INPUT_TYPES(s):

@ -101,8 +103,7 @@ class CheckpointSave:

        if prompt is not None:
            prompt_info = json.dumps(prompt)

        metadata = {"prompt": prompt_info}
        metadata = {}

        enable_modelspec = True
        if isinstance(model.model, comfy.model_base.SDXL):

@ -127,9 +128,11 @@ class CheckpointSave:

        elif model.model.model_type == comfy.model_base.ModelType.V_PREDICTION:
            metadata["modelspec.predict_key"] = "v"

        if extra_pnginfo is not None:
            for x in extra_pnginfo:
                metadata[x] = json.dumps(extra_pnginfo[x])
        if not args.disable_metadata:
            metadata["prompt"] = prompt_info
            if extra_pnginfo is not None:
                for x in extra_pnginfo:
                    metadata[x] = json.dumps(extra_pnginfo[x])

        output_checkpoint = f"{filename}_{counter:05}_.safetensors"
        output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
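CheckpointSave now assembles its metadata only when the --disable-metadata launch flag is off, and the prompt joins the metadata inside that guard. A standalone sketch of the gating pattern (build_metadata is a hypothetical helper; the real code inlines this and reads the flag from comfy.cli_args.args):

```python
import json

def build_metadata(prompt, extra_pnginfo, disable_metadata):
    # Hypothetical helper mirroring the gating above: with --disable-metadata
    # the checkpoint is written with empty metadata.
    metadata = {}
    if not disable_metadata:
        if prompt is not None:
            metadata["prompt"] = json.dumps(prompt)
        if extra_pnginfo is not None:
            for x in extra_pnginfo:
                metadata[x] = json.dumps(extra_pnginfo[x])
    return metadata

print(build_metadata({"1": {"class_type": "KSampler"}}, None, disable_metadata=False))
print(build_metadata({"1": {"class_type": "KSampler"}}, None, disable_metadata=True))  # {}
```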
@ -40,7 +40,8 @@ def cuda_malloc_supported():

    blacklist = {"GeForce GTX TITAN X", "GeForce GTX 980", "GeForce GTX 970", "GeForce GTX 960", "GeForce GTX 950", "GeForce 945M",
                 "GeForce 940M", "GeForce 930M", "GeForce 920M", "GeForce 910M", "GeForce GTX 750", "GeForce GTX 745", "Quadro K620",
                 "Quadro K1200", "Quadro K2200", "Quadro M500", "Quadro M520", "Quadro M600", "Quadro M620", "Quadro M1000",
                 "Quadro M1200", "Quadro M2000", "Quadro M2200", "Quadro M3000", "Quadro M4000", "Quadro M5000", "Quadro M5500", "Quadro M6000"}
                 "Quadro M1200", "Quadro M2000", "Quadro M2200", "Quadro M3000", "Quadro M4000", "Quadro M5000", "Quadro M5500", "Quadro M6000",
                 "GeForce MX110", "GeForce MX130", "GeForce 830M", "GeForce 840M", "GeForce GTX 850M", "GeForce GTX 860M"}

    try:
        names = get_gpu_names()
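The hunk extends the cudaMallocAsync blacklist with more Maxwell-era mobile and entry-level GPUs. The check itself is a simple substring match over the reported device names; a runnable sketch with an abbreviated blacklist:

```python
blacklist = {"GeForce GTX 960", "GeForce MX130", "Quadro K620"}  # abbreviated

def cuda_malloc_supported(gpu_names):
    # Sketch: any blacklisted model name appearing inside a reported device
    # name disables cudaMallocAsync.
    for name in gpu_names:
        for bad in blacklist:
            if bad in name:
                return False
    return True

print(cuda_malloc_supported({"NVIDIA GeForce RTX 3060"}))  # True
print(cuda_malloc_supported({"NVIDIA GeForce MX130"}))     # False
```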
execution.py

@ -42,11 +42,14 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da

def map_node_over_list(obj, input_data_all, func, allow_interrupt=False):
    # check if node wants the lists
    intput_is_list = False
    input_is_list = False
    if hasattr(obj, "INPUT_IS_LIST"):
        intput_is_list = obj.INPUT_IS_LIST
        input_is_list = obj.INPUT_IS_LIST

    max_len_input = max([len(x) for x in input_data_all.values()])
    if len(input_data_all) == 0:
        max_len_input = 0
    else:
        max_len_input = max([len(x) for x in input_data_all.values()])

    # get a slice of inputs, repeat last input when list isn't long enough
    def slice_dict(d, i):

@ -56,11 +59,15 @@ def map_node_over_list(obj, input_data_all, func, allow_interrupt=False):

        return d_new

    results = []
    if intput_is_list:
    if input_is_list:
        if allow_interrupt:
            nodes.before_node_execution()
        results.append(getattr(obj, func)(**input_data_all))
    else:
    elif max_len_input == 0:
        if allow_interrupt:
            nodes.before_node_execution()
        results.append(getattr(obj, func)())
    else:
        for i in range(max_len_input):
            if allow_interrupt:
                nodes.before_node_execution()
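The execution.py change does two things: it fixes the intput_is_list typo and stops max() from being called on an empty sequence when a node has no inputs, which previously raised a ValueError. A self-contained sketch of the corrected dispatch, leaving out the allow_interrupt/before_node_execution hooks:

```python
def map_node_over_list(obj, input_data_all, func, input_is_list=False):
    # Sketch of the corrected behavior: list-taking nodes, zero-input nodes,
    # and the element-wise mapping case.
    if len(input_data_all) == 0:
        max_len_input = 0
    else:
        max_len_input = max(len(v) for v in input_data_all.values())

    if input_is_list:
        return [getattr(obj, func)(**input_data_all)]
    if max_len_input == 0:
        return [getattr(obj, func)()]  # node takes no inputs at all

    def slice_dict(d, i):
        # repeat the last value when a list is shorter than the longest one
        return {k: v[i if i < len(v) else -1] for k, v in d.items()}

    return [getattr(obj, func)(**slice_dict(input_data_all, i)) for i in range(max_len_input)]

class Add:
    def run(self, a, b):
        return a + b

print(map_node_over_list(Add(), {"a": [1, 2, 3], "b": [10]}, "run"))  # [11, 12, 13]
```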
main.py

@ -160,6 +160,8 @@ if __name__ == "__main__":

    if args.auto_launch:
        def startup_server(address, port):
            import webbrowser
            if os.name == 'nt' and address == '0.0.0.0':
                address = '127.0.0.1'
            webbrowser.open(f"http://{address}:{port}")
        call_on_start = startup_server
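The auto-launch helper now substitutes the loopback address on Windows, because a URL containing 0.0.0.0 is not browsable there even though the server listens on all interfaces. The same logic, made runnable on its own:

```python
import os
import webbrowser

def startup_server(address, port):
    # 0.0.0.0 means "listen on all interfaces"; it is not a browsable URL on
    # Windows, so fall back to loopback when building the launch URL.
    if os.name == 'nt' and address == '0.0.0.0':
        address = '127.0.0.1'
    webbrowser.open(f"http://{address}:{port}")

startup_server('0.0.0.0', 8188)  # opens http://127.0.0.1:8188 on Windows
```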
nodes.py

@ -26,6 +26,8 @@ import comfy.utils

import comfy.clip_vision

import comfy.model_management
from comfy.cli_args import args

import importlib

import folder_paths

@ -352,12 +354,22 @@ class SaveLatent:

        if prompt is not None:
            prompt_info = json.dumps(prompt)

        metadata = {"prompt": prompt_info}
        if extra_pnginfo is not None:
            for x in extra_pnginfo:
                metadata[x] = json.dumps(extra_pnginfo[x])
        metadata = None
        if not args.disable_metadata:
            metadata = {"prompt": prompt_info}
            if extra_pnginfo is not None:
                for x in extra_pnginfo:
                    metadata[x] = json.dumps(extra_pnginfo[x])

        file = f"{filename}_{counter:05}_.latent"

        results = list()
        results.append({
            "filename": file,
            "subfolder": subfolder,
            "type": "output"
        })

        file = os.path.join(full_output_folder, file)

        output = {}

@ -365,7 +377,7 @@ class SaveLatent:

        output["latent_format_version_0"] = torch.tensor([])

        comfy.utils.save_torch_file(output, file, metadata=metadata)
        return {}
        return { "ui": { "latents": results } }


class LoadLatent:
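SaveLatent now reports the written file back to the UI and only attaches metadata when it is enabled. Saving a latent boils down to a safetensors file with a version-marker tensor; a minimal sketch using safetensors directly (the real code goes through comfy.utils.save_torch_file):

```python
import torch
from safetensors.torch import save_file

def save_latent(samples: torch.Tensor, path: str, metadata=None):
    # Latents are stored as a safetensors file with a version-marker tensor;
    # metadata, when present, must be a dict of strings.
    output = {
        "latent_tensor": samples.contiguous(),
        "latent_format_version_0": torch.tensor([]),
    }
    save_file(output, path, metadata=metadata)

save_latent(torch.zeros(1, 4, 64, 64), "example.latent", {"prompt": "{}"})
```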
@ -1043,6 +1055,47 @@ class LatentComposite:

        samples_out["samples"] = s
        return (samples_out,)

class LatentBlend:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "samples1": ("LATENT",),
            "samples2": ("LATENT",),
            "blend_factor": ("FLOAT", {
                "default": 0.5,
                "min": 0,
                "max": 1,
                "step": 0.01
            }),
        }}

    RETURN_TYPES = ("LATENT",)
    FUNCTION = "blend"

    CATEGORY = "_for_testing"

    def blend(self, samples1, samples2, blend_factor:float, blend_mode: str="normal"):

        samples_out = samples1.copy()
        samples1 = samples1["samples"]
        samples2 = samples2["samples"]

        if samples1.shape != samples2.shape:
            samples2.permute(0, 3, 1, 2)
            samples2 = comfy.utils.common_upscale(samples2, samples1.shape[3], samples1.shape[2], 'bicubic', crop='center')
            samples2.permute(0, 2, 3, 1)

        samples_blended = self.blend_mode(samples1, samples2, blend_mode)
        samples_blended = samples1 * blend_factor + samples_blended * (1 - blend_factor)
        samples_out["samples"] = samples_blended
        return (samples_out,)

    def blend_mode(self, img1, img2, mode):
        if mode == "normal":
            return img2
        else:
            raise ValueError(f"Unsupported blend mode: {mode}")
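Two things worth noting in LatentBlend: with the only supported mode, "normal", the output is samples1 * blend_factor + samples2 * (1 - blend_factor), so blend_factor = 1 returns the first latent unchanged; and the two bare permute() calls are no-ops, since Tensor.permute returns a new tensor rather than modifying in place, so common_upscale receives the NCHW latent as-is. A quick numeric check of the blend formula:

```python
import torch

s1 = torch.full((1, 4, 8, 8), 2.0)
s2 = torch.zeros((1, 4, 8, 8))

blend_factor = 0.25
blended = s1 * blend_factor + s2 * (1 - blend_factor)  # "normal" mode passes s2 through

assert torch.allclose(blended, torch.full_like(s1, 0.5))  # 2.0*0.25 + 0.0*0.75
```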
class LatentCrop:
    @classmethod
    def INPUT_TYPES(s):

@ -1214,12 +1267,14 @@ class SaveImage:

        for image in images:
            i = 255. * image.cpu().numpy()
            img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
            metadata = PngInfo()
            if prompt is not None:
                metadata.add_text("prompt", json.dumps(prompt))
            if extra_pnginfo is not None:
                for x in extra_pnginfo:
                    metadata.add_text(x, json.dumps(extra_pnginfo[x]))
            metadata = None
            if not args.disable_metadata:
                metadata = PngInfo()
                if prompt is not None:
                    metadata.add_text("prompt", json.dumps(prompt))
                if extra_pnginfo is not None:
                    for x in extra_pnginfo:
                        metadata.add_text(x, json.dumps(extra_pnginfo[x]))

            file = f"{filename}_{counter:05}_.png"
            img.save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=4)

@ -1487,6 +1542,7 @@ NODE_CLASS_MAPPINGS = {

    "KSamplerAdvanced": KSamplerAdvanced,
    "SetLatentNoiseMask": SetLatentNoiseMask,
    "LatentComposite": LatentComposite,
    "LatentBlend": LatentBlend,
    "LatentRotate": LatentRotate,
    "LatentFlip": LatentFlip,
    "LatentCrop": LatentCrop,

@ -1558,6 +1614,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {

    "LatentUpscale": "Upscale Latent",
    "LatentUpscaleBy": "Upscale Latent By",
    "LatentComposite": "Latent Composite",
    "LatentBlend": "Latent Blend",
    "LatentFromBatch" : "Latent From Batch",
    "RepeatLatentBatch": "Repeat Latent Batch",
    # Image
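Adding the class to NODE_CLASS_MAPPINGS (plus a display name) is all it takes to expose a node: the executor looks the class up by type and calls the method named in FUNCTION. A runnable sketch with a stand-in registry and an abbreviated LatentBlend:

```python
# Sketch with a stand-in registry; in ComfyUI the real dicts live in nodes.py.
NODE_CLASS_MAPPINGS = {}

class LatentBlend:  # abbreviated stand-in for the class above
    FUNCTION = "blend"
    def blend(self, samples1, samples2, blend_factor):
        return ({"samples": samples1["samples"] * blend_factor
                 + samples2["samples"] * (1 - blend_factor)},)

NODE_CLASS_MAPPINGS["LatentBlend"] = LatentBlend

node_class = NODE_CLASS_MAPPINGS["LatentBlend"]
(result,) = getattr(node_class(), node_class.FUNCTION)(
    {"samples": 1.0}, {"samples": 0.0}, blend_factor=0.5)
print(result)  # {'samples': 0.5}
```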
@ -159,13 +159,64 @@

      "\n"
    ]
  },
  {
    "cell_type": "markdown",
    "metadata": {
      "id": "kkkkkkkkkkkkkkk"
    },
    "source": [
      "### Run ComfyUI with cloudflared (Recommended Way)\n",
      "\n",
      "\n"
    ]
  },
  {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {
      "id": "jjjjjjjjjjjjjj"
    },
    "outputs": [],
    "source": [
      "!wget https://github.com/cloudflare/cloudflared/releases/latest/download/cloudflared-linux-amd64.deb\n",
      "!dpkg -i cloudflared-linux-amd64.deb\n",
      "\n",
      "import subprocess\n",
      "import threading\n",
      "import time\n",
      "import socket\n",
      "import urllib.request\n",
      "\n",
      "def iframe_thread(port):\n",
      "  while True:\n",
      "    time.sleep(0.5)\n",
      "    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)\n",
      "    result = sock.connect_ex(('127.0.0.1', port))\n",
      "    if result == 0:\n",
      "      break\n",
      "    sock.close()\n",
      "  print(\"\\nComfyUI finished loading, trying to launch cloudflared (if it gets stuck here cloudflared is having issues)\\n\")\n",
      "\n",
      "  p = subprocess.Popen([\"cloudflared\", \"tunnel\", \"--url\", \"http://127.0.0.1:{}\".format(port)], stdout=subprocess.PIPE, stderr=subprocess.PIPE)\n",
      "  for line in p.stderr:\n",
      "    l = line.decode()\n",
      "    if \"trycloudflare.com \" in l:\n",
      "      print(\"This is the URL to access ComfyUI:\", l[l.find(\"http\"):], end='')\n",
      "    #print(l, end='')\n",
      "\n",
      "\n",
      "threading.Thread(target=iframe_thread, daemon=True, args=(8188,)).start()\n",
      "\n",
      "!python main.py --dont-print-server"
    ]
  },
  {
    "cell_type": "markdown",
    "metadata": {
      "id": "kkkkkkkkkkkkkk"
    },
    "source": [
      "### Run ComfyUI with localtunnel (Recommended Way)\n",
      "### Run ComfyUI with localtunnel\n",
      "\n",
      "\n"
    ]
@ -1,5 +1,4 @@

torch
torchdiffeq
torchsde
einops
transformers>=4.25.1
@ -345,6 +345,11 @@ class PromptServer():

            vram_total, torch_vram_total = comfy.model_management.get_total_memory(device, torch_total_too=True)
            vram_free, torch_vram_free = comfy.model_management.get_free_memory(device, torch_free_too=True)
            system_stats = {
                "system": {
                    "os": os.name,
                    "python_version": sys.version,
                    "embedded_python": os.path.split(os.path.split(sys.executable)[0])[1] == "python_embeded"
                },
                "devices": [
                    {
                        "name": device_name,
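The new /system_stats endpoint returns the OS, Python version, and per-device memory numbers; the frontend consumes it via api.getSystemStats(). Queried from Python against a locally running server (address and port assumed to be the defaults):

```python
import json
import urllib.request

# Assumes a ComfyUI server running locally on the default port.
with urllib.request.urlopen("http://127.0.0.1:8188/system_stats") as resp:
    stats = json.load(resp)

print(stats["system"]["os"], stats["system"]["python_version"])
for device in stats["devices"]:
    print(device["name"])
```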
@ -1,4 +1,4 @@

import {app} from "/scripts/app.js";
import {app} from "../../scripts/app.js";

// Adds filtering to combo context menus

@ -27,10 +27,13 @@ const ext = {

				const clickedComboValue = currentNode.widgets
					.filter(w => w.type === "combo" && w.options.values.length === values.length)
					.find(w => w.options.values.every((v, i) => v === values[i]))
					.value;
					?.value;

				let selectedIndex = values.findIndex(v => v === clickedComboValue);
				let selectedItem = displayedItems?.[selectedIndex];
				let selectedIndex = clickedComboValue ? values.findIndex(v => v === clickedComboValue) : 0;
				if (selectedIndex < 0) {
					selectedIndex = 0;
				}
				let selectedItem = displayedItems[selectedIndex];
				updateSelected();

				// Apply highlighting to the selected item
web/extensions/core/linkRenderMode.js (new file)

@ -0,0 +1,25 @@

import { app } from "/scripts/app.js";

const id = "Comfy.LinkRenderMode";
const ext = {
	name: id,
	async setup(app) {
		app.ui.settings.addSetting({
			id,
			name: "Link Render Mode",
			defaultValue: 2,
			type: "combo",
			options: LiteGraph.LINK_RENDER_MODES.map((m, i) => ({
				value: i,
				text: m,
				selected: i == app.canvas.links_render_mode,
			})),
			onChange(value) {
				app.canvas.links_render_mode = +value;
				app.graph.setDirtyCanvas(true);
			},
		});
	},
};

app.registerExtension(ext);
@ -2,7 +2,7 @@ import { ComfyWidgets, addValueControlWidget } from "../../scripts/widgets.js";

import { app } from "../../scripts/app.js";

const CONVERTED_TYPE = "converted-widget";
const VALID_TYPES = ["STRING", "combo", "number"];
const VALID_TYPES = ["STRING", "combo", "number", "BOOLEAN"];

function isConvertableWidget(widget, config) {
	return VALID_TYPES.includes(widget.type) || VALID_TYPES.includes(config[0]);

@ -9835,7 +9835,11 @@ LGraphNode.prototype.executeAction = function(action)

                ctx.textAlign = "center";
                ctx.fillStyle = text_color;
                ctx.fillText(
                    w.label || w.name + " " + Number(w.value).toFixed(3),
                    w.label || w.name + " " + Number(w.value).toFixed(
                        w.options.precision != null
                            ? w.options.precision
                            : 3
                    ),
                    widget_width * 0.5,
                    y + H * 0.7
                );

@ -13835,7 +13839,7 @@ LGraphNode.prototype.executeAction = function(action)

        if (!disabled) {
            element.addEventListener("click", inner_onclick);
        }
        if (options.autoopen) {
        if (!disabled && options.autoopen) {
            LiteGraph.pointerListenerAdd(element,"enter",inner_over);
        }

@ -264,6 +264,15 @@ class ComfyApi extends EventTarget {

			}
		}
	}

	/**
	 * Gets system & device stats
	 * @returns System stats such as python version, OS, per device info
	 */
	async getSystemStats() {
		const res = await this.fetchApi("/system_stats");
		return await res.json();
	}

	/**
	 * Sends a POST request to the API
	 * @param {*} type The endpoint to post to
@ -1,3 +1,4 @@

import { ComfyLogging } from "./logging.js";
import { ComfyWidgets } from "./widgets.js";
import { ComfyUI, $el } from "./ui.js";
import { api } from "./api.js";

@ -31,6 +32,7 @@ export class ComfyApp {

	constructor() {
		this.ui = new ComfyUI(this);
		this.logging = new ComfyLogging(this);

		/**
		 * List of extensions that are registered with the app

@ -768,6 +770,19 @@ export class ComfyApp {

				}
				block_default = true;
			}

			if (e.keyCode == 66 && e.ctrlKey) {
				if (this.selected_nodes) {
					for (var i in this.selected_nodes) {
						if (this.selected_nodes[i].mode === 4) { // never
							this.selected_nodes[i].mode = 0; // always
						} else {
							this.selected_nodes[i].mode = 4; // never
						}
					}
				}
				block_default = true;
			}
		}

		this.graph.change();

@ -914,14 +929,21 @@ export class ComfyApp {

		const origDrawNode = LGraphCanvas.prototype.drawNode;
		LGraphCanvas.prototype.drawNode = function (node, ctx) {
			var editor_alpha = this.editor_alpha;
			var old_color = node.bgcolor;

			if (node.mode === 2) { // never
				this.editor_alpha = 0.4;
			}

			if (node.mode === 4) { // never
				node.bgcolor = "#FF00FF";
				this.editor_alpha = 0.2;
			}

			const res = origDrawNode.apply(this, arguments);

			this.editor_alpha = editor_alpha;
			node.bgcolor = old_color;

			return res;
		};

@ -1003,6 +1025,7 @@ export class ComfyApp {

	 */
	async #loadExtensions() {
		const extensions = await api.getExtensions();
		this.logging.addEntry("Comfy.App", "debug", { Extensions: extensions });
		for (const ext of extensions) {
			try {
				await import(api.apiURL(ext));

@ -1286,6 +1309,9 @@ export class ComfyApp {

					(t) => `<li>${t}</li>`
				).join("")}</ul>Nodes that have failed to load will show as red on the graph.`
			);
			this.logging.addEntry("Comfy.App", "warn", {
				MissingNodes: missingNodeTypes,
			});
		}
	}

@ -1308,7 +1334,7 @@ export class ComfyApp {

				continue;
			}

			if (node.mode === 2) {
			if (node.mode === 2 || node.mode === 4) {
				// Don't serialize muted nodes
				continue;
			}

@ -1331,12 +1357,36 @@ export class ComfyApp {

					let parent = node.getInputNode(i);
					if (parent) {
						let link = node.getInputLink(i);
						while (parent && parent.isVirtualNode) {
							link = parent.getInputLink(link.origin_slot);
							if (link) {
								parent = parent.getInputNode(link.origin_slot);
							} else {
								parent = null;
						while (parent.mode === 4 || parent.isVirtualNode) {
							let found = false;
							if (parent.isVirtualNode) {
								link = parent.getInputLink(link.origin_slot);
								if (link) {
									parent = parent.getInputNode(link.target_slot);
									if (parent) {
										found = true;
									}
								}
							} else if (link && parent.mode === 4) {
								let all_inputs = [link.origin_slot];
								if (parent.inputs) {
									all_inputs = all_inputs.concat(Object.keys(parent.inputs))
									for (let parent_input in all_inputs) {
										parent_input = all_inputs[parent_input];
										if (parent.inputs[parent_input].type === node.inputs[i].type) {
											link = parent.getInputLink(parent_input);
											if (link) {
												parent = parent.getInputNode(parent_input);
											}
											found = true;
											break;
										}
									}
								}
							}

							if (!found) {
								break;
							}
						}
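The graphToPrompt change above makes bypassed (mode 4) nodes transparent during serialization: the walker re-enters a bypassed parent through an input whose type matches the link and keeps going upstream until it finds a real producer or a dead end. A simplified, runnable model of that traversal (Link/Node/resolve_parent are hypothetical stand-ins, not ComfyUI APIs):

```python
from dataclasses import dataclass, field
from typing import Optional

# Minimal stand-ins for graph nodes and links, just to make the traversal
# above concrete and runnable.
@dataclass
class Link:
    source: "Node"
    type: str

@dataclass
class Node:
    name: str
    bypassed: bool = False
    inputs: dict = field(default_factory=dict)  # input name -> Link or None

def resolve_parent(link: Optional[Link]) -> Optional[Node]:
    # Follow a link upstream, re-entering bypassed nodes through the first
    # input whose type matches, stopping at a real producer or a dead end.
    while link is not None and link.source.bypassed:
        link = next((l for l in link.source.inputs.values()
                     if l is not None and l.type == link.type), None)
    return link.source if link is not None else None

producer = Node("producer")
bypass = Node("bypass", bypassed=True,
              inputs={"latent": Link(producer, "LATENT")})
assert resolve_parent(Link(bypass, "LATENT")) is producer
```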
web/scripts/logging.js (new file)

@ -0,0 +1,367 @@

import { $el, ComfyDialog } from "./ui.js";
import { api } from "./api.js";

$el("style", {
	textContent: `
	.comfy-logging-logs {
		display: grid;
		color: var(--fg-color);
		white-space: pre-wrap;
	}
	.comfy-logging-log {
		display: contents;
	}
	.comfy-logging-title {
		background: var(--tr-even-bg-color);
		font-weight: bold;
		margin-bottom: 5px;
		text-align: center;
	}
	.comfy-logging-log div {
		background: var(--row-bg);
		padding: 5px;
	}
	`,
	parent: document.body,
});

// Stringify function supporting max depth and removal of circular references
// https://stackoverflow.com/a/57193345
function stringify(val, depth, replacer, space, onGetObjID) {
	depth = isNaN(+depth) ? 1 : depth;
	var recursMap = new WeakMap();
	function _build(val, depth, o, a, r) {
		// (JSON.stringify() has its own rules, which we respect here by using it for property iteration)
		return !val || typeof val != "object"
			? val
			: ((r = recursMap.has(val)),
			  recursMap.set(val, true),
			  (a = Array.isArray(val)),
			  r
				? (o = (onGetObjID && onGetObjID(val)) || null)
				: JSON.stringify(val, function (k, v) {
						if (a || depth > 0) {
							if (replacer) v = replacer(k, v);
							if (!k) return (a = Array.isArray(v)), (val = v);
							!o && (o = a ? [] : {});
							o[k] = _build(v, a ? depth : depth - 1);
						}
				  }),
			  o === void 0 ? (a ? [] : {}) : o);
	}
	return JSON.stringify(_build(val, depth), null, space);
}

const jsonReplacer = (k, v, ui) => {
	if (v instanceof Array && v.length === 1) {
		v = v[0];
	}
	if (v instanceof Date) {
		v = v.toISOString();
		if (ui) {
			v = v.split("T")[1];
		}
	}
	if (v instanceof Error) {
		let err = "";
		if (v.name) err += v.name + "\n";
		if (v.message) err += v.message + "\n";
		if (v.stack) err += v.stack + "\n";
		if (!err) {
			err = v.toString();
		}
		v = err;
	}
	return v;
};

const fileInput = $el("input", {
	type: "file",
	accept: ".json",
	style: { display: "none" },
	parent: document.body,
});

class ComfyLoggingDialog extends ComfyDialog {
	constructor(logging) {
		super();
		this.logging = logging;
	}

	clear() {
		this.logging.clear();
		this.show();
	}

	export() {
		const blob = new Blob([stringify([...this.logging.entries], 20, jsonReplacer, "\t")], {
			type: "application/json",
		});
		const url = URL.createObjectURL(blob);
		const a = $el("a", {
			href: url,
			download: `comfyui-logs-${Date.now()}.json`,
			style: { display: "none" },
			parent: document.body,
		});
		a.click();
		setTimeout(function () {
			a.remove();
			window.URL.revokeObjectURL(url);
		}, 0);
	}

	import() {
		fileInput.onchange = () => {
			const reader = new FileReader();
			reader.onload = () => {
				fileInput.remove();
				try {
					const obj = JSON.parse(reader.result);
					if (obj instanceof Array) {
						this.show(obj);
					} else {
						throw new Error("Invalid file selected.");
					}
				} catch (error) {
					alert("Unable to load logs: " + error.message);
				}
			};
			reader.readAsText(fileInput.files[0]);
		};
		fileInput.click();
	}

	createButtons() {
		return [
			$el("button", {
				type: "button",
				textContent: "Clear",
				onclick: () => this.clear(),
			}),
			$el("button", {
				type: "button",
				textContent: "Export logs...",
				onclick: () => this.export(),
			}),
			$el("button", {
				type: "button",
				textContent: "View exported logs...",
				onclick: () => this.import(),
			}),
			...super.createButtons(),
		];
	}

	getTypeColor(type) {
		switch (type) {
			case "error":
				return "red";
			case "warn":
				return "orange";
			case "debug":
				return "dodgerblue";
		}
	}

	show(entries) {
		if (!entries) entries = this.logging.entries;
		this.element.style.width = "100%";
		const cols = {
			source: "Source",
			type: "Type",
			timestamp: "Timestamp",
			message: "Message",
		};
		const keys = Object.keys(cols);
		const headers = Object.values(cols).map((title) =>
			$el("div.comfy-logging-title", {
				textContent: title,
			})
		);
		const rows = entries.map((entry, i) => {
			return $el(
				"div.comfy-logging-log",
				{
					$: (el) => el.style.setProperty("--row-bg", `var(--tr-${i % 2 ? "even" : "odd"}-bg-color)`),
				},
				keys.map((key) => {
					let v = entry[key];
					let color;
					if (key === "type") {
						color = this.getTypeColor(v);
					} else {
						v = jsonReplacer(key, v, true);

						if (typeof v === "object") {
							v = stringify(v, 5, jsonReplacer, " ");
						}
					}

					return $el("div", {
						style: {
							color,
						},
						textContent: v,
					});
				})
			);
		});

		const grid = $el(
			"div.comfy-logging-logs",
			{
				style: {
					gridTemplateColumns: `repeat(${headers.length}, 1fr)`,
				},
			},
			[...headers, ...rows]
		);
		const els = [grid];
		if (!this.logging.enabled) {
			els.unshift(
				$el("h3", {
					style: { textAlign: "center" },
					textContent: "Logging is disabled",
				})
			);
		}
		super.show($el("div", els));
	}
}

export class ComfyLogging {
	/**
	 * @type Array<{ source: string, type: string, timestamp: Date, message: any }>
	 */
	entries = [];

	#enabled;
	#console = {};

	get enabled() {
		return this.#enabled;
	}

	set enabled(value) {
		if (value === this.#enabled) return;
		if (value) {
			this.patchConsole();
		} else {
			this.unpatchConsole();
		}
		this.#enabled = value;
	}

	constructor(app) {
		this.app = app;

		this.dialog = new ComfyLoggingDialog(this);
		this.addSetting();
		this.catchUnhandled();
		this.addInitData();
	}

	addSetting() {
		const settingId = "Comfy.Logging.Enabled";
		const htmlSettingId = settingId.replaceAll(".", "-");
		const setting = this.app.ui.settings.addSetting({
			id: settingId,
			name: settingId,
			defaultValue: true,
			type: (name, setter, value) => {
				return $el("tr", [
					$el("td", [
						$el("label", {
							textContent: "Logging",
							for: htmlSettingId,
						}),
					]),
					$el("td", [
						$el("input", {
							id: htmlSettingId,
							type: "checkbox",
							checked: value,
							onchange: (event) => {
								setter((this.enabled = event.target.checked));
							},
						}),
						$el("button", {
							textContent: "View Logs",
							onclick: () => {
								this.app.ui.settings.element.close();
								this.dialog.show();
							},
							style: {
								fontSize: "14px",
								display: "block",
								marginTop: "5px",
							},
						}),
					]),
				]);
			},
		});
		this.enabled = setting.value;
	}

	patchConsole() {
		// Capture common console outputs
		const self = this;
		for (const type of ["log", "warn", "error", "debug"]) {
			const orig = console[type];
			this.#console[type] = orig;
			console[type] = function () {
				orig.apply(console, arguments);
				self.addEntry("console", type, ...arguments);
			};
		}
	}

	unpatchConsole() {
		// Restore original console functions
		for (const type of Object.keys(this.#console)) {
			console[type] = this.#console[type];
		}
		this.#console = {};
	}

	catchUnhandled() {
		// Capture uncaught errors
		window.addEventListener("error", (e) => {
			this.addEntry("window", "error", e.error ?? "Unknown error");
			return false;
		});

		window.addEventListener("unhandledrejection", (e) => {
			this.addEntry("unhandledrejection", "error", e.reason ?? "Unknown error");
		});
	}

	clear() {
		this.entries = [];
	}

	addEntry(source, type, ...args) {
		if (this.enabled) {
			this.entries.push({
				source,
				type,
				timestamp: new Date(),
				message: args,
			});
		}
	}

	log(source, ...args) {
		this.addEntry(source, "log", ...args);
	}

	async addInitData() {
		if (!this.enabled) return;
		const source = "ComfyUI.Logging";
		this.addEntry(source, "debug", { UserAgent: navigator.userAgent });
		const systemStats = await api.getSystemStats();
		this.addEntry(source, "debug", systemStats);
	}
}
@ -234,7 +234,7 @@ class ComfySettingsDialog extends ComfyDialog {

		localStorage[settingId] = JSON.stringify(value);
	}

	addSetting({id, name, type, defaultValue, onChange, attrs = {}, tooltip = "",}) {
	addSetting({id, name, type, defaultValue, onChange, attrs = {}, tooltip = "", options = undefined}) {
		if (!id) {
			throw new Error("Settings must have an ID");
		}

@ -347,6 +347,32 @@ class ComfySettingsDialog extends ComfyDialog {

						]),
					]);
					break;
				case "combo":
					element = $el("tr", [
						labelCell,
						$el("td", [
							$el(
								"select",
								{
									oninput: (e) => {
										setter(e.target.value);
									},
								},
								(typeof options === "function" ? options(value) : options || []).map((opt) => {
									if (typeof opt === "string") {
										opt = { text: opt };
									}
									const v = opt.value ?? opt.text;
									return $el("option", {
										value: v,
										textContent: opt.text,
										selected: value + "" === v + "",
									});
								})
							),
						]),
					]);
					break;
				case "text":
				default:
					if (type !== "text") {

@ -480,7 +506,7 @@ class ComfyList {

	hide() {
		this.element.style.display = "none";
		this.button.textContent = "See " + this.#text;
		this.button.textContent = "View " + this.#text;
	}

	toggle() {

@ -542,6 +568,13 @@ export class ComfyUI {

			defaultValue: "",
		});

		this.settings.addSetting({
			id: "Comfy.DisableSliders",
			name: "Disable sliders.",
			type: "boolean",
			defaultValue: false,
		});

		const fileInput = $el("input", {
			id: "comfy-file-input",
			type: "file",

@ -79,8 +79,8 @@ export function addValueControlWidget(node, targetWidget, defaultValue = "random

	return valueControl;
};

function seedWidget(node, inputName, inputData) {
	const seed = ComfyWidgets.INT(node, inputName, inputData);
function seedWidget(node, inputName, inputData, app) {
	const seed = ComfyWidgets.INT(node, inputName, inputData, app);
	const seedControl = addValueControlWidget(node, seed.widget, "randomize");

	seed.widget.linkedWidgets = [seedControl];

@ -250,19 +250,29 @@ function addMultilineWidget(node, name, opts, app) {

	return { minWidth: 400, minHeight: 200, widget };
}

function isSlider(display, app) {
	if (app.ui.settings.getSettingValue("Comfy.DisableSliders")) {
		return "number"
	}

	return (display==="slider") ? "slider" : "number"
}

export const ComfyWidgets = {
	"INT:seed": seedWidget,
	"INT:noise_seed": seedWidget,
	FLOAT(node, inputName, inputData) {
	FLOAT(node, inputName, inputData, app) {
		let widgetType = isSlider(inputData[1]["display"], app);
		const { val, config } = getNumberDefaults(inputData, 0.5);
		return { widget: node.addWidget("number", inputName, val, () => {}, config) };
		return { widget: node.addWidget(widgetType, inputName, val, () => {}, config) };
	},
	INT(node, inputName, inputData) {
	INT(node, inputName, inputData, app) {
		let widgetType = isSlider(inputData[1]["display"], app);
		const { val, config } = getNumberDefaults(inputData, 1);
		Object.assign(config, { precision: 0 });
		return {
			widget: node.addWidget(
				"number",
				widgetType,
				inputName,
				val,
				function (v) {

@ -273,6 +283,18 @@ export const ComfyWidgets = {

			),
		};
	},
	BOOLEAN(node, inputName, inputData) {
		let defaultVal = inputData[1]["default"];
		return {
			widget: node.addWidget(
				"toggle",
				inputName,
				defaultVal,
				() => {},
				{"on": inputData[1].label_on, "off": inputData[1].label_off}
			)
		};
	},
	STRING(node, inputName, inputData, app) {
		const defaultVal = inputData[1].default || "";
		const multiline = !!inputData[1].multiline;