diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py
index 4bef57580..b6c97e9b0 100644
--- a/comfy/k_diffusion/sampling.py
+++ b/comfy/k_diffusion/sampling.py
@@ -590,14 +590,11 @@ def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=No
     t_fn = lambda sigma: sigma.log().neg()
     old_denoised = None
-    attention_out = None
+    attention_out = []
     for i in trange(len(sigmas) - 1, disable=disable):
-        if attention is not None:
-            extra_args['attention'] = attention[i]
+        extra_args['model_options']['transformer_options']['attention_step'] = i
         denoised, attn = model(x, sigmas[i] * s_in, **extra_args)
-        if attention_out is None:
-            attention_out = torch.empty((len(sigmas), *attn.shape), dtype=attn.dtype, device=attn.device)
-        attention_out[i] = attn
+        attention_out.append(attn)
         if callback is not None:
             callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
diff --git a/comfy/ldm/models/diffusion/ddpm.py b/comfy/ldm/models/diffusion/ddpm.py
index 22b6f52aa..52fda9cb8 100644
--- a/comfy/ldm/models/diffusion/ddpm.py
+++ b/comfy/ldm/models/diffusion/ddpm.py
@@ -846,7 +846,7 @@ class LatentDiffusion(DDPM):
             c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
         return self.p_losses(x, c, t, *args, **kwargs)
 
-    def apply_model(self, x_noisy, t, cond, return_ids=False):
+    def apply_model(self, x_noisy, t, cond, return_ids=False, return_attention=False):
         if isinstance(cond, dict):
             # hybrid case, cond is expected to be a dict
             pass
@@ -859,7 +859,12 @@ class LatentDiffusion(DDPM):
         x_recon, attn = self.model(x_noisy, t, **cond)
 
         if isinstance(x_recon, tuple) and not return_ids:
-            return x_recon[0]
+            x_recon = x_recon[0]
+        else:
+            x_recon = x_recon
+
+        if return_attention:
+            return x_recon, attn
         else:
             return x_recon
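The sampler change above swaps the preallocated attention buffer for a plain Python list, since each step now yields a nested structure of tensors rather than a single fixed-shape tensor, and it records the current step index in transformer_options so downstream attention layers know which recorded entry to reuse. A minimal runnable sketch of that bookkeeping; fake_model and all shapes are invented for illustration:

import torch

def fake_model(x, sigma, model_options=None):
    # stand-in for the denoiser: returns a denoised latent plus a per-step attention payload
    return x * 0.9, [[(torch.rand(8, 16, 16), torch.rand(8, 16, 77))]]

sigmas = torch.linspace(1.0, 0.0, 5)
x = torch.randn(1, 4, 8, 8)
model_options = {'transformer_options': {}}

attention_out = []                                              # one entry per sampling step
for i in range(len(sigmas) - 1):
    model_options['transformer_options']['attention_step'] = i  # which recorded step to look up later
    x, attn = fake_model(x, sigmas[i], model_options=model_options)
    attention_out.append(attn)                                  # nested lists, so shapes may differ per layer

print(len(attention_out))                                       # 4, one per step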
transformer_options["original_shape"]) n = u(self.attn1(m(n), context=context_attn1, value=value_attn1)) else: if return_attention: - n, attn1_weights = self.attn1(n, context=context_attn1, value=value_attn1, return_attention=True) + n, attn1_weights = self.attn1(n, context=context_attn1, value=value_attn1, return_attention=True, + attention_to_mux=attn1_to_mux, attention_weight=attention_weight) else: - n = self.attn1(n, context=context_attn1, value=value_attn1) + n = self.attn1(n, context=context_attn1, value=value_attn1, attention_to_mux=attn1_to_mux, + attention_weight=attention_weight) + + # Interpolate n with attn1_to_mux + # if transformer_options.get("middle_block_step", None) is not None: + # if transformer_options.get("attention_weight", 0.0) > 0.0: + # attentions_to_mux_array = transformer_options.get("attention", None) + # if attentions_to_mux_array is not None: + # a = transformer_options["attention_step"] + # b = transformer_options["middle_block_step"] + # c = transformer_options["transformer_block_step"] + # attn1_to_mux = attentions_to_mux_array[a][b][c][0] + # attention_weight = transformer_options["attention_weight"] + # n = n * (1 - attention_weight) + attn1_to_mux * attention_weight + # print(f"muxed n with attn1_to_mux") x += n if "middle_patch" in transformer_patches: @@ -586,9 +618,24 @@ class BasicTransformerBlock(nn.Module): if return_attention: - n, attn2_weights = self.attn2(n, context=context_attn2, value=value_attn2, return_attention=True) + n, attn2_weights = self.attn2(n, context=context_attn2, value=value_attn2, return_attention=True, + attention_to_mux=attn2_to_mux, attention_weight=attention_weight) else: - n = self.attn2(n, context=context_attn2, value=value_attn2) + n = self.attn2(n, context=context_attn2, value=value_attn2, attention_to_mux=attn2_to_mux, + attention_weight=attention_weight) + + # Interpolate n with attn2_to_mux + # if transformer_options.get("middle_block_step", None) is not None: + # if transformer_options.get("attention_weight", 0.0) > 0.0: + # attentions_to_mux_array = transformer_options.get("attention", None) + # if attentions_to_mux_array is not None: + # a = transformer_options["attention_step"] + # b = transformer_options["middle_block_step"] + # c = transformer_options["transformer_block_step"] + # attn2_to_mux = attentions_to_mux_array[a][b][c][1] + # attention_weight = transformer_options["attention_weight"] + # n = n * (1 - attention_weight) + attn2_to_mux * attention_weight + # print(f"muxed n with attn2_to_mux") x += n x = self.ff(self.norm3(x)) + x @@ -597,7 +644,7 @@ class BasicTransformerBlock(nn.Module): transformer_options["current_index"] += 1 if return_attention: - return x, (attn1_weights, attn2_weights) + return x, (attn1_weights.cpu(), attn2_weights.cpu()) else: return x @@ -660,6 +707,7 @@ class SpatialTransformer(nn.Module): attention_tensors = [] for i, block in enumerate(self.transformer_blocks): + transformer_options["transformer_block_step"] = i if transformer_options.get("return_attention", False): x, attention = block(x, context=context[i], transformer_options=transformer_options) attention_tensors.append(attention) diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py index f204a4773..0e35501c7 100644 --- a/comfy/ldm/modules/diffusionmodules/openaimodel.py +++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py @@ -6,17 +6,10 @@ import torch as th import torch.nn as nn import torch.nn.functional as F -from ldm.modules.diffusionmodules.util 
diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py
index f204a4773..0e35501c7 100644
--- a/comfy/ldm/modules/diffusionmodules/openaimodel.py
+++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py
@@ -6,17 +6,10 @@ import torch as th
 import torch.nn as nn
 import torch.nn.functional as F
 
-from ldm.modules.diffusionmodules.util import (
-    checkpoint,
-    conv_nd,
-    linear,
-    avg_pool_nd,
-    zero_module,
-    normalization,
-    timestep_embedding,
-)
-from ldm.modules.attention import SpatialTransformer
-from ldm.util import exists
+from comfy.ldm.modules.attention import SpatialTransformer
+from comfy.ldm.modules.diffusionmodules.util import (checkpoint, conv_nd, linear, avg_pool_nd, zero_module,
+                                                     normalization, timestep_embedding)
+from comfy.ldm.util import exists
 
 
 # dummy replace
@@ -81,7 +74,7 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
             if isinstance(layer, TimestepBlock):
                 x = layer(x, emb)
             elif isinstance(layer, SpatialTransformer):
-                if transformer_options.get("attention", False):
+                if transformer_options.get("return_attention", False):
                     x, attention = layer(x, context, transformer_options)
                 else:
                     x = layer(x, context, transformer_options)
@@ -825,6 +818,7 @@ class UNetModel(nn.Module):
         h = x.type(self.dtype)
         input_and_output_options = transformer_options.copy()
         input_and_output_options["return_attention"] = False  # No attention to be had in the input blocks
+        input_and_output_options["middle_block_step"] = None
         for id, module in enumerate(self.input_blocks):
             h = module(h, emb, context, input_and_output_options)
@@ -835,7 +829,9 @@ class UNetModel(nn.Module):
            hs.append(h)
 
         attention_tensors = []
+        num_attention_blocks = 0
         for i, module in enumerate(self.middle_block):
+            transformer_options["middle_block_step"] = num_attention_blocks
             if isinstance(module, AttentionBlock):
                 if transformer_options.get("return_attention", False):
                     h, attention = module(h, emb, context, transformer_options)
@@ -845,10 +841,14 @@ class UNetModel(nn.Module):
                     # h = h * combined_attention
                 else:
                     h = module(h, emb, context, transformer_options)
+                num_attention_blocks += 1
             elif isinstance(module, SpatialTransformer):
                 if transformer_options.get("return_attention", False):
                     h, attention = module(h, context, transformer_options)
                     attention_tensors.append(attention)
+                else:
+                    h = module(h, context, transformer_options)
+                num_attention_blocks += 1
             elif isinstance(module, TimestepBlock):
                 h = module(h, emb)
             else:
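The middle-block walk above is where the [b][c][d] part of the nesting comes from: each attention-bearing module (index b, tracked by middle_block_step) contributes one entry, and a SpatialTransformer's entry is itself a list with one (attn1, attn2) tuple per transformer block (index c). A small sketch with dummy tensors standing in for what the blocks return:

import torch

def fake_spatial_transformer(num_layers=2):
    # one (attn1, attn2) tuple per BasicTransformerBlock inside this module
    return [(torch.rand(8, 16, 16).cpu(), torch.rand(8, 16, 77).cpu()) for _ in range(num_layers)]

attention_tensors = []            # the [b][c][d] slice for the current step a
num_attention_blocks = 0
for _ in range(1):                # pretend middle block with one attention-bearing module
    attention_tensors.append(fake_spatial_transformer())
    num_attention_blocks += 1

print(num_attention_blocks, len(attention_tensors[0]), len(attention_tensors[0][0]))
# -> 1 module (b), 2 transformer layers (c), 2 tensors (d: attn1, attn2)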
diff --git a/comfy/ldm/modules/sub_quadratic_attention.py b/comfy/ldm/modules/sub_quadratic_attention.py
index f1e48d62a..981b85da2 100644
--- a/comfy/ldm/modules/sub_quadratic_attention.py
+++ b/comfy/ldm/modules/sub_quadratic_attention.py
@@ -96,6 +96,8 @@ def _query_chunk_attention(
     summarize_chunk: SummarizeChunk,
     kv_chunk_size: int,
     return_attention: bool,
+    attention_to_mux: Optional[Tensor] = None,
+    attention_weight: float = 0.0,
 ) -> Tensor:
     batch_x_heads, k_channels_per_head, k_tokens = key_t.shape
     _, _, v_channels_per_head = value.shape
@@ -140,6 +142,8 @@ def _get_attention_scores_no_kv_chunking(
     scale: float,
     upcast_attention: bool,
     return_attention: bool,
+    attention_to_mux: Optional[Tensor] = None,
+    attention_weight: float = 0.0,
 ) -> Tensor:
     if upcast_attention:
         with torch.autocast(enabled=False, device_type = 'cuda'):
@@ -172,6 +176,11 @@ def _get_attention_scores_no_kv_chunking(
         attn_scores /= summed
         attn_probs = attn_scores
 
+    if attention_to_mux is not None:
+        attention_to_mux = attention_to_mux.to(attn_probs.device)
+        attn_probs = attn_probs * (1 - attention_weight) + attention_to_mux * attention_weight
+        print(f"muxed attention with weight {attention_weight}")
+
     hidden_states_slice = torch.bmm(attn_probs, value)
 
     if return_attention:
@@ -193,6 +202,8 @@ def efficient_dot_product_attention(
     use_checkpoint=True,
     upcast_attention=False,
     return_attention=False,
+    attention_to_mux: Optional[Tensor] = None,
+    attention_weight: float = 0.0,
 ):
     """Computes efficient dot-product attention given query, transposed key, and value.
     This is efficient version of attention presented in
@@ -235,6 +246,8 @@ def efficient_dot_product_attention(
         scale=scale,
         upcast_attention=upcast_attention,
         return_attention=return_attention,
+        attention_to_mux=attention_to_mux,
+        attention_weight=attention_weight,
     ) if k_tokens <= kv_chunk_size else (
         # fast-path for when there's just 1 key-value chunk per query chunk (this is just sliced attention btw)
         partial(
@@ -242,6 +255,8 @@ def efficient_dot_product_attention(
             kv_chunk_size=kv_chunk_size,
             summarize_chunk=summarize_chunk,
             return_attention=return_attention,
+            attention_to_mux=attention_to_mux,
+            attention_weight=attention_weight,
         )
     )
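The blend itself is a plain convex combination of attention probabilities, (1 - attention_weight) * new + attention_weight * recorded, applied just before the probabilities are multiplied with the values. A runnable sketch of that muxing step; the (batch*heads, query tokens, key tokens) shapes are invented, and the blend only makes sense when the recorded map matches the current resolution and token count:

import torch

def mux_attention(attn_probs, attention_to_mux=None, attention_weight=0.0):
    # Blend fresh attention probabilities with a previously recorded map,
    # mirroring the branch added to _get_attention_scores_no_kv_chunking.
    if attention_to_mux is None or attention_weight <= 0.0:
        return attn_probs
    attention_to_mux = attention_to_mux.to(attn_probs.device)
    return attn_probs * (1 - attention_weight) + attention_to_mux * attention_weight

new_attn = torch.softmax(torch.rand(8, 16, 16), dim=-1)   # (batch*heads, q_tokens, k_tokens)
old_attn = torch.softmax(torch.rand(8, 16, 16), dim=-1)
mixed = mux_attention(new_attn, old_attn, attention_weight=0.5)

# A convex blend of row-stochastic maps is still row-stochastic.
print(torch.allclose(mixed.sum(dim=-1), torch.ones(8, 16)))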
diff --git a/comfy/sample.py b/comfy/sample.py
index 0676329f0..564ab1e5a 100644
--- a/comfy/sample.py
+++ b/comfy/sample.py
@@ -58,13 +58,12 @@ def cleanup_additional_models(models):
 
 def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0,
            disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, noise_mask=None,
-           sigmas=None, callback=None, attention=None):
+           sigmas=None, callback=None, attention=None, attention_weight=0.0):
     device = comfy.model_management.get_torch_device()
 
     if noise_mask is not None:
         noise_mask = prepare_mask(noise_mask, noise.shape, device)
 
-    real_model = None
     comfy.model_management.load_model_gpu(model)
     real_model = model.model
 
@@ -79,12 +78,16 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative
     sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name,
                                       scheduler=scheduler, denoise=denoise, model_options=model.model_options)
 
+    transformer_options = model.model_options['transformer_options']
+    if transformer_options is not None and attention is not None and attention_weight > 0.0:
+        transformer_options['attention'] = attention
+        transformer_options['attention_weight'] = attention_weight
+
     samples, attention = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image,
                                         start_step=start_step, last_step=last_step,
                                         force_full_denoise=force_full_denoise, denoise_mask=noise_mask,
-                                        sigmas=sigmas, callback=callback, attention=attention)
+                                        sigmas=sigmas, callback=callback)
     samples = samples.cpu()
-    attention = attention.cpu()
 
     cleanup_additional_models(models)
     return samples, attention
diff --git a/comfy/samplers.py b/comfy/samplers.py
index b52993cd1..2301fa53a 100644
--- a/comfy/samplers.py
+++ b/comfy/samplers.py
@@ -9,8 +9,8 @@ from .ldm.modules.diffusionmodules.util import make_ddim_timesteps
 
 #The main sampling function shared by all the samplers
 #Returns predicted noise
-def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, cond_concat=None, model_options={},
-                      attention=None, attention_weight=0.5):
+def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, cond_concat=None, model_options={}):
+
     def get_area_and_mult(cond, x_in, cond_concat_in, timestep_in):
         area = (x_in.shape[2], x_in.shape[3], 0, 0)
         strength = 1.0
@@ -127,7 +127,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
         return out
 
     def calc_cond_uncond_batch(model_function, cond, uncond, x_in, timestep, max_total_area, cond_concat_in,
-                               model_options, attention=None, attention_weight=0.5):
+                               model_options):
         out_cond = torch.zeros_like(x_in)
         out_count = torch.ones_like(x_in)/100000.0
 
@@ -137,6 +137,10 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
         COND = 0
         UNCOND = 1
 
+        transformer_options = {}
+        if 'transformer_options' in model_options:
+            transformer_options = model_options['transformer_options'].copy()
+
         to_run = []
         for x in cond:
             p = get_area_and_mult(x, x_in, cond_concat_in, timestep)
@@ -199,9 +203,6 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
                 # mixed_attention = attention_weight * torch.cat(attention) + (1 - attention_weight) * generated_attention
                 # c['c_crossattn'] = [mixed_attention]
 
-            transformer_options = {}
-            if 'transformer_options' in model_options:
-                transformer_options = model_options['transformer_options'].copy()
 
             if patches is not None:
                 if "patches" in transformer_options:
@@ -217,7 +218,11 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
            # transformer_options['return_attention'] = True
             c['transformer_options'] = transformer_options
 
-            output = model_function(input_x, timestep_, cond=c).chunk(batch_chunks)
+            if transformer_options.get("return_attention", False):
+                output, attn = model_function(input_x, timestep_, cond=c, return_attention=True)
+                output = output.chunk(batch_chunks)
+            else:
+                output = model_function(input_x, timestep_, cond=c, return_attention=False).chunk(batch_chunks)
             del input_x
 
             model_management.throw_exception_if_processing_interrupted()
@@ -236,17 +241,24 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
         out_uncond /= out_uncond_count
         del out_uncond_count
 
-        return out_cond, out_uncond
+        if transformer_options.get("return_attention", False):
+            return out_cond, out_uncond, attn
+        else:
+            return out_cond, out_uncond
 
     max_total_area = model_management.maximum_batch_area()
-    cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, timestep, max_total_area, cond_concat,
-                                          model_options, attention=attention)
+    cond, uncond, attn = calc_cond_uncond_batch(model_function, cond, uncond, x, timestep, max_total_area,
+                                                cond_concat, model_options)
     if "sampler_cfg_function" in model_options:
-        return model_options["sampler_cfg_function"](cond, uncond, cond_scale), cond[0]  # cond[0] is attention
+        retval = model_options["sampler_cfg_function"](cond, uncond, cond_scale), attn
     else:
-        return uncond + (cond - uncond) * cond_scale, cond[0]  # cond[0] is attention
+        retval = uncond + (cond - uncond) * cond_scale, attn
+    if model_options["transformer_options"].get("return_attention", False):
+        return retval
+    else:
+        return retval, attn
 
 class CompVisVDenoiser(k_diffusion_external.DiscreteVDDPMDenoiser):
     def __init__(self, model, quantize=False, device='cpu'):
@@ -263,7 +275,7 @@ class CFGNoisePredictor(torch.nn.Module):
         self.alphas_cumprod = model.alphas_cumprod
     def apply_model(self, x, timestep, cond, uncond, cond_scale, cond_concat=None, model_options={}, attention=None):
         out, attn = sampling_function(self.inner_model.apply_model, x, timestep, uncond, cond, cond_scale, cond_concat,
-                                      model_options=model_options, attention=attention)
+                                      model_options=model_options)
         return out, attn
 
@@ -594,6 +606,6 @@ class KSampler:
             samples = k_diffusion_sampling.sample_dpm_adaptive(self.model_k, noise, sigma_min, sigmas[0],
                                                                extra_args=extra_args, callback=k_callback)
         else:
            samples, attention = getattr(k_diffusion_sampling, "sample_{}".format(self.sampler))(self.model_k,
-                noise, sigmas, extra_args=extra_args, callback=k_callback, attention=attention)
+                noise, sigmas, extra_args=extra_args, callback=k_callback)
         return samples.to(torch.float32), attention
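With the attention plumbing moved into transformer_options, sampling_function itself is back to plain classifier-free guidance: the conditional and unconditional predictions are blended with cond_scale, and the captured attention is simply carried alongside the result. A runnable sketch of that blend, with random stand-in tensors in place of real model outputs:

import torch

cond = torch.randn(1, 4, 8, 8)     # stand-in for the conditional prediction
uncond = torch.randn(1, 4, 8, 8)   # stand-in for the unconditional prediction
cond_scale = 7.5

guided = uncond + (cond - uncond) * cond_scale

# cond_scale = 1.0 recovers the conditional prediction
assert torch.allclose(uncond + (cond - uncond) * 1.0, cond)
print(guided.shape)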
diff --git a/comfy/utils.py b/comfy/utils.py
index 26ebac2c4..40cb8d04c 100644
--- a/comfy/utils.py
+++ b/comfy/utils.py
@@ -2,8 +2,7 @@ import os
 import torch
 import folder_paths
 import numpy as np
-from safetensors.torch import save_file, safe_open
-
+import joblib
 
 def load_torch_file(ckpt, safe_load=False):
     if ckpt.lower().endswith(".safetensors"):
@@ -28,20 +27,25 @@ def save_latent(samples, filename):
     np.save(filename, samples)
 
+# attention[a][b][c][d]
+# a: number of steps/sigma in this diffusion process
+# b: number of SpatialTransformer or AttentionBlocks used in the middle blocks of the latent diffusion model
+# c: number of transformer layers in the SpatialTransformer or AttentionBlocks
+# d: attn1, attn2
 def save_attention(attention, filename):
     filename = os.path.join(folder_paths.get_output_directory(), filename)
-    save_file({"attention": attention}, filename)
+    joblib.dump(attention, filename)
     print(f"Attention tensor saved to {filename}")
 
+# returns attention[a][b][c][d]
+# a: number of steps/sigma in this diffusion process
+# b: number of SpatialTransformer or AttentionBlocks used in the middle blocks of the latent diffusion model
+# c: number of transformer layers in the SpatialTransformer or AttentionBlocks
+# d: attn1, attn2
 def load_attention(filename):
-    tensors = {}
     filename = os.path.join(folder_paths.get_output_directory(), filename)
-    with safe_open(filename, framework='pt', device=0) as f:
-        for k in f.keys():
-            tensors[k] = f.get_tensor(k)
-    return tensors['attention']
-
+    return joblib.load(filename)
 
 def load_latent(filename):
     filename = os.path.join(folder_paths.get_output_directory(), filename)
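save_attention and load_attention switch from safetensors to joblib because attention[a][b][c][d] is a ragged, nested Python structure of CPU tensors rather than a flat name-to-tensor dict. A runnable round-trip sketch; the structure sizes, shapes, and the attention.joblib filename are made up for illustration (the real helpers resolve the path via folder_paths.get_output_directory()):

import joblib
import torch

# attention[a][b][c][d]: steps -> middle-block modules -> transformer layers -> (attn1, attn2)
attention = [
    [  # step a = 0
        [  # middle-block module b = 0
            (torch.rand(8, 16, 16), torch.rand(8, 16, 77)),  # layer c = 0
        ],
    ],
]

joblib.dump(attention, "attention.joblib")
loaded = joblib.load("attention.joblib")
print(loaded[0][0][0][0].shape, loaded[0][0][0][1].shape)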
{"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}), } } @@ -915,10 +924,11 @@ class KSampler: CATEGORY = "sampling" def sample(self, model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, - denoise=1.0, attention=None): + denoise=1.0, attention=None, attention_weight=0.0): model.model_options["transformer_options"]["return_attention"] = True + return common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, - denoise=denoise, attention=attention) + denoise=denoise, attention=attention, attention_weight=attention_weight) class KSamplerAdvanced: def __init__(self, event_dispatcher): @@ -1338,8 +1348,19 @@ class PrintNode: print(f"Latent hash: {latent_hash}") print(np.array2string(latent["samples"].cpu().numpy(), separator=', ')) + # attention[a][b][c][d] + # a: number of steps/sigma in this diffusion process + # b: number of SpatialTransformer or AttentionBlocks used in the middle blocks of the latent diffusion model + # c: number of transformer layers in the SpatialTransformer or AttentionBlocks + # d: attn1, attn2 if attention is not None: - print(np.array2string(attention.cpu().numpy(), separator=', ')) + print(f'attention has {len(attention)} steps') + print(f'each step has {len(attention[0])} transformer blocks') + print(f'each block has {len(attention[0][0])} transformer layers') + print(f'each transformer layer has {len(attention[0][0][0])} attention tensors (attn1, attn2)') + print(f'the shape of the attention tensors is {attention[0][0][0][0].shape}') + print(f'the first value of the first attention tensor is {attention[0][0][0][0][:1]}') + if text is not None: print(text) @@ -1368,6 +1389,11 @@ class SaveAttention: CATEGORY = "operations" OUTPUT_NODE = True + # attention[a][b][c][d] + # a: number of steps/sigma in this diffusion process + # b: number of SpatialTransformer or AttentionBlocks used in the middle blocks of the latent diffusion model + # c: number of transformer layers in the SpatialTransformer or AttentionBlocks + # d: attn1, attn2 def save_attention(self, attention, filename): comfy.utils.save_attention(attention, filename) return {"ui": {"message": "Saved attention to " + filename}} diff --git a/requirements.txt b/requirements.txt index 0527b31df..a47455609 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,3 +9,4 @@ pytorch_lightning aiohttp accelerate pyyaml +joblib