Support flux 2 klein kv cache model: Use the FluxKVCache node. (#12905)

This commit is contained in:
comfyanonymous 2026-03-12 08:30:50 -07:00 committed by GitHub
parent 8f9ea49571
commit 44f1246c89
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 129 additions and 11 deletions

View File

@ -44,6 +44,22 @@ class FluxParams:
txt_norm: bool = False txt_norm: bool = False
def invert_slices(slices, length):
sorted_slices = sorted(slices)
result = []
current = 0
for start, end in sorted_slices:
if current < start:
result.append((current, start))
current = max(current, end)
if current < length:
result.append((current, length))
return result
class Flux(nn.Module): class Flux(nn.Module):
""" """
Transformer model for flow matching on sequences. Transformer model for flow matching on sequences.
@ -138,6 +154,7 @@ class Flux(nn.Module):
y: Tensor, y: Tensor,
guidance: Tensor = None, guidance: Tensor = None,
control = None, control = None,
timestep_zero_index=None,
transformer_options={}, transformer_options={},
attn_mask: Tensor = None, attn_mask: Tensor = None,
) -> Tensor: ) -> Tensor:
@ -164,10 +181,6 @@ class Flux(nn.Module):
txt = self.txt_norm(txt) txt = self.txt_norm(txt)
txt = self.txt_in(txt) txt = self.txt_in(txt)
vec_orig = vec
if self.params.global_modulation:
vec = (self.double_stream_modulation_img(vec_orig), self.double_stream_modulation_txt(vec_orig))
if "post_input" in patches: if "post_input" in patches:
for p in patches["post_input"]: for p in patches["post_input"]:
out = p({"img": img, "txt": txt, "img_ids": img_ids, "txt_ids": txt_ids, "transformer_options": transformer_options}) out = p({"img": img, "txt": txt, "img_ids": img_ids, "txt_ids": txt_ids, "transformer_options": transformer_options})
@ -182,6 +195,24 @@ class Flux(nn.Module):
else: else:
pe = None pe = None
vec_orig = vec
txt_vec = vec
extra_kwargs = {}
if timestep_zero_index is not None:
modulation_dims = []
batch = vec.shape[0] // 2
vec_orig = vec_orig.reshape(2, batch, vec.shape[1]).movedim(0, 1)
invert = invert_slices(timestep_zero_index, img.shape[1])
for s in invert:
modulation_dims.append((s[0], s[1], 0))
for s in timestep_zero_index:
modulation_dims.append((s[0], s[1], 1))
extra_kwargs["modulation_dims_img"] = modulation_dims
txt_vec = vec[:batch]
if self.params.global_modulation:
vec = (self.double_stream_modulation_img(vec_orig), self.double_stream_modulation_txt(txt_vec))
blocks_replace = patches_replace.get("dit", {}) blocks_replace = patches_replace.get("dit", {})
transformer_options["total_blocks"] = len(self.double_blocks) transformer_options["total_blocks"] = len(self.double_blocks)
transformer_options["block_type"] = "double" transformer_options["block_type"] = "double"
@ -195,7 +226,8 @@ class Flux(nn.Module):
vec=args["vec"], vec=args["vec"],
pe=args["pe"], pe=args["pe"],
attn_mask=args.get("attn_mask"), attn_mask=args.get("attn_mask"),
transformer_options=args.get("transformer_options")) transformer_options=args.get("transformer_options"),
**extra_kwargs)
return out return out
out = blocks_replace[("double_block", i)]({"img": img, out = blocks_replace[("double_block", i)]({"img": img,
@ -213,7 +245,8 @@ class Flux(nn.Module):
vec=vec, vec=vec,
pe=pe, pe=pe,
attn_mask=attn_mask, attn_mask=attn_mask,
transformer_options=transformer_options) transformer_options=transformer_options,
**extra_kwargs)
if control is not None: # Controlnet if control is not None: # Controlnet
control_i = control.get("input") control_i = control.get("input")
@ -230,6 +263,12 @@ class Flux(nn.Module):
if self.params.global_modulation: if self.params.global_modulation:
vec, _ = self.single_stream_modulation(vec_orig) vec, _ = self.single_stream_modulation(vec_orig)
extra_kwargs = {}
if timestep_zero_index is not None:
lambda a: 0 if a == 0 else a + txt.shape[1]
modulation_dims_combined = list(map(lambda x: (0 if x[0] == 0 else x[0] + txt.shape[1], x[1] + txt.shape[1], x[2]), modulation_dims))
extra_kwargs["modulation_dims"] = modulation_dims_combined
transformer_options["total_blocks"] = len(self.single_blocks) transformer_options["total_blocks"] = len(self.single_blocks)
transformer_options["block_type"] = "single" transformer_options["block_type"] = "single"
transformer_options["img_slice"] = [txt.shape[1], img.shape[1]] transformer_options["img_slice"] = [txt.shape[1], img.shape[1]]
@ -242,7 +281,8 @@ class Flux(nn.Module):
vec=args["vec"], vec=args["vec"],
pe=args["pe"], pe=args["pe"],
attn_mask=args.get("attn_mask"), attn_mask=args.get("attn_mask"),
transformer_options=args.get("transformer_options")) transformer_options=args.get("transformer_options"),
**extra_kwargs)
return out return out
out = blocks_replace[("single_block", i)]({"img": img, out = blocks_replace[("single_block", i)]({"img": img,
@ -253,7 +293,7 @@ class Flux(nn.Module):
{"original_block": block_wrap}) {"original_block": block_wrap})
img = out["img"] img = out["img"]
else: else:
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options) img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options, **extra_kwargs)
if control is not None: # Controlnet if control is not None: # Controlnet
control_o = control.get("output") control_o = control.get("output")
@ -264,7 +304,11 @@ class Flux(nn.Module):
img = img[:, txt.shape[1] :, ...] img = img[:, txt.shape[1] :, ...]
img = self.final_layer(img, vec_orig) # (N, T, patch_size ** 2 * out_channels) extra_kwargs = {}
if timestep_zero_index is not None:
extra_kwargs["modulation_dims"] = modulation_dims
img = self.final_layer(img, vec_orig, **extra_kwargs) # (N, T, patch_size ** 2 * out_channels)
return img return img
def process_img(self, x, index=0, h_offset=0, w_offset=0, transformer_options={}): def process_img(self, x, index=0, h_offset=0, w_offset=0, transformer_options={}):
@ -312,13 +356,16 @@ class Flux(nn.Module):
w_len = ((w_orig + (patch_size // 2)) // patch_size) w_len = ((w_orig + (patch_size // 2)) // patch_size)
img, img_ids = self.process_img(x, transformer_options=transformer_options) img, img_ids = self.process_img(x, transformer_options=transformer_options)
img_tokens = img.shape[1] img_tokens = img.shape[1]
timestep_zero_index = None
if ref_latents is not None: if ref_latents is not None:
ref_num_tokens = []
h = 0 h = 0
w = 0 w = 0
index = 0 index = 0
ref_latents_method = kwargs.get("ref_latents_method", self.params.default_ref_method) ref_latents_method = kwargs.get("ref_latents_method", self.params.default_ref_method)
timestep_zero = ref_latents_method == "index_timestep_zero"
for ref in ref_latents: for ref in ref_latents:
if ref_latents_method == "index": if ref_latents_method in ("index", "index_timestep_zero"):
index += self.params.ref_index_scale index += self.params.ref_index_scale
h_offset = 0 h_offset = 0
w_offset = 0 w_offset = 0
@ -342,6 +389,13 @@ class Flux(nn.Module):
kontext, kontext_ids = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset) kontext, kontext_ids = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset)
img = torch.cat([img, kontext], dim=1) img = torch.cat([img, kontext], dim=1)
img_ids = torch.cat([img_ids, kontext_ids], dim=1) img_ids = torch.cat([img_ids, kontext_ids], dim=1)
ref_num_tokens.append(kontext.shape[1])
if timestep_zero:
if index > 0:
timestep = torch.cat([timestep, timestep * 0], dim=0)
timestep_zero_index = [[img_tokens, img_ids.shape[1]]]
transformer_options = transformer_options.copy()
transformer_options["reference_image_num_tokens"] = ref_num_tokens
txt_ids = torch.zeros((bs, context.shape[1], len(self.params.axes_dim)), device=x.device, dtype=torch.float32) txt_ids = torch.zeros((bs, context.shape[1], len(self.params.axes_dim)), device=x.device, dtype=torch.float32)
@ -349,6 +403,6 @@ class Flux(nn.Module):
for i in self.params.txt_ids_dims: for i in self.params.txt_ids_dims:
txt_ids[:, :, i] = torch.linspace(0, context.shape[1] - 1, steps=context.shape[1], device=x.device, dtype=torch.float32) txt_ids[:, :, i] = torch.linspace(0, context.shape[1] - 1, steps=context.shape[1], device=x.device, dtype=torch.float32)
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None)) out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, timestep_zero_index=timestep_zero_index, transformer_options=transformer_options, attn_mask=kwargs.get("attention_mask", None))
out = out[:, :img_tokens] out = out[:, :img_tokens]
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=self.patch_size, pw=self.patch_size)[:,:,:h_orig,:w_orig] return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=self.patch_size, pw=self.patch_size)[:,:,:h_orig,:w_orig]

