From c8c9926eeb0b25dba86f3d9e574e8527c090fc37 Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Mon, 24 Apr 2023 11:55:44 +0100 Subject: [PATCH 1/8] Add progress to vae decode tiled --- comfy/sd.py | 12 +++++++++--- comfy/utils.py | 4 +++- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index 92dbb931d..2aadefadc 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -1,6 +1,7 @@ import torch import contextlib import copy +from tqdm.auto import tqdm import sd1_clip import sd2_clip @@ -437,11 +438,16 @@ class VAE: self.device = device def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16): + it_1 = -(samples.shape[2] // -(tile_y * 2 - overlap)) * -(samples.shape[3] // -(tile_x // 2 - overlap)) + it_2 = -(samples.shape[2] // -(tile_y // 2 - overlap)) * -(samples.shape[3] // -(tile_x * 2 - overlap)) + it_3 = -(samples.shape[2] // -(tile_y - overlap)) * -(samples.shape[3] // -(tile_x - overlap)) + pbar = tqdm(total=samples.shape[0] * (it_1 + it_2 + it_3)) + decode_fn = lambda a: (self.first_stage_model.decode(1. / self.scale_factor * a.to(self.device)) + 1.0) output = torch.clamp(( - (utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = 8) + - utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = 8) + - utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = 8)) + (utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = 8, pbar = pbar) + + utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = 8, pbar = pbar) + + utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = 8, pbar = pbar)) / 3.0) / 2.0, min=0.0, max=1.0) return output diff --git a/comfy/utils.py b/comfy/utils.py index 68f93403c..c7c6a08c5 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -63,7 +63,7 @@ def common_upscale(samples, width, height, upscale_method, crop): return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method) @torch.inference_mode() -def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3): +def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, pbar = None): output = torch.empty((samples.shape[0], out_channels, round(samples.shape[2] * upscale_amount), round(samples.shape[3] * upscale_amount)), device="cpu") for b in range(samples.shape[0]): s = samples[b:b+1] @@ -83,6 +83,8 @@ def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_am mask[:,:,:,mask.shape[3]- 1 - t: mask.shape[3]- t] *= ((1.0/feather) * (t + 1)) out[:,:,round(y*upscale_amount):round((y+tile_y)*upscale_amount),round(x*upscale_amount):round((x+tile_x)*upscale_amount)] += ps * mask out_div[:,:,round(y*upscale_amount):round((y+tile_y)*upscale_amount),round(x*upscale_amount):round((x+tile_x)*upscale_amount)] += mask + if pbar is not None: + pbar.update(1) output[b:b+1] = out/out_div return output From 06ad35b4932fe6cc4382d8b1dfa79fef8284362a Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Tue, 2 May 2023 19:18:07 +0100 Subject: [PATCH 2/8] added progress to encode + upscale --- comfy/sd.py | 12 +++++++++--- comfy_extras/nodes_upscale_model.py | 8 +++++++- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index 2aadefadc..06d6c1a56 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -491,9 +491,15 @@ class VAE: model_management.unload_model() self.first_stage_model = self.first_stage_model.to(self.device) pixel_samples = pixel_samples.movedim(-1,1).to(self.device) - samples = utils.tiled_scale(pixel_samples, lambda a: self.first_stage_model.encode(2. * a - 1.).sample() * self.scale_factor, tile_x, tile_y, overlap, upscale_amount = (1/8), out_channels=4) - samples += utils.tiled_scale(pixel_samples, lambda a: self.first_stage_model.encode(2. * a - 1.).sample() * self.scale_factor, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/8), out_channels=4) - samples += utils.tiled_scale(pixel_samples, lambda a: self.first_stage_model.encode(2. * a - 1.).sample() * self.scale_factor, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/8), out_channels=4) + + it_1 = -(pixel_samples.shape[2] // -(tile_y * 2 - overlap)) * -(pixel_samples.shape[3] // -(tile_x // 2 - overlap)) + it_2 = -(pixel_samples.shape[2] // -(tile_y // 2 - overlap)) * -(pixel_samples.shape[3] // -(tile_x * 2 - overlap)) + it_3 = -(pixel_samples.shape[2] // -(tile_y - overlap)) * -(pixel_samples.shape[3] // -(tile_x - overlap)) + pbar = tqdm(total=(it_1 + it_2 + it_3)) + + samples = utils.tiled_scale(pixel_samples, lambda a: self.first_stage_model.encode(2. * a - 1.).sample() * self.scale_factor, tile_x, tile_y, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar) + samples += utils.tiled_scale(pixel_samples, lambda a: self.first_stage_model.encode(2. * a - 1.).sample() * self.scale_factor, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar) + samples += utils.tiled_scale(pixel_samples, lambda a: self.first_stage_model.encode(2. * a - 1.).sample() * self.scale_factor, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar) samples /= 3.0 self.first_stage_model = self.first_stage_model.cpu() samples = samples.cpu() diff --git a/comfy_extras/nodes_upscale_model.py b/comfy_extras/nodes_upscale_model.py index d8754698c..4fc7dcd77 100644 --- a/comfy_extras/nodes_upscale_model.py +++ b/comfy_extras/nodes_upscale_model.py @@ -4,6 +4,7 @@ from comfy import model_management import torch import comfy.utils import folder_paths +from tqdm.auto import tqdm class UpscaleModelLoader: @classmethod @@ -37,7 +38,12 @@ class ImageUpscaleWithModel: device = model_management.get_torch_device() upscale_model.to(device) in_img = image.movedim(-1,-3).to(device) - s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=128 + 64, tile_y=128 + 64, overlap = 8, upscale_amount=upscale_model.scale) + + tile = 128 + 64 + overlap = 8 + its = -(in_img.shape[2] // -(tile - overlap)) * -(in_img.shape[3] // -(tile - overlap)) + pbar = tqdm(total=its) + s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar) upscale_model.cpu() s = torch.clamp(s.movedim(-3,-1), min=0, max=1.0) return (s,) From 93c64afaa92b425fc863b80ee0b7c618705d7d60 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 2 May 2023 23:00:49 -0400 Subject: [PATCH 3/8] Use sampler callback instead of tqdm hook for progress bar. --- comfy/utils.py | 23 +++++++++++++++++++++++ main.py | 12 ++++-------- nodes.py | 6 +++++- 3 files changed, 32 insertions(+), 9 deletions(-) diff --git a/comfy/utils.py b/comfy/utils.py index 68f93403c..7f3c3978c 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -86,3 +86,26 @@ def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_am output[b:b+1] = out/out_div return output + + +PROGRESS_BAR_HOOK = None +def set_progress_bar_global_hook(function): + global PROGRESS_BAR_HOOK + PROGRESS_BAR_HOOK = function + +class ProgressBar: + def __init__(self, total): + global PROGRESS_BAR_HOOK + self.total = total + self.current = 0 + self.hook = PROGRESS_BAR_HOOK + + def update_absolute(self, value): + if value > self.total: + value = self.total + self.current = value + if self.hook is not None: + self.hook(self.current, self.total) + + def update(self, value): + self.update_absolute(self.current + value) diff --git a/main.py b/main.py index 02c700ebc..f369b82f3 100644 --- a/main.py +++ b/main.py @@ -5,6 +5,7 @@ import shutil import threading from comfy.cli_args import args +import comfy.utils if os.name == "nt": import logging @@ -39,14 +40,9 @@ async def run(server, address='', port=8188, verbose=True, call_on_start=None): await asyncio.gather(server.start(address, port, verbose, call_on_start), server.publish_loop()) def hijack_progress(server): - from tqdm.auto import tqdm - orig_func = getattr(tqdm, "update") - def wrapped_func(*args, **kwargs): - pbar = args[0] - v = orig_func(*args, **kwargs) - server.send_sync("progress", { "value": pbar.n, "max": pbar.total}, server.client_id) - return v - setattr(tqdm, "update", wrapped_func) + def hook(value, total): + server.send_sync("progress", { "value": value, "max": total}, server.client_id) + comfy.utils.set_progress_bar_global_hook(hook) def cleanup_temp(): temp_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp") diff --git a/nodes.py b/nodes.py index 80d508854..90c943fe3 100644 --- a/nodes.py +++ b/nodes.py @@ -815,9 +815,13 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, if "noise_mask" in latent: noise_mask = latent["noise_mask"] + pbar = comfy.utils.ProgressBar(steps) + def callback(step, x0, x): + pbar.update_absolute(step + 1) + samples = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise, disable_noise=disable_noise, start_step=start_step, last_step=last_step, - force_full_denoise=force_full_denoise, noise_mask=noise_mask) + force_full_denoise=force_full_denoise, noise_mask=noise_mask, callback=callback) out = latent.copy() out["samples"] = samples return (out, ) From 27df74101e6e5bb761364b718d57313388b49182 Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Wed, 3 May 2023 17:33:19 +0100 Subject: [PATCH 4/8] reduce duplication --- comfy/sd.py | 14 +++++--------- comfy/utils.py | 6 ++++++ 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index 06d6c1a56..87b380b1c 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -438,10 +438,8 @@ class VAE: self.device = device def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16): - it_1 = -(samples.shape[2] // -(tile_y * 2 - overlap)) * -(samples.shape[3] // -(tile_x // 2 - overlap)) - it_2 = -(samples.shape[2] // -(tile_y // 2 - overlap)) * -(samples.shape[3] // -(tile_x * 2 - overlap)) - it_3 = -(samples.shape[2] // -(tile_y - overlap)) * -(samples.shape[3] // -(tile_x - overlap)) - pbar = tqdm(total=samples.shape[0] * (it_1 + it_2 + it_3)) + steps = samples.shape[0] * utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap) + pbar = tqdm(total=steps) decode_fn = lambda a: (self.first_stage_model.decode(1. / self.scale_factor * a.to(self.device)) + 1.0) output = torch.clamp(( @@ -492,11 +490,9 @@ class VAE: self.first_stage_model = self.first_stage_model.to(self.device) pixel_samples = pixel_samples.movedim(-1,1).to(self.device) - it_1 = -(pixel_samples.shape[2] // -(tile_y * 2 - overlap)) * -(pixel_samples.shape[3] // -(tile_x // 2 - overlap)) - it_2 = -(pixel_samples.shape[2] // -(tile_y // 2 - overlap)) * -(pixel_samples.shape[3] // -(tile_x * 2 - overlap)) - it_3 = -(pixel_samples.shape[2] // -(tile_y - overlap)) * -(pixel_samples.shape[3] // -(tile_x - overlap)) - pbar = tqdm(total=(it_1 + it_2 + it_3)) - + steps = utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap) + pbar = tqdm(total=steps) + samples = utils.tiled_scale(pixel_samples, lambda a: self.first_stage_model.encode(2. * a - 1.).sample() * self.scale_factor, tile_x, tile_y, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar) samples += utils.tiled_scale(pixel_samples, lambda a: self.first_stage_model.encode(2. * a - 1.).sample() * self.scale_factor, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar) samples += utils.tiled_scale(pixel_samples, lambda a: self.first_stage_model.encode(2. * a - 1.).sample() * self.scale_factor, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar) diff --git a/comfy/utils.py b/comfy/utils.py index c7c6a08c5..82d3aa0d8 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -62,6 +62,12 @@ def common_upscale(samples, width, height, upscale_method, crop): s = samples return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method) +def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap): + it_1 = -(height // -(tile_y * 2 - overlap)) * -(width // -(tile_x // 2 - overlap)) + it_2 = -(height // -(tile_y // 2 - overlap)) * -(width // -(tile_x * 2 - overlap)) + it_3 = -(height // -(tile_y - overlap)) * -(width // -(tile_x - overlap)) + return it_1 + it_2 + it_3 + @torch.inference_mode() def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, pbar = None): output = torch.empty((samples.shape[0], out_channels, round(samples.shape[2] * upscale_amount), round(samples.shape[3] * upscale_amount)), device="cpu") From 908dc1d5a8717073f44d136d6d2b4f983ea07d40 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 3 May 2023 12:58:10 -0400 Subject: [PATCH 5/8] Add a total_steps value to sampler callback. --- comfy/extra_samplers/uni_pc.py | 2 +- comfy/samplers.py | 8 +++++--- comfy/utils.py | 4 +++- nodes.py | 4 ++-- 4 files changed, 11 insertions(+), 7 deletions(-) diff --git a/comfy/extra_samplers/uni_pc.py b/comfy/extra_samplers/uni_pc.py index 78bab5936..2ff10caf1 100644 --- a/comfy/extra_samplers/uni_pc.py +++ b/comfy/extra_samplers/uni_pc.py @@ -767,7 +767,7 @@ class UniPC: model_x = self.model_fn(x, vec_t) model_prev_list[-1] = model_x if callback is not None: - callback(step_index, model_prev_list[-1], x) + callback(step_index, model_prev_list[-1], x, steps) else: raise NotImplementedError() if denoise_to_zero: diff --git a/comfy/samplers.py b/comfy/samplers.py index b30fc3d9b..dcf93cca2 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -623,7 +623,8 @@ class KSampler: ddim_callback = None if callback is not None: - ddim_callback = lambda pred_x0, i: callback(i, pred_x0, None) + total_steps = len(timesteps) - 1 + ddim_callback = lambda pred_x0, i: callback(i, pred_x0, None, total_steps) sampler = DDIMSampler(self.model, device=self.device) sampler.make_schedule_timesteps(ddim_timesteps=timesteps, verbose=False) @@ -654,13 +655,14 @@ class KSampler: noise = noise * sigmas[0] k_callback = None + total_steps = len(sigmas) - 1 if callback is not None: - k_callback = lambda x: callback(x["i"], x["denoised"], x["x"]) + k_callback = lambda x: callback(x["i"], x["denoised"], x["x"], total_steps) if latent_image is not None: noise += latent_image if self.sampler == "dpm_fast": - samples = k_diffusion_sampling.sample_dpm_fast(self.model_k, noise, sigma_min, sigmas[0], self.steps, extra_args=extra_args, callback=k_callback, disable=disable_pbar) + samples = k_diffusion_sampling.sample_dpm_fast(self.model_k, noise, sigma_min, sigmas[0], total_steps, extra_args=extra_args, callback=k_callback, disable=disable_pbar) elif self.sampler == "dpm_adaptive": samples = k_diffusion_sampling.sample_dpm_adaptive(self.model_k, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=k_callback, disable=disable_pbar) else: diff --git a/comfy/utils.py b/comfy/utils.py index 7f3c3978c..f1ff97792 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -100,7 +100,9 @@ class ProgressBar: self.current = 0 self.hook = PROGRESS_BAR_HOOK - def update_absolute(self, value): + def update_absolute(self, value, total=None): + if total is not None: + self.total = total if value > self.total: value = self.total self.current = value diff --git a/nodes.py b/nodes.py index 90c943fe3..c2bc36855 100644 --- a/nodes.py +++ b/nodes.py @@ -816,8 +816,8 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, noise_mask = latent["noise_mask"] pbar = comfy.utils.ProgressBar(steps) - def callback(step, x0, x): - pbar.update_absolute(step + 1) + def callback(step, x0, x, total_steps): + pbar.update_absolute(step + 1, total_steps) samples = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise, disable_noise=disable_noise, start_step=start_step, last_step=last_step, From 8912623ea9929848b813f1aeafee0fa9e1281817 Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Wed, 3 May 2023 18:19:22 +0100 Subject: [PATCH 6/8] use comfy progress bar --- comfy/sd.py | 6 +++--- comfy_extras/nodes_upscale_model.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index 32499f600..e4c5282d7 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -516,7 +516,7 @@ class VAE: def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16): steps = samples.shape[0] * utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap) - pbar = tqdm(total=steps) + pbar = utils.ProgressBar(steps) decode_fn = lambda a: (self.first_stage_model.decode(1. / self.scale_factor * a.to(self.device)) + 1.0) output = torch.clamp(( @@ -568,8 +568,8 @@ class VAE: pixel_samples = pixel_samples.movedim(-1,1).to(self.device) steps = utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap) - pbar = tqdm(total=steps) - + pbar = utils.ProgressBar(steps) + samples = utils.tiled_scale(pixel_samples, lambda a: self.first_stage_model.encode(2. * a - 1.).sample() * self.scale_factor, tile_x, tile_y, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar) samples += utils.tiled_scale(pixel_samples, lambda a: self.first_stage_model.encode(2. * a - 1.).sample() * self.scale_factor, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar) samples += utils.tiled_scale(pixel_samples, lambda a: self.first_stage_model.encode(2. * a - 1.).sample() * self.scale_factor, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar) diff --git a/comfy_extras/nodes_upscale_model.py b/comfy_extras/nodes_upscale_model.py index 4fc7dcd77..dfd1994a6 100644 --- a/comfy_extras/nodes_upscale_model.py +++ b/comfy_extras/nodes_upscale_model.py @@ -41,8 +41,8 @@ class ImageUpscaleWithModel: tile = 128 + 64 overlap = 8 - its = -(in_img.shape[2] // -(tile - overlap)) * -(in_img.shape[3] // -(tile - overlap)) - pbar = tqdm(total=its) + steps = -(in_img.shape[2] // -(tile - overlap)) * -(in_img.shape[3] // -(tile - overlap)) + pbar = comfy.utils.ProgressBar(steps) s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar) upscale_model.cpu() s = torch.clamp(s.movedim(-3,-1), min=0, max=1.0) From 5eeecf3fd5adedfa5a92d3549f77a78be714c2a3 Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Wed, 3 May 2023 18:21:23 +0100 Subject: [PATCH 7/8] remove unused import --- comfy/sd.py | 1 - comfy_extras/nodes_upscale_model.py | 1 - 2 files changed, 2 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index e4c5282d7..d60b908b8 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -1,7 +1,6 @@ import torch import contextlib import copy -from tqdm.auto import tqdm import sd1_clip import sd2_clip diff --git a/comfy_extras/nodes_upscale_model.py b/comfy_extras/nodes_upscale_model.py index dfd1994a6..f774b4b77 100644 --- a/comfy_extras/nodes_upscale_model.py +++ b/comfy_extras/nodes_upscale_model.py @@ -4,7 +4,6 @@ from comfy import model_management import torch import comfy.utils import folder_paths -from tqdm.auto import tqdm class UpscaleModelLoader: @classmethod From fcf513e0b6b599e23b7d6f9bde315be6f991652b Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 3 May 2023 17:48:35 -0400 Subject: [PATCH 8/8] Refactor. --- comfy/sd.py | 6 +++++- comfy/utils.py | 6 ++---- comfy_extras/nodes_upscale_model.py | 2 +- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index d60b908b8..174ed35e5 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -515,6 +515,8 @@ class VAE: def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16): steps = samples.shape[0] * utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap) + steps += samples.shape[0] * utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x // 2, tile_y * 2, overlap) + steps += samples.shape[0] * utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x * 2, tile_y // 2, overlap) pbar = utils.ProgressBar(steps) decode_fn = lambda a: (self.first_stage_model.decode(1. / self.scale_factor * a.to(self.device)) + 1.0) @@ -566,7 +568,9 @@ class VAE: self.first_stage_model = self.first_stage_model.to(self.device) pixel_samples = pixel_samples.movedim(-1,1).to(self.device) - steps = utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap) + steps = pixel_samples.shape[0] * utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap) + steps += pixel_samples.shape[0] * utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x // 2, tile_y * 2, overlap) + steps += pixel_samples.shape[0] * utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap) pbar = utils.ProgressBar(steps) samples = utils.tiled_scale(pixel_samples, lambda a: self.first_stage_model.encode(2. * a - 1.).sample() * self.scale_factor, tile_x, tile_y, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar) diff --git a/comfy/utils.py b/comfy/utils.py index 5c7143fd9..09e05d4ed 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -1,4 +1,5 @@ import torch +import math def load_torch_file(ckpt, safe_load=False): if ckpt.lower().endswith(".safetensors"): @@ -63,10 +64,7 @@ def common_upscale(samples, width, height, upscale_method, crop): return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method) def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap): - it_1 = -(height // -(tile_y * 2 - overlap)) * -(width // -(tile_x // 2 - overlap)) - it_2 = -(height // -(tile_y // 2 - overlap)) * -(width // -(tile_x * 2 - overlap)) - it_3 = -(height // -(tile_y - overlap)) * -(width // -(tile_x - overlap)) - return it_1 + it_2 + it_3 + return math.ceil((height / (tile_y - overlap))) * math.ceil((width / (tile_x - overlap))) @torch.inference_mode() def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, pbar = None): diff --git a/comfy_extras/nodes_upscale_model.py b/comfy_extras/nodes_upscale_model.py index f774b4b77..ab5b0ccfc 100644 --- a/comfy_extras/nodes_upscale_model.py +++ b/comfy_extras/nodes_upscale_model.py @@ -40,7 +40,7 @@ class ImageUpscaleWithModel: tile = 128 + 64 overlap = 8 - steps = -(in_img.shape[2] // -(tile - overlap)) * -(in_img.shape[3] // -(tile - overlap)) + steps = in_img.shape[0] * comfy.utils.get_tiled_scale_steps(in_img.shape[3], in_img.shape[2], tile_x=tile, tile_y=tile, overlap=overlap) pbar = comfy.utils.ProgressBar(steps) s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar) upscale_model.cpu()