AITemplate ControlNet

hlky 2023-06-01 13:14:27 +01:00
parent fbc74fbb25
commit 4b8a650932
4 changed files with 387 additions and 256 deletions

comfy/ckpt_convert.py Normal file

@@ -0,0 +1,291 @@
def assign_to_checkpoint(
paths, checkpoint, old_checkpoint, additional_replacements=None
):
"""
This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits
attention layers, and takes into account additional replacements that may arise.
Assigns the weights to the new checkpoint.
"""
assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
for path in paths:
new_path = path["new"]
# Global renaming happens here
new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
if additional_replacements is not None:
for replacement in additional_replacements:
new_path = new_path.replace(replacement["old"], replacement["new"])
# proj_attn.weight has to be converted from conv 1D to linear
if "proj_attn.weight" in new_path:
checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
else:
checkpoint[new_path] = old_checkpoint[path["old"]]
def conv_attn_to_linear(checkpoint):
keys = list(checkpoint.keys())
attn_keys = ["query.weight", "key.weight", "value.weight"]
for key in keys:
if ".".join(key.split(".")[-2:]) in attn_keys:
if checkpoint[key].ndim > 2:
checkpoint[key] = checkpoint[key][:, :, 0, 0]
elif "proj_attn.weight" in key:
if checkpoint[key].ndim > 2:
checkpoint[key] = checkpoint[key][:, :, 0]
def renew_attention_paths(old_list, n_shave_prefix_segments=0):
"""
Updates paths inside attentions to the new naming scheme (local renaming)
"""
mapping = []
for old_item in old_list:
new_item = old_item
# new_item = new_item.replace('norm.weight', 'group_norm.weight')
# new_item = new_item.replace('norm.bias', 'group_norm.bias')
# new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
# new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
# new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
mapping.append({"old": old_item, "new": new_item})
return mapping
def shave_segments(path, n_shave_prefix_segments=1):
"""
Removes segments. Positive values shave the first segments, negative shave the last segments.
"""
if n_shave_prefix_segments >= 0:
return ".".join(path.split(".")[n_shave_prefix_segments:])
else:
return ".".join(path.split(".")[:n_shave_prefix_segments])
def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
"""
Updates paths inside resnets to the new naming scheme (local renaming)
"""
mapping = []
for old_item in old_list:
new_item = old_item.replace("in_layers.0", "norm1")
new_item = new_item.replace("in_layers.2", "conv1")
new_item = new_item.replace("out_layers.0", "norm2")
new_item = new_item.replace("out_layers.3", "conv2")
new_item = new_item.replace("emb_layers.1", "time_emb_proj")
new_item = new_item.replace("skip_connection", "conv_shortcut")
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
mapping.append({"old": old_item, "new": new_item})
return mapping
def convert_ldm_unet_checkpoint(unet_state_dict, layers_per_block=2, controlnet=False):
"""
Takes an LDM UNet state dict and returns a converted, diffusers-style checkpoint.
"""
temp = {}
if controlnet:
unet_key = "control_model."
else:
unet_key = "model.diffusion_model."
for key, value in unet_state_dict.items():
if key.startswith(unet_key):
key = key.replace(unet_key, "")
temp[key] = value
unet_state_dict = temp
new_checkpoint = {}
new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
if not controlnet:
new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
# Retrieves the keys for the input blocks only
num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
input_blocks = {
layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
for layer_id in range(num_input_blocks)
}
# Retrieves the keys for the middle blocks only
num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
middle_blocks = {
layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
for layer_id in range(num_middle_blocks)
}
# Retrieves the keys for the output blocks only
num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
output_blocks = {
layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
for layer_id in range(num_output_blocks)
}
for i in range(1, num_input_blocks):
block_id = (i - 1) // (layers_per_block + 1)
layer_in_block_id = (i - 1) % (layers_per_block + 1)
resnets = [
key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
]
attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
f"input_blocks.{i}.0.op.weight"
)
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
f"input_blocks.{i}.0.op.bias"
)
paths = renew_resnet_paths(resnets)
meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
assign_to_checkpoint(
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path]
)
if len(attentions):
paths = renew_attention_paths(attentions)
meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
assign_to_checkpoint(
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path]
)
resnet_0 = middle_blocks[0]
attentions = middle_blocks[1]
resnet_1 = middle_blocks[2]
resnet_0_paths = renew_resnet_paths(resnet_0)
assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict)
resnet_1_paths = renew_resnet_paths(resnet_1)
assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict)
attentions_paths = renew_attention_paths(attentions)
meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
assign_to_checkpoint(
attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path]
)
for i in range(num_output_blocks):
block_id = i // (layers_per_block + 1)
layer_in_block_id = i % (layers_per_block + 1)
output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
output_block_list = {}
for layer in output_block_layers:
layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
if layer_id in output_block_list:
output_block_list[layer_id].append(layer_name)
else:
output_block_list[layer_id] = [layer_name]
if len(output_block_list) > 1:
resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
resnet_0_paths = renew_resnet_paths(resnets)
paths = renew_resnet_paths(resnets)
meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
assign_to_checkpoint(
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path]
)
output_block_list = {k: sorted(v) for k, v in output_block_list.items()}
if ["conv.bias", "conv.weight"] in output_block_list.values():
index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
f"output_blocks.{i}.{index}.conv.weight"
]
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
f"output_blocks.{i}.{index}.conv.bias"
]
# Clear attentions as they have been attributed above.
if len(attentions) == 2:
attentions = []
if len(attentions):
paths = renew_attention_paths(attentions)
meta_path = {
"old": f"output_blocks.{i}.1",
"new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
}
assign_to_checkpoint(
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path]
)
else:
resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
for path in resnet_0_paths:
old_path = ".".join(["output_blocks", str(i), path["old"]])
new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
new_checkpoint[new_path] = unet_state_dict[old_path]
if controlnet:
# conditioning embedding
orig_index = 0
new_checkpoint["controlnet_cond_embedding.conv_in.weight"] = unet_state_dict.pop(
f"input_hint_block.{orig_index}.weight"
)
new_checkpoint["controlnet_cond_embedding.conv_in.bias"] = unet_state_dict.pop(
f"input_hint_block.{orig_index}.bias"
)
orig_index += 2
diffusers_index = 0
while diffusers_index < 6:
new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.weight"] = unet_state_dict.pop(
f"input_hint_block.{orig_index}.weight"
)
new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.bias"] = unet_state_dict.pop(
f"input_hint_block.{orig_index}.bias"
)
diffusers_index += 1
orig_index += 2
new_checkpoint["controlnet_cond_embedding.conv_out.weight"] = unet_state_dict.pop(
f"input_hint_block.{orig_index}.weight"
)
new_checkpoint["controlnet_cond_embedding.conv_out.bias"] = unet_state_dict.pop(
f"input_hint_block.{orig_index}.bias"
)
# down blocks
for i in range(num_input_blocks):
new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = unet_state_dict.pop(f"zero_convs.{i}.0.weight")
new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = unet_state_dict.pop(f"zero_convs.{i}.0.bias")
# mid block
new_checkpoint["controlnet_mid_block.weight"] = unet_state_dict.pop("middle_block_out.0.weight")
new_checkpoint["controlnet_mid_block.bias"] = unet_state_dict.pop("middle_block_out.0.bias")
return new_checkpoint
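
Note: to make the renaming pipeline above concrete, here is a minimal sketch of how one representative LDM UNet key is mapped; the snippet is illustrative and not part of the commit:

from comfy.ckpt_convert import renew_resnet_paths

old_key = "input_blocks.1.0.in_layers.0.weight"

# Local renaming: in_layers.0 -> norm1, etc.
mapping = renew_resnet_paths([old_key])
print(mapping[0]["new"])  # input_blocks.1.0.norm1.weight

# assign_to_checkpoint then applies the global/meta replacement
# "input_blocks.1.0" -> "down_blocks.0.resnets.0", producing the final
# diffusers-style key: down_blocks.0.resnets.0.norm1.weight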


@@ -505,11 +505,16 @@ class AITemplateModelWrapper(torch.nn.Module):
timesteps_pt = t
latent_model_input = x
encoder_hidden_states = None
down_block_residuals = None
mid_block_residual = None
#TODO: verify this is correct/match DiffusionWrapper (ddpm.py)
if 'c_crossattn' in cond:
encoder_hidden_states = cond['c_crossattn']
if 'c_concat' in cond:
encoder_hidden_states = cond['c_concat']
if "control" in cond:
down_block_residuals = cond["control"]["output"]
mid_block_residual = cond["control"]["middle"][0]
if encoder_hidden_states is None:
raise ValueError(f"conditioning missing, it should be one of these {cond.keys()}")
if type(encoder_hidden_states) is list:
@@ -525,6 +530,10 @@ class AITemplateModelWrapper(torch.nn.Module):
"input1": timesteps_pt.cuda().half(), "input1": timesteps_pt.cuda().half(),
"input2": encoder_hidden_states.cuda().half(), "input2": encoder_hidden_states.cuda().half(),
} }
if down_block_residuals is not None and mid_block_residual is not None:
for i, y in enumerate(down_block_residuals):
inputs[f"down_block_residual_{i}"] = y.permute((0, 2, 3, 1)).contiguous().cuda().half()
inputs["mid_block_residual"] = mid_block_residual.permute((0, 2, 3, 1)).contiguous().cuda().half()
ys = []
num_outputs = len(self.unet_ait_exe.get_output_name_to_index_map())
for i in range(num_outputs):
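
Note: the permute((0, 2, 3, 1)) calls above exist because the compiled AITemplate module consumes NHWC (channels-last) tensors, while the PyTorch residuals are NCHW. A minimal illustrative sketch (shapes are placeholders, not from the commit):

import torch

residual_nchw = torch.randn(2, 320, 64, 64)  # NCHW, as produced in PyTorch
# NCHW -> NHWC; .contiguous() materializes the channels-last layout in memory
residual_nhwc = residual_nchw.permute((0, 2, 3, 1)).contiguous()
assert residual_nhwc.shape == (2, 64, 64, 320)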


@@ -601,6 +601,7 @@ def broadcast_image_to(tensor, target_batch_size, batched_number):
class ControlNet:
def __init__(self, control_model, device=None):
self.aitemplate = None
self.control_model = control_model
self.cond_hint_original = None
self.cond_hint = None
@@ -610,6 +611,31 @@ class ControlNet:
self.device = device
self.previous_controlnet = None
def aitemplate_controlnet(
self, latent_model_input, timesteps, encoder_hidden_states, controlnet_cond
):
if self.aitemplate is None:
raise RuntimeError("No aitemplate loaded")
batch = latent_model_input.shape[0]
timesteps_pt = timesteps.expand(batch * 2)
inputs = {
"input0": latent_model_input.permute((0, 2, 3, 1))
.contiguous()
.cuda()
.half(),
"input1": timesteps_pt.cuda().half(),
"input2": encoder_hidden_states.cuda().half(),
"input3": controlnet_cond.permute((0, 2, 3, 1)).contiguous().cuda().half(),
}
ys = []
num_outputs = len(self.aitemplate.get_output_name_to_index_map())
for i in range(num_outputs):
shape = self.aitemplate.get_output_maximum_shape(i)
ys.append(torch.empty(shape).cuda().half())
self.aitemplate.run_with_tensors(inputs, ys, graph_mode=False)
return ys
def get_control(self, x_noisy, t, cond_txt, batched_number):
control_prev = None
if self.previous_controlnet is not None:
@@ -623,16 +649,18 @@ class ControlNet:
self.cond_hint = utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(self.control_model.dtype).to(self.device)
if x_noisy.shape[0] != self.cond_hint.shape[0]:
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
if self.aitemplate is None:
    if self.control_model.dtype == torch.float16:
        precision_scope = torch.autocast
    else:
        precision_scope = contextlib.nullcontext
    with precision_scope(model_management.get_autocast_device(self.device)):
        self.control_model = model_management.load_if_low_vram(self.control_model)
        control = self.control_model(x=x_noisy, hint=self.cond_hint, timesteps=t, context=cond_txt)
        self.control_model = model_management.unload_if_low_vram(self.control_model)
else:
    control = self.aitemplate_controlnet(x_noisy, t, cond_txt, self.cond_hint)
out = {'middle':[], 'output': []}
autocast_enabled = torch.is_autocast_enabled()
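
Note: the "control" entry consumed by AITemplateModelWrapper earlier in this commit is the dict that get_control builds above. A minimal sketch of that handoff (tensor shapes are illustrative placeholders, not from the commit):

import torch

# Shape of what get_control returns (placeholder tensors)
control = {
    "middle": [torch.randn(2, 1280, 8, 8)],
    "output": [torch.randn(2, 320, 64, 64), torch.randn(2, 320, 64, 64)],
}
cond = {"control": control}
# ...which the wrapper then unpacks as:
down_block_residuals = cond["control"]["output"]
mid_block_residual = cond["control"]["middle"][0]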

nodes.py

@@ -16,6 +16,7 @@ import safetensors.torch
sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy"))
from comfy.aitemplate.model import Model
from comfy.ckpt_convert import convert_ldm_unet_checkpoint
import comfy.diffusers_convert
import comfy.samplers
import comfy.sample
@@ -411,6 +412,53 @@ class CLIPSetLastLayer:
return (clip,)
class AITemplateControlNetLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "control_net": ("CONTROL_NET",),
"aitemplate_module": (folder_paths.get_filename_list("aitemplate"), ),
}}
RETURN_TYPES = ("CONTROL_NET",)
FUNCTION = "load_aitemplate_controlnet"
CATEGORY = "loaders"
def load_aitemplate_controlnet(self, control_net, aitemplate_module):
aitemplate_path = folder_paths.get_full_path("aitemplate", aitemplate_module)
aitemplate = Model(aitemplate_path)
control_net_ait_params = self.map_controlnet_params(convert_ldm_unet_checkpoint(control_net.control_model.state_dict(), controlnet=True))
print("Setting constants")
aitemplate.set_many_constants_with_tensors(control_net_ait_params)
print("Folding constants")
aitemplate.fold_constants()
control_net.aitemplate = aitemplate
return (control_net,)
def map_controlnet_params(self, state_dict):
params_ait = {}
for key, arr in state_dict.items():
arr = arr.to("cuda", dtype=torch.float16)
if len(arr.shape) == 4:
arr = arr.permute((0, 2, 3, 1)).contiguous()
elif key.endswith("ff.net.0.proj.weight"):
w1, w2 = arr.chunk(2, dim=0)
params_ait[key.replace(".", "_")] = w1
params_ait[key.replace(".", "_").replace("proj", "gate")] = w2
continue
elif key.endswith("ff.net.0.proj.bias"):
w1, w2 = arr.chunk(2, dim=0)
params_ait[key.replace(".", "_")] = w1
params_ait[key.replace(".", "_").replace("proj", "gate")] = w2
continue
params_ait[key.replace(".", "_")] = arr
params_ait["controlnet_cond_embedding_conv_in_weight"] = torch.nn.functional.pad(
params_ait["controlnet_cond_embedding_conv_in_weight"], (0, 1, 0, 0, 0, 0, 0, 0)
)
params_ait["arange"] = (
torch.arange(start=0, end=320 // 2, dtype=torch.float32).cuda().half()
)
return params_ait
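
Note: the chunk(2, dim=0) split above reflects diffusers' GEGLU layout, in which the projection and gate weights are stored concatenated along the output dimension, while the AITemplate graph (as assumed here) takes them as two separate constants (the "proj" and "gate" keys). A toy illustration with made-up sizes:

import torch

proj_weight = torch.randn(2560, 320)  # GEGLU: 2 x inner_dim rows, concatenated
w_proj, w_gate = proj_weight.chunk(2, dim=0)
assert w_proj.shape == w_gate.shape == (1280, 320)

The pad call that follows widens the 3-channel hint convolution to 4 input channels (the last axis after the NHWC permute), presumably to satisfy the compiled kernel's channel-alignment requirements.
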
class AITemplateLoader:
@classmethod
def INPUT_TYPES(s):
@@ -425,259 +473,12 @@ class AITemplateLoader:
def load_aitemplate(self, model, aitemplate_module):
aitemplate_path = folder_paths.get_full_path("aitemplate", aitemplate_module)
aitemplate = Model(aitemplate_path)
unet_params_ait = self.map_unet_state_dict(convert_ldm_unet_checkpoint(model.model.state_dict()))
print("Setting constants")
aitemplate.set_many_constants_with_tensors(unet_params_ait)
print("Folding constants")
aitemplate.fold_constants()
return ((aitemplate,model),)
#=================#
# UNet Conversion #
#=================#
def assign_to_checkpoint(
self, paths, checkpoint, old_checkpoint, additional_replacements=None
):
"""
This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits
attention layers, and takes into account additional replacements that may arise.
Assigns the weights to the new checkpoint.
"""
assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
for path in paths:
new_path = path["new"]
# Global renaming happens here
new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
if additional_replacements is not None:
for replacement in additional_replacements:
new_path = new_path.replace(replacement["old"], replacement["new"])
# proj_attn.weight has to be converted from conv 1D to linear
if "proj_attn.weight" in new_path:
checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
else:
checkpoint[new_path] = old_checkpoint[path["old"]]
def conv_attn_to_linear(self, checkpoint):
keys = list(checkpoint.keys())
attn_keys = ["query.weight", "key.weight", "value.weight"]
for key in keys:
if ".".join(key.split(".")[-2:]) in attn_keys:
if checkpoint[key].ndim > 2:
checkpoint[key] = checkpoint[key][:, :, 0, 0]
elif "proj_attn.weight" in key:
if checkpoint[key].ndim > 2:
checkpoint[key] = checkpoint[key][:, :, 0]
def renew_attention_paths(self, old_list, n_shave_prefix_segments=0):
"""
Updates paths inside attentions to the new naming scheme (local renaming)
"""
mapping = []
for old_item in old_list:
new_item = old_item
# new_item = new_item.replace('norm.weight', 'group_norm.weight')
# new_item = new_item.replace('norm.bias', 'group_norm.bias')
# new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
# new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
# new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
mapping.append({"old": old_item, "new": new_item})
return mapping
def shave_segments(self, path, n_shave_prefix_segments=1):
"""
Removes segments. Positive values shave the first segments, negative shave the last segments.
"""
if n_shave_prefix_segments >= 0:
return ".".join(path.split(".")[n_shave_prefix_segments:])
else:
return ".".join(path.split(".")[:n_shave_prefix_segments])
def renew_resnet_paths(self, old_list, n_shave_prefix_segments=0):
"""
Updates paths inside resnets to the new naming scheme (local renaming)
"""
mapping = []
for old_item in old_list:
new_item = old_item.replace("in_layers.0", "norm1")
new_item = new_item.replace("in_layers.2", "conv1")
new_item = new_item.replace("out_layers.0", "norm2")
new_item = new_item.replace("out_layers.3", "conv2")
new_item = new_item.replace("emb_layers.1", "time_emb_proj")
new_item = new_item.replace("skip_connection", "conv_shortcut")
new_item = self.shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
mapping.append({"old": old_item, "new": new_item})
return mapping
def convert_ldm_unet_checkpoint(self, unet_state_dict, layers_per_block=2):
"""
Takes a state dict and a config, and returns a converted checkpoint.
"""
temp = {}
for key, value in unet_state_dict.items():
if key.startswith("model.diffusion_model."):
key = key.replace("model.diffusion_model.", "")
temp[key] = value
unet_state_dict = temp
new_checkpoint = {}
new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
# Retrieves the keys for the input blocks only
num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
input_blocks = {
layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
for layer_id in range(num_input_blocks)
}
# Retrieves the keys for the middle blocks only
num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
middle_blocks = {
layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
for layer_id in range(num_middle_blocks)
}
# Retrieves the keys for the output blocks only
num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
output_blocks = {
layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
for layer_id in range(num_output_blocks)
}
for i in range(1, num_input_blocks):
block_id = (i - 1) // (layers_per_block + 1)
layer_in_block_id = (i - 1) % (layers_per_block + 1)
resnets = [
key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
]
attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
f"input_blocks.{i}.0.op.weight"
)
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
f"input_blocks.{i}.0.op.bias"
)
paths = self.renew_resnet_paths(resnets)
meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
self.assign_to_checkpoint(
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path]
)
if len(attentions):
paths = self.renew_attention_paths(attentions)
meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
self.assign_to_checkpoint(
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path]
)
resnet_0 = middle_blocks[0]
attentions = middle_blocks[1]
resnet_1 = middle_blocks[2]
resnet_0_paths = self.renew_resnet_paths(resnet_0)
self.assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict)
resnet_1_paths = self.renew_resnet_paths(resnet_1)
self.assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict)
attentions_paths = self.renew_attention_paths(attentions)
meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
self.assign_to_checkpoint(
attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path]
)
for i in range(num_output_blocks):
block_id = i // (layers_per_block + 1)
layer_in_block_id = i % (layers_per_block + 1)
output_block_layers = [self.shave_segments(name, 2) for name in output_blocks[i]]
output_block_list = {}
for layer in output_block_layers:
layer_id, layer_name = layer.split(".")[0], self.shave_segments(layer, 1)
if layer_id in output_block_list:
output_block_list[layer_id].append(layer_name)
else:
output_block_list[layer_id] = [layer_name]
if len(output_block_list) > 1:
resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
resnet_0_paths = self.renew_resnet_paths(resnets)
paths = self.renew_resnet_paths(resnets)
meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
self.assign_to_checkpoint(
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path]
)
output_block_list = {k: sorted(v) for k, v in output_block_list.items()}
if ["conv.bias", "conv.weight"] in output_block_list.values():
index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
f"output_blocks.{i}.{index}.conv.weight"
]
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
f"output_blocks.{i}.{index}.conv.bias"
]
# Clear attentions as they have been attributed above.
if len(attentions) == 2:
attentions = []
if len(attentions):
paths = self.renew_attention_paths(attentions)
meta_path = {
"old": f"output_blocks.{i}.1",
"new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
}
self.assign_to_checkpoint(
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path]
)
else:
resnet_0_paths = self.renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
for path in resnet_0_paths:
old_path = ".".join(["output_blocks", str(i), path["old"]])
new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
new_checkpoint[new_path] = unet_state_dict[old_path]
return new_checkpoint
#=========================#
# AITemplate mapping #
@@ -1593,6 +1394,7 @@ NODE_CLASS_MAPPINGS = {
"LatentCrop": LatentCrop, "LatentCrop": LatentCrop,
"LoraLoader": LoraLoader, "LoraLoader": LoraLoader,
"AITemplateLoader": AITemplateLoader, "AITemplateLoader": AITemplateLoader,
"AITemplateControlNetLoader": AITemplateControlNetLoader,
"CLIPLoader": CLIPLoader, "CLIPLoader": CLIPLoader,
"CLIPVisionEncode": CLIPVisionEncode, "CLIPVisionEncode": CLIPVisionEncode,
"StyleModelApply": StyleModelApply, "StyleModelApply": StyleModelApply,
@@ -1627,6 +1429,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"VAELoader": "Load VAE", "VAELoader": "Load VAE",
"LoraLoader": "Load LoRA", "LoraLoader": "Load LoRA",
"AITemplateLoader": "Load AITemplate", "AITemplateLoader": "Load AITemplate",
"AITemplateControlNetLoader": "Load AITemplate (ControlNet)",
"CLIPLoader": "Load CLIP", "CLIPLoader": "Load CLIP",
"ControlNetLoader": "Load ControlNet Model", "ControlNetLoader": "Load ControlNet Model",
"DiffControlNetLoader": "Load ControlNet Model (diff)", "DiffControlNetLoader": "Load ControlNet Model (diff)",