mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-12 20:57:47 +08:00
AITemplate uses main sampling function
This commit is contained in:
parent
7e4da3c48a
commit
fbc74fbb25
@ -84,7 +84,7 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative
|
||||
|
||||
models = load_additional_models(positive, negative)
|
||||
|
||||
sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options, aitemplate=aitemplate, cfg=cfg)
|
||||
sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options, aitemplate=aitemplate)
|
||||
|
||||
samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar)
|
||||
samples = samples.cpu()
|
||||
|
||||
@ -493,42 +493,27 @@ def encode_adm(noise_augmentor, conds, batch_size, device):
|
||||
|
||||
return conds
|
||||
|
||||
class AITemplateModelWrapper:
|
||||
def __init__(self, unet_ait_exe, alphas_cumprod, guidance_scale):
|
||||
class AITemplateModelWrapper(torch.nn.Module):
|
||||
def __init__(self, unet_ait_exe, alphas_cumprod, conditioning_key=None):
|
||||
super().__init__()
|
||||
self.unet_ait_exe = unet_ait_exe
|
||||
self.alphas_cumprod = alphas_cumprod
|
||||
self.guidance_scale = guidance_scale
|
||||
#TODO: use the conditioning key
|
||||
self.conditioning_key = conditioning_key
|
||||
|
||||
def apply_model(self, *args, **kwargs):
|
||||
if len(args) == 3:
|
||||
# unsure if this path is ever used
|
||||
cond = args[-1]
|
||||
args = args[:2]
|
||||
if kwargs.get("cond", None) is not None:
|
||||
cond = kwargs.pop("cond")
|
||||
cond = cond[0][0]
|
||||
if kwargs.get("uncond", None) is not None:
|
||||
uncond = kwargs.pop("uncond")
|
||||
uncond = uncond[0][0]
|
||||
if uncond is not None and cond is not None:
|
||||
if cond.shape[1] > uncond.shape[1]:
|
||||
to_add = cond.shape[1] - uncond.shape[1]
|
||||
padding = torch.zeros((uncond.shape[0], to_add, uncond.shape[2]), device=uncond.device)
|
||||
uncond = torch.cat((uncond, padding), 1)
|
||||
elif uncond.shape[1] > cond.shape[1]:
|
||||
to_add = uncond.shape[1] - cond.shape[1]
|
||||
padding = torch.zeros((cond.shape[0], to_add, cond.shape[2]), device=cond.device)
|
||||
cond = torch.cat((cond, padding), 1)
|
||||
encoder_hidden_states = torch.cat((uncond, cond))
|
||||
elif cond is not None and uncond is None:
|
||||
encoder_hidden_states = torch.cat((cond, cond))
|
||||
elif uncond is not None and cond is None:
|
||||
encoder_hidden_states = torch.cat((uncond, uncond))
|
||||
else:
|
||||
raise Exception("Must provide uncond or cond")
|
||||
latent_model_input, timesteps = args
|
||||
timesteps_pt = torch.cat([timesteps] * 2)
|
||||
latent_model_input = torch.cat([latent_model_input] * 2)
|
||||
def apply_model(self, x, t, cond):
|
||||
timesteps_pt = t
|
||||
latent_model_input = x
|
||||
encoder_hidden_states = None
|
||||
#TODO: verify this is correct/match DiffusionWrapper (ddpm.py)
|
||||
if 'c_crossattn' in cond:
|
||||
encoder_hidden_states = cond['c_crossattn']
|
||||
if 'c_concat' in cond:
|
||||
encoder_hidden_states = cond['c_concat']
|
||||
if encoder_hidden_states is None:
|
||||
raise f"conditioning missing, it should be one of these {cond.keys()}"
|
||||
if type(encoder_hidden_states) is list:
|
||||
encoder_hidden_states = encoder_hidden_states[0]
|
||||
height = latent_model_input.shape[2]
|
||||
width = latent_model_input.shape[3]
|
||||
|
||||
@ -550,8 +535,6 @@ class AITemplateModelWrapper:
|
||||
ys.append(torch.empty(shape).cuda().half())
|
||||
self.unet_ait_exe.run_with_tensors(inputs, ys, graph_mode=False)
|
||||
noise_pred = ys[0].permute((0, 3, 1, 2)).float()
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
return noise_pred
|
||||
|
||||
class KSampler:
|
||||
@ -560,11 +543,12 @@ class KSampler:
|
||||
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde",
|
||||
"dpmpp_2m", "dpmpp_2m_sde", "ddim", "uni_pc", "uni_pc_bh2"]
|
||||
|
||||
def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None, model_options={}, aitemplate=None, cfg=None):
|
||||
def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None, model_options={}, aitemplate=None):
|
||||
self.model = model
|
||||
if aitemplate:
|
||||
alphas_cumprod = model.alphas_cumprod.to(device)
|
||||
self.model_denoise = AITemplateModelWrapper(aitemplate, alphas_cumprod, cfg)
|
||||
self.aitemplate_wrapper = AITemplateModelWrapper(aitemplate, alphas_cumprod)
|
||||
self.model_denoise = CFGNoisePredictor(self.aitemplate_wrapper)
|
||||
else:
|
||||
self.model_denoise = CFGNoisePredictor(self.model)
|
||||
if self.model.parameterization == "v":
|
||||
|
||||
Loading…
Reference in New Issue
Block a user