Mirror of https://github.com/comfyanonymous/ComfyUI.git
LoRA fix

commit 1144d45027
parent da1ab7b298
@@ -510,6 +510,7 @@ class AITemplateModelWrapper(torch.nn.Module):
         #TODO: verify this is correct/match DiffusionWrapper (ddpm.py)
         if 'c_crossattn' in cond:
             encoder_hidden_states = cond['c_crossattn']
+            encoder_hidden_states = torch.cat(encoder_hidden_states, 1)
         if 'c_concat' in cond:
             encoder_hidden_states = cond['c_concat']
         if "control" in cond:
@@ -517,8 +518,6 @@ class AITemplateModelWrapper(torch.nn.Module):
             mid_block_residual = cond["control"]["middle"][0]
         if encoder_hidden_states is None:
             raise f"conditioning missing, it should be one of these {cond.keys()}"
-        if type(encoder_hidden_states) is list:
-            encoder_hidden_states = encoder_hidden_states[0]
         height = latent_model_input.shape[2]
         width = latent_model_input.shape[3]
 
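Two notes on this hunk. In ComfyUI, cond['c_crossattn'] arrives as a list of cross-attention context tensors; the removed lines collapsed that list by keeping only its first element, silently dropping any additional conditioning, while the added line concatenates the list along the token dimension instead. (The pre-existing raise f"..." line is untouched by the commit; raising a plain string actually produces a TypeError, since only exceptions can be raised.) A minimal standalone sketch of the behavioral difference, with illustrative shapes not taken from the commit:

import torch

# Hypothetical: two 77-token conditioning tensors, as produced by
# chained conditioning nodes (shapes are illustrative only).
c_crossattn = [torch.randn(1, 77, 768), torch.randn(1, 77, 768)]

old = c_crossattn[0]             # removed path: keeps (1, 77, 768), drops the rest
new = torch.cat(c_crossattn, 1)  # added path: (1, 154, 768), keeps everything
print(old.shape, new.shape)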
nodes.py (73 changed lines)
@@ -472,15 +472,86 @@ class AITemplateLoader:
     CATEGORY = "loaders"
 
     def load_aitemplate(self, model, aitemplate_module):
+        state_dict = self.patch_model(model.model, model.patches)
         aitemplate_path = folder_paths.get_full_path("aitemplate", aitemplate_module)
         aitemplate = Model(aitemplate_path)
-        unet_params_ait = self.map_unet_state_dict(convert_ldm_unet_checkpoint(model.model.state_dict()))
+        unet_params_ait = self.map_unet_state_dict(convert_ldm_unet_checkpoint(state_dict))
         print("Setting constants")
         aitemplate.set_many_constants_with_tensors(unet_params_ait)
         print("Folding constants")
         aitemplate.fold_constants()
         return ((aitemplate,model),)
 
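This is the core of the fix: load_aitemplate now runs the model's pending patches through the new patch_model helper and sets the AITemplate constants from the patched state dict, so LoRA nodes actually affect the compiled module. Worth noting that state_dict() returns views of the live parameters, so the in-place weight += updates below also mutate the source model; a deep copy would avoid that. A sketch of the patch layout the method appears to consume (key name and shapes are hypothetical, not from the commit):

import torch

# Assumed layout of model.patches: a list of (strength, {weight_key: patch_tuple})
# entries, where the tuple length selects the branch in patch_model:
# 4 -> LoRA/LoCon, 8 -> LoKr, anything else -> LoHa.
rank = 4
lora_up, lora_down = torch.randn(320, rank), torch.randn(rank, 320)
patches = [
    (1.0, {"some.unet.weight": (lora_up, lora_down, None, None)}),  # len-4 LoRA entry
]

The newly added patch_model method follows:
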
+    def patch_model(self, model, patches):
+        model_sd = model.state_dict()
+        for p in patches:
+            for k in p[1]:
+                v = p[1][k]
+                key = k
+                if key not in model_sd:
+                    print("could not patch. key doesn't exist in model:", k)
+                    continue
+
+                weight = model_sd[key]
+
+                alpha = p[0]
+
+                if len(v) == 4: #lora/locon
+                    mat1 = v[0]
+                    mat2 = v[1]
+                    if v[2] is not None:
+                        alpha *= v[2] / mat2.shape[0]
+                    if v[3] is not None:
+                        #locon mid weights, hopefully the math is fine because I didn't properly test it
+                        final_shape = [mat2.shape[1], mat2.shape[0], v[3].shape[2], v[3].shape[3]]
+                        mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1).float(), v[3].transpose(0, 1).flatten(start_dim=1).float()).reshape(final_shape).transpose(0, 1)
+                    weight += (alpha * torch.mm(mat1.flatten(start_dim=1).float(), mat2.flatten(start_dim=1).float())).reshape(weight.shape).type(weight.dtype).to(weight.device)
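This branch folds a standard low-rank update into the base weight: with up-projection mat1 (out x r) and down-projection mat2 (r x in), the merged weight is W' = W + alpha * (mat1 @ mat2), where alpha is the patch strength optionally rescaled by the stored lora_alpha over the rank. A self-contained sketch with hypothetical shapes:

import torch

out_dim, in_dim, rank, strength = 320, 320, 4, 1.0
W = torch.randn(out_dim, in_dim)
mat1 = torch.randn(out_dim, rank)  # LoRA "up" weight
mat2 = torch.randn(rank, in_dim)   # LoRA "down" weight
lora_alpha = 4.0

alpha = strength * lora_alpha / mat2.shape[0]  # same rescale as v[2] / mat2.shape[0]
W_merged = W + alpha * (mat1 @ mat2)           # low-rank update folded into the weight
print(W_merged.shape)

The diff continues with the LoKr branch:
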
+                elif len(v) == 8: #lokr
+                    w1 = v[0]
+                    w2 = v[1]
+                    w1_a = v[3]
+                    w1_b = v[4]
+                    w2_a = v[5]
+                    w2_b = v[6]
+                    t2 = v[7]
+                    dim = None
+
+                    if w1 is None:
+                        dim = w1_b.shape[0]
+                        w1 = torch.mm(w1_a.float(), w1_b.float())
+
+                    if w2 is None:
+                        dim = w2_b.shape[0]
+                        if t2 is None:
+                            w2 = torch.mm(w2_a.float(), w2_b.float())
+                        else:
+                            w2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float(), w2_b.float(), w2_a.float())
+
+                    if len(w2.shape) == 4:
+                        w1 = w1.unsqueeze(2).unsqueeze(2)
+                    if v[2] is not None and dim is not None:
+                        alpha *= v[2] / dim
+
+                    weight += alpha * torch.kron(w1.float(), w2.float()).reshape(weight.shape).type(weight.dtype).to(weight.device)
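LoKr stores the update as a Kronecker product, delta_W = alpha * kron(w1, w2), where either factor may itself be shipped as a low-rank pair (hence the torch.mm reconstructions when w1 or w2 is None). Because the Kronecker product multiplies dimensions, two small factors can cover a large weight. A self-contained sketch with illustrative shapes:

import torch

w1 = torch.randn(8, 8)
w2 = torch.randn(40, 40)
alpha = 1.0

delta = torch.kron(w1, w2)  # (8*40, 8*40) = (320, 320)
W = torch.randn(320, 320)
W_merged = W + alpha * delta.reshape(W.shape)
print(W_merged.shape)

The final branch handles LoHa:
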
+                else: #loha
+                    w1a = v[0]
+                    w1b = v[1]
+                    if v[2] is not None:
+                        alpha *= v[2] / w1b.shape[0]
+                    w2a = v[3]
+                    w2b = v[4]
+                    if v[5] is not None: #cp decomposition
+                        t1 = v[5]
+                        t2 = v[6]
+                        m1 = torch.einsum('i j k l, j r, i p -> p r k l', t1.float(), w1b.float(), w1a.float())
+                        m2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float(), w2b.float(), w2a.float())
+                    else:
+                        m1 = torch.mm(w1a.float(), w1b.float())
+                        m2 = torch.mm(w2a.float(), w2b.float())
+
+                    weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype).to(weight.device)
+        return model_sd
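LoHa combines two low-rank products with an elementwise (Hadamard) product rather than a sum: delta_W = alpha * (w1a @ w1b) * (w2a @ w2b), with the einsum path handling CP-decomposed convolution weights when t1/t2 are present. A sketch of the plain (non-CP) case, shapes hypothetical:

import torch

out_dim, in_dim, rank = 320, 320, 4
w1a, w1b = torch.randn(out_dim, rank), torch.randn(rank, in_dim)
w2a, w2b = torch.randn(out_dim, rank), torch.randn(rank, in_dim)
alpha = 1.0

m1 = w1a @ w1b          # first low-rank product
m2 = w2a @ w2b          # second low-rank product
delta = alpha * m1 * m2 # Hadamard (elementwise) combination
W_merged = torch.randn(out_dim, in_dim) + delta
print(W_merged.shape)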
+
 #=========================#
 # AITemplate mapping #
 #=========================#