Merge branch 'comfyanonymous:master' into master

Gaurav614 2023-08-07 12:08:06 +05:30 committed by GitHub
commit e11af79bcc
31 changed files with 876 additions and 461 deletions

View File

@@ -94,7 +94,7 @@ AMD users can install rocm and pytorch with pip if you don't have it already ins
 ```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/rocm5.4.2```
 This is the command to install the nightly with ROCm 5.6 that supports the 7000 series and might have some performance improvements:
-```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm5.6 -r requirements.txt```
+```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm5.6```
 ### NVIDIA
@@ -126,10 +126,10 @@ After this you should have everything installed and can proceed to running Comfy
 You can install ComfyUI in Apple Mac silicon (M1 or M2) with any recent macOS version.
-1. Install pytorch. For instructions, read the [Accelerated PyTorch training on Mac](https://developer.apple.com/metal/pytorch/) Apple Developer guide.
+1. Install pytorch nightly. For instructions, read the [Accelerated PyTorch training on Mac](https://developer.apple.com/metal/pytorch/) Apple Developer guide (make sure to install the latest pytorch nightly).
 1. Follow the [ComfyUI manual installation](#manual-install-windows-linux) instructions for Windows and Linux.
 1. Install the ComfyUI [dependencies](#dependencies). If you have another Stable Diffusion UI [you might be able to reuse the dependencies](#i-already-have-another-ui-for-stable-diffusion-installed-do-i-really-have-to-install-all-of-these-dependencies).
-1. Launch ComfyUI by running `python main.py`.
+1. Launch ComfyUI by running `python main.py --force-fp16`. Note that --force-fp16 will only work if you installed the latest pytorch nightly.
 > **Note**: Remember to add your models, VAE, LoRAs etc. to the corresponding Comfy folders, as discussed in [ComfyUI manual installation](#manual-install-windows-linux).
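For context, `--force-fp16` on Apple silicon depends on PyTorch's MPS backend being present; a minimal sketch, assuming a recent PyTorch nightly, to verify that before launching:

```python
import torch

# Minimal check: --force-fp16 is only useful if the MPS backend is usable.
if torch.backends.mps.is_available():
    print("MPS backend available; `python main.py --force-fp16` should work")
else:
    # Either this PyTorch build has no MPS support or macOS is too old.
    built = torch.backends.mps.is_built()
    print("MPS not available (built into this PyTorch:", built, ")")
```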

View File

@@ -13,7 +13,7 @@ from ..ldm.modules.diffusionmodules.util import (
 )
 from ..ldm.modules.attention import SpatialTransformer
-from ..ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample, AttentionBlock
+from ..ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample
 from ..ldm.util import exists
@@ -57,6 +57,7 @@ class ControlNet(nn.Module):
         transformer_depth_middle=None,
     ):
         super().__init__()
+        assert use_spatial_transformer == True, "use_spatial_transformer has to be true"
         if use_spatial_transformer:
             assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
@@ -200,13 +201,7 @@ class ControlNet(nn.Module):
                 if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
                     layers.append(
-                        AttentionBlock(
-                            ch,
-                            use_checkpoint=use_checkpoint,
-                            num_heads=num_heads,
-                            num_head_channels=dim_head,
-                            use_new_attention_order=use_new_attention_order,
-                        ) if not use_spatial_transformer else SpatialTransformer(
+                        SpatialTransformer(
                             ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim,
                             disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
                             use_checkpoint=use_checkpoint
@@ -259,13 +254,7 @@ class ControlNet(nn.Module):
                 use_checkpoint=use_checkpoint,
                 use_scale_shift_norm=use_scale_shift_norm,
             ),
-            AttentionBlock(
-                ch,
-                use_checkpoint=use_checkpoint,
-                num_heads=num_heads,
-                num_head_channels=dim_head,
-                use_new_attention_order=use_new_attention_order,
-            ) if not use_spatial_transformer else SpatialTransformer(  # always uses a self-attn
+            SpatialTransformer(  # always uses a self-attn
                 ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
                 disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
                 use_checkpoint=use_checkpoint

View File

@@ -39,6 +39,7 @@ parser.add_argument("--enable-cors-header", type=str, default=None, metavar="ORI
 parser.add_argument("--extra-model-paths-config", type=str, default=None, metavar="PATH", nargs='+', action='append', help="Load one or more extra_model_paths.yaml files.")
 parser.add_argument("--output-directory", type=str, default=None, help="Set the ComfyUI output directory.")
 parser.add_argument("--auto-launch", action="store_true", help="Automatically launch ComfyUI in the default browser.")
+parser.add_argument("--disable-auto-launch", action="store_true", help="Disable auto launching the browser.")
 parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use.")
 cm_group = parser.add_mutually_exclusive_group()
 cm_group.add_argument("--cuda-malloc", action="store_true", help="Enable cudaMallocAsync (enabled by default for torch 2.0 and up).")
@@ -84,7 +85,12 @@ parser.add_argument("--dont-print-server", action="store_true", help="Don't prin
 parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")
 parser.add_argument("--windows-standalone-build", action="store_true", help="Windows standalone build: Enable convenient things that most people using the standalone windows build will probably enjoy (like auto opening the page on startup).")
+parser.add_argument("--disable-metadata", action="store_true", help="Disable saving prompt metadata in files.")
 args = parser.parse_args()
 if args.windows_standalone_build:
     args.auto_launch = True
+if args.disable_auto_launch:
+    args.auto_launch = False
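For reference, the precedence the two flags end up with: `--windows-standalone-build` turns auto-launch on, and the new `--disable-auto-launch` is applied last, so it wins. A self-contained sketch with a hypothetical parser:

```python
import argparse

# Hypothetical standalone parser mirroring the flag precedence in this hunk.
parser = argparse.ArgumentParser()
parser.add_argument("--auto-launch", action="store_true")
parser.add_argument("--disable-auto-launch", action="store_true")
parser.add_argument("--windows-standalone-build", action="store_true")
args = parser.parse_args(["--windows-standalone-build", "--disable-auto-launch"])

if args.windows_standalone_build:
    args.auto_launch = True
if args.disable_auto_launch:
    args.auto_launch = False

print(args.auto_launch)  # False: the explicit disable flag overrides the standalone default
```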

View File

@@ -3,7 +3,6 @@ import math
 from scipy import integrate
 import torch
 from torch import nn
-from torchdiffeq import odeint
 import torchsde
 from tqdm.auto import trange, tqdm
@@ -287,30 +286,6 @@ def sample_lms(model, x, sigmas, extra_args=None, callback=None, disable=None, o
     return x
-@torch.no_grad()
-def log_likelihood(model, x, sigma_min, sigma_max, extra_args=None, atol=1e-4, rtol=1e-4):
-    extra_args = {} if extra_args is None else extra_args
-    s_in = x.new_ones([x.shape[0]])
-    v = torch.randint_like(x, 2) * 2 - 1
-    fevals = 0
-    def ode_fn(sigma, x):
-        nonlocal fevals
-        with torch.enable_grad():
-            x = x[0].detach().requires_grad_()
-            denoised = model(x, sigma * s_in, **extra_args)
-            d = to_d(x, sigma, denoised)
-            fevals += 1
-            grad = torch.autograd.grad((d * v).sum(), x)[0]
-            d_ll = (v * grad).flatten(1).sum(1)
-        return d.detach(), d_ll
-    x_min = x, x.new_zeros([x.shape[0]])
-    t = x.new_tensor([sigma_min, sigma_max])
-    sol = odeint(ode_fn, x_min, t, atol=atol, rtol=rtol, method='dopri5')
-    latent, delta_ll = sol[0][-1], sol[1][-1]
-    ll_prior = torch.distributions.Normal(0, sigma_max).log_prob(latent).flatten(1).sum(1)
-    return ll_prior + delta_ll, {'fevals': fevals}
 class PIDStepSizeController:
     """A PID controller for ODE adaptive step size control."""
     def __init__(self, h, pcoeff, icoeff, dcoeff, order=1, accept_safety=0.81, eps=1e-8):

View File

@ -52,9 +52,9 @@ def init_(tensor):
# feedforward # feedforward
class GEGLU(nn.Module): class GEGLU(nn.Module):
def __init__(self, dim_in, dim_out, dtype=None): def __init__(self, dim_in, dim_out, dtype=None, device=None):
super().__init__() super().__init__()
self.proj = comfy.ops.Linear(dim_in, dim_out * 2, dtype=dtype) self.proj = comfy.ops.Linear(dim_in, dim_out * 2, dtype=dtype, device=device)
def forward(self, x): def forward(self, x):
x, gate = self.proj(x).chunk(2, dim=-1) x, gate = self.proj(x).chunk(2, dim=-1)
@ -62,19 +62,19 @@ class GEGLU(nn.Module):
class FeedForward(nn.Module): class FeedForward(nn.Module):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0., dtype=None): def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0., dtype=None, device=None):
super().__init__() super().__init__()
inner_dim = int(dim * mult) inner_dim = int(dim * mult)
dim_out = default(dim_out, dim) dim_out = default(dim_out, dim)
project_in = nn.Sequential( project_in = nn.Sequential(
comfy.ops.Linear(dim, inner_dim, dtype=dtype), comfy.ops.Linear(dim, inner_dim, dtype=dtype, device=device),
nn.GELU() nn.GELU()
) if not glu else GEGLU(dim, inner_dim, dtype=dtype) ) if not glu else GEGLU(dim, inner_dim, dtype=dtype, device=device)
self.net = nn.Sequential( self.net = nn.Sequential(
project_in, project_in,
nn.Dropout(dropout), nn.Dropout(dropout),
comfy.ops.Linear(inner_dim, dim_out, dtype=dtype) comfy.ops.Linear(inner_dim, dim_out, dtype=dtype, device=device)
) )
def forward(self, x): def forward(self, x):
@ -90,8 +90,8 @@ def zero_module(module):
return module return module
def Normalize(in_channels, dtype=None): def Normalize(in_channels, dtype=None, device=None):
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype) return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
class SpatialSelfAttention(nn.Module): class SpatialSelfAttention(nn.Module):
@ -148,7 +148,7 @@ class SpatialSelfAttention(nn.Module):
class CrossAttentionBirchSan(nn.Module): class CrossAttentionBirchSan(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None): def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None):
super().__init__() super().__init__()
inner_dim = dim_head * heads inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim) context_dim = default(context_dim, query_dim)
@ -156,12 +156,12 @@ class CrossAttentionBirchSan(nn.Module):
self.scale = dim_head ** -0.5 self.scale = dim_head ** -0.5
self.heads = heads self.heads = heads
self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype) self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype) self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype) self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.to_out = nn.Sequential( self.to_out = nn.Sequential(
comfy.ops.Linear(inner_dim, query_dim, dtype=dtype), comfy.ops.Linear(inner_dim, query_dim, dtype=dtype, device=device),
nn.Dropout(dropout) nn.Dropout(dropout)
) )
@ -245,7 +245,7 @@ class CrossAttentionBirchSan(nn.Module):
class CrossAttentionDoggettx(nn.Module): class CrossAttentionDoggettx(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None): def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None):
super().__init__() super().__init__()
inner_dim = dim_head * heads inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim) context_dim = default(context_dim, query_dim)
@ -253,12 +253,12 @@ class CrossAttentionDoggettx(nn.Module):
self.scale = dim_head ** -0.5 self.scale = dim_head ** -0.5
self.heads = heads self.heads = heads
self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype) self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype) self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype) self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.to_out = nn.Sequential( self.to_out = nn.Sequential(
comfy.ops.Linear(inner_dim, query_dim, dtype=dtype), comfy.ops.Linear(inner_dim, query_dim, dtype=dtype, device=device),
nn.Dropout(dropout) nn.Dropout(dropout)
) )
@ -343,7 +343,7 @@ class CrossAttentionDoggettx(nn.Module):
return self.to_out(r2) return self.to_out(r2)
class CrossAttention(nn.Module): class CrossAttention(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None): def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None):
super().__init__() super().__init__()
inner_dim = dim_head * heads inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim) context_dim = default(context_dim, query_dim)
@ -351,12 +351,12 @@ class CrossAttention(nn.Module):
self.scale = dim_head ** -0.5 self.scale = dim_head ** -0.5
self.heads = heads self.heads = heads
self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype) self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype) self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype) self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.to_out = nn.Sequential( self.to_out = nn.Sequential(
comfy.ops.Linear(inner_dim, query_dim, dtype=dtype), comfy.ops.Linear(inner_dim, query_dim, dtype=dtype, device=device),
nn.Dropout(dropout) nn.Dropout(dropout)
) )
@ -399,7 +399,7 @@ class CrossAttention(nn.Module):
class MemoryEfficientCrossAttention(nn.Module): class MemoryEfficientCrossAttention(nn.Module):
# https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, dtype=None): def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, dtype=None, device=None):
super().__init__() super().__init__()
print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using " print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
f"{heads} heads.") f"{heads} heads.")
@ -409,11 +409,11 @@ class MemoryEfficientCrossAttention(nn.Module):
self.heads = heads self.heads = heads
self.dim_head = dim_head self.dim_head = dim_head
self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype) self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype) self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype) self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.to_out = nn.Sequential(comfy.ops.Linear(inner_dim, query_dim, dtype=dtype), nn.Dropout(dropout)) self.to_out = nn.Sequential(comfy.ops.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
self.attention_op: Optional[Any] = None self.attention_op: Optional[Any] = None
def forward(self, x, context=None, value=None, mask=None): def forward(self, x, context=None, value=None, mask=None):
@ -450,7 +450,7 @@ class MemoryEfficientCrossAttention(nn.Module):
return self.to_out(out) return self.to_out(out)
class CrossAttentionPytorch(nn.Module): class CrossAttentionPytorch(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None): def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None):
super().__init__() super().__init__()
inner_dim = dim_head * heads inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim) context_dim = default(context_dim, query_dim)
@ -458,11 +458,11 @@ class CrossAttentionPytorch(nn.Module):
self.heads = heads self.heads = heads
self.dim_head = dim_head self.dim_head = dim_head
self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype) self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype) self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype) self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.to_out = nn.Sequential(comfy.ops.Linear(inner_dim, query_dim, dtype=dtype), nn.Dropout(dropout)) self.to_out = nn.Sequential(comfy.ops.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
self.attention_op: Optional[Any] = None self.attention_op: Optional[Any] = None
def forward(self, x, context=None, value=None, mask=None): def forward(self, x, context=None, value=None, mask=None):
@ -508,17 +508,17 @@ else:
class BasicTransformerBlock(nn.Module): class BasicTransformerBlock(nn.Module):
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
disable_self_attn=False, dtype=None): disable_self_attn=False, dtype=None, device=None):
super().__init__() super().__init__()
self.disable_self_attn = disable_self_attn self.disable_self_attn = disable_self_attn
self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout, self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
context_dim=context_dim if self.disable_self_attn else None, dtype=dtype) # is a self-attention if not self.disable_self_attn context_dim=context_dim if self.disable_self_attn else None, dtype=dtype, device=device) # is a self-attention if not self.disable_self_attn
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff, dtype=dtype) self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device)
self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
heads=n_heads, dim_head=d_head, dropout=dropout, dtype=dtype) # is self-attn if context is none heads=n_heads, dim_head=d_head, dropout=dropout, dtype=dtype, device=device) # is self-attn if context is none
self.norm1 = nn.LayerNorm(dim, dtype=dtype) self.norm1 = nn.LayerNorm(dim, dtype=dtype, device=device)
self.norm2 = nn.LayerNorm(dim, dtype=dtype) self.norm2 = nn.LayerNorm(dim, dtype=dtype, device=device)
self.norm3 = nn.LayerNorm(dim, dtype=dtype) self.norm3 = nn.LayerNorm(dim, dtype=dtype, device=device)
self.checkpoint = checkpoint self.checkpoint = checkpoint
self.n_heads = n_heads self.n_heads = n_heads
self.d_head = d_head self.d_head = d_head
@ -648,34 +648,34 @@ class SpatialTransformer(nn.Module):
def __init__(self, in_channels, n_heads, d_head, def __init__(self, in_channels, n_heads, d_head,
depth=1, dropout=0., context_dim=None, depth=1, dropout=0., context_dim=None,
disable_self_attn=False, use_linear=False, disable_self_attn=False, use_linear=False,
use_checkpoint=True, dtype=None): use_checkpoint=True, dtype=None, device=None):
super().__init__() super().__init__()
if exists(context_dim) and not isinstance(context_dim, list): if exists(context_dim) and not isinstance(context_dim, list):
context_dim = [context_dim] * depth context_dim = [context_dim] * depth
self.in_channels = in_channels self.in_channels = in_channels
inner_dim = n_heads * d_head inner_dim = n_heads * d_head
self.norm = Normalize(in_channels, dtype=dtype) self.norm = Normalize(in_channels, dtype=dtype, device=device)
if not use_linear: if not use_linear:
self.proj_in = nn.Conv2d(in_channels, self.proj_in = nn.Conv2d(in_channels,
inner_dim, inner_dim,
kernel_size=1, kernel_size=1,
stride=1, stride=1,
padding=0, dtype=dtype) padding=0, dtype=dtype, device=device)
else: else:
self.proj_in = comfy.ops.Linear(in_channels, inner_dim, dtype=dtype) self.proj_in = comfy.ops.Linear(in_channels, inner_dim, dtype=dtype, device=device)
self.transformer_blocks = nn.ModuleList( self.transformer_blocks = nn.ModuleList(
[BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d], [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
disable_self_attn=disable_self_attn, checkpoint=use_checkpoint, dtype=dtype) disable_self_attn=disable_self_attn, checkpoint=use_checkpoint, dtype=dtype, device=device)
for d in range(depth)] for d in range(depth)]
) )
if not use_linear: if not use_linear:
self.proj_out = nn.Conv2d(inner_dim,in_channels, self.proj_out = nn.Conv2d(inner_dim,in_channels,
kernel_size=1, kernel_size=1,
stride=1, stride=1,
padding=0, dtype=dtype) padding=0, dtype=dtype, device=device)
else: else:
self.proj_out = comfy.ops.Linear(in_channels, inner_dim, dtype=dtype) self.proj_out = comfy.ops.Linear(in_channels, inner_dim, dtype=dtype, device=device)
self.use_linear = use_linear self.use_linear = use_linear
def forward(self, x, context=None, transformer_options={}): def forward(self, x, context=None, transformer_options={}):
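The recurring change in this file is threading a `device` argument through every constructor alongside `dtype`, so parameters are created directly on the target device instead of being moved afterwards. A minimal sketch of the pattern with plain `nn.Linear` (illustrative module, not ComfyUI code):

```python
import torch
import torch.nn as nn

class TinyProjection(nn.Module):
    # Illustrative module: dtype/device are accepted once and forwarded to every
    # layer it builds, mirroring how the constructors in this file thread `device=`.
    def __init__(self, dim_in, dim_out, dtype=None, device=None):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out, dtype=dtype, device=device)

    def forward(self, x):
        return self.proj(x)

# Parameters start out in fp16 on the chosen device; no post-hoc .to() pass needed.
dev = "cuda" if torch.cuda.is_available() else "cpu"
m = TinyProjection(8, 4, dtype=torch.float16, device=dev)
print(m.proj.weight.device, m.proj.weight.dtype)
```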

View File

@ -8,6 +8,7 @@ from typing import Optional, Any
from ..attention import MemoryEfficientCrossAttention from ..attention import MemoryEfficientCrossAttention
from comfy import model_management from comfy import model_management
import comfy.ops
if model_management.xformers_enabled_vae(): if model_management.xformers_enabled_vae():
import xformers import xformers
@ -48,7 +49,7 @@ class Upsample(nn.Module):
super().__init__() super().__init__()
self.with_conv = with_conv self.with_conv = with_conv
if self.with_conv: if self.with_conv:
self.conv = torch.nn.Conv2d(in_channels, self.conv = comfy.ops.Conv2d(in_channels,
in_channels, in_channels,
kernel_size=3, kernel_size=3,
stride=1, stride=1,
@ -67,7 +68,7 @@ class Downsample(nn.Module):
self.with_conv = with_conv self.with_conv = with_conv
if self.with_conv: if self.with_conv:
# no asymmetric padding in torch conv, must do it ourselves # no asymmetric padding in torch conv, must do it ourselves
self.conv = torch.nn.Conv2d(in_channels, self.conv = comfy.ops.Conv2d(in_channels,
in_channels, in_channels,
kernel_size=3, kernel_size=3,
stride=2, stride=2,
@ -95,30 +96,30 @@ class ResnetBlock(nn.Module):
self.swish = torch.nn.SiLU(inplace=True) self.swish = torch.nn.SiLU(inplace=True)
self.norm1 = Normalize(in_channels) self.norm1 = Normalize(in_channels)
self.conv1 = torch.nn.Conv2d(in_channels, self.conv1 = comfy.ops.Conv2d(in_channels,
out_channels, out_channels,
kernel_size=3, kernel_size=3,
stride=1, stride=1,
padding=1) padding=1)
if temb_channels > 0: if temb_channels > 0:
self.temb_proj = torch.nn.Linear(temb_channels, self.temb_proj = comfy.ops.Linear(temb_channels,
out_channels) out_channels)
self.norm2 = Normalize(out_channels) self.norm2 = Normalize(out_channels)
self.dropout = torch.nn.Dropout(dropout, inplace=True) self.dropout = torch.nn.Dropout(dropout, inplace=True)
self.conv2 = torch.nn.Conv2d(out_channels, self.conv2 = comfy.ops.Conv2d(out_channels,
out_channels, out_channels,
kernel_size=3, kernel_size=3,
stride=1, stride=1,
padding=1) padding=1)
if self.in_channels != self.out_channels: if self.in_channels != self.out_channels:
if self.use_conv_shortcut: if self.use_conv_shortcut:
self.conv_shortcut = torch.nn.Conv2d(in_channels, self.conv_shortcut = comfy.ops.Conv2d(in_channels,
out_channels, out_channels,
kernel_size=3, kernel_size=3,
stride=1, stride=1,
padding=1) padding=1)
else: else:
self.nin_shortcut = torch.nn.Conv2d(in_channels, self.nin_shortcut = comfy.ops.Conv2d(in_channels,
out_channels, out_channels,
kernel_size=1, kernel_size=1,
stride=1, stride=1,
@ -188,22 +189,22 @@ class AttnBlock(nn.Module):
self.in_channels = in_channels self.in_channels = in_channels
self.norm = Normalize(in_channels) self.norm = Normalize(in_channels)
self.q = torch.nn.Conv2d(in_channels, self.q = comfy.ops.Conv2d(in_channels,
in_channels, in_channels,
kernel_size=1, kernel_size=1,
stride=1, stride=1,
padding=0) padding=0)
self.k = torch.nn.Conv2d(in_channels, self.k = comfy.ops.Conv2d(in_channels,
in_channels, in_channels,
kernel_size=1, kernel_size=1,
stride=1, stride=1,
padding=0) padding=0)
self.v = torch.nn.Conv2d(in_channels, self.v = comfy.ops.Conv2d(in_channels,
in_channels, in_channels,
kernel_size=1, kernel_size=1,
stride=1, stride=1,
padding=0) padding=0)
self.proj_out = torch.nn.Conv2d(in_channels, self.proj_out = comfy.ops.Conv2d(in_channels,
in_channels, in_channels,
kernel_size=1, kernel_size=1,
stride=1, stride=1,
@ -243,22 +244,22 @@ class MemoryEfficientAttnBlock(nn.Module):
self.in_channels = in_channels self.in_channels = in_channels
self.norm = Normalize(in_channels) self.norm = Normalize(in_channels)
self.q = torch.nn.Conv2d(in_channels, self.q = comfy.ops.Conv2d(in_channels,
in_channels, in_channels,
kernel_size=1, kernel_size=1,
stride=1, stride=1,
padding=0) padding=0)
self.k = torch.nn.Conv2d(in_channels, self.k = comfy.ops.Conv2d(in_channels,
in_channels, in_channels,
kernel_size=1, kernel_size=1,
stride=1, stride=1,
padding=0) padding=0)
self.v = torch.nn.Conv2d(in_channels, self.v = comfy.ops.Conv2d(in_channels,
in_channels, in_channels,
kernel_size=1, kernel_size=1,
stride=1, stride=1,
padding=0) padding=0)
self.proj_out = torch.nn.Conv2d(in_channels, self.proj_out = comfy.ops.Conv2d(in_channels,
in_channels, in_channels,
kernel_size=1, kernel_size=1,
stride=1, stride=1,
@ -302,22 +303,22 @@ class MemoryEfficientAttnBlockPytorch(nn.Module):
self.in_channels = in_channels self.in_channels = in_channels
self.norm = Normalize(in_channels) self.norm = Normalize(in_channels)
self.q = torch.nn.Conv2d(in_channels, self.q = comfy.ops.Conv2d(in_channels,
in_channels, in_channels,
kernel_size=1, kernel_size=1,
stride=1, stride=1,
padding=0) padding=0)
self.k = torch.nn.Conv2d(in_channels, self.k = comfy.ops.Conv2d(in_channels,
in_channels, in_channels,
kernel_size=1, kernel_size=1,
stride=1, stride=1,
padding=0) padding=0)
self.v = torch.nn.Conv2d(in_channels, self.v = comfy.ops.Conv2d(in_channels,
in_channels, in_channels,
kernel_size=1, kernel_size=1,
stride=1, stride=1,
padding=0) padding=0)
self.proj_out = torch.nn.Conv2d(in_channels, self.proj_out = comfy.ops.Conv2d(in_channels,
in_channels, in_channels,
kernel_size=1, kernel_size=1,
stride=1, stride=1,
@ -399,14 +400,14 @@ class Model(nn.Module):
# timestep embedding # timestep embedding
self.temb = nn.Module() self.temb = nn.Module()
self.temb.dense = nn.ModuleList([ self.temb.dense = nn.ModuleList([
torch.nn.Linear(self.ch, comfy.ops.Linear(self.ch,
self.temb_ch), self.temb_ch),
torch.nn.Linear(self.temb_ch, comfy.ops.Linear(self.temb_ch,
self.temb_ch), self.temb_ch),
]) ])
# downsampling # downsampling
self.conv_in = torch.nn.Conv2d(in_channels, self.conv_in = comfy.ops.Conv2d(in_channels,
self.ch, self.ch,
kernel_size=3, kernel_size=3,
stride=1, stride=1,
@ -475,7 +476,7 @@ class Model(nn.Module):
# end # end
self.norm_out = Normalize(block_in) self.norm_out = Normalize(block_in)
self.conv_out = torch.nn.Conv2d(block_in, self.conv_out = comfy.ops.Conv2d(block_in,
out_ch, out_ch,
kernel_size=3, kernel_size=3,
stride=1, stride=1,
@ -548,7 +549,7 @@ class Encoder(nn.Module):
self.in_channels = in_channels self.in_channels = in_channels
# downsampling # downsampling
self.conv_in = torch.nn.Conv2d(in_channels, self.conv_in = comfy.ops.Conv2d(in_channels,
self.ch, self.ch,
kernel_size=3, kernel_size=3,
stride=1, stride=1,
@ -593,7 +594,7 @@ class Encoder(nn.Module):
# end # end
self.norm_out = Normalize(block_in) self.norm_out = Normalize(block_in)
self.conv_out = torch.nn.Conv2d(block_in, self.conv_out = comfy.ops.Conv2d(block_in,
2*z_channels if double_z else z_channels, 2*z_channels if double_z else z_channels,
kernel_size=3, kernel_size=3,
stride=1, stride=1,
@ -653,7 +654,7 @@ class Decoder(nn.Module):
self.z_shape, np.prod(self.z_shape))) self.z_shape, np.prod(self.z_shape)))
# z to block_in # z to block_in
self.conv_in = torch.nn.Conv2d(z_channels, self.conv_in = comfy.ops.Conv2d(z_channels,
block_in, block_in,
kernel_size=3, kernel_size=3,
stride=1, stride=1,
@ -695,7 +696,7 @@ class Decoder(nn.Module):
# end # end
self.norm_out = Normalize(block_in) self.norm_out = Normalize(block_in)
self.conv_out = torch.nn.Conv2d(block_in, self.conv_out = comfy.ops.Conv2d(block_in,
out_ch, out_ch,
kernel_size=3, kernel_size=3,
stride=1, stride=1,
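These hunks swap `torch.nn.Conv2d`/`torch.nn.Linear` for `comfy.ops.Conv2d`/`comfy.ops.Linear` in the VAE blocks. A rough sketch of what such an ops wrapper is assumed to look like (not copied from `comfy/ops.py`): stock layers whose default weight initialisation is skipped, since a checkpoint overwrites the weights anyway.

```python
import torch

# Sketch of an ops-style wrapper (assumed behaviour, not the actual comfy.ops code):
# identical to the stock layers except reset_parameters() is a no-op, so building
# the module does not spend time initialising weights a checkpoint will replace.
class Linear(torch.nn.Linear):
    def reset_parameters(self):
        return None

class Conv2d(torch.nn.Conv2d):
    def reset_parameters(self):
        return None

conv = Conv2d(4, 4, kernel_size=3, stride=1, padding=1)  # constructed without init work
print(conv.weight.shape)
```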

View File

@ -19,45 +19,6 @@ from ..attention import SpatialTransformer
from comfy.ldm.util import exists from comfy.ldm.util import exists
# dummy replace
def convert_module_to_f16(x):
pass
def convert_module_to_f32(x):
pass
## go
class AttentionPool2d(nn.Module):
"""
Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
"""
def __init__(
self,
spacial_dim: int,
embed_dim: int,
num_heads_channels: int,
output_dim: int = None,
):
super().__init__()
self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5)
self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
self.num_heads = embed_dim // num_heads_channels
self.attention = QKVAttention(self.num_heads)
def forward(self, x):
b, c, *_spatial = x.shape
x = x.reshape(b, c, -1) # NC(HW)
x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
x = self.qkv_proj(x)
x = self.attention(x)
x = self.c_proj(x)
return x[:, :, 0]
class TimestepBlock(nn.Module): class TimestepBlock(nn.Module):
""" """
Any module where forward() takes timestep embeddings as a second argument. Any module where forward() takes timestep embeddings as a second argument.
@ -111,14 +72,14 @@ class Upsample(nn.Module):
upsampling occurs in the inner-two dimensions. upsampling occurs in the inner-two dimensions.
""" """
def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None): def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None, device=None):
super().__init__() super().__init__()
self.channels = channels self.channels = channels
self.out_channels = out_channels or channels self.out_channels = out_channels or channels
self.use_conv = use_conv self.use_conv = use_conv
self.dims = dims self.dims = dims
if use_conv: if use_conv:
self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding, dtype=dtype) self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding, dtype=dtype, device=device)
def forward(self, x, output_shape=None): def forward(self, x, output_shape=None):
assert x.shape[1] == self.channels assert x.shape[1] == self.channels
@ -138,19 +99,6 @@ class Upsample(nn.Module):
x = self.conv(x) x = self.conv(x)
return x return x
class TransposedUpsample(nn.Module):
'Learned 2x upsampling without padding'
def __init__(self, channels, out_channels=None, ks=5):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.up = nn.ConvTranspose2d(self.channels,self.out_channels,kernel_size=ks,stride=2)
def forward(self,x):
return self.up(x)
class Downsample(nn.Module): class Downsample(nn.Module):
""" """
A downsampling layer with an optional convolution. A downsampling layer with an optional convolution.
@ -160,7 +108,7 @@ class Downsample(nn.Module):
downsampling occurs in the inner-two dimensions. downsampling occurs in the inner-two dimensions.
""" """
def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None): def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None, device=None):
super().__init__() super().__init__()
self.channels = channels self.channels = channels
self.out_channels = out_channels or channels self.out_channels = out_channels or channels
@ -169,7 +117,7 @@ class Downsample(nn.Module):
stride = 2 if dims != 3 else (1, 2, 2) stride = 2 if dims != 3 else (1, 2, 2)
if use_conv: if use_conv:
self.op = conv_nd( self.op = conv_nd(
dims, self.channels, self.out_channels, 3, stride=stride, padding=padding, dtype=dtype dims, self.channels, self.out_channels, 3, stride=stride, padding=padding, dtype=dtype, device=device
) )
else: else:
assert self.channels == self.out_channels assert self.channels == self.out_channels
@ -208,7 +156,8 @@ class ResBlock(TimestepBlock):
use_checkpoint=False, use_checkpoint=False,
up=False, up=False,
down=False, down=False,
dtype=None dtype=None,
device=None,
): ):
super().__init__() super().__init__()
self.channels = channels self.channels = channels
@ -220,19 +169,19 @@ class ResBlock(TimestepBlock):
self.use_scale_shift_norm = use_scale_shift_norm self.use_scale_shift_norm = use_scale_shift_norm
self.in_layers = nn.Sequential( self.in_layers = nn.Sequential(
nn.GroupNorm(32, channels, dtype=dtype), nn.GroupNorm(32, channels, dtype=dtype, device=device),
nn.SiLU(), nn.SiLU(),
conv_nd(dims, channels, self.out_channels, 3, padding=1, dtype=dtype), conv_nd(dims, channels, self.out_channels, 3, padding=1, dtype=dtype, device=device),
) )
self.updown = up or down self.updown = up or down
if up: if up:
self.h_upd = Upsample(channels, False, dims, dtype=dtype) self.h_upd = Upsample(channels, False, dims, dtype=dtype, device=device)
self.x_upd = Upsample(channels, False, dims, dtype=dtype) self.x_upd = Upsample(channels, False, dims, dtype=dtype, device=device)
elif down: elif down:
self.h_upd = Downsample(channels, False, dims, dtype=dtype) self.h_upd = Downsample(channels, False, dims, dtype=dtype, device=device)
self.x_upd = Downsample(channels, False, dims, dtype=dtype) self.x_upd = Downsample(channels, False, dims, dtype=dtype, device=device)
else: else:
self.h_upd = self.x_upd = nn.Identity() self.h_upd = self.x_upd = nn.Identity()
@ -240,15 +189,15 @@ class ResBlock(TimestepBlock):
nn.SiLU(), nn.SiLU(),
linear( linear(
emb_channels, emb_channels,
2 * self.out_channels if use_scale_shift_norm else self.out_channels, dtype=dtype 2 * self.out_channels if use_scale_shift_norm else self.out_channels, dtype=dtype, device=device
), ),
) )
self.out_layers = nn.Sequential( self.out_layers = nn.Sequential(
nn.GroupNorm(32, self.out_channels, dtype=dtype), nn.GroupNorm(32, self.out_channels, dtype=dtype, device=device),
nn.SiLU(), nn.SiLU(),
nn.Dropout(p=dropout), nn.Dropout(p=dropout),
zero_module( zero_module(
conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1, dtype=dtype) conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1, dtype=dtype, device=device)
), ),
) )
@ -256,10 +205,10 @@ class ResBlock(TimestepBlock):
self.skip_connection = nn.Identity() self.skip_connection = nn.Identity()
elif use_conv: elif use_conv:
self.skip_connection = conv_nd( self.skip_connection = conv_nd(
dims, channels, self.out_channels, 3, padding=1, dtype=dtype dims, channels, self.out_channels, 3, padding=1, dtype=dtype, device=device
) )
else: else:
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1, dtype=dtype) self.skip_connection = conv_nd(dims, channels, self.out_channels, 1, dtype=dtype, device=device)
def forward(self, x, emb): def forward(self, x, emb):
""" """
@ -295,142 +244,6 @@ class ResBlock(TimestepBlock):
h = self.out_layers(h) h = self.out_layers(h)
return self.skip_connection(x) + h return self.skip_connection(x) + h
class AttentionBlock(nn.Module):
"""
An attention block that allows spatial positions to attend to each other.
Originally ported from here, but adapted to the N-d case.
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
"""
def __init__(
self,
channels,
num_heads=1,
num_head_channels=-1,
use_checkpoint=False,
use_new_attention_order=False,
):
super().__init__()
self.channels = channels
if num_head_channels == -1:
self.num_heads = num_heads
else:
assert (
channels % num_head_channels == 0
), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
self.num_heads = channels // num_head_channels
self.use_checkpoint = use_checkpoint
self.norm = normalization(channels)
self.qkv = conv_nd(1, channels, channels * 3, 1)
if use_new_attention_order:
# split qkv before split heads
self.attention = QKVAttention(self.num_heads)
else:
# split heads before split qkv
self.attention = QKVAttentionLegacy(self.num_heads)
self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
def forward(self, x):
return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
#return pt_checkpoint(self._forward, x) # pytorch
def _forward(self, x):
b, c, *spatial = x.shape
x = x.reshape(b, c, -1)
qkv = self.qkv(self.norm(x))
h = self.attention(qkv)
h = self.proj_out(h)
return (x + h).reshape(b, c, *spatial)
def count_flops_attn(model, _x, y):
"""
A counter for the `thop` package to count the operations in an
attention operation.
Meant to be used like:
macs, params = thop.profile(
model,
inputs=(inputs, timestamps),
custom_ops={QKVAttention: QKVAttention.count_flops},
)
"""
b, c, *spatial = y[0].shape
num_spatial = int(np.prod(spatial))
# We perform two matmuls with the same number of ops.
# The first computes the weight matrix, the second computes
# the combination of the value vectors.
matmul_ops = 2 * b * (num_spatial ** 2) * c
model.total_ops += th.DoubleTensor([matmul_ops])
class QKVAttentionLegacy(nn.Module):
"""
A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
"""
def __init__(self, n_heads):
super().__init__()
self.n_heads = n_heads
def forward(self, qkv):
"""
Apply QKV attention.
:param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
:return: an [N x (H * C) x T] tensor after attention.
"""
bs, width, length = qkv.shape
assert width % (3 * self.n_heads) == 0
ch = width // (3 * self.n_heads)
q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
scale = 1 / math.sqrt(math.sqrt(ch))
weight = th.einsum(
"bct,bcs->bts", q * scale, k * scale
) # More stable with f16 than dividing afterwards
weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
a = th.einsum("bts,bcs->bct", weight, v)
return a.reshape(bs, -1, length)
@staticmethod
def count_flops(model, _x, y):
return count_flops_attn(model, _x, y)
class QKVAttention(nn.Module):
"""
A module which performs QKV attention and splits in a different order.
"""
def __init__(self, n_heads):
super().__init__()
self.n_heads = n_heads
def forward(self, qkv):
"""
Apply QKV attention.
:param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
:return: an [N x (H * C) x T] tensor after attention.
"""
bs, width, length = qkv.shape
assert width % (3 * self.n_heads) == 0
ch = width // (3 * self.n_heads)
q, k, v = qkv.chunk(3, dim=1)
scale = 1 / math.sqrt(math.sqrt(ch))
weight = th.einsum(
"bct,bcs->bts",
(q * scale).view(bs * self.n_heads, ch, length),
(k * scale).view(bs * self.n_heads, ch, length),
) # More stable with f16 than dividing afterwards
weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
return a.reshape(bs, -1, length)
@staticmethod
def count_flops(model, _x, y):
return count_flops_attn(model, _x, y)
class Timestep(nn.Module): class Timestep(nn.Module):
def __init__(self, dim): def __init__(self, dim):
super().__init__() super().__init__()
@ -503,8 +316,10 @@ class UNetModel(nn.Module):
use_linear_in_transformer=False, use_linear_in_transformer=False,
adm_in_channels=None, adm_in_channels=None,
transformer_depth_middle=None, transformer_depth_middle=None,
device=None,
): ):
super().__init__() super().__init__()
assert use_spatial_transformer == True, "use_spatial_transformer has to be true"
if use_spatial_transformer: if use_spatial_transformer:
assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...' assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
@ -564,9 +379,9 @@ class UNetModel(nn.Module):
time_embed_dim = model_channels * 4 time_embed_dim = model_channels * 4
self.time_embed = nn.Sequential( self.time_embed = nn.Sequential(
linear(model_channels, time_embed_dim, dtype=self.dtype), linear(model_channels, time_embed_dim, dtype=self.dtype, device=device),
nn.SiLU(), nn.SiLU(),
linear(time_embed_dim, time_embed_dim, dtype=self.dtype), linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
) )
if self.num_classes is not None: if self.num_classes is not None:
@ -579,9 +394,9 @@ class UNetModel(nn.Module):
assert adm_in_channels is not None assert adm_in_channels is not None
self.label_emb = nn.Sequential( self.label_emb = nn.Sequential(
nn.Sequential( nn.Sequential(
linear(adm_in_channels, time_embed_dim, dtype=self.dtype), linear(adm_in_channels, time_embed_dim, dtype=self.dtype, device=device),
nn.SiLU(), nn.SiLU(),
linear(time_embed_dim, time_embed_dim, dtype=self.dtype), linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
) )
) )
else: else:
@ -590,7 +405,7 @@ class UNetModel(nn.Module):
self.input_blocks = nn.ModuleList( self.input_blocks = nn.ModuleList(
[ [
TimestepEmbedSequential( TimestepEmbedSequential(
conv_nd(dims, in_channels, model_channels, 3, padding=1, dtype=self.dtype) conv_nd(dims, in_channels, model_channels, 3, padding=1, dtype=self.dtype, device=device)
) )
] ]
) )
@ -609,7 +424,8 @@ class UNetModel(nn.Module):
dims=dims, dims=dims,
use_checkpoint=use_checkpoint, use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm, use_scale_shift_norm=use_scale_shift_norm,
dtype=self.dtype dtype=self.dtype,
device=device,
) )
] ]
ch = mult * model_channels ch = mult * model_channels
@ -628,17 +444,10 @@ class UNetModel(nn.Module):
disabled_sa = False disabled_sa = False
if not exists(num_attention_blocks) or nr < num_attention_blocks[level]: if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
layers.append( layers.append(SpatialTransformer(
AttentionBlock(
ch,
use_checkpoint=use_checkpoint,
num_heads=num_heads,
num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order,
) if not use_spatial_transformer else SpatialTransformer(
ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim, ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim,
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer, disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint, dtype=self.dtype use_checkpoint=use_checkpoint, dtype=self.dtype, device=device
) )
) )
self.input_blocks.append(TimestepEmbedSequential(*layers)) self.input_blocks.append(TimestepEmbedSequential(*layers))
@ -657,11 +466,12 @@ class UNetModel(nn.Module):
use_checkpoint=use_checkpoint, use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm, use_scale_shift_norm=use_scale_shift_norm,
down=True, down=True,
dtype=self.dtype dtype=self.dtype,
device=device,
) )
if resblock_updown if resblock_updown
else Downsample( else Downsample(
ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype, device=device
) )
) )
) )
@ -686,18 +496,13 @@ class UNetModel(nn.Module):
dims=dims, dims=dims,
use_checkpoint=use_checkpoint, use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm, use_scale_shift_norm=use_scale_shift_norm,
dtype=self.dtype dtype=self.dtype,
device=device,
), ),
AttentionBlock( SpatialTransformer( # always uses a self-attn
ch,
use_checkpoint=use_checkpoint,
num_heads=num_heads,
num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order,
) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn
ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim, ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer, disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint, dtype=self.dtype use_checkpoint=use_checkpoint, dtype=self.dtype, device=device
), ),
ResBlock( ResBlock(
ch, ch,
@ -706,7 +511,8 @@ class UNetModel(nn.Module):
dims=dims, dims=dims,
use_checkpoint=use_checkpoint, use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm, use_scale_shift_norm=use_scale_shift_norm,
dtype=self.dtype dtype=self.dtype,
device=device,
), ),
) )
self._feature_size += ch self._feature_size += ch
@ -724,7 +530,8 @@ class UNetModel(nn.Module):
dims=dims, dims=dims,
use_checkpoint=use_checkpoint, use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm, use_scale_shift_norm=use_scale_shift_norm,
dtype=self.dtype dtype=self.dtype,
device=device,
) )
] ]
ch = model_channels * mult ch = model_channels * mult
@ -744,16 +551,10 @@ class UNetModel(nn.Module):
if not exists(num_attention_blocks) or i < num_attention_blocks[level]: if not exists(num_attention_blocks) or i < num_attention_blocks[level]:
layers.append( layers.append(
AttentionBlock( SpatialTransformer(
ch,
use_checkpoint=use_checkpoint,
num_heads=num_heads_upsample,
num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order,
) if not use_spatial_transformer else SpatialTransformer(
ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim, ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim,
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer, disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint, dtype=self.dtype use_checkpoint=use_checkpoint, dtype=self.dtype, device=device
) )
) )
if level and i == self.num_res_blocks[level]: if level and i == self.num_res_blocks[level]:
@ -768,43 +569,28 @@ class UNetModel(nn.Module):
use_checkpoint=use_checkpoint, use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm, use_scale_shift_norm=use_scale_shift_norm,
up=True, up=True,
dtype=self.dtype dtype=self.dtype,
device=device,
) )
if resblock_updown if resblock_updown
else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype) else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype, device=device)
) )
ds //= 2 ds //= 2
self.output_blocks.append(TimestepEmbedSequential(*layers)) self.output_blocks.append(TimestepEmbedSequential(*layers))
self._feature_size += ch self._feature_size += ch
self.out = nn.Sequential( self.out = nn.Sequential(
nn.GroupNorm(32, ch, dtype=self.dtype), nn.GroupNorm(32, ch, dtype=self.dtype, device=device),
nn.SiLU(), nn.SiLU(),
zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1, dtype=self.dtype)), zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1, dtype=self.dtype, device=device)),
) )
if self.predict_codebook_ids: if self.predict_codebook_ids:
self.id_predictor = nn.Sequential( self.id_predictor = nn.Sequential(
nn.GroupNorm(32, ch, dtype=self.dtype), nn.GroupNorm(32, ch, dtype=self.dtype, device=device),
conv_nd(dims, model_channels, n_embed, 1), conv_nd(dims, model_channels, n_embed, 1, dtype=self.dtype, device=device),
#nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
) )
def convert_to_fp16(self):
"""
Convert the torso of the model to float16.
"""
self.input_blocks.apply(convert_module_to_f16)
self.middle_block.apply(convert_module_to_f16)
self.output_blocks.apply(convert_module_to_f16)
def convert_to_fp32(self):
"""
Convert the torso of the model to float32.
"""
self.input_blocks.apply(convert_module_to_f32)
self.middle_block.apply(convert_module_to_f32)
self.output_blocks.apply(convert_module_to_f32)
def forward(self, x, timesteps=None, context=None, y=None, control=None, transformer_options={}, **kwargs): def forward(self, x, timesteps=None, context=None, y=None, control=None, transformer_options={}, **kwargs):
""" """
Apply the model to an input batch. Apply the model to an input batch.
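With `dtype` and `device` now passed at construction time, the removed `convert_to_fp16`/`convert_to_fp32` helpers have nothing left to do. A small plain-PyTorch illustration of the difference:

```python
import torch
import torch.nn as nn

dev = "cuda" if torch.cuda.is_available() else "cpu"

# Old style: build in fp32 on the default device, then convert and move afterwards.
layer_a = nn.Linear(16, 16).half().to(dev)

# Style used throughout this file after the change: pass dtype/device to the
# constructor, so the parameters never exist in the wrong precision or place.
layer_b = nn.Linear(16, 16, dtype=torch.float16, device=dev)

print(layer_a.weight.dtype, layer_b.weight.dtype)  # both torch.float16
```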

View File

@ -12,14 +12,14 @@ class ModelType(Enum):
V_PREDICTION = 2 V_PREDICTION = 2
class BaseModel(torch.nn.Module): class BaseModel(torch.nn.Module):
def __init__(self, model_config, model_type=ModelType.EPS): def __init__(self, model_config, model_type=ModelType.EPS, device=None):
super().__init__() super().__init__()
unet_config = model_config.unet_config unet_config = model_config.unet_config
self.latent_format = model_config.latent_format self.latent_format = model_config.latent_format
self.model_config = model_config self.model_config = model_config
self.register_schedule(given_betas=None, beta_schedule="linear", timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3) self.register_schedule(given_betas=None, beta_schedule="linear", timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3)
self.diffusion_model = UNetModel(**unet_config) self.diffusion_model = UNetModel(**unet_config, device=device)
self.model_type = model_type self.model_type = model_type
self.adm_channels = unet_config.get("adm_in_channels", None) self.adm_channels = unet_config.get("adm_in_channels", None)
if self.adm_channels is None: if self.adm_channels is None:
@ -107,8 +107,8 @@ class BaseModel(torch.nn.Module):
class SD21UNCLIP(BaseModel): class SD21UNCLIP(BaseModel):
def __init__(self, model_config, noise_aug_config, model_type=ModelType.V_PREDICTION): def __init__(self, model_config, noise_aug_config, model_type=ModelType.V_PREDICTION, device=None):
super().__init__(model_config, model_type) super().__init__(model_config, model_type, device=device)
self.noise_augmentor = CLIPEmbeddingNoiseAugmentation(**noise_aug_config) self.noise_augmentor = CLIPEmbeddingNoiseAugmentation(**noise_aug_config)
def encode_adm(self, **kwargs): def encode_adm(self, **kwargs):
@ -143,13 +143,13 @@ class SD21UNCLIP(BaseModel):
return adm_out return adm_out
class SDInpaint(BaseModel): class SDInpaint(BaseModel):
def __init__(self, model_config, model_type=ModelType.EPS): def __init__(self, model_config, model_type=ModelType.EPS, device=None):
super().__init__(model_config, model_type) super().__init__(model_config, model_type, device=device)
self.concat_keys = ("mask", "masked_image") self.concat_keys = ("mask", "masked_image")
class SDXLRefiner(BaseModel): class SDXLRefiner(BaseModel):
def __init__(self, model_config, model_type=ModelType.EPS): def __init__(self, model_config, model_type=ModelType.EPS, device=None):
super().__init__(model_config, model_type) super().__init__(model_config, model_type, device=device)
self.embedder = Timestep(256) self.embedder = Timestep(256)
def encode_adm(self, **kwargs): def encode_adm(self, **kwargs):
@ -164,7 +164,6 @@ class SDXLRefiner(BaseModel):
else: else:
aesthetic_score = kwargs.get("aesthetic_score", 6) aesthetic_score = kwargs.get("aesthetic_score", 6)
print(clip_pooled.shape, width, height, crop_w, crop_h, aesthetic_score)
out = [] out = []
out.append(self.embedder(torch.Tensor([height]))) out.append(self.embedder(torch.Tensor([height])))
out.append(self.embedder(torch.Tensor([width]))) out.append(self.embedder(torch.Tensor([width])))
@ -175,8 +174,8 @@ class SDXLRefiner(BaseModel):
return torch.cat((clip_pooled.to(flat.device), flat), dim=1) return torch.cat((clip_pooled.to(flat.device), flat), dim=1)
class SDXL(BaseModel): class SDXL(BaseModel):
def __init__(self, model_config, model_type=ModelType.EPS): def __init__(self, model_config, model_type=ModelType.EPS, device=None):
super().__init__(model_config, model_type) super().__init__(model_config, model_type, device=device)
self.embedder = Timestep(256) self.embedder = Timestep(256)
def encode_adm(self, **kwargs): def encode_adm(self, **kwargs):
@ -188,7 +187,6 @@ class SDXL(BaseModel):
target_width = kwargs.get("target_width", width) target_width = kwargs.get("target_width", width)
target_height = kwargs.get("target_height", height) target_height = kwargs.get("target_height", height)
print(clip_pooled.shape, width, height, crop_w, crop_h, target_width, target_height)
out = [] out = []
out.append(self.embedder(torch.Tensor([height]))) out.append(self.embedder(torch.Tensor([height])))
out.append(self.embedder(torch.Tensor([width]))) out.append(self.embedder(torch.Tensor([width])))

View File

@@ -364,6 +364,7 @@ def text_encoder_device():
     if args.gpu_only:
         return get_torch_device()
     elif vram_state == VRAMState.HIGH_VRAM or vram_state == VRAMState.NORMAL_VRAM:
+        #NOTE: on a Ryzen 5 7600X with 4080 it's faster to shift to GPU
         if torch.get_num_threads() < 8: #leaving the text encoder on the CPU is faster than shifting it if the CPU is fast enough.
             return get_torch_device()
         else:
@@ -534,7 +535,7 @@ def should_use_fp16(device=None, model_params=0):
         return False
     #FP16 is just broken on these cards
-    nvidia_16_series = ["1660", "1650", "1630", "T500", "T550", "T600", "MX550", "MX450"]
+    nvidia_16_series = ["1660", "1650", "1630", "T500", "T550", "T600", "MX550", "MX450", "CMP 30HX"]
     for x in nvidia_16_series:
         if x in props.name:
             return False

View File

@ -19,11 +19,11 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
strength = 1.0 strength = 1.0
if 'timestep_start' in cond[1]: if 'timestep_start' in cond[1]:
timestep_start = cond[1]['timestep_start'] timestep_start = cond[1]['timestep_start']
if timestep_in > timestep_start: if timestep_in[0] > timestep_start:
return None return None
if 'timestep_end' in cond[1]: if 'timestep_end' in cond[1]:
timestep_end = cond[1]['timestep_end'] timestep_end = cond[1]['timestep_end']
if timestep_in < timestep_end: if timestep_in[0] < timestep_end:
return None return None
if 'area' in cond[1]: if 'area' in cond[1]:
area = cond[1]['area'] area = cond[1]['area']

View File

@ -70,13 +70,22 @@ def load_lora(lora, to_load):
alpha = lora[alpha_name].item() alpha = lora[alpha_name].item()
loaded_keys.add(alpha_name) loaded_keys.add(alpha_name)
A_name = "{}.lora_up.weight".format(x) regular_lora = "{}.lora_up.weight".format(x)
diffusers_lora = "{}_lora.up.weight".format(x)
A_name = None
if regular_lora in lora.keys():
A_name = regular_lora
B_name = "{}.lora_down.weight".format(x) B_name = "{}.lora_down.weight".format(x)
mid_name = "{}.lora_mid.weight".format(x) mid_name = "{}.lora_mid.weight".format(x)
elif diffusers_lora in lora.keys():
A_name = diffusers_lora
B_name = "{}_lora.down.weight".format(x)
mid_name = None
if A_name in lora.keys(): if A_name is not None:
mid = None mid = None
if mid_name in lora.keys(): if mid_name is not None and mid_name in lora.keys():
mid = lora[mid_name] mid = lora[mid_name]
loaded_keys.add(mid_name) loaded_keys.add(mid_name)
patch_dict[to_load[x]] = (lora[A_name], lora[B_name], alpha, mid) patch_dict[to_load[x]] = (lora[A_name], lora[B_name], alpha, mid)
@ -202,6 +211,11 @@ def model_lora_keys_unet(model, key_map={}):
if k.endswith(".weight"): if k.endswith(".weight"):
key_lora = k[:-len(".weight")].replace(".", "_") key_lora = k[:-len(".weight")].replace(".", "_")
key_map["lora_unet_{}".format(key_lora)] = "diffusion_model.{}".format(diffusers_keys[k]) key_map["lora_unet_{}".format(key_lora)] = "diffusion_model.{}".format(diffusers_keys[k])
diffusers_lora_key = "unet.{}".format(k[:-len(".weight")].replace(".to_", ".processor.to_"))
if diffusers_lora_key.endswith(".to_out.0"):
diffusers_lora_key = diffusers_lora_key[:-2]
key_map[diffusers_lora_key] = "diffusion_model.{}".format(diffusers_keys[k])
return key_map return key_map
def set_attr(obj, attr, value): def set_attr(obj, attr, value):
@ -864,7 +878,7 @@ def load_controlnet(ckpt_path, model=None):
use_fp16 = model_management.should_use_fp16() use_fp16 = model_management.should_use_fp16()
controlnet_config = model_detection.model_config_from_unet(controlnet_data, prefix, use_fp16).unet_config controlnet_config = model_detection.model_config_from_unet(controlnet_data, prefix, use_fp16).unet_config
controlnet_config.pop("out_channels") controlnet_config.pop("out_channels")
controlnet_config["hint_channels"] = 3 controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
control_model = cldm.ControlNet(**controlnet_config) control_model = cldm.ControlNet(**controlnet_config)
if pth: if pth:
@ -1169,8 +1183,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
clipvision = clip_vision.load_clipvision_from_sd(sd, model_config.clip_vision_prefix, True) clipvision = clip_vision.load_clipvision_from_sd(sd, model_config.clip_vision_prefix, True)
offload_device = model_management.unet_offload_device() offload_device = model_management.unet_offload_device()
model = model_config.get_model(sd, "model.diffusion_model.") model = model_config.get_model(sd, "model.diffusion_model.", device=offload_device)
model = model.to(offload_device)
model.load_model_weights(sd, "model.diffusion_model.") model.load_model_weights(sd, "model.diffusion_model.")
if output_vae: if output_vae:
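The `hint_channels` change above reads the hint channel count from the checkpoint itself instead of hard-coding 3. A minimal sketch of that lookup, with a hypothetical file name and prefix (real prefixes are detected upstream):

```
import torch

controlnet_data = torch.load("controlnet.pth", map_location="cpu")  # hypothetical path
prefix = "control_model."                                            # hypothetical prefix

# The first input hint block is a conv layer, so dim 1 of its weight tensor is
# the number of input (hint) channels: 3 for plain RGB hints, more for variants
# that take extra channels such as a mask.
hint_channels = controlnet_data[prefix + "input_hint_block.0.weight"].shape[1]
print(hint_channels)
```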

View File

@ -91,13 +91,15 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
def set_up_textual_embeddings(self, tokens, current_embeds): def set_up_textual_embeddings(self, tokens, current_embeds):
out_tokens = [] out_tokens = []
next_new_token = token_dict_size = current_embeds.weight.shape[0] next_new_token = token_dict_size = current_embeds.weight.shape[0] - 1
embedding_weights = [] embedding_weights = []
for x in tokens: for x in tokens:
tokens_temp = [] tokens_temp = []
for y in x: for y in x:
if isinstance(y, int): if isinstance(y, int):
if y == token_dict_size: #EOS token
y = -1
tokens_temp += [y] tokens_temp += [y]
else: else:
if y.shape[0] == current_embeds.weight.shape[1]: if y.shape[0] == current_embeds.weight.shape[1]:
@ -110,15 +112,21 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
tokens_temp += [self.empty_tokens[0][-1]] tokens_temp += [self.empty_tokens[0][-1]]
out_tokens += [tokens_temp] out_tokens += [tokens_temp]
if len(embedding_weights) > 0:
new_embedding = torch.nn.Embedding(next_new_token, current_embeds.weight.shape[1], device=current_embeds.weight.device, dtype=current_embeds.weight.dtype)
new_embedding.weight[:token_dict_size] = current_embeds.weight[:]
n = token_dict_size n = token_dict_size
if len(embedding_weights) > 0:
new_embedding = torch.nn.Embedding(next_new_token + 1, current_embeds.weight.shape[1], device=current_embeds.weight.device, dtype=current_embeds.weight.dtype)
new_embedding.weight[:token_dict_size] = current_embeds.weight[:-1]
for x in embedding_weights: for x in embedding_weights:
new_embedding.weight[n] = x new_embedding.weight[n] = x
n += 1 n += 1
new_embedding.weight[n] = current_embeds.weight[-1] #EOS embedding
self.transformer.set_input_embeddings(new_embedding) self.transformer.set_input_embeddings(new_embedding)
return out_tokens
processed_tokens = []
for x in out_tokens:
processed_tokens += [list(map(lambda a: n if a == -1 else a, x))] #The EOS token should always be the largest one
return processed_tokens
def forward(self, tokens): def forward(self, tokens):
backup_embeds = self.transformer.get_input_embeddings() backup_embeds = self.transformer.get_input_embeddings()

View File

@ -109,8 +109,8 @@ class SDXLRefiner(supported_models_base.BASE):
latent_format = latent_formats.SDXL latent_format = latent_formats.SDXL
def get_model(self, state_dict, prefix=""): def get_model(self, state_dict, prefix="", device=None):
return model_base.SDXLRefiner(self) return model_base.SDXLRefiner(self, device=device)
def process_clip_state_dict(self, state_dict): def process_clip_state_dict(self, state_dict):
keys_to_replace = {} keys_to_replace = {}
@ -152,8 +152,8 @@ class SDXL(supported_models_base.BASE):
else: else:
return model_base.ModelType.EPS return model_base.ModelType.EPS
def get_model(self, state_dict, prefix=""): def get_model(self, state_dict, prefix="", device=None):
return model_base.SDXL(self, model_type=self.model_type(state_dict, prefix)) return model_base.SDXL(self, model_type=self.model_type(state_dict, prefix), device=device)
def process_clip_state_dict(self, state_dict): def process_clip_state_dict(self, state_dict):
keys_to_replace = {} keys_to_replace = {}

View File

@ -53,13 +53,13 @@ class BASE:
for x in self.unet_extra_config: for x in self.unet_extra_config:
self.unet_config[x] = self.unet_extra_config[x] self.unet_config[x] = self.unet_extra_config[x]
def get_model(self, state_dict, prefix=""): def get_model(self, state_dict, prefix="", device=None):
if self.inpaint_model(): if self.inpaint_model():
return model_base.SDInpaint(self, model_type=self.model_type(state_dict, prefix)) return model_base.SDInpaint(self, model_type=self.model_type(state_dict, prefix), device=device)
elif self.noise_aug_config is not None: elif self.noise_aug_config is not None:
return model_base.SD21UNCLIP(self, self.noise_aug_config, model_type=self.model_type(state_dict, prefix)) return model_base.SD21UNCLIP(self, self.noise_aug_config, model_type=self.model_type(state_dict, prefix), device=device)
else: else:
return model_base.BaseModel(self, model_type=self.model_type(state_dict, prefix)) return model_base.BaseModel(self, model_type=self.model_type(state_dict, prefix), device=device)
def process_clip_state_dict(self, state_dict): def process_clip_state_dict(self, state_dict):
return state_dict return state_dict

View File

@ -6,6 +6,8 @@ import folder_paths
import json import json
import os import os
from comfy.cli_args import args
class ModelMergeSimple: class ModelMergeSimple:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
@ -101,8 +103,7 @@ class CheckpointSave:
if prompt is not None: if prompt is not None:
prompt_info = json.dumps(prompt) prompt_info = json.dumps(prompt)
metadata = {"prompt": prompt_info} metadata = {}
enable_modelspec = True enable_modelspec = True
if isinstance(model.model, comfy.model_base.SDXL): if isinstance(model.model, comfy.model_base.SDXL):
@ -127,6 +128,8 @@ class CheckpointSave:
elif model.model.model_type == comfy.model_base.ModelType.V_PREDICTION: elif model.model.model_type == comfy.model_base.ModelType.V_PREDICTION:
metadata["modelspec.predict_key"] = "v" metadata["modelspec.predict_key"] = "v"
if not args.disable_metadata:
metadata["prompt"] = prompt_info
if extra_pnginfo is not None: if extra_pnginfo is not None:
for x in extra_pnginfo: for x in extra_pnginfo:
metadata[x] = json.dumps(extra_pnginfo[x]) metadata[x] = json.dumps(extra_pnginfo[x])

View File

@ -40,7 +40,8 @@ def cuda_malloc_supported():
blacklist = {"GeForce GTX TITAN X", "GeForce GTX 980", "GeForce GTX 970", "GeForce GTX 960", "GeForce GTX 950", "GeForce 945M", blacklist = {"GeForce GTX TITAN X", "GeForce GTX 980", "GeForce GTX 970", "GeForce GTX 960", "GeForce GTX 950", "GeForce 945M",
"GeForce 940M", "GeForce 930M", "GeForce 920M", "GeForce 910M", "GeForce GTX 750", "GeForce GTX 745", "Quadro K620", "GeForce 940M", "GeForce 930M", "GeForce 920M", "GeForce 910M", "GeForce GTX 750", "GeForce GTX 745", "Quadro K620",
"Quadro K1200", "Quadro K2200", "Quadro M500", "Quadro M520", "Quadro M600", "Quadro M620", "Quadro M1000", "Quadro K1200", "Quadro K2200", "Quadro M500", "Quadro M520", "Quadro M600", "Quadro M620", "Quadro M1000",
"Quadro M1200", "Quadro M2000", "Quadro M2200", "Quadro M3000", "Quadro M4000", "Quadro M5000", "Quadro M5500", "Quadro M6000"} "Quadro M1200", "Quadro M2000", "Quadro M2200", "Quadro M3000", "Quadro M4000", "Quadro M5000", "Quadro M5500", "Quadro M6000",
"GeForce MX110", "GeForce MX130", "GeForce 830M", "GeForce 840M", "GeForce GTX 850M", "GeForce GTX 860M"}
try: try:
names = get_gpu_names() names = get_gpu_names()

View File

@ -42,10 +42,13 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da
def map_node_over_list(obj, input_data_all, func, allow_interrupt=False): def map_node_over_list(obj, input_data_all, func, allow_interrupt=False):
# check if node wants the lists # check if node wants the lists
intput_is_list = False input_is_list = False
if hasattr(obj, "INPUT_IS_LIST"): if hasattr(obj, "INPUT_IS_LIST"):
intput_is_list = obj.INPUT_IS_LIST input_is_list = obj.INPUT_IS_LIST
if len(input_data_all) == 0:
max_len_input = 0
else:
max_len_input = max([len(x) for x in input_data_all.values()]) max_len_input = max([len(x) for x in input_data_all.values()])
# get a slice of inputs, repeat last input when list isn't long enough # get a slice of inputs, repeat last input when list isn't long enough
@ -56,10 +59,14 @@ def map_node_over_list(obj, input_data_all, func, allow_interrupt=False):
return d_new return d_new
results = [] results = []
if intput_is_list: if input_is_list:
if allow_interrupt: if allow_interrupt:
nodes.before_node_execution() nodes.before_node_execution()
results.append(getattr(obj, func)(**input_data_all)) results.append(getattr(obj, func)(**input_data_all))
elif max_len_input == 0:
if allow_interrupt:
nodes.before_node_execution()
results.append(getattr(obj, func)())
else: else:
for i in range(max_len_input): for i in range(max_len_input):
if allow_interrupt: if allow_interrupt:

View File

@ -160,6 +160,8 @@ if __name__ == "__main__":
if args.auto_launch: if args.auto_launch:
def startup_server(address, port): def startup_server(address, port):
import webbrowser import webbrowser
if os.name == 'nt' and address == '0.0.0.0':
address = '127.0.0.1'
webbrowser.open(f"http://{address}:{port}") webbrowser.open(f"http://{address}:{port}")
call_on_start = startup_server call_on_start = startup_server

View File

@ -26,6 +26,8 @@ import comfy.utils
import comfy.clip_vision import comfy.clip_vision
import comfy.model_management import comfy.model_management
from comfy.cli_args import args
import importlib import importlib
import folder_paths import folder_paths
@ -352,12 +354,22 @@ class SaveLatent:
if prompt is not None: if prompt is not None:
prompt_info = json.dumps(prompt) prompt_info = json.dumps(prompt)
metadata = None
if not args.disable_metadata:
metadata = {"prompt": prompt_info} metadata = {"prompt": prompt_info}
if extra_pnginfo is not None: if extra_pnginfo is not None:
for x in extra_pnginfo: for x in extra_pnginfo:
metadata[x] = json.dumps(extra_pnginfo[x]) metadata[x] = json.dumps(extra_pnginfo[x])
file = f"{filename}_{counter:05}_.latent" file = f"{filename}_{counter:05}_.latent"
results = list()
results.append({
"filename": file,
"subfolder": subfolder,
"type": "output"
})
file = os.path.join(full_output_folder, file) file = os.path.join(full_output_folder, file)
output = {} output = {}
@ -365,7 +377,7 @@ class SaveLatent:
output["latent_format_version_0"] = torch.tensor([]) output["latent_format_version_0"] = torch.tensor([])
comfy.utils.save_torch_file(output, file, metadata=metadata) comfy.utils.save_torch_file(output, file, metadata=metadata)
return {} return { "ui": { "latents": results } }
class LoadLatent: class LoadLatent:
@ -1043,6 +1055,47 @@ class LatentComposite:
samples_out["samples"] = s samples_out["samples"] = s
return (samples_out,) return (samples_out,)
class LatentBlend:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"samples1": ("LATENT",),
"samples2": ("LATENT",),
"blend_factor": ("FLOAT", {
"default": 0.5,
"min": 0,
"max": 1,
"step": 0.01
}),
}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "blend"
CATEGORY = "_for_testing"
def blend(self, samples1, samples2, blend_factor:float, blend_mode: str="normal"):
samples_out = samples1.copy()
samples1 = samples1["samples"]
samples2 = samples2["samples"]
if samples1.shape != samples2.shape:
samples2.permute(0, 3, 1, 2)
samples2 = comfy.utils.common_upscale(samples2, samples1.shape[3], samples1.shape[2], 'bicubic', crop='center')
samples2.permute(0, 2, 3, 1)
samples_blended = self.blend_mode(samples1, samples2, blend_mode)
samples_blended = samples1 * blend_factor + samples_blended * (1 - blend_factor)
samples_out["samples"] = samples_blended
return (samples_out,)
def blend_mode(self, img1, img2, mode):
if mode == "normal":
return img2
else:
raise ValueError(f"Unsupported blend mode: {mode}")
class LatentCrop: class LatentCrop:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
@ -1214,6 +1267,8 @@ class SaveImage:
for image in images: for image in images:
i = 255. * image.cpu().numpy() i = 255. * image.cpu().numpy()
img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8)) img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
metadata = None
if not args.disable_metadata:
metadata = PngInfo() metadata = PngInfo()
if prompt is not None: if prompt is not None:
metadata.add_text("prompt", json.dumps(prompt)) metadata.add_text("prompt", json.dumps(prompt))
@ -1487,6 +1542,7 @@ NODE_CLASS_MAPPINGS = {
"KSamplerAdvanced": KSamplerAdvanced, "KSamplerAdvanced": KSamplerAdvanced,
"SetLatentNoiseMask": SetLatentNoiseMask, "SetLatentNoiseMask": SetLatentNoiseMask,
"LatentComposite": LatentComposite, "LatentComposite": LatentComposite,
"LatentBlend": LatentBlend,
"LatentRotate": LatentRotate, "LatentRotate": LatentRotate,
"LatentFlip": LatentFlip, "LatentFlip": LatentFlip,
"LatentCrop": LatentCrop, "LatentCrop": LatentCrop,
@ -1558,6 +1614,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"LatentUpscale": "Upscale Latent", "LatentUpscale": "Upscale Latent",
"LatentUpscaleBy": "Upscale Latent By", "LatentUpscaleBy": "Upscale Latent By",
"LatentComposite": "Latent Composite", "LatentComposite": "Latent Composite",
"LatentBlend": "Latent Blend",
"LatentFromBatch" : "Latent From Batch", "LatentFromBatch" : "Latent From Batch",
"RepeatLatentBatch": "Repeat Latent Batch", "RepeatLatentBatch": "Repeat Latent Batch",
# Image # Image
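The `LatentBlend` node added above mixes two latents linearly; with the only supported mode (`"normal"`), the math reduces to a weighted sum. A standalone sketch of that arithmetic, assuming two latent tensors of matching shape:

```
import torch

samples1 = torch.randn(1, 4, 64, 64)   # hypothetical SD latents (batch, channels, h, w)
samples2 = torch.randn(1, 4, 64, 64)
blend_factor = 0.5

# "normal" mode: blend_factor weights samples1, the remainder weights samples2,
# so 1.0 returns samples1 unchanged and 0.0 returns samples2.
blended = samples1 * blend_factor + samples2 * (1 - blend_factor)
print(blended.shape)   # torch.Size([1, 4, 64, 64])
```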

View File

@ -159,13 +159,64 @@
"\n" "\n"
] ]
}, },
{
"cell_type": "markdown",
"metadata": {
"id": "kkkkkkkkkkkkkkk"
},
"source": [
"### Run ComfyUI with cloudflared (Recommended Way)\n",
"\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "jjjjjjjjjjjjjj"
},
"outputs": [],
"source": [
"!wget https://github.com/cloudflare/cloudflared/releases/latest/download/cloudflared-linux-amd64.deb\n",
"!dpkg -i cloudflared-linux-amd64.deb\n",
"\n",
"import subprocess\n",
"import threading\n",
"import time\n",
"import socket\n",
"import urllib.request\n",
"\n",
"def iframe_thread(port):\n",
" while True:\n",
" time.sleep(0.5)\n",
" sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)\n",
" result = sock.connect_ex(('127.0.0.1', port))\n",
" if result == 0:\n",
" break\n",
" sock.close()\n",
" print(\"\\nComfyUI finished loading, trying to launch cloudflared (if it gets stuck here cloudflared is having issues)\\n\")\n",
"\n",
" p = subprocess.Popen([\"cloudflared\", \"tunnel\", \"--url\", \"http://127.0.0.1:{}\".format(port)], stdout=subprocess.PIPE, stderr=subprocess.PIPE)\n",
" for line in p.stderr:\n",
" l = line.decode()\n",
" if \"trycloudflare.com \" in l:\n",
" print(\"This is the URL to access ComfyUI:\", l[l.find(\"http\"):], end='')\n",
" #print(l, end='')\n",
"\n",
"\n",
"threading.Thread(target=iframe_thread, daemon=True, args=(8188,)).start()\n",
"\n",
"!python main.py --dont-print-server"
]
},
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {
"id": "kkkkkkkkkkkkkk" "id": "kkkkkkkkkkkkkk"
}, },
"source": [ "source": [
"### Run ComfyUI with localtunnel (Recommended Way)\n", "### Run ComfyUI with localtunnel\n",
"\n", "\n",
"\n" "\n"
] ]

View File

@ -1,5 +1,4 @@
torch torch
torchdiffeq
torchsde torchsde
einops einops
transformers>=4.25.1 transformers>=4.25.1

View File

@ -345,6 +345,11 @@ class PromptServer():
vram_total, torch_vram_total = comfy.model_management.get_total_memory(device, torch_total_too=True) vram_total, torch_vram_total = comfy.model_management.get_total_memory(device, torch_total_too=True)
vram_free, torch_vram_free = comfy.model_management.get_free_memory(device, torch_free_too=True) vram_free, torch_vram_free = comfy.model_management.get_free_memory(device, torch_free_too=True)
system_stats = { system_stats = {
"system": {
"os": os.name,
"python_version": sys.version,
"embedded_python": os.path.split(os.path.split(sys.executable)[0])[1] == "python_embeded"
},
"devices": [ "devices": [
{ {
"name": device_name, "name": device_name,

View File

@ -1,4 +1,4 @@
import {app} from "/scripts/app.js"; import {app} from "../../scripts/app.js";
// Adds filtering to combo context menus // Adds filtering to combo context menus
@ -27,10 +27,13 @@ const ext = {
const clickedComboValue = currentNode.widgets const clickedComboValue = currentNode.widgets
.filter(w => w.type === "combo" && w.options.values.length === values.length) .filter(w => w.type === "combo" && w.options.values.length === values.length)
.find(w => w.options.values.every((v, i) => v === values[i])) .find(w => w.options.values.every((v, i) => v === values[i]))
.value; ?.value;
let selectedIndex = values.findIndex(v => v === clickedComboValue); let selectedIndex = clickedComboValue ? values.findIndex(v => v === clickedComboValue) : 0;
let selectedItem = displayedItems?.[selectedIndex]; if (selectedIndex < 0) {
selectedIndex = 0;
}
let selectedItem = displayedItems[selectedIndex];
updateSelected(); updateSelected();
// Apply highlighting to the selected item // Apply highlighting to the selected item

View File

@ -0,0 +1,25 @@
import { app } from "/scripts/app.js";
const id = "Comfy.LinkRenderMode";
const ext = {
name: id,
async setup(app) {
app.ui.settings.addSetting({
id,
name: "Link Render Mode",
defaultValue: 2,
type: "combo",
options: LiteGraph.LINK_RENDER_MODES.map((m, i) => ({
value: i,
text: m,
selected: i == app.canvas.links_render_mode,
})),
onChange(value) {
app.canvas.links_render_mode = +value;
app.graph.setDirtyCanvas(true);
},
});
},
};
app.registerExtension(ext);

View File

@ -2,7 +2,7 @@ import { ComfyWidgets, addValueControlWidget } from "../../scripts/widgets.js";
import { app } from "../../scripts/app.js"; import { app } from "../../scripts/app.js";
const CONVERTED_TYPE = "converted-widget"; const CONVERTED_TYPE = "converted-widget";
const VALID_TYPES = ["STRING", "combo", "number"]; const VALID_TYPES = ["STRING", "combo", "number", "BOOLEAN"];
function isConvertableWidget(widget, config) { function isConvertableWidget(widget, config) {
return VALID_TYPES.includes(widget.type) || VALID_TYPES.includes(config[0]); return VALID_TYPES.includes(widget.type) || VALID_TYPES.includes(config[0]);

View File

@ -9835,7 +9835,11 @@ LGraphNode.prototype.executeAction = function(action)
ctx.textAlign = "center"; ctx.textAlign = "center";
ctx.fillStyle = text_color; ctx.fillStyle = text_color;
ctx.fillText( ctx.fillText(
w.label || w.name + " " + Number(w.value).toFixed(3), w.label || w.name + " " + Number(w.value).toFixed(
w.options.precision != null
? w.options.precision
: 3
),
widget_width * 0.5, widget_width * 0.5,
y + H * 0.7 y + H * 0.7
); );
@ -13835,7 +13839,7 @@ LGraphNode.prototype.executeAction = function(action)
if (!disabled) { if (!disabled) {
element.addEventListener("click", inner_onclick); element.addEventListener("click", inner_onclick);
} }
if (options.autoopen) { if (!disabled && options.autoopen) {
LiteGraph.pointerListenerAdd(element,"enter",inner_over); LiteGraph.pointerListenerAdd(element,"enter",inner_over);
} }

View File

@ -264,6 +264,15 @@ class ComfyApi extends EventTarget {
} }
} }
/**
* Gets system & device stats
* @returns System stats such as python version, OS, per device info
*/
async getSystemStats() {
const res = await this.fetchApi("/system_stats");
return await res.json();
}
/** /**
* Sends a POST request to the API * Sends a POST request to the API
* @param {*} type The endpoint to post to * @param {*} type The endpoint to post to

View File

@ -1,3 +1,4 @@
import { ComfyLogging } from "./logging.js";
import { ComfyWidgets } from "./widgets.js"; import { ComfyWidgets } from "./widgets.js";
import { ComfyUI, $el } from "./ui.js"; import { ComfyUI, $el } from "./ui.js";
import { api } from "./api.js"; import { api } from "./api.js";
@ -31,6 +32,7 @@ export class ComfyApp {
constructor() { constructor() {
this.ui = new ComfyUI(this); this.ui = new ComfyUI(this);
this.logging = new ComfyLogging(this);
/** /**
* List of extensions that are registered with the app * List of extensions that are registered with the app
@ -768,6 +770,19 @@ export class ComfyApp {
} }
block_default = true; block_default = true;
} }
if (e.keyCode == 66 && e.ctrlKey) {
if (this.selected_nodes) {
for (var i in this.selected_nodes) {
if (this.selected_nodes[i].mode === 4) { // never
this.selected_nodes[i].mode = 0; // always
} else {
this.selected_nodes[i].mode = 4; // never
}
}
}
block_default = true;
}
} }
this.graph.change(); this.graph.change();
@ -914,14 +929,21 @@ export class ComfyApp {
const origDrawNode = LGraphCanvas.prototype.drawNode; const origDrawNode = LGraphCanvas.prototype.drawNode;
LGraphCanvas.prototype.drawNode = function (node, ctx) { LGraphCanvas.prototype.drawNode = function (node, ctx) {
var editor_alpha = this.editor_alpha; var editor_alpha = this.editor_alpha;
var old_color = node.bgcolor;
if (node.mode === 2) { // never if (node.mode === 2) { // never
this.editor_alpha = 0.4; this.editor_alpha = 0.4;
} }
if (node.mode === 4) { // never
node.bgcolor = "#FF00FF";
this.editor_alpha = 0.2;
}
const res = origDrawNode.apply(this, arguments); const res = origDrawNode.apply(this, arguments);
this.editor_alpha = editor_alpha; this.editor_alpha = editor_alpha;
node.bgcolor = old_color;
return res; return res;
}; };
@ -1003,6 +1025,7 @@ export class ComfyApp {
*/ */
async #loadExtensions() { async #loadExtensions() {
const extensions = await api.getExtensions(); const extensions = await api.getExtensions();
this.logging.addEntry("Comfy.App", "debug", { Extensions: extensions });
for (const ext of extensions) { for (const ext of extensions) {
try { try {
await import(api.apiURL(ext)); await import(api.apiURL(ext));
@ -1286,6 +1309,9 @@ export class ComfyApp {
(t) => `<li>${t}</li>` (t) => `<li>${t}</li>`
).join("")}</ul>Nodes that have failed to load will show as red on the graph.` ).join("")}</ul>Nodes that have failed to load will show as red on the graph.`
); );
this.logging.addEntry("Comfy.App", "warn", {
MissingNodes: missingNodeTypes,
});
} }
} }
@ -1308,7 +1334,7 @@ export class ComfyApp {
continue; continue;
} }
if (node.mode === 2) { if (node.mode === 2 || node.mode === 4) {
// Don't serialize muted nodes // Don't serialize muted nodes
continue; continue;
} }
@ -1331,12 +1357,36 @@ export class ComfyApp {
let parent = node.getInputNode(i); let parent = node.getInputNode(i);
if (parent) { if (parent) {
let link = node.getInputLink(i); let link = node.getInputLink(i);
while (parent && parent.isVirtualNode) { while (parent.mode === 4 || parent.isVirtualNode) {
let found = false;
if (parent.isVirtualNode) {
link = parent.getInputLink(link.origin_slot); link = parent.getInputLink(link.origin_slot);
if (link) { if (link) {
parent = parent.getInputNode(link.origin_slot); parent = parent.getInputNode(link.target_slot);
} else { if (parent) {
parent = null; found = true;
}
}
} else if (link && parent.mode === 4) {
let all_inputs = [link.origin_slot];
if (parent.inputs) {
all_inputs = all_inputs.concat(Object.keys(parent.inputs))
for (let parent_input in all_inputs) {
parent_input = all_inputs[parent_input];
if (parent.inputs[parent_input].type === node.inputs[i].type) {
link = parent.getInputLink(parent_input);
if (link) {
parent = parent.getInputNode(parent_input);
}
found = true;
break;
}
}
}
}
if (!found) {
break;
} }
} }

web/scripts/logging.js (new file, 367 lines)
View File

@ -0,0 +1,367 @@
import { $el, ComfyDialog } from "./ui.js";
import { api } from "./api.js";
$el("style", {
textContent: `
.comfy-logging-logs {
display: grid;
color: var(--fg-color);
white-space: pre-wrap;
}
.comfy-logging-log {
display: contents;
}
.comfy-logging-title {
background: var(--tr-even-bg-color);
font-weight: bold;
margin-bottom: 5px;
text-align: center;
}
.comfy-logging-log div {
background: var(--row-bg);
padding: 5px;
}
`,
parent: document.body,
});
// Stringify function supporting max depth and removal of circular references
// https://stackoverflow.com/a/57193345
function stringify(val, depth, replacer, space, onGetObjID) {
depth = isNaN(+depth) ? 1 : depth;
var recursMap = new WeakMap();
function _build(val, depth, o, a, r) {
// (JSON.stringify() has its own rules, which we respect here by using it for property iteration)
return !val || typeof val != "object"
? val
: ((r = recursMap.has(val)),
recursMap.set(val, true),
(a = Array.isArray(val)),
r
? (o = (onGetObjID && onGetObjID(val)) || null)
: JSON.stringify(val, function (k, v) {
if (a || depth > 0) {
if (replacer) v = replacer(k, v);
if (!k) return (a = Array.isArray(v)), (val = v);
!o && (o = a ? [] : {});
o[k] = _build(v, a ? depth : depth - 1);
}
}),
o === void 0 ? (a ? [] : {}) : o);
}
return JSON.stringify(_build(val, depth), null, space);
}
const jsonReplacer = (k, v, ui) => {
if (v instanceof Array && v.length === 1) {
v = v[0];
}
if (v instanceof Date) {
v = v.toISOString();
if (ui) {
v = v.split("T")[1];
}
}
if (v instanceof Error) {
let err = "";
if (v.name) err += v.name + "\n";
if (v.message) err += v.message + "\n";
if (v.stack) err += v.stack + "\n";
if (!err) {
err = v.toString();
}
v = err;
}
return v;
};
const fileInput = $el("input", {
type: "file",
accept: ".json",
style: { display: "none" },
parent: document.body,
});
class ComfyLoggingDialog extends ComfyDialog {
constructor(logging) {
super();
this.logging = logging;
}
clear() {
this.logging.clear();
this.show();
}
export() {
const blob = new Blob([stringify([...this.logging.entries], 20, jsonReplacer, "\t")], {
type: "application/json",
});
const url = URL.createObjectURL(blob);
const a = $el("a", {
href: url,
download: `comfyui-logs-${Date.now()}.json`,
style: { display: "none" },
parent: document.body,
});
a.click();
setTimeout(function () {
a.remove();
window.URL.revokeObjectURL(url);
}, 0);
}
import() {
fileInput.onchange = () => {
const reader = new FileReader();
reader.onload = () => {
fileInput.remove();
try {
const obj = JSON.parse(reader.result);
if (obj instanceof Array) {
this.show(obj);
} else {
throw new Error("Invalid file selected.");
}
} catch (error) {
alert("Unable to load logs: " + error.message);
}
};
reader.readAsText(fileInput.files[0]);
};
fileInput.click();
}
createButtons() {
return [
$el("button", {
type: "button",
textContent: "Clear",
onclick: () => this.clear(),
}),
$el("button", {
type: "button",
textContent: "Export logs...",
onclick: () => this.export(),
}),
$el("button", {
type: "button",
textContent: "View exported logs...",
onclick: () => this.import(),
}),
...super.createButtons(),
];
}
getTypeColor(type) {
switch (type) {
case "error":
return "red";
case "warn":
return "orange";
case "debug":
return "dodgerblue";
}
}
show(entries) {
if (!entries) entries = this.logging.entries;
this.element.style.width = "100%";
const cols = {
source: "Source",
type: "Type",
timestamp: "Timestamp",
message: "Message",
};
const keys = Object.keys(cols);
const headers = Object.values(cols).map((title) =>
$el("div.comfy-logging-title", {
textContent: title,
})
);
const rows = entries.map((entry, i) => {
return $el(
"div.comfy-logging-log",
{
$: (el) => el.style.setProperty("--row-bg", `var(--tr-${i % 2 ? "even" : "odd"}-bg-color)`),
},
keys.map((key) => {
let v = entry[key];
let color;
if (key === "type") {
color = this.getTypeColor(v);
} else {
v = jsonReplacer(key, v, true);
if (typeof v === "object") {
v = stringify(v, 5, jsonReplacer, " ");
}
}
return $el("div", {
style: {
color,
},
textContent: v,
});
})
);
});
const grid = $el(
"div.comfy-logging-logs",
{
style: {
gridTemplateColumns: `repeat(${headers.length}, 1fr)`,
},
},
[...headers, ...rows]
);
const els = [grid];
if (!this.logging.enabled) {
els.unshift(
$el("h3", {
style: { textAlign: "center" },
textContent: "Logging is disabled",
})
);
}
super.show($el("div", els));
}
}
export class ComfyLogging {
/**
* @type Array<{ source: string, type: string, timestamp: Date, message: any }>
*/
entries = [];
#enabled;
#console = {};
get enabled() {
return this.#enabled;
}
set enabled(value) {
if (value === this.#enabled) return;
if (value) {
this.patchConsole();
} else {
this.unpatchConsole();
}
this.#enabled = value;
}
constructor(app) {
this.app = app;
this.dialog = new ComfyLoggingDialog(this);
this.addSetting();
this.catchUnhandled();
this.addInitData();
}
addSetting() {
const settingId = "Comfy.Logging.Enabled";
const htmlSettingId = settingId.replaceAll(".", "-");
const setting = this.app.ui.settings.addSetting({
id: settingId,
name: settingId,
defaultValue: true,
type: (name, setter, value) => {
return $el("tr", [
$el("td", [
$el("label", {
textContent: "Logging",
for: htmlSettingId,
}),
]),
$el("td", [
$el("input", {
id: htmlSettingId,
type: "checkbox",
checked: value,
onchange: (event) => {
setter((this.enabled = event.target.checked));
},
}),
$el("button", {
textContent: "View Logs",
onclick: () => {
this.app.ui.settings.element.close();
this.dialog.show();
},
style: {
fontSize: "14px",
display: "block",
marginTop: "5px",
},
}),
]),
]);
},
});
this.enabled = setting.value;
}
patchConsole() {
// Capture common console outputs
const self = this;
for (const type of ["log", "warn", "error", "debug"]) {
const orig = console[type];
this.#console[type] = orig;
console[type] = function () {
orig.apply(console, arguments);
self.addEntry("console", type, ...arguments);
};
}
}
unpatchConsole() {
// Restore original console functions
for (const type of Object.keys(this.#console)) {
console[type] = this.#console[type];
}
this.#console = {};
}
catchUnhandled() {
// Capture uncaught errors
window.addEventListener("error", (e) => {
this.addEntry("window", "error", e.error ?? "Unknown error");
return false;
});
window.addEventListener("unhandledrejection", (e) => {
this.addEntry("unhandledrejection", "error", e.reason ?? "Unknown error");
});
}
clear() {
this.entries = [];
}
addEntry(source, type, ...args) {
if (this.enabled) {
this.entries.push({
source,
type,
timestamp: new Date(),
message: args,
});
}
}
log(source, ...args) {
this.addEntry(source, "log", ...args);
}
async addInitData() {
if (!this.enabled) return;
const source = "ComfyUI.Logging";
this.addEntry(source, "debug", { UserAgent: navigator.userAgent });
const systemStats = await api.getSystemStats();
this.addEntry(source, "debug", systemStats);
}
}

View File

@ -234,7 +234,7 @@ class ComfySettingsDialog extends ComfyDialog {
localStorage[settingId] = JSON.stringify(value); localStorage[settingId] = JSON.stringify(value);
} }
addSetting({id, name, type, defaultValue, onChange, attrs = {}, tooltip = "",}) { addSetting({id, name, type, defaultValue, onChange, attrs = {}, tooltip = "", options = undefined}) {
if (!id) { if (!id) {
throw new Error("Settings must have an ID"); throw new Error("Settings must have an ID");
} }
@ -347,6 +347,32 @@ class ComfySettingsDialog extends ComfyDialog {
]), ]),
]); ]);
break; break;
case "combo":
element = $el("tr", [
labelCell,
$el("td", [
$el(
"select",
{
oninput: (e) => {
setter(e.target.value);
},
},
(typeof options === "function" ? options(value) : options || []).map((opt) => {
if (typeof opt === "string") {
opt = { text: opt };
}
const v = opt.value ?? opt.text;
return $el("option", {
value: v,
textContent: opt.text,
selected: value + "" === v + "",
});
})
),
]),
]);
break;
case "text": case "text":
default: default:
if (type !== "text") { if (type !== "text") {
@ -480,7 +506,7 @@ class ComfyList {
hide() { hide() {
this.element.style.display = "none"; this.element.style.display = "none";
this.button.textContent = "See " + this.#text; this.button.textContent = "View " + this.#text;
} }
toggle() { toggle() {
@ -542,6 +568,13 @@ export class ComfyUI {
defaultValue: "", defaultValue: "",
}); });
this.settings.addSetting({
id: "Comfy.DisableSliders",
name: "Disable sliders.",
type: "boolean",
defaultValue: false,
});
const fileInput = $el("input", { const fileInput = $el("input", {
id: "comfy-file-input", id: "comfy-file-input",
type: "file", type: "file",

View File

@ -79,8 +79,8 @@ export function addValueControlWidget(node, targetWidget, defaultValue = "random
return valueControl; return valueControl;
}; };
function seedWidget(node, inputName, inputData) { function seedWidget(node, inputName, inputData, app) {
const seed = ComfyWidgets.INT(node, inputName, inputData); const seed = ComfyWidgets.INT(node, inputName, inputData, app);
const seedControl = addValueControlWidget(node, seed.widget, "randomize"); const seedControl = addValueControlWidget(node, seed.widget, "randomize");
seed.widget.linkedWidgets = [seedControl]; seed.widget.linkedWidgets = [seedControl];
@ -250,19 +250,29 @@ function addMultilineWidget(node, name, opts, app) {
return { minWidth: 400, minHeight: 200, widget }; return { minWidth: 400, minHeight: 200, widget };
} }
function isSlider(display, app) {
if (app.ui.settings.getSettingValue("Comfy.DisableSliders")) {
return "number"
}
return (display==="slider") ? "slider" : "number"
}
export const ComfyWidgets = { export const ComfyWidgets = {
"INT:seed": seedWidget, "INT:seed": seedWidget,
"INT:noise_seed": seedWidget, "INT:noise_seed": seedWidget,
FLOAT(node, inputName, inputData) { FLOAT(node, inputName, inputData, app) {
let widgetType = isSlider(inputData[1]["display"], app);
const { val, config } = getNumberDefaults(inputData, 0.5); const { val, config } = getNumberDefaults(inputData, 0.5);
return { widget: node.addWidget("number", inputName, val, () => {}, config) }; return { widget: node.addWidget(widgetType, inputName, val, () => {}, config) };
}, },
INT(node, inputName, inputData) { INT(node, inputName, inputData, app) {
let widgetType = isSlider(inputData[1]["display"], app);
const { val, config } = getNumberDefaults(inputData, 1); const { val, config } = getNumberDefaults(inputData, 1);
Object.assign(config, { precision: 0 }); Object.assign(config, { precision: 0 });
return { return {
widget: node.addWidget( widget: node.addWidget(
"number", widgetType,
inputName, inputName,
val, val,
function (v) { function (v) {
@ -273,6 +283,18 @@ export const ComfyWidgets = {
), ),
}; };
}, },
BOOLEAN(node, inputName, inputData) {
let defaultVal = inputData[1]["default"];
return {
widget: node.addWidget(
"toggle",
inputName,
defaultVal,
() => {},
{"on": inputData[1].label_on, "off": inputData[1].label_off}
)
};
},
STRING(node, inputName, inputData, app) { STRING(node, inputName, inputData, app) {
const defaultVal = inputData[1].default || ""; const defaultVal = inputData[1].default || "";
const multiline = !!inputData[1].multiline; const multiline = !!inputData[1].multiline;