From 6aa1bcd601dfdcb4485ea31947ffbf992a5b54fc Mon Sep 17 00:00:00 2001 From: Jack Bauer <2308123+dmx974@users.noreply.github.com> Date: Sun, 26 Nov 2023 17:23:11 +0400 Subject: [PATCH 01/15] Remove hard coded max_items in history API --- web/scripts/api.js | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/web/scripts/api.js b/web/scripts/api.js index de56b2310..9aa7528af 100644 --- a/web/scripts/api.js +++ b/web/scripts/api.js @@ -254,9 +254,9 @@ class ComfyApi extends EventTarget { * Gets the prompt execution history * @returns Prompt history including node outputs */ - async getHistory() { + async getHistory(max_items=200) { try { - const res = await this.fetchApi("/history?max_items=200"); + const res = await this.fetchApi(`/history?max_items=${max_items}`); return { History: Object.values(await res.json()) }; } catch (error) { console.error(error); From edd6f75d3ad243e6c2d38f2d94191da40d12b2f3 Mon Sep 17 00:00:00 2001 From: David Jeske Date: Sun, 26 Nov 2023 13:10:31 -0700 Subject: [PATCH 02/15] better error for invalid output paths --- folder_paths.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/folder_paths.py b/folder_paths.py index 4a38deec0..5479fd7b2 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -228,8 +228,12 @@ def get_save_image_path(filename_prefix, output_dir, image_width=0, image_height full_output_folder = os.path.join(output_dir, subfolder) if os.path.commonpath((output_dir, os.path.abspath(full_output_folder))) != output_dir: - print("Saving image outside the output folder is not allowed.") - return {} + err = "**** ERROR: Saving image outside the output folder is not allowed." + \ + "\n full_output_folder: " + os.path.abspath(full_output_folder) + \ + "\n output_dir: " + output_dir + \ + "\n commonpath: " + os.path.commonpath((output_dir, os.path.abspath(full_output_folder))) + print(err) + raise Exception(err) try: counter = max(filter(lambda a: a[1][:-1] == filename and a[1][-1] == "_", map(map_filename, os.listdir(full_output_folder))))[0] + 1 From 34eccd863bb41f48346de178a55be308dc36e5e5 Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Mon, 27 Nov 2023 14:00:15 +0000 Subject: [PATCH 03/15] Add simple undo redo history --- web/extensions/core/undoRedo.js | 150 ++++++++++++++++++++++++++++++++ 1 file changed, 150 insertions(+) create mode 100644 web/extensions/core/undoRedo.js diff --git a/web/extensions/core/undoRedo.js b/web/extensions/core/undoRedo.js new file mode 100644 index 000000000..1c1d785a8 --- /dev/null +++ b/web/extensions/core/undoRedo.js @@ -0,0 +1,150 @@ +import { app } from "../../scripts/app.js"; + +const MAX_HISTORY = 50; + +let undo = []; +let redo = []; +let activeState = null; +let isOurLoad = false; +function checkState() { + const currentState = app.graph.serialize(); + if (!graphEqual(activeState, currentState)) { + undo.push(activeState); + if(undo.length > MAX_HISTORY) { + undo.shift(); + } + activeState = clone(currentState); + redo.length = 0; + } +} + +const loadGraphData = app.loadGraphData; +app.loadGraphData = async function () { + const v = await loadGraphData.apply(this, arguments); + if (isOurLoad) { + isOurLoad = false; + } else { + checkState(); + } + return v; +}; + +function clone(obj) { + try { + if (typeof structuredClone !== "undefined") { + return structuredClone(obj); + } + } catch (error) { + // structuredClone is stricter than using JSON.parse/stringify so fallback to that + } + + return 
JSON.parse(JSON.stringify(obj)); +} + +function graphEqual(a, b, root = true) { + if (a === b) return true; + + if (typeof a == "object" && a && typeof b == "object" && b) { + const keys = Object.getOwnPropertyNames(a); + + if (keys.length != Object.getOwnPropertyNames(b).length) { + return false; + } + + for (const key of keys) { + let av = a[key]; + let bv = b[key]; + if (root && key === "nodes") { + // Nodes need to be sorted as the order changes when selecting nodes + av = [...av].sort((a, b) => a.id - b.id); + bv = [...bv].sort((a, b) => a.id - b.id); + } + if (!graphEqual(av, bv, false)) { + return false; + } + } + + return true; + } + + return false; +} + +const undoRedo = async (e) => { + if (e.ctrlKey || e.metaKey) { + if (e.key === "y") { + const prevState = redo.pop(); + if (prevState) { + undo.push(activeState); + isOurLoad = true; + await app.loadGraphData(prevState); + activeState = prevState; + } + return true; + } else if (e.key === "z") { + const prevState = undo.pop(); + if (prevState) { + redo.push(activeState); + isOurLoad = true; + await app.loadGraphData(prevState); + activeState = prevState; + } + return true; + } + } +}; + +const bindInput = (activeEl) => { + if (activeEl?.tagName !== "CANVAS" && activeEl?.tagName !== "BODY") { + for (const evt of ["change", "input", "blur"]) { + if (`on${evt}` in activeEl) { + const listener = () => { + checkState(); + activeEl.removeEventListener(evt, listener); + }; + activeEl.addEventListener(evt, listener); + return true; + } + } + } +}; + +window.addEventListener( + "keydown", + (e) => { + requestAnimationFrame(async () => { + const activeEl = document.activeElement; + if (activeEl?.tagName === "INPUT" || activeEl?.type === "textarea") { + // Ignore events on inputs, they have their native history + return; + } + + // Check if this is a ctrl+z ctrl+y + if (await undoRedo(e)) return; + + // If our active element is some type of input then handle changes after they're done + if (bindInput(activeEl)) return; + checkState(); + }); + }, + true +); + +// Handle clicking DOM elements (e.g. 
widgets) +window.addEventListener("mouseup", () => { + checkState(); +}); + +// Handle litegraph clicks +const processMouseUp = LGraphCanvas.prototype.processMouseUp; +LGraphCanvas.prototype.processMouseUp = function (e) { + const v = processMouseUp.apply(this, arguments); + checkState(); + return v; +}; +const processMouseDown = LGraphCanvas.prototype.processMouseDown; +LGraphCanvas.prototype.processMouseDown = function (e) { + const v = processMouseDown.apply(this, arguments); + checkState(); + return v; +}; From 9be0b30cf1f69384e72823f5112072b15f1f431d Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Mon, 27 Nov 2023 14:02:50 +0000 Subject: [PATCH 04/15] fix formatting --- web/extensions/core/undoRedo.js | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/web/extensions/core/undoRedo.js b/web/extensions/core/undoRedo.js index 1c1d785a8..c6613b0f0 100644 --- a/web/extensions/core/undoRedo.js +++ b/web/extensions/core/undoRedo.js @@ -10,9 +10,9 @@ function checkState() { const currentState = app.graph.serialize(); if (!graphEqual(activeState, currentState)) { undo.push(activeState); - if(undo.length > MAX_HISTORY) { - undo.shift(); - } + if (undo.length > MAX_HISTORY) { + undo.shift(); + } activeState = clone(currentState); redo.length = 0; } From be71bb5e13d716c541a5372a518e9d512073fe18 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 27 Nov 2023 14:04:16 -0500 Subject: [PATCH 05/15] Tweak memory inference calculations a bit. --- comfy/model_base.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/comfy/model_base.py b/comfy/model_base.py index 34274c4ae..3d6879ae6 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -164,12 +164,13 @@ class BaseModel(torch.nn.Module): self.inpaint_model = True def memory_required(self, input_shape): - area = input_shape[0] * input_shape[2] * input_shape[3] if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention(): #TODO: this needs to be tweaked - return (area / (comfy.model_management.dtype_size(self.get_dtype()) * 10)) * (1024 * 1024) + area = max(input_shape[0], 3) * input_shape[2] * input_shape[3] + return (area * comfy.model_management.dtype_size(self.get_dtype()) / 60) * (1024 * 1024) else: #TODO: this formula might be too aggressive since I tweaked the sub-quad and split algorithms to use less memory. + area = input_shape[0] * input_shape[2] * input_shape[3] return (((area * 0.6) / 0.9) + 1024) * (1024 * 1024) From 13fdee6abf7a7b072ad0f1ebbaa76aca13ddd2a8 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 27 Nov 2023 14:55:40 -0500 Subject: [PATCH 06/15] Try to free memory for both cond+uncond before inference. 
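Classifier-free guidance runs the conditional and unconditional batches through the model together, so the pre-sampling memory estimate is now computed for a doubled batch dimension instead of the raw noise shape. Rough sketch of the shape passed to memory_required() below (the concrete numbers are only an example):

    # e.g. a batch of 4 latents at 64x64 (512x512 images) - example values
    noise_shape = [4, 4, 64, 64]
    estimate_shape = [noise_shape[0] * 2] + list(noise_shape[1:])
    # -> [8, 4, 64, 64], i.e. room for cond + uncond in one pass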
--- comfy/model_base.py | 4 ++-- comfy/sample.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/comfy/model_base.py b/comfy/model_base.py index 3d6879ae6..786c9cf47 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -166,8 +166,8 @@ class BaseModel(torch.nn.Module): def memory_required(self, input_shape): if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention(): #TODO: this needs to be tweaked - area = max(input_shape[0], 3) * input_shape[2] * input_shape[3] - return (area * comfy.model_management.dtype_size(self.get_dtype()) / 60) * (1024 * 1024) + area = input_shape[0] * input_shape[2] * input_shape[3] + return (area * comfy.model_management.dtype_size(self.get_dtype()) / 50) * (1024 * 1024) else: #TODO: this formula might be too aggressive since I tweaked the sub-quad and split algorithms to use less memory. area = input_shape[0] * input_shape[2] * input_shape[3] diff --git a/comfy/sample.py b/comfy/sample.py index 4bfdb8ce5..034db97ee 100644 --- a/comfy/sample.py +++ b/comfy/sample.py @@ -83,7 +83,7 @@ def prepare_sampling(model, noise_shape, positive, negative, noise_mask): real_model = None models, inference_memory = get_additional_models(positive, negative, model.model_dtype()) - comfy.model_management.load_models_gpu([model] + models, model.memory_required(noise_shape) + inference_memory) + comfy.model_management.load_models_gpu([model] + models, model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory) real_model = model.model return real_model, positive, negative, noise_mask, models From 488de0b4df524589c11a9bd0e2b3663d03003342 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 27 Nov 2023 16:32:03 -0500 Subject: [PATCH 07/15] ModelSamplingDiscreteLCM -> ModelSamplingDiscreteDistilled --- comfy_extras/nodes_model_advanced.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/comfy_extras/nodes_model_advanced.py b/comfy_extras/nodes_model_advanced.py index 6991c9837..20261aade 100644 --- a/comfy_extras/nodes_model_advanced.py +++ b/comfy_extras/nodes_model_advanced.py @@ -17,7 +17,9 @@ class LCM(comfy.model_sampling.EPS): return c_out * x0 + c_skip * model_input -class ModelSamplingDiscreteLCM(torch.nn.Module): +class ModelSamplingDiscreteDistilled(torch.nn.Module): + original_timesteps = 50 + def __init__(self): super().__init__() self.sigma_data = 1.0 @@ -29,13 +31,12 @@ class ModelSamplingDiscreteLCM(torch.nn.Module): alphas = 1.0 - betas alphas_cumprod = torch.cumprod(alphas, dim=0) - original_timesteps = 50 - self.skip_steps = timesteps // original_timesteps + self.skip_steps = timesteps // self.original_timesteps - alphas_cumprod_valid = torch.zeros((original_timesteps), dtype=torch.float32) - for x in range(original_timesteps): - alphas_cumprod_valid[original_timesteps - 1 - x] = alphas_cumprod[timesteps - 1 - x * self.skip_steps] + alphas_cumprod_valid = torch.zeros((self.original_timesteps), dtype=torch.float32) + for x in range(self.original_timesteps): + alphas_cumprod_valid[self.original_timesteps - 1 - x] = alphas_cumprod[timesteps - 1 - x * self.skip_steps] sigmas = ((1 - alphas_cumprod_valid) / alphas_cumprod_valid) ** 0.5 self.set_sigmas(sigmas) @@ -116,7 +117,7 @@ class ModelSamplingDiscrete: sampling_type = comfy.model_sampling.V_PREDICTION elif sampling == "lcm": sampling_type = LCM - sampling_base = ModelSamplingDiscreteLCM + sampling_base = ModelSamplingDiscreteDistilled class ModelSamplingAdvanced(sampling_base, 
sampling_type): pass From f30b992b18078415f7c31c6c2f5ad1513db0bf5e Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 27 Nov 2023 16:41:33 -0500 Subject: [PATCH 08/15] .sigma and .timestep now return tensors on the same device as the input. --- comfy/model_sampling.py | 6 +++--- comfy_extras/nodes_model_advanced.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/comfy/model_sampling.py b/comfy/model_sampling.py index fac5c995e..69c8b1f01 100644 --- a/comfy/model_sampling.py +++ b/comfy/model_sampling.py @@ -65,15 +65,15 @@ class ModelSamplingDiscrete(torch.nn.Module): def timestep(self, sigma): log_sigma = sigma.log() dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None] - return dists.abs().argmin(dim=0).view(sigma.shape) + return dists.abs().argmin(dim=0).view(sigma.shape).to(sigma.device) def sigma(self, timestep): - t = torch.clamp(timestep.float(), min=0, max=(len(self.sigmas) - 1)) + t = torch.clamp(timestep.float().to(self.log_sigmas.device), min=0, max=(len(self.sigmas) - 1)) low_idx = t.floor().long() high_idx = t.ceil().long() w = t.frac() log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx] - return log_sigma.exp() + return log_sigma.exp().to(timestep.device) def percent_to_sigma(self, percent): if percent <= 0.0: diff --git a/comfy_extras/nodes_model_advanced.py b/comfy_extras/nodes_model_advanced.py index 20261aade..efcdf1932 100644 --- a/comfy_extras/nodes_model_advanced.py +++ b/comfy_extras/nodes_model_advanced.py @@ -56,15 +56,15 @@ class ModelSamplingDiscreteDistilled(torch.nn.Module): def timestep(self, sigma): log_sigma = sigma.log() dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None] - return dists.abs().argmin(dim=0).view(sigma.shape) * self.skip_steps + (self.skip_steps - 1) + return (dists.abs().argmin(dim=0).view(sigma.shape) * self.skip_steps + (self.skip_steps - 1)).to(sigma.device) def sigma(self, timestep): - t = torch.clamp(((timestep - (self.skip_steps - 1)) / self.skip_steps).float(), min=0, max=(len(self.sigmas) - 1)) + t = torch.clamp(((timestep.float().to(self.log_sigmas.device) - (self.skip_steps - 1)) / self.skip_steps).float(), min=0, max=(len(self.sigmas) - 1)) low_idx = t.floor().long() high_idx = t.ceil().long() w = t.frac() log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx] - return log_sigma.exp() + return log_sigma.exp().to(timestep.device) def percent_to_sigma(self, percent): if percent <= 0.0: From c45d1b9b67a98c9ff9743b93caf8303286a430c3 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 27 Nov 2023 17:32:07 -0500 Subject: [PATCH 09/15] Add a function to load a unet from a state dict. 
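The loading logic is split so callers that already hold a state dict in memory can skip the file read. A rough usage sketch of the new helper (the path is only a placeholder):

    import comfy.sd
    import comfy.utils

    sd = comfy.utils.load_torch_file("models/unet/some_unet.safetensors")  # placeholder path
    patcher = comfy.sd.load_unet_state_dict(sd)
    if patcher is None:
        # model type could not be detected from the state dict keys
        raise RuntimeError("unsupported unet")

load_unet(unet_path) keeps its previous behaviour and now simply wraps this helper, raising when detection fails.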
--- comfy/sd.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index 7f85540c4..53c79e1c5 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -481,20 +481,18 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o return (model_patcher, clip, vae, clipvision) -def load_unet(unet_path): #load unet in diffusers format - sd = comfy.utils.load_torch_file(unet_path) +def load_unet_state_dict(sd): #load unet in diffusers format parameters = comfy.utils.calculate_parameters(sd) unet_dtype = model_management.unet_dtype(model_params=parameters) if "input_blocks.0.0.weight" in sd: #ldm model_config = model_detection.model_config_from_unet(sd, "", unet_dtype) if model_config is None: - raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path)) + return None new_sd = sd else: #diffusers model_config = model_detection.model_config_from_diffusers_unet(sd, unet_dtype) if model_config is None: - print("ERROR UNSUPPORTED UNET", unet_path) return None diffusers_keys = comfy.utils.unet_to_diffusers(model_config.unet_config) @@ -514,6 +512,14 @@ def load_unet(unet_path): #load unet in diffusers format print("left over keys in unet:", left_over) return comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device) +def load_unet(unet_path): + sd = comfy.utils.load_torch_file(unet_path) + model = load_unet_state_dict(sd) + if model is None: + print("ERROR UNSUPPORTED UNET", unet_path) + raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path)) + return model + def save_checkpoint(output_path, model, clip, vae, metadata=None): model_management.load_models_gpu([model, clip.load_model()]) sd = model.model.state_dict_for_saving(clip.get_sd(), vae.get_sd()) From 798a34d009cd78f02bd4c0b30f1c9fd6a594d345 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 28 Nov 2023 04:57:59 -0500 Subject: [PATCH 10/15] Lower compress level for image preview. --- nodes.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/nodes.py b/nodes.py index df40f8094..8b4a9b119 100644 --- a/nodes.py +++ b/nodes.py @@ -1337,6 +1337,7 @@ class SaveImage: self.output_dir = folder_paths.get_output_directory() self.type = "output" self.prefix_append = "" + self.compress_level = 4 @classmethod def INPUT_TYPES(s): @@ -1370,7 +1371,7 @@ class SaveImage: metadata.add_text(x, json.dumps(extra_pnginfo[x])) file = f"{filename}_{counter:05}_.png" - img.save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=4) + img.save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=self.compress_level) results.append({ "filename": file, "subfolder": subfolder, @@ -1385,6 +1386,7 @@ class PreviewImage(SaveImage): self.output_dir = folder_paths.get_temp_directory() self.type = "temp" self.prefix_append = "_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for x in range(5)) + self.compress_level = 1 @classmethod def INPUT_TYPES(s): From 983ebc579212e209f52dff014b79bfe1932c0959 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 28 Nov 2023 04:58:32 -0500 Subject: [PATCH 11/15] Use smart model management for VAE to decrease latency. 
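Rather than moving first_stage_model on and off the device by hand around every decode/encode, the VAE now wraps its model in a ModelPatcher and goes through the shared load path, so the weights can stay resident between calls when memory allows. Simplified sketch of the new pattern (taken from the diff below):

    self.patcher = comfy.model_patcher.ModelPatcher(
        self.first_stage_model,
        load_device=self.device,
        offload_device=offload_device)

    # later, once per decode/encode call:
    model_management.load_models_gpu([self.patcher], memory_required=memory_used)

Repeated VAE calls then avoid paying the host/device transfer cost each time.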
--- comfy/sd.py | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index 53c79e1c5..f4f84d0a0 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -187,10 +187,12 @@ class VAE: if device is None: device = model_management.vae_device() self.device = device - self.offload_device = model_management.vae_offload_device() + offload_device = model_management.vae_offload_device() self.vae_dtype = model_management.vae_dtype() self.first_stage_model.to(self.vae_dtype) + self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device) + def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16): steps = samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap) steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x // 2, tile_y * 2, overlap) @@ -219,10 +221,9 @@ class VAE: return samples def decode(self, samples_in): - self.first_stage_model = self.first_stage_model.to(self.device) try: memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype) - model_management.free_memory(memory_used, self.device) + model_management.load_models_gpu([self.patcher], memory_required=memory_used) free_memory = model_management.get_free_memory(self.device) batch_number = int(free_memory / memory_used) batch_number = max(1, batch_number) @@ -235,22 +236,19 @@ class VAE: print("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.") pixel_samples = self.decode_tiled_(samples_in) - self.first_stage_model = self.first_stage_model.to(self.offload_device) pixel_samples = pixel_samples.cpu().movedim(1,-1) return pixel_samples def decode_tiled(self, samples, tile_x=64, tile_y=64, overlap = 16): - self.first_stage_model = self.first_stage_model.to(self.device) + model_management.load_model_gpu(self.patcher) output = self.decode_tiled_(samples, tile_x, tile_y, overlap) - self.first_stage_model = self.first_stage_model.to(self.offload_device) return output.movedim(1,-1) def encode(self, pixel_samples): - self.first_stage_model = self.first_stage_model.to(self.device) pixel_samples = pixel_samples.movedim(-1,1) try: memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype) - model_management.free_memory(memory_used, self.device) + model_management.load_models_gpu([self.patcher], memory_required=memory_used) free_memory = model_management.get_free_memory(self.device) batch_number = int(free_memory / memory_used) batch_number = max(1, batch_number) @@ -263,14 +261,12 @@ class VAE: print("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.") samples = self.encode_tiled_(pixel_samples) - self.first_stage_model = self.first_stage_model.to(self.offload_device) return samples def encode_tiled(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64): - self.first_stage_model = self.first_stage_model.to(self.device) + model_management.load_model_gpu(self.patcher) pixel_samples = pixel_samples.movedim(-1,1) samples = self.encode_tiled_(pixel_samples, tile_x=tile_x, tile_y=tile_y, overlap=overlap) - self.first_stage_model = self.first_stage_model.to(self.offload_device) return samples def get_sd(self): From 21063fa35b53683f6ca01ccf1a5d5b509f702ba7 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 28 Nov 2023 11:01:05 -0500 Subject: [PATCH 12/15] Lower compress level of png sent on websocket. 
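compress_level here is Pillow's zlib setting (0-9); dropping it from 4 to 1 makes each per-step preview noticeably cheaper to encode at the cost of a somewhat larger payload, which is likely the better trade for images that are streamed once over the websocket and then discarded.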
--- server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server.py b/server.py index 1a8e92b8f..9b1e3269d 100644 --- a/server.py +++ b/server.py @@ -576,7 +576,7 @@ class PromptServer(): bytesIO = BytesIO() header = struct.pack(">I", type_num) bytesIO.write(header) - image.save(bytesIO, format=image_type, quality=95, compress_level=4) + image.save(bytesIO, format=image_type, quality=95, compress_level=1) preview_bytes = bytesIO.getvalue() await self.send_bytes(BinaryEventTypes.PREVIEW_IMAGE, preview_bytes, sid=sid) From 57d7f4464f2a40521666cc8436711f73bf728a97 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 28 Nov 2023 13:35:32 -0500 Subject: [PATCH 13/15] Add SDTurboScheduler node. --- comfy_extras/nodes_custom_sampler.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/comfy_extras/nodes_custom_sampler.py b/comfy_extras/nodes_custom_sampler.py index d3c1d4a23..008d0b8d6 100644 --- a/comfy_extras/nodes_custom_sampler.py +++ b/comfy_extras/nodes_custom_sampler.py @@ -81,6 +81,25 @@ class PolyexponentialScheduler: sigmas = k_diffusion_sampling.get_sigmas_polyexponential(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, rho=rho) return (sigmas, ) +class SDTurboScheduler: + @classmethod + def INPUT_TYPES(s): + return {"required": + {"model": ("MODEL",), + "steps": ("INT", {"default": 1, "min": 1, "max": 10}), + } + } + RETURN_TYPES = ("SIGMAS",) + CATEGORY = "sampling/custom_sampling/schedulers" + + FUNCTION = "get_sigmas" + + def get_sigmas(self, model, steps): + timesteps = torch.flip(torch.arange(1, 11) * 100 - 1, (0,))[:steps] + sigmas = model.model.model_sampling.sigma(timesteps) + sigmas = torch.cat([sigmas, sigmas.new_zeros([1])]) + return (sigmas, ) + class VPScheduler: @classmethod def INPUT_TYPES(s): @@ -257,6 +276,7 @@ NODE_CLASS_MAPPINGS = { "ExponentialScheduler": ExponentialScheduler, "PolyexponentialScheduler": PolyexponentialScheduler, "VPScheduler": VPScheduler, + "SDTurboScheduler": SDTurboScheduler, "KSamplerSelect": KSamplerSelect, "SamplerDPMPP_2M_SDE": SamplerDPMPP_2M_SDE, "SamplerDPMPP_SDE": SamplerDPMPP_SDE, From b911eefc4278b6069390d01a6ac9010ae6eecbac Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 28 Nov 2023 14:20:56 -0500 Subject: [PATCH 14/15] Limit gc.collect() to once every 10 seconds. 
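A full gc.collect() plus soft_empty_cache() after every prompt can add a noticeable pause between back-to-back runs, so the worker now throttles them with a simple timestamp guard. The pattern, shown in isolation:

    last_gc_collect = 0  # worker-scope state

    # after each prompt finishes:
    current_time = time.perf_counter()
    if (current_time - last_gc_collect) > 10.0:
        gc.collect()
        comfy.model_management.soft_empty_cache()
        last_gc_collect = current_time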
--- main.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/main.py b/main.py index 1100a07f4..3997fbefc 100644 --- a/main.py +++ b/main.py @@ -88,6 +88,7 @@ def cuda_malloc_warning(): def prompt_worker(q, server): e = execution.PromptExecutor(server) + last_gc_collect = 0 while True: item, item_id = q.get() execution_start_time = time.perf_counter() @@ -97,9 +98,14 @@ def prompt_worker(q, server): if server.client_id is not None: server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, server.client_id) - print("Prompt executed in {:.2f} seconds".format(time.perf_counter() - execution_start_time)) - gc.collect() - comfy.model_management.soft_empty_cache() + current_time = time.perf_counter() + execution_time = current_time - execution_start_time + print("Prompt executed in {:.2f} seconds".format(execution_time)) + if (current_time - last_gc_collect) > 10.0: + gc.collect() + comfy.model_management.soft_empty_cache() + last_gc_collect = current_time + print("gc collect") async def run(server, address='', port=8188, verbose=True, call_on_start=None): await asyncio.gather(server.start(address, port, verbose, call_on_start), server.publish_loop()) From 777f6b15225197898a5f49742682a2be859072d7 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 28 Nov 2023 14:45:00 -0500 Subject: [PATCH 15/15] Add to README that SDXL Turbo is supported. --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 9d7e31790..af1f22811 100644 --- a/README.md +++ b/README.md @@ -31,6 +31,7 @@ This ui will let you design and execute advanced stable diffusion pipelines usin - [GLIGEN](https://comfyanonymous.github.io/ComfyUI_examples/gligen/) - [Model Merging](https://comfyanonymous.github.io/ComfyUI_examples/model_merging/) - [LCM models and Loras](https://comfyanonymous.github.io/ComfyUI_examples/lcm/) +- [SDXL Turbo](https://comfyanonymous.github.io/ComfyUI_examples/sdturbo/) - Latent previews with [TAESD](#how-to-show-high-quality-previews) - Starts up very fast. - Works fully offline: will never download anything.