Merge branch 'comfyanonymous:master' into master

JAlB- 2023-04-07 23:55:36 +03:00 committed by GitHub
commit 032908a6dd
27 changed files with 1379 additions and 392 deletions

View File

@ -14,7 +14,7 @@ This ui will let you design and execute advanced stable diffusion pipelines usin
 - Many optimizations: Only re-executes the parts of the workflow that changes between executions.
 - Command line option: ```--lowvram``` to make it work on GPUs with less than 3GB vram (enabled automatically on GPUs with low vram)
 - Works even if you don't have a GPU with: ```--cpu``` (slow)
-- Can load both ckpt and safetensors models/checkpoints. Standalone VAEs and CLIP models.
+- Can load ckpt, safetensors and diffusers models/checkpoints. Standalone VAEs and CLIP models.
 - Embeddings/Textual inversion
 - [Loras (regular, locon and loha)](https://comfyanonymous.github.io/ComfyUI_examples/lora/)
 - Loading full workflows (with seeds) from generated PNG files.
@ -24,6 +24,7 @@ This ui will let you design and execute advanced stable diffusion pipelines usin
 - [Inpainting](https://comfyanonymous.github.io/ComfyUI_examples/inpaint/) with both regular and inpainting models.
 - [ControlNet and T2I-Adapter](https://comfyanonymous.github.io/ComfyUI_examples/controlnet/)
 - [Upscale Models (ESRGAN, ESRGAN variants, SwinIR, Swin2SR, etc...)](https://comfyanonymous.github.io/ComfyUI_examples/upscale_models/)
+- [unCLIP Models](https://comfyanonymous.github.io/ComfyUI_examples/unclip/)
 - Starts up very fast.
 - Works fully offline: will never download anything.
 - [Config file](extra_model_paths.yaml.example) to set the search paths for models.

comfy/cli_args.py (new file, 31 lines)
View File

@ -0,0 +1,31 @@
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0", help="Specify the IP address to listen on (default: 127.0.0.1). If --listen is provided without an argument, it defaults to 0.0.0.0. (listens on all)")
parser.add_argument("--port", type=int, default=8188, help="Set the listen port.")
parser.add_argument("--enable-cors-header", type=str, default=None, metavar="ORIGIN", nargs="?", const="*", help="Enable CORS (Cross-Origin Resource Sharing) with optional origin or allow all with default '*'.")
parser.add_argument("--extra-model-paths-config", type=str, default=None, metavar="PATH", nargs='+', action='append', help="Load one or more extra_model_paths.yaml files.")
parser.add_argument("--output-directory", type=str, default=None, help="Set the ComfyUI output directory.")
parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use.")
parser.add_argument("--dont-upcast-attention", action="store_true", help="Disable upcasting of attention. Can boost speed but increase the chances of black images.")
parser.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).")
attn_group = parser.add_mutually_exclusive_group()
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization instead of the sub-quadratic one. Ignored when xformers is used.")
attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.")
parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.")
vram_group = parser.add_mutually_exclusive_group()
vram_group.add_argument("--highvram", action="store_true", help="By default models will be unloaded to CPU memory after being used. This option keeps them in GPU memory.")
vram_group.add_argument("--normalvram", action="store_true", help="Used to force normal vram use if lowvram gets automatically enabled.")
vram_group.add_argument("--lowvram", action="store_true", help="Split the unet in parts to use less vram.")
vram_group.add_argument("--novram", action="store_true", help="When lowvram isn't enough.")
vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for everything (slow).")
parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")
parser.add_argument("--windows-standalone-build", action="store_true", help="Windows standalone build: Enable convenient things that most people using the standalone windows build will probably enjoy (like auto opening the page on startup).")
args = parser.parse_args()
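The rest of the merge consumes this parsed namespace instead of scanning sys.argv. A minimal sketch of that pattern, illustrative only and assuming the comfy.cli_args module above is importable:

# Sketch, not part of the diff: other modules import the already-parsed args.
from comfy.cli_args import args

def describe_run():
    # The VRAM flags are declared mutually exclusive above, so at most one is set.
    if args.cpu:
        return "running on CPU"
    if args.lowvram:
        return "low VRAM mode requested"
    if args.highvram:
        return "models kept in GPU memory"
    return f"listening on {args.listen}:{args.port}"

print(describe_run())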

comfy/diffusers_convert.py (new file, 362 lines)
View File

@ -0,0 +1,362 @@
import json
import os
import yaml
import folder_paths
from comfy.ldm.util import instantiate_from_config
from comfy.sd import ModelPatcher, load_model_weights, CLIP, VAE
import os.path as osp
import re
import torch
from safetensors.torch import load_file, save_file
# conversion code from https://github.com/huggingface/diffusers/blob/main/scripts/convert_diffusers_to_original_stable_diffusion.py
# =================#
# UNet Conversion #
# =================#
unet_conversion_map = [
# (stable-diffusion, HF Diffusers)
("time_embed.0.weight", "time_embedding.linear_1.weight"),
("time_embed.0.bias", "time_embedding.linear_1.bias"),
("time_embed.2.weight", "time_embedding.linear_2.weight"),
("time_embed.2.bias", "time_embedding.linear_2.bias"),
("input_blocks.0.0.weight", "conv_in.weight"),
("input_blocks.0.0.bias", "conv_in.bias"),
("out.0.weight", "conv_norm_out.weight"),
("out.0.bias", "conv_norm_out.bias"),
("out.2.weight", "conv_out.weight"),
("out.2.bias", "conv_out.bias"),
]
unet_conversion_map_resnet = [
# (stable-diffusion, HF Diffusers)
("in_layers.0", "norm1"),
("in_layers.2", "conv1"),
("out_layers.0", "norm2"),
("out_layers.3", "conv2"),
("emb_layers.1", "time_emb_proj"),
("skip_connection", "conv_shortcut"),
]
unet_conversion_map_layer = []
# hardcoded number of downblocks and resnets/attentions...
# would need smarter logic for other networks.
for i in range(4):
# loop over downblocks/upblocks
for j in range(2):
# loop over resnets/attentions for downblocks
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
sd_down_res_prefix = f"input_blocks.{3 * i + j + 1}.0."
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
if i < 3:
# no attention layers in down_blocks.3
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
sd_down_atn_prefix = f"input_blocks.{3 * i + j + 1}.1."
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
for j in range(3):
# loop over resnets/attentions for upblocks
hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
sd_up_res_prefix = f"output_blocks.{3 * i + j}.0."
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
if i > 0:
# no attention layers in up_blocks.0
hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
sd_up_atn_prefix = f"output_blocks.{3 * i + j}.1."
unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
if i < 3:
# no downsample in down_blocks.3
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
sd_downsample_prefix = f"input_blocks.{3 * (i + 1)}.0.op."
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
# no upsample in up_blocks.3
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
sd_upsample_prefix = f"output_blocks.{3 * i + 2}.{1 if i == 0 else 2}."
unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
hf_mid_atn_prefix = "mid_block.attentions.0."
sd_mid_atn_prefix = "middle_block.1."
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
for j in range(2):
hf_mid_res_prefix = f"mid_block.resnets.{j}."
sd_mid_res_prefix = f"middle_block.{2 * j}."
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
def convert_unet_state_dict(unet_state_dict):
# buyer beware: this is a *brittle* function,
# and correct output requires that all of these pieces interact in
# the exact order in which I have arranged them.
mapping = {k: k for k in unet_state_dict.keys()}
for sd_name, hf_name in unet_conversion_map:
mapping[hf_name] = sd_name
for k, v in mapping.items():
if "resnets" in k:
for sd_part, hf_part in unet_conversion_map_resnet:
v = v.replace(hf_part, sd_part)
mapping[k] = v
for k, v in mapping.items():
for sd_part, hf_part in unet_conversion_map_layer:
v = v.replace(hf_part, sd_part)
mapping[k] = v
new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
return new_state_dict
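For orientation, here is how a single key flows through the tables above; the replacements below are hand-applied to one name purely as an illustration of what convert_unet_state_dict does in bulk:

# Illustration only: down_blocks.{i}.resnets.{j}. maps to input_blocks.{3*i + j + 1}.0.
# (i=0, j=0 -> input_blocks.1.0.), then the resnet table renames norm1 to in_layers.0.
hf_key = "down_blocks.0.resnets.0.norm1.weight"
sd_key = hf_key.replace("down_blocks.0.resnets.0.", "input_blocks.1.0.").replace("norm1", "in_layers.0")
print(sd_key)  # input_blocks.1.0.in_layers.0.weight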
# ================#
# VAE Conversion #
# ================#
vae_conversion_map = [
# (stable-diffusion, HF Diffusers)
("nin_shortcut", "conv_shortcut"),
("norm_out", "conv_norm_out"),
("mid.attn_1.", "mid_block.attentions.0."),
]
for i in range(4):
# down_blocks have two resnets
for j in range(2):
hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
sd_down_prefix = f"encoder.down.{i}.block.{j}."
vae_conversion_map.append((sd_down_prefix, hf_down_prefix))
if i < 3:
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
sd_downsample_prefix = f"down.{i}.downsample."
vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
sd_upsample_prefix = f"up.{3 - i}.upsample."
vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
# up_blocks have three resnets
# also, up blocks in hf are numbered in reverse from sd
for j in range(3):
hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
sd_up_prefix = f"decoder.up.{3 - i}.block.{j}."
vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
# this part accounts for mid blocks in both the encoder and the decoder
for i in range(2):
hf_mid_res_prefix = f"mid_block.resnets.{i}."
sd_mid_res_prefix = f"mid.block_{i + 1}."
vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
vae_conversion_map_attn = [
# (stable-diffusion, HF Diffusers)
("norm.", "group_norm."),
("q.", "query."),
("k.", "key."),
("v.", "value."),
("proj_out.", "proj_attn."),
]
def reshape_weight_for_sd(w):
# convert HF linear weights to SD conv2d weights
return w.reshape(*w.shape, 1, 1)
def convert_vae_state_dict(vae_state_dict):
mapping = {k: k for k in vae_state_dict.keys()}
for k, v in mapping.items():
for sd_part, hf_part in vae_conversion_map:
v = v.replace(hf_part, sd_part)
mapping[k] = v
for k, v in mapping.items():
if "attentions" in k:
for sd_part, hf_part in vae_conversion_map_attn:
v = v.replace(hf_part, sd_part)
mapping[k] = v
new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
weights_to_convert = ["q", "k", "v", "proj_out"]
for k, v in new_state_dict.items():
for weight_name in weights_to_convert:
if f"mid.attn_1.{weight_name}.weight" in k:
print(f"Reshaping {k} for SD format")
new_state_dict[k] = reshape_weight_for_sd(v)
return new_state_dict
# =========================#
# Text Encoder Conversion #
# =========================#
textenc_conversion_lst = [
# (stable-diffusion, HF Diffusers)
("resblocks.", "text_model.encoder.layers."),
("ln_1", "layer_norm1"),
("ln_2", "layer_norm2"),
(".c_fc.", ".fc1."),
(".c_proj.", ".fc2."),
(".attn", ".self_attn"),
("ln_final.", "transformer.text_model.final_layer_norm."),
("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"),
("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"),
]
protected = {re.escape(x[1]): x[0] for x in textenc_conversion_lst}
textenc_pattern = re.compile("|".join(protected.keys()))
# Ordering is from https://github.com/pytorch/pytorch/blob/master/test/cpp/api/modules.cpp
code2idx = {"q": 0, "k": 1, "v": 2}
def convert_text_enc_state_dict_v20(text_enc_dict):
new_state_dict = {}
capture_qkv_weight = {}
capture_qkv_bias = {}
for k, v in text_enc_dict.items():
if (
k.endswith(".self_attn.q_proj.weight")
or k.endswith(".self_attn.k_proj.weight")
or k.endswith(".self_attn.v_proj.weight")
):
k_pre = k[: -len(".q_proj.weight")]
k_code = k[-len("q_proj.weight")]
if k_pre not in capture_qkv_weight:
capture_qkv_weight[k_pre] = [None, None, None]
capture_qkv_weight[k_pre][code2idx[k_code]] = v
continue
if (
k.endswith(".self_attn.q_proj.bias")
or k.endswith(".self_attn.k_proj.bias")
or k.endswith(".self_attn.v_proj.bias")
):
k_pre = k[: -len(".q_proj.bias")]
k_code = k[-len("q_proj.bias")]
if k_pre not in capture_qkv_bias:
capture_qkv_bias[k_pre] = [None, None, None]
capture_qkv_bias[k_pre][code2idx[k_code]] = v
continue
relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k)
new_state_dict[relabelled_key] = v
for k_pre, tensors in capture_qkv_weight.items():
if None in tensors:
raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
new_state_dict[relabelled_key + ".in_proj_weight"] = torch.cat(tensors)
for k_pre, tensors in capture_qkv_bias.items():
if None in tensors:
raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
new_state_dict[relabelled_key + ".in_proj_bias"] = torch.cat(tensors)
return new_state_dict
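The helper above collects the per-projection q/k/v tensors and concatenates them into the single in_proj parameter the SD-style text encoder expects; a tiny standalone illustration of that final step (shapes are made up):

import torch

hidden = 4  # illustrative hidden size
q_w, k_w, v_w = (torch.randn(hidden, hidden) for _ in range(3))

# Same ordering as code2idx above: q, k, v stacked along dim 0.
in_proj_weight = torch.cat([q_w, k_w, v_w])
print(in_proj_weight.shape)  # torch.Size([12, 4])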
def convert_text_enc_state_dict(text_enc_dict):
return text_enc_dict
def load_diffusers(model_path, fp16=True, output_vae=True, output_clip=True, embedding_directory=None):
diffusers_unet_conf = json.load(open(osp.join(model_path, "unet/config.json")))
diffusers_scheduler_conf = json.load(open(osp.join(model_path, "scheduler/scheduler_config.json")))
# magic
v2 = diffusers_unet_conf["sample_size"] == 96
if 'prediction_type' in diffusers_scheduler_conf:
v_pred = diffusers_scheduler_conf['prediction_type'] == 'v_prediction'
if v2:
if v_pred:
config_path = folder_paths.get_full_path("configs", 'v2-inference-v.yaml')
else:
config_path = folder_paths.get_full_path("configs", 'v2-inference.yaml')
else:
config_path = folder_paths.get_full_path("configs", 'v1-inference.yaml')
with open(config_path, 'r') as stream:
config = yaml.safe_load(stream)
model_config_params = config['model']['params']
clip_config = model_config_params['cond_stage_config']
scale_factor = model_config_params['scale_factor']
vae_config = model_config_params['first_stage_config']
vae_config['scale_factor'] = scale_factor
model_config_params["unet_config"]["params"]["use_fp16"] = fp16
unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.safetensors")
vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.safetensors")
text_enc_path = osp.join(model_path, "text_encoder", "model.safetensors")
# Load models from safetensors if it exists, if it doesn't pytorch
if osp.exists(unet_path):
unet_state_dict = load_file(unet_path, device="cpu")
else:
unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.bin")
unet_state_dict = torch.load(unet_path, map_location="cpu")
if osp.exists(vae_path):
vae_state_dict = load_file(vae_path, device="cpu")
else:
vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.bin")
vae_state_dict = torch.load(vae_path, map_location="cpu")
if osp.exists(text_enc_path):
text_enc_dict = load_file(text_enc_path, device="cpu")
else:
text_enc_path = osp.join(model_path, "text_encoder", "pytorch_model.bin")
text_enc_dict = torch.load(text_enc_path, map_location="cpu")
# Convert the UNet model
unet_state_dict = convert_unet_state_dict(unet_state_dict)
unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()}
# Convert the VAE model
vae_state_dict = convert_vae_state_dict(vae_state_dict)
vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}
# Easiest way to identify v2.0 model seems to be that the text encoder (OpenCLIP) is deeper
is_v20_model = "text_model.encoder.layers.22.layer_norm2.bias" in text_enc_dict
if is_v20_model:
# Need to add the tag 'transformer' in advance so we can knock it out from the final layer-norm
text_enc_dict = {"transformer." + k: v for k, v in text_enc_dict.items()}
text_enc_dict = convert_text_enc_state_dict_v20(text_enc_dict)
text_enc_dict = {"cond_stage_model.model." + k: v for k, v in text_enc_dict.items()}
else:
text_enc_dict = convert_text_enc_state_dict(text_enc_dict)
text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()}
# Put together new checkpoint
sd = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
clip = None
vae = None
class WeightsLoader(torch.nn.Module):
pass
w = WeightsLoader()
load_state_dict_to = []
if output_vae:
vae = VAE(scale_factor=scale_factor, config=vae_config)
w.first_stage_model = vae.first_stage_model
load_state_dict_to = [w]
if output_clip:
clip = CLIP(config=clip_config, embedding_directory=embedding_directory)
w.cond_stage_model = clip.cond_stage_model
load_state_dict_to = [w]
model = instantiate_from_config(config["model"])
model = load_model_weights(model, sd, verbose=False, load_state_dict_to=load_state_dict_to)
if fp16:
model = model.half()
return ModelPatcher(model), clip, vae
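A rough usage sketch (assumed, not part of the diff): load_diffusers expects a Diffusers-style model directory containing unet/, vae/, text_encoder/ and scheduler/ subfolders, and returns the same (model, clip, vae) triple the checkpoint loaders produce. The path below is hypothetical:

import comfy.diffusers_convert

model_dir = "models/diffusers/some-sd15-model"  # hypothetical directory
model, clip, vae = comfy.diffusers_convert.load_diffusers(
    model_dir, fp16=True, output_vae=True, output_clip=True,
    embedding_directory=None)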

View File

@ -21,6 +21,8 @@ if model_management.xformers_enabled():
 import os
 _ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")
+from cli_args import args
 def exists(val):
     return val is not None
@ -474,7 +476,6 @@ class CrossAttentionPytorch(nn.Module):
         return self.to_out(out)
-import sys
 if model_management.xformers_enabled():
     print("Using xformers cross attention")
     CrossAttention = MemoryEfficientCrossAttention
@ -482,7 +483,7 @@ elif model_management.pytorch_attention_enabled():
     print("Using pytorch cross attention")
     CrossAttention = CrossAttentionPytorch
 else:
-    if "--use-split-cross-attention" in sys.argv:
+    if args.use_split_cross_attention:
         print("Using split optimization for cross attention")
         CrossAttention = CrossAttentionDoggettx
     else:

View File

@ -9,7 +9,7 @@ from typing import Optional, Any
 from ldm.modules.attention import MemoryEfficientCrossAttention
 import model_management
-if model_management.xformers_enabled():
+if model_management.xformers_enabled_vae():
     import xformers
     import xformers.ops
@ -364,7 +364,7 @@ class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):
 def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
     assert attn_type in ["vanilla", "vanilla-xformers", "memory-efficient-cross-attn", "linear", "none"], f'attn_type {attn_type} unknown'
-    if model_management.xformers_enabled() and attn_type == "vanilla":
+    if model_management.xformers_enabled_vae() and attn_type == "vanilla":
         attn_type = "vanilla-xformers"
     if model_management.pytorch_attention_enabled() and attn_type == "vanilla":
         attn_type = "vanilla-pytorch"

View File

@ -1,4 +1,4 @@
 #Taken from: https://github.com/dbolya/tomesd
 import torch
 from typing import Tuple, Callable
@ -8,13 +8,23 @@ def do_nothing(x: torch.Tensor, mode:str=None):
     return x
+def mps_gather_workaround(input, dim, index):
+    if input.shape[-1] == 1:
+        return torch.gather(
+            input.unsqueeze(-1),
+            dim - 1 if dim < 0 else dim,
+            index.unsqueeze(-1)
+        ).squeeze(-1)
+    else:
+        return torch.gather(input, dim, index)
 def bipartite_soft_matching_random2d(metric: torch.Tensor,
                                      w: int, h: int, sx: int, sy: int, r: int,
                                      no_rand: bool = False) -> Tuple[Callable, Callable]:
     """
     Partitions the tokens into src and dst and merges r tokens from src to dst.
     Dst tokens are partitioned by choosing one randomy in each (sx, sy) region.
     Args:
      - metric [B, N, C]: metric to use for similarity
      - w: image width in tokens
@ -28,33 +38,49 @@ def bipartite_soft_matching_random2d(metric: torch.Tensor,
     if r <= 0:
         return do_nothing, do_nothing
+    gather = mps_gather_workaround if metric.device.type == "mps" else torch.gather
     with torch.no_grad():
         hsy, wsx = h // sy, w // sx
         # For each sy by sx kernel, randomly assign one token to be dst and the rest src
-        idx_buffer = torch.zeros(1, hsy, wsx, sy*sx, 1, device=metric.device)
         if no_rand:
-            rand_idx = torch.zeros(1, hsy, wsx, 1, 1, device=metric.device, dtype=torch.int64)
+            rand_idx = torch.zeros(hsy, wsx, 1, device=metric.device, dtype=torch.int64)
         else:
-            rand_idx = torch.randint(sy*sx, size=(1, hsy, wsx, 1, 1), device=metric.device)
+            rand_idx = torch.randint(sy*sx, size=(hsy, wsx, 1), device=metric.device)
-        idx_buffer.scatter_(dim=3, index=rand_idx, src=-torch.ones_like(rand_idx, dtype=idx_buffer.dtype))
-        idx_buffer = idx_buffer.view(1, hsy, wsx, sy, sx, 1).transpose(2, 3).reshape(1, N, 1)
-        rand_idx = idx_buffer.argsort(dim=1)
-        num_dst = int((1 / (sx*sy)) * N)
+        # The image might not divide sx and sy, so we need to work on a view of the top left if the idx buffer instead
+        idx_buffer_view = torch.zeros(hsy, wsx, sy*sx, device=metric.device, dtype=torch.int64)
+        idx_buffer_view.scatter_(dim=2, index=rand_idx, src=-torch.ones_like(rand_idx, dtype=rand_idx.dtype))
+        idx_buffer_view = idx_buffer_view.view(hsy, wsx, sy, sx).transpose(1, 2).reshape(hsy * sy, wsx * sx)
+        # Image is not divisible by sx or sy so we need to move it into a new buffer
+        if (hsy * sy) < h or (wsx * sx) < w:
+            idx_buffer = torch.zeros(h, w, device=metric.device, dtype=torch.int64)
+            idx_buffer[:(hsy * sy), :(wsx * sx)] = idx_buffer_view
+        else:
+            idx_buffer = idx_buffer_view
+        # We set dst tokens to be -1 and src to be 0, so an argsort gives us dst|src indices
+        rand_idx = idx_buffer.reshape(1, -1, 1).argsort(dim=1)
+        # We're finished with these
+        del idx_buffer, idx_buffer_view
+        # rand_idx is currently dst|src, so split them
+        num_dst = hsy * wsx
         a_idx = rand_idx[:, num_dst:, :] # src
         b_idx = rand_idx[:, :num_dst, :] # dst
         def split(x):
             C = x.shape[-1]
-            src = x.gather(dim=1, index=a_idx.expand(B, N - num_dst, C))
-            dst = x.gather(dim=1, index=b_idx.expand(B, num_dst, C))
+            src = gather(x, dim=1, index=a_idx.expand(B, N - num_dst, C))
+            dst = gather(x, dim=1, index=b_idx.expand(B, num_dst, C))
             return src, dst
+        # Cosine similarity between A and B
         metric = metric / metric.norm(dim=-1, keepdim=True)
         a, b = split(metric)
         scores = a @ b.transpose(-1, -2)
@ -62,19 +88,20 @@ def bipartite_soft_matching_random2d(metric: torch.Tensor,
         # Can't reduce more than the # tokens in src
         r = min(a.shape[1], r)
+        # Find the most similar greedily
         node_max, node_idx = scores.max(dim=-1)
         edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]
         unm_idx = edge_idx[..., r:, :]  # Unmerged Tokens
         src_idx = edge_idx[..., :r, :]  # Merged Tokens
-        dst_idx = node_idx[..., None].gather(dim=-2, index=src_idx)
+        dst_idx = gather(node_idx[..., None], dim=-2, index=src_idx)
     def merge(x: torch.Tensor, mode="mean") -> torch.Tensor:
         src, dst = split(x)
         n, t1, c = src.shape
-        unm = src.gather(dim=-2, index=unm_idx.expand(n, t1 - r, c))
-        src = src.gather(dim=-2, index=src_idx.expand(n, r, c))
+        unm = gather(src, dim=-2, index=unm_idx.expand(n, t1 - r, c))
+        src = gather(src, dim=-2, index=src_idx.expand(n, r, c))
         dst = dst.scatter_reduce(-2, dst_idx.expand(n, r, c), src, reduce=mode)
         return torch.cat([unm, dst], dim=1)
@ -84,13 +111,13 @@ def bipartite_soft_matching_random2d(metric: torch.Tensor,
         unm, dst = x[..., :unm_len, :], x[..., unm_len:, :]
         _, _, c = unm.shape
-        src = dst.gather(dim=-2, index=dst_idx.expand(B, r, c))
+        src = gather(dst, dim=-2, index=dst_idx.expand(B, r, c))
         # Combine back to the original shape
         out = torch.zeros(B, N, c, device=x.device, dtype=x.dtype)
         out.scatter_(dim=-2, index=b_idx.expand(B, num_dst, c), src=dst)
-        out.scatter_(dim=-2, index=a_idx.expand(B, a_idx.shape[1], 1).gather(dim=1, index=unm_idx).expand(B, unm_len, c), src=unm)
-        out.scatter_(dim=-2, index=a_idx.expand(B, a_idx.shape[1], 1).gather(dim=1, index=src_idx).expand(B, r, c), src=src)
+        out.scatter_(dim=-2, index=gather(a_idx.expand(B, a_idx.shape[1], 1), dim=1, index=unm_idx).expand(B, unm_len, c), src=unm)
+        out.scatter_(dim=-2, index=gather(a_idx.expand(B, a_idx.shape[1], 1), dim=1, index=src_idx).expand(B, r, c), src=src)
         return out
@ -100,14 +127,14 @@ def bipartite_soft_matching_random2d(metric: torch.Tensor,
 def get_functions(x, ratio, original_shape):
     b, c, original_h, original_w = original_shape
     original_tokens = original_h * original_w
-    downsample = int(math.sqrt(original_tokens // x.shape[1]))
+    downsample = int(math.ceil(math.sqrt(original_tokens // x.shape[1])))
     stride_x = 2
     stride_y = 2
     max_downsample = 1
     if downsample <= max_downsample:
-        w = original_w // downsample
-        h = original_h // downsample
+        w = int(math.ceil(original_w / downsample))
+        h = int(math.ceil(original_h / downsample))
         r = int(x.shape[1] * ratio)
         no_rand = False
         m, u = bipartite_soft_matching_random2d(x, w, h, stride_x, stride_y, r, no_rand)
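To make the sizing in get_functions and bipartite_soft_matching_random2d concrete, a small worked example relating ratio, token count and the dst partition size (numbers are illustrative, not from the diff):

import math

original_h, original_w = 64, 64   # e.g. the latent of a 512x512 image
tokens = original_h * original_w  # 4096 tokens at the first attention level
ratio, sx, sy = 0.5, 2, 2

downsample = int(math.ceil(math.sqrt((original_h * original_w) // tokens)))  # 1
w = int(math.ceil(original_w / downsample))   # 64
h = int(math.ceil(original_h / downsample))   # 64
r = int(tokens * ratio)                       # 2048 tokens to merge
num_dst = (h // sy) * (w // sx)               # 1024 dst tokens, the rest are src
print(downsample, w, h, r, num_dst)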

View File

@ -1,36 +1,42 @@
+import psutil
+from enum import Enum
+from cli_args import args
-CPU = 0
-NO_VRAM = 1
-LOW_VRAM = 2
-NORMAL_VRAM = 3
-HIGH_VRAM = 4
-MPS = 5
-accelerate_enabled = False
-vram_state = NORMAL_VRAM
+class VRAMState(Enum):
+    CPU = 0
+    NO_VRAM = 1
+    LOW_VRAM = 2
+    NORMAL_VRAM = 3
+    HIGH_VRAM = 4
+    MPS = 5
+# Determine VRAM State
+vram_state = VRAMState.NORMAL_VRAM
+set_vram_to = VRAMState.NORMAL_VRAM
 total_vram = 0
 total_vram_available_mb = -1
-import sys
-import psutil
-forced_cpu = "--cpu" in sys.argv
-set_vram_to = NORMAL_VRAM
+accelerate_enabled = False
+xpu_available = False
 try:
     import torch
-    total_vram = torch.cuda.mem_get_info(torch.cuda.current_device())[1] / (1024 * 1024)
+    try:
+        import intel_extension_for_pytorch as ipex
+        if torch.xpu.is_available():
+            xpu_available = True
+            total_vram = torch.xpu.get_device_properties(torch.xpu.current_device()).total_memory / (1024 * 1024)
+    except:
+        total_vram = torch.cuda.mem_get_info(torch.cuda.current_device())[1] / (1024 * 1024)
     total_ram = psutil.virtual_memory().total / (1024 * 1024)
-    forced_normal_vram = "--normalvram" in sys.argv
-    if not forced_normal_vram and not forced_cpu:
+    if not args.normalvram and not args.cpu:
         if total_vram <= 4096:
             print("Trying to enable lowvram mode because your GPU seems to have 4GB or less. If you don't want this use: --normalvram")
-            set_vram_to = LOW_VRAM
+            set_vram_to = VRAMState.LOW_VRAM
         elif total_vram > total_ram * 1.1 and total_vram > 14336:
             print("Enabling highvram mode because your GPU has more vram than your computer has ram. If you don't want this use: --normalvram")
-            vram_state = HIGH_VRAM
+            vram_state = VRAMState.HIGH_VRAM
 except:
     pass
@ -39,34 +45,37 @@ try:
 except:
     OOM_EXCEPTION = Exception
-if "--disable-xformers" in sys.argv:
-    XFORMERS_IS_AVAILBLE = False
+if args.disable_xformers:
+    XFORMERS_IS_AVAILABLE = False
 else:
     try:
         import xformers
         import xformers.ops
-        XFORMERS_IS_AVAILBLE = True
+        XFORMERS_IS_AVAILABLE = True
     except:
-        XFORMERS_IS_AVAILBLE = False
+        XFORMERS_IS_AVAILABLE = False
-ENABLE_PYTORCH_ATTENTION = False
-if "--use-pytorch-cross-attention" in sys.argv:
+ENABLE_PYTORCH_ATTENTION = args.use_pytorch_cross_attention
+if ENABLE_PYTORCH_ATTENTION:
     torch.backends.cuda.enable_math_sdp(True)
     torch.backends.cuda.enable_flash_sdp(True)
     torch.backends.cuda.enable_mem_efficient_sdp(True)
-    ENABLE_PYTORCH_ATTENTION = True
-    XFORMERS_IS_AVAILBLE = False
+    XFORMERS_IS_AVAILABLE = False
-if "--lowvram" in sys.argv:
-    set_vram_to = LOW_VRAM
-if "--novram" in sys.argv:
-    set_vram_to = NO_VRAM
-if "--highvram" in sys.argv:
-    vram_state = HIGH_VRAM
+if args.lowvram:
+    set_vram_to = VRAMState.LOW_VRAM
+elif args.novram:
+    set_vram_to = VRAMState.NO_VRAM
+elif args.highvram:
+    vram_state = VRAMState.HIGH_VRAM
+FORCE_FP32 = False
+if args.force_fp32:
+    print("Forcing FP32, if this improves things please report it.")
+    FORCE_FP32 = True
-if set_vram_to == LOW_VRAM or set_vram_to == NO_VRAM:
+if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM):
     try:
         import accelerate
         accelerate_enabled = True
@ -81,14 +90,14 @@ if set_vram_to == LOW_VRAM or set_vram_to == NO_VRAM:
 try:
     if torch.backends.mps.is_available():
-        vram_state = MPS
+        vram_state = VRAMState.MPS
 except:
     pass
-if forced_cpu:
-    vram_state = CPU
+if args.cpu:
+    vram_state = VRAMState.CPU
-print("Set vram state to:", ["CPU", "NO VRAM", "LOW VRAM", "NORMAL VRAM", "HIGH VRAM", "MPS"][vram_state])
+print(f"Set vram state to: {vram_state.name}")
 current_loaded_model = None
@ -109,12 +118,12 @@ def unload_model():
             model_accelerated = False
         #never unload models from GPU on high vram
-        if vram_state != HIGH_VRAM:
+        if vram_state != VRAMState.HIGH_VRAM:
             current_loaded_model.model.cpu()
         current_loaded_model.unpatch_model()
         current_loaded_model = None
-    if vram_state != HIGH_VRAM:
+    if vram_state != VRAMState.HIGH_VRAM:
         if len(current_gpu_controlnets) > 0:
             for n in current_gpu_controlnets:
                 n.cpu()
@ -135,32 +144,32 @@ def load_model_gpu(model):
         model.unpatch_model()
         raise e
     current_loaded_model = model
-    if vram_state == CPU:
+    if vram_state == VRAMState.CPU:
         pass
-    elif vram_state == MPS:
+    elif vram_state == VRAMState.MPS:
         mps_device = torch.device("mps")
         real_model.to(mps_device)
         pass
-    elif vram_state == NORMAL_VRAM or vram_state == HIGH_VRAM:
+    elif vram_state == VRAMState.NORMAL_VRAM or vram_state == VRAMState.HIGH_VRAM:
         model_accelerated = False
-        real_model.cuda()
+        real_model.to(get_torch_device())
     else:
-        if vram_state == NO_VRAM:
+        if vram_state == VRAMState.NO_VRAM:
            device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "256MiB", "cpu": "16GiB"})
-        elif vram_state == LOW_VRAM:
+        elif vram_state == VRAMState.LOW_VRAM:
            device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "{}MiB".format(total_vram_available_mb), "cpu": "16GiB"})
-        accelerate.dispatch_model(real_model, device_map=device_map, main_device="cuda")
+        accelerate.dispatch_model(real_model, device_map=device_map, main_device=get_torch_device())
         model_accelerated = True
     return current_loaded_model
 def load_controlnet_gpu(models):
     global current_gpu_controlnets
     global vram_state
-    if vram_state == CPU:
+    if vram_state == VRAMState.CPU:
         return
-    if vram_state == LOW_VRAM or vram_state == NO_VRAM:
+    if vram_state == VRAMState.LOW_VRAM or vram_state == VRAMState.NO_VRAM:
         #don't load controlnets like this if low vram because they will be loaded right before running and unloaded right after
         return
@ -176,38 +185,57 @@ def load_controlnet_gpu(models):
 def load_if_low_vram(model):
     global vram_state
-    if vram_state == LOW_VRAM or vram_state == NO_VRAM:
-        return model.cuda()
+    if vram_state == VRAMState.LOW_VRAM or vram_state == VRAMState.NO_VRAM:
+        return model.to(get_torch_device())
     return model
 def unload_if_low_vram(model):
     global vram_state
-    if vram_state == LOW_VRAM or vram_state == NO_VRAM:
+    if vram_state == VRAMState.LOW_VRAM or vram_state == VRAMState.NO_VRAM:
         return model.cpu()
     return model
 def get_torch_device():
-    if vram_state == MPS:
+    global xpu_available
+    if vram_state == VRAMState.MPS:
         return torch.device("mps")
-    if vram_state == CPU:
+    if vram_state == VRAMState.CPU:
         return torch.device("cpu")
     else:
-        return torch.cuda.current_device()
+        if xpu_available:
+            return torch.device("xpu")
+        else:
+            return torch.cuda.current_device()
 def get_autocast_device(dev):
     if hasattr(dev, 'type'):
         return dev.type
     return "cuda"
 def xformers_enabled():
-    if vram_state == CPU:
+    if vram_state == VRAMState.CPU:
         return False
-    return XFORMERS_IS_AVAILBLE
+    return XFORMERS_IS_AVAILABLE
+def xformers_enabled_vae():
+    enabled = xformers_enabled()
+    if not enabled:
+        return False
+    try:
+        #0.0.18 has a bug where Nan is returned when inputs are too big (1152x1920 res images and above)
+        if xformers.version.__version__ == "0.0.18":
+            return False
+    except:
+        pass
+    return enabled
 def pytorch_attention_enabled():
     return ENABLE_PYTORCH_ATTENTION
 def get_free_memory(dev=None, torch_free_too=False):
+    global xpu_available
     if dev is None:
         dev = get_torch_device()
@ -215,12 +243,16 @@ def get_free_memory(dev=None, torch_free_too=False):
         mem_free_total = psutil.virtual_memory().available
         mem_free_torch = mem_free_total
     else:
-        stats = torch.cuda.memory_stats(dev)
-        mem_active = stats['active_bytes.all.current']
-        mem_reserved = stats['reserved_bytes.all.current']
-        mem_free_cuda, _ = torch.cuda.mem_get_info(dev)
-        mem_free_torch = mem_reserved - mem_active
-        mem_free_total = mem_free_cuda + mem_free_torch
+        if xpu_available:
+            mem_free_total = torch.xpu.get_device_properties(dev).total_memory - torch.xpu.memory_allocated(dev)
+            mem_free_torch = mem_free_total
+        else:
+            stats = torch.cuda.memory_stats(dev)
+            mem_active = stats['active_bytes.all.current']
+            mem_reserved = stats['reserved_bytes.all.current']
+            mem_free_cuda, _ = torch.cuda.mem_get_info(dev)
+            mem_free_torch = mem_reserved - mem_active
+            mem_free_total = mem_free_cuda + mem_free_torch
     if torch_free_too:
         return (mem_free_total, mem_free_torch)
@ -229,7 +261,7 @@ def get_free_memory(dev=None, torch_free_too=False):
 def maximum_batch_area():
     global vram_state
-    if vram_state == NO_VRAM:
+    if vram_state == VRAMState.NO_VRAM:
         return 0
     memory_free = get_free_memory() / (1024 * 1024)
@ -238,14 +270,18 @@ def maximum_batch_area():
 def cpu_mode():
     global vram_state
-    return vram_state == CPU
+    return vram_state == VRAMState.CPU
 def mps_mode():
     global vram_state
-    return vram_state == MPS
+    return vram_state == VRAMState.MPS
 def should_use_fp16():
-    if cpu_mode() or mps_mode():
+    global xpu_available
+    if FORCE_FP32:
+        return False
+    if cpu_mode() or mps_mode() or xpu_available:
         return False #TODO ?
     if torch.cuda.is_bf16_supported():
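The automatic VRAM mode selection near the top of this file boils down to two thresholds; a standalone restatement (names are illustrative, values in MB as in the code):

def pick_initial_vram_mode(total_vram_mb, total_ram_mb):
    # Mirrors the checks above: <=4GB suggests LOW_VRAM, a card with more VRAM
    # than system RAM (and more than 14GB) suggests HIGH_VRAM.
    if total_vram_mb <= 4096:
        return "LOW_VRAM"
    if total_vram_mb > total_ram_mb * 1.1 and total_vram_mb > 14336:
        return "HIGH_VRAM"
    return "NORMAL_VRAM"

print(pick_initial_vram_mode(6144, 16384))   # NORMAL_VRAM
print(pick_initial_vram_mode(24576, 16384))  # HIGH_VRAM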

View File

@ -348,17 +348,26 @@ def encode_adm(noise_augmentor, conds, batch_size, device):
         if 'adm' in x[1]:
             adm_inputs = []
             weights = []
+            noise_aug = []
             adm_in = x[1]["adm"]
             for adm_c in adm_in:
                 adm_cond = adm_c[0].image_embeds
                 weight = adm_c[1]
-                c_adm, noise_level_emb = noise_augmentor(adm_cond.to(device), noise_level=torch.tensor([0], device=device))
+                noise_augment = adm_c[2]
+                noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment)
+                c_adm, noise_level_emb = noise_augmentor(adm_cond.to(device), noise_level=torch.tensor([noise_level], device=device))
                 adm_out = torch.cat((c_adm, noise_level_emb), 1) * weight
                 weights.append(weight)
+                noise_aug.append(noise_augment)
                 adm_inputs.append(adm_out)
-            adm_out = torch.stack(adm_inputs).sum(0)
-            #TODO: Apply Noise to Embedding Mix
+            if len(noise_aug) > 1:
+                adm_out = torch.stack(adm_inputs).sum(0)
+                #TODO: add a way to control this
+                noise_augment = 0.05
+                noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment)
+                c_adm, noise_level_emb = noise_augmentor(adm_out[:, :noise_augmentor.time_embed.dim], noise_level=torch.tensor([noise_level], device=device))
+                adm_out = torch.cat((c_adm, noise_level_emb), 1)
         else:
             adm_out = torch.zeros((1, noise_augmentor.time_embed.dim * 2), device=device)
         x[1] = x[1].copy()
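A quick arithmetic check on the new noise_augmentation handling, assuming the usual max_noise_level of 1000 for the unCLIP noise augmentor (that value is an assumption, not stated in the diff):

max_noise_level = 1000  # assumed
for noise_augmentation in (0.0, 0.05, 0.5, 1.0):
    noise_level = round((max_noise_level - 1) * noise_augmentation)
    print(noise_augmentation, "->", noise_level)  # 0, 50, 500, 999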

View File

@ -74,9 +74,12 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
                 if isinstance(y, int):
                     tokens_temp += [y]
                 else:
-                    embedding_weights += [y]
-                    tokens_temp += [next_new_token]
-                    next_new_token += 1
+                    if y.shape[0] == current_embeds.weight.shape[1]:
+                        embedding_weights += [y]
+                        tokens_temp += [next_new_token]
+                        next_new_token += 1
+                    else:
+                        print("WARNING: shape mismatch when trying to apply embedding, embedding will be ignored", y.shape[0], current_embeds.weight.shape[1])
             out_tokens += [tokens_temp]
         if len(embedding_weights) > 0:

View File

@ -0,0 +1,210 @@
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
import comfy.utils
class Blend:
def __init__(self):
pass
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"image1": ("IMAGE",),
"image2": ("IMAGE",),
"blend_factor": ("FLOAT", {
"default": 0.5,
"min": 0.0,
"max": 1.0,
"step": 0.01
}),
"blend_mode": (["normal", "multiply", "screen", "overlay", "soft_light"],),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "blend_images"
CATEGORY = "image/postprocessing"
def blend_images(self, image1: torch.Tensor, image2: torch.Tensor, blend_factor: float, blend_mode: str):
if image1.shape != image2.shape:
image2 = image2.permute(0, 3, 1, 2)
image2 = comfy.utils.common_upscale(image2, image1.shape[2], image1.shape[1], upscale_method='bicubic', crop='center')
image2 = image2.permute(0, 2, 3, 1)
blended_image = self.blend_mode(image1, image2, blend_mode)
blended_image = image1 * (1 - blend_factor) + blended_image * blend_factor
blended_image = torch.clamp(blended_image, 0, 1)
return (blended_image,)
def blend_mode(self, img1, img2, mode):
if mode == "normal":
return img2
elif mode == "multiply":
return img1 * img2
elif mode == "screen":
return 1 - (1 - img1) * (1 - img2)
elif mode == "overlay":
return torch.where(img1 <= 0.5, 2 * img1 * img2, 1 - 2 * (1 - img1) * (1 - img2))
elif mode == "soft_light":
return torch.where(img2 <= 0.5, img1 - (1 - 2 * img2) * img1 * (1 - img1), img1 + (2 * img2 - 1) * (self.g(img1) - img1))
else:
raise ValueError(f"Unsupported blend mode: {mode}")
def g(self, x):
return torch.where(x <= 0.25, ((16 * x - 12) * x + 4) * x, torch.sqrt(x))
class Blur:
def __init__(self):
pass
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"image": ("IMAGE",),
"blur_radius": ("INT", {
"default": 1,
"min": 1,
"max": 31,
"step": 1
}),
"sigma": ("FLOAT", {
"default": 1.0,
"min": 0.1,
"max": 10.0,
"step": 0.1
}),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "blur"
CATEGORY = "image/postprocessing"
def gaussian_kernel(self, kernel_size: int, sigma: float):
x, y = torch.meshgrid(torch.linspace(-1, 1, kernel_size), torch.linspace(-1, 1, kernel_size), indexing="ij")
d = torch.sqrt(x * x + y * y)
g = torch.exp(-(d * d) / (2.0 * sigma * sigma))
return g / g.sum()
def blur(self, image: torch.Tensor, blur_radius: int, sigma: float):
if blur_radius == 0:
return (image,)
batch_size, height, width, channels = image.shape
kernel_size = blur_radius * 2 + 1
kernel = self.gaussian_kernel(kernel_size, sigma).repeat(channels, 1, 1).unsqueeze(1)
image = image.permute(0, 3, 1, 2) # Torch wants (B, C, H, W) we use (B, H, W, C)
blurred = F.conv2d(image, kernel, padding=kernel_size // 2, groups=channels)
blurred = blurred.permute(0, 2, 3, 1)
return (blurred,)
class Quantize:
def __init__(self):
pass
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"image": ("IMAGE",),
"colors": ("INT", {
"default": 256,
"min": 1,
"max": 256,
"step": 1
}),
"dither": (["none", "floyd-steinberg"],),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "quantize"
CATEGORY = "image/postprocessing"
def quantize(self, image: torch.Tensor, colors: int = 256, dither: str = "FLOYDSTEINBERG"):
batch_size, height, width, _ = image.shape
result = torch.zeros_like(image)
dither_option = Image.Dither.FLOYDSTEINBERG if dither == "floyd-steinberg" else Image.Dither.NONE
for b in range(batch_size):
tensor_image = image[b]
img = (tensor_image * 255).to(torch.uint8).numpy()
pil_image = Image.fromarray(img, mode='RGB')
palette = pil_image.quantize(colors=colors) # Required as described in https://github.com/python-pillow/Pillow/issues/5836
quantized_image = pil_image.quantize(colors=colors, palette=palette, dither=dither_option)
quantized_array = torch.tensor(np.array(quantized_image.convert("RGB"))).float() / 255
result[b] = quantized_array
return (result,)
class Sharpen:
def __init__(self):
pass
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"image": ("IMAGE",),
"sharpen_radius": ("INT", {
"default": 1,
"min": 1,
"max": 31,
"step": 1
}),
"alpha": ("FLOAT", {
"default": 1.0,
"min": 0.1,
"max": 5.0,
"step": 0.1
}),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "sharpen"
CATEGORY = "image/postprocessing"
def sharpen(self, image: torch.Tensor, sharpen_radius: int, alpha: float):
if sharpen_radius == 0:
return (image,)
batch_size, height, width, channels = image.shape
kernel_size = sharpen_radius * 2 + 1
kernel = torch.ones((kernel_size, kernel_size), dtype=torch.float32) * -1
center = kernel_size // 2
kernel[center, center] = kernel_size**2
kernel *= alpha
kernel = kernel.repeat(channels, 1, 1).unsqueeze(1)
tensor_image = image.permute(0, 3, 1, 2) # Torch wants (B, C, H, W) we use (B, H, W, C)
sharpened = F.conv2d(tensor_image, kernel, padding=center, groups=channels)
sharpened = sharpened.permute(0, 2, 3, 1)
result = torch.clamp(sharpened, 0, 1)
return (result,)
NODE_CLASS_MAPPINGS = {
"ImageBlend": Blend,
"ImageBlur": Blur,
"ImageQuantize": Quantize,
"ImageSharpen": Sharpen,
}
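The Blur node above builds a normalized Gaussian kernel and applies it per channel with a grouped conv2d on (B, H, W, C) images; a self-contained sketch of the same idea, separate from the node classes:

import torch
import torch.nn.functional as F

def gaussian_kernel(kernel_size, sigma):
    # Same construction as Blur.gaussian_kernel: a normalized 2D Gaussian.
    x, y = torch.meshgrid(torch.linspace(-1, 1, kernel_size),
                          torch.linspace(-1, 1, kernel_size), indexing="ij")
    d = torch.sqrt(x * x + y * y)
    g = torch.exp(-(d * d) / (2.0 * sigma * sigma))
    return g / g.sum()

image = torch.rand(1, 64, 64, 3)             # (B, H, W, C), the layout the nodes use
radius, sigma = 2, 1.0
kernel = gaussian_kernel(radius * 2 + 1, sigma).repeat(3, 1, 1).unsqueeze(1)
x = image.permute(0, 3, 1, 2)                # conv2d wants (B, C, H, W)
blurred = F.conv2d(x, kernel, padding=radius, groups=3).permute(0, 2, 3, 1)
print(blurred.shape)                         # torch.Size([1, 64, 64, 3])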

View File

@ -23,10 +23,45 @@ folder_names_and_paths["clip"] = ([os.path.join(models_dir, "clip")], supported_
 folder_names_and_paths["clip_vision"] = ([os.path.join(models_dir, "clip_vision")], supported_pt_extensions)
 folder_names_and_paths["style_models"] = ([os.path.join(models_dir, "style_models")], supported_pt_extensions)
 folder_names_and_paths["embeddings"] = ([os.path.join(models_dir, "embeddings")], supported_pt_extensions)
+folder_names_and_paths["diffusers"] = ([os.path.join(models_dir, "diffusers")], ["folder"])
 folder_names_and_paths["controlnet"] = ([os.path.join(models_dir, "controlnet"), os.path.join(models_dir, "t2i_adapter")], supported_pt_extensions)
 folder_names_and_paths["upscale_models"] = ([os.path.join(models_dir, "upscale_models")], supported_pt_extensions)
+output_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "output")
+temp_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp")
+input_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input")
+if not os.path.exists(input_directory):
+    os.makedirs(input_directory)
+def set_output_directory(output_dir):
+    global output_directory
+    output_directory = output_dir
+def get_output_directory():
+    global output_directory
+    return output_directory
+def get_temp_directory():
+    global temp_directory
+    return temp_directory
+def get_input_directory():
+    global input_directory
+    return input_directory
+#NOTE: used in http server so don't put folders that should not be accessed remotely
+def get_directory_by_type(type_name):
+    if type_name == "output":
+        return get_output_directory()
+    if type_name == "temp":
+        return get_temp_directory()
+    if type_name == "input":
+        return get_input_directory()
+    return None
 def add_model_folder_path(folder_name, full_folder_path):
     global folder_names_and_paths

main.py (97 lines changed)
View File

@ -1,56 +1,33 @@
-import os
-import sys
-import shutil
-import threading
 import asyncio
+import itertools
+import os
+import shutil
+import threading
+from comfy.cli_args import args
 if os.name == "nt":
     import logging
     logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
 if __name__ == "__main__":
-    if '--help' in sys.argv:
-        print()
-        print("Valid Command line Arguments:")
-        print("\t--listen [ip]\t\t\tListen on ip or 0.0.0.0 if none given so the UI can be accessed from other computers.")
-        print("\t--port 8188\t\t\tSet the listen port.")
-        print()
-        print("\t--extra-model-paths-config file.yaml\tload an extra_model_paths.yaml file.")
-        print()
-        print()
-        print("\t--dont-upcast-attention\t\tDisable upcasting of attention \n\t\t\t\t\tcan boost speed but increase the chances of black images.\n")
-        print("\t--use-split-cross-attention\tUse the split cross attention optimization instead of the sub-quadratic one.\n\t\t\t\t\tIgnored when xformers is used.")
-        print("\t--use-pytorch-cross-attention\tUse the new pytorch 2.0 cross attention function.")
-        print("\t--disable-xformers\t\tdisables xformers")
-        print("\t--cuda-device 1\t\tSet the id of the cuda device this instance will use.")
-        print()
-        print("\t--highvram\t\t\tBy default models will be unloaded to CPU memory after being used.\n\t\t\t\t\tThis option keeps them in GPU memory.\n")
-        print("\t--normalvram\t\t\tUsed to force normal vram use if lowvram gets automatically enabled.")
-        print("\t--lowvram\t\t\tSplit the unet in parts to use less vram.")
-        print("\t--novram\t\t\tWhen lowvram isn't enough.")
-        print()
-        print("\t--cpu\t\t\tTo use the CPU for everything (slow).")
-        exit()
-    if '--dont-upcast-attention' in sys.argv:
+    if args.dont_upcast_attention:
         print("disabling upcasting of attention")
         os.environ['ATTN_PRECISION'] = "fp16"
-    try:
-        index = sys.argv.index('--cuda-device')
-        device = sys.argv[index + 1]
-        os.environ['CUDA_VISIBLE_DEVICES'] = device
-        print("Set cuda device to:", device)
-    except:
-        pass
+    if args.cuda_device is not None:
+        os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device)
+        print("Set cuda device to:", args.cuda_device)
-    from nodes import init_custom_nodes
-    import execution
-    import server
-    import folder_paths
     import yaml
+    import execution
+    import folder_paths
+    import server
+    from nodes import init_custom_nodes
 def prompt_worker(q, server):
     e = execution.PromptExecutor(server)
     while True:
@ -109,43 +86,31 @@ if __name__ == "__main__":
     hijack_progress(server)
     threading.Thread(target=prompt_worker, daemon=True, args=(q,server,)).start()
-    try:
-        address = '0.0.0.0'
-        p_index = sys.argv.index('--listen')
-        try:
-            ip = sys.argv[p_index + 1]
-            if ip[:2] != '--':
-                address = ip
-        except:
-            pass
-    except:
-        address = '127.0.0.1'
+    address = args.listen
-    dont_print = False
-    if '--dont-print-server' in sys.argv:
-        dont_print = True
+    dont_print = args.dont_print_server
     extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml")
     if os.path.isfile(extra_model_paths_config_path):
         load_extra_path_config(extra_model_paths_config_path)
-    if '--extra-model-paths-config' in sys.argv:
-        indices = [(i + 1) for i in range(len(sys.argv) - 1) if sys.argv[i] == '--extra-model-paths-config']
-        for i in indices:
-            load_extra_path_config(sys.argv[i])
+    if args.extra_model_paths_config:
+        for config_path in itertools.chain(*args.extra_model_paths_config):
+            load_extra_path_config(config_path)
-    port = 8188
-    try:
-        p_index = sys.argv.index('--port')
-        port = int(sys.argv[p_index + 1])
-    except:
-        pass
+    if args.output_directory:
+        output_dir = os.path.abspath(args.output_directory)
+        print(f"Setting output directory to: {output_dir}")
+        folder_paths.set_output_directory(output_dir)
-    if '--quick-test-for-ci' in sys.argv:
+    port = args.port
+    if args.quick_test_for_ci:
         exit(0)
     call_on_start = None
-    if "--windows-standalone-build" in sys.argv:
+    if args.windows_standalone_build:
         def startup_server(address, port):
             import webbrowser
             webbrowser.open("http://{}:{}".format(address, port))
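One subtle point in the block above: because --extra-model-paths-config is declared with both nargs='+' and action='append', argparse yields a list of lists, which is why main.py flattens it with itertools.chain. A standalone illustration:

import argparse, itertools

p = argparse.ArgumentParser()
p.add_argument("--extra-model-paths-config", type=str, default=None,
               metavar="PATH", nargs='+', action='append')
ns = p.parse_args(["--extra-model-paths-config", "a.yaml", "b.yaml",
                   "--extra-model-paths-config", "c.yaml"])
print(ns.extra_model_paths_config)                           # [['a.yaml', 'b.yaml'], ['c.yaml']]
print(list(itertools.chain(*ns.extra_model_paths_config)))   # ['a.yaml', 'b.yaml', 'c.yaml']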

View File

@ -4,16 +4,17 @@ import os
import sys import sys
import json import json
import hashlib import hashlib
import copy
import traceback import traceback
from PIL import Image from PIL import Image
from PIL.PngImagePlugin import PngInfo from PIL.PngImagePlugin import PngInfo
import numpy as np import numpy as np
sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy")) sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy"))
import comfy.diffusers_convert
import comfy.samplers import comfy.samplers
import comfy.sd import comfy.sd
import comfy.utils import comfy.utils
@ -197,7 +198,7 @@ class CheckpointLoader:
RETURN_TYPES = ("MODEL", "CLIP", "VAE") RETURN_TYPES = ("MODEL", "CLIP", "VAE")
FUNCTION = "load_checkpoint" FUNCTION = "load_checkpoint"
CATEGORY = "loaders" CATEGORY = "advanced/loaders"
def load_checkpoint(self, config_name, ckpt_name, output_vae=True, output_clip=True): def load_checkpoint(self, config_name, ckpt_name, output_vae=True, output_clip=True):
config_path = folder_paths.get_full_path("configs", config_name) config_path = folder_paths.get_full_path("configs", config_name)
@ -219,6 +220,30 @@ class CheckpointLoaderSimple:
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings")) out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
return out return out
class DiffusersLoader:
@classmethod
def INPUT_TYPES(cls):
paths = []
for search_path in folder_paths.get_folder_paths("diffusers"):
if os.path.exists(search_path):
paths += next(os.walk(search_path))[1]
return {"required": {"model_path": (paths,), }}
RETURN_TYPES = ("MODEL", "CLIP", "VAE")
FUNCTION = "load_checkpoint"
CATEGORY = "advanced/loaders"
def load_checkpoint(self, model_path, output_vae=True, output_clip=True):
for search_path in folder_paths.get_folder_paths("diffusers"):
if os.path.exists(search_path):
paths = next(os.walk(search_path))[1]
if model_path in paths:
model_path = os.path.join(search_path, model_path)
break
return comfy.diffusers_convert.load_diffusers(model_path, fp16=model_management.should_use_fp16(), output_vae=output_vae, output_clip=output_clip, embedding_directory=folder_paths.get_folder_paths("embeddings"))
class unCLIPCheckpointLoader: class unCLIPCheckpointLoader:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
@ -227,7 +252,7 @@ class unCLIPCheckpointLoader:
RETURN_TYPES = ("MODEL", "CLIP", "VAE", "CLIP_VISION") RETURN_TYPES = ("MODEL", "CLIP", "VAE", "CLIP_VISION")
FUNCTION = "load_checkpoint" FUNCTION = "load_checkpoint"
CATEGORY = "_for_testing/unclip" CATEGORY = "loaders"
def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True): def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name) ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
@ -445,17 +470,18 @@ class unCLIPConditioning:
return {"required": {"conditioning": ("CONDITIONING", ), return {"required": {"conditioning": ("CONDITIONING", ),
"clip_vision_output": ("CLIP_VISION_OUTPUT", ), "clip_vision_output": ("CLIP_VISION_OUTPUT", ),
"strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}), "strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
"noise_augmentation": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01}),
}} }}
RETURN_TYPES = ("CONDITIONING",) RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "apply_adm" FUNCTION = "apply_adm"
CATEGORY = "_for_testing/unclip" CATEGORY = "conditioning"
def apply_adm(self, conditioning, clip_vision_output, strength): def apply_adm(self, conditioning, clip_vision_output, strength, noise_augmentation):
c = [] c = []
for t in conditioning: for t in conditioning:
o = t[1].copy() o = t[1].copy()
x = (clip_vision_output, strength) x = (clip_vision_output, strength, noise_augmentation)
if "adm" in o: if "adm" in o:
o["adm"] = o["adm"][:] + [x] o["adm"] = o["adm"][:] + [x]
else: else:
@ -776,7 +802,7 @@ class KSamplerAdvanced:
class SaveImage: class SaveImage:
def __init__(self): def __init__(self):
self.output_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "output") self.output_dir = folder_paths.get_output_directory()
self.type = "output" self.type = "output"
@classmethod @classmethod
@ -828,9 +854,6 @@ class SaveImage:
os.makedirs(full_output_folder, exist_ok=True) os.makedirs(full_output_folder, exist_ok=True)
counter = 1 counter = 1
if not os.path.exists(self.output_dir):
os.makedirs(self.output_dir)
results = list() results = list()
for image in images: for image in images:
i = 255. * image.cpu().numpy() i = 255. * image.cpu().numpy()
@ -855,7 +878,7 @@ class SaveImage:
class PreviewImage(SaveImage): class PreviewImage(SaveImage):
def __init__(self): def __init__(self):
self.output_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp") self.output_dir = folder_paths.get_temp_directory()
self.type = "temp" self.type = "temp"
@classmethod @classmethod
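The SaveImage and PreviewImage changes above (and the LoadImage/LoadImageMask changes below) replace paths hard-coded next to the script with accessors on folder_paths. The folder_paths module itself is not part of this diff; the following is only a hypothetical sketch of what such accessors could look like, to make the call sites easier to follow:

import os

# Hypothetical stand-in for the folder_paths accessors used by the nodes.
base_path = os.path.dirname(os.path.realpath(__file__))
output_directory = os.path.join(base_path, "output")
temp_directory = os.path.join(base_path, "temp")
input_directory = os.path.join(base_path, "input")

def set_output_directory(output_dir):
    global output_directory
    output_directory = output_dir

def get_output_directory():
    return output_directory

def get_temp_directory():
    return temp_directory

def get_input_directory():
    return input_directory

def get_directory_by_type(type_name):
    # Used by the /view handler in the server.py hunks further down.
    return {
        "output": get_output_directory(),
        "temp": get_temp_directory(),
        "input": get_input_directory(),
    }.get(type_name)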
@ -866,13 +889,11 @@ class PreviewImage(SaveImage):
} }
class LoadImage: class LoadImage:
input_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input")
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
if not os.path.exists(s.input_dir): input_dir = folder_paths.get_input_directory()
os.makedirs(s.input_dir)
return {"required": return {"required":
{"image": (sorted(os.listdir(s.input_dir)), )}, {"image": (sorted(os.listdir(input_dir)), )},
} }
CATEGORY = "image" CATEGORY = "image"
@ -880,7 +901,8 @@ class LoadImage:
RETURN_TYPES = ("IMAGE", "MASK") RETURN_TYPES = ("IMAGE", "MASK")
FUNCTION = "load_image" FUNCTION = "load_image"
def load_image(self, image): def load_image(self, image):
image_path = os.path.join(self.input_dir, image) input_dir = folder_paths.get_input_directory()
image_path = os.path.join(input_dir, image)
i = Image.open(image_path) i = Image.open(image_path)
image = i.convert("RGB") image = i.convert("RGB")
image = np.array(image).astype(np.float32) / 255.0 image = np.array(image).astype(np.float32) / 255.0
@ -894,18 +916,19 @@ class LoadImage:
@classmethod @classmethod
def IS_CHANGED(s, image): def IS_CHANGED(s, image):
image_path = os.path.join(s.input_dir, image) input_dir = folder_paths.get_input_directory()
image_path = os.path.join(input_dir, image)
m = hashlib.sha256() m = hashlib.sha256()
with open(image_path, 'rb') as f: with open(image_path, 'rb') as f:
m.update(f.read()) m.update(f.read())
return m.digest().hex() return m.digest().hex()
class LoadImageMask: class LoadImageMask:
input_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input")
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
input_dir = folder_paths.get_input_directory()
return {"required": return {"required":
{"image": (sorted(os.listdir(s.input_dir)), ), {"image": (sorted(os.listdir(input_dir)), ),
"channel": (["alpha", "red", "green", "blue"], ),} "channel": (["alpha", "red", "green", "blue"], ),}
} }
@ -914,7 +937,8 @@ class LoadImageMask:
RETURN_TYPES = ("MASK",) RETURN_TYPES = ("MASK",)
FUNCTION = "load_image" FUNCTION = "load_image"
def load_image(self, image, channel): def load_image(self, image, channel):
image_path = os.path.join(self.input_dir, image) input_dir = folder_paths.get_input_directory()
image_path = os.path.join(input_dir, image)
i = Image.open(image_path) i = Image.open(image_path)
mask = None mask = None
c = channel[0].upper() c = channel[0].upper()
@ -929,7 +953,8 @@ class LoadImageMask:
@classmethod @classmethod
def IS_CHANGED(s, image, channel): def IS_CHANGED(s, image, channel):
image_path = os.path.join(s.input_dir, image) input_dir = folder_paths.get_input_directory()
image_path = os.path.join(input_dir, image)
m = hashlib.sha256() m = hashlib.sha256()
with open(image_path, 'rb') as f: with open(image_path, 'rb') as f:
m.update(f.read()) m.update(f.read())
@ -1037,7 +1062,6 @@ class ImagePadForOutpaint:
NODE_CLASS_MAPPINGS = { NODE_CLASS_MAPPINGS = {
"KSampler": KSampler, "KSampler": KSampler,
"CheckpointLoader": CheckpointLoader,
"CheckpointLoaderSimple": CheckpointLoaderSimple, "CheckpointLoaderSimple": CheckpointLoaderSimple,
"CLIPTextEncode": CLIPTextEncode, "CLIPTextEncode": CLIPTextEncode,
"CLIPSetLastLayer": CLIPSetLastLayer, "CLIPSetLastLayer": CLIPSetLastLayer,
@ -1076,6 +1100,8 @@ NODE_CLASS_MAPPINGS = {
"VAEEncodeTiled": VAEEncodeTiled, "VAEEncodeTiled": VAEEncodeTiled,
"TomePatchModel": TomePatchModel, "TomePatchModel": TomePatchModel,
"unCLIPCheckpointLoader": unCLIPCheckpointLoader, "unCLIPCheckpointLoader": unCLIPCheckpointLoader,
"CheckpointLoader": CheckpointLoader,
"DiffusersLoader": DiffusersLoader,
} }
def load_custom_node(module_path): def load_custom_node(module_path):
@ -1112,4 +1138,5 @@ def load_custom_nodes():
def init_custom_nodes(): def init_custom_nodes():
load_custom_nodes() load_custom_nodes()
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_upscale_model.py")) load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_upscale_model.py"))
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_post_processing.py"))
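The registry changes above also show the shape load_custom_nodes() picks up: a module exposing a NODE_CLASS_MAPPINGS dict, the same registry the built-in nodes use. A minimal, made-up custom node module in that shape (the class and its behaviour are illustrative only):

# hypothetical file: custom_nodes/example_invert.py
class InvertImageExample:
    @classmethod
    def INPUT_TYPES(cls):
        return {"required": {"image": ("IMAGE",)}}

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "invert"
    CATEGORY = "image"

    def invert(self, image):
        # IMAGE tensors are float32 in [0, 1], so 1 - x inverts the colors.
        return (1.0 - image,)

NODE_CLASS_MAPPINGS = {
    "InvertImageExample": InvertImageExample,
}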

@ -47,7 +47,7 @@
" !git pull\n", " !git pull\n",
"\n", "\n",
"!echo -= Install dependencies =-\n", "!echo -= Install dependencies =-\n",
"!pip install xformers==0.0.16 -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu117" "!pip install xformers -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu118"
] ]
}, },
{ {

@ -4,7 +4,7 @@ torchsde
einops einops
open-clip-torch open-clip-torch
transformers>=4.25.1 transformers>=4.25.1
safetensors safetensors>=0.3.0
pytorch_lightning pytorch_lightning
aiohttp aiohttp
accelerate accelerate

@ -18,6 +18,7 @@ except ImportError:
sys.exit() sys.exit()
import mimetypes import mimetypes
from comfy.cli_args import args
@web.middleware @web.middleware
@ -27,6 +28,23 @@ async def cache_control(request: web.Request, handler):
response.headers.setdefault('Cache-Control', 'no-cache') response.headers.setdefault('Cache-Control', 'no-cache')
return response return response
def create_cors_middleware(allowed_origin: str):
@web.middleware
async def cors_middleware(request: web.Request, handler):
if request.method == "OPTIONS":
# Pre-flight request. Reply successfully:
response = web.Response()
else:
response = await handler(request)
response.headers['Access-Control-Allow-Origin'] = allowed_origin
response.headers['Access-Control-Allow-Methods'] = 'POST, GET, DELETE, PUT, OPTIONS'
response.headers['Access-Control-Allow-Headers'] = 'Content-Type, Authorization'
response.headers['Access-Control-Allow-Credentials'] = 'true'
return response
return cors_middleware
class PromptServer(): class PromptServer():
def __init__(self, loop): def __init__(self, loop):
PromptServer.instance = self PromptServer.instance = self
@ -37,7 +55,12 @@ class PromptServer():
self.loop = loop self.loop = loop
self.messages = asyncio.Queue() self.messages = asyncio.Queue()
self.number = 0 self.number = 0
self.app = web.Application(client_max_size=20971520, middlewares=[cache_control])
middlewares = [cache_control]
if args.enable_cors_header:
middlewares.append(create_cors_middleware(args.enable_cors_header))
self.app = web.Application(client_max_size=20971520, middlewares=middlewares)
self.sockets = dict() self.sockets = dict()
self.web_root = os.path.join(os.path.dirname( self.web_root = os.path.join(os.path.dirname(
os.path.realpath(__file__)), "web") os.path.realpath(__file__)), "web")
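With the optional CORS middleware wired in above (enabled when args.enable_cors_header is set), every response, including the OPTIONS pre-flight, carries the Access-Control-* headers. One way to check this from Python against a locally running instance; the address, port and origin below are assumptions, not taken from this diff:

import urllib.request

# Send a browser-style pre-flight request to a running server that was
# started with the CORS origin option enabled.
req = urllib.request.Request(
    "http://127.0.0.1:8188/",
    method="OPTIONS",
    headers={
        "Origin": "https://example.com",
        "Access-Control-Request-Method": "POST",
    },
)
with urllib.request.urlopen(req) as resp:
    print(resp.status)
    print(resp.headers.get("Access-Control-Allow-Origin"))
    print(resp.headers.get("Access-Control-Allow-Methods"))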
@ -89,7 +112,7 @@ class PromptServer():
@routes.post("/upload/image") @routes.post("/upload/image")
async def upload_image(request): async def upload_image(request):
upload_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input") upload_dir = folder_paths.get_input_directory()
if not os.path.exists(upload_dir): if not os.path.exists(upload_dir):
os.makedirs(upload_dir) os.makedirs(upload_dir)
@ -122,10 +145,10 @@ class PromptServer():
async def view_image(request): async def view_image(request):
if "filename" in request.rel_url.query: if "filename" in request.rel_url.query:
type = request.rel_url.query.get("type", "output") type = request.rel_url.query.get("type", "output")
if type not in ["output", "input", "temp"]: output_dir = folder_paths.get_directory_by_type(type)
if output_dir is None:
return web.Response(status=400) return web.Response(status=400)
output_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), type)
if "subfolder" in request.rel_url.query: if "subfolder" in request.rel_url.query:
full_output_dir = os.path.join(output_dir, request.rel_url.query["subfolder"]) full_output_dir = os.path.join(output_dir, request.rel_url.query["subfolder"])
if os.path.commonpath((os.path.abspath(full_output_dir), output_dir)) != output_dir: if os.path.commonpath((os.path.abspath(full_output_dir), output_dir)) != output_dir:
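The /view handler above now resolves the directory through folder_paths.get_directory_by_type and keeps the os.path.commonpath guard so a crafted subfolder value cannot climb out of that directory. The same containment idea in isolation (function name invented for the example):

import os

def resolve_subfolder(base_dir, subfolder):
    # Return the absolute path of subfolder inside base_dir, or None if the
    # request would escape base_dir (e.g. "../../etc").
    base_dir = os.path.abspath(base_dir)
    candidate = os.path.abspath(os.path.join(base_dir, subfolder))
    if os.path.commonpath((candidate, base_dir)) != base_dir:
        return None
    return candidate

print(resolve_subfolder("/tmp/output", "batch_01"))    # /tmp/output/batch_01
print(resolve_subfolder("/tmp/output", "../secrets"))  # None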

@ -0,0 +1,137 @@
import { app } from "/scripts/app.js";
// Adds filtering to combo context menus
const id = "Comfy.ContextMenuFilter";
app.registerExtension({
name: id,
init() {
const ctxMenu = LiteGraph.ContextMenu;
LiteGraph.ContextMenu = function (values, options) {
const ctx = ctxMenu.call(this, values, options);
// If we are a dark menu (only used for combo boxes) then add a filter input
if (options?.className === "dark" && values?.length > 10) {
const filter = document.createElement("input");
Object.assign(filter.style, {
width: "calc(100% - 10px)",
border: "0",
boxSizing: "border-box",
background: "#333",
border: "1px solid #999",
margin: "0 0 5px 5px",
color: "#fff",
});
filter.placeholder = "Filter list";
this.root.prepend(filter);
let selectedIndex = 0;
let items = this.root.querySelectorAll(".litemenu-entry");
let itemCount = items.length;
let selectedItem;
// Apply highlighting to the selected item
function updateSelected() {
if (selectedItem) {
selectedItem.style.setProperty("background-color", "");
selectedItem.style.setProperty("color", "");
}
selectedItem = items[selectedIndex];
if (selectedItem) {
selectedItem.style.setProperty("background-color", "#ccc", "important");
selectedItem.style.setProperty("color", "#000", "important");
}
}
const positionList = () => {
const rect = this.root.getBoundingClientRect();
// If the top is off screen then shift the element with scaling applied
if (rect.top < 0) {
const scale = 1 - this.root.getBoundingClientRect().height / this.root.clientHeight;
const shift = (this.root.clientHeight * scale) / 2;
this.root.style.top = -shift + "px";
}
}
updateSelected();
// Arrow up/down to select items
filter.addEventListener("keydown", (e) => {
if (e.key === "ArrowUp") {
if (selectedIndex === 0) {
selectedIndex = itemCount - 1;
} else {
selectedIndex--;
}
updateSelected();
e.preventDefault();
} else if (e.key === "ArrowDown") {
if (selectedIndex === itemCount - 1) {
selectedIndex = 0;
} else {
selectedIndex++;
}
updateSelected();
e.preventDefault();
} else if ((selectedItem && e.key === "Enter") || e.keyCode === 13 || e.keyCode === 10) {
selectedItem.click();
} else if(e.key === "Escape") {
this.close();
}
});
filter.addEventListener("input", () => {
// Hide all items that dont match our filter
const term = filter.value.toLocaleLowerCase();
items = this.root.querySelectorAll(".litemenu-entry");
// When filtering recompute which items are visible for arrow up/down
// Try and maintain selection
let visibleItems = [];
for (const item of items) {
const visible = !term || item.textContent.toLocaleLowerCase().includes(term);
if (visible) {
item.style.display = "block";
if (item === selectedItem) {
selectedIndex = visibleItems.length;
}
visibleItems.push(item);
} else {
item.style.display = "none";
if (item === selectedItem) {
selectedIndex = 0;
}
}
}
items = visibleItems;
updateSelected();
// If we have an event then we can try and position the list under the source
if (options.event) {
let top = options.event.clientY - 10;
const bodyRect = document.body.getBoundingClientRect();
const rootRect = this.root.getBoundingClientRect();
if (bodyRect.height && top > bodyRect.height - rootRect.height - 10) {
top = Math.max(0, bodyRect.height - rootRect.height - 10);
}
this.root.style.top = top + "px";
positionList();
}
});
requestAnimationFrame(() => {
// Focus the filter box when opening
filter.focus();
positionList();
});
}
return ctx;
};
LiteGraph.ContextMenu.prototype = ctxMenu.prototype;
},
});

@ -30,7 +30,8 @@ app.registerExtension({
} }
// Overwrite the value in the serialized workflow pnginfo // Overwrite the value in the serialized workflow pnginfo
workflowNode.widgets_values[widgetIndex] = prompt; if (workflowNode?.widgets_values)
workflowNode.widgets_values[widgetIndex] = prompt;
return prompt; return prompt;
}; };

@ -3,10 +3,10 @@ import { app } from "/scripts/app.js";
// Inverts the scrolling of context menus // Inverts the scrolling of context menus
const id = "Comfy.InvertMenuScrolling"; const id = "Comfy.InvertMenuScrolling";
const ctxMenu = LiteGraph.ContextMenu;
app.registerExtension({ app.registerExtension({
name: id, name: id,
init() { init() {
const ctxMenu = LiteGraph.ContextMenu;
const replace = () => { const replace = () => {
LiteGraph.ContextMenu = function (values, options) { LiteGraph.ContextMenu = function (values, options) {
options = options || {}; options = options || {};

@ -11,11 +11,14 @@ app.registerExtension({
this.properties = {}; this.properties = {};
} }
this.properties.showOutputText = RerouteNode.defaultVisibility; this.properties.showOutputText = RerouteNode.defaultVisibility;
this.properties.horizontal = false;
this.addInput("", "*"); this.addInput("", "*");
this.addOutput(this.properties.showOutputText ? "*" : "", "*"); this.addOutput(this.properties.showOutputText ? "*" : "", "*");
this.onConnectionsChange = function (type, index, connected, link_info) { this.onConnectionsChange = function (type, index, connected, link_info) {
this.applyOrientation();
// Prevent multiple connections to different types when we have no input // Prevent multiple connections to different types when we have no input
if (connected && type === LiteGraph.OUTPUT) { if (connected && type === LiteGraph.OUTPUT) {
// Ignore wildcard nodes as these will be updated to real types // Ignore wildcard nodes as these will be updated to real types
@ -49,13 +52,13 @@ app.registerExtension({
currentNode = null; currentNode = null;
} }
else { else {
// Move the previous node // Move the previous node
currentNode = node; currentNode = node;
} }
} else { } else {
// We've found the end // We've found the end
inputNode = currentNode; inputNode = currentNode;
inputType = node.outputs[link.origin_slot].type; inputType = node.outputs[link.origin_slot]?.type ?? null;
break; break;
} }
} else { } else {
@ -87,7 +90,7 @@ app.registerExtension({
updateNodes.push(node); updateNodes.push(node);
} else { } else {
// We've found an output // We've found an output
const nodeOutType = node.inputs[link.target_slot].type; const nodeOutType = node.inputs && node.inputs[link?.target_slot] && node.inputs[link.target_slot].type ? node.inputs[link.target_slot].type : null;
if (inputType && nodeOutType !== inputType) { if (inputType && nodeOutType !== inputType) {
// The output doesnt match our input so disconnect it // The output doesnt match our input so disconnect it
node.disconnectInput(link.target_slot); node.disconnectInput(link.target_slot);
@ -112,6 +115,7 @@ app.registerExtension({
node.__outputType = displayType; node.__outputType = displayType;
node.outputs[0].name = node.properties.showOutputText ? displayType : ""; node.outputs[0].name = node.properties.showOutputText ? displayType : "";
node.size = node.computeSize(); node.size = node.computeSize();
node.applyOrientation();
for (const l of node.outputs[0].links || []) { for (const l of node.outputs[0].links || []) {
const link = app.graph.links[l]; const link = app.graph.links[l];
@ -153,6 +157,7 @@ app.registerExtension({
this.outputs[0].name = ""; this.outputs[0].name = "";
} }
this.size = this.computeSize(); this.size = this.computeSize();
this.applyOrientation();
app.graph.setDirtyCanvas(true, true); app.graph.setDirtyCanvas(true, true);
}, },
}, },
@ -161,9 +166,32 @@ app.registerExtension({
callback: () => { callback: () => {
RerouteNode.setDefaultTextVisibility(!RerouteNode.defaultVisibility); RerouteNode.setDefaultTextVisibility(!RerouteNode.defaultVisibility);
}, },
},
{
// naming is inverted with respect to LiteGraphNode.horizontal
// LiteGraphNode.horizontal == true means that
// each slot in the inputs and outputs are layed out horizontally,
// which is the opposite of the visual orientation of the inputs and outputs as a node
content: "Set " + (this.properties.horizontal ? "Horizontal" : "Vertical"),
callback: () => {
this.properties.horizontal = !this.properties.horizontal;
this.applyOrientation();
},
} }
); );
} }
applyOrientation() {
this.horizontal = this.properties.horizontal;
if (this.horizontal) {
// we correct the input position, because LiteGraphNode.horizontal
// doesn't account for title presence
// which reroute nodes don't have
this.inputs[0].pos = [this.size[0] / 2, 0];
} else {
delete this.inputs[0].pos;
}
app.graph.setDirtyCanvas(true, true);
}
computeSize() { computeSize() {
return [ return [

@ -2,6 +2,7 @@
<html lang="en"> <html lang="en">
<head> <head>
<meta charset="UTF-8"> <meta charset="UTF-8">
<title>ComfyUI</title>
<meta name="viewport" content="width=device-width, initial-scale=1.0, user-scalable=no"> <meta name="viewport" content="width=device-width, initial-scale=1.0, user-scalable=no">
<link rel="stylesheet" type="text/css" href="lib/litegraph.css" /> <link rel="stylesheet" type="text/css" href="lib/litegraph.css" />
<link rel="stylesheet" type="text/css" href="style.css" /> <link rel="stylesheet" type="text/css" href="style.css" />

@ -89,6 +89,7 @@
NO_TITLE: 1, NO_TITLE: 1,
TRANSPARENT_TITLE: 2, TRANSPARENT_TITLE: 2,
AUTOHIDE_TITLE: 3, AUTOHIDE_TITLE: 3,
VERTICAL_LAYOUT: "vertical", // arrange nodes vertically
proxy: null, //used to redirect calls proxy: null, //used to redirect calls
node_images_path: "", node_images_path: "",
@ -125,14 +126,14 @@
registered_slot_out_types: {}, // slot types for nodeclass registered_slot_out_types: {}, // slot types for nodeclass
slot_types_in: [], // slot types IN slot_types_in: [], // slot types IN
slot_types_out: [], // slot types OUT slot_types_out: [], // slot types OUT
slot_types_default_in: [], // specify for each IN slot type a(/many) deafult node(s), use single string, array, or object (with node, title, parameters, ..) like for search slot_types_default_in: [], // specify for each IN slot type a(/many) default node(s), use single string, array, or object (with node, title, parameters, ..) like for search
slot_types_default_out: [], // specify for each OUT slot type a(/many) deafult node(s), use single string, array, or object (with node, title, parameters, ..) like for search slot_types_default_out: [], // specify for each OUT slot type a(/many) default node(s), use single string, array, or object (with node, title, parameters, ..) like for search
alt_drag_do_clone_nodes: false, // [true!] very handy, ALT click to clone and drag the new node alt_drag_do_clone_nodes: false, // [true!] very handy, ALT click to clone and drag the new node
do_add_triggers_slots: false, // [true!] will create and connect event slots when using action/events connections, !WILL CHANGE node mode when using onTrigger (enable mode colors), onExecuted does not need this do_add_triggers_slots: false, // [true!] will create and connect event slots when using action/events connections, !WILL CHANGE node mode when using onTrigger (enable mode colors), onExecuted does not need this
allow_multi_output_for_events: true, // [false!] being events, it is strongly reccomended to use them sequentually, one by one allow_multi_output_for_events: true, // [false!] being events, it is strongly reccomended to use them sequentially, one by one
middle_click_slot_add_default_node: false, //[true!] allows to create and connect a ndoe clicking with the third button (wheel) middle_click_slot_add_default_node: false, //[true!] allows to create and connect a ndoe clicking with the third button (wheel)
@ -158,80 +159,67 @@
console.log("Node registered: " + type); console.log("Node registered: " + type);
} }
var categories = type.split("/"); const classname = base_class.name;
var classname = base_class.name;
var pos = type.lastIndexOf("/"); const pos = type.lastIndexOf("/");
base_class.category = type.substr(0, pos); base_class.category = type.substring(0, pos);
if (!base_class.title) { if (!base_class.title) {
base_class.title = classname; base_class.title = classname;
} }
//info.name = name.substr(pos+1,name.length - pos);
//extend class //extend class
if (base_class.prototype) { for (var i in LGraphNode.prototype) {
//is a class if (!base_class.prototype[i]) {
for (var i in LGraphNode.prototype) { base_class.prototype[i] = LGraphNode.prototype[i];
if (!base_class.prototype[i]) {
base_class.prototype[i] = LGraphNode.prototype[i];
}
} }
} }
var prev = this.registered_node_types[type]; const prev = this.registered_node_types[type];
if(prev) if(prev) {
console.log("replacing node type: " + type); console.log("replacing node type: " + type);
else }
{ if( !Object.prototype.hasOwnProperty.call( base_class.prototype, "shape") ) {
if( !Object.hasOwnProperty( base_class.prototype, "shape") ) Object.defineProperty(base_class.prototype, "shape", {
Object.defineProperty(base_class.prototype, "shape", { set: function(v) {
set: function(v) { switch (v) {
switch (v) { case "default":
case "default": delete this._shape;
delete this._shape; break;
break; case "box":
case "box": this._shape = LiteGraph.BOX_SHAPE;
this._shape = LiteGraph.BOX_SHAPE; break;
break; case "round":
case "round": this._shape = LiteGraph.ROUND_SHAPE;
this._shape = LiteGraph.ROUND_SHAPE; break;
break; case "circle":
case "circle": this._shape = LiteGraph.CIRCLE_SHAPE;
this._shape = LiteGraph.CIRCLE_SHAPE; break;
break; case "card":
case "card": this._shape = LiteGraph.CARD_SHAPE;
this._shape = LiteGraph.CARD_SHAPE; break;
break; default:
default: this._shape = v;
this._shape = v; }
} },
}, get: function() {
get: function(v) { return this._shape;
return this._shape; },
}, enumerable: true,
enumerable: true, configurable: true
configurable: true });
});
//warnings //used to know which nodes to create when dragging files to the canvas
if (base_class.prototype.onPropertyChange) { if (base_class.supported_extensions) {
console.warn( for (let i in base_class.supported_extensions) {
"LiteGraph node class " + const ext = base_class.supported_extensions[i];
type + if(ext && ext.constructor === String) {
" has onPropertyChange method, it must be called onPropertyChanged with d at the end" this.node_types_by_file_extension[ ext.toLowerCase() ] = base_class;
); }
} }
}
//used to know which nodes create when dragging files to the canvas }
if (base_class.supported_extensions) {
for (var i in base_class.supported_extensions) {
var ext = base_class.supported_extensions[i];
if(ext && ext.constructor === String)
this.node_types_by_file_extension[ ext.toLowerCase() ] = base_class;
}
}
}
this.registered_node_types[type] = base_class; this.registered_node_types[type] = base_class;
if (base_class.constructor.name) { if (base_class.constructor.name) {
@ -252,19 +240,11 @@
" has onPropertyChange method, it must be called onPropertyChanged with d at the end" " has onPropertyChange method, it must be called onPropertyChanged with d at the end"
); );
} }
//used to know which nodes create when dragging files to the canvas
if (base_class.supported_extensions) {
for (var i=0; i < base_class.supported_extensions.length; i++) {
var ext = base_class.supported_extensions[i];
if(ext && ext.constructor === String)
this.node_types_by_file_extension[ ext.toLowerCase() ] = base_class;
}
}
// TODO one would want to know input and ouput :: this would allow trought registerNodeAndSlotType to get all the slots types // TODO one would want to know input and ouput :: this would allow through registerNodeAndSlotType to get all the slots types
//console.debug("Registering "+type); if (this.auto_load_slot_types) {
if (this.auto_load_slot_types) nodeTmp = new base_class(base_class.title || "tmpnode"); new base_class(base_class.title || "tmpnode");
}
}, },
/** /**
@ -1260,37 +1240,39 @@
* Positions every node in a more readable manner * Positions every node in a more readable manner
* @method arrange * @method arrange
*/ */
LGraph.prototype.arrange = function(margin) { LGraph.prototype.arrange = function (margin, layout) {
margin = margin || 100; margin = margin || 100;
var nodes = this.computeExecutionOrder(false, true); const nodes = this.computeExecutionOrder(false, true);
var columns = []; const columns = [];
for (var i = 0; i < nodes.length; ++i) { for (let i = 0; i < nodes.length; ++i) {
var node = nodes[i]; const node = nodes[i];
var col = node._level || 1; const col = node._level || 1;
if (!columns[col]) { if (!columns[col]) {
columns[col] = []; columns[col] = [];
} }
columns[col].push(node); columns[col].push(node);
} }
var x = margin; let x = margin;
for (var i = 0; i < columns.length; ++i) { for (let i = 0; i < columns.length; ++i) {
var column = columns[i]; const column = columns[i];
if (!column) { if (!column) {
continue; continue;
} }
var max_size = 100; let max_size = 100;
var y = margin + LiteGraph.NODE_TITLE_HEIGHT; let y = margin + LiteGraph.NODE_TITLE_HEIGHT;
for (var j = 0; j < column.length; ++j) { for (let j = 0; j < column.length; ++j) {
var node = column[j]; const node = column[j];
node.pos[0] = x; node.pos[0] = (layout == LiteGraph.VERTICAL_LAYOUT) ? y : x;
node.pos[1] = y; node.pos[1] = (layout == LiteGraph.VERTICAL_LAYOUT) ? x : y;
if (node.size[0] > max_size) { const max_size_index = (layout == LiteGraph.VERTICAL_LAYOUT) ? 1 : 0;
max_size = node.size[0]; if (node.size[max_size_index] > max_size) {
max_size = node.size[max_size_index];
} }
y += node.size[1] + margin + LiteGraph.NODE_TITLE_HEIGHT; const node_size_index = (layout == LiteGraph.VERTICAL_LAYOUT) ? 0 : 1;
y += node.size[node_size_index] + margin + LiteGraph.NODE_TITLE_HEIGHT;
} }
x += max_size + margin; x += max_size + margin;
} }
@ -2468,43 +2450,34 @@
this.title = this.constructor.title; this.title = this.constructor.title;
} }
if (this.onConnectionsChange) { if (this.inputs) {
if (this.inputs) { for (var i = 0; i < this.inputs.length; ++i) {
for (var i = 0; i < this.inputs.length; ++i) { var input = this.inputs[i];
var input = this.inputs[i]; var link_info = this.graph ? this.graph.links[input.link] : null;
var link_info = this.graph if (this.onConnectionsChange)
? this.graph.links[input.link] this.onConnectionsChange( LiteGraph.INPUT, i, true, link_info, input ); //link_info has been created now, so its updated
: null;
this.onConnectionsChange(
LiteGraph.INPUT,
i,
true,
link_info,
input
); //link_info has been created now, so its updated
}
}
if (this.outputs) { if( this.onInputAdded )
for (var i = 0; i < this.outputs.length; ++i) { this.onInputAdded(input);
var output = this.outputs[i];
if (!output.links) { }
continue; }
}
for (var j = 0; j < output.links.length; ++j) { if (this.outputs) {
var link_info = this.graph for (var i = 0; i < this.outputs.length; ++i) {
? this.graph.links[output.links[j]] var output = this.outputs[i];
: null; if (!output.links) {
this.onConnectionsChange( continue;
LiteGraph.OUTPUT, }
i, for (var j = 0; j < output.links.length; ++j) {
true, var link_info = this.graph ? this.graph.links[output.links[j]] : null;
link_info, if (this.onConnectionsChange)
output this.onConnectionsChange( LiteGraph.OUTPUT, i, true, link_info, output ); //link_info has been created now, so its updated
); //link_info has been created now, so its updated }
}
} if( this.onOutputAdded )
} this.onOutputAdded(output);
}
} }
if( this.widgets ) if( this.widgets )
@ -3200,6 +3173,15 @@
return; return;
} }
if(slot == null)
{
console.error("slot must be a number");
return;
}
if(slot.constructor !== Number)
console.warn("slot must be a number, use node.trigger('name') if you want to use a string");
var output = this.outputs[slot]; var output = this.outputs[slot];
if (!output) { if (!output) {
return; return;
@ -3346,26 +3328,26 @@
* @param {Object} extra_info this can be used to have special properties of an output (label, special color, position, etc) * @param {Object} extra_info this can be used to have special properties of an output (label, special color, position, etc)
*/ */
LGraphNode.prototype.addOutput = function(name, type, extra_info) { LGraphNode.prototype.addOutput = function(name, type, extra_info) {
var o = { name: name, type: type, links: null }; var output = { name: name, type: type, links: null };
if (extra_info) { if (extra_info) {
for (var i in extra_info) { for (var i in extra_info) {
o[i] = extra_info[i]; output[i] = extra_info[i];
} }
} }
if (!this.outputs) { if (!this.outputs) {
this.outputs = []; this.outputs = [];
} }
this.outputs.push(o); this.outputs.push(output);
if (this.onOutputAdded) { if (this.onOutputAdded) {
this.onOutputAdded(o); this.onOutputAdded(output);
} }
if (LiteGraph.auto_load_slot_types) LiteGraph.registerNodeAndSlotType(this,type,true); if (LiteGraph.auto_load_slot_types) LiteGraph.registerNodeAndSlotType(this,type,true);
this.setSize( this.computeSize() ); this.setSize( this.computeSize() );
this.setDirtyCanvas(true, true); this.setDirtyCanvas(true, true);
return o; return output;
}; };
/** /**
@ -3437,10 +3419,10 @@
*/ */
LGraphNode.prototype.addInput = function(name, type, extra_info) { LGraphNode.prototype.addInput = function(name, type, extra_info) {
type = type || 0; type = type || 0;
var o = { name: name, type: type, link: null }; var input = { name: name, type: type, link: null };
if (extra_info) { if (extra_info) {
for (var i in extra_info) { for (var i in extra_info) {
o[i] = extra_info[i]; input[i] = extra_info[i];
} }
} }
@ -3448,17 +3430,17 @@
this.inputs = []; this.inputs = [];
} }
this.inputs.push(o); this.inputs.push(input);
this.setSize( this.computeSize() ); this.setSize( this.computeSize() );
if (this.onInputAdded) { if (this.onInputAdded) {
this.onInputAdded(o); this.onInputAdded(input);
} }
LiteGraph.registerNodeAndSlotType(this,type); LiteGraph.registerNodeAndSlotType(this,type);
this.setDirtyCanvas(true, true); this.setDirtyCanvas(true, true);
return o; return input;
}; };
/** /**
@ -5210,6 +5192,7 @@ LGraphNode.prototype.executeAction = function(action)
this.allow_dragcanvas = true; this.allow_dragcanvas = true;
this.allow_dragnodes = true; this.allow_dragnodes = true;
this.allow_interaction = true; //allow to control widgets, buttons, collapse, etc this.allow_interaction = true; //allow to control widgets, buttons, collapse, etc
this.multi_select = false; //allow selecting multi nodes without pressing extra keys
this.allow_searchbox = true; this.allow_searchbox = true;
this.allow_reconnect_links = true; //allows to change a connection with having to redo it again this.allow_reconnect_links = true; //allows to change a connection with having to redo it again
this.align_to_grid = false; //snap to grid this.align_to_grid = false; //snap to grid
@ -5435,7 +5418,7 @@ LGraphNode.prototype.executeAction = function(action)
}; };
/** /**
* returns the visualy active graph (in case there are more in the stack) * returns the visually active graph (in case there are more in the stack)
* @method getCurrentGraph * @method getCurrentGraph
* @return {LGraph} the active graph * @return {LGraph} the active graph
*/ */
@ -6060,9 +6043,13 @@ LGraphNode.prototype.executeAction = function(action)
this.graph.beforeChange(); this.graph.beforeChange();
this.node_dragged = node; this.node_dragged = node;
} }
if (!this.selected_nodes[node.id]) { this.processNodeSelected(node, e);
this.processNodeSelected(node, e); } else { // double-click
} /**
* Don't call the function if the block is already selected.
* Otherwise, it could cause the block to be unselected while its panel is open.
*/
if (!node.is_selected) this.processNodeSelected(node, e);
} }
this.dirty_canvas = true; this.dirty_canvas = true;
@ -6474,6 +6461,10 @@ LGraphNode.prototype.executeAction = function(action)
var n = this.selected_nodes[i]; var n = this.selected_nodes[i];
n.pos[0] += delta[0] / this.ds.scale; n.pos[0] += delta[0] / this.ds.scale;
n.pos[1] += delta[1] / this.ds.scale; n.pos[1] += delta[1] / this.ds.scale;
if (!n.is_selected) this.processNodeSelected(n, e); /*
* Don't call the function if the block is already selected.
* Otherwise, it could cause the block to be unselected while dragging.
*/
} }
this.dirty_canvas = true; this.dirty_canvas = true;
@ -7287,7 +7278,7 @@ LGraphNode.prototype.executeAction = function(action)
}; };
LGraphCanvas.prototype.processNodeSelected = function(node, e) { LGraphCanvas.prototype.processNodeSelected = function(node, e) {
this.selectNode(node, e && (e.shiftKey||e.ctrlKey)); this.selectNode(node, e && (e.shiftKey || e.ctrlKey || this.multi_select));
if (this.onNodeSelected) { if (this.onNodeSelected) {
this.onNodeSelected(node); this.onNodeSelected(node);
} }
@ -7323,6 +7314,7 @@ LGraphNode.prototype.executeAction = function(action)
for (var i in nodes) { for (var i in nodes) {
var node = nodes[i]; var node = nodes[i];
if (node.is_selected) { if (node.is_selected) {
this.deselectNode(node);
continue; continue;
} }
@ -7489,8 +7481,8 @@ LGraphNode.prototype.executeAction = function(action)
clientY_rel = e.clientY; clientY_rel = e.clientY;
} }
e.deltaX = clientX_rel - this.last_mouse_position[0]; // e.deltaX = clientX_rel - this.last_mouse_position[0];
e.deltaY = clientY_rel- this.last_mouse_position[1]; // e.deltaY = clientY_rel- this.last_mouse_position[1];
this.last_mouse_position[0] = clientX_rel; this.last_mouse_position[0] = clientX_rel;
this.last_mouse_position[1] = clientY_rel; this.last_mouse_position[1] = clientY_rel;
@ -9742,13 +9734,17 @@ LGraphNode.prototype.executeAction = function(action)
ctx.fillRect(margin, y, widget_width - margin * 2, H); ctx.fillRect(margin, y, widget_width - margin * 2, H);
var range = w.options.max - w.options.min; var range = w.options.max - w.options.min;
var nvalue = (w.value - w.options.min) / range; var nvalue = (w.value - w.options.min) / range;
ctx.fillStyle = active_widget == w ? "#89A" : "#678"; if(nvalue < 0.0) nvalue = 0.0;
if(nvalue > 1.0) nvalue = 1.0;
ctx.fillStyle = w.options.hasOwnProperty("slider_color") ? w.options.slider_color : (active_widget == w ? "#89A" : "#678");
ctx.fillRect(margin, y, nvalue * (widget_width - margin * 2), H); ctx.fillRect(margin, y, nvalue * (widget_width - margin * 2), H);
if(show_text && !w.disabled) if(show_text && !w.disabled)
ctx.strokeRect(margin, y, widget_width - margin * 2, H); ctx.strokeRect(margin, y, widget_width - margin * 2, H);
if (w.marker) { if (w.marker) {
var marker_nvalue = (w.marker - w.options.min) / range; var marker_nvalue = (w.marker - w.options.min) / range;
ctx.fillStyle = "#AA9"; if(marker_nvalue < 0.0) marker_nvalue = 0.0;
if(marker_nvalue > 1.0) marker_nvalue = 1.0;
ctx.fillStyle = w.options.hasOwnProperty("marker_color") ? w.options.marker_color : "#AA9";
ctx.fillRect( margin + marker_nvalue * (widget_width - margin * 2), y, 2, H ); ctx.fillRect( margin + marker_nvalue * (widget_width - margin * 2), y, 2, H );
} }
if (show_text) { if (show_text) {
@ -9915,6 +9911,7 @@ LGraphNode.prototype.executeAction = function(action)
case "slider": case "slider":
var range = w.options.max - w.options.min; var range = w.options.max - w.options.min;
var nvalue = Math.clamp((x - 15) / (widget_width - 30), 0, 1); var nvalue = Math.clamp((x - 15) / (widget_width - 30), 0, 1);
if(w.options.read_only) break;
w.value = w.options.min + (w.options.max - w.options.min) * nvalue; w.value = w.options.min + (w.options.max - w.options.min) * nvalue;
if (w.callback) { if (w.callback) {
setTimeout(function() { setTimeout(function() {
@ -9927,7 +9924,8 @@ LGraphNode.prototype.executeAction = function(action)
case "combo": case "combo":
var old_value = w.value; var old_value = w.value;
if (event.type == LiteGraph.pointerevents_method+"move" && w.type == "number") { if (event.type == LiteGraph.pointerevents_method+"move" && w.type == "number") {
w.value += event.deltaX * 0.1 * (w.options.step || 1); if(event.deltaX)
w.value += event.deltaX * 0.1 * (w.options.step || 1);
if ( w.options.min != null && w.value < w.options.min ) { if ( w.options.min != null && w.value < w.options.min ) {
w.value = w.options.min; w.value = w.options.min;
} }
@ -9994,6 +9992,12 @@ LGraphNode.prototype.executeAction = function(action)
var delta = x < 40 ? -1 : x > widget_width - 40 ? 1 : 0; var delta = x < 40 ? -1 : x > widget_width - 40 ? 1 : 0;
if (event.click_time < 200 && delta == 0) { if (event.click_time < 200 && delta == 0) {
this.prompt("Value",w.value,function(v) { this.prompt("Value",w.value,function(v) {
// check if v is a valid equation or a number
if (/^[0-9+\-*/()\s]+$/.test(v)) {
try {//solve the equation if possible
v = eval(v);
} catch (e) { }
}
this.value = Number(v); this.value = Number(v);
inner_value_change(this, this.value); inner_value_change(this, this.value);
}.bind(w), }.bind(w),
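The number-widget prompt above now accepts simple arithmetic (for example typing 512*2): the input is first matched against a digits-and-operators pattern and only then evaluated. A rough Python analogue of that guard-then-evaluate idea, offered as a sketch rather than code from the patch:

import re

def parse_number_input(text):
    # Accept either a plain number or a simple arithmetic expression.
    if re.fullmatch(r"[0-9+\-*/().\s]+", text):
        try:
            # The pattern above excludes names and attribute access, and
            # builtins are disabled, so eval only ever sees arithmetic.
            return float(eval(text, {"__builtins__": {}}, {}))
        except (SyntaxError, ZeroDivisionError):
            pass
    return float(text)

print(parse_number_input("512*2"))  # 1024.0
print(parse_number_input("3.5"))    # 3.5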
@ -10022,7 +10026,6 @@ LGraphNode.prototype.executeAction = function(action)
case "text": case "text":
if (event.type == LiteGraph.pointerevents_method+"down") { if (event.type == LiteGraph.pointerevents_method+"down") {
this.prompt("Value",w.value,function(v) { this.prompt("Value",w.value,function(v) {
this.value = v;
inner_value_change(this, v); inner_value_change(this, v);
}.bind(w), }.bind(w),
event,w.options ? w.options.multiline : false ); event,w.options ? w.options.multiline : false );
@ -10047,6 +10050,9 @@ LGraphNode.prototype.executeAction = function(action)
}//end for }//end for
function inner_value_change(widget, value) { function inner_value_change(widget, value) {
if(widget.type == "number"){
value = Number(value);
}
widget.value = value; widget.value = value;
if ( widget.options && widget.options.property && node.properties[widget.options.property] !== undefined ) { if ( widget.options && widget.options.property && node.properties[widget.options.property] !== undefined ) {
node.setProperty( widget.options.property, value ); node.setProperty( widget.options.property, value );
@ -11165,7 +11171,7 @@ LGraphNode.prototype.executeAction = function(action)
LGraphCanvas.search_limit = -1; LGraphCanvas.search_limit = -1;
LGraphCanvas.prototype.showSearchBox = function(event, options) { LGraphCanvas.prototype.showSearchBox = function(event, options) {
// proposed defaults // proposed defaults
def_options = { slot_from: null var def_options = { slot_from: null
,node_from: null ,node_from: null
,node_to: null ,node_to: null
,do_type_filter: LiteGraph.search_filter_enabled // TODO check for registered_slot_[in/out]_types not empty // this will be checked for functionality enabled : filter on slot type, in and out ,do_type_filter: LiteGraph.search_filter_enabled // TODO check for registered_slot_[in/out]_types not empty // this will be checked for functionality enabled : filter on slot type, in and out
@ -11863,7 +11869,7 @@ LGraphNode.prototype.executeAction = function(action)
// TODO refactor, theer are different dialog, some uses createDialog, some dont // TODO refactor, theer are different dialog, some uses createDialog, some dont
LGraphCanvas.prototype.createDialog = function(html, options) { LGraphCanvas.prototype.createDialog = function(html, options) {
def_options = { checkForInput: false, closeOnLeave: true, closeOnLeave_checkModified: true }; var def_options = { checkForInput: false, closeOnLeave: true, closeOnLeave_checkModified: true };
options = Object.assign(def_options, options || {}); options = Object.assign(def_options, options || {});
var dialog = document.createElement("div"); var dialog = document.createElement("div");
@ -11993,7 +11999,8 @@ LGraphNode.prototype.executeAction = function(action)
if (root.onClose && typeof root.onClose == "function"){ if (root.onClose && typeof root.onClose == "function"){
root.onClose(); root.onClose();
} }
root.parentNode.removeChild(root); if(root.parentNode)
root.parentNode.removeChild(root);
/* XXX CHECK THIS */ /* XXX CHECK THIS */
if(this.parentNode){ if(this.parentNode){
this.parentNode.removeChild(this); this.parentNode.removeChild(this);
@ -12285,7 +12292,7 @@ LGraphNode.prototype.executeAction = function(action)
var ref_window = this.getCanvasWindow(); var ref_window = this.getCanvasWindow();
var that = this; var that = this;
var graphcanvas = this; var graphcanvas = this;
panel = this.createPanel(node.title || "",{ var panel = this.createPanel(node.title || "",{
closable: true closable: true
,window: ref_window ,window: ref_window
,onOpen: function(){ ,onOpen: function(){

@ -112,6 +112,46 @@ class ComfyApp {
}; };
} }
#addNodeKeyHandler(node) {
const app = this;
const origNodeOnKeyDown = node.prototype.onKeyDown;
node.prototype.onKeyDown = function(e) {
if (origNodeOnKeyDown && origNodeOnKeyDown.apply(this, e) === false) {
return false;
}
if (this.flags.collapsed || !this.imgs || this.imageIndex === null) {
return;
}
let handled = false;
if (e.key === "ArrowLeft" || e.key === "ArrowRight") {
if (e.key === "ArrowLeft") {
this.imageIndex -= 1;
} else if (e.key === "ArrowRight") {
this.imageIndex += 1;
}
this.imageIndex %= this.imgs.length;
if (this.imageIndex < 0) {
this.imageIndex = this.imgs.length + this.imageIndex;
}
handled = true;
} else if (e.key === "Escape") {
this.imageIndex = null;
handled = true;
}
if (handled === true) {
e.preventDefault();
e.stopImmediatePropagation();
return false;
}
}
}
/** /**
* Adds Custom drawing logic for nodes * Adds Custom drawing logic for nodes
* e.g. Draws images and handles thumbnail navigation on nodes that output images * e.g. Draws images and handles thumbnail navigation on nodes that output images
@ -803,6 +843,7 @@ class ComfyApp {
this.#addNodeContextMenuHandler(node); this.#addNodeContextMenuHandler(node);
this.#addDrawBackgroundHandler(node, app); this.#addDrawBackgroundHandler(node, app);
this.#addNodeKeyHandler(node);
await this.#invokeExtensionsAsync("beforeRegisterNodeDef", node, nodeData); await this.#invokeExtensionsAsync("beforeRegisterNodeDef", node, nodeData);
LiteGraph.registerNodeType(nodeId, node); LiteGraph.registerNodeType(nodeId, node);

@ -115,7 +115,6 @@ function dragElement(dragEl, settings) {
savePos = value; savePos = value;
}, },
}); });
function dragMouseDown(e) { function dragMouseDown(e) {
e = e || window.event; e = e || window.event;
e.preventDefault(); e.preventDefault();
@ -163,7 +162,7 @@ class ComfyDialog {
$el("p", { $: (p) => (this.textElement = p) }), $el("p", { $: (p) => (this.textElement = p) }),
$el("button", { $el("button", {
type: "button", type: "button",
textContent: "CLOSE", textContent: "Close",
onclick: () => this.close(), onclick: () => this.close(),
}), }),
]), ]),
@ -226,6 +225,7 @@ class ComfySettingsDialog extends ComfyDialog {
}; };
let element; let element;
value = this.getSettingValue(id, defaultValue);
if (typeof type === "function") { if (typeof type === "function") {
element = type(name, setter, value, attrs); element = type(name, setter, value, attrs);
@ -282,6 +282,16 @@ class ComfySettingsDialog extends ComfyDialog {
return element; return element;
}, },
}); });
const self = this;
return {
get value() {
return self.getSettingValue(id, defaultValue);
},
set value(v) {
self.setSettingValue(id, v);
},
};
} }
show() { show() {
@ -403,6 +413,13 @@ export class ComfyUI {
this.history.update(); this.history.update();
}); });
const confirmClear = this.settings.addSetting({
id: "Comfy.ConfirmClear",
name: "Require confirmation when clearing workflow",
type: "boolean",
defaultValue: true,
});
const fileInput = $el("input", { const fileInput = $el("input", {
type: "file", type: "file",
accept: ".json,image/png", accept: ".json,image/png",
@ -414,7 +431,7 @@ export class ComfyUI {
}); });
this.menuContainer = $el("div.comfy-menu", { parent: document.body }, [ this.menuContainer = $el("div.comfy-menu", { parent: document.body }, [
$el("div", { style: { overflow: "hidden", position: "relative", width: "100%" } }, [ $el("div.drag-handle", { style: { overflow: "hidden", position: "relative", width: "100%", cursor: "default" } }, [
$el("span.drag-handle"), $el("span.drag-handle"),
$el("span", { $: (q) => (this.queueSize = q) }), $el("span", { $: (q) => (this.queueSize = q) }),
$el("button.comfy-settings-btn", { textContent: "⚙️", onclick: () => this.settings.show() }), $el("button.comfy-settings-btn", { textContent: "⚙️", onclick: () => this.settings.show() }),
@ -510,10 +527,16 @@ export class ComfyUI {
$el("button", { textContent: "Load", onclick: () => fileInput.click() }), $el("button", { textContent: "Load", onclick: () => fileInput.click() }),
$el("button", { textContent: "Refresh", onclick: () => app.refreshComboInNodes() }), $el("button", { textContent: "Refresh", onclick: () => app.refreshComboInNodes() }),
$el("button", { textContent: "Clear", onclick: () => { $el("button", { textContent: "Clear", onclick: () => {
app.clean(); if (!confirmClear.value || confirm("Clear workflow?")) {
app.graph.clear(); app.clean();
app.graph.clear();
}
}}),
$el("button", { textContent: "Load Default", onclick: () => {
if (!confirmClear.value || confirm("Load default workflow?")) {
app.loadGraphData()
}
}}), }}),
$el("button", { textContent: "Load Default", onclick: () => app.loadGraphData() }),
]); ]);
dragElement(this.menuContainer, this.settings); dragElement(this.menuContainer, this.settings);

@ -306,7 +306,7 @@ export const ComfyWidgets = {
const fileInput = document.createElement("input"); const fileInput = document.createElement("input");
Object.assign(fileInput, { Object.assign(fileInput, {
type: "file", type: "file",
accept: "image/jpeg,image/png", accept: "image/jpeg,image/png,image/webp",
style: "display: none", style: "display: none",
onchange: async () => { onchange: async () => {
if (fileInput.files.length) { if (fileInput.files.length) {

@ -84,18 +84,19 @@ body {
position: fixed; /* Stay in place */ position: fixed; /* Stay in place */
z-index: 100; /* Sit on top */ z-index: 100; /* Sit on top */
padding: 30px 30px 10px 30px; padding: 30px 30px 10px 30px;
background-color: #ff0000; /* Modal background */ background-color: #353535; /* Modal background */
color: #ff4444;
box-shadow: 0px 0px 20px #888888; box-shadow: 0px 0px 20px #888888;
border-radius: 10px; border-radius: 10px;
text-align: center;
top: 50%; top: 50%;
left: 50%; left: 50%;
max-width: 80vw; max-width: 80vw;
max-height: 80vh; max-height: 80vh;
transform: translate(-50%, -50%); transform: translate(-50%, -50%);
overflow: hidden; overflow: hidden;
min-width: 60%;
justify-content: center; justify-content: center;
font-family: monospace;
font-size: 15px;
} }
.comfy-modal-content { .comfy-modal-content {
@ -115,31 +116,11 @@ body {
margin: 3px 3px 3px 4px; margin: 3px 3px 3px 4px;
} }
.comfy-modal button {
cursor: pointer;
color: #aaaaaa;
border: none;
background-color: transparent;
font-size: 24px;
font-weight: bold;
width: 100%;
}
.comfy-modal button:hover,
.comfy-modal button:focus {
color: #000;
text-decoration: none;
cursor: pointer;
}
.comfy-menu { .comfy-menu {
width: 200px;
font-size: 15px; font-size: 15px;
position: absolute; position: absolute;
top: 50%; top: 50%;
right: 0%; right: 0%;
background-color: white;
color: #000;
text-align: center; text-align: center;
z-index: 100; z-index: 100;
width: 170px; width: 170px;
@ -154,7 +135,8 @@ body {
box-shadow: 3px 3px 8px rgba(0, 0, 0, 0.4); box-shadow: 3px 3px 8px rgba(0, 0, 0, 0.4);
} }
.comfy-menu button { .comfy-menu button,
.comfy-modal button {
font-size: 20px; font-size: 20px;
} }
@ -175,7 +157,8 @@ body {
.comfy-menu > button, .comfy-menu > button,
.comfy-menu-btns button, .comfy-menu-btns button,
.comfy-menu .comfy-list button { .comfy-menu .comfy-list button,
.comfy-modal button{
color: #ddd; color: #ddd;
background-color: #222; background-color: #222;
border-radius: 8px; border-radius: 8px;
@ -265,11 +248,22 @@ button.comfy-queue-btn {
} }
.comfy-modal.comfy-settings { .comfy-modal.comfy-settings {
background-color: var(--bg-color); text-align: center;
color: var(--fg-color); font-family: sans-serif;
color: #999;
z-index: 99; z-index: 99;
} }
.comfy-modal input,
.comfy-modal select {
color: #ddd;
background-color: #222;
border-radius: 8px;
border-color: #4e4e4e;
border-style: solid;
font-size: inherit;
}
@media only screen and (max-height: 850px) { @media only screen and (max-height: 850px) {
.comfy-menu { .comfy-menu {
top: 0 !important; top: 0 !important;
@ -282,3 +276,28 @@ button.comfy-queue-btn {
visibility:hidden visibility:hidden
} }
} }
.graphdialog {
min-height: 1em;
}
.graphdialog .name {
font-size: 14px;
font-family: sans-serif;
color: #999999;
}
.graphdialog button {
margin-top: unset;
vertical-align: unset;
height: 1.6em;
padding-right: 8px;
}
.graphdialog input, .graphdialog textarea, .graphdialog select {
background-color: #222;
border: 2px solid;
border-color: #444444;
color: #ddd;
border-radius: 12px 0 0 12px;
}