aitemplate

This commit is contained in:
hlky 2023-05-15 19:11:30 +01:00
parent 2ec6d1c6e3
commit b32c2eaafd
5 changed files with 424 additions and 16 deletions

View File

@ -65,15 +65,16 @@ def cleanup_additional_models(models):
for m in models: for m in models:
m.cleanup() m.cleanup()
def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, noise_mask=None, sigmas=None, callback=None, disable_pbar=False): def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, noise_mask=None, sigmas=None, callback=None, disable_pbar=False, aitemplate=None):
device = comfy.model_management.get_torch_device() device = comfy.model_management.get_torch_device()
if noise_mask is not None: if noise_mask is not None:
noise_mask = prepare_mask(noise_mask, noise.shape, device) noise_mask = prepare_mask(noise_mask, noise.shape, device)
real_model = None real_model = None
comfy.model_management.load_model_gpu(model) if aitemplate is None:
real_model = model.model comfy.model_management.load_model_gpu(model)
real_model = model.model
noise = noise.to(device) noise = noise.to(device)
latent_image = latent_image.to(device) latent_image = latent_image.to(device)
@ -83,7 +84,7 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative
models = load_additional_models(positive, negative) models = load_additional_models(positive, negative)
sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options) sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options if aitemplate is None else None, aitemplate=aitemplate, cfg=cfg)
samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar) samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar)
samples = samples.cpu() samples = samples.cpu()

View File

@ -1,13 +1,15 @@
from .k_diffusion import sampling as k_diffusion_sampling from .k_diffusion import sampling as k_diffusion_sampling
from .k_diffusion import external as k_diffusion_external from .k_diffusion import external as k_diffusion_external
from .extra_samplers import uni_pc from .extra_samplers import uni_pc
import os
import torch import torch
import contextlib import contextlib
from diffusers import LMSDiscreteScheduler
from comfy import model_management from comfy import model_management
from .ldm.models.diffusion.ddim import DDIMSampler from .ldm.models.diffusion.ddim import DDIMSampler
from .ldm.modules.diffusionmodules.util import make_ddim_timesteps from .ldm.modules.diffusionmodules.util import make_ddim_timesteps
import math import math
from aitemplate.compiler import Model
def lcm(a, b): #TODO: eventually replace by math.lcm (added in python3.9) def lcm(a, b): #TODO: eventually replace by math.lcm (added in python3.9)
return abs(a*b) // math.gcd(a, b) return abs(a*b) // math.gcd(a, b)
@ -493,6 +495,61 @@ def encode_adm(noise_augmentor, conds, batch_size, device):
return conds return conds
class AITemplateModelWrapper:
    """Adapter exposing a compiled AITemplate UNet through the ``apply_model``
    interface expected by the k-diffusion denoiser wrappers.

    Classifier-free guidance happens inside this wrapper: inputs are
    duplicated to a batch of 2 (uncond/cond halves) and the two predictions
    are combined with ``guidance_scale``.
    """

    def __init__(self, unet_ait_exe, alphas_cumprod, guidance_scale):
        # unet_ait_exe: aitemplate.compiler.Model executable for the UNet.
        # alphas_cumprod: scheduler alphas; stored but not read in this class.
        # guidance_scale: CFG scale used to mix uncond/cond noise predictions.
        self.unet_ait_exe = unet_ait_exe
        self.alphas_cumprod = alphas_cumprod
        self.guidance_scale = guidance_scale

    def apply_model(self, *args, **kwargs):
        """Run one denoising step through the AITemplate UNet.

        Accepts either positional ``(latent, timesteps, cond)`` or
        ``(latent, timesteps)`` with a ``cond`` keyword (ComfyUI conditioning
        list). Returns the CFG-combined noise prediction as float32 NCHW.
        """
        if len(args) == 3:
            encoder_hidden_states = args[-1]
            args = args[:2]
        if kwargs.get("cond", None) is not None:
            # ComfyUI conditioning is a nested list; [0][0] is the embedding
            # tensor. Duplicate it so the batch covers uncond + cond.
            # NOTE(review): assumes a single conditioning entry — confirm.
            encoder_hidden_states = kwargs.pop("cond")
            encoder_hidden_states = encoder_hidden_states[0][0]
            encoder_hidden_states = torch.cat([encoder_hidden_states] * 2)
        latent_model_input, timesteps = args
        timesteps_pt = timesteps.expand(2)
        if latent_model_input.shape[0] < 2:
            latent_model_input = torch.cat([latent_model_input] * 2)
        height = latent_model_input.shape[2]
        width = latent_model_input.shape[3]
        # The compiled module expects NHWC fp16 CUDA tensors.
        inputs = {
            "input0": latent_model_input.permute((0, 2, 3, 1))
            .contiguous()
            .cuda()
            .half(),
            "input1": timesteps_pt.cuda().half(),
            "input2": encoder_hidden_states.cuda().half(),
        }
        ys = []
        num_outputs = len(self.unet_ait_exe.get_output_name_to_index_map())
        for i in range(num_outputs):
            # Pre-allocate output buffers, shrinking the maximum shape to the
            # actual batch (2) and spatial size of this call.
            shape = self.unet_ait_exe.get_output_maximum_shape(i)
            shape[0] = 2
            shape[1] = height
            shape[2] = width
            ys.append(torch.empty(shape).cuda().half())
        self.unet_ait_exe.run_with_tensors(inputs, ys, graph_mode=False)
        # Back to NCHW float32, then apply classifier-free guidance.
        noise_pred = ys[0].permute((0, 3, 1, 2)).float()
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
        return noise_pred
