Adding returned attention

KSampler's dpmpp_2m now returns attention, and nodes were added for loading and saving attention. Modified the PrintNode to print attention. Still have to add this to the other samplers.
This commit is contained in:
InconsolableCellist 2023-04-28 17:23:19 -06:00
parent e42746498d
commit 5e062a88de
6 changed files with 106 additions and 30 deletions

View File

@ -111,8 +111,8 @@ class DiscreteEpsDDPMDenoiser(DiscreteSchedule):
def forward(self, input, sigma, **kwargs):
c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
eps = self.get_eps(input * c_in, self.sigma_to_t(sigma), **kwargs)
return input + eps * c_out
eps, attn = self.get_eps(input * c_in, self.sigma_to_t(sigma), **kwargs)
return input + eps * c_out, attn
class OpenAIDenoiser(DiscreteEpsDDPMDenoiser):

View File

@ -590,8 +590,13 @@ def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=No
t_fn = lambda sigma: sigma.log().neg()
old_denoised = None
attention = None
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
denoised, attn = model(x, sigmas[i] * s_in, **extra_args)
if attention is None:
attention = torch.empty((len(sigmas), *attn.shape), dtype=attn.dtype, device=attn.device)
attention[i] = attn
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
@ -604,4 +609,4 @@ def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=No
denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised
x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d
old_denoised = denoised
return x
return x, attention

View File

@ -76,8 +76,9 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative
sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)
samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback)
samples, attention = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback)
samples = samples.cpu()
attention = attention.cpu()
cleanup_additional_models(models)
return samples
return samples, attention

View File

@ -38,17 +38,17 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
if (area[1] + area[3]) < x_in.shape[3]:
for t in range(rr):
mult[:,:,:,area[1] - 1 - t:area[1] - t] *= ((1.0/rr) * (t + 1))
conditionning = {}
conditionning['c_crossattn'] = cond[0]
conditioning = {}
conditioning['c_crossattn'] = cond[0]
if cond_concat_in is not None and len(cond_concat_in) > 0:
cropped = []
for x in cond_concat_in:
cr = x[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
cropped.append(cr)
conditionning['c_concat'] = torch.cat(cropped, dim=1)
conditioning['c_concat'] = torch.cat(cropped, dim=1)
if adm_cond is not None:
conditionning['c_adm'] = adm_cond
conditioning['c_adm'] = adm_cond
control = None
if 'control' in cond[1]:
@ -67,7 +67,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
patches['middle_patch'] = [gligen_patch]
return (input_x, mult, conditionning, area, control, patches)
return (input_x, mult, conditioning, area, control, patches)
def cond_equal_size(c1, c2):
if c1 is c2:
@ -234,9 +234,9 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
max_total_area = model_management.maximum_batch_area()
cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, timestep, max_total_area, cond_concat, model_options)
if "sampler_cfg_function" in model_options:
return model_options["sampler_cfg_function"](cond, uncond, cond_scale)
return model_options["sampler_cfg_function"](cond, uncond, cond_scale), cond[0] # cond[0] is attention
else:
return uncond + (cond - uncond) * cond_scale
return uncond + (cond - uncond) * cond_scale, cond[0] # cond[0] is attention
class CompVisVDenoiser(k_diffusion_external.DiscreteVDDPMDenoiser):
@ -253,8 +253,8 @@ class CFGNoisePredictor(torch.nn.Module):
self.inner_model = model
self.alphas_cumprod = model.alphas_cumprod
def apply_model(self, x, timestep, cond, uncond, cond_scale, cond_concat=None, model_options={}):
out = sampling_function(self.inner_model.apply_model, x, timestep, uncond, cond, cond_scale, cond_concat, model_options=model_options)
return out
out, attn = sampling_function(self.inner_model.apply_model, x, timestep, uncond, cond, cond_scale, cond_concat, model_options=model_options)
return out, attn
class KSamplerX0Inpaint(torch.nn.Module):
@ -265,13 +265,13 @@ class KSamplerX0Inpaint(torch.nn.Module):
if denoise_mask is not None:
latent_mask = 1. - denoise_mask
x = x * denoise_mask + (self.latent_image + self.noise * sigma.reshape([sigma.shape[0]] + [1] * (len(self.noise.shape) - 1))) * latent_mask
out = self.inner_model(x, sigma, cond=cond, uncond=uncond, cond_scale=cond_scale, cond_concat=cond_concat, model_options=model_options)
out, attn = self.inner_model(x, sigma, cond=cond, uncond=uncond, cond_scale=cond_scale, cond_concat=cond_concat, model_options=model_options)
if denoise_mask is not None:
out *= denoise_mask
if denoise_mask is not None:
out += self.latent_image * latent_mask
return out
return out, attn
def simple_scheduler(model, steps):
sigs = []
@ -580,6 +580,6 @@ class KSampler:
elif self.sampler == "dpm_adaptive":
samples = k_diffusion_sampling.sample_dpm_adaptive(self.model_k, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=k_callback)
else:
samples = getattr(k_diffusion_sampling, "sample_{}".format(self.sampler))(self.model_k, noise, sigmas, extra_args=extra_args, callback=k_callback)
samples, attention = getattr(k_diffusion_sampling, "sample_{}".format(self.sampler))(self.model_k, noise, sigmas, extra_args=extra_args, callback=k_callback)
return samples.to(torch.float32)
return samples.to(torch.float32), attention

