mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-10 13:32:36 +08:00
Add returned attention
KSampler's dpmpp_2m now returns attention, and nodes were added for loading and saving attention. The PrintNode was modified so it can print attention. Still to do: return attention from the other samplers as well.
This commit is contained in:
parent
e42746498d
commit
5e062a88de
@ -111,8 +111,8 @@ class DiscreteEpsDDPMDenoiser(DiscreteSchedule):
|
||||
|
||||
def forward(self, input, sigma, **kwargs):
    """Denoise *input* at noise level *sigma*.

    Returns a tuple ``(denoised, attn)``: the denoised sample plus the
    attention tensor propagated up from ``get_eps`` (post-change behavior;
    the scraped diff interleaved the old single-value return with this one).
    """
    # Broadcast the per-sigma scaling factors to the input's rank.
    c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
    # get_eps now returns (eps, attention) instead of just eps.
    eps, attn = self.get_eps(input * c_in, self.sigma_to_t(sigma), **kwargs)
    return input + eps * c_out, attn
|
||||
|
||||
class OpenAIDenoiser(DiscreteEpsDDPMDenoiser):
|
||||
|
||||
@ -590,8 +590,13 @@ def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=No
|
||||
t_fn = lambda sigma: sigma.log().neg()
|
||||
old_denoised = None
|
||||
|
||||
attention = None
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
denoised, attn = model(x, sigmas[i] * s_in, **extra_args)
|
||||
if attention is None:
|
||||
attention = torch.empty((len(sigmas), *attn.shape), dtype=attn.dtype, device=attn.device)
|
||||
attention[i] = attn
|
||||
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
|
||||
@ -604,4 +609,4 @@ def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=No
|
||||
denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised
|
||||
x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d
|
||||
old_denoised = denoised
|
||||
return x
|
||||
return x, attention
|
||||
|
||||
@ -76,8 +76,9 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative
|
||||
|
||||
sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)
|
||||
|
||||
samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback)
|
||||
samples, attention = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback)
|
||||
samples = samples.cpu()
|
||||
attention = attention.cpu()
|
||||
|
||||
cleanup_additional_models(models)
|
||||
return samples
|
||||
return samples, attention
|
||||
|
||||
@ -38,17 +38,17 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
|
||||
if (area[1] + area[3]) < x_in.shape[3]:
|
||||
for t in range(rr):
|
||||
mult[:,:,:,area[1] - 1 - t:area[1] - t] *= ((1.0/rr) * (t + 1))
|
||||
conditionning = {}
|
||||
conditionning['c_crossattn'] = cond[0]
|
||||
conditioning = {}
|
||||
conditioning['c_crossattn'] = cond[0]
|
||||
if cond_concat_in is not None and len(cond_concat_in) > 0:
|
||||
cropped = []
|
||||
for x in cond_concat_in:
|
||||
cr = x[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
|
||||
cropped.append(cr)
|
||||
conditionning['c_concat'] = torch.cat(cropped, dim=1)
|
||||
conditioning['c_concat'] = torch.cat(cropped, dim=1)
|
||||
|
||||
if adm_cond is not None:
|
||||
conditionning['c_adm'] = adm_cond
|
||||
conditioning['c_adm'] = adm_cond
|
||||
|
||||
control = None
|
||||
if 'control' in cond[1]:
|
||||
@ -67,7 +67,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
|
||||
|
||||
patches['middle_patch'] = [gligen_patch]
|
||||
|
||||
return (input_x, mult, conditionning, area, control, patches)
|
||||
return (input_x, mult, conditioning, area, control, patches)
|
||||
|
||||
def cond_equal_size(c1, c2):
|
||||
if c1 is c2:
|
||||
@ -234,9 +234,9 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
|
||||
max_total_area = model_management.maximum_batch_area()
|
||||
cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, timestep, max_total_area, cond_concat, model_options)
|
||||
if "sampler_cfg_function" in model_options:
|
||||
return model_options["sampler_cfg_function"](cond, uncond, cond_scale)
|
||||
return model_options["sampler_cfg_function"](cond, uncond, cond_scale), cond[0] # cond[0] is attention
|
||||
else:
|
||||
return uncond + (cond - uncond) * cond_scale
|
||||
return uncond + (cond - uncond) * cond_scale, cond[0] # cond[0] is attention
|
||||
|
||||
|
||||
class CompVisVDenoiser(k_diffusion_external.DiscreteVDDPMDenoiser):
|
||||
@ -253,8 +253,8 @@ class CFGNoisePredictor(torch.nn.Module):
|
||||
self.inner_model = model
|
||||
self.alphas_cumprod = model.alphas_cumprod
|
||||
def apply_model(self, x, timestep, cond, uncond, cond_scale, cond_concat=None, model_options=None):
    """Run classifier-free-guidance sampling through the wrapped model.

    Returns ``(out, attn)``: the guided prediction and the attention tensor
    that ``sampling_function`` now also returns.
    """
    # Bug fix: avoid a mutable default argument (`model_options={}`);
    # a fresh empty dict is created per call, preserving behavior.
    if model_options is None:
        model_options = {}
    out, attn = sampling_function(self.inner_model.apply_model, x, timestep, uncond, cond, cond_scale, cond_concat, model_options=model_options)
    return out, attn