def init_ait_module(model_name, workdir):
    """Load the compiled AITemplate module for *model_name*.

    The convention is that each model lives in its own subdirectory of
    *workdir* and the compiled shared object is named ``test.so``.
    """
    module_path = os.path.join(workdir, model_name, "test.so")
    return Model(module_path)
class KSampler: class KSampler:
SCHEDULERS = ["normal", "karras", "simple", "ddim_uniform"] SCHEDULERS = ["normal", "karras", "simple", "ddim_uniform"]
@ -500,14 +557,28 @@ class KSampler:
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde",
"dpmpp_2m", "ddim", "uni_pc", "uni_pc_bh2"] "dpmpp_2m", "ddim", "uni_pc", "uni_pc_bh2"]
def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None, model_options={}): def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None, model_options={}, aitemplate=None, cfg=None):
self.model = model self.model = model
self.model_denoise = CFGNoisePredictor(self.model) if aitemplate:
if self.model.parameterization == "v": scheduler = LMSDiscreteScheduler.from_config({
"beta_end": 0.012,
"beta_schedule": "scaled_linear",
"beta_start": 0.00085,
"num_train_timesteps": 1000,
"set_alpha_to_one": False,
"skip_prk_steps": True,
"steps_offset": 1,
"trained_betas": None,
"clip_sample": False
})
self.model_denoise = AITemplateModelWrapper(aitemplate, scheduler.alphas_cumprod, cfg)
else:
self.model_denoise = CFGNoisePredictor(self.model)
if not isinstance(self.model_denoise, AITemplateModelWrapper) and self.model.parameterization == "v":
self.model_wrap = CompVisVDenoiser(self.model_denoise, quantize=True) self.model_wrap = CompVisVDenoiser(self.model_denoise, quantize=True)
self.model_wrap.parameterization = self.model.parameterization
else: else:
self.model_wrap = k_diffusion_external.CompVisDenoiser(self.model_denoise, quantize=True) self.model_wrap = k_diffusion_external.CompVisDenoiser(self.model_denoise, quantize=True)
self.model_wrap.parameterization = self.model.parameterization
self.model_k = KSamplerX0Inpaint(self.model_wrap) self.model_k = KSamplerX0Inpaint(self.model_wrap)
self.device = device self.device = device
if scheduler not in self.SCHEDULERS: if scheduler not in self.SCHEDULERS:
@ -589,19 +660,21 @@ class KSampler:
apply_empty_x_to_equal_area(positive, negative, 'control', lambda cond_cnets, x: cond_cnets[x]) apply_empty_x_to_equal_area(positive, negative, 'control', lambda cond_cnets, x: cond_cnets[x])
apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x]) apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x])
if self.model.model.diffusion_model.dtype == torch.float16: if isinstance(self.model_denoise, AITemplateModelWrapper):
precision_scope = torch.autocast
elif self.model.model.diffusion_model.dtype == torch.float16:
precision_scope = torch.autocast precision_scope = torch.autocast
else: else:
precision_scope = contextlib.nullcontext precision_scope = contextlib.nullcontext
if hasattr(self.model, 'noise_augmentor'): #unclip if not isinstance(self.model_denoise, AITemplateModelWrapper) and hasattr(self.model, 'noise_augmentor'): #unclip
positive = encode_adm(self.model.noise_augmentor, positive, noise.shape[0], self.device) positive = encode_adm(self.model.noise_augmentor, positive, noise.shape[0], self.device)
negative = encode_adm(self.model.noise_augmentor, negative, noise.shape[0], self.device) negative = encode_adm(self.model.noise_augmentor, negative, noise.shape[0], self.device)
extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": self.model_options} extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": self.model_options}
cond_concat = None cond_concat = None
if hasattr(self.model, 'concat_keys'): #inpaint if not isinstance(self.model_denoise, AITemplateModelWrapper) and hasattr(self.model, 'concat_keys'): #inpaint
cond_concat = [] cond_concat = []
for ck in self.model.concat_keys: for ck in self.model.concat_keys:
if denoise_mask is not None: if denoise_mask is not None:

