Merge branch 'master' into bool-input

This commit is contained in:
missionfloyd 2023-05-19 02:13:59 -06:00 committed by GitHub
commit d9973a036e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
63 changed files with 6081 additions and 847 deletions

View File

@ -1,65 +0,0 @@
import pygit2
from datetime import datetime
import sys
def pull(repo, remote_name='origin', branch='master'):
for remote in repo.remotes:
if remote.name == remote_name:
remote.fetch()
remote_master_id = repo.lookup_reference('refs/remotes/origin/%s' % (branch)).target
merge_result, _ = repo.merge_analysis(remote_master_id)
# Up to date, do nothing
if merge_result & pygit2.GIT_MERGE_ANALYSIS_UP_TO_DATE:
return
# We can just fastforward
elif merge_result & pygit2.GIT_MERGE_ANALYSIS_FASTFORWARD:
repo.checkout_tree(repo.get(remote_master_id))
try:
master_ref = repo.lookup_reference('refs/heads/%s' % (branch))
master_ref.set_target(remote_master_id)
except KeyError:
repo.create_branch(branch, repo.get(remote_master_id))
repo.head.set_target(remote_master_id)
elif merge_result & pygit2.GIT_MERGE_ANALYSIS_NORMAL:
repo.merge(remote_master_id)
if repo.index.conflicts is not None:
for conflict in repo.index.conflicts:
print('Conflicts found in:', conflict[0].path)
raise AssertionError('Conflicts, ahhhhh!!')
user = repo.default_signature
tree = repo.index.write_tree()
commit = repo.create_commit('HEAD',
user,
user,
'Merge!',
tree,
[repo.head.target, remote_master_id])
# We need to do this or git CLI will think we are still merging.
repo.state_cleanup()
else:
raise AssertionError('Unknown merge analysis result')
repo = pygit2.Repository(str(sys.argv[1]))
ident = pygit2.Signature('comfyui', 'comfy@ui')
try:
print("stashing current changes")
repo.stash(ident)
except KeyError:
print("nothing to stash")
backup_branch_name = 'backup_branch_{}'.format(datetime.today().strftime('%Y-%m-%d_%H_%M_%S'))
print("creating backup branch: {}".format(backup_branch_name))
repo.branches.local.create(backup_branch_name, repo.head.peel())
print("checking out master branch")
branch = repo.lookup_branch('master')
ref = repo.lookup_reference(branch.name)
repo.checkout(ref)
print("pulling latest changes")
pull(repo)
print("Done!")

View File

@ -1,2 +0,0 @@
..\python_embeded\python.exe .\update.py ..\ComfyUI\
pause

View File

@ -1,3 +1,3 @@
..\python_embeded\python.exe .\update.py ..\ComfyUI\
..\python_embeded\python.exe -s -m pip install --upgrade --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu118 -r ../ComfyUI/requirements.txt pygit2
..\python_embeded\python.exe -s -m pip install --upgrade --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cu121 -r ../ComfyUI/requirements.txt pygit2
pause

View File

@ -1,27 +0,0 @@
HOW TO RUN:
if you have a NVIDIA gpu:
run_nvidia_gpu.bat
To run it in slow CPU mode:
run_cpu.bat
IF YOU GET A RED ERROR IN THE UI MAKE SURE YOU HAVE A MODEL/CHECKPOINT IN: ComfyUI\models\checkpoints
You can download the stable diffusion 1.5 one from: https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt
RECOMMENDED WAY TO UPDATE:
To update the ComfyUI code: update\update_comfyui.bat
To update ComfyUI with the python dependencies:
update\update_comfyui_and_python_dependencies.bat

View File

@ -1,2 +0,0 @@
.\python_embeded\python.exe -s ComfyUI\main.py --cpu --windows-standalone-build
pause

View File

@ -0,0 +1,30 @@
name: "Windows Release cu118 dependencies 2"
on:
workflow_dispatch:
# push:
# branches:
# - master
jobs:
build_dependencies:
runs-on: windows-latest
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
with:
python-version: '3.10.9'
- shell: bash
run: |
python -m pip wheel --no-cache-dir torch torchvision torchaudio xformers --extra-index-url https://download.pytorch.org/whl/cu118 -r requirements.txt pygit2 -w ./temp_wheel_dir
python -m pip install --no-cache-dir ./temp_wheel_dir/*
echo installed basic
ls -lah temp_wheel_dir
mv temp_wheel_dir cu118_python_deps
tar cf cu118_python_deps.tar cu118_python_deps
- uses: actions/cache/save@v3
with:
path: cu118_python_deps.tar
key: ${{ runner.os }}-build-cu118

View File

@ -19,21 +19,21 @@ jobs:
fetch-depth: 0
- uses: actions/setup-python@v4
with:
python-version: '3.10.9'
python-version: '3.11.3'
- shell: bash
run: |
cd ..
cp -r ComfyUI ComfyUI_copy
curl https://www.python.org/ftp/python/3.10.9/python-3.10.9-embed-amd64.zip -o python_embeded.zip
curl https://www.python.org/ftp/python/3.11.3/python-3.11.3-embed-amd64.zip -o python_embeded.zip
unzip python_embeded.zip -d python_embeded
cd python_embeded
echo 'import site' >> ./python310._pth
echo 'import site' >> ./python311._pth
curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
./python.exe get-pip.py
python -m pip wheel torch torchvision torchaudio --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu118 -r ../ComfyUI/requirements.txt pygit2 -w ../temp_wheel_dir
python -m pip wheel torch torchvision torchaudio --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu121 -r ../ComfyUI/requirements.txt pygit2 -w ../temp_wheel_dir
ls ../temp_wheel_dir
./python.exe -s -m pip install --pre ../temp_wheel_dir/*
sed -i '1i../ComfyUI' ./python310._pth
sed -i '1i../ComfyUI' ./python311._pth
cd ..
@ -46,6 +46,8 @@ jobs:
mkdir update
cp -r ComfyUI/.ci/update_windows/* ./update/
cp -r ComfyUI/.ci/windows_base_files/* ./
cp -r ComfyUI/.ci/nightly/update_windows/* ./update/
cp -r ComfyUI/.ci/nightly/windows_base_files/* ./
cd ..

View File

@ -7,6 +7,8 @@ A powerful and modular stable diffusion GUI and backend.
This ui will let you design and execute advanced stable diffusion pipelines using a graph/nodes/flowchart based interface. For some workflow examples and see what ComfyUI can do you can check out:
### [ComfyUI Examples](https://comfyanonymous.github.io/ComfyUI_examples/)
### [Installing ComfyUI](#installing)
## Features
- Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything.
- Fully supports SD1.x and SD2.x
@ -17,6 +19,7 @@ This ui will let you design and execute advanced stable diffusion pipelines usin
- Can load ckpt, safetensors and diffusers models/checkpoints. Standalone VAEs and CLIP models.
- Embeddings/Textual inversion
- [Loras (regular, locon and loha)](https://comfyanonymous.github.io/ComfyUI_examples/lora/)
- [Hypernetworks](https://comfyanonymous.github.io/ComfyUI_examples/hypernetworks/)
- Loading full workflows (with seeds) from generated PNG files.
- Saving/Loading workflows as Json files.
- Nodes interface can be used to create complex workflows like one for [Hires fix](https://comfyanonymous.github.io/ComfyUI_examples/2_pass_txt2img/) or much more advanced ones.
@ -25,6 +28,7 @@ This ui will let you design and execute advanced stable diffusion pipelines usin
- [ControlNet and T2I-Adapter](https://comfyanonymous.github.io/ComfyUI_examples/controlnet/)
- [Upscale Models (ESRGAN, ESRGAN variants, SwinIR, Swin2SR, etc...)](https://comfyanonymous.github.io/ComfyUI_examples/upscale_models/)
- [unCLIP Models](https://comfyanonymous.github.io/ComfyUI_examples/unclip/)
- [GLIGEN](https://comfyanonymous.github.io/ComfyUI_examples/gligen/)
- Starts up very fast.
- Works fully offline: will never download anything.
- [Config file](extra_model_paths.yaml.example) to set the search paths for models.
@ -32,14 +36,29 @@ This ui will let you design and execute advanced stable diffusion pipelines usin
Workflow examples can be found on the [Examples page](https://comfyanonymous.github.io/ComfyUI_examples/)
## Shortcuts
- **Ctrl + A** select all nodes
- **Ctrl + M** mute/unmute selected nodes
- **Delete** or **Backspace** delete selected nodes
- **Space** Holding space key while moving the cursor moves the canvas around. It works when holding the mouse button down so it is easier to connect different nodes when the canvas gets too large.
- **Ctrl/Shift + Click** Add clicked node to selection.
- **Ctrl + C/Ctrl + V** - Copy and paste selected nodes, without maintaining the connection to the outputs of unselected nodes.
- **Ctrl + C/Ctrl + Shift + V** - Copy and paste selected nodes, and maintaining the connection from the outputs of unselected nodes to the inputs of the newly pasted nodes.
- Holding **Shift** and drag selected nodes - Move multiple selected nodes at the same time.
| Keybind | Explanation |
| - | - |
| Ctrl + Enter | Queue up current graph for generation |
| Ctrl + Shift + Enter | Queue up current graph as first for generation |
| Ctrl + S | Save workflow |
| Ctrl + O | Load workflow |
| Ctrl + A | Select all nodes |
| Ctrl + M | Mute/unmute selected nodes |
| Delete/Backspace | Delete selected nodes |
| Ctrl + Delete/Backspace | Delete the current graph |
| Space | Move the canvas around when held and moving the cursor |
| Ctrl/Shift + Click | Add clicked node to selection |
| Ctrl + C/Ctrl + V | Copy and paste selected nodes (without maintaining connections to outputs of unselected nodes) |
| Ctrl + C/Ctrl + Shift + V| Copy and paste selected nodes (maintaining connections from outputs of unselected nodes to inputs of pasted nodes) |
| Shift + Drag | Move multiple selected nodes at the same time |
| Ctrl + D | Load default graph |
| Q | Toggle visibility of the queue |
| H | Toggle visibility of history |
| R | Refresh graph |
| Double-Click LMB | Open node quick search palette |
Ctrl can also be replaced with Cmd instead for MacOS users
# Installing
@ -69,7 +88,7 @@ Put your VAE in: models/vae
At the time of writing this pytorch has issues with python versions higher than 3.10 so make sure your python/pip versions are 3.10.
### AMD (Linux only)
### AMD GPUs (Linux only)
AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version:
```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/rocm5.4.2```

View File

@ -5,17 +5,17 @@ import torch
import torch as th
import torch.nn as nn
from ldm.modules.diffusionmodules.util import (
from ..ldm.modules.diffusionmodules.util import (
conv_nd,
linear,
zero_module,
timestep_embedding,
)
from ldm.modules.attention import SpatialTransformer
from ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample, AttentionBlock
from ldm.models.diffusion.ddpm import LatentDiffusion
from ldm.util import log_txt_as_img, exists, instantiate_from_config
from ..ldm.modules.attention import SpatialTransformer
from ..ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample, AttentionBlock
from ..ldm.models.diffusion.ddpm import LatentDiffusion
from ..ldm.util import log_txt_as_img, exists, instantiate_from_config
class ControlledUnetModel(UNetModel):

View File

@ -7,9 +7,11 @@ parser.add_argument("--port", type=int, default=8188, help="Set the listen port.
parser.add_argument("--enable-cors-header", type=str, default=None, metavar="ORIGIN", nargs="?", const="*", help="Enable CORS (Cross-Origin Resource Sharing) with optional origin or allow all with default '*'.")
parser.add_argument("--extra-model-paths-config", type=str, default=None, metavar="PATH", nargs='+', action='append', help="Load one or more extra_model_paths.yaml files.")
parser.add_argument("--output-directory", type=str, default=None, help="Set the ComfyUI output directory.")
parser.add_argument("--auto-launch", action="store_true", help="Automatically launch ComfyUI in the default browser.")
parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use.")
parser.add_argument("--dont-upcast-attention", action="store_true", help="Disable upcasting of attention. Can boost speed but increase the chances of black images.")
parser.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).")
parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")
attn_group = parser.add_mutually_exclusive_group()
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization instead of the sub-quadratic one. Ignored when xformers is used.")
@ -29,3 +31,6 @@ parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test
parser.add_argument("--windows-standalone-build", action="store_true", help="Windows standalone build: Enable convenient things that most people using the standalone windows build will probably enjoy (like auto opening the page on startup).")
args = parser.parse_args()
if args.windows_standalone_build:
args.auto_launch = True

View File

@ -712,7 +712,7 @@ class UniPC:
def sample(self, x, timesteps, t_start=None, t_end=None, order=3, skip_type='time_uniform',
method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver',
atol=0.0078, rtol=0.05, corrector=False,
atol=0.0078, rtol=0.05, corrector=False, callback=None, disable_pbar=False
):
t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
t_T = self.noise_schedule.T if t_start is None else t_start
@ -723,7 +723,7 @@ class UniPC:
# timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
assert timesteps.shape[0] - 1 == steps
# with torch.no_grad():
for step_index in trange(steps):
for step_index in trange(steps, disable=disable_pbar):
if self.noise_mask is not None:
x = x * self.noise_mask + (1. - self.noise_mask) * (self.masked_image * self.noise_schedule.marginal_alpha(timesteps[step_index]) + self.noise * self.noise_schedule.marginal_std(timesteps[step_index]))
if step_index == 0:
@ -766,6 +766,8 @@ class UniPC:
if model_x is None:
model_x = self.model_fn(x, vec_t)
model_prev_list[-1] = model_x
if callback is not None:
callback(step_index, model_prev_list[-1], x, steps)
else:
raise NotImplementedError()
if denoise_to_zero:
@ -833,7 +835,7 @@ def expand_dims(v, dims):
def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, extra_args=None, callback=None, disable=None, noise_mask=None, variant='bh1'):
def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, extra_args=None, callback=None, disable=False, noise_mask=None, variant='bh1'):
to_zero = False
if sigmas[-1] == 0:
timesteps = torch.nn.functional.interpolate(sigmas[None,None,:-1], size=(len(sigmas),), mode='linear')[0][0]
@ -877,7 +879,7 @@ def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, ex
order = min(3, len(timesteps) - 1)
uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, noise_mask=noise_mask, masked_image=image, noise=noise, variant=variant)
x = uni_pc.sample(img, timesteps=timesteps, skip_type="time_uniform", method="multistep", order=order, lower_order_final=True)
x = uni_pc.sample(img, timesteps=timesteps, skip_type="time_uniform", method="multistep", order=order, lower_order_final=True, callback=callback, disable_pbar=disable)
if not to_zero:
x /= ns.marginal_alpha(timesteps[-1])
return x

360
comfy/gligen.py Normal file
View File

@ -0,0 +1,360 @@
import torch
from torch import nn, einsum
from .ldm.modules.attention import CrossAttention
from inspect import isfunction
def exists(val):
return val is not None
def uniq(arr):
return{el: True for el in arr}.keys()
def default(val, d):
if exists(val):
return val
return d() if isfunction(d) else d
# feedforward
class GEGLU(nn.Module):
def __init__(self, dim_in, dim_out):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out * 2)
def forward(self, x):
x, gate = self.proj(x).chunk(2, dim=-1)
return x * torch.nn.functional.gelu(gate)
class FeedForward(nn.Module):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
super().__init__()
inner_dim = int(dim * mult)
dim_out = default(dim_out, dim)
project_in = nn.Sequential(
nn.Linear(dim, inner_dim),
nn.GELU()
) if not glu else GEGLU(dim, inner_dim)
self.net = nn.Sequential(
project_in,
nn.Dropout(dropout),
nn.Linear(inner_dim, dim_out)
)
def forward(self, x):
return self.net(x)
class GatedCrossAttentionDense(nn.Module):
def __init__(self, query_dim, context_dim, n_heads, d_head):
super().__init__()
self.attn = CrossAttention(
query_dim=query_dim,
context_dim=context_dim,
heads=n_heads,
dim_head=d_head)
self.ff = FeedForward(query_dim, glu=True)
self.norm1 = nn.LayerNorm(query_dim)
self.norm2 = nn.LayerNorm(query_dim)
self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.)))
self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.)))
# this can be useful: we can externally change magnitude of tanh(alpha)
# for example, when it is set to 0, then the entire model is same as
# original one
self.scale = 1
def forward(self, x, objs):
x = x + self.scale * \
torch.tanh(self.alpha_attn) * self.attn(self.norm1(x), objs, objs)
x = x + self.scale * \
torch.tanh(self.alpha_dense) * self.ff(self.norm2(x))
return x
class GatedSelfAttentionDense(nn.Module):
def __init__(self, query_dim, context_dim, n_heads, d_head):
super().__init__()
# we need a linear projection since we need cat visual feature and obj
# feature
self.linear = nn.Linear(context_dim, query_dim)
self.attn = CrossAttention(
query_dim=query_dim,
context_dim=query_dim,
heads=n_heads,
dim_head=d_head)
self.ff = FeedForward(query_dim, glu=True)
self.norm1 = nn.LayerNorm(query_dim)
self.norm2 = nn.LayerNorm(query_dim)
self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.)))
self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.)))
# this can be useful: we can externally change magnitude of tanh(alpha)
# for example, when it is set to 0, then the entire model is same as
# original one
self.scale = 1
def forward(self, x, objs):
N_visual = x.shape[1]
objs = self.linear(objs)
x = x + self.scale * torch.tanh(self.alpha_attn) * self.attn(
self.norm1(torch.cat([x, objs], dim=1)))[:, 0:N_visual, :]
x = x + self.scale * \
torch.tanh(self.alpha_dense) * self.ff(self.norm2(x))
return x
class GatedSelfAttentionDense2(nn.Module):
def __init__(self, query_dim, context_dim, n_heads, d_head):
super().__init__()
# we need a linear projection since we need cat visual feature and obj
# feature
self.linear = nn.Linear(context_dim, query_dim)
self.attn = CrossAttention(
query_dim=query_dim, context_dim=query_dim, dim_head=d_head)
self.ff = FeedForward(query_dim, glu=True)
self.norm1 = nn.LayerNorm(query_dim)
self.norm2 = nn.LayerNorm(query_dim)
self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.)))
self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.)))
# this can be useful: we can externally change magnitude of tanh(alpha)
# for example, when it is set to 0, then the entire model is same as
# original one
self.scale = 1
def forward(self, x, objs):
B, N_visual, _ = x.shape
B, N_ground, _ = objs.shape
objs = self.linear(objs)
# sanity check
size_v = math.sqrt(N_visual)
size_g = math.sqrt(N_ground)
assert int(size_v) == size_v, "Visual tokens must be square rootable"
assert int(size_g) == size_g, "Grounding tokens must be square rootable"
size_v = int(size_v)
size_g = int(size_g)
# select grounding token and resize it to visual token size as residual
out = self.attn(self.norm1(torch.cat([x, objs], dim=1)))[
:, N_visual:, :]
out = out.permute(0, 2, 1).reshape(B, -1, size_g, size_g)
out = torch.nn.functional.interpolate(
out, (size_v, size_v), mode='bicubic')
residual = out.reshape(B, -1, N_visual).permute(0, 2, 1)
# add residual to visual feature
x = x + self.scale * torch.tanh(self.alpha_attn) * residual
x = x + self.scale * \
torch.tanh(self.alpha_dense) * self.ff(self.norm2(x))
return x
class FourierEmbedder():
def __init__(self, num_freqs=64, temperature=100):
self.num_freqs = num_freqs
self.temperature = temperature
self.freq_bands = temperature ** (torch.arange(num_freqs) / num_freqs)
@torch.no_grad()
def __call__(self, x, cat_dim=-1):
"x: arbitrary shape of tensor. dim: cat dim"
out = []
for freq in self.freq_bands:
out.append(torch.sin(freq * x))
out.append(torch.cos(freq * x))
return torch.cat(out, cat_dim)
class PositionNet(nn.Module):
def __init__(self, in_dim, out_dim, fourier_freqs=8):
super().__init__()
self.in_dim = in_dim
self.out_dim = out_dim
self.fourier_embedder = FourierEmbedder(num_freqs=fourier_freqs)
self.position_dim = fourier_freqs * 2 * 4 # 2 is sin&cos, 4 is xyxy
self.linears = nn.Sequential(
nn.Linear(self.in_dim + self.position_dim, 512),
nn.SiLU(),
nn.Linear(512, 512),
nn.SiLU(),
nn.Linear(512, out_dim),
)
self.null_positive_feature = torch.nn.Parameter(
torch.zeros([self.in_dim]))
self.null_position_feature = torch.nn.Parameter(
torch.zeros([self.position_dim]))
def forward(self, boxes, masks, positive_embeddings):
B, N, _ = boxes.shape
masks = masks.unsqueeze(-1)
# embedding position (it may includes padding as placeholder)
xyxy_embedding = self.fourier_embedder(boxes) # B*N*4 --> B*N*C
# learnable null embedding
positive_null = self.null_positive_feature.view(1, 1, -1)
xyxy_null = self.null_position_feature.view(1, 1, -1)
# replace padding with learnable null embedding
positive_embeddings = positive_embeddings * \
masks + (1 - masks) * positive_null
xyxy_embedding = xyxy_embedding * masks + (1 - masks) * xyxy_null
objs = self.linears(
torch.cat([positive_embeddings, xyxy_embedding], dim=-1))
assert objs.shape == torch.Size([B, N, self.out_dim])
return objs
class Gligen(nn.Module):
def __init__(self, modules, position_net, key_dim):
super().__init__()
self.module_list = nn.ModuleList(modules)
self.position_net = position_net
self.key_dim = key_dim
self.max_objs = 30
self.lowvram = False
def _set_position(self, boxes, masks, positive_embeddings):
if self.lowvram == True:
self.position_net.to(boxes.device)
objs = self.position_net(boxes, masks, positive_embeddings)
if self.lowvram == True:
self.position_net.cpu()
def func_lowvram(key, x):
module = self.module_list[key]
module.to(x.device)
r = module(x, objs)
module.cpu()
return r
return func_lowvram
else:
def func(key, x):
module = self.module_list[key]
return module(x, objs)
return func
def set_position(self, latent_image_shape, position_params, device):
batch, c, h, w = latent_image_shape
masks = torch.zeros([self.max_objs], device="cpu")
boxes = []
positive_embeddings = []
for p in position_params:
x1 = (p[4]) / w
y1 = (p[3]) / h
x2 = (p[4] + p[2]) / w
y2 = (p[3] + p[1]) / h
masks[len(boxes)] = 1.0
boxes += [torch.tensor((x1, y1, x2, y2)).unsqueeze(0)]
positive_embeddings += [p[0]]
append_boxes = []
append_conds = []
if len(boxes) < self.max_objs:
append_boxes = [torch.zeros(
[self.max_objs - len(boxes), 4], device="cpu")]
append_conds = [torch.zeros(
[self.max_objs - len(boxes), self.key_dim], device="cpu")]
box_out = torch.cat(
boxes + append_boxes).unsqueeze(0).repeat(batch, 1, 1)
masks = masks.unsqueeze(0).repeat(batch, 1)
conds = torch.cat(positive_embeddings +
append_conds).unsqueeze(0).repeat(batch, 1, 1)
return self._set_position(
box_out.to(device),
masks.to(device),
conds.to(device))
def set_empty(self, latent_image_shape, device):
batch, c, h, w = latent_image_shape
masks = torch.zeros([self.max_objs], device="cpu").repeat(batch, 1)
box_out = torch.zeros([self.max_objs, 4],
device="cpu").repeat(batch, 1, 1)
conds = torch.zeros([self.max_objs, self.key_dim],
device="cpu").repeat(batch, 1, 1)
return self._set_position(
box_out.to(device),
masks.to(device),
conds.to(device))
def set_lowvram(self, value=True):
self.lowvram = value
def cleanup(self):
self.lowvram = False
def get_models(self):
return [self]
def load_gligen(sd):
sd_k = sd.keys()
output_list = []
key_dim = 768
for a in ["input_blocks", "middle_block", "output_blocks"]:
for b in range(20):
k_temp = filter(lambda k: "{}.{}.".format(a, b)
in k and ".fuser." in k, sd_k)
k_temp = map(lambda k: (k, k.split(".fuser.")[-1]), k_temp)
n_sd = {}
for k in k_temp:
n_sd[k[1]] = sd[k[0]]
if len(n_sd) > 0:
query_dim = n_sd["linear.weight"].shape[0]
key_dim = n_sd["linear.weight"].shape[1]
if key_dim == 768: # SD1.x
n_heads = 8
d_head = query_dim // n_heads
else:
d_head = 64
n_heads = query_dim // d_head
gated = GatedSelfAttentionDense(
query_dim, key_dim, n_heads, d_head)
gated.load_state_dict(n_sd, strict=False)
output_list.append(gated)
if "position_net.null_positive_feature" in sd_k:
in_dim = sd["position_net.null_positive_feature"].shape[0]
out_dim = sd["position_net.linears.4.weight"].shape[0]
class WeightsLoader(torch.nn.Module):
pass
w = WeightsLoader()
w.position_net = PositionNet(in_dim, out_dim)
w.load_state_dict(sd, strict=False)
gligen = Gligen(output_list, w.position_net, key_dim)
return gligen

View File

@ -3,11 +3,11 @@ import torch
import torch.nn.functional as F
from contextlib import contextmanager
from ldm.modules.diffusionmodules.model import Encoder, Decoder
from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
from comfy.ldm.modules.diffusionmodules.model import Encoder, Decoder
from comfy.ldm.modules.distributions.distributions import DiagonalGaussianDistribution
from ldm.util import instantiate_from_config
from ldm.modules.ema import LitEma
from comfy.ldm.util import instantiate_from_config
from comfy.ldm.modules.ema import LitEma
# class AutoencoderKL(pl.LightningModule):
class AutoencoderKL(torch.nn.Module):

View File

@ -4,7 +4,7 @@ import torch
import numpy as np
from tqdm import tqdm
from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor
from comfy.ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor
class DDIMSampler(object):
@ -81,6 +81,7 @@ class DDIMSampler(object):
extra_args=None,
to_zero=True,
end_step=None,
disable_pbar=False,
**kwargs
):
self.make_schedule_timesteps(ddim_timesteps=ddim_timesteps, ddim_eta=eta, verbose=verbose)
@ -103,7 +104,8 @@ class DDIMSampler(object):
denoise_function=denoise_function,
extra_args=extra_args,
to_zero=to_zero,
end_step=end_step
end_step=end_step,
disable_pbar=disable_pbar
)
return samples, intermediates
@ -185,7 +187,7 @@ class DDIMSampler(object):
mask=None, x0=None, img_callback=None, log_every_t=100,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None,
ucg_schedule=None, denoise_function=None, extra_args=None, to_zero=True, end_step=None):
ucg_schedule=None, denoise_function=None, extra_args=None, to_zero=True, end_step=None, disable_pbar=False):
device = self.model.betas.device
b = shape[0]
if x_T is None:
@ -204,7 +206,7 @@ class DDIMSampler(object):
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
# print(f"Running DDIM Sampling with {total_steps} timesteps")
iterator = tqdm(time_range[:end_step], desc='DDIM Sampler', total=end_step)
iterator = tqdm(time_range[:end_step], desc='DDIM Sampler', total=end_step, disable=disable_pbar)
for i, step in enumerate(iterator):
index = total_steps - i - 1

View File

@ -19,12 +19,12 @@ from tqdm import tqdm
from torchvision.utils import make_grid
# from pytorch_lightning.utilities.distributed import rank_zero_only
from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
from ldm.modules.ema import LitEma
from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
from ldm.models.autoencoder import IdentityFirstStage, AutoencoderKL
from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
from ldm.models.diffusion.ddim import DDIMSampler
from comfy.ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
from comfy.ldm.modules.ema import LitEma
from comfy.ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
from ..autoencoder import IdentityFirstStage, AutoencoderKL
from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
from .ddim import DDIMSampler
__conditioning_keys__ = {'concat': 'c_concat',

View File

@ -6,10 +6,10 @@ from torch import nn, einsum
from einops import rearrange, repeat
from typing import Optional, Any
from ldm.modules.diffusionmodules.util import checkpoint
from .diffusionmodules.util import checkpoint
from .sub_quadratic_attention import efficient_dot_product_attention
import model_management
from comfy import model_management
from . import tomesd
@ -21,7 +21,7 @@ if model_management.xformers_enabled():
import os
_ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")
from cli_args import args
from comfy.cli_args import args
def exists(val):
return val is not None
@ -163,13 +163,17 @@ class CrossAttentionBirchSan(nn.Module):
nn.Dropout(dropout)
)
def forward(self, x, context=None, mask=None):
def forward(self, x, context=None, value=None, mask=None):
h = self.heads
query = self.to_q(x)
context = default(context, x)
key = self.to_k(context)
value = self.to_v(context)
if value is not None:
value = self.to_v(value)
else:
value = self.to_v(context)
del context, x
query = query.unflatten(-1, (self.heads, -1)).transpose(1,2).flatten(end_dim=1)
@ -256,13 +260,17 @@ class CrossAttentionDoggettx(nn.Module):
nn.Dropout(dropout)
)
def forward(self, x, context=None, mask=None):
def forward(self, x, context=None, value=None, mask=None):
h = self.heads
q_in = self.to_q(x)
context = default(context, x)
k_in = self.to_k(context)
v_in = self.to_v(context)
if value is not None:
v_in = self.to_v(value)
del value
else:
v_in = self.to_v(context)
del context, x
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
@ -350,13 +358,17 @@ class CrossAttention(nn.Module):
nn.Dropout(dropout)
)
def forward(self, x, context=None, mask=None):
def forward(self, x, context=None, value=None, mask=None):
h = self.heads
q = self.to_q(x)
context = default(context, x)
k = self.to_k(context)
v = self.to_v(context)
if value is not None:
v = self.to_v(value)
del value
else:
v = self.to_v(context)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
@ -402,11 +414,15 @@ class MemoryEfficientCrossAttention(nn.Module):
self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
self.attention_op: Optional[Any] = None
def forward(self, x, context=None, mask=None):
def forward(self, x, context=None, value=None, mask=None):
q = self.to_q(x)
context = default(context, x)
k = self.to_k(context)
v = self.to_v(context)
if value is not None:
v = self.to_v(value)
del value
else:
v = self.to_v(context)
b, _, _ = q.shape
q, k, v = map(
@ -447,19 +463,19 @@ class CrossAttentionPytorch(nn.Module):
self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
self.attention_op: Optional[Any] = None
def forward(self, x, context=None, mask=None):
def forward(self, x, context=None, value=None, mask=None):
q = self.to_q(x)
context = default(context, x)
k = self.to_k(context)
v = self.to_v(context)
if value is not None:
v = self.to_v(value)
del value
else:
v = self.to_v(context)
b, _, _ = q.shape
q, k, v = map(
lambda t: t.unsqueeze(3)
.reshape(b, t.shape[1], self.heads, self.dim_head)
.permute(0, 2, 1, 3)
.reshape(b * self.heads, t.shape[1], self.dim_head)
.contiguous(),
lambda t: t.view(b, -1, self.heads, self.dim_head).transpose(1, 2),
(q, k, v),
)
@ -468,10 +484,7 @@ class CrossAttentionPytorch(nn.Module):
if exists(mask):
raise NotImplementedError
out = (
out.unsqueeze(0)
.reshape(b, self.heads, out.shape[1], self.dim_head)
.permute(0, 2, 1, 3)
.reshape(b, out.shape[1], self.heads * self.dim_head)
out.transpose(1, 2).reshape(b, -1, self.heads * self.dim_head)
)
return self.to_out(out)
@ -510,16 +523,52 @@ class BasicTransformerBlock(nn.Module):
return checkpoint(self._forward, (x, context, transformer_options), self.parameters(), self.checkpoint)
def _forward(self, x, context=None, transformer_options={}):
current_index = None
if "current_index" in transformer_options:
current_index = transformer_options["current_index"]
if "patches" in transformer_options:
transformer_patches = transformer_options["patches"]
else:
transformer_patches = {}
n = self.norm1(x)
if self.disable_self_attn:
context_attn1 = context
else:
context_attn1 = None
value_attn1 = None
if "attn1_patch" in transformer_patches:
patch = transformer_patches["attn1_patch"]
if context_attn1 is None:
context_attn1 = n
value_attn1 = context_attn1
for p in patch:
n, context_attn1, value_attn1 = p(current_index, n, context_attn1, value_attn1)
if "tomesd" in transformer_options:
m, u = tomesd.get_functions(x, transformer_options["tomesd"]["ratio"], transformer_options["original_shape"])
n = u(self.attn1(m(n), context=context if self.disable_self_attn else None))
n = u(self.attn1(m(n), context=context_attn1, value=value_attn1))
else:
n = self.attn1(n, context=context if self.disable_self_attn else None)
n = self.attn1(n, context=context_attn1, value=value_attn1)
x += n
if "middle_patch" in transformer_patches:
patch = transformer_patches["middle_patch"]
for p in patch:
x = p(current_index, x)
n = self.norm2(x)
n = self.attn2(n, context=context)
context_attn2 = context
value_attn2 = None
if "attn2_patch" in transformer_patches:
patch = transformer_patches["attn2_patch"]
value_attn2 = context_attn2
for p in patch:
n, context_attn2, value_attn2 = p(current_index, n, context_attn2, value_attn2)
n = self.attn2(n, context=context_attn2, value=value_attn2)
x += n
x = self.ff(self.norm3(x)) + x

View File

@ -6,8 +6,8 @@ import numpy as np
from einops import rearrange
from typing import Optional, Any
from ldm.modules.attention import MemoryEfficientCrossAttention
import model_management
from ..attention import MemoryEfficientCrossAttention
from comfy import model_management
if model_management.xformers_enabled_vae():
import xformers

View File

@ -6,7 +6,7 @@ import torch as th
import torch.nn as nn
import torch.nn.functional as F
from ldm.modules.diffusionmodules.util import (
from .util import (
checkpoint,
conv_nd,
linear,
@ -15,8 +15,8 @@ from ldm.modules.diffusionmodules.util import (
normalization,
timestep_embedding,
)
from ldm.modules.attention import SpatialTransformer
from ldm.util import exists
from ..attention import SpatialTransformer
from comfy.ldm.util import exists
# dummy replace
@ -76,16 +76,31 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
support it as an extra input.
"""
def forward(self, x, emb, context=None, transformer_options={}):
def forward(self, x, emb, context=None, transformer_options={}, output_shape=None):
for layer in self:
if isinstance(layer, TimestepBlock):
x = layer(x, emb)
elif isinstance(layer, SpatialTransformer):
x = layer(x, context, transformer_options)
elif isinstance(layer, Upsample):
x = layer(x, output_shape=output_shape)
else:
x = layer(x)
return x
#This is needed because accelerate makes a copy of transformer_options which breaks "current_index"
def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, output_shape=None):
for layer in ts:
if isinstance(layer, TimestepBlock):
x = layer(x, emb)
elif isinstance(layer, SpatialTransformer):
x = layer(x, context, transformer_options)
transformer_options["current_index"] += 1
elif isinstance(layer, Upsample):
x = layer(x, output_shape=output_shape)
else:
x = layer(x)
return x
class Upsample(nn.Module):
"""
@ -105,14 +120,20 @@ class Upsample(nn.Module):
if use_conv:
self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)
def forward(self, x):
def forward(self, x, output_shape=None):
assert x.shape[1] == self.channels
if self.dims == 3:
x = F.interpolate(
x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
)
shape = [x.shape[2], x.shape[3] * 2, x.shape[4] * 2]
if output_shape is not None:
shape[1] = output_shape[3]
shape[2] = output_shape[4]
else:
x = F.interpolate(x, scale_factor=2, mode="nearest")
shape = [x.shape[2] * 2, x.shape[3] * 2]
if output_shape is not None:
shape[0] = output_shape[2]
shape[1] = output_shape[3]
x = F.interpolate(x, size=shape, mode="nearest")
if self.use_conv:
x = self.conv(x)
return x
@ -782,6 +803,8 @@ class UNetModel(nn.Module):
:return: an [N x C x ...] Tensor of outputs.
"""
transformer_options["original_shape"] = list(x.shape)
transformer_options["current_index"] = 0
assert (y is not None) == (
self.num_classes is not None
), "must specify y if and only if the model is class-conditional"
@ -795,13 +818,13 @@ class UNetModel(nn.Module):
h = x.type(self.dtype)
for id, module in enumerate(self.input_blocks):
h = module(h, emb, context, transformer_options)
h = forward_timestep_embed(module, h, emb, context, transformer_options)
if control is not None and 'input' in control and len(control['input']) > 0:
ctrl = control['input'].pop()
if ctrl is not None:
h += ctrl
hs.append(h)
h = self.middle_block(h, emb, context, transformer_options)
h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options)
if control is not None and 'middle' in control and len(control['middle']) > 0:
h += control['middle'].pop()
@ -811,9 +834,14 @@ class UNetModel(nn.Module):
ctrl = control['output'].pop()
if ctrl is not None:
hsp += ctrl
h = th.cat([h, hsp], dim=1)
del hsp
h = module(h, emb, context, transformer_options)
if len(hs) > 0:
output_shape = hs[-1].shape
else:
output_shape = None
h = forward_timestep_embed(module, h, emb, context, transformer_options, output_shape)
h = h.type(x.dtype)
if self.predict_codebook_ids:
return self.id_predictor(h)

View File

@ -3,8 +3,8 @@ import torch.nn as nn
import numpy as np
from functools import partial
from ldm.modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule
from ldm.util import default
from .util import extract_into_tensor, make_beta_schedule
from comfy.ldm.util import default
class AbstractLowScaleModel(nn.Module):

View File

@ -15,7 +15,7 @@ import torch.nn as nn
import numpy as np
from einops import repeat
from ldm.util import instantiate_from_config
from comfy.ldm.util import instantiate_from_config
def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):

View File

@ -1,5 +1,5 @@
from ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
from ldm.modules.diffusionmodules.openaimodel import Timestep
from ..diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
from ..diffusionmodules.openaimodel import Timestep
import torch
class CLIPEmbeddingNoiseAugmentation(ImageConcatWithNoiseAugmentation):

View File

@ -24,7 +24,7 @@ except ImportError:
from torch import Tensor
from typing import List
import model_management
from comfy import model_management
def dynamic_slice(
x: Tensor,

View File

@ -36,7 +36,7 @@ def bipartite_soft_matching_random2d(metric: torch.Tensor,
"""
B, N, _ = metric.shape
if r <= 0:
if r <= 0 or w == 1 or h == 1:
return do_nothing, do_nothing
gather = mps_gather_workaround if metric.device.type == "mps" else torch.gather

View File

@ -1,6 +1,6 @@
import psutil
from enum import Enum
from cli_args import args
from comfy.cli_args import args
class VRAMState(Enum):
CPU = 0
@ -20,15 +20,30 @@ total_vram_available_mb = -1
accelerate_enabled = False
xpu_available = False
directml_enabled = False
if args.directml is not None:
import torch_directml
directml_enabled = True
device_index = args.directml
if device_index < 0:
directml_device = torch_directml.device()
else:
directml_device = torch_directml.device(device_index)
print("Using directml with device:", torch_directml.device_name(device_index))
# torch_directml.disable_tiled_resources(True)
try:
import torch
try:
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
xpu_available = True
total_vram = torch.xpu.get_device_properties(torch.xpu.current_device()).total_memory / (1024 * 1024)
except:
total_vram = torch.cuda.mem_get_info(torch.cuda.current_device())[1] / (1024 * 1024)
if directml_enabled:
total_vram = 4097 #TODO
else:
try:
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
xpu_available = True
total_vram = torch.xpu.get_device_properties(torch.xpu.current_device()).total_memory / (1024 * 1024)
except:
total_vram = torch.cuda.mem_get_info(torch.cuda.current_device())[1] / (1024 * 1024)
total_ram = psutil.virtual_memory().total / (1024 * 1024)
if not args.normalvram and not args.cpu:
if total_vram <= 4096:
@ -112,6 +127,32 @@ if args.cpu:
print(f"Set vram state to: {vram_state.name}")
def get_torch_device():
global xpu_available
global directml_enabled
if directml_enabled:
global directml_device
return directml_device
if vram_state == VRAMState.MPS:
return torch.device("mps")
if vram_state == VRAMState.CPU:
return torch.device("cpu")
else:
if xpu_available:
return torch.device("xpu")
else:
return torch.cuda.current_device()
def get_torch_device_name(device):
if hasattr(device, 'type'):
return "{}".format(device.type)
return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device))
try:
print("Using device:", get_torch_device_name(get_torch_device()))
except:
print("Could not pick default device.")
current_loaded_model = None
current_gpu_controlnets = []
@ -133,6 +174,7 @@ def unload_model():
#never unload models from GPU on high vram
if vram_state != VRAMState.HIGH_VRAM:
current_loaded_model.model.cpu()
current_loaded_model.model_patches_to("cpu")
current_loaded_model.unpatch_model()
current_loaded_model = None
@ -156,6 +198,8 @@ def load_model_gpu(model):
except Exception as e:
model.unpatch_model()
raise e
model.model_patches_to(get_torch_device())
current_loaded_model = model
if vram_state == VRAMState.CPU:
pass
@ -176,16 +220,23 @@ def load_model_gpu(model):
model_accelerated = True
return current_loaded_model
def load_controlnet_gpu(models):
def load_controlnet_gpu(control_models):
global current_gpu_controlnets
global vram_state
if vram_state == VRAMState.CPU:
return
if vram_state == VRAMState.LOW_VRAM or vram_state == VRAMState.NO_VRAM:
for m in control_models:
if hasattr(m, 'set_lowvram'):
m.set_lowvram(True)
#don't load controlnets like this if low vram because they will be loaded right before running and unloaded right after
return
models = []
for m in control_models:
models += m.get_models()
for m in current_gpu_controlnets:
if m not in models:
m.cpu()
@ -208,18 +259,6 @@ def unload_if_low_vram(model):
return model.cpu()
return model
def get_torch_device():
global xpu_available
if vram_state == VRAMState.MPS:
return torch.device("mps")
if vram_state == VRAMState.CPU:
return torch.device("cpu")
else:
if xpu_available:
return torch.device("xpu")
else:
return torch.cuda.current_device()
def get_autocast_device(dev):
if hasattr(dev, 'type'):
return dev.type
@ -227,8 +266,14 @@ def get_autocast_device(dev):
def xformers_enabled():
global xpu_available
global directml_enabled
if vram_state == VRAMState.CPU:
return False
if xpu_available:
return False
if directml_enabled:
return False
return XFORMERS_IS_AVAILABLE
@ -240,10 +285,20 @@ def xformers_enabled_vae():
return XFORMERS_ENABLED_VAE
def pytorch_attention_enabled():
global ENABLE_PYTORCH_ATTENTION
return ENABLE_PYTORCH_ATTENTION
def pytorch_attention_flash_attention():
global ENABLE_PYTORCH_ATTENTION
if ENABLE_PYTORCH_ATTENTION:
#TODO: more reliable way of checking for flash attention?
if torch.version.cuda: #pytorch flash attention only works on Nvidia
return True
return False
def get_free_memory(dev=None, torch_free_too=False):
global xpu_available
global directml_enabled
if dev is None:
dev = get_torch_device()
@ -251,7 +306,10 @@ def get_free_memory(dev=None, torch_free_too=False):
mem_free_total = psutil.virtual_memory().available
mem_free_torch = mem_free_total
else:
if xpu_available:
if directml_enabled:
mem_free_total = 1024 * 1024 * 1024 #TODO
mem_free_torch = mem_free_total
elif xpu_available:
mem_free_total = torch.xpu.get_device_properties(dev).total_memory - torch.xpu.memory_allocated(dev)
mem_free_torch = mem_free_total
else:
@ -273,7 +331,12 @@ def maximum_batch_area():
return 0
memory_free = get_free_memory() / (1024 * 1024)
area = ((memory_free - 1024) * 0.9) / (0.6)
if xformers_enabled() or pytorch_attention_flash_attention():
#TODO: this needs to be tweaked
area = 20 * memory_free
else:
#TODO: this formula is because AMD sucks and has memory management issues which might be fixed in the future
area = ((memory_free - 1024) * 0.9) / (0.6)
return int(max(area, 0))
def cpu_mode():
@ -286,9 +349,14 @@ def mps_mode():
def should_use_fp16():
global xpu_available
global directml_enabled
if FORCE_FP32:
return False
if directml_enabled:
return False
if cpu_mode() or mps_mode() or xpu_available:
return False #TODO ?
@ -307,6 +375,15 @@ def should_use_fp16():
return True
def soft_empty_cache():
global xpu_available
if xpu_available:
torch.xpu.empty_cache()
elif torch.cuda.is_available():
if torch.version.cuda: #This seems to make things worse on ROCm so I only do it for cuda
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
#TODO: might be cleaner to put this somewhere else
import threading

92
comfy/sample.py Normal file
View File

@ -0,0 +1,92 @@
import torch
import comfy.model_management
import comfy.samplers
import math
import numpy as np
def prepare_noise(latent_image, seed, noise_inds=None):
"""
creates random noise given a latent image and a seed.
optional arg skip can be used to skip and discard x number of noise generations for a given seed
"""
generator = torch.manual_seed(seed)
if noise_inds is None:
return torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu")
unique_inds, inverse = np.unique(noise_inds, return_inverse=True)
noises = []
for i in range(unique_inds[-1]+1):
noise = torch.randn([1] + list(latent_image.size())[1:], dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu")
if i in unique_inds:
noises.append(noise)
noises = [noises[i] for i in inverse]
noises = torch.cat(noises, axis=0)
return noises
def prepare_mask(noise_mask, shape, device):
"""ensures noise mask is of proper dimensions"""
noise_mask = torch.nn.functional.interpolate(noise_mask.reshape((-1, 1, noise_mask.shape[-2], noise_mask.shape[-1])), size=(shape[2], shape[3]), mode="bilinear")
noise_mask = noise_mask.round()
noise_mask = torch.cat([noise_mask] * shape[1], dim=1)
if noise_mask.shape[0] < shape[0]:
noise_mask = noise_mask.repeat(math.ceil(shape[0] / noise_mask.shape[0]), 1, 1, 1)[:shape[0]]
noise_mask = noise_mask.to(device)
return noise_mask
def broadcast_cond(cond, batch, device):
"""broadcasts conditioning to the batch size"""
copy = []
for p in cond:
t = p[0]
if t.shape[0] < batch:
t = torch.cat([t] * batch)
t = t.to(device)
copy += [[t] + p[1:]]
return copy
def get_models_from_cond(cond, model_type):
models = []
for c in cond:
if model_type in c[1]:
models += [c[1][model_type]]
return models
def load_additional_models(positive, negative):
"""loads additional models in positive and negative conditioning"""
control_nets = get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control")
gligen = get_models_from_cond(positive, "gligen") + get_models_from_cond(negative, "gligen")
gligen = [x[1] for x in gligen]
models = control_nets + gligen
comfy.model_management.load_controlnet_gpu(models)
return models
def cleanup_additional_models(models):
"""cleanup additional models that were loaded"""
for m in models:
m.cleanup()
def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, noise_mask=None, sigmas=None, callback=None, disable_pbar=False):
device = comfy.model_management.get_torch_device()
if noise_mask is not None:
noise_mask = prepare_mask(noise_mask, noise.shape, device)
real_model = None
comfy.model_management.load_model_gpu(model)
real_model = model.model
noise = noise.to(device)
latent_image = latent_image.to(device)
positive_copy = broadcast_cond(positive, noise.shape[0], device)
negative_copy = broadcast_cond(negative, noise.shape[0], device)
models = load_additional_models(positive, negative)
sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)
samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar)
samples = samples.cpu()
cleanup_additional_models(models)
return samples

View File

@ -3,26 +3,13 @@ from .k_diffusion import external as k_diffusion_external
from .extra_samplers import uni_pc
import torch
import contextlib
import model_management
from comfy import model_management
from .ldm.models.diffusion.ddim import DDIMSampler
from .ldm.modules.diffusionmodules.util import make_ddim_timesteps
import math
class CFGDenoiser(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.inner_model = model
def forward(self, x, sigma, uncond, cond, cond_scale):
if len(uncond[0]) == len(cond[0]) and x.shape[0] * x.shape[2] * x.shape[3] < (96 * 96): #TODO check memory instead
x_in = torch.cat([x] * 2)
sigma_in = torch.cat([sigma] * 2)
cond_in = torch.cat([uncond, cond])
uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
else:
cond = self.inner_model(x, sigma, cond=cond)
uncond = self.inner_model(x, sigma, cond=uncond)
return uncond + (cond - uncond) * cond_scale
def lcm(a, b): #TODO: eventually replace by math.lcm (added in python3.9)
return abs(a*b) // math.gcd(a, b)
#The main sampling function shared by all the samplers
#Returns predicted noise
@ -36,25 +23,40 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
strength = cond[1]['strength']
adm_cond = None
if 'adm' in cond[1]:
adm_cond = cond[1]['adm']
if 'adm_encoded' in cond[1]:
adm_cond = cond[1]['adm_encoded']
input_x = x_in[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
mult = torch.ones_like(input_x) * strength
if 'mask' in cond[1]:
# Scale the mask to the size of the input
# The mask should have been resized as we began the sampling process
mask_strength = 1.0
if "mask_strength" in cond[1]:
mask_strength = cond[1]["mask_strength"]
mask = cond[1]['mask']
assert(mask.shape[1] == x_in.shape[2])
assert(mask.shape[2] == x_in.shape[3])
mask = mask[:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] * mask_strength
mask = mask.unsqueeze(1).repeat(input_x.shape[0] // mask.shape[0], input_x.shape[1], 1, 1)
else:
mask = torch.ones_like(input_x)
mult = mask * strength
if 'mask' not in cond[1]:
rr = 8
if area[2] != 0:
for t in range(rr):
mult[:,:,t:1+t,:] *= ((1.0/rr) * (t + 1))
if (area[0] + area[2]) < x_in.shape[2]:
for t in range(rr):
mult[:,:,area[0] - 1 - t:area[0] - t,:] *= ((1.0/rr) * (t + 1))
if area[3] != 0:
for t in range(rr):
mult[:,:,:,t:1+t] *= ((1.0/rr) * (t + 1))
if (area[1] + area[3]) < x_in.shape[3]:
for t in range(rr):
mult[:,:,:,area[1] - 1 - t:area[1] - t] *= ((1.0/rr) * (t + 1))
rr = 8
if area[2] != 0:
for t in range(rr):
mult[:,:,t:1+t,:] *= ((1.0/rr) * (t + 1))
if (area[0] + area[2]) < x_in.shape[2]:
for t in range(rr):
mult[:,:,area[0] - 1 - t:area[0] - t,:] *= ((1.0/rr) * (t + 1))
if area[3] != 0:
for t in range(rr):
mult[:,:,:,t:1+t] *= ((1.0/rr) * (t + 1))
if (area[1] + area[3]) < x_in.shape[3]:
for t in range(rr):
mult[:,:,:,area[1] - 1 - t:area[1] - t] *= ((1.0/rr) * (t + 1))
conditionning = {}
conditionning['c_crossattn'] = cond[0]
if cond_concat_in is not None and len(cond_concat_in) > 0:
@ -70,7 +72,21 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
control = None
if 'control' in cond[1]:
control = cond[1]['control']
return (input_x, mult, conditionning, area, control)
patches = None
if 'gligen' in cond[1]:
gligen = cond[1]['gligen']
patches = {}
gligen_type = gligen[0]
gligen_model = gligen[1]
if gligen_type == "position":
gligen_patch = gligen_model.set_position(input_x.shape, gligen[2], input_x.device)
else:
gligen_patch = gligen_model.set_empty(input_x.shape, input_x.device)
patches['middle_patch'] = [gligen_patch]
return (input_x, mult, conditionning, area, control, patches)
def cond_equal_size(c1, c2):
if c1 is c2:
@ -78,8 +94,16 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
if c1.keys() != c2.keys():
return False
if 'c_crossattn' in c1:
if c1['c_crossattn'].shape != c2['c_crossattn'].shape:
return False
s1 = c1['c_crossattn'].shape
s2 = c2['c_crossattn'].shape
if s1 != s2:
if s1[0] != s2[0] or s1[2] != s2[2]: #these 2 cases should not happen
return False
mult_min = lcm(s1[1], s2[1])
diff = mult_min // min(s1[1], s2[1])
if diff > 4: #arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much
return False
if 'c_concat' in c1:
if c1['c_concat'].shape != c2['c_concat'].shape:
return False
@ -91,28 +115,49 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
def can_concat_cond(c1, c2):
if c1[0].shape != c2[0].shape:
return False
#control
if (c1[4] is None) != (c2[4] is None):
return False
if c1[4] is not None:
if c1[4] is not c2[4]:
return False
#patches
if (c1[5] is None) != (c2[5] is None):
return False
if (c1[5] is not None):
if c1[5] is not c2[5]:
return False
return cond_equal_size(c1[2], c2[2])
def cond_cat(c_list):
c_crossattn = []
c_concat = []
c_adm = []
crossattn_max_len = 0
for x in c_list:
if 'c_crossattn' in x:
c_crossattn.append(x['c_crossattn'])
c = x['c_crossattn']
if crossattn_max_len == 0:
crossattn_max_len = c.shape[1]
else:
crossattn_max_len = lcm(crossattn_max_len, c.shape[1])
c_crossattn.append(c)
if 'c_concat' in x:
c_concat.append(x['c_concat'])
if 'c_adm' in x:
c_adm.append(x['c_adm'])
out = {}
if len(c_crossattn) > 0:
out['c_crossattn'] = [torch.cat(c_crossattn)]
c_crossattn_out = []
for c in c_crossattn:
if c.shape[1] < crossattn_max_len:
c = c.repeat(1, crossattn_max_len // c.shape[1], 1) #padding with repeat doesn't change result
c_crossattn_out.append(c)
if len(c_crossattn_out) > 0:
out['c_crossattn'] = [torch.cat(c_crossattn_out)]
if len(c_concat) > 0:
out['c_concat'] = [torch.cat(c_concat)]
if len(c_adm) > 0:
@ -166,6 +211,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
cond_or_uncond = []
area = []
control = None
patches = None
for x in to_batch:
o = to_run.pop(x)
p = o[0]
@ -175,6 +221,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
area += [p[3]]
cond_or_uncond += [o[1]]
control = p[4]
patches = p[5]
batch_chunks = len(cond_or_uncond)
input_x = torch.cat(input_x)
@ -184,8 +231,22 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
if control is not None:
c['control'] = control.get_control(input_x, timestep_, c['c_crossattn'], len(cond_or_uncond))
transformer_options = {}
if 'transformer_options' in model_options:
c['transformer_options'] = model_options['transformer_options']
transformer_options = model_options['transformer_options'].copy()
if patches is not None:
if "patches" in transformer_options:
cur_patches = transformer_options["patches"].copy()
for p in patches:
if p in cur_patches:
cur_patches[p] = cur_patches[p] + patches[p]
else:
cur_patches[p] = patches[p]
else:
transformer_options["patches"] = patches
c['transformer_options'] = transformer_options
output = model_function(input_x, timestep_, cond=c).chunk(batch_chunks)
del input_x
@ -211,7 +272,10 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
max_total_area = model_management.maximum_batch_area()
cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, timestep, max_total_area, cond_concat, model_options)
return uncond + (cond - uncond) * cond_scale
if "sampler_cfg_function" in model_options:
return model_options["sampler_cfg_function"](cond, uncond, cond_scale)
else:
return uncond + (cond - uncond) * cond_scale
class CompVisVDenoiser(k_diffusion_external.DiscreteVDDPMDenoiser):
@ -276,6 +340,60 @@ def blank_inpaint_image_like(latent_image):
blank_image[:,3] *= 0.1380
return blank_image
def get_mask_aabb(masks):
if masks.numel() == 0:
return torch.zeros((0, 4), device=masks.device, dtype=torch.int)
b = masks.shape[0]
bounding_boxes = torch.zeros((b, 4), device=masks.device, dtype=torch.int)
is_empty = torch.zeros((b), device=masks.device, dtype=torch.bool)
for i in range(b):
mask = masks[i]
if mask.numel() == 0:
continue
if torch.max(mask != 0) == False:
is_empty[i] = True
continue
y, x = torch.where(mask)
bounding_boxes[i, 0] = torch.min(x)
bounding_boxes[i, 1] = torch.min(y)
bounding_boxes[i, 2] = torch.max(x)
bounding_boxes[i, 3] = torch.max(y)
return bounding_boxes, is_empty
def resolve_cond_masks(conditions, h, w, device):
# We need to decide on an area outside the sampling loop in order to properly generate opposite areas of equal sizes.
# While we're doing this, we can also resolve the mask device and scaling for performance reasons
for i in range(len(conditions)):
c = conditions[i]
if 'mask' in c[1]:
mask = c[1]['mask']
mask = mask.to(device=device)
modified = c[1].copy()
if len(mask.shape) == 2:
mask = mask.unsqueeze(0)
if mask.shape[2] != h or mask.shape[3] != w:
mask = torch.nn.functional.interpolate(mask.unsqueeze(1), size=(h, w), mode='bilinear', align_corners=False).squeeze(1)
if modified.get("set_area_to_bounds", False):
bounds = torch.max(torch.abs(mask),dim=0).values.unsqueeze(0)
boxes, is_empty = get_mask_aabb(bounds)
if is_empty[0]:
# Use the minimum possible size for efficiency reasons. (Since the mask is all-0, this becomes a noop anyway)
modified['area'] = (8, 8, 0, 0)
else:
box = boxes[0]
H, W, Y, X = (box[3] - box[1] + 1, box[2] - box[0] + 1, box[1], box[0])
H = max(8, H)
W = max(8, W)
area = (int(H), int(W), int(Y), int(X))
modified['area'] = area
modified['mask'] = mask
conditions[i] = [c[0], modified]
def create_cond_with_same_area_if_none(conds, c):
if 'area' not in c[1]:
return
@ -306,8 +424,7 @@ def create_cond_with_same_area_if_none(conds, c):
n = c[1].copy()
conds += [[smallest[0], n]]
def apply_control_net_to_equal_area(conds, uncond):
def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func):
cond_cnets = []
cond_other = []
uncond_cnets = []
@ -315,15 +432,15 @@ def apply_control_net_to_equal_area(conds, uncond):
for t in range(len(conds)):
x = conds[t]
if 'area' not in x[1]:
if 'control' in x[1] and x[1]['control'] is not None:
cond_cnets.append(x[1]['control'])
if name in x[1] and x[1][name] is not None:
cond_cnets.append(x[1][name])
else:
cond_other.append((x, t))
for t in range(len(uncond)):
x = uncond[t]
if 'area' not in x[1]:
if 'control' in x[1] and x[1]['control'] is not None:
uncond_cnets.append(x[1]['control'])
if name in x[1] and x[1][name] is not None:
uncond_cnets.append(x[1][name])
else:
uncond_other.append((x, t))
@ -333,15 +450,16 @@ def apply_control_net_to_equal_area(conds, uncond):
for x in range(len(cond_cnets)):
temp = uncond_other[x % len(uncond_other)]
o = temp[0]
if 'control' in o[1] and o[1]['control'] is not None:
if name in o[1] and o[1][name] is not None:
n = o[1].copy()
n['control'] = cond_cnets[x]
n[name] = uncond_fill_func(cond_cnets, x)
uncond += [[o[0], n]]
else:
n = o[1].copy()
n['control'] = cond_cnets[x]
n[name] = uncond_fill_func(cond_cnets, x)
uncond[temp[1]] = [o[0], n]
def encode_adm(noise_augmentor, conds, batch_size, device):
for t in range(len(conds)):
x = conds[t]
@ -371,12 +489,13 @@ def encode_adm(noise_augmentor, conds, batch_size, device):
else:
adm_out = torch.zeros((1, noise_augmentor.time_embed.dim * 2), device=device)
x[1] = x[1].copy()
x[1]["adm"] = torch.cat([adm_out] * batch_size)
x[1]["adm_encoded"] = torch.cat([adm_out] * batch_size)
return conds
class KSampler:
SCHEDULERS = ["karras", "normal", "simple", "ddim_uniform"]
SCHEDULERS = ["normal", "karras", "simple", "ddim_uniform"]
SAMPLERS = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral",
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde",
"dpmpp_2m", "ddim", "uni_pc", "uni_pc_bh2"]
@ -403,7 +522,7 @@ class KSampler:
self.denoise = denoise
self.model_options = model_options
def _calculate_sigmas(self, steps):
def calculate_sigmas(self, steps):
sigmas = None
discard_penultimate_sigma = False
@ -412,13 +531,13 @@ class KSampler:
discard_penultimate_sigma = True
if self.scheduler == "karras":
sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=self.sigma_min, sigma_max=self.sigma_max, device=self.device)
sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=self.sigma_min, sigma_max=self.sigma_max)
elif self.scheduler == "normal":
sigmas = self.model_wrap.get_sigmas(steps).to(self.device)
sigmas = self.model_wrap.get_sigmas(steps)
elif self.scheduler == "simple":
sigmas = simple_scheduler(self.model_wrap, steps).to(self.device)
sigmas = simple_scheduler(self.model_wrap, steps)
elif self.scheduler == "ddim_uniform":
sigmas = ddim_scheduler(self.model_wrap, steps).to(self.device)
sigmas = ddim_scheduler(self.model_wrap, steps)
else:
print("error invalid scheduler", self.scheduler)
@ -429,15 +548,15 @@ class KSampler:
def set_steps(self, steps, denoise=None):
self.steps = steps
if denoise is None or denoise > 0.9999:
self.sigmas = self._calculate_sigmas(steps)
self.sigmas = self.calculate_sigmas(steps).to(self.device)
else:
new_steps = int(steps/denoise)
sigmas = self._calculate_sigmas(new_steps)
sigmas = self.calculate_sigmas(new_steps).to(self.device)
self.sigmas = sigmas[-(steps + 1):]
def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=None, last_step=None, force_full_denoise=False, denoise_mask=None):
sigmas = self.sigmas
def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=None, last_step=None, force_full_denoise=False, denoise_mask=None, sigmas=None, callback=None, disable_pbar=False):
if sigmas is None:
sigmas = self.sigmas
sigma_min = self.sigma_min
if last_step is not None and last_step < (len(sigmas) - 1):
@ -457,13 +576,18 @@ class KSampler:
positive = positive[:]
negative = negative[:]
resolve_cond_masks(positive, noise.shape[2], noise.shape[3], self.device)
resolve_cond_masks(negative, noise.shape[2], noise.shape[3], self.device)
#make sure each cond area has an opposite one with the same area
for c in positive:
create_cond_with_same_area_if_none(negative, c)
for c in negative:
create_cond_with_same_area_if_none(positive, c)
apply_control_net_to_equal_area(positive, negative)
apply_empty_x_to_equal_area(positive, negative, 'control', lambda cond_cnets, x: cond_cnets[x])
apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x])
if self.model.model.diffusion_model.dtype == torch.float16:
precision_scope = torch.autocast
@ -499,9 +623,9 @@ class KSampler:
with precision_scope(model_management.get_autocast_device(self.device)):
if self.sampler == "uni_pc":
samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask)
samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask, callback=callback, disable=disable_pbar)
elif self.sampler == "uni_pc_bh2":
samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask, variant='bh2')
samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask, callback=callback, variant='bh2', disable=disable_pbar)
elif self.sampler == "ddim":
timesteps = []
for s in range(sigmas.shape[0]):
@ -509,6 +633,12 @@ class KSampler:
noise_mask = None
if denoise_mask is not None:
noise_mask = 1.0 - denoise_mask
ddim_callback = None
if callback is not None:
total_steps = len(timesteps) - 1
ddim_callback = lambda pred_x0, i: callback(i, pred_x0, None, total_steps)
sampler = DDIMSampler(self.model, device=self.device)
sampler.make_schedule_timesteps(ddim_timesteps=timesteps, verbose=False)
z_enc = sampler.stochastic_encode(latent_image, torch.tensor([len(timesteps) - 1] * noise.shape[0]).to(self.device), noise=noise, max_denoise=max_denoise)
@ -522,11 +652,13 @@ class KSampler:
eta=0.0,
x_T=z_enc,
x0=latent_image,
img_callback=ddim_callback,
denoise_function=sampling_function,
extra_args=extra_args,
mask=noise_mask,
to_zero=sigmas[-1]==0,
end_step=sigmas.shape[0] - 1)
end_step=sigmas.shape[0] - 1,
disable_pbar=disable_pbar)
else:
extra_args["denoise_mask"] = denoise_mask
@ -535,13 +667,18 @@ class KSampler:
noise = noise * sigmas[0]
k_callback = None
total_steps = len(sigmas) - 1
if callback is not None:
k_callback = lambda x: callback(x["i"], x["denoised"], x["x"], total_steps)
if latent_image is not None:
noise += latent_image
if self.sampler == "dpm_fast":
samples = k_diffusion_sampling.sample_dpm_fast(self.model_k, noise, sigma_min, sigmas[0], self.steps, extra_args=extra_args)
samples = k_diffusion_sampling.sample_dpm_fast(self.model_k, noise, sigma_min, sigmas[0], total_steps, extra_args=extra_args, callback=k_callback, disable=disable_pbar)
elif self.sampler == "dpm_adaptive":
samples = k_diffusion_sampling.sample_dpm_adaptive(self.model_k, noise, sigma_min, sigmas[0], extra_args=extra_args)
samples = k_diffusion_sampling.sample_dpm_adaptive(self.model_k, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=k_callback, disable=disable_pbar)
else:
samples = getattr(k_diffusion_sampling, "sample_{}".format(self.sampler))(self.model_k, noise, sigmas, extra_args=extra_args)
samples = getattr(k_diffusion_sampling, "sample_{}".format(self.sampler))(self.model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar)
return samples.to(torch.float32)

View File

@ -2,9 +2,9 @@ import torch
import contextlib
import copy
import sd1_clip
import sd2_clip
import model_management
from . import sd1_clip
from . import sd2_clip
from comfy import model_management
from .ldm.util import instantiate_from_config
from .ldm.models.autoencoder import AutoencoderKL
import yaml
@ -13,6 +13,7 @@ from .t2i_adapter import adapter
from . import utils
from . import clip_vision
from . import gligen
def load_model_weights(model, sd, verbose=False, load_state_dict_to=[]):
m, u = model.load_state_dict(sd, strict=False)
@ -110,6 +111,8 @@ def load_lora(path, to_load):
loaded_keys.add(A_name)
loaded_keys.add(B_name)
######## loha
hada_w1_a_name = "{}.hada_w1_a".format(x)
hada_w1_b_name = "{}.hada_w1_b".format(x)
hada_w2_a_name = "{}.hada_w2_a".format(x)
@ -131,6 +134,54 @@ def load_lora(path, to_load):
loaded_keys.add(hada_w2_a_name)
loaded_keys.add(hada_w2_b_name)
######## lokr
lokr_w1_name = "{}.lokr_w1".format(x)
lokr_w2_name = "{}.lokr_w2".format(x)
lokr_w1_a_name = "{}.lokr_w1_a".format(x)
lokr_w1_b_name = "{}.lokr_w1_b".format(x)
lokr_t2_name = "{}.lokr_t2".format(x)
lokr_w2_a_name = "{}.lokr_w2_a".format(x)
lokr_w2_b_name = "{}.lokr_w2_b".format(x)
lokr_w1 = None
if lokr_w1_name in lora.keys():
lokr_w1 = lora[lokr_w1_name]
loaded_keys.add(lokr_w1_name)
lokr_w2 = None
if lokr_w2_name in lora.keys():
lokr_w2 = lora[lokr_w2_name]
loaded_keys.add(lokr_w2_name)
lokr_w1_a = None
if lokr_w1_a_name in lora.keys():
lokr_w1_a = lora[lokr_w1_a_name]
loaded_keys.add(lokr_w1_a_name)
lokr_w1_b = None
if lokr_w1_b_name in lora.keys():
lokr_w1_b = lora[lokr_w1_b_name]
loaded_keys.add(lokr_w1_b_name)
lokr_w2_a = None
if lokr_w2_a_name in lora.keys():
lokr_w2_a = lora[lokr_w2_a_name]
loaded_keys.add(lokr_w2_a_name)
lokr_w2_b = None
if lokr_w2_b_name in lora.keys():
lokr_w2_b = lora[lokr_w2_b_name]
loaded_keys.add(lokr_w2_b_name)
lokr_t2 = None
if lokr_t2_name in lora.keys():
lokr_t2 = lora[lokr_t2_name]
loaded_keys.add(lokr_t2_name)
if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None):
patch_dict[to_load[x]] = (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2)
for x in lora.keys():
if x not in loaded_keys:
print("lora key not loaded", x)
@ -250,6 +301,32 @@ class ModelPatcher:
def set_model_tomesd(self, ratio):
self.model_options["transformer_options"]["tomesd"] = {"ratio": ratio}
def set_model_sampler_cfg_function(self, sampler_cfg_function):
self.model_options["sampler_cfg_function"] = sampler_cfg_function
def set_model_patch(self, patch, name):
to = self.model_options["transformer_options"]
if "patches" not in to:
to["patches"] = {}
to["patches"][name] = to["patches"].get(name, []) + [patch]
def set_model_attn1_patch(self, patch):
self.set_model_patch(patch, "attn1_patch")
def set_model_attn2_patch(self, patch):
self.set_model_patch(patch, "attn2_patch")
def model_patches_to(self, device):
to = self.model_options["transformer_options"]
if "patches" in to:
patches = to["patches"]
for name in patches:
patch_list = patches[name]
for i in range(len(patch_list)):
if hasattr(patch_list[i], "to"):
patch_list[i] = patch_list[i].to(device)
def model_dtype(self):
return self.model.diffusion_model.dtype
@ -288,6 +365,33 @@ class ModelPatcher:
final_shape = [mat2.shape[1], mat2.shape[0], v[3].shape[2], v[3].shape[3]]
mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1).float(), v[3].transpose(0, 1).flatten(start_dim=1).float()).reshape(final_shape).transpose(0, 1)
weight += (alpha * torch.mm(mat1.flatten(start_dim=1).float(), mat2.flatten(start_dim=1).float())).reshape(weight.shape).type(weight.dtype).to(weight.device)
elif len(v) == 8: #lokr
w1 = v[0]
w2 = v[1]
w1_a = v[3]
w1_b = v[4]
w2_a = v[5]
w2_b = v[6]
t2 = v[7]
dim = None
if w1 is None:
dim = w1_b.shape[0]
w1 = torch.mm(w1_a.float(), w1_b.float())
if w2 is None:
dim = w2_b.shape[0]
if t2 is None:
w2 = torch.mm(w2_a.float(), w2_b.float())
else:
w2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float(), w2_b.float(), w2_a.float())
if len(w2.shape) == 4:
w1 = w1.unsqueeze(2).unsqueeze(2)
if v[2] is not None and dim is not None:
alpha *= v[2] / dim
weight += alpha * torch.kron(w1.float(), w2.float()).reshape(weight.shape).type(weight.dtype).to(weight.device)
else: #loha
w1a = v[0]
w1b = v[1]
@ -342,10 +446,10 @@ class CLIP:
else:
params = {}
if self.target_clip == "ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder":
if self.target_clip.endswith("FrozenOpenCLIPEmbedder"):
clip = sd2_clip.SD2ClipModel
tokenizer = sd2_clip.SD2Tokenizer
elif self.target_clip == "ldm.modules.encoders.modules.FrozenCLIPEmbedder":
elif self.target_clip.endswith("FrozenCLIPEmbedder"):
clip = sd1_clip.SD1ClipModel
tokenizer = sd1_clip.SD1Tokenizer
@ -372,10 +476,12 @@ class CLIP:
def clip_layer(self, layer_idx):
self.layer_idx = layer_idx
def encode(self, text):
def tokenize(self, text, return_word_ids=False):
return self.tokenizer.tokenize_with_weights(text, return_word_ids)
def encode_from_tokens(self, tokens, return_pooled=False):
if self.layer_idx is not None:
self.cond_stage_model.clip_layer(self.layer_idx)
tokens = self.tokenizer.tokenize_with_weights(text)
try:
self.patcher.patch_model()
cond = self.cond_stage_model.encode_token_weights(tokens)
@ -383,8 +489,16 @@ class CLIP:
except Exception as e:
self.patcher.unpatch_model()
raise e
if return_pooled:
eos_token_index = max(range(len(tokens[0])), key=tokens[0].__getitem__)
pooled = cond[:, eos_token_index]
return cond, pooled
return cond
def encode(self, text):
tokens = self.tokenize(text)
return self.encode_from_tokens(tokens)
class VAE:
def __init__(self, ckpt_path=None, scale_factor=0.18215, device=None, config=None):
if config is None:
@ -400,11 +514,16 @@ class VAE:
self.device = device
def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16):
steps = samples.shape[0] * utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap)
steps += samples.shape[0] * utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x // 2, tile_y * 2, overlap)
steps += samples.shape[0] * utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x * 2, tile_y // 2, overlap)
pbar = utils.ProgressBar(steps)
decode_fn = lambda a: (self.first_stage_model.decode(1. / self.scale_factor * a.to(self.device)) + 1.0)
output = torch.clamp((
(utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = 8) +
utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = 8) +
utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = 8))
(utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = 8, pbar = pbar) +
utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = 8, pbar = pbar) +
utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = 8, pbar = pbar))
/ 3.0) / 2.0, min=0.0, max=1.0)
return output
@ -448,20 +567,23 @@ class VAE:
model_management.unload_model()
self.first_stage_model = self.first_stage_model.to(self.device)
pixel_samples = pixel_samples.movedim(-1,1).to(self.device)
samples = utils.tiled_scale(pixel_samples, lambda a: self.first_stage_model.encode(2. * a - 1.).sample() * self.scale_factor, tile_x, tile_y, overlap, upscale_amount = (1/8), out_channels=4)
samples += utils.tiled_scale(pixel_samples, lambda a: self.first_stage_model.encode(2. * a - 1.).sample() * self.scale_factor, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/8), out_channels=4)
samples += utils.tiled_scale(pixel_samples, lambda a: self.first_stage_model.encode(2. * a - 1.).sample() * self.scale_factor, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/8), out_channels=4)
steps = pixel_samples.shape[0] * utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap)
steps += pixel_samples.shape[0] * utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x // 2, tile_y * 2, overlap)
steps += pixel_samples.shape[0] * utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap)
pbar = utils.ProgressBar(steps)
samples = utils.tiled_scale(pixel_samples, lambda a: self.first_stage_model.encode(2. * a - 1.).sample() * self.scale_factor, tile_x, tile_y, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
samples += utils.tiled_scale(pixel_samples, lambda a: self.first_stage_model.encode(2. * a - 1.).sample() * self.scale_factor, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
samples += utils.tiled_scale(pixel_samples, lambda a: self.first_stage_model.encode(2. * a - 1.).sample() * self.scale_factor, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
samples /= 3.0
self.first_stage_model = self.first_stage_model.cpu()
samples = samples.cpu()
return samples
def resize_image_to(tensor, target_latent_tensor, batched_number):
tensor = utils.common_upscale(tensor, target_latent_tensor.shape[3] * 8, target_latent_tensor.shape[2] * 8, 'nearest-exact', "center")
target_batch_size = target_latent_tensor.shape[0]
def broadcast_image_to(tensor, target_batch_size, batched_number):
current_batch_size = tensor.shape[0]
print(current_batch_size, target_batch_size)
#print(current_batch_size, target_batch_size)
if current_batch_size == 1:
return tensor
@ -498,7 +620,9 @@ class ControlNet:
if self.cond_hint is not None:
del self.cond_hint
self.cond_hint = None
self.cond_hint = resize_image_to(self.cond_hint_original, x_noisy, batched_number).to(self.control_model.dtype).to(self.device)
self.cond_hint = utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(self.control_model.dtype).to(self.device)
if x_noisy.shape[0] != self.cond_hint.shape[0]:
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
if self.control_model.dtype == torch.float16:
precision_scope = torch.autocast
@ -555,10 +679,10 @@ class ControlNet:
c.strength = self.strength
return c
def get_control_models(self):
def get_models(self):
out = []
if self.previous_controlnet is not None:
out += self.previous_controlnet.get_control_models()
out += self.previous_controlnet.get_models()
out.append(self.control_model)
return out
@ -669,10 +793,14 @@ class T2IAdapter:
if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]:
if self.cond_hint is not None:
del self.cond_hint
self.control_input = None
self.cond_hint = None
self.cond_hint = resize_image_to(self.cond_hint_original, x_noisy, batched_number).float().to(self.device)
self.cond_hint = utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").float().to(self.device)
if self.channels_in == 1 and self.cond_hint.shape[1] > 1:
self.cond_hint = torch.mean(self.cond_hint, 1, keepdim=True)
if x_noisy.shape[0] != self.cond_hint.shape[0]:
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
if self.control_input is None:
self.t2i_model.to(self.device)
self.control_input = self.t2i_model(self.cond_hint)
self.t2i_model.cpu()
@ -728,10 +856,10 @@ class T2IAdapter:
del self.cond_hint
self.cond_hint = None
def get_control_models(self):
def get_models(self):
out = []
if self.previous_controlnet is not None:
out += self.previous_controlnet.get_control_models()
out += self.previous_controlnet.get_models()
return out
def load_t2i_adapter(t2i_data):
@ -771,13 +899,20 @@ def load_clip(ckpt_path, embedding_directory=None):
clip_data = utils.load_torch_file(ckpt_path)
config = {}
if "text_model.encoder.layers.22.mlp.fc1.weight" in clip_data:
config['target'] = 'ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder'
config['target'] = 'comfy.ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder'
else:
config['target'] = 'ldm.modules.encoders.modules.FrozenCLIPEmbedder'
config['target'] = 'comfy.ldm.modules.encoders.modules.FrozenCLIPEmbedder'
clip = CLIP(config=config, embedding_directory=embedding_directory)
clip.load_from_state_dict(clip_data)
return clip
def load_gligen(ckpt_path):
data = utils.load_torch_file(ckpt_path)
model = gligen.load_gligen(data)
if model_management.should_use_fp16():
model = model.half()
return model
def load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=None):
with open(config_path, 'r') as stream:
config = yaml.safe_load(stream)
@ -842,9 +977,9 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
if output_clip:
clip_config = {}
if "cond_stage_model.model.transformer.resblocks.22.attn.out_proj.weight" in sd_keys:
clip_config['target'] = 'ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder'
clip_config['target'] = 'comfy.ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder'
else:
clip_config['target'] = 'ldm.modules.encoders.modules.FrozenCLIPEmbedder'
clip_config['target'] = 'comfy.ldm.modules.encoders.modules.FrozenCLIPEmbedder'
clip = CLIP(config=clip_config, embedding_directory=embedding_directory)
w.cond_stage_model = clip.cond_stage_model
load_state_dict_to = [w]
@ -865,7 +1000,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
noise_schedule_config["timesteps"] = sd[noise_aug_key].shape[0]
noise_schedule_config["beta_schedule"] = "squaredcos_cap_v2"
params["noise_schedule_config"] = noise_schedule_config
noise_aug_config['target'] = "ldm.modules.encoders.noise_aug_modules.CLIPEmbeddingNoiseAugmentation"
noise_aug_config['target'] = "comfy.ldm.modules.encoders.noise_aug_modules.CLIPEmbeddingNoiseAugmentation"
if size == 1280: #h
params["timestep_dim"] = 1024
elif size == 1024: #l
@ -917,19 +1052,19 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
unet_config["in_channels"] = sd['model.diffusion_model.input_blocks.0.0.weight'].shape[1]
unet_config["context_dim"] = sd['model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight'].shape[1]
sd_config["unet_config"] = {"target": "ldm.modules.diffusionmodules.openaimodel.UNetModel", "params": unet_config}
model_config = {"target": "ldm.models.diffusion.ddpm.LatentDiffusion", "params": sd_config}
sd_config["unet_config"] = {"target": "comfy.ldm.modules.diffusionmodules.openaimodel.UNetModel", "params": unet_config}
model_config = {"target": "comfy.ldm.models.diffusion.ddpm.LatentDiffusion", "params": sd_config}
if noise_aug_config is not None: #SD2.x unclip model
sd_config["noise_aug_config"] = noise_aug_config
sd_config["image_size"] = 96
sd_config["embedding_dropout"] = 0.25
sd_config["conditioning_key"] = 'crossattn-adm'
model_config["target"] = "ldm.models.diffusion.ddpm.ImageEmbeddingConditionedLatentDiffusion"
model_config["target"] = "comfy.ldm.models.diffusion.ddpm.ImageEmbeddingConditionedLatentDiffusion"
elif unet_config["in_channels"] > 4: #inpainting model
sd_config["conditioning_key"] = "hybrid"
sd_config["finetune_keys"] = None
model_config["target"] = "ldm.models.diffusion.ddpm.LatentInpaintDiffusion"
model_config["target"] = "comfy.ldm.models.diffusion.ddpm.LatentInpaintDiffusion"
else:
sd_config["conditioning_key"] = "crossattn"

View File

@ -2,6 +2,8 @@ import os
from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextConfig
import torch
import traceback
import zipfile
class ClipTokenWeightEncoder:
def encode_token_weights(self, token_weight_pairs):
@ -170,10 +172,39 @@ def unescape_important(text):
text = text.replace("\0\2", "(")
return text
def safe_load_embed_zip(embed_path):
with zipfile.ZipFile(embed_path) as myzip:
names = list(filter(lambda a: "data/" in a, myzip.namelist()))
names.reverse()
for n in names:
with myzip.open(n) as myfile:
data = myfile.read()
number = len(data) // 4
length_embed = 1024 #sd2.x
if number < 768:
continue
if number % 768 == 0:
length_embed = 768 #sd1.x
num_embeds = number // length_embed
embed = torch.frombuffer(data, dtype=torch.float)
out = embed.reshape((num_embeds, length_embed)).clone()
del embed
return out
def expand_directory_list(directories):
dirs = set()
for x in directories:
dirs.add(x)
for root, subdir, file in os.walk(x, followlinks=True):
dirs.add(root)
return list(dirs)
def load_embed(embedding_name, embedding_directory):
if isinstance(embedding_directory, str):
embedding_directory = [embedding_directory]
embedding_directory = expand_directory_list(embedding_directory)
valid_file = None
for embed_dir in embedding_directory:
embed_path = os.path.join(embed_dir, embedding_name)
@ -194,19 +225,33 @@ def load_embed(embedding_name, embedding_directory):
embed_path = valid_file
if embed_path.lower().endswith(".safetensors"):
import safetensors.torch
embed = safetensors.torch.load_file(embed_path, device="cpu")
else:
if 'weights_only' in torch.load.__code__.co_varnames:
embed = torch.load(embed_path, weights_only=True, map_location="cpu")
embed_out = None
try:
if embed_path.lower().endswith(".safetensors"):
import safetensors.torch
embed = safetensors.torch.load_file(embed_path, device="cpu")
else:
embed = torch.load(embed_path, map_location="cpu")
if 'string_to_param' in embed:
values = embed['string_to_param'].values()
else:
values = embed.values()
return next(iter(values))
if 'weights_only' in torch.load.__code__.co_varnames:
try:
embed = torch.load(embed_path, weights_only=True, map_location="cpu")
except:
embed_out = safe_load_embed_zip(embed_path)
else:
embed = torch.load(embed_path, map_location="cpu")
except Exception as e:
print(traceback.format_exc())
print()
print("error loading embedding, skipping loading:", embedding_name)
return None
if embed_out is None:
if 'string_to_param' in embed:
values = embed['string_to_param'].values()
else:
values = embed.values()
embed_out = next(iter(values))
return embed_out
class SD1Tokenizer:
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None):
@ -224,60 +269,97 @@ class SD1Tokenizer:
self.inv_vocab = {v: k for k, v in vocab.items()}
self.embedding_directory = embedding_directory
self.max_word_length = 8
self.embedding_identifier = "embedding:"
def _try_get_embedding(self, embedding_name:str):
'''
Takes a potential embedding name and tries to retrieve it.
Returns a Tuple consisting of the embedding and any leftover string, embedding can be None.
'''
embed = load_embed(embedding_name, self.embedding_directory)
if embed is None:
stripped = embedding_name.strip(',')
if len(stripped) < len(embedding_name):
embed = load_embed(stripped, self.embedding_directory)
return (embed, embedding_name[len(stripped):])
return (embed, "")
def tokenize_with_weights(self, text:str, return_word_ids=False):
'''
Takes a prompt and converts it to a list of (token, weight, word id) elements.
Tokens can both be integer tokens and pre computed CLIP tensors.
Word id values are unique per word and embedding, where the id 0 is reserved for non word tokens.
Returned list has the dimensions NxM where M is the input size of CLIP
'''
if self.pad_with_end:
pad_token = self.end_token
else:
pad_token = 0
def tokenize_with_weights(self, text):
text = escape_important(text)
parsed_weights = token_weights(text, 1.0)
#tokenize words
tokens = []
for t in parsed_weights:
to_tokenize = unescape_important(t[0]).replace("\n", " ").split(' ')
while len(to_tokenize) > 0:
word = to_tokenize.pop(0)
temp_tokens = []
embedding_identifier = "embedding:"
if word.startswith(embedding_identifier) and self.embedding_directory is not None:
embedding_name = word[len(embedding_identifier):].strip('\n')
embed = load_embed(embedding_name, self.embedding_directory)
for weighted_segment, weight in parsed_weights:
to_tokenize = unescape_important(weighted_segment).replace("\n", " ").split(' ')
to_tokenize = [x for x in to_tokenize if x != ""]
for word in to_tokenize:
#if we find an embedding, deal with the embedding
if word.startswith(self.embedding_identifier) and self.embedding_directory is not None:
embedding_name = word[len(self.embedding_identifier):].strip('\n')
embed, leftover = self._try_get_embedding(embedding_name)
if embed is None:
stripped = embedding_name.strip(',')
if len(stripped) < len(embedding_name):
embed = load_embed(stripped, self.embedding_directory)
if embed is not None:
to_tokenize.insert(0, embedding_name[len(stripped):])
if embed is not None:
if len(embed.shape) == 1:
temp_tokens += [(embed, t[1])]
else:
for x in range(embed.shape[0]):
temp_tokens += [(embed[x], t[1])]
print(f"warning, embedding:{embedding_name} does not exist, ignoring")
else:
print("warning, embedding:{} does not exist, ignoring".format(embedding_name))
elif len(word) > 0:
tt = self.tokenizer(word)["input_ids"][1:-1]
for x in tt:
temp_tokens += [(x, t[1])]
tokens_left = self.max_tokens_per_section - (len(tokens) % self.max_tokens_per_section)
if len(embed.shape) == 1:
tokens.append([(embed, weight)])
else:
tokens.append([(embed[x], weight) for x in range(embed.shape[0])])
#if we accidentally have leftover text, continue parsing using leftover, else move on to next word
if leftover != "":
word = leftover
else:
continue
#parse word
tokens.append([(t, weight) for t in self.tokenizer(word)["input_ids"][1:-1]])
#try not to split words in different sections
if tokens_left < len(temp_tokens) and len(temp_tokens) < (self.max_word_length):
for x in range(tokens_left):
tokens += [(self.end_token, 1.0)]
tokens += temp_tokens
#reshape token array to CLIP input size
batched_tokens = []
batch = [(self.start_token, 1.0, 0)]
batched_tokens.append(batch)
for i, t_group in enumerate(tokens):
#determine if we're going to try and keep the tokens in a single batch
is_large = len(t_group) >= self.max_word_length
out_tokens = []
for x in range(0, len(tokens), self.max_tokens_per_section):
o_token = [(self.start_token, 1.0)] + tokens[x:min(self.max_tokens_per_section + x, len(tokens))]
o_token += [(self.end_token, 1.0)]
if self.pad_with_end:
o_token +=[(self.end_token, 1.0)] * (self.max_length - len(o_token))
else:
o_token +=[(0, 1.0)] * (self.max_length - len(o_token))
while len(t_group) > 0:
if len(t_group) + len(batch) > self.max_length - 1:
remaining_length = self.max_length - len(batch) - 1
#break word in two and add end token
if is_large:
batch.extend([(t,w,i+1) for t,w in t_group[:remaining_length]])
batch.append((self.end_token, 1.0, 0))
t_group = t_group[remaining_length:]
#add end token and pad
else:
batch.append((self.end_token, 1.0, 0))
batch.extend([(pad_token, 1.0, 0)] * (remaining_length))
#start new batch
batch = [(self.start_token, 1.0, 0)]
batched_tokens.append(batch)
else:
batch.extend([(t,w,i+1) for t,w in t_group])
t_group = []
out_tokens += [o_token]
#fill last batch
batch.extend([(self.end_token, 1.0, 0)] + [(pad_token, 1.0, 0)] * (self.max_length - len(batch) - 1))
if not return_word_ids:
batched_tokens = [[(t, w) for t, w,_ in x] for x in batched_tokens]
return batched_tokens
return out_tokens
def untokenize(self, token_weight_pair):
return list(map(lambda a: (a, self.inv_vocab[a[0]]), token_weight_pair))

View File

@ -1,4 +1,4 @@
import sd1_clip
from comfy import sd1_clip
import torch
import os

View File

@ -56,7 +56,12 @@ class Downsample(nn.Module):
def forward(self, x):
assert x.shape[1] == self.channels
return self.op(x)
if not self.use_conv:
padding = [x.shape[2] % 2, x.shape[3] % 2]
self.op.padding = padding
x = self.op(x)
return x
class ResnetBlock(nn.Module):

View File

@ -1,11 +1,15 @@
import torch
import math
def load_torch_file(ckpt):
def load_torch_file(ckpt, safe_load=False):
if ckpt.lower().endswith(".safetensors"):
import safetensors.torch
sd = safetensors.torch.load_file(ckpt, device="cpu")
else:
pl_sd = torch.load(ckpt, map_location="cpu")
if safe_load:
pl_sd = torch.load(ckpt, map_location="cpu", weights_only=True)
else:
pl_sd = torch.load(ckpt, map_location="cpu")
if "global_step" in pl_sd:
print(f"Global Step: {pl_sd['global_step']}")
if "state_dict" in pl_sd:
@ -59,8 +63,11 @@ def common_upscale(samples, width, height, upscale_method, crop):
s = samples
return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method)
def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap):
return math.ceil((height / (tile_y - overlap))) * math.ceil((width / (tile_x - overlap)))
@torch.inference_mode()
def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3):
def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, pbar = None):
output = torch.empty((samples.shape[0], out_channels, round(samples.shape[2] * upscale_amount), round(samples.shape[3] * upscale_amount)), device="cpu")
for b in range(samples.shape[0]):
s = samples[b:b+1]
@ -80,6 +87,33 @@ def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_am
mask[:,:,:,mask.shape[3]- 1 - t: mask.shape[3]- t] *= ((1.0/feather) * (t + 1))
out[:,:,round(y*upscale_amount):round((y+tile_y)*upscale_amount),round(x*upscale_amount):round((x+tile_x)*upscale_amount)] += ps * mask
out_div[:,:,round(y*upscale_amount):round((y+tile_y)*upscale_amount),round(x*upscale_amount):round((x+tile_x)*upscale_amount)] += mask
if pbar is not None:
pbar.update(1)
output[b:b+1] = out/out_div
return output
PROGRESS_BAR_HOOK = None
def set_progress_bar_global_hook(function):
global PROGRESS_BAR_HOOK
PROGRESS_BAR_HOOK = function
class ProgressBar:
def __init__(self, total):
global PROGRESS_BAR_HOOK
self.total = total
self.current = 0
self.hook = PROGRESS_BAR_HOOK
def update_absolute(self, value, total=None):
if total is not None:
self.total = total
if value > self.total:
value = self.total
self.current = value
if self.hook is not None:
self.hook(self.current, self.total)
def update(self, value):
self.update_absolute(self.current + value)

View File

@ -4,7 +4,10 @@
from __future__ import annotations
from collections import OrderedDict
from typing import Literal
try:
from typing import Literal
except ImportError:
from typing_extensions import Literal
import torch
import torch.nn as nn

View File

@ -0,0 +1,110 @@
import comfy.utils
import folder_paths
import torch
def load_hypernetwork_patch(path, strength):
sd = comfy.utils.load_torch_file(path, safe_load=True)
activation_func = sd.get('activation_func', 'linear')
is_layer_norm = sd.get('is_layer_norm', False)
use_dropout = sd.get('use_dropout', False)
activate_output = sd.get('activate_output', False)
last_layer_dropout = sd.get('last_layer_dropout', False)
valid_activation = {
"linear": torch.nn.Identity,
"relu": torch.nn.ReLU,
"leakyrelu": torch.nn.LeakyReLU,
"elu": torch.nn.ELU,
"swish": torch.nn.Hardswish,
"tanh": torch.nn.Tanh,
"sigmoid": torch.nn.Sigmoid,
"softsign": torch.nn.Softsign,
}
if activation_func not in valid_activation:
print("Unsupported Hypernetwork format, if you report it I might implement it.", path, " ", activation_func, is_layer_norm, use_dropout, activate_output, last_layer_dropout)
return None
out = {}
for d in sd:
try:
dim = int(d)
except:
continue
output = []
for index in [0, 1]:
attn_weights = sd[dim][index]
keys = attn_weights.keys()
linears = filter(lambda a: a.endswith(".weight"), keys)
linears = list(map(lambda a: a[:-len(".weight")], linears))
layers = []
for i in range(len(linears)):
lin_name = linears[i]
last_layer = (i == (len(linears) - 1))
penultimate_layer = (i == (len(linears) - 2))
lin_weight = attn_weights['{}.weight'.format(lin_name)]
lin_bias = attn_weights['{}.bias'.format(lin_name)]
layer = torch.nn.Linear(lin_weight.shape[1], lin_weight.shape[0])
layer.load_state_dict({"weight": lin_weight, "bias": lin_bias})
layers.append(layer)
if activation_func != "linear":
if (not last_layer) or (activate_output):
layers.append(valid_activation[activation_func]())
if is_layer_norm:
layers.append(torch.nn.LayerNorm(lin_weight.shape[0]))
if use_dropout:
if (not last_layer) and (not penultimate_layer or last_layer_dropout):
layers.append(torch.nn.Dropout(p=0.3))
output.append(torch.nn.Sequential(*layers))
out[dim] = torch.nn.ModuleList(output)
class hypernetwork_patch:
def __init__(self, hypernet, strength):
self.hypernet = hypernet
self.strength = strength
def __call__(self, current_index, q, k, v):
dim = k.shape[-1]
if dim in self.hypernet:
hn = self.hypernet[dim]
k = k + hn[0](k) * self.strength
v = v + hn[1](v) * self.strength
return q, k, v
def to(self, device):
for d in self.hypernet.keys():
self.hypernet[d] = self.hypernet[d].to(device)
return self
return hypernetwork_patch(out, strength)
class HypernetworkLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"hypernetwork_name": (folder_paths.get_filename_list("hypernetworks"), ),
"strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "load_hypernetwork"
CATEGORY = "loaders"
def load_hypernetwork(self, model, hypernetwork_name, strength):
hypernetwork_path = folder_paths.get_full_path("hypernetworks", hypernetwork_name)
model_hypernetwork = model.clone()
patch = load_hypernetwork_patch(hypernetwork_path, strength)
if patch is not None:
model_hypernetwork.set_model_attn1_patch(patch)
model_hypernetwork.set_model_attn2_patch(patch)
return (model_hypernetwork,)
NODE_CLASS_MAPPINGS = {
"HypernetworkLoader": HypernetworkLoader
}

262
comfy_extras/nodes_mask.py Normal file
View File

@ -0,0 +1,262 @@
import torch
from nodes import MAX_RESOLUTION
class LatentCompositeMasked:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"destination": ("LATENT",),
"source": ("LATENT",),
"x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
"y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
},
"optional": {
"mask": ("MASK",),
}
}
RETURN_TYPES = ("LATENT",)
FUNCTION = "composite"
CATEGORY = "latent"
def composite(self, destination, source, x, y, mask = None):
output = destination.copy()
destination = destination["samples"].clone()
source = source["samples"]
x = max(-source.shape[3] * 8, min(x, destination.shape[3] * 8))
y = max(-source.shape[2] * 8, min(y, destination.shape[2] * 8))
left, top = (x // 8, y // 8)
right, bottom = (left + source.shape[3], top + source.shape[2],)
if mask is None:
mask = torch.ones_like(source)
else:
mask = mask.clone()
mask = torch.nn.functional.interpolate(mask[None, None], size=(source.shape[2], source.shape[3]), mode="bilinear")
mask = mask.repeat((source.shape[0], source.shape[1], 1, 1))
# calculate the bounds of the source that will be overlapping the destination
# this prevents the source trying to overwrite latent pixels that are out of bounds
# of the destination
visible_width, visible_height = (destination.shape[3] - left + min(0, x), destination.shape[2] - top + min(0, y),)
mask = mask[:, :, :visible_height, :visible_width]
inverse_mask = torch.ones_like(mask) - mask
source_portion = mask * source[:, :, :visible_height, :visible_width]
destination_portion = inverse_mask * destination[:, :, top:bottom, left:right]
destination[:, :, top:bottom, left:right] = source_portion + destination_portion
output["samples"] = destination
return (output,)
class MaskToImage:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"mask": ("MASK",),
}
}
CATEGORY = "mask"
RETURN_TYPES = ("IMAGE",)
FUNCTION = "mask_to_image"
def mask_to_image(self, mask):
result = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3)
return (result,)
class ImageToMask:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"image": ("IMAGE",),
"channel": (["red", "green", "blue"],),
}
}
CATEGORY = "mask"
RETURN_TYPES = ("MASK",)
FUNCTION = "image_to_mask"
def image_to_mask(self, image, channel):
channels = ["red", "green", "blue"]
mask = image[0, :, :, channels.index(channel)]
return (mask,)
class SolidMask:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"value": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
"width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
"height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
}
}
CATEGORY = "mask"
RETURN_TYPES = ("MASK",)
FUNCTION = "solid"
def solid(self, value, width, height):
out = torch.full((height, width), value, dtype=torch.float32, device="cpu")
return (out,)
class InvertMask:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"mask": ("MASK",),
}
}
CATEGORY = "mask"
RETURN_TYPES = ("MASK",)
FUNCTION = "invert"
def invert(self, mask):
out = 1.0 - mask
return (out,)
class CropMask:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"mask": ("MASK",),
"x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
"y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
"width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
"height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
}
}
CATEGORY = "mask"
RETURN_TYPES = ("MASK",)
FUNCTION = "crop"
def crop(self, mask, x, y, width, height):
out = mask[y:y + height, x:x + width]
return (out,)
class MaskComposite:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"destination": ("MASK",),
"source": ("MASK",),
"x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
"y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
"operation": (["multiply", "add", "subtract"],),
}
}
CATEGORY = "mask"
RETURN_TYPES = ("MASK",)
FUNCTION = "combine"
def combine(self, destination, source, x, y, operation):
output = destination.clone()
left, top = (x, y,)
right, bottom = (min(left + source.shape[1], destination.shape[1]), min(top + source.shape[0], destination.shape[0]))
visible_width, visible_height = (right - left, bottom - top,)
source_portion = source[:visible_height, :visible_width]
destination_portion = destination[top:bottom, left:right]
if operation == "multiply":
output[top:bottom, left:right] = destination_portion * source_portion
elif operation == "add":
output[top:bottom, left:right] = destination_portion + source_portion
elif operation == "subtract":
output[top:bottom, left:right] = destination_portion - source_portion
output = torch.clamp(output, 0.0, 1.0)
return (output,)
class FeatherMask:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"mask": ("MASK",),
"left": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
"top": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
"right": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
"bottom": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
}
}
CATEGORY = "mask"
RETURN_TYPES = ("MASK",)
FUNCTION = "feather"
def feather(self, mask, left, top, right, bottom):
output = mask.clone()
left = min(left, output.shape[1])
right = min(right, output.shape[1])
top = min(top, output.shape[0])
bottom = min(bottom, output.shape[0])
for x in range(left):
feather_rate = (x + 1.0) / left
output[:, x] *= feather_rate
for x in range(right):
feather_rate = (x + 1) / right
output[:, -x] *= feather_rate
for y in range(top):
feather_rate = (y + 1) / top
output[y, :] *= feather_rate
for y in range(bottom):
feather_rate = (y + 1) / bottom
output[-y, :] *= feather_rate
return (output,)
NODE_CLASS_MAPPINGS = {
"LatentCompositeMasked": LatentCompositeMasked,
"MaskToImage": MaskToImage,
"ImageToMask": ImageToMask,
"SolidMask": SolidMask,
"InvertMask": InvertMask,
"CropMask": CropMask,
"MaskComposite": MaskComposite,
"FeatherMask": FeatherMask,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"ImageToMask": "Convert Image to Mask",
"MaskToImage": "Convert Mask to Image",
}

View File

@ -0,0 +1,108 @@
import torch
class LatentRebatch:
@classmethod
def INPUT_TYPES(s):
return {"required": { "latents": ("LATENT",),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 64}),
}}
RETURN_TYPES = ("LATENT",)
INPUT_IS_LIST = True
OUTPUT_IS_LIST = (True, )
FUNCTION = "rebatch"
CATEGORY = "latent/batch"
@staticmethod
def get_batch(latents, list_ind, offset):
'''prepare a batch out of the list of latents'''
samples = latents[list_ind]['samples']
shape = samples.shape
mask = latents[list_ind]['noise_mask'] if 'noise_mask' in latents[list_ind] else torch.ones((shape[0], 1, shape[2]*8, shape[3]*8), device='cpu')
if mask.shape[-1] != shape[-1] * 8 or mask.shape[-2] != shape[-2]:
torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(shape[-2]*8, shape[-1]*8), mode="bilinear")
if mask.shape[0] < samples.shape[0]:
mask = mask.repeat((shape[0] - 1) // mask.shape[0] + 1, 1, 1, 1)[:shape[0]]
if 'batch_index' in latents[list_ind]:
batch_inds = latents[list_ind]['batch_index']
else:
batch_inds = [x+offset for x in range(shape[0])]
return samples, mask, batch_inds
@staticmethod
def get_slices(indexable, num, batch_size):
'''divides an indexable object into num slices of length batch_size, and a remainder'''
slices = []
for i in range(num):
slices.append(indexable[i*batch_size:(i+1)*batch_size])
if num * batch_size < len(indexable):
return slices, indexable[num * batch_size:]
else:
return slices, None
@staticmethod
def slice_batch(batch, num, batch_size):
result = [LatentRebatch.get_slices(x, num, batch_size) for x in batch]
return list(zip(*result))
@staticmethod
def cat_batch(batch1, batch2):
if batch1[0] is None:
return batch2
result = [torch.cat((b1, b2)) if torch.is_tensor(b1) else b1 + b2 for b1, b2 in zip(batch1, batch2)]
return result
def rebatch(self, latents, batch_size):
batch_size = batch_size[0]
output_list = []
current_batch = (None, None, None)
processed = 0
for i in range(len(latents)):
# fetch new entry of list
#samples, masks, indices = self.get_batch(latents, i)
next_batch = self.get_batch(latents, i, processed)
processed += len(next_batch[2])
# set to current if current is None
if current_batch[0] is None:
current_batch = next_batch
# add previous to list if dimensions do not match
elif next_batch[0].shape[-1] != current_batch[0].shape[-1] or next_batch[0].shape[-2] != current_batch[0].shape[-2]:
sliced, _ = self.slice_batch(current_batch, 1, batch_size)
output_list.append({'samples': sliced[0][0], 'noise_mask': sliced[1][0], 'batch_index': sliced[2][0]})
current_batch = next_batch
# cat if everything checks out
else:
current_batch = self.cat_batch(current_batch, next_batch)
# add to list if dimensions gone above target batch size
if current_batch[0].shape[0] > batch_size:
num = current_batch[0].shape[0] // batch_size
sliced, remainder = self.slice_batch(current_batch, num, batch_size)
for i in range(num):
output_list.append({'samples': sliced[0][i], 'noise_mask': sliced[1][i], 'batch_index': sliced[2][i]})
current_batch = remainder
#add remainder
if current_batch[0] is not None:
sliced, _ = self.slice_batch(current_batch, 1, batch_size)
output_list.append({'samples': sliced[0][0], 'noise_mask': sliced[1][0], 'batch_index': sliced[2][0]})
#get rid of empty masks
for s in output_list:
if s['noise_mask'].mean() == 1.0:
del s['noise_mask']
return (output_list,)
NODE_CLASS_MAPPINGS = {
"RebatchLatents": LatentRebatch,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"RebatchLatents": "Rebatch Latents",
}

View File

@ -1,6 +1,6 @@
import os
from comfy_extras.chainner_models import model_loading
import model_management
from comfy import model_management
import torch
import comfy.utils
import folder_paths
@ -17,7 +17,7 @@ class UpscaleModelLoader:
def load_model(self, model_name):
model_path = folder_paths.get_full_path("upscale_models", model_name)
sd = comfy.utils.load_torch_file(model_path)
sd = comfy.utils.load_torch_file(model_path, safe_load=True)
out = model_loading.load_state_dict(sd).eval()
return (out, )
@ -37,7 +37,12 @@ class ImageUpscaleWithModel:
device = model_management.get_torch_device()
upscale_model.to(device)
in_img = image.movedim(-1,-3).to(device)
s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=128 + 64, tile_y=128 + 64, overlap = 8, upscale_amount=upscale_model.scale)
tile = 128 + 64
overlap = 8
steps = in_img.shape[0] * comfy.utils.get_tiled_scale_steps(in_img.shape[3], in_img.shape[2], tile_x=tile, tile_y=tile, overlap=overlap)
pbar = comfy.utils.ProgressBar(steps)
s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar)
upscale_model.cpu()
s = torch.clamp(s.movedim(-3,-1), min=0, max=1.0)
return (s,)

View File

@ -6,10 +6,13 @@ import threading
import heapq
import traceback
import gc
import time
import torch
import nodes
import comfy.model_management
def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_data={}):
valid_inputs = class_def.INPUT_TYPES()
input_data_all = {}
@ -24,29 +27,88 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da
input_data_all[x] = obj
else:
if ("required" in valid_inputs and x in valid_inputs["required"]) or ("optional" in valid_inputs and x in valid_inputs["optional"]):
input_data_all[x] = input_data
input_data_all[x] = [input_data]
if "hidden" in valid_inputs:
h = valid_inputs["hidden"]
for x in h:
if h[x] == "PROMPT":
input_data_all[x] = prompt
input_data_all[x] = [prompt]
if h[x] == "EXTRA_PNGINFO":
if "extra_pnginfo" in extra_data:
input_data_all[x] = extra_data['extra_pnginfo']
input_data_all[x] = [extra_data['extra_pnginfo']]
if h[x] == "UNIQUE_ID":
input_data_all[x] = unique_id
input_data_all[x] = [unique_id]
return input_data_all
def recursive_execute(server, prompt, outputs, current_item, extra_data={}):
def map_node_over_list(obj, input_data_all, func, allow_interrupt=False):
# check if node wants the lists
intput_is_list = False
if hasattr(obj, "INPUT_IS_LIST"):
intput_is_list = obj.INPUT_IS_LIST
max_len_input = max([len(x) for x in input_data_all.values()])
# get a slice of inputs, repeat last input when list isn't long enough
def slice_dict(d, i):
d_new = dict()
for k,v in d.items():
d_new[k] = v[i if len(v) > i else -1]
return d_new
results = []
if intput_is_list:
if allow_interrupt:
nodes.before_node_execution()
results.append(getattr(obj, func)(**input_data_all))
else:
for i in range(max_len_input):
if allow_interrupt:
nodes.before_node_execution()
results.append(getattr(obj, func)(**slice_dict(input_data_all, i)))
return results
def get_output_data(obj, input_data_all):
results = []
uis = []
return_values = map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True)
for r in return_values:
if isinstance(r, dict):
if 'ui' in r:
uis.append(r['ui'])
if 'result' in r:
results.append(r['result'])
else:
results.append(r)
output = []
if len(results) > 0:
# check which outputs need concatenating
output_is_list = [False] * len(results[0])
if hasattr(obj, "OUTPUT_IS_LIST"):
output_is_list = obj.OUTPUT_IS_LIST
# merge node execution results
for i, is_list in zip(range(len(results[0])), output_is_list):
if is_list:
output.append([x for o in results for x in o[i]])
else:
output.append([o[i] for o in results])
ui = dict()
if len(uis) > 0:
ui = {k: [y for x in uis for y in x[k]] for k in uis[0].keys()}
return output, ui
def recursive_execute(server, prompt, outputs, current_item, extra_data, executed, prompt_id, outputs_ui):
unique_id = current_item
inputs = prompt[unique_id]['inputs']
class_type = prompt[unique_id]['class_type']
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
if unique_id in outputs:
return []
executed = []
return
for x in inputs:
input_data = inputs[x]
@ -55,22 +117,21 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data={}):
input_unique_id = input_data[0]
output_index = input_data[1]
if input_unique_id not in outputs:
executed += recursive_execute(server, prompt, outputs, input_unique_id, extra_data)
recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed, prompt_id, outputs_ui)
input_data_all = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data)
if server.client_id is not None:
server.last_node_id = unique_id
server.send_sync("executing", { "node": unique_id }, server.client_id)
server.send_sync("executing", { "node": unique_id, "prompt_id": prompt_id }, server.client_id)
obj = class_def()
nodes.before_node_execution()
outputs[unique_id] = getattr(obj, obj.FUNCTION)(**input_data_all)
if "ui" in outputs[unique_id]:
output_data, output_ui = get_output_data(obj, input_data_all)
outputs[unique_id] = output_data
if len(output_ui) > 0:
outputs_ui[unique_id] = output_ui
if server.client_id is not None:
server.send_sync("executed", { "node": unique_id, "output": outputs[unique_id]["ui"] }, server.client_id)
if "result" in outputs[unique_id]:
outputs[unique_id] = outputs[unique_id]["result"]
return executed + [unique_id]
server.send_sync("executed", { "node": unique_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id)
executed.add(unique_id)
def recursive_will_execute(prompt, outputs, current_item):
unique_id = current_item
@ -97,40 +158,45 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item
is_changed_old = ''
is_changed = ''
to_delete = False
if hasattr(class_def, 'IS_CHANGED'):
if unique_id in old_prompt and 'is_changed' in old_prompt[unique_id]:
is_changed_old = old_prompt[unique_id]['is_changed']
if 'is_changed' not in prompt[unique_id]:
input_data_all = get_input_data(inputs, class_def, unique_id, outputs)
if input_data_all is not None:
is_changed = class_def.IS_CHANGED(**input_data_all)
prompt[unique_id]['is_changed'] = is_changed
try:
#is_changed = class_def.IS_CHANGED(**input_data_all)
is_changed = map_node_over_list(class_def, input_data_all, "IS_CHANGED")
prompt[unique_id]['is_changed'] = is_changed
except:
to_delete = True
else:
is_changed = prompt[unique_id]['is_changed']
if unique_id not in outputs:
return True
to_delete = False
if is_changed != is_changed_old:
to_delete = True
elif unique_id not in old_prompt:
to_delete = True
elif inputs == old_prompt[unique_id]['inputs']:
for x in inputs:
input_data = inputs[x]
if not to_delete:
if is_changed != is_changed_old:
to_delete = True
elif unique_id not in old_prompt:
to_delete = True
elif inputs == old_prompt[unique_id]['inputs']:
for x in inputs:
input_data = inputs[x]
if isinstance(input_data, list):
input_unique_id = input_data[0]
output_index = input_data[1]
if input_unique_id in outputs:
to_delete = recursive_output_delete_if_changed(prompt, old_prompt, outputs, input_unique_id)
else:
to_delete = True
if to_delete:
break
else:
to_delete = True
if isinstance(input_data, list):
input_unique_id = input_data[0]
output_index = input_data[1]
if input_unique_id in outputs:
to_delete = recursive_output_delete_if_changed(prompt, old_prompt, outputs, input_unique_id)
else:
to_delete = True
if to_delete:
break
else:
to_delete = True
if to_delete:
d = outputs.pop(unique_id)
@ -140,10 +206,11 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item
class PromptExecutor:
def __init__(self, server):
self.outputs = {}
self.outputs_ui = {}
self.old_prompt = {}
self.server = server
def execute(self, prompt, extra_data={}):
def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
nodes.interrupt_processing(False)
if "client_id" in extra_data:
@ -151,40 +218,55 @@ class PromptExecutor:
else:
self.server.client_id = None
execution_start_time = time.perf_counter()
if self.server.client_id is not None:
self.server.send_sync("execution_start", { "prompt_id": prompt_id}, self.server.client_id)
with torch.inference_mode():
#delete cached outputs if nodes don't exist for them
to_delete = []
for o in self.outputs:
if o not in prompt:
to_delete += [o]
for o in to_delete:
d = self.outputs.pop(o)
del d
for x in prompt:
recursive_output_delete_if_changed(prompt, self.old_prompt, self.outputs, x)
current_outputs = set(self.outputs.keys())
executed = []
for x in list(self.outputs_ui.keys()):
if x not in current_outputs:
d = self.outputs_ui.pop(x)
del d
if self.server.client_id is not None:
self.server.send_sync("execution_cached", { "nodes": list(current_outputs) , "prompt_id": prompt_id}, self.server.client_id)
executed = set()
try:
to_execute = []
for x in prompt:
class_ = nodes.NODE_CLASS_MAPPINGS[prompt[x]['class_type']]
if hasattr(class_, 'OUTPUT_NODE'):
to_execute += [(0, x)]
for x in list(execute_outputs):
to_execute += [(0, x)]
while len(to_execute) > 0:
#always execute the output that depends on the least amount of unexecuted nodes first
to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1])), a[-1]), to_execute)))
x = to_execute.pop(0)[-1]
class_ = nodes.NODE_CLASS_MAPPINGS[prompt[x]['class_type']]
if hasattr(class_, 'OUTPUT_NODE'):
if class_.OUTPUT_NODE == True:
valid = False
try:
m = validate_inputs(prompt, x)
valid = m[0]
except:
valid = False
if valid:
executed += recursive_execute(self.server, prompt, self.outputs, x, extra_data)
recursive_execute(self.server, prompt, self.outputs, x, extra_data, executed, prompt_id, self.outputs_ui)
except Exception as e:
print(traceback.format_exc())
if isinstance(e, comfy.model_management.InterruptProcessingException):
print("Processing interrupted")
else:
message = str(traceback.format_exc())
print(message)
if self.server.client_id is not None:
self.server.send_sync("execution_error", { "message": message, "prompt_id": prompt_id }, self.server.client_id)
to_delete = []
for o in self.outputs:
if o not in current_outputs:
if (o not in current_outputs) and (o not in executed):
to_delete += [o]
if o in self.old_prompt:
d = self.old_prompt.pop(o)
@ -192,24 +274,23 @@ class PromptExecutor:
for o in to_delete:
d = self.outputs.pop(o)
del d
else:
executed = set(executed)
finally:
for x in executed:
self.old_prompt[x] = copy.deepcopy(prompt[x])
finally:
self.server.last_node_id = None
if self.server.client_id is not None:
self.server.send_sync("executing", { "node": None }, self.server.client_id)
self.server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, self.server.client_id)
print("Prompt executed in {:.2f} seconds".format(time.perf_counter() - execution_start_time))
gc.collect()
if torch.cuda.is_available():
if torch.version.cuda: #This seems to make things worse on ROCm so I only do it for cuda
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
comfy.model_management.soft_empty_cache()
def validate_inputs(prompt, item):
def validate_inputs(prompt, item, validated):
unique_id = item
if unique_id in validated:
return validated[unique_id]
inputs = prompt[unique_id]['inputs']
class_type = prompt[unique_id]['class_type']
obj_class = nodes.NODE_CLASS_MAPPINGS[class_type]
@ -230,8 +311,9 @@ def validate_inputs(prompt, item):
r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES
if r[val[1]] != type_input:
return (False, "Return type mismatch. {}, {}, {} != {}".format(class_type, x, r[val[1]], type_input))
r = validate_inputs(prompt, o_id)
r = validate_inputs(prompt, o_id, validated)
if r[0] == False:
validated[o_id] = r
return r
else:
if type_input == "INT":
@ -250,10 +332,21 @@ def validate_inputs(prompt, item):
if "max" in info[1] and val > info[1]["max"]:
return (False, "Value bigger than max. {}, {}".format(class_type, x))
if isinstance(type_input, list):
if val not in type_input:
return (False, "Value not in list. {}, {}: {} not in {}".format(class_type, x, val, type_input))
return (True, "")
if hasattr(obj_class, "VALIDATE_INPUTS"):
input_data_all = get_input_data(inputs, obj_class, unique_id)
#ret = obj_class.VALIDATE_INPUTS(**input_data_all)
ret = map_node_over_list(obj_class, input_data_all, "VALIDATE_INPUTS")
for r in ret:
if r != True:
return (False, "{}, {}".format(class_type, r))
else:
if isinstance(type_input, list):
if val not in type_input:
return (False, "Value not in list. {}, {}: {} not in {}".format(class_type, x, val, type_input))
ret = (True, "")
validated[unique_id] = ret
return ret
def validate_prompt(prompt):
outputs = set()
@ -267,19 +360,21 @@ def validate_prompt(prompt):
good_outputs = set()
errors = []
validated = {}
for o in outputs:
valid = False
reason = ""
try:
m = validate_inputs(prompt, o)
m = validate_inputs(prompt, o, validated)
valid = m[0]
reason = m[1]
except:
except Exception as e:
print(traceback.format_exc())
valid = False
reason = "Parsing error"
if valid == True:
good_outputs.add(x)
good_outputs.add(o)
else:
print("Failed to validate prompt for output {} {}".format(o, reason))
print("output will be ignored")
@ -289,7 +384,7 @@ def validate_prompt(prompt):
errors_list = "\n".join(set(map(lambda a: "{}".format(a[1]), errors)))
return (False, "Prompt has no properly connected outputs\n {}".format(errors_list))
return (True, "")
return (True, "", list(good_outputs))
class PromptQueue:
@ -325,8 +420,7 @@ class PromptQueue:
prompt = self.currently_running.pop(item_id)
self.history[prompt[1]] = { "prompt": prompt, "outputs": {} }
for o in outputs:
if "ui" in outputs[o]:
self.history[prompt[1]]["outputs"][o] = outputs[o]["ui"]
self.history[prompt[1]]["outputs"][o] = outputs[o]
self.server.queue_updated()
def get_current_queue(self):

View File

@ -13,11 +13,13 @@ a111:
models/ESRGAN
models/SwinIR
embeddings: embeddings
hypernetworks: models/hypernetworks
controlnet: models/ControlNet
#other_ui:
# base_path: path/to/ui
# checkpoints: models/checkpoints
# gligen: models/gligen
# custom_nodes: path/custom_nodes

View File

@ -12,8 +12,8 @@ except:
folder_names_and_paths = {}
models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
base_path = os.path.dirname(os.path.realpath(__file__))
models_dir = os.path.join(base_path, "models")
folder_names_and_paths["checkpoints"] = ([os.path.join(models_dir, "checkpoints")], supported_ckpt_extensions)
folder_names_and_paths["configs"] = ([os.path.join(models_dir, "configs")], [".yaml"])
@ -26,8 +26,14 @@ folder_names_and_paths["embeddings"] = ([os.path.join(models_dir, "embeddings")]
folder_names_and_paths["diffusers"] = ([os.path.join(models_dir, "diffusers")], ["folder"])
folder_names_and_paths["controlnet"] = ([os.path.join(models_dir, "controlnet"), os.path.join(models_dir, "t2i_adapter")], supported_pt_extensions)
folder_names_and_paths["gligen"] = ([os.path.join(models_dir, "gligen")], supported_pt_extensions)
folder_names_and_paths["upscale_models"] = ([os.path.join(models_dir, "upscale_models")], supported_pt_extensions)
folder_names_and_paths["custom_nodes"] = ([os.path.join(base_path, "custom_nodes")], [])
folder_names_and_paths["hypernetworks"] = ([os.path.join(models_dir, "hypernetworks")], supported_pt_extensions)
output_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "output")
temp_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp")
input_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input")
@ -63,6 +69,46 @@ def get_directory_by_type(type_name):
return None
# determine base_dir rely on annotation if name is 'filename.ext [annotation]' format
# otherwise use default_path as base_dir
def annotated_filepath(name):
if name.endswith("[output]"):
base_dir = get_output_directory()
name = name[:-9]
elif name.endswith("[input]"):
base_dir = get_input_directory()
name = name[:-8]
elif name.endswith("[temp]"):
base_dir = get_temp_directory()
name = name[:-7]
else:
return name, None
return name, base_dir
def get_annotated_filepath(name, default_dir=None):
name, base_dir = annotated_filepath(name)
if base_dir is None:
if default_dir is not None:
base_dir = default_dir
else:
base_dir = get_input_directory() # fallback path
return os.path.join(base_dir, name)
def exists_annotated_filepath(name):
name, base_dir = annotated_filepath(name)
if base_dir is None:
base_dir = get_input_directory() # fallback path
filepath = os.path.join(base_dir, name)
return os.path.exists(filepath)
def add_model_folder_path(folder_name, full_folder_path):
global folder_names_and_paths
if folder_name in folder_names_and_paths:
@ -101,4 +147,37 @@ def get_filename_list(folder_name):
output_list.update(filter_files_extensions(recursive_search(x), folders[1]))
return sorted(list(output_list))
def get_save_image_path(filename_prefix, output_dir, image_width=0, image_height=0):
def map_filename(filename):
prefix_len = len(os.path.basename(filename_prefix))
prefix = filename[:prefix_len + 1]
try:
digits = int(filename[prefix_len + 1:].split('_')[0])
except:
digits = 0
return (digits, prefix)
def compute_vars(input, image_width, image_height):
input = input.replace("%width%", str(image_width))
input = input.replace("%height%", str(image_height))
return input
filename_prefix = compute_vars(filename_prefix, image_width, image_height)
subfolder = os.path.dirname(os.path.normpath(filename_prefix))
filename = os.path.basename(os.path.normpath(filename_prefix))
full_output_folder = os.path.join(output_dir, subfolder)
if os.path.commonpath((output_dir, os.path.abspath(full_output_folder))) != output_dir:
print("Saving image outside the output folder is not allowed.")
return {}
try:
counter = max(filter(lambda a: a[1][:-1] == filename and a[1][-1] == "_", map(map_filename, os.listdir(full_output_folder))))[0] + 1
except ValueError:
counter = 1
except FileNotFoundError:
os.makedirs(full_output_folder, exist_ok=True)
counter = 1
return full_output_folder, filename, counter, subfolder, filename_prefix

40
main.py
View File

@ -5,6 +5,7 @@ import shutil
import threading
from comfy.cli_args import args
import comfy.utils
if os.name == "nt":
import logging
@ -32,21 +33,16 @@ def prompt_worker(q, server):
e = execution.PromptExecutor(server)
while True:
item, item_id = q.get()
e.execute(item[-2], item[-1])
q.task_done(item_id, e.outputs)
e.execute(item[2], item[1], item[3], item[4])
q.task_done(item_id, e.outputs_ui)
async def run(server, address='', port=8188, verbose=True, call_on_start=None):
await asyncio.gather(server.start(address, port, verbose, call_on_start), server.publish_loop())
def hijack_progress(server):
from tqdm.auto import tqdm
orig_func = getattr(tqdm, "update")
def wrapped_func(*args, **kwargs):
pbar = args[0]
v = orig_func(*args, **kwargs)
server.send_sync("progress", { "value": pbar.n, "max": pbar.total}, server.client_id)
return v
setattr(tqdm, "update", wrapped_func)
def hook(value, total):
server.send_sync("progress", { "value": value, "max": total}, server.client_id)
comfy.utils.set_progress_bar_global_hook(hook)
def cleanup_temp():
temp_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp")
@ -81,16 +77,6 @@ if __name__ == "__main__":
server = server.PromptServer(loop)
q = execution.PromptQueue(server)
init_custom_nodes()
server.add_routes()
hijack_progress(server)
threading.Thread(target=prompt_worker, daemon=True, args=(q,server,)).start()
address = args.listen
dont_print = args.dont_print_server
extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml")
if os.path.isfile(extra_model_paths_config_path):
load_extra_path_config(extra_model_paths_config_path)
@ -99,18 +85,22 @@ if __name__ == "__main__":
for config_path in itertools.chain(*args.extra_model_paths_config):
load_extra_path_config(config_path)
init_custom_nodes()
server.add_routes()
hijack_progress(server)
threading.Thread(target=prompt_worker, daemon=True, args=(q,server,)).start()
if args.output_directory:
output_dir = os.path.abspath(args.output_directory)
print(f"Setting output directory to: {output_dir}")
folder_paths.set_output_directory(output_dir)
port = args.port
if args.quick_test_for_ci:
exit(0)
call_on_start = None
if args.windows_standalone_build:
if args.auto_launch:
def startup_server(address, port):
import webbrowser
webbrowser.open("http://{}:{}".format(address, port))
@ -118,10 +108,10 @@ if __name__ == "__main__":
if os.name == "nt":
try:
loop.run_until_complete(run(server, address=address, port=port, verbose=not dont_print, call_on_start=call_on_start))
loop.run_until_complete(run(server, address=args.listen, port=args.port, verbose=not args.dont_print_server, call_on_start=call_on_start))
except KeyboardInterrupt:
pass
else:
loop.run_until_complete(run(server, address=address, port=port, verbose=not dont_print, call_on_start=call_on_start))
loop.run_until_complete(run(server, address=args.listen, port=args.port, verbose=not args.dont_print_server, call_on_start=call_on_start))
cleanup_temp()

View File

569
nodes.py
View File

@ -5,10 +5,13 @@ import sys
import json
import hashlib
import traceback
import math
import time
from PIL import Image
from PIL.PngImagePlugin import PngInfo
import numpy as np
import safetensors.torch
sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy"))
@ -16,21 +19,23 @@ sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "co
import comfy.diffusers_convert
import comfy.samplers
import comfy.sample
import comfy.sd
import comfy.utils
import comfy.clip_vision
import model_management
import comfy.model_management
import importlib
import folder_paths
def before_node_execution():
model_management.throw_exception_if_processing_interrupted()
comfy.model_management.throw_exception_if_processing_interrupted()
def interrupt_processing(value=True):
model_management.interrupt_current_processing(value)
comfy.model_management.interrupt_current_processing(value)
MAX_RESOLUTION=8192
@ -58,14 +63,44 @@ class ConditioningCombine:
def combine(self, conditioning_1, conditioning_2):
return (conditioning_1 + conditioning_2, )
class ConditioningAverage :
@classmethod
def INPUT_TYPES(s):
return {"required": {"conditioning_to": ("CONDITIONING", ), "conditioning_from": ("CONDITIONING", ),
"conditioning_to_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
}}
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "addWeighted"
CATEGORY = "conditioning"
def addWeighted(self, conditioning_to, conditioning_from, conditioning_to_strength):
out = []
if len(conditioning_from) > 1:
print("Warning: ConditioningAverage conditioning_from contains more than 1 cond, only the first one will actually be applied to conditioning_to.")
cond_from = conditioning_from[0][0]
for i in range(len(conditioning_to)):
t1 = conditioning_to[i][0]
t0 = cond_from[:,:t1.shape[1]]
if t0.shape[1] < t1.shape[1]:
t0 = torch.cat([t0] + [torch.zeros((1, (t1.shape[1] - t0.shape[1]), t1.shape[2]))], dim=1)
tw = torch.mul(t1, conditioning_to_strength) + torch.mul(t0, (1.0 - conditioning_to_strength))
n = [tw, conditioning_to[i][1].copy()]
out.append(n)
return (out, )
class ConditioningSetArea:
@classmethod
def INPUT_TYPES(s):
return {"required": {"conditioning": ("CONDITIONING", ),
"width": ("INT", {"default": 64, "min": 64, "max": MAX_RESOLUTION, "step": 64}),
"height": ("INT", {"default": 64, "min": 64, "max": MAX_RESOLUTION, "step": 64}),
"x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 64}),
"y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 64}),
"width": ("INT", {"default": 64, "min": 64, "max": MAX_RESOLUTION, "step": 8}),
"height": ("INT", {"default": 64, "min": 64, "max": MAX_RESOLUTION, "step": 8}),
"x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
"y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
}}
RETURN_TYPES = ("CONDITIONING",)
@ -73,21 +108,46 @@ class ConditioningSetArea:
CATEGORY = "conditioning"
def append(self, conditioning, width, height, x, y, strength, min_sigma=0.0, max_sigma=99.0):
def append(self, conditioning, width, height, x, y, strength):
c = []
for t in conditioning:
n = [t[0], t[1].copy()]
n[1]['area'] = (height // 8, width // 8, y // 8, x // 8)
n[1]['strength'] = strength
n[1]['min_sigma'] = min_sigma
n[1]['max_sigma'] = max_sigma
n[1]['set_area_to_bounds'] = False
c.append(n)
return (c, )
class ConditioningSetMask:
@classmethod
def INPUT_TYPES(s):
return {"required": {"conditioning": ("CONDITIONING", ),
"mask": ("MASK", ),
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
"set_cond_area": (["default", "mask bounds"],),
}}
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "append"
CATEGORY = "conditioning"
def append(self, conditioning, mask, set_cond_area, strength):
c = []
set_area_to_bounds = False
if set_cond_area != "default":
set_area_to_bounds = True
if len(mask.shape) < 3:
mask = mask.unsqueeze(0)
for t in conditioning:
n = [t[0], t[1].copy()]
_, h, w = mask.shape
n[1]['mask'] = mask
n[1]['set_area_to_bounds'] = set_area_to_bounds
n[1]['mask_strength'] = strength
c.append(n)
return (c, )
class VAEDecode:
def __init__(self, device="cpu"):
self.device = device
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples": ("LATENT", ), "vae": ("VAE", )}}
@ -100,9 +160,6 @@ class VAEDecode:
return (vae.decode(samples["samples"]), )
class VAEDecodeTiled:
def __init__(self, device="cpu"):
self.device = device
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples": ("LATENT", ), "vae": ("VAE", )}}
@ -115,9 +172,6 @@ class VAEDecodeTiled:
return (vae.decode_tiled(samples["samples"]), )
class VAEEncode:
def __init__(self, device="cpu"):
self.device = device
@classmethod
def INPUT_TYPES(s):
return {"required": { "pixels": ("IMAGE", ), "vae": ("VAE", )}}
@ -126,20 +180,22 @@ class VAEEncode:
CATEGORY = "latent"
def encode(self, vae, pixels):
x = (pixels.shape[1] // 64) * 64
y = (pixels.shape[2] // 64) * 64
@staticmethod
def vae_encode_crop_pixels(pixels):
x = (pixels.shape[1] // 8) * 8
y = (pixels.shape[2] // 8) * 8
if pixels.shape[1] != x or pixels.shape[2] != y:
pixels = pixels[:,:x,:y,:]
t = vae.encode(pixels[:,:,:,:3])
x_offset = (pixels.shape[1] % 8) // 2
y_offset = (pixels.shape[2] % 8) // 2
pixels = pixels[:, x_offset:x + x_offset, y_offset:y + y_offset, :]
return pixels
def encode(self, vae, pixels):
pixels = self.vae_encode_crop_pixels(pixels)
t = vae.encode(pixels[:,:,:,:3])
return ({"samples":t}, )
class VAEEncodeTiled:
def __init__(self, device="cpu"):
self.device = device
@classmethod
def INPUT_TYPES(s):
return {"required": { "pixels": ("IMAGE", ), "vae": ("VAE", )}}
@ -149,46 +205,123 @@ class VAEEncodeTiled:
CATEGORY = "_for_testing"
def encode(self, vae, pixels):
x = (pixels.shape[1] // 64) * 64
y = (pixels.shape[2] // 64) * 64
if pixels.shape[1] != x or pixels.shape[2] != y:
pixels = pixels[:,:x,:y,:]
pixels = VAEEncode.vae_encode_crop_pixels(pixels)
t = vae.encode_tiled(pixels[:,:,:,:3])
return ({"samples":t}, )
class VAEEncodeForInpaint:
def __init__(self, device="cpu"):
self.device = device
class VAEEncodeForInpaint:
@classmethod
def INPUT_TYPES(s):
return {"required": { "pixels": ("IMAGE", ), "vae": ("VAE", ), "mask": ("MASK", )}}
return {"required": { "pixels": ("IMAGE", ), "vae": ("VAE", ), "mask": ("MASK", ), "grow_mask_by": ("INT", {"default": 6, "min": 0, "max": 64, "step": 1}),}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "encode"
CATEGORY = "latent/inpaint"
def encode(self, vae, pixels, mask):
x = (pixels.shape[1] // 64) * 64
y = (pixels.shape[2] // 64) * 64
mask = torch.nn.functional.interpolate(mask[None,None,], size=(pixels.shape[1], pixels.shape[2]), mode="bilinear")[0][0]
def encode(self, vae, pixels, mask, grow_mask_by=6):
x = (pixels.shape[1] // 8) * 8
y = (pixels.shape[2] // 8) * 8
mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(pixels.shape[1], pixels.shape[2]), mode="bilinear")
pixels = pixels.clone()
if pixels.shape[1] != x or pixels.shape[2] != y:
pixels = pixels[:,:x,:y,:]
mask = mask[:x,:y]
x_offset = (pixels.shape[1] % 8) // 2
y_offset = (pixels.shape[2] % 8) // 2
pixels = pixels[:,x_offset:x + x_offset, y_offset:y + y_offset,:]
mask = mask[:,:,x_offset:x + x_offset, y_offset:y + y_offset]
#grow mask by a few pixels to keep things seamless in latent space
kernel_tensor = torch.ones((1, 1, 6, 6))
mask_erosion = torch.clamp(torch.nn.functional.conv2d((mask.round())[None], kernel_tensor, padding=3), 0, 1)
m = (1.0 - mask.round())
if grow_mask_by == 0:
mask_erosion = mask
else:
kernel_tensor = torch.ones((1, 1, grow_mask_by, grow_mask_by))
padding = math.ceil((grow_mask_by - 1) / 2)
mask_erosion = torch.clamp(torch.nn.functional.conv2d(mask.round(), kernel_tensor, padding=padding), 0, 1)
m = (1.0 - mask.round()).squeeze(1)
for i in range(3):
pixels[:,:,:,i] -= 0.5
pixels[:,:,:,i] *= m
pixels[:,:,:,i] += 0.5
t = vae.encode(pixels)
return ({"samples":t, "noise_mask": (mask_erosion[0][:x,:y].round())}, )
return ({"samples":t, "noise_mask": (mask_erosion[:,:,:x,:y].round())}, )
class SaveLatent:
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples": ("LATENT", ),
"filename_prefix": ("STRING", {"default": "latents/ComfyUI"})},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
}
RETURN_TYPES = ()
FUNCTION = "save"
OUTPUT_NODE = True
CATEGORY = "_for_testing"
def save(self, samples, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
# support save metadata for latent sharing
prompt_info = ""
if prompt is not None:
prompt_info = json.dumps(prompt)
metadata = {"prompt": prompt_info}
if extra_pnginfo is not None:
for x in extra_pnginfo:
metadata[x] = json.dumps(extra_pnginfo[x])
file = f"{filename}_{counter:05}_.latent"
file = os.path.join(full_output_folder, file)
output = {}
output["latent_tensor"] = samples["samples"]
safetensors.torch.save_file(output, file, metadata=metadata)
return {}
class LoadLatent:
@classmethod
def INPUT_TYPES(s):
input_dir = folder_paths.get_input_directory()
files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f)) and f.endswith(".latent")]
return {"required": {"latent": [sorted(files), ]}, }
CATEGORY = "_for_testing"
RETURN_TYPES = ("LATENT", )
FUNCTION = "load"
def load(self, latent):
latent_path = folder_paths.get_annotated_filepath(latent)
latent = safetensors.torch.load_file(latent_path, device="cpu")
samples = {"samples": latent["latent_tensor"].float()}
return (samples, )
@classmethod
def IS_CHANGED(s, latent):
image_path = folder_paths.get_annotated_filepath(latent)
m = hashlib.sha256()
with open(image_path, 'rb') as f:
m.update(f.read())
return m.digest().hex()
@classmethod
def VALIDATE_INPUTS(s, latent):
if not folder_paths.exists_annotated_filepath(latent):
return "Invalid latent file: {}".format(latent)
return True
class CheckpointLoader:
@classmethod
@ -226,7 +359,10 @@ class DiffusersLoader:
paths = []
for search_path in folder_paths.get_folder_paths("diffusers"):
if os.path.exists(search_path):
paths += next(os.walk(search_path))[1]
for root, subdir, files in os.walk(search_path, followlinks=True):
if "model_index.json" in files:
paths.append(os.path.relpath(root, start=search_path))
return {"required": {"model_path": (paths,), }}
RETURN_TYPES = ("MODEL", "CLIP", "VAE")
FUNCTION = "load_checkpoint"
@ -236,12 +372,12 @@ class DiffusersLoader:
def load_checkpoint(self, model_path, output_vae=True, output_clip=True):
for search_path in folder_paths.get_folder_paths("diffusers"):
if os.path.exists(search_path):
paths = next(os.walk(search_path))[1]
if model_path in paths:
model_path = os.path.join(search_path, model_path)
path = os.path.join(search_path, model_path)
if os.path.exists(path):
model_path = path
break
return comfy.diffusers_convert.load_diffusers(model_path, fp16=model_management.should_use_fp16(), output_vae=output_vae, output_clip=output_clip, embedding_directory=folder_paths.get_folder_paths("embeddings"))
return comfy.diffusers_convert.load_diffusers(model_path, fp16=comfy.model_management.should_use_fp16(), output_vae=output_vae, output_clip=output_clip, embedding_directory=folder_paths.get_folder_paths("embeddings"))
class unCLIPCheckpointLoader:
@ -373,7 +509,6 @@ class ControlNetApply:
def apply_controlnet(self, conditioning, control_net, image, strength):
c = []
control_hint = image.movedim(-1,1)
print(control_hint.shape)
for t in conditioning:
n = [t[0], t[1].copy()]
c_net = control_net.copy().set_cond_hint(control_hint, strength)
@ -490,6 +625,51 @@ class unCLIPConditioning:
c.append(n)
return (c, )
class GLIGENLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "gligen_name": (folder_paths.get_filename_list("gligen"), )}}
RETURN_TYPES = ("GLIGEN",)
FUNCTION = "load_gligen"
CATEGORY = "loaders"
def load_gligen(self, gligen_name):
gligen_path = folder_paths.get_full_path("gligen", gligen_name)
gligen = comfy.sd.load_gligen(gligen_path)
return (gligen,)
class GLIGENTextBoxApply:
@classmethod
def INPUT_TYPES(s):
return {"required": {"conditioning_to": ("CONDITIONING", ),
"clip": ("CLIP", ),
"gligen_textbox_model": ("GLIGEN", ),
"text": ("STRING", {"multiline": True}),
"width": ("INT", {"default": 64, "min": 8, "max": MAX_RESOLUTION, "step": 8}),
"height": ("INT", {"default": 64, "min": 8, "max": MAX_RESOLUTION, "step": 8}),
"x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
"y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
}}
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "append"
CATEGORY = "conditioning/gligen"
def append(self, conditioning_to, clip, gligen_textbox_model, text, width, height, x, y):
c = []
cond, cond_pooled = clip.encode_from_tokens(clip.tokenize(text), return_pooled=True)
for t in conditioning_to:
n = [t[0], t[1].copy()]
position_params = [(cond_pooled, height // 8, width // 8, y // 8, x // 8)]
prev = []
if "gligen" in n[1]:
prev = n[1]['gligen'][2]
n[1]['gligen'] = ("position", gligen_textbox_model, prev + position_params)
c.append(n)
return (c, )
class EmptyLatentImage:
def __init__(self, device="cpu"):
@ -497,8 +677,8 @@ class EmptyLatentImage:
@classmethod
def INPUT_TYPES(s):
return {"required": { "width": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 64}),
"height": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 64}),
return {"required": { "width": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 8}),
"height": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 8}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 64})}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "generate"
@ -510,6 +690,63 @@ class EmptyLatentImage:
return ({"samples":latent}, )
class LatentFromBatch:
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples": ("LATENT",),
"batch_index": ("INT", {"default": 0, "min": 0, "max": 63}),
"length": ("INT", {"default": 1, "min": 1, "max": 64}),
}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "frombatch"
CATEGORY = "latent/batch"
def frombatch(self, samples, batch_index, length):
s = samples.copy()
s_in = samples["samples"]
batch_index = min(s_in.shape[0] - 1, batch_index)
length = min(s_in.shape[0] - batch_index, length)
s["samples"] = s_in[batch_index:batch_index + length].clone()
if "noise_mask" in samples:
masks = samples["noise_mask"]
if masks.shape[0] == 1:
s["noise_mask"] = masks.clone()
else:
if masks.shape[0] < s_in.shape[0]:
masks = masks.repeat(math.ceil(s_in.shape[0] / masks.shape[0]), 1, 1, 1)[:s_in.shape[0]]
s["noise_mask"] = masks[batch_index:batch_index + length].clone()
if "batch_index" not in s:
s["batch_index"] = [x for x in range(batch_index, batch_index+length)]
else:
s["batch_index"] = samples["batch_index"][batch_index:batch_index + length]
return (s,)
class RepeatLatentBatch:
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples": ("LATENT",),
"amount": ("INT", {"default": 1, "min": 1, "max": 64}),
}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "repeat"
CATEGORY = "latent/batch"
def repeat(self, samples, amount):
s = samples.copy()
s_in = samples["samples"]
s["samples"] = s_in.repeat((amount, 1,1,1))
if "noise_mask" in samples and samples["noise_mask"].shape[0] > 1:
masks = samples["noise_mask"]
if masks.shape[0] < s_in.shape[0]:
masks = masks.repeat(math.ceil(s_in.shape[0] / masks.shape[0]), 1, 1, 1)[:s_in.shape[0]]
s["noise_mask"] = samples["noise_mask"].repeat((amount, 1,1,1))
if "batch_index" in s:
offset = max(s["batch_index"]) - min(s["batch_index"]) + 1
s["batch_index"] = s["batch_index"] + [x + (i * offset) for i in range(1, amount) for x in s["batch_index"]]
return (s,)
class LatentUpscale:
upscale_methods = ["nearest-exact", "bilinear", "area"]
@ -517,9 +754,10 @@ class LatentUpscale:
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples": ("LATENT",), "upscale_method": (s.upscale_methods,),
"width": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 64}),
"height": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 64}),
"width": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 8}),
"height": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 8}),
"crop": ("BOOL", {"on": "center", "off": "disabled"}),}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "upscale"
@ -621,8 +859,8 @@ class LatentCrop:
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples": ("LATENT",),
"width": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 64}),
"height": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 64}),
"width": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 8}),
"height": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 8}),
"x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
"y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
}}
@ -647,16 +885,6 @@ class LatentCrop:
new_width = width // 8
to_x = new_width + x
to_y = new_height + y
def enforce_image_dim(d, to_d, max_d):
if to_d > max_d:
leftover = (to_d - max_d) % 8
to_d = max_d
d -= leftover
return (d, to_d)
#make sure size is always multiple of 64
x, to_x = enforce_image_dim(x, to_x, samples.shape[3])
y, to_y = enforce_image_dim(y, to_y, samples.shape[2])
s['samples'] = samples[:,:,y:to_y, x:to_x]
return (s,)
@ -673,72 +901,30 @@ class SetLatentNoiseMask:
def set_mask(self, samples, mask):
s = samples.copy()
s["noise_mask"] = mask
s["noise_mask"] = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1]))
return (s,)
def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False):
device = comfy.model_management.get_torch_device()
latent_image = latent["samples"]
noise_mask = None
device = model_management.get_torch_device()
if disable_noise:
noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu")
else:
noise = torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=torch.manual_seed(seed), device="cpu")
batch_inds = latent["batch_index"] if "batch_index" in latent else None
noise = comfy.sample.prepare_noise(latent_image, seed, batch_inds)
noise_mask = None
if "noise_mask" in latent:
noise_mask = latent['noise_mask']
noise_mask = torch.nn.functional.interpolate(noise_mask[None,None,], size=(noise.shape[2], noise.shape[3]), mode="bilinear")
noise_mask = noise_mask.round()
noise_mask = torch.cat([noise_mask] * noise.shape[1], dim=1)
noise_mask = torch.cat([noise_mask] * noise.shape[0])
noise_mask = noise_mask.to(device)
noise_mask = latent["noise_mask"]
real_model = None
model_management.load_model_gpu(model)
real_model = model.model
noise = noise.to(device)
latent_image = latent_image.to(device)
positive_copy = []
negative_copy = []
control_nets = []
for p in positive:
t = p[0]
if t.shape[0] < noise.shape[0]:
t = torch.cat([t] * noise.shape[0])
t = t.to(device)
if 'control' in p[1]:
control_nets += [p[1]['control']]
positive_copy += [[t] + p[1:]]
for n in negative:
t = n[0]
if t.shape[0] < noise.shape[0]:
t = torch.cat([t] * noise.shape[0])
t = t.to(device)
if 'control' in n[1]:
control_nets += [n[1]['control']]
negative_copy += [[t] + n[1:]]
control_net_models = []
for x in control_nets:
control_net_models += x.get_control_models()
model_management.load_controlnet_gpu(control_net_models)
if sampler_name in comfy.samplers.KSampler.SAMPLERS:
sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)
else:
#other samplers
pass
samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask)
samples = samples.cpu()
for c in control_nets:
c.cleanup()
pbar = comfy.utils.ProgressBar(steps)
def callback(step, x0, x, total_steps):
pbar.update_absolute(step + 1, total_steps)
samples = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image,
denoise=denoise, disable_noise=disable_noise, start_step=start_step, last_step=last_step,
force_full_denoise=force_full_denoise, noise_mask=noise_mask, callback=callback)
out = latent.copy()
out["samples"] = samples
return (out, )
@ -817,39 +1003,7 @@ class SaveImage:
CATEGORY = "image"
def save_images(self, images, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
def map_filename(filename):
prefix_len = len(os.path.basename(filename_prefix))
prefix = filename[:prefix_len + 1]
try:
digits = int(filename[prefix_len + 1:].split('_')[0])
except:
digits = 0
return (digits, prefix)
def compute_vars(input):
input = input.replace("%width%", str(images[0].shape[1]))
input = input.replace("%height%", str(images[0].shape[0]))
return input
filename_prefix = compute_vars(filename_prefix)
subfolder = os.path.dirname(os.path.normpath(filename_prefix))
filename = os.path.basename(os.path.normpath(filename_prefix))
full_output_folder = os.path.join(self.output_dir, subfolder)
if os.path.commonpath((self.output_dir, os.path.abspath(full_output_folder))) != self.output_dir:
print("Saving image outside the output folder is not allowed.")
return {}
try:
counter = max(filter(lambda a: a[1][:-1] == filename and a[1][-1] == "_", map(map_filename, os.listdir(full_output_folder))))[0] + 1
except ValueError:
counter = 1
except FileNotFoundError:
os.makedirs(full_output_folder, exist_ok=True)
counter = 1
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0])
results = list()
for image in images:
i = 255. * image.cpu().numpy()
@ -867,7 +1021,7 @@ class SaveImage:
"filename": file,
"subfolder": subfolder,
"type": self.type
});
})
counter += 1
return { "ui": { "images": results } }
@ -888,8 +1042,9 @@ class LoadImage:
@classmethod
def INPUT_TYPES(s):
input_dir = folder_paths.get_input_directory()
files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))]
return {"required":
{"image": (sorted(os.listdir(input_dir)), )},
{"image": (sorted(files), )},
}
CATEGORY = "image"
@ -897,8 +1052,7 @@ class LoadImage:
RETURN_TYPES = ("IMAGE", "MASK")
FUNCTION = "load_image"
def load_image(self, image):
input_dir = folder_paths.get_input_directory()
image_path = os.path.join(input_dir, image)
image_path = folder_paths.get_annotated_filepath(image)
i = Image.open(image_path)
image = i.convert("RGB")
image = np.array(image).astype(np.float32) / 255.0
@ -912,29 +1066,36 @@ class LoadImage:
@classmethod
def IS_CHANGED(s, image):
input_dir = folder_paths.get_input_directory()
image_path = os.path.join(input_dir, image)
image_path = folder_paths.get_annotated_filepath(image)
m = hashlib.sha256()
with open(image_path, 'rb') as f:
m.update(f.read())
return m.digest().hex()
@classmethod
def VALIDATE_INPUTS(s, image):
if not folder_paths.exists_annotated_filepath(image):
return "Invalid image file: {}".format(image)
return True
class LoadImageMask:
_color_channels = ["alpha", "red", "green", "blue"]
@classmethod
def INPUT_TYPES(s):
input_dir = folder_paths.get_input_directory()
files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))]
return {"required":
{"image": (sorted(os.listdir(input_dir)), ),
"channel": (["alpha", "red", "green", "blue"], ),}
{"image": (sorted(files), ),
"channel": (s._color_channels, ), }
}
CATEGORY = "image"
CATEGORY = "mask"
RETURN_TYPES = ("MASK",)
FUNCTION = "load_image"
def load_image(self, image, channel):
input_dir = folder_paths.get_input_directory()
image_path = os.path.join(input_dir, image)
image_path = folder_paths.get_annotated_filepath(image)
i = Image.open(image_path)
if i.getbands() != ("R", "G", "B", "A"):
i = i.convert("RGBA")
@ -951,13 +1112,22 @@ class LoadImageMask:
@classmethod
def IS_CHANGED(s, image, channel):
input_dir = folder_paths.get_input_directory()
image_path = os.path.join(input_dir, image)
image_path = folder_paths.get_annotated_filepath(image)
m = hashlib.sha256()
with open(image_path, 'rb') as f:
m.update(f.read())
return m.digest().hex()
@classmethod
def VALIDATE_INPUTS(s, image, channel):
if not folder_paths.exists_annotated_filepath(image):
return "Invalid image file: {}".format(image)
if channel not in s._color_channels:
return "Invalid color channel: {}".format(channel)
return True
class ImageScale:
upscale_methods = ["nearest-exact", "bilinear", "area"]
@ -1002,10 +1172,10 @@ class ImagePadForOutpaint:
return {
"required": {
"image": ("IMAGE",),
"left": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 64}),
"top": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 64}),
"right": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 64}),
"bottom": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 64}),
"left": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
"top": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
"right": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
"bottom": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
"feathering": ("INT", {"default": 40, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
}
}
@ -1069,6 +1239,8 @@ NODE_CLASS_MAPPINGS = {
"VAELoader": VAELoader,
"EmptyLatentImage": EmptyLatentImage,
"LatentUpscale": LatentUpscale,
"LatentFromBatch": LatentFromBatch,
"RepeatLatentBatch": RepeatLatentBatch,
"SaveImage": SaveImage,
"PreviewImage": PreviewImage,
"LoadImage": LoadImage,
@ -1076,8 +1248,10 @@ NODE_CLASS_MAPPINGS = {
"ImageScale": ImageScale,
"ImageInvert": ImageInvert,
"ImagePadForOutpaint": ImagePadForOutpaint,
"ConditioningAverage ": ConditioningAverage ,
"ConditioningCombine": ConditioningCombine,
"ConditioningSetArea": ConditioningSetArea,
"ConditioningSetMask": ConditioningSetMask,
"KSamplerAdvanced": KSamplerAdvanced,
"SetLatentNoiseMask": SetLatentNoiseMask,
"LatentComposite": LatentComposite,
@ -1098,8 +1272,14 @@ NODE_CLASS_MAPPINGS = {
"VAEEncodeTiled": VAEEncodeTiled,
"TomePatchModel": TomePatchModel,
"unCLIPCheckpointLoader": unCLIPCheckpointLoader,
"GLIGENLoader": GLIGENLoader,
"GLIGENTextBoxApply": GLIGENTextBoxApply,
"CheckpointLoader": CheckpointLoader,
"DiffusersLoader": DiffusersLoader,
"LoadLatent": LoadLatent,
"SaveLatent": SaveLatent
}
NODE_DISPLAY_NAME_MAPPINGS = {
@ -1123,7 +1303,9 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"CLIPTextEncode": "CLIP Text Encode (Prompt)",
"CLIPSetLastLayer": "CLIP Set Last Layer",
"ConditioningCombine": "Conditioning (Combine)",
"ConditioningAverage ": "Conditioning (Average)",
"ConditioningSetArea": "Conditioning (Set Area)",
"ConditioningSetMask": "Conditioning (Set Mask)",
"ControlNetApply": "Apply ControlNet",
# Latent
"VAEEncodeForInpaint": "VAE Encode (for Inpainting)",
@ -1136,6 +1318,8 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"EmptyLatentImage": "Empty Latent Image",
"LatentUpscale": "Upscale Latent",
"LatentComposite": "Latent Composite",
"LatentFromBatch" : "Latent From Batch",
"RepeatLatentBatch": "Repeat Latent Batch",
# Image
"SaveImage": "Save Image",
"PreviewImage": "Preview Image",
@ -1167,24 +1351,45 @@ def load_custom_node(module_path):
NODE_CLASS_MAPPINGS.update(module.NODE_CLASS_MAPPINGS)
if hasattr(module, "NODE_DISPLAY_NAME_MAPPINGS") and getattr(module, "NODE_DISPLAY_NAME_MAPPINGS") is not None:
NODE_DISPLAY_NAME_MAPPINGS.update(module.NODE_DISPLAY_NAME_MAPPINGS)
return True
else:
print(f"Skip {module_path} module for custom nodes due to the lack of NODE_CLASS_MAPPINGS.")
return False
except Exception as e:
print(traceback.format_exc())
print(f"Cannot import {module_path} module for custom nodes:", e)
return False
def load_custom_nodes():
CUSTOM_NODE_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "custom_nodes")
possible_modules = os.listdir(CUSTOM_NODE_PATH)
if "__pycache__" in possible_modules:
possible_modules.remove("__pycache__")
node_paths = folder_paths.get_folder_paths("custom_nodes")
node_import_times = []
for custom_node_path in node_paths:
possible_modules = os.listdir(custom_node_path)
if "__pycache__" in possible_modules:
possible_modules.remove("__pycache__")
for possible_module in possible_modules:
module_path = os.path.join(CUSTOM_NODE_PATH, possible_module)
if os.path.isfile(module_path) and os.path.splitext(module_path)[1] != ".py": continue
load_custom_node(module_path)
for possible_module in possible_modules:
module_path = os.path.join(custom_node_path, possible_module)
if os.path.isfile(module_path) and os.path.splitext(module_path)[1] != ".py": continue
if module_path.endswith(".disabled"): continue
time_before = time.perf_counter()
success = load_custom_node(module_path)
node_import_times.append((time.perf_counter() - time_before, module_path, success))
if len(node_import_times) > 0:
print("\nImport times for custom nodes:")
for n in sorted(node_import_times):
if n[2]:
import_message = ""
else:
import_message = " (IMPORT FAILED)"
print("{:6.1f} seconds{}:".format(n[0], import_message), n[1])
print()
def init_custom_nodes():
load_custom_nodes()
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_hypernetwork.py"))
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_upscale_model.py"))
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_post_processing.py"))
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_mask.py"))
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_rebatch.py"))
load_custom_nodes()

View File

@ -47,7 +47,7 @@
" !git pull\n",
"\n",
"!echo -= Install dependencies =-\n",
"!pip install xformers!=0.0.18 -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu118"
"!pip install xformers!=0.0.18 -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu118 --extra-index-url https://download.pytorch.org/whl/cu117"
]
},
{
@ -119,14 +119,30 @@
"\n",
"\n",
"# ControlNet\n",
"#!wget -c https://huggingface.co/webui/ControlNet-modules-safetensors/resolve/main/control_depth-fp16.safetensors -P ./models/controlnet/\n",
"#!wget -c https://huggingface.co/webui/ControlNet-modules-safetensors/resolve/main/control_scribble-fp16.safetensors -P ./models/controlnet/\n",
"#!wget -c https://huggingface.co/webui/ControlNet-modules-safetensors/resolve/main/control_openpose-fp16.safetensors -P ./models/controlnet/\n",
"#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11e_sd15_ip2p_fp16.safetensors -P ./models/controlnet/\n",
"#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11e_sd15_shuffle_fp16.safetensors -P ./models/controlnet/\n",
"#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11p_sd15_canny_fp16.safetensors -P ./models/controlnet/\n",
"#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11f1p_sd15_depth_fp16.safetensors -P ./models/controlnet/\n",
"#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11p_sd15_inpaint_fp16.safetensors -P ./models/controlnet/\n",
"#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11p_sd15_lineart_fp16.safetensors -P ./models/controlnet/\n",
"#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11p_sd15_mlsd_fp16.safetensors -P ./models/controlnet/\n",
"#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11p_sd15_normalbae_fp16.safetensors -P ./models/controlnet/\n",
"#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11p_sd15_openpose_fp16.safetensors -P ./models/controlnet/\n",
"#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11p_sd15_scribble_fp16.safetensors -P ./models/controlnet/\n",
"#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11p_sd15_seg_fp16.safetensors -P ./models/controlnet/\n",
"#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11p_sd15_softedge_fp16.safetensors -P ./models/controlnet/\n",
"#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11p_sd15s2_lineart_anime_fp16.safetensors -P ./models/controlnet/\n",
"#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11u_sd15_tile_fp16.safetensors -P ./models/controlnet/\n",
"\n",
"\n",
"# Controlnet Preprocessor nodes by Fannovel16\n",
"#!cd custom_nodes && git clone https://github.com/Fannovel16/comfy_controlnet_preprocessors; cd comfy_controlnet_preprocessors && python install.py\n",
"\n",
"\n",
"# GLIGEN\n",
"#!wget -c https://huggingface.co/comfyanonymous/GLIGEN_pruned_safetensors/resolve/main/gligen_sd14_textbox_pruned_fp16.safetensors -P ./models/gligen/\n",
"\n",
"\n",
"# ESRGAN upscale model\n",
"#!wget -c https://huggingface.co/sberbank-ai/Real-ESRGAN/resolve/main/RealESRGAN_x2.pth -P ./models/upscale_models/\n",
"#!wget -c https://huggingface.co/sberbank-ai/Real-ESRGAN/resolve/main/RealESRGAN_x4.pth -P ./models/upscale_models/\n",
@ -159,6 +175,8 @@
"import threading\n",
"import time\n",
"import socket\n",
"import urllib.request\n",
"\n",
"def iframe_thread(port):\n",
" while True:\n",
" time.sleep(0.5)\n",
@ -167,7 +185,9 @@
" if result == 0:\n",
" break\n",
" sock.close()\n",
" print(\"\\nComfyUI finished loading, trying to launch localtunnel (if it gets stuck here localtunnel is having issues)\")\n",
" print(\"\\nComfyUI finished loading, trying to launch localtunnel (if it gets stuck here localtunnel is having issues)\\n\")\n",
"\n",
" print(\"The password/enpoint ip for localtunnel is:\", urllib.request.urlopen('https://ipv4.icanhazip.com').read().decode('utf8').strip(\"\\n\"))\n",
" p = subprocess.Popen([\"lt\", \"--port\", \"{}\".format(port)], stdout=subprocess.PIPE)\n",
" for line in p.stdout:\n",
" print(line.decode(), end='')\n",

155
server.py
View File

@ -7,6 +7,9 @@ import execution
import uuid
import json
import glob
from PIL import Image
from io import BytesIO
try:
import aiohttp
from aiohttp import web
@ -78,7 +81,7 @@ class PromptServer():
# Reusing existing session, remove old
self.sockets.pop(sid, None)
else:
sid = uuid.uuid4().hex
sid = uuid.uuid4().hex
self.sockets[sid] = ws
@ -110,42 +113,96 @@ class PromptServer():
files = glob.glob(os.path.join(self.web_root, 'extensions/**/*.js'), recursive=True)
return web.json_response(list(map(lambda f: "/" + os.path.relpath(f, self.web_root).replace("\\", "/"), files)))
@routes.post("/upload/image")
async def upload_image(request):
upload_dir = folder_paths.get_input_directory()
def get_dir_by_type(dir_type):
if dir_type is None:
dir_type = "input"
if not os.path.exists(upload_dir):
os.makedirs(upload_dir)
post = await request.post()
if dir_type == "input":
type_dir = folder_paths.get_input_directory()
elif dir_type == "temp":
type_dir = folder_paths.get_temp_directory()
elif dir_type == "output":
type_dir = folder_paths.get_output_directory()
return type_dir, dir_type
def image_upload(post, image_save_function=None):
image = post.get("image")
overwrite = post.get("overwrite")
image_upload_type = post.get("type")
upload_dir, image_upload_type = get_dir_by_type(image_upload_type)
if image and image.file:
filename = image.filename
if not filename:
return web.Response(status=400)
subfolder = post.get("subfolder", "")
full_output_folder = os.path.join(upload_dir, os.path.normpath(subfolder))
if os.path.commonpath((upload_dir, os.path.abspath(full_output_folder))) != upload_dir:
return web.Response(status=400)
if not os.path.exists(full_output_folder):
os.makedirs(full_output_folder)
split = os.path.splitext(filename)
i = 1
while os.path.exists(os.path.join(upload_dir, filename)):
filename = f"{split[0]} ({i}){split[1]}"
i += 1
filepath = os.path.join(full_output_folder, filename)
filepath = os.path.join(upload_dir, filename)
if overwrite is not None and (overwrite == "true" or overwrite == "1"):
pass
else:
i = 1
while os.path.exists(filepath):
filename = f"{split[0]} ({i}){split[1]}"
filepath = os.path.join(full_output_folder, filename)
i += 1
with open(filepath, "wb") as f:
f.write(image.file.read())
return web.json_response({"name" : filename})
if image_save_function is not None:
image_save_function(image, post, filepath)
else:
with open(filepath, "wb") as f:
f.write(image.file.read())
return web.json_response({"name" : filename, "subfolder": subfolder, "type": image_upload_type})
else:
return web.Response(status=400)
@routes.post("/upload/image")
async def upload_image(request):
post = await request.post()
return image_upload(post)
@routes.post("/upload/mask")
async def upload_mask(request):
post = await request.post()
def image_save_function(image, post, filepath):
original_pil = Image.open(post.get("original_image").file).convert('RGBA')
mask_pil = Image.open(image.file).convert('RGBA')
# alpha copy
new_alpha = mask_pil.getchannel('A')
original_pil.putalpha(new_alpha)
original_pil.save(filepath, compress_level=4)
return image_upload(post, image_save_function)
@routes.get("/view")
async def view_image(request):
if "filename" in request.rel_url.query:
type = request.rel_url.query.get("type", "output")
output_dir = folder_paths.get_directory_by_type(type)
filename = request.rel_url.query["filename"]
filename,output_dir = folder_paths.annotated_filepath(filename)
# validation for security: prevent accessing arbitrary path
if filename[0] == '/' or '..' in filename:
return web.Response(status=400)
if output_dir is None:
type = request.rel_url.query.get("type", "output")
output_dir = folder_paths.get_directory_by_type(type)
if output_dir is None:
return web.Response(status=400)
@ -155,13 +212,49 @@ class PromptServer():
return web.Response(status=403)
output_dir = full_output_dir
filename = request.rel_url.query["filename"]
filename = os.path.basename(filename)
file = os.path.join(output_dir, filename)
if os.path.isfile(file):
return web.FileResponse(file, headers={"Content-Disposition": f"filename=\"{filename}\""})
if 'channel' not in request.rel_url.query:
channel = 'rgba'
else:
channel = request.rel_url.query["channel"]
if channel == 'rgb':
with Image.open(file) as img:
if img.mode == "RGBA":
r, g, b, a = img.split()
new_img = Image.merge('RGB', (r, g, b))
else:
new_img = img.convert("RGB")
buffer = BytesIO()
new_img.save(buffer, format='PNG')
buffer.seek(0)
return web.Response(body=buffer.read(), content_type='image/png',
headers={"Content-Disposition": f"filename=\"{filename}\""})
elif channel == 'a':
with Image.open(file) as img:
if img.mode == "RGBA":
_, _, _, a = img.split()
else:
a = Image.new('L', img.size, 255)
# alpha img
alpha_img = Image.new('RGBA', img.size)
alpha_img.putalpha(a)
alpha_buffer = BytesIO()
alpha_img.save(alpha_buffer, format='PNG')
alpha_buffer.seek(0)
return web.Response(body=alpha_buffer.read(), content_type='image/png',
headers={"Content-Disposition": f"filename=\"{filename}\""})
else:
return web.FileResponse(file, headers={"Content-Disposition": f"filename=\"{filename}\""})
return web.Response(status=404)
@routes.get("/prompt")
@ -176,6 +269,7 @@ class PromptServer():
info = {}
info['input'] = obj_class.INPUT_TYPES()
info['output'] = obj_class.RETURN_TYPES
info['output_is_list'] = obj_class.OUTPUT_IS_LIST if hasattr(obj_class, 'OUTPUT_IS_LIST') else [False] * len(obj_class.RETURN_TYPES)
info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output']
info['name'] = x
info['display_name'] = nodes.NODE_DISPLAY_NAME_MAPPINGS[x] if x in nodes.NODE_DISPLAY_NAME_MAPPINGS.keys() else x
@ -225,14 +319,15 @@ class PromptServer():
if "client_id" in json_data:
extra_data["client_id"] = json_data["client_id"]
if valid[0]:
self.prompt_queue.put((number, id(prompt), prompt, extra_data))
prompt_id = str(uuid.uuid4())
self.prompt_queue.put((number, prompt_id, prompt, extra_data, valid[2]))
return web.json_response({"prompt_id": prompt_id})
else:
resp_code = 400
out_string = valid[1]
print("invalid prompt:", valid[1])
return web.json_response({"error": valid[1]}, status=400)
else:
return web.json_response({"error": "no prompt"}, status=400)
return web.Response(body=out_string, status=resp_code)
@routes.post("/queue")
async def post_queue(request):
json_data = await request.json()
@ -242,9 +337,9 @@ class PromptServer():
if "delete" in json_data:
to_delete = json_data['delete']
for id_to_delete in to_delete:
delete_func = lambda a: a[1] == int(id_to_delete)
delete_func = lambda a: a[1] == id_to_delete
self.prompt_queue.delete_queue_item(delete_func)
return web.Response(status=200)
@routes.post("/interrupt")
@ -268,7 +363,7 @@ class PromptServer():
def add_routes(self):
self.app.add_routes(self.routes)
self.app.add_routes([
web.static('/', self.web_root),
web.static('/', self.web_root, follow_symlinks=True),
])
def get_queue_info(self):

View File

@ -0,0 +1,166 @@
import { app } from "/scripts/app.js";
import { ComfyDialog, $el } from "/scripts/ui.js";
import { ComfyApp } from "/scripts/app.js";
export class ClipspaceDialog extends ComfyDialog {
static items = [];
static instance = null;
static registerButton(name, contextPredicate, callback) {
const item =
$el("button", {
type: "button",
textContent: name,
contextPredicate: contextPredicate,
onclick: callback
})
ClipspaceDialog.items.push(item);
}
static invalidatePreview() {
if(ComfyApp.clipspace && ComfyApp.clipspace.imgs && ComfyApp.clipspace.imgs.length > 0) {
const img_preview = document.getElementById("clipspace_preview");
if(img_preview) {
img_preview.src = ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src;
img_preview.style.maxHeight = "100%";
img_preview.style.maxWidth = "100%";
}
}
}
static invalidate() {
if(ClipspaceDialog.instance) {
const self = ClipspaceDialog.instance;
// allow reconstruct controls when copying from non-image to image content.
const children = $el("div.comfy-modal-content", [ self.createImgSettings(), ...self.createButtons() ]);
if(self.element) {
// update
self.element.removeChild(self.element.firstChild);
self.element.appendChild(children);
}
else {
// new
self.element = $el("div.comfy-modal", { parent: document.body }, [children,]);
}
if(self.element.children[0].children.length <= 1) {
self.element.children[0].appendChild($el("p", {}, ["Unable to find the features to edit content of a format stored in the current Clipspace."]));
}
ClipspaceDialog.invalidatePreview();
}
}
constructor() {
super();
}
createButtons(self) {
const buttons = [];
for(let idx in ClipspaceDialog.items) {
const item = ClipspaceDialog.items[idx];
if(!item.contextPredicate || item.contextPredicate())
buttons.push(ClipspaceDialog.items[idx]);
}
buttons.push(
$el("button", {
type: "button",
textContent: "Close",
onclick: () => { this.close(); }
})
);
return buttons;
}
createImgSettings() {
if(ComfyApp.clipspace.imgs) {
const combo_items = [];
const imgs = ComfyApp.clipspace.imgs;
for(let i=0; i < imgs.length; i++) {
combo_items.push($el("option", {value:i}, [`${i}`]));
}
const combo1 = $el("select",
{id:"clipspace_img_selector", onchange:(event) => {
ComfyApp.clipspace['selectedIndex'] = event.target.selectedIndex;
ClipspaceDialog.invalidatePreview();
} }, combo_items);
const row1 =
$el("tr", {},
[
$el("td", {}, [$el("font", {color:"white"}, ["Select Image"])]),
$el("td", {}, [combo1])
]);
const combo2 = $el("select",
{id:"clipspace_img_paste_mode", onchange:(event) => {
ComfyApp.clipspace['img_paste_mode'] = event.target.value;
} },
[
$el("option", {value:'selected'}, 'selected'),
$el("option", {value:'all'}, 'all')
]);
combo2.value = ComfyApp.clipspace['img_paste_mode'];
const row2 =
$el("tr", {},
[
$el("td", {}, [$el("font", {color:"white"}, ["Paste Mode"])]),
$el("td", {}, [combo2])
]);
const td = $el("td", {align:'center', width:'100px', height:'100px', colSpan:'2'},
[ $el("img",{id:"clipspace_preview", ondragstart:() => false},[]) ]);
const row3 =
$el("tr", {}, [td]);
return $el("table", {}, [row1, row2, row3]);
}
else {
return [];
}
}
createImgPreview() {
if(ComfyApp.clipspace.imgs) {
return $el("img",{id:"clipspace_preview", ondragstart:() => false});
}
else
return [];
}
show() {
const img_preview = document.getElementById("clipspace_preview");
ClipspaceDialog.invalidate();
this.element.style.display = "block";
}
}
app.registerExtension({
name: "Comfy.Clipspace",
init(app) {
app.openClipspace =
function () {
if(!ClipspaceDialog.instance) {
ClipspaceDialog.instance = new ClipspaceDialog(app);
ComfyApp.clipspace_invalidate_handler = ClipspaceDialog.invalidate;
}
if(ComfyApp.clipspace) {
ClipspaceDialog.instance.show();
}
else
app.ui.dialog.show("Clipspace is Empty!");
};
}
});

View File

@ -107,7 +107,7 @@ const colorPalettes = {
"descrip-text": "#444",
"drag-text": "#555",
"error-text": "#F44336",
"border-color": "#CCC"
"border-color": "#888"
}
},
},
@ -232,10 +232,27 @@ app.registerExtension({
"name": "My Color Palette",
"colors": {
"node_slot": {
},
"litegraph_base": {
},
"comfy_base": {
}
}
};
// Copy over missing keys from default color palette
const defaultColorPalette = colorPalettes[defaultColorPaletteId];
for (const key in defaultColorPalette.colors.litegraph_base) {
if (!colorPalette.colors.litegraph_base[key]) {
colorPalette.colors.litegraph_base[key] = "";
}
}
for (const key in defaultColorPalette.colors.comfy_base) {
if (!colorPalette.colors.comfy_base[key]) {
colorPalette.colors.comfy_base[key] = "";
}
}
return completeColorPalette(colorPalette);
};

View File

@ -0,0 +1,144 @@
import { app } from "/scripts/app.js";
// Allows you to edit the attention weight by holding ctrl (or cmd) and using the up/down arrow keys
app.registerExtension({
name: "Comfy.EditAttention",
init() {
const editAttentionDelta = app.ui.settings.addSetting({
id: "Comfy.EditAttention.Delta",
name: "Ctrl+up/down precision",
type: "slider",
attrs: {
min: 0.01,
max: 0.5,
step: 0.01,
},
defaultValue: 0.05,
});
function incrementWeight(weight, delta) {
const floatWeight = parseFloat(weight);
if (isNaN(floatWeight)) return weight;
const newWeight = floatWeight + delta;
if (newWeight < 0) return "0";
return String(Number(newWeight.toFixed(10)));
}
function findNearestEnclosure(text, cursorPos) {
let start = cursorPos, end = cursorPos;
let openCount = 0, closeCount = 0;
// Find opening parenthesis before cursor
while (start >= 0) {
start--;
if (text[start] === "(" && openCount === closeCount) break;
if (text[start] === "(") openCount++;
if (text[start] === ")") closeCount++;
}
if (start < 0) return false;
openCount = 0;
closeCount = 0;
// Find closing parenthesis after cursor
while (end < text.length) {
if (text[end] === ")" && openCount === closeCount) break;
if (text[end] === "(") openCount++;
if (text[end] === ")") closeCount++;
end++;
}
if (end === text.length) return false;
return { start: start + 1, end: end };
}
function addWeightToParentheses(text) {
const parenRegex = /^\((.*)\)$/;
const parenMatch = text.match(parenRegex);
const floatRegex = /:([+-]?(\d*\.)?\d+([eE][+-]?\d+)?)/;
const floatMatch = text.match(floatRegex);
if (parenMatch && !floatMatch) {
return `(${parenMatch[1]}:1.0)`;
} else {
return text;
}
};
function editAttention(event) {
const inputField = event.composedPath()[0];
const delta = parseFloat(editAttentionDelta.value);
if (inputField.tagName !== "TEXTAREA") return;
if (!(event.key === "ArrowUp" || event.key === "ArrowDown")) return;
if (!event.ctrlKey && !event.metaKey) return;
event.preventDefault();
let start = inputField.selectionStart;
let end = inputField.selectionEnd;
let selectedText = inputField.value.substring(start, end);
// If there is no selection, attempt to find the nearest enclosure, or select the current word
if (!selectedText) {
const nearestEnclosure = findNearestEnclosure(inputField.value, start);
if (nearestEnclosure) {
start = nearestEnclosure.start;
end = nearestEnclosure.end;
selectedText = inputField.value.substring(start, end);
} else {
// Select the current word, find the start and end of the word
const delimiters = " .,\\/!?%^*;:{}=-_`~()\r\n\t";
while (!delimiters.includes(inputField.value[start - 1]) && start > 0) {
start--;
}
while (!delimiters.includes(inputField.value[end]) && end < inputField.value.length) {
end++;
}
selectedText = inputField.value.substring(start, end);
if (!selectedText) return;
}
}
// If the selection ends with a space, remove it
if (selectedText[selectedText.length - 1] === " ") {
selectedText = selectedText.substring(0, selectedText.length - 1);
end -= 1;
}
// If there are parentheses left and right of the selection, select them
if (inputField.value[start - 1] === "(" && inputField.value[end] === ")") {
start -= 1;
end += 1;
selectedText = inputField.value.substring(start, end);
}
// If the selection is not enclosed in parentheses, add them
if (selectedText[0] !== "(" || selectedText[selectedText.length - 1] !== ")") {
selectedText = `(${selectedText})`;
}
// If the selection does not have a weight, add a weight of 1.0
selectedText = addWeightToParentheses(selectedText);
// Increment the weight
const weightDelta = event.key === "ArrowUp" ? delta : -delta;
const updatedText = selectedText.replace(/\((.*):(\d+(?:\.\d+)?)\)/, (match, text, weight) => {
weight = incrementWeight(weight, weightDelta);
if (weight == 1) {
return text;
} else {
return `(${text}:${weight})`;
}
});
inputField.setRangeText(updatedText, start, end, "select");
}
window.addEventListener("keydown", editAttention);
},
});

View File

@ -0,0 +1,76 @@
import { app } from "/scripts/app.js";
const id = "Comfy.Keybinds";
app.registerExtension({
name: id,
init() {
const keybindListener = function(event) {
const modifierPressed = event.ctrlKey || event.metaKey;
// Queue prompt using ctrl or command + enter
if (modifierPressed && (event.key === "Enter" || event.keyCode === 13 || event.keyCode === 10)) {
app.queuePrompt(event.shiftKey ? -1 : 0);
return;
}
const target = event.composedPath()[0];
if (target.tagName === "INPUT" || target.tagName === "TEXTAREA") {
return;
}
const modifierKeyIdMap = {
"s": "#comfy-save-button",
83: "#comfy-save-button",
"o": "#comfy-file-input",
79: "#comfy-file-input",
"Backspace": "#comfy-clear-button",
8: "#comfy-clear-button",
"Delete": "#comfy-clear-button",
46: "#comfy-clear-button",
"d": "#comfy-load-default-button",
68: "#comfy-load-default-button",
};
const modifierKeybindId = modifierKeyIdMap[event.key] || modifierKeyIdMap[event.keyCode];
if (modifierPressed && modifierKeybindId) {
event.preventDefault();
const elem = document.querySelector(modifierKeybindId);
elem.click();
return;
}
// Finished Handling all modifier keybinds, now handle the rest
if (event.ctrlKey || event.altKey || event.metaKey) {
return;
}
// Close out of modals using escape
if (event.key === "Escape" || event.keyCode === 27) {
const modals = document.querySelectorAll(".comfy-modal");
const modal = Array.from(modals).find(modal => window.getComputedStyle(modal).getPropertyValue("display") !== "none");
if (modal) {
modal.style.display = "none";
}
}
const keyIdMap = {
"q": "#comfy-view-queue-button",
81: "#comfy-view-queue-button",
"h": "#comfy-view-history-button",
72: "#comfy-view-history-button",
"r": "#comfy-refresh-button",
82: "#comfy-refresh-button",
};
const buttonId = keyIdMap[event.key] || keyIdMap[event.keyCode];
if (buttonId) {
const button = document.querySelector(buttonId);
button.click();
}
}
window.addEventListener("keydown", keybindListener, true);
}
});

View File

@ -0,0 +1,648 @@
import { app } from "/scripts/app.js";
import { ComfyDialog, $el } from "/scripts/ui.js";
import { ComfyApp } from "/scripts/app.js";
import { ClipspaceDialog } from "/extensions/core/clipspace.js";
// Helper function to convert a data URL to a Blob object
function dataURLToBlob(dataURL) {
const parts = dataURL.split(';base64,');
const contentType = parts[0].split(':')[1];
const byteString = atob(parts[1]);
const arrayBuffer = new ArrayBuffer(byteString.length);
const uint8Array = new Uint8Array(arrayBuffer);
for (let i = 0; i < byteString.length; i++) {
uint8Array[i] = byteString.charCodeAt(i);
}
return new Blob([arrayBuffer], { type: contentType });
}
function loadedImageToBlob(image) {
const canvas = document.createElement('canvas');
canvas.width = image.width;
canvas.height = image.height;
const ctx = canvas.getContext('2d');
ctx.drawImage(image, 0, 0);
const dataURL = canvas.toDataURL('image/png', 1);
const blob = dataURLToBlob(dataURL);
return blob;
}
async function uploadMask(filepath, formData) {
await fetch('/upload/mask', {
method: 'POST',
body: formData
}).then(response => {}).catch(error => {
console.error('Error:', error);
});
ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']] = new Image();
ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src = "/view?" + new URLSearchParams(filepath).toString();
if(ComfyApp.clipspace.images)
ComfyApp.clipspace.images[ComfyApp.clipspace['selectedIndex']] = filepath;
ClipspaceDialog.invalidatePreview();
}
function prepareRGB(image, backupCanvas, backupCtx) {
// paste mask data into alpha channel
backupCtx.drawImage(image, 0, 0, backupCanvas.width, backupCanvas.height);
const backupData = backupCtx.getImageData(0, 0, backupCanvas.width, backupCanvas.height);
// refine mask image
for (let i = 0; i < backupData.data.length; i += 4) {
if(backupData.data[i+3] == 255)
backupData.data[i+3] = 0;
else
backupData.data[i+3] = 255;
backupData.data[i] = 0;
backupData.data[i+1] = 0;
backupData.data[i+2] = 0;
}
backupCtx.globalCompositeOperation = 'source-over';
backupCtx.putImageData(backupData, 0, 0);
}
class MaskEditorDialog extends ComfyDialog {
static instance = null;
static getInstance() {
if(!MaskEditorDialog.instance) {
MaskEditorDialog.instance = new MaskEditorDialog(app);
}
return MaskEditorDialog.instance;
}
is_layout_created = false;
constructor() {
super();
this.element = $el("div.comfy-modal", { parent: document.body },
[ $el("div.comfy-modal-content",
[...this.createButtons()]),
]);
}
createButtons() {
return [];
}
createButton(name, callback) {
var button = document.createElement("button");
button.innerText = name;
button.addEventListener("click", callback);
return button;
}
createLeftButton(name, callback) {
var button = this.createButton(name, callback);
button.style.cssFloat = "left";
button.style.marginRight = "4px";
return button;
}
createRightButton(name, callback) {
var button = this.createButton(name, callback);
button.style.cssFloat = "right";
button.style.marginLeft = "4px";
return button;
}
createLeftSlider(self, name, callback) {
const divElement = document.createElement('div');
divElement.id = "maskeditor-slider";
divElement.style.cssFloat = "left";
divElement.style.fontFamily = "sans-serif";
divElement.style.marginRight = "4px";
divElement.style.color = "var(--input-text)";
divElement.style.backgroundColor = "var(--comfy-input-bg)";
divElement.style.borderRadius = "8px";
divElement.style.borderColor = "var(--border-color)";
divElement.style.borderStyle = "solid";
divElement.style.fontSize = "15px";
divElement.style.height = "21px";
divElement.style.padding = "1px 6px";
divElement.style.display = "flex";
divElement.style.position = "relative";
divElement.style.top = "2px";
self.brush_slider_input = document.createElement('input');
self.brush_slider_input.setAttribute('type', 'range');
self.brush_slider_input.setAttribute('min', '1');
self.brush_slider_input.setAttribute('max', '100');
self.brush_slider_input.setAttribute('value', '10');
const labelElement = document.createElement("label");
labelElement.textContent = name;
divElement.appendChild(labelElement);
divElement.appendChild(self.brush_slider_input);
self.brush_slider_input.addEventListener("change", callback);
return divElement;
}
setlayout(imgCanvas, maskCanvas) {
const self = this;
// If it is specified as relative, using it only as a hidden placeholder for padding is recommended
// to prevent anomalies where it exceeds a certain size and goes outside of the window.
var placeholder = document.createElement("div");
placeholder.style.position = "relative";
placeholder.style.height = "50px";
var bottom_panel = document.createElement("div");
bottom_panel.style.position = "absolute";
bottom_panel.style.bottom = "0px";
bottom_panel.style.left = "20px";
bottom_panel.style.right = "20px";
bottom_panel.style.height = "50px";
var brush = document.createElement("div");
brush.id = "brush";
brush.style.backgroundColor = "transparent";
brush.style.outline = "1px dashed black";
brush.style.boxShadow = "0 0 0 1px white";
brush.style.borderRadius = "50%";
brush.style.MozBorderRadius = "50%";
brush.style.WebkitBorderRadius = "50%";
brush.style.position = "absolute";
brush.style.zIndex = 8889;
brush.style.pointerEvents = "none";
this.brush = brush;
this.element.appendChild(imgCanvas);
this.element.appendChild(maskCanvas);
this.element.appendChild(placeholder); // must below z-index than bottom_panel to avoid covering button
this.element.appendChild(bottom_panel);
document.body.appendChild(brush);
var brush_size_slider = this.createLeftSlider(self, "Thickness", (event) => {
self.brush_size = event.target.value;
self.updateBrushPreview(self, null, null);
});
var clearButton = this.createLeftButton("Clear",
() => {
self.maskCtx.clearRect(0, 0, self.maskCanvas.width, self.maskCanvas.height);
self.backupCtx.clearRect(0, 0, self.backupCanvas.width, self.backupCanvas.height);
});
var cancelButton = this.createRightButton("Cancel", () => {
document.removeEventListener("mouseup", MaskEditorDialog.handleMouseUp);
document.removeEventListener("keydown", MaskEditorDialog.handleKeyDown);
self.close();
});
this.saveButton = this.createRightButton("Save", () => {
document.removeEventListener("mouseup", MaskEditorDialog.handleMouseUp);
document.removeEventListener("keydown", MaskEditorDialog.handleKeyDown);
self.save();
});
this.element.appendChild(imgCanvas);
this.element.appendChild(maskCanvas);
this.element.appendChild(placeholder); // must below z-index than bottom_panel to avoid covering button
this.element.appendChild(bottom_panel);
bottom_panel.appendChild(clearButton);
bottom_panel.appendChild(this.saveButton);
bottom_panel.appendChild(cancelButton);
bottom_panel.appendChild(brush_size_slider);
imgCanvas.style.position = "relative";
imgCanvas.style.top = "200";
imgCanvas.style.left = "0";
maskCanvas.style.position = "absolute";
}
show() {
if(!this.is_layout_created) {
// layout
const imgCanvas = document.createElement('canvas');
const maskCanvas = document.createElement('canvas');
const backupCanvas = document.createElement('canvas');
imgCanvas.id = "imageCanvas";
maskCanvas.id = "maskCanvas";
backupCanvas.id = "backupCanvas";
this.setlayout(imgCanvas, maskCanvas);
// prepare content
this.imgCanvas = imgCanvas;
this.maskCanvas = maskCanvas;
this.backupCanvas = backupCanvas;
this.maskCtx = maskCanvas.getContext('2d');
this.backupCtx = backupCanvas.getContext('2d');
this.setEventHandler(maskCanvas);
this.is_layout_created = true;
// replacement of onClose hook since close is not real close
const self = this;
const observer = new MutationObserver(function(mutations) {
mutations.forEach(function(mutation) {
if (mutation.type === 'attributes' && mutation.attributeName === 'style') {
if(self.last_display_style && self.last_display_style != 'none' && self.element.style.display == 'none') {
ComfyApp.onClipspaceEditorClosed();
}
self.last_display_style = self.element.style.display;
}
});
});
const config = { attributes: true };
observer.observe(this.element, config);
}
this.setImages(this.imgCanvas, this.backupCanvas);
if(ComfyApp.clipspace_return_node) {
this.saveButton.innerText = "Save to node";
}
else {
this.saveButton.innerText = "Save";
}
this.saveButton.disabled = false;
this.element.style.display = "block";
this.element.style.zIndex = 8888; // NOTE: alert dialog must be high priority.
}
isOpened() {
return this.element.style.display == "block";
}
setImages(imgCanvas, backupCanvas) {
const imgCtx = imgCanvas.getContext('2d');
const backupCtx = backupCanvas.getContext('2d');
const maskCtx = this.maskCtx;
const maskCanvas = this.maskCanvas;
backupCtx.clearRect(0,0,this.backupCanvas.width,this.backupCanvas.height);
imgCtx.clearRect(0,0,this.imgCanvas.width,this.imgCanvas.height);
maskCtx.clearRect(0,0,this.maskCanvas.width,this.maskCanvas.height);
// image load
const orig_image = new Image();
window.addEventListener("resize", () => {
// repositioning
imgCanvas.width = window.innerWidth - 250;
imgCanvas.height = window.innerHeight - 200;
// redraw image
let drawWidth = orig_image.width;
let drawHeight = orig_image.height;
if (orig_image.width > imgCanvas.width) {
drawWidth = imgCanvas.width;
drawHeight = (drawWidth / orig_image.width) * orig_image.height;
}
if (drawHeight > imgCanvas.height) {
drawHeight = imgCanvas.height;
drawWidth = (drawHeight / orig_image.height) * orig_image.width;
}
imgCtx.drawImage(orig_image, 0, 0, drawWidth, drawHeight);
// update mask
backupCtx.drawImage(maskCanvas, 0, 0, maskCanvas.width, maskCanvas.height, 0, 0, backupCanvas.width, backupCanvas.height);
maskCanvas.width = drawWidth;
maskCanvas.height = drawHeight;
maskCanvas.style.top = imgCanvas.offsetTop + "px";
maskCanvas.style.left = imgCanvas.offsetLeft + "px";
maskCtx.drawImage(backupCanvas, 0, 0, backupCanvas.width, backupCanvas.height, 0, 0, maskCanvas.width, maskCanvas.height);
});
const filepath = ComfyApp.clipspace.images;
const touched_image = new Image();
touched_image.onload = function() {
backupCanvas.width = touched_image.width;
backupCanvas.height = touched_image.height;
prepareRGB(touched_image, backupCanvas, backupCtx);
};
const alpha_url = new URL(ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src)
alpha_url.searchParams.delete('channel');
alpha_url.searchParams.set('channel', 'a');
touched_image.src = alpha_url;
// original image load
orig_image.onload = function() {
window.dispatchEvent(new Event('resize'));
};
const rgb_url = new URL(ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src);
rgb_url.searchParams.delete('channel');
rgb_url.searchParams.set('channel', 'rgb');
orig_image.src = rgb_url;
this.image = orig_image;
}
setEventHandler(maskCanvas) {
maskCanvas.addEventListener("contextmenu", (event) => {
event.preventDefault();
});
const self = this;
maskCanvas.addEventListener('wheel', (event) => this.handleWheelEvent(self,event));
maskCanvas.addEventListener('pointerdown', (event) => this.handlePointerDown(self,event));
document.addEventListener('pointerup', MaskEditorDialog.handlePointerUp);
maskCanvas.addEventListener('pointermove', (event) => this.draw_move(self,event));
maskCanvas.addEventListener('touchmove', (event) => this.draw_move(self,event));
maskCanvas.addEventListener('pointerover', (event) => { this.brush.style.display = "block"; });
maskCanvas.addEventListener('pointerleave', (event) => { this.brush.style.display = "none"; });
document.addEventListener('keydown', MaskEditorDialog.handleKeyDown);
}
brush_size = 10;
drawing_mode = false;
lastx = -1;
lasty = -1;
lasttime = 0;
static handleKeyDown(event) {
const self = MaskEditorDialog.instance;
if (event.key === ']') {
self.brush_size = Math.min(self.brush_size+2, 100);
} else if (event.key === '[') {
self.brush_size = Math.max(self.brush_size-2, 1);
} else if(event.key === 'Enter') {
self.save();
}
self.updateBrushPreview(self);
}
static handlePointerUp(event) {
event.preventDefault();
MaskEditorDialog.instance.drawing_mode = false;
}
updateBrushPreview(self) {
const brush = self.brush;
var centerX = self.cursorX;
var centerY = self.cursorY;
brush.style.width = self.brush_size * 2 + "px";
brush.style.height = self.brush_size * 2 + "px";
brush.style.left = (centerX - self.brush_size) + "px";
brush.style.top = (centerY - self.brush_size) + "px";
}
handleWheelEvent(self, event) {
if(event.deltaY < 0)
self.brush_size = Math.min(self.brush_size+2, 100);
else
self.brush_size = Math.max(self.brush_size-2, 1);
self.brush_slider_input.value = self.brush_size;
self.updateBrushPreview(self);
}
draw_move(self, event) {
event.preventDefault();
this.cursorX = event.pageX;
this.cursorY = event.pageY;
self.updateBrushPreview(self);
if (window.TouchEvent && event instanceof TouchEvent || event.buttons == 1) {
var diff = performance.now() - self.lasttime;
const maskRect = self.maskCanvas.getBoundingClientRect();
var x = event.offsetX;
var y = event.offsetY
if(event.offsetX == null) {
x = event.targetTouches[0].clientX - maskRect.left;
}
if(event.offsetY == null) {
y = event.targetTouches[0].clientY - maskRect.top;
}
var brush_size = this.brush_size;
if(event instanceof PointerEvent && event.pointerType == 'pen') {
brush_size *= event.pressure;
this.last_pressure = event.pressure;
}
else if(window.TouchEvent && event instanceof TouchEvent && diff < 20){
// The firing interval of PointerEvents in Pen is unreliable, so it is supplemented by TouchEvents.
brush_size *= this.last_pressure;
}
else {
brush_size = this.brush_size;
}
if(diff > 20 && !this.drawing_mode)
requestAnimationFrame(() => {
self.maskCtx.beginPath();
self.maskCtx.fillStyle = "rgb(0,0,0)";
self.maskCtx.globalCompositeOperation = "source-over";
self.maskCtx.arc(x, y, brush_size, 0, Math.PI * 2, false);
self.maskCtx.fill();
self.lastx = x;
self.lasty = y;
});
else
requestAnimationFrame(() => {
self.maskCtx.beginPath();
self.maskCtx.fillStyle = "rgb(0,0,0)";
self.maskCtx.globalCompositeOperation = "source-over";
var dx = x - self.lastx;
var dy = y - self.lasty;
var distance = Math.sqrt(dx * dx + dy * dy);
var directionX = dx / distance;
var directionY = dy / distance;
for (var i = 0; i < distance; i+=5) {
var px = self.lastx + (directionX * i);
var py = self.lasty + (directionY * i);
self.maskCtx.arc(px, py, brush_size, 0, Math.PI * 2, false);
self.maskCtx.fill();
}
self.lastx = x;
self.lasty = y;
});
self.lasttime = performance.now();
}
else if(event.buttons == 2 || event.buttons == 5 || event.buttons == 32) {
const maskRect = self.maskCanvas.getBoundingClientRect();
const x = event.offsetX || event.targetTouches[0].clientX - maskRect.left;
const y = event.offsetY || event.targetTouches[0].clientY - maskRect.top;
var brush_size = this.brush_size;
if(event instanceof PointerEvent && event.pointerType == 'pen') {
brush_size *= event.pressure;
this.last_pressure = event.pressure;
}
else if(window.TouchEvent && event instanceof TouchEvent && diff < 20){
brush_size *= this.last_pressure;
}
else {
brush_size = this.brush_size;
}
if(diff > 20 && !drawing_mode) // cannot tracking drawing_mode for touch event
requestAnimationFrame(() => {
self.maskCtx.beginPath();
self.maskCtx.globalCompositeOperation = "destination-out";
self.maskCtx.arc(x, y, brush_size, 0, Math.PI * 2, false);
self.maskCtx.fill();
self.lastx = x;
self.lasty = y;
});
else
requestAnimationFrame(() => {
self.maskCtx.beginPath();
self.maskCtx.globalCompositeOperation = "destination-out";
var dx = x - self.lastx;
var dy = y - self.lasty;
var distance = Math.sqrt(dx * dx + dy * dy);
var directionX = dx / distance;
var directionY = dy / distance;
for (var i = 0; i < distance; i+=5) {
var px = self.lastx + (directionX * i);
var py = self.lasty + (directionY * i);
self.maskCtx.arc(px, py, brush_size, 0, Math.PI * 2, false);
self.maskCtx.fill();
}
self.lastx = x;
self.lasty = y;
});
self.lasttime = performance.now();
}
}
handlePointerDown(self, event) {
var brush_size = this.brush_size;
if(event instanceof PointerEvent && event.pointerType == 'pen') {
brush_size *= event.pressure;
this.last_pressure = event.pressure;
}
if ([0, 2, 5].includes(event.button)) {
self.drawing_mode = true;
event.preventDefault();
const maskRect = self.maskCanvas.getBoundingClientRect();
const x = event.offsetX || event.targetTouches[0].clientX - maskRect.left;
const y = event.offsetY || event.targetTouches[0].clientY - maskRect.top;
self.maskCtx.beginPath();
if (event.button == 0) {
self.maskCtx.fillStyle = "rgb(0,0,0)";
self.maskCtx.globalCompositeOperation = "source-over";
} else {
self.maskCtx.globalCompositeOperation = "destination-out";
}
self.maskCtx.arc(x, y, brush_size, 0, Math.PI * 2, false);
self.maskCtx.fill();
self.lastx = x;
self.lasty = y;
self.lasttime = performance.now();
}
}
async save() {
const backupCtx = this.backupCanvas.getContext('2d', {willReadFrequently:true});
backupCtx.clearRect(0,0,this.backupCanvas.width,this.backupCanvas.height);
backupCtx.drawImage(this.maskCanvas,
0, 0, this.maskCanvas.width, this.maskCanvas.height,
0, 0, this.backupCanvas.width, this.backupCanvas.height);
// paste mask data into alpha channel
const backupData = backupCtx.getImageData(0, 0, this.backupCanvas.width, this.backupCanvas.height);
// refine mask image
for (let i = 0; i < backupData.data.length; i += 4) {
if(backupData.data[i+3] == 255)
backupData.data[i+3] = 0;
else
backupData.data[i+3] = 255;
backupData.data[i] = 0;
backupData.data[i+1] = 0;
backupData.data[i+2] = 0;
}
backupCtx.globalCompositeOperation = 'source-over';
backupCtx.putImageData(backupData, 0, 0);
const formData = new FormData();
const filename = "clipspace-mask-" + performance.now() + ".png";
const item =
{
"filename": filename,
"subfolder": "clipspace",
"type": "input",
};
if(ComfyApp.clipspace.images)
ComfyApp.clipspace.images[0] = item;
if(ComfyApp.clipspace.widgets) {
const index = ComfyApp.clipspace.widgets.findIndex(obj => obj.name === 'image');
if(index >= 0)
ComfyApp.clipspace.widgets[index].value = item;
}
const dataURL = this.backupCanvas.toDataURL();
const blob = dataURLToBlob(dataURL);
const original_blob = loadedImageToBlob(this.image);
formData.append('image', blob, filename);
formData.append('original_image', original_blob);
formData.append('type', "input");
formData.append('subfolder', "clipspace");
this.saveButton.innerText = "Saving...";
this.saveButton.disabled = true;
await uploadMask(item, formData);
ComfyApp.onClipspaceEditorSave();
this.close();
}
}
app.registerExtension({
name: "Comfy.MaskEditor",
init(app) {
ComfyApp.open_maskeditor =
function () {
const dlg = MaskEditorDialog.getInstance();
if(!dlg.isOpened()) {
dlg.show();
}
};
const context_predicate = () => ComfyApp.clipspace && ComfyApp.clipspace.imgs && ComfyApp.clipspace.imgs.length > 0
ClipspaceDialog.registerButton("MaskEditor", context_predicate, ComfyApp.open_maskeditor);
}
});

View File

@ -0,0 +1,41 @@
import {app} from "../../scripts/app.js";
import {ComfyWidgets} from "../../scripts/widgets.js";
// Node that add notes to your project
app.registerExtension({
name: "Comfy.NoteNode",
registerCustomNodes() {
class NoteNode {
color=LGraphCanvas.node_colors.yellow.color;
bgcolor=LGraphCanvas.node_colors.yellow.bgcolor;
groupcolor = LGraphCanvas.node_colors.yellow.groupcolor;
constructor() {
if (!this.properties) {
this.properties = {};
this.properties.text="";
}
ComfyWidgets.STRING(this, "", ["", {default:this.properties.text, multiline: true}], app)
this.serialize_widgets = true;
this.isVirtualNode = true;
}
}
// Load default visibility
LiteGraph.registerNodeType(
"Note",
Object.assign(NoteNode, {
title_mode: LiteGraph.NORMAL_TITLE,
title: "Note",
collapsable: true,
})
);
NoteNode.category = "utils";
},
});

View File

@ -1,21 +1,91 @@
import { app } from "/scripts/app.js";
import { ComfyWidgets } from "/scripts/widgets.js";
// Adds defaults for quickly adding nodes with middle click on the input/output
app.registerExtension({
name: "Comfy.SlotDefaults",
suggestionsNumber: null,
init() {
LiteGraph.search_filter_enabled = true;
LiteGraph.middle_click_slot_add_default_node = true;
LiteGraph.slot_types_default_in = {
MODEL: "CheckpointLoaderSimple",
LATENT: "EmptyLatentImage",
VAE: "VAELoader",
};
LiteGraph.slot_types_default_out = {
LATENT: "VAEDecode",
IMAGE: "SaveImage",
CLIP: "CLIPTextEncode",
};
this.suggestionsNumber = app.ui.settings.addSetting({
id: "Comfy.NodeSuggestions.number",
name: "number of nodes suggestions",
type: "slider",
attrs: {
min: 1,
max: 100,
step: 1,
},
defaultValue: 5,
onChange: (newVal, oldVal) => {
this.setDefaults(newVal);
}
});
},
slot_types_default_out: {},
slot_types_default_in: {},
async beforeRegisterNodeDef(nodeType, nodeData, app) {
var nodeId = nodeData.name;
var inputs = [];
inputs = nodeData["input"]["required"]; //only show required inputs to reduce the mess also not logical to create node with optional inputs
for (const inputKey in inputs) {
var input = (inputs[inputKey]);
if (typeof input[0] !== "string") continue;
var type = input[0]
if (type in ComfyWidgets) {
var customProperties = input[1]
if (!(customProperties?.forceInput)) continue; //ignore widgets that don't force input
}
if (!(type in this.slot_types_default_out)) {
this.slot_types_default_out[type] = ["Reroute"];
}
if (this.slot_types_default_out[type].includes(nodeId)) continue;
this.slot_types_default_out[type].push(nodeId);
// Input types have to be stored as lower case
// Store each node that can handle this input type
const lowerType = type.toLocaleLowerCase();
if (!(lowerType in LiteGraph.registered_slot_in_types)) {
LiteGraph.registered_slot_in_types[lowerType] = { nodes: [] };
}
LiteGraph.registered_slot_in_types[lowerType].nodes.push(nodeType.comfyClass);
}
var outputs = nodeData["output"];
for (const key in outputs) {
var type = outputs[key];
if (!(type in this.slot_types_default_in)) {
this.slot_types_default_in[type] = ["Reroute"];// ["Reroute", "Primitive"]; primitive doesn't always work :'()
}
this.slot_types_default_in[type].push(nodeId);
// Store each node that can handle this output type
if (!(type in LiteGraph.registered_slot_out_types)) {
LiteGraph.registered_slot_out_types[type] = { nodes: [] };
}
LiteGraph.registered_slot_out_types[type].nodes.push(nodeType.comfyClass);
if(!LiteGraph.slot_types_out.includes(type)) {
LiteGraph.slot_types_out.push(type);
}
}
var maxNum = this.suggestionsNumber.value;
this.setDefaults(maxNum);
},
setDefaults(maxNum) {
LiteGraph.slot_types_default_out = {};
LiteGraph.slot_types_default_in = {};
for (const type in this.slot_types_default_out) {
LiteGraph.slot_types_default_out[type] = this.slot_types_default_out[type].slice(0, maxNum);
}
for (const type in this.slot_types_default_in) {
LiteGraph.slot_types_default_in[type] = this.slot_types_default_in[type].slice(0, maxNum);
}
}
});

View File

@ -9,7 +9,7 @@ app.registerExtension({
app.ui.settings.addSetting({
id: "Comfy.SnapToGrid.GridSize",
name: "Grid Size",
type: "number",
type: "slider",
attrs: {
min: 1,
max: 500,

View File

@ -159,27 +159,33 @@ app.registerExtension({
const r = origOnInputDblClick ? origOnInputDblClick.apply(this, arguments) : undefined;
const input = this.inputs[slot];
if (input.widget && !input[ignoreDblClick]) {
const node = LiteGraph.createNode("PrimitiveNode");
app.graph.add(node);
// Calculate a position that wont directly overlap another node
const pos = [this.pos[0] - node.size[0] - 30, this.pos[1]];
while (isNodeAtPos(pos)) {
pos[1] += LiteGraph.NODE_TITLE_HEIGHT;
if (!input.widget || !input[ignoreDblClick]) {
// Not a widget input or already handled input
if (!(input.type in ComfyWidgets) && !(input.widget.config?.[0] instanceof Array)) {
return r; //also Not a ComfyWidgets input or combo (do nothing)
}
node.pos = pos;
node.connect(0, this, slot);
node.title = input.name;
// Prevent adding duplicates due to triple clicking
input[ignoreDblClick] = true;
setTimeout(() => {
delete input[ignoreDblClick];
}, 300);
}
// Create a primitive node
const node = LiteGraph.createNode("PrimitiveNode");
app.graph.add(node);
// Calculate a position that wont directly overlap another node
const pos = [this.pos[0] - node.size[0] - 30, this.pos[1]];
while (isNodeAtPos(pos)) {
pos[1] += LiteGraph.NODE_TITLE_HEIGHT;
}
node.pos = pos;
node.connect(0, this, slot);
node.title = input.name;
// Prevent adding duplicates due to triple clicking
input[ignoreDblClick] = true;
setTimeout(() => {
delete input[ignoreDblClick];
}, 300);
return r;
};
},
@ -233,7 +239,9 @@ app.registerExtension({
// Fires before the link is made allowing us to reject it if it isn't valid
// No widget, we cant connect
if (!input.widget) return false;
if (!input.widget) {
if (!(input.type in ComfyWidgets)) return false;
}
if (this.outputs[slot].links?.length) {
return this.#isValidConnection(input);
@ -252,9 +260,17 @@ app.registerExtension({
const input = theirNode.inputs[link.target_slot];
if (!input) return;
const widget = input.widget;
const { type, linkType } = getWidgetType(widget.config);
var _widget;
if (!input.widget) {
if (!(input.type in ComfyWidgets)) return;
_widget = { "name": input.name, "config": [input.type, {}] }//fake widget
} else {
_widget = input.widget;
}
const widget = _widget;
const { type, linkType } = getWidgetType(widget.config);
// Update our output to restrict to the widget type
this.outputs[0].type = linkType;
this.outputs[0].name = type;
@ -274,7 +290,7 @@ app.registerExtension({
if (type in ComfyWidgets) {
widget = (ComfyWidgets[type](this, "value", inputData, app) || {}).widget;
} else {
widget = this.addWidget(type, "value", null, () => {}, {});
widget = this.addWidget(type, "value", null, () => { }, {});
}
if (node?.widgets && widget) {
@ -284,7 +300,7 @@ app.registerExtension({
}
}
if (widget.type === "number") {
if (widget.type === "number" || widget.type === "combo") {
addValueControlWidget(this, widget, "fixed");
}
@ -319,7 +335,20 @@ app.registerExtension({
const config1 = this.outputs[0].widget.config;
const config2 = input.widget.config;
if (config1[0] !== config2[0]) return false;
if (config1[0] instanceof Array) {
// These checks shouldnt actually be necessary as the types should match
// but double checking doesn't hurt
// New input isnt a combo
if (!(config2[0] instanceof Array)) return false;
// New imput combo has a different size
if (config1[0].length !== config2[0].length) return false;
// New input combo has different elements
if (config1[0].find((v, i) => config2[0][i] !== v)) return false;
} else if (config1[0] !== config2[0]) {
// Configs dont match
return false;
}
for (const k in config1[1]) {
if (k !== "default") {

View File

@ -3628,6 +3628,18 @@
return size;
};
LGraphNode.prototype.inResizeCorner = function(canvasX, canvasY) {
var rows = this.outputs ? this.outputs.length : 1;
var outputs_offset = (this.constructor.slot_start_y || 0) + rows * LiteGraph.NODE_SLOT_HEIGHT;
return isInsideRectangle(canvasX,
canvasY,
this.pos[0] + this.size[0] - 15,
this.pos[1] + Math.max(this.size[1] - 15, outputs_offset),
20,
20
);
}
/**
* returns all the info available about a property of this node.
*
@ -5868,23 +5880,16 @@ LGraphNode.prototype.executeAction = function(action)
//when clicked on top of a node
//and it is not interactive
if (node && this.allow_interaction && !skip_action && !this.read_only) {
if (node && (this.allow_interaction || node.flags.allow_interaction) && !skip_action && !this.read_only) {
if (!this.live_mode && !node.flags.pinned) {
this.bringToFront(node);
} //if it wasn't selected?
//not dragging mouse to connect two slots
if ( !this.connecting_node && !node.flags.collapsed && !this.live_mode ) {
if ( this.allow_interaction && !this.connecting_node && !node.flags.collapsed && !this.live_mode ) {
//Search for corner for resize
if ( !skip_action &&
node.resizable !== false &&
isInsideRectangle( e.canvasX,
e.canvasY,
node.pos[0] + node.size[0] - 5,
node.pos[1] + node.size[1] - 5,
10,
10
)
node.resizable !== false && node.inResizeCorner(e.canvasX, e.canvasY)
) {
this.graph.beforeChange();
this.resizing_node = node;
@ -6028,7 +6033,7 @@ LGraphNode.prototype.executeAction = function(action)
}
//double clicking
if (is_double_click && this.selected_nodes[node.id]) {
if (this.allow_interaction && is_double_click && this.selected_nodes[node.id]) {
//double click node
if (node.onDblClick) {
node.onDblClick( e, pos, this );
@ -6302,6 +6307,9 @@ LGraphNode.prototype.executeAction = function(action)
this.dirty_canvas = true;
}
//get node over
var node = this.graph.getNodeOnPos(e.canvasX,e.canvasY,this.visible_nodes);
if (this.dragging_rectangle)
{
this.dragging_rectangle[2] = e.canvasX - this.dragging_rectangle[0];
@ -6331,14 +6339,11 @@ LGraphNode.prototype.executeAction = function(action)
this.ds.offset[1] += delta[1] / this.ds.scale;
this.dirty_canvas = true;
this.dirty_bgcanvas = true;
} else if (this.allow_interaction && !this.read_only) {
} else if ((this.allow_interaction || (node && node.flags.allow_interaction)) && !this.read_only) {
if (this.connecting_node) {
this.dirty_canvas = true;
}
//get node over
var node = this.graph.getNodeOnPos(e.canvasX,e.canvasY,this.visible_nodes);
//remove mouseover flag
for (var i = 0, l = this.graph._nodes.length; i < l; ++i) {
if (this.graph._nodes[i].mouseOver && node != this.graph._nodes[i] ) {
@ -6424,16 +6429,7 @@ LGraphNode.prototype.executeAction = function(action)
//Search for corner
if (this.canvas) {
if (
isInsideRectangle(
e.canvasX,
e.canvasY,
node.pos[0] + node.size[0] - 5,
node.pos[1] + node.size[1] - 5,
5,
5
)
) {
if (node.inResizeCorner(e.canvasX, e.canvasY)) {
this.canvas.style.cursor = "se-resize";
} else {
this.canvas.style.cursor = "crosshair";
@ -9738,7 +9734,7 @@ LGraphNode.prototype.executeAction = function(action)
if (show_text) {
ctx.textAlign = "center";
ctx.fillStyle = text_color;
ctx.fillText(w.name, widget_width * 0.5, y + H * 0.7);
ctx.fillText(w.label || w.name, widget_width * 0.5, y + H * 0.7);
}
break;
case "toggle":
@ -9759,8 +9755,9 @@ LGraphNode.prototype.executeAction = function(action)
ctx.fill();
if (show_text) {
ctx.fillStyle = secondary_text_color;
if (w.name != null) {
ctx.fillText(w.name, margin * 2, y + H * 0.7);
const label = w.label || w.name;
if (label != null) {
ctx.fillText(label, margin * 2, y + H * 0.7);
}
ctx.fillStyle = w.value ? text_color : secondary_text_color;
ctx.textAlign = "right";
@ -9795,7 +9792,7 @@ LGraphNode.prototype.executeAction = function(action)
ctx.textAlign = "center";
ctx.fillStyle = text_color;
ctx.fillText(
w.name + " " + Number(w.value).toFixed(3),
w.label || w.name + " " + Number(w.value).toFixed(3),
widget_width * 0.5,
y + H * 0.7
);
@ -9830,7 +9827,7 @@ LGraphNode.prototype.executeAction = function(action)
ctx.fill();
}
ctx.fillStyle = secondary_text_color;
ctx.fillText(w.name, margin * 2 + 5, y + H * 0.7);
ctx.fillText(w.label || w.name, margin * 2 + 5, y + H * 0.7);
ctx.fillStyle = text_color;
ctx.textAlign = "right";
if (w.type == "number") {
@ -9882,8 +9879,9 @@ LGraphNode.prototype.executeAction = function(action)
//ctx.stroke();
ctx.fillStyle = secondary_text_color;
if (w.name != null) {
ctx.fillText(w.name, margin * 2, y + H * 0.7);
const label = w.label || w.name;
if (label != null) {
ctx.fillText(label, margin * 2, y + H * 0.7);
}
ctx.fillStyle = text_color;
ctx.textAlign = "right";
@ -9915,7 +9913,7 @@ LGraphNode.prototype.executeAction = function(action)
event,
active_widget
) {
if (!node.widgets || !node.widgets.length) {
if (!node.widgets || !node.widgets.length || (!this.allow_interaction && !node.flags.allow_interaction)) {
return null;
}
@ -9953,11 +9951,11 @@ LGraphNode.prototype.executeAction = function(action)
}
break;
case "slider":
var range = w.options.max - w.options.min;
var old_value = w.value;
var nvalue = Math.clamp((x - 15) / (widget_width - 30), 0, 1);
if(w.options.read_only) break;
w.value = w.options.min + (w.options.max - w.options.min) * nvalue;
if (w.callback) {
if (old_value != w.value) {
setTimeout(function() {
inner_value_change(w, w.value);
}, 20);
@ -10044,7 +10042,7 @@ LGraphNode.prototype.executeAction = function(action)
if (event.click_time < 200 && delta == 0) {
this.prompt("Value",w.value,function(v) {
// check if v is a valid equation or a number
if (/^[0-9+\-*/()\s]+$/.test(v)) {
if (/^[0-9+\-*/()\s]+|\d+\.\d+$/.test(v)) {
try {//solve the equation if possible
v = eval(v);
} catch (e) { }
@ -10304,6 +10302,119 @@ LGraphNode.prototype.executeAction = function(action)
canvas.graph.add(group);
};
/**
* Determines the furthest nodes in each direction
* @param nodes {LGraphNode[]} the nodes to from which boundary nodes will be extracted
* @return {{left: LGraphNode, top: LGraphNode, right: LGraphNode, bottom: LGraphNode}}
*/
LGraphCanvas.getBoundaryNodes = function(nodes) {
let top = null;
let right = null;
let bottom = null;
let left = null;
for (const nID in nodes) {
const node = nodes[nID];
const [x, y] = node.pos;
const [width, height] = node.size;
if (top === null || y < top.pos[1]) {
top = node;
}
if (right === null || x + width > right.pos[0] + right.size[0]) {
right = node;
}
if (bottom === null || y + height > bottom.pos[1] + bottom.size[1]) {
bottom = node;
}
if (left === null || x < left.pos[0]) {
left = node;
}
}
return {
"top": top,
"right": right,
"bottom": bottom,
"left": left
};
}
/**
* Determines the furthest nodes in each direction for the currently selected nodes
* @return {{left: LGraphNode, top: LGraphNode, right: LGraphNode, bottom: LGraphNode}}
*/
LGraphCanvas.prototype.boundaryNodesForSelection = function() {
return LGraphCanvas.getBoundaryNodes(Object.values(this.selected_nodes));
}
/**
*
* @param {LGraphNode[]} nodes a list of nodes
* @param {"top"|"bottom"|"left"|"right"} direction Direction to align the nodes
* @param {LGraphNode?} align_to Node to align to (if null, align to the furthest node in the given direction)
*/
LGraphCanvas.alignNodes = function (nodes, direction, align_to) {
if (!nodes) {
return;
}
const canvas = LGraphCanvas.active_canvas;
let boundaryNodes = []
if (align_to === undefined) {
boundaryNodes = LGraphCanvas.getBoundaryNodes(nodes)
} else {
boundaryNodes = {
"top": align_to,
"right": align_to,
"bottom": align_to,
"left": align_to
}
}
for (const [_, node] of Object.entries(canvas.selected_nodes)) {
switch (direction) {
case "right":
node.pos[0] = boundaryNodes["right"].pos[0] + boundaryNodes["right"].size[0] - node.size[0];
break;
case "left":
node.pos[0] = boundaryNodes["left"].pos[0];
break;
case "top":
node.pos[1] = boundaryNodes["top"].pos[1];
break;
case "bottom":
node.pos[1] = boundaryNodes["bottom"].pos[1] + boundaryNodes["bottom"].size[1] - node.size[1];
break;
}
}
canvas.dirty_canvas = true;
canvas.dirty_bgcanvas = true;
};
LGraphCanvas.onNodeAlign = function(value, options, event, prev_menu, node) {
new LiteGraph.ContextMenu(["Top", "Bottom", "Left", "Right"], {
event: event,
callback: inner_clicked,
parentMenu: prev_menu,
});
function inner_clicked(value) {
LGraphCanvas.alignNodes(LGraphCanvas.active_canvas.selected_nodes, value.toLowerCase(), node);
}
}
LGraphCanvas.onGroupAlign = function(value, options, event, prev_menu) {
new LiteGraph.ContextMenu(["Top", "Bottom", "Left", "Right"], {
event: event,
callback: inner_clicked,
parentMenu: prev_menu,
});
function inner_clicked(value) {
LGraphCanvas.alignNodes(LGraphCanvas.active_canvas.selected_nodes, value.toLowerCase());
}
}
LGraphCanvas.onMenuAdd = function (node, options, e, prev_menu, callback) {
var canvas = LGraphCanvas.active_canvas;
@ -12904,6 +13015,14 @@ LGraphNode.prototype.executeAction = function(action)
options.push({ content: "Options", callback: that.showShowGraphOptionsPanel });
}*/
if (Object.keys(this.selected_nodes).length > 1) {
options.push({
content: "Align",
has_submenu: true,
callback: LGraphCanvas.onGroupAlign,
})
}
if (this._graph_stack && this._graph_stack.length > 0) {
options.push(null, {
content: "Close subgraph",
@ -13018,6 +13137,14 @@ LGraphNode.prototype.executeAction = function(action)
callback: LGraphCanvas.onMenuNodeToSubgraph
});
if (Object.keys(this.selected_nodes).length > 1) {
options.push({
content: "Align Selected To",
has_submenu: true,
callback: LGraphCanvas.onNodeAlign,
})
}
options.push(null, {
content: "Remove",
disabled: !(node.removable !== false && !node.block_delete ),

View File

@ -35,7 +35,7 @@ class ComfyApi extends EventTarget {
}
let opened = false;
let existingSession = sessionStorage["Comfy.SessionId"] || "";
let existingSession = window.name;
if (existingSession) {
existingSession = "?clientId=" + existingSession;
}
@ -75,7 +75,7 @@ class ComfyApi extends EventTarget {
case "status":
if (msg.data.sid) {
this.clientId = msg.data.sid;
sessionStorage["Comfy.SessionId"] = this.clientId;
window.name = this.clientId;
}
this.dispatchEvent(new CustomEvent("status", { detail: msg.data.status }));
break;
@ -163,7 +163,7 @@ class ComfyApi extends EventTarget {
if (res.status !== 200) {
throw {
response: await res.text(),
response: await res.json(),
};
}
}

View File

@ -2,29 +2,167 @@ import { ComfyWidgets } from "./widgets.js";
import { ComfyUI, $el } from "./ui.js";
import { api } from "./api.js";
import { defaultGraph } from "./defaultGraph.js";
import { getPngMetadata, importA1111 } from "./pnginfo.js";
import { getPngMetadata, importA1111, getLatentMetadata } from "./pnginfo.js";
class ComfyApp {
/**
* List of {number, batchCount} entries to queue
/**
* @typedef {import("types/comfy").ComfyExtension} ComfyExtension
*/
export class ComfyApp {
/**
* List of entries to queue
* @type {{number: number, batchCount: number}[]}
*/
#queueItems = [];
/**
* If the queue is currently being processed
* @type {boolean}
*/
#processingQueue = false;
/**
* Content Clipboard
* @type {serialized node object}
*/
static clipspace = null;
static clipspace_invalidate_handler = null;
static open_maskeditor = null;
static clipspace_return_node = null;
constructor() {
this.ui = new ComfyUI(this);
/**
* List of extensions that are registered with the app
* @type {ComfyExtension[]}
*/
this.extensions = [];
/**
* Stores the execution output data for each node
* @type {Record<string, any>}
*/
this.nodeOutputs = {};
/**
* If the shift key on the keyboard is pressed
* @type {boolean}
*/
this.shiftDown = false;
}
static isImageNode(node) {
return node.imgs || (node && node.widgets && node.widgets.findIndex(obj => obj.name === 'image') >= 0);
}
static onClipspaceEditorSave() {
if(ComfyApp.clipspace_return_node) {
ComfyApp.pasteFromClipspace(ComfyApp.clipspace_return_node);
}
}
static onClipspaceEditorClosed() {
ComfyApp.clipspace_return_node = null;
}
static copyToClipspace(node) {
var widgets = null;
if(node.widgets) {
widgets = node.widgets.map(({ type, name, value }) => ({ type, name, value }));
}
var imgs = undefined;
var orig_imgs = undefined;
if(node.imgs != undefined) {
imgs = [];
orig_imgs = [];
for (let i = 0; i < node.imgs.length; i++) {
imgs[i] = new Image();
imgs[i].src = node.imgs[i].src;
orig_imgs[i] = imgs[i];
}
}
var selectedIndex = 0;
if(node.imageIndex) {
selectedIndex = node.imageIndex;
}
ComfyApp.clipspace = {
'widgets': widgets,
'imgs': imgs,
'original_imgs': orig_imgs,
'images': node.images,
'selectedIndex': selectedIndex,
'img_paste_mode': 'selected' // reset to default im_paste_mode state on copy action
};
ComfyApp.clipspace_return_node = null;
if(ComfyApp.clipspace_invalidate_handler) {
ComfyApp.clipspace_invalidate_handler();
}
}
static pasteFromClipspace(node) {
if(ComfyApp.clipspace) {
// image paste
if(ComfyApp.clipspace.imgs && node.imgs) {
if(node.images && ComfyApp.clipspace.images) {
if(ComfyApp.clipspace['img_paste_mode'] == 'selected') {
app.nodeOutputs[node.id + ""].images = node.images = [ComfyApp.clipspace.images[ComfyApp.clipspace['selectedIndex']]];
}
else
app.nodeOutputs[node.id + ""].images = node.images = ComfyApp.clipspace.images;
}
if(ComfyApp.clipspace.imgs) {
// deep-copy to cut link with clipspace
if(ComfyApp.clipspace['img_paste_mode'] == 'selected') {
const img = new Image();
img.src = ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src;
node.imgs = [img];
node.imageIndex = 0;
}
else {
const imgs = [];
for(let i=0; i<ComfyApp.clipspace.imgs.length; i++) {
imgs[i] = new Image();
imgs[i].src = ComfyApp.clipspace.imgs[i].src;
node.imgs = imgs;
}
}
}
}
if(node.widgets) {
if(ComfyApp.clipspace.images) {
const clip_image = ComfyApp.clipspace.images[ComfyApp.clipspace['selectedIndex']];
const index = node.widgets.findIndex(obj => obj.name === 'image');
if(index >= 0) {
node.widgets[index].value = clip_image;
}
}
if(ComfyApp.clipspace.widgets) {
ComfyApp.clipspace.widgets.forEach(({ type, name, value }) => {
const prop = Object.values(node.widgets).find(obj => obj.type === type && obj.name === name);
if (prop && prop.type != 'button') {
prop.value = value;
prop.callback(value);
}
});
}
}
app.graph.setDirtyCanvas(true);
}
}
/**
* Invoke an extension callback
* @param {string} method The extension callback to execute
* @param {...any} args Any arguments to pass to the callback
* @param {keyof ComfyExtension} method The extension callback to execute
* @param {any[]} args Any arguments to pass to the callback
* @returns
*/
#invokeExtensions(method, ...args) {
@ -109,6 +247,32 @@ class ComfyApp {
);
}
}
// prevent conflict of clipspace content
if(!ComfyApp.clipspace_return_node) {
options.push({
content: "Copy (Clipspace)",
callback: (obj) => { ComfyApp.copyToClipspace(this); }
});
if(ComfyApp.clipspace != null) {
options.push({
content: "Paste (Clipspace)",
callback: () => { ComfyApp.pasteFromClipspace(this); }
});
}
if(ComfyApp.isImageNode(this)) {
options.push({
content: "Open in MaskEditor",
callback: (obj) => {
ComfyApp.copyToClipspace(this);
ComfyApp.clipspace_return_node = this;
ComfyApp.open_maskeditor();
}
});
}
}
};
}
@ -159,6 +323,34 @@ class ComfyApp {
*/
#addDrawBackgroundHandler(node) {
const app = this;
function getImageTop(node) {
let shiftY;
if (node.imageOffset != null) {
shiftY = node.imageOffset;
} else {
if (node.widgets?.length) {
const w = node.widgets[node.widgets.length - 1];
shiftY = w.last_y;
if (w.computeSize) {
shiftY += w.computeSize()[1] + 4;
} else {
shiftY += LiteGraph.NODE_WIDGET_HEIGHT + 4;
}
} else {
shiftY = node.computeSize()[1];
}
}
return shiftY;
}
node.prototype.setSizeForImage = function () {
const minHeight = getImageTop(this) + 220;
if (this.size[1] < minHeight) {
this.setSize([this.size[0], minHeight]);
}
};
node.prototype.onDrawBackground = function (ctx) {
if (!this.flags.collapsed) {
const output = app.nodeOutputs[this.id + ""];
@ -179,9 +371,7 @@ class ComfyApp {
).then((imgs) => {
if (this.images === output.images) {
this.imgs = imgs.filter(Boolean);
if (this.size[1] < 100) {
this.size[1] = 250;
}
this.setSizeForImage?.();
app.graph.setDirtyCanvas(true);
}
});
@ -206,12 +396,7 @@ class ComfyApp {
this.imageIndex = imageIndex = 0;
}
let shiftY;
if (this.imageOffset != null) {
shiftY = this.imageOffset;
} else {
shiftY = this.computeSize()[1];
}
const shiftY = getImageTop(this);
let dw = this.size[0];
let dh = this.size[1];
@ -599,7 +784,7 @@ class ComfyApp {
ctx.globalAlpha = 0.8;
ctx.beginPath();
if (shape == LiteGraph.BOX_SHAPE)
ctx.rect(-6, -6 + LiteGraph.NODE_TITLE_HEIGHT, 12 + size[0] + 1, 12 + size[1] + LiteGraph.NODE_TITLE_HEIGHT);
ctx.rect(-6, -6 - LiteGraph.NODE_TITLE_HEIGHT, 12 + size[0] + 1, 12 + size[1] + LiteGraph.NODE_TITLE_HEIGHT);
else if (shape == LiteGraph.ROUND_SHAPE || (shape == LiteGraph.CARD_SHAPE && node.flags.collapsed))
ctx.roundRect(
-6,
@ -611,12 +796,11 @@ class ComfyApp {
else if (shape == LiteGraph.CARD_SHAPE)
ctx.roundRect(
-6,
-6 + LiteGraph.NODE_TITLE_HEIGHT,
-6 - LiteGraph.NODE_TITLE_HEIGHT,
12 + size[0] + 1,
12 + size[1] + LiteGraph.NODE_TITLE_HEIGHT,
this.round_radius * 2,
2
);
[this.round_radius * 2, this.round_radius * 2, 2, 2]
);
else if (shape == LiteGraph.CIRCLE_SHAPE)
ctx.arc(size[0] * 0.5, size[1] * 0.5, size[0] * 0.5 + 6, 0, Math.PI * 2);
ctx.strokeStyle = color;
@ -691,11 +875,6 @@ class ComfyApp {
#addKeyboardHandler() {
window.addEventListener("keydown", (e) => {
this.shiftDown = e.shiftKey;
// Queue prompt using ctrl or command + enter
if ((e.ctrlKey || e.metaKey) && (e.key === "Enter" || e.keyCode === 13 || e.keyCode === 10)) {
this.queuePrompt(e.shiftKey ? -1 : 0);
}
});
window.addEventListener("keyup", (e) => {
this.shiftDown = e.shiftKey;
@ -723,7 +902,9 @@ class ComfyApp {
await this.#loadExtensions();
// Create and mount the LiteGraph in the DOM
const canvasEl = (this.canvasEl = Object.assign(document.createElement("canvas"), { id: "graph-canvas" }));
const mainCanvas = document.createElement("canvas")
mainCanvas.style.touchAction = "none"
const canvasEl = (this.canvasEl = Object.assign(mainCanvas, { id: "graph-canvas" }));
canvasEl.tabIndex = "1";
document.body.prepend(canvasEl);
@ -835,7 +1016,8 @@ class ComfyApp {
for (const o in nodeData["output"]) {
const output = nodeData["output"][o];
const outputName = nodeData["output_name"][o] || output;
this.addOutput(outputName, output);
const outputShape = nodeData["output_is_list"][o] ? LiteGraph.GRID_SHAPE : LiteGraph.CIRCLE_SHAPE ;
this.addOutput(outputName, output, { shape: outputShape });
}
const s = this.computeSize();
@ -872,8 +1054,10 @@ class ComfyApp {
loadGraphData(graphData) {
this.clean();
let reset_invalid_values = false;
if (!graphData) {
graphData = structuredClone(defaultGraph);
reset_invalid_values = true;
}
const missingNodeTypes = [];
@ -949,9 +1133,20 @@ class ComfyApp {
widget.value = widget.value.slice(7);
}
}
}
if (node.type == "KSampler" || node.type == "KSamplerAdvanced" || node.type == "PrimitiveNode") {
if (widget.name == "control_after_generate") {
if (widget.value == true) {
if (widget.value === true) {
widget.value = "randomize";
} else if (widget.value === false) {
widget.value = "fixed";
}
}
}
if (reset_invalid_values) {
if (widget.type == "combo") {
if (!widget.options.values.includes(widget.value) && widget.options.values.length > 0) {
widget.value = widget.options.values[0];
}
}
}
@ -1067,7 +1262,7 @@ class ComfyApp {
try {
await api.queuePrompt(number, p);
} catch (error) {
this.ui.dialog.show(error.response || error.toString());
this.ui.dialog.show(error.response.error || error.toString());
break;
}
@ -1113,9 +1308,18 @@ class ComfyApp {
this.loadGraphData(JSON.parse(reader.result));
};
reader.readAsText(file);
} else if (file.name?.endsWith(".latent")) {
const info = await getLatentMetadata(file);
if (info.workflow) {
this.loadGraphData(JSON.parse(info.workflow));
}
}
}
/**
* Registers a Comfy web extension with the app
* @param {ComfyExtension} extension
*/
registerExtension(extension) {
if (!extension.name) {
throw new Error("Extensions must have a 'name' property.");
@ -1139,12 +1343,12 @@ class ComfyApp {
for(const widgetNum in node.widgets) {
const widget = node.widgets[widgetNum]
if(widget.type == "combo" && def["input"]["required"][widget.name] !== undefined) {
widget.options.values = def["input"]["required"][widget.name][0];
if(!widget.options.values.includes(widget.value)) {
if(widget.name != 'image' && !widget.options.values.includes(widget.value)) {
widget.value = widget.options.values[0];
widget.callback(widget.value);
}
}
}

View File

@ -47,6 +47,22 @@ export function getPngMetadata(file) {
});
}
export function getLatentMetadata(file) {
return new Promise((r) => {
const reader = new FileReader();
reader.onload = (event) => {
const safetensorsData = new Uint8Array(event.target.result);
const dataView = new DataView(safetensorsData.buffer);
let header_size = dataView.getUint32(0, true);
let offset = 8;
let header = JSON.parse(String.fromCharCode(...safetensorsData.slice(offset, offset + header_size)));
r(header.__metadata__);
};
reader.readAsArrayBuffer(file);
});
}
export async function importA1111(graph, parameters) {
const p = parameters.lastIndexOf("\nSteps:");
if (p > -1) {
@ -131,6 +147,7 @@ export async function importA1111(graph, parameters) {
}
function replaceEmbeddings(text) {
if(!embeddings.length) return text;
return text.replaceAll(
new RegExp(
"\\b(" + embeddings.map((e) => e.replace(/[.*+?^${}()|[\]\\]/g, "\\$&")).join("\\b|\\b") + ")\\b",

View File

@ -270,6 +270,30 @@ class ComfySettingsDialog extends ComfyDialog {
]),
]);
break;
case "slider":
element = $el("div", [
$el("label", { textContent: name }, [
$el("input", {
type: "range",
value,
oninput: (e) => {
setter(e.target.value);
e.target.nextElementSibling.value = e.target.value;
},
...attrs
}),
$el("input", {
type: "number",
value,
oninput: (e) => {
setter(e.target.value);
e.target.previousElementSibling.value = e.target.value;
},
...attrs
}),
]),
]);
break;
default:
console.warn("Unsupported setting type, defaulting to text");
element = $el("div", [
@ -431,9 +455,17 @@ export class ComfyUI {
defaultValue: true,
});
const promptFilename = this.settings.addSetting({
id: "Comfy.PromptFilename",
name: "Prompt for filename when saving workflow",
type: "boolean",
defaultValue: true,
});
const fileInput = $el("input", {
id: "comfy-file-input",
type: "file",
accept: ".json,image/png",
accept: ".json,image/png,.latent",
style: { display: "none" },
parent: document.body,
onchange: () => {
@ -448,6 +480,7 @@ export class ComfyUI {
$el("button.comfy-settings-btn", { textContent: "⚙️", onclick: () => this.settings.show() }),
]),
$el("button.comfy-queue-btn", {
id: "queue-button",
textContent: "Queue Prompt",
onclick: () => app.queuePrompt(0, this.batchCount),
}),
@ -496,9 +529,10 @@ export class ComfyUI {
]),
]),
$el("div.comfy-menu-btns", [
$el("button", { textContent: "Queue Front", onclick: () => app.queuePrompt(-1, this.batchCount) }),
$el("button", { id: "queue-front-button", textContent: "Queue Front", onclick: () => app.queuePrompt(-1, this.batchCount) }),
$el("button", {
$: (b) => (this.queue.button = b),
id: "comfy-view-queue-button",
textContent: "View Queue",
onclick: () => {
this.history.hide();
@ -507,6 +541,7 @@ export class ComfyUI {
}),
$el("button", {
$: (b) => (this.history.button = b),
id: "comfy-view-history-button",
textContent: "View History",
onclick: () => {
this.queue.hide();
@ -517,14 +552,23 @@ export class ComfyUI {
this.queue.element,
this.history.element,
$el("button", {
id: "comfy-save-button",
textContent: "Save",
onclick: () => {
let filename = "workflow.json";
if (promptFilename.value) {
filename = prompt("Save workflow as:", filename);
if (!filename) return;
if (!filename.toLowerCase().endsWith(".json")) {
filename += ".json";
}
}
const json = JSON.stringify(app.graph.serialize(), null, 2); // convert the data to a JSON string
const blob = new Blob([json], { type: "application/json" });
const url = URL.createObjectURL(blob);
const a = $el("a", {
href: url,
download: "workflow.json",
download: filename,
style: { display: "none" },
parent: document.body,
});
@ -535,15 +579,16 @@ export class ComfyUI {
}, 0);
},
}),
$el("button", { textContent: "Load", onclick: () => fileInput.click() }),
$el("button", { textContent: "Refresh", onclick: () => app.refreshComboInNodes() }),
$el("button", { textContent: "Clear", onclick: () => {
$el("button", { id: "comfy-load-button", textContent: "Load", onclick: () => fileInput.click() }),
$el("button", { id: "comfy-refresh-button", textContent: "Refresh", onclick: () => app.refreshComboInNodes() }),
$el("button", { id: "comfy-clipspace-button", textContent: "Clipspace", onclick: () => app.openClipspace() }),
$el("button", { id: "comfy-clear-button", textContent: "Clear", onclick: () => {
if (!confirmClear.value || confirm("Clear workflow?")) {
app.clean();
app.graph.clear();
}
}}),
$el("button", { textContent: "Load Default", onclick: () => {
$el("button", { id: "comfy-load-default-button", textContent: "Load Default", onclick: () => {
if (!confirmClear.value || confirm("Load default workflow?")) {
app.loadGraphData()
}

View File

@ -19,35 +19,60 @@ export function addValueControlWidget(node, targetWidget, defaultValue = "random
var v = valueControl.value;
let min = targetWidget.options.min;
let max = targetWidget.options.max;
// limit to something that javascript can handle
max = Math.min(1125899906842624, max);
min = Math.max(-1125899906842624, min);
let range = (max - min) / (targetWidget.options.step / 10);
if (targetWidget.type == "combo" && v !== "fixed") {
let current_index = targetWidget.options.values.indexOf(targetWidget.value);
let current_length = targetWidget.options.values.length;
//adjust values based on valueControl Behaviour
switch (v) {
case "fixed":
break;
case "increment":
targetWidget.value += targetWidget.options.step / 10;
break;
case "decrement":
targetWidget.value -= targetWidget.options.step / 10;
break;
case "randomize":
targetWidget.value = Math.floor(Math.random() * range) * (targetWidget.options.step / 10) + min;
default:
break;
switch (v) {
case "increment":
current_index += 1;
break;
case "decrement":
current_index -= 1;
break;
case "randomize":
current_index = Math.floor(Math.random() * current_length);
default:
break;
}
current_index = Math.max(0, current_index);
current_index = Math.min(current_length - 1, current_index);
if (current_index >= 0) {
let value = targetWidget.options.values[current_index];
targetWidget.value = value;
targetWidget.callback(value);
}
} else { //number
let min = targetWidget.options.min;
let max = targetWidget.options.max;
// limit to something that javascript can handle
max = Math.min(1125899906842624, max);
min = Math.max(-1125899906842624, min);
let range = (max - min) / (targetWidget.options.step / 10);
//adjust values based on valueControl Behaviour
switch (v) {
case "fixed":
break;
case "increment":
targetWidget.value += targetWidget.options.step / 10;
break;
case "decrement":
targetWidget.value -= targetWidget.options.step / 10;
break;
case "randomize":
targetWidget.value = Math.floor(Math.random() * range) * (targetWidget.options.step / 10) + min;
default:
break;
}
/*check if values are over or under their respective
* ranges and set them to min or max.*/
if (targetWidget.value < min)
targetWidget.value = min;
if (targetWidget.value > max)
targetWidget.value = max;
}
/*check if values are over or under their respective
* ranges and set them to min or max.*/
if (targetWidget.value < min)
targetWidget.value = min;
if (targetWidget.value > max)
targetWidget.value = max;
}
return valueControl;
};
@ -136,9 +161,11 @@ function addMultilineWidget(node, name, opts, app) {
left: `${t.a * margin + t.e}px`,
top: `${t.d * (y + widgetHeight - margin - 3) + t.f}px`,
width: `${(widgetWidth - margin * 2 - 3) * t.a}px`,
background: (!node.color)?'':node.color,
height: `${(this.parent.inputHeight - margin * 2 - 4) * t.d}px`,
position: "absolute",
zIndex: 1,
color: (!node.color)?'':'white',
zIndex: app.graph._nodes.indexOf(node),
fontSize: `${t.d * 10.0}px`,
});
this.inputEl.hidden = !visible;
@ -259,19 +286,51 @@ export const ComfyWidgets = {
let uploadWidget;
function showImage(name) {
// Position the image somewhere sensible
if (!node.imageOffset) {
node.imageOffset = uploadWidget.last_y ? uploadWidget.last_y + 25 : 75;
}
const img = new Image();
img.onload = () => {
node.imgs = [img];
app.graph.setDirtyCanvas(true);
};
img.src = `/view?filename=${name}&type=input`;
let folder_separator = name.lastIndexOf("/");
let subfolder = "";
if (folder_separator > -1) {
subfolder = name.substring(0, folder_separator);
name = name.substring(folder_separator + 1);
}
img.src = `/view?filename=${name}&type=input&subfolder=${subfolder}`;
node.setSizeForImage?.();
}
var default_value = imageWidget.value;
Object.defineProperty(imageWidget, "value", {
set : function(value) {
this._real_value = value;
},
get : function() {
let value = "";
if (this._real_value) {
value = this._real_value;
} else {
return default_value;
}
if (value.filename) {
let real_value = value;
value = "";
if (real_value.subfolder) {
value = real_value.subfolder + "/";
}
value += real_value.filename;
if(real_value.type && real_value.type !== "input")
value += ` [${real_value.type}]`;
}
return value;
}
});
// Add our own callback to the combo widget to render an image when it changes
const cb = node.callback;
imageWidget.callback = function () {

View File

@ -120,7 +120,7 @@ body {
.comfy-menu > button,
.comfy-menu-btns button,
.comfy-menu .comfy-list button,
.comfy-modal button{
.comfy-modal button {
color: var(--input-text);
background-color: var(--comfy-input-bg);
border-radius: 8px;
@ -129,6 +129,15 @@ body {
margin-top: 2px;
}
.comfy-menu > button:hover,
.comfy-menu-btns button:hover,
.comfy-menu .comfy-list button:hover,
.comfy-modal button:hover,
.comfy-settings-btn:hover {
filter: brightness(1.2);
cursor: pointer;
}
.comfy-menu span.drag-handle {
width: 10px;
height: 20px;
@ -160,9 +169,9 @@ body {
.comfy-list {
color: var(--descrip-text);
background-color: #333;
background-color: var(--comfy-menu-bg);
margin-bottom: 10px;
border-color: #4e4e4e;
border-color: var(--border-color);
border-style: solid;
}
@ -217,6 +226,14 @@ button.comfy-queue-btn {
z-index: 99;
}
.comfy-modal.comfy-settings input[type="range"] {
vertical-align: middle;
}
.comfy-modal.comfy-settings input[type="range"] + input[type="number"] {
width: 3.5em;
}
.comfy-modal input,
.comfy-modal select {
color: var(--input-text);
@ -240,8 +257,11 @@ button.comfy-queue-btn {
}
}
/* Input popup */
.graphdialog {
min-height: 1em;
background-color: var(--comfy-menu-bg);
}
.graphdialog .name {
@ -265,15 +285,66 @@ button.comfy-queue-btn {
border-radius: 12px 0 0 12px;
}
/* Context menu */
.litegraph .litemenu-entry.has_submenu {
position: relative;
padding-right: 20px;
}
}
.litemenu-entry.has_submenu::after {
.litemenu-entry.has_submenu::after {
content: ">";
position: absolute;
top: 0;
right: 2px;
}
}
.litegraph.litecontextmenu,
.litegraph.litecontextmenu.dark {
z-index: 9999 !important;
background-color: var(--comfy-menu-bg) !important;
filter: brightness(95%);
}
.litegraph.litecontextmenu .litemenu-entry:hover:not(.disabled):not(.separator) {
background-color: var(--comfy-menu-bg) !important;
filter: brightness(155%);
color: var(--input-text);
}
.litegraph.litecontextmenu .litemenu-entry.submenu,
.litegraph.litecontextmenu.dark .litemenu-entry.submenu {
background-color: var(--comfy-menu-bg) !important;
color: var(--input-text);
}
.litegraph.litecontextmenu input {
background-color: var(--comfy-input-bg) !important;
color: var(--input-text) !important;
}
/* Search box */
.litegraph.litesearchbox {
z-index: 9999 !important;
background-color: var(--comfy-menu-bg) !important;
overflow: hidden;
}
.litegraph.litesearchbox input,
.litegraph.litesearchbox select {
background-color: var(--comfy-input-bg) !important;
color: var(--input-text);
}
.litegraph.lite-search-item {
color: var(--input-text);
background-color: var(--comfy-input-bg);
filter: brightness(80%);
padding-left: 0.2em;
}
.litegraph.lite-search-item.generic_type {
color: var(--input-text);
filter: brightness(50%);
}

78
web/types/comfy.d.ts vendored Normal file
View File

@ -0,0 +1,78 @@
import { LGraphNode, IWidget } from "./litegraph";
import { ComfyApp } from "/scripts/app";
export interface ComfyExtension {
/**
* The name of the extension
*/
name: string;
/**
* Allows any initialisation, e.g. loading resources. Called after the canvas is created but before nodes are added
* @param app The ComfyUI app instance
*/
init(app: ComfyApp): Promise<void>;
/**
* Allows any additonal setup, called after the application is fully set up and running
* @param app The ComfyUI app instance
*/
setup(app: ComfyApp): Promise<void>;
/**
* Called before nodes are registered with the graph
* @param defs The collection of node definitions, add custom ones or edit existing ones
* @param app The ComfyUI app instance
*/
addCustomNodeDefs(defs: Record<string, ComfyObjectInfo>, app: ComfyApp): Promise<void>;
/**
* Allows the extension to add custom widgets
* @param app The ComfyUI app instance
* @returns An array of {[widget name]: widget data}
*/
getCustomWidgets(
app: ComfyApp
): Promise<
Array<
Record<string, (node, inputName, inputData, app) => { widget?: IWidget; minWidth?: number; minHeight?: number }>
>
>;
/**
* Allows the extension to add additional handling to the node before it is registered with LGraph
* @param nodeType The node class (not an instance)
* @param nodeData The original node object info config object
* @param app The ComfyUI app instance
*/
beforeRegisterNodeDef(nodeType: typeof LGraphNode, nodeData: ComfyObjectInfo, app: ComfyApp): Promise<void>;
/**
* Allows the extension to register additional nodes with LGraph after standard nodes are added
* @param app The ComfyUI app instance
*/
registerCustomNodes(app: ComfyApp): Promise<void>;
/**
* Allows the extension to modify a node that has been reloaded onto the graph.
* If you break something in the backend and want to patch workflows in the frontend
* This is the place to do this
* @param node The node that has been loaded
* @param app The ComfyUI app instance
*/
loadedGraphNode(node: LGraphNode, app: ComfyApp);
/**
* Allows the extension to run code after the constructor of the node
* @param node The node that has been created
* @param app The ComfyUI app instance
*/
nodeCreated(node: LGraphNode, app: ComfyApp);
}
export type ComfyObjectInfo = {
name: string;
display_name?: string;
description?: string;
category: string;
input?: {
required?: Record<string, ComfyObjectInfoConfig>;
optional?: Record<string, ComfyObjectInfoConfig>;
};
output?: string[];
output_name: string[];
};
export type ComfyObjectInfoConfig = [string | any[]] | [string | any[], any];

1506
web/types/litegraph.d.ts vendored Normal file

File diff suppressed because it is too large Load Diff