diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py
index efb2d5384..a95707e46 100644
--- a/comfy/clip_vision.py
+++ b/comfy/clip_vision.py
@@ -1,4 +1,4 @@
-from transformers import CLIPVisionModelWithProjection, CLIPVisionConfig, CLIPImageProcessor
+from transformers import CLIPVisionModelWithProjection, CLIPVisionConfig, CLIPImageProcessor, modeling_utils
from .utils import load_torch_file, transformers_convert
import os
import torch
@@ -6,7 +6,8 @@ import torch
class ClipVisionModel():
def __init__(self, json_config):
config = CLIPVisionConfig.from_json_file(json_config)
- self.model = CLIPVisionModelWithProjection(config)
+ with modeling_utils.no_init_weights():
+ self.model = CLIPVisionModelWithProjection(config)
self.processor = CLIPImageProcessor(crop_size=224,
do_center_crop=True,
do_convert_rgb=True,
diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index 573f4e1c6..5fb4fa2af 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -10,6 +10,7 @@ from .diffusionmodules.util import checkpoint
from .sub_quadratic_attention import efficient_dot_product_attention
from comfy import model_management
+import comfy.ops
from . import tomesd
@@ -52,7 +53,7 @@ def init_(tensor):
class GEGLU(nn.Module):
def __init__(self, dim_in, dim_out):
super().__init__()
- self.proj = nn.Linear(dim_in, dim_out * 2)
+ self.proj = comfy.ops.Linear(dim_in, dim_out * 2)
def forward(self, x):
x, gate = self.proj(x).chunk(2, dim=-1)
@@ -65,14 +66,14 @@ class FeedForward(nn.Module):
inner_dim = int(dim * mult)
dim_out = default(dim_out, dim)
project_in = nn.Sequential(
- nn.Linear(dim, inner_dim),
+ comfy.ops.Linear(dim, inner_dim),
nn.GELU()
) if not glu else GEGLU(dim, inner_dim)
self.net = nn.Sequential(
project_in,
nn.Dropout(dropout),
- nn.Linear(inner_dim, dim_out)
+ comfy.ops.Linear(inner_dim, dim_out)
)
def forward(self, x):
@@ -154,12 +155,12 @@ class CrossAttentionBirchSan(nn.Module):
self.scale = dim_head ** -0.5
self.heads = heads
- self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
- self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
- self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+ self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False)
+ self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False)
+ self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False)
self.to_out = nn.Sequential(
- nn.Linear(inner_dim, query_dim),
+ comfy.ops.Linear(inner_dim, query_dim),
nn.Dropout(dropout)
)
@@ -251,12 +252,12 @@ class CrossAttentionDoggettx(nn.Module):
self.scale = dim_head ** -0.5
self.heads = heads
- self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
- self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
- self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+ self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False)
+ self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False)
+ self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False)
self.to_out = nn.Sequential(
- nn.Linear(inner_dim, query_dim),
+ comfy.ops.Linear(inner_dim, query_dim),
nn.Dropout(dropout)
)
@@ -349,12 +350,12 @@ class CrossAttention(nn.Module):
self.scale = dim_head ** -0.5
self.heads = heads
- self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
- self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
- self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+ self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False)
+ self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False)
+ self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False)
self.to_out = nn.Sequential(
- nn.Linear(inner_dim, query_dim),
+ comfy.ops.Linear(inner_dim, query_dim),
nn.Dropout(dropout)
)
@@ -407,11 +408,11 @@ class MemoryEfficientCrossAttention(nn.Module):
self.heads = heads
self.dim_head = dim_head
- self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
- self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
- self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+ self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False)
+ self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False)
+ self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False)
- self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
+ self.to_out = nn.Sequential(comfy.ops.Linear(inner_dim, query_dim), nn.Dropout(dropout))
self.attention_op: Optional[Any] = None
def forward(self, x, context=None, value=None, mask=None):
@@ -456,11 +457,11 @@ class CrossAttentionPytorch(nn.Module):
self.heads = heads
self.dim_head = dim_head
- self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
- self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
- self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+ self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False)
+ self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False)
+ self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False)
- self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
+ self.to_out = nn.Sequential(comfy.ops.Linear(inner_dim, query_dim), nn.Dropout(dropout))
self.attention_op: Optional[Any] = None
def forward(self, x, context=None, value=None, mask=None):
@@ -601,7 +602,7 @@ class SpatialTransformer(nn.Module):
stride=1,
padding=0)
else:
- self.proj_in = nn.Linear(in_channels, inner_dim)
+ self.proj_in = comfy.ops.Linear(in_channels, inner_dim)
self.transformer_blocks = nn.ModuleList(
[BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
@@ -609,13 +610,12 @@ class SpatialTransformer(nn.Module):
for d in range(depth)]
)
if not use_linear:
- self.proj_out = zero_module(nn.Conv2d(inner_dim,
- in_channels,
+ self.proj_out = nn.Conv2d(inner_dim,in_channels,
kernel_size=1,
stride=1,
- padding=0))
+ padding=0)
else:
- self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
+ self.proj_out = comfy.ops.Linear(in_channels, inner_dim)
self.use_linear = use_linear
def forward(self, x, context=None, transformer_options={}):
diff --git a/comfy/ops.py b/comfy/ops.py
new file mode 100644
index 000000000..0654dbcd9
--- /dev/null
+++ b/comfy/ops.py
@@ -0,0 +1,17 @@
+import torch
+
+class Linear(torch.nn.Module):
+ def __init__(self, in_features: int, out_features: int, bias: bool = True,
+ device=None, dtype=None) -> None:
+ factory_kwargs = {'device': device, 'dtype': dtype}
+ super().__init__()
+ self.in_features = in_features
+ self.out_features = out_features
+ self.weight = torch.nn.Parameter(torch.empty((out_features, in_features), **factory_kwargs))
+ if bias:
+ self.bias = torch.nn.Parameter(torch.empty(out_features, **factory_kwargs))
+ else:
+ self.register_parameter('bias', None)
+
+ def forward(self, input):
+ return torch.nn.functional.linear(input, self.weight, self.bias)
diff --git a/comfy/samplers.py b/comfy/samplers.py
index d3cd901e7..dffd7fe7c 100644
--- a/comfy/samplers.py
+++ b/comfy/samplers.py
@@ -273,7 +273,8 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
max_total_area = model_management.maximum_batch_area()
cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, timestep, max_total_area, cond_concat, model_options)
if "sampler_cfg_function" in model_options:
- return model_options["sampler_cfg_function"](cond, uncond, cond_scale)
+ args = {"cond": cond, "uncond": uncond, "cond_scale": cond_scale, "timestep": timestep}
+ return model_options["sampler_cfg_function"](args)
else:
return uncond + (cond - uncond) * cond_scale
diff --git a/comfy/sd.py b/comfy/sd.py
index 4b3cb83a3..db04e0426 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -1,6 +1,7 @@
import torch
import contextlib
import copy
+import inspect
from . import sd1_clip
from . import sd2_clip
@@ -313,8 +314,10 @@ class ModelPatcher:
self.model_options["transformer_options"]["tomesd"] = {"ratio": ratio}
def set_model_sampler_cfg_function(self, sampler_cfg_function):
- self.model_options["sampler_cfg_function"] = sampler_cfg_function
-
+ if len(inspect.signature(sampler_cfg_function).parameters) == 3:
+ self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way
+ else:
+ self.model_options["sampler_cfg_function"] = sampler_cfg_function
def set_model_patch(self, patch, name):
to = self.model_options["transformer_options"]
@@ -1152,9 +1155,9 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
else:
model = model_base.BaseModel(unet_config, v_prediction=v_prediction)
- model = load_model_weights(model, sd, verbose=False, load_state_dict_to=load_state_dict_to)
-
if fp16:
model = model.half()
+ model = load_model_weights(model, sd, verbose=False, load_state_dict_to=load_state_dict_to)
+
return (ModelPatcher(model), clip, vae, clipvision)
diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py
index 91fb4ff27..0df3d9d91 100644
--- a/comfy/sd1_clip.py
+++ b/comfy/sd1_clip.py
@@ -1,6 +1,6 @@
import os
-from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextConfig
+from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextConfig, modeling_utils
import torch
import traceback
import zipfile
@@ -38,7 +38,8 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
if textmodel_json_config is None:
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json")
config = CLIPTextConfig.from_json_file(textmodel_json_config)
- self.transformer = CLIPTextModel(config)
+ with modeling_utils.no_init_weights():
+ self.transformer = CLIPTextModel(config)
self.device = device
self.max_length = max_length
diff --git a/execution.py b/execution.py
index d721a0cb8..4fd271ddf 100644
--- a/execution.py
+++ b/execution.py
@@ -313,7 +313,6 @@ class PromptExecutor:
else:
self.server.client_id = None
- execution_start_time = time.perf_counter()
if self.server.client_id is not None:
self.server.send_sync("execution_start", { "prompt_id": prompt_id}, self.server.client_id)
@@ -361,12 +360,7 @@ class PromptExecutor:
for x in executed:
self.old_prompt[x] = copy.deepcopy(prompt[x])
self.server.last_node_id = None
- if self.server.client_id is not None:
- self.server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, self.server.client_id)
- print("Prompt executed in {:.2f} seconds".format(time.perf_counter() - execution_start_time))
- gc.collect()
- comfy.model_management.soft_empty_cache()
def validate_inputs(prompt, item, validated):
diff --git a/main.py b/main.py
index 527b339aa..99efe5a5a 100644
--- a/main.py
+++ b/main.py
@@ -3,6 +3,8 @@ import itertools
import os
import shutil
import threading
+import gc
+import time
from comfy.cli_args import args
import comfy.utils
@@ -28,15 +30,22 @@ import folder_paths
import server
from server import BinaryEventTypes
from nodes import init_custom_nodes
-
+import comfy.model_management
def prompt_worker(q, server):
e = execution.PromptExecutor(server)
while True:
item, item_id = q.get()
- e.execute(item[2], item[1], item[3], item[4], item[5])
+ execution_start_time = time.perf_counter()
+ prompt_id = item[1]
+ e.execute(item[2], prompt_id, item[3], item[4], item[5])
q.task_done(item_id, e.outputs_ui)
+ if server.client_id is not None:
+ server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, server.client_id)
+ print("Prompt executed in {:.2f} seconds".format(time.perf_counter() - execution_start_time))
+ gc.collect()
+ comfy.model_management.soft_empty_cache()
async def run(server, address='', port=8188, verbose=True, call_on_start=None):
await asyncio.gather(server.start(address, port, verbose, call_on_start), server.publish_loop())
diff --git a/web/extensions/core/colorPalette.js b/web/extensions/core/colorPalette.js
index 84c2a3d10..592dfd2d1 100644
--- a/web/extensions/core/colorPalette.js
+++ b/web/extensions/core/colorPalette.js
@@ -1,6 +1,5 @@
-import { app } from "/scripts/app.js";
-import { $el } from "/scripts/ui.js";
-import { api } from "/scripts/api.js";
+import {app} from "/scripts/app.js";
+import {$el} from "/scripts/ui.js";
// Manage color palettes
@@ -24,6 +23,8 @@ const colorPalettes = {
"TAESD": "#DCC274", // cheesecake
},
"litegraph_base": {
+ "BACKGROUND_IMAGE": "",
+ "CLEAR_BACKGROUND_COLOR": "#222",
"NODE_TITLE_COLOR": "#999",
"NODE_SELECTED_TITLE_COLOR": "#FFF",
"NODE_TEXT_SIZE": 14,
@@ -77,6 +78,8 @@ const colorPalettes = {
"VAE": "#FF7043", // deep orange
},
"litegraph_base": {
+ "BACKGROUND_IMAGE": "",
+ "CLEAR_BACKGROUND_COLOR": "lightgray",
"NODE_TITLE_COLOR": "#222",
"NODE_SELECTED_TITLE_COLOR": "#000",
"NODE_TEXT_SIZE": 14,
@@ -191,7 +194,7 @@ app.registerExtension({
const nodeData = defs[nodeId];
var inputs = nodeData["input"]["required"];
- if (nodeData["input"]["optional"] != undefined){
+ if (nodeData["input"]["optional"] != undefined) {
inputs = Object.assign({}, nodeData["input"]["required"], nodeData["input"]["optional"])
}
@@ -232,12 +235,9 @@ app.registerExtension({
"id": "my_color_palette_unique_id",
"name": "My Color Palette",
"colors": {
- "node_slot": {
- },
- "litegraph_base": {
- },
- "comfy_base": {
- }
+ "node_slot": {},
+ "litegraph_base": {},
+ "comfy_base": {}
}
};
@@ -266,7 +266,7 @@ app.registerExtension({
};
const addCustomColorPalette = async (colorPalette) => {
- if (typeof(colorPalette) !== "object") {
+ if (typeof (colorPalette) !== "object") {
app.ui.dialog.show("Invalid color palette");
return;
}
@@ -286,7 +286,7 @@ app.registerExtension({
return;
}
- if (colorPalette.colors.node_slot && typeof(colorPalette.colors.node_slot) !== "object") {
+ if (colorPalette.colors.node_slot && typeof (colorPalette.colors.node_slot) !== "object") {
app.ui.dialog.show("Invalid color palette colors.node_slot");
return;
}
@@ -301,7 +301,11 @@ app.registerExtension({
}
}
- els.select.append($el("option", { textContent: colorPalette.name + " (custom)", value: "custom_" + colorPalette.id, selected: true }));
+ els.select.append($el("option", {
+ textContent: colorPalette.name + " (custom)",
+ value: "custom_" + colorPalette.id,
+ selected: true
+ }));
setColorPalette("custom_" + colorPalette.id);
await loadColorPalette(colorPalette);
@@ -350,7 +354,7 @@ app.registerExtension({
if (colorPalette.colors.comfy_base) {
const rootStyle = document.documentElement.style;
for (const key in colorPalette.colors.comfy_base) {
- rootStyle.setProperty('--' + key, colorPalette.colors.comfy_base[key]);
+ rootStyle.setProperty('--' + key, colorPalette.colors.comfy_base[key]);
}
}
app.canvas.draw(true, true);
@@ -380,7 +384,7 @@ app.registerExtension({
const fileInput = $el("input", {
type: "file",
accept: ".json",
- style: { display: "none" },
+ style: {display: "none"},
parent: document.body,
onchange: () => {
let file = fileInput.files[0];
@@ -403,17 +407,25 @@ app.registerExtension({
for (const c in colorPalettes) {
const colorPalette = colorPalettes[c];
- options.push($el("option", { textContent: colorPalette.name, value: colorPalette.id, selected: colorPalette.id === value }));
+ options.push($el("option", {
+ textContent: colorPalette.name,
+ value: colorPalette.id,
+ selected: colorPalette.id === value
+ }));
}
let customColorPalettes = getCustomColorPalettes();
for (const c in customColorPalettes) {
const colorPalette = customColorPalettes[c];
- options.push($el("option", { textContent: colorPalette.name + " (custom)", value: "custom_" + colorPalette.id, selected: "custom_" + colorPalette.id === value }));
+ options.push($el("option", {
+ textContent: colorPalette.name + " (custom)",
+ value: "custom_" + colorPalette.id,
+ selected: "custom_" + colorPalette.id === value
+ }));
}
return $el("div", [
- $el("label", { textContent: name || id }, [
+ $el("label", {textContent: name || id}, [
els.select = $el("select", {
onchange: (e) => {
setter(e.target.value);
@@ -427,12 +439,12 @@ app.registerExtension({
const colorPaletteId = app.ui.settings.getSettingValue(id, defaultColorPaletteId);
const colorPalette = await completeColorPalette(getColorPalette(colorPaletteId));
const json = JSON.stringify(colorPalette, null, 2); // convert the data to a JSON string
- const blob = new Blob([json], { type: "application/json" });
+ const blob = new Blob([json], {type: "application/json"});
const url = URL.createObjectURL(blob);
const a = $el("a", {
href: url,
download: colorPaletteId + ".json",
- style: { display: "none" },
+ style: {display: "none"},
parent: document.body,
});
a.click();
@@ -455,12 +467,12 @@ app.registerExtension({
onclick: async () => {
const colorPalette = await getColorPaletteTemplate();
const json = JSON.stringify(colorPalette, null, 2); // convert the data to a JSON string
- const blob = new Blob([json], { type: "application/json" });
+ const blob = new Blob([json], {type: "application/json"});
const url = URL.createObjectURL(blob);
const a = $el("a", {
href: url,
download: "color_palette.json",
- style: { display: "none" },
+ style: {display: "none"},
parent: document.body,
});
a.click();
@@ -496,15 +508,25 @@ app.registerExtension({
return;
}
- if (colorPalettes[value]) {
- await loadColorPalette(colorPalettes[value]);
+ let palette = colorPalettes[value];
+ if (palette) {
+ await loadColorPalette(palette);
} else if (value.startsWith("custom_")) {
value = value.substr(7);
let customColorPalettes = getCustomColorPalettes();
if (customColorPalettes[value]) {
+ palette = customColorPalettes[value];
await loadColorPalette(customColorPalettes[value]);
}
}
+
+ let {BACKGROUND_IMAGE, CLEAR_BACKGROUND_COLOR} = palette.colors.litegraph_base;
+ if (BACKGROUND_IMAGE === undefined || CLEAR_BACKGROUND_COLOR === undefined) {
+ const base = colorPalettes["dark"].colors.litegraph_base;
+ BACKGROUND_IMAGE = base.BACKGROUND_IMAGE;
+ CLEAR_BACKGROUND_COLOR = base.CLEAR_BACKGROUND_COLOR;
+ }
+ app.canvas.updateBackground(BACKGROUND_IMAGE, CLEAR_BACKGROUND_COLOR);
},
});
},
diff --git a/web/index.html b/web/index.html
index da0adb6c2..c48d716e1 100644
--- a/web/index.html
+++ b/web/index.html
@@ -7,6 +7,7 @@
+