diff --git a/comfy/sample.py b/comfy/sample.py
index fb85e3d15..7cbc1744b 100644
--- a/comfy/sample.py
+++ b/comfy/sample.py
@@ -84,7 +84,7 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative
 
     models = load_additional_models(positive, negative)
 
-    sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options, aitemplate=aitemplate, cfg=cfg)
+    sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options, aitemplate=aitemplate)
 
     samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar)
     samples = samples.cpu()
diff --git a/comfy/samplers.py b/comfy/samplers.py
index 14589456b..04193f9bc 100644
--- a/comfy/samplers.py
+++ b/comfy/samplers.py
@@ -493,42 +493,27 @@ def encode_adm(noise_augmentor, conds, batch_size, device):
 
     return conds
 
-class AITemplateModelWrapper:
-    def __init__(self, unet_ait_exe, alphas_cumprod, guidance_scale):
+class AITemplateModelWrapper(torch.nn.Module):
+    def __init__(self, unet_ait_exe, alphas_cumprod, conditioning_key=None):
+        super().__init__()
         self.unet_ait_exe = unet_ait_exe
         self.alphas_cumprod = alphas_cumprod
-        self.guidance_scale = guidance_scale
+        #TODO: use the conditioning key
+        self.conditioning_key = conditioning_key
 
-    def apply_model(self, *args, **kwargs):
-        if len(args) == 3:
-            # unsure if this path is ever used
-            cond = args[-1]
-            args = args[:2]
-        if kwargs.get("cond", None) is not None:
-            cond = kwargs.pop("cond")
-            cond = cond[0][0]
-        if kwargs.get("uncond", None) is not None:
-            uncond = kwargs.pop("uncond")
-            uncond = uncond[0][0]
-        if uncond is not None and cond is not None:
-            if cond.shape[1] > uncond.shape[1]:
-                to_add = cond.shape[1] - uncond.shape[1]
-                padding = torch.zeros((uncond.shape[0], to_add, uncond.shape[2]), device=uncond.device)
-                uncond = torch.cat((uncond, padding), 1)
-            elif uncond.shape[1] > cond.shape[1]:
-                to_add = uncond.shape[1] - cond.shape[1]
-                padding = torch.zeros((cond.shape[0], to_add, cond.shape[2]), device=cond.device)
-                cond = torch.cat((cond, padding), 1)
-            encoder_hidden_states = torch.cat((uncond, cond))
-        elif cond is not None and uncond is None:
-            encoder_hidden_states = torch.cat((cond, cond))
-        elif uncond is not None and cond is None:
-            encoder_hidden_states = torch.cat((uncond, uncond))
-        else:
-            raise Exception("Must provide uncond or cond")
-        latent_model_input, timesteps = args
-        timesteps_pt = torch.cat([timesteps] * 2)
-        latent_model_input = torch.cat([latent_model_input] * 2)
+    def apply_model(self, x, t, cond):
+        timesteps_pt = t
+        latent_model_input = x
+        encoder_hidden_states = None
+        #TODO: verify this is correct / matches DiffusionWrapper (ddpm.py)
+        if 'c_crossattn' in cond:
+            encoder_hidden_states = cond['c_crossattn']
+        if 'c_concat' in cond:
+            encoder_hidden_states = cond['c_concat']
+        if encoder_hidden_states is None:
+            raise Exception(f"conditioning missing, it should be one of these: {cond.keys()}")
+        if type(encoder_hidden_states) is list:
+            encoder_hidden_states = encoder_hidden_states[0]
 
         height = latent_model_input.shape[2]
         width = latent_model_input.shape[3]
@@ -550,8 +535,6 @@ class AITemplateModelWrapper:
             ys.append(torch.empty(shape).cuda().half())
         self.unet_ait_exe.run_with_tensors(inputs, ys, graph_mode=False)
         noise_pred = ys[0].permute((0, 3, 1, 2)).float()
-        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-        noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
         return noise_pred
 
 class KSampler:
@@ -560,11 +543,12 @@ class KSampler:
                 "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_2m", "dpmpp_2m_sde", "ddim", "uni_pc", "uni_pc_bh2"]
 
-    def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None, model_options={}, aitemplate=None, cfg=None):
+    def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None, model_options={}, aitemplate=None):
        self.model = model
        if aitemplate:
            alphas_cumprod = model.alphas_cumprod.to(device)
-            self.model_denoise = AITemplateModelWrapper(aitemplate, alphas_cumprod, cfg)
+            self.aitemplate_wrapper = AITemplateModelWrapper(aitemplate, alphas_cumprod)
+            self.model_denoise = CFGNoisePredictor(self.aitemplate_wrapper)
        else:
            self.model_denoise = CFGNoisePredictor(self.model)
        if self.model.parameterization == "v":