mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-28 23:30:16 +08:00
Progress bar hooks, via the server, are now set via a context. This will be used in other places too.
This commit is contained in:
parent
d6c374942e
commit
881258acb6
@ -21,6 +21,7 @@ from .. import model_management
|
|||||||
from ..component_model.abstract_prompt_queue import AbstractPromptQueue
|
from ..component_model.abstract_prompt_queue import AbstractPromptQueue
|
||||||
from ..component_model.executor_types import ExecutorToClientProgress
|
from ..component_model.executor_types import ExecutorToClientProgress
|
||||||
from ..component_model.queue_types import QueueTuple, HistoryEntry, QueueItem, MAXIMUM_HISTORY_SIZE, ExecutionStatus
|
from ..component_model.queue_types import QueueTuple, HistoryEntry, QueueItem, MAXIMUM_HISTORY_SIZE, ExecutionStatus
|
||||||
|
from ..execution_context import new_execution_context, ExecutionContext
|
||||||
from ..nodes.package import import_all_nodes_in_workspace
|
from ..nodes.package import import_all_nodes_in_workspace
|
||||||
from ..nodes.package_typing import ExportedNodes
|
from ..nodes.package_typing import ExportedNodes
|
||||||
|
|
||||||
@ -153,9 +154,11 @@ def recursive_execute(server: ExecutorToClientProgress,
|
|||||||
prompt_id,
|
prompt_id,
|
||||||
outputs_ui,
|
outputs_ui,
|
||||||
object_storage):
|
object_storage):
|
||||||
|
span = get_current_span()
|
||||||
unique_id = current_item
|
unique_id = current_item
|
||||||
inputs = prompt[unique_id]['inputs']
|
inputs = prompt[unique_id]['inputs']
|
||||||
class_type = prompt[unique_id]['class_type']
|
class_type = prompt[unique_id]['class_type']
|
||||||
|
span.set_attribute("class_type", class_type)
|
||||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||||
if unique_id in outputs:
|
if unique_id in outputs:
|
||||||
return (True, None, None)
|
return (True, None, None)
|
||||||
@ -374,6 +377,10 @@ class PromptExecutor:
|
|||||||
del d
|
del d
|
||||||
|
|
||||||
def execute(self, prompt, prompt_id, extra_data=None, execute_outputs: List[str] = None):
|
def execute(self, prompt, prompt_id, extra_data=None, execute_outputs: List[str] = None):
|
||||||
|
with new_execution_context(ExecutionContext(self.server)):
|
||||||
|
self._execute_inner(prompt, prompt_id, extra_data, execute_outputs)
|
||||||
|
|
||||||
|
def _execute_inner(self, prompt, prompt_id, extra_data=None, execute_outputs: List[str] = None):
|
||||||
if execute_outputs is None:
|
if execute_outputs is None:
|
||||||
execute_outputs = []
|
execute_outputs = []
|
||||||
if extra_data is None:
|
if extra_data is None:
|
||||||
|
|||||||
@ -9,7 +9,6 @@ import time
|
|||||||
|
|
||||||
# main_pre must be the earliest import since it suppresses some spurious warnings
|
# main_pre must be the earliest import since it suppresses some spurious warnings
|
||||||
from .main_pre import args
|
from .main_pre import args
|
||||||
from ..utils import hijack_progress
|
|
||||||
from .extra_model_paths import load_extra_path_config
|
from .extra_model_paths import load_extra_path_config
|
||||||
from .. import model_management
|
from .. import model_management
|
||||||
from ..analytics.analytics import initialize_event_tracking
|
from ..analytics.analytics import initialize_event_tracking
|
||||||
@ -167,7 +166,6 @@ async def main():
|
|||||||
server.prompt_queue = q
|
server.prompt_queue = q
|
||||||
|
|
||||||
server.add_routes()
|
server.add_routes()
|
||||||
hijack_progress(server)
|
|
||||||
cuda_malloc_warning()
|
cuda_malloc_warning()
|
||||||
|
|
||||||
# in a distributed setting, the default prompt worker will not be able to send execution events via the websocket
|
# in a distributed setting, the default prompt worker will not be able to send execution events via the websocket
|
||||||
|
|||||||
@ -102,6 +102,7 @@ class PromptServer(ExecutorToClientProgress):
|
|||||||
self.number: int = 0
|
self.number: int = 0
|
||||||
self.port: int = 8188
|
self.port: int = 8188
|
||||||
self._external_address: Optional[str] = None
|
self._external_address: Optional[str] = None
|
||||||
|
self.receive_all_progress_notifications = True
|
||||||
|
|
||||||
middlewares = [cache_control]
|
middlewares = [cache_control]
|
||||||
if args.enable_cors_header:
|
if args.enable_cors_header:
|
||||||
|
|||||||
@ -61,6 +61,7 @@ class ExecutorToClientProgress(Protocol):
|
|||||||
client_id: Optional[str]
|
client_id: Optional[str]
|
||||||
last_node_id: Optional[str]
|
last_node_id: Optional[str]
|
||||||
last_prompt_id: Optional[str]
|
last_prompt_id: Optional[str]
|
||||||
|
receive_all_progress_notifications: Optional[bool]
|
||||||
|
|
||||||
def send_sync(self,
|
def send_sync(self,
|
||||||
event: SendSyncEvent,
|
event: SendSyncEvent,
|
||||||
|
|||||||
@ -12,7 +12,6 @@ from aio_pika.patterns import RPC
|
|||||||
from ..component_model.executor_types import SendSyncEvent, SendSyncData, ExecutorToClientProgress, \
|
from ..component_model.executor_types import SendSyncEvent, SendSyncData, ExecutorToClientProgress, \
|
||||||
UnencodedPreviewImageMessage
|
UnencodedPreviewImageMessage
|
||||||
from ..component_model.queue_types import BinaryEventTypes
|
from ..component_model.queue_types import BinaryEventTypes
|
||||||
from ..utils import hijack_progress
|
|
||||||
|
|
||||||
|
|
||||||
async def _progress(event: SendSyncEvent, data: SendSyncData, user_id: Optional[str] = None,
|
async def _progress(event: SendSyncEvent, data: SendSyncData, user_id: Optional[str] = None,
|
||||||
@ -38,8 +37,7 @@ class DistributedExecutorToClientProgress(ExecutorToClientProgress):
|
|||||||
self.node_id = None
|
self.node_id = None
|
||||||
self.last_node_id = None
|
self.last_node_id = None
|
||||||
self.last_prompt_id = None
|
self.last_prompt_id = None
|
||||||
if receive_all_progress_notifications:
|
self.receive_all_progress_notifications = receive_all_progress_notifications
|
||||||
hijack_progress(self)
|
|
||||||
|
|
||||||
async def send(self, event: SendSyncEvent, data: SendSyncData, user_id: Optional[str]) -> None:
|
async def send(self, event: SendSyncEvent, data: SendSyncData, user_id: Optional[str]) -> None:
|
||||||
# for now, do not send binary data this way, since it cannot be json serialized / it's impractical
|
# for now, do not send binary data this way, since it cannot be json serialized / it's impractical
|
||||||
|
|||||||
@ -16,6 +16,7 @@ class ServerStub(ExecutorToClientProgress):
|
|||||||
self.client_id = str(uuid.uuid4())
|
self.client_id = str(uuid.uuid4())
|
||||||
self.last_node_id = None
|
self.last_node_id = None
|
||||||
self.last_prompt_id = None
|
self.last_prompt_id = None
|
||||||
|
self.receive_all_progress_notifications = False
|
||||||
|
|
||||||
def send_sync(self,
|
def send_sync(self,
|
||||||
event: Literal["status", "executing"] | BinaryEventTypes | str | None,
|
event: Literal["status", "executing"] | BinaryEventTypes | str | None,
|
||||||
|
|||||||
30
comfy/execution_context.py
Normal file
30
comfy/execution_context.py
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from contextvars import ContextVar
|
||||||
|
from typing import NamedTuple
|
||||||
|
|
||||||
|
from comfy.component_model.executor_types import ExecutorToClientProgress
|
||||||
|
from comfy.distributed.server_stub import ServerStub
|
||||||
|
|
||||||
|
_current_context = ContextVar("comfyui_execution_context")
|
||||||
|
|
||||||
|
|
||||||
|
class ExecutionContext(NamedTuple):
|
||||||
|
server: ExecutorToClientProgress
|
||||||
|
|
||||||
|
|
||||||
|
_empty_execution_context = ExecutionContext(ServerStub())
|
||||||
|
|
||||||
|
|
||||||
|
def current_execution_context() -> ExecutionContext:
|
||||||
|
return _current_context.get()
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def new_execution_context(ctx: ExecutionContext):
|
||||||
|
token = _current_context.set(ctx)
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
_current_context.reset(token)
|
||||||
@ -23,6 +23,7 @@ from .. import model_management
|
|||||||
from ..cli_args import args
|
from ..cli_args import args
|
||||||
|
|
||||||
from ..cmd import folder_paths, latent_preview
|
from ..cmd import folder_paths, latent_preview
|
||||||
|
from ..execution_context import current_execution_context
|
||||||
from ..images import open_image
|
from ..images import open_image
|
||||||
from ..model_downloader import get_filename_list_with_downloadable, get_or_download, KNOWN_CHECKPOINTS, KNOWN_CLIP_VISION_MODELS, KNOWN_GLIGEN_MODELS, KNOWN_UNCLIP_CHECKPOINTS, KNOWN_LORAS, KNOWN_CONTROLNETS, KNOWN_DIFF_CONTROLNETS, KNOWN_VAES, KNOWN_APPROX_VAES
|
from ..model_downloader import get_filename_list_with_downloadable, get_or_download, KNOWN_CHECKPOINTS, KNOWN_CLIP_VISION_MODELS, KNOWN_GLIGEN_MODELS, KNOWN_UNCLIP_CHECKPOINTS, KNOWN_LORAS, KNOWN_CONTROLNETS, KNOWN_DIFF_CONTROLNETS, KNOWN_VAES, KNOWN_APPROX_VAES
|
||||||
from ..nodes.common import MAX_RESOLUTION
|
from ..nodes.common import MAX_RESOLUTION
|
||||||
@ -1297,7 +1298,7 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
|
|||||||
noise_mask = latent["noise_mask"]
|
noise_mask = latent["noise_mask"]
|
||||||
|
|
||||||
callback = latent_preview.prepare_callback(model, steps)
|
callback = latent_preview.prepare_callback(model, steps)
|
||||||
disable_pbar = not utils.PROGRESS_BAR_ENABLED
|
disable_pbar = not current_execution_context().server.receive_all_progress_notifications
|
||||||
samples = sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image,
|
samples = sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image,
|
||||||
denoise=denoise, disable_noise=disable_noise, start_step=start_step, last_step=last_step,
|
denoise=denoise, disable_noise=disable_noise, start_step=start_step, last_step=last_step,
|
||||||
force_full_denoise=force_full_denoise, noise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
|
force_full_denoise=force_full_denoise, noise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
|
||||||
|
|||||||
@ -162,7 +162,8 @@ def mitigated_import_of_vanilla_custom_nodes() -> ExportedNodes:
|
|||||||
# found. we're in the middle of executing the import of execution and server, in all likelihood, so like all things,
|
# found. we're in the middle of executing the import of execution and server, in all likelihood, so like all things,
|
||||||
# the way community custom nodes is pretty radioactive
|
# the way community custom nodes is pretty radioactive
|
||||||
from ..cmd import cuda_malloc, folder_paths, latent_preview
|
from ..cmd import cuda_malloc, folder_paths, latent_preview
|
||||||
for module in (cuda_malloc, folder_paths, latent_preview):
|
from .. import node_helpers
|
||||||
|
for module in (cuda_malloc, folder_paths, latent_preview, node_helpers):
|
||||||
module_short_name = module.__name__.split(".")[-1]
|
module_short_name = module.__name__.split(".")[-1]
|
||||||
sys.modules[module_short_name] = module
|
sys.modules[module_short_name] = module
|
||||||
sys.modules['nodes'] = base_nodes
|
sys.modules['nodes'] = base_nodes
|
||||||
|
|||||||
221
comfy/utils.py
221
comfy/utils.py
@ -1,26 +1,34 @@
|
|||||||
import os.path
|
import logging
|
||||||
import threading
|
|
||||||
from contextlib import contextmanager
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import math
|
import math
|
||||||
|
import os.path
|
||||||
import struct
|
import struct
|
||||||
|
import sys
|
||||||
|
import warnings
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import safetensors.torch
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from . import checkpoint_pickle, interruption
|
from . import checkpoint_pickle, interruption
|
||||||
import safetensors.torch
|
|
||||||
import numpy as np
|
|
||||||
from PIL import Image
|
|
||||||
import logging
|
|
||||||
|
|
||||||
from .component_model.executor_types import ExecutorToClientProgress
|
|
||||||
from .component_model.queue_types import BinaryEventTypes
|
from .component_model.queue_types import BinaryEventTypes
|
||||||
|
from .execution_context import current_execution_context
|
||||||
PROGRESS_BAR_ENABLED = True
|
|
||||||
_progress_bar_hook = threading.local()
|
|
||||||
|
|
||||||
|
|
||||||
|
# deprecate PROGRESS_BAR_ENABLED
|
||||||
|
def _get_progress_bar_enabled():
|
||||||
|
warnings.warn(
|
||||||
|
"The global variable 'PROGRESS_BAR_ENABLED' is deprecated and will be removed in a future version. Use current_execution_context().server.receive_all_progress_notifications instead.",
|
||||||
|
DeprecationWarning,
|
||||||
|
stacklevel=2
|
||||||
|
)
|
||||||
|
return current_execution_context().server.receive_all_progress_notifications
|
||||||
|
|
||||||
|
|
||||||
|
setattr(sys.modules[__name__], 'PROGRESS_BAR_ENABLED', property(_get_progress_bar_enabled))
|
||||||
|
|
||||||
def load_torch_file(ckpt, safe_load=False, device=None):
|
def load_torch_file(ckpt, safe_load=False, device=None):
|
||||||
if device is None:
|
if device is None:
|
||||||
@ -46,12 +54,14 @@ def load_torch_file(ckpt, safe_load=False, device=None):
|
|||||||
sd = pl_sd
|
sd = pl_sd
|
||||||
return sd
|
return sd
|
||||||
|
|
||||||
|
|
||||||
def save_torch_file(sd, ckpt, metadata=None):
|
def save_torch_file(sd, ckpt, metadata=None):
|
||||||
if metadata is not None:
|
if metadata is not None:
|
||||||
safetensors.torch.save_file(sd, ckpt, metadata=metadata)
|
safetensors.torch.save_file(sd, ckpt, metadata=metadata)
|
||||||
else:
|
else:
|
||||||
safetensors.torch.save_file(sd, ckpt)
|
safetensors.torch.save_file(sd, ckpt)
|
||||||
|
|
||||||
|
|
||||||
def calculate_parameters(sd, prefix=""):
|
def calculate_parameters(sd, prefix=""):
|
||||||
params = 0
|
params = 0
|
||||||
for k in sd.keys():
|
for k in sd.keys():
|
||||||
@ -59,12 +69,14 @@ def calculate_parameters(sd, prefix=""):
|
|||||||
params += sd[k].nelement()
|
params += sd[k].nelement()
|
||||||
return params
|
return params
|
||||||
|
|
||||||
|
|
||||||
def state_dict_key_replace(state_dict, keys_to_replace):
|
def state_dict_key_replace(state_dict, keys_to_replace):
|
||||||
for x in keys_to_replace:
|
for x in keys_to_replace:
|
||||||
if x in state_dict:
|
if x in state_dict:
|
||||||
state_dict[keys_to_replace[x]] = state_dict.pop(x)
|
state_dict[keys_to_replace[x]] = state_dict.pop(x)
|
||||||
return state_dict
|
return state_dict
|
||||||
|
|
||||||
|
|
||||||
def state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=False):
|
def state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=False):
|
||||||
if filter_keys:
|
if filter_keys:
|
||||||
out = {}
|
out = {}
|
||||||
@ -115,10 +127,11 @@ def transformers_convert(sd, prefix_from, prefix_to, number):
|
|||||||
for x in range(3):
|
for x in range(3):
|
||||||
p = ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"]
|
p = ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"]
|
||||||
k_to = "{}encoder.layers.{}.{}.{}".format(prefix_to, resblock, p[x], y)
|
k_to = "{}encoder.layers.{}.{}.{}".format(prefix_to, resblock, p[x], y)
|
||||||
sd[k_to] = weights[shape_from*x:shape_from*(x + 1)]
|
sd[k_to] = weights[shape_from * x:shape_from * (x + 1)]
|
||||||
|
|
||||||
return sd
|
return sd
|
||||||
|
|
||||||
|
|
||||||
def clip_text_transformers_convert(sd, prefix_from, prefix_to):
|
def clip_text_transformers_convert(sd, prefix_from, prefix_to):
|
||||||
sd = transformers_convert(sd, prefix_from, "{}text_model.".format(prefix_to), 32)
|
sd = transformers_convert(sd, prefix_from, "{}text_model.".format(prefix_to), 32)
|
||||||
|
|
||||||
@ -200,6 +213,7 @@ UNET_MAP_BASIC = {
|
|||||||
("time_embed.2.bias", "time_embedding.linear_2.bias")
|
("time_embed.2.bias", "time_embedding.linear_2.bias")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def unet_to_diffusers(unet_config):
|
def unet_to_diffusers(unet_config):
|
||||||
if "num_res_blocks" not in unet_config:
|
if "num_res_blocks" not in unet_config:
|
||||||
return {}
|
return {}
|
||||||
@ -266,6 +280,7 @@ def unet_to_diffusers(unet_config):
|
|||||||
|
|
||||||
return diffusers_unet_map
|
return diffusers_unet_map
|
||||||
|
|
||||||
|
|
||||||
def repeat_to_batch_size(tensor, batch_size):
|
def repeat_to_batch_size(tensor, batch_size):
|
||||||
if tensor.shape[0] > batch_size:
|
if tensor.shape[0] > batch_size:
|
||||||
return tensor[:batch_size]
|
return tensor[:batch_size]
|
||||||
@ -273,6 +288,7 @@ def repeat_to_batch_size(tensor, batch_size):
|
|||||||
return tensor.repeat([math.ceil(batch_size / tensor.shape[0])] + [1] * (len(tensor.shape) - 1))[:batch_size]
|
return tensor.repeat([math.ceil(batch_size / tensor.shape[0])] + [1] * (len(tensor.shape) - 1))[:batch_size]
|
||||||
return tensor
|
return tensor
|
||||||
|
|
||||||
|
|
||||||
def resize_to_batch_size(tensor, batch_size):
|
def resize_to_batch_size(tensor, batch_size):
|
||||||
in_batch_size = tensor.shape[0]
|
in_batch_size = tensor.shape[0]
|
||||||
if in_batch_size == batch_size:
|
if in_batch_size == batch_size:
|
||||||
@ -293,13 +309,15 @@ def resize_to_batch_size(tensor, batch_size):
|
|||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
def convert_sd_to(state_dict, dtype):
|
def convert_sd_to(state_dict, dtype):
|
||||||
keys = list(state_dict.keys())
|
keys = list(state_dict.keys())
|
||||||
for k in keys:
|
for k in keys:
|
||||||
state_dict[k] = state_dict[k].to(dtype)
|
state_dict[k] = state_dict[k].to(dtype)
|
||||||
return state_dict
|
return state_dict
|
||||||
|
|
||||||
def safetensors_header(safetensors_path, max_size=100*1024*1024):
|
|
||||||
|
def safetensors_header(safetensors_path, max_size=100 * 1024 * 1024):
|
||||||
with open(safetensors_path, "rb") as f:
|
with open(safetensors_path, "rb") as f:
|
||||||
header = f.read(8)
|
header = f.read(8)
|
||||||
length_of_header = struct.unpack('<Q', header)[0]
|
length_of_header = struct.unpack('<Q', header)[0]
|
||||||
@ -307,6 +325,7 @@ def safetensors_header(safetensors_path, max_size=100*1024*1024):
|
|||||||
return None
|
return None
|
||||||
return f.read(length_of_header)
|
return f.read(length_of_header)
|
||||||
|
|
||||||
|
|
||||||
def set_attr(obj, attr, value):
|
def set_attr(obj, attr, value):
|
||||||
attrs = attr.split(".")
|
attrs = attr.split(".")
|
||||||
for name in attrs[:-1]:
|
for name in attrs[:-1]:
|
||||||
@ -315,9 +334,11 @@ def set_attr(obj, attr, value):
|
|||||||
setattr(obj, attrs[-1], value)
|
setattr(obj, attrs[-1], value)
|
||||||
return prev
|
return prev
|
||||||
|
|
||||||
|
|
||||||
def set_attr_param(obj, attr, value):
|
def set_attr_param(obj, attr, value):
|
||||||
return set_attr(obj, attr, torch.nn.Parameter(value, requires_grad=False))
|
return set_attr(obj, attr, torch.nn.Parameter(value, requires_grad=False))
|
||||||
|
|
||||||
|
|
||||||
def copy_to_param(obj, attr, value):
|
def copy_to_param(obj, attr, value):
|
||||||
# inplace update tensor instead of replacing it
|
# inplace update tensor instead of replacing it
|
||||||
attrs = attr.split(".")
|
attrs = attr.split(".")
|
||||||
@ -326,88 +347,91 @@ def copy_to_param(obj, attr, value):
|
|||||||
prev = getattr(obj, attrs[-1])
|
prev = getattr(obj, attrs[-1])
|
||||||
prev.data.copy_(value)
|
prev.data.copy_(value)
|
||||||
|
|
||||||
|
|
||||||
def get_attr(obj, attr):
|
def get_attr(obj, attr):
|
||||||
attrs = attr.split(".")
|
attrs = attr.split(".")
|
||||||
for name in attrs:
|
for name in attrs:
|
||||||
obj = getattr(obj, name)
|
obj = getattr(obj, name)
|
||||||
return obj
|
return obj
|
||||||
|
|
||||||
|
|
||||||
def bislerp(samples, width, height):
|
def bislerp(samples, width, height):
|
||||||
def slerp(b1, b2, r):
|
def slerp(b1, b2, r):
|
||||||
'''slerps batches b1, b2 according to ratio r, batches should be flat e.g. NxC'''
|
'''slerps batches b1, b2 according to ratio r, batches should be flat e.g. NxC'''
|
||||||
|
|
||||||
c = b1.shape[-1]
|
c = b1.shape[-1]
|
||||||
|
|
||||||
#norms
|
# norms
|
||||||
b1_norms = torch.norm(b1, dim=-1, keepdim=True)
|
b1_norms = torch.norm(b1, dim=-1, keepdim=True)
|
||||||
b2_norms = torch.norm(b2, dim=-1, keepdim=True)
|
b2_norms = torch.norm(b2, dim=-1, keepdim=True)
|
||||||
|
|
||||||
#normalize
|
# normalize
|
||||||
b1_normalized = b1 / b1_norms
|
b1_normalized = b1 / b1_norms
|
||||||
b2_normalized = b2 / b2_norms
|
b2_normalized = b2 / b2_norms
|
||||||
|
|
||||||
#zero when norms are zero
|
# zero when norms are zero
|
||||||
b1_normalized[b1_norms.expand(-1,c) == 0.0] = 0.0
|
b1_normalized[b1_norms.expand(-1, c) == 0.0] = 0.0
|
||||||
b2_normalized[b2_norms.expand(-1,c) == 0.0] = 0.0
|
b2_normalized[b2_norms.expand(-1, c) == 0.0] = 0.0
|
||||||
|
|
||||||
#slerp
|
# slerp
|
||||||
dot = (b1_normalized*b2_normalized).sum(1)
|
dot = (b1_normalized * b2_normalized).sum(1)
|
||||||
omega = torch.acos(dot)
|
omega = torch.acos(dot)
|
||||||
so = torch.sin(omega)
|
so = torch.sin(omega)
|
||||||
|
|
||||||
#technically not mathematically correct, but more pleasing?
|
# technically not mathematically correct, but more pleasing?
|
||||||
res = (torch.sin((1.0-r.squeeze(1))*omega)/so).unsqueeze(1)*b1_normalized + (torch.sin(r.squeeze(1)*omega)/so).unsqueeze(1) * b2_normalized
|
res = (torch.sin((1.0 - r.squeeze(1)) * omega) / so).unsqueeze(1) * b1_normalized + (torch.sin(r.squeeze(1) * omega) / so).unsqueeze(1) * b2_normalized
|
||||||
res *= (b1_norms * (1.0-r) + b2_norms * r).expand(-1,c)
|
res *= (b1_norms * (1.0 - r) + b2_norms * r).expand(-1, c)
|
||||||
|
|
||||||
#edge cases for same or polar opposites
|
# edge cases for same or polar opposites
|
||||||
res[dot > 1 - 1e-5] = b1[dot > 1 - 1e-5]
|
res[dot > 1 - 1e-5] = b1[dot > 1 - 1e-5]
|
||||||
res[dot < 1e-5 - 1] = (b1 * (1.0-r) + b2 * r)[dot < 1e-5 - 1]
|
res[dot < 1e-5 - 1] = (b1 * (1.0 - r) + b2 * r)[dot < 1e-5 - 1]
|
||||||
return res
|
return res
|
||||||
|
|
||||||
def generate_bilinear_data(length_old, length_new, device):
|
def generate_bilinear_data(length_old, length_new, device):
|
||||||
coords_1 = torch.arange(length_old, dtype=torch.float32, device=device).reshape((1,1,1,-1))
|
coords_1 = torch.arange(length_old, dtype=torch.float32, device=device).reshape((1, 1, 1, -1))
|
||||||
coords_1 = torch.nn.functional.interpolate(coords_1, size=(1, length_new), mode="bilinear")
|
coords_1 = torch.nn.functional.interpolate(coords_1, size=(1, length_new), mode="bilinear")
|
||||||
ratios = coords_1 - coords_1.floor()
|
ratios = coords_1 - coords_1.floor()
|
||||||
coords_1 = coords_1.to(torch.int64)
|
coords_1 = coords_1.to(torch.int64)
|
||||||
|
|
||||||
coords_2 = torch.arange(length_old, dtype=torch.float32, device=device).reshape((1,1,1,-1)) + 1
|
coords_2 = torch.arange(length_old, dtype=torch.float32, device=device).reshape((1, 1, 1, -1)) + 1
|
||||||
coords_2[:,:,:,-1] -= 1
|
coords_2[:, :, :, -1] -= 1
|
||||||
coords_2 = torch.nn.functional.interpolate(coords_2, size=(1, length_new), mode="bilinear")
|
coords_2 = torch.nn.functional.interpolate(coords_2, size=(1, length_new), mode="bilinear")
|
||||||
coords_2 = coords_2.to(torch.int64)
|
coords_2 = coords_2.to(torch.int64)
|
||||||
return ratios, coords_1, coords_2
|
return ratios, coords_1, coords_2
|
||||||
|
|
||||||
orig_dtype = samples.dtype
|
orig_dtype = samples.dtype
|
||||||
samples = samples.float()
|
samples = samples.float()
|
||||||
n,c,h,w = samples.shape
|
n, c, h, w = samples.shape
|
||||||
h_new, w_new = (height, width)
|
h_new, w_new = (height, width)
|
||||||
|
|
||||||
#linear w
|
# linear w
|
||||||
ratios, coords_1, coords_2 = generate_bilinear_data(w, w_new, samples.device)
|
ratios, coords_1, coords_2 = generate_bilinear_data(w, w_new, samples.device)
|
||||||
coords_1 = coords_1.expand((n, c, h, -1))
|
coords_1 = coords_1.expand((n, c, h, -1))
|
||||||
coords_2 = coords_2.expand((n, c, h, -1))
|
coords_2 = coords_2.expand((n, c, h, -1))
|
||||||
ratios = ratios.expand((n, 1, h, -1))
|
ratios = ratios.expand((n, 1, h, -1))
|
||||||
|
|
||||||
pass_1 = samples.gather(-1,coords_1).movedim(1, -1).reshape((-1,c))
|
pass_1 = samples.gather(-1, coords_1).movedim(1, -1).reshape((-1, c))
|
||||||
pass_2 = samples.gather(-1,coords_2).movedim(1, -1).reshape((-1,c))
|
pass_2 = samples.gather(-1, coords_2).movedim(1, -1).reshape((-1, c))
|
||||||
ratios = ratios.movedim(1, -1).reshape((-1,1))
|
ratios = ratios.movedim(1, -1).reshape((-1, 1))
|
||||||
|
|
||||||
result = slerp(pass_1, pass_2, ratios)
|
result = slerp(pass_1, pass_2, ratios)
|
||||||
result = result.reshape(n, h, w_new, c).movedim(-1, 1)
|
result = result.reshape(n, h, w_new, c).movedim(-1, 1)
|
||||||
|
|
||||||
#linear h
|
# linear h
|
||||||
ratios, coords_1, coords_2 = generate_bilinear_data(h, h_new, samples.device)
|
ratios, coords_1, coords_2 = generate_bilinear_data(h, h_new, samples.device)
|
||||||
coords_1 = coords_1.reshape((1,1,-1,1)).expand((n, c, -1, w_new))
|
coords_1 = coords_1.reshape((1, 1, -1, 1)).expand((n, c, -1, w_new))
|
||||||
coords_2 = coords_2.reshape((1,1,-1,1)).expand((n, c, -1, w_new))
|
coords_2 = coords_2.reshape((1, 1, -1, 1)).expand((n, c, -1, w_new))
|
||||||
ratios = ratios.reshape((1,1,-1,1)).expand((n, 1, -1, w_new))
|
ratios = ratios.reshape((1, 1, -1, 1)).expand((n, 1, -1, w_new))
|
||||||
|
|
||||||
pass_1 = result.gather(-2,coords_1).movedim(1, -1).reshape((-1,c))
|
pass_1 = result.gather(-2, coords_1).movedim(1, -1).reshape((-1, c))
|
||||||
pass_2 = result.gather(-2,coords_2).movedim(1, -1).reshape((-1,c))
|
pass_2 = result.gather(-2, coords_2).movedim(1, -1).reshape((-1, c))
|
||||||
ratios = ratios.movedim(1, -1).reshape((-1,1))
|
ratios = ratios.movedim(1, -1).reshape((-1, 1))
|
||||||
|
|
||||||
result = slerp(pass_1, pass_2, ratios)
|
result = slerp(pass_1, pass_2, ratios)
|
||||||
result = result.reshape(n, h_new, w_new, c).movedim(-1, 1)
|
result = result.reshape(n, h_new, w_new, c).movedim(-1, 1)
|
||||||
return result.to(orig_dtype)
|
return result.to(orig_dtype)
|
||||||
|
|
||||||
|
|
||||||
def lanczos(samples, width, height):
|
def lanczos(samples, width, height):
|
||||||
images = [Image.fromarray(np.clip(255. * image.movedim(0, -1).cpu().numpy(), 0, 255).astype(np.uint8)) for image in samples]
|
images = [Image.fromarray(np.clip(255. * image.movedim(0, -1).cpu().numpy(), 0, 255).astype(np.uint8)) for image in samples]
|
||||||
images = [image.resize((width, height), resample=Image.Resampling.LANCZOS) for image in images]
|
images = [image.resize((width, height), resample=Image.Resampling.LANCZOS) for image in images]
|
||||||
@ -415,81 +439,95 @@ def lanczos(samples, width, height):
|
|||||||
result = torch.stack(images)
|
result = torch.stack(images)
|
||||||
return result.to(samples.device, samples.dtype)
|
return result.to(samples.device, samples.dtype)
|
||||||
|
|
||||||
def common_upscale(samples, width, height, upscale_method, crop):
|
|
||||||
if crop == "center":
|
|
||||||
old_width = samples.shape[3]
|
|
||||||
old_height = samples.shape[2]
|
|
||||||
old_aspect = old_width / old_height
|
|
||||||
new_aspect = width / height
|
|
||||||
x = 0
|
|
||||||
y = 0
|
|
||||||
if old_aspect > new_aspect:
|
|
||||||
x = round((old_width - old_width * (new_aspect / old_aspect)) / 2)
|
|
||||||
elif old_aspect < new_aspect:
|
|
||||||
y = round((old_height - old_height * (old_aspect / new_aspect)) / 2)
|
|
||||||
s = samples[:,:,y:old_height-y,x:old_width-x]
|
|
||||||
else:
|
|
||||||
s = samples
|
|
||||||
|
|
||||||
if upscale_method == "bislerp":
|
def common_upscale(samples, width, height, upscale_method, crop):
|
||||||
return bislerp(s, width, height)
|
if crop == "center":
|
||||||
elif upscale_method == "lanczos":
|
old_width = samples.shape[3]
|
||||||
return lanczos(s, width, height)
|
old_height = samples.shape[2]
|
||||||
else:
|
old_aspect = old_width / old_height
|
||||||
return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method)
|
new_aspect = width / height
|
||||||
|
x = 0
|
||||||
|
y = 0
|
||||||
|
if old_aspect > new_aspect:
|
||||||
|
x = round((old_width - old_width * (new_aspect / old_aspect)) / 2)
|
||||||
|
elif old_aspect < new_aspect:
|
||||||
|
y = round((old_height - old_height * (old_aspect / new_aspect)) / 2)
|
||||||
|
s = samples[:, :, y:old_height - y, x:old_width - x]
|
||||||
|
else:
|
||||||
|
s = samples
|
||||||
|
|
||||||
|
if upscale_method == "bislerp":
|
||||||
|
return bislerp(s, width, height)
|
||||||
|
elif upscale_method == "lanczos":
|
||||||
|
return lanczos(s, width, height)
|
||||||
|
else:
|
||||||
|
return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method)
|
||||||
|
|
||||||
|
|
||||||
def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap):
|
def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap):
|
||||||
return math.ceil((height / (tile_y - overlap))) * math.ceil((width / (tile_x - overlap)))
|
return math.ceil((height / (tile_y - overlap))) * math.ceil((width / (tile_x - overlap)))
|
||||||
|
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None):
|
def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap=8, upscale_amount=4, out_channels=3, output_device="cpu", pbar=None):
|
||||||
output = torch.empty((samples.shape[0], out_channels, round(samples.shape[2] * upscale_amount), round(samples.shape[3] * upscale_amount)), device=output_device)
|
output = torch.empty((samples.shape[0], out_channels, round(samples.shape[2] * upscale_amount), round(samples.shape[3] * upscale_amount)), device=output_device)
|
||||||
for b in range(samples.shape[0]):
|
for b in range(samples.shape[0]):
|
||||||
s = samples[b:b+1]
|
s = samples[b:b + 1]
|
||||||
out = torch.zeros((s.shape[0], out_channels, round(s.shape[2] * upscale_amount), round(s.shape[3] * upscale_amount)), device=output_device)
|
out = torch.zeros((s.shape[0], out_channels, round(s.shape[2] * upscale_amount), round(s.shape[3] * upscale_amount)), device=output_device)
|
||||||
out_div = torch.zeros((s.shape[0], out_channels, round(s.shape[2] * upscale_amount), round(s.shape[3] * upscale_amount)), device=output_device)
|
out_div = torch.zeros((s.shape[0], out_channels, round(s.shape[2] * upscale_amount), round(s.shape[3] * upscale_amount)), device=output_device)
|
||||||
for y in range(0, s.shape[2], tile_y - overlap):
|
for y in range(0, s.shape[2], tile_y - overlap):
|
||||||
for x in range(0, s.shape[3], tile_x - overlap):
|
for x in range(0, s.shape[3], tile_x - overlap):
|
||||||
x = max(0, min(s.shape[-1] - overlap, x))
|
x = max(0, min(s.shape[-1] - overlap, x))
|
||||||
y = max(0, min(s.shape[-2] - overlap, y))
|
y = max(0, min(s.shape[-2] - overlap, y))
|
||||||
s_in = s[:,:,y:y+tile_y,x:x+tile_x]
|
s_in = s[:, :, y:y + tile_y, x:x + tile_x]
|
||||||
|
|
||||||
ps = function(s_in).to(output_device)
|
ps = function(s_in).to(output_device)
|
||||||
mask = torch.ones_like(ps)
|
mask = torch.ones_like(ps)
|
||||||
feather = round(overlap * upscale_amount)
|
feather = round(overlap * upscale_amount)
|
||||||
for t in range(feather):
|
for t in range(feather):
|
||||||
mask[:,:,t:1+t,:] *= ((1.0/feather) * (t + 1))
|
mask[:, :, t:1 + t, :] *= ((1.0 / feather) * (t + 1))
|
||||||
mask[:,:,mask.shape[2] -1 -t: mask.shape[2]-t,:] *= ((1.0/feather) * (t + 1))
|
mask[:, :, mask.shape[2] - 1 - t: mask.shape[2] - t, :] *= ((1.0 / feather) * (t + 1))
|
||||||
mask[:,:,:,t:1+t] *= ((1.0/feather) * (t + 1))
|
mask[:, :, :, t:1 + t] *= ((1.0 / feather) * (t + 1))
|
||||||
mask[:,:,:,mask.shape[3]- 1 - t: mask.shape[3]- t] *= ((1.0/feather) * (t + 1))
|
mask[:, :, :, mask.shape[3] - 1 - t: mask.shape[3] - t] *= ((1.0 / feather) * (t + 1))
|
||||||
out[:,:,round(y*upscale_amount):round((y+tile_y)*upscale_amount),round(x*upscale_amount):round((x+tile_x)*upscale_amount)] += ps * mask
|
out[:, :, round(y * upscale_amount):round((y + tile_y) * upscale_amount), round(x * upscale_amount):round((x + tile_x) * upscale_amount)] += ps * mask
|
||||||
out_div[:,:,round(y*upscale_amount):round((y+tile_y)*upscale_amount),round(x*upscale_amount):round((x+tile_x)*upscale_amount)] += mask
|
out_div[:, :, round(y * upscale_amount):round((y + tile_y) * upscale_amount), round(x * upscale_amount):round((x + tile_x) * upscale_amount)] += mask
|
||||||
if pbar is not None:
|
if pbar is not None:
|
||||||
pbar.update(1)
|
pbar.update(1)
|
||||||
|
|
||||||
output[b:b+1] = out/out_div
|
output[b:b + 1] = out / out_div
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
def hijack_progress(server: ExecutorToClientProgress):
|
def _progress_bar_update(value: float, total: float, preview_image, client_id: Optional[str] = None):
|
||||||
def hook(value: float, total: float, preview_image):
|
server = current_execution_context().server
|
||||||
interruption.throw_exception_if_processing_interrupted()
|
# todo: this should really be from the context. right now the server is behaving like a context
|
||||||
progress = {"value": value, "max": total, "prompt_id": server.last_prompt_id, "node": server.last_node_id}
|
client_id = client_id or server.client_id
|
||||||
|
interruption.throw_exception_if_processing_interrupted()
|
||||||
|
progress = {"value": value, "max": total, "prompt_id": server.last_prompt_id, "node": server.last_node_id}
|
||||||
|
|
||||||
server.send_sync("progress", progress, server.client_id)
|
server.send_sync("progress", progress, client_id)
|
||||||
if preview_image is not None:
|
if preview_image is not None:
|
||||||
server.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, server.client_id)
|
server.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, client_id)
|
||||||
|
|
||||||
_progress_bar_hook.hook = hook
|
|
||||||
|
|
||||||
|
|
||||||
def set_progress_bar_enabled(enabled):
|
def set_progress_bar_enabled(enabled: bool):
|
||||||
global PROGRESS_BAR_ENABLED
|
warnings.warn(
|
||||||
PROGRESS_BAR_ENABLED = enabled
|
"The global method 'set_progress_bar_enabled' is deprecated and will be removed in a future version. Use current_execution_context().server.receive_all_progress_notifications instead.",
|
||||||
|
DeprecationWarning,
|
||||||
|
stacklevel=2
|
||||||
|
)
|
||||||
|
|
||||||
|
current_execution_context().server.receive_all_progress_notifications = enabled
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
def get_progress_bar_enabled() -> bool:
|
def get_progress_bar_enabled() -> bool:
|
||||||
return PROGRESS_BAR_ENABLED
|
warnings.warn(
|
||||||
|
"The global method 'get_progress_bar_enabled' is deprecated and will be removed in a future version. Use current_execution_context().server.receive_all_progress_notifications instead.",
|
||||||
|
DeprecationWarning,
|
||||||
|
stacklevel=2
|
||||||
|
)
|
||||||
|
return current_execution_context().server.receive_all_progress_notifications
|
||||||
|
|
||||||
|
|
||||||
class _DisabledProgressBar:
|
class _DisabledProgressBar:
|
||||||
@ -505,10 +543,8 @@ class _DisabledProgressBar:
|
|||||||
|
|
||||||
class ProgressBar:
|
class ProgressBar:
|
||||||
def __init__(self, total: float):
|
def __init__(self, total: float):
|
||||||
global _progress_bar_hook
|
|
||||||
self.total: float = total
|
self.total: float = total
|
||||||
self.current: float = 0.0
|
self.current: float = 0.0
|
||||||
self.hook = _progress_bar_hook.hook if hasattr(_progress_bar_hook, "hook") else None
|
|
||||||
|
|
||||||
def update_absolute(self, value, total=None, preview=None):
|
def update_absolute(self, value, total=None, preview=None):
|
||||||
if total is not None:
|
if total is not None:
|
||||||
@ -516,8 +552,7 @@ class ProgressBar:
|
|||||||
if value > self.total:
|
if value > self.total:
|
||||||
value = self.total
|
value = self.total
|
||||||
self.current = value
|
self.current = value
|
||||||
if self.hook is not None:
|
_progress_bar_update(self.current, self.total, preview)
|
||||||
self.hook(self.current, self.total, preview)
|
|
||||||
|
|
||||||
def update(self, value):
|
def update(self, value):
|
||||||
self.update_absolute(self.current + value)
|
self.update_absolute(self.current + value)
|
||||||
@ -556,8 +591,8 @@ def comfy_tqdm():
|
|||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def comfy_progress(total: float) -> ProgressBar:
|
def comfy_progress(total: float) -> ProgressBar:
|
||||||
global PROGRESS_BAR_ENABLED
|
ctx = current_execution_context()
|
||||||
if PROGRESS_BAR_ENABLED:
|
if ctx.server.receive_all_progress_notifications:
|
||||||
yield ProgressBar(total)
|
yield ProgressBar(total)
|
||||||
else:
|
else:
|
||||||
yield _DisabledProgressBar()
|
yield _DisabledProgressBar()
|
||||||
|
|||||||
@ -2,6 +2,7 @@ import comfy.sampler_names
|
|||||||
from comfy import samplers
|
from comfy import samplers
|
||||||
from comfy import model_management
|
from comfy import model_management
|
||||||
from comfy import sample
|
from comfy import sample
|
||||||
|
from comfy.execution_context import current_execution_context
|
||||||
from comfy.k_diffusion import sampling as k_diffusion_sampling
|
from comfy.k_diffusion import sampling as k_diffusion_sampling
|
||||||
from comfy.cmd import latent_preview
|
from comfy.cmd import latent_preview
|
||||||
import torch
|
import torch
|
||||||
@ -416,7 +417,7 @@ class SamplerCustom:
|
|||||||
x0_output = {}
|
x0_output = {}
|
||||||
callback = latent_preview.prepare_callback(model, sigmas.shape[-1] - 1, x0_output)
|
callback = latent_preview.prepare_callback(model, sigmas.shape[-1] - 1, x0_output)
|
||||||
|
|
||||||
disable_pbar = not utils.PROGRESS_BAR_ENABLED
|
disable_pbar = not current_execution_context().server.receive_all_progress_notifications
|
||||||
samples = sample.sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=noise_seed)
|
samples = sample.sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=noise_seed)
|
||||||
|
|
||||||
out = latent.copy()
|
out = latent.copy()
|
||||||
@ -570,7 +571,7 @@ class SamplerCustomAdvanced:
|
|||||||
x0_output = {}
|
x0_output = {}
|
||||||
callback = latent_preview.prepare_callback(guider.model_patcher, sigmas.shape[-1] - 1, x0_output)
|
callback = latent_preview.prepare_callback(guider.model_patcher, sigmas.shape[-1] - 1, x0_output)
|
||||||
|
|
||||||
disable_pbar = not comfy.utils.PROGRESS_BAR_ENABLED
|
disable_pbar = not current_execution_context().server.receive_all_progress_notifications
|
||||||
samples = guider.sample(noise.generate_noise(latent), latent_image, sampler, sigmas, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=noise.seed)
|
samples = guider.sample(noise.generate_noise(latent), latent_image, sampler, sigmas, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=noise.seed)
|
||||||
samples = samples.to(comfy.model_management.intermediate_device())
|
samples = samples.to(comfy.model_management.intermediate_device())
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user