Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2025-12-18 02:23:06 +08:00)

Commit d5001ed90e (parent 8d7b22b720): Make flux support flipflop
@@ -7,6 +7,7 @@ from torch import Tensor, nn
 from einops import rearrange, repeat
 import comfy.ldm.common_dit
 import comfy.patcher_extension
+from comfy.ldm.flipflop_transformer import FlipFlopModule
 
 from .layers import (
     DoubleStreamBlock,
@@ -35,13 +36,13 @@ class FluxParams:
     guidance_embed: bool
 
 
-class Flux(nn.Module):
+class Flux(FlipFlopModule):
     """
     Transformer model for flow matching on sequences.
     """
 
     def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, operations=None, **kwargs):
-        super().__init__()
+        super().__init__(("double_blocks", "single_blocks"))
        self.dtype = dtype
        params = FluxParams(**kwargs)
        self.params = params
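Note: the diff relies on a FlipFlopModule base class from comfy.ldm.flipflop_transformer whose implementation is not part of this change. As a rough sketch of the interface the Flux changes assume (a constructor taking the names of the block lists, plus the get_all_block_module_sizes helper used by ModelPatcher further down), something along these lines would satisfy the call sites; the method bodies and the default of reverse_sort_by_size are assumptions, not the real flip/flop weight-streaming code.

# Hypothetical sketch only, not the actual comfy.ldm.flipflop_transformer code.
# The signatures match how this diff uses the class; the bodies are plain-PyTorch stand-ins.
import torch.nn as nn

class FlipFlopModule(nn.Module):
    def __init__(self, block_module_names):
        super().__init__()
        # Names of the ModuleList attributes holding the transformer blocks,
        # e.g. ("double_blocks", "single_blocks") for Flux.
        self.block_module_names = block_module_names

    def get_all_block_module_sizes(self, reverse_sort_by_size=True):
        # Assumed return shape: (attribute name, bytes per block, number of blocks)
        # for each registered block list, optionally sorted largest block type first.
        sizes = []
        for name in self.block_module_names:
            blocks = getattr(self, name)
            block_size = sum(p.numel() * p.element_size() for p in blocks[0].parameters())
            sizes.append((name, block_size, len(blocks)))
        return sorted(sizes, key=lambda info: info[1], reverse=reverse_sort_by_size)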
@@ -89,6 +90,72 @@ class Flux(nn.Module):
         if final_layer:
             self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, dtype=dtype, device=device, operations=operations)
 
+    def indiv_double_block_fwd(self, i, block, img, txt, vec, pe, attn_mask, control, blocks_replace, transformer_options):
+        if ("double_block", i) in blocks_replace:
+            def block_wrap(args):
+                out = {}
+                out["img"], out["txt"] = block(img=args["img"],
+                                               txt=args["txt"],
+                                               vec=args["vec"],
+                                               pe=args["pe"],
+                                               attn_mask=args.get("attn_mask"),
+                                               transformer_options=args.get("transformer_options"))
+                return out
+
+            out = blocks_replace[("double_block", i)]({"img": img,
+                                                       "txt": txt,
+                                                       "vec": vec,
+                                                       "pe": pe,
+                                                       "attn_mask": attn_mask,
+                                                       "transformer_options": transformer_options},
+                                                      {"original_block": block_wrap})
+            txt = out["txt"]
+            img = out["img"]
+        else:
+            img, txt = block(img=img,
+                             txt=txt,
+                             vec=vec,
+                             pe=pe,
+                             attn_mask=attn_mask,
+                             transformer_options=transformer_options)
+
+        if control is not None: # Controlnet
+            control_i = control.get("input")
+            if i < len(control_i):
+                add = control_i[i]
+                if add is not None:
+                    img[:, :add.shape[1]] += add
+        return img, txt
+
+    def indiv_single_block_fwd(self, i, block, img, txt, vec, pe, attn_mask, control, blocks_replace, transformer_options):
+        if ("single_block", i) in blocks_replace:
+            def block_wrap(args):
+                out = {}
+                out["img"] = block(args["img"],
+                                   vec=args["vec"],
+                                   pe=args["pe"],
+                                   attn_mask=args.get("attn_mask"),
+                                   transformer_options=args.get("transformer_options"))
+                return out
+
+            out = blocks_replace[("single_block", i)]({"img": img,
+                                                       "vec": vec,
+                                                       "pe": pe,
+                                                       "attn_mask": attn_mask,
+                                                       "transformer_options": transformer_options},
+                                                      {"original_block": block_wrap})
+            img = out["img"]
+        else:
+            img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options)
+
+        if control is not None: # Controlnet
+            control_o = control.get("output")
+            if i < len(control_o):
+                add = control_o[i]
+                if add is not None:
+                    img[:, txt.shape[1] : txt.shape[1] + add.shape[1], ...] += add
+        return img
+
     def forward_orig(
         self,
         img: Tensor,
@@ -136,74 +203,16 @@ class Flux(nn.Module):
             pe = None
 
         blocks_replace = patches_replace.get("dit", {})
-        for i, block in enumerate(self.double_blocks):
-            if ("double_block", i) in blocks_replace:
-                def block_wrap(args):
-                    out = {}
-                    out["img"], out["txt"] = block(img=args["img"],
-                                                   txt=args["txt"],
-                                                   vec=args["vec"],
-                                                   pe=args["pe"],
-                                                   attn_mask=args.get("attn_mask"),
-                                                   transformer_options=args.get("transformer_options"))
-                    return out
-
-                out = blocks_replace[("double_block", i)]({"img": img,
-                                                           "txt": txt,
-                                                           "vec": vec,
-                                                           "pe": pe,
-                                                           "attn_mask": attn_mask,
-                                                           "transformer_options": transformer_options},
-                                                          {"original_block": block_wrap})
-                txt = out["txt"]
-                img = out["img"]
-            else:
-                img, txt = block(img=img,
-                                 txt=txt,
-                                 vec=vec,
-                                 pe=pe,
-                                 attn_mask=attn_mask,
-                                 transformer_options=transformer_options)
-
-            if control is not None: # Controlnet
-                control_i = control.get("input")
-                if i < len(control_i):
-                    add = control_i[i]
-                    if add is not None:
-                        img[:, :add.shape[1]] += add
+        # execute double blocks
+        img, txt = self.execute_blocks("double_blocks", self.indiv_double_block_fwd, (img, txt), vec, pe, attn_mask, control, blocks_replace, transformer_options)
 
         if img.dtype == torch.float16:
             img = torch.nan_to_num(img, nan=0.0, posinf=65504, neginf=-65504)
 
         img = torch.cat((txt, img), 1)
 
-        for i, block in enumerate(self.single_blocks):
-            if ("single_block", i) in blocks_replace:
-                def block_wrap(args):
-                    out = {}
-                    out["img"] = block(args["img"],
-                                       vec=args["vec"],
-                                       pe=args["pe"],
-                                       attn_mask=args.get("attn_mask"),
-                                       transformer_options=args.get("transformer_options"))
-                    return out
-
-                out = blocks_replace[("single_block", i)]({"img": img,
-                                                           "vec": vec,
-                                                           "pe": pe,
-                                                           "attn_mask": attn_mask,
-                                                           "transformer_options": transformer_options},
-                                                          {"original_block": block_wrap})
-                img = out["img"]
-            else:
-                img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options)
-
-            if control is not None: # Controlnet
-                control_o = control.get("output")
-                if i < len(control_o):
-                    add = control_o[i]
-                    if add is not None:
-                        img[:, txt.shape[1] : txt.shape[1] + add.shape[1], ...] += add
+        # execute single blocks
+        img = self.execute_blocks("single_blocks", self.indiv_single_block_fwd, img, txt, vec, pe, attn_mask, control, blocks_replace, transformer_options)
 
         img = img[:, txt.shape[1] :, ...]
 
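Note: execute_blocks is the other FlipFlopModule method this diff depends on, and it is also not shown here. Judging purely from the two call sites above, it walks one named block list and hands each block to the supplied per-block callback (indiv_double_block_fwd or indiv_single_block_fwd), threading the first argument through as running state. A naive, sequential stand-in could look like the sketch below; the real method is presumably where the flipflop prefetching of the next block's weights happens.

# Naive stand-in for FlipFlopModule.execute_blocks, inferred from its call sites above.
# No weight streaming: every block is assumed to already sit on the compute device.
def execute_blocks(module, block_list_name, block_fwd, inp, *args):
    blocks = getattr(module, block_list_name)      # e.g. module.double_blocks
    is_tuple = isinstance(inp, tuple)              # (img, txt) for double blocks, img for single
    state = inp if is_tuple else (inp,)
    for i, block in enumerate(blocks):
        out = block_fwd(i, block, *state, *args)   # the indiv_*_block_fwd callbacks
        state = out if is_tuple else (out,)
    return state if is_tuple else state[0]

With that shape, the double-block call returns the updated (img, txt) pair and the single-block call returns the fused img tensor, which matches how forward_orig consumes the results in the hunk above.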
@@ -639,7 +639,7 @@ class ModelPatcher:
         block_buffer = 3
         valid_block_types = []
         # for each block type, check if have enough room to flipflop
-        for block_info in self.model.diffusion_model.get_all_block_module_sizes():
+        for block_info in self.model.diffusion_model.get_all_block_module_sizes(reverse_sort_by_size=False):
             block_size: int = block_info[1]
             if block_size * block_buffer < lowvram_model_memory:
                 valid_block_types.append(block_info)
@@ -653,6 +653,7 @@ class ModelPatcher:
             n_fit_in_memory = int(leftover_memory // block_size)
             # if all (or more) of this block type would fit in memory, no need to flipflop with it
             if n_fit_in_memory >= total_blocks:
+                leftover_memory -= total_blocks * block_size
                 continue
             # if the amount of this block that would fit in memory is less than buffer, skip this block type
             if n_fit_in_memory < block_buffer:
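Note: to make the ModelPatcher budgeting concrete: with block_buffer = 3, a block type is only a flipflop candidate if at least three of its blocks fit in the low-VRAM budget, and the newly added leftover_memory -= total_blocks * block_size line charges fully resident block types against the budget before the remaining (now smallest-first, because of reverse_sort_by_size=False) block types are considered. Below is a collapsed, illustrative version of that logic with made-up sizes.

# Illustrative numbers only; real sizes come from get_all_block_module_sizes().
# The two loops from ModelPatcher are collapsed into one here for readability.
lowvram_model_memory = 6 * 1024**3                 # 6 GiB budget for transformer blocks
block_buffer = 3                                   # need room for at least 3 blocks to flipflop

# (name, bytes per block, block count), smallest block type first
block_types = [("single_blocks", 120 * 1024**2, 38),
               ("double_blocks", 340 * 1024**2, 19)]

leftover_memory = lowvram_model_memory
for name, block_size, total_blocks in block_types:
    if block_size * block_buffer >= lowvram_model_memory:
        continue                                   # not even the buffer fits, skip this type
    n_fit_in_memory = int(leftover_memory // block_size)
    if n_fit_in_memory >= total_blocks:
        leftover_memory -= total_blocks * block_size   # the new line: keep the whole type resident
        continue
    if n_fit_in_memory < block_buffer:
        continue                                   # too little room left to flipflop this type
    print(f"flipflop {name}: {n_fit_in_memory} of {total_blocks} blocks resident")
# -> single_blocks stay fully resident (4560 MiB), double_blocks flipflop with 4 of 19 resident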