Merge branch 'comfyanonymous:master' into feature/blockweights

Author: ltdrdata
Date: 2023-04-07 17:40:25 +09:00 (committed by GitHub)
commit 81720a855f
15 changed files with 721 additions and 212 deletions

README.md

@@ -14,7 +14,7 @@ This ui will let you design and execute advanced stable diffusion pipelines usin
 - Many optimizations: Only re-executes the parts of the workflow that changes between executions.
 - Command line option: ```--lowvram``` to make it work on GPUs with less than 3GB vram (enabled automatically on GPUs with low vram)
 - Works even if you don't have a GPU with: ```--cpu``` (slow)
-- Can load both ckpt and safetensors models/checkpoints. Standalone VAEs and CLIP models.
+- Can load ckpt, safetensors and diffusers models/checkpoints. Standalone VAEs and CLIP models.
 - Embeddings/Textual inversion
 - [Loras (regular, locon and loha)](https://comfyanonymous.github.io/ComfyUI_examples/lora/)
 - Loading full workflows (with seeds) from generated PNG files.

comfy/cli_args.py (new file, 31 lines)

@@ -0,0 +1,31 @@
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0", help="Specify the IP address to listen on (default: 127.0.0.1). If --listen is provided without an argument, it defaults to 0.0.0.0. (listens on all)")
parser.add_argument("--port", type=int, default=8188, help="Set the listen port.")
parser.add_argument("--enable-cors-header", type=str, default=None, metavar="ORIGIN", nargs="?", const="*", help="Enable CORS (Cross-Origin Resource Sharing) with optional origin or allow all with default '*'.")
parser.add_argument("--extra-model-paths-config", type=str, default=None, metavar="PATH", nargs='+', action='append', help="Load one or more extra_model_paths.yaml files.")
parser.add_argument("--output-directory", type=str, default=None, help="Set the ComfyUI output directory.")
parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use.")
parser.add_argument("--dont-upcast-attention", action="store_true", help="Disable upcasting of attention. Can boost speed but increase the chances of black images.")
parser.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).")
attn_group = parser.add_mutually_exclusive_group()
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization instead of the sub-quadratic one. Ignored when xformers is used.")
attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.")
parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.")
vram_group = parser.add_mutually_exclusive_group()
vram_group.add_argument("--highvram", action="store_true", help="By default models will be unloaded to CPU memory after being used. This option keeps them in GPU memory.")
vram_group.add_argument("--normalvram", action="store_true", help="Used to force normal vram use if lowvram gets automatically enabled.")
vram_group.add_argument("--lowvram", action="store_true", help="Split the unet in parts to use less vram.")
vram_group.add_argument("--novram", action="store_true", help="When lowvram isn't enough.")
vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for everything (slow).")
parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")
parser.add_argument("--windows-standalone-build", action="store_true", help="Windows standalone build.")
args = parser.parse_args()
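
As a point of reference, here is a minimal, hypothetical sketch of how downstream modules are expected to consume this parser; the import path mirrors the `from comfy.cli_args import args` lines added elsewhere in this commit, and the printed strings are illustrative only:

```python
# Sketch: reuse the already-parsed namespace instead of touching sys.argv directly.
from comfy.cli_args import args

if args.lowvram:
    print("lowvram mode requested")
print("server will listen on {}:{}".format(args.listen, args.port))
```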

comfy/diffusers_convert.py (new file, 362 lines)

@@ -0,0 +1,362 @@
import json
import os
import yaml
import folder_paths
from comfy.ldm.util import instantiate_from_config
from comfy.sd import ModelPatcher, load_model_weights, CLIP, VAE
import os.path as osp
import re
import torch
from safetensors.torch import load_file, save_file
# conversion code from https://github.com/huggingface/diffusers/blob/main/scripts/convert_diffusers_to_original_stable_diffusion.py
# =================#
# UNet Conversion #
# =================#
unet_conversion_map = [
# (stable-diffusion, HF Diffusers)
("time_embed.0.weight", "time_embedding.linear_1.weight"),
("time_embed.0.bias", "time_embedding.linear_1.bias"),
("time_embed.2.weight", "time_embedding.linear_2.weight"),
("time_embed.2.bias", "time_embedding.linear_2.bias"),
("input_blocks.0.0.weight", "conv_in.weight"),
("input_blocks.0.0.bias", "conv_in.bias"),
("out.0.weight", "conv_norm_out.weight"),
("out.0.bias", "conv_norm_out.bias"),
("out.2.weight", "conv_out.weight"),
("out.2.bias", "conv_out.bias"),
]
unet_conversion_map_resnet = [
# (stable-diffusion, HF Diffusers)
("in_layers.0", "norm1"),
("in_layers.2", "conv1"),
("out_layers.0", "norm2"),
("out_layers.3", "conv2"),
("emb_layers.1", "time_emb_proj"),
("skip_connection", "conv_shortcut"),
]
unet_conversion_map_layer = []
# hardcoded number of downblocks and resnets/attentions...
# would need smarter logic for other networks.
for i in range(4):
# loop over downblocks/upblocks
for j in range(2):
# loop over resnets/attentions for downblocks
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
sd_down_res_prefix = f"input_blocks.{3 * i + j + 1}.0."
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
if i < 3:
# no attention layers in down_blocks.3
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
sd_down_atn_prefix = f"input_blocks.{3 * i + j + 1}.1."
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
for j in range(3):
# loop over resnets/attentions for upblocks
hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
sd_up_res_prefix = f"output_blocks.{3 * i + j}.0."
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
if i > 0:
# no attention layers in up_blocks.0
hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
sd_up_atn_prefix = f"output_blocks.{3 * i + j}.1."
unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
if i < 3:
# no downsample in down_blocks.3
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
sd_downsample_prefix = f"input_blocks.{3 * (i + 1)}.0.op."
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
# no upsample in up_blocks.3
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
sd_upsample_prefix = f"output_blocks.{3 * i + 2}.{1 if i == 0 else 2}."
unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
hf_mid_atn_prefix = "mid_block.attentions.0."
sd_mid_atn_prefix = "middle_block.1."
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
for j in range(2):
hf_mid_res_prefix = f"mid_block.resnets.{j}."
sd_mid_res_prefix = f"middle_block.{2 * j}."
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
def convert_unet_state_dict(unet_state_dict):
# buyer beware: this is a *brittle* function,
# and correct output requires that all of these pieces interact in
# the exact order in which I have arranged them.
mapping = {k: k for k in unet_state_dict.keys()}
for sd_name, hf_name in unet_conversion_map:
mapping[hf_name] = sd_name
for k, v in mapping.items():
if "resnets" in k:
for sd_part, hf_part in unet_conversion_map_resnet:
v = v.replace(hf_part, sd_part)
mapping[k] = v
for k, v in mapping.items():
for sd_part, hf_part in unet_conversion_map_layer:
v = v.replace(hf_part, sd_part)
mapping[k] = v
new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
return new_state_dict
# ================#
# VAE Conversion #
# ================#
vae_conversion_map = [
# (stable-diffusion, HF Diffusers)
("nin_shortcut", "conv_shortcut"),
("norm_out", "conv_norm_out"),
("mid.attn_1.", "mid_block.attentions.0."),
]
for i in range(4):
# down_blocks have two resnets
for j in range(2):
hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
sd_down_prefix = f"encoder.down.{i}.block.{j}."
vae_conversion_map.append((sd_down_prefix, hf_down_prefix))
if i < 3:
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
sd_downsample_prefix = f"down.{i}.downsample."
vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
sd_upsample_prefix = f"up.{3 - i}.upsample."
vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
# up_blocks have three resnets
# also, up blocks in hf are numbered in reverse from sd
for j in range(3):
hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
sd_up_prefix = f"decoder.up.{3 - i}.block.{j}."
vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
# this part accounts for mid blocks in both the encoder and the decoder
for i in range(2):
hf_mid_res_prefix = f"mid_block.resnets.{i}."
sd_mid_res_prefix = f"mid.block_{i + 1}."
vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
vae_conversion_map_attn = [
# (stable-diffusion, HF Diffusers)
("norm.", "group_norm."),
("q.", "query."),
("k.", "key."),
("v.", "value."),
("proj_out.", "proj_attn."),
]
def reshape_weight_for_sd(w):
# convert HF linear weights to SD conv2d weights
return w.reshape(*w.shape, 1, 1)
def convert_vae_state_dict(vae_state_dict):
mapping = {k: k for k in vae_state_dict.keys()}
for k, v in mapping.items():
for sd_part, hf_part in vae_conversion_map:
v = v.replace(hf_part, sd_part)
mapping[k] = v
for k, v in mapping.items():
if "attentions" in k:
for sd_part, hf_part in vae_conversion_map_attn:
v = v.replace(hf_part, sd_part)
mapping[k] = v
new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
weights_to_convert = ["q", "k", "v", "proj_out"]
for k, v in new_state_dict.items():
for weight_name in weights_to_convert:
if f"mid.attn_1.{weight_name}.weight" in k:
print(f"Reshaping {k} for SD format")
new_state_dict[k] = reshape_weight_for_sd(v)
return new_state_dict
# =========================#
# Text Encoder Conversion #
# =========================#
textenc_conversion_lst = [
# (stable-diffusion, HF Diffusers)
("resblocks.", "text_model.encoder.layers."),
("ln_1", "layer_norm1"),
("ln_2", "layer_norm2"),
(".c_fc.", ".fc1."),
(".c_proj.", ".fc2."),
(".attn", ".self_attn"),
("ln_final.", "transformer.text_model.final_layer_norm."),
("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"),
("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"),
]
protected = {re.escape(x[1]): x[0] for x in textenc_conversion_lst}
textenc_pattern = re.compile("|".join(protected.keys()))
# Ordering is from https://github.com/pytorch/pytorch/blob/master/test/cpp/api/modules.cpp
code2idx = {"q": 0, "k": 1, "v": 2}
def convert_text_enc_state_dict_v20(text_enc_dict):
new_state_dict = {}
capture_qkv_weight = {}
capture_qkv_bias = {}
for k, v in text_enc_dict.items():
if (
k.endswith(".self_attn.q_proj.weight")
or k.endswith(".self_attn.k_proj.weight")
or k.endswith(".self_attn.v_proj.weight")
):
k_pre = k[: -len(".q_proj.weight")]
k_code = k[-len("q_proj.weight")]
if k_pre not in capture_qkv_weight:
capture_qkv_weight[k_pre] = [None, None, None]
capture_qkv_weight[k_pre][code2idx[k_code]] = v
continue
if (
k.endswith(".self_attn.q_proj.bias")
or k.endswith(".self_attn.k_proj.bias")
or k.endswith(".self_attn.v_proj.bias")
):
k_pre = k[: -len(".q_proj.bias")]
k_code = k[-len("q_proj.bias")]
if k_pre not in capture_qkv_bias:
capture_qkv_bias[k_pre] = [None, None, None]
capture_qkv_bias[k_pre][code2idx[k_code]] = v
continue
relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k)
new_state_dict[relabelled_key] = v
for k_pre, tensors in capture_qkv_weight.items():
if None in tensors:
raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
new_state_dict[relabelled_key + ".in_proj_weight"] = torch.cat(tensors)
for k_pre, tensors in capture_qkv_bias.items():
if None in tensors:
raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
new_state_dict[relabelled_key + ".in_proj_bias"] = torch.cat(tensors)
return new_state_dict
def convert_text_enc_state_dict(text_enc_dict):
return text_enc_dict
def load_diffusers(model_path, fp16=True, output_vae=True, output_clip=True, embedding_directory=None):
diffusers_unet_conf = json.load(open(osp.join(model_path, "unet/config.json")))
diffusers_scheduler_conf = json.load(open(osp.join(model_path, "scheduler/scheduler_config.json")))
# magic
v2 = diffusers_unet_conf["sample_size"] == 96
if 'prediction_type' in diffusers_scheduler_conf:
v_pred = diffusers_scheduler_conf['prediction_type'] == 'v_prediction'
if v2:
if v_pred:
config_path = folder_paths.get_full_path("configs", 'v2-inference-v.yaml')
else:
config_path = folder_paths.get_full_path("configs", 'v2-inference.yaml')
else:
config_path = folder_paths.get_full_path("configs", 'v1-inference.yaml')
with open(config_path, 'r') as stream:
config = yaml.safe_load(stream)
model_config_params = config['model']['params']
clip_config = model_config_params['cond_stage_config']
scale_factor = model_config_params['scale_factor']
vae_config = model_config_params['first_stage_config']
vae_config['scale_factor'] = scale_factor
model_config_params["unet_config"]["params"]["use_fp16"] = fp16
unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.safetensors")
vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.safetensors")
text_enc_path = osp.join(model_path, "text_encoder", "model.safetensors")
# Load models from safetensors if it exists, if it doesn't pytorch
if osp.exists(unet_path):
unet_state_dict = load_file(unet_path, device="cpu")
else:
unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.bin")
unet_state_dict = torch.load(unet_path, map_location="cpu")
if osp.exists(vae_path):
vae_state_dict = load_file(vae_path, device="cpu")
else:
vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.bin")
vae_state_dict = torch.load(vae_path, map_location="cpu")
if osp.exists(text_enc_path):
text_enc_dict = load_file(text_enc_path, device="cpu")
else:
text_enc_path = osp.join(model_path, "text_encoder", "pytorch_model.bin")
text_enc_dict = torch.load(text_enc_path, map_location="cpu")
# Convert the UNet model
unet_state_dict = convert_unet_state_dict(unet_state_dict)
unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()}
# Convert the VAE model
vae_state_dict = convert_vae_state_dict(vae_state_dict)
vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}
# Easiest way to identify v2.0 model seems to be that the text encoder (OpenCLIP) is deeper
is_v20_model = "text_model.encoder.layers.22.layer_norm2.bias" in text_enc_dict
if is_v20_model:
# Need to add the tag 'transformer' in advance so we can knock it out from the final layer-norm
text_enc_dict = {"transformer." + k: v for k, v in text_enc_dict.items()}
text_enc_dict = convert_text_enc_state_dict_v20(text_enc_dict)
text_enc_dict = {"cond_stage_model.model." + k: v for k, v in text_enc_dict.items()}
else:
text_enc_dict = convert_text_enc_state_dict(text_enc_dict)
text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()}
# Put together new checkpoint
sd = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
clip = None
vae = None
class WeightsLoader(torch.nn.Module):
pass
w = WeightsLoader()
load_state_dict_to = []
if output_vae:
vae = VAE(scale_factor=scale_factor, config=vae_config)
w.first_stage_model = vae.first_stage_model
load_state_dict_to = [w]
if output_clip:
clip = CLIP(config=clip_config, embedding_directory=embedding_directory)
w.cond_stage_model = clip.cond_stage_model
load_state_dict_to = [w]
model = instantiate_from_config(config["model"])
model = load_model_weights(model, sd, verbose=False, load_state_dict_to=load_state_dict_to)
if fp16:
model = model.half()
return ModelPatcher(model), clip, vae
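
For orientation, a minimal usage sketch of the new loader. The directory path below is an assumption; point it at any local diffusers-format checkpoint folder containing the unet/, vae/ and text_encoder/ subfolders referenced above:

```python
# Hypothetical example: load a diffusers-format folder and inspect what it returns.
import comfy.diffusers_convert

model_patcher, clip, vae = comfy.diffusers_convert.load_diffusers(
    "models/diffusers/stable-diffusion-v1-5",  # assumed local diffusers checkpoint directory
    fp16=True,
    output_vae=True,
    output_clip=True,
)
print(type(model_patcher), type(clip), type(vae))
```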

comfy/ldm/modules/attention.py

@@ -21,6 +21,8 @@ if model_management.xformers_enabled():
 import os
 _ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")
+from cli_args import args
 def exists(val):
     return val is not None

@@ -474,7 +476,6 @@ class CrossAttentionPytorch(nn.Module):
         return self.to_out(out)
-import sys
 if model_management.xformers_enabled():
     print("Using xformers cross attention")
     CrossAttention = MemoryEfficientCrossAttention

@@ -482,7 +483,7 @@ elif model_management.pytorch_attention_enabled():
     print("Using pytorch cross attention")
     CrossAttention = CrossAttentionPytorch
 else:
-    if "--use-split-cross-attention" in sys.argv:
+    if args.use_split_cross_attention:
         print("Using split optimization for cross attention")
         CrossAttention = CrossAttentionDoggettx
     else:

comfy/model_management.py

@@ -1,36 +1,42 @@
+import psutil
+from enum import Enum
+from cli_args import args
-CPU = 0
-NO_VRAM = 1
-LOW_VRAM = 2
-NORMAL_VRAM = 3
-HIGH_VRAM = 4
-MPS = 5
+class VRAMState(Enum):
+    CPU = 0
+    NO_VRAM = 1
+    LOW_VRAM = 2
+    NORMAL_VRAM = 3
+    HIGH_VRAM = 4
+    MPS = 5
-accelerate_enabled = False
-vram_state = NORMAL_VRAM
+# Determine VRAM State
+vram_state = VRAMState.NORMAL_VRAM
+set_vram_to = VRAMState.NORMAL_VRAM
 total_vram = 0
 total_vram_available_mb = -1
-import sys
-import psutil
-forced_cpu = "--cpu" in sys.argv
-set_vram_to = NORMAL_VRAM
+accelerate_enabled = False
+xpu_available = False
 try:
     import torch
-    total_vram = torch.cuda.mem_get_info(torch.cuda.current_device())[1] / (1024 * 1024)
+    try:
+        import intel_extension_for_pytorch as ipex
+        if torch.xpu.is_available():
+            xpu_available = True
+            total_vram = torch.xpu.get_device_properties(torch.xpu.current_device()).total_memory / (1024 * 1024)
+    except:
+        total_vram = torch.cuda.mem_get_info(torch.cuda.current_device())[1] / (1024 * 1024)
     total_ram = psutil.virtual_memory().total / (1024 * 1024)
-    forced_normal_vram = "--normalvram" in sys.argv
-    if not forced_normal_vram and not forced_cpu:
+    if not args.normalvram and not args.cpu:
         if total_vram <= 4096:
             print("Trying to enable lowvram mode because your GPU seems to have 4GB or less. If you don't want this use: --normalvram")
-            set_vram_to = LOW_VRAM
+            set_vram_to = VRAMState.LOW_VRAM
         elif total_vram > total_ram * 1.1 and total_vram > 14336:
             print("Enabling highvram mode because your GPU has more vram than your computer has ram. If you don't want this use: --normalvram")
-            vram_state = HIGH_VRAM
+            vram_state = VRAMState.HIGH_VRAM
 except:
     pass
@@ -39,34 +45,37 @@ try:
 except:
     OOM_EXCEPTION = Exception

-if "--disable-xformers" in sys.argv:
-    XFORMERS_IS_AVAILBLE = False
+if args.disable_xformers:
+    XFORMERS_IS_AVAILABLE = False
 else:
     try:
         import xformers
         import xformers.ops
-        XFORMERS_IS_AVAILBLE = True
+        XFORMERS_IS_AVAILABLE = True
     except:
-        XFORMERS_IS_AVAILBLE = False
+        XFORMERS_IS_AVAILABLE = False

-ENABLE_PYTORCH_ATTENTION = False
-if "--use-pytorch-cross-attention" in sys.argv:
+ENABLE_PYTORCH_ATTENTION = args.use_pytorch_cross_attention
+if ENABLE_PYTORCH_ATTENTION:
     torch.backends.cuda.enable_math_sdp(True)
     torch.backends.cuda.enable_flash_sdp(True)
     torch.backends.cuda.enable_mem_efficient_sdp(True)
-    ENABLE_PYTORCH_ATTENTION = True
-    XFORMERS_IS_AVAILBLE = False
+    XFORMERS_IS_AVAILABLE = False

+if args.lowvram:
+    set_vram_to = VRAMState.LOW_VRAM
+elif args.novram:
+    set_vram_to = VRAMState.NO_VRAM
+elif args.highvram:
+    vram_state = VRAMState.HIGH_VRAM
+
+FORCE_FP32 = False
+if args.force_fp32:
+    print("Forcing FP32, if this improves things please report it.")
+    FORCE_FP32 = True

-if "--lowvram" in sys.argv:
-    set_vram_to = LOW_VRAM
-if "--novram" in sys.argv:
-    set_vram_to = NO_VRAM
-if "--highvram" in sys.argv:
-    vram_state = HIGH_VRAM
-
-if set_vram_to == LOW_VRAM or set_vram_to == NO_VRAM:
+if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM):
     try:
         import accelerate
         accelerate_enabled = True
@@ -81,14 +90,14 @@ if set_vram_to == LOW_VRAM or set_vram_to == NO_VRAM:
 try:
     if torch.backends.mps.is_available():
-        vram_state = MPS
+        vram_state = VRAMState.MPS
 except:
     pass

-if forced_cpu:
-    vram_state = CPU
+if args.cpu:
+    vram_state = VRAMState.CPU

-print("Set vram state to:", ["CPU", "NO VRAM", "LOW VRAM", "NORMAL VRAM", "HIGH VRAM", "MPS"][vram_state])
+print(f"Set vram state to: {vram_state.name}")

 current_loaded_model = None

@@ -109,12 +118,12 @@ def unload_model():
             model_accelerated = False

         #never unload models from GPU on high vram
-        if vram_state != HIGH_VRAM:
+        if vram_state != VRAMState.HIGH_VRAM:
             current_loaded_model.model.cpu()
         current_loaded_model.unpatch_model()
         current_loaded_model = None

-    if vram_state != HIGH_VRAM:
+    if vram_state != VRAMState.HIGH_VRAM:
         if len(current_gpu_controlnets) > 0:
             for n in current_gpu_controlnets:
                 n.cpu()
@@ -135,32 +144,32 @@ def load_model_gpu(model):
         model.unpatch_model()
         raise e
     current_loaded_model = model
-    if vram_state == CPU:
+    if vram_state == VRAMState.CPU:
         pass
-    elif vram_state == MPS:
+    elif vram_state == VRAMState.MPS:
         mps_device = torch.device("mps")
         real_model.to(mps_device)
         pass
-    elif vram_state == NORMAL_VRAM or vram_state == HIGH_VRAM:
+    elif vram_state == VRAMState.NORMAL_VRAM or vram_state == VRAMState.HIGH_VRAM:
         model_accelerated = False
-        real_model.cuda()
+        real_model.to(get_torch_device())
     else:
-        if vram_state == NO_VRAM:
+        if vram_state == VRAMState.NO_VRAM:
             device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "256MiB", "cpu": "16GiB"})
-        elif vram_state == LOW_VRAM:
+        elif vram_state == VRAMState.LOW_VRAM:
             device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "{}MiB".format(total_vram_available_mb), "cpu": "16GiB"})
-        accelerate.dispatch_model(real_model, device_map=device_map, main_device="cuda")
+        accelerate.dispatch_model(real_model, device_map=device_map, main_device=get_torch_device())
         model_accelerated = True
     return current_loaded_model

 def load_controlnet_gpu(models):
     global current_gpu_controlnets
     global vram_state
-    if vram_state == CPU:
+    if vram_state == VRAMState.CPU:
         return
-    if vram_state == LOW_VRAM or vram_state == NO_VRAM:
+    if vram_state == VRAMState.LOW_VRAM or vram_state == VRAMState.NO_VRAM:
         #don't load controlnets like this if low vram because they will be loaded right before running and unloaded right after
         return
@@ -176,23 +185,27 @@ def load_controlnet_gpu(models):
 def load_if_low_vram(model):
     global vram_state
-    if vram_state == LOW_VRAM or vram_state == NO_VRAM:
-        return model.cuda()
+    if vram_state == VRAMState.LOW_VRAM or vram_state == VRAMState.NO_VRAM:
+        return model.to(get_torch_device())
     return model

 def unload_if_low_vram(model):
     global vram_state
-    if vram_state == LOW_VRAM or vram_state == NO_VRAM:
+    if vram_state == VRAMState.LOW_VRAM or vram_state == VRAMState.NO_VRAM:
         return model.cpu()
     return model

 def get_torch_device():
-    if vram_state == MPS:
+    global xpu_available
+    if vram_state == VRAMState.MPS:
         return torch.device("mps")
-    if vram_state == CPU:
+    if vram_state == VRAMState.CPU:
         return torch.device("cpu")
     else:
-        return torch.cuda.current_device()
+        if xpu_available:
+            return torch.device("xpu")
+        else:
+            return torch.cuda.current_device()

 def get_autocast_device(dev):
     if hasattr(dev, 'type'):
@@ -201,9 +214,9 @@ def get_autocast_device(dev):
 def xformers_enabled():
-    if vram_state == CPU:
+    if vram_state == VRAMState.CPU:
         return False
-    return XFORMERS_IS_AVAILBLE
+    return XFORMERS_IS_AVAILABLE

 def xformers_enabled_vae():

@@ -222,6 +235,7 @@ def pytorch_attention_enabled():
     return ENABLE_PYTORCH_ATTENTION

 def get_free_memory(dev=None, torch_free_too=False):
+    global xpu_available
     if dev is None:
         dev = get_torch_device()
@@ -229,12 +243,16 @@ def get_free_memory(dev=None, torch_free_too=False):
         mem_free_total = psutil.virtual_memory().available
         mem_free_torch = mem_free_total
     else:
-        stats = torch.cuda.memory_stats(dev)
-        mem_active = stats['active_bytes.all.current']
-        mem_reserved = stats['reserved_bytes.all.current']
-        mem_free_cuda, _ = torch.cuda.mem_get_info(dev)
-        mem_free_torch = mem_reserved - mem_active
-        mem_free_total = mem_free_cuda + mem_free_torch
+        if xpu_available:
+            mem_free_total = torch.xpu.get_device_properties(dev).total_memory - torch.xpu.memory_allocated(dev)
+            mem_free_torch = mem_free_total
+        else:
+            stats = torch.cuda.memory_stats(dev)
+            mem_active = stats['active_bytes.all.current']
+            mem_reserved = stats['reserved_bytes.all.current']
+            mem_free_cuda, _ = torch.cuda.mem_get_info(dev)
+            mem_free_torch = mem_reserved - mem_active
+            mem_free_total = mem_free_cuda + mem_free_torch

     if torch_free_too:
         return (mem_free_total, mem_free_torch)
@@ -243,7 +261,7 @@ def get_free_memory(dev=None, torch_free_too=False):
 def maximum_batch_area():
     global vram_state
-    if vram_state == NO_VRAM:
+    if vram_state == VRAMState.NO_VRAM:
         return 0

     memory_free = get_free_memory() / (1024 * 1024)

@@ -252,14 +270,18 @@ def maximum_batch_area():
 def cpu_mode():
     global vram_state
-    return vram_state == CPU
+    return vram_state == VRAMState.CPU

 def mps_mode():
     global vram_state
-    return vram_state == MPS
+    return vram_state == VRAMState.MPS

 def should_use_fp16():
-    if cpu_mode() or mps_mode():
+    global xpu_available
+    if FORCE_FP32:
+        return False
+    if cpu_mode() or mps_mode() or xpu_available:
         return False #TODO ?

 if torch.cuda.is_bf16_supported():
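
For context, a minimal sketch of how other parts of the codebase typically consume these helpers after this refactor; the import path and the dummy tensor below are assumptions for illustration only:

```python
# Hypothetical caller: pick the device and dtype the way nodes do after this change.
import torch
import comfy.model_management as model_management  # assumed import path

device = model_management.get_torch_device()
dtype = torch.float16 if model_management.should_use_fp16() else torch.float32
x = torch.zeros((1, 4, 64, 64), dtype=dtype, device=device)  # illustrative latent-sized tensor
print(device, dtype, x.shape)
```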

folder_paths.py

@@ -23,10 +23,45 @@ folder_names_and_paths["clip"] = ([os.path.join(models_dir, "clip")], supported_
 folder_names_and_paths["clip_vision"] = ([os.path.join(models_dir, "clip_vision")], supported_pt_extensions)
 folder_names_and_paths["style_models"] = ([os.path.join(models_dir, "style_models")], supported_pt_extensions)
 folder_names_and_paths["embeddings"] = ([os.path.join(models_dir, "embeddings")], supported_pt_extensions)
+folder_names_and_paths["diffusers"] = ([os.path.join(models_dir, "diffusers")], ["folder"])

 folder_names_and_paths["controlnet"] = ([os.path.join(models_dir, "controlnet"), os.path.join(models_dir, "t2i_adapter")], supported_pt_extensions)
 folder_names_and_paths["upscale_models"] = ([os.path.join(models_dir, "upscale_models")], supported_pt_extensions)

+output_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "output")
+temp_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp")
+input_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input")
+
+if not os.path.exists(input_directory):
+    os.makedirs(input_directory)
+
+def set_output_directory(output_dir):
+    global output_directory
+    output_directory = output_dir
+
+def get_output_directory():
+    global output_directory
+    return output_directory
+
+def get_temp_directory():
+    global temp_directory
+    return temp_directory
+
+def get_input_directory():
+    global input_directory
+    return input_directory
+
+#NOTE: used in http server so don't put folders that should not be accessed remotely
+def get_directory_by_type(type_name):
+    if type_name == "output":
+        return get_output_directory()
+    if type_name == "temp":
+        return get_temp_directory()
+    if type_name == "input":
+        return get_input_directory()
+    return None
+
 def add_model_folder_path(folder_name, full_folder_path):
     global folder_names_and_paths
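
A short sketch of how the new directory helpers are meant to be consumed by nodes and the HTTP server; this is a hypothetical call site and assumes folder_paths is importable from the repository root, as in the hunks above:

```python
import os
import folder_paths

# Nodes resolve their target directories through the helpers instead of __file__.
out_dir = folder_paths.get_output_directory()
os.makedirs(out_dir, exist_ok=True)

# The HTTP layer maps a user-supplied type string to a directory, or None if invalid.
for t in ("output", "temp", "input", "not-a-real-type"):
    print(t, folder_paths.get_directory_by_type(t))
```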

main.py

@@ -1,56 +1,33 @@
-import os
-import sys
-import shutil
-import threading
 import asyncio
+import itertools
+import os
+import shutil
+import threading
+
+from comfy.cli_args import args

 if os.name == "nt":
     import logging
     logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())

 if __name__ == "__main__":
-    if '--help' in sys.argv:
-        print()
-        print("Valid Command line Arguments:")
-        print("\t--listen [ip]\t\t\tListen on ip or 0.0.0.0 if none given so the UI can be accessed from other computers.")
-        print("\t--port 8188\t\t\tSet the listen port.")
-        print()
-        print("\t--extra-model-paths-config file.yaml\tload an extra_model_paths.yaml file.")
-        print()
-        print()
-        print("\t--dont-upcast-attention\t\tDisable upcasting of attention \n\t\t\t\t\tcan boost speed but increase the chances of black images.\n")
-        print("\t--use-split-cross-attention\tUse the split cross attention optimization instead of the sub-quadratic one.\n\t\t\t\t\tIgnored when xformers is used.")
-        print("\t--use-pytorch-cross-attention\tUse the new pytorch 2.0 cross attention function.")
-        print("\t--disable-xformers\t\tdisables xformers")
-        print("\t--cuda-device 1\t\tSet the id of the cuda device this instance will use.")
-        print()
-        print("\t--highvram\t\t\tBy default models will be unloaded to CPU memory after being used.\n\t\t\t\t\tThis option keeps them in GPU memory.\n")
-        print("\t--normalvram\t\t\tUsed to force normal vram use if lowvram gets automatically enabled.")
-        print("\t--lowvram\t\t\tSplit the unet in parts to use less vram.")
-        print("\t--novram\t\t\tWhen lowvram isn't enough.")
-        print()
-        print("\t--cpu\t\t\tTo use the CPU for everything (slow).")
-        exit()
-
-    if '--dont-upcast-attention' in sys.argv:
+    if args.dont_upcast_attention:
         print("disabling upcasting of attention")
         os.environ['ATTN_PRECISION'] = "fp16"

-    try:
-        index = sys.argv.index('--cuda-device')
-        device = sys.argv[index + 1]
-        os.environ['CUDA_VISIBLE_DEVICES'] = device
-        print("Set cuda device to:", device)
-    except:
-        pass
+    if args.cuda_device is not None:
+        os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device)
+        print("Set cuda device to:", args.cuda_device)

-from nodes import init_custom_nodes
-import execution
-import server
-import folder_paths
 import yaml
+
+import execution
+import folder_paths
+import server
+from nodes import init_custom_nodes

 def prompt_worker(q, server):
     e = execution.PromptExecutor(server)
     while True:
@@ -109,43 +86,31 @@ if __name__ == "__main__":
     hijack_progress(server)

     threading.Thread(target=prompt_worker, daemon=True, args=(q,server,)).start()
-    try:
-        address = '0.0.0.0'
-        p_index = sys.argv.index('--listen')
-        try:
-            ip = sys.argv[p_index + 1]
-            if ip[:2] != '--':
-                address = ip
-        except:
-            pass
-    except:
-        address = '127.0.0.1'

-    dont_print = False
-    if '--dont-print-server' in sys.argv:
-        dont_print = True
+    address = args.listen
+    dont_print = args.dont_print_server

     extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml")
     if os.path.isfile(extra_model_paths_config_path):
         load_extra_path_config(extra_model_paths_config_path)

-    if '--extra-model-paths-config' in sys.argv:
-        indices = [(i + 1) for i in range(len(sys.argv) - 1) if sys.argv[i] == '--extra-model-paths-config']
-        for i in indices:
-            load_extra_path_config(sys.argv[i])
+    if args.extra_model_paths_config:
+        for config_path in itertools.chain(*args.extra_model_paths_config):
+            load_extra_path_config(config_path)

-    port = 8188
-    try:
-        p_index = sys.argv.index('--port')
-        port = int(sys.argv[p_index + 1])
-    except:
-        pass
+    if args.output_directory:
+        output_dir = os.path.abspath(args.output_directory)
+        print(f"Setting output directory to: {output_dir}")
+        folder_paths.set_output_directory(output_dir)

-    if '--quick-test-for-ci' in sys.argv:
+    port = args.port
+    if args.quick_test_for_ci:
         exit(0)

     call_on_start = None
-    if "--windows-standalone-build" in sys.argv:
+    if args.windows_standalone_build:
         def startup_server(address, port):
             import webbrowser
             webbrowser.open("http://{}:{}".format(address, port))

nodes.py

@@ -4,16 +4,17 @@ import os
 import sys
 import json
 import hashlib
-import copy
 import traceback

 from PIL import Image
 from PIL.PngImagePlugin import PngInfo
 import numpy as np

 sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy"))

+import comfy.diffusers_convert
 import comfy.samplers
 import comfy.sd
 import comfy.utils
@@ -219,6 +220,30 @@ class CheckpointLoaderSimple:
         out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
         return out

+class DiffusersLoader:
+    @classmethod
+    def INPUT_TYPES(cls):
+        paths = []
+        for search_path in folder_paths.get_folder_paths("diffusers"):
+            if os.path.exists(search_path):
+                paths += next(os.walk(search_path))[1]
+        return {"required": {"model_path": (paths,), }}
+    RETURN_TYPES = ("MODEL", "CLIP", "VAE")
+    FUNCTION = "load_checkpoint"
+
+    CATEGORY = "advanced/loaders"
+
+    def load_checkpoint(self, model_path, output_vae=True, output_clip=True):
+        for search_path in folder_paths.get_folder_paths("diffusers"):
+            if os.path.exists(search_path):
+                paths = next(os.walk(search_path))[1]
+                if model_path in paths:
+                    model_path = os.path.join(search_path, model_path)
+                    break
+        return comfy.diffusers_convert.load_diffusers(model_path, fp16=model_management.should_use_fp16(), output_vae=output_vae, output_clip=output_clip, embedding_directory=folder_paths.get_folder_paths("embeddings"))
+
 class unCLIPCheckpointLoader:
     @classmethod
     def INPUT_TYPES(s):
@@ -853,7 +878,7 @@ class KSamplerAdvanced:

 class SaveImage:
     def __init__(self):
-        self.output_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "output")
+        self.output_dir = folder_paths.get_output_directory()
         self.type = "output"

     @classmethod

@@ -905,9 +930,6 @@ class SaveImage:
             os.makedirs(full_output_folder, exist_ok=True)
             counter = 1

-        if not os.path.exists(self.output_dir):
-            os.makedirs(self.output_dir)
-
         results = list()
         for image in images:
             i = 255. * image.cpu().numpy()

@@ -932,7 +954,7 @@ class SaveImage:

 class PreviewImage(SaveImage):
     def __init__(self):
-        self.output_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp")
+        self.output_dir = folder_paths.get_temp_directory()
         self.type = "temp"

     @classmethod

@@ -943,13 +965,11 @@ class PreviewImage(SaveImage):
             }

 class LoadImage:
-    input_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input")
     @classmethod
     def INPUT_TYPES(s):
-        if not os.path.exists(s.input_dir):
-            os.makedirs(s.input_dir)
+        input_dir = folder_paths.get_input_directory()
         return {"required":
-                    {"image": (sorted(os.listdir(s.input_dir)), )},
+                    {"image": (sorted(os.listdir(input_dir)), )},
                }

     CATEGORY = "image"

@@ -957,7 +977,8 @@ class LoadImage:
     RETURN_TYPES = ("IMAGE", "MASK")
     FUNCTION = "load_image"
     def load_image(self, image):
-        image_path = os.path.join(self.input_dir, image)
+        input_dir = folder_paths.get_input_directory()
+        image_path = os.path.join(input_dir, image)
         i = Image.open(image_path)
         image = i.convert("RGB")
         image = np.array(image).astype(np.float32) / 255.0

@@ -971,18 +992,19 @@ class LoadImage:
     @classmethod
     def IS_CHANGED(s, image):
-        image_path = os.path.join(s.input_dir, image)
+        input_dir = folder_paths.get_input_directory()
+        image_path = os.path.join(input_dir, image)
         m = hashlib.sha256()
         with open(image_path, 'rb') as f:
             m.update(f.read())
         return m.digest().hex()

 class LoadImageMask:
-    input_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input")
     @classmethod
     def INPUT_TYPES(s):
+        input_dir = folder_paths.get_input_directory()
         return {"required":
-                    {"image": (sorted(os.listdir(s.input_dir)), ),
+                    {"image": (sorted(os.listdir(input_dir)), ),
                      "channel": (["alpha", "red", "green", "blue"], ),}
                }

@@ -991,7 +1013,8 @@ class LoadImageMask:
     RETURN_TYPES = ("MASK",)
     FUNCTION = "load_image"
     def load_image(self, image, channel):
-        image_path = os.path.join(self.input_dir, image)
+        input_dir = folder_paths.get_input_directory()
+        image_path = os.path.join(input_dir, image)
         i = Image.open(image_path)
         mask = None
         c = channel[0].upper()

@@ -1006,7 +1029,8 @@ class LoadImageMask:
     @classmethod
     def IS_CHANGED(s, image, channel):
-        image_path = os.path.join(s.input_dir, image)
+        input_dir = folder_paths.get_input_directory()
+        image_path = os.path.join(input_dir, image)
         m = hashlib.sha256()
         with open(image_path, 'rb') as f:
             m.update(f.read())

@@ -1154,6 +1178,7 @@ NODE_CLASS_MAPPINGS = {
     "TomePatchModel": TomePatchModel,
     "unCLIPCheckpointLoader": unCLIPCheckpointLoader,
     "CheckpointLoader": CheckpointLoader,
+    "DiffusersLoader": DiffusersLoader,
 }

 def load_custom_node(module_path):

requirements.txt

@@ -4,7 +4,7 @@ torchsde
 einops
 open-clip-torch
 transformers>=4.25.1
-safetensors
+safetensors>=0.3.0
 pytorch_lightning
 aiohttp
 accelerate

server.py

@@ -18,6 +18,7 @@ except ImportError:
     sys.exit()

 import mimetypes
+from comfy.cli_args import args

 @web.middleware

@@ -27,6 +28,23 @@ async def cache_control(request: web.Request, handler):
         response.headers.setdefault('Cache-Control', 'no-cache')
     return response

+def create_cors_middleware(allowed_origin: str):
+    @web.middleware
+    async def cors_middleware(request: web.Request, handler):
+        if request.method == "OPTIONS":
+            # Pre-flight request. Reply successfully:
+            response = web.Response()
+        else:
+            response = await handler(request)
+
+        response.headers['Access-Control-Allow-Origin'] = allowed_origin
+        response.headers['Access-Control-Allow-Methods'] = 'POST, GET, DELETE, PUT, OPTIONS'
+        response.headers['Access-Control-Allow-Headers'] = 'Content-Type, Authorization'
+        response.headers['Access-Control-Allow-Credentials'] = 'true'
+        return response
+
+    return cors_middleware
+
 class PromptServer():
     def __init__(self, loop):
         PromptServer.instance = self

@@ -37,7 +55,12 @@ class PromptServer():
         self.loop = loop
         self.messages = asyncio.Queue()
         self.number = 0
-        self.app = web.Application(client_max_size=20971520, middlewares=[cache_control])
+
+        middlewares = [cache_control]
+        if args.enable_cors_header:
+            middlewares.append(create_cors_middleware(args.enable_cors_header))
+
+        self.app = web.Application(client_max_size=20971520, middlewares=middlewares)
         self.sockets = dict()
         self.web_root = os.path.join(os.path.dirname(
             os.path.realpath(__file__)), "web")

@@ -89,7 +112,7 @@ class PromptServer():
         @routes.post("/upload/image")
         async def upload_image(request):
-            upload_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input")
+            upload_dir = folder_paths.get_input_directory()

             if not os.path.exists(upload_dir):
                 os.makedirs(upload_dir)

@@ -122,10 +145,10 @@ class PromptServer():
         async def view_image(request):
             if "filename" in request.rel_url.query:
                 type = request.rel_url.query.get("type", "output")
-                if type not in ["output", "input", "temp"]:
+                output_dir = folder_paths.get_directory_by_type(type)
+                if output_dir is None:
                     return web.Response(status=400)
-                output_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), type)

                 if "subfolder" in request.rel_url.query:
                     full_output_dir = os.path.join(output_dir, request.rel_url.query["subfolder"])
                     if os.path.commonpath((os.path.abspath(full_output_dir), output_dir)) != output_dir:
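
A quick way to see the new CORS option in action is a pre-flight request from a separate client. The sketch below is hypothetical: it assumes a local instance was started with `python main.py --enable-cors-header "http://localhost:5173"` and that the default port 8188 is in use:

```python
# Hypothetical check of the CORS headers added by create_cors_middleware.
import http.client

conn = http.client.HTTPConnection("127.0.0.1", 8188)
# Simulate a browser pre-flight request against the prompt endpoint.
conn.request("OPTIONS", "/prompt", headers={"Origin": "http://localhost:5173"})
resp = conn.getresponse()
print(resp.status, resp.getheader("Access-Control-Allow-Origin"))
conn.close()
```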

web/scripts/app.js

@@ -112,6 +112,46 @@ class ComfyApp {
        };
    }

+   #addNodeKeyHandler(node) {
+       const app = this;
+       const origNodeOnKeyDown = node.prototype.onKeyDown;
+
+       node.prototype.onKeyDown = function(e) {
+           if (origNodeOnKeyDown && origNodeOnKeyDown.apply(this, e) === false) {
+               return false;
+           }
+
+           if (this.flags.collapsed || !this.imgs || this.imageIndex === null) {
+               return;
+           }
+
+           let handled = false;
+
+           if (e.key === "ArrowLeft" || e.key === "ArrowRight") {
+               if (e.key === "ArrowLeft") {
+                   this.imageIndex -= 1;
+               } else if (e.key === "ArrowRight") {
+                   this.imageIndex += 1;
+               }
+               this.imageIndex %= this.imgs.length;
+
+               if (this.imageIndex < 0) {
+                   this.imageIndex = this.imgs.length + this.imageIndex;
+               }
+               handled = true;
+           } else if (e.key === "Escape") {
+               this.imageIndex = null;
+               handled = true;
+           }
+
+           if (handled === true) {
+               e.preventDefault();
+               e.stopImmediatePropagation();
+               return false;
+           }
+       }
+   }
+
    /**
     * Adds Custom drawing logic for nodes
     * e.g. Draws images and handles thumbnail navigation on nodes that output images

@@ -803,6 +843,7 @@ class ComfyApp {
                this.#addNodeContextMenuHandler(node);
                this.#addDrawBackgroundHandler(node, app);
+               this.#addNodeKeyHandler(node);

                await this.#invokeExtensionsAsync("beforeRegisterNodeDef", node, nodeData);
                LiteGraph.registerNodeType(nodeId, node);

web/scripts/ui.js

@@ -115,14 +115,6 @@ function dragElement(dragEl, settings) {
            savePos = value;
        },
    });

-   settings.addSetting({
-       id: "Comfy.ConfirmClear",
-       name: "Require confirmation when clearing workflow",
-       type: "boolean",
-       defaultValue: true,
-   });
-
    function dragMouseDown(e) {
        e = e || window.event;
        e.preventDefault();

@@ -170,7 +162,7 @@ class ComfyDialog {
            $el("p", { $: (p) => (this.textElement = p) }),
            $el("button", {
                type: "button",
-               textContent: "CLOSE",
+               textContent: "Close",
                onclick: () => this.close(),
            }),
        ]),

@@ -233,6 +225,7 @@ class ComfySettingsDialog extends ComfyDialog {
            };

            let element;
+           value = this.getSettingValue(id, defaultValue);

            if (typeof type === "function") {
                element = type(name, setter, value, attrs);

@@ -289,6 +282,16 @@ class ComfySettingsDialog extends ComfyDialog {
                    return element;
                },
            });

+       const self = this;
+       return {
+           get value() {
+               return self.getSettingValue(id, defaultValue);
+           },
+           set value(v) {
+               self.setSettingValue(id, v);
+           },
+       };
    }

    show() {

@@ -410,6 +413,13 @@ export class ComfyUI {
            this.history.update();
        });

+       const confirmClear = this.settings.addSetting({
+           id: "Comfy.ConfirmClear",
+           name: "Require confirmation when clearing workflow",
+           type: "boolean",
+           defaultValue: true,
+       });
+
        const fileInput = $el("input", {
            type: "file",
            accept: ".json,image/png",

@@ -421,7 +431,7 @@ export class ComfyUI {
        });

        this.menuContainer = $el("div.comfy-menu", { parent: document.body }, [
-           $el("div", { style: { overflow: "hidden", position: "relative", width: "100%" } }, [
+           $el("div.drag-handle", { style: { overflow: "hidden", position: "relative", width: "100%", cursor: "default" } }, [
                $el("span.drag-handle"),
                $el("span", { $: (q) => (this.queueSize = q) }),
                $el("button.comfy-settings-btn", { textContent: "⚙️", onclick: () => this.settings.show() }),

@@ -517,13 +527,13 @@ export class ComfyUI {
            $el("button", { textContent: "Load", onclick: () => fileInput.click() }),
            $el("button", { textContent: "Refresh", onclick: () => app.refreshComboInNodes() }),
            $el("button", { textContent: "Clear", onclick: () => {
-               if (localStorage.getItem("Comfy.Settings.Comfy.ConfirmClear") == "false" || confirm("Clear workflow?")) {
+               if (!confirmClear.value || confirm("Clear workflow?")) {
                    app.clean();
                    app.graph.clear();
                }
            }}),
            $el("button", { textContent: "Load Default", onclick: () => {
-               if (localStorage.getItem("Comfy.Settings.Comfy.ConfirmClear") == "false" || confirm("Load default workflow?")) {
+               if (!confirmClear.value || confirm("Load default workflow?")) {
                    app.loadGraphData()
                }
            }}),

web/scripts/widgets.js

@@ -306,7 +306,7 @@ export const ComfyWidgets = {
        const fileInput = document.createElement("input");
        Object.assign(fileInput, {
            type: "file",
-           accept: "image/jpeg,image/png",
+           accept: "image/jpeg,image/png,image/webp",
            style: "display: none",
            onchange: async () => {
                if (fileInput.files.length) {

web/style.css

@@ -39,18 +39,19 @@ body {
    position: fixed; /* Stay in place */
    z-index: 100; /* Sit on top */
    padding: 30px 30px 10px 30px;
-   background-color: #ff0000; /* Modal background */
+   background-color: #353535; /* Modal background */
+   color: #ff4444;
    box-shadow: 0px 0px 20px #888888;
    border-radius: 10px;
-   text-align: center;
    top: 50%;
    left: 50%;
    max-width: 80vw;
    max-height: 80vh;
    transform: translate(-50%, -50%);
    overflow: hidden;
-   min-width: 60%;
    justify-content: center;
+   font-family: monospace;
+   font-size: 15px;
 }

 .comfy-modal-content {

@@ -70,31 +71,11 @@ body {
    margin: 3px 3px 3px 4px;
 }

-.comfy-modal button {
-   cursor: pointer;
-   color: #aaaaaa;
-   border: none;
-   background-color: transparent;
-   font-size: 24px;
-   font-weight: bold;
-   width: 100%;
-}
-
-.comfy-modal button:hover,
-.comfy-modal button:focus {
-   color: #000;
-   text-decoration: none;
-   cursor: pointer;
-}
-
 .comfy-menu {
-   width: 200px;
    font-size: 15px;
    position: absolute;
    top: 50%;
    right: 0%;
-   background-color: white;
-   color: #000;
    text-align: center;
    z-index: 100;
    width: 170px;

@@ -109,7 +90,8 @@ body {
    box-shadow: 3px 3px 8px rgba(0, 0, 0, 0.4);
 }

-.comfy-menu button {
+.comfy-menu button,
+.comfy-modal button {
    font-size: 20px;
 }

@@ -130,7 +112,8 @@ body {
 .comfy-menu > button,
 .comfy-menu-btns button,
-.comfy-menu .comfy-list button {
+.comfy-menu .comfy-list button,
+.comfy-modal button{
    color: #ddd;
    background-color: #222;
    border-radius: 8px;

@@ -220,11 +203,22 @@ button.comfy-queue-btn {
 }

 .comfy-modal.comfy-settings {
-   background-color: var(--bg-color);
-   color: var(--fg-color);
+   text-align: center;
+   font-family: sans-serif;
+   color: #999;
    z-index: 99;
 }

+.comfy-modal input,
+.comfy-modal select {
+   color: #ddd;
+   background-color: #222;
+   border-radius: 8px;
+   border-color: #4e4e4e;
+   border-style: solid;
+   font-size: inherit;
+}
+
 @media only screen and (max-height: 850px) {
    .comfy-menu {
        top: 0 !important;
@@ -239,26 +233,26 @@ button.comfy-queue-btn {
 }

 .graphdialog {
    min-height: 1em;
 }

 .graphdialog .name {
    font-size: 14px;
    font-family: sans-serif;
    color: #999999;
 }

 .graphdialog button {
    margin-top: unset;
    vertical-align: unset;
    height: 1.6em;
    padding-right: 8px;
 }

 .graphdialog input, .graphdialog textarea, .graphdialog select {
    background-color: #222;
    border: 2px solid;
    border-color: #444444;
    color: #ddd;
    border-radius: 12px 0 0 12px;
 }