Pass AITemplate, MODEL to AITemplate KSampler, -diffusers req

* Pass MODEL to AITemplate KSampler, but don't move to device
* Take alphas_cumprod from MODEL, -diffusers req, move alphas_cumprod only to device
* -checks for AITemplateModelWrapper, inpaint etc maybe still won't work, untested
  * v2 should work, but will require module compiled for v2
This commit is contained in:
hlky 2023-05-16 07:15:12 +01:00
parent 22ad299546
commit c8195ab3e5
4 changed files with 13 additions and 27 deletions

View File

@@ -74,7 +74,7 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative
real_model = None real_model = None
if aitemplate is None: if aitemplate is None:
comfy.model_management.load_model_gpu(model) comfy.model_management.load_model_gpu(model)
real_model = model.model real_model = model.model
noise = noise.to(device) noise = noise.to(device)
latent_image = latent_image.to(device) latent_image = latent_image.to(device)
@@ -84,7 +84,7 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative
models = load_additional_models(positive, negative) models = load_additional_models(positive, negative)
sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options if aitemplate is None else None, aitemplate=aitemplate, cfg=cfg) sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options, aitemplate=aitemplate, cfg=cfg)
samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar) samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar)
samples = samples.cpu() samples = samples.cpu()

View File

@@ -3,7 +3,6 @@ from .k_diffusion import external as k_diffusion_external
from .extra_samplers import uni_pc from .extra_samplers import uni_pc
import torch import torch
import contextlib import contextlib
from diffusers import LMSDiscreteScheduler
from comfy import model_management from comfy import model_management
from .ldm.models.diffusion.ddim import DDIMSampler from .ldm.models.diffusion.ddim import DDIMSampler
from .ldm.modules.diffusionmodules.util import make_ddim_timesteps from .ldm.modules.diffusionmodules.util import make_ddim_timesteps
@@ -546,25 +545,15 @@ class KSampler:
def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None, model_options={}, aitemplate=None, cfg=None): def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None, model_options={}, aitemplate=None, cfg=None):
self.model = model self.model = model
if aitemplate: if aitemplate:
scheduler = LMSDiscreteScheduler.from_config({ alphas_cumprod = model.alphas_cumprod.to(device)
"beta_end": 0.012, self.model_denoise = AITemplateModelWrapper(aitemplate, alphas_cumprod, cfg)
"beta_schedule": "scaled_linear",
"beta_start": 0.00085,
"num_train_timesteps": 1000,
"set_alpha_to_one": False,
"skip_prk_steps": True,
"steps_offset": 1,
"trained_betas": None,
"clip_sample": False
})
self.model_denoise = AITemplateModelWrapper(aitemplate, scheduler.alphas_cumprod, cfg)
else: else:
self.model_denoise = CFGNoisePredictor(self.model) self.model_denoise = CFGNoisePredictor(self.model)
if not isinstance(self.model_denoise, AITemplateModelWrapper) and self.model.parameterization == "v": if self.model.parameterization == "v":
self.model_wrap = CompVisVDenoiser(self.model_denoise, quantize=True) self.model_wrap = CompVisVDenoiser(self.model_denoise, quantize=True)
self.model_wrap.parameterization = self.model.parameterization
else: else:
self.model_wrap = k_diffusion_external.CompVisDenoiser(self.model_denoise, quantize=True) self.model_wrap = k_diffusion_external.CompVisDenoiser(self.model_denoise, quantize=True)
self.model_wrap.parameterization = self.model.parameterization
self.model_k = KSamplerX0Inpaint(self.model_wrap) self.model_k = KSamplerX0Inpaint(self.model_wrap)
self.device = device self.device = device
if scheduler not in self.SCHEDULERS: if scheduler not in self.SCHEDULERS:
@@ -646,21 +635,19 @@ class KSampler:
apply_empty_x_to_equal_area(positive, negative, 'control', lambda cond_cnets, x: cond_cnets[x]) apply_empty_x_to_equal_area(positive, negative, 'control', lambda cond_cnets, x: cond_cnets[x])
apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x]) apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x])
if isinstance(self.model_denoise, AITemplateModelWrapper): if self.model.model.diffusion_model.dtype == torch.float16:
precision_scope = torch.autocast
elif self.model.model.diffusion_model.dtype == torch.float16:
precision_scope = torch.autocast precision_scope = torch.autocast
else: else:
precision_scope = contextlib.nullcontext precision_scope = contextlib.nullcontext
if not isinstance(self.model_denoise, AITemplateModelWrapper) and hasattr(self.model, 'noise_augmentor'): #unclip if hasattr(self.model, 'noise_augmentor'): #unclip
positive = encode_adm(self.model.noise_augmentor, positive, noise.shape[0], self.device) positive = encode_adm(self.model.noise_augmentor, positive, noise.shape[0], self.device)
negative = encode_adm(self.model.noise_augmentor, negative, noise.shape[0], self.device) negative = encode_adm(self.model.noise_augmentor, negative, noise.shape[0], self.device)
extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": self.model_options} extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": self.model_options}
cond_concat = None cond_concat = None
if not isinstance(self.model_denoise, AITemplateModelWrapper) and hasattr(self.model, 'concat_keys'): #inpaint if hasattr(self.model, 'concat_keys'): #inpaint
cond_concat = [] cond_concat = []
for ck in self.model.concat_keys: for ck in self.model.concat_keys:
if denoise_mask is not None: if denoise_mask is not None:

View File

@@ -350,13 +350,12 @@ class AITemplateLoader:
def load_aitemplate(self, model, aitemplate_module): def load_aitemplate(self, model, aitemplate_module):
aitemplate_path = folder_paths.get_full_path("aitemplate", aitemplate_module) aitemplate_path = folder_paths.get_full_path("aitemplate", aitemplate_module)
aitemplate = Model(aitemplate_path) aitemplate = Model(aitemplate_path)
model = self.convert_ldm_unet_checkpoint(model.model.state_dict()) unet_params_ait = self.map_unet_state_dict(self.convert_ldm_unet_checkpoint(model.model.state_dict()))
unet_params_ait = self.map_unet_state_dict(model)
print("Setting constants") print("Setting constants")
aitemplate.set_many_constants_with_tensors(unet_params_ait) aitemplate.set_many_constants_with_tensors(unet_params_ait)
print("Folding constants") print("Folding constants")
aitemplate.fold_constants() aitemplate.fold_constants()
return (aitemplate,) return ((aitemplate,model),)
#=================# #=================#
# UNet Conversion # # UNet Conversion #
#=================# #=================#
@@ -1201,7 +1200,8 @@ class KSamplerAITemplate:
CATEGORY = "sampling" CATEGORY = "sampling"
def sample(self, model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0): def sample(self, model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0):
return common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise, aitemplate=model) aitemplate, model = model
return common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise, aitemplate=aitemplate)
class KSamplerAdvanced: class KSamplerAdvanced:

View File

@ -9,4 +9,3 @@ pytorch_lightning
aiohttp aiohttp
accelerate accelerate
pyyaml pyyaml
diffusers