Merge remote-tracking branch 'origin/master' into group-nodes

This commit is contained in:
pythongosssss 2023-11-29 17:44:51 +00:00
commit 249fc9255b
13 changed files with 227 additions and 40 deletions

View File

@ -31,6 +31,7 @@ This ui will let you design and execute advanced stable diffusion pipelines usin
- [GLIGEN](https://comfyanonymous.github.io/ComfyUI_examples/gligen/) - [GLIGEN](https://comfyanonymous.github.io/ComfyUI_examples/gligen/)
- [Model Merging](https://comfyanonymous.github.io/ComfyUI_examples/model_merging/) - [Model Merging](https://comfyanonymous.github.io/ComfyUI_examples/model_merging/)
- [LCM models and Loras](https://comfyanonymous.github.io/ComfyUI_examples/lcm/) - [LCM models and Loras](https://comfyanonymous.github.io/ComfyUI_examples/lcm/)
- [SDXL Turbo](https://comfyanonymous.github.io/ComfyUI_examples/sdturbo/)
- Latent previews with [TAESD](#how-to-show-high-quality-previews) - Latent previews with [TAESD](#how-to-show-high-quality-previews)
- Starts up very fast. - Starts up very fast.
- Works fully offline: will never download anything. - Works fully offline: will never download anything.

View File

@ -164,12 +164,13 @@ class BaseModel(torch.nn.Module):
self.inpaint_model = True self.inpaint_model = True
def memory_required(self, input_shape): def memory_required(self, input_shape):
area = input_shape[0] * input_shape[2] * input_shape[3]
if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention(): if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention():
#TODO: this needs to be tweaked #TODO: this needs to be tweaked
return (area / (comfy.model_management.dtype_size(self.get_dtype()) * 10)) * (1024 * 1024) area = input_shape[0] * input_shape[2] * input_shape[3]
return (area * comfy.model_management.dtype_size(self.get_dtype()) / 50) * (1024 * 1024)
else: else:
#TODO: this formula might be too aggressive since I tweaked the sub-quad and split algorithms to use less memory. #TODO: this formula might be too aggressive since I tweaked the sub-quad and split algorithms to use less memory.
area = input_shape[0] * input_shape[2] * input_shape[3]
return (((area * 0.6) / 0.9) + 1024) * (1024 * 1024) return (((area * 0.6) / 0.9) + 1024) * (1024 * 1024)

View File

@ -65,15 +65,15 @@ class ModelSamplingDiscrete(torch.nn.Module):
def timestep(self, sigma): def timestep(self, sigma):
log_sigma = sigma.log() log_sigma = sigma.log()
dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None] dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None]
return dists.abs().argmin(dim=0).view(sigma.shape) return dists.abs().argmin(dim=0).view(sigma.shape).to(sigma.device)
def sigma(self, timestep): def sigma(self, timestep):
t = torch.clamp(timestep.float(), min=0, max=(len(self.sigmas) - 1)) t = torch.clamp(timestep.float().to(self.log_sigmas.device), min=0, max=(len(self.sigmas) - 1))
low_idx = t.floor().long() low_idx = t.floor().long()
high_idx = t.ceil().long() high_idx = t.ceil().long()
w = t.frac() w = t.frac()
log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx] log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
return log_sigma.exp() return log_sigma.exp().to(timestep.device)
def percent_to_sigma(self, percent): def percent_to_sigma(self, percent):
if percent <= 0.0: if percent <= 0.0:

View File

@ -83,7 +83,7 @@ def prepare_sampling(model, noise_shape, positive, negative, noise_mask):
real_model = None real_model = None
models, inference_memory = get_additional_models(positive, negative, model.model_dtype()) models, inference_memory = get_additional_models(positive, negative, model.model_dtype())
comfy.model_management.load_models_gpu([model] + models, model.memory_required(noise_shape) + inference_memory) comfy.model_management.load_models_gpu([model] + models, model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory)
real_model = model.model real_model = model.model
return real_model, positive, negative, noise_mask, models return real_model, positive, negative, noise_mask, models

View File

@ -187,10 +187,12 @@ class VAE:
if device is None: if device is None:
device = model_management.vae_device() device = model_management.vae_device()
self.device = device self.device = device
self.offload_device = model_management.vae_offload_device() offload_device = model_management.vae_offload_device()
self.vae_dtype = model_management.vae_dtype() self.vae_dtype = model_management.vae_dtype()
self.first_stage_model.to(self.vae_dtype) self.first_stage_model.to(self.vae_dtype)
self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16): def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16):
steps = samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap) steps = samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap)
steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x // 2, tile_y * 2, overlap) steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x // 2, tile_y * 2, overlap)
@ -219,10 +221,9 @@ class VAE:
return samples return samples
def decode(self, samples_in): def decode(self, samples_in):
self.first_stage_model = self.first_stage_model.to(self.device)
try: try:
memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype) memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
model_management.free_memory(memory_used, self.device) model_management.load_models_gpu([self.patcher], memory_required=memory_used)
free_memory = model_management.get_free_memory(self.device) free_memory = model_management.get_free_memory(self.device)
batch_number = int(free_memory / memory_used) batch_number = int(free_memory / memory_used)
batch_number = max(1, batch_number) batch_number = max(1, batch_number)
@ -235,22 +236,19 @@ class VAE:
print("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.") print("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
pixel_samples = self.decode_tiled_(samples_in) pixel_samples = self.decode_tiled_(samples_in)
self.first_stage_model = self.first_stage_model.to(self.offload_device)
pixel_samples = pixel_samples.cpu().movedim(1,-1) pixel_samples = pixel_samples.cpu().movedim(1,-1)
return pixel_samples return pixel_samples
def decode_tiled(self, samples, tile_x=64, tile_y=64, overlap = 16): def decode_tiled(self, samples, tile_x=64, tile_y=64, overlap = 16):
self.first_stage_model = self.first_stage_model.to(self.device) model_management.load_model_gpu(self.patcher)
output = self.decode_tiled_(samples, tile_x, tile_y, overlap) output = self.decode_tiled_(samples, tile_x, tile_y, overlap)
self.first_stage_model = self.first_stage_model.to(self.offload_device)
return output.movedim(1,-1) return output.movedim(1,-1)
def encode(self, pixel_samples): def encode(self, pixel_samples):
self.first_stage_model = self.first_stage_model.to(self.device)
pixel_samples = pixel_samples.movedim(-1,1) pixel_samples = pixel_samples.movedim(-1,1)
try: try:
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype) memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
model_management.free_memory(memory_used, self.device) model_management.load_models_gpu([self.patcher], memory_required=memory_used)
free_memory = model_management.get_free_memory(self.device) free_memory = model_management.get_free_memory(self.device)
batch_number = int(free_memory / memory_used) batch_number = int(free_memory / memory_used)
batch_number = max(1, batch_number) batch_number = max(1, batch_number)
@ -263,14 +261,12 @@ class VAE:
print("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.") print("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
samples = self.encode_tiled_(pixel_samples) samples = self.encode_tiled_(pixel_samples)
self.first_stage_model = self.first_stage_model.to(self.offload_device)
return samples return samples
def encode_tiled(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64): def encode_tiled(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
self.first_stage_model = self.first_stage_model.to(self.device) model_management.load_model_gpu(self.patcher)
pixel_samples = pixel_samples.movedim(-1,1) pixel_samples = pixel_samples.movedim(-1,1)
samples = self.encode_tiled_(pixel_samples, tile_x=tile_x, tile_y=tile_y, overlap=overlap) samples = self.encode_tiled_(pixel_samples, tile_x=tile_x, tile_y=tile_y, overlap=overlap)
self.first_stage_model = self.first_stage_model.to(self.offload_device)
return samples return samples
def get_sd(self): def get_sd(self):
@ -481,20 +477,18 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
return (model_patcher, clip, vae, clipvision) return (model_patcher, clip, vae, clipvision)
def load_unet(unet_path): #load unet in diffusers format def load_unet_state_dict(sd): #load unet in diffusers format
sd = comfy.utils.load_torch_file(unet_path)
parameters = comfy.utils.calculate_parameters(sd) parameters = comfy.utils.calculate_parameters(sd)
unet_dtype = model_management.unet_dtype(model_params=parameters) unet_dtype = model_management.unet_dtype(model_params=parameters)
if "input_blocks.0.0.weight" in sd: #ldm if "input_blocks.0.0.weight" in sd: #ldm
model_config = model_detection.model_config_from_unet(sd, "", unet_dtype) model_config = model_detection.model_config_from_unet(sd, "", unet_dtype)
if model_config is None: if model_config is None:
raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path)) return None
new_sd = sd new_sd = sd
else: #diffusers else: #diffusers
model_config = model_detection.model_config_from_diffusers_unet(sd, unet_dtype) model_config = model_detection.model_config_from_diffusers_unet(sd, unet_dtype)
if model_config is None: if model_config is None:
print("ERROR UNSUPPORTED UNET", unet_path)
return None return None
diffusers_keys = comfy.utils.unet_to_diffusers(model_config.unet_config) diffusers_keys = comfy.utils.unet_to_diffusers(model_config.unet_config)
@ -514,6 +508,14 @@ def load_unet(unet_path): #load unet in diffusers format
print("left over keys in unet:", left_over) print("left over keys in unet:", left_over)
return comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device) return comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device)
def load_unet(unet_path):
sd = comfy.utils.load_torch_file(unet_path)
model = load_unet_state_dict(sd)
if model is None:
print("ERROR UNSUPPORTED UNET", unet_path)
raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))
return model
def save_checkpoint(output_path, model, clip, vae, metadata=None): def save_checkpoint(output_path, model, clip, vae, metadata=None):
model_management.load_models_gpu([model, clip.load_model()]) model_management.load_models_gpu([model, clip.load_model()])
sd = model.model.state_dict_for_saving(clip.get_sd(), vae.get_sd()) sd = model.model.state_dict_for_saving(clip.get_sd(), vae.get_sd())

