AITemplate uses main sampling function

This commit is contained in:
hlky 2023-05-31 20:52:20 +01:00
parent 7e4da3c48a
commit fbc74fbb25
2 changed files with 22 additions and 38 deletions

View File

@ -84,7 +84,7 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative
models = load_additional_models(positive, negative) models = load_additional_models(positive, negative)
sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options, aitemplate=aitemplate, cfg=cfg) sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options, aitemplate=aitemplate)
samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar) samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar)
samples = samples.cpu() samples = samples.cpu()

View File

@ -493,42 +493,27 @@ def encode_adm(noise_augmentor, conds, batch_size, device):
return conds return conds
class AITemplateModelWrapper: class AITemplateModelWrapper(torch.nn.Module):
def __init__(self, unet_ait_exe, alphas_cumprod, guidance_scale): def __init__(self, unet_ait_exe, alphas_cumprod, conditioning_key=None):
super().__init__()
self.unet_ait_exe = unet_ait_exe self.unet_ait_exe = unet_ait_exe
self.alphas_cumprod = alphas_cumprod self.alphas_cumprod = alphas_cumprod
self.guidance_scale = guidance_scale #TODO: use the conditioning key
self.conditioning_key = conditioning_key
def apply_model(self, *args, **kwargs): def apply_model(self, x, t, cond):
if len(args) == 3: timesteps_pt = t
# unsure if this path is ever used latent_model_input = x
cond = args[-1] encoder_hidden_states = None
args = args[:2] #TODO: verify this is correct/match DiffusionWrapper (ddpm.py)
if kwargs.get("cond", None) is not None: if 'c_crossattn' in cond:
cond = kwargs.pop("cond") encoder_hidden_states = cond['c_crossattn']
cond = cond[0][0] if 'c_concat' in cond:
if kwargs.get("uncond", None) is not None: encoder_hidden_states = cond['c_concat']
uncond = kwargs.pop("uncond") if encoder_hidden_states is None:
uncond = uncond[0][0] raise f"conditioning missing, it should be one of these {cond.keys()}"
if uncond is not None and cond is not None: if type(encoder_hidden_states) is list:
if cond.shape[1] > uncond.shape[1]: encoder_hidden_states = encoder_hidden_states[0]
to_add = cond.shape[1] - uncond.shape[1]
padding = torch.zeros((uncond.shape[0], to_add, uncond.shape[2]), device=uncond.device)
uncond = torch.cat((uncond, padding), 1)
elif uncond.shape[1] > cond.shape[1]:
to_add = uncond.shape[1] - cond.shape[1]
padding = torch.zeros((cond.shape[0], to_add, cond.shape[2]), device=cond.device)
cond = torch.cat((cond, padding), 1)
encoder_hidden_states = torch.cat((uncond, cond))
elif cond is not None and uncond is None:
encoder_hidden_states = torch.cat((cond, cond))
elif uncond is not None and cond is None:
encoder_hidden_states = torch.cat((uncond, uncond))
else:
raise Exception("Must provide uncond or cond")
latent_model_input, timesteps = args
timesteps_pt = torch.cat([timesteps] * 2)
latent_model_input = torch.cat([latent_model_input] * 2)
height = latent_model_input.shape[2] height = latent_model_input.shape[2]
width = latent_model_input.shape[3] width = latent_model_input.shape[3]
@ -550,8 +535,6 @@ class AITemplateModelWrapper:
ys.append(torch.empty(shape).cuda().half()) ys.append(torch.empty(shape).cuda().half())
self.unet_ait_exe.run_with_tensors(inputs, ys, graph_mode=False) self.unet_ait_exe.run_with_tensors(inputs, ys, graph_mode=False)
noise_pred = ys[0].permute((0, 3, 1, 2)).float() noise_pred = ys[0].permute((0, 3, 1, 2)).float()
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
return noise_pred return noise_pred
class KSampler: class KSampler:
@ -560,11 +543,12 @@ class KSampler:
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde",
"dpmpp_2m", "dpmpp_2m_sde", "ddim", "uni_pc", "uni_pc_bh2"] "dpmpp_2m", "dpmpp_2m_sde", "ddim", "uni_pc", "uni_pc_bh2"]
def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None, model_options={}, aitemplate=None, cfg=None): def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None, model_options={}, aitemplate=None):
self.model = model self.model = model
if aitemplate: if aitemplate:
alphas_cumprod = model.alphas_cumprod.to(device) alphas_cumprod = model.alphas_cumprod.to(device)
self.model_denoise = AITemplateModelWrapper(aitemplate, alphas_cumprod, cfg) self.aitemplate_wrapper = AITemplateModelWrapper(aitemplate, alphas_cumprod)
self.model_denoise = CFGNoisePredictor(self.aitemplate_wrapper)
else: else:
self.model_denoise = CFGNoisePredictor(self.model) self.model_denoise = CFGNoisePredictor(self.model)
if self.model.parameterization == "v": if self.model.parameterization == "v":