Mirror of https://github.com/comfyanonymous/ComfyUI.git, synced 2026-01-10 14:20:49 +08:00
Progress bar hooks, previously registered via the server, are now set through an execution context. This context will be used in other places too.
This commit is contained in:
parent d6c374942e
commit 881258acb6
@@ -21,6 +21,7 @@ from .. import model_management
 from ..component_model.abstract_prompt_queue import AbstractPromptQueue
 from ..component_model.executor_types import ExecutorToClientProgress
 from ..component_model.queue_types import QueueTuple, HistoryEntry, QueueItem, MAXIMUM_HISTORY_SIZE, ExecutionStatus
+from ..execution_context import new_execution_context, ExecutionContext
 from ..nodes.package import import_all_nodes_in_workspace
 from ..nodes.package_typing import ExportedNodes
 
@@ -153,9 +154,11 @@ def recursive_execute(server: ExecutorToClientProgress,
                       prompt_id,
                       outputs_ui,
                       object_storage):
+    span = get_current_span()
     unique_id = current_item
     inputs = prompt[unique_id]['inputs']
     class_type = prompt[unique_id]['class_type']
+    span.set_attribute("class_type", class_type)
     class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
     if unique_id in outputs:
         return (True, None, None)
@@ -374,6 +377,10 @@ class PromptExecutor:
         del d
 
     def execute(self, prompt, prompt_id, extra_data=None, execute_outputs: List[str] = None):
+        with new_execution_context(ExecutionContext(self.server)):
+            self._execute_inner(prompt, prompt_id, extra_data, execute_outputs)
+
+    def _execute_inner(self, prompt, prompt_id, extra_data=None, execute_outputs: List[str] = None):
         if execute_outputs is None:
             execute_outputs = []
         if extra_data is None:
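Note: `execute` becomes a thin wrapper that binds the server into an ambient execution context before delegating to `_execute_inner`, so code running anywhere inside the prompt can reach the server without it being threaded through as a parameter. A minimal sketch of a consumer (the helper name here is hypothetical; `current_execution_context` is from the new module added below):

from comfy.execution_context import current_execution_context

def report_halfway():  # hypothetical helper called from inside a node
    # resolves to the same server object PromptExecutor.execute() bound above
    server = current_execution_context().server
    server.send_sync("progress", {"value": 1, "max": 2}, server.client_id)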
@@ -9,7 +9,6 @@ import time
 
 # main_pre must be the earliest import since it suppresses some spurious warnings
 from .main_pre import args
-from ..utils import hijack_progress
 from .extra_model_paths import load_extra_path_config
 from .. import model_management
 from ..analytics.analytics import initialize_event_tracking
@@ -167,7 +166,6 @@ async def main():
     server.prompt_queue = q
 
     server.add_routes()
-    hijack_progress(server)
     cuda_malloc_warning()
 
     # in a distributed setting, the default prompt worker will not be able to send execution events via the websocket
@@ -102,6 +102,7 @@ class PromptServer(ExecutorToClientProgress):
         self.number: int = 0
         self.port: int = 8188
         self._external_address: Optional[str] = None
+        self.receive_all_progress_notifications = True
 
         middlewares = [cache_control]
         if args.enable_cors_header:
@@ -61,6 +61,7 @@ class ExecutorToClientProgress(Protocol):
     client_id: Optional[str]
     last_node_id: Optional[str]
    last_prompt_id: Optional[str]
+    receive_all_progress_notifications: Optional[bool]
 
     def send_sync(self,
                   event: SendSyncEvent,
@@ -12,7 +12,6 @@ from aio_pika.patterns import RPC
 from ..component_model.executor_types import SendSyncEvent, SendSyncData, ExecutorToClientProgress, \
     UnencodedPreviewImageMessage
 from ..component_model.queue_types import BinaryEventTypes
-from ..utils import hijack_progress
 
 
 async def _progress(event: SendSyncEvent, data: SendSyncData, user_id: Optional[str] = None,
@@ -38,8 +37,7 @@ class DistributedExecutorToClientProgress(ExecutorToClientProgress):
         self.node_id = None
         self.last_node_id = None
         self.last_prompt_id = None
-        if receive_all_progress_notifications:
-            hijack_progress(self)
+        self.receive_all_progress_notifications = receive_all_progress_notifications
 
     async def send(self, event: SendSyncEvent, data: SendSyncData, user_id: Optional[str]) -> None:
         # for now, do not send binary data this way, since it cannot be json serialized / it's impractical
@@ -16,6 +16,7 @@ class ServerStub(ExecutorToClientProgress):
         self.client_id = str(uuid.uuid4())
         self.last_node_id = None
         self.last_prompt_id = None
+        self.receive_all_progress_notifications = False
 
     def send_sync(self,
                   event: Literal["status", "executing"] | BinaryEventTypes | str | None,
comfy/execution_context.py (new file, 30 lines)
@@ -0,0 +1,30 @@
+from __future__ import annotations
+
+from contextlib import contextmanager
+from contextvars import ContextVar
+from typing import NamedTuple
+
+from comfy.component_model.executor_types import ExecutorToClientProgress
+from comfy.distributed.server_stub import ServerStub
+
+_current_context = ContextVar("comfyui_execution_context")
+
+
+class ExecutionContext(NamedTuple):
+    server: ExecutorToClientProgress
+
+
+_empty_execution_context = ExecutionContext(ServerStub())
+
+
+def current_execution_context() -> ExecutionContext:
+    return _current_context.get()
+
+
+@contextmanager
+def new_execution_context(ctx: ExecutionContext):
+    token = _current_context.set(ctx)
+    try:
+        yield
+    finally:
+        _current_context.reset(token)
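Note: `ContextVar` state is scoped per thread and per asyncio task, which is what lets concurrent prompt workers each carry their own server. As written, `current_execution_context()` raises `LookupError` when no `new_execution_context(...)` block is active, and `_empty_execution_context` is defined but never used in this file; the likely intent is to fall back to the `ServerStub`-backed context (whose `receive_all_progress_notifications` is `False`, disabling progress). A hypothetical fix, not what the diff does:

def current_execution_context() -> ExecutionContext:
    # fall back to the stub-backed context instead of raising
    # LookupError when called outside new_execution_context(...)
    return _current_context.get(_empty_execution_context)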
@@ -23,6 +23,7 @@ from .. import model_management
 from ..cli_args import args
 
 from ..cmd import folder_paths, latent_preview
+from ..execution_context import current_execution_context
 from ..images import open_image
 from ..model_downloader import get_filename_list_with_downloadable, get_or_download, KNOWN_CHECKPOINTS, KNOWN_CLIP_VISION_MODELS, KNOWN_GLIGEN_MODELS, KNOWN_UNCLIP_CHECKPOINTS, KNOWN_LORAS, KNOWN_CONTROLNETS, KNOWN_DIFF_CONTROLNETS, KNOWN_VAES, KNOWN_APPROX_VAES
 from ..nodes.common import MAX_RESOLUTION
@@ -1297,7 +1298,7 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
         noise_mask = latent["noise_mask"]
 
     callback = latent_preview.prepare_callback(model, steps)
-    disable_pbar = not utils.PROGRESS_BAR_ENABLED
+    disable_pbar = not current_execution_context().server.receive_all_progress_notifications
     samples = sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image,
                             denoise=denoise, disable_noise=disable_noise, start_step=start_step, last_step=last_step,
                             force_full_denoise=force_full_denoise, noise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
@@ -162,7 +162,8 @@ def mitigated_import_of_vanilla_custom_nodes() -> ExportedNodes:
     # found. we're in the middle of executing the import of execution and server, in all likelihood, so like all things,
     # the way community custom nodes is pretty radioactive
     from ..cmd import cuda_malloc, folder_paths, latent_preview
-    for module in (cuda_malloc, folder_paths, latent_preview):
+    from .. import node_helpers
+    for module in (cuda_malloc, folder_paths, latent_preview, node_helpers):
         module_short_name = module.__name__.split(".")[-1]
         sys.modules[module_short_name] = module
     sys.modules['nodes'] = base_nodes
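Note: the loop above registers fork-internal modules in `sys.modules` under the bare names that vanilla custom nodes import (`import folder_paths`, `import node_helpers`, and so on). A condensed, self-contained illustration of the aliasing technique; the module names here are stand-ins:

import sys
import types

# stand-in for a package-internal module such as comfy.node_helpers
node_helpers = types.ModuleType("comfy.node_helpers")

# alias it under its short name; `import node_helpers` now resolves to it,
# because the import system consults sys.modules before searching the path
sys.modules["node_helpers"] = node_helpers

import node_helpers as vanilla_view
assert vanilla_view is node_helpers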
comfy/utils.py (221 changed lines)
@@ -1,26 +1,34 @@
-import os.path
-import threading
-from contextlib import contextmanager
-
-import torch
+import logging
 import math
+import os.path
 import struct
 import sys
+import warnings
+from contextlib import contextmanager
+from typing import Optional
+
+import numpy as np
+import safetensors.torch
+import torch
+from PIL import Image
+from tqdm import tqdm
 
 from . import checkpoint_pickle, interruption
-import safetensors.torch
-import numpy as np
-from PIL import Image
-import logging
 
 from .component_model.executor_types import ExecutorToClientProgress
 from .component_model.queue_types import BinaryEventTypes
-
-PROGRESS_BAR_ENABLED = True
-_progress_bar_hook = threading.local()
+from .execution_context import current_execution_context
+
+
+# deprecate PROGRESS_BAR_ENABLED
+def _get_progress_bar_enabled():
+    warnings.warn(
+        "The global variable 'PROGRESS_BAR_ENABLED' is deprecated and will be removed in a future version. Use current_execution_context().server.receive_all_progress_notifications instead.",
+        DeprecationWarning,
+        stacklevel=2
+    )
+    return current_execution_context().server.receive_all_progress_notifications
+
+
+setattr(sys.modules[__name__], 'PROGRESS_BAR_ENABLED', property(_get_progress_bar_enabled))
 
 def load_torch_file(ckpt, safe_load=False, device=None):
     if device is None:
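Note: `property` is a descriptor, and descriptors are only invoked when found on a class, not in an instance `__dict__`. A module is an instance of `types.ModuleType`, so the `setattr(sys.modules[__name__], 'PROGRESS_BAR_ENABLED', property(...))` line above stores the property object itself; `utils.PROGRESS_BAR_ENABLED` would evaluate to that (always truthy) property instance rather than a bool. The usual way to get the intended lazy, deprecation-warning attribute is a PEP 562 module-level `__getattr__`; a sketch of that alternative:

def __getattr__(name):
    # PEP 562: called only when normal module attribute lookup fails
    if name == "PROGRESS_BAR_ENABLED":
        return _get_progress_bar_enabled()  # emits the DeprecationWarning
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")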
@@ -46,12 +54,14 @@ def load_torch_file(ckpt, safe_load=False, device=None):
             sd = pl_sd
     return sd
 
+
 def save_torch_file(sd, ckpt, metadata=None):
     if metadata is not None:
         safetensors.torch.save_file(sd, ckpt, metadata=metadata)
     else:
         safetensors.torch.save_file(sd, ckpt)
 
+
 def calculate_parameters(sd, prefix=""):
     params = 0
     for k in sd.keys():
@@ -59,12 +69,14 @@ def calculate_parameters(sd, prefix=""):
             params += sd[k].nelement()
     return params
 
+
 def state_dict_key_replace(state_dict, keys_to_replace):
     for x in keys_to_replace:
         if x in state_dict:
             state_dict[keys_to_replace[x]] = state_dict.pop(x)
     return state_dict
 
+
 def state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=False):
     if filter_keys:
         out = {}
@@ -115,10 +127,11 @@ def transformers_convert(sd, prefix_from, prefix_to, number):
                 for x in range(3):
                     p = ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"]
                     k_to = "{}encoder.layers.{}.{}.{}".format(prefix_to, resblock, p[x], y)
-                    sd[k_to] = weights[shape_from*x:shape_from*(x + 1)]
+                    sd[k_to] = weights[shape_from * x:shape_from * (x + 1)]
 
     return sd
 
+
 def clip_text_transformers_convert(sd, prefix_from, prefix_to):
     sd = transformers_convert(sd, prefix_from, "{}text_model.".format(prefix_to), 32)
@@ -200,6 +213,7 @@ UNET_MAP_BASIC = {
     ("time_embed.2.bias", "time_embedding.linear_2.bias")
 }
 
+
 def unet_to_diffusers(unet_config):
     if "num_res_blocks" not in unet_config:
         return {}
@@ -266,6 +280,7 @@ def unet_to_diffusers(unet_config):
 
     return diffusers_unet_map
 
+
 def repeat_to_batch_size(tensor, batch_size):
     if tensor.shape[0] > batch_size:
         return tensor[:batch_size]
@@ -273,6 +288,7 @@ def repeat_to_batch_size(tensor, batch_size):
         return tensor.repeat([math.ceil(batch_size / tensor.shape[0])] + [1] * (len(tensor.shape) - 1))[:batch_size]
     return tensor
 
+
 def resize_to_batch_size(tensor, batch_size):
     in_batch_size = tensor.shape[0]
     if in_batch_size == batch_size:
@@ -293,13 +309,15 @@ def resize_to_batch_size(tensor, batch_size):
 
     return output
 
+
 def convert_sd_to(state_dict, dtype):
     keys = list(state_dict.keys())
     for k in keys:
         state_dict[k] = state_dict[k].to(dtype)
     return state_dict
 
-def safetensors_header(safetensors_path, max_size=100*1024*1024):
+
+def safetensors_header(safetensors_path, max_size=100 * 1024 * 1024):
     with open(safetensors_path, "rb") as f:
         header = f.read(8)
         length_of_header = struct.unpack('<Q', header)[0]
@@ -307,6 +325,7 @@ def safetensors_header(safetensors_path, max_size=100*1024*1024):
             return None
         return f.read(length_of_header)
 
+
 def set_attr(obj, attr, value):
     attrs = attr.split(".")
     for name in attrs[:-1]:
@@ -315,9 +334,11 @@ def set_attr(obj, attr, value):
     setattr(obj, attrs[-1], value)
     return prev
 
+
 def set_attr_param(obj, attr, value):
     return set_attr(obj, attr, torch.nn.Parameter(value, requires_grad=False))
 
+
 def copy_to_param(obj, attr, value):
     # inplace update tensor instead of replacing it
     attrs = attr.split(".")
@@ -326,88 +347,91 @@ def copy_to_param(obj, attr, value):
     prev = getattr(obj, attrs[-1])
     prev.data.copy_(value)
 
+
 def get_attr(obj, attr):
     attrs = attr.split(".")
     for name in attrs:
         obj = getattr(obj, name)
     return obj
 
+
 def bislerp(samples, width, height):
     def slerp(b1, b2, r):
         '''slerps batches b1, b2 according to ratio r, batches should be flat e.g. NxC'''
 
         c = b1.shape[-1]
 
-        #norms
+        # norms
         b1_norms = torch.norm(b1, dim=-1, keepdim=True)
         b2_norms = torch.norm(b2, dim=-1, keepdim=True)
 
-        #normalize
+        # normalize
         b1_normalized = b1 / b1_norms
         b2_normalized = b2 / b2_norms
 
-        #zero when norms are zero
-        b1_normalized[b1_norms.expand(-1,c) == 0.0] = 0.0
-        b2_normalized[b2_norms.expand(-1,c) == 0.0] = 0.0
+        # zero when norms are zero
+        b1_normalized[b1_norms.expand(-1, c) == 0.0] = 0.0
+        b2_normalized[b2_norms.expand(-1, c) == 0.0] = 0.0
 
-        #slerp
-        dot = (b1_normalized*b2_normalized).sum(1)
+        # slerp
+        dot = (b1_normalized * b2_normalized).sum(1)
         omega = torch.acos(dot)
         so = torch.sin(omega)
 
-        #technically not mathematically correct, but more pleasing?
-        res = (torch.sin((1.0-r.squeeze(1))*omega)/so).unsqueeze(1)*b1_normalized + (torch.sin(r.squeeze(1)*omega)/so).unsqueeze(1) * b2_normalized
-        res *= (b1_norms * (1.0-r) + b2_norms * r).expand(-1,c)
+        # technically not mathematically correct, but more pleasing?
+        res = (torch.sin((1.0 - r.squeeze(1)) * omega) / so).unsqueeze(1) * b1_normalized + (torch.sin(r.squeeze(1) * omega) / so).unsqueeze(1) * b2_normalized
+        res *= (b1_norms * (1.0 - r) + b2_norms * r).expand(-1, c)
 
-        #edge cases for same or polar opposites
+        # edge cases for same or polar opposites
         res[dot > 1 - 1e-5] = b1[dot > 1 - 1e-5]
-        res[dot < 1e-5 - 1] = (b1 * (1.0-r) + b2 * r)[dot < 1e-5 - 1]
+        res[dot < 1e-5 - 1] = (b1 * (1.0 - r) + b2 * r)[dot < 1e-5 - 1]
         return res
 
     def generate_bilinear_data(length_old, length_new, device):
-        coords_1 = torch.arange(length_old, dtype=torch.float32, device=device).reshape((1,1,1,-1))
+        coords_1 = torch.arange(length_old, dtype=torch.float32, device=device).reshape((1, 1, 1, -1))
         coords_1 = torch.nn.functional.interpolate(coords_1, size=(1, length_new), mode="bilinear")
         ratios = coords_1 - coords_1.floor()
         coords_1 = coords_1.to(torch.int64)
 
-        coords_2 = torch.arange(length_old, dtype=torch.float32, device=device).reshape((1,1,1,-1)) + 1
-        coords_2[:,:,:,-1] -= 1
+        coords_2 = torch.arange(length_old, dtype=torch.float32, device=device).reshape((1, 1, 1, -1)) + 1
+        coords_2[:, :, :, -1] -= 1
         coords_2 = torch.nn.functional.interpolate(coords_2, size=(1, length_new), mode="bilinear")
         coords_2 = coords_2.to(torch.int64)
         return ratios, coords_1, coords_2
 
     orig_dtype = samples.dtype
     samples = samples.float()
-    n,c,h,w = samples.shape
+    n, c, h, w = samples.shape
     h_new, w_new = (height, width)
 
-    #linear w
+    # linear w
     ratios, coords_1, coords_2 = generate_bilinear_data(w, w_new, samples.device)
     coords_1 = coords_1.expand((n, c, h, -1))
     coords_2 = coords_2.expand((n, c, h, -1))
     ratios = ratios.expand((n, 1, h, -1))
 
-    pass_1 = samples.gather(-1,coords_1).movedim(1, -1).reshape((-1,c))
-    pass_2 = samples.gather(-1,coords_2).movedim(1, -1).reshape((-1,c))
-    ratios = ratios.movedim(1, -1).reshape((-1,1))
+    pass_1 = samples.gather(-1, coords_1).movedim(1, -1).reshape((-1, c))
+    pass_2 = samples.gather(-1, coords_2).movedim(1, -1).reshape((-1, c))
+    ratios = ratios.movedim(1, -1).reshape((-1, 1))
 
     result = slerp(pass_1, pass_2, ratios)
     result = result.reshape(n, h, w_new, c).movedim(-1, 1)
 
-    #linear h
+    # linear h
     ratios, coords_1, coords_2 = generate_bilinear_data(h, h_new, samples.device)
-    coords_1 = coords_1.reshape((1,1,-1,1)).expand((n, c, -1, w_new))
-    coords_2 = coords_2.reshape((1,1,-1,1)).expand((n, c, -1, w_new))
-    ratios = ratios.reshape((1,1,-1,1)).expand((n, 1, -1, w_new))
+    coords_1 = coords_1.reshape((1, 1, -1, 1)).expand((n, c, -1, w_new))
+    coords_2 = coords_2.reshape((1, 1, -1, 1)).expand((n, c, -1, w_new))
+    ratios = ratios.reshape((1, 1, -1, 1)).expand((n, 1, -1, w_new))
 
-    pass_1 = result.gather(-2,coords_1).movedim(1, -1).reshape((-1,c))
-    pass_2 = result.gather(-2,coords_2).movedim(1, -1).reshape((-1,c))
-    ratios = ratios.movedim(1, -1).reshape((-1,1))
+    pass_1 = result.gather(-2, coords_1).movedim(1, -1).reshape((-1, c))
+    pass_2 = result.gather(-2, coords_2).movedim(1, -1).reshape((-1, c))
+    ratios = ratios.movedim(1, -1).reshape((-1, 1))
 
     result = slerp(pass_1, pass_2, ratios)
     result = result.reshape(n, h_new, w_new, c).movedim(-1, 1)
     return result.to(orig_dtype)
 
+
 def lanczos(samples, width, height):
     images = [Image.fromarray(np.clip(255. * image.movedim(0, -1).cpu().numpy(), 0, 255).astype(np.uint8)) for image in samples]
     images = [image.resize((width, height), resample=Image.Resampling.LANCZOS) for image in images]
@@ -415,81 +439,95 @@ def lanczos(samples, width, height):
     result = torch.stack(images)
     return result.to(samples.device, samples.dtype)
 
-def common_upscale(samples, width, height, upscale_method, crop):
-    if crop == "center":
-        old_width = samples.shape[3]
-        old_height = samples.shape[2]
-        old_aspect = old_width / old_height
-        new_aspect = width / height
-        x = 0
-        y = 0
-        if old_aspect > new_aspect:
-            x = round((old_width - old_width * (new_aspect / old_aspect)) / 2)
-        elif old_aspect < new_aspect:
-            y = round((old_height - old_height * (old_aspect / new_aspect)) / 2)
-        s = samples[:,:,y:old_height-y,x:old_width-x]
-    else:
-        s = samples
-
-    if upscale_method == "bislerp":
-        return bislerp(s, width, height)
-    elif upscale_method == "lanczos":
-        return lanczos(s, width, height)
-    else:
-        return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method)
+
+def common_upscale(samples, width, height, upscale_method, crop):
+    if crop == "center":
+        old_width = samples.shape[3]
+        old_height = samples.shape[2]
+        old_aspect = old_width / old_height
+        new_aspect = width / height
+        x = 0
+        y = 0
+        if old_aspect > new_aspect:
+            x = round((old_width - old_width * (new_aspect / old_aspect)) / 2)
+        elif old_aspect < new_aspect:
+            y = round((old_height - old_height * (old_aspect / new_aspect)) / 2)
+        s = samples[:, :, y:old_height - y, x:old_width - x]
+    else:
+        s = samples
+
+    if upscale_method == "bislerp":
+        return bislerp(s, width, height)
+    elif upscale_method == "lanczos":
+        return lanczos(s, width, height)
+    else:
+        return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method)
+
 
 def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap):
     return math.ceil((height / (tile_y - overlap))) * math.ceil((width / (tile_x - overlap)))
 
+
 @torch.inference_mode()
-def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None):
+def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap=8, upscale_amount=4, out_channels=3, output_device="cpu", pbar=None):
     output = torch.empty((samples.shape[0], out_channels, round(samples.shape[2] * upscale_amount), round(samples.shape[3] * upscale_amount)), device=output_device)
     for b in range(samples.shape[0]):
-        s = samples[b:b+1]
+        s = samples[b:b + 1]
         out = torch.zeros((s.shape[0], out_channels, round(s.shape[2] * upscale_amount), round(s.shape[3] * upscale_amount)), device=output_device)
         out_div = torch.zeros((s.shape[0], out_channels, round(s.shape[2] * upscale_amount), round(s.shape[3] * upscale_amount)), device=output_device)
         for y in range(0, s.shape[2], tile_y - overlap):
             for x in range(0, s.shape[3], tile_x - overlap):
                 x = max(0, min(s.shape[-1] - overlap, x))
                 y = max(0, min(s.shape[-2] - overlap, y))
-                s_in = s[:,:,y:y+tile_y,x:x+tile_x]
+                s_in = s[:, :, y:y + tile_y, x:x + tile_x]
 
                 ps = function(s_in).to(output_device)
                 mask = torch.ones_like(ps)
                 feather = round(overlap * upscale_amount)
                 for t in range(feather):
-                    mask[:,:,t:1+t,:] *= ((1.0/feather) * (t + 1))
-                    mask[:,:,mask.shape[2] -1 -t: mask.shape[2]-t,:] *= ((1.0/feather) * (t + 1))
-                    mask[:,:,:,t:1+t] *= ((1.0/feather) * (t + 1))
-                    mask[:,:,:,mask.shape[3]- 1 - t: mask.shape[3]- t] *= ((1.0/feather) * (t + 1))
-                out[:,:,round(y*upscale_amount):round((y+tile_y)*upscale_amount),round(x*upscale_amount):round((x+tile_x)*upscale_amount)] += ps * mask
-                out_div[:,:,round(y*upscale_amount):round((y+tile_y)*upscale_amount),round(x*upscale_amount):round((x+tile_x)*upscale_amount)] += mask
+                    mask[:, :, t:1 + t, :] *= ((1.0 / feather) * (t + 1))
+                    mask[:, :, mask.shape[2] - 1 - t: mask.shape[2] - t, :] *= ((1.0 / feather) * (t + 1))
+                    mask[:, :, :, t:1 + t] *= ((1.0 / feather) * (t + 1))
+                    mask[:, :, :, mask.shape[3] - 1 - t: mask.shape[3] - t] *= ((1.0 / feather) * (t + 1))
+                out[:, :, round(y * upscale_amount):round((y + tile_y) * upscale_amount), round(x * upscale_amount):round((x + tile_x) * upscale_amount)] += ps * mask
+                out_div[:, :, round(y * upscale_amount):round((y + tile_y) * upscale_amount), round(x * upscale_amount):round((x + tile_x) * upscale_amount)] += mask
                 if pbar is not None:
                     pbar.update(1)
 
-        output[b:b+1] = out/out_div
+        output[b:b + 1] = out / out_div
     return output
 
 
-def hijack_progress(server: ExecutorToClientProgress):
-    def hook(value: float, total: float, preview_image):
-        interruption.throw_exception_if_processing_interrupted()
-        progress = {"value": value, "max": total, "prompt_id": server.last_prompt_id, "node": server.last_node_id}
+def _progress_bar_update(value: float, total: float, preview_image, client_id: Optional[str] = None):
+    server = current_execution_context().server
+    # todo: this should really be from the context. right now the server is behaving like a context
+    client_id = client_id or server.client_id
+    interruption.throw_exception_if_processing_interrupted()
+    progress = {"value": value, "max": total, "prompt_id": server.last_prompt_id, "node": server.last_node_id}
 
-        server.send_sync("progress", progress, server.client_id)
-        if preview_image is not None:
-            server.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, server.client_id)
-
-    _progress_bar_hook.hook = hook
+    server.send_sync("progress", progress, client_id)
+    if preview_image is not None:
+        server.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, client_id)
 
 
-def set_progress_bar_enabled(enabled):
-    global PROGRESS_BAR_ENABLED
-    PROGRESS_BAR_ENABLED = enabled
+def set_progress_bar_enabled(enabled: bool):
+    warnings.warn(
+        "The global method 'set_progress_bar_enabled' is deprecated and will be removed in a future version. Use current_execution_context().server.receive_all_progress_notifications instead.",
+        DeprecationWarning,
+        stacklevel=2
+    )
+
+    current_execution_context().server.receive_all_progress_notifications = enabled
+    pass
 
 
 def get_progress_bar_enabled() -> bool:
-    return PROGRESS_BAR_ENABLED
+    warnings.warn(
+        "The global method 'get_progress_bar_enabled' is deprecated and will be removed in a future version. Use current_execution_context().server.receive_all_progress_notifications instead.",
+        DeprecationWarning,
+        stacklevel=2
+    )
+    return current_execution_context().server.receive_all_progress_notifications
 
 
 class _DisabledProgressBar:
@@ -505,10 +543,8 @@ class _DisabledProgressBar:
 
 class ProgressBar:
     def __init__(self, total: float):
-        global _progress_bar_hook
         self.total: float = total
         self.current: float = 0.0
-        self.hook = _progress_bar_hook.hook if hasattr(_progress_bar_hook, "hook") else None
 
     def update_absolute(self, value, total=None, preview=None):
         if total is not None:
@@ -516,8 +552,7 @@ class ProgressBar:
         if value > self.total:
             value = self.total
         self.current = value
-        if self.hook is not None:
-            self.hook(self.current, self.total, preview)
+        _progress_bar_update(self.current, self.total, preview)
 
     def update(self, value):
         self.update_absolute(self.current + value)
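Note: `ProgressBar` no longer captures a thread-local hook at construction; `update_absolute` always calls `_progress_bar_update`, which resolves the server from the ambient execution context at call time. A sketch of the resulting flow, assuming an active context:

pbar = ProgressBar(total=20)   # no thread-local hook captured anymore
pbar.update(1)
# -> update_absolute(1) -> _progress_bar_update(1, 20, None)
#    -> server = current_execution_context().server
#    -> server.send_sync("progress", {...}, server.client_id)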
@@ -556,8 +591,8 @@ def comfy_tqdm():
 
 @contextmanager
 def comfy_progress(total: float) -> ProgressBar:
-    global PROGRESS_BAR_ENABLED
-    if PROGRESS_BAR_ENABLED:
+    ctx = current_execution_context()
+    if ctx.server.receive_all_progress_notifications:
         yield ProgressBar(total)
     else:
         yield _DisabledProgressBar()
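With this change `comfy_progress` decides between a live and a disabled bar from the context's server rather than the old module global. Typical usage, a sketch (`num_steps` is caller-defined):

from comfy.utils import comfy_progress

with comfy_progress(total=num_steps) as pbar:
    for _ in range(num_steps):
        ...  # one unit of work
        pbar.update(1)  # effectively a no-op when the server opted out of progress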
@@ -2,6 +2,7 @@ import comfy.sampler_names
 from comfy import samplers
 from comfy import model_management
 from comfy import sample
+from comfy.execution_context import current_execution_context
 from comfy.k_diffusion import sampling as k_diffusion_sampling
 from comfy.cmd import latent_preview
 import torch
@@ -416,7 +417,7 @@ class SamplerCustom:
         x0_output = {}
         callback = latent_preview.prepare_callback(model, sigmas.shape[-1] - 1, x0_output)
 
-        disable_pbar = not utils.PROGRESS_BAR_ENABLED
+        disable_pbar = not current_execution_context().server.receive_all_progress_notifications
         samples = sample.sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=noise_seed)
 
         out = latent.copy()
@@ -570,7 +571,7 @@ class SamplerCustomAdvanced:
         x0_output = {}
         callback = latent_preview.prepare_callback(guider.model_patcher, sigmas.shape[-1] - 1, x0_output)
 
-        disable_pbar = not comfy.utils.PROGRESS_BAR_ENABLED
+        disable_pbar = not current_execution_context().server.receive_all_progress_notifications
         samples = guider.sample(noise.generate_noise(latent), latent_image, sampler, sigmas, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=noise.seed)
         samples = samples.to(comfy.model_management.intermediate_device())
 