from __future__ import annotations import contextlib import itertools import logging import math import os import os.path import random import struct import sys import warnings from contextlib import contextmanager from typing import Optional, Any import numpy as np import safetensors.torch import torch from PIL import Image from tqdm import tqdm from . import checkpoint_pickle, interruption from .component_model.executor_types import ExecutorToClientProgress, ProgressMessage from .component_model.queue_types import BinaryEventTypes from .execution_context import current_execution_context # deprecate PROGRESS_BAR_ENABLED def _get_progress_bar_enabled(): warnings.warn( "The global variable 'PROGRESS_BAR_ENABLED' is deprecated and will be removed in a future version. Use current_execution_context().server.receive_all_progress_notifications instead.", DeprecationWarning, stacklevel=2 ) return current_execution_context().server.receive_all_progress_notifications setattr(sys.modules[__name__], 'PROGRESS_BAR_ENABLED', property(_get_progress_bar_enabled)) def load_torch_file(ckpt, safe_load=False, device=None): if device is None: device = torch.device("cpu") if ckpt is None: raise FileNotFoundError("the checkpoint was not found") if ckpt.lower().endswith(".safetensors"): sd = safetensors.torch.load_file(ckpt, device=device.type) else: if safe_load: if not 'weights_only' in torch.load.__code__.co_varnames: logging.warning("Warning torch.load doesn't support weights_only on this pytorch version, loading unsafely.") safe_load = False if safe_load: pl_sd = torch.load(ckpt, map_location=device, weights_only=True) else: pl_sd = torch.load(ckpt, map_location=device, pickle_module=checkpoint_pickle) if "global_step" in pl_sd: logging.debug(f"Global Step: {pl_sd['global_step']}") if "state_dict" in pl_sd: sd = pl_sd["state_dict"] else: sd = pl_sd return sd def save_torch_file(sd, ckpt, metadata=None): if metadata is not None: safetensors.torch.save_file(sd, ckpt, metadata=metadata) else: safetensors.torch.save_file(sd, ckpt) def calculate_parameters(sd, prefix=""): params = 0 for k in sd.keys(): if k.startswith(prefix): params += sd[k].nelement() return params def state_dict_key_replace(state_dict, keys_to_replace): for x in keys_to_replace: if x in state_dict: state_dict[keys_to_replace[x]] = state_dict.pop(x) return state_dict def state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=False): if filter_keys: out = {} else: out = state_dict for rp in replace_prefix: replace = list(map(lambda a: (a, "{}{}".format(replace_prefix[rp], a[len(rp):])), filter(lambda a: a.startswith(rp), state_dict.keys()))) for x in replace: w = state_dict.pop(x[0]) out[x[1]] = w return out def transformers_convert(sd, prefix_from, prefix_to, number): keys_to_replace = { "{}positional_embedding": "{}embeddings.position_embedding.weight", "{}token_embedding.weight": "{}embeddings.token_embedding.weight", "{}ln_final.weight": "{}final_layer_norm.weight", "{}ln_final.bias": "{}final_layer_norm.bias", } for k in keys_to_replace: x = k.format(prefix_from) if x in sd: sd[keys_to_replace[k].format(prefix_to)] = sd.pop(x) resblock_to_replace = { "ln_1": "layer_norm1", "ln_2": "layer_norm2", "mlp.c_fc": "mlp.fc1", "mlp.c_proj": "mlp.fc2", "attn.out_proj": "self_attn.out_proj", } for resblock in range(number): for x in resblock_to_replace: for y in ["weight", "bias"]: k = "{}transformer.resblocks.{}.{}.{}".format(prefix_from, resblock, x, y) k_to = "{}encoder.layers.{}.{}.{}".format(prefix_to, resblock, resblock_to_replace[x], y) if k in sd: sd[k_to] = sd.pop(k) for y in ["weight", "bias"]: k_from = "{}transformer.resblocks.{}.attn.in_proj_{}".format(prefix_from, resblock, y) if k_from in sd: weights = sd.pop(k_from) shape_from = weights.shape[0] // 3 for x in range(3): p = ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"] k_to = "{}encoder.layers.{}.{}.{}".format(prefix_to, resblock, p[x], y) sd[k_to] = weights[shape_from * x:shape_from * (x + 1)] return sd def clip_text_transformers_convert(sd, prefix_from, prefix_to): sd = transformers_convert(sd, prefix_from, "{}text_model.".format(prefix_to), 32) tp = "{}text_projection.weight".format(prefix_from) if tp in sd: sd["{}text_projection.weight".format(prefix_to)] = sd.pop(tp) tp = "{}text_projection".format(prefix_from) if tp in sd: sd["{}text_projection.weight".format(prefix_to)] = sd.pop(tp).transpose(0, 1).contiguous() return sd UNET_MAP_ATTENTIONS = { "proj_in.weight", "proj_in.bias", "proj_out.weight", "proj_out.bias", "norm.weight", "norm.bias", } TRANSFORMER_BLOCKS = { "norm1.weight", "norm1.bias", "norm2.weight", "norm2.bias", "norm3.weight", "norm3.bias", "attn1.to_q.weight", "attn1.to_k.weight", "attn1.to_v.weight", "attn1.to_out.0.weight", "attn1.to_out.0.bias", "attn2.to_q.weight", "attn2.to_k.weight", "attn2.to_v.weight", "attn2.to_out.0.weight", "attn2.to_out.0.bias", "ff.net.0.proj.weight", "ff.net.0.proj.bias", "ff.net.2.weight", "ff.net.2.bias", } UNET_MAP_RESNET = { "in_layers.2.weight": "conv1.weight", "in_layers.2.bias": "conv1.bias", "emb_layers.1.weight": "time_emb_proj.weight", "emb_layers.1.bias": "time_emb_proj.bias", "out_layers.3.weight": "conv2.weight", "out_layers.3.bias": "conv2.bias", "skip_connection.weight": "conv_shortcut.weight", "skip_connection.bias": "conv_shortcut.bias", "in_layers.0.weight": "norm1.weight", "in_layers.0.bias": "norm1.bias", "out_layers.0.weight": "norm2.weight", "out_layers.0.bias": "norm2.bias", } UNET_MAP_BASIC = { ("label_emb.0.0.weight", "class_embedding.linear_1.weight"), ("label_emb.0.0.bias", "class_embedding.linear_1.bias"), ("label_emb.0.2.weight", "class_embedding.linear_2.weight"), ("label_emb.0.2.bias", "class_embedding.linear_2.bias"), ("label_emb.0.0.weight", "add_embedding.linear_1.weight"), ("label_emb.0.0.bias", "add_embedding.linear_1.bias"), ("label_emb.0.2.weight", "add_embedding.linear_2.weight"), ("label_emb.0.2.bias", "add_embedding.linear_2.bias"), ("input_blocks.0.0.weight", "conv_in.weight"), ("input_blocks.0.0.bias", "conv_in.bias"), ("out.0.weight", "conv_norm_out.weight"), ("out.0.bias", "conv_norm_out.bias"), ("out.2.weight", "conv_out.weight"), ("out.2.bias", "conv_out.bias"), ("time_embed.0.weight", "time_embedding.linear_1.weight"), ("time_embed.0.bias", "time_embedding.linear_1.bias"), ("time_embed.2.weight", "time_embedding.linear_2.weight"), ("time_embed.2.bias", "time_embedding.linear_2.bias") } def unet_to_diffusers(unet_config): if "num_res_blocks" not in unet_config: return {} num_res_blocks = unet_config["num_res_blocks"] channel_mult = unet_config["channel_mult"] transformer_depth = unet_config["transformer_depth"][:] transformer_depth_output = unet_config["transformer_depth_output"][:] num_blocks = len(channel_mult) transformers_mid = unet_config.get("transformer_depth_middle", None) diffusers_unet_map = {} for x in range(num_blocks): n = 1 + (num_res_blocks[x] + 1) * x for i in range(num_res_blocks[x]): for b in UNET_MAP_RESNET: diffusers_unet_map["down_blocks.{}.resnets.{}.{}".format(x, i, UNET_MAP_RESNET[b])] = "input_blocks.{}.0.{}".format(n, b) num_transformers = transformer_depth.pop(0) if num_transformers > 0: for b in UNET_MAP_ATTENTIONS: diffusers_unet_map["down_blocks.{}.attentions.{}.{}".format(x, i, b)] = "input_blocks.{}.1.{}".format(n, b) for t in range(num_transformers): for b in TRANSFORMER_BLOCKS: diffusers_unet_map["down_blocks.{}.attentions.{}.transformer_blocks.{}.{}".format(x, i, t, b)] = "input_blocks.{}.1.transformer_blocks.{}.{}".format(n, t, b) n += 1 for k in ["weight", "bias"]: diffusers_unet_map["down_blocks.{}.downsamplers.0.conv.{}".format(x, k)] = "input_blocks.{}.0.op.{}".format(n, k) i = 0 for b in UNET_MAP_ATTENTIONS: diffusers_unet_map["mid_block.attentions.{}.{}".format(i, b)] = "middle_block.1.{}".format(b) for t in range(transformers_mid): for b in TRANSFORMER_BLOCKS: diffusers_unet_map["mid_block.attentions.{}.transformer_blocks.{}.{}".format(i, t, b)] = "middle_block.1.transformer_blocks.{}.{}".format(t, b) for i, n in enumerate([0, 2]): for b in UNET_MAP_RESNET: diffusers_unet_map["mid_block.resnets.{}.{}".format(i, UNET_MAP_RESNET[b])] = "middle_block.{}.{}".format(n, b) num_res_blocks = list(reversed(num_res_blocks)) for x in range(num_blocks): n = (num_res_blocks[x] + 1) * x l = num_res_blocks[x] + 1 for i in range(l): c = 0 for b in UNET_MAP_RESNET: diffusers_unet_map["up_blocks.{}.resnets.{}.{}".format(x, i, UNET_MAP_RESNET[b])] = "output_blocks.{}.0.{}".format(n, b) c += 1 num_transformers = transformer_depth_output.pop() if num_transformers > 0: c += 1 for b in UNET_MAP_ATTENTIONS: diffusers_unet_map["up_blocks.{}.attentions.{}.{}".format(x, i, b)] = "output_blocks.{}.1.{}".format(n, b) for t in range(num_transformers): for b in TRANSFORMER_BLOCKS: diffusers_unet_map["up_blocks.{}.attentions.{}.transformer_blocks.{}.{}".format(x, i, t, b)] = "output_blocks.{}.1.transformer_blocks.{}.{}".format(n, t, b) if i == l - 1: for k in ["weight", "bias"]: diffusers_unet_map["up_blocks.{}.upsamplers.0.conv.{}".format(x, k)] = "output_blocks.{}.{}.conv.{}".format(n, c, k) n += 1 for k in UNET_MAP_BASIC: diffusers_unet_map[k[1]] = k[0] return diffusers_unet_map def swap_scale_shift(weight): shift, scale = weight.chunk(2, dim=0) new_weight = torch.cat([scale, shift], dim=0) return new_weight MMDIT_MAP_BASIC = { ("context_embedder.bias", "context_embedder.bias"), ("context_embedder.weight", "context_embedder.weight"), ("t_embedder.mlp.0.bias", "time_text_embed.timestep_embedder.linear_1.bias"), ("t_embedder.mlp.0.weight", "time_text_embed.timestep_embedder.linear_1.weight"), ("t_embedder.mlp.2.bias", "time_text_embed.timestep_embedder.linear_2.bias"), ("t_embedder.mlp.2.weight", "time_text_embed.timestep_embedder.linear_2.weight"), ("x_embedder.proj.bias", "pos_embed.proj.bias"), ("x_embedder.proj.weight", "pos_embed.proj.weight"), ("y_embedder.mlp.0.bias", "time_text_embed.text_embedder.linear_1.bias"), ("y_embedder.mlp.0.weight", "time_text_embed.text_embedder.linear_1.weight"), ("y_embedder.mlp.2.bias", "time_text_embed.text_embedder.linear_2.bias"), ("y_embedder.mlp.2.weight", "time_text_embed.text_embedder.linear_2.weight"), ("pos_embed", "pos_embed.pos_embed"), ("final_layer.adaLN_modulation.1.bias", "norm_out.linear.bias", swap_scale_shift), ("final_layer.adaLN_modulation.1.weight", "norm_out.linear.weight", swap_scale_shift), ("final_layer.linear.bias", "proj_out.bias"), ("final_layer.linear.weight", "proj_out.weight"), } MMDIT_MAP_BLOCK = { ("context_block.adaLN_modulation.1.bias", "norm1_context.linear.bias"), ("context_block.adaLN_modulation.1.weight", "norm1_context.linear.weight"), ("context_block.attn.proj.bias", "attn.to_add_out.bias"), ("context_block.attn.proj.weight", "attn.to_add_out.weight"), ("context_block.mlp.fc1.bias", "ff_context.net.0.proj.bias"), ("context_block.mlp.fc1.weight", "ff_context.net.0.proj.weight"), ("context_block.mlp.fc2.bias", "ff_context.net.2.bias"), ("context_block.mlp.fc2.weight", "ff_context.net.2.weight"), ("x_block.adaLN_modulation.1.bias", "norm1.linear.bias"), ("x_block.adaLN_modulation.1.weight", "norm1.linear.weight"), ("x_block.attn.proj.bias", "attn.to_out.0.bias"), ("x_block.attn.proj.weight", "attn.to_out.0.weight"), ("x_block.mlp.fc1.bias", "ff.net.0.proj.bias"), ("x_block.mlp.fc1.weight", "ff.net.0.proj.weight"), ("x_block.mlp.fc2.bias", "ff.net.2.bias"), ("x_block.mlp.fc2.weight", "ff.net.2.weight"), } def mmdit_to_diffusers(mmdit_config, output_prefix=""): key_map = {} depth = mmdit_config.get("depth", 0) num_blocks = mmdit_config.get("num_blocks", depth) for i in range(num_blocks): block_from = "transformer_blocks.{}".format(i) block_to = "{}joint_blocks.{}".format(output_prefix, i) offset = depth * 64 for end in ("weight", "bias"): k = "{}.attn.".format(block_from) qkv = "{}.x_block.attn.qkv.{}".format(block_to, end) key_map["{}to_q.{}".format(k, end)] = (qkv, (0, 0, offset)) key_map["{}to_k.{}".format(k, end)] = (qkv, (0, offset, offset)) key_map["{}to_v.{}".format(k, end)] = (qkv, (0, offset * 2, offset)) qkv = "{}.context_block.attn.qkv.{}".format(block_to, end) key_map["{}add_q_proj.{}".format(k, end)] = (qkv, (0, 0, offset)) key_map["{}add_k_proj.{}".format(k, end)] = (qkv, (0, offset, offset)) key_map["{}add_v_proj.{}".format(k, end)] = (qkv, (0, offset * 2, offset)) for k in MMDIT_MAP_BLOCK: key_map["{}.{}".format(block_from, k[1])] = "{}.{}".format(block_to, k[0]) map_basic = MMDIT_MAP_BASIC.copy() map_basic.add(("joint_blocks.{}.context_block.adaLN_modulation.1.bias".format(depth - 1), "transformer_blocks.{}.norm1_context.linear.bias".format(depth - 1), swap_scale_shift)) map_basic.add(("joint_blocks.{}.context_block.adaLN_modulation.1.weight".format(depth - 1), "transformer_blocks.{}.norm1_context.linear.weight".format(depth - 1), swap_scale_shift)) for k in map_basic: if len(k) > 2: key_map[k[1]] = ("{}{}".format(output_prefix, k[0]), None, k[2]) else: key_map[k[1]] = "{}{}".format(output_prefix, k[0]) return key_map def repeat_to_batch_size(tensor, batch_size, dim=0): if tensor.shape[dim] > batch_size: return tensor.narrow(dim, 0, batch_size) elif tensor.shape[dim] < batch_size: return tensor.repeat(dim * [1] + [math.ceil(batch_size / tensor.shape[dim])] + [1] * (len(tensor.shape) - 1 - dim)).narrow(dim, 0, batch_size) return tensor def resize_to_batch_size(tensor, batch_size): in_batch_size = tensor.shape[0] if in_batch_size == batch_size: return tensor if batch_size <= 1: return tensor[:batch_size] output = torch.empty([batch_size] + list(tensor.shape)[1:], dtype=tensor.dtype, device=tensor.device) if batch_size < in_batch_size: scale = (in_batch_size - 1) / (batch_size - 1) for i in range(batch_size): output[i] = tensor[min(round(i * scale), in_batch_size - 1)] else: scale = in_batch_size / batch_size for i in range(batch_size): output[i] = tensor[min(math.floor((i + 0.5) * scale), in_batch_size - 1)] return output def convert_sd_to(state_dict, dtype): keys = list(state_dict.keys()) for k in keys: state_dict[k] = state_dict[k].to(dtype) return state_dict def safetensors_header(safetensors_path, max_size=100 * 1024 * 1024): with open(safetensors_path, "rb") as f: header = f.read(8) length_of_header = struct.unpack(' max_size: return None return f.read(length_of_header) def set_attr(obj, attr, value): attrs = attr.split(".") for name in attrs[:-1]: obj = getattr(obj, name) prev = getattr(obj, attrs[-1]) setattr(obj, attrs[-1], value) return prev def set_attr_param(obj, attr, value): return set_attr(obj, attr, torch.nn.Parameter(value, requires_grad=False)) def copy_to_param(obj, attr, value): # inplace update tensor instead of replacing it attrs = attr.split(".") for name in attrs[:-1]: obj = getattr(obj, name) prev = getattr(obj, attrs[-1]) prev.data.copy_(value) def get_attr(obj, attr): attrs = attr.split(".") for name in attrs: obj = getattr(obj, name) return obj def bislerp(samples, width, height): def slerp(b1, b2, r): '''slerps batches b1, b2 according to ratio r, batches should be flat e.g. NxC''' c = b1.shape[-1] # norms b1_norms = torch.norm(b1, dim=-1, keepdim=True) b2_norms = torch.norm(b2, dim=-1, keepdim=True) # normalize b1_normalized = b1 / b1_norms b2_normalized = b2 / b2_norms # zero when norms are zero b1_normalized[b1_norms.expand(-1, c) == 0.0] = 0.0 b2_normalized[b2_norms.expand(-1, c) == 0.0] = 0.0 # slerp dot = (b1_normalized * b2_normalized).sum(1) omega = torch.acos(dot) so = torch.sin(omega) # technically not mathematically correct, but more pleasing? res = (torch.sin((1.0 - r.squeeze(1)) * omega) / so).unsqueeze(1) * b1_normalized + (torch.sin(r.squeeze(1) * omega) / so).unsqueeze(1) * b2_normalized res *= (b1_norms * (1.0 - r) + b2_norms * r).expand(-1, c) # edge cases for same or polar opposites res[dot > 1 - 1e-5] = b1[dot > 1 - 1e-5] res[dot < 1e-5 - 1] = (b1 * (1.0 - r) + b2 * r)[dot < 1e-5 - 1] return res def generate_bilinear_data(length_old, length_new, device): coords_1 = torch.arange(length_old, dtype=torch.float32, device=device).reshape((1, 1, 1, -1)) coords_1 = torch.nn.functional.interpolate(coords_1, size=(1, length_new), mode="bilinear") ratios = coords_1 - coords_1.floor() coords_1 = coords_1.to(torch.int64) coords_2 = torch.arange(length_old, dtype=torch.float32, device=device).reshape((1, 1, 1, -1)) + 1 coords_2[:, :, :, -1] -= 1 coords_2 = torch.nn.functional.interpolate(coords_2, size=(1, length_new), mode="bilinear") coords_2 = coords_2.to(torch.int64) return ratios, coords_1, coords_2 orig_dtype = samples.dtype samples = samples.float() n, c, h, w = samples.shape h_new, w_new = (height, width) # linear w ratios, coords_1, coords_2 = generate_bilinear_data(w, w_new, samples.device) coords_1 = coords_1.expand((n, c, h, -1)) coords_2 = coords_2.expand((n, c, h, -1)) ratios = ratios.expand((n, 1, h, -1)) pass_1 = samples.gather(-1, coords_1).movedim(1, -1).reshape((-1, c)) pass_2 = samples.gather(-1, coords_2).movedim(1, -1).reshape((-1, c)) ratios = ratios.movedim(1, -1).reshape((-1, 1)) result = slerp(pass_1, pass_2, ratios) result = result.reshape(n, h, w_new, c).movedim(-1, 1) # linear h ratios, coords_1, coords_2 = generate_bilinear_data(h, h_new, samples.device) coords_1 = coords_1.reshape((1, 1, -1, 1)).expand((n, c, -1, w_new)) coords_2 = coords_2.reshape((1, 1, -1, 1)).expand((n, c, -1, w_new)) ratios = ratios.reshape((1, 1, -1, 1)).expand((n, 1, -1, w_new)) pass_1 = result.gather(-2, coords_1).movedim(1, -1).reshape((-1, c)) pass_2 = result.gather(-2, coords_2).movedim(1, -1).reshape((-1, c)) ratios = ratios.movedim(1, -1).reshape((-1, 1)) result = slerp(pass_1, pass_2, ratios) result = result.reshape(n, h_new, w_new, c).movedim(-1, 1) return result.to(orig_dtype) def lanczos(samples, width, height): images = [Image.fromarray(np.clip(255. * image.movedim(0, -1).cpu().numpy(), 0, 255).astype(np.uint8)) for image in samples] images = [image.resize((width, height), resample=Image.Resampling.LANCZOS) for image in images] images = [torch.from_numpy(np.array(image).astype(np.float32) / 255.0).movedim(-1, 0) for image in images] result = torch.stack(images) return result.to(samples.device, samples.dtype) def common_upscale(samples, width, height, upscale_method, crop): if crop == "center": old_width = samples.shape[3] old_height = samples.shape[2] old_aspect = old_width / old_height new_aspect = width / height x = 0 y = 0 if old_aspect > new_aspect: x = round((old_width - old_width * (new_aspect / old_aspect)) / 2) elif old_aspect < new_aspect: y = round((old_height - old_height * (old_aspect / new_aspect)) / 2) s = samples[:, :, y:old_height - y, x:old_width - x] else: s = samples if upscale_method == "bislerp": return bislerp(s, width, height) elif upscale_method == "lanczos": return lanczos(s, width, height) else: return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method) def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap): return math.ceil((height / (tile_y - overlap))) * math.ceil((width / (tile_x - overlap))) @torch.inference_mode() def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_amount=4, out_channels=3, output_device="cpu", pbar=None): dims = len(tile) output = torch.empty([samples.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), samples.shape[2:])), device=output_device) for b in range(samples.shape[0]): s = samples[b:b + 1] out = torch.zeros([s.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), s.shape[2:])), device=output_device) out_div = torch.zeros([s.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), s.shape[2:])), device=output_device) for it in itertools.product(*map(lambda a: range(0, a[0], a[1] - overlap), zip(s.shape[2:], tile))): s_in = s upscaled = [] for d in range(dims): pos = max(0, min(s.shape[d + 2] - overlap, it[d])) l = min(tile[d], s.shape[d + 2] - pos) s_in = s_in.narrow(d + 2, pos, l) upscaled.append(round(pos * upscale_amount)) ps = function(s_in).to(output_device) mask = torch.ones_like(ps) feather = round(overlap * upscale_amount) for t in range(feather): for d in range(2, dims + 2): m = mask.narrow(d, t, 1) m *= ((1.0 / feather) * (t + 1)) m = mask.narrow(d, mask.shape[d] - 1 - t, 1) m *= ((1.0 / feather) * (t + 1)) o = out o_d = out_div for d in range(dims): o = o.narrow(d + 2, upscaled[d], mask.shape[d + 2]) o_d = o_d.narrow(d + 2, upscaled[d], mask.shape[d + 2]) o += ps * mask o_d += mask if pbar is not None: pbar.update(1) output[b:b + 1] = out / out_div return output def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap=8, upscale_amount=4, out_channels=3, output_device="cpu", pbar=None): return tiled_scale_multidim(samples, function, (tile_y, tile_x), overlap, upscale_amount, out_channels, output_device, pbar) def _progress_bar_update(value: float, total: float, preview_image_or_data: Optional[Any] = None, client_id: Optional[str] = None, server: Optional[ExecutorToClientProgress] = None): server = server or current_execution_context().server # todo: this should really be from the context. right now the server is behaving like a context client_id = client_id or server.client_id interruption.throw_exception_if_processing_interrupted() progress: ProgressMessage = {"value": value, "max": total, "prompt_id": server.last_prompt_id, "node": server.last_node_id} if isinstance(preview_image_or_data, dict): progress["output"] = preview_image_or_data server.send_sync("progress", progress, client_id) # todo: investigate a better way to send the image data, since it needs the node ID if preview_image_or_data is not None and not isinstance(preview_image_or_data, dict): server.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image_or_data, client_id) def set_progress_bar_enabled(enabled: bool): warnings.warn( "The global method 'set_progress_bar_enabled' is deprecated and will be removed in a future version. Use current_execution_context().server.receive_all_progress_notifications instead.", DeprecationWarning, stacklevel=2 ) current_execution_context().server.receive_all_progress_notifications = enabled pass def get_progress_bar_enabled() -> bool: warnings.warn( "The global method 'get_progress_bar_enabled' is deprecated and will be removed in a future version. Use current_execution_context().server.receive_all_progress_notifications instead.", DeprecationWarning, stacklevel=2 ) return current_execution_context().server.receive_all_progress_notifications class _DisabledProgressBar: def __init__(self, *args, **kwargs): pass def update(self, *args, **kwargs): pass def update_absolute(self, *args, **kwargs): pass class ProgressBar: def __init__(self, total: float): self.total: float = total self.current: float = 0.0 self.server = current_execution_context().server def update_absolute(self, value, total=None, preview_image_or_output=None): if total is not None: self.total = total if value > self.total: value = self.total self.current = value _progress_bar_update(self.current, self.total, preview_image_or_output, server=self.server) def update(self, value): self.update_absolute(self.current + value) def get_project_root() -> str: return os.path.join(os.path.dirname(__file__), "..") @contextmanager def comfy_tqdm(): """ Monky patches child calls to tqdm and sends the progress to the UI :return: """ _original_init = tqdm.__init__ _original_update = tqdm.update try: def __init(self, *args, **kwargs): _original_init(self, *args, **kwargs) self._progress_bar = ProgressBar(self.total) def __update(self, n=1): assert self._progress_bar is not None _original_update(self, n) self._progress_bar.update(n) tqdm.__init__ = __init tqdm.update = __update yield finally: # Restore original tqdm tqdm.__init__ = _original_init tqdm.update = _original_update @contextmanager def comfy_progress(total: float) -> ProgressBar: ctx = current_execution_context() if ctx.server.receive_all_progress_notifications: yield ProgressBar(total) else: yield _DisabledProgressBar() @contextlib.contextmanager def seed_for_block(seed): # Save the current random state torch_rng_state = torch.get_rng_state() random_state = random.getstate() numpy_rng_state = np.random.get_state() if torch.cuda.is_available(): cuda_rng_state = torch.cuda.get_rng_state_all() # Set the new seed torch.manual_seed(seed) random.seed(seed) np.random.seed(seed) if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed) try: yield finally: # Restore the previous random state torch.set_rng_state(torch_rng_state) random.setstate(random_state) np.random.set_state(numpy_rng_state) if torch.cuda.is_available(): torch.cuda.set_rng_state_all(cuda_rng_state)