Mirror of https://github.com/comfyanonymous/ComfyUI.git, synced 2026-03-13 21:27:41 +08:00
AITemplate uses main sampling function

commit fbc74fbb25 (parent 7e4da3c48a)
@@ -84,7 +84,7 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative
     models = load_additional_models(positive, negative)

-    sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options, aitemplate=aitemplate, cfg=cfg)
+    sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options, aitemplate=aitemplate)

     samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar)
     samples = samples.cpu()
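Note: after this hunk, the guidance scale no longer configures the sampler object; it only reaches sampling through sample(). A minimal usage sketch under that assumption (argument values here are hypothetical, and the trailing sample() keyword arguments from the hunk above are omitted for brevity):

# Sketch only: cfg is passed to sample(), not to the constructor.
sampler = comfy.samplers.KSampler(real_model, steps=20, device=device,
                                  sampler="dpmpp_2m", scheduler="normal",
                                  denoise=1.0, model_options=model.model_options,
                                  aitemplate=aitemplate)
samples = sampler.sample(noise, positive_copy, negative_copy, cfg=8.0,
                         latent_image=latent_image)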
@@ -493,42 +493,27 @@ def encode_adm(noise_augmentor, conds, batch_size, device):
     return conds

-class AITemplateModelWrapper:
-    def __init__(self, unet_ait_exe, alphas_cumprod, guidance_scale):
+class AITemplateModelWrapper(torch.nn.Module):
+    def __init__(self, unet_ait_exe, alphas_cumprod, conditioning_key=None):
+        super().__init__()
         self.unet_ait_exe = unet_ait_exe
         self.alphas_cumprod = alphas_cumprod
-        self.guidance_scale = guidance_scale
+        #TODO: use the conditioning key
+        self.conditioning_key = conditioning_key

-    def apply_model(self, *args, **kwargs):
-        if len(args) == 3:
-            # unsure if this path is ever used
-            cond = args[-1]
-            args = args[:2]
-        if kwargs.get("cond", None) is not None:
-            cond = kwargs.pop("cond")
-            cond = cond[0][0]
-        if kwargs.get("uncond", None) is not None:
-            uncond = kwargs.pop("uncond")
-            uncond = uncond[0][0]
-        if uncond is not None and cond is not None:
-            if cond.shape[1] > uncond.shape[1]:
-                to_add = cond.shape[1] - uncond.shape[1]
-                padding = torch.zeros((uncond.shape[0], to_add, uncond.shape[2]), device=uncond.device)
-                uncond = torch.cat((uncond, padding), 1)
-            elif uncond.shape[1] > cond.shape[1]:
-                to_add = uncond.shape[1] - cond.shape[1]
-                padding = torch.zeros((cond.shape[0], to_add, cond.shape[2]), device=cond.device)
-                cond = torch.cat((cond, padding), 1)
-            encoder_hidden_states = torch.cat((uncond, cond))
-        elif cond is not None and uncond is None:
-            encoder_hidden_states = torch.cat((cond, cond))
-        elif uncond is not None and cond is None:
-            encoder_hidden_states = torch.cat((uncond, uncond))
-        else:
-            raise Exception("Must provide uncond or cond")
-        latent_model_input, timesteps = args
-        timesteps_pt = torch.cat([timesteps] * 2)
-        latent_model_input = torch.cat([latent_model_input] * 2)
+    def apply_model(self, x, t, cond):
+        timesteps_pt = t
+        latent_model_input = x
+        encoder_hidden_states = None
+        #TODO: verify this is correct/match DiffusionWrapper (ddpm.py)
+        if 'c_crossattn' in cond:
+            encoder_hidden_states = cond['c_crossattn']
+        if 'c_concat' in cond:
+            encoder_hidden_states = cond['c_concat']
+        if encoder_hidden_states is None:
+            raise Exception(f"conditioning missing, it should be one of these: {cond.keys()}")
+        if type(encoder_hidden_states) is list:
+            encoder_hidden_states = encoder_hidden_states[0]
         height = latent_model_input.shape[2]
         width = latent_model_input.shape[3]
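Note: apply_model now takes an already-batched latent, matching timesteps, and a conditioning dict, and returns a raw noise prediction with no guidance applied. A minimal sketch of driving it (the tensor shapes are hypothetical SD1-style values, and `wrapper` is assumed to be an AITemplateModelWrapper built as in the hunk above):

import torch

# Sketch only: the caller has already concatenated uncond and cond along the batch.
x = torch.randn(2, 4, 64, 64)                      # batched latent input
t = torch.full((2,), 999)                          # one timestep per batch element
cond = {'c_crossattn': [torch.randn(2, 77, 768)]}  # cross-attention context
eps = wrapper.apply_model(x, t, cond)              # raw prediction; CFG happens upstream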
@@ -550,8 +535,6 @@ class AITemplateModelWrapper:
         ys.append(torch.empty(shape).cuda().half())
         self.unet_ait_exe.run_with_tensors(inputs, ys, graph_mode=False)
         noise_pred = ys[0].permute((0, 3, 1, 2)).float()
-        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-        noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
         return noise_pred

 class KSampler:
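Note: the two removed lines are the classifier-free guidance blend. After this change the wrapper returns the raw prediction and the shared sampling path performs the same combination. Conceptually, for a doubled (uncond, cond) batch and a guidance scale cfg:

# Sketch of the blend that now lives in the main sampling code:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + cfg * (noise_pred_text - noise_pred_uncond)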
@@ -560,11 +543,12 @@ class KSampler:
                 "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde",
                 "dpmpp_2m", "dpmpp_2m_sde", "ddim", "uni_pc", "uni_pc_bh2"]

-    def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None, model_options={}, aitemplate=None, cfg=None):
+    def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None, model_options={}, aitemplate=None):
         self.model = model
         if aitemplate:
             alphas_cumprod = model.alphas_cumprod.to(device)
-            self.model_denoise = AITemplateModelWrapper(aitemplate, alphas_cumprod, cfg)
+            self.aitemplate_wrapper = AITemplateModelWrapper(aitemplate, alphas_cumprod)
+            self.model_denoise = CFGNoisePredictor(self.aitemplate_wrapper)
         else:
             self.model_denoise = CFGNoisePredictor(self.model)
         if self.model.parameterization == "v":
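Note: CFGNoisePredictor is the existing wrapper ComfyUI already used in the else branch; routing the AITemplate executable through it as well is what lets this backend reuse the main sampling function. A simplified, hypothetical sketch of what such a CFG wrapper does around any model exposing apply_model (illustrative only, not ComfyUI's actual implementation):

import torch

class SimpleCFGDenoiser(torch.nn.Module):
    # Illustrative only: runs the inner model on cond and uncond inputs,
    # then blends the two raw predictions with the guidance scale.
    def __init__(self, inner_model):
        super().__init__()
        self.inner_model = inner_model

    def forward(self, x, t, cond, uncond, cond_scale):
        eps_cond = self.inner_model.apply_model(x, t, cond)
        eps_uncond = self.inner_model.apply_model(x, t, uncond)
        return eps_uncond + (eps_cond - eps_uncond) * cond_scale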