View File

@ -2,6 +2,7 @@ import os
supported_ckpt_extensions = set(['.ckpt', '.pth']) supported_ckpt_extensions = set(['.ckpt', '.pth'])
supported_pt_extensions = set(['.ckpt', '.pt', '.bin', '.pth']) supported_pt_extensions = set(['.ckpt', '.pt', '.bin', '.pth'])
supported_ait_extensions = set(['.so'])
try: try:
import safetensors.torch import safetensors.torch
supported_ckpt_extensions.add('.safetensors') supported_ckpt_extensions.add('.safetensors')
@ -16,7 +17,7 @@ base_path = os.path.dirname(os.path.realpath(__file__))
models_dir = os.path.join(base_path, "models") models_dir = os.path.join(base_path, "models")
folder_names_and_paths["checkpoints"] = ([os.path.join(models_dir, "checkpoints")], supported_ckpt_extensions) folder_names_and_paths["checkpoints"] = ([os.path.join(models_dir, "checkpoints")], supported_ckpt_extensions)
folder_names_and_paths["configs"] = ([os.path.join(models_dir, "configs")], [".yaml"]) folder_names_and_paths["configs"] = ([os.path.join(models_dir, "configs")], [".yaml"])
folder_names_and_paths["aitemplate"] = ([os.path.join(models_dir, "aitemplate")], supported_ait_extensions)
folder_names_and_paths["loras"] = ([os.path.join(models_dir, "loras")], supported_pt_extensions) folder_names_and_paths["loras"] = ([os.path.join(models_dir, "loras")], supported_pt_extensions)
folder_names_and_paths["vae"] = ([os.path.join(models_dir, "vae")], supported_pt_extensions) folder_names_and_paths["vae"] = ([os.path.join(models_dir, "vae")], supported_pt_extensions)
folder_names_and_paths["clip"] = ([os.path.join(models_dir, "clip")], supported_pt_extensions) folder_names_and_paths["clip"] = ([os.path.join(models_dir, "clip")], supported_pt_extensions)

View File

339
nodes.py
View File

