Merge branch 'comfyanonymous:master' into master

Commit 3859cbb98d, authored by BlenderNeko on 2023-05-06 15:16:33 +02:00, committed by GitHub.
37 changed files with 699 additions and 309 deletions


@@ -1,65 +0,0 @@
import pygit2
from datetime import datetime
import sys

def pull(repo, remote_name='origin', branch='master'):
    for remote in repo.remotes:
        if remote.name == remote_name:
            remote.fetch()
            remote_master_id = repo.lookup_reference('refs/remotes/origin/%s' % (branch)).target
            merge_result, _ = repo.merge_analysis(remote_master_id)
            # Up to date, do nothing
            if merge_result & pygit2.GIT_MERGE_ANALYSIS_UP_TO_DATE:
                return
            # We can just fastforward
            elif merge_result & pygit2.GIT_MERGE_ANALYSIS_FASTFORWARD:
                repo.checkout_tree(repo.get(remote_master_id))
                try:
                    master_ref = repo.lookup_reference('refs/heads/%s' % (branch))
                    master_ref.set_target(remote_master_id)
                except KeyError:
                    repo.create_branch(branch, repo.get(remote_master_id))
                repo.head.set_target(remote_master_id)
            elif merge_result & pygit2.GIT_MERGE_ANALYSIS_NORMAL:
                repo.merge(remote_master_id)

                if repo.index.conflicts is not None:
                    for conflict in repo.index.conflicts:
                        print('Conflicts found in:', conflict[0].path)
                    raise AssertionError('Conflicts, ahhhhh!!')

                user = repo.default_signature
                tree = repo.index.write_tree()
                commit = repo.create_commit('HEAD',
                                            user,
                                            user,
                                            'Merge!',
                                            tree,
                                            [repo.head.target, remote_master_id])
                # We need to do this or git CLI will think we are still merging.
                repo.state_cleanup()
            else:
                raise AssertionError('Unknown merge analysis result')

repo = pygit2.Repository(str(sys.argv[1]))
ident = pygit2.Signature('comfyui', 'comfy@ui')
try:
    print("stashing current changes")
    repo.stash(ident)
except KeyError:
    print("nothing to stash")
backup_branch_name = 'backup_branch_{}'.format(datetime.today().strftime('%Y-%m-%d_%H_%M_%S'))
print("creating backup branch: {}".format(backup_branch_name))
repo.branches.local.create(backup_branch_name, repo.head.peel())

print("checking out master branch")
branch = repo.lookup_branch('master')
ref = repo.lookup_reference(branch.name)
repo.checkout(ref)

print("pulling latest changes")
pull(repo)

print("Done!")


@@ -1,2 +0,0 @@
..\python_embeded\python.exe .\update.py ..\ComfyUI\
pause


@@ -1,3 +1,3 @@
 ..\python_embeded\python.exe .\update.py ..\ComfyUI\
-..\python_embeded\python.exe -s -m pip install --upgrade --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu118 -r ../ComfyUI/requirements.txt pygit2
+..\python_embeded\python.exe -s -m pip install --upgrade --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cu121 -r ../ComfyUI/requirements.txt pygit2
 pause


@@ -1,27 +0,0 @@
HOW TO RUN:
if you have a NVIDIA gpu:
run_nvidia_gpu.bat
To run it in slow CPU mode:
run_cpu.bat
IF YOU GET A RED ERROR IN THE UI MAKE SURE YOU HAVE A MODEL/CHECKPOINT IN: ComfyUI\models\checkpoints
You can download the stable diffusion 1.5 one from: https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt
RECOMMENDED WAY TO UPDATE:
To update the ComfyUI code: update\update_comfyui.bat
To update ComfyUI with the python dependencies:
update\update_comfyui_and_python_dependencies.bat


@@ -1,2 +0,0 @@
.\python_embeded\python.exe -s ComfyUI\main.py --cpu --windows-standalone-build
pause


@@ -17,7 +17,7 @@ jobs:
       - shell: bash
         run: |
-          python -m pip wheel --no-cache-dir torch torchvision torchaudio xformers==0.0.19.dev516 --extra-index-url https://download.pytorch.org/whl/cu118 -r requirements.txt pygit2 -w ./temp_wheel_dir
+          python -m pip wheel --no-cache-dir torch torchvision torchaudio xformers --extra-index-url https://download.pytorch.org/whl/cu118 -r requirements.txt pygit2 -w ./temp_wheel_dir
           python -m pip install --no-cache-dir ./temp_wheel_dir/*
           echo installed basic
           ls -lah temp_wheel_dir


@@ -19,21 +19,21 @@ jobs:
           fetch-depth: 0
       - uses: actions/setup-python@v4
         with:
-          python-version: '3.10.9'
+          python-version: '3.11.3'
       - shell: bash
         run: |
           cd ..
           cp -r ComfyUI ComfyUI_copy
-          curl https://www.python.org/ftp/python/3.10.9/python-3.10.9-embed-amd64.zip -o python_embeded.zip
+          curl https://www.python.org/ftp/python/3.11.3/python-3.11.3-embed-amd64.zip -o python_embeded.zip
           unzip python_embeded.zip -d python_embeded
           cd python_embeded
-          echo 'import site' >> ./python310._pth
+          echo 'import site' >> ./python311._pth
           curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
           ./python.exe get-pip.py
-          python -m pip wheel torch torchvision torchaudio --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu118 -r ../ComfyUI/requirements.txt pygit2 -w ../temp_wheel_dir
+          python -m pip wheel torch torchvision torchaudio --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu121 -r ../ComfyUI/requirements.txt pygit2 -w ../temp_wheel_dir
           ls ../temp_wheel_dir
           ./python.exe -s -m pip install --pre ../temp_wheel_dir/*
-          sed -i '1i../ComfyUI' ./python310._pth
+          sed -i '1i../ComfyUI' ./python311._pth
           cd ..
@@ -46,6 +46,8 @@ jobs:
           mkdir update
           cp -r ComfyUI/.ci/update_windows/* ./update/
           cp -r ComfyUI/.ci/windows_base_files/* ./
+          cp -r ComfyUI/.ci/nightly/update_windows/* ./update/
+          cp -r ComfyUI/.ci/nightly/windows_base_files/* ./
           cd ..


@@ -7,6 +7,8 @@ A powerful and modular stable diffusion GUI and backend.
 This ui will let you design and execute advanced stable diffusion pipelines using a graph/nodes/flowchart based interface. For some workflow examples and see what ComfyUI can do you can check out:
 ### [ComfyUI Examples](https://comfyanonymous.github.io/ComfyUI_examples/)
 
+### [Installing ComfyUI](#installing)
+
 ## Features
 - Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything.
 - Fully supports SD1.x and SD2.x


@@ -5,17 +5,17 @@ import torch
 import torch as th
 import torch.nn as nn
 
-from ldm.modules.diffusionmodules.util import (
+from ..ldm.modules.diffusionmodules.util import (
     conv_nd,
     linear,
     zero_module,
     timestep_embedding,
 )
 
-from ldm.modules.attention import SpatialTransformer
-from ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample, AttentionBlock
-from ldm.models.diffusion.ddpm import LatentDiffusion
-from ldm.util import log_txt_as_img, exists, instantiate_from_config
+from ..ldm.modules.attention import SpatialTransformer
+from ..ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample, AttentionBlock
+from ..ldm.models.diffusion.ddpm import LatentDiffusion
+from ..ldm.util import log_txt_as_img, exists, instantiate_from_config
 
 class ControlledUnetModel(UNetModel):


@@ -10,6 +10,7 @@ parser.add_argument("--output-directory", type=str, default=None, help="Set the
 parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use.")
 parser.add_argument("--dont-upcast-attention", action="store_true", help="Disable upcasting of attention. Can boost speed but increase the chances of black images.")
 parser.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).")
+parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")
 
 attn_group = parser.add_mutually_exclusive_group()
 attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization instead of the sub-quadratic one. Ignored when xformers is used.")
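
Note: the new --directml flag takes an optional device index; passing the flag with no value stores the const of -1, which the model management changes below treat as "use the default DirectML device". A minimal, self-contained sketch of that argparse behaviour (the parser object here is a stand-alone illustration, not ComfyUI's real parser):

import argparse

# Reproduces only the flag added above to show how nargs="?" / const=-1 parse.
parser = argparse.ArgumentParser()
parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")

print(parser.parse_args([]).directml)                  # None -> DirectML disabled
print(parser.parse_args(["--directml"]).directml)      # -1   -> default DirectML device
print(parser.parse_args(["--directml", "1"]).directml) # 1    -> DirectML adapter index 1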


@@ -712,7 +712,7 @@ class UniPC:
     def sample(self, x, timesteps, t_start=None, t_end=None, order=3, skip_type='time_uniform',
         method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver',
-        atol=0.0078, rtol=0.05, corrector=False, callback=None
+        atol=0.0078, rtol=0.05, corrector=False, callback=None, disable_pbar=False
     ):
         t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
         t_T = self.noise_schedule.T if t_start is None else t_start
@@ -723,7 +723,7 @@ class UniPC:
             # timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
             assert timesteps.shape[0] - 1 == steps
             # with torch.no_grad():
-            for step_index in trange(steps):
+            for step_index in trange(steps, disable=disable_pbar):
                 if self.noise_mask is not None:
                     x = x * self.noise_mask + (1. - self.noise_mask) * (self.masked_image * self.noise_schedule.marginal_alpha(timesteps[step_index]) + self.noise * self.noise_schedule.marginal_std(timesteps[step_index]))
                 if step_index == 0:
@@ -767,7 +767,7 @@ class UniPC:
                     model_x = self.model_fn(x, vec_t)
                     model_prev_list[-1] = model_x
                 if callback is not None:
-                    callback(step_index, model_prev_list[-1], x)
+                    callback(step_index, model_prev_list[-1], x, steps)
         else:
             raise NotImplementedError()
         if denoise_to_zero:
@@ -835,7 +835,7 @@ def expand_dims(v, dims):
-def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, extra_args=None, callback=None, disable=None, noise_mask=None, variant='bh1'):
+def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, extra_args=None, callback=None, disable=False, noise_mask=None, variant='bh1'):
     to_zero = False
     if sigmas[-1] == 0:
         timesteps = torch.nn.functional.interpolate(sigmas[None,None,:-1], size=(len(sigmas),), mode='linear')[0][0]
@@ -879,7 +879,7 @@ def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, ex
     order = min(3, len(timesteps) - 1)
     uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, noise_mask=noise_mask, masked_image=image, noise=noise, variant=variant)
-    x = uni_pc.sample(img, timesteps=timesteps, skip_type="time_uniform", method="multistep", order=order, lower_order_final=True, callback=callback)
+    x = uni_pc.sample(img, timesteps=timesteps, skip_type="time_uniform", method="multistep", order=order, lower_order_final=True, callback=callback, disable_pbar=disable)
     if not to_zero:
         x /= ns.marginal_alpha(timesteps[-1])
     return x
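
Note: the callback plumbed through UniPC now receives the total step count as a fourth argument. A hedged sketch of a compatible callback (the name and body are illustrative, not part of the commit):

import torch

def preview_callback(step, x0, x, total_steps):
    # step: current step index, x0: current denoised prediction,
    # x: current latent, total_steps: length of the schedule
    print(f"step {step + 1}/{total_steps}, x0 mean {x0.mean().item():.4f}")

preview_callback(0, torch.zeros(1, 4, 8, 8), torch.zeros(1, 4, 8, 8), 20)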


@@ -1,6 +1,6 @@
 import torch
 from torch import nn, einsum
-from ldm.modules.attention import CrossAttention
+from .ldm.modules.attention import CrossAttention
 from inspect import isfunction
 
@@ -242,14 +242,28 @@
         self.position_net = position_net
         self.key_dim = key_dim
         self.max_objs = 30
+        self.lowvram = False
 
     def _set_position(self, boxes, masks, positive_embeddings):
+        if self.lowvram == True:
+            self.position_net.to(boxes.device)
+
         objs = self.position_net(boxes, masks, positive_embeddings)
-        def func(key, x):
-            module = self.module_list[key]
-            return module(x, objs)
-        return func
+
+        if self.lowvram == True:
+            self.position_net.cpu()
+            def func_lowvram(key, x):
+                module = self.module_list[key]
+                module.to(x.device)
+                r = module(x, objs)
+                module.cpu()
+                return r
+            return func_lowvram
+        else:
+            def func(key, x):
+                module = self.module_list[key]
+                return module(x, objs)
+            return func
 
     def set_position(self, latent_image_shape, position_params, device):
         batch, c, h, w = latent_image_shape
@@ -294,8 +308,11 @@
             masks.to(device),
             conds.to(device))
 
+    def set_lowvram(self, value=True):
+        self.lowvram = value
+
     def cleanup(self):
-        pass
+        self.lowvram = False
 
     def get_models(self):
         return [self]
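
Note: the low-VRAM path keeps the GLIGEN modules on the CPU and only moves the one being evaluated onto the GPU for a single call; model management turns it on via set_lowvram(True) in the low-VRAM branch further below. A minimal, self-contained sketch of that closure pattern (all names here are illustrative):

import torch
import torch.nn as nn

def make_lowvram_func(module: nn.Module):
    def func_lowvram(x: torch.Tensor):
        module.to(x.device)   # load weights just in time
        r = module(x)
        module.cpu()          # release VRAM right after the call
        return r
    return func_lowvram

layer = nn.Linear(8, 8)
f = make_lowvram_func(layer)
print(f(torch.randn(1, 8)).shape)  # torch.Size([1, 8])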


@@ -3,11 +3,11 @@ import torch
 import torch.nn.functional as F
 from contextlib import contextmanager
 
-from ldm.modules.diffusionmodules.model import Encoder, Decoder
-from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
+from comfy.ldm.modules.diffusionmodules.model import Encoder, Decoder
+from comfy.ldm.modules.distributions.distributions import DiagonalGaussianDistribution
 
-from ldm.util import instantiate_from_config
-from ldm.modules.ema import LitEma
+from comfy.ldm.util import instantiate_from_config
+from comfy.ldm.modules.ema import LitEma
 
 # class AutoencoderKL(pl.LightningModule):
 class AutoencoderKL(torch.nn.Module):


@@ -4,7 +4,7 @@ import torch
 import numpy as np
 from tqdm import tqdm
 
-from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor
+from comfy.ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor
 
 class DDIMSampler(object):
@@ -81,6 +81,7 @@ class DDIMSampler(object):
                extra_args=None,
                to_zero=True,
                end_step=None,
+               disable_pbar=False,
                **kwargs
                ):
         self.make_schedule_timesteps(ddim_timesteps=ddim_timesteps, ddim_eta=eta, verbose=verbose)
@@ -103,7 +104,8 @@ class DDIMSampler(object):
                         denoise_function=denoise_function,
                         extra_args=extra_args,
                         to_zero=to_zero,
-                        end_step=end_step
+                        end_step=end_step,
+                        disable_pbar=disable_pbar
                         )
         return samples, intermediates
@@ -185,7 +187,7 @@ class DDIMSampler(object):
                       mask=None, x0=None, img_callback=None, log_every_t=100,
                       temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                       unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None,
-                      ucg_schedule=None, denoise_function=None, extra_args=None, to_zero=True, end_step=None):
+                      ucg_schedule=None, denoise_function=None, extra_args=None, to_zero=True, end_step=None, disable_pbar=False):
         device = self.model.betas.device
         b = shape[0]
         if x_T is None:
@@ -204,7 +206,7 @@ class DDIMSampler(object):
         total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
 
         # print(f"Running DDIM Sampling with {total_steps} timesteps")
-        iterator = tqdm(time_range[:end_step], desc='DDIM Sampler', total=end_step)
+        iterator = tqdm(time_range[:end_step], desc='DDIM Sampler', total=end_step, disable=disable_pbar)
 
         for i, step in enumerate(iterator):
             index = total_steps - i - 1


@@ -19,12 +19,12 @@ from tqdm import tqdm
 from torchvision.utils import make_grid
 # from pytorch_lightning.utilities.distributed import rank_zero_only
 
-from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
-from ldm.modules.ema import LitEma
-from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
-from ldm.models.autoencoder import IdentityFirstStage, AutoencoderKL
-from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
-from ldm.models.diffusion.ddim import DDIMSampler
+from comfy.ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
+from comfy.ldm.modules.ema import LitEma
+from comfy.ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
+from ..autoencoder import IdentityFirstStage, AutoencoderKL
+from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
+from .ddim import DDIMSampler
 
 __conditioning_keys__ = {'concat': 'c_concat',


@@ -6,7 +6,7 @@ from torch import nn, einsum
 from einops import rearrange, repeat
 from typing import Optional, Any
 
-from ldm.modules.diffusionmodules.util import checkpoint
+from .diffusionmodules.util import checkpoint
 from .sub_quadratic_attention import efficient_dot_product_attention
 
 from comfy import model_management
@@ -21,7 +21,7 @@ if model_management.xformers_enabled():
 import os
 _ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")
 
-from cli_args import args
+from comfy.cli_args import args
 
 def exists(val):
     return val is not None
@@ -572,9 +572,6 @@ class BasicTransformerBlock(nn.Module):
             x += n
         x = self.ff(self.norm3(x)) + x
 
-        if current_index is not None:
-            transformer_options["current_index"] += 1
-
         return x


@@ -6,7 +6,7 @@ import numpy as np
 from einops import rearrange
 from typing import Optional, Any
 
-from ldm.modules.attention import MemoryEfficientCrossAttention
+from ..attention import MemoryEfficientCrossAttention
 from comfy import model_management
 
 if model_management.xformers_enabled_vae():


@@ -6,7 +6,7 @@ import torch as th
 import torch.nn as nn
 import torch.nn.functional as F
 
-from ldm.modules.diffusionmodules.util import (
+from .util import (
     checkpoint,
     conv_nd,
     linear,
@@ -15,8 +15,8 @@ from ldm.modules.diffusionmodules.util import (
     normalization,
     timestep_embedding,
 )
-from ldm.modules.attention import SpatialTransformer
-from ldm.util import exists
+from ..attention import SpatialTransformer
+from comfy.ldm.util import exists
 
 # dummy replace
@@ -76,16 +76,31 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
     support it as an extra input.
     """
 
-    def forward(self, x, emb, context=None, transformer_options={}):
+    def forward(self, x, emb, context=None, transformer_options={}, output_shape=None):
         for layer in self:
             if isinstance(layer, TimestepBlock):
                 x = layer(x, emb)
             elif isinstance(layer, SpatialTransformer):
                 x = layer(x, context, transformer_options)
+            elif isinstance(layer, Upsample):
+                x = layer(x, output_shape=output_shape)
             else:
                 x = layer(x)
         return x
 
+#This is needed because accelerate makes a copy of transformer_options which breaks "current_index"
+def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, output_shape=None):
+    for layer in ts:
+        if isinstance(layer, TimestepBlock):
+            x = layer(x, emb)
+        elif isinstance(layer, SpatialTransformer):
+            x = layer(x, context, transformer_options)
+            transformer_options["current_index"] += 1
+        elif isinstance(layer, Upsample):
+            x = layer(x, output_shape=output_shape)
+        else:
+            x = layer(x)
+    return x
 
 class Upsample(nn.Module):
     """
@@ -105,14 +120,20 @@ class Upsample(nn.Module):
         if use_conv:
             self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)
 
-    def forward(self, x):
+    def forward(self, x, output_shape=None):
         assert x.shape[1] == self.channels
         if self.dims == 3:
-            x = F.interpolate(
-                x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
-            )
+            shape = [x.shape[2], x.shape[3] * 2, x.shape[4] * 2]
+            if output_shape is not None:
+                shape[1] = output_shape[3]
+                shape[2] = output_shape[4]
         else:
-            x = F.interpolate(x, scale_factor=2, mode="nearest")
+            shape = [x.shape[2] * 2, x.shape[3] * 2]
+            if output_shape is not None:
+                shape[0] = output_shape[2]
+                shape[1] = output_shape[3]
+        x = F.interpolate(x, size=shape, mode="nearest")
         if self.use_conv:
             x = self.conv(x)
         return x
@@ -797,13 +818,13 @@ class UNetModel(nn.Module):
         h = x.type(self.dtype)
         for id, module in enumerate(self.input_blocks):
-            h = module(h, emb, context, transformer_options)
+            h = forward_timestep_embed(module, h, emb, context, transformer_options)
             if control is not None and 'input' in control and len(control['input']) > 0:
                 ctrl = control['input'].pop()
                 if ctrl is not None:
                     h += ctrl
             hs.append(h)
-        h = self.middle_block(h, emb, context, transformer_options)
+        h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options)
         if control is not None and 'middle' in control and len(control['middle']) > 0:
             h += control['middle'].pop()
 
@@ -813,9 +834,14 @@ class UNetModel(nn.Module):
                 ctrl = control['output'].pop()
                 if ctrl is not None:
                     hsp += ctrl
 
             h = th.cat([h, hsp], dim=1)
             del hsp
-            h = module(h, emb, context, transformer_options)
+            if len(hs) > 0:
+                output_shape = hs[-1].shape
+            else:
+                output_shape = None
+            h = forward_timestep_embed(module, h, emb, context, transformer_options, output_shape)
         h = h.type(x.dtype)
         if self.predict_codebook_ids:
             return self.id_predictor(h)
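
Note: the output_shape plumbing lets an Upsample layer match the exact spatial size of the corresponding skip connection instead of always doubling, which matters when a latent dimension is odd. A minimal, self-contained illustration (sizes are made up):

import torch
import torch.nn.functional as F

skip = torch.randn(1, 4, 77, 64)   # shape the decoder must concatenate with
h = torch.randn(1, 4, 39, 32)      # downsampled activation (stride-2 conv of 77 -> 39)

doubled = F.interpolate(h, scale_factor=2, mode="nearest")
print(doubled.shape)               # torch.Size([1, 4, 78, 64]) -- off by one row
matched = F.interpolate(h, size=skip.shape[2:], mode="nearest")
print(matched.shape)               # torch.Size([1, 4, 77, 64]) -- matches the skip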


@@ -3,8 +3,8 @@ import torch.nn as nn
 import numpy as np
 from functools import partial
 
-from ldm.modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule
-from ldm.util import default
+from .util import extract_into_tensor, make_beta_schedule
+from comfy.ldm.util import default
 
 class AbstractLowScaleModel(nn.Module):


@@ -15,7 +15,7 @@ import torch.nn as nn
 import numpy as np
 from einops import repeat
 
-from ldm.util import instantiate_from_config
+from comfy.ldm.util import instantiate_from_config
 
 def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):


@@ -1,5 +1,5 @@
-from ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
-from ldm.modules.diffusionmodules.openaimodel import Timestep
+from ..diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
+from ..diffusionmodules.openaimodel import Timestep
 import torch
 
 class CLIPEmbeddingNoiseAugmentation(ImageConcatWithNoiseAugmentation):


@@ -1,6 +1,6 @@
 import psutil
 from enum import Enum
-from cli_args import args
+from comfy.cli_args import args
 
 class VRAMState(Enum):
     CPU = 0
@@ -20,15 +20,30 @@ total_vram_available_mb = -1
 accelerate_enabled = False
 xpu_available = False
 
+directml_enabled = False
+if args.directml is not None:
+    import torch_directml
+    directml_enabled = True
+    device_index = args.directml
+    if device_index < 0:
+        directml_device = torch_directml.device()
+    else:
+        directml_device = torch_directml.device(device_index)
+    print("Using directml with device:", torch_directml.device_name(device_index))
+    # torch_directml.disable_tiled_resources(True)
+
 try:
     import torch
-    try:
-        import intel_extension_for_pytorch as ipex
-        if torch.xpu.is_available():
-            xpu_available = True
-            total_vram = torch.xpu.get_device_properties(torch.xpu.current_device()).total_memory / (1024 * 1024)
-    except:
-        total_vram = torch.cuda.mem_get_info(torch.cuda.current_device())[1] / (1024 * 1024)
+    if directml_enabled:
+        total_vram = 4097 #TODO
+    else:
+        try:
+            import intel_extension_for_pytorch as ipex
+            if torch.xpu.is_available():
+                xpu_available = True
+                total_vram = torch.xpu.get_device_properties(torch.xpu.current_device()).total_memory / (1024 * 1024)
+        except:
+            total_vram = torch.cuda.mem_get_info(torch.cuda.current_device())[1] / (1024 * 1024)
     total_ram = psutil.virtual_memory().total / (1024 * 1024)
     if not args.normalvram and not args.cpu:
         if total_vram <= 4096:
@@ -186,6 +201,9 @@ def load_controlnet_gpu(control_models):
         return
 
     if vram_state == VRAMState.LOW_VRAM or vram_state == VRAMState.NO_VRAM:
+        for m in control_models:
+            if hasattr(m, 'set_lowvram'):
+                m.set_lowvram(True)
         #don't load controlnets like this if low vram because they will be loaded right before running and unloaded right after
         return
 
@@ -217,6 +235,10 @@ def unload_if_low_vram(model):
 def get_torch_device():
     global xpu_available
+    global directml_enabled
+    if directml_enabled:
+        global directml_device
+        return directml_device
     if vram_state == VRAMState.MPS:
         return torch.device("mps")
     if vram_state == VRAMState.CPU:
@@ -234,8 +256,14 @@ def get_autocast_device(dev):
 
 def xformers_enabled():
+    global xpu_available
+    global directml_enabled
     if vram_state == VRAMState.CPU:
         return False
+    if xpu_available:
+        return False
+    if directml_enabled:
+        return False
     return XFORMERS_IS_AVAILABLE
 
@@ -251,6 +279,7 @@ def pytorch_attention_enabled():
 def get_free_memory(dev=None, torch_free_too=False):
     global xpu_available
+    global directml_enabled
     if dev is None:
         dev = get_torch_device()
 
@@ -258,7 +287,10 @@ def get_free_memory(dev=None, torch_free_too=False):
         mem_free_total = psutil.virtual_memory().available
         mem_free_torch = mem_free_total
     else:
-        if xpu_available:
+        if directml_enabled:
+            mem_free_total = 1024 * 1024 * 1024 #TODO
+            mem_free_torch = mem_free_total
+        elif xpu_available:
             mem_free_total = torch.xpu.get_device_properties(dev).total_memory - torch.xpu.memory_allocated(dev)
             mem_free_torch = mem_free_total
         else:
@@ -293,9 +325,14 @@ def mps_mode():
 def should_use_fp16():
     global xpu_available
+    global directml_enabled
     if FORCE_FP32:
         return False
 
+    if directml_enabled:
+        return False
+
     if cpu_mode() or mps_mode() or xpu_available:
         return False #TODO ?


@@ -65,7 +65,7 @@ def cleanup_additional_models(models):
     for m in models:
         m.cleanup()
 
-def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, noise_mask=None, sigmas=None, callback=None):
+def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, noise_mask=None, sigmas=None, callback=None, disable_pbar=False):
     device = comfy.model_management.get_torch_device()
 
     if noise_mask is not None:
@@ -85,7 +85,7 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative
     sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)
 
-    samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback)
+    samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar)
     samples = samples.cpu()
 
     cleanup_additional_models(models)


@@ -23,21 +23,36 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
                 adm_cond = cond[1]['adm_encoded']
 
         input_x = x_in[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
-        mult = torch.ones_like(input_x) * strength
+        if 'mask' in cond[1]:
+            # Scale the mask to the size of the input
+            # The mask should have been resized as we began the sampling process
+            mask_strength = 1.0
+            if "mask_strength" in cond[1]:
+                mask_strength = cond[1]["mask_strength"]
+            mask = cond[1]['mask']
+            assert(mask.shape[1] == x_in.shape[2])
+            assert(mask.shape[2] == x_in.shape[3])
+            mask = mask[:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] * mask_strength
+            mask = mask.unsqueeze(1).repeat(input_x.shape[0] // mask.shape[0], input_x.shape[1], 1, 1)
+        else:
+            mask = torch.ones_like(input_x)
+        mult = mask * strength
 
-        rr = 8
-        if area[2] != 0:
-            for t in range(rr):
-                mult[:,:,t:1+t,:] *= ((1.0/rr) * (t + 1))
-        if (area[0] + area[2]) < x_in.shape[2]:
-            for t in range(rr):
-                mult[:,:,area[0] - 1 - t:area[0] - t,:] *= ((1.0/rr) * (t + 1))
-        if area[3] != 0:
-            for t in range(rr):
-                mult[:,:,:,t:1+t] *= ((1.0/rr) * (t + 1))
-        if (area[1] + area[3]) < x_in.shape[3]:
-            for t in range(rr):
-                mult[:,:,:,area[1] - 1 - t:area[1] - t] *= ((1.0/rr) * (t + 1))
+        if 'mask' not in cond[1]:
+            rr = 8
+            if area[2] != 0:
+                for t in range(rr):
+                    mult[:,:,t:1+t,:] *= ((1.0/rr) * (t + 1))
+            if (area[0] + area[2]) < x_in.shape[2]:
+                for t in range(rr):
+                    mult[:,:,area[0] - 1 - t:area[0] - t,:] *= ((1.0/rr) * (t + 1))
+            if area[3] != 0:
+                for t in range(rr):
+                    mult[:,:,:,t:1+t] *= ((1.0/rr) * (t + 1))
+            if (area[1] + area[3]) < x_in.shape[3]:
+                for t in range(rr):
+                    mult[:,:,:,area[1] - 1 - t:area[1] - t] *= ((1.0/rr) * (t + 1))
+
         conditionning = {}
         conditionning['c_crossattn'] = cond[0]
         if cond_concat_in is not None and len(cond_concat_in) > 0:
@@ -301,6 +316,71 @@ def blank_inpaint_image_like(latent_image):
     blank_image[:,3] *= 0.1380
     return blank_image
 
+def get_mask_aabb(masks):
+    if masks.numel() == 0:
+        return torch.zeros((0, 4), device=masks.device, dtype=torch.int)
+
+    b = masks.shape[0]
+
+    bounding_boxes = torch.zeros((b, 4), device=masks.device, dtype=torch.int)
+    is_empty = torch.zeros((b), device=masks.device, dtype=torch.bool)
+    for i in range(b):
+        mask = masks[i]
+        if mask.numel() == 0:
+            continue
+        if torch.max(mask != 0) == False:
+            is_empty[i] = True
+            continue
+        y, x = torch.where(mask)
+        bounding_boxes[i, 0] = torch.min(x)
+        bounding_boxes[i, 1] = torch.min(y)
+        bounding_boxes[i, 2] = torch.max(x)
+        bounding_boxes[i, 3] = torch.max(y)
+
+    return bounding_boxes, is_empty
+
+def resolve_cond_masks(conditions, h, w, device):
+    # We need to decide on an area outside the sampling loop in order to properly generate opposite areas of equal sizes.
+    # While we're doing this, we can also resolve the mask device and scaling for performance reasons
+    for i in range(len(conditions)):
+        c = conditions[i]
+        if 'mask' in c[1]:
+            mask = c[1]['mask']
+            mask = mask.to(device=device)
+            modified = c[1].copy()
+            if len(mask.shape) == 2:
+                mask = mask.unsqueeze(0)
+            if mask.shape[2] != h or mask.shape[3] != w:
+                mask = torch.nn.functional.interpolate(mask.unsqueeze(1), size=(h, w), mode='bilinear', align_corners=False).squeeze(1)
+
+            if modified.get("set_area_to_bounds", False):
+                bounds = torch.max(torch.abs(mask),dim=0).values.unsqueeze(0)
+                boxes, is_empty = get_mask_aabb(bounds)
+                if is_empty[0]:
+                    # Use the minimum possible size for efficiency reasons. (Since the mask is all-0, this becomes a noop anyway)
+                    modified['area'] = (8, 8, 0, 0)
+                else:
+                    box = boxes[0]
+                    H, W, Y, X = (box[3] - box[1] + 1, box[2] - box[0] + 1, box[1], box[0])
+                    # Make sure the height and width are divisible by 8
+                    if X % 8 != 0:
+                        newx = X // 8 * 8
+                        W = W + (X - newx)
+                        X = newx
+                    if Y % 8 != 0:
+                        newy = Y // 8 * 8
+                        H = H + (Y - newy)
+                        Y = newy
+                    if H % 8 != 0:
+                        H = H + (8 - (H % 8))
+                    if W % 8 != 0:
+                        W = W + (8 - (W % 8))
+                    area = (int(H), int(W), int(Y), int(X))
+                    modified['area'] = area
+
+            modified['mask'] = mask
+            conditions[i] = [c[0], modified]
+
 def create_cond_with_same_area_if_none(conds, c):
     if 'area' not in c[1]:
         return
@@ -461,8 +541,7 @@ class KSampler:
             sigmas = self.calculate_sigmas(new_steps).to(self.device)
             self.sigmas = sigmas[-(steps + 1):]
 
-    def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=None, last_step=None, force_full_denoise=False, denoise_mask=None, sigmas=None, callback=None):
+    def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=None, last_step=None, force_full_denoise=False, denoise_mask=None, sigmas=None, callback=None, disable_pbar=False):
        if sigmas is None:
            sigmas = self.sigmas
        sigma_min = self.sigma_min
@@ -484,6 +563,10 @@ class KSampler:
        positive = positive[:]
        negative = negative[:]
 
+       resolve_cond_masks(positive, noise.shape[2], noise.shape[3], self.device)
+       resolve_cond_masks(negative, noise.shape[2], noise.shape[3], self.device)
+
        #make sure each cond area has an opposite one with the same area
        for c in positive:
            create_cond_with_same_area_if_none(negative, c)
@@ -527,9 +610,9 @@ class KSampler:
        with precision_scope(model_management.get_autocast_device(self.device)):
            if self.sampler == "uni_pc":
-               samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask, callback=callback)
+               samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask, callback=callback, disable=disable_pbar)
            elif self.sampler == "uni_pc_bh2":
-               samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask, callback=callback, variant='bh2')
+               samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask, callback=callback, variant='bh2', disable=disable_pbar)
            elif self.sampler == "ddim":
                timesteps = []
                for s in range(sigmas.shape[0]):
@@ -540,7 +623,8 @@ class KSampler:
                ddim_callback = None
                if callback is not None:
-                   ddim_callback = lambda pred_x0, i: callback(i, pred_x0, None)
+                   total_steps = len(timesteps) - 1
+                   ddim_callback = lambda pred_x0, i: callback(i, pred_x0, None, total_steps)
 
                sampler = DDIMSampler(self.model, device=self.device)
                sampler.make_schedule_timesteps(ddim_timesteps=timesteps, verbose=False)
@@ -560,7 +644,8 @@ class KSampler:
                                extra_args=extra_args,
                                mask=noise_mask,
                                to_zero=sigmas[-1]==0,
-                               end_step=sigmas.shape[0] - 1)
+                               end_step=sigmas.shape[0] - 1,
+                               disable_pbar=disable_pbar)
            else:
                extra_args["denoise_mask"] = denoise_mask
@@ -570,16 +655,17 @@ class KSampler:
                    noise = noise * sigmas[0]
 
                k_callback = None
+               total_steps = len(sigmas) - 1
                if callback is not None:
-                   k_callback = lambda x: callback(x["i"], x["denoised"], x["x"])
+                   k_callback = lambda x: callback(x["i"], x["denoised"], x["x"], total_steps)
 
                if latent_image is not None:
                    noise += latent_image
                if self.sampler == "dpm_fast":
-                   samples = k_diffusion_sampling.sample_dpm_fast(self.model_k, noise, sigma_min, sigmas[0], self.steps, extra_args=extra_args, callback=k_callback)
+                   samples = k_diffusion_sampling.sample_dpm_fast(self.model_k, noise, sigma_min, sigmas[0], total_steps, extra_args=extra_args, callback=k_callback, disable=disable_pbar)
                elif self.sampler == "dpm_adaptive":
-                   samples = k_diffusion_sampling.sample_dpm_adaptive(self.model_k, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=k_callback)
+                   samples = k_diffusion_sampling.sample_dpm_adaptive(self.model_k, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=k_callback, disable=disable_pbar)
                else:
-                   samples = getattr(k_diffusion_sampling, "sample_{}".format(self.sampler))(self.model_k, noise, sigmas, extra_args=extra_args, callback=k_callback)
+                   samples = getattr(k_diffusion_sampling, "sample_{}".format(self.sampler))(self.model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar)
 
        return samples.to(torch.float32)
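
Note: the new set_area_to_bounds path shrinks a masked condition to the mask's bounding box and then snaps it to multiples of 8 latent pixels. A minimal, self-contained sketch of the bounding-box step used by get_mask_aabb (the mask values are made up):

import torch

mask = torch.zeros(1, 64, 64)
mask[0, 10:30, 20:50] = 1.0

y, x = torch.where(mask[0])
box = (int(x.min()), int(y.min()), int(x.max()), int(y.max()))
print(box)  # (20, 10, 49, 29) -> the x0, y0, x1, y1 stored per mask in bounding_boxes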


@@ -2,8 +2,8 @@ import torch
 import contextlib
 import copy
 
-import sd1_clip
-import sd2_clip
+from . import sd1_clip
+from . import sd2_clip
 
 from comfy import model_management
 from .ldm.util import instantiate_from_config
 from .ldm.models.autoencoder import AutoencoderKL
@@ -111,6 +111,8 @@ def load_lora(path, to_load):
            loaded_keys.add(A_name)
            loaded_keys.add(B_name)
 
+        ######## loha
+
        hada_w1_a_name = "{}.hada_w1_a".format(x)
        hada_w1_b_name = "{}.hada_w1_b".format(x)
        hada_w2_a_name = "{}.hada_w2_a".format(x)
@@ -132,6 +134,54 @@ def load_lora(path, to_load):
            loaded_keys.add(hada_w2_a_name)
            loaded_keys.add(hada_w2_b_name)
 
+        ######## lokr
+        lokr_w1_name = "{}.lokr_w1".format(x)
+        lokr_w2_name = "{}.lokr_w2".format(x)
+        lokr_w1_a_name = "{}.lokr_w1_a".format(x)
+        lokr_w1_b_name = "{}.lokr_w1_b".format(x)
+        lokr_t2_name = "{}.lokr_t2".format(x)
+        lokr_w2_a_name = "{}.lokr_w2_a".format(x)
+        lokr_w2_b_name = "{}.lokr_w2_b".format(x)
+
+        lokr_w1 = None
+        if lokr_w1_name in lora.keys():
+            lokr_w1 = lora[lokr_w1_name]
+            loaded_keys.add(lokr_w1_name)
+
+        lokr_w2 = None
+        if lokr_w2_name in lora.keys():
+            lokr_w2 = lora[lokr_w2_name]
+            loaded_keys.add(lokr_w2_name)
+
+        lokr_w1_a = None
+        if lokr_w1_a_name in lora.keys():
+            lokr_w1_a = lora[lokr_w1_a_name]
+            loaded_keys.add(lokr_w1_a_name)
+
+        lokr_w1_b = None
+        if lokr_w1_b_name in lora.keys():
+            lokr_w1_b = lora[lokr_w1_b_name]
+            loaded_keys.add(lokr_w1_b_name)
+
+        lokr_w2_a = None
+        if lokr_w2_a_name in lora.keys():
+            lokr_w2_a = lora[lokr_w2_a_name]
+            loaded_keys.add(lokr_w2_a_name)
+
+        lokr_w2_b = None
+        if lokr_w2_b_name in lora.keys():
+            lokr_w2_b = lora[lokr_w2_b_name]
+            loaded_keys.add(lokr_w2_b_name)
+
+        lokr_t2 = None
+        if lokr_t2_name in lora.keys():
+            lokr_t2 = lora[lokr_t2_name]
+            loaded_keys.add(lokr_t2_name)
+
+        if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None):
+            patch_dict[to_load[x]] = (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2)
+
    for x in lora.keys():
        if x not in loaded_keys:
            print("lora key not loaded", x)
@@ -315,6 +365,33 @@ class ModelPatcher:
                    final_shape = [mat2.shape[1], mat2.shape[0], v[3].shape[2], v[3].shape[3]]
                    mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1).float(), v[3].transpose(0, 1).flatten(start_dim=1).float()).reshape(final_shape).transpose(0, 1)
                weight += (alpha * torch.mm(mat1.flatten(start_dim=1).float(), mat2.flatten(start_dim=1).float())).reshape(weight.shape).type(weight.dtype).to(weight.device)
+            elif len(v) == 8: #lokr
+                w1 = v[0]
+                w2 = v[1]
+                w1_a = v[3]
+                w1_b = v[4]
+                w2_a = v[5]
+                w2_b = v[6]
+                t2 = v[7]
+                dim = None
+
+                if w1 is None:
+                    dim = w1_b.shape[0]
+                    w1 = torch.mm(w1_a.float(), w1_b.float())
+
+                if w2 is None:
+                    dim = w2_b.shape[0]
+                    if t2 is None:
+                        w2 = torch.mm(w2_a.float(), w2_b.float())
+                    else:
+                        w2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float(), w2_b.float(), w2_a.float())
+
+                if len(w2.shape) == 4:
+                    w1 = w1.unsqueeze(2).unsqueeze(2)
+                if v[2] is not None and dim is not None:
+                    alpha *= v[2] / dim
+
+                weight += alpha * torch.kron(w1.float(), w2.float()).reshape(weight.shape).type(weight.dtype).to(weight.device)
            else: #loha
                w1a = v[0]
                w1b = v[1]
@@ -369,10 +446,10 @@ class CLIP:
        else:
            params = {}
 
-       if self.target_clip == "ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder":
+       if self.target_clip.endswith("FrozenOpenCLIPEmbedder"):
            clip = sd2_clip.SD2ClipModel
            tokenizer = sd2_clip.SD2Tokenizer
-       elif self.target_clip == "ldm.modules.encoders.modules.FrozenCLIPEmbedder":
+       elif self.target_clip.endswith("FrozenCLIPEmbedder"):
            clip = sd1_clip.SD1ClipModel
            tokenizer = sd1_clip.SD1Tokenizer
@@ -437,11 +514,16 @@ class VAE:
        self.device = device
 
    def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16):
+       steps = samples.shape[0] * utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap)
+       steps += samples.shape[0] * utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x // 2, tile_y * 2, overlap)
+       steps += samples.shape[0] * utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x * 2, tile_y // 2, overlap)
+       pbar = utils.ProgressBar(steps)
+
        decode_fn = lambda a: (self.first_stage_model.decode(1. / self.scale_factor * a.to(self.device)) + 1.0)
        output = torch.clamp((
-           (utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = 8) +
-           utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = 8) +
-           utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = 8))
+           (utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = 8, pbar = pbar) +
+           utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = 8, pbar = pbar) +
+           utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = 8, pbar = pbar))
            / 3.0) / 2.0, min=0.0, max=1.0)
        return output
@@ -485,9 +567,15 @@ class VAE:
        model_management.unload_model()
        self.first_stage_model = self.first_stage_model.to(self.device)
        pixel_samples = pixel_samples.movedim(-1,1).to(self.device)
-       samples = utils.tiled_scale(pixel_samples, lambda a: self.first_stage_model.encode(2. * a - 1.).sample() * self.scale_factor, tile_x, tile_y, overlap, upscale_amount = (1/8), out_channels=4)
-       samples += utils.tiled_scale(pixel_samples, lambda a: self.first_stage_model.encode(2. * a - 1.).sample() * self.scale_factor, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/8), out_channels=4)
-       samples += utils.tiled_scale(pixel_samples, lambda a: self.first_stage_model.encode(2. * a - 1.).sample() * self.scale_factor, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/8), out_channels=4)
+
+       steps = pixel_samples.shape[0] * utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap)
+       steps += pixel_samples.shape[0] * utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x // 2, tile_y * 2, overlap)
+       steps += pixel_samples.shape[0] * utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap)
+       pbar = utils.ProgressBar(steps)
+
+       samples = utils.tiled_scale(pixel_samples, lambda a: self.first_stage_model.encode(2. * a - 1.).sample() * self.scale_factor, tile_x, tile_y, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
+       samples += utils.tiled_scale(pixel_samples, lambda a: self.first_stage_model.encode(2. * a - 1.).sample() * self.scale_factor, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
+       samples += utils.tiled_scale(pixel_samples, lambda a: self.first_stage_model.encode(2. * a - 1.).sample() * self.scale_factor, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
        samples /= 3.0
        self.first_stage_model = self.first_stage_model.cpu()
        samples = samples.cpu()
@@ -808,9 +896,9 @@ def load_clip(ckpt_path, embedding_directory=None):
    clip_data = utils.load_torch_file(ckpt_path)
    config = {}
    if "text_model.encoder.layers.22.mlp.fc1.weight" in clip_data:
-       config['target'] = 'ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder'
+       config['target'] = 'comfy.ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder'
    else:
-       config['target'] = 'ldm.modules.encoders.modules.FrozenCLIPEmbedder'
+       config['target'] = 'comfy.ldm.modules.encoders.modules.FrozenCLIPEmbedder'
    clip = CLIP(config=config, embedding_directory=embedding_directory)
    clip.load_from_state_dict(clip_data)
    return clip
@@ -886,9 +974,9 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
    if output_clip:
        clip_config = {}
        if "cond_stage_model.model.transformer.resblocks.22.attn.out_proj.weight" in sd_keys:
-           clip_config['target'] = 'ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder'
+           clip_config['target'] = 'comfy.ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder'
        else:
-           clip_config['target'] = 'ldm.modules.encoders.modules.FrozenCLIPEmbedder'
+           clip_config['target'] = 'comfy.ldm.modules.encoders.modules.FrozenCLIPEmbedder'
        clip = CLIP(config=clip_config, embedding_directory=embedding_directory)
        w.cond_stage_model = clip.cond_stage_model
        load_state_dict_to = [w]
@@ -909,7 +997,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
            noise_schedule_config["timesteps"] = sd[noise_aug_key].shape[0]
            noise_schedule_config["beta_schedule"] = "squaredcos_cap_v2"
            params["noise_schedule_config"] = noise_schedule_config
-           noise_aug_config['target'] = "ldm.modules.encoders.noise_aug_modules.CLIPEmbeddingNoiseAugmentation"
+           noise_aug_config['target'] = "comfy.ldm.modules.encoders.noise_aug_modules.CLIPEmbeddingNoiseAugmentation"
            if size == 1280: #h
                params["timestep_dim"] = 1024
            elif size == 1024: #l
@@ -961,19 +1049,19 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
    unet_config["in_channels"] = sd['model.diffusion_model.input_blocks.0.0.weight'].shape[1]
    unet_config["context_dim"] = sd['model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight'].shape[1]
 
-   sd_config["unet_config"] = {"target": "ldm.modules.diffusionmodules.openaimodel.UNetModel", "params": unet_config}
-   model_config = {"target": "ldm.models.diffusion.ddpm.LatentDiffusion", "params": sd_config}
+   sd_config["unet_config"] = {"target": "comfy.ldm.modules.diffusionmodules.openaimodel.UNetModel", "params": unet_config}
+   model_config = {"target": "comfy.ldm.models.diffusion.ddpm.LatentDiffusion", "params": sd_config}
 
    if noise_aug_config is not None: #SD2.x unclip model
        sd_config["noise_aug_config"] = noise_aug_config
        sd_config["image_size"] = 96
        sd_config["embedding_dropout"] = 0.25
        sd_config["conditioning_key"] = 'crossattn-adm'
-       model_config["target"] = "ldm.models.diffusion.ddpm.ImageEmbeddingConditionedLatentDiffusion"
+       model_config["target"] = "comfy.ldm.models.diffusion.ddpm.ImageEmbeddingConditionedLatentDiffusion"
    elif unet_config["in_channels"] > 4: #inpainting model
        sd_config["conditioning_key"] = "hybrid"
        sd_config["finetune_keys"] = None
-       model_config["target"] = "ldm.models.diffusion.ddpm.LatentInpaintDiffusion"
+       model_config["target"] = "comfy.ldm.models.diffusion.ddpm.LatentInpaintDiffusion"
    else:
        sd_config["conditioning_key"] = "crossattn"


@@ -191,11 +191,20 @@ def safe_load_embed_zip(embed_path):
             del embed
     return out
 
+def expand_directory_list(directories):
+    dirs = set()
+    for x in directories:
+        dirs.add(x)
+        for root, subdir, file in os.walk(x, followlinks=True):
+            dirs.add(root)
+    return list(dirs)
+
 def load_embed(embedding_name, embedding_directory):
     if isinstance(embedding_directory, str):
         embedding_directory = [embedding_directory]
 
+    embedding_directory = expand_directory_list(embedding_directory)
+
     valid_file = None
     for embed_dir in embedding_directory:
         embed_path = os.path.join(embed_dir, embedding_name)
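
Note: expand_directory_list means embeddings stored in subfolders of the embeddings directory (symlinks followed) can now be found by name. A minimal, self-contained sketch of what it returns (the function body is copied from the hunk above; the temporary folder names are made up):

import os
import tempfile

def expand_directory_list(directories):
    dirs = set()
    for x in directories:
        dirs.add(x)
        for root, subdir, file in os.walk(x, followlinks=True):
            dirs.add(root)
    return list(dirs)

base = tempfile.mkdtemp()
os.makedirs(os.path.join(base, "sd15", "styles"))
print(sorted(expand_directory_list([base])))
# [base, base/sd15, base/sd15/styles]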


@@ -1,4 +1,5 @@
import torch
import math
def load_torch_file(ckpt, safe_load=False):
if ckpt.lower().endswith(".safetensors"):
@@ -62,8 +63,11 @@ def common_upscale(samples, width, height, upscale_method, crop):
s = samples
return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method)
def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap):
return math.ceil((height / (tile_y - overlap))) * math.ceil((width / (tile_x - overlap)))
@torch.inference_mode()
-def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3):
+def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, pbar = None):
output = torch.empty((samples.shape[0], out_channels, round(samples.shape[2] * upscale_amount), round(samples.shape[3] * upscale_amount)), device="cpu")
for b in range(samples.shape[0]):
s = samples[b:b+1]
@@ -83,6 +87,33 @@ def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_am
mask[:,:,:,mask.shape[3]- 1 - t: mask.shape[3]- t] *= ((1.0/feather) * (t + 1))
out[:,:,round(y*upscale_amount):round((y+tile_y)*upscale_amount),round(x*upscale_amount):round((x+tile_x)*upscale_amount)] += ps * mask
out_div[:,:,round(y*upscale_amount):round((y+tile_y)*upscale_amount),round(x*upscale_amount):round((x+tile_x)*upscale_amount)] += mask
if pbar is not None:
pbar.update(1)
output[b:b+1] = out/out_div
return output
PROGRESS_BAR_HOOK = None
def set_progress_bar_global_hook(function):
global PROGRESS_BAR_HOOK
PROGRESS_BAR_HOOK = function
class ProgressBar:
def __init__(self, total):
global PROGRESS_BAR_HOOK
self.total = total
self.current = 0
self.hook = PROGRESS_BAR_HOOK
def update_absolute(self, value, total=None):
if total is not None:
self.total = total
if value > self.total:
value = self.total
self.current = value
if self.hook is not None:
self.hook(self.current, self.total)
def update(self, value):
self.update_absolute(self.current + value)
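
For illustration only: a minimal sketch of how the new global hook and ProgressBar are wired together. The print hook is a stand-in for the websocket hook that hijack_progress installs; it assumes ComfyUI's comfy package is importable.

import comfy.utils

def print_hook(value, total):
    # stand-in for the hook main.py registers, which forwards progress to the UI
    print(f"progress {value}/{total}")

comfy.utils.set_progress_bar_global_hook(print_hook)

pbar = comfy.utils.ProgressBar(20)   # e.g. 20 sampling steps
for _ in range(20):
    pbar.update(1)                   # each update calls the hook with (current, total)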

View File

@@ -18,6 +18,7 @@ def load_hypernetwork_patch(path, strength):
"swish": torch.nn.Hardswish,
"tanh": torch.nn.Tanh,
"sigmoid": torch.nn.Sigmoid,
"softsign": torch.nn.Softsign,
}
if activation_func not in valid_activation:

View File

@@ -37,7 +37,12 @@ class ImageUpscaleWithModel:
device = model_management.get_torch_device()
upscale_model.to(device)
in_img = image.movedim(-1,-3).to(device)
-s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=128 + 64, tile_y=128 + 64, overlap = 8, upscale_amount=upscale_model.scale)
+tile = 128 + 64
+overlap = 8
+steps = in_img.shape[0] * comfy.utils.get_tiled_scale_steps(in_img.shape[3], in_img.shape[2], tile_x=tile, tile_y=tile, overlap=overlap)
+pbar = comfy.utils.ProgressBar(steps)
+s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar)
upscale_model.cpu()
s = torch.clamp(s.movedim(-3,-1), min=0, max=1.0)
return (s,)
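
A worked example of the new progress-step count, assuming a hypothetical 1x3x512x512 input image: with tile = 192 and overlap = 8, each axis needs ceil(512 / (192 - 8)) = 3 tiles, so the progress bar is sized to 1 * 3 * 3 = 9 updates, one per tile processed by tiled_scale.

import math

def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap):
    # same formula as comfy.utils.get_tiled_scale_steps above
    return math.ceil((height / (tile_y - overlap))) * math.ceil((width / (tile_x - overlap)))

print(get_tiled_scale_steps(512, 512, 192, 192, 8))  # -> 9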

main.py
View File

@@ -5,6 +5,7 @@ import shutil
import threading
from comfy.cli_args import args
import comfy.utils
if os.name == "nt":
import logging
@@ -39,14 +40,9 @@ async def run(server, address='', port=8188, verbose=True, call_on_start=None):
await asyncio.gather(server.start(address, port, verbose, call_on_start), server.publish_loop())
def hijack_progress(server):
-from tqdm.auto import tqdm
-orig_func = getattr(tqdm, "update")
-def wrapped_func(*args, **kwargs):
-pbar = args[0]
-v = orig_func(*args, **kwargs)
-server.send_sync("progress", { "value": pbar.n, "max": pbar.total}, server.client_id)
-return v
-setattr(tqdm, "update", wrapped_func)
+def hook(value, total):
+server.send_sync("progress", { "value": value, "max": total}, server.client_id)
+comfy.utils.set_progress_bar_global_hook(hook)
def cleanup_temp():
temp_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp")

nodes.py
View File

@@ -5,6 +5,7 @@ import sys
import json
import hashlib
import traceback
import math
from PIL import Image
from PIL.PngImagePlugin import PngInfo
@@ -59,14 +60,44 @@ class ConditioningCombine:
def combine(self, conditioning_1, conditioning_2):
return (conditioning_1 + conditioning_2, )
class ConditioningAverage :
@classmethod
def INPUT_TYPES(s):
return {"required": {"conditioning_to": ("CONDITIONING", ), "conditioning_from": ("CONDITIONING", ),
"conditioning_to_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
}}
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "addWeighted"
CATEGORY = "conditioning"
def addWeighted(self, conditioning_to, conditioning_from, conditioning_to_strength):
out = []
if len(conditioning_from) > 1:
print("Warning: ConditioningAverage conditioning_from contains more than 1 cond, only the first one will actually be applied to conditioning_to.")
cond_from = conditioning_from[0][0]
for i in range(len(conditioning_to)):
t1 = conditioning_to[i][0]
t0 = cond_from[:,:t1.shape[1]]
if t0.shape[1] < t1.shape[1]:
t0 = torch.cat([t0] + [torch.zeros((1, (t1.shape[1] - t0.shape[1]), t1.shape[2]))], dim=1)
tw = torch.mul(t1, conditioning_to_strength) + torch.mul(t0, (1.0 - conditioning_to_strength))
n = [tw, conditioning_to[i][1].copy()]
out.append(n)
return (out, )
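
For illustration only: the blend the new ConditioningAverage node computes is a plain linear interpolation between the two conditioning tensors (the conditioning_from tensor is first cropped or zero-padded to the other's length). A toy sketch with made-up shapes:

import torch

conditioning_to_strength = 0.25
t1 = torch.ones(1, 77, 768)    # conditioning_to (made-up SD-style cond tensor)
t0 = torch.zeros(1, 77, 768)   # conditioning_from, already cropped/padded to t1's length

tw = torch.mul(t1, conditioning_to_strength) + torch.mul(t0, (1.0 - conditioning_to_strength))
print(tw[0, 0, 0].item())      # 0.25, i.e. 25% of conditioning_to and 75% of conditioning_from
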
class ConditioningSetArea:
@classmethod
def INPUT_TYPES(s):
return {"required": {"conditioning": ("CONDITIONING", ),
-"width": ("INT", {"default": 64, "min": 64, "max": MAX_RESOLUTION, "step": 64}),
-"height": ("INT", {"default": 64, "min": 64, "max": MAX_RESOLUTION, "step": 64}),
-"x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 64}),
-"y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 64}),
+"width": ("INT", {"default": 64, "min": 64, "max": MAX_RESOLUTION, "step": 8}),
+"height": ("INT", {"default": 64, "min": 64, "max": MAX_RESOLUTION, "step": 8}),
+"x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
+"y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
}}
RETURN_TYPES = ("CONDITIONING",)
@@ -80,11 +111,41 @@ class ConditioningSetArea:
n = [t[0], t[1].copy()]
n[1]['area'] = (height // 8, width // 8, y // 8, x // 8)
n[1]['strength'] = strength
n[1]['set_area_to_bounds'] = False
n[1]['min_sigma'] = min_sigma
n[1]['max_sigma'] = max_sigma
c.append(n)
return (c, )
class ConditioningSetMask:
@classmethod
def INPUT_TYPES(s):
return {"required": {"conditioning": ("CONDITIONING", ),
"mask": ("MASK", ),
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
"set_cond_area": (["default", "mask bounds"],),
}}
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "append"
CATEGORY = "conditioning"
def append(self, conditioning, mask, set_cond_area, strength):
c = []
set_area_to_bounds = False
if set_cond_area != "default":
set_area_to_bounds = True
if len(mask.shape) < 3:
mask = mask.unsqueeze(0)
for t in conditioning:
n = [t[0], t[1].copy()]
_, h, w = mask.shape
n[1]['mask'] = mask
n[1]['set_area_to_bounds'] = set_area_to_bounds
n[1]['mask_strength'] = strength
c.append(n)
return (c, )
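
For illustration only: the shape handling in ConditioningSetMask just guarantees the mask has a leading batch dimension before it is stored on each conditioning entry; set_area_to_bounds is only a flag here and is consumed later by the sampling code (not shown in this diff).

import torch

mask = torch.ones(64, 64)            # a plain 2-D mask
if len(mask.shape) < 3:
    mask = mask.unsqueeze(0)         # -> shape (1, 64, 64)

n = [None, {}]                       # stand-in for one conditioning entry [cond_tensor, options]
n[1]['mask'] = mask
n[1]['set_area_to_bounds'] = True    # the "mask bounds" option
n[1]['mask_strength'] = 1.0
print(mask.shape)                    # torch.Size([1, 64, 64])
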
class VAEDecode:
def __init__(self, device="cpu"):
self.device = device
@@ -127,16 +188,21 @@ class VAEEncode:
CATEGORY = "latent"
-def encode(self, vae, pixels):
-x = (pixels.shape[1] // 64) * 64
-y = (pixels.shape[2] // 64) * 64
+@staticmethod
+def vae_encode_crop_pixels(pixels):
+x = (pixels.shape[1] // 8) * 8
+y = (pixels.shape[2] // 8) * 8
if pixels.shape[1] != x or pixels.shape[2] != y:
-pixels = pixels[:,:x,:y,:]
+x_offset = (pixels.shape[1] % 8) // 2
+y_offset = (pixels.shape[2] % 8) // 2
+pixels = pixels[:, x_offset:x + x_offset, y_offset:y + y_offset, :]
+return pixels
+def encode(self, vae, pixels):
+pixels = self.vae_encode_crop_pixels(pixels)
t = vae.encode(pixels[:,:,:,:3])
return ({"samples":t}, )
class VAEEncodeTiled:
def __init__(self, device="cpu"):
self.device = device
@@ -150,38 +216,43 @@ class VAEEncodeTiled:
CATEGORY = "_for_testing"
def encode(self, vae, pixels):
-x = (pixels.shape[1] // 64) * 64
-y = (pixels.shape[2] // 64) * 64
-if pixels.shape[1] != x or pixels.shape[2] != y:
-pixels = pixels[:,:x,:y,:]
+pixels = VAEEncode.vae_encode_crop_pixels(pixels)
t = vae.encode_tiled(pixels[:,:,:,:3])
return ({"samples":t}, )
class VAEEncodeForInpaint:
def __init__(self, device="cpu"):
self.device = device
@classmethod
def INPUT_TYPES(s):
-return {"required": { "pixels": ("IMAGE", ), "vae": ("VAE", ), "mask": ("MASK", )}}
+return {"required": { "pixels": ("IMAGE", ), "vae": ("VAE", ), "mask": ("MASK", ), "grow_mask_by": ("INT", {"default": 6, "min": 0, "max": 64, "step": 1}),}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "encode"
CATEGORY = "latent/inpaint"
-def encode(self, vae, pixels, mask):
-x = (pixels.shape[1] // 64) * 64
-y = (pixels.shape[2] // 64) * 64
+def encode(self, vae, pixels, mask, grow_mask_by=6):
+x = (pixels.shape[1] // 8) * 8
+y = (pixels.shape[2] // 8) * 8
mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(pixels.shape[1], pixels.shape[2]), mode="bilinear")
pixels = pixels.clone()
if pixels.shape[1] != x or pixels.shape[2] != y:
-pixels = pixels[:,:x,:y,:]
-mask = mask[:,:,:x,:y]
+x_offset = (pixels.shape[1] % 8) // 2
+y_offset = (pixels.shape[2] % 8) // 2
+pixels = pixels[:,x_offset:x + x_offset, y_offset:y + y_offset,:]
+mask = mask[:,:,x_offset:x + x_offset, y_offset:y + y_offset]
#grow mask by a few pixels to keep things seamless in latent space
-kernel_tensor = torch.ones((1, 1, 6, 6))
-mask_erosion = torch.clamp(torch.nn.functional.conv2d(mask.round(), kernel_tensor, padding=3), 0, 1)
+if grow_mask_by == 0:
+mask_erosion = mask
+else:
+kernel_tensor = torch.ones((1, 1, grow_mask_by, grow_mask_by))
+padding = math.ceil((grow_mask_by - 1) / 2)
+mask_erosion = torch.clamp(torch.nn.functional.conv2d(mask.round(), kernel_tensor, padding=padding), 0, 1)
m = (1.0 - mask.round()).squeeze(1)
for i in range(3):
pixels[:,:,:,i] -= 0.5
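
For illustration only: a small sketch of the now-configurable mask-growing step, using a toy 1x1x7x7 mask with a single marked pixel. Convolving with an all-ones grow_mask_by x grow_mask_by kernel and clamping to [0, 1] dilates the masked region, and grow_mask_by = 0 now skips the dilation entirely.

import math
import torch

mask = torch.zeros(1, 1, 7, 7)
mask[:, :, 3, 3] = 1.0                       # one marked pixel in the middle

grow_mask_by = 6
kernel_tensor = torch.ones((1, 1, grow_mask_by, grow_mask_by))
padding = math.ceil((grow_mask_by - 1) / 2)  # 3
mask_erosion = torch.clamp(torch.nn.functional.conv2d(mask.round(), kernel_tensor, padding=padding), 0, 1)
print(mask_erosion.shape, int(mask_erosion.sum()))  # torch.Size([1, 1, 8, 8]) 36: the pixel grew into a 6x6 block
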
@@ -543,8 +614,8 @@ class EmptyLatentImage:
@classmethod
def INPUT_TYPES(s):
-return {"required": { "width": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 64}),
-"height": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 64}),
+return {"required": { "width": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 8}),
+"height": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 8}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 64})}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "generate"
@@ -587,8 +658,8 @@ class LatentUpscale:
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples": ("LATENT",), "upscale_method": (s.upscale_methods,),
-"width": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 64}),
-"height": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 64}),
+"width": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 8}),
+"height": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 8}),
"crop": (s.crop_methods,)}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "upscale"
@@ -690,8 +761,8 @@ class LatentCrop:
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples": ("LATENT",),
-"width": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 64}),
-"height": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 64}),
+"width": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 8}),
+"height": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 8}),
"x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
"y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
}}
@@ -716,16 +787,6 @@ class LatentCrop:
new_width = width // 8
to_x = new_width + x
to_y = new_height + y
-def enforce_image_dim(d, to_d, max_d):
-if to_d > max_d:
-leftover = (to_d - max_d) % 8
-to_d = max_d
-d -= leftover
-return (d, to_d)
-#make sure size is always multiple of 64
-x, to_x = enforce_image_dim(x, to_x, samples.shape[3])
-y, to_y = enforce_image_dim(y, to_y, samples.shape[2])
s['samples'] = samples[:,:,y:to_y, x:to_x]
return (s,)
@@ -759,9 +820,13 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
if "noise_mask" in latent:
noise_mask = latent["noise_mask"]
pbar = comfy.utils.ProgressBar(steps)
def callback(step, x0, x, total_steps):
pbar.update_absolute(step + 1, total_steps)
samples = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image,
denoise=denoise, disable_noise=disable_noise, start_step=start_step, last_step=last_step,
-force_full_denoise=force_full_denoise, noise_mask=noise_mask)
+force_full_denoise=force_full_denoise, noise_mask=noise_mask, callback=callback)
out = latent.copy()
out["samples"] = samples
return (out, )
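
For illustration only: how the new callback is expected to drive the progress bar during sampling. The node passes callback into comfy.sample.sample and forwards each step to a ProgressBar, which in turn calls the globally installed hook; the values below are illustrative.

import comfy.utils

steps = 20
pbar = comfy.utils.ProgressBar(steps)

def callback(step, x0, x, total_steps):
    # invoked by the sampler once per denoising step; x0/x are latents and are not needed here
    pbar.update_absolute(step + 1, total_steps)

callback(4, None, None, steps)   # after the 5th step the hook (if installed) receives (5, 20)
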
@@ -1043,10 +1108,10 @@ class ImagePadForOutpaint:
return {
"required": {
"image": ("IMAGE",),
-"left": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 64}),
-"top": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 64}),
-"right": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 64}),
-"bottom": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 64}),
+"left": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
+"top": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
+"right": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
+"bottom": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
"feathering": ("INT", {"default": 40, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
}
}
@@ -1118,8 +1183,10 @@ NODE_CLASS_MAPPINGS = {
"ImageScale": ImageScale,
"ImageInvert": ImageInvert,
"ImagePadForOutpaint": ImagePadForOutpaint,
"ConditioningAverage ": ConditioningAverage ,
"ConditioningCombine": ConditioningCombine,
"ConditioningSetArea": ConditioningSetArea,
"ConditioningSetMask": ConditioningSetMask,
"KSamplerAdvanced": KSamplerAdvanced,
"SetLatentNoiseMask": SetLatentNoiseMask,
"LatentComposite": LatentComposite,
@@ -1168,7 +1235,9 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"CLIPTextEncode": "CLIP Text Encode (Prompt)",
"CLIPSetLastLayer": "CLIP Set Last Layer",
"ConditioningCombine": "Conditioning (Combine)",
"ConditioningAverage ": "Conditioning (Average)",
"ConditioningSetArea": "Conditioning (Set Area)",
"ConditioningSetMask": "Conditioning (Set Mask)",
"ControlNetApply": "Apply ControlNet",
# Latent
"VAEEncodeForInpaint": "VAE Encode (for Inpainting)",

View File

@@ -232,10 +232,27 @@
"name": "My Color Palette",
"colors": {
"node_slot": {
},
"litegraph_base": {
},
"comfy_base": {
}
}
};
// Copy over missing keys from default color palette
const defaultColorPalette = colorPalettes[defaultColorPaletteId];
for (const key in defaultColorPalette.colors.litegraph_base) {
if (!colorPalette.colors.litegraph_base[key]) {
colorPalette.colors.litegraph_base[key] = "";
}
}
for (const key in defaultColorPalette.colors.comfy_base) {
if (!colorPalette.colors.comfy_base[key]) {
colorPalette.colors.comfy_base[key] = "";
}
}
return completeColorPalette(colorPalette);
};

View File

@@ -6,6 +6,7 @@ app.registerExtension({
name: "Comfy.SlotDefaults",
suggestionsNumber: null,
init() {
LiteGraph.search_filter_enabled = true;
LiteGraph.middle_click_slot_add_default_node = true;
this.suggestionsNumber = app.ui.settings.addSetting({
id: "Comfy.NodeSuggestions.number",
@@ -43,6 +44,14 @@ app.registerExtension({
}
if (this.slot_types_default_out[type].includes(nodeId)) continue;
this.slot_types_default_out[type].push(nodeId);
// Input types have to be stored as lower case
// Store each node that can handle this input type
const lowerType = type.toLocaleLowerCase();
if (!(lowerType in LiteGraph.registered_slot_in_types)) {
LiteGraph.registered_slot_in_types[lowerType] = { nodes: [] };
}
LiteGraph.registered_slot_in_types[lowerType].nodes.push(nodeType.comfyClass);
}
var outputs = nodeData["output"];
@@ -53,6 +62,16 @@ app.registerExtension({
}
this.slot_types_default_in[type].push(nodeId);
// Store each node that can handle this output type
if (!(type in LiteGraph.registered_slot_out_types)) {
LiteGraph.registered_slot_out_types[type] = { nodes: [] };
}
LiteGraph.registered_slot_out_types[type].nodes.push(nodeType.comfyClass);
if(!LiteGraph.slot_types_out.includes(type)) {
LiteGraph.slot_types_out.push(type);
}
}
var maxNum = this.suggestionsNumber.value;
this.setDefaults(maxNum);

View File

@@ -3628,6 +3628,18 @@
return size;
};
LGraphNode.prototype.inResizeCorner = function(canvasX, canvasY) {
var rows = this.outputs ? this.outputs.length : 1;
var outputs_offset = (this.constructor.slot_start_y || 0) + rows * LiteGraph.NODE_SLOT_HEIGHT;
return isInsideRectangle(canvasX,
canvasY,
this.pos[0] + this.size[0] - 15,
this.pos[1] + Math.max(this.size[1] - 15, outputs_offset),
20,
20
);
}
/**
* returns all the info available about a property of this node.
*
@@ -5877,14 +5889,7 @@ LGraphNode.prototype.executeAction = function(action)
if ( !this.connecting_node && !node.flags.collapsed && !this.live_mode ) {
//Search for corner for resize
if ( !skip_action &&
-node.resizable !== false &&
-isInsideRectangle( e.canvasX,
-e.canvasY,
-node.pos[0] + node.size[0] - 5,
-node.pos[1] + node.size[1] - 5,
-10,
-10
-)
+node.resizable !== false && node.inResizeCorner(e.canvasX, e.canvasY)
) {
this.graph.beforeChange();
this.resizing_node = node;
@@ -6424,16 +6429,7 @@ LGraphNode.prototype.executeAction = function(action)
//Search for corner
if (this.canvas) {
-if (
-isInsideRectangle(
-e.canvasX,
-e.canvasY,
-node.pos[0] + node.size[0] - 5,
-node.pos[1] + node.size[1] - 5,
-5,
-5
-)
-) {
+if (node.inResizeCorner(e.canvasX, e.canvasY)) {
this.canvas.style.cursor = "se-resize";
} else {
this.canvas.style.cursor = "crosshair";

View File

@@ -263,6 +263,34 @@ export class ComfyApp {
*/
#addDrawBackgroundHandler(node) {
const app = this;
function getImageTop(node) {
let shiftY;
if (node.imageOffset != null) {
shiftY = node.imageOffset;
} else {
if (node.widgets?.length) {
const w = node.widgets[node.widgets.length - 1];
shiftY = w.last_y;
if (w.computeSize) {
shiftY += w.computeSize()[1] + 4;
} else {
shiftY += LiteGraph.NODE_WIDGET_HEIGHT + 4;
}
} else {
shiftY = node.computeSize()[1];
}
}
return shiftY;
}
node.prototype.setSizeForImage = function () {
const minHeight = getImageTop(this) + 220;
if (this.size[1] < minHeight) {
this.setSize([this.size[0], minHeight]);
}
};
node.prototype.onDrawBackground = function (ctx) {
if (!this.flags.collapsed) {
const output = app.nodeOutputs[this.id + ""];
@@ -283,9 +311,7 @@ export class ComfyApp {
).then((imgs) => {
if (this.images === output.images) {
this.imgs = imgs.filter(Boolean);
-if (this.size[1] < 100) {
-this.size[1] = 250;
-}
+this.setSizeForImage?.();
app.graph.setDirtyCanvas(true);
}
});
@@ -310,12 +336,7 @@ export class ComfyApp {
this.imageIndex = imageIndex = 0;
}
-let shiftY;
-if (this.imageOffset != null) {
-shiftY = this.imageOffset;
-} else {
-shiftY = this.computeSize()[1];
-}
+const shiftY = getImageTop(this);
let dw = this.size[0];
let dh = this.size[1];
@@ -703,7 +724,7 @@ export class ComfyApp {
ctx.globalAlpha = 0.8;
ctx.beginPath();
if (shape == LiteGraph.BOX_SHAPE)
-ctx.rect(-6, -6 + LiteGraph.NODE_TITLE_HEIGHT, 12 + size[0] + 1, 12 + size[1] + LiteGraph.NODE_TITLE_HEIGHT);
+ctx.rect(-6, -6 - LiteGraph.NODE_TITLE_HEIGHT, 12 + size[0] + 1, 12 + size[1] + LiteGraph.NODE_TITLE_HEIGHT);
else if (shape == LiteGraph.ROUND_SHAPE || (shape == LiteGraph.CARD_SHAPE && node.flags.collapsed))
ctx.roundRect(
-6,
@@ -715,12 +736,11 @@ export class ComfyApp {
else if (shape == LiteGraph.CARD_SHAPE)
ctx.roundRect(
-6,
--6 + LiteGraph.NODE_TITLE_HEIGHT,
+-6 - LiteGraph.NODE_TITLE_HEIGHT,
12 + size[0] + 1,
12 + size[1] + LiteGraph.NODE_TITLE_HEIGHT,
-this.round_radius * 2,
-2
-);
+[this.round_radius * 2, this.round_radius * 2, 2, 2]
+);
else if (shape == LiteGraph.CIRCLE_SHAPE)
ctx.arc(size[0] * 0.5, size[1] * 0.5, size[0] * 0.5 + 6, 0, Math.PI * 2);
ctx.strokeStyle = color;
@@ -972,8 +992,10 @@ export class ComfyApp {
loadGraphData(graphData) {
this.clean();
let reset_invalid_values = false;
if (!graphData) {
graphData = structuredClone(defaultGraph);
reset_invalid_values = true;
}
const missingNodeTypes = [];
@@ -1059,6 +1081,13 @@ export class ComfyApp {
}
}
}
if (reset_invalid_values) {
if (widget.type == "combo") {
if (!widget.options.values.includes(widget.value) && widget.options.values.length > 0) {
widget.value = widget.options.values[0];
}
}
}
}
}

View File

@@ -261,20 +261,13 @@ export const ComfyWidgets = {
let uploadWidget;
function showImage(name) {
-// Position the image somewhere sensible
-if (!node.imageOffset) {
-node.imageOffset = uploadWidget.last_y ? uploadWidget.last_y + 25 : 75;
-}
const img = new Image();
img.onload = () => {
node.imgs = [img];
app.graph.setDirtyCanvas(true);
};
img.src = `/view?filename=${name}&type=input`;
-if ((node.size[1] - node.imageOffset) < 100) {
-node.size[1] = 250 + node.imageOffset;
-}
+node.setSizeForImage?.();
}
// Add our own callback to the combo widget to render an image when it changes

View File

@@ -120,7 +120,7 @@ body {
.comfy-menu > button,
.comfy-menu-btns button,
.comfy-menu .comfy-list button,
-.comfy-modal button{
+.comfy-modal button {
color: var(--input-text);
background-color: var(--comfy-input-bg);
border-radius: 8px;
@@ -129,6 +129,15 @@
margin-top: 2px;
}
.comfy-menu > button:hover,
.comfy-menu-btns button:hover,
.comfy-menu .comfy-list button:hover,
.comfy-modal button:hover,
.comfy-settings-btn:hover {
filter: brightness(1.2);
cursor: pointer;
}
.comfy-menu span.drag-handle {
width: 10px;
height: 20px;
@@ -248,8 +257,11 @@ button.comfy-queue-btn {
}
}
/* Input popup */
.graphdialog {
min-height: 1em;
background-color: var(--comfy-menu-bg);
}
.graphdialog .name {
@@ -273,15 +285,66 @@ button.comfy-queue-btn {
border-radius: 12px 0 0 12px;
}
/* Context menu */
.litegraph .litemenu-entry.has_submenu {
position: relative;
padding-right: 20px;
}
.litemenu-entry.has_submenu::after {
content: ">";
position: absolute;
top: 0;
right: 2px;
}
.litegraph.litecontextmenu,
.litegraph.litecontextmenu.dark {
z-index: 9999 !important;
background-color: var(--comfy-menu-bg) !important;
filter: brightness(95%);
}
.litegraph.litecontextmenu .litemenu-entry:hover:not(.disabled):not(.separator) {
background-color: var(--comfy-menu-bg) !important;
filter: brightness(155%);
color: var(--input-text);
}
.litegraph.litecontextmenu .litemenu-entry.submenu,
.litegraph.litecontextmenu.dark .litemenu-entry.submenu {
background-color: var(--comfy-menu-bg) !important;
color: var(--input-text);
}
.litegraph.litecontextmenu input {
background-color: var(--comfy-input-bg) !important;
color: var(--input-text) !important;
}
/* Search box */
.litegraph.litesearchbox {
z-index: 9999 !important;
background-color: var(--comfy-menu-bg) !important;
overflow: hidden;
}
.litegraph.litesearchbox input,
.litegraph.litesearchbox select {
background-color: var(--comfy-input-bg) !important;
color: var(--input-text);
}
.litegraph.lite-search-item {
color: var(--input-text);
background-color: var(--comfy-input-bg);
filter: brightness(80%);
padding-left: 0.2em;
}
.litegraph.lite-search-item.generic_type {
color: var(--input-text);
filter: brightness(50%);
}