First working attention loading, maybe?

InconsolableCellist 2023-04-30 12:51:13 -06:00
parent f969ec5108
commit 2ae3b42b26
10 changed files with 171 additions and 60 deletions

View File

@@ -590,14 +590,11 @@ def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=No
     t_fn = lambda sigma: sigma.log().neg()
     old_denoised = None
-    attention_out = None
+    attention_out = []
     for i in trange(len(sigmas) - 1, disable=disable):
-        if attention is not None:
-            extra_args['attention'] = attention[i]
+        extra_args['model_options']['transformer_options']['attention_step'] = i
         denoised, attn = model(x, sigmas[i] * s_in, **extra_args)
-        if attention_out is None:
-            attention_out = torch.empty((len(sigmas), *attn.shape), dtype=attn.dtype, device=attn.device)
-        attention_out[i] = attn
+        attention_out.append(attn)
         if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
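
For reference, a minimal sketch (not part of the commit) of the collection pattern this hunk moves to: per-step attention maps go into a plain Python list, since their shapes can differ across blocks and a preallocated tensor no longer fits, and the current step index is threaded through extra_args so the attention layers know which stored maps to blend in. The `model` stub is assumed to return (denoised, attn) as in this commit, and the actual DPM++ update of x is omitted.

    # Illustrative sketch only; the real sampler also updates x each step.
    def collect_attention(model, x, sigmas, extra_args):
        s_in = x.new_ones([x.shape[0]])
        attention_out = []  # a list, not a preallocated tensor: shapes may vary
        for i in range(len(sigmas) - 1):
            # tell the attention layers which sampling step this is
            extra_args['model_options']['transformer_options']['attention_step'] = i
            denoised, attn = model(x, sigmas[i] * s_in, **extra_args)
            attention_out.append(attn)
        return attention_out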

View File

@@ -846,7 +846,7 @@ class LatentDiffusion(DDPM):
             c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
         return self.p_losses(x, c, t, *args, **kwargs)

-    def apply_model(self, x_noisy, t, cond, return_ids=False):
+    def apply_model(self, x_noisy, t, cond, return_ids=False, return_attention=False):
         if isinstance(cond, dict):
             # hybrid case, cond is expected to be a dict
             pass
@@ -859,7 +859,12 @@ class LatentDiffusion(DDPM):
         x_recon, attn = self.model(x_noisy, t, **cond)

         if isinstance(x_recon, tuple) and not return_ids:
-            return x_recon[0]
+            x_recon = x_recon[0]
+        else:
+            x_recon = x_recon
+
+        if return_attention:
+            return x_recon, attn
         else:
             return x_recon

View File

@@ -163,7 +163,8 @@ class CrossAttentionBirchSan(nn.Module):
             nn.Dropout(dropout)
         )

-    def forward(self, x, context=None, value=None, mask=None, return_attention=False):
+    def forward(self, x, context=None, value=None, mask=None, return_attention=False, attention_to_mux=None,
+                attention_weight=0.0):
         query = self.to_q(x)
         context = default(context, x)
         key = self.to_k(context)
@@ -228,6 +229,8 @@ class CrossAttentionBirchSan(nn.Module):
             use_checkpoint=self.training,
             upcast_attention=upcast_attention,
             return_attention=return_attention,
+            attention_to_mux=attention_to_mux,
+            attention_weight=attention_weight,
         )
         if return_attention:
             hidden_states, attention = output
@@ -559,14 +562,43 @@ class BasicTransformerBlock(nn.Module):
                 for p in patch:
                     n, context_attn1, value_attn1 = p(current_index, n, context_attn1, value_attn1)

+        attn1_to_mux = None
+        attn2_to_mux = None
+        attention_weight = 0
+        if transformer_options.get("middle_block_step", None) is not None:
+            if transformer_options.get("attention_weight", 0.0) > 0.0:
+                attentions_to_mux_array = transformer_options.get("attention", None)
+                if attentions_to_mux_array is not None:
+                    a = transformer_options["attention_step"]
+                    b = transformer_options["middle_block_step"]
+                    c = transformer_options["transformer_block_step"]
+                    attn1_to_mux = attentions_to_mux_array[a][b][c][0]
+                    attn2_to_mux = attentions_to_mux_array[a][b][c][1]
+                    attention_weight = transformer_options["attention_weight"]
+
         if "tomesd" in transformer_options:
             m, u = tomesd.get_functions(x, transformer_options["tomesd"]["ratio"], transformer_options["original_shape"])
             n = u(self.attn1(m(n), context=context_attn1, value=value_attn1))
         else:
             if return_attention:
-                n, attn1_weights = self.attn1(n, context=context_attn1, value=value_attn1, return_attention=True)
+                n, attn1_weights = self.attn1(n, context=context_attn1, value=value_attn1, return_attention=True,
+                                              attention_to_mux=attn1_to_mux, attention_weight=attention_weight)
             else:
-                n = self.attn1(n, context=context_attn1, value=value_attn1)
+                n = self.attn1(n, context=context_attn1, value=value_attn1, attention_to_mux=attn1_to_mux,
+                               attention_weight=attention_weight)
+
+        # Interpolate n with attn1_to_mux
+        # if transformer_options.get("middle_block_step", None) is not None:
+        #     if transformer_options.get("attention_weight", 0.0) > 0.0:
+        #         attentions_to_mux_array = transformer_options.get("attention", None)
+        #         if attentions_to_mux_array is not None:
+        #             a = transformer_options["attention_step"]
+        #             b = transformer_options["middle_block_step"]
+        #             c = transformer_options["transformer_block_step"]
+        #             attn1_to_mux = attentions_to_mux_array[a][b][c][0]
+        #             attention_weight = transformer_options["attention_weight"]
+        #             n = n * (1 - attention_weight) + attn1_to_mux * attention_weight
+        #             print(f"muxed n with attn1_to_mux")

         x += n
         if "middle_patch" in transformer_patches:
@@ -586,9 +618,24 @@ class BasicTransformerBlock(nn.Module):
         if return_attention:
-            n, attn2_weights = self.attn2(n, context=context_attn2, value=value_attn2, return_attention=True)
+            n, attn2_weights = self.attn2(n, context=context_attn2, value=value_attn2, return_attention=True,
+                                          attention_to_mux=attn2_to_mux, attention_weight=attention_weight)
         else:
-            n = self.attn2(n, context=context_attn2, value=value_attn2)
+            n = self.attn2(n, context=context_attn2, value=value_attn2, attention_to_mux=attn2_to_mux,
+                           attention_weight=attention_weight)
+
+        # Interpolate n with attn2_to_mux
+        # if transformer_options.get("middle_block_step", None) is not None:
+        #     if transformer_options.get("attention_weight", 0.0) > 0.0:
+        #         attentions_to_mux_array = transformer_options.get("attention", None)
+        #         if attentions_to_mux_array is not None:
+        #             a = transformer_options["attention_step"]
+        #             b = transformer_options["middle_block_step"]
+        #             c = transformer_options["transformer_block_step"]
+        #             attn2_to_mux = attentions_to_mux_array[a][b][c][1]
+        #             attention_weight = transformer_options["attention_weight"]
+        #             n = n * (1 - attention_weight) + attn2_to_mux * attention_weight
+        #             print(f"muxed n with attn2_to_mux")

         x += n
         x = self.ff(self.norm3(x)) + x
@@ -597,7 +644,7 @@ class BasicTransformerBlock(nn.Module):
         transformer_options["current_index"] += 1

         if return_attention:
-            return x, (attn1_weights, attn2_weights)
+            return x, (attn1_weights.cpu(), attn2_weights.cpu())
         else:
             return x
@@ -660,6 +707,7 @@ class SpatialTransformer(nn.Module):
         attention_tensors = []
         for i, block in enumerate(self.transformer_blocks):
+            transformer_options["transformer_block_step"] = i
             if transformer_options.get("return_attention", False):
                 x, attention = block(x, context=context[i], transformer_options=transformer_options)
                 attention_tensors.append(attention)
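
The lookup these blocks perform can be condensed into one helper; a sketch only, assuming the stored structure is attention[a][b][c][d] as documented later in this commit (sampling step, middle-block module, transformer layer, attn1/attn2):

    def lookup_mux_targets(transformer_options):
        # Returns (attn1_to_mux, attn2_to_mux, attention_weight); sketch only.
        if transformer_options.get("middle_block_step", None) is None:
            return None, None, 0.0
        if transformer_options.get("attention_weight", 0.0) <= 0.0:
            return None, None, 0.0
        stored = transformer_options.get("attention", None)
        if stored is None:
            return None, None, 0.0
        a = transformer_options["attention_step"]           # sampling step
        b = transformer_options["middle_block_step"]        # attention module in the middle block
        c = transformer_options["transformer_block_step"]   # transformer layer within it
        return stored[a][b][c][0], stored[a][b][c][1], transformer_options["attention_weight"]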

View File

@@ -6,17 +6,10 @@ import torch as th
 import torch.nn as nn
 import torch.nn.functional as F

-from ldm.modules.diffusionmodules.util import (
-    checkpoint,
-    conv_nd,
-    linear,
-    avg_pool_nd,
-    zero_module,
-    normalization,
-    timestep_embedding,
-)
-from ldm.modules.attention import SpatialTransformer
-from ldm.util import exists
+from comfy.ldm.modules.attention import SpatialTransformer
+from comfy.ldm.modules.diffusionmodules.util import (checkpoint, conv_nd, linear, avg_pool_nd, zero_module,
+                                                     normalization, timestep_embedding)
+from comfy.ldm.util import exists

 # dummy replace
@@ -81,7 +74,7 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
             if isinstance(layer, TimestepBlock):
                 x = layer(x, emb)
             elif isinstance(layer, SpatialTransformer):
-                if transformer_options.get("attention", False):
+                if transformer_options.get("return_attention", False):
                     x, attention = layer(x, context, transformer_options)
                 else:
                     x = layer(x, context, transformer_options)
@@ -825,6 +818,7 @@ class UNetModel(nn.Module):
         h = x.type(self.dtype)
         input_and_output_options = transformer_options.copy()
         input_and_output_options["return_attention"] = False  # No attention to be had in the input blocks
+        input_and_output_options["middle_block_step"] = None
         for id, module in enumerate(self.input_blocks):
             h = module(h, emb, context, input_and_output_options)
@@ -835,7 +829,9 @@ class UNetModel(nn.Module):
             hs.append(h)

         attention_tensors = []
+        num_attention_blocks = 0
         for i, module in enumerate(self.middle_block):
+            transformer_options["middle_block_step"] = num_attention_blocks
             if isinstance(module, AttentionBlock):
                 if transformer_options.get("return_attention", False):
                     h, attention = module(h, emb, context, transformer_options)
@@ -845,10 +841,14 @@ class UNetModel(nn.Module):
                     # h = h * combined_attention
                 else:
                     h = module(h, emb, context, transformer_options)
+                num_attention_blocks += 1
             elif isinstance(module, SpatialTransformer):
                 if transformer_options.get("return_attention", False):
                     h, attention = module(h, context, transformer_options)
                     attention_tensors.append(attention)
+                else:
+                    h = module(h, context, transformer_options)
+                num_attention_blocks += 1
             elif isinstance(module, TimestepBlock):
                 h = module(h, emb)
             else:
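
In short, every module in the middle block is tagged with the count of attention-bearing modules seen so far, so middle_block_step lines up with the b axis of the saved structure. A condensed sketch of that bookkeeping; `is_attention_module` is a hypothetical stand-in for the isinstance checks on AttentionBlock/SpatialTransformer and `modules` for self.middle_block:

    def tag_middle_block_steps(modules, transformer_options, is_attention_module):
        num_attention_blocks = 0
        for module in modules:
            # every module sees the index of the next attention-bearing module
            transformer_options["middle_block_step"] = num_attention_blocks
            if is_attention_module(module):
                num_attention_blocks += 1
        return num_attention_blocks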

View File

@@ -96,6 +96,8 @@ def _query_chunk_attention(
     summarize_chunk: SummarizeChunk,
     kv_chunk_size: int,
     return_attention: bool,
+    attention_to_mux: Optional[Tensor] = None,
+    attention_weight: float = 0.0,
 ) -> Tensor:
     batch_x_heads, k_channels_per_head, k_tokens = key_t.shape
     _, _, v_channels_per_head = value.shape
@@ -140,6 +142,8 @@ def _get_attention_scores_no_kv_chunking(
     scale: float,
     upcast_attention: bool,
     return_attention: bool,
+    attention_to_mux: Optional[Tensor] = None,
+    attention_weight: float = 0.0,
 ) -> Tensor:
     if upcast_attention:
         with torch.autocast(enabled=False, device_type = 'cuda'):
@@ -172,6 +176,11 @@ def _get_attention_scores_no_kv_chunking(
         attn_scores /= summed
         attn_probs = attn_scores

+    if attention_to_mux is not None:
+        attention_to_mux = attention_to_mux.to(attn_probs.device)
+        attn_probs = attn_probs * (1 - attention_weight) + attention_to_mux * attention_weight
+        print(f"muxed attention with weight {attention_weight}")
+
     hidden_states_slice = torch.bmm(attn_probs, value)

     if return_attention:
@@ -193,6 +202,8 @@ def efficient_dot_product_attention(
     use_checkpoint=True,
     upcast_attention=False,
     return_attention=False,
+    attention_to_mux: Optional[Tensor] = None,
+    attention_weight: float = 0.0,
 ):
     """Computes efficient dot-product attention given query, transposed key, and value.
     This is efficient version of attention presented in
@@ -235,6 +246,8 @@ def efficient_dot_product_attention(
         scale=scale,
         upcast_attention=upcast_attention,
         return_attention=return_attention,
+        attention_to_mux=attention_to_mux,
+        attention_weight=attention_weight,
     ) if k_tokens <= kv_chunk_size else (
         # fast-path for when there's just 1 key-value chunk per query chunk (this is just sliced attention btw)
         partial(
@@ -242,6 +255,8 @@ def efficient_dot_product_attention(
             kv_chunk_size=kv_chunk_size,
             summarize_chunk=summarize_chunk,
             return_attention=return_attention,
+            attention_to_mux=attention_to_mux,
+            attention_weight=attention_weight,
         )
     )
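
The blend inserted above is a convex combination of the freshly computed attention probabilities and the stored map. As a standalone sketch (assuming both tensors share a shape such as (batch*heads, q_tokens, k_tokens)): since each input is a softmax output whose rows sum to 1, the convex combination is still row-stochastic, so no renormalization is needed before the torch.bmm with the values.

    import torch

    def mux_attention(attn_probs: torch.Tensor, stored: torch.Tensor, weight: float) -> torch.Tensor:
        # Convex blend: rows of both inputs sum to 1, so rows of the result do too.
        stored = stored.to(device=attn_probs.device, dtype=attn_probs.dtype)
        return attn_probs * (1.0 - weight) + stored * weight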

View File

@@ -58,13 +58,12 @@ def cleanup_additional_models(models):
 def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0,
            disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, noise_mask=None,
-           sigmas=None, callback=None, attention=None):
+           sigmas=None, callback=None, attention=None, attention_weight=0.0):
     device = comfy.model_management.get_torch_device()

     if noise_mask is not None:
         noise_mask = prepare_mask(noise_mask, noise.shape, device)

-    real_model = None
     comfy.model_management.load_model_gpu(model)
     real_model = model.model
@@ -79,12 +78,16 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative
     sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler,
                                       denoise=denoise, model_options=model.model_options)

+    transformer_options = model.model_options['transformer_options']
+    if transformer_options is not None and attention is not None and attention_weight > 0.0:
+        transformer_options['attention'] = attention
+        transformer_options['attention_weight'] = attention_weight
+
     samples, attention = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image,
                                         start_step=start_step, last_step=last_step,
                                         force_full_denoise=force_full_denoise, denoise_mask=noise_mask,
-                                        sigmas=sigmas, callback=callback, attention=attention)
+                                        sigmas=sigmas, callback=callback)
     samples = samples.cpu()
-    attention = attention.cpu()
     cleanup_additional_models(models)
     return samples, attention

View File

@@ -9,8 +9,8 @@ from .ldm.modules.diffusionmodules.util import make_ddim_timesteps

 #The main sampling function shared by all the samplers
 #Returns predicted noise
-def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, cond_concat=None, model_options={},
-                      attention=None, attention_weight=0.5):
+def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, cond_concat=None, model_options={}):
     def get_area_and_mult(cond, x_in, cond_concat_in, timestep_in):
         area = (x_in.shape[2], x_in.shape[3], 0, 0)
         strength = 1.0
@@ -127,7 +127,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
         return out

     def calc_cond_uncond_batch(model_function, cond, uncond, x_in, timestep, max_total_area, cond_concat_in,
-                               model_options, attention=None, attention_weight=0.5):
+                               model_options):
         out_cond = torch.zeros_like(x_in)
         out_count = torch.ones_like(x_in)/100000.0
@@ -137,6 +137,10 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
         COND = 0
         UNCOND = 1

+        transformer_options = {}
+        if 'transformer_options' in model_options:
+            transformer_options = model_options['transformer_options'].copy()
+
         to_run = []
         for x in cond:
             p = get_area_and_mult(x, x_in, cond_concat_in, timestep)
@@ -199,9 +203,6 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
             # mixed_attention = attention_weight * torch.cat(attention) + (1 - attention_weight) * generated_attention
             # c['c_crossattn'] = [mixed_attention]

-            transformer_options = {}
-            if 'transformer_options' in model_options:
-                transformer_options = model_options['transformer_options'].copy()

             if patches is not None:
                 if "patches" in transformer_options:
@@ -217,7 +218,11 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
             # transformer_options['return_attention'] = True
             c['transformer_options'] = transformer_options

-            output = model_function(input_x, timestep_, cond=c).chunk(batch_chunks)
+            if transformer_options.get("return_attention", False):
+                output, attn = model_function(input_x, timestep_, cond=c, return_attention=True)
+                output = output.chunk(batch_chunks)
+            else:
+                output = model_function(input_x, timestep_, cond=c, return_attention=False).chunk(batch_chunks)
             del input_x

             model_management.throw_exception_if_processing_interrupted()
@@ -236,17 +241,24 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
         out_uncond /= out_uncond_count
         del out_uncond_count

-        return out_cond, out_uncond
+        if transformer_options.get("return_attention", False):
+            return out_cond, out_uncond, attn
+        else:
+            return out_cond, out_uncond

     max_total_area = model_management.maximum_batch_area()
-    cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, timestep, max_total_area, cond_concat,
-                                          model_options, attention=attention)
+    cond, uncond, attn = calc_cond_uncond_batch(model_function, cond, uncond, x, timestep, max_total_area,
+                                                cond_concat, model_options)
     if "sampler_cfg_function" in model_options:
-        return model_options["sampler_cfg_function"](cond, uncond, cond_scale), cond[0] # cond[0] is attention
+        retval = model_options["sampler_cfg_function"](cond, uncond, cond_scale), attn
     else:
-        return uncond + (cond - uncond) * cond_scale, cond[0] # cond[0] is attention
+        retval = uncond + (cond - uncond) * cond_scale, attn
+    if model_options["transformer_options"].get("return_attention", False):
+        return retval
+    else:
+        return retval, attn

 class CompVisVDenoiser(k_diffusion_external.DiscreteVDDPMDenoiser):
     def __init__(self, model, quantize=False, device='cpu'):
@@ -263,7 +275,7 @@ class CFGNoisePredictor(torch.nn.Module):
         self.alphas_cumprod = model.alphas_cumprod
     def apply_model(self, x, timestep, cond, uncond, cond_scale, cond_concat=None, model_options={}, attention=None):
         out, attn = sampling_function(self.inner_model.apply_model, x, timestep, uncond, cond, cond_scale, cond_concat,
-                                      model_options=model_options, attention=attention)
+                                      model_options=model_options)
         return out, attn
@@ -594,6 +606,6 @@ class KSampler:
                 samples = k_diffusion_sampling.sample_dpm_adaptive(self.model_k, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=k_callback)
             else:
                 samples, attention = getattr(k_diffusion_sampling, "sample_{}".format(self.sampler))(self.model_k,
-                    noise, sigmas, extra_args=extra_args, callback=k_callback, attention=attention)
+                    noise, sigmas, extra_args=extra_args, callback=k_callback)
         return samples.to(torch.float32), attention
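
For reference, the combine at the end of sampling_function is plain classifier-free guidance; a one-line sketch:

    def cfg_combine(cond, uncond, cond_scale):
        # Extrapolate from the unconditional prediction toward the conditional one;
        # cond_scale == 1.0 reproduces cond exactly, larger values amplify conditioning.
        return uncond + (cond - uncond) * cond_scale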

View File

@@ -2,8 +2,7 @@ import os
 import torch
 import folder_paths
 import numpy as np
-from safetensors.torch import save_file, safe_open
+import joblib

 def load_torch_file(ckpt, safe_load=False):
     if ckpt.lower().endswith(".safetensors"):
@@ -28,20 +27,25 @@ def save_latent(samples, filename):
     np.save(filename, samples)

+# attention[a][b][c][d]
+# a: number of steps/sigma in this diffusion process
+# b: number of SpatialTransformer or AttentionBlocks used in the middle blocks of the latent diffusion model
+# c: number of transformer layers in the SpatialTransformer or AttentionBlocks
+# d: attn1, attn2
 def save_attention(attention, filename):
     filename = os.path.join(folder_paths.get_output_directory(), filename)
-    save_file({"attention": attention}, filename)
+    joblib.dump(attention, filename)
     print(f"Attention tensor saved to {filename}")

+# returns attention[a][b][c][d]
+# a: number of steps/sigma in this diffusion process
+# b: number of SpatialTransformer or AttentionBlocks used in the middle blocks of the latent diffusion model
+# c: number of transformer layers in the SpatialTransformer or AttentionBlocks
+# d: attn1, attn2
 def load_attention(filename):
-    tensors = {}
     filename = os.path.join(folder_paths.get_output_directory(), filename)
-    with safe_open(filename, framework='pt', device=0) as f:
-        for k in f.keys():
-            tensors[k] = f.get_tensor(k)
-    return tensors['attention']
+    return joblib.load(filename)

 def load_latent(filename):
     filename = os.path.join(folder_paths.get_output_directory(), filename)
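
The switch from safetensors to joblib is because safetensors serializes a flat dict of tensors, while the structure here is a nested Python list of (attn1, attn2) tensor pairs. A round-trip sketch with a hypothetical toy structure (2 steps, 1 middle-block module, 1 transformer layer; the shape is illustrative):

    import joblib
    import torch

    # Toy stand-in for attention[a][b][c][d].
    attention = [[[(torch.rand(8, 64, 77), torch.rand(8, 64, 77))]] for _ in range(2)]

    joblib.dump(attention, "attention.pkl")  # arbitrary nested Python objects are fine
    loaded = joblib.load("attention.pkl")
    assert torch.equal(loaded[1][0][0][0], attention[1][0][0][0])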

View File

@@ -866,7 +866,8 @@ class SetLatentNoiseMask:
         return (s,)

 def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0,
-                    disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, attention=None):
+                    disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, attention=None,
+                    attention_weight=0.0):
     device = comfy.model_management.get_torch_device()
     latent_image = latent["samples"]
@@ -880,9 +881,16 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
     if "noise_mask" in latent:
         noise_mask = latent["noise_mask"]

-    samples, attention = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image,
-                                  denoise=denoise, disable_noise=disable_noise, start_step=start_step, last_step=last_step,
-                                  force_full_denoise=force_full_denoise, noise_mask=noise_mask, attention=attention)
+    samples, attention = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative,
+                                             latent_image, denoise=denoise, disable_noise=disable_noise, start_step=start_step,
+                                             last_step=last_step, force_full_denoise=force_full_denoise, noise_mask=noise_mask,
+                                             attention=attention, attention_weight=attention_weight)
+
+    # attention[a][b][c][d]
+    # a: number of steps/sigma in this diffusion process
+    # b: number of SpatialTransformer or AttentionBlocks used in the middle blocks of the latent diffusion model
+    # c: number of transformer layers in the SpatialTransformer or AttentionBlocks
+    # d: attn1, attn2
     out = latent.copy()
     out["samples"] = samples
     return (out, attention)
@@ -906,6 +914,7 @@ class KSampler:
             },
             "optional": {
                 "attention": ("ATTENTION",),
+                "attention_weight": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
             }
         }
@@ -915,10 +924,11 @@ class KSampler:
     CATEGORY = "sampling"

     def sample(self, model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image,
-               denoise=1.0, attention=None):
+               denoise=1.0, attention=None, attention_weight=0.0):
         model.model_options["transformer_options"]["return_attention"] = True
         return common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image,
-                               denoise=denoise, attention=attention)
+                               denoise=denoise, attention=attention, attention_weight=attention_weight)

 class KSamplerAdvanced:
     def __init__(self, event_dispatcher):
@@ -1338,8 +1348,19 @@ class PrintNode:
             print(f"Latent hash: {latent_hash}")
             print(np.array2string(latent["samples"].cpu().numpy(), separator=', '))

+        # attention[a][b][c][d]
+        # a: number of steps/sigma in this diffusion process
+        # b: number of SpatialTransformer or AttentionBlocks used in the middle blocks of the latent diffusion model
+        # c: number of transformer layers in the SpatialTransformer or AttentionBlocks
+        # d: attn1, attn2
         if attention is not None:
-            print(np.array2string(attention.cpu().numpy(), separator=', '))
+            print(f'attention has {len(attention)} steps')
+            print(f'each step has {len(attention[0])} transformer blocks')
+            print(f'each block has {len(attention[0][0])} transformer layers')
+            print(f'each transformer layer has {len(attention[0][0][0])} attention tensors (attn1, attn2)')
+            print(f'the shape of the attention tensors is {attention[0][0][0][0].shape}')
+            print(f'the first value of the first attention tensor is {attention[0][0][0][0][:1]}')
         if text is not None:
             print(text)
@@ -1368,6 +1389,11 @@ class SaveAttention:
     CATEGORY = "operations"
     OUTPUT_NODE = True

+    # attention[a][b][c][d]
+    # a: number of steps/sigma in this diffusion process
+    # b: number of SpatialTransformer or AttentionBlocks used in the middle blocks of the latent diffusion model
+    # c: number of transformer layers in the SpatialTransformer or AttentionBlocks
+    # d: attn1, attn2
     def save_attention(self, attention, filename):
         comfy.utils.save_attention(attention, filename)
         return {"ui": {"message": "Saved attention to " + filename}}

View File

@@ -9,3 +9,4 @@ pytorch_lightning
 aiohttp
 accelerate
 pyyaml
+joblib