@ -7,7 +7,8 @@ import hashlib
import traceback import traceback
import math import math
import time import time
from aitemplate.compiler import Model
from diffusers import LMSDiscreteScheduler
from PIL import Image from PIL import Image
from PIL.PngImagePlugin import PngInfo from PIL.PngImagePlugin import PngInfo
import numpy as np import numpy as np
@ -334,6 +335,308 @@ class CLIPSetLastLayer:
clip.clip_layer(stop_at_clip_layer) clip.clip_layer(stop_at_clip_layer)
return (clip,) return (clip,)
class AITemplateLoader:
    """ComfyUI loader node that binds a checkpoint's UNet weights into a
    pre-compiled AITemplate module (``.so``).

    The checkpoint's UNet state dict is converted from LDM key naming to
    diffusers naming, then remapped to AITemplate constant names, and finally
    set as constants on the compiled module, which is returned in the MODEL
    slot for use by the AITemplate KSampler.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "model": ("MODEL",),
                              "aitemplate_module": (folder_paths.get_filename_list("aitemplate"), ),
                              }}
    RETURN_TYPES = ("MODEL",)
    FUNCTION = "load_aitemplate"

    CATEGORY = "loaders"

    def load_aitemplate(self, model, aitemplate_module):
        """Load the selected .so and bind the model's converted UNet weights."""
        aitemplate_path = folder_paths.get_full_path("aitemplate", aitemplate_module)
        aitemplate = Model(aitemplate_path)
        # LDM -> diffusers key names, then diffusers -> AITemplate constants.
        model = self.convert_ldm_unet_checkpoint(model.model.state_dict())
        unet_params_ait = self.map_unet_state_dict(model)
        print("Setting constants")
        aitemplate.set_many_constants_with_tensors(unet_params_ait)
        print("Folding constants")
        aitemplate.fold_constants()
        return (aitemplate,)

    #=================#
    # UNet Conversion #
    #=================#
    def assign_to_checkpoint(
        self, paths, checkpoint, old_checkpoint, additional_replacements=None
    ):
        """
        This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits
        attention layers, and takes into account additional replacements that may arise.

        Assigns the weights to the new checkpoint.
        """
        assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."

        for path in paths:
            new_path = path["new"]

            # Global renaming happens here
            new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
            new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
            new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")

            if additional_replacements is not None:
                for replacement in additional_replacements:
                    new_path = new_path.replace(replacement["old"], replacement["new"])

            # proj_attn.weight has to be converted from conv 1D to linear
            if "proj_attn.weight" in new_path:
                checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
            else:
                checkpoint[new_path] = old_checkpoint[path["old"]]

    def conv_attn_to_linear(self, checkpoint):
        """Squeeze trailing conv dims off attention weights, in place."""
        keys = list(checkpoint.keys())
        attn_keys = ["query.weight", "key.weight", "value.weight"]
        for key in keys:
            if ".".join(key.split(".")[-2:]) in attn_keys:
                if checkpoint[key].ndim > 2:
                    checkpoint[key] = checkpoint[key][:, :, 0, 0]
            elif "proj_attn.weight" in key:
                if checkpoint[key].ndim > 2:
                    checkpoint[key] = checkpoint[key][:, :, 0]

    def renew_attention_paths(self, old_list, n_shave_prefix_segments=0):
        """
        Updates paths inside attentions to the new naming scheme (local renaming)
        """
        # Currently an identity mapping: the old and new names match and the
        # commented replacements below are intentionally disabled.
        mapping = []
        for old_item in old_list:
            new_item = old_item

            # new_item = new_item.replace('norm.weight', 'group_norm.weight')
            # new_item = new_item.replace('norm.bias', 'group_norm.bias')

            # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
            # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')

            # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)

            mapping.append({"old": old_item, "new": new_item})

        return mapping

    def shave_segments(self, path, n_shave_prefix_segments=1):
        """
        Removes segments. Positive values shave the first segments, negative shave the last segments.
        """
        if n_shave_prefix_segments >= 0:
            return ".".join(path.split(".")[n_shave_prefix_segments:])
        else:
            return ".".join(path.split(".")[:n_shave_prefix_segments])

    def renew_resnet_paths(self, old_list, n_shave_prefix_segments=0):
        """
        Updates paths inside resnets to the new naming scheme (local renaming)
        """
        mapping = []
        for old_item in old_list:
            new_item = old_item.replace("in_layers.0", "norm1")
            new_item = new_item.replace("in_layers.2", "conv1")

            new_item = new_item.replace("out_layers.0", "norm2")
            new_item = new_item.replace("out_layers.3", "conv2")

            new_item = new_item.replace("emb_layers.1", "time_emb_proj")
            new_item = new_item.replace("skip_connection", "conv_shortcut")

            new_item = self.shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)

            mapping.append({"old": old_item, "new": new_item})

        return mapping

    def convert_ldm_unet_checkpoint(self, unet_state_dict, layers_per_block=2):
        """
        Takes a state dict and a config, and returns a converted checkpoint.

        Converts LDM UNet key naming (time_embed / input_blocks /
        middle_block / output_blocks) to diffusers naming (time_embedding /
        down_blocks / mid_block / up_blocks).
        """
        # Strip the LDM wrapper prefix so all keys are UNet-relative.
        temp = {}
        for key, value in unet_state_dict.items():
            if key.startswith("model.diffusion_model."):
                key = key.replace("model.diffusion_model.", "")
            temp[key] = value
        unet_state_dict = temp

        new_checkpoint = {}

        new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
        new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
        new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
        new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]

        new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
        new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]

        new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
        new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
        new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
        new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]

        # Retrieves the keys for the input blocks only
        num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
        input_blocks = {
            layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
            for layer_id in range(num_input_blocks)
        }

        # Retrieves the keys for the middle blocks only
        num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
        middle_blocks = {
            layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
            for layer_id in range(num_middle_blocks)
        }

        # Retrieves the keys for the output blocks only
        num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
        output_blocks = {
            layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
            for layer_id in range(num_output_blocks)
        }

        for i in range(1, num_input_blocks):
            # Each down block holds `layers_per_block` resnets plus an
            # optional downsampler, hence the +1 in the index arithmetic.
            block_id = (i - 1) // (layers_per_block + 1)
            layer_in_block_id = (i - 1) % (layers_per_block + 1)

            resnets = [
                key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
            ]
            attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]

            if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
                # ".op" convs are downsamplers, mapped separately.
                new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
                    f"input_blocks.{i}.0.op.weight"
                )
                new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
                    f"input_blocks.{i}.0.op.bias"
                )

            paths = self.renew_resnet_paths(resnets)
            meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
            self.assign_to_checkpoint(
                paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path]
            )

            if len(attentions):
                paths = self.renew_attention_paths(attentions)
                meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
                self.assign_to_checkpoint(
                    paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path]
                )

        # Middle block is always resnet / attention / resnet.
        resnet_0 = middle_blocks[0]
        attentions = middle_blocks[1]
        resnet_1 = middle_blocks[2]

        resnet_0_paths = self.renew_resnet_paths(resnet_0)
        self.assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict)

        resnet_1_paths = self.renew_resnet_paths(resnet_1)
        self.assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict)

        attentions_paths = self.renew_attention_paths(attentions)
        meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
        self.assign_to_checkpoint(
            attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path]
        )

        for i in range(num_output_blocks):
            block_id = i // (layers_per_block + 1)
            layer_in_block_id = i % (layers_per_block + 1)
            output_block_layers = [self.shave_segments(name, 2) for name in output_blocks[i]]
            output_block_list = {}

            for layer in output_block_layers:
                layer_id, layer_name = layer.split(".")[0], self.shave_segments(layer, 1)
                if layer_id in output_block_list:
                    output_block_list[layer_id].append(layer_name)
                else:
                    output_block_list[layer_id] = [layer_name]

            if len(output_block_list) > 1:
                resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
                attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]

                resnet_0_paths = self.renew_resnet_paths(resnets)
                paths = self.renew_resnet_paths(resnets)

                meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
                self.assign_to_checkpoint(
                    paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path]
                )

                output_block_list = {k: sorted(v) for k, v in output_block_list.items()}
                if ["conv.bias", "conv.weight"] in output_block_list.values():
                    # This sub-block is an upsampler conv rather than attention.
                    index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
                    new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
                        f"output_blocks.{i}.{index}.conv.weight"
                    ]
                    new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
                        f"output_blocks.{i}.{index}.conv.bias"
                    ]

                    # Clear attentions as they have been attributed above.
                    if len(attentions) == 2:
                        attentions = []

                if len(attentions):
                    paths = self.renew_attention_paths(attentions)
                    meta_path = {
                        "old": f"output_blocks.{i}.1",
                        "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
                    }
                    self.assign_to_checkpoint(
                        paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path]
                    )
            else:
                # Resnet-only output block: rename keys directly.
                resnet_0_paths = self.renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
                for path in resnet_0_paths:
                    old_path = ".".join(["output_blocks", str(i), path["old"]])
                    new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])

                    new_checkpoint[new_path] = unet_state_dict[old_path]

        return new_checkpoint

    #=========================#
    #  AITemplate mapping     #
    #=========================#
    def map_unet_state_dict(self, state_dict, dim=320):
        """Remap a diffusers-named state dict to AITemplate constant names.

        Dots become underscores, 4-D conv weights go NCHW -> NHWC fp16 CUDA,
        and GEGLU ``ff.net.0.proj`` tensors are split into proj/gate halves.
        ``dim`` sizes the "arange" constant (presumably for the timestep
        embedding — confirm against the compiled model).
        """
        params_ait = {}
        for key, arr in state_dict.items():
            if key.startswith("model.diffusion_model."):
                key = key.replace("model.diffusion_model.", "")
            arr = arr.to("cuda", dtype=torch.float16)
            if len(arr.shape) == 4:
                arr = arr.permute((0, 2, 3, 1)).contiguous()
            elif key.endswith("ff.net.0.proj.weight"):
                # GEGLU: first half is the projection, second half the gate.
                w1, w2 = arr.chunk(2, dim=0)
                params_ait[key.replace(".", "_")] = w1
                params_ait[key.replace(".", "_").replace("proj", "gate")] = w2
                continue
            elif key.endswith("ff.net.0.proj.bias"):
                w1, w2 = arr.chunk(2, dim=0)
                params_ait[key.replace(".", "_")] = w1
                params_ait[key.replace(".", "_").replace("proj", "gate")] = w2
                continue
            params_ait[key.replace(".", "_")] = arr

        params_ait["arange"] = (
            torch.arange(start=0, end=dim // 2, dtype=torch.float32).cuda().half()
        )

        return params_ait
class LoraLoader: class LoraLoader:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
@ -826,7 +1129,7 @@ class SetLatentNoiseMask:
s["noise_mask"] = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])) s["noise_mask"] = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1]))
return (s,) return (s,)
def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False): def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, aitemplate=None):
device = comfy.model_management.get_torch_device() device = comfy.model_management.get_torch_device()
latent_image = latent["samples"] latent_image = latent["samples"]
@ -846,7 +1149,7 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
samples = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, samples = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image,
denoise=denoise, disable_noise=disable_noise, start_step=start_step, last_step=last_step, denoise=denoise, disable_noise=disable_noise, start_step=start_step, last_step=last_step,
force_full_denoise=force_full_denoise, noise_mask=noise_mask, callback=callback) force_full_denoise=force_full_denoise, noise_mask=noise_mask, callback=callback, aitemplate=aitemplate)
out = latent.copy() out = latent.copy()
out["samples"] = samples out["samples"] = samples
return (out, ) return (out, )
@ -875,6 +1178,32 @@ class KSampler:
def sample(self, model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0): def sample(self, model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0):
return common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise) return common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise)
class KSamplerAITemplate:
    """ComfyUI sampling node that drives a compiled AITemplate UNet.

    Mirrors the regular KSampler node, but forwards the MODEL value (which the
    AITemplateLoader node sets to the compiled module) as ``aitemplate``.
    """

    RETURN_TYPES = ("LATENT",)
    FUNCTION = "sample"

    CATEGORY = "sampling"

    @classmethod
    def INPUT_TYPES(cls):
        required = {
            "model": ("MODEL",),
            "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
            "steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
            "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0}),
            "sampler_name": (comfy.samplers.KSampler.SAMPLERS,),
            "scheduler": (comfy.samplers.KSampler.SCHEDULERS,),
            "positive": ("CONDITIONING",),
            "negative": ("CONDITIONING",),
            "latent_image": ("LATENT",),
            "denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
        }
        return {"required": required}

    def sample(self, model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0):
        # The loader outputs the compiled AITemplate executable in the MODEL
        # slot, so it is passed both positionally and as `aitemplate`.
        return common_ksampler(
            model, seed, steps, cfg, sampler_name, scheduler,
            positive, negative, latent_image,
            denoise=denoise, aitemplate=model,
        )
