Merge branch 'Main' into feature/refresh

* MODIFIED: Change refresh single node (each node button) -> refresh whole node (main menu button)
This commit is contained in:
Lt.Dr.Data 2023-03-23 13:37:49 +09:00
commit b8dd84f545
12 changed files with 242 additions and 87 deletions

View File

@ -26,6 +26,7 @@ This ui will let you design and execute advanced stable diffusion pipelines usin
- [Upscale Models (ESRGAN, ESRGAN variants, SwinIR, Swin2SR, etc...)](https://comfyanonymous.github.io/ComfyUI_examples/upscale_models/)
- Starts up very fast.
- Works fully offline: will never download anything.
- [Config file](extra_model_paths.yaml.example) to set the search paths for models.
Workflow examples can be found on the [Examples page](https://comfyanonymous.github.io/ComfyUI_examples/)
@ -136,9 +137,9 @@ This will let you use: pip3.10 to install all the dependencies.
## How to increase generation speed?
The fp16 model configs in the CheckpointLoader can be used to load them in fp16 mode, depending on your GPU this will increase your gen speed by a significant amount.
Make sure you use the CheckpointLoaderSimple node to load checkpoints. It will auto pick the right settings depending on your GPU.
You can also set this command line setting to disable the upcasting to fp32 in some cross attention operations which will increase your speed. Note that this will very likely give you black images on SD2.x models.
You can set this command line setting to disable the upcasting to fp32 in some cross attention operations which will increase your speed. Note that this doesn't do anything when xformers is enabled and will very likely give you black images on SD2.x models.
```--dont-upcast-attention```

View File

@ -20,11 +20,6 @@ if model_management.xformers_enabled():
import os
_ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")
try:
OOM_EXCEPTION = torch.cuda.OutOfMemoryError
except:
OOM_EXCEPTION = Exception
def exists(val):
return val is not None
@ -312,7 +307,7 @@ class CrossAttentionDoggettx(nn.Module):
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
del s2
break
except OOM_EXCEPTION as e:
except model_management.OOM_EXCEPTION as e:
if first_op_done == False:
torch.cuda.empty_cache()
torch.cuda.ipc_collect()

View File

@ -13,11 +13,6 @@ if model_management.xformers_enabled():
import xformers
import xformers.ops
try:
OOM_EXCEPTION = torch.cuda.OutOfMemoryError
except:
OOM_EXCEPTION = Exception
def get_timestep_embedding(timesteps, embedding_dim):
"""
This matches the implementation in Denoising Diffusion Probabilistic Models:
@ -221,7 +216,7 @@ class AttnBlock(nn.Module):
r1[:, :, i:end] = torch.bmm(v, s2)
del s2
break
except OOM_EXCEPTION as e:
except model_management.OOM_EXCEPTION as e:
steps *= 2
if steps > 128:
raise e
@ -616,19 +611,17 @@ class Encoder(nn.Module):
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
already_padded = True
# downsampling
hs = [self.conv_in(x)]
h = self.conv_in(x)
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks):
h = self.down[i_level].block[i_block](hs[-1], temb)
h = self.down[i_level].block[i_block](h, temb)
if len(self.down[i_level].attn) > 0:
h = self.down[i_level].attn[i_block](h)
hs.append(h)
if i_level != self.num_resolutions-1:
hs.append(self.down[i_level].downsample(hs[-1], already_padded))
h = self.down[i_level].downsample(h, already_padded)
already_padded = False
# middle
h = hs[-1]
h = self.mid.block_1(h, temb)
h = self.mid.attn_1(h)
h = self.mid.block_2(h, temb)

View File

@ -24,10 +24,7 @@ except ImportError:
from torch import Tensor
from typing import List
try:
OOM_EXCEPTION = torch.cuda.OutOfMemoryError
except:
OOM_EXCEPTION = Exception
import model_management
def dynamic_slice(
x: Tensor,
@ -161,7 +158,7 @@ def _get_attention_scores_no_kv_chunking(
try:
attn_probs = attn_scores.softmax(dim=-1)
del attn_scores
except OOM_EXCEPTION:
except model_management.OOM_EXCEPTION:
print("ran out of memory while running softmax in _get_attention_scores_no_kv_chunking, trying slower in place softmax instead")
attn_scores -= attn_scores.max(dim=-1, keepdim=True).values
torch.exp(attn_scores, out=attn_scores)

View File

@ -31,6 +31,11 @@ try:
except:
pass
try:
OOM_EXCEPTION = torch.cuda.OutOfMemoryError
except:
OOM_EXCEPTION = Exception
if "--disable-xformers" in sys.argv:
XFORMERS_IS_AVAILBLE = False
else:
@ -231,7 +236,7 @@ def should_use_fp16():
return False
#FP32 is faster on those cards?
nvidia_16_series = ["1660", "1650", "1630"]
nvidia_16_series = ["1660", "1650", "1630", "T500", "T550", "T600"]
for x in nvidia_16_series:
if x in props.name:
return False

View File

@ -129,12 +129,17 @@ def load_lora(path, to_load):
A_name = "{}.lora_up.weight".format(x)
B_name = "{}.lora_down.weight".format(x)
alpha_name = "{}.alpha".format(x)
mid_name = "{}.lora_mid.weight".format(x)
if A_name in lora.keys():
alpha = None
if alpha_name in lora.keys():
alpha = lora[alpha_name].item()
loaded_keys.add(alpha_name)
patch_dict[to_load[x]] = (lora[A_name], lora[B_name], alpha)
mid = None
if mid_name in lora.keys():
mid = lora[mid_name]
loaded_keys.add(mid_name)
patch_dict[to_load[x]] = (lora[A_name], lora[B_name], alpha, mid)
loaded_keys.add(A_name)
loaded_keys.add(B_name)
for x in lora.keys():
@ -279,6 +284,10 @@ class ModelPatcher:
mat2 = v[1]
if v[2] is not None:
alpha *= v[2] / mat2.shape[0]
if v[3] is not None:
#locon mid weights, hopefully the math is fine because I didn't properly test it
final_shape = [mat2.shape[1], mat2.shape[0], v[3].shape[2], v[3].shape[3]]
mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1).float(), v[3].transpose(0, 1).flatten(start_dim=1).float()).reshape(final_shape).transpose(0, 1)
weight += (alpha * torch.mm(mat1.flatten(start_dim=1).float(), mat2.flatten(start_dim=1).float())).reshape(weight.shape).type(weight.dtype).to(weight.device)
return self.model
def unpatch_model(self):
@ -374,20 +383,34 @@ class VAE:
device = model_management.get_torch_device()
self.device = device
def decode(self, samples):
def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16):
decode_fn = lambda a: (self.first_stage_model.decode(1. / self.scale_factor * a.to(self.device)) + 1.0)
output = torch.clamp((
(utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = 8) +
utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = 8) +
utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = 8))
/ 3.0) / 2.0, min=0.0, max=1.0)
return output
def decode(self, samples_in):
model_management.unload_model()
self.first_stage_model = self.first_stage_model.to(self.device)
samples = samples.to(self.device)
pixel_samples = self.first_stage_model.decode(1. / self.scale_factor * samples)
pixel_samples = torch.clamp((pixel_samples + 1.0) / 2.0, min=0.0, max=1.0)
try:
samples = samples_in.to(self.device)
pixel_samples = self.first_stage_model.decode(1. / self.scale_factor * samples)
pixel_samples = torch.clamp((pixel_samples + 1.0) / 2.0, min=0.0, max=1.0)
except model_management.OOM_EXCEPTION as e:
print("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
pixel_samples = self.decode_tiled_(samples_in)
self.first_stage_model = self.first_stage_model.cpu()
pixel_samples = pixel_samples.cpu().movedim(1,-1)
return pixel_samples
def decode_tiled(self, samples, tile_x=64, tile_y=64, overlap = 8):
def decode_tiled(self, samples, tile_x=64, tile_y=64, overlap = 16):
model_management.unload_model()
self.first_stage_model = self.first_stage_model.to(self.device)
output = utils.tiled_scale(samples, lambda a: torch.clamp((self.first_stage_model.decode(1. / self.scale_factor * a.to(self.device)) + 1.0) / 2.0, min=0.0, max=1.0), tile_x, tile_y, overlap, upscale_amount = 8)
output = self.decode_tiled_(samples, tile_x, tile_y, overlap)
self.first_stage_model = self.first_stage_model.cpu()
return output.movedim(1,-1)
@ -405,6 +428,9 @@ class VAE:
self.first_stage_model = self.first_stage_model.to(self.device)
pixel_samples = pixel_samples.movedim(-1,1).to(self.device)
samples = utils.tiled_scale(pixel_samples, lambda a: self.first_stage_model.encode(2. * a - 1.).sample() * self.scale_factor, tile_x, tile_y, overlap, upscale_amount = (1/8), out_channels=4)
samples += utils.tiled_scale(pixel_samples, lambda a: self.first_stage_model.encode(2. * a - 1.).sample() * self.scale_factor, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/8), out_channels=4)
samples += utils.tiled_scale(pixel_samples, lambda a: self.first_stage_model.encode(2. * a - 1.).sample() * self.scale_factor, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/8), out_channels=4)
samples /= 3.0
self.first_stage_model = self.first_stage_model.cpu()
samples = samples.cpu()
return samples

View File

@ -10,7 +10,7 @@ class UpscaleModelLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model_name": (folder_paths.get_filename_list("upscale_models"), )},
"widget": { "Refresh": ("REFRESH", [("model_name", "upscale_models")]) }}
}
RETURN_TYPES = ("UPSCALE_MODEL",)
FUNCTION = "load_model"

View File

@ -143,7 +143,7 @@ class PromptExecutor:
else:
self.server.client_id = None
with torch.no_grad():
with torch.inference_mode():
for x in prompt:
recursive_output_delete_if_changed(prompt, self.old_prompt, self.outputs, x)

View File

@ -31,6 +31,8 @@ def before_node_execution():
def interrupt_processing(value=True):
model_management.interrupt_current_processing(value)
MAX_RESOLUTION=8192
class CLIPTextEncode:
@classmethod
def INPUT_TYPES(s):
@ -59,10 +61,10 @@ class ConditioningSetArea:
@classmethod
def INPUT_TYPES(s):
return {"required": {"conditioning": ("CONDITIONING", ),
"width": ("INT", {"default": 64, "min": 64, "max": 4096, "step": 64}),
"height": ("INT", {"default": 64, "min": 64, "max": 4096, "step": 64}),
"x": ("INT", {"default": 0, "min": 0, "max": 4096, "step": 64}),
"y": ("INT", {"default": 0, "min": 0, "max": 4096, "step": 64}),
"width": ("INT", {"default": 64, "min": 64, "max": MAX_RESOLUTION, "step": 64}),
"height": ("INT", {"default": 64, "min": 64, "max": MAX_RESOLUTION, "step": 64}),
"x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 64}),
"y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 64}),
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
}}
RETURN_TYPES = ("CONDITIONING",)
@ -192,7 +194,7 @@ class CheckpointLoader:
def INPUT_TYPES(s):
return {"required": { "config_name": (folder_paths.get_filename_list("configs"), ),
"ckpt_name": (folder_paths.get_filename_list("checkpoints"), ) },
"widget": { "Refresh": ("REFRESH", [("config_name", "configs"), ("ckpt_name", "checkpoints")]) }}
}
RETURN_TYPES = ("MODEL", "CLIP", "VAE")
FUNCTION = "load_checkpoint"
@ -207,7 +209,7 @@ class CheckpointLoaderSimple:
@classmethod
def INPUT_TYPES(s):
return {"required": { "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ) },
"widget": { "Refresh": ("REFRESH", [("ckpt_name", "checkpoints")]) }}
}
RETURN_TYPES = ("MODEL", "CLIP", "VAE")
FUNCTION = "load_checkpoint"
@ -242,7 +244,7 @@ class LoraLoader:
"lora_name": (folder_paths.get_filename_list("loras"), ),
"strength_model": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
"strength_clip": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}) },
"widget": { "Refresh": ("REFRESH", [("lora_name", "loras")]) }}
}
RETURN_TYPES = ("MODEL", "CLIP")
FUNCTION = "load_lora"
@ -257,7 +259,7 @@ class VAELoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "vae_name": (folder_paths.get_filename_list("vae"), ) },
"widget": { "Refresh": ("REFRESH", [("vae_name", "vae")]) }}
}
RETURN_TYPES = ("VAE",)
FUNCTION = "load_vae"
@ -273,7 +275,7 @@ class ControlNetLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "control_net_name": (folder_paths.get_filename_list("controlnet"), )},
"widget": { "Refresh": ("REFRESH", [("control_net_name", "controlnet")]) }}
}
RETURN_TYPES = ("CONTROL_NET",)
FUNCTION = "load_controlnet"
@ -290,7 +292,7 @@ class DiffControlNetLoader:
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"control_net_name": (folder_paths.get_filename_list("controlnet"), )},
"widget": { "Refresh": ("REFRESH", [("control_net_name", "controlnet")]) }}
}
RETURN_TYPES = ("CONTROL_NET",)
FUNCTION = "load_controlnet"
@ -333,7 +335,8 @@ class CLIPLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip_name": (folder_paths.get_filename_list("clip"), ),},
"widget": { "Refresh": ("REFRESH", [("clip_name", "clip")]) }}
}
RETURN_TYPES = ("CLIP",)
FUNCTION = "load_clip"
@ -341,14 +344,14 @@ class CLIPLoader:
def load_clip(self, clip_name):
clip_path = folder_paths.get_full_path("clip", clip_name)
clip = comfy.sd.load_clip(ckpt_path=clip_path, embedding_directory=CheckpointLoader.embedding_directory)
clip = comfy.sd.load_clip(ckpt_path=clip_path, embedding_directory=folder_paths.get_folder_paths("embeddings"))
return (clip,)
class CLIPVisionLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip_name": (folder_paths.get_filename_list("clip_vision"), )},
"widget": { "Refresh": ("REFRESH", [("clip_name", "clip_vision")]) } }
}
RETURN_TYPES = ("CLIP_VISION",)
FUNCTION = "load_clip"
@ -379,7 +382,7 @@ class StyleModelLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "style_model_name": (folder_paths.get_filename_list("style_models"), )},
"widget": { "Refresh": ("REFRESH", [("style_model_name", "style_models")]) }}
}
RETURN_TYPES = ("STYLE_MODEL",)
FUNCTION = "load_style_model"
@ -418,8 +421,8 @@ class EmptyLatentImage:
@classmethod
def INPUT_TYPES(s):
return {"required": { "width": ("INT", {"default": 512, "min": 64, "max": 4096, "step": 64}),
"height": ("INT", {"default": 512, "min": 64, "max": 4096, "step": 64}),
return {"required": { "width": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 64}),
"height": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 64}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 64})}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "generate"
@ -439,8 +442,8 @@ class LatentUpscale:
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples": ("LATENT",), "upscale_method": (s.upscale_methods,),
"width": ("INT", {"default": 512, "min": 64, "max": 4096, "step": 64}),
"height": ("INT", {"default": 512, "min": 64, "max": 4096, "step": 64}),
"width": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 64}),
"height": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 64}),
"crop": (s.crop_methods,)}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "upscale"
@ -501,9 +504,9 @@ class LatentComposite:
def INPUT_TYPES(s):
return {"required": { "samples_to": ("LATENT",),
"samples_from": ("LATENT",),
"x": ("INT", {"default": 0, "min": 0, "max": 4096, "step": 8}),
"y": ("INT", {"default": 0, "min": 0, "max": 4096, "step": 8}),
"feather": ("INT", {"default": 0, "min": 0, "max": 4096, "step": 8}),
"x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
"y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
"feather": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "composite"
@ -542,10 +545,10 @@ class LatentCrop:
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples": ("LATENT",),
"width": ("INT", {"default": 512, "min": 64, "max": 4096, "step": 64}),
"height": ("INT", {"default": 512, "min": 64, "max": 4096, "step": 64}),
"x": ("INT", {"default": 0, "min": 0, "max": 4096, "step": 8}),
"y": ("INT", {"default": 0, "min": 0, "max": 4096, "step": 8}),
"width": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 64}),
"height": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 64}),
"x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
"y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "crop"
@ -812,7 +815,7 @@ class LoadImage:
if not os.path.exists(s.input_dir):
os.makedirs(s.input_dir)
return {"required": {"image": (sorted(os.listdir(s.input_dir)), )},
"widget": { "Refresh": ("REFRESH", [("image", "input")]) } }
}
CATEGORY = "image"
@ -881,8 +884,8 @@ class ImageScale:
@classmethod
def INPUT_TYPES(s):
return {"required": { "image": ("IMAGE",), "upscale_method": (s.upscale_methods,),
"width": ("INT", {"default": 512, "min": 1, "max": 4096, "step": 1}),
"height": ("INT", {"default": 512, "min": 1, "max": 4096, "step": 1}),
"width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
"height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
"crop": (s.crop_methods,)}}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "upscale"

View File

@ -80,10 +80,23 @@ class ComfyApp {
img = this.imgs[this.overIndex];
}
if (img) {
options.unshift({
content: "Open Image",
callback: () => window.open(img.src, "_blank"),
});
options.unshift(
{
content: "Open Image",
callback: () => window.open(img.src, "_blank"),
},
{
content: "Save Image",
callback: () => {
const a = document.createElement("a");
a.href = img.src;
a.setAttribute("download", new URLSearchParams(new URL(img.src).search).get("filename"));
document.body.append(a);
a.click();
requestAnimationFrame(() => a.remove());
},
}
);
}
}
};
@ -481,6 +494,7 @@ class ComfyApp {
// Create and mount the LiteGraph in the DOM
const canvasEl = (this.canvasEl = Object.assign(document.createElement("canvas"), { id: "graph-canvas" }));
canvasEl.tabIndex = "1"
document.body.prepend(canvasEl);
this.graph = new LGraph();
@ -760,6 +774,71 @@ class ComfyApp {
}
this.extensions.push(extension);
}
/**
* Refresh file list on whole nodes
*/
async refreshNodes() {
for(let nodeNum in this.graph._nodes) {
const node = this.graph._nodes[nodeNum];
var data = [];
switch(node.type) {
case "CheckpointLoader":
data = { "config_name": "configs",
"ckpt_name": "checkpoints" };
break;
case "CheckpointLoaderSimple":
data = { "ckpt_name": "checkpoints" };
break;
case "LoraLoader":
data = { "lora_name": "loras" };
break;
case "VAELoader":
data = { "vae_name": "vae" };
break;
case "ControlNetLoader":
case "DiffControlNetLoader":
data = { "control_net_name": "controlnet" };
break;
case "CLIPLoader":
data = { "clip_name": "clip" };
break;
case "CLIPVisionLoader":
data = { "clip_name": "clip_vision" };
break;
case "StyleModelLoader":
data = { "style_model_name": "style_models" };
break;
case "LoadImage":
data = { "image": "input" };
break;
case "UpscaleModelLoader":
data = { "model_name": "upscale_models" };
break;
default:
break;
}
for (let i in data) {
const w = node.widgets.find((w) => w.name === i);
const filelist = await api.getFiles(data[i]);
w.options.values = filelist.files;
w.value = filelist.files[0];
}
}
}
}
export const app = new ComfyApp();

View File

@ -326,6 +326,7 @@ export class ComfyUI {
}, 0);
},
}),
$el("button", { textContent: "Refresh", onclick: () => app.refreshNodes() }),
$el("button", { textContent: "Load", onclick: () => fileInput.click() }),
$el("button", { textContent: "Clear", onclick: () => app.graph.clear() }),
$el("button", { textContent: "Load Default", onclick: () => app.loadGraphData() }),

View File

@ -1,5 +1,3 @@
import { api } from "./api.js";
function getNumberDefaults(inputData, defaultStep) {
let defaultVal = inputData[1]["default"];
let { min, max, step } = inputData[1];
@ -29,22 +27,59 @@ function seedWidget(node, inputName, inputData) {
return { widget: seed, randomize };
}
function refreshWidget(node, name, data) {
async function refresh_callback() {
const items = data[1];
for (let i in items) {
const w = node.widgets.find((w) => w.name === items[i][0]);
const filelist = await api.getFiles(items[i][1]);
w.options.values = filelist.files;
w.value = filelist.files[0];
const MultilineSymbol = Symbol();
function addMultilineWidget(node, name, opts, app) {
const MIN_SIZE = 50;
function computeSize(size) {
if (node.widgets[0].last_y == null) return;
let y = node.widgets[0].last_y;
let freeSpace = size[1] - y;
// Compute the height of all non customtext widgets
let widgetHeight = 0;
const multi = [];
for (let i = 0; i < node.widgets.length; i++) {
const w = node.widgets[i];
if (w.type === "customtext") {
multi.push(w);
} else {
if (w.computeSize) {
widgetHeight += w.computeSize()[1] + 4;
} else {
widgetHeight += LiteGraph.NODE_WIDGET_HEIGHT + 4;
}
}
}
// See how large each text input can be
freeSpace -= widgetHeight;
freeSpace /= multi.length;
if (freeSpace < MIN_SIZE) {
// There isnt enough space for all the widgets, increase the size of the node
freeSpace = MIN_SIZE;
node.size[1] = y + widgetHeight + freeSpace * multi.length;
node.graph.setDirtyCanvas(true);
}
// Position each of the widgets
for (const w of node.widgets) {
w.y = y;
if (w.type === "customtext") {
y += freeSpace;
} else if (w.computeSize) {
y += w.computeSize()[1] + 4;
} else {
y += LiteGraph.NODE_WIDGET_HEIGHT + 4;
}
}
node.inputHeight = freeSpace;
}
const refresh = node.addWidget("button", name, true, function(v) { refresh_callback(); }, {});
return { refresh };
}
function addMultilineWidget(node, name, defaultVal, app) {
const widget = {
type: "customtext",
name,
@ -55,14 +90,19 @@ function addMultilineWidget(node, name, defaultVal, app) {
this.inputEl.value = x;
},
draw: function (ctx, _, widgetWidth, y, widgetHeight) {
if (!this.parent.inputHeight) {
// If we are initially offscreen when created we wont have received a resize event
// Calculate it here instead
computeSize(node.size);
}
const visible = app.canvas.ds.scale > 0.5;
const t = ctx.getTransform();
const margin = 10;
Object.assign(this.inputEl.style, {
left: `${t.a * margin + t.e}px`,
top: `${t.d * (y + widgetHeight - margin) + t.f}px`,
top: `${t.d * (y + widgetHeight - margin - 3) + t.f}px`,
width: `${(widgetWidth - margin * 2 - 3) * t.a}px`,
height: `${(this.parent.size[1] - (y + widgetHeight) - 3) * t.d}px`,
height: `${(this.parent.inputHeight - margin * 2 - 4) * t.d}px`,
position: "absolute",
zIndex: 1,
fontSize: `${t.d * 10.0}px`,
@ -72,7 +112,8 @@ function addMultilineWidget(node, name, defaultVal, app) {
};
widget.inputEl = document.createElement("textarea");
widget.inputEl.className = "comfy-multiline-input";
widget.inputEl.value = defaultVal;
widget.inputEl.value = opts.defaultVal;
widget.inputEl.placeholder = opts.placeholder || "";
document.addEventListener("mousedown", function (event) {
if (!widget.inputEl.contains(event.target)) {
widget.inputEl.blur();
@ -108,6 +149,20 @@ function addMultilineWidget(node, name, defaultVal, app) {
}
};
if (!(MultilineSymbol in node)) {
node[MultilineSymbol] = true;
const onResize = node.onResize;
node.onResize = function (size) {
computeSize(size);
// Call original resizer handler
if (onResize) {
onResize.apply(this, arguments);
}
};
}
return { minWidth: 400, minHeight: 200, widget };
}
@ -120,6 +175,7 @@ export const ComfyWidgets = {
},
INT(node, inputName, inputData) {
const { val, config } = getNumberDefaults(inputData, 1);
Object.assign(config, { precision: 0 });
return {
widget: node.addWidget(
"number",
@ -133,13 +189,12 @@ export const ComfyWidgets = {
),
};
},
REFRESH:refreshWidget,
STRING(node, inputName, inputData, app) {
const defaultVal = inputData[1].default || "";
const multiline = !!inputData[1].multiline;
if (multiline) {
return addMultilineWidget(node, inputName, defaultVal, app);
return addMultilineWidget(node, inputName, { defaultVal, ...inputData[1] }, app);
} else {
return { widget: node.addWidget("text", inputName, defaultVal, () => {}, {}) };
}
@ -151,7 +206,7 @@ export const ComfyWidgets = {
function showImage(name) {
// Position the image somewhere sensible
if (!node.imageOffset) {
node.imageOffset = uploadWidget.last_y ? uploadWidget.last_y + 50 : 100;
node.imageOffset = uploadWidget.last_y ? uploadWidget.last_y + 25 : 75;
}
const img = new Image();