From 5e062a88de196723fbdcd94be5d342d73fc0c69a Mon Sep 17 00:00:00 2001 From: InconsolableCellist <23345188+InconsolableCellist@users.noreply.github.com> Date: Fri, 28 Apr 2023 17:23:19 -0600 Subject: [PATCH] Adding returned attention KSampler's dpmpp_2m now returns attention, and added nodes for loading, and saving attention. Modified the PrintNode to print attention. Still have to add it to other samplers. --- comfy/k_diffusion/external.py | 4 +- comfy/k_diffusion/sampling.py | 9 ++++- comfy/sample.py | 5 ++- comfy/samplers.py | 26 ++++++------ comfy/utils.py | 16 ++++++++ nodes.py | 76 ++++++++++++++++++++++++++++++----- 6 files changed, 106 insertions(+), 30 deletions(-) diff --git a/comfy/k_diffusion/external.py b/comfy/k_diffusion/external.py index 49ce5ae39..688f6c973 100644 --- a/comfy/k_diffusion/external.py +++ b/comfy/k_diffusion/external.py @@ -111,8 +111,8 @@ class DiscreteEpsDDPMDenoiser(DiscreteSchedule): def forward(self, input, sigma, **kwargs): c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)] - eps = self.get_eps(input * c_in, self.sigma_to_t(sigma), **kwargs) - return input + eps * c_out + eps, attn = self.get_eps(input * c_in, self.sigma_to_t(sigma), **kwargs) + return input + eps * c_out, attn class OpenAIDenoiser(DiscreteEpsDDPMDenoiser): diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index c809d39fb..1bf6034d1 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -590,8 +590,13 @@ def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=No t_fn = lambda sigma: sigma.log().neg() old_denoised = None + attention = None for i in trange(len(sigmas) - 1, disable=disable): - denoised = model(x, sigmas[i] * s_in, **extra_args) + denoised, attn = model(x, sigmas[i] * s_in, **extra_args) + if attention is None: + attention = torch.empty((len(sigmas), *attn.shape), dtype=attn.dtype, device=attn.device) + attention[i] = attn + if callback is not 
None: callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1]) @@ -604,4 +609,4 @@ def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=No denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d old_denoised = denoised - return x + return x, attention diff --git a/comfy/sample.py b/comfy/sample.py index f4132bbed..f28123106 100644 --- a/comfy/sample.py +++ b/comfy/sample.py @@ -76,8 +76,9 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options) - samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback) + samples, attention = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback) samples = samples.cpu() + attention = attention.cpu() cleanup_additional_models(models) - return samples + return samples, attention diff --git a/comfy/samplers.py b/comfy/samplers.py index fc19ddcfc..b7a1cf207 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -38,17 +38,17 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con if (area[1] + area[3]) < x_in.shape[3]: for t in range(rr): mult[:,:,:,area[1] - 1 - t:area[1] - t] *= ((1.0/rr) * (t + 1)) - conditionning = {} - conditionning['c_crossattn'] = cond[0] + conditioning = {} + conditioning['c_crossattn'] = cond[0] if cond_concat_in is not None and 
len(cond_concat_in) > 0: cropped = [] for x in cond_concat_in: cr = x[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] cropped.append(cr) - conditionning['c_concat'] = torch.cat(cropped, dim=1) + conditioning['c_concat'] = torch.cat(cropped, dim=1) if adm_cond is not None: - conditionning['c_adm'] = adm_cond + conditioning['c_adm'] = adm_cond control = None if 'control' in cond[1]: @@ -67,7 +67,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con patches['middle_patch'] = [gligen_patch] - return (input_x, mult, conditionning, area, control, patches) + return (input_x, mult, conditioning, area, control, patches) def cond_equal_size(c1, c2): if c1 is c2: @@ -234,9 +234,9 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con max_total_area = model_management.maximum_batch_area() cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, timestep, max_total_area, cond_concat, model_options) if "sampler_cfg_function" in model_options: - return model_options["sampler_cfg_function"](cond, uncond, cond_scale) + return model_options["sampler_cfg_function"](cond, uncond, cond_scale), cond[0] # cond[0] is attention else: - return uncond + (cond - uncond) * cond_scale + return uncond + (cond - uncond) * cond_scale, cond[0] # cond[0] is attention class CompVisVDenoiser(k_diffusion_external.DiscreteVDDPMDenoiser): @@ -253,8 +253,8 @@ class CFGNoisePredictor(torch.nn.Module): self.inner_model = model self.alphas_cumprod = model.alphas_cumprod def apply_model(self, x, timestep, cond, uncond, cond_scale, cond_concat=None, model_options={}): - out = sampling_function(self.inner_model.apply_model, x, timestep, uncond, cond, cond_scale, cond_concat, model_options=model_options) - return out + out, attn = sampling_function(self.inner_model.apply_model, x, timestep, uncond, cond, cond_scale, cond_concat, model_options=model_options) + return out, attn class KSamplerX0Inpaint(torch.nn.Module): @@ -265,13 
+265,13 @@ class KSamplerX0Inpaint(torch.nn.Module): if denoise_mask is not None: latent_mask = 1. - denoise_mask x = x * denoise_mask + (self.latent_image + self.noise * sigma.reshape([sigma.shape[0]] + [1] * (len(self.noise.shape) - 1))) * latent_mask - out = self.inner_model(x, sigma, cond=cond, uncond=uncond, cond_scale=cond_scale, cond_concat=cond_concat, model_options=model_options) + out, attn = self.inner_model(x, sigma, cond=cond, uncond=uncond, cond_scale=cond_scale, cond_concat=cond_concat, model_options=model_options) if denoise_mask is not None: out *= denoise_mask if denoise_mask is not None: out += self.latent_image * latent_mask - return out + return out, attn def simple_scheduler(model, steps): sigs = [] @@ -580,6 +580,6 @@ class KSampler: elif self.sampler == "dpm_adaptive": samples = k_diffusion_sampling.sample_dpm_adaptive(self.model_k, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=k_callback) else: - samples = getattr(k_diffusion_sampling, "sample_{}".format(self.sampler))(self.model_k, noise, sigmas, extra_args=extra_args, callback=k_callback) + samples, attention = getattr(k_diffusion_sampling, "sample_{}".format(self.sampler))(self.model_k, noise, sigmas, extra_args=extra_args, callback=k_callback) - return samples.to(torch.float32) + return samples.to(torch.float32), attention diff --git a/comfy/utils.py b/comfy/utils.py index 46bc325c6..26ebac2c4 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -2,6 +2,7 @@ import os import torch import folder_paths import numpy as np +from safetensors.torch import save_file, safe_open def load_torch_file(ckpt, safe_load=False): @@ -27,6 +28,21 @@ def save_latent(samples, filename): np.save(filename, samples) +def save_attention(attention, filename): + filename = os.path.join(folder_paths.get_output_directory(), filename) + save_file({"attention": attention}, filename) + print(f"Attention tensor saved to {filename}") + + +def load_attention(filename): + tensors = {} + filename = 
os.path.join(folder_paths.get_output_directory(), filename) + with safe_open(filename, framework='pt', device=0) as f: + for k in f.keys(): + tensors[k] = f.get_tensor(k) + return tensors['attention'] + + def load_latent(filename): filename = os.path.join(folder_paths.get_output_directory(), filename) return torch.from_numpy(np.load(filename)) diff --git a/nodes.py b/nodes.py index 7eeea9650..52d34641d 100644 --- a/nodes.py +++ b/nodes.py @@ -665,9 +665,6 @@ class SaveLatent: def save(self, samples, filename): s = samples.copy() comfy.utils.save_latent(samples["samples"], filename) - - @clas - return (samples,) class LoadLatent: @@ -882,12 +879,12 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, if "noise_mask" in latent: noise_mask = latent["noise_mask"] - samples = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, + samples, attention = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise, disable_noise=disable_noise, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, noise_mask=noise_mask) out = latent.copy() out["samples"] = samples - return (out, ) + return (out, attention) class KSampler: def __init__(self, event_dispatcher): @@ -907,7 +904,7 @@ class KSampler: "denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), }} - RETURN_TYPES = ("LATENT",) + RETURN_TYPES = ("LATENT","ATTENTION") FUNCTION = "sample" CATEGORY = "sampling" @@ -1281,7 +1278,6 @@ class EventListener: return True RETURN_TYPES = ("BOOL",) - RETURN_NAMES = ("fired",) FUNCTION = "listen" @@ -1299,7 +1295,7 @@ class EventListener: return (self._fired,) -class PrinterNode: +class PrintNode: def __init__(self, event_dispatcher): self.event_dispatcher = event_dispatcher @@ -1310,6 +1306,7 @@ class PrinterNode: "required": {}, "optional": { "text": ("text",), + "attention": ("ATTENTION",), 
"latent": ("LATENT",), } } @@ -1322,16 +1319,69 @@ class PrinterNode: CATEGORY = "operations" OUTPUT_NODE = True - def print_value(self, text=None, latent=None): + def print_value(self, text=None, latent=None, attention=None): if latent is not None: latent_hash = hashlib.sha256(latent["samples"].cpu().numpy().tobytes()).hexdigest() print(f"Latent hash: {latent_hash}") print(np.array2string(latent["samples"].cpu().numpy(), separator=', ')) + if attention is not None: + print(np.array2string(attention.cpu().numpy(), separator=', ')) - print(text) + if text is not None: + print(text) return {"ui": {"": text}} +class SaveAttention: + @classmethod + def __init__(self, event_dispatcher): + self.event_dispatcher = event_dispatcher + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "attention": ("ATTENTION",), + "filename": ("STRING", {"default": "attention.safetensor"}), + }, + } + + @classmethod + def IS_CHANGED(cls, *args, **kwargs): + return True + + RETURN_TYPES = () + FUNCTION = "save_attention" + CATEGORY = "operations" + OUTPUT_NODE = True + + def save_attention(self, attention, filename): + comfy.utils.save_attention(attention, filename) + return {"ui": {"message": "Saved attention to " + filename}} + + + +class LoadAttention: + @classmethod + def __init__(self, event_dispatcher): + self.event_dispatcher = event_dispatcher + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "filename": ("STRING", {"default": "attention.safetensor"}), + }, + } + + RETURN_TYPES = ("ATTENTION",) + FUNCTION = "load_attention" + CATEGORY = "operations" + + def load_attention(self, filename): + return (comfy.utils.load_attention(filename),) + + NODE_CLASS_MAPPINGS = { "KSampler": KSampler, @@ -1382,9 +1432,11 @@ NODE_CLASS_MAPPINGS = { "CheckpointLoader": CheckpointLoader, "DiffusersLoader": DiffusersLoader, "FrameCounter": FrameCounter, - "PrinterNode": PrinterNode, + "PrinterNode": PrintNode, "EventListener": EventListener, "MuxLatent": MuxLatent, + 
"SaveAttention": SaveAttention, + "LoadAttention": LoadAttention, } NODE_DISPLAY_NAME_MAPPINGS = { @@ -1440,6 +1492,8 @@ NODE_DISPLAY_NAME_MAPPINGS = { "PrinterNode": "Print", "EventListener": "Event Listener", "MuxLatent": "Mux Latent", + "SaveAttention": "Save Attention", + "LoadAttention": "Load Attention", } def load_custom_node(module_path):