Progress bar hooks, via the server, are now set via a context. This will be used in other places too.

This commit is contained in:
doctorpangloss 2024-05-09 13:24:06 -07:00
parent d6c374942e
commit 881258acb6
11 changed files with 176 additions and 102 deletions

View File

@ -21,6 +21,7 @@ from .. import model_management
from ..component_model.abstract_prompt_queue import AbstractPromptQueue
from ..component_model.executor_types import ExecutorToClientProgress
from ..component_model.queue_types import QueueTuple, HistoryEntry, QueueItem, MAXIMUM_HISTORY_SIZE, ExecutionStatus
from ..execution_context import new_execution_context, ExecutionContext
from ..nodes.package import import_all_nodes_in_workspace
from ..nodes.package_typing import ExportedNodes
@ -153,9 +154,11 @@ def recursive_execute(server: ExecutorToClientProgress,
prompt_id,
outputs_ui,
object_storage):
span = get_current_span()
unique_id = current_item
inputs = prompt[unique_id]['inputs']
class_type = prompt[unique_id]['class_type']
span.set_attribute("class_type", class_type)
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
if unique_id in outputs:
return (True, None, None)
@ -374,6 +377,10 @@ class PromptExecutor:
del d
def execute(self, prompt, prompt_id, extra_data=None, execute_outputs: List[str] = None):
with new_execution_context(ExecutionContext(self.server)):
self._execute_inner(prompt, prompt_id, extra_data, execute_outputs)
def _execute_inner(self, prompt, prompt_id, extra_data=None, execute_outputs: List[str] = None):
if execute_outputs is None:
execute_outputs = []
if extra_data is None:

View File

@ -9,7 +9,6 @@ import time
# main_pre must be the earliest import since it suppresses some spurious warnings
from .main_pre import args
from ..utils import hijack_progress
from .extra_model_paths import load_extra_path_config
from .. import model_management
from ..analytics.analytics import initialize_event_tracking
@ -167,7 +166,6 @@ async def main():
server.prompt_queue = q
server.add_routes()
hijack_progress(server)
cuda_malloc_warning()
# in a distributed setting, the default prompt worker will not be able to send execution events via the websocket

View File

@ -102,6 +102,7 @@ class PromptServer(ExecutorToClientProgress):
self.number: int = 0
self.port: int = 8188
self._external_address: Optional[str] = None
self.receive_all_progress_notifications = True
middlewares = [cache_control]
if args.enable_cors_header:

View File