View File

@ -2,6 +2,7 @@ import os
import torch
import folder_paths
import numpy as np
from safetensors.torch import save_file, safe_open
def load_torch_file(ckpt, safe_load=False):
@ -27,6 +28,21 @@ def save_latent(samples, filename):
np.save(filename, samples)
def save_attention(attention, filename):
    """Save an attention tensor to a safetensors file in the output directory.

    Args:
        attention: the attention tensor to persist.
        filename: file name relative to the output directory.
    """
    filename = os.path.join(folder_paths.get_output_directory(), filename)
    save_file({"attention": attention}, filename)
    # Bug fix: the original f-string contained no placeholder, so the
    # destination path was never actually printed.
    print(f"Attention tensor saved to {filename}")
def load_attention(filename):
    """Load the 'attention' tensor from a safetensors file in the output directory.

    Args:
        filename: file name relative to the output directory.

    Returns:
        The tensor stored under the 'attention' key.
    """
    path = os.path.join(folder_paths.get_output_directory(), filename)
    # NOTE(review): device=0 loads straight onto the first CUDA device — confirm
    # this is intended rather than loading to CPU.
    with safe_open(path, framework='pt', device=0) as handle:
        loaded = {key: handle.get_tensor(key) for key in handle.keys()}
    return loaded['attention']
def load_latent(filename):
    """Read a latent tensor previously saved with np.save from the output directory.

    Args:
        filename: file name relative to the output directory.

    Returns:
        A torch tensor sharing memory with the loaded numpy array.
    """
    full_path = os.path.join(folder_paths.get_output_directory(), filename)
    array = np.load(full_path)
    return torch.from_numpy(array)

View File

