Fix for long prompts

This commit is contained in:
hlky 2023-05-16 20:55:19 +01:00
parent 6769b918fb
commit d879a33956

View File

@ -511,6 +511,14 @@ class AITemplateModelWrapper:
uncond = kwargs.pop("uncond")
uncond = uncond[0][0]
if uncond is not None and cond is not None:
if cond.shape[1] > uncond.shape[1]:
to_add = cond.shape[1] - uncond.shape[1]
padding = torch.zeros((uncond.shape[0], to_add, uncond.shape[2]), device=uncond.device)
uncond = torch.cat((uncond, padding), 1)
elif uncond.shape[1] > cond.shape[1]:
to_add = uncond.shape[1] - cond.shape[1]
padding = torch.zeros((cond.shape[0], to_add, cond.shape[2]), device=cond.device)
cond = torch.cat((cond, padding), 1)
encoder_hidden_states = torch.cat((uncond, cond))
elif cond is not None and uncond is None:
encoder_hidden_states = torch.cat((cond, cond))