View File

@ -6,6 +6,7 @@ import comfy.model_management
import torch import torch
import math import math
import nodes import nodes
import comfy.ldm.flux.math
class CLIPTextEncodeFlux(io.ComfyNode): class CLIPTextEncodeFlux(io.ComfyNode):
@classmethod @classmethod
@ -231,6 +232,68 @@ class Flux2Scheduler(io.ComfyNode):
sigmas = get_schedule(steps, round(seq_len)) sigmas = get_schedule(steps, round(seq_len))
return io.NodeOutput(sigmas) return io.NodeOutput(sigmas)
class KV_Attn_Input:
def __init__(self):
self.cache = {}
def __call__(self, q, k, v, extra_options, **kwargs):
reference_image_num_tokens = extra_options.get("reference_image_num_tokens", [])
if len(reference_image_num_tokens) == 0:
return {}
ref_toks = sum(reference_image_num_tokens)
cache_key = "{}_{}".format(extra_options["block_type"], extra_options["block_index"])
if cache_key in self.cache:
kk, vv = self.cache[cache_key]
self.set_cache = False
return {"q": q, "k": torch.cat((k, kk), dim=2), "v": torch.cat((v, vv), dim=2)}
self.cache[cache_key] = (k[:, :, -ref_toks:], v[:, :, -ref_toks:])
self.set_cache = True
return {"q": q, "k": k, "v": v}
def cleanup(self):
self.cache = {}
class FluxKVCache(io.ComfyNode):
@classmethod
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="FluxKVCache",
display_name="Flux KV Cache",
description="Enables KV Cache optimization for reference images on Flux family models.",
category="",
is_experimental=True,
inputs=[
io.Model.Input("model", tooltip="The model to use KV Cache on."),
],
outputs=[
io.Model.Output(tooltip="The patched model with KV Cache enabled."),
],
)
@classmethod
def execute(cls, model: io.Model.Type) -> io.NodeOutput:
m = model.clone()
input_patch_obj = KV_Attn_Input()
def model_input_patch(inputs):
if len(input_patch_obj.cache) > 0:
ref_image_tokens = sum(inputs["transformer_options"].get("reference_image_num_tokens", []))
if ref_image_tokens > 0:
img = inputs["img"]
inputs["img"] = img[:, :-ref_image_tokens]
return inputs
m.set_model_attn1_patch(input_patch_obj)
m.set_model_post_input_patch(model_input_patch)
if hasattr(model.model.diffusion_model, "params"):
m.add_object_patch("diffusion_model.params.default_ref_method", "index_timestep_zero")
else:
m.add_object_patch("diffusion_model.default_ref_method", "index_timestep_zero")
return io.NodeOutput(m)
class FluxExtension(ComfyExtension): class FluxExtension(ComfyExtension):
@override @override
@ -243,6 +306,7 @@ class FluxExtension(ComfyExtension):
FluxKontextMultiReferenceLatentMethod, FluxKontextMultiReferenceLatentMethod,
EmptyFlux2LatentImage, EmptyFlux2LatentImage,
Flux2Scheduler, Flux2Scheduler,
FluxKVCache,
] ]