class KSamplerAdvanced: class KSamplerAdvanced:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
@ -1211,12 +1540,14 @@ NODE_CLASS_MAPPINGS = {
"ConditioningSetArea": ConditioningSetArea, "ConditioningSetArea": ConditioningSetArea,
"ConditioningSetMask": ConditioningSetMask, "ConditioningSetMask": ConditioningSetMask,
"KSamplerAdvanced": KSamplerAdvanced, "KSamplerAdvanced": KSamplerAdvanced,
"KSamplerAITemplate": KSamplerAITemplate,
"SetLatentNoiseMask": SetLatentNoiseMask, "SetLatentNoiseMask": SetLatentNoiseMask,
"LatentComposite": LatentComposite, "LatentComposite": LatentComposite,
"LatentRotate": LatentRotate, "LatentRotate": LatentRotate,
"LatentFlip": LatentFlip, "LatentFlip": LatentFlip,
"LatentCrop": LatentCrop, "LatentCrop": LatentCrop,
"LoraLoader": LoraLoader, "LoraLoader": LoraLoader,
"AITemplateLoader": AITemplateLoader,
"CLIPLoader": CLIPLoader, "CLIPLoader": CLIPLoader,
"CLIPVisionEncode": CLIPVisionEncode, "CLIPVisionEncode": CLIPVisionEncode,
"StyleModelApply": StyleModelApply, "StyleModelApply": StyleModelApply,
@ -1241,11 +1572,13 @@ NODE_DISPLAY_NAME_MAPPINGS = {
# Sampling # Sampling
"KSampler": "KSampler", "KSampler": "KSampler",
"KSamplerAdvanced": "KSampler (Advanced)", "KSamplerAdvanced": "KSampler (Advanced)",
"KSamplerAITemplate": "KSampler (AITemplate)",
# Loaders # Loaders
"CheckpointLoader": "Load Checkpoint (With Config)", "CheckpointLoader": "Load Checkpoint (With Config)",
"CheckpointLoaderSimple": "Load Checkpoint", "CheckpointLoaderSimple": "Load Checkpoint",
"VAELoader": "Load VAE", "VAELoader": "Load VAE",
"LoraLoader": "Load LoRA", "LoraLoader": "Load LoRA",
"AITemplateLoader": "Load AITemplate",
"CLIPLoader": "Load CLIP", "CLIPLoader": "Load CLIP",
"ControlNetLoader": "Load ControlNet Model", "ControlNetLoader": "Load ControlNet Model",
"DiffControlNetLoader": "Load ControlNet Model (diff)", "DiffControlNetLoader": "Load ControlNet Model (diff)",