""" This file is part of ComfyUI. Copyright (C) 2024 Comfy This program is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details. You should have received a copy of the GNU General Public License along with this program. If not, see . """ import torch import math import struct import ctypes import os import comfy.memory_management import safetensors.torch import numpy as np from PIL import Image import logging import itertools import threading from torch.nn.functional import interpolate from tqdm.auto import trange from einops import rearrange from comfy.cli_args import args import json import time import warnings MMAP_TORCH_FILES = args.mmap_torch_files DISABLE_MMAP = args.disable_mmap if True: # ckpt/pt file whitelist for safe loading of old sd files class ModelCheckpoint: pass ModelCheckpoint.__module__ = "pytorch_lightning.callbacks.model_checkpoint" def scalar(*args, **kwargs): return None scalar.__module__ = "numpy.core.multiarray" from numpy import dtype from numpy.dtypes import Float64DType def encode(*args, **kwargs): # no longer necessary on newer torch return None encode.__module__ = "_codecs" torch.serialization.add_safe_globals([ModelCheckpoint, scalar, dtype, Float64DType, encode]) logging.info("Checkpoint files will always be loaded safely.") # Current as of safetensors 0.7.0 _TYPES = { "F64": torch.float64, "F32": torch.float32, "F16": torch.float16, "BF16": torch.bfloat16, "I64": torch.int64, "I32": torch.int32, "I16": torch.int16, "I8": torch.int8, "U8": torch.uint8, "BOOL": torch.bool, "F8_E4M3": torch.float8_e4m3fn, "F8_E5M2": torch.float8_e5m2, "C64": torch.complex64, "U64": torch.uint64, "U32": torch.uint32, "U16": torch.uint16, } def load_safetensors(ckpt): import comfy_aimdo.model_mmap f = open(ckpt, "rb", buffering=0) file_lock = threading.Lock() model_mmap = comfy_aimdo.model_mmap.ModelMMAP(ckpt) file_size = os.path.getsize(ckpt) mv = memoryview((ctypes.c_uint8 * file_size).from_address(model_mmap.get())) header_size = struct.unpack(" 0: message = e.args[0] if "HeaderTooLarge" in message: raise ValueError("{}\n\nFile path: {}\n\nThe safetensors file is corrupt or invalid. Make sure this is actually a safetensors file and not a ckpt or pt or other filetype.".format(message, ckpt)) if "MetadataIncompleteBuffer" in message: raise ValueError("{}\n\nFile path: {}\n\nThe safetensors file is corrupt/incomplete. Check the file size and make sure you have copied/downloaded it correctly.".format(message, ckpt)) raise e else: torch_args = {} if MMAP_TORCH_FILES: torch_args["mmap"] = True pl_sd = torch.load(ckpt, map_location=device, weights_only=True, **torch_args) if "state_dict" in pl_sd: sd = pl_sd["state_dict"] else: if len(pl_sd) == 1: key = list(pl_sd.keys())[0] sd = pl_sd[key] if not isinstance(sd, dict): sd = pl_sd else: sd = pl_sd return (sd, metadata) if return_metadata else sd def save_torch_file(sd, ckpt, metadata=None): if metadata is not None: safetensors.torch.save_file(sd, ckpt, metadata=metadata) else: safetensors.torch.save_file(sd, ckpt) def calculate_parameters(sd, prefix=""): params = 0 for k in sd.keys(): if k.startswith(prefix): w = sd[k] params += w.nelement() return params def weight_dtype(sd, prefix=""): dtypes = {} for k in sd.keys(): if k.startswith(prefix): w = sd[k] dtypes[w.dtype] = dtypes.get(w.dtype, 0) + w.numel() if len(dtypes) == 0: return None return max(dtypes, key=dtypes.get) def state_dict_key_replace(state_dict, keys_to_replace): for x in keys_to_replace: if x in state_dict: state_dict[keys_to_replace[x]] = state_dict.pop(x) return state_dict def state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=False): if filter_keys: out = {} else: out = state_dict for rp in replace_prefix: replace = list(map(lambda a: (a, "{}{}".format(replace_prefix[rp], a[len(rp):])), filter(lambda a: a.startswith(rp), state_dict.keys()))) for x in replace: w = state_dict.pop(x[0]) out[x[1]] = w return out def transformers_convert(sd, prefix_from, prefix_to, number): keys_to_replace = { "{}positional_embedding": "{}embeddings.position_embedding.weight", "{}token_embedding.weight": "{}embeddings.token_embedding.weight", "{}ln_final.weight": "{}final_layer_norm.weight", "{}ln_final.bias": "{}final_layer_norm.bias", } for k in keys_to_replace: x = k.format(prefix_from) if x in sd: sd[keys_to_replace[k].format(prefix_to)] = sd.pop(x) resblock_to_replace = { "ln_1": "layer_norm1", "ln_2": "layer_norm2", "mlp.c_fc": "mlp.fc1", "mlp.c_proj": "mlp.fc2", "attn.out_proj": "self_attn.out_proj", } for resblock in range(number): for x in resblock_to_replace: for y in ["weight", "bias"]: k = "{}transformer.resblocks.{}.{}.{}".format(prefix_from, resblock, x, y) k_to = "{}encoder.layers.{}.{}.{}".format(prefix_to, resblock, resblock_to_replace[x], y) if k in sd: sd[k_to] = sd.pop(k) for y in ["weight", "bias"]: k_from = "{}transformer.resblocks.{}.attn.in_proj_{}".format(prefix_from, resblock, y) if k_from in sd: weights = sd.pop(k_from) shape_from = weights.shape[0] // 3 for x in range(3): p = ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"] k_to = "{}encoder.layers.{}.{}.{}".format(prefix_to, resblock, p[x], y) sd[k_to] = weights[shape_from*x:shape_from*(x + 1)] return sd def clip_text_transformers_convert(sd, prefix_from, prefix_to): sd = transformers_convert(sd, prefix_from, "{}text_model.".format(prefix_to), 32) tp = "{}text_projection.weight".format(prefix_from) if tp in sd: sd["{}text_projection.weight".format(prefix_to)] = sd.pop(tp) tp = "{}text_projection".format(prefix_from) if tp in sd: sd["{}text_projection.weight".format(prefix_to)] = sd.pop(tp).transpose(0, 1).contiguous() return sd UNET_MAP_ATTENTIONS = { "proj_in.weight", "proj_in.bias", "proj_out.weight", "proj_out.bias", "norm.weight", "norm.bias", } TRANSFORMER_BLOCKS = { "norm1.weight", "norm1.bias", "norm2.weight", "norm2.bias", "norm3.weight", "norm3.bias", "attn1.to_q.weight", "attn1.to_k.weight", "attn1.to_v.weight", "attn1.to_out.0.weight", "attn1.to_out.0.bias", "attn2.to_q.weight", "attn2.to_k.weight", "attn2.to_v.weight", "attn2.to_out.0.weight", "attn2.to_out.0.bias", "ff.net.0.proj.weight", "ff.net.0.proj.bias", "ff.net.2.weight", "ff.net.2.bias", } UNET_MAP_RESNET = { "in_layers.2.weight": "conv1.weight", "in_layers.2.bias": "conv1.bias", "emb_layers.1.weight": "time_emb_proj.weight", "emb_layers.1.bias": "time_emb_proj.bias", "out_layers.3.weight": "conv2.weight", "out_layers.3.bias": "conv2.bias", "skip_connection.weight": "conv_shortcut.weight", "skip_connection.bias": "conv_shortcut.bias", "in_layers.0.weight": "norm1.weight", "in_layers.0.bias": "norm1.bias", "out_layers.0.weight": "norm2.weight", "out_layers.0.bias": "norm2.bias", } UNET_MAP_BASIC = { ("label_emb.0.0.weight", "class_embedding.linear_1.weight"), ("label_emb.0.0.bias", "class_embedding.linear_1.bias"), ("label_emb.0.2.weight", "class_embedding.linear_2.weight"), ("label_emb.0.2.bias", "class_embedding.linear_2.bias"), ("label_emb.0.0.weight", "add_embedding.linear_1.weight"), ("label_emb.0.0.bias", "add_embedding.linear_1.bias"), ("label_emb.0.2.weight", "add_embedding.linear_2.weight"), ("label_emb.0.2.bias", "add_embedding.linear_2.bias"), ("input_blocks.0.0.weight", "conv_in.weight"), ("input_blocks.0.0.bias", "conv_in.bias"), ("out.0.weight", "conv_norm_out.weight"), ("out.0.bias", "conv_norm_out.bias"), ("out.2.weight", "conv_out.weight"), ("out.2.bias", "conv_out.bias"), ("time_embed.0.weight", "time_embedding.linear_1.weight"), ("time_embed.0.bias", "time_embedding.linear_1.bias"), ("time_embed.2.weight", "time_embedding.linear_2.weight"), ("time_embed.2.bias", "time_embedding.linear_2.bias") } def unet_to_diffusers(unet_config): if "num_res_blocks" not in unet_config: return {} num_res_blocks = unet_config["num_res_blocks"] channel_mult = unet_config["channel_mult"] transformer_depth = unet_config["transformer_depth"][:] transformer_depth_output = unet_config["transformer_depth_output"][:] num_blocks = len(channel_mult) transformers_mid = unet_config.get("transformer_depth_middle", None) diffusers_unet_map = {} for x in range(num_blocks): n = 1 + (num_res_blocks[x] + 1) * x for i in range(num_res_blocks[x]): for b in UNET_MAP_RESNET: diffusers_unet_map["down_blocks.{}.resnets.{}.{}".format(x, i, UNET_MAP_RESNET[b])] = "input_blocks.{}.0.{}".format(n, b) num_transformers = transformer_depth.pop(0) if num_transformers > 0: for b in UNET_MAP_ATTENTIONS: diffusers_unet_map["down_blocks.{}.attentions.{}.{}".format(x, i, b)] = "input_blocks.{}.1.{}".format(n, b) for t in range(num_transformers): for b in TRANSFORMER_BLOCKS: diffusers_unet_map["down_blocks.{}.attentions.{}.transformer_blocks.{}.{}".format(x, i, t, b)] = "input_blocks.{}.1.transformer_blocks.{}.{}".format(n, t, b) n += 1 for k in ["weight", "bias"]: diffusers_unet_map["down_blocks.{}.downsamplers.0.conv.{}".format(x, k)] = "input_blocks.{}.0.op.{}".format(n, k) i = 0 for b in UNET_MAP_ATTENTIONS: diffusers_unet_map["mid_block.attentions.{}.{}".format(i, b)] = "middle_block.1.{}".format(b) for t in range(transformers_mid): for b in TRANSFORMER_BLOCKS: diffusers_unet_map["mid_block.attentions.{}.transformer_blocks.{}.{}".format(i, t, b)] = "middle_block.1.transformer_blocks.{}.{}".format(t, b) for i, n in enumerate([0, 2]): for b in UNET_MAP_RESNET: diffusers_unet_map["mid_block.resnets.{}.{}".format(i, UNET_MAP_RESNET[b])] = "middle_block.{}.{}".format(n, b) num_res_blocks = list(reversed(num_res_blocks)) for x in range(num_blocks): n = (num_res_blocks[x] + 1) * x l = num_res_blocks[x] + 1 for i in range(l): c = 0 for b in UNET_MAP_RESNET: diffusers_unet_map["up_blocks.{}.resnets.{}.{}".format(x, i, UNET_MAP_RESNET[b])] = "output_blocks.{}.0.{}".format(n, b) c += 1 num_transformers = transformer_depth_output.pop() if num_transformers > 0: c += 1 for b in UNET_MAP_ATTENTIONS: diffusers_unet_map["up_blocks.{}.attentions.{}.{}".format(x, i, b)] = "output_blocks.{}.1.{}".format(n, b) for t in range(num_transformers): for b in TRANSFORMER_BLOCKS: diffusers_unet_map["up_blocks.{}.attentions.{}.transformer_blocks.{}.{}".format(x, i, t, b)] = "output_blocks.{}.1.transformer_blocks.{}.{}".format(n, t, b) if i == l - 1: for k in ["weight", "bias"]: diffusers_unet_map["up_blocks.{}.upsamplers.0.conv.{}".format(x, k)] = "output_blocks.{}.{}.conv.{}".format(n, c, k) n += 1 for k in UNET_MAP_BASIC: diffusers_unet_map[k[1]] = k[0] return diffusers_unet_map def swap_scale_shift(weight): shift, scale = weight.chunk(2, dim=0) new_weight = torch.cat([scale, shift], dim=0) return new_weight MMDIT_MAP_BASIC = { ("context_embedder.bias", "context_embedder.bias"), ("context_embedder.weight", "context_embedder.weight"), ("t_embedder.mlp.0.bias", "time_text_embed.timestep_embedder.linear_1.bias"), ("t_embedder.mlp.0.weight", "time_text_embed.timestep_embedder.linear_1.weight"), ("t_embedder.mlp.2.bias", "time_text_embed.timestep_embedder.linear_2.bias"), ("t_embedder.mlp.2.weight", "time_text_embed.timestep_embedder.linear_2.weight"), ("x_embedder.proj.bias", "pos_embed.proj.bias"), ("x_embedder.proj.weight", "pos_embed.proj.weight"), ("y_embedder.mlp.0.bias", "time_text_embed.text_embedder.linear_1.bias"), ("y_embedder.mlp.0.weight", "time_text_embed.text_embedder.linear_1.weight"), ("y_embedder.mlp.2.bias", "time_text_embed.text_embedder.linear_2.bias"), ("y_embedder.mlp.2.weight", "time_text_embed.text_embedder.linear_2.weight"), ("pos_embed", "pos_embed.pos_embed"), ("final_layer.adaLN_modulation.1.bias", "norm_out.linear.bias", swap_scale_shift), ("final_layer.adaLN_modulation.1.weight", "norm_out.linear.weight", swap_scale_shift), ("final_layer.linear.bias", "proj_out.bias"), ("final_layer.linear.weight", "proj_out.weight"), } MMDIT_MAP_BLOCK = { ("context_block.adaLN_modulation.1.bias", "norm1_context.linear.bias"), ("context_block.adaLN_modulation.1.weight", "norm1_context.linear.weight"), ("context_block.attn.proj.bias", "attn.to_add_out.bias"), ("context_block.attn.proj.weight", "attn.to_add_out.weight"), ("context_block.mlp.fc1.bias", "ff_context.net.0.proj.bias"), ("context_block.mlp.fc1.weight", "ff_context.net.0.proj.weight"), ("context_block.mlp.fc2.bias", "ff_context.net.2.bias"), ("context_block.mlp.fc2.weight", "ff_context.net.2.weight"), ("context_block.attn.ln_q.weight", "attn.norm_added_q.weight"), ("context_block.attn.ln_k.weight", "attn.norm_added_k.weight"), ("x_block.adaLN_modulation.1.bias", "norm1.linear.bias"), ("x_block.adaLN_modulation.1.weight", "norm1.linear.weight"), ("x_block.attn.proj.bias", "attn.to_out.0.bias"), ("x_block.attn.proj.weight", "attn.to_out.0.weight"), ("x_block.attn.ln_q.weight", "attn.norm_q.weight"), ("x_block.attn.ln_k.weight", "attn.norm_k.weight"), ("x_block.attn2.proj.bias", "attn2.to_out.0.bias"), ("x_block.attn2.proj.weight", "attn2.to_out.0.weight"), ("x_block.attn2.ln_q.weight", "attn2.norm_q.weight"), ("x_block.attn2.ln_k.weight", "attn2.norm_k.weight"), ("x_block.mlp.fc1.bias", "ff.net.0.proj.bias"), ("x_block.mlp.fc1.weight", "ff.net.0.proj.weight"), ("x_block.mlp.fc2.bias", "ff.net.2.bias"), ("x_block.mlp.fc2.weight", "ff.net.2.weight"), } def mmdit_to_diffusers(mmdit_config, output_prefix=""): key_map = {} depth = mmdit_config.get("depth", 0) num_blocks = mmdit_config.get("num_blocks", depth) for i in range(num_blocks): block_from = "transformer_blocks.{}".format(i) block_to = "{}joint_blocks.{}".format(output_prefix, i) offset = depth * 64 for end in ("weight", "bias"): k = "{}.attn.".format(block_from) qkv = "{}.x_block.attn.qkv.{}".format(block_to, end) key_map["{}to_q.{}".format(k, end)] = (qkv, (0, 0, offset)) key_map["{}to_k.{}".format(k, end)] = (qkv, (0, offset, offset)) key_map["{}to_v.{}".format(k, end)] = (qkv, (0, offset * 2, offset)) qkv = "{}.context_block.attn.qkv.{}".format(block_to, end) key_map["{}add_q_proj.{}".format(k, end)] = (qkv, (0, 0, offset)) key_map["{}add_k_proj.{}".format(k, end)] = (qkv, (0, offset, offset)) key_map["{}add_v_proj.{}".format(k, end)] = (qkv, (0, offset * 2, offset)) k = "{}.attn2.".format(block_from) qkv = "{}.x_block.attn2.qkv.{}".format(block_to, end) key_map["{}to_q.{}".format(k, end)] = (qkv, (0, 0, offset)) key_map["{}to_k.{}".format(k, end)] = (qkv, (0, offset, offset)) key_map["{}to_v.{}".format(k, end)] = (qkv, (0, offset * 2, offset)) for k in MMDIT_MAP_BLOCK: key_map["{}.{}".format(block_from, k[1])] = "{}.{}".format(block_to, k[0]) map_basic = MMDIT_MAP_BASIC.copy() map_basic.add(("joint_blocks.{}.context_block.adaLN_modulation.1.bias".format(depth - 1), "transformer_blocks.{}.norm1_context.linear.bias".format(depth - 1), swap_scale_shift)) map_basic.add(("joint_blocks.{}.context_block.adaLN_modulation.1.weight".format(depth - 1), "transformer_blocks.{}.norm1_context.linear.weight".format(depth - 1), swap_scale_shift)) for k in map_basic: if len(k) > 2: key_map[k[1]] = ("{}{}".format(output_prefix, k[0]), None, k[2]) else: key_map[k[1]] = "{}{}".format(output_prefix, k[0]) return key_map PIXART_MAP_BASIC = { ("csize_embedder.mlp.0.weight", "adaln_single.emb.resolution_embedder.linear_1.weight"), ("csize_embedder.mlp.0.bias", "adaln_single.emb.resolution_embedder.linear_1.bias"), ("csize_embedder.mlp.2.weight", "adaln_single.emb.resolution_embedder.linear_2.weight"), ("csize_embedder.mlp.2.bias", "adaln_single.emb.resolution_embedder.linear_2.bias"), ("ar_embedder.mlp.0.weight", "adaln_single.emb.aspect_ratio_embedder.linear_1.weight"), ("ar_embedder.mlp.0.bias", "adaln_single.emb.aspect_ratio_embedder.linear_1.bias"), ("ar_embedder.mlp.2.weight", "adaln_single.emb.aspect_ratio_embedder.linear_2.weight"), ("ar_embedder.mlp.2.bias", "adaln_single.emb.aspect_ratio_embedder.linear_2.bias"), ("x_embedder.proj.weight", "pos_embed.proj.weight"), ("x_embedder.proj.bias", "pos_embed.proj.bias"), ("y_embedder.y_embedding", "caption_projection.y_embedding"), ("y_embedder.y_proj.fc1.weight", "caption_projection.linear_1.weight"), ("y_embedder.y_proj.fc1.bias", "caption_projection.linear_1.bias"), ("y_embedder.y_proj.fc2.weight", "caption_projection.linear_2.weight"), ("y_embedder.y_proj.fc2.bias", "caption_projection.linear_2.bias"), ("t_embedder.mlp.0.weight", "adaln_single.emb.timestep_embedder.linear_1.weight"), ("t_embedder.mlp.0.bias", "adaln_single.emb.timestep_embedder.linear_1.bias"), ("t_embedder.mlp.2.weight", "adaln_single.emb.timestep_embedder.linear_2.weight"), ("t_embedder.mlp.2.bias", "adaln_single.emb.timestep_embedder.linear_2.bias"), ("t_block.1.weight", "adaln_single.linear.weight"), ("t_block.1.bias", "adaln_single.linear.bias"), ("final_layer.linear.weight", "proj_out.weight"), ("final_layer.linear.bias", "proj_out.bias"), ("final_layer.scale_shift_table", "scale_shift_table"), } PIXART_MAP_BLOCK = { ("scale_shift_table", "scale_shift_table"), ("attn.proj.weight", "attn1.to_out.0.weight"), ("attn.proj.bias", "attn1.to_out.0.bias"), ("mlp.fc1.weight", "ff.net.0.proj.weight"), ("mlp.fc1.bias", "ff.net.0.proj.bias"), ("mlp.fc2.weight", "ff.net.2.weight"), ("mlp.fc2.bias", "ff.net.2.bias"), ("cross_attn.proj.weight" ,"attn2.to_out.0.weight"), ("cross_attn.proj.bias" ,"attn2.to_out.0.bias"), } def pixart_to_diffusers(mmdit_config, output_prefix=""): key_map = {} depth = mmdit_config.get("depth", 0) offset = mmdit_config.get("hidden_size", 1152) for i in range(depth): block_from = "transformer_blocks.{}".format(i) block_to = "{}blocks.{}".format(output_prefix, i) for end in ("weight", "bias"): s = "{}.attn1.".format(block_from) qkv = "{}.attn.qkv.{}".format(block_to, end) key_map["{}to_q.{}".format(s, end)] = (qkv, (0, 0, offset)) key_map["{}to_k.{}".format(s, end)] = (qkv, (0, offset, offset)) key_map["{}to_v.{}".format(s, end)] = (qkv, (0, offset * 2, offset)) s = "{}.attn2.".format(block_from) q = "{}.cross_attn.q_linear.{}".format(block_to, end) kv = "{}.cross_attn.kv_linear.{}".format(block_to, end) key_map["{}to_q.{}".format(s, end)] = q key_map["{}to_k.{}".format(s, end)] = (kv, (0, 0, offset)) key_map["{}to_v.{}".format(s, end)] = (kv, (0, offset, offset)) for k in PIXART_MAP_BLOCK: key_map["{}.{}".format(block_from, k[1])] = "{}.{}".format(block_to, k[0]) for k in PIXART_MAP_BASIC: key_map[k[1]] = "{}{}".format(output_prefix, k[0]) return key_map def auraflow_to_diffusers(mmdit_config, output_prefix=""): n_double_layers = mmdit_config.get("n_double_layers", 0) n_layers = mmdit_config.get("n_layers", 0) key_map = {} for i in range(n_layers): if i < n_double_layers: index = i prefix_from = "joint_transformer_blocks" prefix_to = "{}double_layers".format(output_prefix) block_map = { "attn.to_q.weight": "attn.w2q.weight", "attn.to_k.weight": "attn.w2k.weight", "attn.to_v.weight": "attn.w2v.weight", "attn.to_out.0.weight": "attn.w2o.weight", "attn.add_q_proj.weight": "attn.w1q.weight", "attn.add_k_proj.weight": "attn.w1k.weight", "attn.add_v_proj.weight": "attn.w1v.weight", "attn.to_add_out.weight": "attn.w1o.weight", "ff.linear_1.weight": "mlpX.c_fc1.weight", "ff.linear_2.weight": "mlpX.c_fc2.weight", "ff.out_projection.weight": "mlpX.c_proj.weight", "ff_context.linear_1.weight": "mlpC.c_fc1.weight", "ff_context.linear_2.weight": "mlpC.c_fc2.weight", "ff_context.out_projection.weight": "mlpC.c_proj.weight", "norm1.linear.weight": "modX.1.weight", "norm1_context.linear.weight": "modC.1.weight", } else: index = i - n_double_layers prefix_from = "single_transformer_blocks" prefix_to = "{}single_layers".format(output_prefix) block_map = { "attn.to_q.weight": "attn.w1q.weight", "attn.to_k.weight": "attn.w1k.weight", "attn.to_v.weight": "attn.w1v.weight", "attn.to_out.0.weight": "attn.w1o.weight", "norm1.linear.weight": "modCX.1.weight", "ff.linear_1.weight": "mlp.c_fc1.weight", "ff.linear_2.weight": "mlp.c_fc2.weight", "ff.out_projection.weight": "mlp.c_proj.weight" } for k in block_map: key_map["{}.{}.{}".format(prefix_from, index, k)] = "{}.{}.{}".format(prefix_to, index, block_map[k]) MAP_BASIC = { ("positional_encoding", "pos_embed.pos_embed"), ("register_tokens", "register_tokens"), ("t_embedder.mlp.0.weight", "time_step_proj.linear_1.weight"), ("t_embedder.mlp.0.bias", "time_step_proj.linear_1.bias"), ("t_embedder.mlp.2.weight", "time_step_proj.linear_2.weight"), ("t_embedder.mlp.2.bias", "time_step_proj.linear_2.bias"), ("cond_seq_linear.weight", "context_embedder.weight"), ("init_x_linear.weight", "pos_embed.proj.weight"), ("init_x_linear.bias", "pos_embed.proj.bias"), ("final_linear.weight", "proj_out.weight"), ("modF.1.weight", "norm_out.linear.weight", swap_scale_shift), } for k in MAP_BASIC: if len(k) > 2: key_map[k[1]] = ("{}{}".format(output_prefix, k[0]), None, k[2]) else: key_map[k[1]] = "{}{}".format(output_prefix, k[0]) return key_map def flux_to_diffusers(mmdit_config, output_prefix=""): n_double_layers = mmdit_config.get("depth", 0) n_single_layers = mmdit_config.get("depth_single_blocks", 0) hidden_size = mmdit_config.get("hidden_size", 0) key_map = {} for index in range(n_double_layers): prefix_from = "transformer_blocks.{}".format(index) prefix_to = "{}double_blocks.{}".format(output_prefix, index) for end in ("weight", "bias"): k = "{}.attn.".format(prefix_from) qkv = "{}.img_attn.qkv.{}".format(prefix_to, end) key_map["{}to_q.{}".format(k, end)] = (qkv, (0, 0, hidden_size)) key_map["{}to_k.{}".format(k, end)] = (qkv, (0, hidden_size, hidden_size)) key_map["{}to_v.{}".format(k, end)] = (qkv, (0, hidden_size * 2, hidden_size)) k = "{}.attn.".format(prefix_from) qkv = "{}.txt_attn.qkv.{}".format(prefix_to, end) key_map["{}add_q_proj.{}".format(k, end)] = (qkv, (0, 0, hidden_size)) key_map["{}add_k_proj.{}".format(k, end)] = (qkv, (0, hidden_size, hidden_size)) key_map["{}add_v_proj.{}".format(k, end)] = (qkv, (0, hidden_size * 2, hidden_size)) block_map = { "attn.to_out.0.weight": "img_attn.proj.weight", "attn.to_out.0.bias": "img_attn.proj.bias", "norm1.linear.weight": "img_mod.lin.weight", "norm1.linear.bias": "img_mod.lin.bias", "norm1_context.linear.weight": "txt_mod.lin.weight", "norm1_context.linear.bias": "txt_mod.lin.bias", "attn.to_add_out.weight": "txt_attn.proj.weight", "attn.to_add_out.bias": "txt_attn.proj.bias", "ff.net.0.proj.weight": "img_mlp.0.weight", "ff.net.0.proj.bias": "img_mlp.0.bias", "ff.net.2.weight": "img_mlp.2.weight", "ff.net.2.bias": "img_mlp.2.bias", "ff_context.net.0.proj.weight": "txt_mlp.0.weight", "ff_context.net.0.proj.bias": "txt_mlp.0.bias", "ff_context.net.2.weight": "txt_mlp.2.weight", "ff_context.net.2.bias": "txt_mlp.2.bias", "ff.linear_in.weight": "img_mlp.0.weight", # LyCoris LoKr "ff.linear_in.bias": "img_mlp.0.bias", "ff.linear_out.weight": "img_mlp.2.weight", "ff.linear_out.bias": "img_mlp.2.bias", "ff_context.linear_in.weight": "txt_mlp.0.weight", "ff_context.linear_in.bias": "txt_mlp.0.bias", "ff_context.linear_out.weight": "txt_mlp.2.weight", "ff_context.linear_out.bias": "txt_mlp.2.bias", "attn.norm_q.weight": "img_attn.norm.query_norm.weight", "attn.norm_k.weight": "img_attn.norm.key_norm.weight", "attn.norm_added_q.weight": "txt_attn.norm.query_norm.weight", "attn.norm_added_k.weight": "txt_attn.norm.key_norm.weight", } for k in block_map: key_map["{}.{}".format(prefix_from, k)] = "{}.{}".format(prefix_to, block_map[k]) for index in range(n_single_layers): prefix_from = "single_transformer_blocks.{}".format(index) prefix_to = "{}single_blocks.{}".format(output_prefix, index) for end in ("weight", "bias"): k = "{}.attn.".format(prefix_from) qkv = "{}.linear1.{}".format(prefix_to, end) key_map["{}to_q.{}".format(k, end)] = (qkv, (0, 0, hidden_size)) key_map["{}to_k.{}".format(k, end)] = (qkv, (0, hidden_size, hidden_size)) key_map["{}to_v.{}".format(k, end)] = (qkv, (0, hidden_size * 2, hidden_size)) key_map["{}.proj_mlp.{}".format(prefix_from, end)] = (qkv, (0, hidden_size * 3, hidden_size * 4)) block_map = { "norm.linear.weight": "modulation.lin.weight", "norm.linear.bias": "modulation.lin.bias", "proj_out.weight": "linear2.weight", "proj_out.bias": "linear2.bias", "attn.norm_q.weight": "norm.query_norm.weight", "attn.norm_k.weight": "norm.key_norm.weight", "attn.to_qkv_mlp_proj.weight": "linear1.weight", # Flux 2 "attn.to_out.weight": "linear2.weight", # Flux 2 } for k in block_map: key_map["{}.{}".format(prefix_from, k)] = "{}.{}".format(prefix_to, block_map[k]) MAP_BASIC = { ("final_layer.linear.bias", "proj_out.bias"), ("final_layer.linear.weight", "proj_out.weight"), ("img_in.bias", "x_embedder.bias"), ("img_in.weight", "x_embedder.weight"), ("time_in.in_layer.bias", "time_text_embed.timestep_embedder.linear_1.bias"), ("time_in.in_layer.weight", "time_text_embed.timestep_embedder.linear_1.weight"), ("time_in.out_layer.bias", "time_text_embed.timestep_embedder.linear_2.bias"), ("time_in.out_layer.weight", "time_text_embed.timestep_embedder.linear_2.weight"), ("txt_in.bias", "context_embedder.bias"), ("txt_in.weight", "context_embedder.weight"), ("vector_in.in_layer.bias", "time_text_embed.text_embedder.linear_1.bias"), ("vector_in.in_layer.weight", "time_text_embed.text_embedder.linear_1.weight"), ("vector_in.out_layer.bias", "time_text_embed.text_embedder.linear_2.bias"), ("vector_in.out_layer.weight", "time_text_embed.text_embedder.linear_2.weight"), ("guidance_in.in_layer.bias", "time_text_embed.guidance_embedder.linear_1.bias"), ("guidance_in.in_layer.weight", "time_text_embed.guidance_embedder.linear_1.weight"), ("guidance_in.out_layer.bias", "time_text_embed.guidance_embedder.linear_2.bias"), ("guidance_in.out_layer.weight", "time_text_embed.guidance_embedder.linear_2.weight"), ("final_layer.adaLN_modulation.1.bias", "norm_out.linear.bias", swap_scale_shift), ("final_layer.adaLN_modulation.1.weight", "norm_out.linear.weight", swap_scale_shift), ("pos_embed_input.bias", "controlnet_x_embedder.bias"), ("pos_embed_input.weight", "controlnet_x_embedder.weight"), } for k in MAP_BASIC: if len(k) > 2: key_map[k[1]] = ("{}{}".format(output_prefix, k[0]), None, k[2]) else: key_map[k[1]] = "{}{}".format(output_prefix, k[0]) return key_map def z_image_to_diffusers(mmdit_config, output_prefix=""): n_layers = mmdit_config.get("n_layers", 0) hidden_size = mmdit_config.get("dim", 0) n_context_refiner = mmdit_config.get("n_refiner_layers", 2) n_noise_refiner = mmdit_config.get("n_refiner_layers", 2) key_map = {} def add_block_keys(prefix_from, prefix_to, has_adaln=True): for end in ("weight", "bias"): k = "{}.attention.".format(prefix_from) qkv = "{}.attention.qkv.{}".format(prefix_to, end) key_map["{}to_q.{}".format(k, end)] = (qkv, (0, 0, hidden_size)) key_map["{}to_k.{}".format(k, end)] = (qkv, (0, hidden_size, hidden_size)) key_map["{}to_v.{}".format(k, end)] = (qkv, (0, hidden_size * 2, hidden_size)) block_map = { "attention.norm_q.weight": "attention.q_norm.weight", "attention.norm_k.weight": "attention.k_norm.weight", "attention.to_out.0.weight": "attention.out.weight", "attention.to_out.0.bias": "attention.out.bias", "attention_norm1.weight": "attention_norm1.weight", "attention_norm2.weight": "attention_norm2.weight", "feed_forward.w1.weight": "feed_forward.w1.weight", "feed_forward.w2.weight": "feed_forward.w2.weight", "feed_forward.w3.weight": "feed_forward.w3.weight", "ffn_norm1.weight": "ffn_norm1.weight", "ffn_norm2.weight": "ffn_norm2.weight", } if has_adaln: block_map["adaLN_modulation.0.weight"] = "adaLN_modulation.0.weight" block_map["adaLN_modulation.0.bias"] = "adaLN_modulation.0.bias" for k, v in block_map.items(): key_map["{}.{}".format(prefix_from, k)] = "{}.{}".format(prefix_to, v) for i in range(n_layers): add_block_keys("layers.{}".format(i), "{}layers.{}".format(output_prefix, i)) for i in range(n_context_refiner): add_block_keys("context_refiner.{}".format(i), "{}context_refiner.{}".format(output_prefix, i)) for i in range(n_noise_refiner): add_block_keys("noise_refiner.{}".format(i), "{}noise_refiner.{}".format(output_prefix, i)) MAP_BASIC = [ ("final_layer.linear.weight", "all_final_layer.2-1.linear.weight"), ("final_layer.linear.bias", "all_final_layer.2-1.linear.bias"), ("final_layer.adaLN_modulation.1.weight", "all_final_layer.2-1.adaLN_modulation.1.weight"), ("final_layer.adaLN_modulation.1.bias", "all_final_layer.2-1.adaLN_modulation.1.bias"), ("x_embedder.weight", "all_x_embedder.2-1.weight"), ("x_embedder.bias", "all_x_embedder.2-1.bias"), ("x_pad_token", "x_pad_token"), ("cap_embedder.0.weight", "cap_embedder.0.weight"), ("cap_embedder.1.weight", "cap_embedder.1.weight"), ("cap_embedder.1.bias", "cap_embedder.1.bias"), ("cap_pad_token", "cap_pad_token"), ("t_embedder.mlp.0.weight", "t_embedder.mlp.0.weight"), ("t_embedder.mlp.0.bias", "t_embedder.mlp.0.bias"), ("t_embedder.mlp.2.weight", "t_embedder.mlp.2.weight"), ("t_embedder.mlp.2.bias", "t_embedder.mlp.2.bias"), ] for c, diffusers in MAP_BASIC: key_map[diffusers] = "{}{}".format(output_prefix, c) return key_map def repeat_to_batch_size(tensor, batch_size, dim=0): if tensor.shape[dim] > batch_size: return tensor.narrow(dim, 0, batch_size) elif tensor.shape[dim] < batch_size: return tensor.repeat(dim * [1] + [math.ceil(batch_size / tensor.shape[dim])] + [1] * (len(tensor.shape) - 1 - dim)).narrow(dim, 0, batch_size) return tensor def resize_to_batch_size(tensor, batch_size): in_batch_size = tensor.shape[0] if in_batch_size == batch_size: return tensor if batch_size <= 1: return tensor[:batch_size] output = torch.empty([batch_size] + list(tensor.shape)[1:], dtype=tensor.dtype, device=tensor.device) if batch_size < in_batch_size: scale = (in_batch_size - 1) / (batch_size - 1) for i in range(batch_size): output[i] = tensor[min(round(i * scale), in_batch_size - 1)] else: scale = in_batch_size / batch_size for i in range(batch_size): output[i] = tensor[min(math.floor((i + 0.5) * scale), in_batch_size - 1)] return output def resize_list_to_batch_size(l, batch_size): in_batch_size = len(l) if in_batch_size == batch_size or in_batch_size == 0: return l if batch_size <= 1: return l[:batch_size] output = [] if batch_size < in_batch_size: scale = (in_batch_size - 1) / (batch_size - 1) for i in range(batch_size): output.append(l[min(round(i * scale), in_batch_size - 1)]) else: scale = in_batch_size / batch_size for i in range(batch_size): output.append(l[min(math.floor((i + 0.5) * scale), in_batch_size - 1)]) return output def convert_sd_to(state_dict, dtype): keys = list(state_dict.keys()) for k in keys: state_dict[k] = state_dict[k].to(dtype) return state_dict def safetensors_header(safetensors_path, max_size=100*1024*1024): with open(safetensors_path, "rb") as f: header = f.read(8) length_of_header = struct.unpack(' max_size: return None return f.read(length_of_header) ATTR_UNSET={} def resolve_attr(obj, attr): attrs = attr.split(".") for name in attrs[:-1]: obj = getattr(obj, name) return obj, attrs[-1] def set_attr(obj, attr, value): obj, name = resolve_attr(obj, attr) prev = getattr(obj, name, ATTR_UNSET) if value is ATTR_UNSET: delattr(obj, name) else: setattr(obj, name, value) return prev def set_attr_param(obj, attr, value): # Clone inference tensors (created under torch.inference_mode) since # their version counter is frozen and nn.Parameter() cannot wrap them. if (not torch.is_inference_mode_enabled()) and value.is_inference(): value = value.clone() return set_attr(obj, attr, torch.nn.Parameter(value, requires_grad=False)) def set_attr_buffer(obj, attr, value): obj, name = resolve_attr(obj, attr) prev = getattr(obj, name, ATTR_UNSET) persistent = name not in getattr(obj, "_non_persistent_buffers_set", set()) obj.register_buffer(name, value, persistent=persistent) return prev def copy_to_param(obj, attr, value): # inplace update tensor instead of replacing it attrs = attr.split(".") for name in attrs[:-1]: obj = getattr(obj, name) prev = getattr(obj, attrs[-1]) prev.data.copy_(value) def get_attr(obj, attr: str): """Retrieves a nested attribute from an object using dot notation. Args: obj: The object to get the attribute from attr (str): The attribute path using dot notation (e.g. "model.layer.weight") Returns: The value of the requested attribute Example: model = MyModel() weight = get_attr(model, "layer1.conv.weight") # Equivalent to: model.layer1.conv.weight Important: Always prefer `comfy.model_patcher.ModelPatcher.get_model_object` when accessing nested model objects under `ModelPatcher.model`. """ attrs = attr.split(".") for name in attrs: obj = getattr(obj, name) return obj def bislerp(samples, width, height): def slerp(b1, b2, r): '''slerps batches b1, b2 according to ratio r, batches should be flat e.g. NxC''' c = b1.shape[-1] #norms b1_norms = torch.norm(b1, dim=-1, keepdim=True) b2_norms = torch.norm(b2, dim=-1, keepdim=True) #normalize b1_normalized = b1 / b1_norms b2_normalized = b2 / b2_norms #zero when norms are zero b1_normalized[b1_norms.expand(-1,c) == 0.0] = 0.0 b2_normalized[b2_norms.expand(-1,c) == 0.0] = 0.0 #slerp dot = (b1_normalized*b2_normalized).sum(1) omega = torch.acos(dot) so = torch.sin(omega) #technically not mathematically correct, but more pleasing? res = (torch.sin((1.0-r.squeeze(1))*omega)/so).unsqueeze(1)*b1_normalized + (torch.sin(r.squeeze(1)*omega)/so).unsqueeze(1) * b2_normalized res *= (b1_norms * (1.0-r) + b2_norms * r).expand(-1,c) #edge cases for same or polar opposites res[dot > 1 - 1e-5] = b1[dot > 1 - 1e-5] res[dot < 1e-5 - 1] = (b1 * (1.0-r) + b2 * r)[dot < 1e-5 - 1] return res def generate_bilinear_data(length_old, length_new, device): coords_1 = torch.arange(length_old, dtype=torch.float32, device=device).reshape((1,1,1,-1)) coords_1 = torch.nn.functional.interpolate(coords_1, size=(1, length_new), mode="bilinear") ratios = coords_1 - coords_1.floor() coords_1 = coords_1.to(torch.int64) coords_2 = torch.arange(length_old, dtype=torch.float32, device=device).reshape((1,1,1,-1)) + 1 coords_2[:,:,:,-1] -= 1 coords_2 = torch.nn.functional.interpolate(coords_2, size=(1, length_new), mode="bilinear") coords_2 = coords_2.to(torch.int64) return ratios, coords_1, coords_2 orig_dtype = samples.dtype samples = samples.float() n,c,h,w = samples.shape h_new, w_new = (height, width) #linear w ratios, coords_1, coords_2 = generate_bilinear_data(w, w_new, samples.device) coords_1 = coords_1.expand((n, c, h, -1)) coords_2 = coords_2.expand((n, c, h, -1)) ratios = ratios.expand((n, 1, h, -1)) pass_1 = samples.gather(-1,coords_1).movedim(1, -1).reshape((-1,c)) pass_2 = samples.gather(-1,coords_2).movedim(1, -1).reshape((-1,c)) ratios = ratios.movedim(1, -1).reshape((-1,1)) result = slerp(pass_1, pass_2, ratios) result = result.reshape(n, h, w_new, c).movedim(-1, 1) #linear h ratios, coords_1, coords_2 = generate_bilinear_data(h, h_new, samples.device) coords_1 = coords_1.reshape((1,1,-1,1)).expand((n, c, -1, w_new)) coords_2 = coords_2.reshape((1,1,-1,1)).expand((n, c, -1, w_new)) ratios = ratios.reshape((1,1,-1,1)).expand((n, 1, -1, w_new)) pass_1 = result.gather(-2,coords_1).movedim(1, -1).reshape((-1,c)) pass_2 = result.gather(-2,coords_2).movedim(1, -1).reshape((-1,c)) ratios = ratios.movedim(1, -1).reshape((-1,1)) result = slerp(pass_1, pass_2, ratios) result = result.reshape(n, h_new, w_new, c).movedim(-1, 1) return result.to(orig_dtype) def lanczos(samples, width, height): #the below API is strict and expects grayscale to be squeezed if samples.ndim == 4: samples = samples.squeeze(1) if samples.shape[1] == 1 else samples.movedim(1, -1) images = [Image.fromarray(np.clip(255. * image.cpu().numpy(), 0, 255).astype(np.uint8)) for image in samples] images = [image.resize((width, height), resample=Image.Resampling.LANCZOS) for image in images] images = [torch.from_numpy(t).movedim(-1, 0) if (t := np.array(image).astype(np.float32) / 255.0).ndim == 3 else torch.from_numpy(t) for image in images] result = torch.stack(images) return result.to(samples.device, samples.dtype) def common_upscale(samples, width, height, upscale_method, crop): orig_shape = tuple(samples.shape) if len(orig_shape) > 4: samples = samples.reshape(samples.shape[0], samples.shape[1], -1, samples.shape[-2], samples.shape[-1]) samples = samples.movedim(2, 1) samples = samples.reshape(-1, orig_shape[1], orig_shape[-2], orig_shape[-1]) if crop == "center": old_width = samples.shape[-1] old_height = samples.shape[-2] old_aspect = old_width / old_height new_aspect = width / height x = 0 y = 0 if old_aspect > new_aspect: x = round((old_width - old_width * (new_aspect / old_aspect)) / 2) elif old_aspect < new_aspect: y = round((old_height - old_height * (old_aspect / new_aspect)) / 2) s = samples.narrow(-2, y, old_height - y * 2).narrow(-1, x, old_width - x * 2) else: s = samples if upscale_method == "bislerp": out = bislerp(s, width, height) elif upscale_method == "lanczos": out = lanczos(s, width, height) else: out = torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method) if len(orig_shape) == 4: return out out = out.reshape((orig_shape[0], -1, orig_shape[1]) + (height, width)) return out.movedim(2, 1).reshape(orig_shape[:-2] + (height, width)) def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap): rows = 1 if height <= tile_y else math.ceil((height - overlap) / (tile_y - overlap)) cols = 1 if width <= tile_x else math.ceil((width - overlap) / (tile_x - overlap)) return rows * cols @torch.inference_mode() def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_amount=4, out_channels=3, output_device="cpu", downscale=False, index_formulas=None, pbar=None): dims = len(tile) if not (isinstance(upscale_amount, (tuple, list))): upscale_amount = [upscale_amount] * dims if not (isinstance(overlap, (tuple, list))): overlap = [overlap] * dims if index_formulas is None: index_formulas = upscale_amount if not (isinstance(index_formulas, (tuple, list))): index_formulas = [index_formulas] * dims def get_upscale(dim, val): up = upscale_amount[dim] if callable(up): return up(val) else: return up * val def get_downscale(dim, val): up = upscale_amount[dim] if callable(up): return up(val) else: return val / up def get_upscale_pos(dim, val): up = index_formulas[dim] if callable(up): return up(val) else: return up * val def get_downscale_pos(dim, val): up = index_formulas[dim] if callable(up): return up(val) else: return val / up if downscale: get_scale = get_downscale get_pos = get_downscale_pos else: get_scale = get_upscale get_pos = get_upscale_pos def mult_list_upscale(a): out = [] for i in range(len(a)): out.append(round(get_scale(i, a[i]))) return out output = torch.empty([samples.shape[0], out_channels] + mult_list_upscale(samples.shape[2:]), device=output_device) for b in range(samples.shape[0]): s = samples[b:b+1] # handle entire input fitting in a single tile if all(s.shape[d+2] <= tile[d] for d in range(dims)): output[b:b+1] = function(s).to(output_device) if pbar is not None: pbar.update(1) continue out = output[b:b+1].zero_() out_div = torch.zeros([s.shape[0], 1] + mult_list_upscale(s.shape[2:]), device=output_device) positions = [range(0, s.shape[d+2] - overlap[d], tile[d] - overlap[d]) if s.shape[d+2] > tile[d] else [0] for d in range(dims)] for it in itertools.product(*positions): s_in = s upscaled = [] for d in range(dims): pos = max(0, min(s.shape[d + 2] - overlap[d], it[d])) l = min(tile[d], s.shape[d + 2] - pos) s_in = s_in.narrow(d + 2, pos, l) upscaled.append(round(get_pos(d, pos))) ps = function(s_in).to(output_device) mask = torch.ones([1, 1] + list(ps.shape[2:]), device=output_device) for d in range(2, dims + 2): feather = round(get_scale(d - 2, overlap[d - 2])) if feather >= mask.shape[d]: continue for t in range(feather): a = (t + 1) / feather mask.narrow(d, t, 1).mul_(a) mask.narrow(d, mask.shape[d] - 1 - t, 1).mul_(a) o = out o_d = out_div ps_view = ps mask_view = mask for d in range(dims): l = min(ps_view.shape[d + 2], o.shape[d + 2] - upscaled[d]) o = o.narrow(d + 2, upscaled[d], l) o_d = o_d.narrow(d + 2, upscaled[d], l) if l < ps_view.shape[d + 2]: ps_view = ps_view.narrow(d + 2, 0, l) mask_view = mask_view.narrow(d + 2, 0, l) o.add_(ps_view * mask_view) o_d.add_(mask_view) if pbar is not None: pbar.update(1) out.div_(out_div) return output def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None): return tiled_scale_multidim(samples, function, (tile_y, tile_x), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=output_device, pbar=pbar) def tiled_scale_multidim_multigpu(samples, functions, tile=(64, 64), overlap=8, upscale_amount=4, out_channels=3, output_device="cpu", downscale=False, index_formulas=None, pbar=None): """Multigpu variant of tiled_scale_multidim. ``functions`` is a dict[torch.device, callable]. Round-robin dispatches tile positions across devices via threading. Each thread maintains its own per-device CPU output and divisor buffer, applying the same feathered overlap mask formula as the single-device path. Buffers are summed at the end, producing output that is bit-equivalent to ``tiled_scale_multidim`` within fp32 add-order noise. Falls back to ``tiled_scale_multidim`` with the only function when ``len(functions) < 2``. Falls back to single-device on the "whole input fits in one tile" branch (no parallelism available at that granularity). """ devices = list(functions.keys()) if len(devices) < 2: only_fn = next(iter(functions.values())) if functions else None return tiled_scale_multidim(samples, only_fn, tile=tile, overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=output_device, downscale=downscale, index_formulas=index_formulas, pbar=pbar) dims = len(tile) if not (isinstance(upscale_amount, (tuple, list))): upscale_amount = [upscale_amount] * dims if not (isinstance(overlap, (tuple, list))): overlap = [overlap] * dims if index_formulas is None: index_formulas = upscale_amount if not (isinstance(index_formulas, (tuple, list))): index_formulas = [index_formulas] * dims def get_upscale(dim, val): up = upscale_amount[dim] return up(val) if callable(up) else up * val def get_downscale(dim, val): up = upscale_amount[dim] return up(val) if callable(up) else val / up def get_upscale_pos(dim, val): up = index_formulas[dim] return up(val) if callable(up) else up * val def get_downscale_pos(dim, val): up = index_formulas[dim] return up(val) if callable(up) else val / up if downscale: get_scale = get_downscale get_pos = get_downscale_pos else: get_scale = get_upscale get_pos = get_upscale_pos def mult_list_upscale(a): return [round(get_scale(i, a[i])) for i in range(len(a))] output = torch.empty([samples.shape[0], out_channels] + mult_list_upscale(samples.shape[2:]), device=output_device) merge_device = torch.device("cpu") pbar_lock = threading.Lock() if pbar is not None else None primary_device = devices[0] samples_staged = samples if samples.device.type == "cpu" else samples.to("cpu", non_blocking=False) for b in range(samples_staged.shape[0]): s = samples_staged[b:b+1] if all(s.shape[d+2] <= tile[d] for d in range(dims)): with torch.inference_mode(): output[b:b+1] = functions[primary_device](s.to(primary_device, non_blocking=True)).to(output_device) if pbar is not None: pbar.update(1) continue positions = [range(0, s.shape[d+2] - overlap[d], tile[d] - overlap[d]) if s.shape[d+2] > tile[d] else [0] for d in range(dims)] split = {devices[i]: itertools.islice(itertools.product(*positions), i, None, len(devices)) for i in range(len(devices))} out_shape = [s.shape[0], out_channels] + mult_list_upscale(s.shape[2:]) div_shape = [s.shape[0], 1] + mult_list_upscale(s.shape[2:]) bufs = {d: torch.zeros(out_shape, device=merge_device) for d in devices} divs = {d: torch.zeros(div_shape, device=merge_device) for d in devices} worker_errors: list[BaseException] = [] worker_lock = threading.Lock() def worker(device, my_positions): try: if device.type == "cuda": torch.cuda.set_device(device) fn = functions[device] local_buf = bufs[device] local_div = divs[device] with torch.inference_mode(): for it in my_positions: s_in = s upscaled = [] for d in range(dims): pos = max(0, min(s.shape[d + 2] - overlap[d], it[d])) l = min(tile[d], s.shape[d + 2] - pos) s_in = s_in.narrow(d + 2, pos, l) upscaled.append(round(get_pos(d, pos))) s_in_dev = s_in.to(device, non_blocking=True) ps = fn(s_in_dev).to(merge_device) mask = torch.ones([1, 1] + list(ps.shape[2:]), device=merge_device) for d in range(2, dims + 2): feather = round(get_scale(d - 2, overlap[d - 2])) if feather >= mask.shape[d]: continue for t in range(feather): a = (t + 1) / feather mask.narrow(d, t, 1).mul_(a) mask.narrow(d, mask.shape[d] - 1 - t, 1).mul_(a) o = local_buf o_d = local_div ps_view = ps mask_view = mask for d in range(dims): l = min(ps_view.shape[d + 2], o.shape[d + 2] - upscaled[d]) o = o.narrow(d + 2, upscaled[d], l) o_d = o_d.narrow(d + 2, upscaled[d], l) if l < ps_view.shape[d + 2]: ps_view = ps_view.narrow(d + 2, 0, l) mask_view = mask_view.narrow(d + 2, 0, l) o.add_(ps_view * mask_view) o_d.add_(mask_view) if pbar is not None: with pbar_lock: pbar.update(1) if device.type == "cuda": torch.cuda.synchronize(device) except BaseException as e: with worker_lock: worker_errors.append(e) threads = [threading.Thread(target=worker, args=(d, split[d])) for d in devices] for t in threads: t.start() for t in threads: t.join() if worker_errors: raise worker_errors[0] combined_buf = sum(bufs.values()) combined_div = sum(divs.values()) output[b:b+1] = combined_buf / combined_div return output def model_trange(*args, **kwargs): if not comfy.memory_management.aimdo_enabled: return trange(*args, **kwargs) pbar = trange(*args, **kwargs, smoothing=1.0) pbar._i = 0 pbar.set_postfix_str(" Model Initializing ... ") _update = pbar.update def warmup_update(n=1): pbar._i += 1 if pbar._i == 1: pbar.i1_time = time.time() pbar.set_postfix_str(" Model Initialization complete! ") elif pbar._i == 2: #bring forward the effective start time based the diff between first and second iteration #to attempt to remove load overhead from the final step rate estimate. pbar.start_t = pbar.i1_time - (time.time() - pbar.i1_time) pbar.set_postfix_str("") _update(n) pbar.update = warmup_update return pbar PROGRESS_BAR_ENABLED = True def set_progress_bar_enabled(enabled): global PROGRESS_BAR_ENABLED PROGRESS_BAR_ENABLED = enabled PROGRESS_BAR_HOOK = None def set_progress_bar_global_hook(function): global PROGRESS_BAR_HOOK PROGRESS_BAR_HOOK = function # Throttle settings for progress bar updates to reduce WebSocket flooding PROGRESS_THROTTLE_MIN_INTERVAL = 0.1 # 100ms minimum between updates PROGRESS_THROTTLE_MIN_PERCENT = 0.5 # 0.5% minimum progress change class ProgressBar: def __init__(self, total, node_id=None): global PROGRESS_BAR_HOOK self.total = total self.current = 0 self.hook = PROGRESS_BAR_HOOK self.node_id = node_id self._last_update_time = 0.0 self._last_sent_value = -1 def update_absolute(self, value, total=None, preview=None): if total is not None: self.total = total if value > self.total: value = self.total self.current = value if self.hook is not None: current_time = time.perf_counter() is_first = (self._last_sent_value < 0) is_final = (value >= self.total) has_preview = (preview is not None) # Always send immediately for previews, first update, or final update if has_preview or is_first or is_final: self.hook(self.current, self.total, preview, node_id=self.node_id) self._last_update_time = current_time self._last_sent_value = value return # Apply throttling for regular progress updates if self.total > 0: percent_changed = ((value - max(0, self._last_sent_value)) / self.total) * 100 else: percent_changed = 100 time_elapsed = current_time - self._last_update_time if time_elapsed >= PROGRESS_THROTTLE_MIN_INTERVAL and percent_changed >= PROGRESS_THROTTLE_MIN_PERCENT: self.hook(self.current, self.total, preview, node_id=self.node_id) self._last_update_time = current_time self._last_sent_value = value def update(self, value): self.update_absolute(self.current + value) def reshape_mask(input_mask, output_shape): dims = len(output_shape) - 2 if dims == 1: scale_mode = "linear" if dims == 2: input_mask = input_mask.reshape((-1, 1, input_mask.shape[-2], input_mask.shape[-1])) scale_mode = "bilinear" if dims == 3: if len(input_mask.shape) < 5: input_mask = input_mask.reshape((1, 1, -1, input_mask.shape[-2], input_mask.shape[-1])) scale_mode = "trilinear" mask = torch.nn.functional.interpolate(input_mask, size=output_shape[2:], mode=scale_mode) if mask.shape[1] < output_shape[1]: mask = mask.repeat((1, output_shape[1]) + (1,) * dims)[:,:output_shape[1]] mask = repeat_to_batch_size(mask, output_shape[0]) return mask def upscale_dit_mask(mask: torch.Tensor, img_size_in, img_size_out): hi, wi = img_size_in ho, wo = img_size_out # if it's already the correct size, no need to do anything if (hi, wi) == (ho, wo): return mask if mask.ndim == 2: mask = mask.unsqueeze(0) if mask.ndim != 3: raise ValueError(f"Got a mask of shape {list(mask.shape)}, expected [b, q, k] or [q, k]") txt_tokens = mask.shape[1] - (hi * wi) # quadrants of the mask txt_to_txt = mask[:, :txt_tokens, :txt_tokens] txt_to_img = mask[:, :txt_tokens, txt_tokens:] img_to_img = mask[:, txt_tokens:, txt_tokens:] img_to_txt = mask[:, txt_tokens:, :txt_tokens] # convert to 1d x 2d, interpolate, then back to 1d x 1d txt_to_img = rearrange (txt_to_img, "b t (h w) -> b t h w", h=hi, w=wi) txt_to_img = interpolate(txt_to_img, size=img_size_out, mode="bilinear") txt_to_img = rearrange (txt_to_img, "b t h w -> b t (h w)") # this one is hard because we have to do it twice # convert to 1d x 2d, interpolate, then to 2d x 1d, interpolate, then 1d x 1d img_to_img = rearrange (img_to_img, "b hw (h w) -> b hw h w", h=hi, w=wi) img_to_img = interpolate(img_to_img, size=img_size_out, mode="bilinear") img_to_img = rearrange (img_to_img, "b (hk wk) hq wq -> b (hq wq) hk wk", hk=hi, wk=wi) img_to_img = interpolate(img_to_img, size=img_size_out, mode="bilinear") img_to_img = rearrange (img_to_img, "b (hq wq) hk wk -> b (hk wk) (hq wq)", hq=ho, wq=wo) # convert to 2d x 1d, interpolate, then back to 1d x 1d img_to_txt = rearrange (img_to_txt, "b (h w) t -> b t h w", h=hi, w=wi) img_to_txt = interpolate(img_to_txt, size=img_size_out, mode="bilinear") img_to_txt = rearrange (img_to_txt, "b t h w -> b (h w) t") # reassemble the mask from blocks out = torch.cat([ torch.cat([txt_to_txt, txt_to_img], dim=2), torch.cat([img_to_txt, img_to_img], dim=2)], dim=1 ) return out def pack_latents(latents): latent_shapes = [] tensors = [] for tensor in latents: latent_shapes.append(tensor.shape) tensors.append(tensor.reshape(tensor.shape[0], 1, -1)) latent = torch.cat(tensors, dim=-1) return latent, latent_shapes def unpack_latents(combined_latent, latent_shapes): if len(latent_shapes) > 1: output_tensors = [] for shape in latent_shapes: cut = math.prod(shape[1:]) tens = combined_latent[:, :, :cut] combined_latent = combined_latent[:, :, cut:] output_tensors.append(tens.reshape([tens.shape[0]] + list(shape)[1:])) else: output_tensors = [combined_latent] return output_tensors def detect_layer_quantization(state_dict, prefix): for k in state_dict: if k.startswith(prefix) and k.endswith(".comfy_quant"): logging.info("Found quantization metadata version 1") return {"mixed_ops": True} return None def convert_old_quants(state_dict, model_prefix="", metadata={}): if metadata is None: metadata = {} quant_metadata = None if "_quantization_metadata" not in metadata: scaled_fp8_key = "{}scaled_fp8".format(model_prefix) if scaled_fp8_key in state_dict: scaled_fp8_weight = state_dict[scaled_fp8_key] scaled_fp8_dtype = scaled_fp8_weight.dtype if scaled_fp8_dtype == torch.float32: scaled_fp8_dtype = torch.float8_e4m3fn if scaled_fp8_weight.nelement() == 2: full_precision_matrix_mult = True else: full_precision_matrix_mult = False out_sd = {} layers = {} for k in list(state_dict.keys()): if k == scaled_fp8_key: continue if not k.startswith(model_prefix): out_sd[k] = state_dict[k] continue k_out = k w = state_dict.pop(k) layer = None if k_out.endswith(".scale_weight"): layer = k_out[:-len(".scale_weight")] k_out = "{}.weight_scale".format(layer) if layer is not None: layer_conf = {"format": "float8_e4m3fn"} if full_precision_matrix_mult: layer_conf["full_precision_matrix_mult"] = full_precision_matrix_mult layers[layer] = layer_conf if k_out.endswith(".scale_input"): layer = k_out[:-len(".scale_input")] k_out = "{}.input_scale".format(layer) if w.item() == 1.0: continue out_sd[k_out] = w state_dict = out_sd quant_metadata = {"layers": layers} else: quant_metadata = json.loads(metadata["_quantization_metadata"]) if quant_metadata is not None: layers = quant_metadata["layers"] for k, v in layers.items(): state_dict["{}.comfy_quant".format(k)] = torch.tensor(list(json.dumps(v).encode('utf-8')), dtype=torch.uint8) return state_dict, metadata def string_to_seed(data): crc = 0xFFFFFFFF for byte in data: if isinstance(byte, str): byte = ord(byte) crc ^= byte for _ in range(8): if crc & 1: crc = (crc >> 1) ^ 0xEDB88320 else: crc >>= 1 return crc ^ 0xFFFFFFFF def deepcopy_list_dict(obj, memo=None): if memo is None: memo = {} obj_id = id(obj) if obj_id in memo: return memo[obj_id] if isinstance(obj, dict): res = {deepcopy_list_dict(k, memo): deepcopy_list_dict(v, memo) for k, v in obj.items()} elif isinstance(obj, list): res = [deepcopy_list_dict(i, memo) for i in obj] else: res = obj memo[obj_id] = res return res