@ -61,6 +61,7 @@ class ExecutorToClientProgress(Protocol):
client_id: Optional[str]
last_node_id: Optional[str]
last_prompt_id: Optional[str]
receive_all_progress_notifications: Optional[bool]
def send_sync(self,
event: SendSyncEvent,

View File

@ -12,7 +12,6 @@ from aio_pika.patterns import RPC
from ..component_model.executor_types import SendSyncEvent, SendSyncData, ExecutorToClientProgress, \
UnencodedPreviewImageMessage
from ..component_model.queue_types import BinaryEventTypes
from ..utils import hijack_progress
async def _progress(event: SendSyncEvent, data: SendSyncData, user_id: Optional[str] = None,
@ -38,8 +37,7 @@ class DistributedExecutorToClientProgress(ExecutorToClientProgress):
self.node_id = None
self.last_node_id = None
self.last_prompt_id = None
if receive_all_progress_notifications:
hijack_progress(self)
self.receive_all_progress_notifications = receive_all_progress_notifications
async def send(self, event: SendSyncEvent, data: SendSyncData, user_id: Optional[str]) -> None:
# for now, do not send binary data this way, since it cannot be json serialized / it's impractical

View File

@ -16,6 +16,7 @@ class ServerStub(ExecutorToClientProgress):
self.client_id = str(uuid.uuid4())
self.last_node_id = None
self.last_prompt_id = None
self.receive_all_progress_notifications = False
def send_sync(self,
event: Literal["status", "executing"] | BinaryEventTypes | str | None,

View File

@ -0,0 +1,30 @@
from __future__ import annotations
from contextlib import contextmanager
from contextvars import ContextVar
from typing import NamedTuple
from comfy.component_model.executor_types import ExecutorToClientProgress
from comfy.distributed.server_stub import ServerStub
_current_context = ContextVar("comfyui_execution_context")
class ExecutionContext(NamedTuple):
server: ExecutorToClientProgress
_empty_execution_context = ExecutionContext(ServerStub())
def current_execution_context() -> ExecutionContext:
return _current_context.get()
@contextmanager
def new_execution_context(ctx: ExecutionContext):
token = _current_context.set(ctx)
try:
yield
finally:
_current_context.reset(token)

View File

@ -23,6 +23,7 @@ from .. import model_management
from ..cli_args import args
from ..cmd import folder_paths, latent_preview
from ..execution_context import current_execution_context
from ..images import open_image
from ..model_downloader import get_filename_list_with_downloadable, get_or_download, KNOWN_CHECKPOINTS, KNOWN_CLIP_VISION_MODELS, KNOWN_GLIGEN_MODELS, KNOWN_UNCLIP_CHECKPOINTS, KNOWN_LORAS, KNOWN_CONTROLNETS, KNOWN_DIFF_CONTROLNETS, KNOWN_VAES, KNOWN_APPROX_VAES
from ..nodes.common import MAX_RESOLUTION
@ -1297,7 +1298,7 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
noise_mask = latent["noise_mask"]
callback = latent_preview.prepare_callback(model, steps)
disable_pbar = not utils.PROGRESS_BAR_ENABLED
disable_pbar = not current_execution_context().server.receive_all_progress_notifications
samples = sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image,
denoise=denoise, disable_noise=disable_noise, start_step=start_step, last_step=last_step,
force_full_denoise=force_full_denoise, noise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)

View File

@ -162,7 +162,8 @@ def mitigated_import_of_vanilla_custom_nodes() -> ExportedNodes:
# found. we're in the middle of executing the import of execution and server, in all likelihood, so like all things,
# the way community custom nodes is pretty radioactive
from ..cmd import cuda_malloc, folder_paths, latent_preview
for module in (cuda_malloc, folder_paths, latent_preview):
from .. import node_helpers
for module in (cuda_malloc, folder_paths, latent_preview, node_helpers):
module_short_name = module.__name__.split(".")[-1]
sys.modules[module_short_name] = module
sys.modules['nodes'] = base_nodes

View File

@ -1,26 +1,34 @@
import os.path
import threading
from contextlib import contextmanager
import torch
import logging
import math
import os.path
import struct
import sys
import warnings
from contextlib import contextmanager
from typing import Optional
import numpy as np
import safetensors.torch
import torch
from PIL import Image
from tqdm import tqdm
from . import checkpoint_pickle, interruption
import safetensors.torch
import numpy as np
from PIL import Image
import logging
from .component_model.executor_types import ExecutorToClientProgress
from .component_model.queue_types import BinaryEventTypes
PROGRESS_BAR_ENABLED = True
_progress_bar_hook = threading.local()
from .execution_context import current_execution_context
# deprecate PROGRESS_BAR_ENABLED
def _get_progress_bar_enabled():
warnings.warn(
"The global variable 'PROGRESS_BAR_ENABLED' is deprecated and will be removed in a future version. Use current_execution_context().server.receive_all_progress_notifications instead.",
DeprecationWarning,
stacklevel=2
)
return current_execution_context().server.receive_all_progress_notifications
setattr(sys.modules[__name__], 'PROGRESS_BAR_ENABLED', property(_get_progress_bar_enabled))
def load_torch_file(ckpt, safe_load=False, device=None):
if device is None:
@ -46,12 +54,14 @@ def load_torch_file(ckpt, safe_load=False, device=None):
sd = pl_sd
return sd
def save_torch_file(sd, ckpt, metadata=None):
if metadata is not None:
safetensors.torch.save_file(sd, ckpt, metadata=metadata)
else:
safetensors.torch.save_file(sd, ckpt)
def calculate_parameters(sd, prefix=""):
params = 0
for k in sd.keys():
@ -59,12 +69,14 @@ def calculate_parameters(sd, prefix=""):
params += sd[k].nelement()
return params
def state_dict_key_replace(state_dict, keys_to_replace):
for x in keys_to_replace:
if x in state_dict:
state_dict[keys_to_replace[x]] = state_dict.pop(x)
return state_dict
def state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=False):
if filter_keys:
out = {}
@ -115,10 +127,11 @@ def transformers_convert(sd, prefix_from, prefix_to, number):
for x in range(3):
p = ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"]
k_to = "{}encoder.layers.{}.{}.{}".format(prefix_to, resblock, p[x], y)
sd[k_to] = weights[shape_from*x:shape_from*(x + 1)]
sd[k_to] = weights[shape_from * x:shape_from * (x + 1)]
return sd
def clip_text_transformers_convert(sd, prefix_from, prefix_to):
sd = transformers_convert(sd, prefix_from, "{}text_model.".format(prefix_to), 32)
@ -200,6 +213,7 @@ UNET_MAP_BASIC = {
("time_embed.2.bias", "time_embedding.linear_2.bias")
}
def unet_to_diffusers(unet_config):
if "num_res_blocks" not in unet_config:
return {}
@ -266,6 +280,7 @@ def unet_to_diffusers(unet_config):
return diffusers_unet_map
def repeat_to_batch_size(tensor, batch_size):
if tensor.shape[0] > batch_size:
return tensor[:batch_size]
@ -273,6 +288,7 @@ def repeat_to_batch_size(tensor, batch_size):
return tensor.repeat([math.ceil(batch_size / tensor.shape[0])] + [1] * (len(tensor.shape) - 1))[:batch_size]
return tensor
def resize_to_batch_size(tensor, batch_size):
in_batch_size = tensor.shape[0]
if in_batch_size == batch_size:
@ -293,13 +309,15 @@ def resize_to_batch_size(tensor, batch_size):
return output
def convert_sd_to(state_dict, dtype):
keys = list(state_dict.keys())
for k in keys:
state_dict[k] = state_dict[k].to(dtype)
return state_dict
def safetensors_header(safetensors_path, max_size=100*1024*1024):
def safetensors_header(safetensors_path, max_size=100 * 1024 * 1024):
with open(safetensors_path, "rb") as f:
header = f.read(8)
length_of_header = struct.unpack('<Q', header)[0]
@ -307,6 +325,7 @@ def safetensors_header(safetensors_path, max_size=100*1024*1024):
return None
return f.read(length_of_header)
def set_attr(obj, attr, value):
attrs = attr.split(".")
for name in attrs[:-1]:
@ -315,9 +334,11 @@ def set_attr(obj, attr, value):
setattr(obj, attrs[-1], value)
return prev
def set_attr_param(obj, attr, value):
return set_attr(obj, attr, torch.nn.Parameter(value, requires_grad=False))
def copy_to_param(obj, attr, value):
# inplace update tensor instead of replacing it
attrs = attr.split(".")
@ -326,88 +347,91 @@ def copy_to_param(obj, attr, value):
prev = getattr(obj, attrs[-1])
prev.data.copy_(value)
def get_attr(obj, attr):
attrs = attr.split(".")
for name in attrs:
obj = getattr(obj, name)
return obj
def bislerp(samples, width, height):
def slerp(b1, b2, r):
'''slerps batches b1, b2 according to ratio r, batches should be flat e.g. NxC'''
c = b1.shape[-1]
#norms
# norms
b1_norms = torch.norm(b1, dim=-1, keepdim=True)
b2_norms = torch.norm(b2, dim=-1, keepdim=True)
#normalize
# normalize
b1_normalized = b1 / b1_norms
b2_normalized = b2 / b2_norms
#zero when norms are zero
b1_normalized[b1_norms.expand(-1,c) == 0.0] = 0.0
b2_normalized[b2_norms.expand(-1,c) == 0.0] = 0.0
# zero when norms are zero
b1_normalized[b1_norms.expand(-1, c) == 0.0] = 0.0
b2_normalized[b2_norms.expand(-1, c) == 0.0] = 0.0
#slerp
dot = (b1_normalized*b2_normalized).sum(1)
# slerp
dot = (b1_normalized * b2_normalized).sum(1)
omega = torch.acos(dot)
so = torch.sin(omega)
#technically not mathematically correct, but more pleasing?
res = (torch.sin((1.0-r.squeeze(1))*omega)/so).unsqueeze(1)*b1_normalized + (torch.sin(r.squeeze(1)*omega)/so).unsqueeze(1) * b2_normalized
res *= (b1_norms * (1.0-r) + b2_norms * r).expand(-1,c)
# technically not mathematically correct, but more pleasing?
res = (torch.sin((1.0 - r.squeeze(1)) * omega) / so).unsqueeze(1) * b1_normalized + (torch.sin(r.squeeze(1) * omega) / so).unsqueeze(1) * b2_normalized
res *= (b1_norms * (1.0 - r) + b2_norms * r).expand(-1, c)
#edge cases for same or polar opposites
# edge cases for same or polar opposites
res[dot > 1 - 1e-5] = b1[dot > 1 - 1e-5]
res[dot < 1e-5 - 1] = (b1 * (1.0-r) + b2 * r)[dot < 1e-5 - 1]
res[dot < 1e-5 - 1] = (b1 * (1.0 - r) + b2 * r)[dot < 1e-5 - 1]
return res
def generate_bilinear_data(length_old, length_new, device):
coords_1 = torch.arange(length_old, dtype=torch.float32, device=device).reshape((1,1,1,-1))
coords_1 = torch.arange(length_old, dtype=torch.float32, device=device).reshape((1, 1, 1, -1))
coords_1 = torch.nn.functional.interpolate(coords_1, size=(1, length_new), mode="bilinear")
ratios = coords_1 - coords_1.floor()
coords_1 = coords_1.to(torch.int64)
coords_2 = torch.arange(length_old, dtype=torch.float32, device=device).reshape((1,1,1,-1)) + 1
coords_2[:,:,:,-1] -= 1
coords_2 = torch.arange(length_old, dtype=torch.float32, device=device).reshape((1, 1, 1, -1)) + 1
coords_2[:, :, :, -1] -= 1
coords_2 = torch.nn.functional.interpolate(coords_2, size=(1, length_new), mode="bilinear")
coords_2 = coords_2.to(torch.int64)
return ratios, coords_1, coords_2
orig_dtype = samples.dtype
samples = samples.float()
n,c,h,w = samples.shape
n, c, h, w = samples.shape
h_new, w_new = (height, width)
#linear w
# linear w
ratios, coords_1, coords_2 = generate_bilinear_data(w, w_new, samples.device)
coords_1 = coords_1.expand((n, c, h, -1))
coords_2 = coords_2.expand((n, c, h, -1))
ratios = ratios.expand((n, 1, h, -1))
pass_1 = samples.gather(-1,coords_1).movedim(1, -1).reshape((-1,c))
pass_2 = samples.gather(-1,coords_2).movedim(1, -1).reshape((-1,c))
ratios = ratios.movedim(1, -1).reshape((-1,1))
pass_1 = samples.gather(-1, coords_1).movedim(1, -1).reshape((-1, c))
pass_2 = samples.gather(-1, coords_2).movedim(1, -1).reshape((-1, c))
ratios = ratios.movedim(1, -1).reshape((-1, 1))
result = slerp(pass_1, pass_2, ratios)
result = result.reshape(n, h, w_new, c).movedim(-1, 1)
#linear h
# linear h
ratios, coords_1, coords_2 = generate_bilinear_data(h, h_new, samples.device)
coords_1 = coords_1.reshape((1,1,-1,1)).expand((n, c, -1, w_new))
coords_2 = coords_2.reshape((1,1,-1,1)).expand((n, c, -1, w_new))
ratios = ratios.reshape((1,1,-1,1)).expand((n, 1, -1, w_new))
coords_1 = coords_1.reshape((1, 1, -1, 1)).expand((n, c, -1, w_new))
coords_2 = coords_2.reshape((1, 1, -1, 1)).expand((n, c, -1, w_new))
ratios = ratios.reshape((1, 1, -1, 1)).expand((n, 1, -1, w_new))
pass_1 = result.gather(-2,coords_1).movedim(1, -1).reshape((-1,c))
pass_2 = result.gather(-2,coords_2).movedim(1, -1).reshape((-1,c))
ratios = ratios.movedim(1, -1).reshape((-1,1))
pass_1 = result.gather(-2, coords_1).movedim(1, -1).reshape((-1, c))
pass_2 = result.gather(-2, coords_2).movedim(1, -1).reshape((-1, c))
ratios = ratios.movedim(1, -1).reshape((-1, 1))
result = slerp(pass_1, pass_2, ratios)
result = result.reshape(n, h_new, w_new, c).movedim(-1, 1)
return result.to(orig_dtype)
def lanczos(samples, width, height):
images = [Image.fromarray(np.clip(255. * image.movedim(0, -1).cpu().numpy(), 0, 255).astype(np.uint8)) for image in samples]
images = [image.resize((width, height), resample=Image.Resampling.LANCZOS) for image in images]
@ -415,81 +439,95 @@ def lanczos(samples, width, height):
result = torch.stack(images)
return result.to(samples.device, samples.dtype)
def common_upscale(samples, width, height, upscale_method, crop):
if crop == "center":
old_width = samples.shape[3]
old_height = samples.shape[2]
old_aspect = old_width / old_height
new_aspect = width / height
x = 0
y = 0
if old_aspect > new_aspect:
x = round((old_width - old_width * (new_aspect / old_aspect)) / 2)
elif old_aspect < new_aspect:
y = round((old_height - old_height * (old_aspect / new_aspect)) / 2)
s = samples[:,:,y:old_height-y,x:old_width-x]
else:
s = samples
if upscale_method == "bislerp":
return bislerp(s, width, height)
elif upscale_method == "lanczos":
return lanczos(s, width, height)
else:
return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method)
def common_upscale(samples, width, height, upscale_method, crop):
if crop == "center":
old_width = samples.shape[3]
old_height = samples.shape[2]
old_aspect = old_width / old_height
new_aspect = width / height
x = 0
y = 0
if old_aspect > new_aspect:
x = round((old_width - old_width * (new_aspect / old_aspect)) / 2)
elif old_aspect < new_aspect:
y = round((old_height - old_height * (old_aspect / new_aspect)) / 2)
s = samples[:, :, y:old_height - y, x:old_width - x]
else:
s = samples
if upscale_method == "bislerp":
return bislerp(s, width, height)
elif upscale_method == "lanczos":
return lanczos(s, width, height)
else:
return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method)
def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap):
return math.ceil((height / (tile_y - overlap))) * math.ceil((width / (tile_x - overlap)))
@torch.inference_mode()
def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None):
def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap=8, upscale_amount=4, out_channels=3, output_device="cpu", pbar=None):
output = torch.empty((samples.shape[0], out_channels, round(samples.shape[2] * upscale_amount), round(samples.shape[3] * upscale_amount)), device=output_device)
for b in range(samples.shape[0]):
s = samples[b:b+1]
s = samples[b:b + 1]
out = torch.zeros((s.shape[0], out_channels, round(s.shape[2] * upscale_amount), round(s.shape[3] * upscale_amount)), device=output_device)
out_div = torch.zeros((s.shape[0], out_channels, round(s.shape[2] * upscale_amount), round(s.shape[3] * upscale_amount)), device=output_device)
for y in range(0, s.shape[2], tile_y - overlap):
for x in range(0, s.shape[3], tile_x - overlap):
x = max(0, min(s.shape[-1] - overlap, x))
y = max(0, min(s.shape[-2] - overlap, y))
s_in = s[:,:,y:y+tile_y,x:x+tile_x]
s_in = s[:, :, y:y + tile_y, x:x + tile_x]
ps = function(s_in).to(output_device)
mask = torch.ones_like(ps)
feather = round(overlap * upscale_amount)
for t in range(feather):
mask[:,:,t:1+t,:] *= ((1.0/feather) * (t + 1))
mask[:,:,mask.shape[2] -1 -t: mask.shape[2]-t,:] *= ((1.0/feather) * (t + 1))
mask[:,:,:,t:1+t] *= ((1.0/feather) * (t + 1))
mask[:,:,:,mask.shape[3]- 1 - t: mask.shape[3]- t] *= ((1.0/feather) * (t + 1))
out[:,:,round(y*upscale_amount):round((y+tile_y)*upscale_amount),round(x*upscale_amount):round((x+tile_x)*upscale_amount)] += ps * mask
out_div[:,:,round(y*upscale_amount):round((y+tile_y)*upscale_amount),round(x*upscale_amount):round((x+tile_x)*upscale_amount)] += mask
mask[:, :, t:1 + t, :] *= ((1.0 / feather) * (t + 1))
mask[:, :, mask.shape[2] - 1 - t: mask.shape[2] - t, :] *= ((1.0 / feather) * (t + 1))
mask[:, :, :, t:1 + t] *= ((1.0 / feather) * (t + 1))
mask[:, :, :, mask.shape[3] - 1 - t: mask.shape[3] - t] *= ((1.0 / feather) * (t + 1))
out[:, :, round(y * upscale_amount):round((y + tile_y) * upscale_amount), round(x * upscale_amount):round((x + tile_x) * upscale_amount)] += ps * mask
out_div[:, :, round(y * upscale_amount):round((y + tile_y) * upscale_amount), round(x * upscale_amount):round((x + tile_x) * upscale_amount)] += mask
if pbar is not None:
pbar.update(1)
output[b:b+1] = out/out_div
output[b:b + 1] = out / out_div
return output
def hijack_progress(server: ExecutorToClientProgress):
def hook(value: float, total: float, preview_image):
interruption.throw_exception_if_processing_interrupted()
progress = {"value": value, "max": total, "prompt_id": server.last_prompt_id, "node": server.last_node_id}
def _progress_bar_update(value: float, total: float, preview_image, client_id: Optional[str] = None):
server = current_execution_context().server
# todo: this should really be from the context. right now the server is behaving like a context
client_id = client_id or server.client_id
interruption.throw_exception_if_processing_interrupted()
progress = {"value": value, "max": total, "prompt_id": server.last_prompt_id, "node": server.last_node_id}
server.send_sync("progress", progress, server.client_id)
if preview_image is not None:
server.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, server.client_id)
_progress_bar_hook.hook = hook
server.send_sync("progress", progress, client_id)
if preview_image is not None:
server.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, client_id)
def set_progress_bar_enabled(enabled):
global PROGRESS_BAR_ENABLED
PROGRESS_BAR_ENABLED = enabled
def set_progress_bar_enabled(enabled: bool):
warnings.warn(
"The global method 'set_progress_bar_enabled' is deprecated and will be removed in a future version. Use current_execution_context().server.receive_all_progress_notifications instead.",
DeprecationWarning,
stacklevel=2
)
current_execution_context().server.receive_all_progress_notifications = enabled
pass
def get_progress_bar_enabled() -> bool:
return PROGRESS_BAR_ENABLED
warnings.warn(
"The global method 'get_progress_bar_enabled' is deprecated and will be removed in a future version. Use current_execution_context().server.receive_all_progress_notifications instead.",
DeprecationWarning,
stacklevel=2
)
return current_execution_context().server.receive_all_progress_notifications
class _DisabledProgressBar:
@ -505,10 +543,8 @@ class _DisabledProgressBar:
class ProgressBar:
def __init__(self, total: float):
global _progress_bar_hook
self.total: float = total
self.current: float = 0.0
self.hook = _progress_bar_hook.hook if hasattr(_progress_bar_hook, "hook") else None
def update_absolute(self, value, total=None, preview=None):
if total is not None:
@ -516,8 +552,7 @@ class ProgressBar:
if value > self.total:
value = self.total
self.current = value
if self.hook is not None:
self.hook(self.current, self.total, preview)
_progress_bar_update(self.current, self.total, preview)
def update(self, value):
self.update_absolute(self.current + value)
@ -556,8 +591,8 @@ def comfy_tqdm():
@contextmanager
def comfy_progress(total: float) -> ProgressBar:
global PROGRESS_BAR_ENABLED
if PROGRESS_BAR_ENABLED:
ctx = current_execution_context()
if ctx.server.receive_all_progress_notifications:
yield ProgressBar(total)
else:
yield _DisabledProgressBar()

View File

@ -2,6 +2,7 @@ import comfy.sampler_names
from comfy import samplers
from comfy import model_management
from comfy import sample
from comfy.execution_context import current_execution_context
from comfy.k_diffusion import sampling as k_diffusion_sampling
from comfy.cmd import latent_preview
import torch
@ -416,7 +417,7 @@ class SamplerCustom:
x0_output = {}
callback = latent_preview.prepare_callback(model, sigmas.shape[-1] - 1, x0_output)
disable_pbar = not utils.PROGRESS_BAR_ENABLED
disable_pbar = not current_execution_context().server.receive_all_progress_notifications
samples = sample.sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=noise_seed)
out = latent.copy()
@ -570,7 +571,7 @@ class SamplerCustomAdvanced:
x0_output = {}
callback = latent_preview.prepare_callback(guider.model_patcher, sigmas.shape[-1] - 1, x0_output)
disable_pbar = not comfy.utils.PROGRESS_BAR_ENABLED
disable_pbar = not current_execution_context().server.receive_all_progress_notifications
samples = guider.sample(noise.generate_noise(latent), latent_image, sampler, sigmas, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=noise.seed)
samples = samples.to(comfy.model_management.intermediate_device())