|
||||
|
||||
class KSamplerX0Inpaint(torch.nn.Module):
|
||||
@ -265,13 +265,13 @@ class KSamplerX0Inpaint(torch.nn.Module):
|
||||
if denoise_mask is not None:
|
||||
latent_mask = 1. - denoise_mask
|
||||
x = x * denoise_mask + (self.latent_image + self.noise * sigma.reshape([sigma.shape[0]] + [1] * (len(self.noise.shape) - 1))) * latent_mask
|
||||
out = self.inner_model(x, sigma, cond=cond, uncond=uncond, cond_scale=cond_scale, cond_concat=cond_concat, model_options=model_options)
|
||||
out, attn = self.inner_model(x, sigma, cond=cond, uncond=uncond, cond_scale=cond_scale, cond_concat=cond_concat, model_options=model_options)
|
||||
if denoise_mask is not None:
|
||||
out *= denoise_mask
|
||||
|
||||
if denoise_mask is not None:
|
||||
out += self.latent_image * latent_mask
|
||||
return out
|
||||
return out, attn
|
||||
|
||||
def simple_scheduler(model, steps):
|
||||
sigs = []
|
||||
@ -580,6 +580,6 @@ class KSampler:
|
||||
elif self.sampler == "dpm_adaptive":
|
||||
samples = k_diffusion_sampling.sample_dpm_adaptive(self.model_k, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=k_callback)
|
||||
else:
|
||||
samples = getattr(k_diffusion_sampling, "sample_{}".format(self.sampler))(self.model_k, noise, sigmas, extra_args=extra_args, callback=k_callback)
|
||||
samples, attention = getattr(k_diffusion_sampling, "sample_{}".format(self.sampler))(self.model_k, noise, sigmas, extra_args=extra_args, callback=k_callback)
|
||||
|
||||
return samples.to(torch.float32)
|
||||
return samples.to(torch.float32), attention
|
||||
|
||||
@ -2,6 +2,7 @@ import os
|
||||
import torch
|
||||
import folder_paths
|
||||
import numpy as np
|
||||
from safetensors.torch import save_file, safe_open
|
||||
|
||||
|
||||
def load_torch_file(ckpt, safe_load=False):
|
||||
@ -27,6 +28,21 @@ def save_latent(samples, filename):
|
||||
np.save(filename, samples)
|
||||
|
||||
|
||||
def save_attention(attention, filename):
    """Save an attention tensor to *filename* (relative to the output dir) as safetensors."""
    # Resolve the filename against ComfyUI's configured output directory.
    filename = os.path.join(folder_paths.get_output_directory(), filename)
    save_file({"attention": attention}, filename)
    # Bug fix: the f-string contained no placeholder, so the message never
    # reported the destination; include the resolved path.
    print(f"Attention tensor saved to {filename}")
|
||||
|
||||
def load_attention(filename):
    """Load the 'attention' tensor from a safetensors file in the output directory."""
    path = os.path.join(folder_paths.get_output_directory(), filename)
    # NOTE(review): device=0 appears to load straight onto GPU 0 — confirm intended.
    with safe_open(path, framework='pt', device=0) as f:
        tensors = {key: f.get_tensor(key) for key in f.keys()}
    return tensors['attention']
|
||||
|
||||
def load_latent(filename):
    """Read a latent previously written with np.save from the output directory."""
    path = os.path.join(folder_paths.get_output_directory(), filename)
    return torch.from_numpy(np.load(path))
