Make flux support flipflop

This commit is contained in:
Jedrzej Kosinski 2025-10-02 17:53:22 -07:00
parent 8d7b22b720
commit d5001ed90e
2 changed files with 75 additions and 65 deletions

View File

@ -7,6 +7,7 @@ from torch import Tensor, nn
from einops import rearrange, repeat
import comfy.ldm.common_dit
import comfy.patcher_extension
from comfy.ldm.flipflop_transformer import FlipFlopModule
from .layers import (
DoubleStreamBlock,
@ -35,13 +36,13 @@ class FluxParams:
guidance_embed: bool
class Flux(nn.Module):
class Flux(FlipFlopModule):
"""
Transformer model for flow matching on sequences.
"""
def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, operations=None, **kwargs):
super().__init__()
super().__init__(("double_blocks", "single_blocks"))
self.dtype = dtype
params = FluxParams(**kwargs)
self.params = params
@ -89,6 +90,72 @@ class Flux(nn.Module):
if final_layer:
self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, dtype=dtype, device=device, operations=operations)
def indiv_double_block_fwd(self, i, block, img, txt, vec, pe, attn_mask, control, blocks_replace, transformer_options):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"], out["txt"] = block(img=args["img"],
txt=args["txt"],
vec=args["vec"],
pe=args["pe"],
attn_mask=args.get("attn_mask"),
transformer_options=args.get("transformer_options"))
return out
out = blocks_replace[("double_block", i)]({"img": img,
"txt": txt,
"vec": vec,
"pe": pe,
"attn_mask": attn_mask,
"transformer_options": transformer_options},
{"original_block": block_wrap})
txt = out["txt"]
img = out["img"]
else:
img, txt = block(img=img,
txt=txt,
vec=vec,
pe=pe,
attn_mask=attn_mask,
transformer_options=transformer_options)
if control is not None: # Controlnet
control_i = control.get("input")
if i < len(control_i):
add = control_i[i]
if add is not None:
img[:, :add.shape[1]] += add
return img, txt
def indiv_single_block_fwd(self, i, block, img, txt, vec, pe, attn_mask, control, blocks_replace, transformer_options):
if ("single_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"],
vec=args["vec"],
pe=args["pe"],
attn_mask=args.get("attn_mask"),
transformer_options=args.get("transformer_options"))
return out
out = blocks_replace[("single_block", i)]({"img": img,
"vec": vec,
"pe": pe,
"attn_mask": attn_mask,
"transformer_options": transformer_options},
{"original_block": block_wrap})
img = out["img"]
else:
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options)
if control is not None: # Controlnet
control_o = control.get("output")
if i < len(control_o):
add = control_o[i]
if add is not None:
img[:, txt.shape[1] : txt.shape[1] + add.shape[1], ...] += add
return img
def forward_orig(
self,
img: Tensor,
@ -136,74 +203,16 @@ class Flux(nn.Module):
pe = None
blocks_replace = patches_replace.get("dit", {})
for i, block in enumerate(self.double_blocks):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"], out["txt"] = block(img=args["img"],
txt=args["txt"],
vec=args["vec"],
pe=args["pe"],
attn_mask=args.get("attn_mask"),
transformer_options=args.get("transformer_options"))
return out
out = blocks_replace[("double_block", i)]({"img": img,
"txt": txt,
"vec": vec,
"pe": pe,
"attn_mask": attn_mask,
"transformer_options": transformer_options},
{"original_block": block_wrap})
txt = out["txt"]
img = out["img"]
else:
img, txt = block(img=img,
txt=txt,
vec=vec,
pe=pe,
attn_mask=attn_mask,
transformer_options=transformer_options)
if control is not None: # Controlnet
control_i = control.get("input")
if i < len(control_i):
add = control_i[i]
if add is not None:
img[:, :add.shape[1]] += add
# execute double blocks
img, txt = self.execute_blocks("double_blocks", self.indiv_double_block_fwd, (img, txt), vec, pe, attn_mask, control, blocks_replace, transformer_options)
if img.dtype == torch.float16:
img = torch.nan_to_num(img, nan=0.0, posinf=65504, neginf=-65504)
img = torch.cat((txt, img), 1)
for i, block in enumerate(self.single_blocks):
if ("single_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"],
vec=args["vec"],
pe=args["pe"],
attn_mask=args.get("attn_mask"),
transformer_options=args.get("transformer_options"))
return out
out = blocks_replace[("single_block", i)]({"img": img,
"vec": vec,
"pe": pe,
"attn_mask": attn_mask,
"transformer_options": transformer_options},
{"original_block": block_wrap})
img = out["img"]
else:
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options)
if control is not None: # Controlnet
control_o = control.get("output")
if i < len(control_o):
add = control_o[i]
if add is not None:
img[:, txt.shape[1] : txt.shape[1] + add.shape[1], ...] += add
# execute single blocks
img = self.execute_blocks("single_blocks", self.indiv_single_block_fwd, img, txt, vec, pe, attn_mask, control, blocks_replace, transformer_options)
img = img[:, txt.shape[1] :, ...]

View File

@ -639,7 +639,7 @@ class ModelPatcher:
block_buffer = 3
valid_block_types = []
# for each block type, check if have enough room to flipflop
for block_info in self.model.diffusion_model.get_all_block_module_sizes():
for block_info in self.model.diffusion_model.get_all_block_module_sizes(reverse_sort_by_size=False):
block_size: int = block_info[1]
if block_size * block_buffer < lowvram_model_memory:
valid_block_types.append(block_info)
@ -653,6 +653,7 @@ class ModelPatcher:
n_fit_in_memory = int(leftover_memory // block_size)
# if all (or more) of this block type would fit in memory, no need to flipflop with it
if n_fit_in_memory >= total_blocks:
leftover_memory -= total_blocks * block_size
continue
# if the amount of this block that would fit in memory is less than buffer, skip this block type
if n_fit_in_memory < block_buffer: