mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-10 08:17:44 +08:00
Merge branch 'master' into bool-input
This commit is contained in:
commit
d9973a036e
@ -1,65 +0,0 @@
|
||||
import pygit2
|
||||
from datetime import datetime
|
||||
import sys
|
||||
|
||||
def pull(repo, remote_name='origin', branch='master'):
|
||||
for remote in repo.remotes:
|
||||
if remote.name == remote_name:
|
||||
remote.fetch()
|
||||
remote_master_id = repo.lookup_reference('refs/remotes/origin/%s' % (branch)).target
|
||||
merge_result, _ = repo.merge_analysis(remote_master_id)
|
||||
# Up to date, do nothing
|
||||
if merge_result & pygit2.GIT_MERGE_ANALYSIS_UP_TO_DATE:
|
||||
return
|
||||
# We can just fastforward
|
||||
elif merge_result & pygit2.GIT_MERGE_ANALYSIS_FASTFORWARD:
|
||||
repo.checkout_tree(repo.get(remote_master_id))
|
||||
try:
|
||||
master_ref = repo.lookup_reference('refs/heads/%s' % (branch))
|
||||
master_ref.set_target(remote_master_id)
|
||||
except KeyError:
|
||||
repo.create_branch(branch, repo.get(remote_master_id))
|
||||
repo.head.set_target(remote_master_id)
|
||||
elif merge_result & pygit2.GIT_MERGE_ANALYSIS_NORMAL:
|
||||
repo.merge(remote_master_id)
|
||||
|
||||
if repo.index.conflicts is not None:
|
||||
for conflict in repo.index.conflicts:
|
||||
print('Conflicts found in:', conflict[0].path)
|
||||
raise AssertionError('Conflicts, ahhhhh!!')
|
||||
|
||||
user = repo.default_signature
|
||||
tree = repo.index.write_tree()
|
||||
commit = repo.create_commit('HEAD',
|
||||
user,
|
||||
user,
|
||||
'Merge!',
|
||||
tree,
|
||||
[repo.head.target, remote_master_id])
|
||||
# We need to do this or git CLI will think we are still merging.
|
||||
repo.state_cleanup()
|
||||
else:
|
||||
raise AssertionError('Unknown merge analysis result')
|
||||
|
||||
|
||||
repo = pygit2.Repository(str(sys.argv[1]))
|
||||
ident = pygit2.Signature('comfyui', 'comfy@ui')
|
||||
try:
|
||||
print("stashing current changes")
|
||||
repo.stash(ident)
|
||||
except KeyError:
|
||||
print("nothing to stash")
|
||||
backup_branch_name = 'backup_branch_{}'.format(datetime.today().strftime('%Y-%m-%d_%H_%M_%S'))
|
||||
print("creating backup branch: {}".format(backup_branch_name))
|
||||
repo.branches.local.create(backup_branch_name, repo.head.peel())
|
||||
|
||||
print("checking out master branch")
|
||||
branch = repo.lookup_branch('master')
|
||||
ref = repo.lookup_reference(branch.name)
|
||||
repo.checkout(ref)
|
||||
|
||||
print("pulling latest changes")
|
||||
pull(repo)
|
||||
|
||||
print("Done!")
|
||||
|
||||
@ -1,2 +0,0 @@
|
||||
..\python_embeded\python.exe .\update.py ..\ComfyUI\
|
||||
pause
|
||||
@ -1,3 +1,3 @@
|
||||
..\python_embeded\python.exe .\update.py ..\ComfyUI\
|
||||
..\python_embeded\python.exe -s -m pip install --upgrade --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu118 -r ../ComfyUI/requirements.txt pygit2
|
||||
..\python_embeded\python.exe -s -m pip install --upgrade --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cu121 -r ../ComfyUI/requirements.txt pygit2
|
||||
pause
|
||||
|
||||
@ -1,27 +0,0 @@
|
||||
HOW TO RUN:
|
||||
|
||||
if you have a NVIDIA gpu:
|
||||
|
||||
run_nvidia_gpu.bat
|
||||
|
||||
|
||||
|
||||
To run it in slow CPU mode:
|
||||
|
||||
run_cpu.bat
|
||||
|
||||
|
||||
|
||||
IF YOU GET A RED ERROR IN THE UI MAKE SURE YOU HAVE A MODEL/CHECKPOINT IN: ComfyUI\models\checkpoints
|
||||
|
||||
You can download the stable diffusion 1.5 one from: https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt
|
||||
|
||||
|
||||
|
||||
RECOMMENDED WAY TO UPDATE:
|
||||
To update the ComfyUI code: update\update_comfyui.bat
|
||||
|
||||
|
||||
|
||||
To update ComfyUI with the python dependencies:
|
||||
update\update_comfyui_and_python_dependencies.bat
|
||||
@ -1,2 +0,0 @@
|
||||
.\python_embeded\python.exe -s ComfyUI\main.py --cpu --windows-standalone-build
|
||||
pause
|
||||
30
.github/workflows/windows_release_cu118_dependencies_2.yml
vendored
Normal file
30
.github/workflows/windows_release_cu118_dependencies_2.yml
vendored
Normal file
@ -0,0 +1,30 @@
|
||||
name: "Windows Release cu118 dependencies 2"
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
# push:
|
||||
# branches:
|
||||
# - master
|
||||
|
||||
jobs:
|
||||
build_dependencies:
|
||||
runs-on: windows-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: '3.10.9'
|
||||
|
||||
- shell: bash
|
||||
run: |
|
||||
python -m pip wheel --no-cache-dir torch torchvision torchaudio xformers --extra-index-url https://download.pytorch.org/whl/cu118 -r requirements.txt pygit2 -w ./temp_wheel_dir
|
||||
python -m pip install --no-cache-dir ./temp_wheel_dir/*
|
||||
echo installed basic
|
||||
ls -lah temp_wheel_dir
|
||||
mv temp_wheel_dir cu118_python_deps
|
||||
tar cf cu118_python_deps.tar cu118_python_deps
|
||||
|
||||
- uses: actions/cache/save@v3
|
||||
with:
|
||||
path: cu118_python_deps.tar
|
||||
key: ${{ runner.os }}-build-cu118
|
||||
@ -19,21 +19,21 @@ jobs:
|
||||
fetch-depth: 0
|
||||
- uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: '3.10.9'
|
||||
python-version: '3.11.3'
|
||||
- shell: bash
|
||||
run: |
|
||||
cd ..
|
||||
cp -r ComfyUI ComfyUI_copy
|
||||
curl https://www.python.org/ftp/python/3.10.9/python-3.10.9-embed-amd64.zip -o python_embeded.zip
|
||||
curl https://www.python.org/ftp/python/3.11.3/python-3.11.3-embed-amd64.zip -o python_embeded.zip
|
||||
unzip python_embeded.zip -d python_embeded
|
||||
cd python_embeded
|
||||
echo 'import site' >> ./python310._pth
|
||||
echo 'import site' >> ./python311._pth
|
||||
curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
|
||||
./python.exe get-pip.py
|
||||
python -m pip wheel torch torchvision torchaudio --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu118 -r ../ComfyUI/requirements.txt pygit2 -w ../temp_wheel_dir
|
||||
python -m pip wheel torch torchvision torchaudio --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu121 -r ../ComfyUI/requirements.txt pygit2 -w ../temp_wheel_dir
|
||||
ls ../temp_wheel_dir
|
||||
./python.exe -s -m pip install --pre ../temp_wheel_dir/*
|
||||
sed -i '1i../ComfyUI' ./python310._pth
|
||||
sed -i '1i../ComfyUI' ./python311._pth
|
||||
cd ..
|
||||
|
||||
|
||||
@ -46,6 +46,8 @@ jobs:
|
||||
mkdir update
|
||||
cp -r ComfyUI/.ci/update_windows/* ./update/
|
||||
cp -r ComfyUI/.ci/windows_base_files/* ./
|
||||
cp -r ComfyUI/.ci/nightly/update_windows/* ./update/
|
||||
cp -r ComfyUI/.ci/nightly/windows_base_files/* ./
|
||||
|
||||
cd ..
|
||||
|
||||
|
||||
37
README.md
37
README.md
@ -7,6 +7,8 @@ A powerful and modular stable diffusion GUI and backend.
|
||||
This ui will let you design and execute advanced stable diffusion pipelines using a graph/nodes/flowchart based interface. For some workflow examples and see what ComfyUI can do you can check out:
|
||||
### [ComfyUI Examples](https://comfyanonymous.github.io/ComfyUI_examples/)
|
||||
|
||||
### [Installing ComfyUI](#installing)
|
||||
|
||||
## Features
|
||||
- Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything.
|
||||
- Fully supports SD1.x and SD2.x
|
||||
@ -17,6 +19,7 @@ This ui will let you design and execute advanced stable diffusion pipelines usin
|
||||
- Can load ckpt, safetensors and diffusers models/checkpoints. Standalone VAEs and CLIP models.
|
||||
- Embeddings/Textual inversion
|
||||
- [Loras (regular, locon and loha)](https://comfyanonymous.github.io/ComfyUI_examples/lora/)
|
||||
- [Hypernetworks](https://comfyanonymous.github.io/ComfyUI_examples/hypernetworks/)
|
||||
- Loading full workflows (with seeds) from generated PNG files.
|
||||
- Saving/Loading workflows as Json files.
|
||||
- Nodes interface can be used to create complex workflows like one for [Hires fix](https://comfyanonymous.github.io/ComfyUI_examples/2_pass_txt2img/) or much more advanced ones.
|
||||
@ -25,6 +28,7 @@ This ui will let you design and execute advanced stable diffusion pipelines usin
|
||||
- [ControlNet and T2I-Adapter](https://comfyanonymous.github.io/ComfyUI_examples/controlnet/)
|
||||
- [Upscale Models (ESRGAN, ESRGAN variants, SwinIR, Swin2SR, etc...)](https://comfyanonymous.github.io/ComfyUI_examples/upscale_models/)
|
||||
- [unCLIP Models](https://comfyanonymous.github.io/ComfyUI_examples/unclip/)
|
||||
- [GLIGEN](https://comfyanonymous.github.io/ComfyUI_examples/gligen/)
|
||||
- Starts up very fast.
|
||||
- Works fully offline: will never download anything.
|
||||
- [Config file](extra_model_paths.yaml.example) to set the search paths for models.
|
||||
@ -32,14 +36,29 @@ This ui will let you design and execute advanced stable diffusion pipelines usin
|
||||
Workflow examples can be found on the [Examples page](https://comfyanonymous.github.io/ComfyUI_examples/)
|
||||
|
||||
## Shortcuts
|
||||
- **Ctrl + A** select all nodes
|
||||
- **Ctrl + M** mute/unmute selected nodes
|
||||
- **Delete** or **Backspace** delete selected nodes
|
||||
- **Space** Holding space key while moving the cursor moves the canvas around. It works when holding the mouse button down so it is easier to connect different nodes when the canvas gets too large.
|
||||
- **Ctrl/Shift + Click** Add clicked node to selection.
|
||||
- **Ctrl + C/Ctrl + V** - Copy and paste selected nodes, without maintaining the connection to the outputs of unselected nodes.
|
||||
- **Ctrl + C/Ctrl + Shift + V** - Copy and paste selected nodes, and maintaining the connection from the outputs of unselected nodes to the inputs of the newly pasted nodes.
|
||||
- Holding **Shift** and drag selected nodes - Move multiple selected nodes at the same time.
|
||||
|
||||
| Keybind | Explanation |
|
||||
| - | - |
|
||||
| Ctrl + Enter | Queue up current graph for generation |
|
||||
| Ctrl + Shift + Enter | Queue up current graph as first for generation |
|
||||
| Ctrl + S | Save workflow |
|
||||
| Ctrl + O | Load workflow |
|
||||
| Ctrl + A | Select all nodes |
|
||||
| Ctrl + M | Mute/unmute selected nodes |
|
||||
| Delete/Backspace | Delete selected nodes |
|
||||
| Ctrl + Delete/Backspace | Delete the current graph |
|
||||
| Space | Move the canvas around when held and moving the cursor |
|
||||
| Ctrl/Shift + Click | Add clicked node to selection |
|
||||
| Ctrl + C/Ctrl + V | Copy and paste selected nodes (without maintaining connections to outputs of unselected nodes) |
|
||||
| Ctrl + C/Ctrl + Shift + V| Copy and paste selected nodes (maintaining connections from outputs of unselected nodes to inputs of pasted nodes) |
|
||||
| Shift + Drag | Move multiple selected nodes at the same time |
|
||||
| Ctrl + D | Load default graph |
|
||||
| Q | Toggle visibility of the queue |
|
||||
| H | Toggle visibility of history |
|
||||
| R | Refresh graph |
|
||||
| Double-Click LMB | Open node quick search palette |
|
||||
|
||||
Ctrl can also be replaced with Cmd instead for MacOS users
|
||||
|
||||
# Installing
|
||||
|
||||
@ -69,7 +88,7 @@ Put your VAE in: models/vae
|
||||
|
||||
At the time of writing this pytorch has issues with python versions higher than 3.10 so make sure your python/pip versions are 3.10.
|
||||
|
||||
### AMD (Linux only)
|
||||
### AMD GPUs (Linux only)
|
||||
AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version:
|
||||
|
||||
```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/rocm5.4.2```
|
||||
|
||||
@ -5,17 +5,17 @@ import torch
|
||||
import torch as th
|
||||
import torch.nn as nn
|
||||
|
||||
from ldm.modules.diffusionmodules.util import (
|
||||
from ..ldm.modules.diffusionmodules.util import (
|
||||
conv_nd,
|
||||
linear,
|
||||
zero_module,
|
||||
timestep_embedding,
|
||||
)
|
||||
|
||||
from ldm.modules.attention import SpatialTransformer
|
||||
from ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample, AttentionBlock
|
||||
from ldm.models.diffusion.ddpm import LatentDiffusion
|
||||
from ldm.util import log_txt_as_img, exists, instantiate_from_config
|
||||
from ..ldm.modules.attention import SpatialTransformer
|
||||
from ..ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample, AttentionBlock
|
||||
from ..ldm.models.diffusion.ddpm import LatentDiffusion
|
||||
from ..ldm.util import log_txt_as_img, exists, instantiate_from_config
|
||||
|
||||
|
||||
class ControlledUnetModel(UNetModel):
|
||||
|
||||
@ -7,9 +7,11 @@ parser.add_argument("--port", type=int, default=8188, help="Set the listen port.
|
||||
parser.add_argument("--enable-cors-header", type=str, default=None, metavar="ORIGIN", nargs="?", const="*", help="Enable CORS (Cross-Origin Resource Sharing) with optional origin or allow all with default '*'.")
|
||||
parser.add_argument("--extra-model-paths-config", type=str, default=None, metavar="PATH", nargs='+', action='append', help="Load one or more extra_model_paths.yaml files.")
|
||||
parser.add_argument("--output-directory", type=str, default=None, help="Set the ComfyUI output directory.")
|
||||
parser.add_argument("--auto-launch", action="store_true", help="Automatically launch ComfyUI in the default browser.")
|
||||
parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use.")
|
||||
parser.add_argument("--dont-upcast-attention", action="store_true", help="Disable upcasting of attention. Can boost speed but increase the chances of black images.")
|
||||
parser.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).")
|
||||
parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")
|
||||
|
||||
attn_group = parser.add_mutually_exclusive_group()
|
||||
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization instead of the sub-quadratic one. Ignored when xformers is used.")
|
||||
@ -29,3 +31,6 @@ parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test
|
||||
parser.add_argument("--windows-standalone-build", action="store_true", help="Windows standalone build: Enable convenient things that most people using the standalone windows build will probably enjoy (like auto opening the page on startup).")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.windows_standalone_build:
|
||||
args.auto_launch = True
|
||||
|
||||
@ -712,7 +712,7 @@ class UniPC:
|
||||
|
||||
def sample(self, x, timesteps, t_start=None, t_end=None, order=3, skip_type='time_uniform',
|
||||
method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver',
|
||||
atol=0.0078, rtol=0.05, corrector=False,
|
||||
atol=0.0078, rtol=0.05, corrector=False, callback=None, disable_pbar=False
|
||||
):
|
||||
t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
|
||||
t_T = self.noise_schedule.T if t_start is None else t_start
|
||||
@ -723,7 +723,7 @@ class UniPC:
|
||||
# timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
|
||||
assert timesteps.shape[0] - 1 == steps
|
||||
# with torch.no_grad():
|
||||
for step_index in trange(steps):
|
||||
for step_index in trange(steps, disable=disable_pbar):
|
||||
if self.noise_mask is not None:
|
||||
x = x * self.noise_mask + (1. - self.noise_mask) * (self.masked_image * self.noise_schedule.marginal_alpha(timesteps[step_index]) + self.noise * self.noise_schedule.marginal_std(timesteps[step_index]))
|
||||
if step_index == 0:
|
||||
@ -766,6 +766,8 @@ class UniPC:
|
||||
if model_x is None:
|
||||
model_x = self.model_fn(x, vec_t)
|
||||
model_prev_list[-1] = model_x
|
||||
if callback is not None:
|
||||
callback(step_index, model_prev_list[-1], x, steps)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
if denoise_to_zero:
|
||||
@ -833,7 +835,7 @@ def expand_dims(v, dims):
|
||||
|
||||
|
||||
|
||||
def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, extra_args=None, callback=None, disable=None, noise_mask=None, variant='bh1'):
|
||||
def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, extra_args=None, callback=None, disable=False, noise_mask=None, variant='bh1'):
|
||||
to_zero = False
|
||||
if sigmas[-1] == 0:
|
||||
timesteps = torch.nn.functional.interpolate(sigmas[None,None,:-1], size=(len(sigmas),), mode='linear')[0][0]
|
||||
@ -877,7 +879,7 @@ def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, ex
|
||||
|
||||
order = min(3, len(timesteps) - 1)
|
||||
uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, noise_mask=noise_mask, masked_image=image, noise=noise, variant=variant)
|
||||
x = uni_pc.sample(img, timesteps=timesteps, skip_type="time_uniform", method="multistep", order=order, lower_order_final=True)
|
||||
x = uni_pc.sample(img, timesteps=timesteps, skip_type="time_uniform", method="multistep", order=order, lower_order_final=True, callback=callback, disable_pbar=disable)
|
||||
if not to_zero:
|
||||
x /= ns.marginal_alpha(timesteps[-1])
|
||||
return x
|
||||
|
||||
360
comfy/gligen.py
Normal file
360
comfy/gligen.py
Normal file
@ -0,0 +1,360 @@
|
||||
import torch
|
||||
from torch import nn, einsum
|
||||
from .ldm.modules.attention import CrossAttention
|
||||
from inspect import isfunction
|
||||
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
|
||||
def uniq(arr):
|
||||
return{el: True for el in arr}.keys()
|
||||
|
||||
|
||||
def default(val, d):
|
||||
if exists(val):
|
||||
return val
|
||||
return d() if isfunction(d) else d
|
||||
|
||||
|
||||
# feedforward
|
||||
class GEGLU(nn.Module):
|
||||
def __init__(self, dim_in, dim_out):
|
||||
super().__init__()
|
||||
self.proj = nn.Linear(dim_in, dim_out * 2)
|
||||
|
||||
def forward(self, x):
|
||||
x, gate = self.proj(x).chunk(2, dim=-1)
|
||||
return x * torch.nn.functional.gelu(gate)
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
|
||||
super().__init__()
|
||||
inner_dim = int(dim * mult)
|
||||
dim_out = default(dim_out, dim)
|
||||
project_in = nn.Sequential(
|
||||
nn.Linear(dim, inner_dim),
|
||||
nn.GELU()
|
||||
) if not glu else GEGLU(dim, inner_dim)
|
||||
|
||||
self.net = nn.Sequential(
|
||||
project_in,
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(inner_dim, dim_out)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
|
||||
class GatedCrossAttentionDense(nn.Module):
|
||||
def __init__(self, query_dim, context_dim, n_heads, d_head):
|
||||
super().__init__()
|
||||
|
||||
self.attn = CrossAttention(
|
||||
query_dim=query_dim,
|
||||
context_dim=context_dim,
|
||||
heads=n_heads,
|
||||
dim_head=d_head)
|
||||
self.ff = FeedForward(query_dim, glu=True)
|
||||
|
||||
self.norm1 = nn.LayerNorm(query_dim)
|
||||
self.norm2 = nn.LayerNorm(query_dim)
|
||||
|
||||
self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.)))
|
||||
self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.)))
|
||||
|
||||
# this can be useful: we can externally change magnitude of tanh(alpha)
|
||||
# for example, when it is set to 0, then the entire model is same as
|
||||
# original one
|
||||
self.scale = 1
|
||||
|
||||
def forward(self, x, objs):
|
||||
|
||||
x = x + self.scale * \
|
||||
torch.tanh(self.alpha_attn) * self.attn(self.norm1(x), objs, objs)
|
||||
x = x + self.scale * \
|
||||
torch.tanh(self.alpha_dense) * self.ff(self.norm2(x))
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class GatedSelfAttentionDense(nn.Module):
|
||||
def __init__(self, query_dim, context_dim, n_heads, d_head):
|
||||
super().__init__()
|
||||
|
||||
# we need a linear projection since we need cat visual feature and obj
|
||||
# feature
|
||||
self.linear = nn.Linear(context_dim, query_dim)
|
||||
|
||||
self.attn = CrossAttention(
|
||||
query_dim=query_dim,
|
||||
context_dim=query_dim,
|
||||
heads=n_heads,
|
||||
dim_head=d_head)
|
||||
self.ff = FeedForward(query_dim, glu=True)
|
||||
|
||||
self.norm1 = nn.LayerNorm(query_dim)
|
||||
self.norm2 = nn.LayerNorm(query_dim)
|
||||
|
||||
self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.)))
|
||||
self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.)))
|
||||
|
||||
# this can be useful: we can externally change magnitude of tanh(alpha)
|
||||
# for example, when it is set to 0, then the entire model is same as
|
||||
# original one
|
||||
self.scale = 1
|
||||
|
||||
def forward(self, x, objs):
|
||||
|
||||
N_visual = x.shape[1]
|
||||
objs = self.linear(objs)
|
||||
|
||||
x = x + self.scale * torch.tanh(self.alpha_attn) * self.attn(
|
||||
self.norm1(torch.cat([x, objs], dim=1)))[:, 0:N_visual, :]
|
||||
x = x + self.scale * \
|
||||
torch.tanh(self.alpha_dense) * self.ff(self.norm2(x))
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class GatedSelfAttentionDense2(nn.Module):
|
||||
def __init__(self, query_dim, context_dim, n_heads, d_head):
|
||||
super().__init__()
|
||||
|
||||
# we need a linear projection since we need cat visual feature and obj
|
||||
# feature
|
||||
self.linear = nn.Linear(context_dim, query_dim)
|
||||
|
||||
self.attn = CrossAttention(
|
||||
query_dim=query_dim, context_dim=query_dim, dim_head=d_head)
|
||||
self.ff = FeedForward(query_dim, glu=True)
|
||||
|
||||
self.norm1 = nn.LayerNorm(query_dim)
|
||||
self.norm2 = nn.LayerNorm(query_dim)
|
||||
|
||||
self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.)))
|
||||
self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.)))
|
||||
|
||||
# this can be useful: we can externally change magnitude of tanh(alpha)
|
||||
# for example, when it is set to 0, then the entire model is same as
|
||||
# original one
|
||||
self.scale = 1
|
||||
|
||||
def forward(self, x, objs):
|
||||
|
||||
B, N_visual, _ = x.shape
|
||||
B, N_ground, _ = objs.shape
|
||||
|
||||
objs = self.linear(objs)
|
||||
|
||||
# sanity check
|
||||
size_v = math.sqrt(N_visual)
|
||||
size_g = math.sqrt(N_ground)
|
||||
assert int(size_v) == size_v, "Visual tokens must be square rootable"
|
||||
assert int(size_g) == size_g, "Grounding tokens must be square rootable"
|
||||
size_v = int(size_v)
|
||||
size_g = int(size_g)
|
||||
|
||||
# select grounding token and resize it to visual token size as residual
|
||||
out = self.attn(self.norm1(torch.cat([x, objs], dim=1)))[
|
||||
:, N_visual:, :]
|
||||
out = out.permute(0, 2, 1).reshape(B, -1, size_g, size_g)
|
||||
out = torch.nn.functional.interpolate(
|
||||
out, (size_v, size_v), mode='bicubic')
|
||||
residual = out.reshape(B, -1, N_visual).permute(0, 2, 1)
|
||||
|
||||
# add residual to visual feature
|
||||
x = x + self.scale * torch.tanh(self.alpha_attn) * residual
|
||||
x = x + self.scale * \
|
||||
torch.tanh(self.alpha_dense) * self.ff(self.norm2(x))
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class FourierEmbedder():
|
||||
def __init__(self, num_freqs=64, temperature=100):
|
||||
|
||||
self.num_freqs = num_freqs
|
||||
self.temperature = temperature
|
||||
self.freq_bands = temperature ** (torch.arange(num_freqs) / num_freqs)
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, x, cat_dim=-1):
|
||||
"x: arbitrary shape of tensor. dim: cat dim"
|
||||
out = []
|
||||
for freq in self.freq_bands:
|
||||
out.append(torch.sin(freq * x))
|
||||
out.append(torch.cos(freq * x))
|
||||
return torch.cat(out, cat_dim)
|
||||
|
||||
|
||||
class PositionNet(nn.Module):
|
||||
def __init__(self, in_dim, out_dim, fourier_freqs=8):
|
||||
super().__init__()
|
||||
self.in_dim = in_dim
|
||||
self.out_dim = out_dim
|
||||
|
||||
self.fourier_embedder = FourierEmbedder(num_freqs=fourier_freqs)
|
||||
self.position_dim = fourier_freqs * 2 * 4 # 2 is sin&cos, 4 is xyxy
|
||||
|
||||
self.linears = nn.Sequential(
|
||||
nn.Linear(self.in_dim + self.position_dim, 512),
|
||||
nn.SiLU(),
|
||||
nn.Linear(512, 512),
|
||||
nn.SiLU(),
|
||||
nn.Linear(512, out_dim),
|
||||
)
|
||||
|
||||
self.null_positive_feature = torch.nn.Parameter(
|
||||
torch.zeros([self.in_dim]))
|
||||
self.null_position_feature = torch.nn.Parameter(
|
||||
torch.zeros([self.position_dim]))
|
||||
|
||||
def forward(self, boxes, masks, positive_embeddings):
|
||||
B, N, _ = boxes.shape
|
||||
masks = masks.unsqueeze(-1)
|
||||
|
||||
# embedding position (it may includes padding as placeholder)
|
||||
xyxy_embedding = self.fourier_embedder(boxes) # B*N*4 --> B*N*C
|
||||
|
||||
# learnable null embedding
|
||||
positive_null = self.null_positive_feature.view(1, 1, -1)
|
||||
xyxy_null = self.null_position_feature.view(1, 1, -1)
|
||||
|
||||
# replace padding with learnable null embedding
|
||||
positive_embeddings = positive_embeddings * \
|
||||
masks + (1 - masks) * positive_null
|
||||
xyxy_embedding = xyxy_embedding * masks + (1 - masks) * xyxy_null
|
||||
|
||||
objs = self.linears(
|
||||
torch.cat([positive_embeddings, xyxy_embedding], dim=-1))
|
||||
assert objs.shape == torch.Size([B, N, self.out_dim])
|
||||
return objs
|
||||
|
||||
|
||||
class Gligen(nn.Module):
|
||||
def __init__(self, modules, position_net, key_dim):
|
||||
super().__init__()
|
||||
self.module_list = nn.ModuleList(modules)
|
||||
self.position_net = position_net
|
||||
self.key_dim = key_dim
|
||||
self.max_objs = 30
|
||||
self.lowvram = False
|
||||
|
||||
def _set_position(self, boxes, masks, positive_embeddings):
|
||||
if self.lowvram == True:
|
||||
self.position_net.to(boxes.device)
|
||||
|
||||
objs = self.position_net(boxes, masks, positive_embeddings)
|
||||
|
||||
if self.lowvram == True:
|
||||
self.position_net.cpu()
|
||||
def func_lowvram(key, x):
|
||||
module = self.module_list[key]
|
||||
module.to(x.device)
|
||||
r = module(x, objs)
|
||||
module.cpu()
|
||||
return r
|
||||
return func_lowvram
|
||||
else:
|
||||
def func(key, x):
|
||||
module = self.module_list[key]
|
||||
return module(x, objs)
|
||||
return func
|
||||
|
||||
def set_position(self, latent_image_shape, position_params, device):
|
||||
batch, c, h, w = latent_image_shape
|
||||
masks = torch.zeros([self.max_objs], device="cpu")
|
||||
boxes = []
|
||||
positive_embeddings = []
|
||||
for p in position_params:
|
||||
x1 = (p[4]) / w
|
||||
y1 = (p[3]) / h
|
||||
x2 = (p[4] + p[2]) / w
|
||||
y2 = (p[3] + p[1]) / h
|
||||
masks[len(boxes)] = 1.0
|
||||
boxes += [torch.tensor((x1, y1, x2, y2)).unsqueeze(0)]
|
||||
positive_embeddings += [p[0]]
|
||||
append_boxes = []
|
||||
append_conds = []
|
||||
if len(boxes) < self.max_objs:
|
||||
append_boxes = [torch.zeros(
|
||||
[self.max_objs - len(boxes), 4], device="cpu")]
|
||||
append_conds = [torch.zeros(
|
||||
[self.max_objs - len(boxes), self.key_dim], device="cpu")]
|
||||
|
||||
box_out = torch.cat(
|
||||
boxes + append_boxes).unsqueeze(0).repeat(batch, 1, 1)
|
||||
masks = masks.unsqueeze(0).repeat(batch, 1)
|
||||
conds = torch.cat(positive_embeddings +
|
||||
append_conds).unsqueeze(0).repeat(batch, 1, 1)
|
||||
return self._set_position(
|
||||
box_out.to(device),
|
||||
masks.to(device),
|
||||
conds.to(device))
|
||||
|
||||
def set_empty(self, latent_image_shape, device):
|
||||
batch, c, h, w = latent_image_shape
|
||||
masks = torch.zeros([self.max_objs], device="cpu").repeat(batch, 1)
|
||||
box_out = torch.zeros([self.max_objs, 4],
|
||||
device="cpu").repeat(batch, 1, 1)
|
||||
conds = torch.zeros([self.max_objs, self.key_dim],
|
||||
device="cpu").repeat(batch, 1, 1)
|
||||
return self._set_position(
|
||||
box_out.to(device),
|
||||
masks.to(device),
|
||||
conds.to(device))
|
||||
|
||||
def set_lowvram(self, value=True):
|
||||
self.lowvram = value
|
||||
|
||||
def cleanup(self):
|
||||
self.lowvram = False
|
||||
|
||||
def get_models(self):
|
||||
return [self]
|
||||
|
||||
def load_gligen(sd):
|
||||
sd_k = sd.keys()
|
||||
output_list = []
|
||||
key_dim = 768
|
||||
for a in ["input_blocks", "middle_block", "output_blocks"]:
|
||||
for b in range(20):
|
||||
k_temp = filter(lambda k: "{}.{}.".format(a, b)
|
||||
in k and ".fuser." in k, sd_k)
|
||||
k_temp = map(lambda k: (k, k.split(".fuser.")[-1]), k_temp)
|
||||
|
||||
n_sd = {}
|
||||
for k in k_temp:
|
||||
n_sd[k[1]] = sd[k[0]]
|
||||
if len(n_sd) > 0:
|
||||
query_dim = n_sd["linear.weight"].shape[0]
|
||||
key_dim = n_sd["linear.weight"].shape[1]
|
||||
|
||||
if key_dim == 768: # SD1.x
|
||||
n_heads = 8
|
||||
d_head = query_dim // n_heads
|
||||
else:
|
||||
d_head = 64
|
||||
n_heads = query_dim // d_head
|
||||
|
||||
gated = GatedSelfAttentionDense(
|
||||
query_dim, key_dim, n_heads, d_head)
|
||||
gated.load_state_dict(n_sd, strict=False)
|
||||
output_list.append(gated)
|
||||
|
||||
if "position_net.null_positive_feature" in sd_k:
|
||||
in_dim = sd["position_net.null_positive_feature"].shape[0]
|
||||
out_dim = sd["position_net.linears.4.weight"].shape[0]
|
||||
|
||||
class WeightsLoader(torch.nn.Module):
|
||||
pass
|
||||
w = WeightsLoader()
|
||||
w.position_net = PositionNet(in_dim, out_dim)
|
||||
w.load_state_dict(sd, strict=False)
|
||||
|
||||
gligen = Gligen(output_list, w.position_net, key_dim)
|
||||
return gligen
|
||||
@ -3,11 +3,11 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
from contextlib import contextmanager
|
||||
|
||||
from ldm.modules.diffusionmodules.model import Encoder, Decoder
|
||||
from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
|
||||
from comfy.ldm.modules.diffusionmodules.model import Encoder, Decoder
|
||||
from comfy.ldm.modules.distributions.distributions import DiagonalGaussianDistribution
|
||||
|
||||
from ldm.util import instantiate_from_config
|
||||
from ldm.modules.ema import LitEma
|
||||
from comfy.ldm.util import instantiate_from_config
|
||||
from comfy.ldm.modules.ema import LitEma
|
||||
|
||||
# class AutoencoderKL(pl.LightningModule):
|
||||
class AutoencoderKL(torch.nn.Module):
|
||||
|
||||
@ -4,7 +4,7 @@ import torch
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
|
||||
from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor
|
||||
from comfy.ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor
|
||||
|
||||
|
||||
class DDIMSampler(object):
|
||||
@ -81,6 +81,7 @@ class DDIMSampler(object):
|
||||
extra_args=None,
|
||||
to_zero=True,
|
||||
end_step=None,
|
||||
disable_pbar=False,
|
||||
**kwargs
|
||||
):
|
||||
self.make_schedule_timesteps(ddim_timesteps=ddim_timesteps, ddim_eta=eta, verbose=verbose)
|
||||
@ -103,7 +104,8 @@ class DDIMSampler(object):
|
||||
denoise_function=denoise_function,
|
||||
extra_args=extra_args,
|
||||
to_zero=to_zero,
|
||||
end_step=end_step
|
||||
end_step=end_step,
|
||||
disable_pbar=disable_pbar
|
||||
)
|
||||
return samples, intermediates
|
||||
|
||||
@ -185,7 +187,7 @@ class DDIMSampler(object):
|
||||
mask=None, x0=None, img_callback=None, log_every_t=100,
|
||||
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
||||
unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None,
|
||||
ucg_schedule=None, denoise_function=None, extra_args=None, to_zero=True, end_step=None):
|
||||
ucg_schedule=None, denoise_function=None, extra_args=None, to_zero=True, end_step=None, disable_pbar=False):
|
||||
device = self.model.betas.device
|
||||
b = shape[0]
|
||||
if x_T is None:
|
||||
@ -204,7 +206,7 @@ class DDIMSampler(object):
|
||||
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
|
||||
# print(f"Running DDIM Sampling with {total_steps} timesteps")
|
||||
|
||||
iterator = tqdm(time_range[:end_step], desc='DDIM Sampler', total=end_step)
|
||||
iterator = tqdm(time_range[:end_step], desc='DDIM Sampler', total=end_step, disable=disable_pbar)
|
||||
|
||||
for i, step in enumerate(iterator):
|
||||
index = total_steps - i - 1
|
||||
|
||||
@ -19,12 +19,12 @@ from tqdm import tqdm
|
||||
from torchvision.utils import make_grid
|
||||
# from pytorch_lightning.utilities.distributed import rank_zero_only
|
||||
|
||||
from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
|
||||
from ldm.modules.ema import LitEma
|
||||
from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
|
||||
from ldm.models.autoencoder import IdentityFirstStage, AutoencoderKL
|
||||
from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
|
||||
from ldm.models.diffusion.ddim import DDIMSampler
|
||||
from comfy.ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
|
||||
from comfy.ldm.modules.ema import LitEma
|
||||
from comfy.ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
|
||||
from ..autoencoder import IdentityFirstStage, AutoencoderKL
|
||||
from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
|
||||
from .ddim import DDIMSampler
|
||||
|
||||
|
||||
__conditioning_keys__ = {'concat': 'c_concat',
|
||||
|
||||
@ -6,10 +6,10 @@ from torch import nn, einsum
|
||||
from einops import rearrange, repeat
|
||||
from typing import Optional, Any
|
||||
|
||||
from ldm.modules.diffusionmodules.util import checkpoint
|
||||
from .diffusionmodules.util import checkpoint
|
||||
from .sub_quadratic_attention import efficient_dot_product_attention
|
||||
|
||||
import model_management
|
||||
from comfy import model_management
|
||||
|
||||
from . import tomesd
|
||||
|
||||
@ -21,7 +21,7 @@ if model_management.xformers_enabled():
|
||||
import os
|
||||
_ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")
|
||||
|
||||
from cli_args import args
|
||||
from comfy.cli_args import args
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
@ -163,13 +163,17 @@ class CrossAttentionBirchSan(nn.Module):
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
def forward(self, x, context=None, mask=None):
|
||||
def forward(self, x, context=None, value=None, mask=None):
|
||||
h = self.heads
|
||||
|
||||
query = self.to_q(x)
|
||||
context = default(context, x)
|
||||
key = self.to_k(context)
|
||||
value = self.to_v(context)
|
||||
if value is not None:
|
||||
value = self.to_v(value)
|
||||
else:
|
||||
value = self.to_v(context)
|
||||
|
||||
del context, x
|
||||
|
||||
query = query.unflatten(-1, (self.heads, -1)).transpose(1,2).flatten(end_dim=1)
|
||||
@ -256,13 +260,17 @@ class CrossAttentionDoggettx(nn.Module):
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
def forward(self, x, context=None, mask=None):
|
||||
def forward(self, x, context=None, value=None, mask=None):
|
||||
h = self.heads
|
||||
|
||||
q_in = self.to_q(x)
|
||||
context = default(context, x)
|
||||
k_in = self.to_k(context)
|
||||
v_in = self.to_v(context)
|
||||
if value is not None:
|
||||
v_in = self.to_v(value)
|
||||
del value
|
||||
else:
|
||||
v_in = self.to_v(context)
|
||||
del context, x
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
|
||||
@ -350,13 +358,17 @@ class CrossAttention(nn.Module):
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
def forward(self, x, context=None, mask=None):
|
||||
def forward(self, x, context=None, value=None, mask=None):
|
||||
h = self.heads
|
||||
|
||||
q = self.to_q(x)
|
||||
context = default(context, x)
|
||||
k = self.to_k(context)
|
||||
v = self.to_v(context)
|
||||
if value is not None:
|
||||
v = self.to_v(value)
|
||||
del value
|
||||
else:
|
||||
v = self.to_v(context)
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
|
||||
|
||||
@ -402,11 +414,15 @@ class MemoryEfficientCrossAttention(nn.Module):
|
||||
self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
|
||||
self.attention_op: Optional[Any] = None
|
||||
|
||||
def forward(self, x, context=None, mask=None):
|
||||
def forward(self, x, context=None, value=None, mask=None):
|
||||
q = self.to_q(x)
|
||||
context = default(context, x)
|
||||
k = self.to_k(context)
|
||||
v = self.to_v(context)
|
||||
if value is not None:
|
||||
v = self.to_v(value)
|
||||
del value
|
||||
else:
|
||||
v = self.to_v(context)
|
||||
|
||||
b, _, _ = q.shape
|
||||
q, k, v = map(
|
||||
@ -447,19 +463,19 @@ class CrossAttentionPytorch(nn.Module):
|
||||
self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
|
||||
self.attention_op: Optional[Any] = None
|
||||
|
||||
def forward(self, x, context=None, mask=None):
|
||||
def forward(self, x, context=None, value=None, mask=None):
|
||||
q = self.to_q(x)
|
||||
context = default(context, x)
|
||||
k = self.to_k(context)
|
||||
v = self.to_v(context)
|
||||
if value is not None:
|
||||
v = self.to_v(value)
|
||||
del value
|
||||
else:
|
||||
v = self.to_v(context)
|
||||
|
||||
b, _, _ = q.shape
|
||||
q, k, v = map(
|
||||
lambda t: t.unsqueeze(3)
|
||||
.reshape(b, t.shape[1], self.heads, self.dim_head)
|
||||
.permute(0, 2, 1, 3)
|
||||
.reshape(b * self.heads, t.shape[1], self.dim_head)
|
||||
.contiguous(),
|
||||
lambda t: t.view(b, -1, self.heads, self.dim_head).transpose(1, 2),
|
||||
(q, k, v),
|
||||
)
|
||||
|
||||
@ -468,10 +484,7 @@ class CrossAttentionPytorch(nn.Module):
|
||||
if exists(mask):
|
||||
raise NotImplementedError
|
||||
out = (
|
||||
out.unsqueeze(0)
|
||||
.reshape(b, self.heads, out.shape[1], self.dim_head)
|
||||
.permute(0, 2, 1, 3)
|
||||
.reshape(b, out.shape[1], self.heads * self.dim_head)
|
||||
out.transpose(1, 2).reshape(b, -1, self.heads * self.dim_head)
|
||||
)
|
||||
|
||||
return self.to_out(out)
|
||||
@ -510,16 +523,52 @@ class BasicTransformerBlock(nn.Module):
|
||||
return checkpoint(self._forward, (x, context, transformer_options), self.parameters(), self.checkpoint)
|
||||
|
||||
def _forward(self, x, context=None, transformer_options={}):
|
||||
current_index = None
|
||||
if "current_index" in transformer_options:
|
||||
current_index = transformer_options["current_index"]
|
||||
if "patches" in transformer_options:
|
||||
transformer_patches = transformer_options["patches"]
|
||||
else:
|
||||
transformer_patches = {}
|
||||
|
||||
n = self.norm1(x)
|
||||
if self.disable_self_attn:
|
||||
context_attn1 = context
|
||||
else:
|
||||
context_attn1 = None
|
||||
value_attn1 = None
|
||||
|
||||
if "attn1_patch" in transformer_patches:
|
||||
patch = transformer_patches["attn1_patch"]
|
||||
if context_attn1 is None:
|
||||
context_attn1 = n
|
||||
value_attn1 = context_attn1
|
||||
for p in patch:
|
||||
n, context_attn1, value_attn1 = p(current_index, n, context_attn1, value_attn1)
|
||||
|
||||
if "tomesd" in transformer_options:
|
||||
m, u = tomesd.get_functions(x, transformer_options["tomesd"]["ratio"], transformer_options["original_shape"])
|
||||
n = u(self.attn1(m(n), context=context if self.disable_self_attn else None))
|
||||
n = u(self.attn1(m(n), context=context_attn1, value=value_attn1))
|
||||
else:
|
||||
n = self.attn1(n, context=context if self.disable_self_attn else None)
|
||||
n = self.attn1(n, context=context_attn1, value=value_attn1)
|
||||
|
||||
x += n
|
||||
if "middle_patch" in transformer_patches:
|
||||
patch = transformer_patches["middle_patch"]
|
||||
for p in patch:
|
||||
x = p(current_index, x)
|
||||
|
||||
n = self.norm2(x)
|
||||
n = self.attn2(n, context=context)
|
||||
|
||||
context_attn2 = context
|
||||
value_attn2 = None
|
||||
if "attn2_patch" in transformer_patches:
|
||||
patch = transformer_patches["attn2_patch"]
|
||||
value_attn2 = context_attn2
|
||||
for p in patch:
|
||||
n, context_attn2, value_attn2 = p(current_index, n, context_attn2, value_attn2)
|
||||
|
||||
n = self.attn2(n, context=context_attn2, value=value_attn2)
|
||||
|
||||
x += n
|
||||
x = self.ff(self.norm3(x)) + x
|
||||
|
||||
@ -6,8 +6,8 @@ import numpy as np
|
||||
from einops import rearrange
|
||||
from typing import Optional, Any
|
||||
|
||||
from ldm.modules.attention import MemoryEfficientCrossAttention
|
||||
import model_management
|
||||
from ..attention import MemoryEfficientCrossAttention
|
||||
from comfy import model_management
|
||||
|
||||
if model_management.xformers_enabled_vae():
|
||||
import xformers
|
||||
|
||||
@ -6,7 +6,7 @@ import torch as th
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ldm.modules.diffusionmodules.util import (
|
||||
from .util import (
|
||||
checkpoint,
|
||||
conv_nd,
|
||||
linear,
|
||||
@ -15,8 +15,8 @@ from ldm.modules.diffusionmodules.util import (
|
||||
normalization,
|
||||
timestep_embedding,
|
||||
)
|
||||
from ldm.modules.attention import SpatialTransformer
|
||||
from ldm.util import exists
|
||||
from ..attention import SpatialTransformer
|
||||
from comfy.ldm.util import exists
|
||||
|
||||
|
||||
# dummy replace
|
||||
@ -76,16 +76,31 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
|
||||
support it as an extra input.
|
||||
"""
|
||||
|
||||
def forward(self, x, emb, context=None, transformer_options={}):
|
||||
def forward(self, x, emb, context=None, transformer_options={}, output_shape=None):
|
||||
for layer in self:
|
||||
if isinstance(layer, TimestepBlock):
|
||||
x = layer(x, emb)
|
||||
elif isinstance(layer, SpatialTransformer):
|
||||
x = layer(x, context, transformer_options)
|
||||
elif isinstance(layer, Upsample):
|
||||
x = layer(x, output_shape=output_shape)
|
||||
else:
|
||||
x = layer(x)
|
||||
return x
|
||||
|
||||
#This is needed because accelerate makes a copy of transformer_options which breaks "current_index"
|
||||
def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, output_shape=None):
|
||||
for layer in ts:
|
||||
if isinstance(layer, TimestepBlock):
|
||||
x = layer(x, emb)
|
||||
elif isinstance(layer, SpatialTransformer):
|
||||
x = layer(x, context, transformer_options)
|
||||
transformer_options["current_index"] += 1
|
||||
elif isinstance(layer, Upsample):
|
||||
x = layer(x, output_shape=output_shape)
|
||||
else:
|
||||
x = layer(x)
|
||||
return x
|
||||
|
||||
class Upsample(nn.Module):
|
||||
"""
|
||||
@ -105,14 +120,20 @@ class Upsample(nn.Module):
|
||||
if use_conv:
|
||||
self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x, output_shape=None):
|
||||
assert x.shape[1] == self.channels
|
||||
if self.dims == 3:
|
||||
x = F.interpolate(
|
||||
x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
|
||||
)
|
||||
shape = [x.shape[2], x.shape[3] * 2, x.shape[4] * 2]
|
||||
if output_shape is not None:
|
||||
shape[1] = output_shape[3]
|
||||
shape[2] = output_shape[4]
|
||||
else:
|
||||
x = F.interpolate(x, scale_factor=2, mode="nearest")
|
||||
shape = [x.shape[2] * 2, x.shape[3] * 2]
|
||||
if output_shape is not None:
|
||||
shape[0] = output_shape[2]
|
||||
shape[1] = output_shape[3]
|
||||
|
||||
x = F.interpolate(x, size=shape, mode="nearest")
|
||||
if self.use_conv:
|
||||
x = self.conv(x)
|
||||
return x
|
||||
@ -782,6 +803,8 @@ class UNetModel(nn.Module):
|
||||
:return: an [N x C x ...] Tensor of outputs.
|
||||
"""
|
||||
transformer_options["original_shape"] = list(x.shape)
|
||||
transformer_options["current_index"] = 0
|
||||
|
||||
assert (y is not None) == (
|
||||
self.num_classes is not None
|
||||
), "must specify y if and only if the model is class-conditional"
|
||||
@ -795,13 +818,13 @@ class UNetModel(nn.Module):
|
||||
|
||||
h = x.type(self.dtype)
|
||||
for id, module in enumerate(self.input_blocks):
|
||||
h = module(h, emb, context, transformer_options)
|
||||
h = forward_timestep_embed(module, h, emb, context, transformer_options)
|
||||
if control is not None and 'input' in control and len(control['input']) > 0:
|
||||
ctrl = control['input'].pop()
|
||||
if ctrl is not None:
|
||||
h += ctrl
|
||||
hs.append(h)
|
||||
h = self.middle_block(h, emb, context, transformer_options)
|
||||
h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options)
|
||||
if control is not None and 'middle' in control and len(control['middle']) > 0:
|
||||
h += control['middle'].pop()
|
||||
|
||||
@ -811,9 +834,14 @@ class UNetModel(nn.Module):
|
||||
ctrl = control['output'].pop()
|
||||
if ctrl is not None:
|
||||
hsp += ctrl
|
||||
|
||||
h = th.cat([h, hsp], dim=1)
|
||||
del hsp
|
||||
h = module(h, emb, context, transformer_options)
|
||||
if len(hs) > 0:
|
||||
output_shape = hs[-1].shape
|
||||
else:
|
||||
output_shape = None
|
||||
h = forward_timestep_embed(module, h, emb, context, transformer_options, output_shape)
|
||||
h = h.type(x.dtype)
|
||||
if self.predict_codebook_ids:
|
||||
return self.id_predictor(h)
|
||||
|
||||
@ -3,8 +3,8 @@ import torch.nn as nn
|
||||
import numpy as np
|
||||
from functools import partial
|
||||
|
||||
from ldm.modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule
|
||||
from ldm.util import default
|
||||
from .util import extract_into_tensor, make_beta_schedule
|
||||
from comfy.ldm.util import default
|
||||
|
||||
|
||||
class AbstractLowScaleModel(nn.Module):
|
||||
|
||||
@ -15,7 +15,7 @@ import torch.nn as nn
|
||||
import numpy as np
|
||||
from einops import repeat
|
||||
|
||||
from ldm.util import instantiate_from_config
|
||||
from comfy.ldm.util import instantiate_from_config
|
||||
|
||||
|
||||
def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
|
||||
from ldm.modules.diffusionmodules.openaimodel import Timestep
|
||||
from ..diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
|
||||
from ..diffusionmodules.openaimodel import Timestep
|
||||
import torch
|
||||
|
||||
class CLIPEmbeddingNoiseAugmentation(ImageConcatWithNoiseAugmentation):
|
||||
|
||||
@ -24,7 +24,7 @@ except ImportError:
|
||||
from torch import Tensor
|
||||
from typing import List
|
||||
|
||||
import model_management
|
||||
from comfy import model_management
|
||||
|
||||
def dynamic_slice(
|
||||
x: Tensor,
|
||||
|
||||
@ -36,7 +36,7 @@ def bipartite_soft_matching_random2d(metric: torch.Tensor,
|
||||
"""
|
||||
B, N, _ = metric.shape
|
||||
|
||||
if r <= 0:
|
||||
if r <= 0 or w == 1 or h == 1:
|
||||
return do_nothing, do_nothing
|
||||
|
||||
gather = mps_gather_workaround if metric.device.type == "mps" else torch.gather
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import psutil
|
||||
from enum import Enum
|
||||
from cli_args import args
|
||||
from comfy.cli_args import args
|
||||
|
||||
class VRAMState(Enum):
|
||||
CPU = 0
|
||||
@ -20,15 +20,30 @@ total_vram_available_mb = -1
|
||||
accelerate_enabled = False
|
||||
xpu_available = False
|
||||
|
||||
directml_enabled = False
|
||||
if args.directml is not None:
|
||||
import torch_directml
|
||||
directml_enabled = True
|
||||
device_index = args.directml
|
||||
if device_index < 0:
|
||||
directml_device = torch_directml.device()
|
||||
else:
|
||||
directml_device = torch_directml.device(device_index)
|
||||
print("Using directml with device:", torch_directml.device_name(device_index))
|
||||
# torch_directml.disable_tiled_resources(True)
|
||||
|
||||
try:
|
||||
import torch
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex
|
||||
if torch.xpu.is_available():
|
||||
xpu_available = True
|
||||
total_vram = torch.xpu.get_device_properties(torch.xpu.current_device()).total_memory / (1024 * 1024)
|
||||
except:
|
||||
total_vram = torch.cuda.mem_get_info(torch.cuda.current_device())[1] / (1024 * 1024)
|
||||
if directml_enabled:
|
||||
total_vram = 4097 #TODO
|
||||
else:
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex
|
||||
if torch.xpu.is_available():
|
||||
xpu_available = True
|
||||
total_vram = torch.xpu.get_device_properties(torch.xpu.current_device()).total_memory / (1024 * 1024)
|
||||
except:
|
||||
total_vram = torch.cuda.mem_get_info(torch.cuda.current_device())[1] / (1024 * 1024)
|
||||
total_ram = psutil.virtual_memory().total / (1024 * 1024)
|
||||
if not args.normalvram and not args.cpu:
|
||||
if total_vram <= 4096:
|
||||
@ -112,6 +127,32 @@ if args.cpu:
|
||||
|
||||
print(f"Set vram state to: {vram_state.name}")
|
||||
|
||||
def get_torch_device():
|
||||
global xpu_available
|
||||
global directml_enabled
|
||||
if directml_enabled:
|
||||
global directml_device
|
||||
return directml_device
|
||||
if vram_state == VRAMState.MPS:
|
||||
return torch.device("mps")
|
||||
if vram_state == VRAMState.CPU:
|
||||
return torch.device("cpu")
|
||||
else:
|
||||
if xpu_available:
|
||||
return torch.device("xpu")
|
||||
else:
|
||||
return torch.cuda.current_device()
|
||||
|
||||
def get_torch_device_name(device):
|
||||
if hasattr(device, 'type'):
|
||||
return "{}".format(device.type)
|
||||
return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device))
|
||||
|
||||
try:
|
||||
print("Using device:", get_torch_device_name(get_torch_device()))
|
||||
except:
|
||||
print("Could not pick default device.")
|
||||
|
||||
|
||||
current_loaded_model = None
|
||||
current_gpu_controlnets = []
|
||||
@ -133,6 +174,7 @@ def unload_model():
|
||||
#never unload models from GPU on high vram
|
||||
if vram_state != VRAMState.HIGH_VRAM:
|
||||
current_loaded_model.model.cpu()
|
||||
current_loaded_model.model_patches_to("cpu")
|
||||
current_loaded_model.unpatch_model()
|
||||
current_loaded_model = None
|
||||
|
||||
@ -156,6 +198,8 @@ def load_model_gpu(model):
|
||||
except Exception as e:
|
||||
model.unpatch_model()
|
||||
raise e
|
||||
|
||||
model.model_patches_to(get_torch_device())
|
||||
current_loaded_model = model
|
||||
if vram_state == VRAMState.CPU:
|
||||
pass
|
||||
@ -176,16 +220,23 @@ def load_model_gpu(model):
|
||||
model_accelerated = True
|
||||
return current_loaded_model
|
||||
|
||||
def load_controlnet_gpu(models):
|
||||
def load_controlnet_gpu(control_models):
|
||||
global current_gpu_controlnets
|
||||
global vram_state
|
||||
if vram_state == VRAMState.CPU:
|
||||
return
|
||||
|
||||
if vram_state == VRAMState.LOW_VRAM or vram_state == VRAMState.NO_VRAM:
|
||||
for m in control_models:
|
||||
if hasattr(m, 'set_lowvram'):
|
||||
m.set_lowvram(True)
|
||||
#don't load controlnets like this if low vram because they will be loaded right before running and unloaded right after
|
||||
return
|
||||
|
||||
models = []
|
||||
for m in control_models:
|
||||
models += m.get_models()
|
||||
|
||||
for m in current_gpu_controlnets:
|
||||
if m not in models:
|
||||
m.cpu()
|
||||
@ -208,18 +259,6 @@ def unload_if_low_vram(model):
|
||||
return model.cpu()
|
||||
return model
|
||||
|
||||
def get_torch_device():
|
||||
global xpu_available
|
||||
if vram_state == VRAMState.MPS:
|
||||
return torch.device("mps")
|
||||
if vram_state == VRAMState.CPU:
|
||||
return torch.device("cpu")
|
||||
else:
|
||||
if xpu_available:
|
||||
return torch.device("xpu")
|
||||
else:
|
||||
return torch.cuda.current_device()
|
||||
|
||||
def get_autocast_device(dev):
|
||||
if hasattr(dev, 'type'):
|
||||
return dev.type
|
||||
@ -227,8 +266,14 @@ def get_autocast_device(dev):
|
||||
|
||||
|
||||
def xformers_enabled():
|
||||
global xpu_available
|
||||
global directml_enabled
|
||||
if vram_state == VRAMState.CPU:
|
||||
return False
|
||||
if xpu_available:
|
||||
return False
|
||||
if directml_enabled:
|
||||
return False
|
||||
return XFORMERS_IS_AVAILABLE
|
||||
|
||||
|
||||
@ -240,10 +285,20 @@ def xformers_enabled_vae():
|
||||
return XFORMERS_ENABLED_VAE
|
||||
|
||||
def pytorch_attention_enabled():
|
||||
global ENABLE_PYTORCH_ATTENTION
|
||||
return ENABLE_PYTORCH_ATTENTION
|
||||
|
||||
def pytorch_attention_flash_attention():
|
||||
global ENABLE_PYTORCH_ATTENTION
|
||||
if ENABLE_PYTORCH_ATTENTION:
|
||||
#TODO: more reliable way of checking for flash attention?
|
||||
if torch.version.cuda: #pytorch flash attention only works on Nvidia
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_free_memory(dev=None, torch_free_too=False):
|
||||
global xpu_available
|
||||
global directml_enabled
|
||||
if dev is None:
|
||||
dev = get_torch_device()
|
||||
|
||||
@ -251,7 +306,10 @@ def get_free_memory(dev=None, torch_free_too=False):
|
||||
mem_free_total = psutil.virtual_memory().available
|
||||
mem_free_torch = mem_free_total
|
||||
else:
|
||||
if xpu_available:
|
||||
if directml_enabled:
|
||||
mem_free_total = 1024 * 1024 * 1024 #TODO
|
||||
mem_free_torch = mem_free_total
|
||||
elif xpu_available:
|
||||
mem_free_total = torch.xpu.get_device_properties(dev).total_memory - torch.xpu.memory_allocated(dev)
|
||||
mem_free_torch = mem_free_total
|
||||
else:
|
||||
@ -273,7 +331,12 @@ def maximum_batch_area():
|
||||
return 0
|
||||
|
||||
memory_free = get_free_memory() / (1024 * 1024)
|
||||
area = ((memory_free - 1024) * 0.9) / (0.6)
|
||||
if xformers_enabled() or pytorch_attention_flash_attention():
|
||||
#TODO: this needs to be tweaked
|
||||
area = 20 * memory_free
|
||||
else:
|
||||
#TODO: this formula is because AMD sucks and has memory management issues which might be fixed in the future
|
||||
area = ((memory_free - 1024) * 0.9) / (0.6)
|
||||
return int(max(area, 0))
|
||||
|
||||
def cpu_mode():
|
||||
@ -286,9 +349,14 @@ def mps_mode():
|
||||
|
||||
def should_use_fp16():
|
||||
global xpu_available
|
||||
global directml_enabled
|
||||
|
||||
if FORCE_FP32:
|
||||
return False
|
||||
|
||||
if directml_enabled:
|
||||
return False
|
||||
|
||||
if cpu_mode() or mps_mode() or xpu_available:
|
||||
return False #TODO ?
|
||||
|
||||
@ -307,6 +375,15 @@ def should_use_fp16():
|
||||
|
||||
return True
|
||||
|
||||
def soft_empty_cache():
|
||||
global xpu_available
|
||||
if xpu_available:
|
||||
torch.xpu.empty_cache()
|
||||
elif torch.cuda.is_available():
|
||||
if torch.version.cuda: #This seems to make things worse on ROCm so I only do it for cuda
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.ipc_collect()
|
||||
|
||||
#TODO: might be cleaner to put this somewhere else
|
||||
import threading
|
||||
|
||||
|
||||
92
comfy/sample.py
Normal file
92
comfy/sample.py
Normal file
@ -0,0 +1,92 @@
|
||||
import torch
|
||||
import comfy.model_management
|
||||
import comfy.samplers
|
||||
import math
|
||||
import numpy as np
|
||||
|
||||
def prepare_noise(latent_image, seed, noise_inds=None):
|
||||
"""
|
||||
creates random noise given a latent image and a seed.
|
||||
optional arg skip can be used to skip and discard x number of noise generations for a given seed
|
||||
"""
|
||||
generator = torch.manual_seed(seed)
|
||||
if noise_inds is None:
|
||||
return torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu")
|
||||
|
||||
unique_inds, inverse = np.unique(noise_inds, return_inverse=True)
|
||||
noises = []
|
||||
for i in range(unique_inds[-1]+1):
|
||||
noise = torch.randn([1] + list(latent_image.size())[1:], dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu")
|
||||
if i in unique_inds:
|
||||
noises.append(noise)
|
||||
noises = [noises[i] for i in inverse]
|
||||
noises = torch.cat(noises, axis=0)
|
||||
return noises
|
||||
|
||||
def prepare_mask(noise_mask, shape, device):
|
||||
"""ensures noise mask is of proper dimensions"""
|
||||
noise_mask = torch.nn.functional.interpolate(noise_mask.reshape((-1, 1, noise_mask.shape[-2], noise_mask.shape[-1])), size=(shape[2], shape[3]), mode="bilinear")
|
||||
noise_mask = noise_mask.round()
|
||||
noise_mask = torch.cat([noise_mask] * shape[1], dim=1)
|
||||
if noise_mask.shape[0] < shape[0]:
|
||||
noise_mask = noise_mask.repeat(math.ceil(shape[0] / noise_mask.shape[0]), 1, 1, 1)[:shape[0]]
|
||||
noise_mask = noise_mask.to(device)
|
||||
return noise_mask
|
||||
|
||||
def broadcast_cond(cond, batch, device):
|
||||
"""broadcasts conditioning to the batch size"""
|
||||
copy = []
|
||||
for p in cond:
|
||||
t = p[0]
|
||||
if t.shape[0] < batch:
|
||||
t = torch.cat([t] * batch)
|
||||
t = t.to(device)
|
||||
copy += [[t] + p[1:]]
|
||||
return copy
|
||||
|
||||
def get_models_from_cond(cond, model_type):
|
||||
models = []
|
||||
for c in cond:
|
||||
if model_type in c[1]:
|
||||
models += [c[1][model_type]]
|
||||
return models
|
||||
|
||||
def load_additional_models(positive, negative):
|
||||
"""loads additional models in positive and negative conditioning"""
|
||||
control_nets = get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control")
|
||||
gligen = get_models_from_cond(positive, "gligen") + get_models_from_cond(negative, "gligen")
|
||||
gligen = [x[1] for x in gligen]
|
||||
models = control_nets + gligen
|
||||
comfy.model_management.load_controlnet_gpu(models)
|
||||
return models
|
||||
|
||||
def cleanup_additional_models(models):
|
||||
"""cleanup additional models that were loaded"""
|
||||
for m in models:
|
||||
m.cleanup()
|
||||
|
||||
def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, noise_mask=None, sigmas=None, callback=None, disable_pbar=False):
|
||||
device = comfy.model_management.get_torch_device()
|
||||
|
||||
if noise_mask is not None:
|
||||
noise_mask = prepare_mask(noise_mask, noise.shape, device)
|
||||
|
||||
real_model = None
|
||||
comfy.model_management.load_model_gpu(model)
|
||||
real_model = model.model
|
||||
|
||||
noise = noise.to(device)
|
||||
latent_image = latent_image.to(device)
|
||||
|
||||
positive_copy = broadcast_cond(positive, noise.shape[0], device)
|
||||
negative_copy = broadcast_cond(negative, noise.shape[0], device)
|
||||
|
||||
models = load_additional_models(positive, negative)
|
||||
|
||||
sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)
|
||||
|
||||
samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar)
|
||||
samples = samples.cpu()
|
||||
|
||||
cleanup_additional_models(models)
|
||||
return samples
|
||||
@ -3,26 +3,13 @@ from .k_diffusion import external as k_diffusion_external
|
||||
from .extra_samplers import uni_pc
|
||||
import torch
|
||||
import contextlib
|
||||
import model_management
|
||||
from comfy import model_management
|
||||
from .ldm.models.diffusion.ddim import DDIMSampler
|
||||
from .ldm.modules.diffusionmodules.util import make_ddim_timesteps
|
||||
import math
|
||||
|
||||
class CFGDenoiser(torch.nn.Module):
|
||||
def __init__(self, model):
|
||||
super().__init__()
|
||||
self.inner_model = model
|
||||
|
||||
def forward(self, x, sigma, uncond, cond, cond_scale):
|
||||
if len(uncond[0]) == len(cond[0]) and x.shape[0] * x.shape[2] * x.shape[3] < (96 * 96): #TODO check memory instead
|
||||
x_in = torch.cat([x] * 2)
|
||||
sigma_in = torch.cat([sigma] * 2)
|
||||
cond_in = torch.cat([uncond, cond])
|
||||
uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
|
||||
else:
|
||||
cond = self.inner_model(x, sigma, cond=cond)
|
||||
uncond = self.inner_model(x, sigma, cond=uncond)
|
||||
return uncond + (cond - uncond) * cond_scale
|
||||
|
||||
def lcm(a, b): #TODO: eventually replace by math.lcm (added in python3.9)
|
||||
return abs(a*b) // math.gcd(a, b)
|
||||
|
||||
#The main sampling function shared by all the samplers
|
||||
#Returns predicted noise
|
||||
@ -36,25 +23,40 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
|
||||
strength = cond[1]['strength']
|
||||
|
||||
adm_cond = None
|
||||
if 'adm' in cond[1]:
|
||||
adm_cond = cond[1]['adm']
|
||||
if 'adm_encoded' in cond[1]:
|
||||
adm_cond = cond[1]['adm_encoded']
|
||||
|
||||
input_x = x_in[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
|
||||
mult = torch.ones_like(input_x) * strength
|
||||
if 'mask' in cond[1]:
|
||||
# Scale the mask to the size of the input
|
||||
# The mask should have been resized as we began the sampling process
|
||||
mask_strength = 1.0
|
||||
if "mask_strength" in cond[1]:
|
||||
mask_strength = cond[1]["mask_strength"]
|
||||
mask = cond[1]['mask']
|
||||
assert(mask.shape[1] == x_in.shape[2])
|
||||
assert(mask.shape[2] == x_in.shape[3])
|
||||
mask = mask[:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] * mask_strength
|
||||
mask = mask.unsqueeze(1).repeat(input_x.shape[0] // mask.shape[0], input_x.shape[1], 1, 1)
|
||||
else:
|
||||
mask = torch.ones_like(input_x)
|
||||
mult = mask * strength
|
||||
|
||||
if 'mask' not in cond[1]:
|
||||
rr = 8
|
||||
if area[2] != 0:
|
||||
for t in range(rr):
|
||||
mult[:,:,t:1+t,:] *= ((1.0/rr) * (t + 1))
|
||||
if (area[0] + area[2]) < x_in.shape[2]:
|
||||
for t in range(rr):
|
||||
mult[:,:,area[0] - 1 - t:area[0] - t,:] *= ((1.0/rr) * (t + 1))
|
||||
if area[3] != 0:
|
||||
for t in range(rr):
|
||||
mult[:,:,:,t:1+t] *= ((1.0/rr) * (t + 1))
|
||||
if (area[1] + area[3]) < x_in.shape[3]:
|
||||
for t in range(rr):
|
||||
mult[:,:,:,area[1] - 1 - t:area[1] - t] *= ((1.0/rr) * (t + 1))
|
||||
|
||||
rr = 8
|
||||
if area[2] != 0:
|
||||
for t in range(rr):
|
||||
mult[:,:,t:1+t,:] *= ((1.0/rr) * (t + 1))
|
||||
if (area[0] + area[2]) < x_in.shape[2]:
|
||||
for t in range(rr):
|
||||
mult[:,:,area[0] - 1 - t:area[0] - t,:] *= ((1.0/rr) * (t + 1))
|
||||
if area[3] != 0:
|
||||
for t in range(rr):
|
||||
mult[:,:,:,t:1+t] *= ((1.0/rr) * (t + 1))
|
||||
if (area[1] + area[3]) < x_in.shape[3]:
|
||||
for t in range(rr):
|
||||
mult[:,:,:,area[1] - 1 - t:area[1] - t] *= ((1.0/rr) * (t + 1))
|
||||
conditionning = {}
|
||||
conditionning['c_crossattn'] = cond[0]
|
||||
if cond_concat_in is not None and len(cond_concat_in) > 0:
|
||||
@ -70,7 +72,21 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
|
||||
control = None
|
||||
if 'control' in cond[1]:
|
||||
control = cond[1]['control']
|
||||
return (input_x, mult, conditionning, area, control)
|
||||
|
||||
patches = None
|
||||
if 'gligen' in cond[1]:
|
||||
gligen = cond[1]['gligen']
|
||||
patches = {}
|
||||
gligen_type = gligen[0]
|
||||
gligen_model = gligen[1]
|
||||
if gligen_type == "position":
|
||||
gligen_patch = gligen_model.set_position(input_x.shape, gligen[2], input_x.device)
|
||||
else:
|
||||
gligen_patch = gligen_model.set_empty(input_x.shape, input_x.device)
|
||||
|
||||
patches['middle_patch'] = [gligen_patch]
|
||||
|
||||
return (input_x, mult, conditionning, area, control, patches)
|
||||
|
||||
def cond_equal_size(c1, c2):
|
||||
if c1 is c2:
|
||||
@ -78,8 +94,16 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
|
||||
if c1.keys() != c2.keys():
|
||||
return False
|
||||
if 'c_crossattn' in c1:
|
||||
if c1['c_crossattn'].shape != c2['c_crossattn'].shape:
|
||||
return False
|
||||
s1 = c1['c_crossattn'].shape
|
||||
s2 = c2['c_crossattn'].shape
|
||||
if s1 != s2:
|
||||
if s1[0] != s2[0] or s1[2] != s2[2]: #these 2 cases should not happen
|
||||
return False
|
||||
|
||||
mult_min = lcm(s1[1], s2[1])
|
||||
diff = mult_min // min(s1[1], s2[1])
|
||||
if diff > 4: #arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much
|
||||
return False
|
||||
if 'c_concat' in c1:
|
||||
if c1['c_concat'].shape != c2['c_concat'].shape:
|
||||
return False
|
||||
@ -91,28 +115,49 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
|
||||
def can_concat_cond(c1, c2):
|
||||
if c1[0].shape != c2[0].shape:
|
||||
return False
|
||||
|
||||
#control
|
||||
if (c1[4] is None) != (c2[4] is None):
|
||||
return False
|
||||
if c1[4] is not None:
|
||||
if c1[4] is not c2[4]:
|
||||
return False
|
||||
|
||||
#patches
|
||||
if (c1[5] is None) != (c2[5] is None):
|
||||
return False
|
||||
if (c1[5] is not None):
|
||||
if c1[5] is not c2[5]:
|
||||
return False
|
||||
|
||||
return cond_equal_size(c1[2], c2[2])
|
||||
|
||||
def cond_cat(c_list):
|
||||
c_crossattn = []
|
||||
c_concat = []
|
||||
c_adm = []
|
||||
crossattn_max_len = 0
|
||||
for x in c_list:
|
||||
if 'c_crossattn' in x:
|
||||
c_crossattn.append(x['c_crossattn'])
|
||||
c = x['c_crossattn']
|
||||
if crossattn_max_len == 0:
|
||||
crossattn_max_len = c.shape[1]
|
||||
else:
|
||||
crossattn_max_len = lcm(crossattn_max_len, c.shape[1])
|
||||
c_crossattn.append(c)
|
||||
if 'c_concat' in x:
|
||||
c_concat.append(x['c_concat'])
|
||||
if 'c_adm' in x:
|
||||
c_adm.append(x['c_adm'])
|
||||
out = {}
|
||||
if len(c_crossattn) > 0:
|
||||
out['c_crossattn'] = [torch.cat(c_crossattn)]
|
||||
c_crossattn_out = []
|
||||
for c in c_crossattn:
|
||||
if c.shape[1] < crossattn_max_len:
|
||||
c = c.repeat(1, crossattn_max_len // c.shape[1], 1) #padding with repeat doesn't change result
|
||||
c_crossattn_out.append(c)
|
||||
|
||||
if len(c_crossattn_out) > 0:
|
||||
out['c_crossattn'] = [torch.cat(c_crossattn_out)]
|
||||
if len(c_concat) > 0:
|
||||
out['c_concat'] = [torch.cat(c_concat)]
|
||||
if len(c_adm) > 0:
|
||||
@ -166,6 +211,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
|
||||
cond_or_uncond = []
|
||||
area = []
|
||||
control = None
|
||||
patches = None
|
||||
for x in to_batch:
|
||||
o = to_run.pop(x)
|
||||
p = o[0]
|
||||
@ -175,6 +221,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
|
||||
area += [p[3]]
|
||||
cond_or_uncond += [o[1]]
|
||||
control = p[4]
|
||||
patches = p[5]
|
||||
|
||||
batch_chunks = len(cond_or_uncond)
|
||||
input_x = torch.cat(input_x)
|
||||
@ -184,8 +231,22 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
|
||||
if control is not None:
|
||||
c['control'] = control.get_control(input_x, timestep_, c['c_crossattn'], len(cond_or_uncond))
|
||||
|
||||
transformer_options = {}
|
||||
if 'transformer_options' in model_options:
|
||||
c['transformer_options'] = model_options['transformer_options']
|
||||
transformer_options = model_options['transformer_options'].copy()
|
||||
|
||||
if patches is not None:
|
||||
if "patches" in transformer_options:
|
||||
cur_patches = transformer_options["patches"].copy()
|
||||
for p in patches:
|
||||
if p in cur_patches:
|
||||
cur_patches[p] = cur_patches[p] + patches[p]
|
||||
else:
|
||||
cur_patches[p] = patches[p]
|
||||
else:
|
||||
transformer_options["patches"] = patches
|
||||
|
||||
c['transformer_options'] = transformer_options
|
||||
|
||||
output = model_function(input_x, timestep_, cond=c).chunk(batch_chunks)
|
||||
del input_x
|
||||
@ -211,7 +272,10 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
|
||||
|
||||
max_total_area = model_management.maximum_batch_area()
|
||||
cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, timestep, max_total_area, cond_concat, model_options)
|
||||
return uncond + (cond - uncond) * cond_scale
|
||||
if "sampler_cfg_function" in model_options:
|
||||
return model_options["sampler_cfg_function"](cond, uncond, cond_scale)
|
||||
else:
|
||||
return uncond + (cond - uncond) * cond_scale
|
||||
|
||||
|
||||
class CompVisVDenoiser(k_diffusion_external.DiscreteVDDPMDenoiser):
|
||||
@ -276,6 +340,60 @@ def blank_inpaint_image_like(latent_image):
|
||||
blank_image[:,3] *= 0.1380
|
||||
return blank_image
|
||||
|
||||
def get_mask_aabb(masks):
|
||||
if masks.numel() == 0:
|
||||
return torch.zeros((0, 4), device=masks.device, dtype=torch.int)
|
||||
|
||||
b = masks.shape[0]
|
||||
|
||||
bounding_boxes = torch.zeros((b, 4), device=masks.device, dtype=torch.int)
|
||||
is_empty = torch.zeros((b), device=masks.device, dtype=torch.bool)
|
||||
for i in range(b):
|
||||
mask = masks[i]
|
||||
if mask.numel() == 0:
|
||||
continue
|
||||
if torch.max(mask != 0) == False:
|
||||
is_empty[i] = True
|
||||
continue
|
||||
y, x = torch.where(mask)
|
||||
bounding_boxes[i, 0] = torch.min(x)
|
||||
bounding_boxes[i, 1] = torch.min(y)
|
||||
bounding_boxes[i, 2] = torch.max(x)
|
||||
bounding_boxes[i, 3] = torch.max(y)
|
||||
|
||||
return bounding_boxes, is_empty
|
||||
|
||||
def resolve_cond_masks(conditions, h, w, device):
|
||||
# We need to decide on an area outside the sampling loop in order to properly generate opposite areas of equal sizes.
|
||||
# While we're doing this, we can also resolve the mask device and scaling for performance reasons
|
||||
for i in range(len(conditions)):
|
||||
c = conditions[i]
|
||||
if 'mask' in c[1]:
|
||||
mask = c[1]['mask']
|
||||
mask = mask.to(device=device)
|
||||
modified = c[1].copy()
|
||||
if len(mask.shape) == 2:
|
||||
mask = mask.unsqueeze(0)
|
||||
if mask.shape[2] != h or mask.shape[3] != w:
|
||||
mask = torch.nn.functional.interpolate(mask.unsqueeze(1), size=(h, w), mode='bilinear', align_corners=False).squeeze(1)
|
||||
|
||||
if modified.get("set_area_to_bounds", False):
|
||||
bounds = torch.max(torch.abs(mask),dim=0).values.unsqueeze(0)
|
||||
boxes, is_empty = get_mask_aabb(bounds)
|
||||
if is_empty[0]:
|
||||
# Use the minimum possible size for efficiency reasons. (Since the mask is all-0, this becomes a noop anyway)
|
||||
modified['area'] = (8, 8, 0, 0)
|
||||
else:
|
||||
box = boxes[0]
|
||||
H, W, Y, X = (box[3] - box[1] + 1, box[2] - box[0] + 1, box[1], box[0])
|
||||
H = max(8, H)
|
||||
W = max(8, W)
|
||||
area = (int(H), int(W), int(Y), int(X))
|
||||
modified['area'] = area
|
||||
|
||||
modified['mask'] = mask
|
||||
conditions[i] = [c[0], modified]
|
||||
|
||||
def create_cond_with_same_area_if_none(conds, c):
|
||||
if 'area' not in c[1]:
|
||||
return
|
||||
@ -306,8 +424,7 @@ def create_cond_with_same_area_if_none(conds, c):
|
||||
n = c[1].copy()
|
||||
conds += [[smallest[0], n]]
|
||||
|
||||
|
||||
def apply_control_net_to_equal_area(conds, uncond):
|
||||
def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func):
|
||||
cond_cnets = []
|
||||
cond_other = []
|
||||
uncond_cnets = []
|
||||
@ -315,15 +432,15 @@ def apply_control_net_to_equal_area(conds, uncond):
|
||||
for t in range(len(conds)):
|
||||
x = conds[t]
|
||||
if 'area' not in x[1]:
|
||||
if 'control' in x[1] and x[1]['control'] is not None:
|
||||
cond_cnets.append(x[1]['control'])
|
||||
if name in x[1] and x[1][name] is not None:
|
||||
cond_cnets.append(x[1][name])
|
||||
else:
|
||||
cond_other.append((x, t))
|
||||
for t in range(len(uncond)):
|
||||
x = uncond[t]
|
||||
if 'area' not in x[1]:
|
||||
if 'control' in x[1] and x[1]['control'] is not None:
|
||||
uncond_cnets.append(x[1]['control'])
|
||||
if name in x[1] and x[1][name] is not None:
|
||||
uncond_cnets.append(x[1][name])
|
||||
else:
|
||||
uncond_other.append((x, t))
|
||||
|
||||
@ -333,15 +450,16 @@ def apply_control_net_to_equal_area(conds, uncond):
|
||||
for x in range(len(cond_cnets)):
|
||||
temp = uncond_other[x % len(uncond_other)]
|
||||
o = temp[0]
|
||||
if 'control' in o[1] and o[1]['control'] is not None:
|
||||
if name in o[1] and o[1][name] is not None:
|
||||
n = o[1].copy()
|
||||
n['control'] = cond_cnets[x]
|
||||
n[name] = uncond_fill_func(cond_cnets, x)
|
||||
uncond += [[o[0], n]]
|
||||
else:
|
||||
n = o[1].copy()
|
||||
n['control'] = cond_cnets[x]
|
||||
n[name] = uncond_fill_func(cond_cnets, x)
|
||||
uncond[temp[1]] = [o[0], n]
|
||||
|
||||
|
||||
def encode_adm(noise_augmentor, conds, batch_size, device):
|
||||
for t in range(len(conds)):
|
||||
x = conds[t]
|
||||
@ -371,12 +489,13 @@ def encode_adm(noise_augmentor, conds, batch_size, device):
|
||||
else:
|
||||
adm_out = torch.zeros((1, noise_augmentor.time_embed.dim * 2), device=device)
|
||||
x[1] = x[1].copy()
|
||||
x[1]["adm"] = torch.cat([adm_out] * batch_size)
|
||||
x[1]["adm_encoded"] = torch.cat([adm_out] * batch_size)
|
||||
|
||||
return conds
|
||||
|
||||
|
||||
class KSampler:
|
||||
SCHEDULERS = ["karras", "normal", "simple", "ddim_uniform"]
|
||||
SCHEDULERS = ["normal", "karras", "simple", "ddim_uniform"]
|
||||
SAMPLERS = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral",
|
||||
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde",
|
||||
"dpmpp_2m", "ddim", "uni_pc", "uni_pc_bh2"]
|
||||
@ -403,7 +522,7 @@ class KSampler:
|
||||
self.denoise = denoise
|
||||
self.model_options = model_options
|
||||
|
||||
def _calculate_sigmas(self, steps):
|
||||
def calculate_sigmas(self, steps):
|
||||
sigmas = None
|
||||
|
||||
discard_penultimate_sigma = False
|
||||
@ -412,13 +531,13 @@ class KSampler:
|
||||
discard_penultimate_sigma = True
|
||||
|
||||
if self.scheduler == "karras":
|
||||
sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=self.sigma_min, sigma_max=self.sigma_max, device=self.device)
|
||||
sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=self.sigma_min, sigma_max=self.sigma_max)
|
||||
elif self.scheduler == "normal":
|
||||
sigmas = self.model_wrap.get_sigmas(steps).to(self.device)
|
||||
sigmas = self.model_wrap.get_sigmas(steps)
|
||||
elif self.scheduler == "simple":
|
||||
sigmas = simple_scheduler(self.model_wrap, steps).to(self.device)
|
||||
sigmas = simple_scheduler(self.model_wrap, steps)
|
||||
elif self.scheduler == "ddim_uniform":
|
||||
sigmas = ddim_scheduler(self.model_wrap, steps).to(self.device)
|
||||
sigmas = ddim_scheduler(self.model_wrap, steps)
|
||||
else:
|
||||
print("error invalid scheduler", self.scheduler)
|
||||
|
||||
@ -429,15 +548,15 @@ class KSampler:
|
||||
def set_steps(self, steps, denoise=None):
|
||||
self.steps = steps
|
||||
if denoise is None or denoise > 0.9999:
|
||||
self.sigmas = self._calculate_sigmas(steps)
|
||||
self.sigmas = self.calculate_sigmas(steps).to(self.device)
|
||||
else:
|
||||
new_steps = int(steps/denoise)
|
||||
sigmas = self._calculate_sigmas(new_steps)
|
||||
sigmas = self.calculate_sigmas(new_steps).to(self.device)
|
||||
self.sigmas = sigmas[-(steps + 1):]
|
||||
|
||||
|
||||
def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=None, last_step=None, force_full_denoise=False, denoise_mask=None):
|
||||
sigmas = self.sigmas
|
||||
def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=None, last_step=None, force_full_denoise=False, denoise_mask=None, sigmas=None, callback=None, disable_pbar=False):
|
||||
if sigmas is None:
|
||||
sigmas = self.sigmas
|
||||
sigma_min = self.sigma_min
|
||||
|
||||
if last_step is not None and last_step < (len(sigmas) - 1):
|
||||
@ -457,13 +576,18 @@ class KSampler:
|
||||
|
||||
positive = positive[:]
|
||||
negative = negative[:]
|
||||
|
||||
resolve_cond_masks(positive, noise.shape[2], noise.shape[3], self.device)
|
||||
resolve_cond_masks(negative, noise.shape[2], noise.shape[3], self.device)
|
||||
|
||||
#make sure each cond area has an opposite one with the same area
|
||||
for c in positive:
|
||||
create_cond_with_same_area_if_none(negative, c)
|
||||
for c in negative:
|
||||
create_cond_with_same_area_if_none(positive, c)
|
||||
|
||||
apply_control_net_to_equal_area(positive, negative)
|
||||
apply_empty_x_to_equal_area(positive, negative, 'control', lambda cond_cnets, x: cond_cnets[x])
|
||||
apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x])
|
||||
|
||||
if self.model.model.diffusion_model.dtype == torch.float16:
|
||||
precision_scope = torch.autocast
|
||||
@ -499,9 +623,9 @@ class KSampler:
|
||||
|
||||
with precision_scope(model_management.get_autocast_device(self.device)):
|
||||
if self.sampler == "uni_pc":
|
||||
samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask)
|
||||
samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask, callback=callback, disable=disable_pbar)
|
||||
elif self.sampler == "uni_pc_bh2":
|
||||
samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask, variant='bh2')
|
||||
samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask, callback=callback, variant='bh2', disable=disable_pbar)
|
||||
elif self.sampler == "ddim":
|
||||
timesteps = []
|
||||
for s in range(sigmas.shape[0]):
|
||||
@ -509,6 +633,12 @@ class KSampler:
|
||||
noise_mask = None
|
||||
if denoise_mask is not None:
|
||||
noise_mask = 1.0 - denoise_mask
|
||||
|
||||
ddim_callback = None
|
||||
if callback is not None:
|
||||
total_steps = len(timesteps) - 1
|
||||
ddim_callback = lambda pred_x0, i: callback(i, pred_x0, None, total_steps)
|
||||
|
||||
sampler = DDIMSampler(self.model, device=self.device)
|
||||
sampler.make_schedule_timesteps(ddim_timesteps=timesteps, verbose=False)
|
||||
z_enc = sampler.stochastic_encode(latent_image, torch.tensor([len(timesteps) - 1] * noise.shape[0]).to(self.device), noise=noise, max_denoise=max_denoise)
|
||||
@ -522,11 +652,13 @@ class KSampler:
|
||||
eta=0.0,
|
||||
x_T=z_enc,
|
||||
x0=latent_image,
|
||||
img_callback=ddim_callback,
|
||||
denoise_function=sampling_function,
|
||||
extra_args=extra_args,
|
||||
mask=noise_mask,
|
||||
to_zero=sigmas[-1]==0,
|
||||
end_step=sigmas.shape[0] - 1)
|
||||
end_step=sigmas.shape[0] - 1,
|
||||
disable_pbar=disable_pbar)
|
||||
|
||||
else:
|
||||
extra_args["denoise_mask"] = denoise_mask
|
||||
@ -535,13 +667,18 @@ class KSampler:
|
||||
|
||||
noise = noise * sigmas[0]
|
||||
|
||||
k_callback = None
|
||||
total_steps = len(sigmas) - 1
|
||||
if callback is not None:
|
||||
k_callback = lambda x: callback(x["i"], x["denoised"], x["x"], total_steps)
|
||||
|
||||
if latent_image is not None:
|
||||
noise += latent_image
|
||||
if self.sampler == "dpm_fast":
|
||||
samples = k_diffusion_sampling.sample_dpm_fast(self.model_k, noise, sigma_min, sigmas[0], self.steps, extra_args=extra_args)
|
||||
samples = k_diffusion_sampling.sample_dpm_fast(self.model_k, noise, sigma_min, sigmas[0], total_steps, extra_args=extra_args, callback=k_callback, disable=disable_pbar)
|
||||
elif self.sampler == "dpm_adaptive":
|
||||
samples = k_diffusion_sampling.sample_dpm_adaptive(self.model_k, noise, sigma_min, sigmas[0], extra_args=extra_args)
|
||||
samples = k_diffusion_sampling.sample_dpm_adaptive(self.model_k, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=k_callback, disable=disable_pbar)
|
||||
else:
|
||||
samples = getattr(k_diffusion_sampling, "sample_{}".format(self.sampler))(self.model_k, noise, sigmas, extra_args=extra_args)
|
||||
samples = getattr(k_diffusion_sampling, "sample_{}".format(self.sampler))(self.model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar)
|
||||
|
||||
return samples.to(torch.float32)
|
||||
|
||||
201
comfy/sd.py
201
comfy/sd.py
@ -2,9 +2,9 @@ import torch
|
||||
import contextlib
|
||||
import copy
|
||||
|
||||
import sd1_clip
|
||||
import sd2_clip
|
||||
import model_management
|
||||
from . import sd1_clip
|
||||
from . import sd2_clip
|
||||
from comfy import model_management
|
||||
from .ldm.util import instantiate_from_config
|
||||
from .ldm.models.autoencoder import AutoencoderKL
|
||||
import yaml
|
||||
@ -13,6 +13,7 @@ from .t2i_adapter import adapter
|
||||
|
||||
from . import utils
|
||||
from . import clip_vision
|
||||
from . import gligen
|
||||
|
||||
def load_model_weights(model, sd, verbose=False, load_state_dict_to=[]):
|
||||
m, u = model.load_state_dict(sd, strict=False)
|
||||
@ -110,6 +111,8 @@ def load_lora(path, to_load):
|
||||
loaded_keys.add(A_name)
|
||||
loaded_keys.add(B_name)
|
||||
|
||||
|
||||
######## loha
|
||||
hada_w1_a_name = "{}.hada_w1_a".format(x)
|
||||
hada_w1_b_name = "{}.hada_w1_b".format(x)
|
||||
hada_w2_a_name = "{}.hada_w2_a".format(x)
|
||||
@ -131,6 +134,54 @@ def load_lora(path, to_load):
|
||||
loaded_keys.add(hada_w2_a_name)
|
||||
loaded_keys.add(hada_w2_b_name)
|
||||
|
||||
|
||||
######## lokr
|
||||
lokr_w1_name = "{}.lokr_w1".format(x)
|
||||
lokr_w2_name = "{}.lokr_w2".format(x)
|
||||
lokr_w1_a_name = "{}.lokr_w1_a".format(x)
|
||||
lokr_w1_b_name = "{}.lokr_w1_b".format(x)
|
||||
lokr_t2_name = "{}.lokr_t2".format(x)
|
||||
lokr_w2_a_name = "{}.lokr_w2_a".format(x)
|
||||
lokr_w2_b_name = "{}.lokr_w2_b".format(x)
|
||||
|
||||
lokr_w1 = None
|
||||
if lokr_w1_name in lora.keys():
|
||||
lokr_w1 = lora[lokr_w1_name]
|
||||
loaded_keys.add(lokr_w1_name)
|
||||
|
||||
lokr_w2 = None
|
||||
if lokr_w2_name in lora.keys():
|
||||
lokr_w2 = lora[lokr_w2_name]
|
||||
loaded_keys.add(lokr_w2_name)
|
||||
|
||||
lokr_w1_a = None
|
||||
if lokr_w1_a_name in lora.keys():
|
||||
lokr_w1_a = lora[lokr_w1_a_name]
|
||||
loaded_keys.add(lokr_w1_a_name)
|
||||
|
||||
lokr_w1_b = None
|
||||
if lokr_w1_b_name in lora.keys():
|
||||
lokr_w1_b = lora[lokr_w1_b_name]
|
||||
loaded_keys.add(lokr_w1_b_name)
|
||||
|
||||
lokr_w2_a = None
|
||||
if lokr_w2_a_name in lora.keys():
|
||||
lokr_w2_a = lora[lokr_w2_a_name]
|
||||
loaded_keys.add(lokr_w2_a_name)
|
||||
|
||||
lokr_w2_b = None
|
||||
if lokr_w2_b_name in lora.keys():
|
||||
lokr_w2_b = lora[lokr_w2_b_name]
|
||||
loaded_keys.add(lokr_w2_b_name)
|
||||
|
||||
lokr_t2 = None
|
||||
if lokr_t2_name in lora.keys():
|
||||
lokr_t2 = lora[lokr_t2_name]
|
||||
loaded_keys.add(lokr_t2_name)
|
||||
|
||||
if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None):
|
||||
patch_dict[to_load[x]] = (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2)
|
||||
|
||||
for x in lora.keys():
|
||||
if x not in loaded_keys:
|
||||
print("lora key not loaded", x)
|
||||
@ -250,6 +301,32 @@ class ModelPatcher:
|
||||
def set_model_tomesd(self, ratio):
|
||||
self.model_options["transformer_options"]["tomesd"] = {"ratio": ratio}
|
||||
|
||||
def set_model_sampler_cfg_function(self, sampler_cfg_function):
|
||||
self.model_options["sampler_cfg_function"] = sampler_cfg_function
|
||||
|
||||
|
||||
def set_model_patch(self, patch, name):
|
||||
to = self.model_options["transformer_options"]
|
||||
if "patches" not in to:
|
||||
to["patches"] = {}
|
||||
to["patches"][name] = to["patches"].get(name, []) + [patch]
|
||||
|
||||
def set_model_attn1_patch(self, patch):
|
||||
self.set_model_patch(patch, "attn1_patch")
|
||||
|
||||
def set_model_attn2_patch(self, patch):
|
||||
self.set_model_patch(patch, "attn2_patch")
|
||||
|
||||
def model_patches_to(self, device):
|
||||
to = self.model_options["transformer_options"]
|
||||
if "patches" in to:
|
||||
patches = to["patches"]
|
||||
for name in patches:
|
||||
patch_list = patches[name]
|
||||
for i in range(len(patch_list)):
|
||||
if hasattr(patch_list[i], "to"):
|
||||
patch_list[i] = patch_list[i].to(device)
|
||||
|
||||
def model_dtype(self):
|
||||
return self.model.diffusion_model.dtype
|
||||
|
||||
@ -288,6 +365,33 @@ class ModelPatcher:
|
||||
final_shape = [mat2.shape[1], mat2.shape[0], v[3].shape[2], v[3].shape[3]]
|
||||
mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1).float(), v[3].transpose(0, 1).flatten(start_dim=1).float()).reshape(final_shape).transpose(0, 1)
|
||||
weight += (alpha * torch.mm(mat1.flatten(start_dim=1).float(), mat2.flatten(start_dim=1).float())).reshape(weight.shape).type(weight.dtype).to(weight.device)
|
||||
elif len(v) == 8: #lokr
|
||||
w1 = v[0]
|
||||
w2 = v[1]
|
||||
w1_a = v[3]
|
||||
w1_b = v[4]
|
||||
w2_a = v[5]
|
||||
w2_b = v[6]
|
||||
t2 = v[7]
|
||||
dim = None
|
||||
|
||||
if w1 is None:
|
||||
dim = w1_b.shape[0]
|
||||
w1 = torch.mm(w1_a.float(), w1_b.float())
|
||||
|
||||
if w2 is None:
|
||||
dim = w2_b.shape[0]
|
||||
if t2 is None:
|
||||
w2 = torch.mm(w2_a.float(), w2_b.float())
|
||||
else:
|
||||
w2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float(), w2_b.float(), w2_a.float())
|
||||
|
||||
if len(w2.shape) == 4:
|
||||
w1 = w1.unsqueeze(2).unsqueeze(2)
|
||||
if v[2] is not None and dim is not None:
|
||||
alpha *= v[2] / dim
|
||||
|
||||
weight += alpha * torch.kron(w1.float(), w2.float()).reshape(weight.shape).type(weight.dtype).to(weight.device)
|
||||
else: #loha
|
||||
w1a = v[0]
|
||||
w1b = v[1]
|
||||
@ -342,10 +446,10 @@ class CLIP:
|
||||
else:
|
||||
params = {}
|
||||
|
||||
if self.target_clip == "ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder":
|
||||
if self.target_clip.endswith("FrozenOpenCLIPEmbedder"):
|
||||
clip = sd2_clip.SD2ClipModel
|
||||
tokenizer = sd2_clip.SD2Tokenizer
|
||||
elif self.target_clip == "ldm.modules.encoders.modules.FrozenCLIPEmbedder":
|
||||
elif self.target_clip.endswith("FrozenCLIPEmbedder"):
|
||||
clip = sd1_clip.SD1ClipModel
|
||||
tokenizer = sd1_clip.SD1Tokenizer
|
||||
|
||||
@ -372,10 +476,12 @@ class CLIP:
|
||||
def clip_layer(self, layer_idx):
|
||||
self.layer_idx = layer_idx
|
||||
|
||||
def encode(self, text):
|
||||
def tokenize(self, text, return_word_ids=False):
|
||||
return self.tokenizer.tokenize_with_weights(text, return_word_ids)
|
||||
|
||||
def encode_from_tokens(self, tokens, return_pooled=False):
|
||||
if self.layer_idx is not None:
|
||||
self.cond_stage_model.clip_layer(self.layer_idx)
|
||||
tokens = self.tokenizer.tokenize_with_weights(text)
|
||||
try:
|
||||
self.patcher.patch_model()
|
||||
cond = self.cond_stage_model.encode_token_weights(tokens)
|
||||
@ -383,8 +489,16 @@ class CLIP:
|
||||
except Exception as e:
|
||||
self.patcher.unpatch_model()
|
||||
raise e
|
||||
if return_pooled:
|
||||
eos_token_index = max(range(len(tokens[0])), key=tokens[0].__getitem__)
|
||||
pooled = cond[:, eos_token_index]
|
||||
return cond, pooled
|
||||
return cond
|
||||
|
||||
def encode(self, text):
|
||||
tokens = self.tokenize(text)
|
||||
return self.encode_from_tokens(tokens)
|
||||
|
||||
class VAE:
|
||||
def __init__(self, ckpt_path=None, scale_factor=0.18215, device=None, config=None):
|
||||
if config is None:
|
||||
@ -400,11 +514,16 @@ class VAE:
|
||||
self.device = device
|
||||
|
||||
def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16):
|
||||
steps = samples.shape[0] * utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap)
|
||||
steps += samples.shape[0] * utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x // 2, tile_y * 2, overlap)
|
||||
steps += samples.shape[0] * utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x * 2, tile_y // 2, overlap)
|
||||
pbar = utils.ProgressBar(steps)
|
||||
|
||||
decode_fn = lambda a: (self.first_stage_model.decode(1. / self.scale_factor * a.to(self.device)) + 1.0)
|
||||
output = torch.clamp((
|
||||
(utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = 8) +
|
||||
utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = 8) +
|
||||
utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = 8))
|
||||
(utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = 8, pbar = pbar) +
|
||||
utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = 8, pbar = pbar) +
|
||||
utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = 8, pbar = pbar))
|
||||
/ 3.0) / 2.0, min=0.0, max=1.0)
|
||||
return output
|
||||
|
||||
@ -448,20 +567,23 @@ class VAE:
|
||||
model_management.unload_model()
|
||||
self.first_stage_model = self.first_stage_model.to(self.device)
|
||||
pixel_samples = pixel_samples.movedim(-1,1).to(self.device)
|
||||
samples = utils.tiled_scale(pixel_samples, lambda a: self.first_stage_model.encode(2. * a - 1.).sample() * self.scale_factor, tile_x, tile_y, overlap, upscale_amount = (1/8), out_channels=4)
|
||||
samples += utils.tiled_scale(pixel_samples, lambda a: self.first_stage_model.encode(2. * a - 1.).sample() * self.scale_factor, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/8), out_channels=4)
|
||||
samples += utils.tiled_scale(pixel_samples, lambda a: self.first_stage_model.encode(2. * a - 1.).sample() * self.scale_factor, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/8), out_channels=4)
|
||||
|
||||
steps = pixel_samples.shape[0] * utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap)
|
||||
steps += pixel_samples.shape[0] * utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x // 2, tile_y * 2, overlap)
|
||||
steps += pixel_samples.shape[0] * utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap)
|
||||
pbar = utils.ProgressBar(steps)
|
||||
|
||||
samples = utils.tiled_scale(pixel_samples, lambda a: self.first_stage_model.encode(2. * a - 1.).sample() * self.scale_factor, tile_x, tile_y, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
|
||||
samples += utils.tiled_scale(pixel_samples, lambda a: self.first_stage_model.encode(2. * a - 1.).sample() * self.scale_factor, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
|
||||
samples += utils.tiled_scale(pixel_samples, lambda a: self.first_stage_model.encode(2. * a - 1.).sample() * self.scale_factor, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
|
||||
samples /= 3.0
|
||||
self.first_stage_model = self.first_stage_model.cpu()
|
||||
samples = samples.cpu()
|
||||
return samples
|
||||
|
||||
def resize_image_to(tensor, target_latent_tensor, batched_number):
|
||||
tensor = utils.common_upscale(tensor, target_latent_tensor.shape[3] * 8, target_latent_tensor.shape[2] * 8, 'nearest-exact', "center")
|
||||
target_batch_size = target_latent_tensor.shape[0]
|
||||
|
||||
def broadcast_image_to(tensor, target_batch_size, batched_number):
|
||||
current_batch_size = tensor.shape[0]
|
||||
print(current_batch_size, target_batch_size)
|
||||
#print(current_batch_size, target_batch_size)
|
||||
if current_batch_size == 1:
|
||||
return tensor
|
||||
|
||||
@ -498,7 +620,9 @@ class ControlNet:
|
||||
if self.cond_hint is not None:
|
||||
del self.cond_hint
|
||||
self.cond_hint = None
|
||||
self.cond_hint = resize_image_to(self.cond_hint_original, x_noisy, batched_number).to(self.control_model.dtype).to(self.device)
|
||||
self.cond_hint = utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(self.control_model.dtype).to(self.device)
|
||||
if x_noisy.shape[0] != self.cond_hint.shape[0]:
|
||||
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
|
||||
|
||||
if self.control_model.dtype == torch.float16:
|
||||
precision_scope = torch.autocast
|
||||
@ -555,10 +679,10 @@ class ControlNet:
|
||||
c.strength = self.strength
|
||||
return c
|
||||
|
||||
def get_control_models(self):
|
||||
def get_models(self):
|
||||
out = []
|
||||
if self.previous_controlnet is not None:
|
||||
out += self.previous_controlnet.get_control_models()
|
||||
out += self.previous_controlnet.get_models()
|
||||
out.append(self.control_model)
|
||||
return out
|
||||
|
||||
@ -669,10 +793,14 @@ class T2IAdapter:
|
||||
if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]:
|
||||
if self.cond_hint is not None:
|
||||
del self.cond_hint
|
||||
self.control_input = None
|
||||
self.cond_hint = None
|
||||
self.cond_hint = resize_image_to(self.cond_hint_original, x_noisy, batched_number).float().to(self.device)
|
||||
self.cond_hint = utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").float().to(self.device)
|
||||
if self.channels_in == 1 and self.cond_hint.shape[1] > 1:
|
||||
self.cond_hint = torch.mean(self.cond_hint, 1, keepdim=True)
|
||||
if x_noisy.shape[0] != self.cond_hint.shape[0]:
|
||||
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
|
||||
if self.control_input is None:
|
||||
self.t2i_model.to(self.device)
|
||||
self.control_input = self.t2i_model(self.cond_hint)
|
||||
self.t2i_model.cpu()
|
||||
@ -728,10 +856,10 @@ class T2IAdapter:
|
||||
del self.cond_hint
|
||||
self.cond_hint = None
|
||||
|
||||
def get_control_models(self):
|
||||
def get_models(self):
|
||||
out = []
|
||||
if self.previous_controlnet is not None:
|
||||
out += self.previous_controlnet.get_control_models()
|
||||
out += self.previous_controlnet.get_models()
|
||||
return out
|
||||
|
||||
def load_t2i_adapter(t2i_data):
|
||||
@ -771,13 +899,20 @@ def load_clip(ckpt_path, embedding_directory=None):
|
||||
clip_data = utils.load_torch_file(ckpt_path)
|
||||
config = {}
|
||||
if "text_model.encoder.layers.22.mlp.fc1.weight" in clip_data:
|
||||
config['target'] = 'ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder'
|
||||
config['target'] = 'comfy.ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder'
|
||||
else:
|
||||
config['target'] = 'ldm.modules.encoders.modules.FrozenCLIPEmbedder'
|
||||
config['target'] = 'comfy.ldm.modules.encoders.modules.FrozenCLIPEmbedder'
|
||||
clip = CLIP(config=config, embedding_directory=embedding_directory)
|
||||
clip.load_from_state_dict(clip_data)
|
||||
return clip
|
||||
|
||||
def load_gligen(ckpt_path):
|
||||
data = utils.load_torch_file(ckpt_path)
|
||||
model = gligen.load_gligen(data)
|
||||
if model_management.should_use_fp16():
|
||||
model = model.half()
|
||||
return model
|
||||
|
||||
def load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=None):
|
||||
with open(config_path, 'r') as stream:
|
||||
config = yaml.safe_load(stream)
|
||||
@ -842,9 +977,9 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
||||
if output_clip:
|
||||
clip_config = {}
|
||||
if "cond_stage_model.model.transformer.resblocks.22.attn.out_proj.weight" in sd_keys:
|
||||
clip_config['target'] = 'ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder'
|
||||
clip_config['target'] = 'comfy.ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder'
|
||||
else:
|
||||
clip_config['target'] = 'ldm.modules.encoders.modules.FrozenCLIPEmbedder'
|
||||
clip_config['target'] = 'comfy.ldm.modules.encoders.modules.FrozenCLIPEmbedder'
|
||||
clip = CLIP(config=clip_config, embedding_directory=embedding_directory)
|
||||
w.cond_stage_model = clip.cond_stage_model
|
||||
load_state_dict_to = [w]
|
||||
@ -865,7 +1000,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
||||
noise_schedule_config["timesteps"] = sd[noise_aug_key].shape[0]
|
||||
noise_schedule_config["beta_schedule"] = "squaredcos_cap_v2"
|
||||
params["noise_schedule_config"] = noise_schedule_config
|
||||
noise_aug_config['target'] = "ldm.modules.encoders.noise_aug_modules.CLIPEmbeddingNoiseAugmentation"
|
||||
noise_aug_config['target'] = "comfy.ldm.modules.encoders.noise_aug_modules.CLIPEmbeddingNoiseAugmentation"
|
||||
if size == 1280: #h
|
||||
params["timestep_dim"] = 1024
|
||||
elif size == 1024: #l
|
||||
@ -917,19 +1052,19 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
||||
unet_config["in_channels"] = sd['model.diffusion_model.input_blocks.0.0.weight'].shape[1]
|
||||
unet_config["context_dim"] = sd['model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight'].shape[1]
|
||||
|
||||
sd_config["unet_config"] = {"target": "ldm.modules.diffusionmodules.openaimodel.UNetModel", "params": unet_config}
|
||||
model_config = {"target": "ldm.models.diffusion.ddpm.LatentDiffusion", "params": sd_config}
|
||||
sd_config["unet_config"] = {"target": "comfy.ldm.modules.diffusionmodules.openaimodel.UNetModel", "params": unet_config}
|
||||
model_config = {"target": "comfy.ldm.models.diffusion.ddpm.LatentDiffusion", "params": sd_config}
|
||||
|
||||
if noise_aug_config is not None: #SD2.x unclip model
|
||||
sd_config["noise_aug_config"] = noise_aug_config
|
||||
sd_config["image_size"] = 96
|
||||
sd_config["embedding_dropout"] = 0.25
|
||||
sd_config["conditioning_key"] = 'crossattn-adm'
|
||||
model_config["target"] = "ldm.models.diffusion.ddpm.ImageEmbeddingConditionedLatentDiffusion"
|
||||
model_config["target"] = "comfy.ldm.models.diffusion.ddpm.ImageEmbeddingConditionedLatentDiffusion"
|
||||
elif unet_config["in_channels"] > 4: #inpainting model
|
||||
sd_config["conditioning_key"] = "hybrid"
|
||||
sd_config["finetune_keys"] = None
|
||||
model_config["target"] = "ldm.models.diffusion.ddpm.LatentInpaintDiffusion"
|
||||
model_config["target"] = "comfy.ldm.models.diffusion.ddpm.LatentInpaintDiffusion"
|
||||
else:
|
||||
sd_config["conditioning_key"] = "crossattn"
|
||||
|
||||
|
||||
@ -2,6 +2,8 @@ import os
|
||||
|
||||
from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextConfig
|
||||
import torch
|
||||
import traceback
|
||||
import zipfile
|
||||
|
||||
class ClipTokenWeightEncoder:
|
||||
def encode_token_weights(self, token_weight_pairs):
|
||||
@ -170,10 +172,39 @@ def unescape_important(text):
|
||||
text = text.replace("\0\2", "(")
|
||||
return text
|
||||
|
||||
def safe_load_embed_zip(embed_path):
|
||||
with zipfile.ZipFile(embed_path) as myzip:
|
||||
names = list(filter(lambda a: "data/" in a, myzip.namelist()))
|
||||
names.reverse()
|
||||
for n in names:
|
||||
with myzip.open(n) as myfile:
|
||||
data = myfile.read()
|
||||
number = len(data) // 4
|
||||
length_embed = 1024 #sd2.x
|
||||
if number < 768:
|
||||
continue
|
||||
if number % 768 == 0:
|
||||
length_embed = 768 #sd1.x
|
||||
num_embeds = number // length_embed
|
||||
embed = torch.frombuffer(data, dtype=torch.float)
|
||||
out = embed.reshape((num_embeds, length_embed)).clone()
|
||||
del embed
|
||||
return out
|
||||
|
||||
def expand_directory_list(directories):
|
||||
dirs = set()
|
||||
for x in directories:
|
||||
dirs.add(x)
|
||||
for root, subdir, file in os.walk(x, followlinks=True):
|
||||
dirs.add(root)
|
||||
return list(dirs)
|
||||
|
||||
def load_embed(embedding_name, embedding_directory):
|
||||
if isinstance(embedding_directory, str):
|
||||
embedding_directory = [embedding_directory]
|
||||
|
||||
embedding_directory = expand_directory_list(embedding_directory)
|
||||
|
||||
valid_file = None
|
||||
for embed_dir in embedding_directory:
|
||||
embed_path = os.path.join(embed_dir, embedding_name)
|
||||
@ -194,19 +225,33 @@ def load_embed(embedding_name, embedding_directory):
|
||||
|
||||
embed_path = valid_file
|
||||
|
||||
if embed_path.lower().endswith(".safetensors"):
|
||||
import safetensors.torch
|
||||
embed = safetensors.torch.load_file(embed_path, device="cpu")
|
||||
else:
|
||||
if 'weights_only' in torch.load.__code__.co_varnames:
|
||||
embed = torch.load(embed_path, weights_only=True, map_location="cpu")
|
||||
embed_out = None
|
||||
|
||||
try:
|
||||
if embed_path.lower().endswith(".safetensors"):
|
||||
import safetensors.torch
|
||||
embed = safetensors.torch.load_file(embed_path, device="cpu")
|
||||
else:
|
||||
embed = torch.load(embed_path, map_location="cpu")
|
||||
if 'string_to_param' in embed:
|
||||
values = embed['string_to_param'].values()
|
||||
else:
|
||||
values = embed.values()
|
||||
return next(iter(values))
|
||||
if 'weights_only' in torch.load.__code__.co_varnames:
|
||||
try:
|
||||
embed = torch.load(embed_path, weights_only=True, map_location="cpu")
|
||||
except:
|
||||
embed_out = safe_load_embed_zip(embed_path)
|
||||
else:
|
||||
embed = torch.load(embed_path, map_location="cpu")
|
||||
except Exception as e:
|
||||
print(traceback.format_exc())
|
||||
print()
|
||||
print("error loading embedding, skipping loading:", embedding_name)
|
||||
return None
|
||||
|
||||
if embed_out is None:
|
||||
if 'string_to_param' in embed:
|
||||
values = embed['string_to_param'].values()
|
||||
else:
|
||||
values = embed.values()
|
||||
embed_out = next(iter(values))
|
||||
return embed_out
|
||||
|
||||
class SD1Tokenizer:
|
||||
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None):
|
||||
@ -224,60 +269,97 @@ class SD1Tokenizer:
|
||||
self.inv_vocab = {v: k for k, v in vocab.items()}
|
||||
self.embedding_directory = embedding_directory
|
||||
self.max_word_length = 8
|
||||
self.embedding_identifier = "embedding:"
|
||||
|
||||
def _try_get_embedding(self, embedding_name:str):
|
||||
'''
|
||||
Takes a potential embedding name and tries to retrieve it.
|
||||
Returns a Tuple consisting of the embedding and any leftover string, embedding can be None.
|
||||
'''
|
||||
embed = load_embed(embedding_name, self.embedding_directory)
|
||||
if embed is None:
|
||||
stripped = embedding_name.strip(',')
|
||||
if len(stripped) < len(embedding_name):
|
||||
embed = load_embed(stripped, self.embedding_directory)
|
||||
return (embed, embedding_name[len(stripped):])
|
||||
return (embed, "")
|
||||
|
||||
|
||||
def tokenize_with_weights(self, text:str, return_word_ids=False):
|
||||
'''
|
||||
Takes a prompt and converts it to a list of (token, weight, word id) elements.
|
||||
Tokens can both be integer tokens and pre computed CLIP tensors.
|
||||
Word id values are unique per word and embedding, where the id 0 is reserved for non word tokens.
|
||||
Returned list has the dimensions NxM where M is the input size of CLIP
|
||||
'''
|
||||
if self.pad_with_end:
|
||||
pad_token = self.end_token
|
||||
else:
|
||||
pad_token = 0
|
||||
|
||||
def tokenize_with_weights(self, text):
|
||||
text = escape_important(text)
|
||||
parsed_weights = token_weights(text, 1.0)
|
||||
|
||||
#tokenize words
|
||||
tokens = []
|
||||
for t in parsed_weights:
|
||||
to_tokenize = unescape_important(t[0]).replace("\n", " ").split(' ')
|
||||
while len(to_tokenize) > 0:
|
||||
word = to_tokenize.pop(0)
|
||||
temp_tokens = []
|
||||
embedding_identifier = "embedding:"
|
||||
if word.startswith(embedding_identifier) and self.embedding_directory is not None:
|
||||
embedding_name = word[len(embedding_identifier):].strip('\n')
|
||||
embed = load_embed(embedding_name, self.embedding_directory)
|
||||
for weighted_segment, weight in parsed_weights:
|
||||
to_tokenize = unescape_important(weighted_segment).replace("\n", " ").split(' ')
|
||||
to_tokenize = [x for x in to_tokenize if x != ""]
|
||||
for word in to_tokenize:
|
||||
#if we find an embedding, deal with the embedding
|
||||
if word.startswith(self.embedding_identifier) and self.embedding_directory is not None:
|
||||
embedding_name = word[len(self.embedding_identifier):].strip('\n')
|
||||
embed, leftover = self._try_get_embedding(embedding_name)
|
||||
if embed is None:
|
||||
stripped = embedding_name.strip(',')
|
||||
if len(stripped) < len(embedding_name):
|
||||
embed = load_embed(stripped, self.embedding_directory)
|
||||
if embed is not None:
|
||||
to_tokenize.insert(0, embedding_name[len(stripped):])
|
||||
|
||||
if embed is not None:
|
||||
if len(embed.shape) == 1:
|
||||
temp_tokens += [(embed, t[1])]
|
||||
else:
|
||||
for x in range(embed.shape[0]):
|
||||
temp_tokens += [(embed[x], t[1])]
|
||||
print(f"warning, embedding:{embedding_name} does not exist, ignoring")
|
||||
else:
|
||||
print("warning, embedding:{} does not exist, ignoring".format(embedding_name))
|
||||
elif len(word) > 0:
|
||||
tt = self.tokenizer(word)["input_ids"][1:-1]
|
||||
for x in tt:
|
||||
temp_tokens += [(x, t[1])]
|
||||
tokens_left = self.max_tokens_per_section - (len(tokens) % self.max_tokens_per_section)
|
||||
if len(embed.shape) == 1:
|
||||
tokens.append([(embed, weight)])
|
||||
else:
|
||||
tokens.append([(embed[x], weight) for x in range(embed.shape[0])])
|
||||
#if we accidentally have leftover text, continue parsing using leftover, else move on to next word
|
||||
if leftover != "":
|
||||
word = leftover
|
||||
else:
|
||||
continue
|
||||
#parse word
|
||||
tokens.append([(t, weight) for t in self.tokenizer(word)["input_ids"][1:-1]])
|
||||
|
||||
#try not to split words in different sections
|
||||
if tokens_left < len(temp_tokens) and len(temp_tokens) < (self.max_word_length):
|
||||
for x in range(tokens_left):
|
||||
tokens += [(self.end_token, 1.0)]
|
||||
tokens += temp_tokens
|
||||
#reshape token array to CLIP input size
|
||||
batched_tokens = []
|
||||
batch = [(self.start_token, 1.0, 0)]
|
||||
batched_tokens.append(batch)
|
||||
for i, t_group in enumerate(tokens):
|
||||
#determine if we're going to try and keep the tokens in a single batch
|
||||
is_large = len(t_group) >= self.max_word_length
|
||||
|
||||
out_tokens = []
|
||||
for x in range(0, len(tokens), self.max_tokens_per_section):
|
||||
o_token = [(self.start_token, 1.0)] + tokens[x:min(self.max_tokens_per_section + x, len(tokens))]
|
||||
o_token += [(self.end_token, 1.0)]
|
||||
if self.pad_with_end:
|
||||
o_token +=[(self.end_token, 1.0)] * (self.max_length - len(o_token))
|
||||
else:
|
||||
o_token +=[(0, 1.0)] * (self.max_length - len(o_token))
|
||||
while len(t_group) > 0:
|
||||
if len(t_group) + len(batch) > self.max_length - 1:
|
||||
remaining_length = self.max_length - len(batch) - 1
|
||||
#break word in two and add end token
|
||||
if is_large:
|
||||
batch.extend([(t,w,i+1) for t,w in t_group[:remaining_length]])
|
||||
batch.append((self.end_token, 1.0, 0))
|
||||
t_group = t_group[remaining_length:]
|
||||
#add end token and pad
|
||||
else:
|
||||
batch.append((self.end_token, 1.0, 0))
|
||||
batch.extend([(pad_token, 1.0, 0)] * (remaining_length))
|
||||
#start new batch
|
||||
batch = [(self.start_token, 1.0, 0)]
|
||||
batched_tokens.append(batch)
|
||||
else:
|
||||
batch.extend([(t,w,i+1) for t,w in t_group])
|
||||
t_group = []
|
||||
|
||||
out_tokens += [o_token]
|
||||
#fill last batch
|
||||
batch.extend([(self.end_token, 1.0, 0)] + [(pad_token, 1.0, 0)] * (self.max_length - len(batch) - 1))
|
||||
|
||||
if not return_word_ids:
|
||||
batched_tokens = [[(t, w) for t, w,_ in x] for x in batched_tokens]
|
||||
|
||||
return batched_tokens
|
||||
|
||||
return out_tokens
|
||||
|
||||
def untokenize(self, token_weight_pair):
|
||||
return list(map(lambda a: (a, self.inv_vocab[a[0]]), token_weight_pair))
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
import sd1_clip
|
||||
from comfy import sd1_clip
|
||||
import torch
|
||||
import os
|
||||
|
||||
|
||||
@ -56,7 +56,12 @@ class Downsample(nn.Module):
|
||||
|
||||
def forward(self, x):
|
||||
assert x.shape[1] == self.channels
|
||||
return self.op(x)
|
||||
if not self.use_conv:
|
||||
padding = [x.shape[2] % 2, x.shape[3] % 2]
|
||||
self.op.padding = padding
|
||||
|
||||
x = self.op(x)
|
||||
return x
|
||||
|
||||
|
||||
class ResnetBlock(nn.Module):
|
||||
|
||||
@ -1,11 +1,15 @@
|
||||
import torch
|
||||
import math
|
||||
|
||||
def load_torch_file(ckpt):
|
||||
def load_torch_file(ckpt, safe_load=False):
|
||||
if ckpt.lower().endswith(".safetensors"):
|
||||
import safetensors.torch
|
||||
sd = safetensors.torch.load_file(ckpt, device="cpu")
|
||||
else:
|
||||
pl_sd = torch.load(ckpt, map_location="cpu")
|
||||
if safe_load:
|
||||
pl_sd = torch.load(ckpt, map_location="cpu", weights_only=True)
|
||||
else:
|
||||
pl_sd = torch.load(ckpt, map_location="cpu")
|
||||
if "global_step" in pl_sd:
|
||||
print(f"Global Step: {pl_sd['global_step']}")
|
||||
if "state_dict" in pl_sd:
|
||||
@ -59,8 +63,11 @@ def common_upscale(samples, width, height, upscale_method, crop):
|
||||
s = samples
|
||||
return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method)
|
||||
|
||||
def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap):
|
||||
return math.ceil((height / (tile_y - overlap))) * math.ceil((width / (tile_x - overlap)))
|
||||
|
||||
@torch.inference_mode()
|
||||
def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3):
|
||||
def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, pbar = None):
|
||||
output = torch.empty((samples.shape[0], out_channels, round(samples.shape[2] * upscale_amount), round(samples.shape[3] * upscale_amount)), device="cpu")
|
||||
for b in range(samples.shape[0]):
|
||||
s = samples[b:b+1]
|
||||
@ -80,6 +87,33 @@ def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_am
|
||||
mask[:,:,:,mask.shape[3]- 1 - t: mask.shape[3]- t] *= ((1.0/feather) * (t + 1))
|
||||
out[:,:,round(y*upscale_amount):round((y+tile_y)*upscale_amount),round(x*upscale_amount):round((x+tile_x)*upscale_amount)] += ps * mask
|
||||
out_div[:,:,round(y*upscale_amount):round((y+tile_y)*upscale_amount),round(x*upscale_amount):round((x+tile_x)*upscale_amount)] += mask
|
||||
if pbar is not None:
|
||||
pbar.update(1)
|
||||
|
||||
output[b:b+1] = out/out_div
|
||||
return output
|
||||
|
||||
|
||||
PROGRESS_BAR_HOOK = None
|
||||
def set_progress_bar_global_hook(function):
|
||||
global PROGRESS_BAR_HOOK
|
||||
PROGRESS_BAR_HOOK = function
|
||||
|
||||
class ProgressBar:
|
||||
def __init__(self, total):
|
||||
global PROGRESS_BAR_HOOK
|
||||
self.total = total
|
||||
self.current = 0
|
||||
self.hook = PROGRESS_BAR_HOOK
|
||||
|
||||
def update_absolute(self, value, total=None):
|
||||
if total is not None:
|
||||
self.total = total
|
||||
if value > self.total:
|
||||
value = self.total
|
||||
self.current = value
|
||||
if self.hook is not None:
|
||||
self.hook(self.current, self.total)
|
||||
|
||||
def update(self, value):
|
||||
self.update_absolute(self.current + value)
|
||||
|
||||
@ -4,7 +4,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import OrderedDict
|
||||
from typing import Literal
|
||||
try:
|
||||
from typing import Literal
|
||||
except ImportError:
|
||||
from typing_extensions import Literal
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
110
comfy_extras/nodes_hypernetwork.py
Normal file
110
comfy_extras/nodes_hypernetwork.py
Normal file
@ -0,0 +1,110 @@
|
||||
import comfy.utils
|
||||
import folder_paths
|
||||
import torch
|
||||
|
||||
def load_hypernetwork_patch(path, strength):
|
||||
sd = comfy.utils.load_torch_file(path, safe_load=True)
|
||||
activation_func = sd.get('activation_func', 'linear')
|
||||
is_layer_norm = sd.get('is_layer_norm', False)
|
||||
use_dropout = sd.get('use_dropout', False)
|
||||
activate_output = sd.get('activate_output', False)
|
||||
last_layer_dropout = sd.get('last_layer_dropout', False)
|
||||
|
||||
valid_activation = {
|
||||
"linear": torch.nn.Identity,
|
||||
"relu": torch.nn.ReLU,
|
||||
"leakyrelu": torch.nn.LeakyReLU,
|
||||
"elu": torch.nn.ELU,
|
||||
"swish": torch.nn.Hardswish,
|
||||
"tanh": torch.nn.Tanh,
|
||||
"sigmoid": torch.nn.Sigmoid,
|
||||
"softsign": torch.nn.Softsign,
|
||||
}
|
||||
|
||||
if activation_func not in valid_activation:
|
||||
print("Unsupported Hypernetwork format, if you report it I might implement it.", path, " ", activation_func, is_layer_norm, use_dropout, activate_output, last_layer_dropout)
|
||||
return None
|
||||
|
||||
out = {}
|
||||
|
||||
for d in sd:
|
||||
try:
|
||||
dim = int(d)
|
||||
except:
|
||||
continue
|
||||
|
||||
output = []
|
||||
for index in [0, 1]:
|
||||
attn_weights = sd[dim][index]
|
||||
keys = attn_weights.keys()
|
||||
|
||||
linears = filter(lambda a: a.endswith(".weight"), keys)
|
||||
linears = list(map(lambda a: a[:-len(".weight")], linears))
|
||||
layers = []
|
||||
|
||||
for i in range(len(linears)):
|
||||
lin_name = linears[i]
|
||||
last_layer = (i == (len(linears) - 1))
|
||||
penultimate_layer = (i == (len(linears) - 2))
|
||||
|
||||
lin_weight = attn_weights['{}.weight'.format(lin_name)]
|
||||
lin_bias = attn_weights['{}.bias'.format(lin_name)]
|
||||
layer = torch.nn.Linear(lin_weight.shape[1], lin_weight.shape[0])
|
||||
layer.load_state_dict({"weight": lin_weight, "bias": lin_bias})
|
||||
layers.append(layer)
|
||||
if activation_func != "linear":
|
||||
if (not last_layer) or (activate_output):
|
||||
layers.append(valid_activation[activation_func]())
|
||||
if is_layer_norm:
|
||||
layers.append(torch.nn.LayerNorm(lin_weight.shape[0]))
|
||||
if use_dropout:
|
||||
if (not last_layer) and (not penultimate_layer or last_layer_dropout):
|
||||
layers.append(torch.nn.Dropout(p=0.3))
|
||||
|
||||
output.append(torch.nn.Sequential(*layers))
|
||||
out[dim] = torch.nn.ModuleList(output)
|
||||
|
||||
class hypernetwork_patch:
|
||||
def __init__(self, hypernet, strength):
|
||||
self.hypernet = hypernet
|
||||
self.strength = strength
|
||||
def __call__(self, current_index, q, k, v):
|
||||
dim = k.shape[-1]
|
||||
if dim in self.hypernet:
|
||||
hn = self.hypernet[dim]
|
||||
k = k + hn[0](k) * self.strength
|
||||
v = v + hn[1](v) * self.strength
|
||||
|
||||
return q, k, v
|
||||
|
||||
def to(self, device):
|
||||
for d in self.hypernet.keys():
|
||||
self.hypernet[d] = self.hypernet[d].to(device)
|
||||
return self
|
||||
|
||||
return hypernetwork_patch(out, strength)
|
||||
|
||||
class HypernetworkLoader:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "model": ("MODEL",),
|
||||
"hypernetwork_name": (folder_paths.get_filename_list("hypernetworks"), ),
|
||||
"strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
|
||||
}}
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
FUNCTION = "load_hypernetwork"
|
||||
|
||||
CATEGORY = "loaders"
|
||||
|
||||
def load_hypernetwork(self, model, hypernetwork_name, strength):
|
||||
hypernetwork_path = folder_paths.get_full_path("hypernetworks", hypernetwork_name)
|
||||
model_hypernetwork = model.clone()
|
||||
patch = load_hypernetwork_patch(hypernetwork_path, strength)
|
||||
if patch is not None:
|
||||
model_hypernetwork.set_model_attn1_patch(patch)
|
||||
model_hypernetwork.set_model_attn2_patch(patch)
|
||||
return (model_hypernetwork,)
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"HypernetworkLoader": HypernetworkLoader
|
||||
}
|
||||
262
comfy_extras/nodes_mask.py
Normal file
262
comfy_extras/nodes_mask.py
Normal file
@ -0,0 +1,262 @@
|
||||
import torch
|
||||
|
||||
from nodes import MAX_RESOLUTION
|
||||
|
||||
class LatentCompositeMasked:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"destination": ("LATENT",),
|
||||
"source": ("LATENT",),
|
||||
"x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
|
||||
"y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
|
||||
},
|
||||
"optional": {
|
||||
"mask": ("MASK",),
|
||||
}
|
||||
}
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
FUNCTION = "composite"
|
||||
|
||||
CATEGORY = "latent"
|
||||
|
||||
def composite(self, destination, source, x, y, mask = None):
|
||||
output = destination.copy()
|
||||
destination = destination["samples"].clone()
|
||||
source = source["samples"]
|
||||
|
||||
x = max(-source.shape[3] * 8, min(x, destination.shape[3] * 8))
|
||||
y = max(-source.shape[2] * 8, min(y, destination.shape[2] * 8))
|
||||
|
||||
left, top = (x // 8, y // 8)
|
||||
right, bottom = (left + source.shape[3], top + source.shape[2],)
|
||||
|
||||
|
||||
if mask is None:
|
||||
mask = torch.ones_like(source)
|
||||
else:
|
||||
mask = mask.clone()
|
||||
mask = torch.nn.functional.interpolate(mask[None, None], size=(source.shape[2], source.shape[3]), mode="bilinear")
|
||||
mask = mask.repeat((source.shape[0], source.shape[1], 1, 1))
|
||||
|
||||
# calculate the bounds of the source that will be overlapping the destination
|
||||
# this prevents the source trying to overwrite latent pixels that are out of bounds
|
||||
# of the destination
|
||||
visible_width, visible_height = (destination.shape[3] - left + min(0, x), destination.shape[2] - top + min(0, y),)
|
||||
|
||||
mask = mask[:, :, :visible_height, :visible_width]
|
||||
inverse_mask = torch.ones_like(mask) - mask
|
||||
|
||||
source_portion = mask * source[:, :, :visible_height, :visible_width]
|
||||
destination_portion = inverse_mask * destination[:, :, top:bottom, left:right]
|
||||
|
||||
destination[:, :, top:bottom, left:right] = source_portion + destination_portion
|
||||
|
||||
output["samples"] = destination
|
||||
|
||||
return (output,)
|
||||
|
||||
class MaskToImage:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"mask": ("MASK",),
|
||||
}
|
||||
}
|
||||
|
||||
CATEGORY = "mask"
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "mask_to_image"
|
||||
|
||||
def mask_to_image(self, mask):
|
||||
result = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3)
|
||||
return (result,)
|
||||
|
||||
class ImageToMask:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"image": ("IMAGE",),
|
||||
"channel": (["red", "green", "blue"],),
|
||||
}
|
||||
}
|
||||
|
||||
CATEGORY = "mask"
|
||||
|
||||
RETURN_TYPES = ("MASK",)
|
||||
FUNCTION = "image_to_mask"
|
||||
|
||||
def image_to_mask(self, image, channel):
|
||||
channels = ["red", "green", "blue"]
|
||||
mask = image[0, :, :, channels.index(channel)]
|
||||
return (mask,)
|
||||
|
||||
class SolidMask:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"value": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
|
||||
"width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
|
||||
"height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
|
||||
}
|
||||
}
|
||||
|
||||
CATEGORY = "mask"
|
||||
|
||||
RETURN_TYPES = ("MASK",)
|
||||
|
||||
FUNCTION = "solid"
|
||||
|
||||
def solid(self, value, width, height):
|
||||
out = torch.full((height, width), value, dtype=torch.float32, device="cpu")
|
||||
return (out,)
|
||||
|
||||
class InvertMask:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"mask": ("MASK",),
|
||||
}
|
||||
}
|
||||
|
||||
CATEGORY = "mask"
|
||||
|
||||
RETURN_TYPES = ("MASK",)
|
||||
|
||||
FUNCTION = "invert"
|
||||
|
||||
def invert(self, mask):
|
||||
out = 1.0 - mask
|
||||
return (out,)
|
||||
|
||||
class CropMask:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"mask": ("MASK",),
|
||||
"x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
|
||||
"y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
|
||||
"width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
|
||||
"height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
|
||||
}
|
||||
}
|
||||
|
||||
CATEGORY = "mask"
|
||||
|
||||
RETURN_TYPES = ("MASK",)
|
||||
|
||||
FUNCTION = "crop"
|
||||
|
||||
def crop(self, mask, x, y, width, height):
|
||||
out = mask[y:y + height, x:x + width]
|
||||
return (out,)
|
||||
|
||||
class MaskComposite:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"destination": ("MASK",),
|
||||
"source": ("MASK",),
|
||||
"x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
|
||||
"y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
|
||||
"operation": (["multiply", "add", "subtract"],),
|
||||
}
|
||||
}
|
||||
|
||||
CATEGORY = "mask"
|
||||
|
||||
RETURN_TYPES = ("MASK",)
|
||||
|
||||
FUNCTION = "combine"
|
||||
|
||||
def combine(self, destination, source, x, y, operation):
|
||||
output = destination.clone()
|
||||
|
||||
left, top = (x, y,)
|
||||
right, bottom = (min(left + source.shape[1], destination.shape[1]), min(top + source.shape[0], destination.shape[0]))
|
||||
visible_width, visible_height = (right - left, bottom - top,)
|
||||
|
||||
source_portion = source[:visible_height, :visible_width]
|
||||
destination_portion = destination[top:bottom, left:right]
|
||||
|
||||
if operation == "multiply":
|
||||
output[top:bottom, left:right] = destination_portion * source_portion
|
||||
elif operation == "add":
|
||||
output[top:bottom, left:right] = destination_portion + source_portion
|
||||
elif operation == "subtract":
|
||||
output[top:bottom, left:right] = destination_portion - source_portion
|
||||
|
||||
output = torch.clamp(output, 0.0, 1.0)
|
||||
|
||||
return (output,)
|
||||
|
||||
class FeatherMask:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"mask": ("MASK",),
|
||||
"left": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
|
||||
"top": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
|
||||
"right": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
|
||||
"bottom": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
|
||||
}
|
||||
}
|
||||
|
||||
CATEGORY = "mask"
|
||||
|
||||
RETURN_TYPES = ("MASK",)
|
||||
|
||||
FUNCTION = "feather"
|
||||
|
||||
def feather(self, mask, left, top, right, bottom):
|
||||
output = mask.clone()
|
||||
|
||||
left = min(left, output.shape[1])
|
||||
right = min(right, output.shape[1])
|
||||
top = min(top, output.shape[0])
|
||||
bottom = min(bottom, output.shape[0])
|
||||
|
||||
for x in range(left):
|
||||
feather_rate = (x + 1.0) / left
|
||||
output[:, x] *= feather_rate
|
||||
|
||||
for x in range(right):
|
||||
feather_rate = (x + 1) / right
|
||||
output[:, -x] *= feather_rate
|
||||
|
||||
for y in range(top):
|
||||
feather_rate = (y + 1) / top
|
||||
output[y, :] *= feather_rate
|
||||
|
||||
for y in range(bottom):
|
||||
feather_rate = (y + 1) / bottom
|
||||
output[-y, :] *= feather_rate
|
||||
|
||||
return (output,)
|
||||
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"LatentCompositeMasked": LatentCompositeMasked,
|
||||
"MaskToImage": MaskToImage,
|
||||
"ImageToMask": ImageToMask,
|
||||
"SolidMask": SolidMask,
|
||||
"InvertMask": InvertMask,
|
||||
"CropMask": CropMask,
|
||||
"MaskComposite": MaskComposite,
|
||||
"FeatherMask": FeatherMask,
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"ImageToMask": "Convert Image to Mask",
|
||||
"MaskToImage": "Convert Mask to Image",
|
||||
}
|
||||
108
comfy_extras/nodes_rebatch.py
Normal file
108
comfy_extras/nodes_rebatch.py
Normal file
@ -0,0 +1,108 @@
|
||||
import torch
|
||||
|
||||
class LatentRebatch:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "latents": ("LATENT",),
|
||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 64}),
|
||||
}}
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
INPUT_IS_LIST = True
|
||||
OUTPUT_IS_LIST = (True, )
|
||||
|
||||
FUNCTION = "rebatch"
|
||||
|
||||
CATEGORY = "latent/batch"
|
||||
|
||||
@staticmethod
|
||||
def get_batch(latents, list_ind, offset):
|
||||
'''prepare a batch out of the list of latents'''
|
||||
samples = latents[list_ind]['samples']
|
||||
shape = samples.shape
|
||||
mask = latents[list_ind]['noise_mask'] if 'noise_mask' in latents[list_ind] else torch.ones((shape[0], 1, shape[2]*8, shape[3]*8), device='cpu')
|
||||
if mask.shape[-1] != shape[-1] * 8 or mask.shape[-2] != shape[-2]:
|
||||
torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(shape[-2]*8, shape[-1]*8), mode="bilinear")
|
||||
if mask.shape[0] < samples.shape[0]:
|
||||
mask = mask.repeat((shape[0] - 1) // mask.shape[0] + 1, 1, 1, 1)[:shape[0]]
|
||||
if 'batch_index' in latents[list_ind]:
|
||||
batch_inds = latents[list_ind]['batch_index']
|
||||
else:
|
||||
batch_inds = [x+offset for x in range(shape[0])]
|
||||
return samples, mask, batch_inds
|
||||
|
||||
@staticmethod
|
||||
def get_slices(indexable, num, batch_size):
|
||||
'''divides an indexable object into num slices of length batch_size, and a remainder'''
|
||||
slices = []
|
||||
for i in range(num):
|
||||
slices.append(indexable[i*batch_size:(i+1)*batch_size])
|
||||
if num * batch_size < len(indexable):
|
||||
return slices, indexable[num * batch_size:]
|
||||
else:
|
||||
return slices, None
|
||||
|
||||
@staticmethod
|
||||
def slice_batch(batch, num, batch_size):
|
||||
result = [LatentRebatch.get_slices(x, num, batch_size) for x in batch]
|
||||
return list(zip(*result))
|
||||
|
||||
@staticmethod
|
||||
def cat_batch(batch1, batch2):
|
||||
if batch1[0] is None:
|
||||
return batch2
|
||||
result = [torch.cat((b1, b2)) if torch.is_tensor(b1) else b1 + b2 for b1, b2 in zip(batch1, batch2)]
|
||||
return result
|
||||
|
||||
def rebatch(self, latents, batch_size):
|
||||
batch_size = batch_size[0]
|
||||
|
||||
output_list = []
|
||||
current_batch = (None, None, None)
|
||||
processed = 0
|
||||
|
||||
for i in range(len(latents)):
|
||||
# fetch new entry of list
|
||||
#samples, masks, indices = self.get_batch(latents, i)
|
||||
next_batch = self.get_batch(latents, i, processed)
|
||||
processed += len(next_batch[2])
|
||||
# set to current if current is None
|
||||
if current_batch[0] is None:
|
||||
current_batch = next_batch
|
||||
# add previous to list if dimensions do not match
|
||||
elif next_batch[0].shape[-1] != current_batch[0].shape[-1] or next_batch[0].shape[-2] != current_batch[0].shape[-2]:
|
||||
sliced, _ = self.slice_batch(current_batch, 1, batch_size)
|
||||
output_list.append({'samples': sliced[0][0], 'noise_mask': sliced[1][0], 'batch_index': sliced[2][0]})
|
||||
current_batch = next_batch
|
||||
# cat if everything checks out
|
||||
else:
|
||||
current_batch = self.cat_batch(current_batch, next_batch)
|
||||
|
||||
# add to list if dimensions gone above target batch size
|
||||
if current_batch[0].shape[0] > batch_size:
|
||||
num = current_batch[0].shape[0] // batch_size
|
||||
sliced, remainder = self.slice_batch(current_batch, num, batch_size)
|
||||
|
||||
for i in range(num):
|
||||
output_list.append({'samples': sliced[0][i], 'noise_mask': sliced[1][i], 'batch_index': sliced[2][i]})
|
||||
|
||||
current_batch = remainder
|
||||
|
||||
#add remainder
|
||||
if current_batch[0] is not None:
|
||||
sliced, _ = self.slice_batch(current_batch, 1, batch_size)
|
||||
output_list.append({'samples': sliced[0][0], 'noise_mask': sliced[1][0], 'batch_index': sliced[2][0]})
|
||||
|
||||
#get rid of empty masks
|
||||
for s in output_list:
|
||||
if s['noise_mask'].mean() == 1.0:
|
||||
del s['noise_mask']
|
||||
|
||||
return (output_list,)
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"RebatchLatents": LatentRebatch,
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"RebatchLatents": "Rebatch Latents",
|
||||
}
|
||||
@ -1,6 +1,6 @@
|
||||
import os
|
||||
from comfy_extras.chainner_models import model_loading
|
||||
import model_management
|
||||
from comfy import model_management
|
||||
import torch
|
||||
import comfy.utils
|
||||
import folder_paths
|
||||
@ -17,7 +17,7 @@ class UpscaleModelLoader:
|
||||
|
||||
def load_model(self, model_name):
|
||||
model_path = folder_paths.get_full_path("upscale_models", model_name)
|
||||
sd = comfy.utils.load_torch_file(model_path)
|
||||
sd = comfy.utils.load_torch_file(model_path, safe_load=True)
|
||||
out = model_loading.load_state_dict(sd).eval()
|
||||
return (out, )
|
||||
|
||||
@ -37,7 +37,12 @@ class ImageUpscaleWithModel:
|
||||
device = model_management.get_torch_device()
|
||||
upscale_model.to(device)
|
||||
in_img = image.movedim(-1,-3).to(device)
|
||||
s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=128 + 64, tile_y=128 + 64, overlap = 8, upscale_amount=upscale_model.scale)
|
||||
|
||||
tile = 128 + 64
|
||||
overlap = 8
|
||||
steps = in_img.shape[0] * comfy.utils.get_tiled_scale_steps(in_img.shape[3], in_img.shape[2], tile_x=tile, tile_y=tile, overlap=overlap)
|
||||
pbar = comfy.utils.ProgressBar(steps)
|
||||
s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar)
|
||||
upscale_model.cpu()
|
||||
s = torch.clamp(s.movedim(-3,-1), min=0, max=1.0)
|
||||
return (s,)
|
||||
|
||||
248
execution.py
248
execution.py
@ -6,10 +6,13 @@ import threading
|
||||
import heapq
|
||||
import traceback
|
||||
import gc
|
||||
import time
|
||||
|
||||
import torch
|
||||
import nodes
|
||||
|
||||
import comfy.model_management
|
||||
|
||||
def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_data={}):
|
||||
valid_inputs = class_def.INPUT_TYPES()
|
||||
input_data_all = {}
|
||||
@ -24,29 +27,88 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da
|
||||
input_data_all[x] = obj
|
||||
else:
|
||||
if ("required" in valid_inputs and x in valid_inputs["required"]) or ("optional" in valid_inputs and x in valid_inputs["optional"]):
|
||||
input_data_all[x] = input_data
|
||||
input_data_all[x] = [input_data]
|
||||
|
||||
if "hidden" in valid_inputs:
|
||||
h = valid_inputs["hidden"]
|
||||
for x in h:
|
||||
if h[x] == "PROMPT":
|
||||
input_data_all[x] = prompt
|
||||
input_data_all[x] = [prompt]
|
||||
if h[x] == "EXTRA_PNGINFO":
|
||||
if "extra_pnginfo" in extra_data:
|
||||
input_data_all[x] = extra_data['extra_pnginfo']
|
||||
input_data_all[x] = [extra_data['extra_pnginfo']]
|
||||
if h[x] == "UNIQUE_ID":
|
||||
input_data_all[x] = unique_id
|
||||
input_data_all[x] = [unique_id]
|
||||
return input_data_all
|
||||
|
||||
def recursive_execute(server, prompt, outputs, current_item, extra_data={}):
|
||||
def map_node_over_list(obj, input_data_all, func, allow_interrupt=False):
|
||||
# check if node wants the lists
|
||||
intput_is_list = False
|
||||
if hasattr(obj, "INPUT_IS_LIST"):
|
||||
intput_is_list = obj.INPUT_IS_LIST
|
||||
|
||||
max_len_input = max([len(x) for x in input_data_all.values()])
|
||||
|
||||
# get a slice of inputs, repeat last input when list isn't long enough
|
||||
def slice_dict(d, i):
|
||||
d_new = dict()
|
||||
for k,v in d.items():
|
||||
d_new[k] = v[i if len(v) > i else -1]
|
||||
return d_new
|
||||
|
||||
results = []
|
||||
if intput_is_list:
|
||||
if allow_interrupt:
|
||||
nodes.before_node_execution()
|
||||
results.append(getattr(obj, func)(**input_data_all))
|
||||
else:
|
||||
for i in range(max_len_input):
|
||||
if allow_interrupt:
|
||||
nodes.before_node_execution()
|
||||
results.append(getattr(obj, func)(**slice_dict(input_data_all, i)))
|
||||
return results
|
||||
|
||||
def get_output_data(obj, input_data_all):
|
||||
|
||||
results = []
|
||||
uis = []
|
||||
return_values = map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True)
|
||||
|
||||
for r in return_values:
|
||||
if isinstance(r, dict):
|
||||
if 'ui' in r:
|
||||
uis.append(r['ui'])
|
||||
if 'result' in r:
|
||||
results.append(r['result'])
|
||||
else:
|
||||
results.append(r)
|
||||
|
||||
output = []
|
||||
if len(results) > 0:
|
||||
# check which outputs need concatenating
|
||||
output_is_list = [False] * len(results[0])
|
||||
if hasattr(obj, "OUTPUT_IS_LIST"):
|
||||
output_is_list = obj.OUTPUT_IS_LIST
|
||||
|
||||
# merge node execution results
|
||||
for i, is_list in zip(range(len(results[0])), output_is_list):
|
||||
if is_list:
|
||||
output.append([x for o in results for x in o[i]])
|
||||
else:
|
||||
output.append([o[i] for o in results])
|
||||
|
||||
ui = dict()
|
||||
if len(uis) > 0:
|
||||
ui = {k: [y for x in uis for y in x[k]] for k in uis[0].keys()}
|
||||
return output, ui
|
||||
|
||||
def recursive_execute(server, prompt, outputs, current_item, extra_data, executed, prompt_id, outputs_ui):
|
||||
unique_id = current_item
|
||||
inputs = prompt[unique_id]['inputs']
|
||||
class_type = prompt[unique_id]['class_type']
|
||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||
if unique_id in outputs:
|
||||
return []
|
||||
|
||||
executed = []
|
||||
return
|
||||
|
||||
for x in inputs:
|
||||
input_data = inputs[x]
|
||||
@ -55,22 +117,21 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data={}):
|
||||
input_unique_id = input_data[0]
|
||||
output_index = input_data[1]
|
||||
if input_unique_id not in outputs:
|
||||
executed += recursive_execute(server, prompt, outputs, input_unique_id, extra_data)
|
||||
recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed, prompt_id, outputs_ui)
|
||||
|
||||
input_data_all = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data)
|
||||
if server.client_id is not None:
|
||||
server.last_node_id = unique_id
|
||||
server.send_sync("executing", { "node": unique_id }, server.client_id)
|
||||
server.send_sync("executing", { "node": unique_id, "prompt_id": prompt_id }, server.client_id)
|
||||
obj = class_def()
|
||||
|
||||
nodes.before_node_execution()
|
||||
outputs[unique_id] = getattr(obj, obj.FUNCTION)(**input_data_all)
|
||||
if "ui" in outputs[unique_id]:
|
||||
output_data, output_ui = get_output_data(obj, input_data_all)
|
||||
outputs[unique_id] = output_data
|
||||
if len(output_ui) > 0:
|
||||
outputs_ui[unique_id] = output_ui
|
||||
if server.client_id is not None:
|
||||
server.send_sync("executed", { "node": unique_id, "output": outputs[unique_id]["ui"] }, server.client_id)
|
||||
if "result" in outputs[unique_id]:
|
||||
outputs[unique_id] = outputs[unique_id]["result"]
|
||||
return executed + [unique_id]
|
||||
server.send_sync("executed", { "node": unique_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id)
|
||||
executed.add(unique_id)
|
||||
|
||||
def recursive_will_execute(prompt, outputs, current_item):
|
||||
unique_id = current_item
|
||||
@ -97,40 +158,45 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item
|
||||
|
||||
is_changed_old = ''
|
||||
is_changed = ''
|
||||
to_delete = False
|
||||
if hasattr(class_def, 'IS_CHANGED'):
|
||||
if unique_id in old_prompt and 'is_changed' in old_prompt[unique_id]:
|
||||
is_changed_old = old_prompt[unique_id]['is_changed']
|
||||
if 'is_changed' not in prompt[unique_id]:
|
||||
input_data_all = get_input_data(inputs, class_def, unique_id, outputs)
|
||||
if input_data_all is not None:
|
||||
is_changed = class_def.IS_CHANGED(**input_data_all)
|
||||
prompt[unique_id]['is_changed'] = is_changed
|
||||
try:
|
||||
#is_changed = class_def.IS_CHANGED(**input_data_all)
|
||||
is_changed = map_node_over_list(class_def, input_data_all, "IS_CHANGED")
|
||||
prompt[unique_id]['is_changed'] = is_changed
|
||||
except:
|
||||
to_delete = True
|
||||
else:
|
||||
is_changed = prompt[unique_id]['is_changed']
|
||||
|
||||
if unique_id not in outputs:
|
||||
return True
|
||||
|
||||
to_delete = False
|
||||
if is_changed != is_changed_old:
|
||||
to_delete = True
|
||||
elif unique_id not in old_prompt:
|
||||
to_delete = True
|
||||
elif inputs == old_prompt[unique_id]['inputs']:
|
||||
for x in inputs:
|
||||
input_data = inputs[x]
|
||||
if not to_delete:
|
||||
if is_changed != is_changed_old:
|
||||
to_delete = True
|
||||
elif unique_id not in old_prompt:
|
||||
to_delete = True
|
||||
elif inputs == old_prompt[unique_id]['inputs']:
|
||||
for x in inputs:
|
||||
input_data = inputs[x]
|
||||
|
||||
if isinstance(input_data, list):
|
||||
input_unique_id = input_data[0]
|
||||
output_index = input_data[1]
|
||||
if input_unique_id in outputs:
|
||||
to_delete = recursive_output_delete_if_changed(prompt, old_prompt, outputs, input_unique_id)
|
||||
else:
|
||||
to_delete = True
|
||||
if to_delete:
|
||||
break
|
||||
else:
|
||||
to_delete = True
|
||||
if isinstance(input_data, list):
|
||||
input_unique_id = input_data[0]
|
||||
output_index = input_data[1]
|
||||
if input_unique_id in outputs:
|
||||
to_delete = recursive_output_delete_if_changed(prompt, old_prompt, outputs, input_unique_id)
|
||||
else:
|
||||
to_delete = True
|
||||
if to_delete:
|
||||
break
|
||||
else:
|
||||
to_delete = True
|
||||
|
||||
if to_delete:
|
||||
d = outputs.pop(unique_id)
|
||||
@ -140,10 +206,11 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item
|
||||
class PromptExecutor:
|
||||
def __init__(self, server):
|
||||
self.outputs = {}
|
||||
self.outputs_ui = {}
|
||||
self.old_prompt = {}
|
||||
self.server = server
|
||||
|
||||
def execute(self, prompt, extra_data={}):
|
||||
def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
|
||||
nodes.interrupt_processing(False)
|
||||
|
||||
if "client_id" in extra_data:
|
||||
@ -151,40 +218,55 @@ class PromptExecutor:
|
||||
else:
|
||||
self.server.client_id = None
|
||||
|
||||
execution_start_time = time.perf_counter()
|
||||
if self.server.client_id is not None:
|
||||
self.server.send_sync("execution_start", { "prompt_id": prompt_id}, self.server.client_id)
|
||||
|
||||
with torch.inference_mode():
|
||||
#delete cached outputs if nodes don't exist for them
|
||||
to_delete = []
|
||||
for o in self.outputs:
|
||||
if o not in prompt:
|
||||
to_delete += [o]
|
||||
for o in to_delete:
|
||||
d = self.outputs.pop(o)
|
||||
del d
|
||||
|
||||
for x in prompt:
|
||||
recursive_output_delete_if_changed(prompt, self.old_prompt, self.outputs, x)
|
||||
|
||||
current_outputs = set(self.outputs.keys())
|
||||
executed = []
|
||||
for x in list(self.outputs_ui.keys()):
|
||||
if x not in current_outputs:
|
||||
d = self.outputs_ui.pop(x)
|
||||
del d
|
||||
|
||||
if self.server.client_id is not None:
|
||||
self.server.send_sync("execution_cached", { "nodes": list(current_outputs) , "prompt_id": prompt_id}, self.server.client_id)
|
||||
executed = set()
|
||||
try:
|
||||
to_execute = []
|
||||
for x in prompt:
|
||||
class_ = nodes.NODE_CLASS_MAPPINGS[prompt[x]['class_type']]
|
||||
if hasattr(class_, 'OUTPUT_NODE'):
|
||||
to_execute += [(0, x)]
|
||||
for x in list(execute_outputs):
|
||||
to_execute += [(0, x)]
|
||||
|
||||
while len(to_execute) > 0:
|
||||
#always execute the output that depends on the least amount of unexecuted nodes first
|
||||
to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1])), a[-1]), to_execute)))
|
||||
x = to_execute.pop(0)[-1]
|
||||
|
||||
class_ = nodes.NODE_CLASS_MAPPINGS[prompt[x]['class_type']]
|
||||
if hasattr(class_, 'OUTPUT_NODE'):
|
||||
if class_.OUTPUT_NODE == True:
|
||||
valid = False
|
||||
try:
|
||||
m = validate_inputs(prompt, x)
|
||||
valid = m[0]
|
||||
except:
|
||||
valid = False
|
||||
if valid:
|
||||
executed += recursive_execute(self.server, prompt, self.outputs, x, extra_data)
|
||||
recursive_execute(self.server, prompt, self.outputs, x, extra_data, executed, prompt_id, self.outputs_ui)
|
||||
except Exception as e:
|
||||
print(traceback.format_exc())
|
||||
if isinstance(e, comfy.model_management.InterruptProcessingException):
|
||||
print("Processing interrupted")
|
||||
else:
|
||||
message = str(traceback.format_exc())
|
||||
print(message)
|
||||
if self.server.client_id is not None:
|
||||
self.server.send_sync("execution_error", { "message": message, "prompt_id": prompt_id }, self.server.client_id)
|
||||
|
||||
to_delete = []
|
||||
for o in self.outputs:
|
||||
if o not in current_outputs:
|
||||
if (o not in current_outputs) and (o not in executed):
|
||||
to_delete += [o]
|
||||
if o in self.old_prompt:
|
||||
d = self.old_prompt.pop(o)
|
||||
@ -192,24 +274,23 @@ class PromptExecutor:
|
||||
for o in to_delete:
|
||||
d = self.outputs.pop(o)
|
||||
del d
|
||||
else:
|
||||
executed = set(executed)
|
||||
finally:
|
||||
for x in executed:
|
||||
self.old_prompt[x] = copy.deepcopy(prompt[x])
|
||||
finally:
|
||||
self.server.last_node_id = None
|
||||
if self.server.client_id is not None:
|
||||
self.server.send_sync("executing", { "node": None }, self.server.client_id)
|
||||
self.server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, self.server.client_id)
|
||||
|
||||
print("Prompt executed in {:.2f} seconds".format(time.perf_counter() - execution_start_time))
|
||||
gc.collect()
|
||||
if torch.cuda.is_available():
|
||||
if torch.version.cuda: #This seems to make things worse on ROCm so I only do it for cuda
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.ipc_collect()
|
||||
comfy.model_management.soft_empty_cache()
|
||||
|
||||
|
||||
def validate_inputs(prompt, item):
|
||||
def validate_inputs(prompt, item, validated):
|
||||
unique_id = item
|
||||
if unique_id in validated:
|
||||
return validated[unique_id]
|
||||
|
||||
inputs = prompt[unique_id]['inputs']
|
||||
class_type = prompt[unique_id]['class_type']
|
||||
obj_class = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||
@ -230,8 +311,9 @@ def validate_inputs(prompt, item):
|
||||
r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES
|
||||
if r[val[1]] != type_input:
|
||||
return (False, "Return type mismatch. {}, {}, {} != {}".format(class_type, x, r[val[1]], type_input))
|
||||
r = validate_inputs(prompt, o_id)
|
||||
r = validate_inputs(prompt, o_id, validated)
|
||||
if r[0] == False:
|
||||
validated[o_id] = r
|
||||
return r
|
||||
else:
|
||||
if type_input == "INT":
|
||||
@ -250,10 +332,21 @@ def validate_inputs(prompt, item):
|
||||
if "max" in info[1] and val > info[1]["max"]:
|
||||
return (False, "Value bigger than max. {}, {}".format(class_type, x))
|
||||
|
||||
if isinstance(type_input, list):
|
||||
if val not in type_input:
|
||||
return (False, "Value not in list. {}, {}: {} not in {}".format(class_type, x, val, type_input))
|
||||
return (True, "")
|
||||
if hasattr(obj_class, "VALIDATE_INPUTS"):
|
||||
input_data_all = get_input_data(inputs, obj_class, unique_id)
|
||||
#ret = obj_class.VALIDATE_INPUTS(**input_data_all)
|
||||
ret = map_node_over_list(obj_class, input_data_all, "VALIDATE_INPUTS")
|
||||
for r in ret:
|
||||
if r != True:
|
||||
return (False, "{}, {}".format(class_type, r))
|
||||
else:
|
||||
if isinstance(type_input, list):
|
||||
if val not in type_input:
|
||||
return (False, "Value not in list. {}, {}: {} not in {}".format(class_type, x, val, type_input))
|
||||
|
||||
ret = (True, "")
|
||||
validated[unique_id] = ret
|
||||
return ret
|
||||
|
||||
def validate_prompt(prompt):
|
||||
outputs = set()
|
||||
@ -267,19 +360,21 @@ def validate_prompt(prompt):
|
||||
|
||||
good_outputs = set()
|
||||
errors = []
|
||||
validated = {}
|
||||
for o in outputs:
|
||||
valid = False
|
||||
reason = ""
|
||||
try:
|
||||
m = validate_inputs(prompt, o)
|
||||
m = validate_inputs(prompt, o, validated)
|
||||
valid = m[0]
|
||||
reason = m[1]
|
||||
except:
|
||||
except Exception as e:
|
||||
print(traceback.format_exc())
|
||||
valid = False
|
||||
reason = "Parsing error"
|
||||
|
||||
if valid == True:
|
||||
good_outputs.add(x)
|
||||
good_outputs.add(o)
|
||||
else:
|
||||
print("Failed to validate prompt for output {} {}".format(o, reason))
|
||||
print("output will be ignored")
|
||||
@ -289,7 +384,7 @@ def validate_prompt(prompt):
|
||||
errors_list = "\n".join(set(map(lambda a: "{}".format(a[1]), errors)))
|
||||
return (False, "Prompt has no properly connected outputs\n {}".format(errors_list))
|
||||
|
||||
return (True, "")
|
||||
return (True, "", list(good_outputs))
|
||||
|
||||
|
||||
class PromptQueue:
|
||||
@ -325,8 +420,7 @@ class PromptQueue:
|
||||
prompt = self.currently_running.pop(item_id)
|
||||
self.history[prompt[1]] = { "prompt": prompt, "outputs": {} }
|
||||
for o in outputs:
|
||||
if "ui" in outputs[o]:
|
||||
self.history[prompt[1]]["outputs"][o] = outputs[o]["ui"]
|
||||
self.history[prompt[1]]["outputs"][o] = outputs[o]
|
||||
self.server.queue_updated()
|
||||
|
||||
def get_current_queue(self):
|
||||
|
||||
@ -13,11 +13,13 @@ a111:
|
||||
models/ESRGAN
|
||||
models/SwinIR
|
||||
embeddings: embeddings
|
||||
hypernetworks: models/hypernetworks
|
||||
controlnet: models/ControlNet
|
||||
|
||||
#other_ui:
|
||||
# base_path: path/to/ui
|
||||
# checkpoints: models/checkpoints
|
||||
|
||||
# gligen: models/gligen
|
||||
# custom_nodes: path/custom_nodes
|
||||
|
||||
|
||||
|
||||
@ -12,8 +12,8 @@ except:
|
||||
|
||||
folder_names_and_paths = {}
|
||||
|
||||
|
||||
models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
|
||||
base_path = os.path.dirname(os.path.realpath(__file__))
|
||||
models_dir = os.path.join(base_path, "models")
|
||||
folder_names_and_paths["checkpoints"] = ([os.path.join(models_dir, "checkpoints")], supported_ckpt_extensions)
|
||||
folder_names_and_paths["configs"] = ([os.path.join(models_dir, "configs")], [".yaml"])
|
||||
|
||||
@ -26,8 +26,14 @@ folder_names_and_paths["embeddings"] = ([os.path.join(models_dir, "embeddings")]
|
||||
folder_names_and_paths["diffusers"] = ([os.path.join(models_dir, "diffusers")], ["folder"])
|
||||
|
||||
folder_names_and_paths["controlnet"] = ([os.path.join(models_dir, "controlnet"), os.path.join(models_dir, "t2i_adapter")], supported_pt_extensions)
|
||||
folder_names_and_paths["gligen"] = ([os.path.join(models_dir, "gligen")], supported_pt_extensions)
|
||||
|
||||
folder_names_and_paths["upscale_models"] = ([os.path.join(models_dir, "upscale_models")], supported_pt_extensions)
|
||||
|
||||
folder_names_and_paths["custom_nodes"] = ([os.path.join(base_path, "custom_nodes")], [])
|
||||
|
||||
folder_names_and_paths["hypernetworks"] = ([os.path.join(models_dir, "hypernetworks")], supported_pt_extensions)
|
||||
|
||||
output_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "output")
|
||||
temp_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp")
|
||||
input_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input")
|
||||
@ -63,6 +69,46 @@ def get_directory_by_type(type_name):
|
||||
return None
|
||||
|
||||
|
||||
# determine base_dir rely on annotation if name is 'filename.ext [annotation]' format
|
||||
# otherwise use default_path as base_dir
|
||||
def annotated_filepath(name):
|
||||
if name.endswith("[output]"):
|
||||
base_dir = get_output_directory()
|
||||
name = name[:-9]
|
||||
elif name.endswith("[input]"):
|
||||
base_dir = get_input_directory()
|
||||
name = name[:-8]
|
||||
elif name.endswith("[temp]"):
|
||||
base_dir = get_temp_directory()
|
||||
name = name[:-7]
|
||||
else:
|
||||
return name, None
|
||||
|
||||
return name, base_dir
|
||||
|
||||
|
||||
def get_annotated_filepath(name, default_dir=None):
|
||||
name, base_dir = annotated_filepath(name)
|
||||
|
||||
if base_dir is None:
|
||||
if default_dir is not None:
|
||||
base_dir = default_dir
|
||||
else:
|
||||
base_dir = get_input_directory() # fallback path
|
||||
|
||||
return os.path.join(base_dir, name)
|
||||
|
||||
|
||||
def exists_annotated_filepath(name):
|
||||
name, base_dir = annotated_filepath(name)
|
||||
|
||||
if base_dir is None:
|
||||
base_dir = get_input_directory() # fallback path
|
||||
|
||||
filepath = os.path.join(base_dir, name)
|
||||
return os.path.exists(filepath)
|
||||
|
||||
|
||||
def add_model_folder_path(folder_name, full_folder_path):
|
||||
global folder_names_and_paths
|
||||
if folder_name in folder_names_and_paths:
|
||||
@ -101,4 +147,37 @@ def get_filename_list(folder_name):
|
||||
output_list.update(filter_files_extensions(recursive_search(x), folders[1]))
|
||||
return sorted(list(output_list))
|
||||
|
||||
def get_save_image_path(filename_prefix, output_dir, image_width=0, image_height=0):
|
||||
def map_filename(filename):
|
||||
prefix_len = len(os.path.basename(filename_prefix))
|
||||
prefix = filename[:prefix_len + 1]
|
||||
try:
|
||||
digits = int(filename[prefix_len + 1:].split('_')[0])
|
||||
except:
|
||||
digits = 0
|
||||
return (digits, prefix)
|
||||
|
||||
def compute_vars(input, image_width, image_height):
|
||||
input = input.replace("%width%", str(image_width))
|
||||
input = input.replace("%height%", str(image_height))
|
||||
return input
|
||||
|
||||
filename_prefix = compute_vars(filename_prefix, image_width, image_height)
|
||||
|
||||
subfolder = os.path.dirname(os.path.normpath(filename_prefix))
|
||||
filename = os.path.basename(os.path.normpath(filename_prefix))
|
||||
|
||||
full_output_folder = os.path.join(output_dir, subfolder)
|
||||
|
||||
if os.path.commonpath((output_dir, os.path.abspath(full_output_folder))) != output_dir:
|
||||
print("Saving image outside the output folder is not allowed.")
|
||||
return {}
|
||||
|
||||
try:
|
||||
counter = max(filter(lambda a: a[1][:-1] == filename and a[1][-1] == "_", map(map_filename, os.listdir(full_output_folder))))[0] + 1
|
||||
except ValueError:
|
||||
counter = 1
|
||||
except FileNotFoundError:
|
||||
os.makedirs(full_output_folder, exist_ok=True)
|
||||
counter = 1
|
||||
return full_output_folder, filename, counter, subfolder, filename_prefix
|
||||
|
||||
40
main.py
40
main.py
@ -5,6 +5,7 @@ import shutil
|
||||
import threading
|
||||
|
||||
from comfy.cli_args import args
|
||||
import comfy.utils
|
||||
|
||||
if os.name == "nt":
|
||||
import logging
|
||||
@ -32,21 +33,16 @@ def prompt_worker(q, server):
|
||||
e = execution.PromptExecutor(server)
|
||||
while True:
|
||||
item, item_id = q.get()
|
||||
e.execute(item[-2], item[-1])
|
||||
q.task_done(item_id, e.outputs)
|
||||
e.execute(item[2], item[1], item[3], item[4])
|
||||
q.task_done(item_id, e.outputs_ui)
|
||||
|
||||
async def run(server, address='', port=8188, verbose=True, call_on_start=None):
|
||||
await asyncio.gather(server.start(address, port, verbose, call_on_start), server.publish_loop())
|
||||
|
||||
def hijack_progress(server):
|
||||
from tqdm.auto import tqdm
|
||||
orig_func = getattr(tqdm, "update")
|
||||
def wrapped_func(*args, **kwargs):
|
||||
pbar = args[0]
|
||||
v = orig_func(*args, **kwargs)
|
||||
server.send_sync("progress", { "value": pbar.n, "max": pbar.total}, server.client_id)
|
||||
return v
|
||||
setattr(tqdm, "update", wrapped_func)
|
||||
def hook(value, total):
|
||||
server.send_sync("progress", { "value": value, "max": total}, server.client_id)
|
||||
comfy.utils.set_progress_bar_global_hook(hook)
|
||||
|
||||
def cleanup_temp():
|
||||
temp_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp")
|
||||
@ -81,16 +77,6 @@ if __name__ == "__main__":
|
||||
server = server.PromptServer(loop)
|
||||
q = execution.PromptQueue(server)
|
||||
|
||||
init_custom_nodes()
|
||||
server.add_routes()
|
||||
hijack_progress(server)
|
||||
|
||||
threading.Thread(target=prompt_worker, daemon=True, args=(q,server,)).start()
|
||||
|
||||
address = args.listen
|
||||
|
||||
dont_print = args.dont_print_server
|
||||
|
||||
extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml")
|
||||
if os.path.isfile(extra_model_paths_config_path):
|
||||
load_extra_path_config(extra_model_paths_config_path)
|
||||
@ -99,18 +85,22 @@ if __name__ == "__main__":
|
||||
for config_path in itertools.chain(*args.extra_model_paths_config):
|
||||
load_extra_path_config(config_path)
|
||||
|
||||
init_custom_nodes()
|
||||
server.add_routes()
|
||||
hijack_progress(server)
|
||||
|
||||
threading.Thread(target=prompt_worker, daemon=True, args=(q,server,)).start()
|
||||
|
||||
if args.output_directory:
|
||||
output_dir = os.path.abspath(args.output_directory)
|
||||
print(f"Setting output directory to: {output_dir}")
|
||||
folder_paths.set_output_directory(output_dir)
|
||||
|
||||
port = args.port
|
||||
|
||||
if args.quick_test_for_ci:
|
||||
exit(0)
|
||||
|
||||
call_on_start = None
|
||||
if args.windows_standalone_build:
|
||||
if args.auto_launch:
|
||||
def startup_server(address, port):
|
||||
import webbrowser
|
||||
webbrowser.open("http://{}:{}".format(address, port))
|
||||
@ -118,10 +108,10 @@ if __name__ == "__main__":
|
||||
|
||||
if os.name == "nt":
|
||||
try:
|
||||
loop.run_until_complete(run(server, address=address, port=port, verbose=not dont_print, call_on_start=call_on_start))
|
||||
loop.run_until_complete(run(server, address=args.listen, port=args.port, verbose=not args.dont_print_server, call_on_start=call_on_start))
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
else:
|
||||
loop.run_until_complete(run(server, address=address, port=port, verbose=not dont_print, call_on_start=call_on_start))
|
||||
loop.run_until_complete(run(server, address=args.listen, port=args.port, verbose=not args.dont_print_server, call_on_start=call_on_start))
|
||||
|
||||
cleanup_temp()
|
||||
|
||||
0
models/gligen/put_gligen_models_here
Normal file
0
models/gligen/put_gligen_models_here
Normal file
0
models/hypernetworks/put_hypernetworks_here
Normal file
0
models/hypernetworks/put_hypernetworks_here
Normal file
569
nodes.py
569
nodes.py
@ -5,10 +5,13 @@ import sys
|
||||
import json
|
||||
import hashlib
|
||||
import traceback
|
||||
import math
|
||||
import time
|
||||
|
||||
from PIL import Image
|
||||
from PIL.PngImagePlugin import PngInfo
|
||||
import numpy as np
|
||||
import safetensors.torch
|
||||
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy"))
|
||||
@ -16,21 +19,23 @@ sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "co
|
||||
|
||||
import comfy.diffusers_convert
|
||||
import comfy.samplers
|
||||
import comfy.sample
|
||||
import comfy.sd
|
||||
import comfy.utils
|
||||
|
||||
import comfy.clip_vision
|
||||
|
||||
import model_management
|
||||
import comfy.model_management
|
||||
import importlib
|
||||
|
||||
import folder_paths
|
||||
|
||||
|
||||
def before_node_execution():
|
||||
model_management.throw_exception_if_processing_interrupted()
|
||||
comfy.model_management.throw_exception_if_processing_interrupted()
|
||||
|
||||
def interrupt_processing(value=True):
|
||||
model_management.interrupt_current_processing(value)
|
||||
comfy.model_management.interrupt_current_processing(value)
|
||||
|
||||
MAX_RESOLUTION=8192
|
||||
|
||||
@ -58,14 +63,44 @@ class ConditioningCombine:
|
||||
def combine(self, conditioning_1, conditioning_2):
|
||||
return (conditioning_1 + conditioning_2, )
|
||||
|
||||
class ConditioningAverage :
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"conditioning_to": ("CONDITIONING", ), "conditioning_from": ("CONDITIONING", ),
|
||||
"conditioning_to_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
|
||||
}}
|
||||
RETURN_TYPES = ("CONDITIONING",)
|
||||
FUNCTION = "addWeighted"
|
||||
|
||||
CATEGORY = "conditioning"
|
||||
|
||||
def addWeighted(self, conditioning_to, conditioning_from, conditioning_to_strength):
|
||||
out = []
|
||||
|
||||
if len(conditioning_from) > 1:
|
||||
print("Warning: ConditioningAverage conditioning_from contains more than 1 cond, only the first one will actually be applied to conditioning_to.")
|
||||
|
||||
cond_from = conditioning_from[0][0]
|
||||
|
||||
for i in range(len(conditioning_to)):
|
||||
t1 = conditioning_to[i][0]
|
||||
t0 = cond_from[:,:t1.shape[1]]
|
||||
if t0.shape[1] < t1.shape[1]:
|
||||
t0 = torch.cat([t0] + [torch.zeros((1, (t1.shape[1] - t0.shape[1]), t1.shape[2]))], dim=1)
|
||||
|
||||
tw = torch.mul(t1, conditioning_to_strength) + torch.mul(t0, (1.0 - conditioning_to_strength))
|
||||
n = [tw, conditioning_to[i][1].copy()]
|
||||
out.append(n)
|
||||
return (out, )
|
||||
|
||||
class ConditioningSetArea:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"conditioning": ("CONDITIONING", ),
|
||||
"width": ("INT", {"default": 64, "min": 64, "max": MAX_RESOLUTION, "step": 64}),
|
||||
"height": ("INT", {"default": 64, "min": 64, "max": MAX_RESOLUTION, "step": 64}),
|
||||
"x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 64}),
|
||||
"y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 64}),
|
||||
"width": ("INT", {"default": 64, "min": 64, "max": MAX_RESOLUTION, "step": 8}),
|
||||
"height": ("INT", {"default": 64, "min": 64, "max": MAX_RESOLUTION, "step": 8}),
|
||||
"x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
|
||||
"y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
|
||||
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
}}
|
||||
RETURN_TYPES = ("CONDITIONING",)
|
||||
@ -73,21 +108,46 @@ class ConditioningSetArea:
|
||||
|
||||
CATEGORY = "conditioning"
|
||||
|
||||
def append(self, conditioning, width, height, x, y, strength, min_sigma=0.0, max_sigma=99.0):
|
||||
def append(self, conditioning, width, height, x, y, strength):
|
||||
c = []
|
||||
for t in conditioning:
|
||||
n = [t[0], t[1].copy()]
|
||||
n[1]['area'] = (height // 8, width // 8, y // 8, x // 8)
|
||||
n[1]['strength'] = strength
|
||||
n[1]['min_sigma'] = min_sigma
|
||||
n[1]['max_sigma'] = max_sigma
|
||||
n[1]['set_area_to_bounds'] = False
|
||||
c.append(n)
|
||||
return (c, )
|
||||
|
||||
class ConditioningSetMask:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"conditioning": ("CONDITIONING", ),
|
||||
"mask": ("MASK", ),
|
||||
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
"set_cond_area": (["default", "mask bounds"],),
|
||||
}}
|
||||
RETURN_TYPES = ("CONDITIONING",)
|
||||
FUNCTION = "append"
|
||||
|
||||
CATEGORY = "conditioning"
|
||||
|
||||
def append(self, conditioning, mask, set_cond_area, strength):
|
||||
c = []
|
||||
set_area_to_bounds = False
|
||||
if set_cond_area != "default":
|
||||
set_area_to_bounds = True
|
||||
if len(mask.shape) < 3:
|
||||
mask = mask.unsqueeze(0)
|
||||
for t in conditioning:
|
||||
n = [t[0], t[1].copy()]
|
||||
_, h, w = mask.shape
|
||||
n[1]['mask'] = mask
|
||||
n[1]['set_area_to_bounds'] = set_area_to_bounds
|
||||
n[1]['mask_strength'] = strength
|
||||
c.append(n)
|
||||
return (c, )
|
||||
|
||||
class VAEDecode:
|
||||
def __init__(self, device="cpu"):
|
||||
self.device = device
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "samples": ("LATENT", ), "vae": ("VAE", )}}
|
||||
@ -100,9 +160,6 @@ class VAEDecode:
|
||||
return (vae.decode(samples["samples"]), )
|
||||
|
||||
class VAEDecodeTiled:
|
||||
def __init__(self, device="cpu"):
|
||||
self.device = device
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "samples": ("LATENT", ), "vae": ("VAE", )}}
|
||||
@ -115,9 +172,6 @@ class VAEDecodeTiled:
|
||||
return (vae.decode_tiled(samples["samples"]), )
|
||||
|
||||
class VAEEncode:
|
||||
def __init__(self, device="cpu"):
|
||||
self.device = device
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "pixels": ("IMAGE", ), "vae": ("VAE", )}}
|
||||
@ -126,20 +180,22 @@ class VAEEncode:
|
||||
|
||||
CATEGORY = "latent"
|
||||
|
||||
def encode(self, vae, pixels):
|
||||
x = (pixels.shape[1] // 64) * 64
|
||||
y = (pixels.shape[2] // 64) * 64
|
||||
@staticmethod
|
||||
def vae_encode_crop_pixels(pixels):
|
||||
x = (pixels.shape[1] // 8) * 8
|
||||
y = (pixels.shape[2] // 8) * 8
|
||||
if pixels.shape[1] != x or pixels.shape[2] != y:
|
||||
pixels = pixels[:,:x,:y,:]
|
||||
t = vae.encode(pixels[:,:,:,:3])
|
||||
x_offset = (pixels.shape[1] % 8) // 2
|
||||
y_offset = (pixels.shape[2] % 8) // 2
|
||||
pixels = pixels[:, x_offset:x + x_offset, y_offset:y + y_offset, :]
|
||||
return pixels
|
||||
|
||||
def encode(self, vae, pixels):
|
||||
pixels = self.vae_encode_crop_pixels(pixels)
|
||||
t = vae.encode(pixels[:,:,:,:3])
|
||||
return ({"samples":t}, )
|
||||
|
||||
|
||||
class VAEEncodeTiled:
|
||||
def __init__(self, device="cpu"):
|
||||
self.device = device
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "pixels": ("IMAGE", ), "vae": ("VAE", )}}
|
||||
@ -149,46 +205,123 @@ class VAEEncodeTiled:
|
||||
CATEGORY = "_for_testing"
|
||||
|
||||
def encode(self, vae, pixels):
|
||||
x = (pixels.shape[1] // 64) * 64
|
||||
y = (pixels.shape[2] // 64) * 64
|
||||
if pixels.shape[1] != x or pixels.shape[2] != y:
|
||||
pixels = pixels[:,:x,:y,:]
|
||||
pixels = VAEEncode.vae_encode_crop_pixels(pixels)
|
||||
t = vae.encode_tiled(pixels[:,:,:,:3])
|
||||
|
||||
return ({"samples":t}, )
|
||||
class VAEEncodeForInpaint:
|
||||
def __init__(self, device="cpu"):
|
||||
self.device = device
|
||||
|
||||
class VAEEncodeForInpaint:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "pixels": ("IMAGE", ), "vae": ("VAE", ), "mask": ("MASK", )}}
|
||||
return {"required": { "pixels": ("IMAGE", ), "vae": ("VAE", ), "mask": ("MASK", ), "grow_mask_by": ("INT", {"default": 6, "min": 0, "max": 64, "step": 1}),}}
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
FUNCTION = "encode"
|
||||
|
||||
CATEGORY = "latent/inpaint"
|
||||
|
||||
def encode(self, vae, pixels, mask):
|
||||
x = (pixels.shape[1] // 64) * 64
|
||||
y = (pixels.shape[2] // 64) * 64
|
||||
mask = torch.nn.functional.interpolate(mask[None,None,], size=(pixels.shape[1], pixels.shape[2]), mode="bilinear")[0][0]
|
||||
def encode(self, vae, pixels, mask, grow_mask_by=6):
|
||||
x = (pixels.shape[1] // 8) * 8
|
||||
y = (pixels.shape[2] // 8) * 8
|
||||
mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(pixels.shape[1], pixels.shape[2]), mode="bilinear")
|
||||
|
||||
pixels = pixels.clone()
|
||||
if pixels.shape[1] != x or pixels.shape[2] != y:
|
||||
pixels = pixels[:,:x,:y,:]
|
||||
mask = mask[:x,:y]
|
||||
x_offset = (pixels.shape[1] % 8) // 2
|
||||
y_offset = (pixels.shape[2] % 8) // 2
|
||||
pixels = pixels[:,x_offset:x + x_offset, y_offset:y + y_offset,:]
|
||||
mask = mask[:,:,x_offset:x + x_offset, y_offset:y + y_offset]
|
||||
|
||||
#grow mask by a few pixels to keep things seamless in latent space
|
||||
kernel_tensor = torch.ones((1, 1, 6, 6))
|
||||
mask_erosion = torch.clamp(torch.nn.functional.conv2d((mask.round())[None], kernel_tensor, padding=3), 0, 1)
|
||||
m = (1.0 - mask.round())
|
||||
if grow_mask_by == 0:
|
||||
mask_erosion = mask
|
||||
else:
|
||||
kernel_tensor = torch.ones((1, 1, grow_mask_by, grow_mask_by))
|
||||
padding = math.ceil((grow_mask_by - 1) / 2)
|
||||
|
||||
mask_erosion = torch.clamp(torch.nn.functional.conv2d(mask.round(), kernel_tensor, padding=padding), 0, 1)
|
||||
|
||||
m = (1.0 - mask.round()).squeeze(1)
|
||||
for i in range(3):
|
||||
pixels[:,:,:,i] -= 0.5
|
||||
pixels[:,:,:,i] *= m
|
||||
pixels[:,:,:,i] += 0.5
|
||||
t = vae.encode(pixels)
|
||||
|
||||
return ({"samples":t, "noise_mask": (mask_erosion[0][:x,:y].round())}, )
|
||||
return ({"samples":t, "noise_mask": (mask_erosion[:,:,:x,:y].round())}, )
|
||||
|
||||
|
||||
class SaveLatent:
|
||||
def __init__(self):
|
||||
self.output_dir = folder_paths.get_output_directory()
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "samples": ("LATENT", ),
|
||||
"filename_prefix": ("STRING", {"default": "latents/ComfyUI"})},
|
||||
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
|
||||
}
|
||||
RETURN_TYPES = ()
|
||||
FUNCTION = "save"
|
||||
|
||||
OUTPUT_NODE = True
|
||||
|
||||
CATEGORY = "_for_testing"
|
||||
|
||||
def save(self, samples, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
|
||||
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
|
||||
|
||||
# support save metadata for latent sharing
|
||||
prompt_info = ""
|
||||
if prompt is not None:
|
||||
prompt_info = json.dumps(prompt)
|
||||
|
||||
metadata = {"prompt": prompt_info}
|
||||
if extra_pnginfo is not None:
|
||||
for x in extra_pnginfo:
|
||||
metadata[x] = json.dumps(extra_pnginfo[x])
|
||||
|
||||
file = f"{filename}_{counter:05}_.latent"
|
||||
file = os.path.join(full_output_folder, file)
|
||||
|
||||
output = {}
|
||||
output["latent_tensor"] = samples["samples"]
|
||||
|
||||
safetensors.torch.save_file(output, file, metadata=metadata)
|
||||
|
||||
return {}
|
||||
|
||||
|
||||
class LoadLatent:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
input_dir = folder_paths.get_input_directory()
|
||||
files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f)) and f.endswith(".latent")]
|
||||
return {"required": {"latent": [sorted(files), ]}, }
|
||||
|
||||
CATEGORY = "_for_testing"
|
||||
|
||||
RETURN_TYPES = ("LATENT", )
|
||||
FUNCTION = "load"
|
||||
|
||||
def load(self, latent):
|
||||
latent_path = folder_paths.get_annotated_filepath(latent)
|
||||
latent = safetensors.torch.load_file(latent_path, device="cpu")
|
||||
samples = {"samples": latent["latent_tensor"].float()}
|
||||
return (samples, )
|
||||
|
||||
@classmethod
|
||||
def IS_CHANGED(s, latent):
|
||||
image_path = folder_paths.get_annotated_filepath(latent)
|
||||
m = hashlib.sha256()
|
||||
with open(image_path, 'rb') as f:
|
||||
m.update(f.read())
|
||||
return m.digest().hex()
|
||||
|
||||
@classmethod
|
||||
def VALIDATE_INPUTS(s, latent):
|
||||
if not folder_paths.exists_annotated_filepath(latent):
|
||||
return "Invalid latent file: {}".format(latent)
|
||||
return True
|
||||
|
||||
|
||||
class CheckpointLoader:
|
||||
@classmethod
|
||||
@ -226,7 +359,10 @@ class DiffusersLoader:
|
||||
paths = []
|
||||
for search_path in folder_paths.get_folder_paths("diffusers"):
|
||||
if os.path.exists(search_path):
|
||||
paths += next(os.walk(search_path))[1]
|
||||
for root, subdir, files in os.walk(search_path, followlinks=True):
|
||||
if "model_index.json" in files:
|
||||
paths.append(os.path.relpath(root, start=search_path))
|
||||
|
||||
return {"required": {"model_path": (paths,), }}
|
||||
RETURN_TYPES = ("MODEL", "CLIP", "VAE")
|
||||
FUNCTION = "load_checkpoint"
|
||||
@ -236,12 +372,12 @@ class DiffusersLoader:
|
||||
def load_checkpoint(self, model_path, output_vae=True, output_clip=True):
|
||||
for search_path in folder_paths.get_folder_paths("diffusers"):
|
||||
if os.path.exists(search_path):
|
||||
paths = next(os.walk(search_path))[1]
|
||||
if model_path in paths:
|
||||
model_path = os.path.join(search_path, model_path)
|
||||
path = os.path.join(search_path, model_path)
|
||||
if os.path.exists(path):
|
||||
model_path = path
|
||||
break
|
||||
|
||||
return comfy.diffusers_convert.load_diffusers(model_path, fp16=model_management.should_use_fp16(), output_vae=output_vae, output_clip=output_clip, embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
||||
return comfy.diffusers_convert.load_diffusers(model_path, fp16=comfy.model_management.should_use_fp16(), output_vae=output_vae, output_clip=output_clip, embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
||||
|
||||
|
||||
class unCLIPCheckpointLoader:
|
||||
@ -373,7 +509,6 @@ class ControlNetApply:
|
||||
def apply_controlnet(self, conditioning, control_net, image, strength):
|
||||
c = []
|
||||
control_hint = image.movedim(-1,1)
|
||||
print(control_hint.shape)
|
||||
for t in conditioning:
|
||||
n = [t[0], t[1].copy()]
|
||||
c_net = control_net.copy().set_cond_hint(control_hint, strength)
|
||||
@ -490,6 +625,51 @@ class unCLIPConditioning:
|
||||
c.append(n)
|
||||
return (c, )
|
||||
|
||||
class GLIGENLoader:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "gligen_name": (folder_paths.get_filename_list("gligen"), )}}
|
||||
|
||||
RETURN_TYPES = ("GLIGEN",)
|
||||
FUNCTION = "load_gligen"
|
||||
|
||||
CATEGORY = "loaders"
|
||||
|
||||
def load_gligen(self, gligen_name):
|
||||
gligen_path = folder_paths.get_full_path("gligen", gligen_name)
|
||||
gligen = comfy.sd.load_gligen(gligen_path)
|
||||
return (gligen,)
|
||||
|
||||
class GLIGENTextBoxApply:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"conditioning_to": ("CONDITIONING", ),
|
||||
"clip": ("CLIP", ),
|
||||
"gligen_textbox_model": ("GLIGEN", ),
|
||||
"text": ("STRING", {"multiline": True}),
|
||||
"width": ("INT", {"default": 64, "min": 8, "max": MAX_RESOLUTION, "step": 8}),
|
||||
"height": ("INT", {"default": 64, "min": 8, "max": MAX_RESOLUTION, "step": 8}),
|
||||
"x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
|
||||
"y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
|
||||
}}
|
||||
RETURN_TYPES = ("CONDITIONING",)
|
||||
FUNCTION = "append"
|
||||
|
||||
CATEGORY = "conditioning/gligen"
|
||||
|
||||
def append(self, conditioning_to, clip, gligen_textbox_model, text, width, height, x, y):
|
||||
c = []
|
||||
cond, cond_pooled = clip.encode_from_tokens(clip.tokenize(text), return_pooled=True)
|
||||
for t in conditioning_to:
|
||||
n = [t[0], t[1].copy()]
|
||||
position_params = [(cond_pooled, height // 8, width // 8, y // 8, x // 8)]
|
||||
prev = []
|
||||
if "gligen" in n[1]:
|
||||
prev = n[1]['gligen'][2]
|
||||
|
||||
n[1]['gligen'] = ("position", gligen_textbox_model, prev + position_params)
|
||||
c.append(n)
|
||||
return (c, )
|
||||
|
||||
class EmptyLatentImage:
|
||||
def __init__(self, device="cpu"):
|
||||
@ -497,8 +677,8 @@ class EmptyLatentImage:
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "width": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 64}),
|
||||
"height": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 64}),
|
||||
return {"required": { "width": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 8}),
|
||||
"height": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 8}),
|
||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 64})}}
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
FUNCTION = "generate"
|
||||
@ -510,6 +690,63 @@ class EmptyLatentImage:
|
||||
return ({"samples":latent}, )
|
||||
|
||||
|
||||
class LatentFromBatch:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "samples": ("LATENT",),
|
||||
"batch_index": ("INT", {"default": 0, "min": 0, "max": 63}),
|
||||
"length": ("INT", {"default": 1, "min": 1, "max": 64}),
|
||||
}}
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
FUNCTION = "frombatch"
|
||||
|
||||
CATEGORY = "latent/batch"
|
||||
|
||||
def frombatch(self, samples, batch_index, length):
|
||||
s = samples.copy()
|
||||
s_in = samples["samples"]
|
||||
batch_index = min(s_in.shape[0] - 1, batch_index)
|
||||
length = min(s_in.shape[0] - batch_index, length)
|
||||
s["samples"] = s_in[batch_index:batch_index + length].clone()
|
||||
if "noise_mask" in samples:
|
||||
masks = samples["noise_mask"]
|
||||
if masks.shape[0] == 1:
|
||||
s["noise_mask"] = masks.clone()
|
||||
else:
|
||||
if masks.shape[0] < s_in.shape[0]:
|
||||
masks = masks.repeat(math.ceil(s_in.shape[0] / masks.shape[0]), 1, 1, 1)[:s_in.shape[0]]
|
||||
s["noise_mask"] = masks[batch_index:batch_index + length].clone()
|
||||
if "batch_index" not in s:
|
||||
s["batch_index"] = [x for x in range(batch_index, batch_index+length)]
|
||||
else:
|
||||
s["batch_index"] = samples["batch_index"][batch_index:batch_index + length]
|
||||
return (s,)
|
||||
|
||||
class RepeatLatentBatch:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "samples": ("LATENT",),
|
||||
"amount": ("INT", {"default": 1, "min": 1, "max": 64}),
|
||||
}}
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
FUNCTION = "repeat"
|
||||
|
||||
CATEGORY = "latent/batch"
|
||||
|
||||
def repeat(self, samples, amount):
|
||||
s = samples.copy()
|
||||
s_in = samples["samples"]
|
||||
|
||||
s["samples"] = s_in.repeat((amount, 1,1,1))
|
||||
if "noise_mask" in samples and samples["noise_mask"].shape[0] > 1:
|
||||
masks = samples["noise_mask"]
|
||||
if masks.shape[0] < s_in.shape[0]:
|
||||
masks = masks.repeat(math.ceil(s_in.shape[0] / masks.shape[0]), 1, 1, 1)[:s_in.shape[0]]
|
||||
s["noise_mask"] = samples["noise_mask"].repeat((amount, 1,1,1))
|
||||
if "batch_index" in s:
|
||||
offset = max(s["batch_index"]) - min(s["batch_index"]) + 1
|
||||
s["batch_index"] = s["batch_index"] + [x + (i * offset) for i in range(1, amount) for x in s["batch_index"]]
|
||||
return (s,)
|
||||
|
||||
class LatentUpscale:
|
||||
upscale_methods = ["nearest-exact", "bilinear", "area"]
|
||||
@ -517,9 +754,10 @@ class LatentUpscale:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "samples": ("LATENT",), "upscale_method": (s.upscale_methods,),
|
||||
"width": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 64}),
|
||||
"height": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 64}),
|
||||
"width": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 8}),
|
||||
"height": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 8}),
|
||||
"crop": ("BOOL", {"on": "center", "off": "disabled"}),}}
|
||||
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
FUNCTION = "upscale"
|
||||
|
||||
@ -621,8 +859,8 @@ class LatentCrop:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "samples": ("LATENT",),
|
||||
"width": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 64}),
|
||||
"height": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 64}),
|
||||
"width": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 8}),
|
||||
"height": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 8}),
|
||||
"x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
|
||||
"y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
|
||||
}}
|
||||
@ -647,16 +885,6 @@ class LatentCrop:
|
||||
new_width = width // 8
|
||||
to_x = new_width + x
|
||||
to_y = new_height + y
|
||||
def enforce_image_dim(d, to_d, max_d):
|
||||
if to_d > max_d:
|
||||
leftover = (to_d - max_d) % 8
|
||||
to_d = max_d
|
||||
d -= leftover
|
||||
return (d, to_d)
|
||||
|
||||
#make sure size is always multiple of 64
|
||||
x, to_x = enforce_image_dim(x, to_x, samples.shape[3])
|
||||
y, to_y = enforce_image_dim(y, to_y, samples.shape[2])
|
||||
s['samples'] = samples[:,:,y:to_y, x:to_x]
|
||||
return (s,)
|
||||
|
||||
@ -673,72 +901,30 @@ class SetLatentNoiseMask:
|
||||
|
||||
def set_mask(self, samples, mask):
|
||||
s = samples.copy()
|
||||
s["noise_mask"] = mask
|
||||
s["noise_mask"] = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1]))
|
||||
return (s,)
|
||||
|
||||
|
||||
def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False):
|
||||
device = comfy.model_management.get_torch_device()
|
||||
latent_image = latent["samples"]
|
||||
noise_mask = None
|
||||
device = model_management.get_torch_device()
|
||||
|
||||
if disable_noise:
|
||||
noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu")
|
||||
else:
|
||||
noise = torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=torch.manual_seed(seed), device="cpu")
|
||||
batch_inds = latent["batch_index"] if "batch_index" in latent else None
|
||||
noise = comfy.sample.prepare_noise(latent_image, seed, batch_inds)
|
||||
|
||||
noise_mask = None
|
||||
if "noise_mask" in latent:
|
||||
noise_mask = latent['noise_mask']
|
||||
noise_mask = torch.nn.functional.interpolate(noise_mask[None,None,], size=(noise.shape[2], noise.shape[3]), mode="bilinear")
|
||||
noise_mask = noise_mask.round()
|
||||
noise_mask = torch.cat([noise_mask] * noise.shape[1], dim=1)
|
||||
noise_mask = torch.cat([noise_mask] * noise.shape[0])
|
||||
noise_mask = noise_mask.to(device)
|
||||
noise_mask = latent["noise_mask"]
|
||||
|
||||
real_model = None
|
||||
model_management.load_model_gpu(model)
|
||||
real_model = model.model
|
||||
|
||||
noise = noise.to(device)
|
||||
latent_image = latent_image.to(device)
|
||||
|
||||
positive_copy = []
|
||||
negative_copy = []
|
||||
|
||||
control_nets = []
|
||||
for p in positive:
|
||||
t = p[0]
|
||||
if t.shape[0] < noise.shape[0]:
|
||||
t = torch.cat([t] * noise.shape[0])
|
||||
t = t.to(device)
|
||||
if 'control' in p[1]:
|
||||
control_nets += [p[1]['control']]
|
||||
positive_copy += [[t] + p[1:]]
|
||||
for n in negative:
|
||||
t = n[0]
|
||||
if t.shape[0] < noise.shape[0]:
|
||||
t = torch.cat([t] * noise.shape[0])
|
||||
t = t.to(device)
|
||||
if 'control' in n[1]:
|
||||
control_nets += [n[1]['control']]
|
||||
negative_copy += [[t] + n[1:]]
|
||||
|
||||
control_net_models = []
|
||||
for x in control_nets:
|
||||
control_net_models += x.get_control_models()
|
||||
model_management.load_controlnet_gpu(control_net_models)
|
||||
|
||||
if sampler_name in comfy.samplers.KSampler.SAMPLERS:
|
||||
sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)
|
||||
else:
|
||||
#other samplers
|
||||
pass
|
||||
|
||||
samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask)
|
||||
samples = samples.cpu()
|
||||
for c in control_nets:
|
||||
c.cleanup()
|
||||
pbar = comfy.utils.ProgressBar(steps)
|
||||
def callback(step, x0, x, total_steps):
|
||||
pbar.update_absolute(step + 1, total_steps)
|
||||
|
||||
samples = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image,
|
||||
denoise=denoise, disable_noise=disable_noise, start_step=start_step, last_step=last_step,
|
||||
force_full_denoise=force_full_denoise, noise_mask=noise_mask, callback=callback)
|
||||
out = latent.copy()
|
||||
out["samples"] = samples
|
||||
return (out, )
|
||||
@ -817,39 +1003,7 @@ class SaveImage:
|
||||
CATEGORY = "image"
|
||||
|
||||
def save_images(self, images, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
|
||||
def map_filename(filename):
|
||||
prefix_len = len(os.path.basename(filename_prefix))
|
||||
prefix = filename[:prefix_len + 1]
|
||||
try:
|
||||
digits = int(filename[prefix_len + 1:].split('_')[0])
|
||||
except:
|
||||
digits = 0
|
||||
return (digits, prefix)
|
||||
|
||||
def compute_vars(input):
|
||||
input = input.replace("%width%", str(images[0].shape[1]))
|
||||
input = input.replace("%height%", str(images[0].shape[0]))
|
||||
return input
|
||||
|
||||
filename_prefix = compute_vars(filename_prefix)
|
||||
|
||||
subfolder = os.path.dirname(os.path.normpath(filename_prefix))
|
||||
filename = os.path.basename(os.path.normpath(filename_prefix))
|
||||
|
||||
full_output_folder = os.path.join(self.output_dir, subfolder)
|
||||
|
||||
if os.path.commonpath((self.output_dir, os.path.abspath(full_output_folder))) != self.output_dir:
|
||||
print("Saving image outside the output folder is not allowed.")
|
||||
return {}
|
||||
|
||||
try:
|
||||
counter = max(filter(lambda a: a[1][:-1] == filename and a[1][-1] == "_", map(map_filename, os.listdir(full_output_folder))))[0] + 1
|
||||
except ValueError:
|
||||
counter = 1
|
||||
except FileNotFoundError:
|
||||
os.makedirs(full_output_folder, exist_ok=True)
|
||||
counter = 1
|
||||
|
||||
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0])
|
||||
results = list()
|
||||
for image in images:
|
||||
i = 255. * image.cpu().numpy()
|
||||
@ -867,7 +1021,7 @@ class SaveImage:
|
||||
"filename": file,
|
||||
"subfolder": subfolder,
|
||||
"type": self.type
|
||||
});
|
||||
})
|
||||
counter += 1
|
||||
|
||||
return { "ui": { "images": results } }
|
||||
@ -888,8 +1042,9 @@ class LoadImage:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
input_dir = folder_paths.get_input_directory()
|
||||
files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))]
|
||||
return {"required":
|
||||
{"image": (sorted(os.listdir(input_dir)), )},
|
||||
{"image": (sorted(files), )},
|
||||
}
|
||||
|
||||
CATEGORY = "image"
|
||||
@ -897,8 +1052,7 @@ class LoadImage:
|
||||
RETURN_TYPES = ("IMAGE", "MASK")
|
||||
FUNCTION = "load_image"
|
||||
def load_image(self, image):
|
||||
input_dir = folder_paths.get_input_directory()
|
||||
image_path = os.path.join(input_dir, image)
|
||||
image_path = folder_paths.get_annotated_filepath(image)
|
||||
i = Image.open(image_path)
|
||||
image = i.convert("RGB")
|
||||
image = np.array(image).astype(np.float32) / 255.0
|
||||
@ -912,29 +1066,36 @@ class LoadImage:
|
||||
|
||||
@classmethod
|
||||
def IS_CHANGED(s, image):
|
||||
input_dir = folder_paths.get_input_directory()
|
||||
image_path = os.path.join(input_dir, image)
|
||||
image_path = folder_paths.get_annotated_filepath(image)
|
||||
m = hashlib.sha256()
|
||||
with open(image_path, 'rb') as f:
|
||||
m.update(f.read())
|
||||
return m.digest().hex()
|
||||
|
||||
@classmethod
|
||||
def VALIDATE_INPUTS(s, image):
|
||||
if not folder_paths.exists_annotated_filepath(image):
|
||||
return "Invalid image file: {}".format(image)
|
||||
|
||||
return True
|
||||
|
||||
class LoadImageMask:
|
||||
_color_channels = ["alpha", "red", "green", "blue"]
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
input_dir = folder_paths.get_input_directory()
|
||||
files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))]
|
||||
return {"required":
|
||||
{"image": (sorted(os.listdir(input_dir)), ),
|
||||
"channel": (["alpha", "red", "green", "blue"], ),}
|
||||
{"image": (sorted(files), ),
|
||||
"channel": (s._color_channels, ), }
|
||||
}
|
||||
|
||||
CATEGORY = "image"
|
||||
CATEGORY = "mask"
|
||||
|
||||
RETURN_TYPES = ("MASK",)
|
||||
FUNCTION = "load_image"
|
||||
def load_image(self, image, channel):
|
||||
input_dir = folder_paths.get_input_directory()
|
||||
image_path = os.path.join(input_dir, image)
|
||||
image_path = folder_paths.get_annotated_filepath(image)
|
||||
i = Image.open(image_path)
|
||||
if i.getbands() != ("R", "G", "B", "A"):
|
||||
i = i.convert("RGBA")
|
||||
@ -951,13 +1112,22 @@ class LoadImageMask:
|
||||
|
||||
@classmethod
|
||||
def IS_CHANGED(s, image, channel):
|
||||
input_dir = folder_paths.get_input_directory()
|
||||
image_path = os.path.join(input_dir, image)
|
||||
image_path = folder_paths.get_annotated_filepath(image)
|
||||
m = hashlib.sha256()
|
||||
with open(image_path, 'rb') as f:
|
||||
m.update(f.read())
|
||||
return m.digest().hex()
|
||||
|
||||
@classmethod
|
||||
def VALIDATE_INPUTS(s, image, channel):
|
||||
if not folder_paths.exists_annotated_filepath(image):
|
||||
return "Invalid image file: {}".format(image)
|
||||
|
||||
if channel not in s._color_channels:
|
||||
return "Invalid color channel: {}".format(channel)
|
||||
|
||||
return True
|
||||
|
||||
class ImageScale:
|
||||
upscale_methods = ["nearest-exact", "bilinear", "area"]
|
||||
|
||||
@ -1002,10 +1172,10 @@ class ImagePadForOutpaint:
|
||||
return {
|
||||
"required": {
|
||||
"image": ("IMAGE",),
|
||||
"left": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 64}),
|
||||
"top": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 64}),
|
||||
"right": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 64}),
|
||||
"bottom": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 64}),
|
||||
"left": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
|
||||
"top": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
|
||||
"right": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
|
||||
"bottom": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
|
||||
"feathering": ("INT", {"default": 40, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
|
||||
}
|
||||
}
|
||||
@ -1069,6 +1239,8 @@ NODE_CLASS_MAPPINGS = {
|
||||
"VAELoader": VAELoader,
|
||||
"EmptyLatentImage": EmptyLatentImage,
|
||||
"LatentUpscale": LatentUpscale,
|
||||
"LatentFromBatch": LatentFromBatch,
|
||||
"RepeatLatentBatch": RepeatLatentBatch,
|
||||
"SaveImage": SaveImage,
|
||||
"PreviewImage": PreviewImage,
|
||||
"LoadImage": LoadImage,
|
||||
@ -1076,8 +1248,10 @@ NODE_CLASS_MAPPINGS = {
|
||||
"ImageScale": ImageScale,
|
||||
"ImageInvert": ImageInvert,
|
||||
"ImagePadForOutpaint": ImagePadForOutpaint,
|
||||
"ConditioningAverage ": ConditioningAverage ,
|
||||
"ConditioningCombine": ConditioningCombine,
|
||||
"ConditioningSetArea": ConditioningSetArea,
|
||||
"ConditioningSetMask": ConditioningSetMask,
|
||||
"KSamplerAdvanced": KSamplerAdvanced,
|
||||
"SetLatentNoiseMask": SetLatentNoiseMask,
|
||||
"LatentComposite": LatentComposite,
|
||||
@ -1098,8 +1272,14 @@ NODE_CLASS_MAPPINGS = {
|
||||
"VAEEncodeTiled": VAEEncodeTiled,
|
||||
"TomePatchModel": TomePatchModel,
|
||||
"unCLIPCheckpointLoader": unCLIPCheckpointLoader,
|
||||
"GLIGENLoader": GLIGENLoader,
|
||||
"GLIGENTextBoxApply": GLIGENTextBoxApply,
|
||||
|
||||
"CheckpointLoader": CheckpointLoader,
|
||||
"DiffusersLoader": DiffusersLoader,
|
||||
|
||||
"LoadLatent": LoadLatent,
|
||||
"SaveLatent": SaveLatent
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
@ -1123,7 +1303,9 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"CLIPTextEncode": "CLIP Text Encode (Prompt)",
|
||||
"CLIPSetLastLayer": "CLIP Set Last Layer",
|
||||
"ConditioningCombine": "Conditioning (Combine)",
|
||||
"ConditioningAverage ": "Conditioning (Average)",
|
||||
"ConditioningSetArea": "Conditioning (Set Area)",
|
||||
"ConditioningSetMask": "Conditioning (Set Mask)",
|
||||
"ControlNetApply": "Apply ControlNet",
|
||||
# Latent
|
||||
"VAEEncodeForInpaint": "VAE Encode (for Inpainting)",
|
||||
@ -1136,6 +1318,8 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"EmptyLatentImage": "Empty Latent Image",
|
||||
"LatentUpscale": "Upscale Latent",
|
||||
"LatentComposite": "Latent Composite",
|
||||
"LatentFromBatch" : "Latent From Batch",
|
||||
"RepeatLatentBatch": "Repeat Latent Batch",
|
||||
# Image
|
||||
"SaveImage": "Save Image",
|
||||
"PreviewImage": "Preview Image",
|
||||
@ -1167,24 +1351,45 @@ def load_custom_node(module_path):
|
||||
NODE_CLASS_MAPPINGS.update(module.NODE_CLASS_MAPPINGS)
|
||||
if hasattr(module, "NODE_DISPLAY_NAME_MAPPINGS") and getattr(module, "NODE_DISPLAY_NAME_MAPPINGS") is not None:
|
||||
NODE_DISPLAY_NAME_MAPPINGS.update(module.NODE_DISPLAY_NAME_MAPPINGS)
|
||||
return True
|
||||
else:
|
||||
print(f"Skip {module_path} module for custom nodes due to the lack of NODE_CLASS_MAPPINGS.")
|
||||
return False
|
||||
except Exception as e:
|
||||
print(traceback.format_exc())
|
||||
print(f"Cannot import {module_path} module for custom nodes:", e)
|
||||
return False
|
||||
|
||||
def load_custom_nodes():
|
||||
CUSTOM_NODE_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "custom_nodes")
|
||||
possible_modules = os.listdir(CUSTOM_NODE_PATH)
|
||||
if "__pycache__" in possible_modules:
|
||||
possible_modules.remove("__pycache__")
|
||||
node_paths = folder_paths.get_folder_paths("custom_nodes")
|
||||
node_import_times = []
|
||||
for custom_node_path in node_paths:
|
||||
possible_modules = os.listdir(custom_node_path)
|
||||
if "__pycache__" in possible_modules:
|
||||
possible_modules.remove("__pycache__")
|
||||
|
||||
for possible_module in possible_modules:
|
||||
module_path = os.path.join(CUSTOM_NODE_PATH, possible_module)
|
||||
if os.path.isfile(module_path) and os.path.splitext(module_path)[1] != ".py": continue
|
||||
load_custom_node(module_path)
|
||||
for possible_module in possible_modules:
|
||||
module_path = os.path.join(custom_node_path, possible_module)
|
||||
if os.path.isfile(module_path) and os.path.splitext(module_path)[1] != ".py": continue
|
||||
if module_path.endswith(".disabled"): continue
|
||||
time_before = time.perf_counter()
|
||||
success = load_custom_node(module_path)
|
||||
node_import_times.append((time.perf_counter() - time_before, module_path, success))
|
||||
|
||||
if len(node_import_times) > 0:
|
||||
print("\nImport times for custom nodes:")
|
||||
for n in sorted(node_import_times):
|
||||
if n[2]:
|
||||
import_message = ""
|
||||
else:
|
||||
import_message = " (IMPORT FAILED)"
|
||||
print("{:6.1f} seconds{}:".format(n[0], import_message), n[1])
|
||||
print()
|
||||
|
||||
def init_custom_nodes():
|
||||
load_custom_nodes()
|
||||
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_hypernetwork.py"))
|
||||
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_upscale_model.py"))
|
||||
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_post_processing.py"))
|
||||
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_mask.py"))
|
||||
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_rebatch.py"))
|
||||
load_custom_nodes()
|
||||
|
||||
@ -47,7 +47,7 @@
|
||||
" !git pull\n",
|
||||
"\n",
|
||||
"!echo -= Install dependencies =-\n",
|
||||
"!pip install xformers!=0.0.18 -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu118"
|
||||
"!pip install xformers!=0.0.18 -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu118 --extra-index-url https://download.pytorch.org/whl/cu117"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -119,14 +119,30 @@
|
||||
"\n",
|
||||
"\n",
|
||||
"# ControlNet\n",
|
||||
"#!wget -c https://huggingface.co/webui/ControlNet-modules-safetensors/resolve/main/control_depth-fp16.safetensors -P ./models/controlnet/\n",
|
||||
"#!wget -c https://huggingface.co/webui/ControlNet-modules-safetensors/resolve/main/control_scribble-fp16.safetensors -P ./models/controlnet/\n",
|
||||
"#!wget -c https://huggingface.co/webui/ControlNet-modules-safetensors/resolve/main/control_openpose-fp16.safetensors -P ./models/controlnet/\n",
|
||||
"#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11e_sd15_ip2p_fp16.safetensors -P ./models/controlnet/\n",
|
||||
"#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11e_sd15_shuffle_fp16.safetensors -P ./models/controlnet/\n",
|
||||
"#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11p_sd15_canny_fp16.safetensors -P ./models/controlnet/\n",
|
||||
"#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11f1p_sd15_depth_fp16.safetensors -P ./models/controlnet/\n",
|
||||
"#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11p_sd15_inpaint_fp16.safetensors -P ./models/controlnet/\n",
|
||||
"#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11p_sd15_lineart_fp16.safetensors -P ./models/controlnet/\n",
|
||||
"#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11p_sd15_mlsd_fp16.safetensors -P ./models/controlnet/\n",
|
||||
"#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11p_sd15_normalbae_fp16.safetensors -P ./models/controlnet/\n",
|
||||
"#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11p_sd15_openpose_fp16.safetensors -P ./models/controlnet/\n",
|
||||
"#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11p_sd15_scribble_fp16.safetensors -P ./models/controlnet/\n",
|
||||
"#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11p_sd15_seg_fp16.safetensors -P ./models/controlnet/\n",
|
||||
"#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11p_sd15_softedge_fp16.safetensors -P ./models/controlnet/\n",
|
||||
"#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11p_sd15s2_lineart_anime_fp16.safetensors -P ./models/controlnet/\n",
|
||||
"#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11u_sd15_tile_fp16.safetensors -P ./models/controlnet/\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# Controlnet Preprocessor nodes by Fannovel16\n",
|
||||
"#!cd custom_nodes && git clone https://github.com/Fannovel16/comfy_controlnet_preprocessors; cd comfy_controlnet_preprocessors && python install.py\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# GLIGEN\n",
|
||||
"#!wget -c https://huggingface.co/comfyanonymous/GLIGEN_pruned_safetensors/resolve/main/gligen_sd14_textbox_pruned_fp16.safetensors -P ./models/gligen/\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# ESRGAN upscale model\n",
|
||||
"#!wget -c https://huggingface.co/sberbank-ai/Real-ESRGAN/resolve/main/RealESRGAN_x2.pth -P ./models/upscale_models/\n",
|
||||
"#!wget -c https://huggingface.co/sberbank-ai/Real-ESRGAN/resolve/main/RealESRGAN_x4.pth -P ./models/upscale_models/\n",
|
||||
@ -159,6 +175,8 @@
|
||||
"import threading\n",
|
||||
"import time\n",
|
||||
"import socket\n",
|
||||
"import urllib.request\n",
|
||||
"\n",
|
||||
"def iframe_thread(port):\n",
|
||||
" while True:\n",
|
||||
" time.sleep(0.5)\n",
|
||||
@ -167,7 +185,9 @@
|
||||
" if result == 0:\n",
|
||||
" break\n",
|
||||
" sock.close()\n",
|
||||
" print(\"\\nComfyUI finished loading, trying to launch localtunnel (if it gets stuck here localtunnel is having issues)\")\n",
|
||||
" print(\"\\nComfyUI finished loading, trying to launch localtunnel (if it gets stuck here localtunnel is having issues)\\n\")\n",
|
||||
"\n",
|
||||
" print(\"The password/enpoint ip for localtunnel is:\", urllib.request.urlopen('https://ipv4.icanhazip.com').read().decode('utf8').strip(\"\\n\"))\n",
|
||||
" p = subprocess.Popen([\"lt\", \"--port\", \"{}\".format(port)], stdout=subprocess.PIPE)\n",
|
||||
" for line in p.stdout:\n",
|
||||
" print(line.decode(), end='')\n",
|
||||
|
||||
155
server.py
155
server.py
@ -7,6 +7,9 @@ import execution
|
||||
import uuid
|
||||
import json
|
||||
import glob
|
||||
from PIL import Image
|
||||
from io import BytesIO
|
||||
|
||||
try:
|
||||
import aiohttp
|
||||
from aiohttp import web
|
||||
@ -78,7 +81,7 @@ class PromptServer():
|
||||
# Reusing existing session, remove old
|
||||
self.sockets.pop(sid, None)
|
||||
else:
|
||||
sid = uuid.uuid4().hex
|
||||
sid = uuid.uuid4().hex
|
||||
|
||||
self.sockets[sid] = ws
|
||||
|
||||
@ -110,42 +113,96 @@ class PromptServer():
|
||||
files = glob.glob(os.path.join(self.web_root, 'extensions/**/*.js'), recursive=True)
|
||||
return web.json_response(list(map(lambda f: "/" + os.path.relpath(f, self.web_root).replace("\\", "/"), files)))
|
||||
|
||||
@routes.post("/upload/image")
|
||||
async def upload_image(request):
|
||||
upload_dir = folder_paths.get_input_directory()
|
||||
def get_dir_by_type(dir_type):
|
||||
if dir_type is None:
|
||||
dir_type = "input"
|
||||
|
||||
if not os.path.exists(upload_dir):
|
||||
os.makedirs(upload_dir)
|
||||
|
||||
post = await request.post()
|
||||
if dir_type == "input":
|
||||
type_dir = folder_paths.get_input_directory()
|
||||
elif dir_type == "temp":
|
||||
type_dir = folder_paths.get_temp_directory()
|
||||
elif dir_type == "output":
|
||||
type_dir = folder_paths.get_output_directory()
|
||||
|
||||
return type_dir, dir_type
|
||||
|
||||
def image_upload(post, image_save_function=None):
|
||||
image = post.get("image")
|
||||
overwrite = post.get("overwrite")
|
||||
|
||||
image_upload_type = post.get("type")
|
||||
upload_dir, image_upload_type = get_dir_by_type(image_upload_type)
|
||||
|
||||
if image and image.file:
|
||||
filename = image.filename
|
||||
if not filename:
|
||||
return web.Response(status=400)
|
||||
|
||||
subfolder = post.get("subfolder", "")
|
||||
full_output_folder = os.path.join(upload_dir, os.path.normpath(subfolder))
|
||||
|
||||
if os.path.commonpath((upload_dir, os.path.abspath(full_output_folder))) != upload_dir:
|
||||
return web.Response(status=400)
|
||||
|
||||
if not os.path.exists(full_output_folder):
|
||||
os.makedirs(full_output_folder)
|
||||
|
||||
split = os.path.splitext(filename)
|
||||
i = 1
|
||||
while os.path.exists(os.path.join(upload_dir, filename)):
|
||||
filename = f"{split[0]} ({i}){split[1]}"
|
||||
i += 1
|
||||
filepath = os.path.join(full_output_folder, filename)
|
||||
|
||||
filepath = os.path.join(upload_dir, filename)
|
||||
if overwrite is not None and (overwrite == "true" or overwrite == "1"):
|
||||
pass
|
||||
else:
|
||||
i = 1
|
||||
while os.path.exists(filepath):
|
||||
filename = f"{split[0]} ({i}){split[1]}"
|
||||
filepath = os.path.join(full_output_folder, filename)
|
||||
i += 1
|
||||
|
||||
with open(filepath, "wb") as f:
|
||||
f.write(image.file.read())
|
||||
|
||||
return web.json_response({"name" : filename})
|
||||
if image_save_function is not None:
|
||||
image_save_function(image, post, filepath)
|
||||
else:
|
||||
with open(filepath, "wb") as f:
|
||||
f.write(image.file.read())
|
||||
|
||||
return web.json_response({"name" : filename, "subfolder": subfolder, "type": image_upload_type})
|
||||
else:
|
||||
return web.Response(status=400)
|
||||
|
||||
@routes.post("/upload/image")
|
||||
async def upload_image(request):
|
||||
post = await request.post()
|
||||
return image_upload(post)
|
||||
|
||||
@routes.post("/upload/mask")
|
||||
async def upload_mask(request):
|
||||
post = await request.post()
|
||||
|
||||
def image_save_function(image, post, filepath):
|
||||
original_pil = Image.open(post.get("original_image").file).convert('RGBA')
|
||||
mask_pil = Image.open(image.file).convert('RGBA')
|
||||
|
||||
# alpha copy
|
||||
new_alpha = mask_pil.getchannel('A')
|
||||
original_pil.putalpha(new_alpha)
|
||||
original_pil.save(filepath, compress_level=4)
|
||||
|
||||
return image_upload(post, image_save_function)
|
||||
|
||||
@routes.get("/view")
|
||||
async def view_image(request):
|
||||
if "filename" in request.rel_url.query:
|
||||
type = request.rel_url.query.get("type", "output")
|
||||
output_dir = folder_paths.get_directory_by_type(type)
|
||||
filename = request.rel_url.query["filename"]
|
||||
filename,output_dir = folder_paths.annotated_filepath(filename)
|
||||
|
||||
# validation for security: prevent accessing arbitrary path
|
||||
if filename[0] == '/' or '..' in filename:
|
||||
return web.Response(status=400)
|
||||
|
||||
if output_dir is None:
|
||||
type = request.rel_url.query.get("type", "output")
|
||||
output_dir = folder_paths.get_directory_by_type(type)
|
||||
|
||||
if output_dir is None:
|
||||
return web.Response(status=400)
|
||||
|
||||
@ -155,13 +212,49 @@ class PromptServer():
|
||||
return web.Response(status=403)
|
||||
output_dir = full_output_dir
|
||||
|
||||
filename = request.rel_url.query["filename"]
|
||||
filename = os.path.basename(filename)
|
||||
file = os.path.join(output_dir, filename)
|
||||
|
||||
if os.path.isfile(file):
|
||||
return web.FileResponse(file, headers={"Content-Disposition": f"filename=\"{filename}\""})
|
||||
|
||||
if 'channel' not in request.rel_url.query:
|
||||
channel = 'rgba'
|
||||
else:
|
||||
channel = request.rel_url.query["channel"]
|
||||
|
||||
if channel == 'rgb':
|
||||
with Image.open(file) as img:
|
||||
if img.mode == "RGBA":
|
||||
r, g, b, a = img.split()
|
||||
new_img = Image.merge('RGB', (r, g, b))
|
||||
else:
|
||||
new_img = img.convert("RGB")
|
||||
|
||||
buffer = BytesIO()
|
||||
new_img.save(buffer, format='PNG')
|
||||
buffer.seek(0)
|
||||
|
||||
return web.Response(body=buffer.read(), content_type='image/png',
|
||||
headers={"Content-Disposition": f"filename=\"{filename}\""})
|
||||
|
||||
elif channel == 'a':
|
||||
with Image.open(file) as img:
|
||||
if img.mode == "RGBA":
|
||||
_, _, _, a = img.split()
|
||||
else:
|
||||
a = Image.new('L', img.size, 255)
|
||||
|
||||
# alpha img
|
||||
alpha_img = Image.new('RGBA', img.size)
|
||||
alpha_img.putalpha(a)
|
||||
alpha_buffer = BytesIO()
|
||||
alpha_img.save(alpha_buffer, format='PNG')
|
||||
alpha_buffer.seek(0)
|
||||
|
||||
return web.Response(body=alpha_buffer.read(), content_type='image/png',
|
||||
headers={"Content-Disposition": f"filename=\"{filename}\""})
|
||||
else:
|
||||
return web.FileResponse(file, headers={"Content-Disposition": f"filename=\"{filename}\""})
|
||||
|
||||
return web.Response(status=404)
|
||||
|
||||
@routes.get("/prompt")
|
||||
@ -176,6 +269,7 @@ class PromptServer():
|
||||
info = {}
|
||||
info['input'] = obj_class.INPUT_TYPES()
|
||||
info['output'] = obj_class.RETURN_TYPES
|
||||
info['output_is_list'] = obj_class.OUTPUT_IS_LIST if hasattr(obj_class, 'OUTPUT_IS_LIST') else [False] * len(obj_class.RETURN_TYPES)
|
||||
info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output']
|
||||
info['name'] = x
|
||||
info['display_name'] = nodes.NODE_DISPLAY_NAME_MAPPINGS[x] if x in nodes.NODE_DISPLAY_NAME_MAPPINGS.keys() else x
|
||||
@ -225,14 +319,15 @@ class PromptServer():
|
||||
if "client_id" in json_data:
|
||||
extra_data["client_id"] = json_data["client_id"]
|
||||
if valid[0]:
|
||||
self.prompt_queue.put((number, id(prompt), prompt, extra_data))
|
||||
prompt_id = str(uuid.uuid4())
|
||||
self.prompt_queue.put((number, prompt_id, prompt, extra_data, valid[2]))
|
||||
return web.json_response({"prompt_id": prompt_id})
|
||||
else:
|
||||
resp_code = 400
|
||||
out_string = valid[1]
|
||||
print("invalid prompt:", valid[1])
|
||||
return web.json_response({"error": valid[1]}, status=400)
|
||||
else:
|
||||
return web.json_response({"error": "no prompt"}, status=400)
|
||||
|
||||
return web.Response(body=out_string, status=resp_code)
|
||||
|
||||
@routes.post("/queue")
|
||||
async def post_queue(request):
|
||||
json_data = await request.json()
|
||||
@ -242,9 +337,9 @@ class PromptServer():
|
||||
if "delete" in json_data:
|
||||
to_delete = json_data['delete']
|
||||
for id_to_delete in to_delete:
|
||||
delete_func = lambda a: a[1] == int(id_to_delete)
|
||||
delete_func = lambda a: a[1] == id_to_delete
|
||||
self.prompt_queue.delete_queue_item(delete_func)
|
||||
|
||||
|
||||
return web.Response(status=200)
|
||||
|
||||
@routes.post("/interrupt")
|
||||
@ -268,7 +363,7 @@ class PromptServer():
|
||||
def add_routes(self):
|
||||
self.app.add_routes(self.routes)
|
||||
self.app.add_routes([
|
||||
web.static('/', self.web_root),
|
||||
web.static('/', self.web_root, follow_symlinks=True),
|
||||
])
|
||||
|
||||
def get_queue_info(self):
|
||||
|
||||
166
web/extensions/core/clipspace.js
Normal file
166
web/extensions/core/clipspace.js
Normal file
@ -0,0 +1,166 @@
|
||||
import { app } from "/scripts/app.js";
|
||||
import { ComfyDialog, $el } from "/scripts/ui.js";
|
||||
import { ComfyApp } from "/scripts/app.js";
|
||||
|
||||
export class ClipspaceDialog extends ComfyDialog {
|
||||
static items = [];
|
||||
static instance = null;
|
||||
|
||||
static registerButton(name, contextPredicate, callback) {
|
||||
const item =
|
||||
$el("button", {
|
||||
type: "button",
|
||||
textContent: name,
|
||||
contextPredicate: contextPredicate,
|
||||
onclick: callback
|
||||
})
|
||||
|
||||
ClipspaceDialog.items.push(item);
|
||||
}
|
||||
|
||||
static invalidatePreview() {
|
||||
if(ComfyApp.clipspace && ComfyApp.clipspace.imgs && ComfyApp.clipspace.imgs.length > 0) {
|
||||
const img_preview = document.getElementById("clipspace_preview");
|
||||
if(img_preview) {
|
||||
img_preview.src = ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src;
|
||||
img_preview.style.maxHeight = "100%";
|
||||
img_preview.style.maxWidth = "100%";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static invalidate() {
|
||||
if(ClipspaceDialog.instance) {
|
||||
const self = ClipspaceDialog.instance;
|
||||
// allow reconstruct controls when copying from non-image to image content.
|
||||
const children = $el("div.comfy-modal-content", [ self.createImgSettings(), ...self.createButtons() ]);
|
||||
|
||||
if(self.element) {
|
||||
// update
|
||||
self.element.removeChild(self.element.firstChild);
|
||||
self.element.appendChild(children);
|
||||
}
|
||||
else {
|
||||
// new
|
||||
self.element = $el("div.comfy-modal", { parent: document.body }, [children,]);
|
||||
}
|
||||
|
||||
if(self.element.children[0].children.length <= 1) {
|
||||
self.element.children[0].appendChild($el("p", {}, ["Unable to find the features to edit content of a format stored in the current Clipspace."]));
|
||||
}
|
||||
|
||||
ClipspaceDialog.invalidatePreview();
|
||||
}
|
||||
}
|
||||
|
||||
constructor() {
|
||||
super();
|
||||
}
|
||||
|
||||
createButtons(self) {
|
||||
const buttons = [];
|
||||
|
||||
for(let idx in ClipspaceDialog.items) {
|
||||
const item = ClipspaceDialog.items[idx];
|
||||
if(!item.contextPredicate || item.contextPredicate())
|
||||
buttons.push(ClipspaceDialog.items[idx]);
|
||||
}
|
||||
|
||||
buttons.push(
|
||||
$el("button", {
|
||||
type: "button",
|
||||
textContent: "Close",
|
||||
onclick: () => { this.close(); }
|
||||
})
|
||||
);
|
||||
|
||||
return buttons;
|
||||
}
|
||||
|
||||
createImgSettings() {
|
||||
if(ComfyApp.clipspace.imgs) {
|
||||
const combo_items = [];
|
||||
const imgs = ComfyApp.clipspace.imgs;
|
||||
|
||||
for(let i=0; i < imgs.length; i++) {
|
||||
combo_items.push($el("option", {value:i}, [`${i}`]));
|
||||
}
|
||||
|
||||
const combo1 = $el("select",
|
||||
{id:"clipspace_img_selector", onchange:(event) => {
|
||||
ComfyApp.clipspace['selectedIndex'] = event.target.selectedIndex;
|
||||
ClipspaceDialog.invalidatePreview();
|
||||
} }, combo_items);
|
||||
|
||||
const row1 =
|
||||
$el("tr", {},
|
||||
[
|
||||
$el("td", {}, [$el("font", {color:"white"}, ["Select Image"])]),
|
||||
$el("td", {}, [combo1])
|
||||
]);
|
||||
|
||||
|
||||
const combo2 = $el("select",
|
||||
{id:"clipspace_img_paste_mode", onchange:(event) => {
|
||||
ComfyApp.clipspace['img_paste_mode'] = event.target.value;
|
||||
} },
|
||||
[
|
||||
$el("option", {value:'selected'}, 'selected'),
|
||||
$el("option", {value:'all'}, 'all')
|
||||
]);
|
||||
combo2.value = ComfyApp.clipspace['img_paste_mode'];
|
||||
|
||||
const row2 =
|
||||
$el("tr", {},
|
||||
[
|
||||
$el("td", {}, [$el("font", {color:"white"}, ["Paste Mode"])]),
|
||||
$el("td", {}, [combo2])
|
||||
]);
|
||||
|
||||
const td = $el("td", {align:'center', width:'100px', height:'100px', colSpan:'2'},
|
||||
[ $el("img",{id:"clipspace_preview", ondragstart:() => false},[]) ]);
|
||||
|
||||
const row3 =
|
||||
$el("tr", {}, [td]);
|
||||
|
||||
return $el("table", {}, [row1, row2, row3]);
|
||||
}
|
||||
else {
|
||||
return [];
|
||||
}
|
||||
}
|
||||
|
||||
createImgPreview() {
|
||||
if(ComfyApp.clipspace.imgs) {
|
||||
return $el("img",{id:"clipspace_preview", ondragstart:() => false});
|
||||
}
|
||||
else
|
||||
return [];
|
||||
}
|
||||
|
||||
show() {
|
||||
const img_preview = document.getElementById("clipspace_preview");
|
||||
ClipspaceDialog.invalidate();
|
||||
|
||||
this.element.style.display = "block";
|
||||
}
|
||||
}
|
||||
|
||||
app.registerExtension({
|
||||
name: "Comfy.Clipspace",
|
||||
init(app) {
|
||||
app.openClipspace =
|
||||
function () {
|
||||
if(!ClipspaceDialog.instance) {
|
||||
ClipspaceDialog.instance = new ClipspaceDialog(app);
|
||||
ComfyApp.clipspace_invalidate_handler = ClipspaceDialog.invalidate;
|
||||
}
|
||||
|
||||
if(ComfyApp.clipspace) {
|
||||
ClipspaceDialog.instance.show();
|
||||
}
|
||||
else
|
||||
app.ui.dialog.show("Clipspace is Empty!");
|
||||
};
|
||||
}
|
||||
});
|
||||
@ -107,7 +107,7 @@ const colorPalettes = {
|
||||
"descrip-text": "#444",
|
||||
"drag-text": "#555",
|
||||
"error-text": "#F44336",
|
||||
"border-color": "#CCC"
|
||||
"border-color": "#888"
|
||||
}
|
||||
},
|
||||
},
|
||||
@ -232,10 +232,27 @@ app.registerExtension({
|
||||
"name": "My Color Palette",
|
||||
"colors": {
|
||||
"node_slot": {
|
||||
},
|
||||
"litegraph_base": {
|
||||
},
|
||||
"comfy_base": {
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Copy over missing keys from default color palette
|
||||
const defaultColorPalette = colorPalettes[defaultColorPaletteId];
|
||||
for (const key in defaultColorPalette.colors.litegraph_base) {
|
||||
if (!colorPalette.colors.litegraph_base[key]) {
|
||||
colorPalette.colors.litegraph_base[key] = "";
|
||||
}
|
||||
}
|
||||
for (const key in defaultColorPalette.colors.comfy_base) {
|
||||
if (!colorPalette.colors.comfy_base[key]) {
|
||||
colorPalette.colors.comfy_base[key] = "";
|
||||
}
|
||||
}
|
||||
|
||||
return completeColorPalette(colorPalette);
|
||||
};
|
||||
|
||||
|
||||
144
web/extensions/core/editAttention.js
Normal file
144
web/extensions/core/editAttention.js
Normal file
@ -0,0 +1,144 @@
|
||||
import { app } from "/scripts/app.js";
|
||||
|
||||
// Allows you to edit the attention weight by holding ctrl (or cmd) and using the up/down arrow keys
|
||||
|
||||
app.registerExtension({
|
||||
name: "Comfy.EditAttention",
|
||||
init() {
|
||||
const editAttentionDelta = app.ui.settings.addSetting({
|
||||
id: "Comfy.EditAttention.Delta",
|
||||
name: "Ctrl+up/down precision",
|
||||
type: "slider",
|
||||
attrs: {
|
||||
min: 0.01,
|
||||
max: 0.5,
|
||||
step: 0.01,
|
||||
},
|
||||
defaultValue: 0.05,
|
||||
});
|
||||
|
||||
function incrementWeight(weight, delta) {
|
||||
const floatWeight = parseFloat(weight);
|
||||
if (isNaN(floatWeight)) return weight;
|
||||
const newWeight = floatWeight + delta;
|
||||
if (newWeight < 0) return "0";
|
||||
return String(Number(newWeight.toFixed(10)));
|
||||
}
|
||||
|
||||
function findNearestEnclosure(text, cursorPos) {
|
||||
let start = cursorPos, end = cursorPos;
|
||||
let openCount = 0, closeCount = 0;
|
||||
|
||||
// Find opening parenthesis before cursor
|
||||
while (start >= 0) {
|
||||
start--;
|
||||
if (text[start] === "(" && openCount === closeCount) break;
|
||||
if (text[start] === "(") openCount++;
|
||||
if (text[start] === ")") closeCount++;
|
||||
}
|
||||
if (start < 0) return false;
|
||||
|
||||
openCount = 0;
|
||||
closeCount = 0;
|
||||
|
||||
// Find closing parenthesis after cursor
|
||||
while (end < text.length) {
|
||||
if (text[end] === ")" && openCount === closeCount) break;
|
||||
if (text[end] === "(") openCount++;
|
||||
if (text[end] === ")") closeCount++;
|
||||
end++;
|
||||
}
|
||||
if (end === text.length) return false;
|
||||
|
||||
return { start: start + 1, end: end };
|
||||
}
|
||||
|
||||
function addWeightToParentheses(text) {
|
||||
const parenRegex = /^\((.*)\)$/;
|
||||
const parenMatch = text.match(parenRegex);
|
||||
|
||||
const floatRegex = /:([+-]?(\d*\.)?\d+([eE][+-]?\d+)?)/;
|
||||
const floatMatch = text.match(floatRegex);
|
||||
|
||||
if (parenMatch && !floatMatch) {
|
||||
return `(${parenMatch[1]}:1.0)`;
|
||||
} else {
|
||||
return text;
|
||||
}
|
||||
};
|
||||
|
||||
function editAttention(event) {
|
||||
const inputField = event.composedPath()[0];
|
||||
const delta = parseFloat(editAttentionDelta.value);
|
||||
|
||||
if (inputField.tagName !== "TEXTAREA") return;
|
||||
if (!(event.key === "ArrowUp" || event.key === "ArrowDown")) return;
|
||||
if (!event.ctrlKey && !event.metaKey) return;
|
||||
|
||||
event.preventDefault();
|
||||
|
||||
let start = inputField.selectionStart;
|
||||
let end = inputField.selectionEnd;
|
||||
let selectedText = inputField.value.substring(start, end);
|
||||
|
||||
// If there is no selection, attempt to find the nearest enclosure, or select the current word
|
||||
if (!selectedText) {
|
||||
const nearestEnclosure = findNearestEnclosure(inputField.value, start);
|
||||
if (nearestEnclosure) {
|
||||
start = nearestEnclosure.start;
|
||||
end = nearestEnclosure.end;
|
||||
selectedText = inputField.value.substring(start, end);
|
||||
} else {
|
||||
// Select the current word, find the start and end of the word
|
||||
const delimiters = " .,\\/!?%^*;:{}=-_`~()\r\n\t";
|
||||
|
||||
while (!delimiters.includes(inputField.value[start - 1]) && start > 0) {
|
||||
start--;
|
||||
}
|
||||
|
||||
while (!delimiters.includes(inputField.value[end]) && end < inputField.value.length) {
|
||||
end++;
|
||||
}
|
||||
|
||||
selectedText = inputField.value.substring(start, end);
|
||||
if (!selectedText) return;
|
||||
}
|
||||
}
|
||||
|
||||
// If the selection ends with a space, remove it
|
||||
if (selectedText[selectedText.length - 1] === " ") {
|
||||
selectedText = selectedText.substring(0, selectedText.length - 1);
|
||||
end -= 1;
|
||||
}
|
||||
|
||||
// If there are parentheses left and right of the selection, select them
|
||||
if (inputField.value[start - 1] === "(" && inputField.value[end] === ")") {
|
||||
start -= 1;
|
||||
end += 1;
|
||||
selectedText = inputField.value.substring(start, end);
|
||||
}
|
||||
|
||||
// If the selection is not enclosed in parentheses, add them
|
||||
if (selectedText[0] !== "(" || selectedText[selectedText.length - 1] !== ")") {
|
||||
selectedText = `(${selectedText})`;
|
||||
}
|
||||
|
||||
// If the selection does not have a weight, add a weight of 1.0
|
||||
selectedText = addWeightToParentheses(selectedText);
|
||||
|
||||
// Increment the weight
|
||||
const weightDelta = event.key === "ArrowUp" ? delta : -delta;
|
||||
const updatedText = selectedText.replace(/\((.*):(\d+(?:\.\d+)?)\)/, (match, text, weight) => {
|
||||
weight = incrementWeight(weight, weightDelta);
|
||||
if (weight == 1) {
|
||||
return text;
|
||||
} else {
|
||||
return `(${text}:${weight})`;
|
||||
}
|
||||
});
|
||||
|
||||
inputField.setRangeText(updatedText, start, end, "select");
|
||||
}
|
||||
window.addEventListener("keydown", editAttention);
|
||||
},
|
||||
});
|
||||
76
web/extensions/core/keybinds.js
Normal file
76
web/extensions/core/keybinds.js
Normal file
@ -0,0 +1,76 @@
|
||||
import { app } from "/scripts/app.js";
|
||||
|
||||
const id = "Comfy.Keybinds";
|
||||
app.registerExtension({
|
||||
name: id,
|
||||
init() {
|
||||
const keybindListener = function(event) {
|
||||
const modifierPressed = event.ctrlKey || event.metaKey;
|
||||
|
||||
// Queue prompt using ctrl or command + enter
|
||||
if (modifierPressed && (event.key === "Enter" || event.keyCode === 13 || event.keyCode === 10)) {
|
||||
app.queuePrompt(event.shiftKey ? -1 : 0);
|
||||
return;
|
||||
}
|
||||
|
||||
const target = event.composedPath()[0];
|
||||
|
||||
if (target.tagName === "INPUT" || target.tagName === "TEXTAREA") {
|
||||
return;
|
||||
}
|
||||
|
||||
const modifierKeyIdMap = {
|
||||
"s": "#comfy-save-button",
|
||||
83: "#comfy-save-button",
|
||||
"o": "#comfy-file-input",
|
||||
79: "#comfy-file-input",
|
||||
"Backspace": "#comfy-clear-button",
|
||||
8: "#comfy-clear-button",
|
||||
"Delete": "#comfy-clear-button",
|
||||
46: "#comfy-clear-button",
|
||||
"d": "#comfy-load-default-button",
|
||||
68: "#comfy-load-default-button",
|
||||
};
|
||||
|
||||
const modifierKeybindId = modifierKeyIdMap[event.key] || modifierKeyIdMap[event.keyCode];
|
||||
if (modifierPressed && modifierKeybindId) {
|
||||
event.preventDefault();
|
||||
|
||||
const elem = document.querySelector(modifierKeybindId);
|
||||
elem.click();
|
||||
return;
|
||||
}
|
||||
|
||||
// Finished Handling all modifier keybinds, now handle the rest
|
||||
if (event.ctrlKey || event.altKey || event.metaKey) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Close out of modals using escape
|
||||
if (event.key === "Escape" || event.keyCode === 27) {
|
||||
const modals = document.querySelectorAll(".comfy-modal");
|
||||
const modal = Array.from(modals).find(modal => window.getComputedStyle(modal).getPropertyValue("display") !== "none");
|
||||
if (modal) {
|
||||
modal.style.display = "none";
|
||||
}
|
||||
}
|
||||
|
||||
const keyIdMap = {
|
||||
"q": "#comfy-view-queue-button",
|
||||
81: "#comfy-view-queue-button",
|
||||
"h": "#comfy-view-history-button",
|
||||
72: "#comfy-view-history-button",
|
||||
"r": "#comfy-refresh-button",
|
||||
82: "#comfy-refresh-button",
|
||||
};
|
||||
|
||||
const buttonId = keyIdMap[event.key] || keyIdMap[event.keyCode];
|
||||
if (buttonId) {
|
||||
const button = document.querySelector(buttonId);
|
||||
button.click();
|
||||
}
|
||||
}
|
||||
|
||||
window.addEventListener("keydown", keybindListener, true);
|
||||
}
|
||||
});
|
||||
648
web/extensions/core/maskeditor.js
Normal file
648
web/extensions/core/maskeditor.js
Normal file
@ -0,0 +1,648 @@
|
||||
import { app } from "/scripts/app.js";
|
||||
import { ComfyDialog, $el } from "/scripts/ui.js";
|
||||
import { ComfyApp } from "/scripts/app.js";
|
||||
import { ClipspaceDialog } from "/extensions/core/clipspace.js";
|
||||
|
||||
// Helper function to convert a data URL to a Blob object
|
||||
function dataURLToBlob(dataURL) {
|
||||
const parts = dataURL.split(';base64,');
|
||||
const contentType = parts[0].split(':')[1];
|
||||
const byteString = atob(parts[1]);
|
||||
const arrayBuffer = new ArrayBuffer(byteString.length);
|
||||
const uint8Array = new Uint8Array(arrayBuffer);
|
||||
for (let i = 0; i < byteString.length; i++) {
|
||||
uint8Array[i] = byteString.charCodeAt(i);
|
||||
}
|
||||
return new Blob([arrayBuffer], { type: contentType });
|
||||
}
|
||||
|
||||
function loadedImageToBlob(image) {
|
||||
const canvas = document.createElement('canvas');
|
||||
|
||||
canvas.width = image.width;
|
||||
canvas.height = image.height;
|
||||
|
||||
const ctx = canvas.getContext('2d');
|
||||
|
||||
ctx.drawImage(image, 0, 0);
|
||||
|
||||
const dataURL = canvas.toDataURL('image/png', 1);
|
||||
const blob = dataURLToBlob(dataURL);
|
||||
|
||||
return blob;
|
||||
}
|
||||
|
||||
async function uploadMask(filepath, formData) {
|
||||
await fetch('/upload/mask', {
|
||||
method: 'POST',
|
||||
body: formData
|
||||
}).then(response => {}).catch(error => {
|
||||
console.error('Error:', error);
|
||||
});
|
||||
|
||||
ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']] = new Image();
|
||||
ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src = "/view?" + new URLSearchParams(filepath).toString();
|
||||
|
||||
if(ComfyApp.clipspace.images)
|
||||
ComfyApp.clipspace.images[ComfyApp.clipspace['selectedIndex']] = filepath;
|
||||
|
||||
ClipspaceDialog.invalidatePreview();
|
||||
}
|
||||
|
||||
function prepareRGB(image, backupCanvas, backupCtx) {
|
||||
// paste mask data into alpha channel
|
||||
backupCtx.drawImage(image, 0, 0, backupCanvas.width, backupCanvas.height);
|
||||
const backupData = backupCtx.getImageData(0, 0, backupCanvas.width, backupCanvas.height);
|
||||
|
||||
// refine mask image
|
||||
for (let i = 0; i < backupData.data.length; i += 4) {
|
||||
if(backupData.data[i+3] == 255)
|
||||
backupData.data[i+3] = 0;
|
||||
else
|
||||
backupData.data[i+3] = 255;
|
||||
|
||||
backupData.data[i] = 0;
|
||||
backupData.data[i+1] = 0;
|
||||
backupData.data[i+2] = 0;
|
||||
}
|
||||
|
||||
backupCtx.globalCompositeOperation = 'source-over';
|
||||
backupCtx.putImageData(backupData, 0, 0);
|
||||
}
|
||||
|
||||
class MaskEditorDialog extends ComfyDialog {
|
||||
static instance = null;
|
||||
|
||||
static getInstance() {
|
||||
if(!MaskEditorDialog.instance) {
|
||||
MaskEditorDialog.instance = new MaskEditorDialog(app);
|
||||
}
|
||||
|
||||
return MaskEditorDialog.instance;
|
||||
}
|
||||
|
||||
is_layout_created = false;
|
||||
|
||||
constructor() {
|
||||
super();
|
||||
this.element = $el("div.comfy-modal", { parent: document.body },
|
||||
[ $el("div.comfy-modal-content",
|
||||
[...this.createButtons()]),
|
||||
]);
|
||||
}
|
||||
|
||||
createButtons() {
|
||||
return [];
|
||||
}
|
||||
|
||||
createButton(name, callback) {
|
||||
var button = document.createElement("button");
|
||||
button.innerText = name;
|
||||
button.addEventListener("click", callback);
|
||||
return button;
|
||||
}
|
||||
|
||||
createLeftButton(name, callback) {
|
||||
var button = this.createButton(name, callback);
|
||||
button.style.cssFloat = "left";
|
||||
button.style.marginRight = "4px";
|
||||
return button;
|
||||
}
|
||||
|
||||
createRightButton(name, callback) {
|
||||
var button = this.createButton(name, callback);
|
||||
button.style.cssFloat = "right";
|
||||
button.style.marginLeft = "4px";
|
||||
return button;
|
||||
}
|
||||
|
||||
createLeftSlider(self, name, callback) {
|
||||
const divElement = document.createElement('div');
|
||||
divElement.id = "maskeditor-slider";
|
||||
divElement.style.cssFloat = "left";
|
||||
divElement.style.fontFamily = "sans-serif";
|
||||
divElement.style.marginRight = "4px";
|
||||
divElement.style.color = "var(--input-text)";
|
||||
divElement.style.backgroundColor = "var(--comfy-input-bg)";
|
||||
divElement.style.borderRadius = "8px";
|
||||
divElement.style.borderColor = "var(--border-color)";
|
||||
divElement.style.borderStyle = "solid";
|
||||
divElement.style.fontSize = "15px";
|
||||
divElement.style.height = "21px";
|
||||
divElement.style.padding = "1px 6px";
|
||||
divElement.style.display = "flex";
|
||||
divElement.style.position = "relative";
|
||||
divElement.style.top = "2px";
|
||||
self.brush_slider_input = document.createElement('input');
|
||||
self.brush_slider_input.setAttribute('type', 'range');
|
||||
self.brush_slider_input.setAttribute('min', '1');
|
||||
self.brush_slider_input.setAttribute('max', '100');
|
||||
self.brush_slider_input.setAttribute('value', '10');
|
||||
const labelElement = document.createElement("label");
|
||||
labelElement.textContent = name;
|
||||
|
||||
divElement.appendChild(labelElement);
|
||||
divElement.appendChild(self.brush_slider_input);
|
||||
|
||||
self.brush_slider_input.addEventListener("change", callback);
|
||||
|
||||
return divElement;
|
||||
}
|
||||
|
||||
setlayout(imgCanvas, maskCanvas) {
|
||||
const self = this;
|
||||
|
||||
// If it is specified as relative, using it only as a hidden placeholder for padding is recommended
|
||||
// to prevent anomalies where it exceeds a certain size and goes outside of the window.
|
||||
var placeholder = document.createElement("div");
|
||||
placeholder.style.position = "relative";
|
||||
placeholder.style.height = "50px";
|
||||
|
||||
var bottom_panel = document.createElement("div");
|
||||
bottom_panel.style.position = "absolute";
|
||||
bottom_panel.style.bottom = "0px";
|
||||
bottom_panel.style.left = "20px";
|
||||
bottom_panel.style.right = "20px";
|
||||
bottom_panel.style.height = "50px";
|
||||
|
||||
var brush = document.createElement("div");
|
||||
brush.id = "brush";
|
||||
brush.style.backgroundColor = "transparent";
|
||||
brush.style.outline = "1px dashed black";
|
||||
brush.style.boxShadow = "0 0 0 1px white";
|
||||
brush.style.borderRadius = "50%";
|
||||
brush.style.MozBorderRadius = "50%";
|
||||
brush.style.WebkitBorderRadius = "50%";
|
||||
brush.style.position = "absolute";
|
||||
brush.style.zIndex = 8889;
|
||||
brush.style.pointerEvents = "none";
|
||||
this.brush = brush;
|
||||
this.element.appendChild(imgCanvas);
|
||||
this.element.appendChild(maskCanvas);
|
||||
this.element.appendChild(placeholder); // must below z-index than bottom_panel to avoid covering button
|
||||
this.element.appendChild(bottom_panel);
|
||||
document.body.appendChild(brush);
|
||||
|
||||
var brush_size_slider = this.createLeftSlider(self, "Thickness", (event) => {
|
||||
self.brush_size = event.target.value;
|
||||
self.updateBrushPreview(self, null, null);
|
||||
});
|
||||
var clearButton = this.createLeftButton("Clear",
|
||||
() => {
|
||||
self.maskCtx.clearRect(0, 0, self.maskCanvas.width, self.maskCanvas.height);
|
||||
self.backupCtx.clearRect(0, 0, self.backupCanvas.width, self.backupCanvas.height);
|
||||
});
|
||||
var cancelButton = this.createRightButton("Cancel", () => {
|
||||
document.removeEventListener("mouseup", MaskEditorDialog.handleMouseUp);
|
||||
document.removeEventListener("keydown", MaskEditorDialog.handleKeyDown);
|
||||
self.close();
|
||||
});
|
||||
|
||||
this.saveButton = this.createRightButton("Save", () => {
|
||||
document.removeEventListener("mouseup", MaskEditorDialog.handleMouseUp);
|
||||
document.removeEventListener("keydown", MaskEditorDialog.handleKeyDown);
|
||||
self.save();
|
||||
});
|
||||
|
||||
this.element.appendChild(imgCanvas);
|
||||
this.element.appendChild(maskCanvas);
|
||||
this.element.appendChild(placeholder); // must below z-index than bottom_panel to avoid covering button
|
||||
this.element.appendChild(bottom_panel);
|
||||
|
||||
bottom_panel.appendChild(clearButton);
|
||||
bottom_panel.appendChild(this.saveButton);
|
||||
bottom_panel.appendChild(cancelButton);
|
||||
bottom_panel.appendChild(brush_size_slider);
|
||||
|
||||
imgCanvas.style.position = "relative";
|
||||
imgCanvas.style.top = "200";
|
||||
imgCanvas.style.left = "0";
|
||||
|
||||
maskCanvas.style.position = "absolute";
|
||||
}
|
||||
|
||||
show() {
|
||||
if(!this.is_layout_created) {
|
||||
// layout
|
||||
const imgCanvas = document.createElement('canvas');
|
||||
const maskCanvas = document.createElement('canvas');
|
||||
const backupCanvas = document.createElement('canvas');
|
||||
|
||||
imgCanvas.id = "imageCanvas";
|
||||
maskCanvas.id = "maskCanvas";
|
||||
backupCanvas.id = "backupCanvas";
|
||||
|
||||
this.setlayout(imgCanvas, maskCanvas);
|
||||
|
||||
// prepare content
|
||||
this.imgCanvas = imgCanvas;
|
||||
this.maskCanvas = maskCanvas;
|
||||
this.backupCanvas = backupCanvas;
|
||||
this.maskCtx = maskCanvas.getContext('2d');
|
||||
this.backupCtx = backupCanvas.getContext('2d');
|
||||
|
||||
this.setEventHandler(maskCanvas);
|
||||
|
||||
this.is_layout_created = true;
|
||||
|
||||
// replacement of onClose hook since close is not real close
|
||||
const self = this;
|
||||
const observer = new MutationObserver(function(mutations) {
|
||||
mutations.forEach(function(mutation) {
|
||||
if (mutation.type === 'attributes' && mutation.attributeName === 'style') {
|
||||
if(self.last_display_style && self.last_display_style != 'none' && self.element.style.display == 'none') {
|
||||
ComfyApp.onClipspaceEditorClosed();
|
||||
}
|
||||
|
||||
self.last_display_style = self.element.style.display;
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
const config = { attributes: true };
|
||||
observer.observe(this.element, config);
|
||||
}
|
||||
|
||||
this.setImages(this.imgCanvas, this.backupCanvas);
|
||||
|
||||
if(ComfyApp.clipspace_return_node) {
|
||||
this.saveButton.innerText = "Save to node";
|
||||
}
|
||||
else {
|
||||
this.saveButton.innerText = "Save";
|
||||
}
|
||||
this.saveButton.disabled = false;
|
||||
|
||||
this.element.style.display = "block";
|
||||
this.element.style.zIndex = 8888; // NOTE: alert dialog must be high priority.
|
||||
}
|
||||
|
||||
isOpened() {
|
||||
return this.element.style.display == "block";
|
||||
}
|
||||
|
||||
setImages(imgCanvas, backupCanvas) {
|
||||
const imgCtx = imgCanvas.getContext('2d');
|
||||
const backupCtx = backupCanvas.getContext('2d');
|
||||
const maskCtx = this.maskCtx;
|
||||
const maskCanvas = this.maskCanvas;
|
||||
|
||||
backupCtx.clearRect(0,0,this.backupCanvas.width,this.backupCanvas.height);
|
||||
imgCtx.clearRect(0,0,this.imgCanvas.width,this.imgCanvas.height);
|
||||
maskCtx.clearRect(0,0,this.maskCanvas.width,this.maskCanvas.height);
|
||||
|
||||
// image load
|
||||
const orig_image = new Image();
|
||||
window.addEventListener("resize", () => {
|
||||
// repositioning
|
||||
imgCanvas.width = window.innerWidth - 250;
|
||||
imgCanvas.height = window.innerHeight - 200;
|
||||
|
||||
// redraw image
|
||||
let drawWidth = orig_image.width;
|
||||
let drawHeight = orig_image.height;
|
||||
if (orig_image.width > imgCanvas.width) {
|
||||
drawWidth = imgCanvas.width;
|
||||
drawHeight = (drawWidth / orig_image.width) * orig_image.height;
|
||||
}
|
||||
|
||||
if (drawHeight > imgCanvas.height) {
|
||||
drawHeight = imgCanvas.height;
|
||||
drawWidth = (drawHeight / orig_image.height) * orig_image.width;
|
||||
}
|
||||
|
||||
imgCtx.drawImage(orig_image, 0, 0, drawWidth, drawHeight);
|
||||
|
||||
// update mask
|
||||
backupCtx.drawImage(maskCanvas, 0, 0, maskCanvas.width, maskCanvas.height, 0, 0, backupCanvas.width, backupCanvas.height);
|
||||
maskCanvas.width = drawWidth;
|
||||
maskCanvas.height = drawHeight;
|
||||
maskCanvas.style.top = imgCanvas.offsetTop + "px";
|
||||
maskCanvas.style.left = imgCanvas.offsetLeft + "px";
|
||||
maskCtx.drawImage(backupCanvas, 0, 0, backupCanvas.width, backupCanvas.height, 0, 0, maskCanvas.width, maskCanvas.height);
|
||||
});
|
||||
|
||||
const filepath = ComfyApp.clipspace.images;
|
||||
|
||||
const touched_image = new Image();
|
||||
|
||||
touched_image.onload = function() {
|
||||
backupCanvas.width = touched_image.width;
|
||||
backupCanvas.height = touched_image.height;
|
||||
|
||||
prepareRGB(touched_image, backupCanvas, backupCtx);
|
||||
};
|
||||
|
||||
const alpha_url = new URL(ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src)
|
||||
alpha_url.searchParams.delete('channel');
|
||||
alpha_url.searchParams.set('channel', 'a');
|
||||
touched_image.src = alpha_url;
|
||||
|
||||
// original image load
|
||||
orig_image.onload = function() {
|
||||
window.dispatchEvent(new Event('resize'));
|
||||
};
|
||||
|
||||
const rgb_url = new URL(ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src);
|
||||
rgb_url.searchParams.delete('channel');
|
||||
rgb_url.searchParams.set('channel', 'rgb');
|
||||
orig_image.src = rgb_url;
|
||||
this.image = orig_image;
|
||||
}
|
||||
|
||||
setEventHandler(maskCanvas) {
|
||||
maskCanvas.addEventListener("contextmenu", (event) => {
|
||||
event.preventDefault();
|
||||
});
|
||||
|
||||
const self = this;
|
||||
maskCanvas.addEventListener('wheel', (event) => this.handleWheelEvent(self,event));
|
||||
maskCanvas.addEventListener('pointerdown', (event) => this.handlePointerDown(self,event));
|
||||
document.addEventListener('pointerup', MaskEditorDialog.handlePointerUp);
|
||||
maskCanvas.addEventListener('pointermove', (event) => this.draw_move(self,event));
|
||||
maskCanvas.addEventListener('touchmove', (event) => this.draw_move(self,event));
|
||||
maskCanvas.addEventListener('pointerover', (event) => { this.brush.style.display = "block"; });
|
||||
maskCanvas.addEventListener('pointerleave', (event) => { this.brush.style.display = "none"; });
|
||||
document.addEventListener('keydown', MaskEditorDialog.handleKeyDown);
|
||||
}
|
||||
|
||||
brush_size = 10;
|
||||
drawing_mode = false;
|
||||
lastx = -1;
|
||||
lasty = -1;
|
||||
lasttime = 0;
|
||||
|
||||
static handleKeyDown(event) {
|
||||
const self = MaskEditorDialog.instance;
|
||||
if (event.key === ']') {
|
||||
self.brush_size = Math.min(self.brush_size+2, 100);
|
||||
} else if (event.key === '[') {
|
||||
self.brush_size = Math.max(self.brush_size-2, 1);
|
||||
} else if(event.key === 'Enter') {
|
||||
self.save();
|
||||
}
|
||||
|
||||
self.updateBrushPreview(self);
|
||||
}
|
||||
|
||||
static handlePointerUp(event) {
|
||||
event.preventDefault();
|
||||
MaskEditorDialog.instance.drawing_mode = false;
|
||||
}
|
||||
|
||||
updateBrushPreview(self) {
|
||||
const brush = self.brush;
|
||||
|
||||
var centerX = self.cursorX;
|
||||
var centerY = self.cursorY;
|
||||
|
||||
brush.style.width = self.brush_size * 2 + "px";
|
||||
brush.style.height = self.brush_size * 2 + "px";
|
||||
brush.style.left = (centerX - self.brush_size) + "px";
|
||||
brush.style.top = (centerY - self.brush_size) + "px";
|
||||
}
|
||||
|
||||
handleWheelEvent(self, event) {
|
||||
if(event.deltaY < 0)
|
||||
self.brush_size = Math.min(self.brush_size+2, 100);
|
||||
else
|
||||
self.brush_size = Math.max(self.brush_size-2, 1);
|
||||
|
||||
self.brush_slider_input.value = self.brush_size;
|
||||
|
||||
self.updateBrushPreview(self);
|
||||
}
|
||||
|
||||
draw_move(self, event) {
|
||||
event.preventDefault();
|
||||
|
||||
this.cursorX = event.pageX;
|
||||
this.cursorY = event.pageY;
|
||||
|
||||
self.updateBrushPreview(self);
|
||||
|
||||
if (window.TouchEvent && event instanceof TouchEvent || event.buttons == 1) {
|
||||
var diff = performance.now() - self.lasttime;
|
||||
|
||||
const maskRect = self.maskCanvas.getBoundingClientRect();
|
||||
|
||||
var x = event.offsetX;
|
||||
var y = event.offsetY
|
||||
|
||||
if(event.offsetX == null) {
|
||||
x = event.targetTouches[0].clientX - maskRect.left;
|
||||
}
|
||||
|
||||
if(event.offsetY == null) {
|
||||
y = event.targetTouches[0].clientY - maskRect.top;
|
||||
}
|
||||
|
||||
var brush_size = this.brush_size;
|
||||
if(event instanceof PointerEvent && event.pointerType == 'pen') {
|
||||
brush_size *= event.pressure;
|
||||
this.last_pressure = event.pressure;
|
||||
}
|
||||
else if(window.TouchEvent && event instanceof TouchEvent && diff < 20){
|
||||
// The firing interval of PointerEvents in Pen is unreliable, so it is supplemented by TouchEvents.
|
||||
brush_size *= this.last_pressure;
|
||||
}
|
||||
else {
|
||||
brush_size = this.brush_size;
|
||||
}
|
||||
|
||||
if(diff > 20 && !this.drawing_mode)
|
||||
requestAnimationFrame(() => {
|
||||
self.maskCtx.beginPath();
|
||||
self.maskCtx.fillStyle = "rgb(0,0,0)";
|
||||
self.maskCtx.globalCompositeOperation = "source-over";
|
||||
self.maskCtx.arc(x, y, brush_size, 0, Math.PI * 2, false);
|
||||
self.maskCtx.fill();
|
||||
self.lastx = x;
|
||||
self.lasty = y;
|
||||
});
|
||||
else
|
||||
requestAnimationFrame(() => {
|
||||
self.maskCtx.beginPath();
|
||||
self.maskCtx.fillStyle = "rgb(0,0,0)";
|
||||
self.maskCtx.globalCompositeOperation = "source-over";
|
||||
|
||||
var dx = x - self.lastx;
|
||||
var dy = y - self.lasty;
|
||||
|
||||
var distance = Math.sqrt(dx * dx + dy * dy);
|
||||
var directionX = dx / distance;
|
||||
var directionY = dy / distance;
|
||||
|
||||
for (var i = 0; i < distance; i+=5) {
|
||||
var px = self.lastx + (directionX * i);
|
||||
var py = self.lasty + (directionY * i);
|
||||
self.maskCtx.arc(px, py, brush_size, 0, Math.PI * 2, false);
|
||||
self.maskCtx.fill();
|
||||
}
|
||||
self.lastx = x;
|
||||
self.lasty = y;
|
||||
});
|
||||
|
||||
self.lasttime = performance.now();
|
||||
}
|
||||
else if(event.buttons == 2 || event.buttons == 5 || event.buttons == 32) {
|
||||
const maskRect = self.maskCanvas.getBoundingClientRect();
|
||||
const x = event.offsetX || event.targetTouches[0].clientX - maskRect.left;
|
||||
const y = event.offsetY || event.targetTouches[0].clientY - maskRect.top;
|
||||
|
||||
var brush_size = this.brush_size;
|
||||
if(event instanceof PointerEvent && event.pointerType == 'pen') {
|
||||
brush_size *= event.pressure;
|
||||
this.last_pressure = event.pressure;
|
||||
}
|
||||
else if(window.TouchEvent && event instanceof TouchEvent && diff < 20){
|
||||
brush_size *= this.last_pressure;
|
||||
}
|
||||
else {
|
||||
brush_size = this.brush_size;
|
||||
}
|
||||
|
||||
if(diff > 20 && !drawing_mode) // cannot tracking drawing_mode for touch event
|
||||
requestAnimationFrame(() => {
|
||||
self.maskCtx.beginPath();
|
||||
self.maskCtx.globalCompositeOperation = "destination-out";
|
||||
self.maskCtx.arc(x, y, brush_size, 0, Math.PI * 2, false);
|
||||
self.maskCtx.fill();
|
||||
self.lastx = x;
|
||||
self.lasty = y;
|
||||
});
|
||||
else
|
||||
requestAnimationFrame(() => {
|
||||
self.maskCtx.beginPath();
|
||||
self.maskCtx.globalCompositeOperation = "destination-out";
|
||||
|
||||
var dx = x - self.lastx;
|
||||
var dy = y - self.lasty;
|
||||
|
||||
var distance = Math.sqrt(dx * dx + dy * dy);
|
||||
var directionX = dx / distance;
|
||||
var directionY = dy / distance;
|
||||
|
||||
for (var i = 0; i < distance; i+=5) {
|
||||
var px = self.lastx + (directionX * i);
|
||||
var py = self.lasty + (directionY * i);
|
||||
self.maskCtx.arc(px, py, brush_size, 0, Math.PI * 2, false);
|
||||
self.maskCtx.fill();
|
||||
}
|
||||
self.lastx = x;
|
||||
self.lasty = y;
|
||||
});
|
||||
|
||||
self.lasttime = performance.now();
|
||||
}
|
||||
}
|
||||
|
||||
handlePointerDown(self, event) {
|
||||
var brush_size = this.brush_size;
|
||||
if(event instanceof PointerEvent && event.pointerType == 'pen') {
|
||||
brush_size *= event.pressure;
|
||||
this.last_pressure = event.pressure;
|
||||
}
|
||||
|
||||
if ([0, 2, 5].includes(event.button)) {
|
||||
self.drawing_mode = true;
|
||||
|
||||
event.preventDefault();
|
||||
const maskRect = self.maskCanvas.getBoundingClientRect();
|
||||
const x = event.offsetX || event.targetTouches[0].clientX - maskRect.left;
|
||||
const y = event.offsetY || event.targetTouches[0].clientY - maskRect.top;
|
||||
|
||||
self.maskCtx.beginPath();
|
||||
if (event.button == 0) {
|
||||
self.maskCtx.fillStyle = "rgb(0,0,0)";
|
||||
self.maskCtx.globalCompositeOperation = "source-over";
|
||||
} else {
|
||||
self.maskCtx.globalCompositeOperation = "destination-out";
|
||||
}
|
||||
self.maskCtx.arc(x, y, brush_size, 0, Math.PI * 2, false);
|
||||
self.maskCtx.fill();
|
||||
self.lastx = x;
|
||||
self.lasty = y;
|
||||
self.lasttime = performance.now();
|
||||
}
|
||||
}
|
||||
|
||||
async save() {
|
||||
const backupCtx = this.backupCanvas.getContext('2d', {willReadFrequently:true});
|
||||
|
||||
backupCtx.clearRect(0,0,this.backupCanvas.width,this.backupCanvas.height);
|
||||
backupCtx.drawImage(this.maskCanvas,
|
||||
0, 0, this.maskCanvas.width, this.maskCanvas.height,
|
||||
0, 0, this.backupCanvas.width, this.backupCanvas.height);
|
||||
|
||||
// paste mask data into alpha channel
|
||||
const backupData = backupCtx.getImageData(0, 0, this.backupCanvas.width, this.backupCanvas.height);
|
||||
|
||||
// refine mask image
|
||||
for (let i = 0; i < backupData.data.length; i += 4) {
|
||||
if(backupData.data[i+3] == 255)
|
||||
backupData.data[i+3] = 0;
|
||||
else
|
||||
backupData.data[i+3] = 255;
|
||||
|
||||
backupData.data[i] = 0;
|
||||
backupData.data[i+1] = 0;
|
||||
backupData.data[i+2] = 0;
|
||||
}
|
||||
|
||||
backupCtx.globalCompositeOperation = 'source-over';
|
||||
backupCtx.putImageData(backupData, 0, 0);
|
||||
|
||||
const formData = new FormData();
|
||||
const filename = "clipspace-mask-" + performance.now() + ".png";
|
||||
|
||||
const item =
|
||||
{
|
||||
"filename": filename,
|
||||
"subfolder": "clipspace",
|
||||
"type": "input",
|
||||
};
|
||||
|
||||
if(ComfyApp.clipspace.images)
|
||||
ComfyApp.clipspace.images[0] = item;
|
||||
|
||||
if(ComfyApp.clipspace.widgets) {
|
||||
const index = ComfyApp.clipspace.widgets.findIndex(obj => obj.name === 'image');
|
||||
|
||||
if(index >= 0)
|
||||
ComfyApp.clipspace.widgets[index].value = item;
|
||||
}
|
||||
|
||||
const dataURL = this.backupCanvas.toDataURL();
|
||||
const blob = dataURLToBlob(dataURL);
|
||||
|
||||
const original_blob = loadedImageToBlob(this.image);
|
||||
|
||||
formData.append('image', blob, filename);
|
||||
formData.append('original_image', original_blob);
|
||||
formData.append('type', "input");
|
||||
formData.append('subfolder', "clipspace");
|
||||
|
||||
this.saveButton.innerText = "Saving...";
|
||||
this.saveButton.disabled = true;
|
||||
await uploadMask(item, formData);
|
||||
ComfyApp.onClipspaceEditorSave();
|
||||
this.close();
|
||||
}
|
||||
}
|
||||
|
||||
app.registerExtension({
|
||||
name: "Comfy.MaskEditor",
|
||||
init(app) {
|
||||
ComfyApp.open_maskeditor =
|
||||
function () {
|
||||
const dlg = MaskEditorDialog.getInstance();
|
||||
if(!dlg.isOpened()) {
|
||||
dlg.show();
|
||||
}
|
||||
};
|
||||
|
||||
const context_predicate = () => ComfyApp.clipspace && ComfyApp.clipspace.imgs && ComfyApp.clipspace.imgs.length > 0
|
||||
ClipspaceDialog.registerButton("MaskEditor", context_predicate, ComfyApp.open_maskeditor);
|
||||
}
|
||||
});
|
||||
41
web/extensions/core/noteNode.js
Normal file
41
web/extensions/core/noteNode.js
Normal file
@ -0,0 +1,41 @@
|
||||
import {app} from "../../scripts/app.js";
|
||||
import {ComfyWidgets} from "../../scripts/widgets.js";
|
||||
// Node that add notes to your project
|
||||
|
||||
app.registerExtension({
|
||||
name: "Comfy.NoteNode",
|
||||
registerCustomNodes() {
|
||||
class NoteNode {
|
||||
color=LGraphCanvas.node_colors.yellow.color;
|
||||
bgcolor=LGraphCanvas.node_colors.yellow.bgcolor;
|
||||
groupcolor = LGraphCanvas.node_colors.yellow.groupcolor;
|
||||
constructor() {
|
||||
if (!this.properties) {
|
||||
this.properties = {};
|
||||
this.properties.text="";
|
||||
}
|
||||
|
||||
ComfyWidgets.STRING(this, "", ["", {default:this.properties.text, multiline: true}], app)
|
||||
|
||||
this.serialize_widgets = true;
|
||||
this.isVirtualNode = true;
|
||||
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
||||
// Load default visibility
|
||||
|
||||
LiteGraph.registerNodeType(
|
||||
"Note",
|
||||
Object.assign(NoteNode, {
|
||||
title_mode: LiteGraph.NORMAL_TITLE,
|
||||
title: "Note",
|
||||
collapsable: true,
|
||||
})
|
||||
);
|
||||
|
||||
NoteNode.category = "utils";
|
||||
},
|
||||
});
|
||||
@ -1,21 +1,91 @@
|
||||
import { app } from "/scripts/app.js";
|
||||
|
||||
import { ComfyWidgets } from "/scripts/widgets.js";
|
||||
// Adds defaults for quickly adding nodes with middle click on the input/output
|
||||
|
||||
app.registerExtension({
|
||||
name: "Comfy.SlotDefaults",
|
||||
suggestionsNumber: null,
|
||||
init() {
|
||||
LiteGraph.search_filter_enabled = true;
|
||||
LiteGraph.middle_click_slot_add_default_node = true;
|
||||
LiteGraph.slot_types_default_in = {
|
||||
MODEL: "CheckpointLoaderSimple",
|
||||
LATENT: "EmptyLatentImage",
|
||||
VAE: "VAELoader",
|
||||
};
|
||||
|
||||
LiteGraph.slot_types_default_out = {
|
||||
LATENT: "VAEDecode",
|
||||
IMAGE: "SaveImage",
|
||||
CLIP: "CLIPTextEncode",
|
||||
};
|
||||
this.suggestionsNumber = app.ui.settings.addSetting({
|
||||
id: "Comfy.NodeSuggestions.number",
|
||||
name: "number of nodes suggestions",
|
||||
type: "slider",
|
||||
attrs: {
|
||||
min: 1,
|
||||
max: 100,
|
||||
step: 1,
|
||||
},
|
||||
defaultValue: 5,
|
||||
onChange: (newVal, oldVal) => {
|
||||
this.setDefaults(newVal);
|
||||
}
|
||||
});
|
||||
},
|
||||
slot_types_default_out: {},
|
||||
slot_types_default_in: {},
|
||||
async beforeRegisterNodeDef(nodeType, nodeData, app) {
|
||||
var nodeId = nodeData.name;
|
||||
var inputs = [];
|
||||
inputs = nodeData["input"]["required"]; //only show required inputs to reduce the mess also not logical to create node with optional inputs
|
||||
for (const inputKey in inputs) {
|
||||
var input = (inputs[inputKey]);
|
||||
if (typeof input[0] !== "string") continue;
|
||||
|
||||
var type = input[0]
|
||||
if (type in ComfyWidgets) {
|
||||
var customProperties = input[1]
|
||||
if (!(customProperties?.forceInput)) continue; //ignore widgets that don't force input
|
||||
}
|
||||
|
||||
if (!(type in this.slot_types_default_out)) {
|
||||
this.slot_types_default_out[type] = ["Reroute"];
|
||||
}
|
||||
if (this.slot_types_default_out[type].includes(nodeId)) continue;
|
||||
this.slot_types_default_out[type].push(nodeId);
|
||||
|
||||
// Input types have to be stored as lower case
|
||||
// Store each node that can handle this input type
|
||||
const lowerType = type.toLocaleLowerCase();
|
||||
if (!(lowerType in LiteGraph.registered_slot_in_types)) {
|
||||
LiteGraph.registered_slot_in_types[lowerType] = { nodes: [] };
|
||||
}
|
||||
LiteGraph.registered_slot_in_types[lowerType].nodes.push(nodeType.comfyClass);
|
||||
}
|
||||
|
||||
var outputs = nodeData["output"];
|
||||
for (const key in outputs) {
|
||||
var type = outputs[key];
|
||||
if (!(type in this.slot_types_default_in)) {
|
||||
this.slot_types_default_in[type] = ["Reroute"];// ["Reroute", "Primitive"]; primitive doesn't always work :'()
|
||||
}
|
||||
|
||||
this.slot_types_default_in[type].push(nodeId);
|
||||
|
||||
// Store each node that can handle this output type
|
||||
if (!(type in LiteGraph.registered_slot_out_types)) {
|
||||
LiteGraph.registered_slot_out_types[type] = { nodes: [] };
|
||||
}
|
||||
LiteGraph.registered_slot_out_types[type].nodes.push(nodeType.comfyClass);
|
||||
|
||||
if(!LiteGraph.slot_types_out.includes(type)) {
|
||||
LiteGraph.slot_types_out.push(type);
|
||||
}
|
||||
}
|
||||
var maxNum = this.suggestionsNumber.value;
|
||||
this.setDefaults(maxNum);
|
||||
},
|
||||
setDefaults(maxNum) {
|
||||
|
||||
LiteGraph.slot_types_default_out = {};
|
||||
LiteGraph.slot_types_default_in = {};
|
||||
|
||||
for (const type in this.slot_types_default_out) {
|
||||
LiteGraph.slot_types_default_out[type] = this.slot_types_default_out[type].slice(0, maxNum);
|
||||
}
|
||||
for (const type in this.slot_types_default_in) {
|
||||
LiteGraph.slot_types_default_in[type] = this.slot_types_default_in[type].slice(0, maxNum);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
@ -9,7 +9,7 @@ app.registerExtension({
|
||||
app.ui.settings.addSetting({
|
||||
id: "Comfy.SnapToGrid.GridSize",
|
||||
name: "Grid Size",
|
||||
type: "number",
|
||||
type: "slider",
|
||||
attrs: {
|
||||
min: 1,
|
||||
max: 500,
|
||||
|
||||
@ -159,27 +159,33 @@ app.registerExtension({
|
||||
const r = origOnInputDblClick ? origOnInputDblClick.apply(this, arguments) : undefined;
|
||||
|
||||
const input = this.inputs[slot];
|
||||
if (input.widget && !input[ignoreDblClick]) {
|
||||
const node = LiteGraph.createNode("PrimitiveNode");
|
||||
app.graph.add(node);
|
||||
|
||||
// Calculate a position that wont directly overlap another node
|
||||
const pos = [this.pos[0] - node.size[0] - 30, this.pos[1]];
|
||||
while (isNodeAtPos(pos)) {
|
||||
pos[1] += LiteGraph.NODE_TITLE_HEIGHT;
|
||||
if (!input.widget || !input[ignoreDblClick]) {
|
||||
// Not a widget input or already handled input
|
||||
if (!(input.type in ComfyWidgets) && !(input.widget.config?.[0] instanceof Array)) {
|
||||
return r; //also Not a ComfyWidgets input or combo (do nothing)
|
||||
}
|
||||
|
||||
node.pos = pos;
|
||||
node.connect(0, this, slot);
|
||||
node.title = input.name;
|
||||
|
||||
// Prevent adding duplicates due to triple clicking
|
||||
input[ignoreDblClick] = true;
|
||||
setTimeout(() => {
|
||||
delete input[ignoreDblClick];
|
||||
}, 300);
|
||||
}
|
||||
|
||||
// Create a primitive node
|
||||
const node = LiteGraph.createNode("PrimitiveNode");
|
||||
app.graph.add(node);
|
||||
|
||||
// Calculate a position that wont directly overlap another node
|
||||
const pos = [this.pos[0] - node.size[0] - 30, this.pos[1]];
|
||||
while (isNodeAtPos(pos)) {
|
||||
pos[1] += LiteGraph.NODE_TITLE_HEIGHT;
|
||||
}
|
||||
|
||||
node.pos = pos;
|
||||
node.connect(0, this, slot);
|
||||
node.title = input.name;
|
||||
|
||||
// Prevent adding duplicates due to triple clicking
|
||||
input[ignoreDblClick] = true;
|
||||
setTimeout(() => {
|
||||
delete input[ignoreDblClick];
|
||||
}, 300);
|
||||
|
||||
return r;
|
||||
};
|
||||
},
|
||||
@ -233,7 +239,9 @@ app.registerExtension({
|
||||
// Fires before the link is made allowing us to reject it if it isn't valid
|
||||
|
||||
// No widget, we cant connect
|
||||
if (!input.widget) return false;
|
||||
if (!input.widget) {
|
||||
if (!(input.type in ComfyWidgets)) return false;
|
||||
}
|
||||
|
||||
if (this.outputs[slot].links?.length) {
|
||||
return this.#isValidConnection(input);
|
||||
@ -252,9 +260,17 @@ app.registerExtension({
|
||||
const input = theirNode.inputs[link.target_slot];
|
||||
if (!input) return;
|
||||
|
||||
const widget = input.widget;
|
||||
const { type, linkType } = getWidgetType(widget.config);
|
||||
|
||||
var _widget;
|
||||
if (!input.widget) {
|
||||
if (!(input.type in ComfyWidgets)) return;
|
||||
_widget = { "name": input.name, "config": [input.type, {}] }//fake widget
|
||||
} else {
|
||||
_widget = input.widget;
|
||||
}
|
||||
|
||||
const widget = _widget;
|
||||
const { type, linkType } = getWidgetType(widget.config);
|
||||
// Update our output to restrict to the widget type
|
||||
this.outputs[0].type = linkType;
|
||||
this.outputs[0].name = type;
|
||||
@ -274,7 +290,7 @@ app.registerExtension({
|
||||
if (type in ComfyWidgets) {
|
||||
widget = (ComfyWidgets[type](this, "value", inputData, app) || {}).widget;
|
||||
} else {
|
||||
widget = this.addWidget(type, "value", null, () => {}, {});
|
||||
widget = this.addWidget(type, "value", null, () => { }, {});
|
||||
}
|
||||
|
||||
if (node?.widgets && widget) {
|
||||
@ -284,7 +300,7 @@ app.registerExtension({
|
||||
}
|
||||
}
|
||||
|
||||
if (widget.type === "number") {
|
||||
if (widget.type === "number" || widget.type === "combo") {
|
||||
addValueControlWidget(this, widget, "fixed");
|
||||
}
|
||||
|
||||
@ -319,7 +335,20 @@ app.registerExtension({
|
||||
const config1 = this.outputs[0].widget.config;
|
||||
const config2 = input.widget.config;
|
||||
|
||||
if (config1[0] !== config2[0]) return false;
|
||||
if (config1[0] instanceof Array) {
|
||||
// These checks shouldnt actually be necessary as the types should match
|
||||
// but double checking doesn't hurt
|
||||
|
||||
// New input isnt a combo
|
||||
if (!(config2[0] instanceof Array)) return false;
|
||||
// New imput combo has a different size
|
||||
if (config1[0].length !== config2[0].length) return false;
|
||||
// New input combo has different elements
|
||||
if (config1[0].find((v, i) => config2[0][i] !== v)) return false;
|
||||
} else if (config1[0] !== config2[0]) {
|
||||
// Configs dont match
|
||||
return false;
|
||||
}
|
||||
|
||||
for (const k in config1[1]) {
|
||||
if (k !== "default") {
|
||||
|
||||
@ -3628,6 +3628,18 @@
|
||||
return size;
|
||||
};
|
||||
|
||||
LGraphNode.prototype.inResizeCorner = function(canvasX, canvasY) {
|
||||
var rows = this.outputs ? this.outputs.length : 1;
|
||||
var outputs_offset = (this.constructor.slot_start_y || 0) + rows * LiteGraph.NODE_SLOT_HEIGHT;
|
||||
return isInsideRectangle(canvasX,
|
||||
canvasY,
|
||||
this.pos[0] + this.size[0] - 15,
|
||||
this.pos[1] + Math.max(this.size[1] - 15, outputs_offset),
|
||||
20,
|
||||
20
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* returns all the info available about a property of this node.
|
||||
*
|
||||
@ -5868,23 +5880,16 @@ LGraphNode.prototype.executeAction = function(action)
|
||||
|
||||
//when clicked on top of a node
|
||||
//and it is not interactive
|
||||
if (node && this.allow_interaction && !skip_action && !this.read_only) {
|
||||
if (node && (this.allow_interaction || node.flags.allow_interaction) && !skip_action && !this.read_only) {
|
||||
if (!this.live_mode && !node.flags.pinned) {
|
||||
this.bringToFront(node);
|
||||
} //if it wasn't selected?
|
||||
|
||||
//not dragging mouse to connect two slots
|
||||
if ( !this.connecting_node && !node.flags.collapsed && !this.live_mode ) {
|
||||
if ( this.allow_interaction && !this.connecting_node && !node.flags.collapsed && !this.live_mode ) {
|
||||
//Search for corner for resize
|
||||
if ( !skip_action &&
|
||||
node.resizable !== false &&
|
||||
isInsideRectangle( e.canvasX,
|
||||
e.canvasY,
|
||||
node.pos[0] + node.size[0] - 5,
|
||||
node.pos[1] + node.size[1] - 5,
|
||||
10,
|
||||
10
|
||||
)
|
||||
node.resizable !== false && node.inResizeCorner(e.canvasX, e.canvasY)
|
||||
) {
|
||||
this.graph.beforeChange();
|
||||
this.resizing_node = node;
|
||||
@ -6028,7 +6033,7 @@ LGraphNode.prototype.executeAction = function(action)
|
||||
}
|
||||
|
||||
//double clicking
|
||||
if (is_double_click && this.selected_nodes[node.id]) {
|
||||
if (this.allow_interaction && is_double_click && this.selected_nodes[node.id]) {
|
||||
//double click node
|
||||
if (node.onDblClick) {
|
||||
node.onDblClick( e, pos, this );
|
||||
@ -6302,6 +6307,9 @@ LGraphNode.prototype.executeAction = function(action)
|
||||
this.dirty_canvas = true;
|
||||
}
|
||||
|
||||
//get node over
|
||||
var node = this.graph.getNodeOnPos(e.canvasX,e.canvasY,this.visible_nodes);
|
||||
|
||||
if (this.dragging_rectangle)
|
||||
{
|
||||
this.dragging_rectangle[2] = e.canvasX - this.dragging_rectangle[0];
|
||||
@ -6331,14 +6339,11 @@ LGraphNode.prototype.executeAction = function(action)
|
||||
this.ds.offset[1] += delta[1] / this.ds.scale;
|
||||
this.dirty_canvas = true;
|
||||
this.dirty_bgcanvas = true;
|
||||
} else if (this.allow_interaction && !this.read_only) {
|
||||
} else if ((this.allow_interaction || (node && node.flags.allow_interaction)) && !this.read_only) {
|
||||
if (this.connecting_node) {
|
||||
this.dirty_canvas = true;
|
||||
}
|
||||
|
||||
//get node over
|
||||
var node = this.graph.getNodeOnPos(e.canvasX,e.canvasY,this.visible_nodes);
|
||||
|
||||
//remove mouseover flag
|
||||
for (var i = 0, l = this.graph._nodes.length; i < l; ++i) {
|
||||
if (this.graph._nodes[i].mouseOver && node != this.graph._nodes[i] ) {
|
||||
@ -6424,16 +6429,7 @@ LGraphNode.prototype.executeAction = function(action)
|
||||
|
||||
//Search for corner
|
||||
if (this.canvas) {
|
||||
if (
|
||||
isInsideRectangle(
|
||||
e.canvasX,
|
||||
e.canvasY,
|
||||
node.pos[0] + node.size[0] - 5,
|
||||
node.pos[1] + node.size[1] - 5,
|
||||
5,
|
||||
5
|
||||
)
|
||||
) {
|
||||
if (node.inResizeCorner(e.canvasX, e.canvasY)) {
|
||||
this.canvas.style.cursor = "se-resize";
|
||||
} else {
|
||||
this.canvas.style.cursor = "crosshair";
|
||||
@ -9738,7 +9734,7 @@ LGraphNode.prototype.executeAction = function(action)
|
||||
if (show_text) {
|
||||
ctx.textAlign = "center";
|
||||
ctx.fillStyle = text_color;
|
||||
ctx.fillText(w.name, widget_width * 0.5, y + H * 0.7);
|
||||
ctx.fillText(w.label || w.name, widget_width * 0.5, y + H * 0.7);
|
||||
}
|
||||
break;
|
||||
case "toggle":
|
||||
@ -9759,8 +9755,9 @@ LGraphNode.prototype.executeAction = function(action)
|
||||
ctx.fill();
|
||||
if (show_text) {
|
||||
ctx.fillStyle = secondary_text_color;
|
||||
if (w.name != null) {
|
||||
ctx.fillText(w.name, margin * 2, y + H * 0.7);
|
||||
const label = w.label || w.name;
|
||||
if (label != null) {
|
||||
ctx.fillText(label, margin * 2, y + H * 0.7);
|
||||
}
|
||||
ctx.fillStyle = w.value ? text_color : secondary_text_color;
|
||||
ctx.textAlign = "right";
|
||||
@ -9795,7 +9792,7 @@ LGraphNode.prototype.executeAction = function(action)
|
||||
ctx.textAlign = "center";
|
||||
ctx.fillStyle = text_color;
|
||||
ctx.fillText(
|
||||
w.name + " " + Number(w.value).toFixed(3),
|
||||
w.label || w.name + " " + Number(w.value).toFixed(3),
|
||||
widget_width * 0.5,
|
||||
y + H * 0.7
|
||||
);
|
||||
@ -9830,7 +9827,7 @@ LGraphNode.prototype.executeAction = function(action)
|
||||
ctx.fill();
|
||||
}
|
||||
ctx.fillStyle = secondary_text_color;
|
||||
ctx.fillText(w.name, margin * 2 + 5, y + H * 0.7);
|
||||
ctx.fillText(w.label || w.name, margin * 2 + 5, y + H * 0.7);
|
||||
ctx.fillStyle = text_color;
|
||||
ctx.textAlign = "right";
|
||||
if (w.type == "number") {
|
||||
@ -9882,8 +9879,9 @@ LGraphNode.prototype.executeAction = function(action)
|
||||
|
||||
//ctx.stroke();
|
||||
ctx.fillStyle = secondary_text_color;
|
||||
if (w.name != null) {
|
||||
ctx.fillText(w.name, margin * 2, y + H * 0.7);
|
||||
const label = w.label || w.name;
|
||||
if (label != null) {
|
||||
ctx.fillText(label, margin * 2, y + H * 0.7);
|
||||
}
|
||||
ctx.fillStyle = text_color;
|
||||
ctx.textAlign = "right";
|
||||
@ -9915,7 +9913,7 @@ LGraphNode.prototype.executeAction = function(action)
|
||||
event,
|
||||
active_widget
|
||||
) {
|
||||
if (!node.widgets || !node.widgets.length) {
|
||||
if (!node.widgets || !node.widgets.length || (!this.allow_interaction && !node.flags.allow_interaction)) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@ -9953,11 +9951,11 @@ LGraphNode.prototype.executeAction = function(action)
|
||||
}
|
||||
break;
|
||||
case "slider":
|
||||
var range = w.options.max - w.options.min;
|
||||
var old_value = w.value;
|
||||
var nvalue = Math.clamp((x - 15) / (widget_width - 30), 0, 1);
|
||||
if(w.options.read_only) break;
|
||||
w.value = w.options.min + (w.options.max - w.options.min) * nvalue;
|
||||
if (w.callback) {
|
||||
if (old_value != w.value) {
|
||||
setTimeout(function() {
|
||||
inner_value_change(w, w.value);
|
||||
}, 20);
|
||||
@ -10044,7 +10042,7 @@ LGraphNode.prototype.executeAction = function(action)
|
||||
if (event.click_time < 200 && delta == 0) {
|
||||
this.prompt("Value",w.value,function(v) {
|
||||
// check if v is a valid equation or a number
|
||||
if (/^[0-9+\-*/()\s]+$/.test(v)) {
|
||||
if (/^[0-9+\-*/()\s]+|\d+\.\d+$/.test(v)) {
|
||||
try {//solve the equation if possible
|
||||
v = eval(v);
|
||||
} catch (e) { }
|
||||
@ -10304,6 +10302,119 @@ LGraphNode.prototype.executeAction = function(action)
|
||||
canvas.graph.add(group);
|
||||
};
|
||||
|
||||
/**
|
||||
* Determines the furthest nodes in each direction
|
||||
* @param nodes {LGraphNode[]} the nodes to from which boundary nodes will be extracted
|
||||
* @return {{left: LGraphNode, top: LGraphNode, right: LGraphNode, bottom: LGraphNode}}
|
||||
*/
|
||||
LGraphCanvas.getBoundaryNodes = function(nodes) {
|
||||
let top = null;
|
||||
let right = null;
|
||||
let bottom = null;
|
||||
let left = null;
|
||||
for (const nID in nodes) {
|
||||
const node = nodes[nID];
|
||||
const [x, y] = node.pos;
|
||||
const [width, height] = node.size;
|
||||
|
||||
if (top === null || y < top.pos[1]) {
|
||||
top = node;
|
||||
}
|
||||
if (right === null || x + width > right.pos[0] + right.size[0]) {
|
||||
right = node;
|
||||
}
|
||||
if (bottom === null || y + height > bottom.pos[1] + bottom.size[1]) {
|
||||
bottom = node;
|
||||
}
|
||||
if (left === null || x < left.pos[0]) {
|
||||
left = node;
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
"top": top,
|
||||
"right": right,
|
||||
"bottom": bottom,
|
||||
"left": left
|
||||
};
|
||||
}
|
||||
/**
|
||||
* Determines the furthest nodes in each direction for the currently selected nodes
|
||||
* @return {{left: LGraphNode, top: LGraphNode, right: LGraphNode, bottom: LGraphNode}}
|
||||
*/
|
||||
LGraphCanvas.prototype.boundaryNodesForSelection = function() {
|
||||
return LGraphCanvas.getBoundaryNodes(Object.values(this.selected_nodes));
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param {LGraphNode[]} nodes a list of nodes
|
||||
* @param {"top"|"bottom"|"left"|"right"} direction Direction to align the nodes
|
||||
* @param {LGraphNode?} align_to Node to align to (if null, align to the furthest node in the given direction)
|
||||
*/
|
||||
LGraphCanvas.alignNodes = function (nodes, direction, align_to) {
|
||||
if (!nodes) {
|
||||
return;
|
||||
}
|
||||
|
||||
const canvas = LGraphCanvas.active_canvas;
|
||||
let boundaryNodes = []
|
||||
if (align_to === undefined) {
|
||||
boundaryNodes = LGraphCanvas.getBoundaryNodes(nodes)
|
||||
} else {
|
||||
boundaryNodes = {
|
||||
"top": align_to,
|
||||
"right": align_to,
|
||||
"bottom": align_to,
|
||||
"left": align_to
|
||||
}
|
||||
}
|
||||
|
||||
for (const [_, node] of Object.entries(canvas.selected_nodes)) {
|
||||
switch (direction) {
|
||||
case "right":
|
||||
node.pos[0] = boundaryNodes["right"].pos[0] + boundaryNodes["right"].size[0] - node.size[0];
|
||||
break;
|
||||
case "left":
|
||||
node.pos[0] = boundaryNodes["left"].pos[0];
|
||||
break;
|
||||
case "top":
|
||||
node.pos[1] = boundaryNodes["top"].pos[1];
|
||||
break;
|
||||
case "bottom":
|
||||
node.pos[1] = boundaryNodes["bottom"].pos[1] + boundaryNodes["bottom"].size[1] - node.size[1];
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
canvas.dirty_canvas = true;
|
||||
canvas.dirty_bgcanvas = true;
|
||||
};
|
||||
|
||||
LGraphCanvas.onNodeAlign = function(value, options, event, prev_menu, node) {
|
||||
new LiteGraph.ContextMenu(["Top", "Bottom", "Left", "Right"], {
|
||||
event: event,
|
||||
callback: inner_clicked,
|
||||
parentMenu: prev_menu,
|
||||
});
|
||||
|
||||
function inner_clicked(value) {
|
||||
LGraphCanvas.alignNodes(LGraphCanvas.active_canvas.selected_nodes, value.toLowerCase(), node);
|
||||
}
|
||||
}
|
||||
|
||||
LGraphCanvas.onGroupAlign = function(value, options, event, prev_menu) {
|
||||
new LiteGraph.ContextMenu(["Top", "Bottom", "Left", "Right"], {
|
||||
event: event,
|
||||
callback: inner_clicked,
|
||||
parentMenu: prev_menu,
|
||||
});
|
||||
|
||||
function inner_clicked(value) {
|
||||
LGraphCanvas.alignNodes(LGraphCanvas.active_canvas.selected_nodes, value.toLowerCase());
|
||||
}
|
||||
}
|
||||
|
||||
LGraphCanvas.onMenuAdd = function (node, options, e, prev_menu, callback) {
|
||||
|
||||
var canvas = LGraphCanvas.active_canvas;
|
||||
@ -12904,6 +13015,14 @@ LGraphNode.prototype.executeAction = function(action)
|
||||
options.push({ content: "Options", callback: that.showShowGraphOptionsPanel });
|
||||
}*/
|
||||
|
||||
if (Object.keys(this.selected_nodes).length > 1) {
|
||||
options.push({
|
||||
content: "Align",
|
||||
has_submenu: true,
|
||||
callback: LGraphCanvas.onGroupAlign,
|
||||
})
|
||||
}
|
||||
|
||||
if (this._graph_stack && this._graph_stack.length > 0) {
|
||||
options.push(null, {
|
||||
content: "Close subgraph",
|
||||
@ -13018,6 +13137,14 @@ LGraphNode.prototype.executeAction = function(action)
|
||||
callback: LGraphCanvas.onMenuNodeToSubgraph
|
||||
});
|
||||
|
||||
if (Object.keys(this.selected_nodes).length > 1) {
|
||||
options.push({
|
||||
content: "Align Selected To",
|
||||
has_submenu: true,
|
||||
callback: LGraphCanvas.onNodeAlign,
|
||||
})
|
||||
}
|
||||
|
||||
options.push(null, {
|
||||
content: "Remove",
|
||||
disabled: !(node.removable !== false && !node.block_delete ),
|
||||
|
||||
@ -35,7 +35,7 @@ class ComfyApi extends EventTarget {
|
||||
}
|
||||
|
||||
let opened = false;
|
||||
let existingSession = sessionStorage["Comfy.SessionId"] || "";
|
||||
let existingSession = window.name;
|
||||
if (existingSession) {
|
||||
existingSession = "?clientId=" + existingSession;
|
||||
}
|
||||
@ -75,7 +75,7 @@ class ComfyApi extends EventTarget {
|
||||
case "status":
|
||||
if (msg.data.sid) {
|
||||
this.clientId = msg.data.sid;
|
||||
sessionStorage["Comfy.SessionId"] = this.clientId;
|
||||
window.name = this.clientId;
|
||||
}
|
||||
this.dispatchEvent(new CustomEvent("status", { detail: msg.data.status }));
|
||||
break;
|
||||
@ -163,7 +163,7 @@ class ComfyApi extends EventTarget {
|
||||
|
||||
if (res.status !== 200) {
|
||||
throw {
|
||||
response: await res.text(),
|
||||
response: await res.json(),
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
@ -2,29 +2,167 @@ import { ComfyWidgets } from "./widgets.js";
|
||||
import { ComfyUI, $el } from "./ui.js";
|
||||
import { api } from "./api.js";
|
||||
import { defaultGraph } from "./defaultGraph.js";
|
||||
import { getPngMetadata, importA1111 } from "./pnginfo.js";
|
||||
import { getPngMetadata, importA1111, getLatentMetadata } from "./pnginfo.js";
|
||||
|
||||
class ComfyApp {
|
||||
/**
|
||||
* List of {number, batchCount} entries to queue
|
||||
/**
|
||||
* @typedef {import("types/comfy").ComfyExtension} ComfyExtension
|
||||
*/
|
||||
|
||||
export class ComfyApp {
|
||||
/**
|
||||
* List of entries to queue
|
||||
* @type {{number: number, batchCount: number}[]}
|
||||
*/
|
||||
#queueItems = [];
|
||||
/**
|
||||
* If the queue is currently being processed
|
||||
* @type {boolean}
|
||||
*/
|
||||
#processingQueue = false;
|
||||
|
||||
/**
|
||||
* Content Clipboard
|
||||
* @type {serialized node object}
|
||||
*/
|
||||
static clipspace = null;
|
||||
static clipspace_invalidate_handler = null;
|
||||
static open_maskeditor = null;
|
||||
static clipspace_return_node = null;
|
||||
|
||||
constructor() {
|
||||
this.ui = new ComfyUI(this);
|
||||
|
||||
/**
|
||||
* List of extensions that are registered with the app
|
||||
* @type {ComfyExtension[]}
|
||||
*/
|
||||
this.extensions = [];
|
||||
|
||||
/**
|
||||
* Stores the execution output data for each node
|
||||
* @type {Record<string, any>}
|
||||
*/
|
||||
this.nodeOutputs = {};
|
||||
|
||||
/**
|
||||
* If the shift key on the keyboard is pressed
|
||||
* @type {boolean}
|
||||
*/
|
||||
this.shiftDown = false;
|
||||
}
|
||||
|
||||
static isImageNode(node) {
|
||||
return node.imgs || (node && node.widgets && node.widgets.findIndex(obj => obj.name === 'image') >= 0);
|
||||
}
|
||||
|
||||
static onClipspaceEditorSave() {
|
||||
if(ComfyApp.clipspace_return_node) {
|
||||
ComfyApp.pasteFromClipspace(ComfyApp.clipspace_return_node);
|
||||
}
|
||||
}
|
||||
|
||||
static onClipspaceEditorClosed() {
|
||||
ComfyApp.clipspace_return_node = null;
|
||||
}
|
||||
|
||||
static copyToClipspace(node) {
|
||||
var widgets = null;
|
||||
if(node.widgets) {
|
||||
widgets = node.widgets.map(({ type, name, value }) => ({ type, name, value }));
|
||||
}
|
||||
|
||||
var imgs = undefined;
|
||||
var orig_imgs = undefined;
|
||||
if(node.imgs != undefined) {
|
||||
imgs = [];
|
||||
orig_imgs = [];
|
||||
|
||||
for (let i = 0; i < node.imgs.length; i++) {
|
||||
imgs[i] = new Image();
|
||||
imgs[i].src = node.imgs[i].src;
|
||||
orig_imgs[i] = imgs[i];
|
||||
}
|
||||
}
|
||||
|
||||
var selectedIndex = 0;
|
||||
if(node.imageIndex) {
|
||||
selectedIndex = node.imageIndex;
|
||||
}
|
||||
|
||||
ComfyApp.clipspace = {
|
||||
'widgets': widgets,
|
||||
'imgs': imgs,
|
||||
'original_imgs': orig_imgs,
|
||||
'images': node.images,
|
||||
'selectedIndex': selectedIndex,
|
||||
'img_paste_mode': 'selected' // reset to default im_paste_mode state on copy action
|
||||
};
|
||||
|
||||
ComfyApp.clipspace_return_node = null;
|
||||
|
||||
if(ComfyApp.clipspace_invalidate_handler) {
|
||||
ComfyApp.clipspace_invalidate_handler();
|
||||
}
|
||||
}
|
||||
|
||||
static pasteFromClipspace(node) {
|
||||
if(ComfyApp.clipspace) {
|
||||
// image paste
|
||||
if(ComfyApp.clipspace.imgs && node.imgs) {
|
||||
if(node.images && ComfyApp.clipspace.images) {
|
||||
if(ComfyApp.clipspace['img_paste_mode'] == 'selected') {
|
||||
app.nodeOutputs[node.id + ""].images = node.images = [ComfyApp.clipspace.images[ComfyApp.clipspace['selectedIndex']]];
|
||||
}
|
||||
else
|
||||
app.nodeOutputs[node.id + ""].images = node.images = ComfyApp.clipspace.images;
|
||||
}
|
||||
|
||||
if(ComfyApp.clipspace.imgs) {
|
||||
// deep-copy to cut link with clipspace
|
||||
if(ComfyApp.clipspace['img_paste_mode'] == 'selected') {
|
||||
const img = new Image();
|
||||
img.src = ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src;
|
||||
node.imgs = [img];
|
||||
node.imageIndex = 0;
|
||||
}
|
||||
else {
|
||||
const imgs = [];
|
||||
for(let i=0; i<ComfyApp.clipspace.imgs.length; i++) {
|
||||
imgs[i] = new Image();
|
||||
imgs[i].src = ComfyApp.clipspace.imgs[i].src;
|
||||
node.imgs = imgs;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if(node.widgets) {
|
||||
if(ComfyApp.clipspace.images) {
|
||||
const clip_image = ComfyApp.clipspace.images[ComfyApp.clipspace['selectedIndex']];
|
||||
const index = node.widgets.findIndex(obj => obj.name === 'image');
|
||||
if(index >= 0) {
|
||||
node.widgets[index].value = clip_image;
|
||||
}
|
||||
}
|
||||
if(ComfyApp.clipspace.widgets) {
|
||||
ComfyApp.clipspace.widgets.forEach(({ type, name, value }) => {
|
||||
const prop = Object.values(node.widgets).find(obj => obj.type === type && obj.name === name);
|
||||
if (prop && prop.type != 'button') {
|
||||
prop.value = value;
|
||||
prop.callback(value);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
app.graph.setDirtyCanvas(true);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Invoke an extension callback
|
||||
* @param {string} method The extension callback to execute
|
||||
* @param {...any} args Any arguments to pass to the callback
|
||||
* @param {keyof ComfyExtension} method The extension callback to execute
|
||||
* @param {any[]} args Any arguments to pass to the callback
|
||||
* @returns
|
||||
*/
|
||||
#invokeExtensions(method, ...args) {
|
||||
@ -109,6 +247,32 @@ class ComfyApp {
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// prevent conflict of clipspace content
|
||||
if(!ComfyApp.clipspace_return_node) {
|
||||
options.push({
|
||||
content: "Copy (Clipspace)",
|
||||
callback: (obj) => { ComfyApp.copyToClipspace(this); }
|
||||
});
|
||||
|
||||
if(ComfyApp.clipspace != null) {
|
||||
options.push({
|
||||
content: "Paste (Clipspace)",
|
||||
callback: () => { ComfyApp.pasteFromClipspace(this); }
|
||||
});
|
||||
}
|
||||
|
||||
if(ComfyApp.isImageNode(this)) {
|
||||
options.push({
|
||||
content: "Open in MaskEditor",
|
||||
callback: (obj) => {
|
||||
ComfyApp.copyToClipspace(this);
|
||||
ComfyApp.clipspace_return_node = this;
|
||||
ComfyApp.open_maskeditor();
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
@ -159,6 +323,34 @@ class ComfyApp {
|
||||
*/
|
||||
#addDrawBackgroundHandler(node) {
|
||||
const app = this;
|
||||
|
||||
function getImageTop(node) {
|
||||
let shiftY;
|
||||
if (node.imageOffset != null) {
|
||||
shiftY = node.imageOffset;
|
||||
} else {
|
||||
if (node.widgets?.length) {
|
||||
const w = node.widgets[node.widgets.length - 1];
|
||||
shiftY = w.last_y;
|
||||
if (w.computeSize) {
|
||||
shiftY += w.computeSize()[1] + 4;
|
||||
} else {
|
||||
shiftY += LiteGraph.NODE_WIDGET_HEIGHT + 4;
|
||||
}
|
||||
} else {
|
||||
shiftY = node.computeSize()[1];
|
||||
}
|
||||
}
|
||||
return shiftY;
|
||||
}
|
||||
|
||||
node.prototype.setSizeForImage = function () {
|
||||
const minHeight = getImageTop(this) + 220;
|
||||
if (this.size[1] < minHeight) {
|
||||
this.setSize([this.size[0], minHeight]);
|
||||
}
|
||||
};
|
||||
|
||||
node.prototype.onDrawBackground = function (ctx) {
|
||||
if (!this.flags.collapsed) {
|
||||
const output = app.nodeOutputs[this.id + ""];
|
||||
@ -179,9 +371,7 @@ class ComfyApp {
|
||||
).then((imgs) => {
|
||||
if (this.images === output.images) {
|
||||
this.imgs = imgs.filter(Boolean);
|
||||
if (this.size[1] < 100) {
|
||||
this.size[1] = 250;
|
||||
}
|
||||
this.setSizeForImage?.();
|
||||
app.graph.setDirtyCanvas(true);
|
||||
}
|
||||
});
|
||||
@ -206,12 +396,7 @@ class ComfyApp {
|
||||
this.imageIndex = imageIndex = 0;
|
||||
}
|
||||
|
||||
let shiftY;
|
||||
if (this.imageOffset != null) {
|
||||
shiftY = this.imageOffset;
|
||||
} else {
|
||||
shiftY = this.computeSize()[1];
|
||||
}
|
||||
const shiftY = getImageTop(this);
|
||||
|
||||
let dw = this.size[0];
|
||||
let dh = this.size[1];
|
||||
@ -599,7 +784,7 @@ class ComfyApp {
|
||||
ctx.globalAlpha = 0.8;
|
||||
ctx.beginPath();
|
||||
if (shape == LiteGraph.BOX_SHAPE)
|
||||
ctx.rect(-6, -6 + LiteGraph.NODE_TITLE_HEIGHT, 12 + size[0] + 1, 12 + size[1] + LiteGraph.NODE_TITLE_HEIGHT);
|
||||
ctx.rect(-6, -6 - LiteGraph.NODE_TITLE_HEIGHT, 12 + size[0] + 1, 12 + size[1] + LiteGraph.NODE_TITLE_HEIGHT);
|
||||
else if (shape == LiteGraph.ROUND_SHAPE || (shape == LiteGraph.CARD_SHAPE && node.flags.collapsed))
|
||||
ctx.roundRect(
|
||||
-6,
|
||||
@ -611,12 +796,11 @@ class ComfyApp {
|
||||
else if (shape == LiteGraph.CARD_SHAPE)
|
||||
ctx.roundRect(
|
||||
-6,
|
||||
-6 + LiteGraph.NODE_TITLE_HEIGHT,
|
||||
-6 - LiteGraph.NODE_TITLE_HEIGHT,
|
||||
12 + size[0] + 1,
|
||||
12 + size[1] + LiteGraph.NODE_TITLE_HEIGHT,
|
||||
this.round_radius * 2,
|
||||
2
|
||||
);
|
||||
[this.round_radius * 2, this.round_radius * 2, 2, 2]
|
||||
);
|
||||
else if (shape == LiteGraph.CIRCLE_SHAPE)
|
||||
ctx.arc(size[0] * 0.5, size[1] * 0.5, size[0] * 0.5 + 6, 0, Math.PI * 2);
|
||||
ctx.strokeStyle = color;
|
||||
@ -691,11 +875,6 @@ class ComfyApp {
|
||||
#addKeyboardHandler() {
|
||||
window.addEventListener("keydown", (e) => {
|
||||
this.shiftDown = e.shiftKey;
|
||||
|
||||
// Queue prompt using ctrl or command + enter
|
||||
if ((e.ctrlKey || e.metaKey) && (e.key === "Enter" || e.keyCode === 13 || e.keyCode === 10)) {
|
||||
this.queuePrompt(e.shiftKey ? -1 : 0);
|
||||
}
|
||||
});
|
||||
window.addEventListener("keyup", (e) => {
|
||||
this.shiftDown = e.shiftKey;
|
||||
@ -723,7 +902,9 @@ class ComfyApp {
|
||||
await this.#loadExtensions();
|
||||
|
||||
// Create and mount the LiteGraph in the DOM
|
||||
const canvasEl = (this.canvasEl = Object.assign(document.createElement("canvas"), { id: "graph-canvas" }));
|
||||
const mainCanvas = document.createElement("canvas")
|
||||
mainCanvas.style.touchAction = "none"
|
||||
const canvasEl = (this.canvasEl = Object.assign(mainCanvas, { id: "graph-canvas" }));
|
||||
canvasEl.tabIndex = "1";
|
||||
document.body.prepend(canvasEl);
|
||||
|
||||
@ -835,7 +1016,8 @@ class ComfyApp {
|
||||
for (const o in nodeData["output"]) {
|
||||
const output = nodeData["output"][o];
|
||||
const outputName = nodeData["output_name"][o] || output;
|
||||
this.addOutput(outputName, output);
|
||||
const outputShape = nodeData["output_is_list"][o] ? LiteGraph.GRID_SHAPE : LiteGraph.CIRCLE_SHAPE ;
|
||||
this.addOutput(outputName, output, { shape: outputShape });
|
||||
}
|
||||
|
||||
const s = this.computeSize();
|
||||
@ -872,8 +1054,10 @@ class ComfyApp {
|
||||
loadGraphData(graphData) {
|
||||
this.clean();
|
||||
|
||||
let reset_invalid_values = false;
|
||||
if (!graphData) {
|
||||
graphData = structuredClone(defaultGraph);
|
||||
reset_invalid_values = true;
|
||||
}
|
||||
|
||||
const missingNodeTypes = [];
|
||||
@ -949,9 +1133,20 @@ class ComfyApp {
|
||||
widget.value = widget.value.slice(7);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (node.type == "KSampler" || node.type == "KSamplerAdvanced" || node.type == "PrimitiveNode") {
|
||||
if (widget.name == "control_after_generate") {
|
||||
if (widget.value == true) {
|
||||
if (widget.value === true) {
|
||||
widget.value = "randomize";
|
||||
} else if (widget.value === false) {
|
||||
widget.value = "fixed";
|
||||
}
|
||||
}
|
||||
}
|
||||
if (reset_invalid_values) {
|
||||
if (widget.type == "combo") {
|
||||
if (!widget.options.values.includes(widget.value) && widget.options.values.length > 0) {
|
||||
widget.value = widget.options.values[0];
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1067,7 +1262,7 @@ class ComfyApp {
|
||||
try {
|
||||
await api.queuePrompt(number, p);
|
||||
} catch (error) {
|
||||
this.ui.dialog.show(error.response || error.toString());
|
||||
this.ui.dialog.show(error.response.error || error.toString());
|
||||
break;
|
||||
}
|
||||
|
||||
@ -1113,9 +1308,18 @@ class ComfyApp {
|
||||
this.loadGraphData(JSON.parse(reader.result));
|
||||
};
|
||||
reader.readAsText(file);
|
||||
} else if (file.name?.endsWith(".latent")) {
|
||||
const info = await getLatentMetadata(file);
|
||||
if (info.workflow) {
|
||||
this.loadGraphData(JSON.parse(info.workflow));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Registers a Comfy web extension with the app
|
||||
* @param {ComfyExtension} extension
|
||||
*/
|
||||
registerExtension(extension) {
|
||||
if (!extension.name) {
|
||||
throw new Error("Extensions must have a 'name' property.");
|
||||
@ -1139,12 +1343,12 @@ class ComfyApp {
|
||||
|
||||
for(const widgetNum in node.widgets) {
|
||||
const widget = node.widgets[widgetNum]
|
||||
|
||||
if(widget.type == "combo" && def["input"]["required"][widget.name] !== undefined) {
|
||||
widget.options.values = def["input"]["required"][widget.name][0];
|
||||
|
||||
if(!widget.options.values.includes(widget.value)) {
|
||||
if(widget.name != 'image' && !widget.options.values.includes(widget.value)) {
|
||||
widget.value = widget.options.values[0];
|
||||
widget.callback(widget.value);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -47,6 +47,22 @@ export function getPngMetadata(file) {
|
||||
});
|
||||
}
|
||||
|
||||
export function getLatentMetadata(file) {
|
||||
return new Promise((r) => {
|
||||
const reader = new FileReader();
|
||||
reader.onload = (event) => {
|
||||
const safetensorsData = new Uint8Array(event.target.result);
|
||||
const dataView = new DataView(safetensorsData.buffer);
|
||||
let header_size = dataView.getUint32(0, true);
|
||||
let offset = 8;
|
||||
let header = JSON.parse(String.fromCharCode(...safetensorsData.slice(offset, offset + header_size)));
|
||||
r(header.__metadata__);
|
||||
};
|
||||
|
||||
reader.readAsArrayBuffer(file);
|
||||
});
|
||||
}
|
||||
|
||||
export async function importA1111(graph, parameters) {
|
||||
const p = parameters.lastIndexOf("\nSteps:");
|
||||
if (p > -1) {
|
||||
@ -131,6 +147,7 @@ export async function importA1111(graph, parameters) {
|
||||
}
|
||||
|
||||
function replaceEmbeddings(text) {
|
||||
if(!embeddings.length) return text;
|
||||
return text.replaceAll(
|
||||
new RegExp(
|
||||
"\\b(" + embeddings.map((e) => e.replace(/[.*+?^${}()|[\]\\]/g, "\\$&")).join("\\b|\\b") + ")\\b",
|
||||
|
||||
@ -270,6 +270,30 @@ class ComfySettingsDialog extends ComfyDialog {
|
||||
]),
|
||||
]);
|
||||
break;
|
||||
case "slider":
|
||||
element = $el("div", [
|
||||
$el("label", { textContent: name }, [
|
||||
$el("input", {
|
||||
type: "range",
|
||||
value,
|
||||
oninput: (e) => {
|
||||
setter(e.target.value);
|
||||
e.target.nextElementSibling.value = e.target.value;
|
||||
},
|
||||
...attrs
|
||||
}),
|
||||
$el("input", {
|
||||
type: "number",
|
||||
value,
|
||||
oninput: (e) => {
|
||||
setter(e.target.value);
|
||||
e.target.previousElementSibling.value = e.target.value;
|
||||
},
|
||||
...attrs
|
||||
}),
|
||||
]),
|
||||
]);
|
||||
break;
|
||||
default:
|
||||
console.warn("Unsupported setting type, defaulting to text");
|
||||
element = $el("div", [
|
||||
@ -431,9 +455,17 @@ export class ComfyUI {
|
||||
defaultValue: true,
|
||||
});
|
||||
|
||||
const promptFilename = this.settings.addSetting({
|
||||
id: "Comfy.PromptFilename",
|
||||
name: "Prompt for filename when saving workflow",
|
||||
type: "boolean",
|
||||
defaultValue: true,
|
||||
});
|
||||
|
||||
const fileInput = $el("input", {
|
||||
id: "comfy-file-input",
|
||||
type: "file",
|
||||
accept: ".json,image/png",
|
||||
accept: ".json,image/png,.latent",
|
||||
style: { display: "none" },
|
||||
parent: document.body,
|
||||
onchange: () => {
|
||||
@ -448,6 +480,7 @@ export class ComfyUI {
|
||||
$el("button.comfy-settings-btn", { textContent: "⚙️", onclick: () => this.settings.show() }),
|
||||
]),
|
||||
$el("button.comfy-queue-btn", {
|
||||
id: "queue-button",
|
||||
textContent: "Queue Prompt",
|
||||
onclick: () => app.queuePrompt(0, this.batchCount),
|
||||
}),
|
||||
@ -496,9 +529,10 @@ export class ComfyUI {
|
||||
]),
|
||||
]),
|
||||
$el("div.comfy-menu-btns", [
|
||||
$el("button", { textContent: "Queue Front", onclick: () => app.queuePrompt(-1, this.batchCount) }),
|
||||
$el("button", { id: "queue-front-button", textContent: "Queue Front", onclick: () => app.queuePrompt(-1, this.batchCount) }),
|
||||
$el("button", {
|
||||
$: (b) => (this.queue.button = b),
|
||||
id: "comfy-view-queue-button",
|
||||
textContent: "View Queue",
|
||||
onclick: () => {
|
||||
this.history.hide();
|
||||
@ -507,6 +541,7 @@ export class ComfyUI {
|
||||
}),
|
||||
$el("button", {
|
||||
$: (b) => (this.history.button = b),
|
||||
id: "comfy-view-history-button",
|
||||
textContent: "View History",
|
||||
onclick: () => {
|
||||
this.queue.hide();
|
||||
@ -517,14 +552,23 @@ export class ComfyUI {
|
||||
this.queue.element,
|
||||
this.history.element,
|
||||
$el("button", {
|
||||
id: "comfy-save-button",
|
||||
textContent: "Save",
|
||||
onclick: () => {
|
||||
let filename = "workflow.json";
|
||||
if (promptFilename.value) {
|
||||
filename = prompt("Save workflow as:", filename);
|
||||
if (!filename) return;
|
||||
if (!filename.toLowerCase().endsWith(".json")) {
|
||||
filename += ".json";
|
||||
}
|
||||
}
|
||||
const json = JSON.stringify(app.graph.serialize(), null, 2); // convert the data to a JSON string
|
||||
const blob = new Blob([json], { type: "application/json" });
|
||||
const url = URL.createObjectURL(blob);
|
||||
const a = $el("a", {
|
||||
href: url,
|
||||
download: "workflow.json",
|
||||
download: filename,
|
||||
style: { display: "none" },
|
||||
parent: document.body,
|
||||
});
|
||||
@ -535,15 +579,16 @@ export class ComfyUI {
|
||||
}, 0);
|
||||
},
|
||||
}),
|
||||
$el("button", { textContent: "Load", onclick: () => fileInput.click() }),
|
||||
$el("button", { textContent: "Refresh", onclick: () => app.refreshComboInNodes() }),
|
||||
$el("button", { textContent: "Clear", onclick: () => {
|
||||
$el("button", { id: "comfy-load-button", textContent: "Load", onclick: () => fileInput.click() }),
|
||||
$el("button", { id: "comfy-refresh-button", textContent: "Refresh", onclick: () => app.refreshComboInNodes() }),
|
||||
$el("button", { id: "comfy-clipspace-button", textContent: "Clipspace", onclick: () => app.openClipspace() }),
|
||||
$el("button", { id: "comfy-clear-button", textContent: "Clear", onclick: () => {
|
||||
if (!confirmClear.value || confirm("Clear workflow?")) {
|
||||
app.clean();
|
||||
app.graph.clear();
|
||||
}
|
||||
}}),
|
||||
$el("button", { textContent: "Load Default", onclick: () => {
|
||||
$el("button", { id: "comfy-load-default-button", textContent: "Load Default", onclick: () => {
|
||||
if (!confirmClear.value || confirm("Load default workflow?")) {
|
||||
app.loadGraphData()
|
||||
}
|
||||
|
||||
@ -19,35 +19,60 @@ export function addValueControlWidget(node, targetWidget, defaultValue = "random
|
||||
|
||||
var v = valueControl.value;
|
||||
|
||||
let min = targetWidget.options.min;
|
||||
let max = targetWidget.options.max;
|
||||
// limit to something that javascript can handle
|
||||
max = Math.min(1125899906842624, max);
|
||||
min = Math.max(-1125899906842624, min);
|
||||
let range = (max - min) / (targetWidget.options.step / 10);
|
||||
if (targetWidget.type == "combo" && v !== "fixed") {
|
||||
let current_index = targetWidget.options.values.indexOf(targetWidget.value);
|
||||
let current_length = targetWidget.options.values.length;
|
||||
|
||||
//adjust values based on valueControl Behaviour
|
||||
switch (v) {
|
||||
case "fixed":
|
||||
break;
|
||||
case "increment":
|
||||
targetWidget.value += targetWidget.options.step / 10;
|
||||
break;
|
||||
case "decrement":
|
||||
targetWidget.value -= targetWidget.options.step / 10;
|
||||
break;
|
||||
case "randomize":
|
||||
targetWidget.value = Math.floor(Math.random() * range) * (targetWidget.options.step / 10) + min;
|
||||
default:
|
||||
break;
|
||||
switch (v) {
|
||||
case "increment":
|
||||
current_index += 1;
|
||||
break;
|
||||
case "decrement":
|
||||
current_index -= 1;
|
||||
break;
|
||||
case "randomize":
|
||||
current_index = Math.floor(Math.random() * current_length);
|
||||
default:
|
||||
break;
|
||||
}
|
||||
current_index = Math.max(0, current_index);
|
||||
current_index = Math.min(current_length - 1, current_index);
|
||||
if (current_index >= 0) {
|
||||
let value = targetWidget.options.values[current_index];
|
||||
targetWidget.value = value;
|
||||
targetWidget.callback(value);
|
||||
}
|
||||
} else { //number
|
||||
let min = targetWidget.options.min;
|
||||
let max = targetWidget.options.max;
|
||||
// limit to something that javascript can handle
|
||||
max = Math.min(1125899906842624, max);
|
||||
min = Math.max(-1125899906842624, min);
|
||||
let range = (max - min) / (targetWidget.options.step / 10);
|
||||
|
||||
//adjust values based on valueControl Behaviour
|
||||
switch (v) {
|
||||
case "fixed":
|
||||
break;
|
||||
case "increment":
|
||||
targetWidget.value += targetWidget.options.step / 10;
|
||||
break;
|
||||
case "decrement":
|
||||
targetWidget.value -= targetWidget.options.step / 10;
|
||||
break;
|
||||
case "randomize":
|
||||
targetWidget.value = Math.floor(Math.random() * range) * (targetWidget.options.step / 10) + min;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
/*check if values are over or under their respective
|
||||
* ranges and set them to min or max.*/
|
||||
if (targetWidget.value < min)
|
||||
targetWidget.value = min;
|
||||
|
||||
if (targetWidget.value > max)
|
||||
targetWidget.value = max;
|
||||
}
|
||||
/*check if values are over or under their respective
|
||||
* ranges and set them to min or max.*/
|
||||
if (targetWidget.value < min)
|
||||
targetWidget.value = min;
|
||||
|
||||
if (targetWidget.value > max)
|
||||
targetWidget.value = max;
|
||||
}
|
||||
return valueControl;
|
||||
};
|
||||
@ -136,9 +161,11 @@ function addMultilineWidget(node, name, opts, app) {
|
||||
left: `${t.a * margin + t.e}px`,
|
||||
top: `${t.d * (y + widgetHeight - margin - 3) + t.f}px`,
|
||||
width: `${(widgetWidth - margin * 2 - 3) * t.a}px`,
|
||||
background: (!node.color)?'':node.color,
|
||||
height: `${(this.parent.inputHeight - margin * 2 - 4) * t.d}px`,
|
||||
position: "absolute",
|
||||
zIndex: 1,
|
||||
color: (!node.color)?'':'white',
|
||||
zIndex: app.graph._nodes.indexOf(node),
|
||||
fontSize: `${t.d * 10.0}px`,
|
||||
});
|
||||
this.inputEl.hidden = !visible;
|
||||
@ -259,19 +286,51 @@ export const ComfyWidgets = {
|
||||
let uploadWidget;
|
||||
|
||||
function showImage(name) {
|
||||
// Position the image somewhere sensible
|
||||
if (!node.imageOffset) {
|
||||
node.imageOffset = uploadWidget.last_y ? uploadWidget.last_y + 25 : 75;
|
||||
}
|
||||
|
||||
const img = new Image();
|
||||
img.onload = () => {
|
||||
node.imgs = [img];
|
||||
app.graph.setDirtyCanvas(true);
|
||||
};
|
||||
img.src = `/view?filename=${name}&type=input`;
|
||||
let folder_separator = name.lastIndexOf("/");
|
||||
let subfolder = "";
|
||||
if (folder_separator > -1) {
|
||||
subfolder = name.substring(0, folder_separator);
|
||||
name = name.substring(folder_separator + 1);
|
||||
}
|
||||
img.src = `/view?filename=${name}&type=input&subfolder=${subfolder}`;
|
||||
node.setSizeForImage?.();
|
||||
}
|
||||
|
||||
var default_value = imageWidget.value;
|
||||
Object.defineProperty(imageWidget, "value", {
|
||||
set : function(value) {
|
||||
this._real_value = value;
|
||||
},
|
||||
|
||||
get : function() {
|
||||
let value = "";
|
||||
if (this._real_value) {
|
||||
value = this._real_value;
|
||||
} else {
|
||||
return default_value;
|
||||
}
|
||||
|
||||
if (value.filename) {
|
||||
let real_value = value;
|
||||
value = "";
|
||||
if (real_value.subfolder) {
|
||||
value = real_value.subfolder + "/";
|
||||
}
|
||||
|
||||
value += real_value.filename;
|
||||
|
||||
if(real_value.type && real_value.type !== "input")
|
||||
value += ` [${real_value.type}]`;
|
||||
}
|
||||
return value;
|
||||
}
|
||||
});
|
||||
|
||||
// Add our own callback to the combo widget to render an image when it changes
|
||||
const cb = node.callback;
|
||||
imageWidget.callback = function () {
|
||||
|
||||
@ -120,7 +120,7 @@ body {
|
||||
.comfy-menu > button,
|
||||
.comfy-menu-btns button,
|
||||
.comfy-menu .comfy-list button,
|
||||
.comfy-modal button{
|
||||
.comfy-modal button {
|
||||
color: var(--input-text);
|
||||
background-color: var(--comfy-input-bg);
|
||||
border-radius: 8px;
|
||||
@ -129,6 +129,15 @@ body {
|
||||
margin-top: 2px;
|
||||
}
|
||||
|
||||
.comfy-menu > button:hover,
|
||||
.comfy-menu-btns button:hover,
|
||||
.comfy-menu .comfy-list button:hover,
|
||||
.comfy-modal button:hover,
|
||||
.comfy-settings-btn:hover {
|
||||
filter: brightness(1.2);
|
||||
cursor: pointer;
|
||||
}
|
||||
|
||||
.comfy-menu span.drag-handle {
|
||||
width: 10px;
|
||||
height: 20px;
|
||||
@ -160,9 +169,9 @@ body {
|
||||
|
||||
.comfy-list {
|
||||
color: var(--descrip-text);
|
||||
background-color: #333;
|
||||
background-color: var(--comfy-menu-bg);
|
||||
margin-bottom: 10px;
|
||||
border-color: #4e4e4e;
|
||||
border-color: var(--border-color);
|
||||
border-style: solid;
|
||||
}
|
||||
|
||||
@ -217,6 +226,14 @@ button.comfy-queue-btn {
|
||||
z-index: 99;
|
||||
}
|
||||
|
||||
.comfy-modal.comfy-settings input[type="range"] {
|
||||
vertical-align: middle;
|
||||
}
|
||||
|
||||
.comfy-modal.comfy-settings input[type="range"] + input[type="number"] {
|
||||
width: 3.5em;
|
||||
}
|
||||
|
||||
.comfy-modal input,
|
||||
.comfy-modal select {
|
||||
color: var(--input-text);
|
||||
@ -240,8 +257,11 @@ button.comfy-queue-btn {
|
||||
}
|
||||
}
|
||||
|
||||
/* Input popup */
|
||||
|
||||
.graphdialog {
|
||||
min-height: 1em;
|
||||
background-color: var(--comfy-menu-bg);
|
||||
}
|
||||
|
||||
.graphdialog .name {
|
||||
@ -265,15 +285,66 @@ button.comfy-queue-btn {
|
||||
border-radius: 12px 0 0 12px;
|
||||
}
|
||||
|
||||
/* Context menu */
|
||||
|
||||
.litegraph .litemenu-entry.has_submenu {
|
||||
position: relative;
|
||||
padding-right: 20px;
|
||||
}
|
||||
}
|
||||
|
||||
.litemenu-entry.has_submenu::after {
|
||||
.litemenu-entry.has_submenu::after {
|
||||
content: ">";
|
||||
position: absolute;
|
||||
top: 0;
|
||||
right: 2px;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
.litegraph.litecontextmenu,
|
||||
.litegraph.litecontextmenu.dark {
|
||||
z-index: 9999 !important;
|
||||
background-color: var(--comfy-menu-bg) !important;
|
||||
filter: brightness(95%);
|
||||
}
|
||||
|
||||
.litegraph.litecontextmenu .litemenu-entry:hover:not(.disabled):not(.separator) {
|
||||
background-color: var(--comfy-menu-bg) !important;
|
||||
filter: brightness(155%);
|
||||
color: var(--input-text);
|
||||
}
|
||||
|
||||
.litegraph.litecontextmenu .litemenu-entry.submenu,
|
||||
.litegraph.litecontextmenu.dark .litemenu-entry.submenu {
|
||||
background-color: var(--comfy-menu-bg) !important;
|
||||
color: var(--input-text);
|
||||
}
|
||||
|
||||
.litegraph.litecontextmenu input {
|
||||
background-color: var(--comfy-input-bg) !important;
|
||||
color: var(--input-text) !important;
|
||||
}
|
||||
|
||||
/* Search box */
|
||||
|
||||
.litegraph.litesearchbox {
|
||||
z-index: 9999 !important;
|
||||
background-color: var(--comfy-menu-bg) !important;
|
||||
overflow: hidden;
|
||||
}
|
||||
|
||||
.litegraph.litesearchbox input,
|
||||
.litegraph.litesearchbox select {
|
||||
background-color: var(--comfy-input-bg) !important;
|
||||
color: var(--input-text);
|
||||
}
|
||||
|
||||
.litegraph.lite-search-item {
|
||||
color: var(--input-text);
|
||||
background-color: var(--comfy-input-bg);
|
||||
filter: brightness(80%);
|
||||
padding-left: 0.2em;
|
||||
}
|
||||
|
||||
.litegraph.lite-search-item.generic_type {
|
||||
color: var(--input-text);
|
||||
filter: brightness(50%);
|
||||
}
|
||||
|
||||
78
web/types/comfy.d.ts
vendored
Normal file
78
web/types/comfy.d.ts
vendored
Normal file
@ -0,0 +1,78 @@
|
||||
import { LGraphNode, IWidget } from "./litegraph";
|
||||
import { ComfyApp } from "/scripts/app";
|
||||
|
||||
export interface ComfyExtension {
|
||||
/**
|
||||
* The name of the extension
|
||||
*/
|
||||
name: string;
|
||||
/**
|
||||
* Allows any initialisation, e.g. loading resources. Called after the canvas is created but before nodes are added
|
||||
* @param app The ComfyUI app instance
|
||||
*/
|
||||
init(app: ComfyApp): Promise<void>;
|
||||
/**
|
||||
* Allows any additonal setup, called after the application is fully set up and running
|
||||
* @param app The ComfyUI app instance
|
||||
*/
|
||||
setup(app: ComfyApp): Promise<void>;
|
||||
/**
|
||||
* Called before nodes are registered with the graph
|
||||
* @param defs The collection of node definitions, add custom ones or edit existing ones
|
||||
* @param app The ComfyUI app instance
|
||||
*/
|
||||
addCustomNodeDefs(defs: Record<string, ComfyObjectInfo>, app: ComfyApp): Promise<void>;
|
||||
/**
|
||||
* Allows the extension to add custom widgets
|
||||
* @param app The ComfyUI app instance
|
||||
* @returns An array of {[widget name]: widget data}
|
||||
*/
|
||||
getCustomWidgets(
|
||||
app: ComfyApp
|
||||
): Promise<
|
||||
Array<
|
||||
Record<string, (node, inputName, inputData, app) => { widget?: IWidget; minWidth?: number; minHeight?: number }>
|
||||
>
|
||||
>;
|
||||
/**
|
||||
* Allows the extension to add additional handling to the node before it is registered with LGraph
|
||||
* @param nodeType The node class (not an instance)
|
||||
* @param nodeData The original node object info config object
|
||||
* @param app The ComfyUI app instance
|
||||
*/
|
||||
beforeRegisterNodeDef(nodeType: typeof LGraphNode, nodeData: ComfyObjectInfo, app: ComfyApp): Promise<void>;
|
||||
/**
|
||||
* Allows the extension to register additional nodes with LGraph after standard nodes are added
|
||||
* @param app The ComfyUI app instance
|
||||
*/
|
||||
registerCustomNodes(app: ComfyApp): Promise<void>;
|
||||
/**
|
||||
* Allows the extension to modify a node that has been reloaded onto the graph.
|
||||
* If you break something in the backend and want to patch workflows in the frontend
|
||||
* This is the place to do this
|
||||
* @param node The node that has been loaded
|
||||
* @param app The ComfyUI app instance
|
||||
*/
|
||||
loadedGraphNode(node: LGraphNode, app: ComfyApp);
|
||||
/**
|
||||
* Allows the extension to run code after the constructor of the node
|
||||
* @param node The node that has been created
|
||||
* @param app The ComfyUI app instance
|
||||
*/
|
||||
nodeCreated(node: LGraphNode, app: ComfyApp);
|
||||
}
|
||||
|
||||
export type ComfyObjectInfo = {
|
||||
name: string;
|
||||
display_name?: string;
|
||||
description?: string;
|
||||
category: string;
|
||||
input?: {
|
||||
required?: Record<string, ComfyObjectInfoConfig>;
|
||||
optional?: Record<string, ComfyObjectInfoConfig>;
|
||||
};
|
||||
output?: string[];
|
||||
output_name: string[];
|
||||
};
|
||||
|
||||
export type ComfyObjectInfoConfig = [string | any[]] | [string | any[], any];
|
||||
1506
web/types/litegraph.d.ts
vendored
Normal file
1506
web/types/litegraph.d.ts
vendored
Normal file
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user