View File

@ -81,6 +81,25 @@ class PolyexponentialScheduler:
sigmas = k_diffusion_sampling.get_sigmas_polyexponential(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, rho=rho) sigmas = k_diffusion_sampling.get_sigmas_polyexponential(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, rho=rho)
return (sigmas, ) return (sigmas, )
class SDTurboScheduler:
@classmethod
def INPUT_TYPES(s):
return {"required":
{"model": ("MODEL",),
"steps": ("INT", {"default": 1, "min": 1, "max": 10}),
}
}
RETURN_TYPES = ("SIGMAS",)
CATEGORY = "sampling/custom_sampling/schedulers"
FUNCTION = "get_sigmas"
def get_sigmas(self, model, steps):
timesteps = torch.flip(torch.arange(1, 11) * 100 - 1, (0,))[:steps]
sigmas = model.model.model_sampling.sigma(timesteps)
sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
return (sigmas, )
class VPScheduler: class VPScheduler:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
@ -257,6 +276,7 @@ NODE_CLASS_MAPPINGS = {
"ExponentialScheduler": ExponentialScheduler, "ExponentialScheduler": ExponentialScheduler,
"PolyexponentialScheduler": PolyexponentialScheduler, "PolyexponentialScheduler": PolyexponentialScheduler,
"VPScheduler": VPScheduler, "VPScheduler": VPScheduler,
"SDTurboScheduler": SDTurboScheduler,
"KSamplerSelect": KSamplerSelect, "KSamplerSelect": KSamplerSelect,
"SamplerDPMPP_2M_SDE": SamplerDPMPP_2M_SDE, "SamplerDPMPP_2M_SDE": SamplerDPMPP_2M_SDE,
"SamplerDPMPP_SDE": SamplerDPMPP_SDE, "SamplerDPMPP_SDE": SamplerDPMPP_SDE,

View File

@ -17,7 +17,9 @@ class LCM(comfy.model_sampling.EPS):
return c_out * x0 + c_skip * model_input return c_out * x0 + c_skip * model_input
class ModelSamplingDiscreteLCM(torch.nn.Module): class ModelSamplingDiscreteDistilled(torch.nn.Module):
original_timesteps = 50
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.sigma_data = 1.0 self.sigma_data = 1.0
@ -29,13 +31,12 @@ class ModelSamplingDiscreteLCM(torch.nn.Module):
alphas = 1.0 - betas alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0) alphas_cumprod = torch.cumprod(alphas, dim=0)
original_timesteps = 50 self.skip_steps = timesteps // self.original_timesteps
self.skip_steps = timesteps // original_timesteps
alphas_cumprod_valid = torch.zeros((original_timesteps), dtype=torch.float32) alphas_cumprod_valid = torch.zeros((self.original_timesteps), dtype=torch.float32)
for x in range(original_timesteps): for x in range(self.original_timesteps):
alphas_cumprod_valid[original_timesteps - 1 - x] = alphas_cumprod[timesteps - 1 - x * self.skip_steps] alphas_cumprod_valid[self.original_timesteps - 1 - x] = alphas_cumprod[timesteps - 1 - x * self.skip_steps]
sigmas = ((1 - alphas_cumprod_valid) / alphas_cumprod_valid) ** 0.5 sigmas = ((1 - alphas_cumprod_valid) / alphas_cumprod_valid) ** 0.5
self.set_sigmas(sigmas) self.set_sigmas(sigmas)
@ -55,15 +56,15 @@ class ModelSamplingDiscreteLCM(torch.nn.Module):
def timestep(self, sigma): def timestep(self, sigma):
log_sigma = sigma.log() log_sigma = sigma.log()
dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None] dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None]
return dists.abs().argmin(dim=0).view(sigma.shape) * self.skip_steps + (self.skip_steps - 1) return (dists.abs().argmin(dim=0).view(sigma.shape) * self.skip_steps + (self.skip_steps - 1)).to(sigma.device)
def sigma(self, timestep): def sigma(self, timestep):
t = torch.clamp(((timestep - (self.skip_steps - 1)) / self.skip_steps).float(), min=0, max=(len(self.sigmas) - 1)) t = torch.clamp(((timestep.float().to(self.log_sigmas.device) - (self.skip_steps - 1)) / self.skip_steps).float(), min=0, max=(len(self.sigmas) - 1))
low_idx = t.floor().long() low_idx = t.floor().long()
high_idx = t.ceil().long() high_idx = t.ceil().long()
w = t.frac() w = t.frac()
log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx] log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
return log_sigma.exp() return log_sigma.exp().to(timestep.device)
def percent_to_sigma(self, percent): def percent_to_sigma(self, percent):
if percent <= 0.0: if percent <= 0.0:
@ -116,7 +117,7 @@ class ModelSamplingDiscrete:
sampling_type = comfy.model_sampling.V_PREDICTION sampling_type = comfy.model_sampling.V_PREDICTION
elif sampling == "lcm": elif sampling == "lcm":
sampling_type = LCM sampling_type = LCM
sampling_base = ModelSamplingDiscreteLCM sampling_base = ModelSamplingDiscreteDistilled
class ModelSamplingAdvanced(sampling_base, sampling_type): class ModelSamplingAdvanced(sampling_base, sampling_type):
pass pass

