Attention weights really should be saved and returned now.

This commit is contained in:
InconsolableCellist 2023-04-30 00:38:16 -06:00
parent 5e062a88de
commit f969ec5108
8 changed files with 213 additions and 62 deletions

View File

@ -582,7 +582,7 @@ def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=N
@torch.no_grad()
def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=None):
def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=None, attention=None):
"""DPM-Solver++(2M)."""
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
@ -590,12 +590,14 @@ def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=No
t_fn = lambda sigma: sigma.log().neg()
old_denoised = None
attention = None
attention_out = None
for i in trange(len(sigmas) - 1, disable=disable):
if attention is not None:
extra_args['attention'] = attention[i]
denoised, attn = model(x, sigmas[i] * s_in, **extra_args)
if attention is None:
attention = torch.empty((len(sigmas), *attn.shape), dtype=attn.dtype, device=attn.device)
attention[i] = attn
if attention_out is None:
attention_out = torch.empty((len(sigmas), *attn.shape), dtype=attn.dtype, device=attn.device)
attention_out[i] = attn
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
@ -609,4 +611,4 @@ def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=No
denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised
x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d
old_denoised = denoised
return x, attention
return x, attention_out
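
An illustrative, self-contained sketch of the capture-and-replay flow that sample_dpmpp_2m now follows: the first pass records one attention tensor per sigma step into a preallocated buffer, and that stack can be passed back in on a later pass. dummy_model and capture_loop are hypothetical stand-ins, not code from this commit.

    import torch

    def dummy_model(x, sigma, attention=None):
        # Stand-in denoiser: returns a fake "denoised" latent plus a fake attention map.
        attn = torch.softmax(torch.randn(x.shape[0], 4, 4), dim=-1)
        return x * 0.9, attn

    def capture_loop(model, x, sigmas, attention=None):
        attention_out = None
        for i in range(len(sigmas) - 1):
            extra = {}
            if attention is not None:
                extra["attention"] = attention[i]  # replay the map saved for this step
            denoised, attn = model(x, sigmas[i], **extra)
            if attention_out is None:
                # One slot per sigma, shaped like a single step's attention map
                # (as in the diff, the final slot is never written).
                attention_out = torch.empty((len(sigmas), *attn.shape),
                                            dtype=attn.dtype, device=attn.device)
            attention_out[i] = attn
            x = denoised
        return x, attention_out

    x = torch.randn(1, 4)
    sigmas = torch.linspace(1.0, 0.0, 5)
    _, saved = capture_loop(dummy_model, x, sigmas)               # first pass: capture
    _, _ = capture_loop(dummy_model, x, sigmas, attention=saved)  # second pass: replay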

View File

@ -856,7 +856,7 @@ class LatentDiffusion(DDPM):
key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn'
cond = {key: cond}
x_recon = self.model(x_noisy, t, **cond)
x_recon, attn = self.model(x_noisy, t, **cond)
if isinstance(x_recon, tuple) and not return_ids:
return x_recon[0]
@ -1318,6 +1318,7 @@ class DiffusionWrapper(torch.nn.Module):
assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm', 'hybrid-adm', 'crossattn-adm']
def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None, control=None, transformer_options={}):
return_attention = transformer_options.get("return_attention", False)
if self.conditioning_key is None:
out = self.diffusion_model(x, t, control=control, transformer_options=transformer_options)
elif self.conditioning_key == 'concat':
@ -1334,7 +1335,8 @@ class DiffusionWrapper(torch.nn.Module):
# an error: RuntimeError: forward() is missing value for argument 'argument_3'.
out = self.scripted_diffusion_model(x, t, cc, control=control, transformer_options=transformer_options)
else:
out = self.diffusion_model(x, t, context=cc, control=control, transformer_options=transformer_options)
out, attn = self.diffusion_model(x, t, context=cc, control=control,
transformer_options=transformer_options)
elif self.conditioning_key == 'hybrid':
xc = torch.cat([x] + c_concat, dim=1)
cc = torch.cat(c_crossattn, 1)
@ -1354,7 +1356,10 @@ class DiffusionWrapper(torch.nn.Module):
else:
raise NotImplementedError()
return out
if return_attention:
return out, attn
else:
return out
class LatentUpscaleDiffusion(LatentDiffusion):
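
The DiffusionWrapper change makes the return shape depend on a flag read from transformer_options. A minimal sketch of that contract; TinyWrapper and its lambda inner model are illustrative stand-ins, and initializing attn to None is an assumption of the sketch (the commit only unpacks attention on the crossattn branch) that keeps the tuple return defined on branches that produce no attention.

    import torch

    class TinyWrapper(torch.nn.Module):
        # Hypothetical stand-in for DiffusionWrapper: the inner callable yields
        # (out, attn); the options flag decides whether the attention is surfaced
        # to the caller, so callers that never set it keep getting a plain tensor.
        def __init__(self, inner):
            super().__init__()
            self.inner = inner

        def forward(self, x, transformer_options={}, conditioned=True):
            return_attention = transformer_options.get("return_attention", False)
            attn = None  # defined up front so every branch reaches the final return
            if conditioned:
                out, attn = self.inner(x)   # crossattn-style branch: attention exists
            else:
                out = x                     # concat-style branch: no attention produced
            if return_attention:
                return out, attn
            return out

    wrapper = TinyWrapper(lambda x: (x * 2, torch.ones(x.shape[0], 3, 3)))
    y = wrapper(torch.randn(2, 4))                                     # plain output
    y, attn = wrapper(torch.randn(2, 4), {"return_attention": True})   # output + weights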

View File

@ -163,9 +163,7 @@ class CrossAttentionBirchSan(nn.Module):
nn.Dropout(dropout)
)
def forward(self, x, context=None, value=None, mask=None):
h = self.heads
def forward(self, x, context=None, value=None, mask=None, return_attention=False):
query = self.to_q(x)
context = default(context, x)
key = self.to_k(context)
@ -220,7 +218,7 @@ class CrossAttentionBirchSan(nn.Module):
kv_chunk_size = kv_chunk_size_x
kv_chunk_size_min = kv_chunk_size_min_x
hidden_states = efficient_dot_product_attention(
output = efficient_dot_product_attention(
query,
key_t,
value,
@ -229,7 +227,12 @@ class CrossAttentionBirchSan(nn.Module):
kv_chunk_size_min=kv_chunk_size_min,
use_checkpoint=self.training,
upcast_attention=upcast_attention,
return_attention=return_attention,
)
if return_attention:
hidden_states, attention = output
else:
hidden_states = output
hidden_states = hidden_states.to(dtype)
@ -239,7 +242,10 @@ class CrossAttentionBirchSan(nn.Module):
hidden_states = out_proj(hidden_states)
hidden_states = dropout(hidden_states)
return hidden_states
if return_attention:
return hidden_states, attention
else:
return hidden_states
class CrossAttentionDoggettx(nn.Module):
@ -358,7 +364,7 @@ class CrossAttention(nn.Module):
nn.Dropout(dropout)
)
def forward(self, x, context=None, value=None, mask=None):
def forward(self, x, context=None, value=None, mask=None, return_attention=False):
h = self.heads
q = self.to_q(x)
@ -393,7 +399,13 @@ class CrossAttention(nn.Module):
out = einsum('b i j, b j d -> b i d', sim, v)
out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
return self.to_out(out)
out = self.to_out(out)
if return_attention:
sim = rearrange(sim, '(b h) i j -> b h i j', h=h)
return out, sim
else:
return out
class MemoryEfficientCrossAttention(nn.Module):
# https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
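
For reference, a compact sketch of how attention probabilities can be returned alongside the output in the plain (non-chunked) path, mirroring the shape handling in CrossAttention.forward: heads are folded into the batch for the einsums, and the weights are rearranged back to (batch, heads, queries, keys) before being handed to the caller. MiniCrossAttention is a simplified assumption, not the repository class.

    import torch
    from torch import einsum, nn
    from einops import rearrange

    class MiniCrossAttention(nn.Module):
        def __init__(self, dim, heads=4, dim_head=16):
            super().__init__()
            inner = heads * dim_head
            self.heads, self.scale = heads, dim_head ** -0.5
            self.to_q = nn.Linear(dim, inner, bias=False)
            self.to_k = nn.Linear(dim, inner, bias=False)
            self.to_v = nn.Linear(dim, inner, bias=False)
            self.to_out = nn.Linear(inner, dim)

        def forward(self, x, context=None, return_attention=False):
            h = self.heads
            context = x if context is None else context
            q, k, v = self.to_q(x), self.to_k(context), self.to_v(context)
            q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q, k, v))
            sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
            sim = sim.softmax(dim=-1)                       # attention probabilities
            out = einsum('b i j, b j d -> b i d', sim, v)
            out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
            out = self.to_out(out)
            if return_attention:
                # expose per-head weights as (batch, heads, queries, keys)
                return out, rearrange(sim, '(b h) i j -> b h i j', h=h)
            return out

    attn = MiniCrossAttention(dim=32)
    out, weights = attn(torch.randn(2, 10, 32), return_attention=True)  # weights: (2, 4, 10, 10)
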
@ -523,6 +535,7 @@ class BasicTransformerBlock(nn.Module):
return checkpoint(self._forward, (x, context, transformer_options), self.parameters(), self.checkpoint)
def _forward(self, x, context=None, transformer_options={}):
return_attention = transformer_options.get("return_attention", False)
current_index = None
if "current_index" in transformer_options:
current_index = transformer_options["current_index"]
@ -550,7 +563,10 @@ class BasicTransformerBlock(nn.Module):
m, u = tomesd.get_functions(x, transformer_options["tomesd"]["ratio"], transformer_options["original_shape"])
n = u(self.attn1(m(n), context=context_attn1, value=value_attn1))
else:
n = self.attn1(n, context=context_attn1, value=value_attn1)
if return_attention:
n, attn1_weights = self.attn1(n, context=context_attn1, value=value_attn1, return_attention=True)
else:
n = self.attn1(n, context=context_attn1, value=value_attn1)
x += n
if "middle_patch" in transformer_patches:
@ -568,14 +584,22 @@ class BasicTransformerBlock(nn.Module):
for p in patch:
n, context_attn2, value_attn2 = p(current_index, n, context_attn2, value_attn2)
n = self.attn2(n, context=context_attn2, value=value_attn2)
if return_attention:
n, attn2_weights = self.attn2(n, context=context_attn2, value=value_attn2, return_attention=True)
else:
n = self.attn2(n, context=context_attn2, value=value_attn2)
x += n
x = self.ff(self.norm3(x)) + x
if current_index is not None:
transformer_options["current_index"] += 1
return x
if return_attention:
return x, (attn1_weights, attn2_weights)
else:
return x
class SpatialTransformer(nn.Module):
@ -633,12 +657,24 @@ class SpatialTransformer(nn.Module):
x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
if self.use_linear:
x = self.proj_in(x)
attention_tensors = []
for i, block in enumerate(self.transformer_blocks):
x = block(x, context=context[i], transformer_options=transformer_options)
if transformer_options.get("return_attention", False):
x, attention = block(x, context=context[i], transformer_options=transformer_options)
attention_tensors.append(attention)
else:
x = block(x, context=context[i], transformer_options=transformer_options)
if self.use_linear:
x = self.proj_out(x)
x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
if not self.use_linear:
x = self.proj_out(x)
return x + x_in
if transformer_options.get("return_attention", False):
return x + x_in, attention_tensors
else:
return x + x_in
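
A sketch of how the return_attention flag propagates through nesting and how the maps bubble back up, in the same shape the SpatialTransformer change produces: each block hands back an (attn1, attn2) pair and the enclosing transformer collects one pair per block. ToyBlock and ToyTransformer are hypothetical; torch's built-in MultiheadAttention stands in for the repository's attention classes.

    import torch
    from torch import nn

    class ToyBlock(nn.Module):
        # Stand-in for BasicTransformerBlock: two attention stages, each able to
        # hand back its weights when asked.
        def __init__(self, dim):
            super().__init__()
            self.attn1 = nn.MultiheadAttention(dim, num_heads=4, batch_first=True)
            self.attn2 = nn.MultiheadAttention(dim, num_heads=4, batch_first=True)

        def forward(self, x, transformer_options={}):
            return_attention = transformer_options.get("return_attention", False)
            n, w1 = self.attn1(x, x, x, need_weights=return_attention)
            x = x + n
            n, w2 = self.attn2(x, x, x, need_weights=return_attention)
            x = x + n
            return (x, (w1, w2)) if return_attention else x

    class ToyTransformer(nn.Module):
        def __init__(self, dim, depth=2):
            super().__init__()
            self.blocks = nn.ModuleList([ToyBlock(dim) for _ in range(depth)])

        def forward(self, x, transformer_options={}):
            collect = transformer_options.get("return_attention", False)
            attention_tensors = []
            for block in self.blocks:
                if collect:
                    x, attn = block(x, transformer_options)
                    attention_tensors.append(attn)   # one (attn1, attn2) pair per block
                else:
                    x = block(x, transformer_options)
            return (x, attention_tensors) if collect else x

    model = ToyTransformer(dim=32)
    x_out, attn_list = model(torch.randn(2, 10, 32), {"return_attention": True})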

View File

@ -81,7 +81,10 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
if isinstance(layer, TimestepBlock):
x = layer(x, emb)
elif isinstance(layer, SpatialTransformer):
x = layer(x, context, transformer_options)
if transformer_options.get("attention", False):
x, attention = layer(x, context, transformer_options)
else:
x = layer(x, context, transformer_options)
else:
x = layer(x)
return x
@ -310,17 +313,26 @@ class AttentionBlock(nn.Module):
self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
def forward(self, x):
return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
def forward(self, x, return_attention=False):
return checkpoint(self._forward, (x, return_attention), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
#return pt_checkpoint(self._forward, x) # pytorch
def _forward(self, x):
def _forward(self, x, return_attention=False):
b, c, *spatial = x.shape
x = x.reshape(b, c, -1)
qkv = self.qkv(self.norm(x))
h = self.attention(qkv)
if return_attention:
h, attention_weights = self.attention(qkv, return_attention=True)
else:
h = self.attention(qkv)
h = self.proj_out(h)
return (x + h).reshape(b, c, *spatial)
h = (x + h).reshape(b, c, *spatial)
if return_attention:
return h, attention_weights
else:
return h
def count_flops_attn(model, _x, y):
@ -352,7 +364,7 @@ class QKVAttentionLegacy(nn.Module):
super().__init__()
self.n_heads = n_heads
def forward(self, qkv):
def forward(self, qkv, return_attention=False):
"""
Apply QKV attention.
:param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
@ -368,7 +380,12 @@ class QKVAttentionLegacy(nn.Module):
) # More stable with f16 than dividing afterwards
weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
a = th.einsum("bts,bcs->bct", weight, v)
return a.reshape(bs, -1, length)
a = a.reshape(bs, -1, length)
if return_attention:
return a, weight
else:
return a
@staticmethod
def count_flops(model, _x, y):
@ -384,7 +401,7 @@ class QKVAttention(nn.Module):
super().__init__()
self.n_heads = n_heads
def forward(self, qkv):
def forward(self, qkv, return_attention=False):
"""
Apply QKV attention.
:param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
@ -402,7 +419,12 @@ class QKVAttention(nn.Module):
) # More stable with f16 than dividing afterwards
weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
return a.reshape(bs, -1, length)
a = a.reshape(bs, -1, length)
if return_attention:
return a, weight
else:
return a
@staticmethod
def count_flops(model, _x, y):
@ -772,13 +794,18 @@ class UNetModel(nn.Module):
self.middle_block.apply(convert_module_to_f32)
self.output_blocks.apply(convert_module_to_f32)
def forward(self, x, timesteps=None, context=None, y=None, control=None, transformer_options={}, **kwargs):
def forward(self, x, timesteps=None, context=None, y=None, control=None, transformer_options={},
input_attention=None, attention_weight=None, **kwargs):
"""
Apply the model to an input batch.
:param x: an [N x C x ...] Tensor of inputs.
:param timesteps: a 1-D batch of timesteps.
:param context: conditioning plugged in via crossattn
:param y: an [N] Tensor of labels, if class-conditional.
:param control: a dictionary of control parameters for the model.
:param transformer_options: a dictionary of options to pass to the transformer
:param input_attention: an optional Tensor of attentions to weight to the attention block
:param attention_weight: the weight with which to mux the input attention
:return: an [N x C x ...] Tensor of outputs.
"""
transformer_options["original_shape"] = list(x.shape)
@ -796,14 +823,37 @@ class UNetModel(nn.Module):
emb = emb + self.label_emb(y)
h = x.type(self.dtype)
input_and_output_options = transformer_options.copy()
input_and_output_options["return_attention"] = False # No attention to be had in the input blocks
for id, module in enumerate(self.input_blocks):
h = module(h, emb, context, transformer_options)
h = module(h, emb, context, input_and_output_options)
if control is not None and 'input' in control and len(control['input']) > 0:
ctrl = control['input'].pop()
if ctrl is not None:
h += ctrl
hs.append(h)
h = self.middle_block(h, emb, context, transformer_options)
attention_tensors = []
for i, module in enumerate(self.middle_block):
if isinstance(module, AttentionBlock):
if transformer_options.get("return_attention", False):
h, attention = module(h, emb, context, transformer_options)
attention_tensors.append(attention)
# if input_attention is not None and attention_weight is not None:
# combined_attention = attention_weight * input_attention + (1 - attention_weight) * attention
# h = h * combined_attention
else:
h = module(h, emb, context, transformer_options)
elif isinstance(module, SpatialTransformer):
if transformer_options.get("return_attention", False):
h, attention = module(h, context, transformer_options)
attention_tensors.append(attention)
elif isinstance(module, TimestepBlock):
h = module(h, emb)
else:
h = module(h)
if control is not None and 'middle' in control and len(control['middle']) > 0:
h += control['middle'].pop()
@ -815,9 +865,9 @@ class UNetModel(nn.Module):
hsp += ctrl
h = th.cat([h, hsp], dim=1)
del hsp
h = module(h, emb, context, transformer_options)
h = module(h, emb, context, input_and_output_options)
h = h.type(x.dtype)
if self.predict_codebook_ids:
return self.id_predictor(h)
return self.id_predictor(h), attention_tensors
else:
return self.out(h)
return self.out(h), attention_tensors
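
The middle-block change replaces a single sequential call with a hand-written walk over the container, because each module type has a different call signature and only some of them can return attention. A toy version of that dispatch pattern; TimeBlock, AttnBlock, and run_middle are illustrative names only.

    import torch
    from torch import nn

    class TimeBlock(nn.Module):
        # Stand-in for a TimestepBlock: needs the timestep embedding as a second argument.
        def __init__(self, dim):
            super().__init__()
            self.lin = nn.Linear(dim, dim)

        def forward(self, x, emb):
            return self.lin(x) + emb

    class AttnBlock(nn.Module):
        # Stand-in for an attention block that can hand back its weights.
        def forward(self, x, return_attention=False):
            w = torch.softmax(x @ x.transpose(-1, -2), dim=-1)
            out = w @ x
            return (out, w) if return_attention else out

    def run_middle(blocks, x, emb, transformer_options):
        want_attn = transformer_options.get("return_attention", False)
        attention_tensors = []
        for module in blocks:
            if isinstance(module, AttnBlock):
                if want_attn:
                    x, attn = module(x, return_attention=True)
                    attention_tensors.append(attn)
                else:
                    x = module(x)
            elif isinstance(module, TimeBlock):
                x = module(x, emb)        # timestep blocks never yield attention
            else:
                x = module(x)
        return x, attention_tensors

    blocks = nn.ModuleList([TimeBlock(8), AttnBlock(), nn.Identity()])
    x, emb = torch.randn(2, 5, 8), torch.randn(2, 5, 8)
    h, attns = run_middle(blocks, x, emb, {"return_attention": True})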

View File

@ -95,6 +95,7 @@ def _query_chunk_attention(
value: Tensor,
summarize_chunk: SummarizeChunk,
kv_chunk_size: int,
return_attention: bool,
) -> Tensor:
batch_x_heads, k_channels_per_head, k_tokens = key_t.shape
_, _, v_channels_per_head = value.shape
@ -125,7 +126,11 @@ def _query_chunk_attention(
all_values = chunk_values.sum(dim=0)
all_weights = torch.unsqueeze(chunk_weights, -1).sum(dim=0)
return all_values / all_weights
if return_attention:
return all_values / all_weights, chunk_weights
else:
return all_values / all_weights
# TODO: refactor CrossAttention#get_attention_scores to share code with this
def _get_attention_scores_no_kv_chunking(
@ -134,6 +139,7 @@ def _get_attention_scores_no_kv_chunking(
value: Tensor,
scale: float,
upcast_attention: bool,
return_attention: bool,
) -> Tensor:
if upcast_attention:
with torch.autocast(enabled=False, device_type = 'cuda'):
@ -167,7 +173,11 @@ def _get_attention_scores_no_kv_chunking(
attn_probs = attn_scores
hidden_states_slice = torch.bmm(attn_probs, value)
return hidden_states_slice
if return_attention:
return hidden_states_slice, attn_probs
else:
return hidden_states_slice
class ScannedChunk(NamedTuple):
chunk_idx: int
@ -182,6 +192,7 @@ def efficient_dot_product_attention(
kv_chunk_size_min: Optional[int] = None,
use_checkpoint=True,
upcast_attention=False,
return_attention=False,
):
"""Computes efficient dot-product attention given query, transposed key, and value.
This is efficient version of attention presented in
@ -197,6 +208,8 @@ def efficient_dot_product_attention(
kv_chunk_size: Optional[int]: key/value chunks size. if None: defaults to sqrt(key_tokens)
kv_chunk_size_min: Optional[int]: key/value minimum chunk size. only considered when kv_chunk_size is None. changes `sqrt(key_tokens)` into `max(sqrt(key_tokens), kv_chunk_size_min)`, to ensure our chunk sizes don't get too small (smaller chunks = more chunks = less concurrent work done).
use_checkpoint: bool: whether to use checkpointing (recommended True for training, False for inference)
upcast_attention: bool: whether to upcast attention to fp32 (?) #
return_attention: bool: whether to return attention weights
Returns:
Output of shape `[batch * num_heads, query_tokens, channels_per_head]`.
"""
@ -220,13 +233,15 @@ def efficient_dot_product_attention(
compute_query_chunk_attn: ComputeQueryChunkAttn = partial(
_get_attention_scores_no_kv_chunking,
scale=scale,
upcast_attention=upcast_attention
upcast_attention=upcast_attention,
return_attention=return_attention,
) if k_tokens <= kv_chunk_size else (
# fast-path for when there's just 1 key-value chunk per query chunk (this is just sliced attention btw)
partial(
_query_chunk_attention,
kv_chunk_size=kv_chunk_size,
summarize_chunk=summarize_chunk,
return_attention=return_attention,
)
)
@ -236,15 +251,25 @@ def efficient_dot_product_attention(
query=query,
key_t=key_t,
value=value,
return_attention=return_attention,
)
# TODO: maybe we should use torch.empty_like(query) to allocate storage in-advance,
# and pass slices to be mutated, instead of torch.cat()ing the returned slices
res = torch.cat([
results = [
compute_query_chunk_attn(
query=get_query_chunk(i * query_chunk_size),
key_t=key_t,
value=value,
return_attention=return_attention,
) for i in range(math.ceil(q_tokens / query_chunk_size))
], dim=1)
return res
]
res = torch.cat([result[0] if return_attention else result for result in results], dim=1)
if return_attention:
attn_weights = [result[1] for result in results]
return res, attn_weights
else:
return res
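
A self-contained sketch of the query-chunking idea behind efficient_dot_product_attention and of how the new flag changes its return value: each query slice is attended against the full key/value, the per-slice outputs are concatenated back along the token axis, and the weights, when requested, come back as a list of per-chunk matrices rather than one tensor. chunked_attention is a simplified stand-in that takes the key untransposed.

    import torch

    def chunked_attention(query, key, value, query_chunk_size=64, return_attention=False):
        # query/key/value: (batch*heads, tokens, channels)
        scale = query.shape[-1] ** -0.5
        outputs, weights = [], []
        for start in range(0, query.shape[1], query_chunk_size):
            q = query[:, start:start + query_chunk_size]
            attn = torch.softmax(q @ key.transpose(-1, -2) * scale, dim=-1)
            outputs.append(attn @ value)
            if return_attention:
                weights.append(attn)   # (batch*heads, chunk_tokens, key_tokens)
        out = torch.cat(outputs, dim=1)
        return (out, weights) if return_attention else out

    q = torch.randn(8, 100, 40)
    k = torch.randn(8, 77, 40)
    v = torch.randn(8, 77, 40)
    out, w = chunked_attention(q, k, v, query_chunk_size=64, return_attention=True)
    # out: (8, 100, 40); w: [(8, 64, 77), (8, 36, 77)]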

View File

@ -56,7 +56,9 @@ def cleanup_additional_models(models):
for m in models:
m.cleanup()
def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, noise_mask=None, sigmas=None, callback=None):
def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0,
disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, noise_mask=None,
sigmas=None, callback=None, attention=None):
device = comfy.model_management.get_torch_device()
if noise_mask is not None:
@ -74,9 +76,13 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative
models = load_additional_models(positive, negative)
sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)
sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler,
denoise=denoise, model_options=model.model_options)
samples, attention = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback)
samples, attention = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image,
start_step=start_step, last_step=last_step,
force_full_denoise=force_full_denoise, denoise_mask=noise_mask,
sigmas=sigmas, callback=callback, attention=attention)
samples = samples.cpu()
attention = attention.cpu()

View File

@ -9,7 +9,8 @@ from .ldm.modules.diffusionmodules.util import make_ddim_timesteps
#The main sampling function shared by all the samplers
#Returns predicted noise
def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, cond_concat=None, model_options={}):
def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, cond_concat=None, model_options={},
attention=None, attention_weight=0.5):
def get_area_and_mult(cond, x_in, cond_concat_in, timestep_in):
area = (x_in.shape[2], x_in.shape[3], 0, 0)
strength = 1.0
@ -125,7 +126,8 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
out['c_adm'] = torch.cat(c_adm)
return out
def calc_cond_uncond_batch(model_function, cond, uncond, x_in, timestep, max_total_area, cond_concat_in, model_options):
def calc_cond_uncond_batch(model_function, cond, uncond, x_in, timestep, max_total_area, cond_concat_in,
model_options, attention=None, attention_weight=0.5):
out_cond = torch.zeros_like(x_in)
out_count = torch.ones_like(x_in)/100000.0
@ -192,6 +194,11 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
if control is not None:
c['control'] = control.get_control(input_x, timestep_, c['c_crossattn'], len(cond_or_uncond))
# if attention is not None:
# generated_attention = c['c_crossattn'][0]
# mixed_attention = attention_weight * torch.cat(attention) + (1 - attention_weight) * generated_attention
# c['c_crossattn'] = [mixed_attention]
transformer_options = {}
if 'transformer_options' in model_options:
transformer_options = model_options['transformer_options'].copy()
@ -207,6 +214,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
else:
transformer_options["patches"] = patches
# transformer_options['return_attention'] = True
c['transformer_options'] = transformer_options
output = model_function(input_x, timestep_, cond=c).chunk(batch_chunks)
@ -232,7 +240,8 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
max_total_area = model_management.maximum_batch_area()
cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, timestep, max_total_area, cond_concat, model_options)
cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, timestep, max_total_area, cond_concat,
model_options, attention=attention)
if "sampler_cfg_function" in model_options:
return model_options["sampler_cfg_function"](cond, uncond, cond_scale), cond[0] # cond[0] is attention
else:
@ -252,8 +261,9 @@ class CFGNoisePredictor(torch.nn.Module):
super().__init__()
self.inner_model = model
self.alphas_cumprod = model.alphas_cumprod
def apply_model(self, x, timestep, cond, uncond, cond_scale, cond_concat=None, model_options={}):
out, attn = sampling_function(self.inner_model.apply_model, x, timestep, uncond, cond, cond_scale, cond_concat, model_options=model_options)
def apply_model(self, x, timestep, cond, uncond, cond_scale, cond_concat=None, model_options={}, attention=None):
out, attn = sampling_function(self.inner_model.apply_model, x, timestep, uncond, cond, cond_scale, cond_concat,
model_options=model_options, attention=attention)
return out, attn
@ -261,11 +271,13 @@ class KSamplerX0Inpaint(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.inner_model = model
def forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask, cond_concat=None, model_options={}):
def forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask, cond_concat=None, model_options={},
attention=None):
if denoise_mask is not None:
latent_mask = 1. - denoise_mask
x = x * denoise_mask + (self.latent_image + self.noise * sigma.reshape([sigma.shape[0]] + [1] * (len(self.noise.shape) - 1))) * latent_mask
out, attn = self.inner_model(x, sigma, cond=cond, uncond=uncond, cond_scale=cond_scale, cond_concat=cond_concat, model_options=model_options)
out, attn = self.inner_model(x, sigma, cond=cond, uncond=uncond, cond_scale=cond_scale, cond_concat=cond_concat,
model_options=model_options, attention=attention)
if denoise_mask is not None:
out *= denoise_mask
@ -462,7 +474,8 @@ class KSampler:
self.sigmas = sigmas[-(steps + 1):]
def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=None, last_step=None, force_full_denoise=False, denoise_mask=None, sigmas=None, callback=None):
def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=None, last_step=None,
force_full_denoise=False, denoise_mask=None, sigmas=None, callback=None, attention=None):
if sigmas is None:
sigmas = self.sigmas
sigma_min = self.sigma_min
@ -580,6 +593,7 @@ class KSampler:
elif self.sampler == "dpm_adaptive":
samples = k_diffusion_sampling.sample_dpm_adaptive(self.model_k, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=k_callback)
else:
samples, attention = getattr(k_diffusion_sampling, "sample_{}".format(self.sampler))(self.model_k, noise, sigmas, extra_args=extra_args, callback=k_callback)
samples, attention = getattr(k_diffusion_sampling, "sample_{}".format(self.sampler))(self.model_k,
noise, sigmas, extra_args=extra_args, callback=k_callback, attention=attention)
return samples.to(torch.float32), attention
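
The new input_attention/attention_weight parameters on UNetModel.forward and the commented-out mixing code above point at a weighted blend of a saved attention map with a freshly generated one. A runnable toy of that mix, based only on those commented lines; blend_attention is a hypothetical name.

    import torch

    def blend_attention(saved, generated, attention_weight=0.5):
        # Convex mix of a previously captured attention map with the one the current
        # pass just produced; 1.0 replays the saved map verbatim, 0.0 ignores it.
        # Shapes must already match.
        return attention_weight * saved + (1.0 - attention_weight) * generated

    saved = torch.softmax(torch.randn(2, 8, 64, 64), dim=-1)
    generated = torch.softmax(torch.randn(2, 8, 64, 64), dim=-1)
    mixed = blend_attention(saved, generated, attention_weight=0.5)
    # Rows of a convex combination of row-stochastic maps still sum to 1.
    assert torch.allclose(mixed.sum(dim=-1), torch.ones(2, 8, 64))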

View File

@ -865,7 +865,8 @@ class SetLatentNoiseMask:
s["noise_mask"] = mask
return (s,)
def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False):
def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0,
disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, attention=None):
device = comfy.model_management.get_torch_device()
latent_image = latent["samples"]
@ -881,7 +882,7 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
samples, attention = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image,
denoise=denoise, disable_noise=disable_noise, start_step=start_step, last_step=last_step,
force_full_denoise=force_full_denoise, noise_mask=noise_mask)
force_full_denoise=force_full_denoise, noise_mask=noise_mask, attention=attention)
out = latent.copy()
out["samples"] = samples
return (out, attention)
@ -902,15 +903,22 @@ class KSampler:
"negative": ("CONDITIONING", ),
"latent_image": ("LATENT", ),
"denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
}}
},
"optional": {
"attention": ("ATTENTION",),
}
}
RETURN_TYPES = ("LATENT","ATTENTION")
FUNCTION = "sample"
CATEGORY = "sampling"
def sample(self, model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0):
return common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise)
def sample(self, model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image,
denoise=1.0, attention=None):
model.model_options["transformer_options"]["return_attention"] = True
return common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image,
denoise=denoise, attention=attention)
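
With KSampler now exposing an ATTENTION output and accepting an optional ATTENTION input, a downstream node can sit between the two sockets. A sketch of what such a node could look like, following the INPUT_TYPES / RETURN_TYPES / FUNCTION / CATEGORY conventions visible above; AttentionPassthrough and its scaling behavior are an assumption for illustration, not part of this commit.

    class AttentionPassthrough:
        # Hypothetical node: accepts the ATTENTION tensor emitted by KSampler,
        # optionally scales it, and hands it back out so it can be wired into
        # another sampler's optional "attention" input.
        @classmethod
        def INPUT_TYPES(cls):
            return {
                "required": {
                    "attention": ("ATTENTION",),
                    "weight": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
                },
            }

        RETURN_TYPES = ("ATTENTION",)
        FUNCTION = "apply"
        CATEGORY = "sampling"

        def apply(self, attention, weight):
            # Nodes return a tuple with one entry per RETURN_TYPES slot.
            return (attention * weight,)
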
class KSamplerAdvanced:
def __init__(self, event_dispatcher):
@ -931,14 +939,19 @@ class KSamplerAdvanced:
"start_at_step": ("INT", {"default": 0, "min": 0, "max": 10000}),
"end_at_step": ("INT", {"default": 10000, "min": 0, "max": 10000}),
"return_with_leftover_noise": (["disable", "enable"], ),
}}
},
"optional": {
"attention": ("ATTENTION",),
}
}
RETURN_TYPES = ("LATENT",)
FUNCTION = "sample"
CATEGORY = "sampling"
def sample(self, model, add_noise, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, start_at_step, end_at_step, return_with_leftover_noise, denoise=1.0):
def sample(self, model, add_noise, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative,
latent_image, start_at_step, end_at_step, return_with_leftover_noise, denoise=1.0, attention=None):
force_full_denoise = True
if return_with_leftover_noise == "enable":
force_full_denoise = False