@ -665,9 +665,6 @@ class SaveLatent:
def save(self, samples, filename):
s = samples.copy()
comfy.utils.save_latent(samples["samples"], filename)
@clas
return (samples,)
class LoadLatent:
@ -882,12 +879,12 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
if "noise_mask" in latent:
noise_mask = latent["noise_mask"]
samples = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image,
samples, attention = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image,
denoise=denoise, disable_noise=disable_noise, start_step=start_step, last_step=last_step,
force_full_denoise=force_full_denoise, noise_mask=noise_mask)
out = latent.copy()
out["samples"] = samples
return (out, )
return (out, attention)
class KSampler:
def __init__(self, event_dispatcher):
@ -907,7 +904,7 @@ class KSampler:
"denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
}}
RETURN_TYPES = ("LATENT",)
RETURN_TYPES = ("LATENT","ATTENTION")
FUNCTION = "sample"
CATEGORY = "sampling"
@ -1281,7 +1278,6 @@ class EventListener:
return True
RETURN_TYPES = ("BOOL",)
RETURN_NAMES = ("fired",)
FUNCTION = "listen"
@ -1299,7 +1295,7 @@ class EventListener:
return (self._fired,)
class PrinterNode:
class PrintNode:
def __init__(self, event_dispatcher):
self.event_dispatcher = event_dispatcher
@ -1310,6 +1306,7 @@ class PrinterNode:
"required": {},
"optional": {
"text": ("text",),
"attention": ("ATTENTION",),
"latent": ("LATENT",),
}
}
@ -1322,16 +1319,69 @@ class PrinterNode:
CATEGORY = "operations"
OUTPUT_NODE = True
def print_value(self, text=None, latent=None):
def print_value(self, text=None, latent=None, attention=None):
if latent is not None:
latent_hash = hashlib.sha256(latent["samples"].cpu().numpy().tobytes()).hexdigest()
print(f"Latent hash: {latent_hash}")
print(np.array2string(latent["samples"].cpu().numpy(), separator=', '))
if attention is not None:
print(np.array2string(attention.cpu().numpy(), separator=', '))
print(text)
if text is not None:
print(text)
return {"ui": {"": text}}
class SaveAttention:
    """Node that writes an attention tensor to a safetensors file in the output directory."""

    # Bug fix: the original had a stray @classmethod on __init__, which makes
    # `self` bind to the class at instantiation time — event_dispatcher would be
    # stored as a class attribute shared by all instances instead of instance state.
    def __init__(self, event_dispatcher):
        self.event_dispatcher = event_dispatcher

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "attention": ("ATTENTION",),
                "filename": ("STRING", {"default": "attention.safetensor"}),
            },
        }

    @classmethod
    def IS_CHANGED(cls, *args, **kwargs):
        # Always report changed so the node re-executes (saving is a side effect).
        return True

    RETURN_TYPES = ()
    FUNCTION = "save_attention"
    CATEGORY = "operations"
    OUTPUT_NODE = True

    def save_attention(self, attention, filename):
        """Persist the attention tensor and report the destination in the UI."""
        comfy.utils.save_attention(attention, filename)
        return {"ui": {"message": "Saved attention to " + filename}}
class LoadAttention:
    """Node that loads a previously saved attention tensor from the output directory."""

    # Bug fix: the original had a stray @classmethod on __init__, which makes
    # `self` bind to the class at instantiation time — event_dispatcher would be
    # stored as a class attribute shared by all instances instead of instance state.
    def __init__(self, event_dispatcher):
        self.event_dispatcher = event_dispatcher

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "filename": ("STRING", {"default": "attention.safetensor"}),
            },
        }

    RETURN_TYPES = ("ATTENTION",)
    FUNCTION = "load_attention"
    CATEGORY = "operations"

    def load_attention(self, filename):
        """Return the loaded attention tensor as a single-element result tuple."""
        return (comfy.utils.load_attention(filename),)
NODE_CLASS_MAPPINGS = {
"KSampler": KSampler,
@ -1382,9 +1432,11 @@ NODE_CLASS_MAPPINGS = {
"CheckpointLoader": CheckpointLoader,
"DiffusersLoader": DiffusersLoader,
"FrameCounter": FrameCounter,
"PrinterNode": PrinterNode,
"PrinterNode": PrintNode,
"EventListener": EventListener,
"MuxLatent": MuxLatent,
"SaveAttention": SaveAttention,
"LoadAttention": LoadAttention,
}
NODE_DISPLAY_NAME_MAPPINGS = {
@ -1440,6 +1492,8 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"PrinterNode": "Print",
"EventListener": "Event Listener",
"MuxLatent": "Mux Latent",
"SaveAttention": "Save Attention",
"LoadAttention": "Load Attention",
}
def load_custom_node(module_path):