Mirror of https://github.com/comfyanonymous/ComfyUI.git
First working attention loading, maybe?
Commit 2ae3b42b26 (parent f969ec5108)
@@ -590,14 +590,11 @@ def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=No
     t_fn = lambda sigma: sigma.log().neg()
     old_denoised = None

-    attention_out = None
+    attention_out = []
     for i in trange(len(sigmas) - 1, disable=disable):
-        if attention is not None:
-            extra_args['attention'] = attention[i]
+        extra_args['model_options']['transformer_options']['attention_step'] = i
         denoised, attn = model(x, sigmas[i] * s_in, **extra_args)
-        if attention_out is None:
-            attention_out = torch.empty((len(sigmas), *attn.shape), dtype=attn.dtype, device=attn.device)
-        attention_out[i] = attn
+        attention_out.append(attn)

         if callback is not None:
             callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
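A minimal sketch of the accumulation pattern this hunk switches to: the sampler records the current step index in `transformer_options` and appends each step's attention to a plain Python list, since the per-step attention is now a nested structure rather than a fixed-shape tensor. Only the names taken from the hunk are real; the helper itself is illustrative.

```python
def collect_per_step_attention(model, x, sigmas, extra_args):
    # One list entry per sampling step; a preallocated tensor no longer fits
    # because each entry is a nested structure (blocks -> layers -> (attn1, attn2)).
    attention_out = []
    for i in range(len(sigmas) - 1):
        extra_args['model_options']['transformer_options']['attention_step'] = i
        denoised, attn = model(x, sigmas[i], **extra_args)
        attention_out.append(attn)
    return attention_out
```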
@@ -846,7 +846,7 @@ class LatentDiffusion(DDPM):
                 c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
         return self.p_losses(x, c, t, *args, **kwargs)

-    def apply_model(self, x_noisy, t, cond, return_ids=False):
+    def apply_model(self, x_noisy, t, cond, return_ids=False, return_attention=False):
         if isinstance(cond, dict):
             # hybrid case, cond is expected to be a dict
             pass
@@ -859,7 +859,12 @@ class LatentDiffusion(DDPM):
         x_recon, attn = self.model(x_noisy, t, **cond)

         if isinstance(x_recon, tuple) and not return_ids:
-            return x_recon[0]
+            x_recon = x_recon[0]
+        else:
+            x_recon = x_recon
+
+        if return_attention:
+            return x_recon, attn
         else:
             return x_recon

@@ -163,7 +163,8 @@ class CrossAttentionBirchSan(nn.Module):
             nn.Dropout(dropout)
         )

-    def forward(self, x, context=None, value=None, mask=None, return_attention=False):
+    def forward(self, x, context=None, value=None, mask=None, return_attention=False, attention_to_mux=None,
+                attention_weight=0.0):
         query = self.to_q(x)
         context = default(context, x)
         key = self.to_k(context)
@@ -228,6 +229,8 @@ class CrossAttentionBirchSan(nn.Module):
             use_checkpoint=self.training,
             upcast_attention=upcast_attention,
             return_attention=return_attention,
+            attention_to_mux=attention_to_mux,
+            attention_weight=attention_weight,
         )
         if return_attention:
             hidden_states, attention = output
@@ -559,14 +562,43 @@ class BasicTransformerBlock(nn.Module):
             for p in patch:
                 n, context_attn1, value_attn1 = p(current_index, n, context_attn1, value_attn1)

+        attn1_to_mux = None
+        attn2_to_mux = None
+        attention_weight = 0
+        if transformer_options.get("middle_block_step", None) is not None:
+            if transformer_options.get("attention_weight", 0.0) > 0.0:
+                attentions_to_mux_array = transformer_options.get("attention", None)
+                if attentions_to_mux_array is not None:
+                    a = transformer_options["attention_step"]
+                    b = transformer_options["middle_block_step"]
+                    c = transformer_options["transformer_block_step"]
+                    attn1_to_mux = attentions_to_mux_array[a][b][c][0]
+                    attn2_to_mux = attentions_to_mux_array[a][b][c][1]
+                    attention_weight = transformer_options["attention_weight"]
+
         if "tomesd" in transformer_options:
             m, u = tomesd.get_functions(x, transformer_options["tomesd"]["ratio"], transformer_options["original_shape"])
             n = u(self.attn1(m(n), context=context_attn1, value=value_attn1))
         else:
             if return_attention:
-                n, attn1_weights = self.attn1(n, context=context_attn1, value=value_attn1, return_attention=True)
+                n, attn1_weights = self.attn1(n, context=context_attn1, value=value_attn1, return_attention=True,
+                                              attention_to_mux=attn1_to_mux, attention_weight=attention_weight)
             else:
-                n = self.attn1(n, context=context_attn1, value=value_attn1)
+                n = self.attn1(n, context=context_attn1, value=value_attn1, attention_to_mux=attn1_to_mux,
+                               attention_weight=attention_weight)

+        # Interpolate n with attn1_to_mux
+        # if transformer_options.get("middle_block_step", None) is not None:
+        #     if transformer_options.get("attention_weight", 0.0) > 0.0:
+        #         attentions_to_mux_array = transformer_options.get("attention", None)
+        #         if attentions_to_mux_array is not None:
+        #             a = transformer_options["attention_step"]
+        #             b = transformer_options["middle_block_step"]
+        #             c = transformer_options["transformer_block_step"]
+        #             attn1_to_mux = attentions_to_mux_array[a][b][c][0]
+        #             attention_weight = transformer_options["attention_weight"]
+        #             n = n * (1 - attention_weight) + attn1_to_mux * attention_weight
+        #             print(f"muxed n with attn1_to_mux")
+
         x += n
         if "middle_patch" in transformer_patches:
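How the three indices kept in `transformer_options` select a saved attention tensor; a rough sketch, where `saved` stands for the attention[a][b][c][d] structure produced by an earlier run and the helper name is illustrative, not part of this commit:

```python
def lookup_saved_attention(saved, transformer_options):
    # Same indexing as the block above: saved[a][b][c][d]
    a = transformer_options["attention_step"]          # sampling step
    b = transformer_options["middle_block_step"]       # attention module in the UNet middle block
    c = transformer_options["transformer_block_step"]  # transformer layer inside that module
    attn1_saved = saved[a][b][c][0]                    # d = 0 -> attn1 (self-attention)
    attn2_saved = saved[a][b][c][1]                    # d = 1 -> attn2 (cross-attention)
    return attn1_saved, attn2_saved
```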
@@ -586,9 +618,24 @@ class BasicTransformerBlock(nn.Module):


         if return_attention:
-            n, attn2_weights = self.attn2(n, context=context_attn2, value=value_attn2, return_attention=True)
+            n, attn2_weights = self.attn2(n, context=context_attn2, value=value_attn2, return_attention=True,
+                                          attention_to_mux=attn2_to_mux, attention_weight=attention_weight)
         else:
-            n = self.attn2(n, context=context_attn2, value=value_attn2)
+            n = self.attn2(n, context=context_attn2, value=value_attn2, attention_to_mux=attn2_to_mux,
+                           attention_weight=attention_weight)

+        # Interpolate n with attn2_to_mux
+        # if transformer_options.get("middle_block_step", None) is not None:
+        #     if transformer_options.get("attention_weight", 0.0) > 0.0:
+        #         attentions_to_mux_array = transformer_options.get("attention", None)
+        #         if attentions_to_mux_array is not None:
+        #             a = transformer_options["attention_step"]
+        #             b = transformer_options["middle_block_step"]
+        #             c = transformer_options["transformer_block_step"]
+        #             attn2_to_mux = attentions_to_mux_array[a][b][c][1]
+        #             attention_weight = transformer_options["attention_weight"]
+        #             n = n * (1 - attention_weight) + attn2_to_mux * attention_weight
+        #             print(f"muxed n with attn2_to_mux")
+
         x += n
         x = self.ff(self.norm3(x)) + x
@@ -597,7 +644,7 @@ class BasicTransformerBlock(nn.Module):
             transformer_options["current_index"] += 1

         if return_attention:
-            return x, (attn1_weights, attn2_weights)
+            return x, (attn1_weights.cpu(), attn2_weights.cpu())
         else:
             return x

@@ -660,6 +707,7 @@ class SpatialTransformer(nn.Module):

         attention_tensors = []
         for i, block in enumerate(self.transformer_blocks):
+            transformer_options["transformer_block_step"] = i
             if transformer_options.get("return_attention", False):
                 x, attention = block(x, context=context[i], transformer_options=transformer_options)
                 attention_tensors.append(attention)
@@ -6,17 +6,10 @@ import torch as th
 import torch.nn as nn
 import torch.nn.functional as F

-from ldm.modules.diffusionmodules.util import (
-    checkpoint,
-    conv_nd,
-    linear,
-    avg_pool_nd,
-    zero_module,
-    normalization,
-    timestep_embedding,
-)
-from ldm.modules.attention import SpatialTransformer
-from ldm.util import exists
+from comfy.ldm.modules.attention import SpatialTransformer
+from comfy.ldm.modules.diffusionmodules.util import (checkpoint, conv_nd, linear, avg_pool_nd, zero_module,
+                                                     normalization, timestep_embedding)
+from comfy.ldm.util import exists


 # dummy replace
@@ -81,7 +74,7 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
             if isinstance(layer, TimestepBlock):
                 x = layer(x, emb)
             elif isinstance(layer, SpatialTransformer):
-                if transformer_options.get("attention", False):
+                if transformer_options.get("return_attention", False):
                     x, attention = layer(x, context, transformer_options)
                 else:
                     x = layer(x, context, transformer_options)
@@ -825,6 +818,7 @@ class UNetModel(nn.Module):
         h = x.type(self.dtype)
         input_and_output_options = transformer_options.copy()
         input_and_output_options["return_attention"] = False # No attention to be had in the input blocks
+        input_and_output_options["middle_block_step"] = None

         for id, module in enumerate(self.input_blocks):
             h = module(h, emb, context, input_and_output_options)
@@ -835,7 +829,9 @@ class UNetModel(nn.Module):
             hs.append(h)

         attention_tensors = []
+        num_attention_blocks = 0
         for i, module in enumerate(self.middle_block):
+            transformer_options["middle_block_step"] = num_attention_blocks
             if isinstance(module, AttentionBlock):
                 if transformer_options.get("return_attention", False):
                     h, attention = module(h, emb, context, transformer_options)
@@ -845,10 +841,14 @@ class UNetModel(nn.Module):
                     # h = h * combined_attention
                 else:
                     h = module(h, emb, context, transformer_options)
+                num_attention_blocks += 1
             elif isinstance(module, SpatialTransformer):
                 if transformer_options.get("return_attention", False):
                     h, attention = module(h, context, transformer_options)
                     attention_tensors.append(attention)
+                else:
+                    h = module(h, context, transformer_options)
+                num_attention_blocks += 1
             elif isinstance(module, TimestepBlock):
                 h = module(h, emb)
             else:
@@ -96,6 +96,8 @@ def _query_chunk_attention(
     summarize_chunk: SummarizeChunk,
     kv_chunk_size: int,
     return_attention: bool,
+    attention_to_mux: Optional[Tensor] = None,
+    attention_weight: float = 0.0,
 ) -> Tensor:
     batch_x_heads, k_channels_per_head, k_tokens = key_t.shape
     _, _, v_channels_per_head = value.shape
@@ -140,6 +142,8 @@ def _get_attention_scores_no_kv_chunking(
     scale: float,
     upcast_attention: bool,
     return_attention: bool,
+    attention_to_mux: Optional[Tensor] = None,
+    attention_weight: float = 0.0,
 ) -> Tensor:
     if upcast_attention:
         with torch.autocast(enabled=False, device_type = 'cuda'):
@@ -172,6 +176,11 @@ def _get_attention_scores_no_kv_chunking(
         attn_scores /= summed
         attn_probs = attn_scores

+    if attention_to_mux is not None:
+        attention_to_mux = attention_to_mux.to(attn_probs.device)
+        attn_probs = attn_probs * (1 - attention_weight) + attention_to_mux * attention_weight
+        print(f"muxed attention with weight {attention_weight}")
+
     hidden_states_slice = torch.bmm(attn_probs, value)

     if return_attention:
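The mux itself is a convex blend of the freshly computed attention probabilities with the stored ones; below is a small self-contained sketch of the same arithmetic (the tensor shapes and helper name are illustrative):

```python
import torch

def mux_attention(attn_probs, attention_to_mux, attention_weight=0.5):
    # weight 0.0 keeps the freshly computed attention, weight 1.0 replaces it
    # entirely with the stored attention; the diff only moves the stored tensor
    # onto the right device before blending.
    attention_to_mux = attention_to_mux.to(attn_probs.device)
    return attn_probs * (1 - attention_weight) + attention_to_mux * attention_weight

# Two stand-ins for attention maps of matching (batch*heads, q_tokens, k_tokens) shape:
fresh = torch.rand(8, 64, 77)
stored = torch.rand(8, 64, 77)
blended = mux_attention(fresh, stored, 0.5)
```

Since both operands in the real code are softmax outputs over the key dimension, the blended map still sums to 1 per query row and can be passed straight into the `torch.bmm` with the value tensor.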
@@ -193,6 +202,8 @@ def efficient_dot_product_attention(
     use_checkpoint=True,
     upcast_attention=False,
     return_attention=False,
+    attention_to_mux: Optional[Tensor] = None,
+    attention_weight: float = 0.0,
 ):
     """Computes efficient dot-product attention given query, transposed key, and value.
     This is efficient version of attention presented in
@@ -235,6 +246,8 @@ def efficient_dot_product_attention(
         scale=scale,
         upcast_attention=upcast_attention,
         return_attention=return_attention,
+        attention_to_mux=attention_to_mux,
+        attention_weight=attention_weight,
     ) if k_tokens <= kv_chunk_size else (
         # fast-path for when there's just 1 key-value chunk per query chunk (this is just sliced attention btw)
         partial(
@@ -242,6 +255,8 @@ def efficient_dot_product_attention(
             kv_chunk_size=kv_chunk_size,
             summarize_chunk=summarize_chunk,
             return_attention=return_attention,
+            attention_to_mux=attention_to_mux,
+            attention_weight=attention_weight,
         )
     )

@@ -58,13 +58,12 @@ def cleanup_additional_models(models):

 def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0,
            disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, noise_mask=None,
-           sigmas=None, callback=None, attention=None):
+           sigmas=None, callback=None, attention=None, attention_weight=0.0):
     device = comfy.model_management.get_torch_device()

     if noise_mask is not None:
         noise_mask = prepare_mask(noise_mask, noise.shape, device)

-    real_model = None
     comfy.model_management.load_model_gpu(model)
     real_model = model.model

@@ -79,12 +78,16 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative
     sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler,
                                       denoise=denoise, model_options=model.model_options)

+    transformer_options = model.model_options['transformer_options']
+    if transformer_options is not None and attention is not None and attention_weight > 0.0:
+        transformer_options['attention'] = attention
+        transformer_options['attention_weight'] = attention_weight
+
     samples, attention = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image,
                                         start_step=start_step, last_step=last_step,
                                         force_full_denoise=force_full_denoise, denoise_mask=noise_mask,
-                                        sigmas=sigmas, callback=callback, attention=attention)
+                                        sigmas=sigmas, callback=callback)
     samples = samples.cpu()
-    attention = attention.cpu()

     cleanup_additional_models(models)
     return samples, attention
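Putting the new plumbing together, a hypothetical two-pass flow: sample once with attention capture enabled, save the structure, then resample while blending the saved attention back in. This is a sketch assuming the signatures introduced in this commit; the wrapper name, file name, and weight are illustrative.

```python
def resample_with_saved_attention(model, noise, steps, cfg, sampler_name, scheduler,
                                  positive, negative, latent_image,
                                  path="attention.pkl", weight=0.5):
    """Illustrative two-pass use of the attention replay added in this commit."""
    import comfy.sample
    import comfy.utils

    # First pass: sample normally and capture the per-step attention structure.
    samples, attention = comfy.sample.sample(model, noise, steps, cfg, sampler_name,
                                             scheduler, positive, negative, latent_image)
    comfy.utils.save_attention(attention, path)

    # Second pass: reload the structure and blend it in at the given weight.
    saved = comfy.utils.load_attention(path)
    samples2, _ = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler,
                                      positive, negative, latent_image,
                                      attention=saved, attention_weight=weight)
    return samples, samples2
```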
@@ -9,8 +9,8 @@ from .ldm.modules.diffusionmodules.util import make_ddim_timesteps

 #The main sampling function shared by all the samplers
 #Returns predicted noise
-def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, cond_concat=None, model_options={},
-                      attention=None, attention_weight=0.5):
+def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, cond_concat=None, model_options={}):
     def get_area_and_mult(cond, x_in, cond_concat_in, timestep_in):
         area = (x_in.shape[2], x_in.shape[3], 0, 0)
         strength = 1.0
@@ -127,7 +127,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
         return out

     def calc_cond_uncond_batch(model_function, cond, uncond, x_in, timestep, max_total_area, cond_concat_in,
-                               model_options, attention=None, attention_weight=0.5):
+                               model_options):
         out_cond = torch.zeros_like(x_in)
         out_count = torch.ones_like(x_in)/100000.0

@@ -137,6 +137,10 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
         COND = 0
         UNCOND = 1

+        transformer_options = {}
+        if 'transformer_options' in model_options:
+            transformer_options = model_options['transformer_options'].copy()
+
         to_run = []
         for x in cond:
             p = get_area_and_mult(x, x_in, cond_concat_in, timestep)
@@ -199,9 +203,6 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
             # mixed_attention = attention_weight * torch.cat(attention) + (1 - attention_weight) * generated_attention
             # c['c_crossattn'] = [mixed_attention]

-            transformer_options = {}
-            if 'transformer_options' in model_options:
-                transformer_options = model_options['transformer_options'].copy()

             if patches is not None:
                 if "patches" in transformer_options:
@@ -217,7 +218,11 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
             # transformer_options['return_attention'] = True
             c['transformer_options'] = transformer_options

-            output = model_function(input_x, timestep_, cond=c).chunk(batch_chunks)
+            if transformer_options.get("return_attention", False):
+                output, attn = model_function(input_x, timestep_, cond=c, return_attention=True)
+                output = output.chunk(batch_chunks)
+            else:
+                output = model_function(input_x, timestep_, cond=c, return_attention=False).chunk(batch_chunks)
             del input_x

             model_management.throw_exception_if_processing_interrupted()
@@ -236,17 +241,24 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
         out_uncond /= out_uncond_count
         del out_uncond_count

-        return out_cond, out_uncond
+        if transformer_options.get("return_attention", False):
+            return out_cond, out_uncond, attn
+        else:
+            return out_cond, out_uncond


     max_total_area = model_management.maximum_batch_area()
-    cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, timestep, max_total_area, cond_concat,
-                                          model_options, attention=attention)
+    cond, uncond, attn = calc_cond_uncond_batch(model_function, cond, uncond, x, timestep, max_total_area,
+                                                cond_concat, model_options)
     if "sampler_cfg_function" in model_options:
-        return model_options["sampler_cfg_function"](cond, uncond, cond_scale), cond[0] # cond[0] is attention
+        retval = model_options["sampler_cfg_function"](cond, uncond, cond_scale), attn
     else:
-        return uncond + (cond - uncond) * cond_scale, cond[0] # cond[0] is attention
+        retval = uncond + (cond - uncond) * cond_scale, attn

+    if model_options["transformer_options"].get("return_attention", False):
+        return retval
+    else:
+        return retval, attn

 class CompVisVDenoiser(k_diffusion_external.DiscreteVDDPMDenoiser):
     def __init__(self, model, quantize=False, device='cpu'):
@@ -594,6 +606,6 @@ class KSampler:
                 samples = k_diffusion_sampling.sample_dpm_adaptive(self.model_k, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=k_callback)
             else:
                 samples, attention = getattr(k_diffusion_sampling, "sample_{}".format(self.sampler))(self.model_k,
-                    noise, sigmas, extra_args=extra_args, callback=k_callback, attention=attention)
+                    noise, sigmas, extra_args=extra_args, callback=k_callback)

         return samples.to(torch.float32), attention
@@ -2,8 +2,7 @@ import os
 import torch
 import folder_paths
 import numpy as np
-from safetensors.torch import save_file, safe_open
-
+import joblib

 def load_torch_file(ckpt, safe_load=False):
     if ckpt.lower().endswith(".safetensors"):
@@ -28,20 +27,25 @@ def save_latent(samples, filename):
     np.save(filename, samples)


+# attention[a][b][c][d]
+# a: number of steps/sigma in this diffusion process
+# b: number of SpatialTransformer or AttentionBlocks used in the middle blocks of the latent diffusion model
+# c: number of transformer layers in the SpatialTransformer or AttentionBlocks
+# d: attn1, attn2
 def save_attention(attention, filename):
     filename = os.path.join(folder_paths.get_output_directory(), filename)
-    save_file({"attention": attention}, filename)
+    joblib.dump(attention, filename)
     print(f"Attention tensor saved to {filename}")


+# returns attention[a][b][c][d]
+# a: number of steps/sigma in this diffusion process
+# b: number of SpatialTransformer or AttentionBlocks used in the middle blocks of the latent diffusion model
+# c: number of transformer layers in the SpatialTransformer or AttentionBlocks
+# d: attn1, attn2
 def load_attention(filename):
-    tensors = {}
     filename = os.path.join(folder_paths.get_output_directory(), filename)
-    with safe_open(filename, framework='pt', device=0) as f:
-        for k in f.keys():
-            tensors[k] = f.get_tensor(k)
-    return tensors['attention']
+    return joblib.load(filename)


 def load_latent(filename):
     filename = os.path.join(folder_paths.get_output_directory(), filename)
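joblib replaces safetensors here because the saved object is a nested Python list of per-step tuples of tensors rather than a flat dict of uniformly shaped tensors. A toy round-trip under that assumption (all sizes and the file name are made up):

```python
import joblib
import torch

# Toy stand-in for the structure described above: attention[a][b][c][d]
# a: sampling step, b: middle-block attention module, c: transformer layer, d: (attn1, attn2)
attention = [
    [  # one middle-block module per step
        [  # one transformer layer per module
            (torch.rand(8, 64, 64), torch.rand(8, 64, 77)),  # (attn1, attn2)
        ]
    ]
    for _ in range(2)  # two sampling steps
]

joblib.dump(attention, "attention_toy.pkl")  # what save_attention() does, minus the output directory
loaded = joblib.load("attention_toy.pkl")
assert loaded[0][0][0][1].shape == (8, 64, 77)
```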
nodes.py
@@ -866,7 +866,8 @@ class SetLatentNoiseMask:
         return (s,)

 def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0,
-                    disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, attention=None):
+                    disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, attention=None,
+                    attention_weight=0.0):
     device = comfy.model_management.get_torch_device()
     latent_image = latent["samples"]

@@ -880,9 +881,16 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
     if "noise_mask" in latent:
         noise_mask = latent["noise_mask"]

-    samples, attention = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image,
-                                             denoise=denoise, disable_noise=disable_noise, start_step=start_step, last_step=last_step,
-                                             force_full_denoise=force_full_denoise, noise_mask=noise_mask, attention=attention)
+    samples, attention = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative,
+                                             latent_image, denoise=denoise, disable_noise=disable_noise, start_step=start_step,
+                                             last_step=last_step, force_full_denoise=force_full_denoise, noise_mask=noise_mask,
+                                             attention=attention, attention_weight=attention_weight)
+
+    # attention[a][b][c][d]
+    # a: number of steps/sigma in this diffusion process
+    # b: number of SpatialTransformer or AttentionBlocks used in the middle blocks of the latent diffusion model
+    # c: number of transformer layers in the SpatialTransformer or AttentionBlocks
+    # d: attn1, attn2
     out = latent.copy()
     out["samples"] = samples
     return (out, attention)
@@ -906,6 +914,7 @@ class KSampler:
                      },
                 "optional": {
                     "attention": ("ATTENTION",),
+                    "attention_weight": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
                 }
                 }

@@ -915,10 +924,11 @@ class KSampler:
     CATEGORY = "sampling"

     def sample(self, model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image,
-               denoise=1.0, attention=None):
+               denoise=1.0, attention=None, attention_weight=0.0):
         model.model_options["transformer_options"]["return_attention"] = True

         return common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image,
-                               denoise=denoise, attention=attention)
+                               denoise=denoise, attention=attention, attention_weight=attention_weight)


 class KSamplerAdvanced:
     def __init__(self, event_dispatcher):
@@ -1338,8 +1348,19 @@ class PrintNode:
             print(f"Latent hash: {latent_hash}")
             print(np.array2string(latent["samples"].cpu().numpy(), separator=', '))

+        # attention[a][b][c][d]
+        # a: number of steps/sigma in this diffusion process
+        # b: number of SpatialTransformer or AttentionBlocks used in the middle blocks of the latent diffusion model
+        # c: number of transformer layers in the SpatialTransformer or AttentionBlocks
+        # d: attn1, attn2
         if attention is not None:
-            print(np.array2string(attention.cpu().numpy(), separator=', '))
+            print(f'attention has {len(attention)} steps')
+            print(f'each step has {len(attention[0])} transformer blocks')
+            print(f'each block has {len(attention[0][0])} transformer layers')
+            print(f'each transformer layer has {len(attention[0][0][0])} attention tensors (attn1, attn2)')
+            print(f'the shape of the attention tensors is {attention[0][0][0][0].shape}')
+            print(f'the first value of the first attention tensor is {attention[0][0][0][0][:1]}')


         if text is not None:
             print(text)
@@ -1368,6 +1389,11 @@ class SaveAttention:
     CATEGORY = "operations"
     OUTPUT_NODE = True

+    # attention[a][b][c][d]
+    # a: number of steps/sigma in this diffusion process
+    # b: number of SpatialTransformer or AttentionBlocks used in the middle blocks of the latent diffusion model
+    # c: number of transformer layers in the SpatialTransformer or AttentionBlocks
+    # d: attn1, attn2
     def save_attention(self, attention, filename):
         comfy.utils.save_attention(attention, filename)
         return {"ui": {"message": "Saved attention to " + filename}}
@@ -9,3 +9,4 @@ pytorch_lightning
 aiohttp
 accelerate
 pyyaml
+joblib