|
||||
76
nodes.py
76
nodes.py
@ -665,9 +665,6 @@ class SaveLatent:
|
||||
def save(self, samples, filename):
|
||||
s = samples.copy()
|
||||
comfy.utils.save_latent(samples["samples"], filename)
|
||||
|
||||
@clas
|
||||
|
||||
return (samples,)
|
||||
|
||||
class LoadLatent:
|
||||
@ -882,12 +879,12 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
|
||||
if "noise_mask" in latent:
|
||||
noise_mask = latent["noise_mask"]
|
||||
|
||||
samples = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image,
|
||||
samples, attention = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image,
|
||||
denoise=denoise, disable_noise=disable_noise, start_step=start_step, last_step=last_step,
|
||||
force_full_denoise=force_full_denoise, noise_mask=noise_mask)
|
||||
out = latent.copy()
|
||||
out["samples"] = samples
|
||||
return (out, )
|
||||
return (out, attention)
|
||||
|
||||
class KSampler:
|
||||
def __init__(self, event_dispatcher):
|
||||
@ -907,7 +904,7 @@ class KSampler:
|
||||
"denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
|
||||
}}
|
||||
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
RETURN_TYPES = ("LATENT","ATTENTION")
|
||||
FUNCTION = "sample"
|
||||
|
||||
CATEGORY = "sampling"
|
||||
@ -1281,7 +1278,6 @@ class EventListener:
|
||||
return True
|
||||
|
||||
RETURN_TYPES = ("BOOL",)
|
||||
RETURN_NAMES = ("fired",)
|
||||
|
||||
FUNCTION = "listen"
|
||||
|
||||
@ -1299,7 +1295,7 @@ class EventListener:
|
||||
|
||||
return (self._fired,)
|
||||
|
||||
class PrinterNode:
|
||||
class PrintNode:
|
||||
|
||||
def __init__(self, event_dispatcher):
|
||||
self.event_dispatcher = event_dispatcher
|
||||
@ -1310,6 +1306,7 @@ class PrinterNode:
|
||||
"required": {},
|
||||
"optional": {
|
||||
"text": ("text",),
|
||||
"attention": ("ATTENTION",),
|
||||
"latent": ("LATENT",),
|
||||
}
|
||||
}
|
||||
@ -1322,16 +1319,69 @@ class PrinterNode:
|
||||
CATEGORY = "operations"
|
||||
OUTPUT_NODE = True
|
||||
|
||||
def print_value(self, text=None, latent=None, attention=None):
    """Print whichever optional inputs were wired in.

    - latent: prints a SHA-256 hash of the samples plus their full values.
    - attention: prints the attention tensor values.
    - text: printed only when provided (post-change behavior; the old
      unconditional print was removed in the diff).
    Returns the UI payload echoing *text*.
    """
    if latent is not None:
        # Hash gives a compact, stable identity for the latent across runs.
        latent_hash = hashlib.sha256(latent["samples"].cpu().numpy().tobytes()).hexdigest()
        print(f"Latent hash: {latent_hash}")
        print(np.array2string(latent["samples"].cpu().numpy(), separator=', '))

    if attention is not None:
        print(np.array2string(attention.cpu().numpy(), separator=', '))

    if text is not None:
        print(text)
    return {"ui": {"": text}}
|
||||
class SaveAttention:
    """Output node that writes an ATTENTION tensor to a safetensors file."""

    def __init__(self, event_dispatcher):
        # Bug fix: __init__ was decorated with a stray @classmethod, which
        # made instantiation set `event_dispatcher` on the class instead of
        # the instance; it must be a plain instance initializer.
        self.event_dispatcher = event_dispatcher

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "attention": ("ATTENTION",),
                "filename": ("STRING", {"default": "attention.safetensor"}),
            },
        }

    @classmethod
    def IS_CHANGED(cls, *args, **kwargs):
        # Always report changed so the node re-executes (saving is a side effect).
        return True

    RETURN_TYPES = ()
    FUNCTION = "save_attention"
    CATEGORY = "operations"
    OUTPUT_NODE = True

    def save_attention(self, attention, filename):
        comfy.utils.save_attention(attention, filename)
        return {"ui": {"message": "Saved attention to " + filename}}
|
||||
|
||||
|
||||
class LoadAttention:
    """Node that loads an ATTENTION tensor from a safetensors file."""

    def __init__(self, event_dispatcher):
        # Bug fix: removed the stray @classmethod that decorated __init__,
        # which broke per-instance state (it bound to the class instead).
        self.event_dispatcher = event_dispatcher

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "filename": ("STRING", {"default": "attention.safetensor"}),
            },
        }

    RETURN_TYPES = ("ATTENTION",)
    FUNCTION = "load_attention"
    CATEGORY = "operations"

    def load_attention(self, filename):
        return (comfy.utils.load_attention(filename),)
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"KSampler": KSampler,
|
||||
@ -1382,9 +1432,11 @@ NODE_CLASS_MAPPINGS = {
|
||||
"CheckpointLoader": CheckpointLoader,
|
||||
"DiffusersLoader": DiffusersLoader,
|
||||
"FrameCounter": FrameCounter,
|
||||
"PrinterNode": PrinterNode,
|
||||
"PrinterNode": PrintNode,
|
||||
"EventListener": EventListener,
|
||||
"MuxLatent": MuxLatent,
|
||||
"SaveAttention": SaveAttention,
|
||||
"LoadAttention": LoadAttention,
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
@ -1440,6 +1492,8 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"PrinterNode": "Print",
|
||||
"EventListener": "Event Listener",
|
||||
"MuxLatent": "Mux Latent",
|
||||
"SaveAttention": "Save Attention",
|
||||
"LoadAttention": "Load Attention",
|
||||
}
|
||||
|
||||
def load_custom_node(module_path):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user