View File

@ -228,8 +228,12 @@ def get_save_image_path(filename_prefix, output_dir, image_width=0, image_height
full_output_folder = os.path.join(output_dir, subfolder) full_output_folder = os.path.join(output_dir, subfolder)
if os.path.commonpath((output_dir, os.path.abspath(full_output_folder))) != output_dir: if os.path.commonpath((output_dir, os.path.abspath(full_output_folder))) != output_dir:
print("Saving image outside the output folder is not allowed.") err = "**** ERROR: Saving image outside the output folder is not allowed." + \
return {} "\n full_output_folder: " + os.path.abspath(full_output_folder) + \
"\n output_dir: " + output_dir + \
"\n commonpath: " + os.path.commonpath((output_dir, os.path.abspath(full_output_folder)))
print(err)
raise Exception(err)
try: try:
counter = max(filter(lambda a: a[1][:-1] == filename and a[1][-1] == "_", map(map_filename, os.listdir(full_output_folder))))[0] + 1 counter = max(filter(lambda a: a[1][:-1] == filename and a[1][-1] == "_", map(map_filename, os.listdir(full_output_folder))))[0] + 1

12
main.py
View File

@ -88,6 +88,7 @@ def cuda_malloc_warning():
def prompt_worker(q, server): def prompt_worker(q, server):
e = execution.PromptExecutor(server) e = execution.PromptExecutor(server)
last_gc_collect = 0
while True: while True:
item, item_id = q.get() item, item_id = q.get()
execution_start_time = time.perf_counter() execution_start_time = time.perf_counter()
@ -97,9 +98,14 @@ def prompt_worker(q, server):
if server.client_id is not None: if server.client_id is not None:
server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, server.client_id) server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, server.client_id)
print("Prompt executed in {:.2f} seconds".format(time.perf_counter() - execution_start_time)) current_time = time.perf_counter()
gc.collect() execution_time = current_time - execution_start_time
comfy.model_management.soft_empty_cache() print("Prompt executed in {:.2f} seconds".format(execution_time))
if (current_time - last_gc_collect) > 10.0:
gc.collect()
comfy.model_management.soft_empty_cache()
last_gc_collect = current_time
print("gc collect")
async def run(server, address='', port=8188, verbose=True, call_on_start=None): async def run(server, address='', port=8188, verbose=True, call_on_start=None):
await asyncio.gather(server.start(address, port, verbose, call_on_start), server.publish_loop()) await asyncio.gather(server.start(address, port, verbose, call_on_start), server.publish_loop())

