Fix for negative prompt in AITemplateModelWrapper

This commit is contained in:
hlky 2023-05-16 18:44:51 +01:00
parent f380b89a71
commit 6769b918fb

View File

@ -501,13 +501,24 @@ class AITemplateModelWrapper:
def apply_model(self, *args, **kwargs):
if len(args) == 3:
encoder_hidden_states = args[-1]
# unsure if this path is ever used
cond = args[-1]
args = args[:2]
if kwargs.get("cond", None) is not None:
encoder_hidden_states = kwargs.pop("cond")
cond = kwargs.pop("cond")
cond = cond[0][0]
if kwargs.get("uncond", None) is not None:
uncond = kwargs.pop("uncond")
uncond = uncond[0][0]
if uncond is not None and cond is not None:
encoder_hidden_states = torch.cat((uncond, cond))
elif cond is not None and uncond is None:
encoder_hidden_states = torch.cat((cond, cond))
elif uncond is not None and cond is None:
encoder_hidden_states = torch.cat((uncond, uncond))
else:
raise Exception("Must provide uncond or cond")
latent_model_input, timesteps = args
encoder_hidden_states = encoder_hidden_states[0][0]
encoder_hidden_states = torch.cat([encoder_hidden_states] * 2)
timesteps_pt = torch.cat([timesteps] * 2)
latent_model_input = torch.cat([latent_model_input] * 2)
height = latent_model_input.shape[2]