From 881258acb659b8e190b7e9a7fc90f8be9d87b161 Mon Sep 17 00:00:00 2001
From: doctorpangloss <@hiddenswitch.com>
Date: Thu, 9 May 2024 13:24:06 -0700
Subject: [PATCH] Progress bar hooks are now set via an execution context
 rather than by hijacking the server. The context will be used in other
 places too.

---
 comfy/cmd/execution.py                      |   7 +
 comfy/cmd/main.py                           |   2 -
 comfy/cmd/server.py                         |   1 +
 comfy/component_model/executor_types.py     |   1 +
 comfy/distributed/distributed_progress.py   |   4 +-
 comfy/distributed/server_stub.py            |   1 +
 comfy/execution_context.py                  |  30 +++
 comfy/nodes/base_nodes.py                   |   3 +-
 comfy/nodes/vanilla_node_importing.py       |   3 +-
 comfy/utils.py                              | 221 ++++++++++++---------
 comfy_extras/nodes/nodes_custom_sampler.py  |   5 +-
 11 files changed, 176 insertions(+), 102 deletions(-)
 create mode 100644 comfy/execution_context.py
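Notes: the patch replaces the old thread-local hijack_progress hook with an ambient execution context built on the standard library's contextvars module. A ContextVar set inside a context manager is visible to everything that runs under the with block (including async tasks, which copy the current context) and is restored afterwards via the token returned by set(). A minimal self-contained sketch of the pattern, with illustrative names that are not part of this patch:

    from contextlib import contextmanager
    from contextvars import ContextVar
    from typing import Iterator, NamedTuple


    class RequestContext(NamedTuple):
        user: str


    _ctx: ContextVar[RequestContext] = ContextVar("request_context", default=RequestContext(user="nobody"))


    @contextmanager
    def request_context(ctx: RequestContext) -> Iterator[None]:
        token = _ctx.set(ctx)  # push the new ambient value
        try:
            yield
        finally:
            _ctx.reset(token)  # restore whatever was active before


    def whoami() -> str:
        # callers need no extra parameter; the ambient context supplies it
        return _ctx.get().user


    with request_context(RequestContext(user="alice")):
        assert whoami() == "alice"
    assert whoami() == "nobody"  # the default applies outside the block

This is the same shape as comfy/execution_context.py below: PromptExecutor.execute wraps the real work in new_execution_context(ExecutionContext(self.server)), and anything that needs the server (progress bars, previews) reads it back with current_execution_context().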
diff --git a/comfy/cmd/execution.py b/comfy/cmd/execution.py
index 6c7c0afb6..4b82f98e1 100644
--- a/comfy/cmd/execution.py
+++ b/comfy/cmd/execution.py
@@ -21,6 +21,7 @@ from .. import model_management
 from ..component_model.abstract_prompt_queue import AbstractPromptQueue
 from ..component_model.executor_types import ExecutorToClientProgress
 from ..component_model.queue_types import QueueTuple, HistoryEntry, QueueItem, MAXIMUM_HISTORY_SIZE, ExecutionStatus
+from ..execution_context import new_execution_context, ExecutionContext
 from ..nodes.package import import_all_nodes_in_workspace
 from ..nodes.package_typing import ExportedNodes
 
@@ -153,9 +154,11 @@ def recursive_execute(server: ExecutorToClientProgress, prompt_id, outputs_ui, object_storage):
+    span = get_current_span()
     unique_id = current_item
     inputs = prompt[unique_id]['inputs']
     class_type = prompt[unique_id]['class_type']
+    span.set_attribute("class_type", class_type)
     class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
     if unique_id in outputs:
         return (True, None, None)
@@ -374,6 +377,10 @@ class PromptExecutor:
             del d
 
     def execute(self, prompt, prompt_id, extra_data=None, execute_outputs: List[str] = None):
+        with new_execution_context(ExecutionContext(self.server)):
+            self._execute_inner(prompt, prompt_id, extra_data, execute_outputs)
+
+    def _execute_inner(self, prompt, prompt_id, extra_data=None, execute_outputs: List[str] = None):
         if execute_outputs is None:
             execute_outputs = []
         if extra_data is None:
diff --git a/comfy/cmd/main.py b/comfy/cmd/main.py
index efb03e599..48886eef5 100644
--- a/comfy/cmd/main.py
+++ b/comfy/cmd/main.py
@@ -9,7 +9,6 @@ import time
 
 # main_pre must be the earliest import since it suppresses some spurious warnings
 from .main_pre import args
-from ..utils import hijack_progress
 from .extra_model_paths import load_extra_path_config
 from .. import model_management
 from ..analytics.analytics import initialize_event_tracking
@@ -167,7 +166,6 @@ async def main():
     server.prompt_queue = q
 
     server.add_routes()
-    hijack_progress(server)
     cuda_malloc_warning()
 
     # in a distributed setting, the default prompt worker will not be able to send execution events via the websocket
diff --git a/comfy/cmd/server.py b/comfy/cmd/server.py
index fca02d485..5a5d18c49 100644
--- a/comfy/cmd/server.py
+++ b/comfy/cmd/server.py
@@ -102,6 +102,7 @@ class PromptServer(ExecutorToClientProgress):
         self.number: int = 0
         self.port: int = 8188
         self._external_address: Optional[str] = None
+        self.receive_all_progress_notifications = True
 
         middlewares = [cache_control]
         if args.enable_cors_header:
diff --git a/comfy/component_model/executor_types.py b/comfy/component_model/executor_types.py
index 6dfc1e4be..1aa8b1a38 100644
--- a/comfy/component_model/executor_types.py
+++ b/comfy/component_model/executor_types.py
@@ -61,6 +61,7 @@ class ExecutorToClientProgress(Protocol):
     client_id: Optional[str]
     last_node_id: Optional[str]
     last_prompt_id: Optional[str]
+    receive_all_progress_notifications: Optional[bool]
 
     def send_sync(self,
                   event: SendSyncEvent,
diff --git a/comfy/distributed/distributed_progress.py b/comfy/distributed/distributed_progress.py
index 264f385cd..4b8a5c0a0 100644
--- a/comfy/distributed/distributed_progress.py
+++ b/comfy/distributed/distributed_progress.py
@@ -12,7 +12,6 @@ from aio_pika.patterns import RPC
 from ..component_model.executor_types import SendSyncEvent, SendSyncData, ExecutorToClientProgress, \
     UnencodedPreviewImageMessage
 from ..component_model.queue_types import BinaryEventTypes
-from ..utils import hijack_progress
 
 
 async def _progress(event: SendSyncEvent, data: SendSyncData, user_id: Optional[str] = None,
@@ -38,8 +37,7 @@ class DistributedExecutorToClientProgress(ExecutorToClientProgress):
         self.node_id = None
         self.last_node_id = None
         self.last_prompt_id = None
-        if receive_all_progress_notifications:
-            hijack_progress(self)
+        self.receive_all_progress_notifications = receive_all_progress_notifications
 
     async def send(self, event: SendSyncEvent, data: SendSyncData, user_id: Optional[str]) -> None:
         # for now, do not send binary data this way, since it cannot be json serialized / it's impractical
diff --git a/comfy/distributed/server_stub.py b/comfy/distributed/server_stub.py
index 490492a02..ea8ec952e 100644
--- a/comfy/distributed/server_stub.py
+++ b/comfy/distributed/server_stub.py
@@ -16,6 +16,7 @@ class ServerStub(ExecutorToClientProgress):
         self.client_id = str(uuid.uuid4())
         self.last_node_id = None
         self.last_prompt_id = None
+        self.receive_all_progress_notifications = False
 
     def send_sync(self,
                   event: Literal["status", "executing"] | BinaryEventTypes | str | None,
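With hijack_progress gone, opting in or out of fine-grained progress events is a plain attribute on every ExecutorToClientProgress implementation: PromptServer sets receive_all_progress_notifications = True, ServerStub sets False, and DistributedExecutorToClientProgress forwards its constructor flag. A minimal conforming sink in the same spirit as ServerStub; the class is illustrative, and the real protocol's send_sync signature is longer than shown here:

    from typing import Optional


    class NullProgressServer:
        # drops every event; a real implementation forwards them to a client
        def __init__(self) -> None:
            self.client_id: Optional[str] = None
            self.last_node_id: Optional[str] = None
            self.last_prompt_id: Optional[str] = None
            self.receive_all_progress_notifications: Optional[bool] = False  # opt out

        def send_sync(self, event, data, sid: Optional[str] = None) -> None:
            pass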
diff --git a/comfy/execution_context.py b/comfy/execution_context.py
new file mode 100644
index 000000000..15a78f97d
--- /dev/null
+++ b/comfy/execution_context.py
@@ -0,0 +1,30 @@
+from __future__ import annotations
+
+from contextlib import contextmanager
+from contextvars import ContextVar
+from typing import NamedTuple
+
+from comfy.component_model.executor_types import ExecutorToClientProgress
+from comfy.distributed.server_stub import ServerStub
+
+_current_context = ContextVar("comfyui_execution_context")
+
+
+class ExecutionContext(NamedTuple):
+    server: ExecutorToClientProgress
+
+
+_empty_execution_context = ExecutionContext(ServerStub())
+
+
+def current_execution_context() -> ExecutionContext:
+    return _current_context.get(_empty_execution_context)
+
+
+@contextmanager
+def new_execution_context(ctx: ExecutionContext):
+    token = _current_context.set(ctx)
+    try:
+        yield
+    finally:
+        _current_context.reset(token)
diff --git a/comfy/nodes/base_nodes.py b/comfy/nodes/base_nodes.py
index 1a3397479..0a0cc7862 100644
--- a/comfy/nodes/base_nodes.py
+++ b/comfy/nodes/base_nodes.py
@@ -23,6 +23,7 @@ from .. import model_management
 from ..cli_args import args
 from ..cmd import folder_paths, latent_preview
+from ..execution_context import current_execution_context
 from ..images import open_image
 from ..model_downloader import get_filename_list_with_downloadable, get_or_download, KNOWN_CHECKPOINTS, KNOWN_CLIP_VISION_MODELS, KNOWN_GLIGEN_MODELS, KNOWN_UNCLIP_CHECKPOINTS, KNOWN_LORAS, KNOWN_CONTROLNETS, KNOWN_DIFF_CONTROLNETS, KNOWN_VAES, KNOWN_APPROX_VAES
 from ..nodes.common import MAX_RESOLUTION
@@ -1297,7 +1298,7 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
         noise_mask = latent["noise_mask"]
 
     callback = latent_preview.prepare_callback(model, steps)
-    disable_pbar = not utils.PROGRESS_BAR_ENABLED
+    disable_pbar = not current_execution_context().server.receive_all_progress_notifications
     samples = sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image,
                             denoise=denoise, disable_noise=disable_noise, start_step=start_step, last_step=last_step,
                             force_full_denoise=force_full_denoise, noise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
diff --git a/comfy/nodes/vanilla_node_importing.py b/comfy/nodes/vanilla_node_importing.py
index b6e40b308..a08198f72 100644
--- a/comfy/nodes/vanilla_node_importing.py
+++ b/comfy/nodes/vanilla_node_importing.py
@@ -162,7 +162,8 @@ def mitigated_import_of_vanilla_custom_nodes() -> ExportedNodes:
     # found. we're in the middle of executing the import of execution and server, in all likelihood, so like all things,
     # the way community custom nodes is pretty radioactive
     from ..cmd import cuda_malloc, folder_paths, latent_preview
-    for module in (cuda_malloc, folder_paths, latent_preview):
+    from .. import node_helpers
+    for module in (cuda_malloc, folder_paths, latent_preview, node_helpers):
         module_short_name = module.__name__.split(".")[-1]
         sys.modules[module_short_name] = module
     sys.modules['nodes'] = base_nodes
diff --git a/comfy/utils.py b/comfy/utils.py
index f26b8b834..75e56c1db 100644
--- a/comfy/utils.py
+++ b/comfy/utils.py
@@ -1,26 +1,39 @@
-import os.path
-import threading
-from contextlib import contextmanager
-
-import torch
+import logging
 import math
+import os.path
 import struct
+import warnings
+from contextlib import contextmanager
+from typing import Optional
+
+import numpy as np
+import safetensors.torch
+import torch
+from PIL import Image
 from tqdm import tqdm
 
 from . import checkpoint_pickle, interruption
-import safetensors.torch
-import numpy as np
-from PIL import Image
-import logging
-
-from .component_model.executor_types import ExecutorToClientProgress
 from .component_model.queue_types import BinaryEventTypes
-
-PROGRESS_BAR_ENABLED = True
-_progress_bar_hook = threading.local()
+from .execution_context import current_execution_context
+
+
+# deprecate PROGRESS_BAR_ENABLED
+def _get_progress_bar_enabled():
+    warnings.warn(
+        "The global variable 'PROGRESS_BAR_ENABLED' is deprecated and will be removed in a future version. Use current_execution_context().server.receive_all_progress_notifications instead.",
+        DeprecationWarning,
+        stacklevel=2
+    )
+    return current_execution_context().server.receive_all_progress_notifications
+
+
+# PEP 562 module __getattr__: a property() assigned onto a module object is
+# never invoked, so reads of utils.PROGRESS_BAR_ENABLED go through here instead
+def __getattr__(name):
+    if name == "PROGRESS_BAR_ENABLED":
+        return _get_progress_bar_enabled()
+    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
 
 
 def load_torch_file(ckpt, safe_load=False, device=None):
     if device is None:
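One wrinkle in keeping the old module global importable: modules are instances, and property objects are only invoked when they live on a class, so assigning a property onto sys.modules[__name__] would hand callers the descriptor itself rather than the current flag. PEP 562 (Python 3.7+) exists for exactly this case: a module-level __getattr__ runs whenever normal lookup on the module fails. A generic sketch of the idiom, where mymodule and _FLAGS are hypothetical:

    # mymodule.py -- hypothetical module showing the PEP 562 deprecation shim
    import warnings

    _FLAGS = {"ENABLED": True}


    def __getattr__(name):
        # called only when normal attribute lookup on the module fails
        if name in _FLAGS:
            warnings.warn(f"mymodule.{name} is deprecated", DeprecationWarning, stacklevel=2)
            return _FLAGS[name]
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

Reads of mymodule.ENABLED keep working and emit the warning; attribute writes still bypass the shim, which is why set_progress_bar_enabled below also gets a deprecation path.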
@@ -46,12 +54,14 @@ def load_torch_file(ckpt, safe_load=False, device=None):
             sd = pl_sd
     return sd
 
+
 def save_torch_file(sd, ckpt, metadata=None):
     if metadata is not None:
         safetensors.torch.save_file(sd, ckpt, metadata=metadata)
     else:
         safetensors.torch.save_file(sd, ckpt)
 
+
 def calculate_parameters(sd, prefix=""):
     params = 0
     for k in sd.keys():
@@ -59,12 +69,14 @@ def calculate_parameters(sd, prefix=""):
             params += sd[k].nelement()
     return params
 
+
 def state_dict_key_replace(state_dict, keys_to_replace):
     for x in keys_to_replace:
         if x in state_dict:
             state_dict[keys_to_replace[x]] = state_dict.pop(x)
     return state_dict
 
+
 def state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=False):
     if filter_keys:
         out = {}
@@ -115,10 +127,11 @@ def transformers_convert(sd, prefix_from, prefix_to, number):
         for x in range(3):
             p = ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"]
             k_to = "{}encoder.layers.{}.{}.{}".format(prefix_to, resblock, p[x], y)
-            sd[k_to] = weights[shape_from*x:shape_from*(x + 1)]
+            sd[k_to] = weights[shape_from * x:shape_from * (x + 1)]
 
     return sd
 
+
 def clip_text_transformers_convert(sd, prefix_from, prefix_to):
     sd = transformers_convert(sd, prefix_from, "{}text_model.".format(prefix_to), 32)
@@ -200,6 +213,7 @@ UNET_MAP_BASIC = {
     ("time_embed.2.bias", "time_embedding.linear_2.bias")
 }
 
+
 def unet_to_diffusers(unet_config):
     if "num_res_blocks" not in unet_config:
         return {}
@@ -266,6 +280,7 @@ def unet_to_diffusers(unet_config):
 
     return diffusers_unet_map
 
+
 def repeat_to_batch_size(tensor, batch_size):
     if tensor.shape[0] > batch_size:
         return tensor[:batch_size]
@@ -273,6 +288,7 @@ def repeat_to_batch_size(tensor, batch_size):
         return tensor.repeat([math.ceil(batch_size / tensor.shape[0])] + [1] * (len(tensor.shape) - 1))[:batch_size]
     return tensor
 
+
 def resize_to_batch_size(tensor, batch_size):
     in_batch_size = tensor.shape[0]
     if in_batch_size == batch_size:
@@ -293,13 +309,15 @@ def resize_to_batch_size(tensor, batch_size):
 
     return output
 
+
 def convert_sd_to(state_dict, dtype):
     keys = list(state_dict.keys())
     for k in keys:
         state_dict[k] = state_dict[k].to(dtype)
     return state_dict
 
-def safetensors_header(safetensors_path, max_size=100*1024*1024):
+
+def safetensors_header(safetensors_path, max_size=100 * 1024 * 1024):
     with open(safetensors_path, "rb") as f:
         header = f.read(8)
         length_of_header = struct.unpack('<Q', header)[0]
         if length_of_header > max_size:
             return None
         return f.read(length_of_header)
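safetensors_header relies on the file format's fixed prologue: the first 8 bytes are a little-endian unsigned 64-bit length, followed by that many bytes of JSON metadata. A standalone reader sketch under that assumption; the function name, path handling, and error policy here are illustrative:

    import json
    import struct


    def read_safetensors_metadata(path: str, max_size: int = 100 * 1024 * 1024) -> dict:
        with open(path, "rb") as f:
            (length,) = struct.unpack('<Q', f.read(8))  # u64 little-endian header length
            if length > max_size:
                raise ValueError("header larger than the sanity limit")
            return json.loads(f.read(length))  # tensor names, dtypes, shapes, offsets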
[diff content between safetensors_header and the tail of the slerp() helper inside bislerp() was lost in extraction and could not be recovered]
@@ ... @@ def bislerp(samples, width, height):
         res[dot > 1 - 1e-5] = b1[dot > 1 - 1e-5]
-        res[dot < 1e-5 - 1] = (b1 * (1.0-r) + b2 * r)[dot < 1e-5 - 1]
+        res[dot < 1e-5 - 1] = (b1 * (1.0 - r) + b2 * r)[dot < 1e-5 - 1]
         return res
 
     def generate_bilinear_data(length_old, length_new, device):
-        coords_1 = torch.arange(length_old, dtype=torch.float32, device=device).reshape((1,1,1,-1))
+        coords_1 = torch.arange(length_old, dtype=torch.float32, device=device).reshape((1, 1, 1, -1))
         coords_1 = torch.nn.functional.interpolate(coords_1, size=(1, length_new), mode="bilinear")
         ratios = coords_1 - coords_1.floor()
         coords_1 = coords_1.to(torch.int64)
-        coords_2 = torch.arange(length_old, dtype=torch.float32, device=device).reshape((1,1,1,-1)) + 1
-        coords_2[:,:,:,-1] -= 1
+        coords_2 = torch.arange(length_old, dtype=torch.float32, device=device).reshape((1, 1, 1, -1)) + 1
+        coords_2[:, :, :, -1] -= 1
         coords_2 = torch.nn.functional.interpolate(coords_2, size=(1, length_new), mode="bilinear")
         coords_2 = coords_2.to(torch.int64)
         return ratios, coords_1, coords_2
 
     orig_dtype = samples.dtype
     samples = samples.float()
-    n,c,h,w = samples.shape
+    n, c, h, w = samples.shape
     h_new, w_new = (height, width)
 
-    #linear w
+    # linear w
     ratios, coords_1, coords_2 = generate_bilinear_data(w, w_new, samples.device)
     coords_1 = coords_1.expand((n, c, h, -1))
     coords_2 = coords_2.expand((n, c, h, -1))
     ratios = ratios.expand((n, 1, h, -1))
 
-    pass_1 = samples.gather(-1,coords_1).movedim(1, -1).reshape((-1,c))
-    pass_2 = samples.gather(-1,coords_2).movedim(1, -1).reshape((-1,c))
-    ratios = ratios.movedim(1, -1).reshape((-1,1))
+    pass_1 = samples.gather(-1, coords_1).movedim(1, -1).reshape((-1, c))
+    pass_2 = samples.gather(-1, coords_2).movedim(1, -1).reshape((-1, c))
+    ratios = ratios.movedim(1, -1).reshape((-1, 1))
 
     result = slerp(pass_1, pass_2, ratios)
     result = result.reshape(n, h, w_new, c).movedim(-1, 1)
 
-    #linear h
+    # linear h
     ratios, coords_1, coords_2 = generate_bilinear_data(h, h_new, samples.device)
-    coords_1 = coords_1.reshape((1,1,-1,1)).expand((n, c, -1, w_new))
-    coords_2 = coords_2.reshape((1,1,-1,1)).expand((n, c, -1, w_new))
-    ratios = ratios.reshape((1,1,-1,1)).expand((n, 1, -1, w_new))
+    coords_1 = coords_1.reshape((1, 1, -1, 1)).expand((n, c, -1, w_new))
+    coords_2 = coords_2.reshape((1, 1, -1, 1)).expand((n, c, -1, w_new))
+    ratios = ratios.reshape((1, 1, -1, 1)).expand((n, 1, -1, w_new))
 
-    pass_1 = result.gather(-2,coords_1).movedim(1, -1).reshape((-1,c))
-    pass_2 = result.gather(-2,coords_2).movedim(1, -1).reshape((-1,c))
-    ratios = ratios.movedim(1, -1).reshape((-1,1))
+    pass_1 = result.gather(-2, coords_1).movedim(1, -1).reshape((-1, c))
+    pass_2 = result.gather(-2, coords_2).movedim(1, -1).reshape((-1, c))
+    ratios = ratios.movedim(1, -1).reshape((-1, 1))
 
     result = slerp(pass_1, pass_2, ratios)
     result = result.reshape(n, h_new, w_new, c).movedim(-1, 1)
     return result.to(orig_dtype)
 
+
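For orientation: bislerp resamples in two separable passes, slerp across neighboring columns and then across neighboring rows. The helper's interpolation (partly lost to extraction above) follows the standard spherical form: writing u1, u2 for b1, b2 normalized and w = arccos(u1 . u2), slerp(b1, b2; r) = sin((1 - r) * w) / sin(w) * u1 + sin(r * w) / sin(w) * u2, rescaled by the linearly interpolated norms. The surviving edge-case lines return b1 outright when the vectors are nearly identical and fall back to a plain lerp when they are nearly antipodal, where sin(w) approaches zero and the formula becomes numerically unstable.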
 def lanczos(samples, width, height):
     images = [Image.fromarray(np.clip(255. * image.movedim(0, -1).cpu().numpy(), 0, 255).astype(np.uint8)) for image in samples]
     images = [image.resize((width, height), resample=Image.Resampling.LANCZOS) for image in images]
@@ -415,81 +439,95 @@ def lanczos(samples, width, height):
     result = torch.stack(images)
     return result.to(samples.device, samples.dtype)
 
-def common_upscale(samples, width, height, upscale_method, crop):
-    if crop == "center":
-        old_width = samples.shape[3]
-        old_height = samples.shape[2]
-        old_aspect = old_width / old_height
-        new_aspect = width / height
-        x = 0
-        y = 0
-        if old_aspect > new_aspect:
-            x = round((old_width - old_width * (new_aspect / old_aspect)) / 2)
-        elif old_aspect < new_aspect:
-            y = round((old_height - old_height * (old_aspect / new_aspect)) / 2)
-        s = samples[:,:,y:old_height-y,x:old_width-x]
-    else:
-        s = samples
-    if upscale_method == "bislerp":
-        return bislerp(s, width, height)
-    elif upscale_method == "lanczos":
-        return lanczos(s, width, height)
-    else:
-        return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method)
+
+def common_upscale(samples, width, height, upscale_method, crop):
+    if crop == "center":
+        old_width = samples.shape[3]
+        old_height = samples.shape[2]
+        old_aspect = old_width / old_height
+        new_aspect = width / height
+        x = 0
+        y = 0
+        if old_aspect > new_aspect:
+            x = round((old_width - old_width * (new_aspect / old_aspect)) / 2)
+        elif old_aspect < new_aspect:
+            y = round((old_height - old_height * (old_aspect / new_aspect)) / 2)
+        s = samples[:, :, y:old_height - y, x:old_width - x]
+    else:
+        s = samples
+
+    if upscale_method == "bislerp":
+        return bislerp(s, width, height)
+    elif upscale_method == "lanczos":
+        return lanczos(s, width, height)
+    else:
+        return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method)
+
 
 def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap):
     return math.ceil((height / (tile_y - overlap))) * math.ceil((width / (tile_x - overlap)))
 
+
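get_tiled_scale_steps counts the tiles the loop below will visit: tiles advance by (tile - overlap) along each axis, so the total is ceil(height / (tile_y - overlap)) * ceil(width / (tile_x - overlap)). A quick worked check with the same defaults tiled_scale uses (the input sizes are hypothetical):

    import math

    def steps(width: int, height: int, tile_x: int = 64, tile_y: int = 64, overlap: int = 8) -> int:
        return math.ceil(height / (tile_y - overlap)) * math.ceil(width / (tile_x - overlap))

    # a 512x512 input with 64-pixel tiles and 8 overlap: ceil(512/56)**2 == 10 * 10
    assert steps(512, 512) == 100

This is the number a caller would pass as the progress bar total so the bar completes exactly when tiling does.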
 @torch.inference_mode()
-def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None):
+def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap=8, upscale_amount=4, out_channels=3, output_device="cpu", pbar=None):
     output = torch.empty((samples.shape[0], out_channels, round(samples.shape[2] * upscale_amount), round(samples.shape[3] * upscale_amount)), device=output_device)
     for b in range(samples.shape[0]):
-        s = samples[b:b+1]
+        s = samples[b:b + 1]
         out = torch.zeros((s.shape[0], out_channels, round(s.shape[2] * upscale_amount), round(s.shape[3] * upscale_amount)), device=output_device)
         out_div = torch.zeros((s.shape[0], out_channels, round(s.shape[2] * upscale_amount), round(s.shape[3] * upscale_amount)), device=output_device)
         for y in range(0, s.shape[2], tile_y - overlap):
             for x in range(0, s.shape[3], tile_x - overlap):
                 x = max(0, min(s.shape[-1] - overlap, x))
                 y = max(0, min(s.shape[-2] - overlap, y))
-                s_in = s[:,:,y:y+tile_y,x:x+tile_x]
+                s_in = s[:, :, y:y + tile_y, x:x + tile_x]
 
                 ps = function(s_in).to(output_device)
                 mask = torch.ones_like(ps)
                 feather = round(overlap * upscale_amount)
                 for t in range(feather):
-                    mask[:,:,t:1+t,:] *= ((1.0/feather) * (t + 1))
-                    mask[:,:,mask.shape[2] -1 -t: mask.shape[2]-t,:] *= ((1.0/feather) * (t + 1))
-                    mask[:,:,:,t:1+t] *= ((1.0/feather) * (t + 1))
-                    mask[:,:,:,mask.shape[3]- 1 - t: mask.shape[3]- t] *= ((1.0/feather) * (t + 1))
-                out[:,:,round(y*upscale_amount):round((y+tile_y)*upscale_amount),round(x*upscale_amount):round((x+tile_x)*upscale_amount)] += ps * mask
-                out_div[:,:,round(y*upscale_amount):round((y+tile_y)*upscale_amount),round(x*upscale_amount):round((x+tile_x)*upscale_amount)] += mask
+                    mask[:, :, t:1 + t, :] *= ((1.0 / feather) * (t + 1))
+                    mask[:, :, mask.shape[2] - 1 - t: mask.shape[2] - t, :] *= ((1.0 / feather) * (t + 1))
+                    mask[:, :, :, t:1 + t] *= ((1.0 / feather) * (t + 1))
+                    mask[:, :, :, mask.shape[3] - 1 - t: mask.shape[3] - t] *= ((1.0 / feather) * (t + 1))
+                out[:, :, round(y * upscale_amount):round((y + tile_y) * upscale_amount), round(x * upscale_amount):round((x + tile_x) * upscale_amount)] += ps * mask
+                out_div[:, :, round(y * upscale_amount):round((y + tile_y) * upscale_amount), round(x * upscale_amount):round((x + tile_x) * upscale_amount)] += mask
 
                 if pbar is not None:
                     pbar.update(1)
 
-        output[b:b+1] = out/out_div
+        output[b:b + 1] = out / out_div
     return output
 
 
-def hijack_progress(server: ExecutorToClientProgress):
-    def hook(value: float, total: float, preview_image):
-        interruption.throw_exception_if_processing_interrupted()
-        progress = {"value": value, "max": total, "prompt_id": server.last_prompt_id, "node": server.last_node_id}
-
-        server.send_sync("progress", progress, server.client_id)
-        if preview_image is not None:
-            server.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, server.client_id)
-
-    _progress_bar_hook.hook = hook
+def _progress_bar_update(value: float, total: float, preview_image, client_id: Optional[str] = None):
+    server = current_execution_context().server
+    # todo: this should really be from the context. right now the server is behaving like a context
+    client_id = client_id or server.client_id
+    interruption.throw_exception_if_processing_interrupted()
+    progress = {"value": value, "max": total, "prompt_id": server.last_prompt_id, "node": server.last_node_id}
 
+    server.send_sync("progress", progress, client_id)
+    if preview_image is not None:
+        server.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, client_id)
 
-def set_progress_bar_enabled(enabled):
-    global PROGRESS_BAR_ENABLED
-    PROGRESS_BAR_ENABLED = enabled
+
+def set_progress_bar_enabled(enabled: bool):
+    warnings.warn(
+        "The global method 'set_progress_bar_enabled' is deprecated and will be removed in a future version. Use current_execution_context().server.receive_all_progress_notifications instead.",
+        DeprecationWarning,
+        stacklevel=2
+    )
+    current_execution_context().server.receive_all_progress_notifications = enabled
 
 
 def get_progress_bar_enabled() -> bool:
-    return PROGRESS_BAR_ENABLED
+    warnings.warn(
+        "The global method 'get_progress_bar_enabled' is deprecated and will be removed in a future version. Use current_execution_context().server.receive_all_progress_notifications instead.",
+        DeprecationWarning,
+        stacklevel=2
+    )
+    return current_execution_context().server.receive_all_progress_notifications
 
 
 class _DisabledProgressBar:
@@ -505,10 +543,8 @@ class _DisabledProgressBar:
 
 class ProgressBar:
     def __init__(self, total: float):
-        global _progress_bar_hook
         self.total: float = total
         self.current: float = 0.0
-        self.hook = _progress_bar_hook.hook if hasattr(_progress_bar_hook, "hook") else None
 
     def update_absolute(self, value, total=None, preview=None):
         if total is not None:
@@ -516,8 +552,7 @@ class ProgressBar:
         if value > self.total:
             value = self.total
         self.current = value
-        if self.hook is not None:
-            self.hook(self.current, self.total, preview)
+        _progress_bar_update(self.current, self.total, preview)
 
     def update(self, value):
         self.update_absolute(self.current + value)
@@ -556,8 +591,8 @@ def comfy_tqdm():
 
 @contextmanager
 def comfy_progress(total: float) -> ProgressBar:
-    global PROGRESS_BAR_ENABLED
-    if PROGRESS_BAR_ENABLED:
+    ctx = current_execution_context()
+    if ctx.server.receive_all_progress_notifications:
         yield ProgressBar(total)
     else:
         yield _DisabledProgressBar()
diff --git a/comfy_extras/nodes/nodes_custom_sampler.py b/comfy_extras/nodes/nodes_custom_sampler.py
index f3a7d7e37..7baaf5307 100644
--- a/comfy_extras/nodes/nodes_custom_sampler.py
+++ b/comfy_extras/nodes/nodes_custom_sampler.py
@@ -2,6 +2,7 @@ import comfy.sampler_names
 from comfy import samplers
 from comfy import model_management
 from comfy import sample
+from comfy.execution_context import current_execution_context
 from comfy.k_diffusion import sampling as k_diffusion_sampling
 from comfy.cmd import latent_preview
 import torch
@@ -416,7 +417,7 @@ class SamplerCustom:
         x0_output = {}
         callback = latent_preview.prepare_callback(model, sigmas.shape[-1] - 1, x0_output)
 
-        disable_pbar = not utils.PROGRESS_BAR_ENABLED
+        disable_pbar = not current_execution_context().server.receive_all_progress_notifications
         samples = sample.sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=noise_seed)
 
         out = latent.copy()
@@ -570,7 +571,7 @@ class SamplerCustomAdvanced:
         x0_output = {}
         callback = latent_preview.prepare_callback(guider.model_patcher, sigmas.shape[-1] - 1, x0_output)
 
-        disable_pbar = not comfy.utils.PROGRESS_BAR_ENABLED
+        disable_pbar = not current_execution_context().server.receive_all_progress_notifications
         samples = guider.sample(noise.generate_noise(latent), latent_image, sampler, sigmas, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=noise.seed)
         samples = samples.to(comfy.model_management.intermediate_device())
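After this patch, node code no longer reads or toggles a module global; it derives progress behavior from the ambient context. A sketch of what a node body looks like under the new scheme; the node function itself is hypothetical, while comfy_progress and current_execution_context are the helpers from the patch:

    from comfy.execution_context import current_execution_context
    from comfy.utils import comfy_progress


    def run_many_steps(steps: int):
        # disable a sampler's own bar when the server opted out of progress events
        disable_pbar = not current_execution_context().server.receive_all_progress_notifications

        with comfy_progress(total=steps) as pbar:
            for _ in range(steps):
                ...  # one unit of work
                pbar.update(1)  # routed to the context's server, or dropped if disabled

Because the context is established by PromptExecutor.execute, the same node code reports to the websocket server in a local install and to the distributed progress implementation in a worker, with no import-time coupling to either.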