From 6769b918fb53319a8eb1dff3d01510c44444033b Mon Sep 17 00:00:00 2001 From: hlky <106811348+hlky@users.noreply.github.com> Date: Tue, 16 May 2023 18:44:51 +0100 Subject: [PATCH] Fix for negative prompt in AITemplateModelWrapper --- comfy/samplers.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/comfy/samplers.py b/comfy/samplers.py index 9b1ac7cf0..d57d1030b 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -501,13 +501,24 @@ class AITemplateModelWrapper: def apply_model(self, *args, **kwargs): if len(args) == 3: - encoder_hidden_states = args[-1] + # unsure if this path is ever used + cond = args[-1] args = args[:2] if kwargs.get("cond", None) is not None: - encoder_hidden_states = kwargs.pop("cond") + cond = kwargs.pop("cond") + cond = cond[0][0] + if kwargs.get("uncond", None) is not None: + uncond = kwargs.pop("uncond") + uncond = uncond[0][0] + if uncond is not None and cond is not None: + encoder_hidden_states = torch.cat((uncond, cond)) + elif cond is not None and uncond is None: + encoder_hidden_states = torch.cat((cond, cond)) + elif uncond is not None and cond is None: + encoder_hidden_states = torch.cat((uncond, uncond)) + else: + raise Exception("Must provide uncond or cond") latent_model_input, timesteps = args - encoder_hidden_states = encoder_hidden_states[0][0] - encoder_hidden_states = torch.cat([encoder_hidden_states] * 2) timesteps_pt = torch.cat([timesteps] * 2) latent_model_input = torch.cat([latent_model_input] * 2) height = latent_model_input.shape[2]