mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-13 13:17:45 +08:00
Support flux 2 klein kv cache model: Use the FluxKVCache node. (#12905)
This commit is contained in:
parent
8f9ea49571
commit
44f1246c89
@ -44,6 +44,22 @@ class FluxParams:
|
||||
txt_norm: bool = False
|
||||
|
||||
|
||||
def invert_slices(slices, length):
|
||||
sorted_slices = sorted(slices)
|
||||
result = []
|
||||
current = 0
|
||||
|
||||
for start, end in sorted_slices:
|
||||
if current < start:
|
||||
result.append((current, start))
|
||||
current = max(current, end)
|
||||
|
||||
if current < length:
|
||||
result.append((current, length))
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class Flux(nn.Module):
|
||||
"""
|
||||
Transformer model for flow matching on sequences.
|
||||
@ -138,6 +154,7 @@ class Flux(nn.Module):
|
||||
y: Tensor,
|
||||
guidance: Tensor = None,
|
||||
control = None,
|
||||
timestep_zero_index=None,
|
||||
transformer_options={},
|
||||
attn_mask: Tensor = None,
|
||||
) -> Tensor:
|
||||
@ -164,10 +181,6 @@ class Flux(nn.Module):
|
||||
txt = self.txt_norm(txt)
|
||||
txt = self.txt_in(txt)
|
||||
|
||||
vec_orig = vec
|
||||
if self.params.global_modulation:
|
||||
vec = (self.double_stream_modulation_img(vec_orig), self.double_stream_modulation_txt(vec_orig))
|
||||
|
||||
if "post_input" in patches:
|
||||
for p in patches["post_input"]:
|
||||
out = p({"img": img, "txt": txt, "img_ids": img_ids, "txt_ids": txt_ids, "transformer_options": transformer_options})
|
||||
@ -182,6 +195,24 @@ class Flux(nn.Module):
|
||||
else:
|
||||
pe = None
|
||||
|
||||
vec_orig = vec
|
||||
txt_vec = vec
|
||||
extra_kwargs = {}
|
||||
if timestep_zero_index is not None:
|
||||
modulation_dims = []
|
||||
batch = vec.shape[0] // 2
|
||||
vec_orig = vec_orig.reshape(2, batch, vec.shape[1]).movedim(0, 1)
|
||||
invert = invert_slices(timestep_zero_index, img.shape[1])
|
||||
for s in invert:
|
||||
modulation_dims.append((s[0], s[1], 0))
|
||||
for s in timestep_zero_index:
|
||||
modulation_dims.append((s[0], s[1], 1))
|
||||
extra_kwargs["modulation_dims_img"] = modulation_dims
|
||||
txt_vec = vec[:batch]
|
||||
|
||||
if self.params.global_modulation:
|
||||
vec = (self.double_stream_modulation_img(vec_orig), self.double_stream_modulation_txt(txt_vec))
|
||||
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
transformer_options["total_blocks"] = len(self.double_blocks)
|
||||
transformer_options["block_type"] = "double"
|
||||
@ -195,7 +226,8 @@ class Flux(nn.Module):
|
||||
vec=args["vec"],
|
||||
pe=args["pe"],
|
||||
attn_mask=args.get("attn_mask"),
|
||||
transformer_options=args.get("transformer_options"))
|
||||
transformer_options=args.get("transformer_options"),
|
||||
**extra_kwargs)
|
||||
return out
|
||||
|
||||
out = blocks_replace[("double_block", i)]({"img": img,
|
||||
@ -213,7 +245,8 @@ class Flux(nn.Module):
|
||||
vec=vec,
|
||||
pe=pe,
|
||||
attn_mask=attn_mask,
|
||||
transformer_options=transformer_options)
|
||||
transformer_options=transformer_options,
|
||||
**extra_kwargs)
|
||||
|
||||
if control is not None: # Controlnet
|
||||
control_i = control.get("input")
|
||||
@ -230,6 +263,12 @@ class Flux(nn.Module):
|
||||
if self.params.global_modulation:
|
||||
vec, _ = self.single_stream_modulation(vec_orig)
|
||||
|
||||
extra_kwargs = {}
|
||||
if timestep_zero_index is not None:
|
||||
lambda a: 0 if a == 0 else a + txt.shape[1]
|
||||
modulation_dims_combined = list(map(lambda x: (0 if x[0] == 0 else x[0] + txt.shape[1], x[1] + txt.shape[1], x[2]), modulation_dims))
|
||||
extra_kwargs["modulation_dims"] = modulation_dims_combined
|
||||
|
||||
transformer_options["total_blocks"] = len(self.single_blocks)
|
||||
transformer_options["block_type"] = "single"
|
||||
transformer_options["img_slice"] = [txt.shape[1], img.shape[1]]
|
||||
@ -242,7 +281,8 @@ class Flux(nn.Module):
|
||||
vec=args["vec"],
|
||||
pe=args["pe"],
|
||||
attn_mask=args.get("attn_mask"),
|
||||
transformer_options=args.get("transformer_options"))
|
||||
transformer_options=args.get("transformer_options"),
|
||||
**extra_kwargs)
|
||||
return out
|
||||
|
||||
out = blocks_replace[("single_block", i)]({"img": img,
|
||||
@ -253,7 +293,7 @@ class Flux(nn.Module):
|
||||
{"original_block": block_wrap})
|
||||
img = out["img"]
|
||||
else:
|
||||
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options)
|
||||
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options, **extra_kwargs)
|
||||
|
||||
if control is not None: # Controlnet
|
||||
control_o = control.get("output")
|
||||
@ -264,7 +304,11 @@ class Flux(nn.Module):
|
||||
|
||||
img = img[:, txt.shape[1] :, ...]
|
||||
|
||||
img = self.final_layer(img, vec_orig) # (N, T, patch_size ** 2 * out_channels)
|
||||
extra_kwargs = {}
|
||||
if timestep_zero_index is not None:
|
||||
extra_kwargs["modulation_dims"] = modulation_dims
|
||||
|
||||
img = self.final_layer(img, vec_orig, **extra_kwargs) # (N, T, patch_size ** 2 * out_channels)
|
||||
return img
|
||||
|
||||
def process_img(self, x, index=0, h_offset=0, w_offset=0, transformer_options={}):
|
||||
@ -312,13 +356,16 @@ class Flux(nn.Module):
|
||||
w_len = ((w_orig + (patch_size // 2)) // patch_size)
|
||||
img, img_ids = self.process_img(x, transformer_options=transformer_options)
|
||||
img_tokens = img.shape[1]
|
||||
timestep_zero_index = None
|
||||
if ref_latents is not None:
|
||||
ref_num_tokens = []
|
||||
h = 0
|
||||
w = 0
|
||||
index = 0
|
||||
ref_latents_method = kwargs.get("ref_latents_method", self.params.default_ref_method)
|
||||
timestep_zero = ref_latents_method == "index_timestep_zero"
|
||||
for ref in ref_latents:
|
||||
if ref_latents_method == "index":
|
||||
if ref_latents_method in ("index", "index_timestep_zero"):
|
||||
index += self.params.ref_index_scale
|
||||
h_offset = 0
|
||||
w_offset = 0
|
||||
@ -342,6 +389,13 @@ class Flux(nn.Module):
|
||||
kontext, kontext_ids = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset)
|
||||
img = torch.cat([img, kontext], dim=1)
|
||||
img_ids = torch.cat([img_ids, kontext_ids], dim=1)
|
||||
ref_num_tokens.append(kontext.shape[1])
|
||||
if timestep_zero:
|
||||
if index > 0:
|
||||
timestep = torch.cat([timestep, timestep * 0], dim=0)
|
||||
timestep_zero_index = [[img_tokens, img_ids.shape[1]]]
|
||||
transformer_options = transformer_options.copy()
|
||||
transformer_options["reference_image_num_tokens"] = ref_num_tokens
|
||||
|
||||
txt_ids = torch.zeros((bs, context.shape[1], len(self.params.axes_dim)), device=x.device, dtype=torch.float32)
|
||||
|
||||
@ -349,6 +403,6 @@ class Flux(nn.Module):
|
||||
for i in self.params.txt_ids_dims:
|
||||
txt_ids[:, :, i] = torch.linspace(0, context.shape[1] - 1, steps=context.shape[1], device=x.device, dtype=torch.float32)
|
||||
|
||||
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
|
||||
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, timestep_zero_index=timestep_zero_index, transformer_options=transformer_options, attn_mask=kwargs.get("attention_mask", None))
|
||||
out = out[:, :img_tokens]
|
||||
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=self.patch_size, pw=self.patch_size)[:,:,:h_orig,:w_orig]
|
||||
|
||||
@ -6,6 +6,7 @@ import comfy.model_management
|
||||
import torch
|
||||
import math
|
||||
import nodes
|
||||
import comfy.ldm.flux.math
|
||||
|
||||
class CLIPTextEncodeFlux(io.ComfyNode):
|
||||
@classmethod
|
||||
@ -231,6 +232,68 @@ class Flux2Scheduler(io.ComfyNode):
|
||||
sigmas = get_schedule(steps, round(seq_len))
|
||||
return io.NodeOutput(sigmas)
|
||||
|
||||
class KV_Attn_Input:
|
||||
def __init__(self):
|
||||
self.cache = {}
|
||||
|
||||
def __call__(self, q, k, v, extra_options, **kwargs):
|
||||
reference_image_num_tokens = extra_options.get("reference_image_num_tokens", [])
|
||||
if len(reference_image_num_tokens) == 0:
|
||||
return {}
|
||||
|
||||
ref_toks = sum(reference_image_num_tokens)
|
||||
cache_key = "{}_{}".format(extra_options["block_type"], extra_options["block_index"])
|
||||
if cache_key in self.cache:
|
||||
kk, vv = self.cache[cache_key]
|
||||
self.set_cache = False
|
||||
return {"q": q, "k": torch.cat((k, kk), dim=2), "v": torch.cat((v, vv), dim=2)}
|
||||
|
||||
self.cache[cache_key] = (k[:, :, -ref_toks:], v[:, :, -ref_toks:])
|
||||
self.set_cache = True
|
||||
return {"q": q, "k": k, "v": v}
|
||||
|
||||
def cleanup(self):
|
||||
self.cache = {}
|
||||
|
||||
|
||||
class FluxKVCache(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="FluxKVCache",
|
||||
display_name="Flux KV Cache",
|
||||
description="Enables KV Cache optimization for reference images on Flux family models.",
|
||||
category="",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
io.Model.Input("model", tooltip="The model to use KV Cache on."),
|
||||
],
|
||||
outputs=[
|
||||
io.Model.Output(tooltip="The patched model with KV Cache enabled."),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model: io.Model.Type) -> io.NodeOutput:
|
||||
m = model.clone()
|
||||
input_patch_obj = KV_Attn_Input()
|
||||
|
||||
def model_input_patch(inputs):
|
||||
if len(input_patch_obj.cache) > 0:
|
||||
ref_image_tokens = sum(inputs["transformer_options"].get("reference_image_num_tokens", []))
|
||||
if ref_image_tokens > 0:
|
||||
img = inputs["img"]
|
||||
inputs["img"] = img[:, :-ref_image_tokens]
|
||||
return inputs
|
||||
|
||||
m.set_model_attn1_patch(input_patch_obj)
|
||||
m.set_model_post_input_patch(model_input_patch)
|
||||
if hasattr(model.model.diffusion_model, "params"):
|
||||
m.add_object_patch("diffusion_model.params.default_ref_method", "index_timestep_zero")
|
||||
else:
|
||||
m.add_object_patch("diffusion_model.default_ref_method", "index_timestep_zero")
|
||||
|
||||
return io.NodeOutput(m)
|
||||
|
||||
class FluxExtension(ComfyExtension):
|
||||
@override
|
||||
@ -243,6 +306,7 @@ class FluxExtension(ComfyExtension):
|
||||
FluxKontextMultiReferenceLatentMethod,
|
||||
EmptyFlux2LatentImage,
|
||||
Flux2Scheduler,
|
||||
FluxKVCache,
|
||||
]
|
||||
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user