diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py
index 1bf6034d1..4bef57580 100644
--- a/comfy/k_diffusion/sampling.py
+++ b/comfy/k_diffusion/sampling.py
@@ -582,7 +582,7 @@ def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=N
 
 
 @torch.no_grad()
-def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=None):
+def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=None, attention=None):
     """DPM-Solver++(2M)."""
     extra_args = {} if extra_args is None else extra_args
     s_in = x.new_ones([x.shape[0]])
@@ -590,12 +590,14 @@ def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=No
     t_fn = lambda sigma: sigma.log().neg()
     old_denoised = None
 
-    attention = None
+    attention_out = None
     for i in trange(len(sigmas) - 1, disable=disable):
+        if attention is not None:
+            extra_args['attention'] = attention[i]
         denoised, attn = model(x, sigmas[i] * s_in, **extra_args)
-        if attention is None:
-            attention = torch.empty((len(sigmas), *attn.shape), dtype=attn.dtype, device=attn.device)
-        attention[i] = attn
+        if attention_out is None:
+            attention_out = torch.empty((len(sigmas), *attn.shape), dtype=attn.dtype, device=attn.device)
+        attention_out[i] = attn
         if callback is not None:
             callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
@@ -609,4 +611,4 @@ def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=No
         denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised
         x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d
         old_denoised = denoised
-    return x, attention
+    return x, attention_out
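
A minimal usage sketch of the reworked sample_dpmpp_2m signature (illustrative only, not part of the diff; model, noise, sigmas and extra_args are assumed to be prepared the way KSampler.sample prepares them, with a wrapped model that returns (denoised, attn) per step):

    from comfy.k_diffusion import sampling as k_diffusion_sampling

    # First pass: latents plus a [num_sigmas, *attn.shape] tensor of per-step attention.
    latents, attn_maps = k_diffusion_sampling.sample_dpmpp_2m(model, noise, sigmas, extra_args=extra_args)

    # Second pass: replay the recorded maps; step i receives attn_maps[i] through extra_args['attention'].
    # The downstream mixing in sampling_function is still commented out in this diff, so this only
    # exercises the plumbing.
    latents_2, _ = k_diffusion_sampling.sample_dpmpp_2m(model, noise, sigmas, extra_args=extra_args,
                                                        attention=attn_maps)
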
diff --git a/comfy/ldm/models/diffusion/ddpm.py b/comfy/ldm/models/diffusion/ddpm.py
index d3f0eb2b2..22b6f52aa 100644
--- a/comfy/ldm/models/diffusion/ddpm.py
+++ b/comfy/ldm/models/diffusion/ddpm.py
@@ -856,7 +856,7 @@ class LatentDiffusion(DDPM):
             key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn'
             cond = {key: cond}
 
-        x_recon = self.model(x_noisy, t, **cond)
+        x_recon, attn = self.model(x_noisy, t, **cond)
 
         if isinstance(x_recon, tuple) and not return_ids:
             return x_recon[0]
@@ -1318,6 +1318,7 @@ class DiffusionWrapper(torch.nn.Module):
         assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm', 'hybrid-adm', 'crossattn-adm']
 
     def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None, control=None, transformer_options={}):
+        return_attention = transformer_options.get("return_attention", False)
         if self.conditioning_key is None:
             out = self.diffusion_model(x, t, control=control, transformer_options=transformer_options)
         elif self.conditioning_key == 'concat':
@@ -1334,7 +1335,8 @@ class DiffusionWrapper(torch.nn.Module):
                 # an error: RuntimeError: forward() is missing value for argument 'argument_3'.
                 out = self.scripted_diffusion_model(x, t, cc, control=control, transformer_options=transformer_options)
             else:
-                out = self.diffusion_model(x, t, context=cc, control=control, transformer_options=transformer_options)
+                out, attn = self.diffusion_model(x, t, context=cc, control=control,
+                                                 transformer_options=transformer_options)
         elif self.conditioning_key == 'hybrid':
             xc = torch.cat([x] + c_concat, dim=1)
             cc = torch.cat(c_crossattn, 1)
@@ -1354,7 +1356,10 @@ class DiffusionWrapper(torch.nn.Module):
         else:
             raise NotImplementedError()
 
-        return out
+        if return_attention:
+            return out, attn
+        else:
+            return out
 
 
 class LatentUpscaleDiffusion(LatentDiffusion):
diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index ce7180d91..d96a2dbec 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -163,9 +163,7 @@ class CrossAttentionBirchSan(nn.Module):
             nn.Dropout(dropout)
         )
 
-    def forward(self, x, context=None, value=None, mask=None):
-        h = self.heads
-
+    def forward(self, x, context=None, value=None, mask=None, return_attention=False):
         query = self.to_q(x)
         context = default(context, x)
         key = self.to_k(context)
@@ -220,7 +218,7 @@ class CrossAttentionBirchSan(nn.Module):
             kv_chunk_size = kv_chunk_size_x
             kv_chunk_size_min = kv_chunk_size_min_x
 
-        hidden_states = efficient_dot_product_attention(
+        output = efficient_dot_product_attention(
            query,
            key_t,
            value,
@@ -229,7 +227,12 @@ class CrossAttentionBirchSan(nn.Module):
            kv_chunk_size_min=kv_chunk_size_min,
            use_checkpoint=self.training,
            upcast_attention=upcast_attention,
+           return_attention=return_attention,
         )
+        if return_attention:
+            hidden_states, attention = output
+        else:
+            hidden_states = output
 
         hidden_states = hidden_states.to(dtype)
 
@@ -239,7 +242,10 @@ class CrossAttentionBirchSan(nn.Module):
         hidden_states = out_proj(hidden_states)
         hidden_states = dropout(hidden_states)
 
-        return hidden_states
+        if return_attention:
+            return hidden_states, attention
+        else:
+            return hidden_states
 
 
 class CrossAttentionDoggettx(nn.Module):
@@ -358,7 +364,7 @@ class CrossAttention(nn.Module):
             nn.Dropout(dropout)
         )
 
-    def forward(self, x, context=None, value=None, mask=None):
+    def forward(self, x, context=None, value=None, mask=None, return_attention=False):
         h = self.heads
 
         q = self.to_q(x)
@@ -393,7 +399,13 @@ class CrossAttention(nn.Module):
 
         out = einsum('b i j, b j d -> b i d', sim, v)
         out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
-        return self.to_out(out)
+        out = self.to_out(out)
+
+        if return_attention:
+            sim = rearrange(sim, '(b h) i j -> b h i j', h=h)
+            return out, sim
+        else:
+            return out
 
 
 class MemoryEfficientCrossAttention(nn.Module):
     # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
@@ -523,6 +535,7 @@ class BasicTransformerBlock(nn.Module):
         return checkpoint(self._forward, (x, context, transformer_options), self.parameters(), self.checkpoint)
 
     def _forward(self, x, context=None, transformer_options={}):
+        return_attention = transformer_options.get("return_attention", False)
         current_index = None
         if "current_index" in transformer_options:
             current_index = transformer_options["current_index"]
@@ -550,7 +563,10 @@ class BasicTransformerBlock(nn.Module):
                 m, u = tomesd.get_functions(x, transformer_options["tomesd"]["ratio"], transformer_options["original_shape"])
                 n = u(self.attn1(m(n), context=context_attn1, value=value_attn1))
             else:
-                n = self.attn1(n, context=context_attn1, value=value_attn1)
+                if return_attention:
+                    n, attn1_weights = self.attn1(n, context=context_attn1, value=value_attn1, return_attention=True)
+                else:
+                    n = self.attn1(n, context=context_attn1, value=value_attn1)
 
         x += n
         if "middle_patch" in transformer_patches:
@@ -568,14 +584,22 @@ class BasicTransformerBlock(nn.Module):
             for p in patch:
                 n, context_attn2, value_attn2 = p(current_index, n, context_attn2, value_attn2)
 
-        n = self.attn2(n, context=context_attn2, value=value_attn2)
+
+        if return_attention:
+            n, attn2_weights = self.attn2(n, context=context_attn2, value=value_attn2, return_attention=True)
+        else:
+            n = self.attn2(n, context=context_attn2, value=value_attn2)
 
         x += n
         x = self.ff(self.norm3(x)) + x
 
         if current_index is not None:
             transformer_options["current_index"] += 1
-        return x
+
+        if return_attention:
+            return x, (attn1_weights, attn2_weights)
+        else:
+            return x
 
 
 class SpatialTransformer(nn.Module):
@@ -633,12 +657,24 @@ class SpatialTransformer(nn.Module):
         x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
         if self.use_linear:
             x = self.proj_in(x)
+
+        attention_tensors = []
         for i, block in enumerate(self.transformer_blocks):
-            x = block(x, context=context[i], transformer_options=transformer_options)
+            if transformer_options.get("return_attention", False):
+                x, attention = block(x, context=context[i], transformer_options=transformer_options)
+                attention_tensors.append(attention)
+            else:
+                x = block(x, context=context[i], transformer_options=transformer_options)
+
         if self.use_linear:
             x = self.proj_out(x)
         x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
+
         if not self.use_linear:
             x = self.proj_out(x)
-        return x + x_in
+
+        if transformer_options.get("return_attention", False):
+            return x + x_in, attention_tensors
+        else:
+            return x + x_in
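
With return_attention=True the einsum-based CrossAttention path now also hands back the softmaxed similarity matrix, rearranged to [batch, heads, query_tokens, key_tokens]. A small pure-PyTorch sketch of that shape contract (illustrative, not part of the diff; it mirrors the patched path rather than calling the module, since ComfyUI rebinds the CrossAttention name depending on the selected attention backend):

    import torch
    from einops import rearrange

    b, h, i, j, d = 1, 8, 64, 77, 40
    sim = torch.randn(b * h, i, j).softmax(dim=-1)          # per-head attention probabilities
    v = torch.randn(b * h, j, d)
    out = torch.einsum('b i j, b j d -> b i d', sim, v)     # [8, 64, 40], as in the existing forward
    weights = rearrange(sim, '(b h) i j -> b h i j', h=h)   # what forward() now returns as the second value
    # weights: [1, 8, 64, 77] -> batch, heads, query tokens, key tokens
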
diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py
index 4c69c8567..f204a4773 100644
--- a/comfy/ldm/modules/diffusionmodules/openaimodel.py
+++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py
@@ -81,7 +81,10 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
             if isinstance(layer, TimestepBlock):
                 x = layer(x, emb)
             elif isinstance(layer, SpatialTransformer):
-                x = layer(x, context, transformer_options)
+                if transformer_options.get("attention", False):
+                    x, attention = layer(x, context, transformer_options)
+                else:
+                    x = layer(x, context, transformer_options)
             else:
                 x = layer(x)
         return x
@@ -310,17 +313,26 @@ class AttentionBlock(nn.Module):
 
         self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
 
-    def forward(self, x):
-        return checkpoint(self._forward, (x,), self.parameters(), True)   # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
+    def forward(self, x, return_attention=False):
+        return checkpoint(self._forward, (x, return_attention), self.parameters(), True)   # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
         #return pt_checkpoint(self._forward, x)  # pytorch
 
-    def _forward(self, x):
+    def _forward(self, x, return_attention=False):
         b, c, *spatial = x.shape
         x = x.reshape(b, c, -1)
         qkv = self.qkv(self.norm(x))
-        h = self.attention(qkv)
+
+        if return_attention:
+            h, attention_weights = self.attention(qkv, return_attention=True)
+        else:
+            h = self.attention(qkv)
+
         h = self.proj_out(h)
-        return (x + h).reshape(b, c, *spatial)
+        h = (x + h).reshape(b, c, *spatial)
+        if return_attention:
+            return h, attention_weights
+        else:
+            return h
 
 
 def count_flops_attn(model, _x, y):
@@ -352,7 +364,7 @@ class QKVAttentionLegacy(nn.Module):
         super().__init__()
         self.n_heads = n_heads
 
-    def forward(self, qkv):
+    def forward(self, qkv, return_attention=False):
         """
         Apply QKV attention.
         :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
@@ -368,7 +380,12 @@ class QKVAttentionLegacy(nn.Module):
         )  # More stable with f16 than dividing afterwards
         weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
         a = th.einsum("bts,bcs->bct", weight, v)
-        return a.reshape(bs, -1, length)
+        a = a.reshape(bs, -1, length)
+
+        if return_attention:
+            return a, weight
+        else:
+            return a
 
     @staticmethod
     def count_flops(model, _x, y):
@@ -384,7 +401,7 @@ class QKVAttention(nn.Module):
         super().__init__()
         self.n_heads = n_heads
 
-    def forward(self, qkv):
+    def forward(self, qkv, return_attention=False):
         """
         Apply QKV attention.
         :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
@@ -402,7 +419,12 @@ class QKVAttention(nn.Module):
         )  # More stable with f16 than dividing afterwards
         weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
         a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
-        return a.reshape(bs, -1, length)
+        a = a.reshape(bs, -1, length)
+
+        if return_attention:
+            return a, weight
+        else:
+            return a
 
     @staticmethod
     def count_flops(model, _x, y):
@@ -772,13 +794,18 @@ class UNetModel(nn.Module):
         self.middle_block.apply(convert_module_to_f32)
         self.output_blocks.apply(convert_module_to_f32)
 
-    def forward(self, x, timesteps=None, context=None, y=None, control=None, transformer_options={}, **kwargs):
+    def forward(self, x, timesteps=None, context=None, y=None, control=None, transformer_options={},
+                input_attention=None, attention_weight=None, **kwargs):
         """
         Apply the model to an input batch.
         :param x: an [N x C x ...] Tensor of inputs.
         :param timesteps: a 1-D batch of timesteps.
         :param context: conditioning plugged in via crossattn
         :param y: an [N] Tensor of labels, if class-conditional.
+        :param control: a dictionary of control parameters for the model.
+        :param transformer_options: a dictionary of options to pass to the transformer blocks.
+        :param input_attention: an optional Tensor of attention weights to feed to the attention blocks.
+        :param attention_weight: the weight with which to mix the input attention into the generated attention.
         :return: an [N x C x ...] Tensor of outputs.
""" transformer_options["original_shape"] = list(x.shape) @@ -796,14 +823,37 @@ class UNetModel(nn.Module): emb = emb + self.label_emb(y) h = x.type(self.dtype) + input_and_output_options = transformer_options.copy() + input_and_output_options["return_attention"] = False # No attention to be had in the input blocks + for id, module in enumerate(self.input_blocks): - h = module(h, emb, context, transformer_options) + h = module(h, emb, context, input_and_output_options) if control is not None and 'input' in control and len(control['input']) > 0: ctrl = control['input'].pop() if ctrl is not None: h += ctrl hs.append(h) - h = self.middle_block(h, emb, context, transformer_options) + + attention_tensors = [] + for i, module in enumerate(self.middle_block): + if isinstance(module, AttentionBlock): + if transformer_options.get("return_attention", False): + h, attention = module(h, emb, context, transformer_options) + attention_tensors.append(attention) + # if input_attention is not None and attention_weight is not None: + # combined_attention = attention_weight * input_attention + (1 - attention_weight) * attention + # h = h * combined_attention + else: + h = module(h, emb, context, transformer_options) + elif isinstance(module, SpatialTransformer): + if transformer_options.get("return_attention", False): + h, attention = module(h, context, transformer_options) + attention_tensors.append(attention) + elif isinstance(module, TimestepBlock): + h = module(h, emb) + else: + h = module(h) + if control is not None and 'middle' in control and len(control['middle']) > 0: h += control['middle'].pop() @@ -815,9 +865,9 @@ class UNetModel(nn.Module): hsp += ctrl h = th.cat([h, hsp], dim=1) del hsp - h = module(h, emb, context, transformer_options) + h = module(h, emb, context, input_and_output_options) h = h.type(x.dtype) if self.predict_codebook_ids: - return self.id_predictor(h) + return self.id_predictor(h), attention_tensors else: - return self.out(h) + return self.out(h), attention_tensors diff --git a/comfy/ldm/modules/sub_quadratic_attention.py b/comfy/ldm/modules/sub_quadratic_attention.py index 573cce74f..f1e48d62a 100644 --- a/comfy/ldm/modules/sub_quadratic_attention.py +++ b/comfy/ldm/modules/sub_quadratic_attention.py @@ -95,6 +95,7 @@ def _query_chunk_attention( value: Tensor, summarize_chunk: SummarizeChunk, kv_chunk_size: int, + return_attention: bool, ) -> Tensor: batch_x_heads, k_channels_per_head, k_tokens = key_t.shape _, _, v_channels_per_head = value.shape @@ -125,7 +126,11 @@ def _query_chunk_attention( all_values = chunk_values.sum(dim=0) all_weights = torch.unsqueeze(chunk_weights, -1).sum(dim=0) - return all_values / all_weights + + if return_attention: + return all_values / all_weights, chunk_weights + else: + return all_values / all_weights # TODO: refactor CrossAttention#get_attention_scores to share code with this def _get_attention_scores_no_kv_chunking( @@ -134,6 +139,7 @@ def _get_attention_scores_no_kv_chunking( value: Tensor, scale: float, upcast_attention: bool, + return_attention: bool, ) -> Tensor: if upcast_attention: with torch.autocast(enabled=False, device_type = 'cuda'): @@ -167,7 +173,11 @@ def _get_attention_scores_no_kv_chunking( attn_probs = attn_scores hidden_states_slice = torch.bmm(attn_probs, value) - return hidden_states_slice + + if return_attention: + return hidden_states_slice, attn_probs + else: + return hidden_states_slice class ScannedChunk(NamedTuple): chunk_idx: int @@ -182,6 +192,7 @@ def efficient_dot_product_attention( kv_chunk_size_min: 
diff --git a/comfy/ldm/modules/sub_quadratic_attention.py b/comfy/ldm/modules/sub_quadratic_attention.py
index 573cce74f..f1e48d62a 100644
--- a/comfy/ldm/modules/sub_quadratic_attention.py
+++ b/comfy/ldm/modules/sub_quadratic_attention.py
@@ -95,6 +95,7 @@ def _query_chunk_attention(
     value: Tensor,
     summarize_chunk: SummarizeChunk,
     kv_chunk_size: int,
+    return_attention: bool,
 ) -> Tensor:
     batch_x_heads, k_channels_per_head, k_tokens = key_t.shape
     _, _, v_channels_per_head = value.shape
@@ -125,7 +126,11 @@ def _query_chunk_attention(
 
     all_values = chunk_values.sum(dim=0)
     all_weights = torch.unsqueeze(chunk_weights, -1).sum(dim=0)
-    return all_values / all_weights
+
+    if return_attention:
+        return all_values / all_weights, chunk_weights
+    else:
+        return all_values / all_weights
 
 # TODO: refactor CrossAttention#get_attention_scores to share code with this
 def _get_attention_scores_no_kv_chunking(
@@ -134,6 +139,7 @@ def _get_attention_scores_no_kv_chunking(
     value: Tensor,
     scale: float,
     upcast_attention: bool,
+    return_attention: bool,
 ) -> Tensor:
     if upcast_attention:
         with torch.autocast(enabled=False, device_type = 'cuda'):
@@ -167,7 +173,11 @@ def _get_attention_scores_no_kv_chunking(
         attn_probs = attn_scores
 
     hidden_states_slice = torch.bmm(attn_probs, value)
-    return hidden_states_slice
+
+    if return_attention:
+        return hidden_states_slice, attn_probs
+    else:
+        return hidden_states_slice
 
 class ScannedChunk(NamedTuple):
     chunk_idx: int
@@ -182,6 +192,7 @@ def efficient_dot_product_attention(
     kv_chunk_size_min: Optional[int] = None,
     use_checkpoint=True,
     upcast_attention=False,
+    return_attention=False,
 ):
     """Computes efficient dot-product attention given query, transposed key, and value.
       This is efficient version of attention presented in
@@ -197,6 +208,8 @@ def efficient_dot_product_attention(
         kv_chunk_size: Optional[int]: key/value chunks size. if None: defaults to sqrt(key_tokens)
         kv_chunk_size_min: Optional[int]: key/value minimum chunk size. only considered when kv_chunk_size is None. changes `sqrt(key_tokens)` into `max(sqrt(key_tokens), kv_chunk_size_min)`, to ensure our chunk sizes don't get too small (smaller chunks = more chunks = less concurrent work done).
         use_checkpoint: bool: whether to use checkpointing (recommended True for training, False for inference)
+        upcast_attention: bool: whether to upcast attention to fp32
+        return_attention: bool: whether to also return the attention weights
       Returns:
         Output of shape `[batch * num_heads, query_tokens, channels_per_head]`.
       """
@@ -220,13 +233,15 @@ def efficient_dot_product_attention(
     compute_query_chunk_attn: ComputeQueryChunkAttn = partial(
         _get_attention_scores_no_kv_chunking,
         scale=scale,
-        upcast_attention=upcast_attention
+        upcast_attention=upcast_attention,
+        return_attention=return_attention,
     ) if k_tokens <= kv_chunk_size else (
         # fast-path for when there's just 1 key-value chunk per query chunk (this is just sliced attention btw)
         partial(
             _query_chunk_attention,
             kv_chunk_size=kv_chunk_size,
             summarize_chunk=summarize_chunk,
+            return_attention=return_attention,
         )
     )
 
@@ -236,15 +251,25 @@ def efficient_dot_product_attention(
             query=query,
             key_t=key_t,
             value=value,
+            return_attention=return_attention,
         )
 
     # TODO: maybe we should use torch.empty_like(query) to allocate storage in-advance,
     # and pass slices to be mutated, instead of torch.cat()ing the returned slices
-    res = torch.cat([
+    results = [
         compute_query_chunk_attn(
             query=get_query_chunk(i * query_chunk_size),
             key_t=key_t,
             value=value,
+            return_attention=return_attention,
         ) for i in range(math.ceil(q_tokens / query_chunk_size))
-    ], dim=1)
-    return res
+    ]
+
+    res = torch.cat([result[0] if return_attention else result for result in results], dim=1)
+
+    if return_attention:
+        attn_weights = [result[1] for result in results]
+        return res, attn_weights
+    else:
+        return res
+
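
A standalone sketch of the new return_attention flag on efficient_dot_product_attention (not part of the diff; the chunk sizes are chosen only to force the chunked, no-kv-chunking code path, where the full per-chunk attention probabilities are returned as a list):

    import torch
    from comfy.ldm.modules.sub_quadratic_attention import efficient_dot_product_attention

    q   = torch.randn(2, 64, 40)   # [batch*heads, query_tokens, channels_per_head]
    k_t = torch.randn(2, 40, 77)   # the key is passed pre-transposed
    v   = torch.randn(2, 77, 40)

    out, attn_chunks = efficient_dot_product_attention(
        q, k_t, v,
        query_chunk_size=32,       # two query chunks instead of the single-chunk fast path
        kv_chunk_size=128,         # one kv chunk, so full attention probabilities per query chunk
        use_checkpoint=False,
        return_attention=True,
    )
    # out: [2, 64, 40]; attn_chunks: list of two [2, 32, 77] tensors.
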
diff --git a/comfy/sample.py b/comfy/sample.py
index f28123106..0676329f0 100644
--- a/comfy/sample.py
+++ b/comfy/sample.py
@@ -56,7 +56,9 @@ def cleanup_additional_models(models):
     for m in models:
         m.cleanup()
 
-def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, noise_mask=None, sigmas=None, callback=None):
+def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0,
+           disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, noise_mask=None,
+           sigmas=None, callback=None, attention=None):
     device = comfy.model_management.get_torch_device()
 
     if noise_mask is not None:
@@ -74,9 +76,13 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative
 
     models = load_additional_models(positive, negative)
 
-    sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)
+    sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler,
+                                      denoise=denoise, model_options=model.model_options)
 
-    samples, attention = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback)
+    samples, attention = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image,
+                                        start_step=start_step, last_step=last_step,
+                                        force_full_denoise=force_full_denoise, denoise_mask=noise_mask,
+                                        sigmas=sigmas, callback=callback, attention=attention)
     samples = samples.cpu()
     attention = attention.cpu()
diff --git a/comfy/samplers.py b/comfy/samplers.py
index b7a1cf207..b52993cd1 100644
--- a/comfy/samplers.py
+++ b/comfy/samplers.py
@@ -9,7 +9,8 @@ from .ldm.modules.diffusionmodules.util import make_ddim_timesteps
 
 #The main sampling function shared by all the samplers
 #Returns predicted noise
-def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, cond_concat=None, model_options={}):
+def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, cond_concat=None, model_options={},
+                      attention=None, attention_weight=0.5):
     def get_area_and_mult(cond, x_in, cond_concat_in, timestep_in):
         area = (x_in.shape[2], x_in.shape[3], 0, 0)
         strength = 1.0
@@ -125,7 +126,8 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
                 out['c_adm'] = torch.cat(c_adm)
         return out
 
-    def calc_cond_uncond_batch(model_function, cond, uncond, x_in, timestep, max_total_area, cond_concat_in, model_options):
+    def calc_cond_uncond_batch(model_function, cond, uncond, x_in, timestep, max_total_area, cond_concat_in,
+                               model_options, attention=None, attention_weight=0.5):
         out_cond = torch.zeros_like(x_in)
         out_count = torch.ones_like(x_in)/100000.0
 
@@ -192,6 +194,11 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
             if control is not None:
                 c['control'] = control.get_control(input_x, timestep_, c['c_crossattn'], len(cond_or_uncond))
 
+            # if attention is not None:
+            #     generated_attention = c['c_crossattn'][0]
+            #     mixed_attention = attention_weight * torch.cat(attention) + (1 - attention_weight) * generated_attention
+            #     c['c_crossattn'] = [mixed_attention]
+
             transformer_options = {}
             if 'transformer_options' in model_options:
                 transformer_options = model_options['transformer_options'].copy()
@@ -207,6 +214,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
                 else:
                     transformer_options["patches"] = patches
 
+            # transformer_options['return_attention'] = True
             c['transformer_options'] = transformer_options
 
             output = model_function(input_x, timestep_, cond=c).chunk(batch_chunks)
@@ -232,7 +240,8 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
 
     max_total_area = model_management.maximum_batch_area()
 
-    cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, timestep, max_total_area, cond_concat, model_options)
+    cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, timestep, max_total_area, cond_concat,
+                                          model_options, attention=attention)
     if "sampler_cfg_function" in model_options:
         return model_options["sampler_cfg_function"](cond, uncond, cond_scale), cond[0]  # cond[0] is attention
     else:
@@ -252,8 +261,9 @@ class CFGNoisePredictor(torch.nn.Module):
         super().__init__()
         self.inner_model = model
         self.alphas_cumprod = model.alphas_cumprod
-    def apply_model(self, x, timestep, cond, uncond, cond_scale, cond_concat=None, model_options={}):
-        out, attn = sampling_function(self.inner_model.apply_model, x, timestep, uncond, cond, cond_scale, cond_concat, model_options=model_options)
+    def apply_model(self, x, timestep, cond, uncond, cond_scale, cond_concat=None, model_options={}, attention=None):
+        out, attn = sampling_function(self.inner_model.apply_model, x, timestep, uncond, cond, cond_scale, cond_concat,
+                                      model_options=model_options, attention=attention)
         return out, attn
 
 
@@ -261,11 +271,13 @@ class KSamplerX0Inpaint(torch.nn.Module):
     def __init__(self, model):
         super().__init__()
         self.inner_model = model
-    def forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask, cond_concat=None, model_options={}):
+    def forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask, cond_concat=None, model_options={},
+                attention=None):
         if denoise_mask is not None:
             latent_mask = 1. - denoise_mask
             x = x * denoise_mask + (self.latent_image + self.noise * sigma.reshape([sigma.shape[0]] + [1] * (len(self.noise.shape) - 1))) * latent_mask
-        out, attn = self.inner_model(x, sigma, cond=cond, uncond=uncond, cond_scale=cond_scale, cond_concat=cond_concat, model_options=model_options)
+        out, attn = self.inner_model(x, sigma, cond=cond, uncond=uncond, cond_scale=cond_scale, cond_concat=cond_concat,
+                                     model_options=model_options, attention=attention)
         if denoise_mask is not None:
             out *= denoise_mask
@@ -462,7 +474,8 @@ class KSampler:
 
         self.sigmas = sigmas[-(steps + 1):]
 
-    def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=None, last_step=None, force_full_denoise=False, denoise_mask=None, sigmas=None, callback=None):
+    def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=None, last_step=None,
+               force_full_denoise=False, denoise_mask=None, sigmas=None, callback=None, attention=None):
         if sigmas is None:
             sigmas = self.sigmas
         sigma_min = self.sigma_min
@@ -580,6 +593,7 @@ class KSampler:
             elif self.sampler == "dpm_adaptive":
                 samples = k_diffusion_sampling.sample_dpm_adaptive(self.model_k, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=k_callback)
             else:
-                samples, attention = getattr(k_diffusion_sampling, "sample_{}".format(self.sampler))(self.model_k, noise, sigmas, extra_args=extra_args, callback=k_callback)
+                samples, attention = getattr(k_diffusion_sampling, "sample_{}".format(self.sampler))(self.model_k,
+                    noise, sigmas, extra_args=extra_args, callback=k_callback, attention=attention)
 
         return samples.to(torch.float32), attention
diff --git a/nodes.py b/nodes.py
index 52d34641d..f19692623 100644
--- a/nodes.py
+++ b/nodes.py
@@ -865,7 +865,8 @@ class SetLatentNoiseMask:
         s["noise_mask"] = mask
         return (s,)
 
-def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False):
+def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0,
+                    disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, attention=None):
     device = comfy.model_management.get_torch_device()
     latent_image = latent["samples"]
 
@@ -881,7 +882,7 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
 
     samples, attention = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image,
                                   denoise=denoise, disable_noise=disable_noise, start_step=start_step, last_step=last_step,
-                                  force_full_denoise=force_full_denoise, noise_mask=noise_mask)
+                                  force_full_denoise=force_full_denoise, noise_mask=noise_mask, attention=attention)
     out = latent.copy()
     out["samples"] = samples
     return (out, attention)
@@ -902,15 +903,22 @@ class KSampler:
                     "negative": ("CONDITIONING", ),
                     "latent_image": ("LATENT", ),
                     "denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
-                    }}
+                    },
+                "optional": {
+                    "attention": ("ATTENTION",),
+                    }
+                }
 
     RETURN_TYPES = ("LATENT","ATTENTION")
     FUNCTION = "sample"
 
     CATEGORY = "sampling"
 
-    def sample(self, model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0):
-        return common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise)
+    def sample(self, model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image,
+               denoise=1.0, attention=None):
+        model.model_options["transformer_options"]["return_attention"] = True
+        return common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image,
+                               denoise=denoise, attention=attention)
 
 class KSamplerAdvanced:
     def __init__(self, event_dispatcher):
@@ -931,14 +939,19 @@ class KSamplerAdvanced:
                     "start_at_step": ("INT", {"default": 0, "min": 0, "max": 10000}),
                     "end_at_step": ("INT", {"default": 10000, "min": 0, "max": 10000}),
                     "return_with_leftover_noise": (["disable", "enable"], ),
-                    }}
+                    },
+                "optional": {
+                    "attention": ("ATTENTION",),
+                    }
+                }
 
     RETURN_TYPES = ("LATENT",)
     FUNCTION = "sample"
 
     CATEGORY = "sampling"
 
-    def sample(self, model, add_noise, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, start_at_step, end_at_step, return_with_leftover_noise, denoise=1.0):
+    def sample(self, model, add_noise, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative,
+               latent_image, start_at_step, end_at_step, return_with_leftover_noise, denoise=1.0, attention=None):
         force_full_denoise = True
         if return_with_leftover_noise == "enable":
             force_full_denoise = False
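
At the node level the intent is a two-pass round trip over the new optional ATTENTION socket: sample once with return_attention enabled, then feed the recorded tensor back into a second run. A hedged sketch against common_ksampler (inputs such as model, the conditioning and the latent dict are assumed to come from the usual loader nodes, positive_edit is a hypothetical edited prompt, and only sample_dpmpp_2m currently threads the attention through; the mixing step in sampling_function is still commented out in this diff):

    # Mirrors what the patched KSampler node does before sampling.
    model.model_options["transformer_options"]["return_attention"] = True

    latent_a, attention = common_ksampler(model, seed=0, steps=20, cfg=7.0, sampler_name="dpmpp_2m",
                                          scheduler="normal", positive=positive, negative=negative,
                                          latent=latent)

    # Replay the recorded attention, e.g. under edited conditioning.
    latent_b, _ = common_ksampler(model, seed=0, steps=20, cfg=7.0, sampler_name="dpmpp_2m",
                                  scheduler="normal", positive=positive_edit, negative=negative,
                                  latent=latent, attention=attention)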