Merge branch 'comfyanonymous:master' into feat/is_change_object_storage

Dr.Lt.Data 2023-07-22 13:00:01 +09:00 committed by GitHub
commit 6f3bdb6e64
3 changed files with 129 additions and 74 deletions


@@ -118,3 +118,57 @@ def model_config_from_unet_config(unet_config):
 def model_config_from_unet(state_dict, unet_key_prefix, use_fp16):
     unet_config = detect_unet_config(state_dict, unet_key_prefix, use_fp16)
     return model_config_from_unet_config(unet_config)
+
+def model_config_from_diffusers_unet(state_dict, use_fp16):
+    match = {}
+    match["context_dim"] = state_dict["down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_k.weight"].shape[1]
+    match["model_channels"] = state_dict["conv_in.weight"].shape[0]
+    match["in_channels"] = state_dict["conv_in.weight"].shape[1]
+    match["adm_in_channels"] = None
+    if "class_embedding.linear_1.weight" in state_dict:
+        match["adm_in_channels"] = state_dict["class_embedding.linear_1.weight"].shape[1]
+    elif "add_embedding.linear_1.weight" in state_dict:
+        match["adm_in_channels"] = state_dict["add_embedding.linear_1.weight"].shape[1]
+
+    SDXL = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
+            'num_classes': 'sequential', 'adm_in_channels': 2816, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320,
+            'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 2, 10], 'channel_mult': [1, 2, 4],
+            'transformer_depth_middle': 10, 'use_linear_in_transformer': True, 'context_dim': 2048}
+
+    SDXL_refiner = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
+                    'num_classes': 'sequential', 'adm_in_channels': 2560, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 384,
+                    'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 4, 4, 0], 'channel_mult': [1, 2, 4, 4],
+                    'transformer_depth_middle': 4, 'use_linear_in_transformer': True, 'context_dim': 1280}
+
+    SD21 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
+            'adm_in_channels': None, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2,
+            'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
+            'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024}
+
+    SD21_uncliph = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
+                    'num_classes': 'sequential', 'adm_in_channels': 2048, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320,
+                    'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
+                    'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024}
+
+    SD21_unclipl = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
+                    'num_classes': 'sequential', 'adm_in_channels': 1536, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320,
+                    'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
+                    'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024}
+
+    SD15 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
+            'adm_in_channels': None, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2,
+            'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
+            'transformer_depth_middle': 1, 'use_linear_in_transformer': False, 'context_dim': 768}
+
+    supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl]
+
+    for unet_config in supported_models:
+        matches = True
+        for k in match:
+            if match[k] != unet_config[k]:
+                matches = False
+                break
+        if matches:
+            return model_config_from_unet_config(unet_config)
+    return None
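
The new model_config_from_diffusers_unet builds a small fingerprint (context_dim, model_channels, in_channels, adm_in_channels) from a diffusers-format state dict and compares it against the known SD1.5 / SD2.1 / unCLIP / SDXL configs. A minimal usage sketch, assuming the upstream comfy.utils and comfy.model_detection module layout; the checkpoint path is hypothetical:

import comfy.utils as utils
import comfy.model_detection as model_detection

# Hypothetical diffusers-format UNet checkpoint; any state dict with the
# expected "down_blocks.*" / "conv_in.*" keys is matched the same way.
sd = utils.load_torch_file("unet/diffusion_pytorch_model.safetensors")
model_config = model_detection.model_config_from_diffusers_unet(sd, True)
if model_config is None:
    print("UNet layout not recognized by the supported_models table")
else:
    print("detected UNet config:", model_config.unet_config)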


@@ -202,6 +202,14 @@ def model_lora_keys_unet(model, key_map={}):
             key_map["lora_unet_{}".format(key_lora)] = "diffusion_model.{}".format(diffusers_keys[k])
     return key_map
 
+def set_attr(obj, attr, value):
+    attrs = attr.split(".")
+    for name in attrs[:-1]:
+        obj = getattr(obj, name)
+    prev = getattr(obj, attrs[-1])
+    setattr(obj, attrs[-1], torch.nn.Parameter(value))
+    del prev
+
 class ModelPatcher:
     def __init__(self, model, load_device, offload_device, size=0):
         self.size = size
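
set_attr is now a module-level helper: it walks a dotted attribute path, keeps a reference to the previous value, and replaces the leaf with a fresh torch.nn.Parameter. A small illustrative sketch with set_attr from the hunk above in scope (the toy module is an assumption, not part of this diff):

import torch

class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 4)

m = Toy()
# Replace the nested parameter by walking "layer" then swapping "weight".
set_attr(m, "layer.weight", torch.zeros(4, 4))
print(isinstance(m.layer.weight, torch.nn.Parameter), float(m.layer.weight.abs().sum()))  # True 0.0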
@@ -340,10 +348,11 @@ class ModelPatcher:
             weight = model_sd[key]
 
             if key not in self.backup:
-                self.backup[key] = weight.to(self.offload_device, copy=True)
+                self.backup[key] = weight.to(self.offload_device)
 
             temp_weight = weight.to(torch.float32, copy=True)
-            weight[:] = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
+            out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
+            set_attr(self.model, key, out_weight)
             del temp_weight
         return self.model
@@ -439,13 +448,6 @@ class ModelPatcher:
     def unpatch_model(self):
         keys = list(self.backup.keys())
-        def set_attr(obj, attr, value):
-            attrs = attr.split(".")
-            for name in attrs[:-1]:
-                obj = getattr(obj, name)
-            prev = getattr(obj, attrs[-1])
-            setattr(obj, attrs[-1], torch.nn.Parameter(value))
-            del prev
 
         for k in keys:
             set_attr(self.model, k, self.backup[k])
@@ -763,6 +765,51 @@ class ControlNet:
 def load_controlnet(ckpt_path, model=None):
     controlnet_data = utils.load_torch_file(ckpt_path, safe_load=True)
+
+    controlnet_config = None
+    if "controlnet_cond_embedding.conv_in.weight" in controlnet_data: #diffusers format
+        use_fp16 = model_management.should_use_fp16()
+        controlnet_config = model_detection.model_config_from_diffusers_unet(controlnet_data, use_fp16).unet_config
+        diffusers_keys = utils.unet_to_diffusers(controlnet_config)
+        diffusers_keys["controlnet_mid_block.weight"] = "middle_block_out.0.weight"
+        diffusers_keys["controlnet_mid_block.bias"] = "middle_block_out.0.bias"
+
+        count = 0
+        loop = True
+        while loop:
+            suffix = [".weight", ".bias"]
+            for s in suffix:
+                k_in = "controlnet_down_blocks.{}{}".format(count, s)
+                k_out = "zero_convs.{}.0{}".format(count, s)
+                if k_in not in controlnet_data:
+                    loop = False
+                    break
+                diffusers_keys[k_in] = k_out
+            count += 1
+
+        count = 0
+        loop = True
+        while loop:
+            suffix = [".weight", ".bias"]
+            for s in suffix:
+                if count == 0:
+                    k_in = "controlnet_cond_embedding.conv_in{}".format(s)
+                else:
+                    k_in = "controlnet_cond_embedding.blocks.{}{}".format(count - 1, s)
+                k_out = "input_hint_block.{}{}".format(count * 2, s)
+                if k_in not in controlnet_data:
+                    k_in = "controlnet_cond_embedding.conv_out{}".format(s)
+                    loop = False
+                diffusers_keys[k_in] = k_out
+            count += 1
+
+        new_sd = {}
+        for k in diffusers_keys:
+            if k in controlnet_data:
+                new_sd[diffusers_keys[k]] = controlnet_data.pop(k)
+
+        controlnet_data = new_sd
+
     pth_key = 'control_model.zero_convs.0.0.weight'
     pth = False
     key = 'zero_convs.0.0.weight'
@@ -778,9 +825,9 @@ def load_controlnet(ckpt_path, model=None):
             print("error checkpoint does not contain controlnet or t2i adapter data", ckpt_path)
         return net
 
-    use_fp16 = model_management.should_use_fp16()
-
-    controlnet_config = model_detection.model_config_from_unet(controlnet_data, prefix, use_fp16).unet_config
+    if controlnet_config is None:
+        use_fp16 = model_management.should_use_fp16()
+        controlnet_config = model_detection.model_config_from_unet(controlnet_data, prefix, use_fp16).unet_config
     controlnet_config.pop("out_channels")
     controlnet_config["hint_channels"] = 3
     control_model = cldm.ControlNet(**controlnet_config)
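
The two while loops above translate diffusers ControlNet key names into the original cldm names before the weights are loaded. For a typical SD ControlNet the resulting rename table contains entries along these lines (an illustrative excerpt derived from the loop logic, not a dump of a real checkpoint):

example_mapping = {
    "controlnet_mid_block.weight": "middle_block_out.0.weight",
    "controlnet_down_blocks.0.weight": "zero_convs.0.0.weight",
    "controlnet_down_blocks.0.bias": "zero_convs.0.0.bias",
    "controlnet_down_blocks.1.weight": "zero_convs.1.0.weight",
    "controlnet_cond_embedding.conv_in.weight": "input_hint_block.0.weight",
    "controlnet_cond_embedding.blocks.0.weight": "input_hint_block.2.weight",
    # conv_out lands on the last input_hint_block conv; the index depends on the block count
    "controlnet_cond_embedding.conv_out.weight": "input_hint_block.14.weight",
}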
@@ -1138,69 +1185,24 @@ def load_unet(unet_path): #load unet in diffusers format
     parameters = calculate_parameters(sd, "")
     fp16 = model_management.should_use_fp16(model_params=parameters)
 
-    match = {}
-    match["context_dim"] = sd["down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_k.weight"].shape[1]
-    match["model_channels"] = sd["conv_in.weight"].shape[0]
-    match["in_channels"] = sd["conv_in.weight"].shape[1]
-    match["adm_in_channels"] = None
-    if "class_embedding.linear_1.weight" in sd:
-        match["adm_in_channels"] = sd["class_embedding.linear_1.weight"].shape[1]
-    elif "add_embedding.linear_1.weight" in sd:
-        match["adm_in_channels"] = sd["add_embedding.linear_1.weight"].shape[1]
-
-    SDXL = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
-            'num_classes': 'sequential', 'adm_in_channels': 2816, 'use_fp16': fp16, 'in_channels': 4, 'model_channels': 320,
-            'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 2, 10], 'channel_mult': [1, 2, 4],
-            'transformer_depth_middle': 10, 'use_linear_in_transformer': True, 'context_dim': 2048}
-
-    SDXL_refiner = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
-                    'num_classes': 'sequential', 'adm_in_channels': 2560, 'use_fp16': fp16, 'in_channels': 4, 'model_channels': 384,
-                    'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 4, 4, 0], 'channel_mult': [1, 2, 4, 4],
-                    'transformer_depth_middle': 4, 'use_linear_in_transformer': True, 'context_dim': 1280}
-
-    SD21 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
-            'adm_in_channels': None, 'use_fp16': fp16, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2,
-            'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
-            'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024}
-
-    SD21_uncliph = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
-                    'num_classes': 'sequential', 'adm_in_channels': 2048, 'use_fp16': True, 'in_channels': 4, 'model_channels': 320,
-                    'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
-                    'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024}
-
-    SD21_unclipl = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
-                    'num_classes': 'sequential', 'adm_in_channels': 1536, 'use_fp16': True, 'in_channels': 4, 'model_channels': 320,
-                    'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
-                    'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024}
-
-    SD15 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
-            'adm_in_channels': None, 'use_fp16': True, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2,
-            'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
-            'transformer_depth_middle': 1, 'use_linear_in_transformer': False, 'context_dim': 768}
-
-    supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl]
-    print("match", match)
-    for unet_config in supported_models:
-        matches = True
-        for k in match:
-            if match[k] != unet_config[k]:
-                matches = False
-                break
-        if matches:
-            diffusers_keys = utils.unet_to_diffusers(unet_config)
-            new_sd = {}
-            for k in diffusers_keys:
-                if k in sd:
-                    new_sd[diffusers_keys[k]] = sd.pop(k)
-                else:
-                    print(diffusers_keys[k], k)
-            offload_device = model_management.unet_offload_device()
-            model_config = model_detection.model_config_from_unet_config(unet_config)
-            model = model_config.get_model(new_sd, "")
-            model = model.to(offload_device)
-            model.load_model_weights(new_sd, "")
-            return ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device)
-    print("ERROR UNSUPPORTED UNET", unet_path)
+    model_config = model_detection.model_config_from_diffusers_unet(sd, fp16)
+    if model_config is None:
+        print("ERROR UNSUPPORTED UNET", unet_path)
+        return None
+
+    diffusers_keys = utils.unet_to_diffusers(model_config.unet_config)
+
+    new_sd = {}
+    for k in diffusers_keys:
+        if k in sd:
+            new_sd[diffusers_keys[k]] = sd.pop(k)
+        else:
+            print(diffusers_keys[k], k)
+    offload_device = model_management.unet_offload_device()
+    model = model_config.get_model(new_sd, "")
+    model = model.to(offload_device)
+    model.load_model_weights(new_sd, "")
+    return ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device)
 
 def save_checkpoint(output_path, model, clip, vae, metadata=None):
     try:
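
With detection delegated to model_config_from_diffusers_unet, load_unet now either returns a ModelPatcher or None. A hedged usage sketch, assuming these functions live in comfy.sd as in upstream ComfyUI; the checkpoint path is hypothetical:

import comfy.sd

patcher = comfy.sd.load_unet("models/unet/sdxl_unet.safetensors")  # hypothetical file
if patcher is None:
    print("unsupported or non-diffusers UNet")
else:
    # The detected model comes back wrapped in a ModelPatcher on the unet offload device.
    print(type(patcher.model).__name__)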


@@ -6,7 +6,6 @@ import threading
 import heapq
 import traceback
 import gc
-import time
 import torch
 import nodes