View File

@ -1337,6 +1337,7 @@ class SaveImage:
self.output_dir = folder_paths.get_output_directory() self.output_dir = folder_paths.get_output_directory()
self.type = "output" self.type = "output"
self.prefix_append = "" self.prefix_append = ""
self.compress_level = 4
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
@ -1370,7 +1371,7 @@ class SaveImage:
metadata.add_text(x, json.dumps(extra_pnginfo[x])) metadata.add_text(x, json.dumps(extra_pnginfo[x]))
file = f"{filename}_{counter:05}_.png" file = f"{filename}_{counter:05}_.png"
img.save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=4) img.save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=self.compress_level)
results.append({ results.append({
"filename": file, "filename": file,
"subfolder": subfolder, "subfolder": subfolder,
@ -1385,6 +1386,7 @@ class PreviewImage(SaveImage):
self.output_dir = folder_paths.get_temp_directory() self.output_dir = folder_paths.get_temp_directory()
self.type = "temp" self.type = "temp"
self.prefix_append = "_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for x in range(5)) self.prefix_append = "_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for x in range(5))
self.compress_level = 1
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):

View File

@ -576,7 +576,7 @@ class PromptServer():
bytesIO = BytesIO() bytesIO = BytesIO()
header = struct.pack(">I", type_num) header = struct.pack(">I", type_num)
bytesIO.write(header) bytesIO.write(header)
image.save(bytesIO, format=image_type, quality=95, compress_level=4) image.save(bytesIO, format=image_type, quality=95, compress_level=1)
preview_bytes = bytesIO.getvalue() preview_bytes = bytesIO.getvalue()
await self.send_bytes(BinaryEventTypes.PREVIEW_IMAGE, preview_bytes, sid=sid) await self.send_bytes(BinaryEventTypes.PREVIEW_IMAGE, preview_bytes, sid=sid)

View File

@ -0,0 +1,150 @@
import { app } from "../../scripts/app.js";
const MAX_HISTORY = 50;
let undo = [];
let redo = [];
let activeState = null;
let isOurLoad = false;
function checkState() {
const currentState = app.graph.serialize();
if (!graphEqual(activeState, currentState)) {
undo.push(activeState);
if (undo.length > MAX_HISTORY) {
undo.shift();
}
activeState = clone(currentState);
redo.length = 0;
}
}
const loadGraphData = app.loadGraphData;
app.loadGraphData = async function () {
const v = await loadGraphData.apply(this, arguments);
if (isOurLoad) {
isOurLoad = false;
} else {
checkState();
}
return v;
};
function clone(obj) {
try {
if (typeof structuredClone !== "undefined") {
return structuredClone(obj);
}
} catch (error) {
// structuredClone is stricter than using JSON.parse/stringify so fallback to that
}
return JSON.parse(JSON.stringify(obj));
}
function graphEqual(a, b, root = true) {
if (a === b) return true;
if (typeof a == "object" && a && typeof b == "object" && b) {
const keys = Object.getOwnPropertyNames(a);
if (keys.length != Object.getOwnPropertyNames(b).length) {
return false;
}
for (const key of keys) {
let av = a[key];
let bv = b[key];
if (root && key === "nodes") {
// Nodes need to be sorted as the order changes when selecting nodes
av = [...av].sort((a, b) => a.id - b.id);
bv = [...bv].sort((a, b) => a.id - b.id);
}
if (!graphEqual(av, bv, false)) {
return false;
}
}
return true;
}
return false;
}
const undoRedo = async (e) => {
if (e.ctrlKey || e.metaKey) {
if (e.key === "y") {
const prevState = redo.pop();
if (prevState) {
undo.push(activeState);
isOurLoad = true;
await app.loadGraphData(prevState);
activeState = prevState;
}
return true;
} else if (e.key === "z") {
const prevState = undo.pop();
if (prevState) {
redo.push(activeState);
isOurLoad = true;
await app.loadGraphData(prevState);
activeState = prevState;
}
return true;
}
}
};
const bindInput = (activeEl) => {
if (activeEl?.tagName !== "CANVAS" && activeEl?.tagName !== "BODY") {
for (const evt of ["change", "input", "blur"]) {
if (`on${evt}` in activeEl) {
const listener = () => {
checkState();
activeEl.removeEventListener(evt, listener);
};
activeEl.addEventListener(evt, listener);
return true;
}
}
}
};
window.addEventListener(
"keydown",
(e) => {
requestAnimationFrame(async () => {
const activeEl = document.activeElement;
if (activeEl?.tagName === "INPUT" || activeEl?.type === "textarea") {
// Ignore events on inputs, they have their native history
return;
}
// Check if this is a ctrl+z ctrl+y
if (await undoRedo(e)) return;
// If our active element is some type of input then handle changes after they're done
if (bindInput(activeEl)) return;
checkState();
});
},
true
);
// Handle clicking DOM elements (e.g. widgets)
window.addEventListener("mouseup", () => {
checkState();
});
// Handle litegraph clicks
const processMouseUp = LGraphCanvas.prototype.processMouseUp;
LGraphCanvas.prototype.processMouseUp = function (e) {
const v = processMouseUp.apply(this, arguments);
checkState();
return v;
};
const processMouseDown = LGraphCanvas.prototype.processMouseDown;
LGraphCanvas.prototype.processMouseDown = function (e) {
const v = processMouseDown.apply(this, arguments);
checkState();
return v;
};

View File

@ -254,9 +254,9 @@ class ComfyApi extends EventTarget {
* Gets the prompt execution history * Gets the prompt execution history
* @returns Prompt history including node outputs * @returns Prompt history including node outputs
*/ */
async getHistory() { async getHistory(max_items=200) {
try { try {
const res = await this.fetchApi("/history?max_items=200"); const res = await this.fetchApi(`/history?max_items=${max_items}`);
return { History: Object.values(await res.json()) }; return { History: Object.values(await res.json()) };
} catch (error) { } catch (error) {
console.error(error); console.error(error);