Merge remote-tracking branch 'upstream/master'

This commit is contained in:
ssit 2023-06-18 14:14:13 -04:00
commit 27cfe3d0d3
12 changed files with 397 additions and 237 deletions

View File

@ -87,13 +87,13 @@ Put your SD checkpoints (the huge ckpt/safetensors files) in: models/checkpoints
Put your VAE in: models/vae Put your VAE in: models/vae
At the time of writing this pytorch has issues with python versions higher than 3.10 so make sure your python/pip versions are 3.10.
### AMD GPUs (Linux only) ### AMD GPUs (Linux only)
AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version: AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version:
```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/rocm5.4.2``` ```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/rocm5.4.2```
This is the command to install the nightly with ROCm 5.5 that supports the 7000 series and might have some performance improvements:
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm5.5 -r requirements.txt```
### NVIDIA ### NVIDIA
@ -178,16 +178,6 @@ To use a textual inversion concepts/embeddings in a text prompt put them in the
```embedding:embedding_filename.pt``` ```embedding:embedding_filename.pt```
### Fedora
To get python 3.10 on fedora:
```dnf install python3.10```
Then you can:
```python3.10 -m ensurepip```
This will let you use: pip3.10 to install all the dependencies.
## How to increase generation speed? ## How to increase generation speed?

View File

@ -134,7 +134,7 @@ class CompVisDenoiser(DiscreteEpsDDPMDenoiser):
"""A wrapper for CompVis diffusion models.""" """A wrapper for CompVis diffusion models."""
def __init__(self, model, quantize=False, device='cpu'): def __init__(self, model, quantize=False, device='cpu'):
super().__init__(model, model.alphas_cumprod, quantize=quantize) super().__init__(model, model.alphas_cumprod.float(), quantize=quantize)
def get_eps(self, *args, **kwargs): def get_eps(self, *args, **kwargs):
return self.inner_model.apply_model(*args, **kwargs) return self.inner_model.apply_model(*args, **kwargs)
@ -173,7 +173,7 @@ class CompVisVDenoiser(DiscreteVDDPMDenoiser):
"""A wrapper for CompVis diffusion models that output v.""" """A wrapper for CompVis diffusion models that output v."""
def __init__(self, model, quantize=False, device='cpu'): def __init__(self, model, quantize=False, device='cpu'):
super().__init__(model, model.alphas_cumprod, quantize=quantize) super().__init__(model, model.alphas_cumprod.float(), quantize=quantize)
def get_v(self, x, t, cond, **kwargs): def get_v(self, x, t, cond, **kwargs):
return self.inner_model.apply_model(x, t, cond) return self.inner_model.apply_model(x, t, cond)

View File

@ -284,7 +284,7 @@ class DDIMSampler(object):
model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond) model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond)
if self.model.parameterization == "v": if self.model.parameterization == "v":
e_t = self.model.predict_eps_from_z_and_v(x, t, model_output) e_t = extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * model_output + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x
else: else:
e_t = model_output e_t = model_output
@ -306,7 +306,7 @@ class DDIMSampler(object):
if self.model.parameterization != "v": if self.model.parameterization != "v":
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
else: else:
pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output) pred_x0 = extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * x - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * model_output
if quantize_denoised: if quantize_denoised:
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)

View File

@ -51,9 +51,9 @@ def init_(tensor):
# feedforward # feedforward
class GEGLU(nn.Module): class GEGLU(nn.Module):
def __init__(self, dim_in, dim_out): def __init__(self, dim_in, dim_out, dtype=None):
super().__init__() super().__init__()
self.proj = comfy.ops.Linear(dim_in, dim_out * 2) self.proj = comfy.ops.Linear(dim_in, dim_out * 2, dtype=dtype)
def forward(self, x): def forward(self, x):
x, gate = self.proj(x).chunk(2, dim=-1) x, gate = self.proj(x).chunk(2, dim=-1)
@ -68,7 +68,7 @@ class FeedForward(nn.Module):
project_in = nn.Sequential( project_in = nn.Sequential(
comfy.ops.Linear(dim, inner_dim, dtype=dtype), comfy.ops.Linear(dim, inner_dim, dtype=dtype),
nn.GELU() nn.GELU()
) if not glu else GEGLU(dim, inner_dim) ) if not glu else GEGLU(dim, inner_dim, dtype=dtype)
self.net = nn.Sequential( self.net = nn.Sequential(
project_in, project_in,
@ -89,8 +89,8 @@ def zero_module(module):
return module return module
def Normalize(in_channels): def Normalize(in_channels, dtype=None):
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype)
class SpatialSelfAttention(nn.Module): class SpatialSelfAttention(nn.Module):
@ -594,7 +594,7 @@ class SpatialTransformer(nn.Module):
context_dim = [context_dim] context_dim = [context_dim]
self.in_channels = in_channels self.in_channels = in_channels
inner_dim = n_heads * d_head inner_dim = n_heads * d_head
self.norm = Normalize(in_channels) self.norm = Normalize(in_channels, dtype=dtype)
if not use_linear: if not use_linear:
self.proj_in = nn.Conv2d(in_channels, self.proj_in = nn.Conv2d(in_channels,
inner_dim, inner_dim,

View File

@ -111,14 +111,14 @@ class Upsample(nn.Module):
upsampling occurs in the inner-two dimensions. upsampling occurs in the inner-two dimensions.
""" """
def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None):
super().__init__() super().__init__()
self.channels = channels self.channels = channels
self.out_channels = out_channels or channels self.out_channels = out_channels or channels
self.use_conv = use_conv self.use_conv = use_conv
self.dims = dims self.dims = dims
if use_conv: if use_conv:
self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding) self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding, dtype=dtype)
def forward(self, x, output_shape=None): def forward(self, x, output_shape=None):
assert x.shape[1] == self.channels assert x.shape[1] == self.channels
@ -160,7 +160,7 @@ class Downsample(nn.Module):
downsampling occurs in the inner-two dimensions. downsampling occurs in the inner-two dimensions.
""" """
def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1): def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None):
super().__init__() super().__init__()
self.channels = channels self.channels = channels
self.out_channels = out_channels or channels self.out_channels = out_channels or channels
@ -169,7 +169,7 @@ class Downsample(nn.Module):
stride = 2 if dims != 3 else (1, 2, 2) stride = 2 if dims != 3 else (1, 2, 2)
if use_conv: if use_conv:
self.op = conv_nd( self.op = conv_nd(
dims, self.channels, self.out_channels, 3, stride=stride, padding=padding dims, self.channels, self.out_channels, 3, stride=stride, padding=padding, dtype=dtype
) )
else: else:
assert self.channels == self.out_channels assert self.channels == self.out_channels
@ -220,7 +220,7 @@ class ResBlock(TimestepBlock):
self.use_scale_shift_norm = use_scale_shift_norm self.use_scale_shift_norm = use_scale_shift_norm
self.in_layers = nn.Sequential( self.in_layers = nn.Sequential(
normalization(channels), normalization(channels, dtype=dtype),
nn.SiLU(), nn.SiLU(),
conv_nd(dims, channels, self.out_channels, 3, padding=1, dtype=dtype), conv_nd(dims, channels, self.out_channels, 3, padding=1, dtype=dtype),
) )
@ -228,11 +228,11 @@ class ResBlock(TimestepBlock):
self.updown = up or down self.updown = up or down
if up: if up:
self.h_upd = Upsample(channels, False, dims) self.h_upd = Upsample(channels, False, dims, dtype=dtype)
self.x_upd = Upsample(channels, False, dims) self.x_upd = Upsample(channels, False, dims, dtype=dtype)
elif down: elif down:
self.h_upd = Downsample(channels, False, dims) self.h_upd = Downsample(channels, False, dims, dtype=dtype)
self.x_upd = Downsample(channels, False, dims) self.x_upd = Downsample(channels, False, dims, dtype=dtype)
else: else:
self.h_upd = self.x_upd = nn.Identity() self.h_upd = self.x_upd = nn.Identity()
@ -240,11 +240,11 @@ class ResBlock(TimestepBlock):
nn.SiLU(), nn.SiLU(),
linear( linear(
emb_channels, emb_channels,
2 * self.out_channels if use_scale_shift_norm else self.out_channels, 2 * self.out_channels if use_scale_shift_norm else self.out_channels, dtype=dtype
), ),
) )
self.out_layers = nn.Sequential( self.out_layers = nn.Sequential(
normalization(self.out_channels), normalization(self.out_channels, dtype=dtype),
nn.SiLU(), nn.SiLU(),
nn.Dropout(p=dropout), nn.Dropout(p=dropout),
zero_module( zero_module(
@ -604,6 +604,7 @@ class UNetModel(nn.Module):
dims=dims, dims=dims,
use_checkpoint=use_checkpoint, use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm, use_scale_shift_norm=use_scale_shift_norm,
dtype=self.dtype
) )
] ]
ch = mult * model_channels ch = mult * model_channels
@ -651,10 +652,11 @@ class UNetModel(nn.Module):
use_checkpoint=use_checkpoint, use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm, use_scale_shift_norm=use_scale_shift_norm,
down=True, down=True,
dtype=self.dtype
) )
if resblock_updown if resblock_updown
else Downsample( else Downsample(
ch, conv_resample, dims=dims, out_channels=out_ch ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype
) )
) )
) )
@ -679,6 +681,7 @@ class UNetModel(nn.Module):
dims=dims, dims=dims,
use_checkpoint=use_checkpoint, use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm, use_scale_shift_norm=use_scale_shift_norm,
dtype=self.dtype
), ),
AttentionBlock( AttentionBlock(
ch, ch,
@ -698,6 +701,7 @@ class UNetModel(nn.Module):
dims=dims, dims=dims,
use_checkpoint=use_checkpoint, use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm, use_scale_shift_norm=use_scale_shift_norm,
dtype=self.dtype
), ),
) )
self._feature_size += ch self._feature_size += ch
@ -715,6 +719,7 @@ class UNetModel(nn.Module):
dims=dims, dims=dims,
use_checkpoint=use_checkpoint, use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm, use_scale_shift_norm=use_scale_shift_norm,
dtype=self.dtype
) )
] ]
ch = model_channels * mult ch = model_channels * mult
@ -758,18 +763,19 @@ class UNetModel(nn.Module):
use_checkpoint=use_checkpoint, use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm, use_scale_shift_norm=use_scale_shift_norm,
up=True, up=True,
dtype=self.dtype
) )
if resblock_updown if resblock_updown
else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype)
) )
ds //= 2 ds //= 2
self.output_blocks.append(TimestepEmbedSequential(*layers)) self.output_blocks.append(TimestepEmbedSequential(*layers))
self._feature_size += ch self._feature_size += ch
self.out = nn.Sequential( self.out = nn.Sequential(
normalization(ch), normalization(ch, dtype=self.dtype),
nn.SiLU(), nn.SiLU(),
zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1, dtype=self.dtype)),
) )
if self.predict_codebook_ids: if self.predict_codebook_ids:
self.id_predictor = nn.Sequential( self.id_predictor = nn.Sequential(

View File

@ -206,13 +206,13 @@ def mean_flat(tensor):
return tensor.mean(dim=list(range(1, len(tensor.shape)))) return tensor.mean(dim=list(range(1, len(tensor.shape))))
def normalization(channels): def normalization(channels, dtype=None):
""" """
Make a standard normalization layer. Make a standard normalization layer.
:param channels: number of input channels. :param channels: number of input channels.
:return: an nn.Module for normalization. :return: an nn.Module for normalization.
""" """
return GroupNorm32(32, channels) return GroupNorm32(32, channels, dtype=dtype)
# PyTorch 1.7 has SiLU, but we support PyTorch 1.5. # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.

View File

@ -1159,9 +1159,6 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
else: else:
model = model_base.BaseModel(unet_config, v_prediction=v_prediction) model = model_base.BaseModel(unet_config, v_prediction=v_prediction)
if fp16:
model = model.half()
model = load_model_weights(model, sd, verbose=False, load_state_dict_to=load_state_dict_to) model = load_model_weights(model, sd, verbose=False, load_state_dict_to=load_state_dict_to)
return (ModelPatcher(model), clip, vae, clipvision) return (ModelPatcher(model), clip, vae, clipvision)

View File

@ -756,7 +756,7 @@ class RepeatLatentBatch:
return (s,) return (s,)
class LatentUpscale: class LatentUpscale:
upscale_methods = ["nearest-exact", "bilinear", "area", "bislerp"] upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "bislerp"]
crop_methods = ["disabled", "center"] crop_methods = ["disabled", "center"]
@classmethod @classmethod
@ -776,7 +776,7 @@ class LatentUpscale:
return (s,) return (s,)
class LatentUpscaleBy: class LatentUpscaleBy:
upscale_methods = ["nearest-exact", "bilinear", "area", "bislerp"] upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "bislerp"]
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
@ -1172,7 +1172,7 @@ class LoadImageMask:
return True return True
class ImageScale: class ImageScale:
upscale_methods = ["nearest-exact", "bilinear", "area"] upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic"]
crop_methods = ["disabled", "center"] crop_methods = ["disabled", "center"]
@classmethod @classmethod
@ -1193,7 +1193,7 @@ class ImageScale:
return (s,) return (s,)
class ImageScaleBy: class ImageScaleBy:
upscale_methods = ["nearest-exact", "bilinear", "area"] upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic"]
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):

View File

@ -56,7 +56,9 @@ const colorPalettes = {
"descrip-text": "#999", "descrip-text": "#999",
"drag-text": "#ccc", "drag-text": "#ccc",
"error-text": "#ff4444", "error-text": "#ff4444",
"border-color": "#4e4e4e" "border-color": "#4e4e4e",
"tr-even-bg-color": "#222",
"tr-odd-bg-color": "#353535",
} }
}, },
}, },
@ -111,7 +113,9 @@ const colorPalettes = {
"descrip-text": "#444", "descrip-text": "#444",
"drag-text": "#555", "drag-text": "#555",
"error-text": "#F44336", "error-text": "#F44336",
"border-color": "#888" "border-color": "#888",
"tr-even-bg-color": "#f9f9f9",
"tr-odd-bg-color": "#fff",
} }
}, },
}, },
@ -165,7 +169,9 @@ const colorPalettes = {
"descrip-text": "#586e75", // Base01 "descrip-text": "#586e75", // Base01
"drag-text": "#839496", // Base0 "drag-text": "#839496", // Base0
"error-text": "#dc322f", // Solarized Red "error-text": "#dc322f", // Solarized Red
"border-color": "#657b83" // Base00 "border-color": "#657b83", // Base00
"tr-even-bg-color": "#002b36",
"tr-odd-bg-color": "#073642",
} }
}, },
} }
@ -194,7 +200,7 @@ app.registerExtension({
const nodeData = defs[nodeId]; const nodeData = defs[nodeId];
var inputs = nodeData["input"]["required"]; var inputs = nodeData["input"]["required"];
if (nodeData["input"]["optional"] != undefined) { if (nodeData["input"]["optional"] !== undefined) {
inputs = Object.assign({}, nodeData["input"]["required"], nodeData["input"]["optional"]) inputs = Object.assign({}, nodeData["input"]["required"], nodeData["input"]["optional"])
} }
@ -214,7 +220,7 @@ app.registerExtension({
} }
return types; return types;
}; }
function completeColorPalette(colorPalette) { function completeColorPalette(colorPalette) {
var types = getSlotTypes(); var types = getSlotTypes();
@ -228,7 +234,7 @@ app.registerExtension({
colorPalette.colors.node_slot = sortObjectKeys(colorPalette.colors.node_slot); colorPalette.colors.node_slot = sortObjectKeys(colorPalette.colors.node_slot);
return colorPalette; return colorPalette;
}; }
const getColorPaletteTemplate = async () => { const getColorPaletteTemplate = async () => {
let colorPalette = { let colorPalette = {
@ -267,31 +273,31 @@ app.registerExtension({
const addCustomColorPalette = async (colorPalette) => { const addCustomColorPalette = async (colorPalette) => {
if (typeof (colorPalette) !== "object") { if (typeof (colorPalette) !== "object") {
app.ui.dialog.show("Invalid color palette"); alert("Invalid color palette.");
return; return;
} }
if (!colorPalette.id) { if (!colorPalette.id) {
app.ui.dialog.show("Color palette missing id"); alert("Color palette missing id.");
return; return;
} }
if (!colorPalette.name) { if (!colorPalette.name) {
app.ui.dialog.show("Color palette missing name"); alert("Color palette missing name.");
return; return;
} }
if (!colorPalette.colors) { if (!colorPalette.colors) {
app.ui.dialog.show("Color palette missing colors"); alert("Color palette missing colors.");
return; return;
} }
if (colorPalette.colors.node_slot && typeof (colorPalette.colors.node_slot) !== "object") { if (colorPalette.colors.node_slot && typeof (colorPalette.colors.node_slot) !== "object") {
app.ui.dialog.show("Invalid color palette colors.node_slot"); alert("Invalid color palette colors.node_slot.");
return; return;
} }
let customColorPalettes = getCustomColorPalettes(); const customColorPalettes = getCustomColorPalettes();
customColorPalettes[colorPalette.id] = colorPalette; customColorPalettes[colorPalette.id] = colorPalette;
setCustomColorPalettes(customColorPalettes); setCustomColorPalettes(customColorPalettes);
@ -312,7 +318,7 @@ app.registerExtension({
}; };
const deleteCustomColorPalette = async (colorPaletteId) => { const deleteCustomColorPalette = async (colorPaletteId) => {
let customColorPalettes = getCustomColorPalettes(); const customColorPalettes = getCustomColorPalettes();
delete customColorPalettes[colorPaletteId]; delete customColorPalettes[colorPaletteId];
setCustomColorPalettes(customColorPalettes); setCustomColorPalettes(customColorPalettes);
@ -387,8 +393,7 @@ app.registerExtension({
style: {display: "none"}, style: {display: "none"},
parent: document.body, parent: document.body,
onchange: () => { onchange: () => {
let file = fileInput.files[0]; const file = fileInput.files[0];
if (file.type === "application/json" || file.name.endsWith(".json")) { if (file.type === "application/json" || file.name.endsWith(".json")) {
const reader = new FileReader(); const reader = new FileReader();
reader.onload = async () => { reader.onload = async () => {
@ -403,104 +408,116 @@ app.registerExtension({
id, id,
name: "Color Palette", name: "Color Palette",
type: (name, setter, value) => { type: (name, setter, value) => {
let options = []; const options = [
...Object.values(colorPalettes).map(c=> $el("option", {
textContent: c.name,
value: c.id,
selected: c.id === value
})),
...Object.values(getCustomColorPalettes()).map(c=>$el("option", {
textContent: `${c.name} (custom)`,
value: `custom_${c.id}`,
selected: `custom_${c.id}` === value
})) ,
];
for (const c in colorPalettes) { els.select = $el("select", {
const colorPalette = colorPalettes[c]; style: {
options.push($el("option", { marginBottom: "0.15rem",
textContent: colorPalette.name, width: "100%",
value: colorPalette.id, },
selected: colorPalette.id === value onchange: (e) => {
})); setter(e.target.value);
} }
}, options)
let customColorPalettes = getCustomColorPalettes(); return $el("tr", [
for (const c in customColorPalettes) { $el("td", [
const colorPalette = customColorPalettes[c]; $el("label", {
options.push($el("option", { for: id.replaceAll(".", "-"),
textContent: colorPalette.name + " (custom)", textContent: "Color palette:",
value: "custom_" + colorPalette.id, }),
selected: "custom_" + colorPalette.id === value
}));
}
return $el("div", [
$el("label", {textContent: name || id}, [
els.select = $el("select", {
onchange: (e) => {
setter(e.target.value);
}
}, options)
]), ]),
$el("input", { $el("td", [
type: "button", els.select,
value: "Export", $el("div", {
onclick: async () => { style: {
const colorPaletteId = app.ui.settings.getSettingValue(id, defaultColorPaletteId); display: "grid",
const colorPalette = await completeColorPalette(getColorPalette(colorPaletteId)); gap: "4px",
const json = JSON.stringify(colorPalette, null, 2); // convert the data to a JSON string gridAutoFlow: "column",
const blob = new Blob([json], {type: "application/json"}); },
const url = URL.createObjectURL(blob); }, [
const a = $el("a", { $el("input", {
href: url, type: "button",
download: colorPaletteId + ".json", value: "Export",
style: {display: "none"}, onclick: async () => {
parent: document.body, const colorPaletteId = app.ui.settings.getSettingValue(id, defaultColorPaletteId);
}); const colorPalette = await completeColorPalette(getColorPalette(colorPaletteId));
a.click(); const json = JSON.stringify(colorPalette, null, 2); // convert the data to a JSON string
setTimeout(function () { const blob = new Blob([json], {type: "application/json"});
a.remove(); const url = URL.createObjectURL(blob);
window.URL.revokeObjectURL(url); const a = $el("a", {
}, 0); href: url,
}, download: colorPaletteId + ".json",
}), style: {display: "none"},
$el("input", { parent: document.body,
type: "button", });
value: "Import", a.click();
onclick: () => { setTimeout(function () {
fileInput.click(); a.remove();
} window.URL.revokeObjectURL(url);
}), }, 0);
$el("input", { },
type: "button", }),
value: "Template", $el("input", {
onclick: async () => { type: "button",
const colorPalette = await getColorPaletteTemplate(); value: "Import",
const json = JSON.stringify(colorPalette, null, 2); // convert the data to a JSON string onclick: () => {
const blob = new Blob([json], {type: "application/json"}); fileInput.click();
const url = URL.createObjectURL(blob); }
const a = $el("a", { }),
href: url, $el("input", {
download: "color_palette.json", type: "button",
style: {display: "none"}, value: "Template",
parent: document.body, onclick: async () => {
}); const colorPalette = await getColorPaletteTemplate();
a.click(); const json = JSON.stringify(colorPalette, null, 2); // convert the data to a JSON string
setTimeout(function () { const blob = new Blob([json], {type: "application/json"});
a.remove(); const url = URL.createObjectURL(blob);
window.URL.revokeObjectURL(url); const a = $el("a", {
}, 0); href: url,
} download: "color_palette.json",
}), style: {display: "none"},
$el("input", { parent: document.body,
type: "button", });
value: "Delete", a.click();
onclick: async () => { setTimeout(function () {
let colorPaletteId = app.ui.settings.getSettingValue(id, defaultColorPaletteId); a.remove();
window.URL.revokeObjectURL(url);
}, 0);
}
}),
$el("input", {
type: "button",
value: "Delete",
onclick: async () => {
let colorPaletteId = app.ui.settings.getSettingValue(id, defaultColorPaletteId);
if (colorPalettes[colorPaletteId]) { if (colorPalettes[colorPaletteId]) {
app.ui.dialog.show("You cannot delete built-in color palette"); alert("You cannot delete a built-in color palette.");
return; return;
} }
if (colorPaletteId.startsWith("custom_")) { if (colorPaletteId.startsWith("custom_")) {
colorPaletteId = colorPaletteId.substr(7); colorPaletteId = colorPaletteId.substr(7);
} }
await deleteCustomColorPalette(colorPaletteId); await deleteCustomColorPalette(colorPaletteId);
} }
}), }),
]); ]),
]),
])
}, },
defaultValue: defaultColorPaletteId, defaultValue: defaultColorPaletteId,
async onChange(value) { async onChange(value) {

View File

@ -10,7 +10,7 @@ app.registerExtension({
LiteGraph.middle_click_slot_add_default_node = true; LiteGraph.middle_click_slot_add_default_node = true;
this.suggestionsNumber = app.ui.settings.addSetting({ this.suggestionsNumber = app.ui.settings.addSetting({
id: "Comfy.NodeSuggestions.number", id: "Comfy.NodeSuggestions.number",
name: "number of nodes suggestions", name: "Number of nodes suggestions",
type: "slider", type: "slider",
attrs: { attrs: {
min: 1, min: 1,

View File

@ -1,19 +1,26 @@
import { api } from "./api.js"; import {api} from "./api.js";
export function $el(tag, propsOrChildren, children) { export function $el(tag, propsOrChildren, children) {
const split = tag.split("."); const split = tag.split(".");
const element = document.createElement(split.shift()); const element = document.createElement(split.shift());
element.classList.add(...split); if (split.length > 0) {
element.classList.add(...split);
}
if (propsOrChildren) { if (propsOrChildren) {
if (Array.isArray(propsOrChildren)) { if (Array.isArray(propsOrChildren)) {
element.append(...propsOrChildren); element.append(...propsOrChildren);
} else { } else {
const { parent, $: cb, dataset, style } = propsOrChildren; const {parent, $: cb, dataset, style} = propsOrChildren;
delete propsOrChildren.parent; delete propsOrChildren.parent;
delete propsOrChildren.$; delete propsOrChildren.$;
delete propsOrChildren.dataset; delete propsOrChildren.dataset;
delete propsOrChildren.style; delete propsOrChildren.style;
if (Object.hasOwn(propsOrChildren, "for")) {
element.setAttribute("for", propsOrChildren.for)
}
if (style) { if (style) {
Object.assign(element.style, style); Object.assign(element.style, style);
} }
@ -119,6 +126,7 @@ function dragElement(dragEl, settings) {
savePos = value; savePos = value;
}, },
}); });
function dragMouseDown(e) { function dragMouseDown(e) {
e = e || window.event; e = e || window.event;
e.preventDefault(); e.preventDefault();
@ -161,8 +169,8 @@ function dragElement(dragEl, settings) {
export class ComfyDialog { export class ComfyDialog {
constructor() { constructor() {
this.element = $el("div.comfy-modal", { parent: document.body }, [ this.element = $el("div.comfy-modal", {parent: document.body}, [
$el("div.comfy-modal-content", [$el("p", { $: (p) => (this.textElement = p) }), ...this.createButtons()]), $el("div.comfy-modal-content", [$el("p", {$: (p) => (this.textElement = p)}), ...this.createButtons()]),
]); ]);
} }
@ -193,7 +201,22 @@ export class ComfyDialog {
class ComfySettingsDialog extends ComfyDialog { class ComfySettingsDialog extends ComfyDialog {
constructor() { constructor() {
super(); super();
this.element.classList.add("comfy-settings"); this.element = $el("dialog", {
id: "comfy-settings-dialog",
parent: document.body,
}, [
$el("table.comfy-modal-content.comfy-table", [
$el("caption", {textContent: "Settings"}),
$el("tbody", {$: (tbody) => (this.textElement = tbody)}),
$el("button", {
type: "button",
textContent: "Close",
onclick: () => {
this.element.close();
},
}),
]),
]);
this.settings = []; this.settings = [];
} }
@ -208,15 +231,16 @@ class ComfySettingsDialog extends ComfyDialog {
localStorage[settingId] = JSON.stringify(value); localStorage[settingId] = JSON.stringify(value);
} }
addSetting({ id, name, type, defaultValue, onChange, attrs = {}, tooltip = "", }) { addSetting({id, name, type, defaultValue, onChange, attrs = {}, tooltip = "",}) {
if (!id) { if (!id) {
throw new Error("Settings must have an ID"); throw new Error("Settings must have an ID");
} }
if (this.settings.find((s) => s.id === id)) { if (this.settings.find((s) => s.id === id)) {
throw new Error("Setting IDs must be unique"); throw new Error(`Setting ${id} of type ${type} must have a unique ID.`);
} }
const settingId = "Comfy.Settings." + id; const settingId = `Comfy.Settings.${id}`;
const v = localStorage[settingId]; const v = localStorage[settingId];
let value = v == null ? defaultValue : JSON.parse(v); let value = v == null ? defaultValue : JSON.parse(v);
@ -234,34 +258,50 @@ class ComfySettingsDialog extends ComfyDialog {
localStorage[settingId] = JSON.stringify(v); localStorage[settingId] = JSON.stringify(v);
value = v; value = v;
}; };
value = this.getSettingValue(id, defaultValue);
let element; let element;
value = this.getSettingValue(id, defaultValue); const htmlID = id.replaceAll(".", "-");
const labelCell = $el("td", [
$el("label", {
for: htmlID,
classList: [tooltip !== "" ? "comfy-tooltip-indicator" : ""],
textContent: name.endsWith(":") ? name : `${name}:`,
})
]);
if (typeof type === "function") { if (typeof type === "function") {
element = type(name, setter, value, attrs); element = type(name, setter, value, attrs);
} else { } else {
switch (type) { switch (type) {
case "boolean": case "boolean":
element = $el("div", [ element = $el("tr", [
$el("label", { textContent: name || id }, [ labelCell,
$el("td", [
$el("input", { $el("input", {
id: htmlID,
type: "checkbox", type: "checkbox",
checked: !!value, checked: value,
oninput: (e) => { onchange: (event) => {
setter(e.target.checked); const isChecked = event.target.checked;
if (onChange !== undefined) {
onChange(isChecked)
}
this.setSettingValue(id, isChecked);
}, },
...attrs
}), }),
]), ]),
]); ])
break; break;
case "number": case "number":
element = $el("div", [ element = $el("tr", [
$el("label", { textContent: name || id }, [ labelCell,
$el("td", [
$el("input", { $el("input", {
type, type,
value, value,
id: htmlID,
oninput: (e) => { oninput: (e) => {
setter(e.target.value); setter(e.target.value);
}, },
@ -271,46 +311,62 @@ class ComfySettingsDialog extends ComfyDialog {
]); ]);
break; break;
case "slider": case "slider":
element = $el("div", [ element = $el("tr", [
$el("label", { textContent: name }, [ labelCell,
$el("input", { $el("td", [
type: "range", $el("div", {
value, style: {
oninput: (e) => { display: "grid",
setter(e.target.value); gridAutoFlow: "column",
e.target.nextElementSibling.value = e.target.value;
}, },
...attrs }, [
}), $el("input", {
$el("input", { ...attrs,
type: "number", value,
value, type: "range",
oninput: (e) => { oninput: (e) => {
setter(e.target.value); setter(e.target.value);
e.target.previousElementSibling.value = e.target.value; e.target.nextElementSibling.value = e.target.value;
}, },
...attrs }),
}), $el("input", {
...attrs,
value,
id: htmlID,
type: "number",
style: {maxWidth: "4rem"},
oninput: (e) => {
setter(e.target.value);
e.target.previousElementSibling.value = e.target.value;
},
}),
]),
]), ]),
]); ]);
break; break;
case "text":
default: default:
console.warn("Unsupported setting type, defaulting to text"); if (type !== "text") {
element = $el("div", [ console.warn(`Unsupported setting type '${type}, defaulting to text`);
$el("label", { textContent: name || id }, [ }
element = $el("tr", [
labelCell,
$el("td", [
$el("input", { $el("input", {
value, value,
id: htmlID,
oninput: (e) => { oninput: (e) => {
setter(e.target.value); setter(e.target.value);
}, },
...attrs ...attrs,
}), }),
]), ]),
]); ]);
break; break;
} }
} }
if(tooltip) { if (tooltip) {
element.title = tooltip; element.title = tooltip;
} }
@ -330,13 +386,16 @@ class ComfySettingsDialog extends ComfyDialog {
} }
show() { show() {
super.show(); this.textElement.replaceChildren(
Object.assign(this.textElement.style, { $el("tr", {
display: "flex", style: {display: "none"},
flexDirection: "column", }, [
gap: "10px" $el("th"),
}); $el("th", {style: {width: "33%"}})
this.textElement.replaceChildren(...this.settings.map((s) => s.render())); ]),
...this.settings.map((s) => s.render()),
)
this.element.showModal();
} }
} }
@ -369,7 +428,7 @@ class ComfyList {
name: "Delete", name: "Delete",
cb: () => api.deleteItem(this.#type, item.prompt[1]), cb: () => api.deleteItem(this.#type, item.prompt[1]),
}; };
return $el("div", { textContent: item.prompt[0] + ": " }, [ return $el("div", {textContent: item.prompt[0] + ": "}, [
$el("button", { $el("button", {
textContent: "Load", textContent: "Load",
onclick: () => { onclick: () => {
@ -398,7 +457,7 @@ class ComfyList {
await this.load(); await this.load();
}, },
}), }),
$el("button", { textContent: "Refresh", onclick: () => this.load() }), $el("button", {textContent: "Refresh", onclick: () => this.load()}),
]) ])
); );
} }
@ -475,8 +534,8 @@ export class ComfyUI {
*/ */
const previewImage = this.settings.addSetting({ const previewImage = this.settings.addSetting({
id: "Comfy.PreviewFormat", id: "Comfy.PreviewFormat",
name: "When displaying a preview in the image widget, convert it to a lightweight image. (webp, jpeg, webp;50, ...)", name: "When displaying a preview in the image widget, convert it to a lightweight image, e.g. webp, jpeg, webp;50, etc.",
type: "string", type: "text",
defaultValue: "", defaultValue: "",
}); });
@ -484,18 +543,25 @@ export class ComfyUI {
id: "comfy-file-input", id: "comfy-file-input",
type: "file", type: "file",
accept: ".json,image/png,.latent", accept: ".json,image/png,.latent",
style: { display: "none" }, style: {display: "none"},
parent: document.body, parent: document.body,
onchange: () => { onchange: () => {
app.handleFile(fileInput.files[0]); app.handleFile(fileInput.files[0]);
}, },
}); });
this.menuContainer = $el("div.comfy-menu", { parent: document.body }, [ this.menuContainer = $el("div.comfy-menu", {parent: document.body}, [
$el("div.drag-handle", { style: { overflow: "hidden", position: "relative", width: "100%", cursor: "default" } }, [ $el("div.drag-handle", {
style: {
overflow: "hidden",
position: "relative",
width: "100%",
cursor: "default"
}
}, [
$el("span.drag-handle"), $el("span.drag-handle"),
$el("span", { $: (q) => (this.queueSize = q) }), $el("span", {$: (q) => (this.queueSize = q)}),
$el("button.comfy-settings-btn", { textContent: "⚙️", onclick: () => this.settings.show() }), $el("button.comfy-settings-btn", {textContent: "⚙️", onclick: () => this.settings.show()}),
]), ]),
$el("button.comfy-queue-btn", { $el("button.comfy-queue-btn", {
id: "queue-button", id: "queue-button",
@ -503,7 +569,7 @@ export class ComfyUI {
onclick: () => app.queuePrompt(0, this.batchCount), onclick: () => app.queuePrompt(0, this.batchCount),
}), }),
$el("div", {}, [ $el("div", {}, [
$el("label", { innerHTML: "Extra options" }, [ $el("label", {innerHTML: "Extra options"}, [
$el("input", { $el("input", {
type: "checkbox", type: "checkbox",
onchange: (i) => { onchange: (i) => {
@ -514,14 +580,14 @@ export class ComfyUI {
}), }),
]), ]),
]), ]),
$el("div", { id: "extraOptions", style: { width: "100%", display: "none" } }, [ $el("div", {id: "extraOptions", style: {width: "100%", display: "none"}}, [
$el("label", { innerHTML: "Batch count" }, [ $el("label", {innerHTML: "Batch count"}, [
$el("input", { $el("input", {
id: "batchCountInputNumber", id: "batchCountInputNumber",
type: "number", type: "number",
value: this.batchCount, value: this.batchCount,
min: "1", min: "1",
style: { width: "35%", "margin-left": "0.4em" }, style: {width: "35%", "margin-left": "0.4em"},
oninput: (i) => { oninput: (i) => {
this.batchCount = i.target.value; this.batchCount = i.target.value;
document.getElementById("batchCountInputRange").value = this.batchCount; document.getElementById("batchCountInputRange").value = this.batchCount;
@ -547,7 +613,11 @@ export class ComfyUI {
]), ]),
]), ]),
$el("div.comfy-menu-btns", [ $el("div.comfy-menu-btns", [
$el("button", { id: "queue-front-button", textContent: "Queue Front", onclick: () => app.queuePrompt(-1, this.batchCount) }), $el("button", {
id: "queue-front-button",
textContent: "Queue Front",
onclick: () => app.queuePrompt(-1, this.batchCount)
}),
$el("button", { $el("button", {
$: (b) => (this.queue.button = b), $: (b) => (this.queue.button = b),
id: "comfy-view-queue-button", id: "comfy-view-queue-button",
@ -582,12 +652,12 @@ export class ComfyUI {
} }
} }
const json = JSON.stringify(app.graph.serialize(), null, 2); // convert the data to a JSON string const json = JSON.stringify(app.graph.serialize(), null, 2); // convert the data to a JSON string
const blob = new Blob([json], { type: "application/json" }); const blob = new Blob([json], {type: "application/json"});
const url = URL.createObjectURL(blob); const url = URL.createObjectURL(blob);
const a = $el("a", { const a = $el("a", {
href: url, href: url,
download: filename, download: filename,
style: { display: "none" }, style: {display: "none"},
parent: document.body, parent: document.body,
}); });
a.click(); a.click();
@ -597,25 +667,33 @@ export class ComfyUI {
}, 0); }, 0);
}, },
}), }),
$el("button", { id: "comfy-load-button", textContent: "Load", onclick: () => fileInput.click() }), $el("button", {id: "comfy-load-button", textContent: "Load", onclick: () => fileInput.click()}),
$el("button", { id: "comfy-refresh-button", textContent: "Refresh", onclick: () => app.refreshComboInNodes() }), $el("button", {
$el("button", { id: "comfy-clipspace-button", textContent: "Clipspace", onclick: () => app.openClipspace() }), id: "comfy-refresh-button",
$el("button", { id: "comfy-clear-button", textContent: "Clear", onclick: () => { textContent: "Refresh",
if (!confirmClear.value || confirm("Clear workflow?")) { onclick: () => app.refreshComboInNodes()
app.clean(); }),
app.graph.clear(); $el("button", {id: "comfy-clipspace-button", textContent: "Clipspace", onclick: () => app.openClipspace()}),
$el("button", {
id: "comfy-clear-button", textContent: "Clear", onclick: () => {
if (!confirmClear.value || confirm("Clear workflow?")) {
app.clean();
app.graph.clear();
}
} }
}}), }),
$el("button", { id: "comfy-load-default-button", textContent: "Load Default", onclick: () => { $el("button", {
if (!confirmClear.value || confirm("Load default workflow?")) { id: "comfy-load-default-button", textContent: "Load Default", onclick: () => {
app.loadGraphData() if (!confirmClear.value || confirm("Load default workflow?")) {
app.loadGraphData()
}
} }
}}), }),
]); ]);
dragElement(this.menuContainer, this.settings); dragElement(this.menuContainer, this.settings);
this.setStatus({ exec_info: { queue_remaining: "X" } }); this.setStatus({exec_info: {queue_remaining: "X"}});
} }
setStatus(status) { setStatus(status) {

View File

@ -8,6 +8,8 @@
--drag-text: #ccc; --drag-text: #ccc;
--error-text: #ff4444; --error-text: #ff4444;
--border-color: #4e4e4e; --border-color: #4e4e4e;
--tr-even-bg-color: #222;
--tr-odd-bg-color: #353535;
} }
@media (prefers-color-scheme: dark) { @media (prefers-color-scheme: dark) {
@ -220,7 +222,7 @@ button.comfy-queue-btn {
margin: 6px 0 !important; margin: 6px 0 !important;
} }
.comfy-modal.comfy-settings, .comfy-modal.comfy-settings,
.comfy-modal.comfy-manage-templates { .comfy-modal.comfy-manage-templates {
text-align: center; text-align: center;
font-family: sans-serif; font-family: sans-serif;
@ -246,6 +248,11 @@ button.comfy-queue-btn {
font-size: inherit; font-size: inherit;
} }
.comfy-tooltip-indicator {
text-decoration: underline;
text-decoration-style: dashed;
}
@media only screen and (max-height: 850px) { @media only screen and (max-height: 850px) {
.comfy-menu { .comfy-menu {
top: 0 !important; top: 0 !important;
@ -254,8 +261,9 @@ button.comfy-queue-btn {
right: 0 !important; right: 0 !important;
border-radius: 0; border-radius: 0;
} }
.comfy-menu span.drag-handle { .comfy-menu span.drag-handle {
visibility:hidden visibility: hidden
} }
} }
@ -287,11 +295,75 @@ button.comfy-queue-btn {
border-radius: 12px 0 0 12px; border-radius: 12px 0 0 12px;
} }
/* Dialogs */
dialog {
box-shadow: 0 0 20px #888888;
}
dialog::backdrop {
background: rgba(0, 0, 0, 0.5);
}
#comfy-settings-dialog {
padding: 0;
width: 41rem;
}
#comfy-settings-dialog tr > td:first-child {
text-align: right;
}
#comfy-settings-dialog button {
background-color: var(--bg-color);
border: 1px var(--border-color) solid;
border-radius: 0;
color: var(--input-text);
font-size: 1rem;
padding: 0.5rem;
}
#comfy-settings-dialog button:hover {
background-color: var(--tr-odd-bg-color);
}
/* General CSS for tables */
.comfy-table {
border-collapse: collapse;
color: var(--input-text);
font-family: Arial, sans-serif;
width: 100%;
}
.comfy-table caption {
background-color: var(--bg-color);
color: var(--input-text);
font-size: 1rem;
font-weight: bold;
padding: 8px;
text-align: center;
}
.comfy-table tr:nth-child(even) {
background-color: var(--tr-even-bg-color);
}
.comfy-table tr:nth-child(odd) {
background-color: var(--tr-odd-bg-color);
}
.comfy-table td,
.comfy-table th {
border: 1px solid var(--border-color);
padding: 8px;
}
/* Context menu */ /* Context menu */
.litegraph .dialog { .litegraph .dialog {
z-index: 1; z-index: 1;
font-family: Arial, sans-serif; font-family: Arial, sans-serif;
} }
.litegraph .litemenu-entry.has_submenu { .litegraph .litemenu-entry.has_submenu {