hlky 2023-06-02 15:32:10 +01:00
parent da1ab7b298
commit 1144d45027
2 changed files with 73 additions and 3 deletions

@@ -510,6 +510,7 @@ class AITemplateModelWrapper(torch.nn.Module):
         #TODO: verify this is correct/match DiffusionWrapper (ddpm.py)
         if 'c_crossattn' in cond:
             encoder_hidden_states = cond['c_crossattn']
+            encoder_hidden_states = torch.cat(encoder_hidden_states, 1)
         if 'c_concat' in cond:
             encoder_hidden_states = cond['c_concat']
         if "control" in cond:
@@ -517,8 +518,6 @@ class AITemplateModelWrapper(torch.nn.Module):
             mid_block_residual = cond["control"]["middle"][0]
         if encoder_hidden_states is None:
             raise f"conditioning missing, it should be one of these {cond.keys()}"
-        if type(encoder_hidden_states) is list:
-            encoder_hidden_states = encoder_hidden_states[0]
         height = latent_model_input.shape[2]
         width = latent_model_input.shape[3]
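
Aside (not part of the commit): a minimal sketch, with made-up tensor shapes, of what this change does. The removed lines kept only the first tensor of the `c_crossattn` list, while the added line concatenates the whole list along dim 1 (the token dimension, assuming the usual (batch, tokens, channels) layout).

import torch

# Hypothetical conditioning: two tensors with a (batch, tokens, channels) layout.
c_crossattn = [torch.randn(2, 77, 768), torch.randn(2, 77, 768)]

# Behaviour removed by this commit: keep only the first entry -> shape (2, 77, 768).
old_hidden_states = c_crossattn[0]

# Behaviour added by this commit: concatenate along dim 1 -> shape (2, 154, 768),
# so every tensor in the list contributes to the cross-attention context.
new_hidden_states = torch.cat(c_crossattn, 1)

print(old_hidden_states.shape, new_hidden_states.shape)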

@@ -472,15 +472,86 @@ class AITemplateLoader:
     CATEGORY = "loaders"

     def load_aitemplate(self, model, aitemplate_module):
+        state_dict = self.patch_model(model.model, model.patches)
         aitemplate_path = folder_paths.get_full_path("aitemplate", aitemplate_module)
         aitemplate = Model(aitemplate_path)
-        unet_params_ait = self.map_unet_state_dict(convert_ldm_unet_checkpoint(model.model.state_dict()))
+        unet_params_ait = self.map_unet_state_dict(convert_ldm_unet_checkpoint(state_dict))
         print("Setting constants")
         aitemplate.set_many_constants_with_tensors(unet_params_ait)
         print("Folding constants")
         aitemplate.fold_constants()
         return ((aitemplate,model),)
+
+    def patch_model(self, model, patches):
+        model_sd = model.state_dict()
+        for p in patches:
+            for k in p[1]:
+                v = p[1][k]
+                key = k
+                if key not in model_sd:
+                    print("could not patch. key doesn't exist in model:", k)
+                    continue
+
+                weight = model_sd[key]
+                alpha = p[0]
+
+                if len(v) == 4: #lora/locon
+                    mat1 = v[0]
+                    mat2 = v[1]
+                    if v[2] is not None:
+                        alpha *= v[2] / mat2.shape[0]
+                    if v[3] is not None:
+                        #locon mid weights, hopefully the math is fine because I didn't properly test it
+                        final_shape = [mat2.shape[1], mat2.shape[0], v[3].shape[2], v[3].shape[3]]
+                        mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1).float(), v[3].transpose(0, 1).flatten(start_dim=1).float()).reshape(final_shape).transpose(0, 1)
+                    weight += (alpha * torch.mm(mat1.flatten(start_dim=1).float(), mat2.flatten(start_dim=1).float())).reshape(weight.shape).type(weight.dtype).to(weight.device)
+                elif len(v) == 8: #lokr
+                    w1 = v[0]
+                    w2 = v[1]
+                    w1_a = v[3]
+                    w1_b = v[4]
+                    w2_a = v[5]
+                    w2_b = v[6]
+                    t2 = v[7]
+                    dim = None
+                    if w1 is None:
+                        dim = w1_b.shape[0]
+                        w1 = torch.mm(w1_a.float(), w1_b.float())
+                    if w2 is None:
+                        dim = w2_b.shape[0]
+                        if t2 is None:
+                            w2 = torch.mm(w2_a.float(), w2_b.float())
+                        else:
+                            w2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float(), w2_b.float(), w2_a.float())
+                    if len(w2.shape) == 4:
+                        w1 = w1.unsqueeze(2).unsqueeze(2)
+                    if v[2] is not None and dim is not None:
+                        alpha *= v[2] / dim
+                    weight += alpha * torch.kron(w1.float(), w2.float()).reshape(weight.shape).type(weight.dtype).to(weight.device)
+                else: #loha
+                    w1a = v[0]
+                    w1b = v[1]
+                    if v[2] is not None:
+                        alpha *= v[2] / w1b.shape[0]
+                    w2a = v[3]
+                    w2b = v[4]
+                    if v[5] is not None: #cp decomposition
+                        t1 = v[5]
+                        t2 = v[6]
+                        m1 = torch.einsum('i j k l, j r, i p -> p r k l', t1.float(), w1b.float(), w1a.float())
+                        m2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float(), w2b.float(), w2a.float())
+                    else:
+                        m1 = torch.mm(w1a.float(), w1b.float())
+                        m2 = torch.mm(w2a.float(), w2b.float())
+                    weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype).to(weight.device)
+        return model_sd
+
 #=========================#
 # AITemplate mapping #
 #=========================#
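
Aside (illustrative only, with made-up shapes and alpha; the real patches come from model.patches): the core of the lora/locon branch in patch_model above is a low-rank update, while the lokr branch builds its delta from a Kronecker product of two small factors.

import torch

# Hypothetical: a 320x320 weight patched by a rank-8 lora pair.
weight = torch.randn(320, 320)
mat1 = torch.randn(320, 8)   # corresponds to v[0]
mat2 = torch.randn(8, 320)   # corresponds to v[1]
alpha = 0.8

# lora/locon: W <- W + alpha * (mat1 @ mat2), reshaped back onto the weight
# (the optional locon mid weights in v[3] are skipped here).
weight += (alpha * torch.mm(mat1, mat2)).reshape(weight.shape)

# lokr: the delta is a Kronecker product of two small factors with the same total size.
w1 = torch.randn(4, 4)
w2 = torch.randn(80, 80)
weight += alpha * torch.kron(w1, w2).reshape(weight.shape)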