diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py
index efb2d5384..a95707e46 100644
--- a/comfy/clip_vision.py
+++ b/comfy/clip_vision.py
@@ -1,4 +1,4 @@
-from transformers import CLIPVisionModelWithProjection, CLIPVisionConfig, CLIPImageProcessor
+from transformers import CLIPVisionModelWithProjection, CLIPVisionConfig, CLIPImageProcessor, modeling_utils
 from .utils import load_torch_file, transformers_convert
 import os
 import torch
@@ -6,7 +6,8 @@ import torch
 class ClipVisionModel():
     def __init__(self, json_config):
         config = CLIPVisionConfig.from_json_file(json_config)
-        self.model = CLIPVisionModelWithProjection(config)
+        with modeling_utils.no_init_weights():
+            self.model = CLIPVisionModelWithProjection(config)
         self.processor = CLIPImageProcessor(crop_size=224,
                                             do_center_crop=True,
                                             do_convert_rgb=True,
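The clip_vision.py change wraps model construction in transformers' `no_init_weights()` context manager, which skips the random weight initialization that `CLIPVisionModelWithProjection(config)` would otherwise run; that work is wasted when a checkpoint is loaded immediately afterwards. A minimal sketch of the pattern, assuming a local config JSON and a hypothetical checkpoint file in place of ComfyUI's real loading paths:

```python
import torch
from transformers import CLIPVisionConfig, CLIPVisionModelWithProjection, modeling_utils

# Hypothetical config path, standing in for the json_config argument above.
config = CLIPVisionConfig.from_json_file("clip_vision_config.json")

with modeling_utils.no_init_weights():
    # Parameters are allocated, but the usual random init is skipped.
    model = CLIPVisionModelWithProjection(config)

# Real values are expected to arrive from a checkpoint right after,
# so the uninitialized tensors are never actually used.
state_dict = torch.load("clip_vision_weights.pt", map_location="cpu")  # hypothetical checkpoint
model.load_state_dict(state_dict, strict=False)
```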
diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index 573f4e1c6..5fb4fa2af 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -10,6 +10,7 @@ from .diffusionmodules.util import checkpoint
 from .sub_quadratic_attention import efficient_dot_product_attention
 from comfy import model_management
 
+import comfy.ops
 
 from . import tomesd
 
@@ -52,7 +53,7 @@ def init_(tensor):
 class GEGLU(nn.Module):
     def __init__(self, dim_in, dim_out):
         super().__init__()
-        self.proj = nn.Linear(dim_in, dim_out * 2)
+        self.proj = comfy.ops.Linear(dim_in, dim_out * 2)
 
     def forward(self, x):
         x, gate = self.proj(x).chunk(2, dim=-1)
@@ -65,14 +66,14 @@ class FeedForward(nn.Module):
         inner_dim = int(dim * mult)
         dim_out = default(dim_out, dim)
         project_in = nn.Sequential(
-            nn.Linear(dim, inner_dim),
+            comfy.ops.Linear(dim, inner_dim),
             nn.GELU()
         ) if not glu else GEGLU(dim, inner_dim)
 
         self.net = nn.Sequential(
             project_in,
             nn.Dropout(dropout),
-            nn.Linear(inner_dim, dim_out)
+            comfy.ops.Linear(inner_dim, dim_out)
         )
 
     def forward(self, x):
@@ -154,12 +155,12 @@ class CrossAttentionBirchSan(nn.Module):
         self.scale = dim_head ** -0.5
         self.heads = heads
 
-        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
-        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
-        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False)
+        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False)
+        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False)
 
         self.to_out = nn.Sequential(
-            nn.Linear(inner_dim, query_dim),
+            comfy.ops.Linear(inner_dim, query_dim),
             nn.Dropout(dropout)
         )
 
@@ -251,12 +252,12 @@ class CrossAttentionDoggettx(nn.Module):
         self.scale = dim_head ** -0.5
         self.heads = heads
 
-        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
-        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
-        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False)
+        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False)
+        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False)
 
         self.to_out = nn.Sequential(
-            nn.Linear(inner_dim, query_dim),
+            comfy.ops.Linear(inner_dim, query_dim),
             nn.Dropout(dropout)
         )
 
@@ -349,12 +350,12 @@ class CrossAttention(nn.Module):
         self.scale = dim_head ** -0.5
         self.heads = heads
 
-        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
-        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
-        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False)
+        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False)
+        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False)
 
         self.to_out = nn.Sequential(
-            nn.Linear(inner_dim, query_dim),
+            comfy.ops.Linear(inner_dim, query_dim),
             nn.Dropout(dropout)
         )
 
@@ -407,11 +408,11 @@ class MemoryEfficientCrossAttention(nn.Module):
         self.heads = heads
         self.dim_head = dim_head
 
-        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
-        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
-        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False)
+        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False)
+        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False)
 
-        self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
+        self.to_out = nn.Sequential(comfy.ops.Linear(inner_dim, query_dim), nn.Dropout(dropout))
         self.attention_op: Optional[Any] = None
 
     def forward(self, x, context=None, value=None, mask=None):
@@ -456,11 +457,11 @@ class CrossAttentionPytorch(nn.Module):
         self.heads = heads
         self.dim_head = dim_head
 
-        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
-        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
-        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False)
+        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False)
+        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False)
 
-        self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
+        self.to_out = nn.Sequential(comfy.ops.Linear(inner_dim, query_dim), nn.Dropout(dropout))
         self.attention_op: Optional[Any] = None
 
     def forward(self, x, context=None, value=None, mask=None):
@@ -601,7 +602,7 @@ class SpatialTransformer(nn.Module):
                                  stride=1,
                                  padding=0)
         else:
-            self.proj_in = nn.Linear(in_channels, inner_dim)
+            self.proj_in = comfy.ops.Linear(in_channels, inner_dim)
 
         self.transformer_blocks = nn.ModuleList(
             [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
@@ -609,13 +610,12 @@ class SpatialTransformer(nn.Module):
                 for d in range(depth)]
         )
         if not use_linear:
-            self.proj_out = zero_module(nn.Conv2d(inner_dim,
-                                                  in_channels,
+            self.proj_out = nn.Conv2d(inner_dim,in_channels,
                                                   kernel_size=1,
                                                   stride=1,
-                                                  padding=0))
+                                                  padding=0)
         else:
-            self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
+            self.proj_out = comfy.ops.Linear(in_channels, inner_dim)
         self.use_linear = use_linear
 
     def forward(self, x, context=None, transformer_options={}):
diff --git a/comfy/ops.py b/comfy/ops.py
new file mode 100644
index 000000000..0654dbcd9
--- /dev/null
+++ b/comfy/ops.py
@@ -0,0 +1,17 @@
+import torch
+
+class Linear(torch.nn.Module):
+    def __init__(self, in_features: int, out_features: int, bias: bool = True,
+                 device=None, dtype=None) -> None:
+        factory_kwargs = {'device': device, 'dtype': dtype}
+        super().__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+        self.weight = torch.nn.Parameter(torch.empty((out_features, in_features), **factory_kwargs))
+        if bias:
+            self.bias = torch.nn.Parameter(torch.empty(out_features, **factory_kwargs))
+        else:
+            self.register_parameter('bias', None)
+
+    def forward(self, input):
+        return torch.nn.functional.linear(input, self.weight, self.bias)
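The new `comfy.ops.Linear` keeps the constructor and forward behaviour of `torch.nn.Linear` but never calls `reset_parameters()`, so its parameters come from `torch.empty` and stay uninitialized until a state dict is loaded over them. A rough usage sketch under that assumption (run from a ComfyUI checkout so that `comfy.ops` is importable):

```python
import torch
import comfy.ops

# Same signature as torch.nn.Linear, but construction does no random init;
# the parameters hold whatever torch.empty happened to allocate.
layer = comfy.ops.Linear(768, 320, bias=True)

# Loading a state dict is what gives the layer meaningful weights.
reference = torch.nn.Linear(768, 320)
layer.load_state_dict(reference.state_dict())

x = torch.randn(4, 768)
out = layer(x)    # forward is a plain F.linear(input, weight, bias)
print(out.shape)  # torch.Size([4, 320])
```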
diff --git a/comfy/samplers.py b/comfy/samplers.py
index d3cd901e7..dffd7fe7c 100644
--- a/comfy/samplers.py
+++ b/comfy/samplers.py
@@ -273,7 +273,8 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
         max_total_area = model_management.maximum_batch_area()
         cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, timestep, max_total_area, cond_concat, model_options)
         if "sampler_cfg_function" in model_options:
-            return model_options["sampler_cfg_function"](cond, uncond, cond_scale)
+            args = {"cond": cond, "uncond": uncond, "cond_scale": cond_scale, "timestep": timestep}
+            return model_options["sampler_cfg_function"](args)
         else:
             return uncond + (cond - uncond) * cond_scale
 
diff --git a/comfy/sd.py b/comfy/sd.py
index 4b3cb83a3..db04e0426 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -1,6 +1,7 @@
 import torch
 import contextlib
 import copy
+import inspect
 
 from . import sd1_clip
 from . import sd2_clip
@@ -313,8 +314,10 @@ class ModelPatcher:
         self.model_options["transformer_options"]["tomesd"] = {"ratio": ratio}
 
     def set_model_sampler_cfg_function(self, sampler_cfg_function):
-        self.model_options["sampler_cfg_function"] = sampler_cfg_function
-
+        if len(inspect.signature(sampler_cfg_function).parameters) == 3:
+            self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way
+        else:
+            self.model_options["sampler_cfg_function"] = sampler_cfg_function
 
     def set_model_patch(self, patch, name):
         to = self.model_options["transformer_options"]
@@ -1152,9 +1155,9 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
     else:
         model = model_base.BaseModel(unet_config, v_prediction=v_prediction)
 
-    model = load_model_weights(model, sd, verbose=False, load_state_dict_to=load_state_dict_to)
-
     if fp16:
         model = model.half()
 
+    model = load_model_weights(model, sd, verbose=False, load_state_dict_to=load_state_dict_to)
+
     return (ModelPatcher(model), clip, vae, clipvision)
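Together, the samplers.py and sd.py hunks switch `sampler_cfg_function` to a single-dict calling convention (which now also carries the timestep) while staying compatible with old three-argument callbacks, which `set_model_sampler_cfg_function` detects via `inspect.signature` and wraps. A small sketch of both conventions; the `register` helper below is a hypothetical stand-in that mirrors the wrapping logic from the patch:

```python
import inspect
import torch

def legacy_cfg(cond, uncond, cond_scale):
    # Old convention: three positional arguments.
    return uncond + (cond - uncond) * cond_scale

def new_cfg(args):
    # New convention: one dict, which also exposes args["timestep"].
    return args["uncond"] + (args["cond"] - args["uncond"]) * args["cond_scale"]

def register(sampler_cfg_function):
    # Mirrors ModelPatcher.set_model_sampler_cfg_function above.
    if len(inspect.signature(sampler_cfg_function).parameters) == 3:
        return lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"])  # old way
    return sampler_cfg_function

args = {"cond": torch.ones(2), "uncond": torch.zeros(2), "cond_scale": 7.5,
        "timestep": torch.tensor([999])}
for fn in (legacy_cfg, new_cfg):
    print(register(fn)(args))  # both print tensor([7.5000, 7.5000])
```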
diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py
index 91fb4ff27..0df3d9d91 100644
--- a/comfy/sd1_clip.py
+++ b/comfy/sd1_clip.py
@@ -1,6 +1,6 @@
 import os
 
-from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextConfig
+from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextConfig, modeling_utils
 import torch
 import traceback
 import zipfile
@@ -38,7 +38,8 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         if textmodel_json_config is None:
             textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json")
         config = CLIPTextConfig.from_json_file(textmodel_json_config)
-        self.transformer = CLIPTextModel(config)
+        with modeling_utils.no_init_weights():
+            self.transformer = CLIPTextModel(config)
         self.device = device
         self.max_length = max_length
 
diff --git a/execution.py b/execution.py
index d721a0cb8..4fd271ddf 100644
--- a/execution.py
+++ b/execution.py
@@ -313,7 +313,6 @@ class PromptExecutor:
             else:
                 self.server.client_id = None
 
-        execution_start_time = time.perf_counter()
         if self.server.client_id is not None:
             self.server.send_sync("execution_start", { "prompt_id": prompt_id}, self.server.client_id)
 
@@ -361,12 +360,7 @@ class PromptExecutor:
             for x in executed:
                 self.old_prompt[x] = copy.deepcopy(prompt[x])
             self.server.last_node_id = None
-            if self.server.client_id is not None:
-                self.server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, self.server.client_id)
-                print("Prompt executed in {:.2f} seconds".format(time.perf_counter() - execution_start_time))
 
-        gc.collect()
-        comfy.model_management.soft_empty_cache()
 
 
 def validate_inputs(prompt, item, validated):
diff --git a/main.py b/main.py
index 527b339aa..99efe5a5a 100644
--- a/main.py
+++ b/main.py
@@ -3,6 +3,8 @@ import itertools
 import os
 import shutil
 import threading
+import gc
+import time
 
 from comfy.cli_args import args
 import comfy.utils
@@ -28,15 +30,22 @@ import folder_paths
 import server
 from server import BinaryEventTypes
 from nodes import init_custom_nodes
-
+import comfy.model_management
 
 def prompt_worker(q, server):
     e = execution.PromptExecutor(server)
     while True:
         item, item_id = q.get()
-        e.execute(item[2], item[1], item[3], item[4], item[5])
+        execution_start_time = time.perf_counter()
+        prompt_id = item[1]
+        e.execute(item[2], prompt_id, item[3], item[4], item[5])
         q.task_done(item_id, e.outputs_ui)
+        if server.client_id is not None:
+            server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, server.client_id)
+            print("Prompt executed in {:.2f} seconds".format(time.perf_counter() - execution_start_time))
+        gc.collect()
+        comfy.model_management.soft_empty_cache()
 
 
 async def run(server, address='', port=8188, verbose=True, call_on_start=None):
     await asyncio.gather(server.start(address, port, verbose, call_on_start), server.publish_loop())
nodeData["input"]["required"]; - if (nodeData["input"]["optional"] != undefined){ + if (nodeData["input"]["optional"] != undefined) { inputs = Object.assign({}, nodeData["input"]["required"], nodeData["input"]["optional"]) } @@ -232,12 +235,9 @@ app.registerExtension({ "id": "my_color_palette_unique_id", "name": "My Color Palette", "colors": { - "node_slot": { - }, - "litegraph_base": { - }, - "comfy_base": { - } + "node_slot": {}, + "litegraph_base": {}, + "comfy_base": {} } }; @@ -266,7 +266,7 @@ app.registerExtension({ }; const addCustomColorPalette = async (colorPalette) => { - if (typeof(colorPalette) !== "object") { + if (typeof (colorPalette) !== "object") { app.ui.dialog.show("Invalid color palette"); return; } @@ -286,7 +286,7 @@ app.registerExtension({ return; } - if (colorPalette.colors.node_slot && typeof(colorPalette.colors.node_slot) !== "object") { + if (colorPalette.colors.node_slot && typeof (colorPalette.colors.node_slot) !== "object") { app.ui.dialog.show("Invalid color palette colors.node_slot"); return; } @@ -301,7 +301,11 @@ app.registerExtension({ } } - els.select.append($el("option", { textContent: colorPalette.name + " (custom)", value: "custom_" + colorPalette.id, selected: true })); + els.select.append($el("option", { + textContent: colorPalette.name + " (custom)", + value: "custom_" + colorPalette.id, + selected: true + })); setColorPalette("custom_" + colorPalette.id); await loadColorPalette(colorPalette); @@ -350,7 +354,7 @@ app.registerExtension({ if (colorPalette.colors.comfy_base) { const rootStyle = document.documentElement.style; for (const key in colorPalette.colors.comfy_base) { - rootStyle.setProperty('--' + key, colorPalette.colors.comfy_base[key]); + rootStyle.setProperty('--' + key, colorPalette.colors.comfy_base[key]); } } app.canvas.draw(true, true); @@ -380,7 +384,7 @@ app.registerExtension({ const fileInput = $el("input", { type: "file", accept: ".json", - style: { display: "none" }, + style: {display: "none"}, parent: document.body, onchange: () => { let file = fileInput.files[0]; @@ -403,17 +407,25 @@ app.registerExtension({ for (const c in colorPalettes) { const colorPalette = colorPalettes[c]; - options.push($el("option", { textContent: colorPalette.name, value: colorPalette.id, selected: colorPalette.id === value })); + options.push($el("option", { + textContent: colorPalette.name, + value: colorPalette.id, + selected: colorPalette.id === value + })); } let customColorPalettes = getCustomColorPalettes(); for (const c in customColorPalettes) { const colorPalette = customColorPalettes[c]; - options.push($el("option", { textContent: colorPalette.name + " (custom)", value: "custom_" + colorPalette.id, selected: "custom_" + colorPalette.id === value })); + options.push($el("option", { + textContent: colorPalette.name + " (custom)", + value: "custom_" + colorPalette.id, + selected: "custom_" + colorPalette.id === value + })); } return $el("div", [ - $el("label", { textContent: name || id }, [ + $el("label", {textContent: name || id}, [ els.select = $el("select", { onchange: (e) => { setter(e.target.value); @@ -427,12 +439,12 @@ app.registerExtension({ const colorPaletteId = app.ui.settings.getSettingValue(id, defaultColorPaletteId); const colorPalette = await completeColorPalette(getColorPalette(colorPaletteId)); const json = JSON.stringify(colorPalette, null, 2); // convert the data to a JSON string - const blob = new Blob([json], { type: "application/json" }); + const blob = new Blob([json], {type: "application/json"}); const url = 
diff --git a/web/index.html b/web/index.html
index da0adb6c2..c48d716e1 100644
--- a/web/index.html
+++ b/web/index.html
@@ -7,6 +7,7 @@
+