"""
|
|
This file is part of ComfyUI.
|
|
Copyright (C) 2024 Comfy
|
|
|
|
This program is free software: you can redistribute it and/or modify
|
|
it under the terms of the GNU General Public License as published by
|
|
the Free Software Foundation, either version 3 of the License, or
|
|
(at your option) any later version.
|
|
|
|
This program is distributed in the hope that it will be useful,
|
|
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
GNU General Public License for more details.
|
|
|
|
You should have received a copy of the GNU General Public License
|
|
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import contextlib
|
|
import itertools
|
|
import json
|
|
import logging
|
|
import math
|
|
import os
|
|
import random
|
|
import struct
|
|
import sys
|
|
import warnings
|
|
from contextlib import contextmanager
|
|
from pathlib import Path
|
|
from typing import Optional, Any
|
|
|
|
import accelerate
|
|
import numpy as np
|
|
import safetensors.torch
|
|
import torch
|
|
from PIL import Image
|
|
from tqdm import tqdm
|
|
|
|
from . import checkpoint_pickle, interruption
|
|
from .component_model import files
|
|
from .component_model.deprecation import _deprecate_method
|
|
from .component_model.executor_types import ExecutorToClientProgress, ProgressMessage
|
|
from .component_model.queue_types import BinaryEventTypes
|
|
from .execution_context import current_execution_context
|
|
|
|
|
|

# deprecate PROGRESS_BAR_ENABLED
def _get_progress_bar_enabled():
    warnings.warn(
        "The global variable 'PROGRESS_BAR_ENABLED' is deprecated and will be removed in a future version. Use current_execution_context().server.receive_all_progress_notifications instead.",
        DeprecationWarning,
        stacklevel=2
    )
    return current_execution_context().server.receive_all_progress_notifications


setattr(sys.modules[__name__], 'PROGRESS_BAR_ENABLED', property(_get_progress_bar_enabled))


def load_torch_file(ckpt: str, safe_load=False, device=None):
    if device is None:
        device = torch.device("cpu")
    if ckpt is None:
        raise FileNotFoundError("the checkpoint was not found")
    if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"):
        sd = safetensors.torch.load_file(ckpt, device=device.type)
    elif ckpt.lower().endswith("index.json"):
        # from accelerate
        index_filename = ckpt
        checkpoint_folder = os.path.split(index_filename)[0]
        with open(index_filename) as f:
            index = json.loads(f.read())

        if "weight_map" in index:
            index = index["weight_map"]
        checkpoint_files = sorted(list(set(index.values())))
        checkpoint_files = [os.path.join(checkpoint_folder, f) for f in checkpoint_files]
        sd: dict[str, torch.Tensor] = {}
        for checkpoint_file in checkpoint_files:
            sd.update(safetensors.torch.load_file(str(checkpoint_file), device=device.type))
    else:
        if safe_load:
            if 'weights_only' not in torch.load.__code__.co_varnames:
                logging.warning("torch.load doesn't support weights_only on this PyTorch version, loading unsafely.")
                safe_load = False
        if safe_load:
            pl_sd = torch.load(ckpt, map_location=device, weights_only=True)
        else:
            pl_sd = torch.load(ckpt, map_location=device, pickle_module=checkpoint_pickle)
        if "global_step" in pl_sd:
            logging.debug(f"Global Step: {pl_sd['global_step']}")
        if "state_dict" in pl_sd:
            sd = pl_sd["state_dict"]
        else:
            sd = pl_sd
    return sd
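

# A minimal usage sketch (not part of the original API surface): the checkpoint
# path below is hypothetical, and weight_dtype() is the helper defined further
# down in this file.
def _example_load_checkpoint(path="/path/to/checkpoint.safetensors"):
    # .safetensors/.sft files are parsed by safetensors directly; legacy pickle
    # checkpoints use torch.load(weights_only=True) when safe_load=True and the
    # restricted checkpoint_pickle unpickler otherwise.
    sd = load_torch_file(path, safe_load=True)
    logging.info("loaded %d tensors, dominant dtype %s", len(sd), weight_dtype(sd))
    return sd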


def save_torch_file(sd, ckpt, metadata=None):
    if metadata is not None:
        safetensors.torch.save_file(sd, ckpt, metadata=metadata)
    else:
        safetensors.torch.save_file(sd, ckpt)


def calculate_parameters(sd, prefix=""):
    params = 0
    for k in sd.keys():
        if k.startswith(prefix):
            w = sd[k]
            params += w.nelement()
    return params


def weight_dtype(sd, prefix=""):
    dtypes = {}
    for k in sd.keys():
        if k.startswith(prefix):
            w = sd[k]
            dtypes[w.dtype] = dtypes.get(w.dtype, 0) + 1

    if len(dtypes) == 0:
        return None

    return max(dtypes, key=dtypes.get)


def state_dict_key_replace(state_dict, keys_to_replace):
    for x in keys_to_replace:
        if x in state_dict:
            state_dict[keys_to_replace[x]] = state_dict.pop(x)
    return state_dict


def state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=False):
    if filter_keys:
        out = {}
    else:
        out = state_dict
    for rp in replace_prefix:
        replace = list(map(lambda a: (a, "{}{}".format(replace_prefix[rp], a[len(rp):])), filter(lambda a: a.startswith(rp), state_dict.keys())))
        for x in replace:
            w = state_dict.pop(x[0])
            out[x[1]] = w
    return out
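

# Illustrative sketch of the prefix helper on a toy state dict; the keys are
# made up for the example.
def _example_prefix_replace():
    sd = {
        "model.diffusion_model.input_blocks.0.0.weight": torch.zeros(1),
        "first_stage_model.decoder.conv_in.weight": torch.zeros(1),
    }
    # With filter_keys=True only matching keys are kept, so this extracts the
    # UNet weights and strips their prefix in one pass.
    unet_sd = state_dict_prefix_replace(sd, {"model.diffusion_model.": ""}, filter_keys=True)
    return sorted(unet_sd.keys())  # ['input_blocks.0.0.weight']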


def transformers_convert(sd, prefix_from, prefix_to, number):
    keys_to_replace = {
        "{}positional_embedding": "{}embeddings.position_embedding.weight",
        "{}token_embedding.weight": "{}embeddings.token_embedding.weight",
        "{}ln_final.weight": "{}final_layer_norm.weight",
        "{}ln_final.bias": "{}final_layer_norm.bias",
    }

    for k in keys_to_replace:
        x = k.format(prefix_from)
        if x in sd:
            sd[keys_to_replace[k].format(prefix_to)] = sd.pop(x)

    resblock_to_replace = {
        "ln_1": "layer_norm1",
        "ln_2": "layer_norm2",
        "mlp.c_fc": "mlp.fc1",
        "mlp.c_proj": "mlp.fc2",
        "attn.out_proj": "self_attn.out_proj",
    }

    for resblock in range(number):
        for x in resblock_to_replace:
            for y in ["weight", "bias"]:
                k = "{}transformer.resblocks.{}.{}.{}".format(prefix_from, resblock, x, y)
                k_to = "{}encoder.layers.{}.{}.{}".format(prefix_to, resblock, resblock_to_replace[x], y)
                if k in sd:
                    sd[k_to] = sd.pop(k)

        for y in ["weight", "bias"]:
            k_from = "{}transformer.resblocks.{}.attn.in_proj_{}".format(prefix_from, resblock, y)
            if k_from in sd:
                weights = sd.pop(k_from)
                shape_from = weights.shape[0] // 3
                for x in range(3):
                    p = ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"]
                    k_to = "{}encoder.layers.{}.{}.{}".format(prefix_to, resblock, p[x], y)
                    sd[k_to] = weights[shape_from * x:shape_from * (x + 1)]

    return sd
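

# Sketch of the conversion on a toy, made-up OpenCLIP-style state dict: the
# fused in_proj weight is split into equal q/k/v projections and the remaining
# keys are renamed to transformers-style names.
def _example_transformers_convert():
    sd = {
        "cond.transformer.resblocks.0.attn.in_proj_weight": torch.zeros(6, 4),
        "cond.ln_final.weight": torch.zeros(4),
    }
    sd = transformers_convert(sd, "cond.", "out.", number=1)
    # -> out.encoder.layers.0.self_attn.{q,k,v}_proj.weight (each 2x4)
    #    and out.final_layer_norm.weight
    return sorted(sd.keys())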


def clip_text_transformers_convert(sd, prefix_from, prefix_to):
    sd = transformers_convert(sd, prefix_from, "{}text_model.".format(prefix_to), 32)

    tp = "{}text_projection.weight".format(prefix_from)
    if tp in sd:
        sd["{}text_projection.weight".format(prefix_to)] = sd.pop(tp)

    tp = "{}text_projection".format(prefix_from)
    if tp in sd:
        sd["{}text_projection.weight".format(prefix_to)] = sd.pop(tp).transpose(0, 1).contiguous()
    return sd


UNET_MAP_ATTENTIONS = {
    "proj_in.weight",
    "proj_in.bias",
    "proj_out.weight",
    "proj_out.bias",
    "norm.weight",
    "norm.bias",
}

TRANSFORMER_BLOCKS = {
    "norm1.weight",
    "norm1.bias",
    "norm2.weight",
    "norm2.bias",
    "norm3.weight",
    "norm3.bias",
    "attn1.to_q.weight",
    "attn1.to_k.weight",
    "attn1.to_v.weight",
    "attn1.to_out.0.weight",
    "attn1.to_out.0.bias",
    "attn2.to_q.weight",
    "attn2.to_k.weight",
    "attn2.to_v.weight",
    "attn2.to_out.0.weight",
    "attn2.to_out.0.bias",
    "ff.net.0.proj.weight",
    "ff.net.0.proj.bias",
    "ff.net.2.weight",
    "ff.net.2.bias",
}

UNET_MAP_RESNET = {
    "in_layers.2.weight": "conv1.weight",
    "in_layers.2.bias": "conv1.bias",
    "emb_layers.1.weight": "time_emb_proj.weight",
    "emb_layers.1.bias": "time_emb_proj.bias",
    "out_layers.3.weight": "conv2.weight",
    "out_layers.3.bias": "conv2.bias",
    "skip_connection.weight": "conv_shortcut.weight",
    "skip_connection.bias": "conv_shortcut.bias",
    "in_layers.0.weight": "norm1.weight",
    "in_layers.0.bias": "norm1.bias",
    "out_layers.0.weight": "norm2.weight",
    "out_layers.0.bias": "norm2.bias",
}

UNET_MAP_BASIC = {
    ("label_emb.0.0.weight", "class_embedding.linear_1.weight"),
    ("label_emb.0.0.bias", "class_embedding.linear_1.bias"),
    ("label_emb.0.2.weight", "class_embedding.linear_2.weight"),
    ("label_emb.0.2.bias", "class_embedding.linear_2.bias"),
    ("label_emb.0.0.weight", "add_embedding.linear_1.weight"),
    ("label_emb.0.0.bias", "add_embedding.linear_1.bias"),
    ("label_emb.0.2.weight", "add_embedding.linear_2.weight"),
    ("label_emb.0.2.bias", "add_embedding.linear_2.bias"),
    ("input_blocks.0.0.weight", "conv_in.weight"),
    ("input_blocks.0.0.bias", "conv_in.bias"),
    ("out.0.weight", "conv_norm_out.weight"),
    ("out.0.bias", "conv_norm_out.bias"),
    ("out.2.weight", "conv_out.weight"),
    ("out.2.bias", "conv_out.bias"),
    ("time_embed.0.weight", "time_embedding.linear_1.weight"),
    ("time_embed.0.bias", "time_embedding.linear_1.bias"),
    ("time_embed.2.weight", "time_embedding.linear_2.weight"),
    ("time_embed.2.bias", "time_embedding.linear_2.bias")
}


def unet_to_diffusers(unet_config):
    if "num_res_blocks" not in unet_config:
        return {}
    num_res_blocks = unet_config["num_res_blocks"]
    channel_mult = unet_config["channel_mult"]
    transformer_depth = unet_config["transformer_depth"][:]
    transformer_depth_output = unet_config["transformer_depth_output"][:]
    num_blocks = len(channel_mult)

    transformers_mid = unet_config.get("transformer_depth_middle", None)

    diffusers_unet_map = {}
    for x in range(num_blocks):
        n = 1 + (num_res_blocks[x] + 1) * x
        for i in range(num_res_blocks[x]):
            for b in UNET_MAP_RESNET:
                diffusers_unet_map["down_blocks.{}.resnets.{}.{}".format(x, i, UNET_MAP_RESNET[b])] = "input_blocks.{}.0.{}".format(n, b)
            num_transformers = transformer_depth.pop(0)
            if num_transformers > 0:
                for b in UNET_MAP_ATTENTIONS:
                    diffusers_unet_map["down_blocks.{}.attentions.{}.{}".format(x, i, b)] = "input_blocks.{}.1.{}".format(n, b)
                for t in range(num_transformers):
                    for b in TRANSFORMER_BLOCKS:
                        diffusers_unet_map["down_blocks.{}.attentions.{}.transformer_blocks.{}.{}".format(x, i, t, b)] = "input_blocks.{}.1.transformer_blocks.{}.{}".format(n, t, b)
            n += 1
        for k in ["weight", "bias"]:
            diffusers_unet_map["down_blocks.{}.downsamplers.0.conv.{}".format(x, k)] = "input_blocks.{}.0.op.{}".format(n, k)

    i = 0
    for b in UNET_MAP_ATTENTIONS:
        diffusers_unet_map["mid_block.attentions.{}.{}".format(i, b)] = "middle_block.1.{}".format(b)
    for t in range(transformers_mid):
        for b in TRANSFORMER_BLOCKS:
            diffusers_unet_map["mid_block.attentions.{}.transformer_blocks.{}.{}".format(i, t, b)] = "middle_block.1.transformer_blocks.{}.{}".format(t, b)

    for i, n in enumerate([0, 2]):
        for b in UNET_MAP_RESNET:
            diffusers_unet_map["mid_block.resnets.{}.{}".format(i, UNET_MAP_RESNET[b])] = "middle_block.{}.{}".format(n, b)

    num_res_blocks = list(reversed(num_res_blocks))
    for x in range(num_blocks):
        n = (num_res_blocks[x] + 1) * x
        l = num_res_blocks[x] + 1
        for i in range(l):
            c = 0
            for b in UNET_MAP_RESNET:
                diffusers_unet_map["up_blocks.{}.resnets.{}.{}".format(x, i, UNET_MAP_RESNET[b])] = "output_blocks.{}.0.{}".format(n, b)
            c += 1
            num_transformers = transformer_depth_output.pop()
            if num_transformers > 0:
                c += 1
                for b in UNET_MAP_ATTENTIONS:
                    diffusers_unet_map["up_blocks.{}.attentions.{}.{}".format(x, i, b)] = "output_blocks.{}.1.{}".format(n, b)
                for t in range(num_transformers):
                    for b in TRANSFORMER_BLOCKS:
                        diffusers_unet_map["up_blocks.{}.attentions.{}.transformer_blocks.{}.{}".format(x, i, t, b)] = "output_blocks.{}.1.transformer_blocks.{}.{}".format(n, t, b)
            if i == l - 1:
                for k in ["weight", "bias"]:
                    diffusers_unet_map["up_blocks.{}.upsamplers.0.conv.{}".format(x, k)] = "output_blocks.{}.{}.conv.{}".format(n, c, k)
            n += 1

    for k in UNET_MAP_BASIC:
        diffusers_unet_map[k[1]] = k[0]

    return diffusers_unet_map
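

# A small sketch with made-up config values, just to show which direction the
# returned mapping points: diffusers-style key -> native UNet key. Real configs
# are produced by model detection elsewhere in the codebase.
def _example_unet_to_diffusers():
    toy_config = {
        "num_res_blocks": [1, 1],
        "channel_mult": [1, 2],
        "transformer_depth": [1, 0],
        "transformer_depth_output": [0, 0, 1, 1],
        "transformer_depth_middle": 1,
    }
    mapping = unet_to_diffusers(toy_config)
    assert mapping["conv_in.weight"] == "input_blocks.0.0.weight"
    return mapping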


def swap_scale_shift(weight):
    shift, scale = weight.chunk(2, dim=0)
    new_weight = torch.cat([scale, shift], dim=0)
    return new_weight


MMDIT_MAP_BASIC = {
    ("context_embedder.bias", "context_embedder.bias"),
    ("context_embedder.weight", "context_embedder.weight"),
    ("t_embedder.mlp.0.bias", "time_text_embed.timestep_embedder.linear_1.bias"),
    ("t_embedder.mlp.0.weight", "time_text_embed.timestep_embedder.linear_1.weight"),
    ("t_embedder.mlp.2.bias", "time_text_embed.timestep_embedder.linear_2.bias"),
    ("t_embedder.mlp.2.weight", "time_text_embed.timestep_embedder.linear_2.weight"),
    ("x_embedder.proj.bias", "pos_embed.proj.bias"),
    ("x_embedder.proj.weight", "pos_embed.proj.weight"),
    ("y_embedder.mlp.0.bias", "time_text_embed.text_embedder.linear_1.bias"),
    ("y_embedder.mlp.0.weight", "time_text_embed.text_embedder.linear_1.weight"),
    ("y_embedder.mlp.2.bias", "time_text_embed.text_embedder.linear_2.bias"),
    ("y_embedder.mlp.2.weight", "time_text_embed.text_embedder.linear_2.weight"),
    ("pos_embed", "pos_embed.pos_embed"),
    ("final_layer.adaLN_modulation.1.bias", "norm_out.linear.bias", swap_scale_shift),
    ("final_layer.adaLN_modulation.1.weight", "norm_out.linear.weight", swap_scale_shift),
    ("final_layer.linear.bias", "proj_out.bias"),
    ("final_layer.linear.weight", "proj_out.weight"),
}

MMDIT_MAP_BLOCK = {
    ("context_block.adaLN_modulation.1.bias", "norm1_context.linear.bias"),
    ("context_block.adaLN_modulation.1.weight", "norm1_context.linear.weight"),
    ("context_block.attn.proj.bias", "attn.to_add_out.bias"),
    ("context_block.attn.proj.weight", "attn.to_add_out.weight"),
    ("context_block.mlp.fc1.bias", "ff_context.net.0.proj.bias"),
    ("context_block.mlp.fc1.weight", "ff_context.net.0.proj.weight"),
    ("context_block.mlp.fc2.bias", "ff_context.net.2.bias"),
    ("context_block.mlp.fc2.weight", "ff_context.net.2.weight"),
    ("x_block.adaLN_modulation.1.bias", "norm1.linear.bias"),
    ("x_block.adaLN_modulation.1.weight", "norm1.linear.weight"),
    ("x_block.attn.proj.bias", "attn.to_out.0.bias"),
    ("x_block.attn.proj.weight", "attn.to_out.0.weight"),
    ("x_block.mlp.fc1.bias", "ff.net.0.proj.bias"),
    ("x_block.mlp.fc1.weight", "ff.net.0.proj.weight"),
    ("x_block.mlp.fc2.bias", "ff.net.2.bias"),
    ("x_block.mlp.fc2.weight", "ff.net.2.weight"),
}


def mmdit_to_diffusers(mmdit_config, output_prefix=""):
    key_map = {}

    depth = mmdit_config.get("depth", 0)
    num_blocks = mmdit_config.get("num_blocks", depth)
    for i in range(num_blocks):
        block_from = "transformer_blocks.{}".format(i)
        block_to = "{}joint_blocks.{}".format(output_prefix, i)

        offset = depth * 64

        for end in ("weight", "bias"):
            k = "{}.attn.".format(block_from)
            qkv = "{}.x_block.attn.qkv.{}".format(block_to, end)
            key_map["{}to_q.{}".format(k, end)] = (qkv, (0, 0, offset))
            key_map["{}to_k.{}".format(k, end)] = (qkv, (0, offset, offset))
            key_map["{}to_v.{}".format(k, end)] = (qkv, (0, offset * 2, offset))

            qkv = "{}.context_block.attn.qkv.{}".format(block_to, end)
            key_map["{}add_q_proj.{}".format(k, end)] = (qkv, (0, 0, offset))
            key_map["{}add_k_proj.{}".format(k, end)] = (qkv, (0, offset, offset))
            key_map["{}add_v_proj.{}".format(k, end)] = (qkv, (0, offset * 2, offset))

        for k in MMDIT_MAP_BLOCK:
            key_map["{}.{}".format(block_from, k[1])] = "{}.{}".format(block_to, k[0])

    map_basic = MMDIT_MAP_BASIC.copy()
    map_basic.add(("joint_blocks.{}.context_block.adaLN_modulation.1.bias".format(depth - 1), "transformer_blocks.{}.norm1_context.linear.bias".format(depth - 1), swap_scale_shift))
    map_basic.add(("joint_blocks.{}.context_block.adaLN_modulation.1.weight".format(depth - 1), "transformer_blocks.{}.norm1_context.linear.weight".format(depth - 1), swap_scale_shift))

    for k in map_basic:
        if len(k) > 2:
            key_map[k[1]] = ("{}{}".format(output_prefix, k[0]), None, k[2])
        else:
            key_map[k[1]] = "{}{}".format(output_prefix, k[0])

    return key_map


def auraflow_to_diffusers(mmdit_config, output_prefix=""):
    n_double_layers = mmdit_config.get("n_double_layers", 0)
    n_layers = mmdit_config.get("n_layers", 0)

    key_map = {}
    for i in range(n_layers):
        if i < n_double_layers:
            index = i
            prefix_from = "joint_transformer_blocks"
            prefix_to = "{}double_layers".format(output_prefix)
            block_map = {
                "attn.to_q.weight": "attn.w2q.weight",
                "attn.to_k.weight": "attn.w2k.weight",
                "attn.to_v.weight": "attn.w2v.weight",
                "attn.to_out.0.weight": "attn.w2o.weight",
                "attn.add_q_proj.weight": "attn.w1q.weight",
                "attn.add_k_proj.weight": "attn.w1k.weight",
                "attn.add_v_proj.weight": "attn.w1v.weight",
                "attn.to_add_out.weight": "attn.w1o.weight",
                "ff.linear_1.weight": "mlpX.c_fc1.weight",
                "ff.linear_2.weight": "mlpX.c_fc2.weight",
                "ff.out_projection.weight": "mlpX.c_proj.weight",
                "ff_context.linear_1.weight": "mlpC.c_fc1.weight",
                "ff_context.linear_2.weight": "mlpC.c_fc2.weight",
                "ff_context.out_projection.weight": "mlpC.c_proj.weight",
                "norm1.linear.weight": "modX.1.weight",
                "norm1_context.linear.weight": "modC.1.weight",
            }
        else:
            index = i - n_double_layers
            prefix_from = "single_transformer_blocks"
            prefix_to = "{}single_layers".format(output_prefix)

            block_map = {
                "attn.to_q.weight": "attn.w1q.weight",
                "attn.to_k.weight": "attn.w1k.weight",
                "attn.to_v.weight": "attn.w1v.weight",
                "attn.to_out.0.weight": "attn.w1o.weight",
                "norm1.linear.weight": "modCX.1.weight",
                "ff.linear_1.weight": "mlp.c_fc1.weight",
                "ff.linear_2.weight": "mlp.c_fc2.weight",
                "ff.out_projection.weight": "mlp.c_proj.weight"
            }

        for k in block_map:
            key_map["{}.{}.{}".format(prefix_from, index, k)] = "{}.{}.{}".format(prefix_to, index, block_map[k])

    MAP_BASIC = {
        ("positional_encoding", "pos_embed.pos_embed"),
        ("register_tokens", "register_tokens"),
        ("t_embedder.mlp.0.weight", "time_step_proj.linear_1.weight"),
        ("t_embedder.mlp.0.bias", "time_step_proj.linear_1.bias"),
        ("t_embedder.mlp.2.weight", "time_step_proj.linear_2.weight"),
        ("t_embedder.mlp.2.bias", "time_step_proj.linear_2.bias"),
        ("cond_seq_linear.weight", "context_embedder.weight"),
        ("init_x_linear.weight", "pos_embed.proj.weight"),
        ("init_x_linear.bias", "pos_embed.proj.bias"),
        ("final_linear.weight", "proj_out.weight"),
        ("modF.1.weight", "norm_out.linear.weight", swap_scale_shift),
    }

    for k in MAP_BASIC:
        if len(k) > 2:
            key_map[k[1]] = ("{}{}".format(output_prefix, k[0]), None, k[2])
        else:
            key_map[k[1]] = "{}{}".format(output_prefix, k[0])

    return key_map


def flux_to_diffusers(mmdit_config, output_prefix=""):
    n_double_layers = mmdit_config.get("depth", 0)
    n_single_layers = mmdit_config.get("depth_single_blocks", 0)
    hidden_size = mmdit_config.get("hidden_size", 0)

    key_map = {}
    for index in range(n_double_layers):
        prefix_from = "transformer_blocks.{}".format(index)
        prefix_to = "{}double_blocks.{}".format(output_prefix, index)

        for end in ("weight", "bias"):
            k = "{}.attn.".format(prefix_from)
            qkv = "{}.img_attn.qkv.{}".format(prefix_to, end)
            key_map["{}to_q.{}".format(k, end)] = (qkv, (0, 0, hidden_size))
            key_map["{}to_k.{}".format(k, end)] = (qkv, (0, hidden_size, hidden_size))
            key_map["{}to_v.{}".format(k, end)] = (qkv, (0, hidden_size * 2, hidden_size))

            k = "{}.attn.".format(prefix_from)
            qkv = "{}.txt_attn.qkv.{}".format(prefix_to, end)
            key_map["{}add_q_proj.{}".format(k, end)] = (qkv, (0, 0, hidden_size))
            key_map["{}add_k_proj.{}".format(k, end)] = (qkv, (0, hidden_size, hidden_size))
            key_map["{}add_v_proj.{}".format(k, end)] = (qkv, (0, hidden_size * 2, hidden_size))

        block_map = {
            "attn.to_out.0.weight": "img_attn.proj.weight",
            "attn.to_out.0.bias": "img_attn.proj.bias",
            "norm1.linear.weight": "img_mod.lin.weight",
            "norm1.linear.bias": "img_mod.lin.bias",
            "norm1_context.linear.weight": "txt_mod.lin.weight",
            "norm1_context.linear.bias": "txt_mod.lin.bias",
            "attn.to_add_out.weight": "txt_attn.proj.weight",
            "attn.to_add_out.bias": "txt_attn.proj.bias",
            "ff.net.0.proj.weight": "img_mlp.0.weight",
            "ff.net.0.proj.bias": "img_mlp.0.bias",
            "ff.net.2.weight": "img_mlp.2.weight",
            "ff.net.2.bias": "img_mlp.2.bias",
            "ff_context.net.0.proj.weight": "txt_mlp.0.weight",
            "ff_context.net.0.proj.bias": "txt_mlp.0.bias",
            "ff_context.net.2.weight": "txt_mlp.2.weight",
            "ff_context.net.2.bias": "txt_mlp.2.bias",
            "attn.norm_q.weight": "img_attn.norm.query_norm.scale",
            "attn.norm_k.weight": "img_attn.norm.key_norm.scale",
            "attn.norm_added_q.weight": "txt_attn.norm.query_norm.scale",
            "attn.norm_added_k.weight": "txt_attn.norm.key_norm.scale",
        }

        for k in block_map:
            key_map["{}.{}".format(prefix_from, k)] = "{}.{}".format(prefix_to, block_map[k])

    for index in range(n_single_layers):
        prefix_from = "single_transformer_blocks.{}".format(index)
        prefix_to = "{}single_blocks.{}".format(output_prefix, index)

        for end in ("weight", "bias"):
            k = "{}.attn.".format(prefix_from)
            qkv = "{}.linear1.{}".format(prefix_to, end)
            key_map["{}to_q.{}".format(k, end)] = (qkv, (0, 0, hidden_size))
            key_map["{}to_k.{}".format(k, end)] = (qkv, (0, hidden_size, hidden_size))
            key_map["{}to_v.{}".format(k, end)] = (qkv, (0, hidden_size * 2, hidden_size))
            key_map["{}.proj_mlp.{}".format(prefix_from, end)] = (qkv, (0, hidden_size * 3, hidden_size * 4))

        block_map = {
            "norm.linear.weight": "modulation.lin.weight",
            "norm.linear.bias": "modulation.lin.bias",
            "proj_out.weight": "linear2.weight",
            "proj_out.bias": "linear2.bias",
            "attn.norm_q.weight": "norm.query_norm.scale",
            "attn.norm_k.weight": "norm.key_norm.scale",
        }

        for k in block_map:
            key_map["{}.{}".format(prefix_from, k)] = "{}.{}".format(prefix_to, block_map[k])

    MAP_BASIC = {
        ("final_layer.linear.bias", "proj_out.bias"),
        ("final_layer.linear.weight", "proj_out.weight"),
        ("img_in.bias", "x_embedder.bias"),
        ("img_in.weight", "x_embedder.weight"),
        ("time_in.in_layer.bias", "time_text_embed.timestep_embedder.linear_1.bias"),
        ("time_in.in_layer.weight", "time_text_embed.timestep_embedder.linear_1.weight"),
        ("time_in.out_layer.bias", "time_text_embed.timestep_embedder.linear_2.bias"),
        ("time_in.out_layer.weight", "time_text_embed.timestep_embedder.linear_2.weight"),
        ("txt_in.bias", "context_embedder.bias"),
        ("txt_in.weight", "context_embedder.weight"),
        ("vector_in.in_layer.bias", "time_text_embed.text_embedder.linear_1.bias"),
        ("vector_in.in_layer.weight", "time_text_embed.text_embedder.linear_1.weight"),
        ("vector_in.out_layer.bias", "time_text_embed.text_embedder.linear_2.bias"),
        ("vector_in.out_layer.weight", "time_text_embed.text_embedder.linear_2.weight"),
        ("guidance_in.in_layer.bias", "time_text_embed.guidance_embedder.linear_1.bias"),
        ("guidance_in.in_layer.weight", "time_text_embed.guidance_embedder.linear_1.weight"),
        ("guidance_in.out_layer.bias", "time_text_embed.guidance_embedder.linear_2.bias"),
        ("guidance_in.out_layer.weight", "time_text_embed.guidance_embedder.linear_2.weight"),
        ("final_layer.adaLN_modulation.1.bias", "norm_out.linear.bias", swap_scale_shift),
        ("final_layer.adaLN_modulation.1.weight", "norm_out.linear.weight", swap_scale_shift),
        ("pos_embed_input.bias", "controlnet_x_embedder.bias"),
        ("pos_embed_input.weight", "controlnet_x_embedder.weight"),
    }

    for k in MAP_BASIC:
        if len(k) > 2:
            key_map[k[1]] = ("{}{}".format(output_prefix, k[0]), None, k[2])
        else:
            key_map[k[1]] = "{}{}".format(output_prefix, k[0])

    return key_map
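

# Sketch of the tuple-valued targets these maps produce. This is not the loader
# ComfyUI uses; it only shows that (fused_key, (dim, offset, size)) describes a
# slice of a fused qkv tensor that torch.Tensor.narrow can recover. hidden=8 is
# an arbitrary illustrative size.
def _example_flux_slice_tuples():
    hidden = 8
    key_map = flux_to_diffusers({"depth": 1, "depth_single_blocks": 0, "hidden_size": hidden})
    fused = {"double_blocks.0.img_attn.qkv.weight": torch.randn(hidden * 3, hidden)}
    target, (dim, offset, size) = key_map["transformer_blocks.0.attn.to_q.weight"]
    q = fused[target].narrow(dim, offset, size)  # shape (hidden, hidden)
    return q.shape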


def repeat_to_batch_size(tensor, batch_size, dim=0):
    if tensor.shape[dim] > batch_size:
        return tensor.narrow(dim, 0, batch_size)
    elif tensor.shape[dim] < batch_size:
        return tensor.repeat(dim * [1] + [math.ceil(batch_size / tensor.shape[dim])] + [1] * (len(tensor.shape) - 1 - dim)).narrow(dim, 0, batch_size)
    return tensor


def resize_to_batch_size(tensor, batch_size):
    in_batch_size = tensor.shape[0]
    if in_batch_size == batch_size:
        return tensor

    if batch_size <= 1:
        return tensor[:batch_size]

    output = torch.empty([batch_size] + list(tensor.shape)[1:], dtype=tensor.dtype, device=tensor.device)
    if batch_size < in_batch_size:
        scale = (in_batch_size - 1) / (batch_size - 1)
        for i in range(batch_size):
            output[i] = tensor[min(round(i * scale), in_batch_size - 1)]
    else:
        scale = in_batch_size / batch_size
        for i in range(batch_size):
            output[i] = tensor[min(math.floor((i + 0.5) * scale), in_batch_size - 1)]

    return output
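

# A minimal sketch of the two batch helpers on toy tensors: repeat_to_batch_size
# tiles rows cyclically, while resize_to_batch_size picks evenly spaced rows.
def _example_batch_resizing():
    tiled = repeat_to_batch_size(torch.randn(2, 4), 5)    # rows 0,1,0,1,0 -> (5, 4)
    picked = resize_to_batch_size(torch.randn(6, 4), 3)   # rows 0,2,5     -> (3, 4)
    return tiled.shape, picked.shape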


def convert_sd_to(state_dict, dtype):
    keys = list(state_dict.keys())
    for k in keys:
        state_dict[k] = state_dict[k].to(dtype)
    return state_dict


def safetensors_header(safetensors_path, max_size=100 * 1024 * 1024):
    with open(safetensors_path, "rb") as f:
        header = f.read(8)
        length_of_header = struct.unpack('<Q', header)[0]
        if length_of_header > max_size:
            return None
        return f.read(length_of_header)
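

# Sketch of inspecting a checkpoint without loading any weights: the bytes
# returned by safetensors_header are the JSON header of the file, so tensor
# names, dtypes and shapes can be listed directly. The path is hypothetical.
def _example_inspect_safetensors(path="/path/to/model.safetensors"):
    header = safetensors_header(path)
    if header is None:
        return None
    entries = json.loads(header)
    return {name: (info.get("dtype"), info.get("shape"))
            for name, info in entries.items() if name != "__metadata__"}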


def set_attr(obj, attr, value):
    attrs = attr.split(".")
    for name in attrs[:-1]:
        obj = getattr(obj, name)
    prev = getattr(obj, attrs[-1])
    setattr(obj, attrs[-1], value)
    return prev


def set_attr_param(obj, attr, value):
    return set_attr(obj, attr, torch.nn.Parameter(value, requires_grad=False))


def copy_to_param(obj, attr, value):
    # inplace update tensor instead of replacing it
    attrs = attr.split(".")
    for name in attrs[:-1]:
        obj = getattr(obj, name)
    prev = getattr(obj, attrs[-1])
    prev.data.copy_(value)


def get_attr(obj, attr):
    attrs = attr.split(".")
    for name in attrs:
        obj = getattr(obj, name)
    return obj
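

# A minimal sketch on a throwaway module: the dotted paths mirror state dict
# keys, so "0.weight" addresses the weight of the first submodule.
def _example_attr_helpers():
    model = torch.nn.Sequential(torch.nn.Linear(4, 4))
    weight = get_attr(model, "0.weight")
    set_attr_param(model, "0.weight", torch.zeros_like(weight))  # replace the Parameter
    copy_to_param(model, "0.bias", torch.zeros(4))               # update in place
    return get_attr(model, "0.weight").sum()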


def bislerp(samples, width, height):
    def slerp(b1, b2, r):
        '''slerps batches b1, b2 according to ratio r; batches should be flat, e.g. NxC'''

        c = b1.shape[-1]

        # norms
        b1_norms = torch.norm(b1, dim=-1, keepdim=True)
        b2_norms = torch.norm(b2, dim=-1, keepdim=True)

        # normalize
        b1_normalized = b1 / b1_norms
        b2_normalized = b2 / b2_norms

        # zero when norms are zero
        b1_normalized[b1_norms.expand(-1, c) == 0.0] = 0.0
        b2_normalized[b2_norms.expand(-1, c) == 0.0] = 0.0

        # slerp
        dot = (b1_normalized * b2_normalized).sum(1)
        omega = torch.acos(dot)
        so = torch.sin(omega)

        # technically not mathematically correct, but more pleasing?
        res = (torch.sin((1.0 - r.squeeze(1)) * omega) / so).unsqueeze(1) * b1_normalized + (torch.sin(r.squeeze(1) * omega) / so).unsqueeze(1) * b2_normalized
        res *= (b1_norms * (1.0 - r) + b2_norms * r).expand(-1, c)

        # edge cases for same or polar opposites
        res[dot > 1 - 1e-5] = b1[dot > 1 - 1e-5]
        res[dot < 1e-5 - 1] = (b1 * (1.0 - r) + b2 * r)[dot < 1e-5 - 1]
        return res

    def generate_bilinear_data(length_old, length_new, device):
        coords_1 = torch.arange(length_old, dtype=torch.float32, device=device).reshape((1, 1, 1, -1))
        coords_1 = torch.nn.functional.interpolate(coords_1, size=(1, length_new), mode="bilinear")
        ratios = coords_1 - coords_1.floor()
        coords_1 = coords_1.to(torch.int64)

        coords_2 = torch.arange(length_old, dtype=torch.float32, device=device).reshape((1, 1, 1, -1)) + 1
        coords_2[:, :, :, -1] -= 1
        coords_2 = torch.nn.functional.interpolate(coords_2, size=(1, length_new), mode="bilinear")
        coords_2 = coords_2.to(torch.int64)
        return ratios, coords_1, coords_2

    orig_dtype = samples.dtype
    samples = samples.float()
    n, c, h, w = samples.shape
    h_new, w_new = (height, width)

    # linear w
    ratios, coords_1, coords_2 = generate_bilinear_data(w, w_new, samples.device)
    coords_1 = coords_1.expand((n, c, h, -1))
    coords_2 = coords_2.expand((n, c, h, -1))
    ratios = ratios.expand((n, 1, h, -1))

    pass_1 = samples.gather(-1, coords_1).movedim(1, -1).reshape((-1, c))
    pass_2 = samples.gather(-1, coords_2).movedim(1, -1).reshape((-1, c))
    ratios = ratios.movedim(1, -1).reshape((-1, 1))

    result = slerp(pass_1, pass_2, ratios)
    result = result.reshape(n, h, w_new, c).movedim(-1, 1)

    # linear h
    ratios, coords_1, coords_2 = generate_bilinear_data(h, h_new, samples.device)
    coords_1 = coords_1.reshape((1, 1, -1, 1)).expand((n, c, -1, w_new))
    coords_2 = coords_2.reshape((1, 1, -1, 1)).expand((n, c, -1, w_new))
    ratios = ratios.reshape((1, 1, -1, 1)).expand((n, 1, -1, w_new))

    pass_1 = result.gather(-2, coords_1).movedim(1, -1).reshape((-1, c))
    pass_2 = result.gather(-2, coords_2).movedim(1, -1).reshape((-1, c))
    ratios = ratios.movedim(1, -1).reshape((-1, 1))

    result = slerp(pass_1, pass_2, ratios)
    result = result.reshape(n, h_new, w_new, c).movedim(-1, 1)
    return result.to(orig_dtype)


def lanczos(samples, width, height):
    images = [Image.fromarray(np.clip(255. * image.movedim(0, -1).cpu().numpy(), 0, 255).astype(np.uint8)) for image in samples]
    images = [image.resize((width, height), resample=Image.Resampling.LANCZOS) for image in images]
    images = [torch.from_numpy(np.array(image).astype(np.float32) / 255.0).movedim(-1, 0) for image in images]
    result = torch.stack(images)
    return result.to(samples.device, samples.dtype)


def common_upscale(samples, width, height, upscale_method, crop):
    if crop == "center":
        old_width = samples.shape[3]
        old_height = samples.shape[2]
        old_aspect = old_width / old_height
        new_aspect = width / height
        x = 0
        y = 0
        if old_aspect > new_aspect:
            x = round((old_width - old_width * (new_aspect / old_aspect)) / 2)
        elif old_aspect < new_aspect:
            y = round((old_height - old_height * (old_aspect / new_aspect)) / 2)
        s = samples[:, :, y:old_height - y, x:old_width - x]
    else:
        s = samples

    if upscale_method == "bislerp":
        return bislerp(s, width, height)
    elif upscale_method == "lanczos":
        return lanczos(s, width, height)
    else:
        return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method)
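

# Sketch: an NCHW latent scaled with the bislerp path and center cropping. The
# sizes are arbitrary; any torch.nn.functional.interpolate mode also works as
# the upscale_method.
def _example_common_upscale():
    latent = torch.randn(1, 4, 64, 48)
    upscaled = common_upscale(latent, width=96, height=96, upscale_method="bislerp", crop="center")
    return upscaled.shape  # torch.Size([1, 4, 96, 96])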


def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap):
    rows = 1 if height <= tile_y else math.ceil((height - overlap) / (tile_y - overlap))
    cols = 1 if width <= tile_x else math.ceil((width - overlap) / (tile_x - overlap))
    return rows * cols


@torch.inference_mode()
def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_amount=4, out_channels=3, output_device="cpu", pbar=None):
    dims = len(tile)
    output = torch.empty([samples.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), samples.shape[2:])), device=output_device)

    for b in range(samples.shape[0]):
        s = samples[b:b + 1]

        # handle entire input fitting in a single tile
        if all(s.shape[d + 2] <= tile[d] for d in range(dims)):
            output[b:b + 1] = function(s).to(output_device)
            if pbar is not None:
                pbar.update(1)
            continue

        out = torch.zeros([s.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), s.shape[2:])), device=output_device)
        out_div = torch.zeros([s.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), s.shape[2:])), device=output_device)

        positions = [range(0, s.shape[d + 2], tile[d] - overlap) if s.shape[d + 2] > tile[d] else [0] for d in range(dims)]

        for it in itertools.product(*positions):
            s_in = s
            upscaled = []

            for d in range(dims):
                pos = max(0, min(s.shape[d + 2] - overlap, it[d]))
                l = min(tile[d], s.shape[d + 2] - pos)
                s_in = s_in.narrow(d + 2, pos, l)
                upscaled.append(round(pos * upscale_amount))

            ps = function(s_in).to(output_device)
            mask = torch.ones_like(ps)
            feather = round(overlap * upscale_amount)

            for t in range(feather):
                for d in range(2, dims + 2):
                    a = (t + 1) / feather
                    mask.narrow(d, t, 1).mul_(a)
                    mask.narrow(d, mask.shape[d] - 1 - t, 1).mul_(a)

            o = out
            o_d = out_div
            for d in range(dims):
                o = o.narrow(d + 2, upscaled[d], mask.shape[d + 2])
                o_d = o_d.narrow(d + 2, upscaled[d], mask.shape[d + 2])

            o.add_(ps * mask)
            o_d.add_(mask)

            if pbar is not None:
                pbar.update(1)

        output[b:b + 1] = out / out_div
    return output


def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap=8, upscale_amount=4, out_channels=3, output_device="cpu", pbar=None):
    return tiled_scale_multidim(samples, function, (tile_y, tile_x), overlap, upscale_amount, out_channels, output_device, pbar)
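

# Sketch of how the tiling helpers fit together on a toy "model": the callable
# is applied per 64x64 tile, the feathered tiles are blended back together, and
# get_tiled_scale_steps gives the tile count a ProgressBar would be fed in a
# real node. The nearest-neighbor upscaler here is a stand-in.
def _example_tiled_scale():
    image = torch.randn(1, 3, 96, 96)
    steps = get_tiled_scale_steps(96, 96, tile_x=64, tile_y=64, overlap=8)  # 4 tiles
    upscaled = tiled_scale(
        image,
        lambda a: torch.nn.functional.interpolate(a, scale_factor=2, mode="nearest"),
        tile_x=64, tile_y=64, overlap=8, upscale_amount=2, out_channels=3)
    return steps, upscaled.shape  # (4, torch.Size([1, 3, 192, 192]))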


def _progress_bar_update(value: float, total: float, preview_image_or_data: Optional[Any] = None, client_id: Optional[str] = None, server: Optional[ExecutorToClientProgress] = None):
    server = server or current_execution_context().server
    # todo: this should really be from the context. right now the server is behaving like a context
    client_id = client_id or server.client_id
    interruption.throw_exception_if_processing_interrupted()
    progress: ProgressMessage = {"value": value, "max": total, "prompt_id": server.last_prompt_id, "node": server.last_node_id}
    if isinstance(preview_image_or_data, dict):
        progress["output"] = preview_image_or_data

    server.send_sync("progress", progress, client_id)

    # todo: investigate a better way to send the image data, since it needs the node ID
    if preview_image_or_data is not None and not isinstance(preview_image_or_data, dict):
        server.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image_or_data, client_id)


def set_progress_bar_enabled(enabled: bool):
    warnings.warn(
        "The global method 'set_progress_bar_enabled' is deprecated and will be removed in a future version. Use current_execution_context().server.receive_all_progress_notifications instead.",
        DeprecationWarning,
        stacklevel=2
    )

    current_execution_context().server.receive_all_progress_notifications = enabled


def get_progress_bar_enabled() -> bool:
    warnings.warn(
        "The global method 'get_progress_bar_enabled' is deprecated and will be removed in a future version. Use current_execution_context().server.receive_all_progress_notifications instead.",
        DeprecationWarning,
        stacklevel=2
    )
    return current_execution_context().server.receive_all_progress_notifications


class _DisabledProgressBar:
    def __init__(self, *args, **kwargs):
        pass

    def update(self, *args, **kwargs):
        pass

    def update_absolute(self, *args, **kwargs):
        pass


class ProgressBar:
    def __init__(self, total: float):
        self.total: float = total
        self.current: float = 0.0
        self.server = current_execution_context().server

    def update_absolute(self, value, total=None, preview_image_or_output=None):
        if total is not None:
            self.total = total
        if value > self.total:
            value = self.total
        self.current = value
        _progress_bar_update(self.current, self.total, preview_image_or_output, server=self.server)

    def update(self, value):
        self.update_absolute(self.current + value)


@_deprecate_method(version="1.0.0", message="The root project directory isn't valid when the application is installed as a package. Use os.getcwd() instead.")
def get_project_root() -> str:
    return files.get_package_as_path("comfy")


@contextmanager
def comfy_tqdm():
    """
    Monkey-patches child calls to tqdm and sends the progress to the UI.

    :return:
    """
    _original_init = tqdm.__init__
    _original_update = tqdm.update
    try:
        def __init(self, *args, **kwargs):
            _original_init(self, *args, **kwargs)
            self._progress_bar = ProgressBar(self.total)

        def __update(self, n=1):
            assert self._progress_bar is not None
            _original_update(self, n)
            self._progress_bar.update(n)

        tqdm.__init__ = __init
        tqdm.update = __update
        yield
    finally:
        # Restore original tqdm
        tqdm.__init__ = _original_init
        tqdm.update = _original_update
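

# Usage sketch, assuming an execution context with a server is active: while
# the context manager is in effect, any library that reports progress through
# tqdm is mirrored to the ComfyUI client.
def _example_comfy_tqdm():
    with comfy_tqdm():
        for _ in tqdm(range(10)):
            pass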


@contextmanager
def comfy_progress(total: float) -> ProgressBar:
    ctx = current_execution_context()
    if ctx.server.receive_all_progress_notifications:
        yield ProgressBar(total)
    else:
        yield _DisabledProgressBar()
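

# Usage sketch: comfy_progress yields a real ProgressBar when the client asked
# for progress notifications and a no-op bar otherwise, so node code does not
# need to branch on that setting. The step count here is arbitrary.
def _example_comfy_progress(steps=20):
    with comfy_progress(total=steps) as progress:
        for i in range(steps):
            progress.update_absolute(i + 1, steps)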


@contextlib.contextmanager
def seed_for_block(seed):
    # Save the current random state
    torch_rng_state = torch.get_rng_state()
    random_state = random.getstate()
    numpy_rng_state = np.random.get_state()
    # todo: investigate with torch.random.fork_rng(devices=(device,))
    if torch.cuda.is_available():
        cuda_rng_state = torch.cuda.get_rng_state_all()
    else:
        cuda_rng_state = None

    # Set the new seed
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    try:
        yield
    finally:
        # Restore the previous random state
        torch.set_rng_state(torch_rng_state)
        random.setstate(random_state)
        np.random.set_state(numpy_rng_state)
        if torch.cuda.is_available():
            torch.cuda.set_rng_state_all(cuda_rng_state)
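

# Sketch: the same seed inside the block produces the same noise, and the RNG
# state from before the block is restored afterwards, so surrounding code is
# unaffected.
def _example_seed_for_block():
    with seed_for_block(42):
        a = torch.randn(4)
    with seed_for_block(42):
        b = torch.randn(4)
    assert torch.equal(a, b)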


def pil2tensor(image: Image) -> torch.Tensor:
    return torch.from_numpy(np.array(image).astype(np.float32) / 255.0).unsqueeze(0)


def tensor2pil(t_image: torch.Tensor) -> Image:
    return Image.fromarray(np.clip(255.0 * t_image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8))
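

# Round-trip sketch: pil2tensor yields the batch-first, HWC, 0..1 float layout
# ComfyUI IMAGE tensors use, and tensor2pil quantizes it back to an 8-bit PIL
# image.
def _example_pil_roundtrip():
    image = Image.new("RGB", (8, 8), color=(255, 0, 0))
    tensor = pil2tensor(image)  # shape (1, 8, 8, 3)
    back = tensor2pil(tensor)   # 8x8 RGB PIL image
    return tensor.shape, back.size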