diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py
index 14f90cea5..5b014fc1b 100644
--- a/comfy/ldm/flux/model.py
+++ b/comfy/ldm/flux/model.py
@@ -7,6 +7,7 @@ from torch import Tensor, nn
 from einops import rearrange, repeat
 import comfy.ldm.common_dit
 import comfy.patcher_extension
+from comfy.ldm.flipflop_transformer import FlipFlopModule
 
 from .layers import (
     DoubleStreamBlock,
@@ -35,13 +36,13 @@ class FluxParams:
     guidance_embed: bool
 
 
-class Flux(nn.Module):
+class Flux(FlipFlopModule):
     """
     Transformer model for flow matching on sequences.
     """
 
     def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, operations=None, **kwargs):
-        super().__init__()
+        super().__init__(("double_blocks", "single_blocks"))
         self.dtype = dtype
         params = FluxParams(**kwargs)
         self.params = params
@@ -89,6 +90,72 @@ class Flux(nn.Module):
         if final_layer:
             self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, dtype=dtype, device=device, operations=operations)
 
+    def indiv_double_block_fwd(self, i, block, img, txt, vec, pe, attn_mask, control, blocks_replace, transformer_options):
+        if ("double_block", i) in blocks_replace:
+            def block_wrap(args):
+                out = {}
+                out["img"], out["txt"] = block(img=args["img"],
+                                               txt=args["txt"],
+                                               vec=args["vec"],
+                                               pe=args["pe"],
+                                               attn_mask=args.get("attn_mask"),
+                                               transformer_options=args.get("transformer_options"))
+                return out
+
+            out = blocks_replace[("double_block", i)]({"img": img,
+                                                       "txt": txt,
+                                                       "vec": vec,
+                                                       "pe": pe,
+                                                       "attn_mask": attn_mask,
+                                                       "transformer_options": transformer_options},
+                                                      {"original_block": block_wrap})
+            txt = out["txt"]
+            img = out["img"]
+        else:
+            img, txt = block(img=img,
+                             txt=txt,
+                             vec=vec,
+                             pe=pe,
+                             attn_mask=attn_mask,
+                             transformer_options=transformer_options)
+
+        if control is not None: # Controlnet
+            control_i = control.get("input")
+            if i < len(control_i):
+                add = control_i[i]
+                if add is not None:
+                    img[:, :add.shape[1]] += add
+        return img, txt
+
+    def indiv_single_block_fwd(self, i, block, img, txt, vec, pe, attn_mask, control, blocks_replace, transformer_options):
+        if ("single_block", i) in blocks_replace:
+            def block_wrap(args):
+                out = {}
+                out["img"] = block(args["img"],
+                                   vec=args["vec"],
+                                   pe=args["pe"],
+                                   attn_mask=args.get("attn_mask"),
+                                   transformer_options=args.get("transformer_options"))
+                return out
+
+            out = blocks_replace[("single_block", i)]({"img": img,
+                                                       "vec": vec,
+                                                       "pe": pe,
+                                                       "attn_mask": attn_mask,
+                                                       "transformer_options": transformer_options},
+                                                      {"original_block": block_wrap})
+            img = out["img"]
+        else:
+            img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options)
+
+        if control is not None: # Controlnet
+            control_o = control.get("output")
+            if i < len(control_o):
+                add = control_o[i]
+                if add is not None:
+                    img[:, txt.shape[1] : txt.shape[1] + add.shape[1], ...] += add
+        return img
+
     def forward_orig(
         self,
         img: Tensor,
@@ -136,74 +203,16 @@ class Flux(nn.Module):
             pe = None
 
         blocks_replace = patches_replace.get("dit", {})
-        for i, block in enumerate(self.double_blocks):
-            if ("double_block", i) in blocks_replace:
-                def block_wrap(args):
-                    out = {}
-                    out["img"], out["txt"] = block(img=args["img"],
-                                                   txt=args["txt"],
-                                                   vec=args["vec"],
-                                                   pe=args["pe"],
-                                                   attn_mask=args.get("attn_mask"),
-                                                   transformer_options=args.get("transformer_options"))
-                    return out
-
-                out = blocks_replace[("double_block", i)]({"img": img,
-                                                           "txt": txt,
-                                                           "vec": vec,
-                                                           "pe": pe,
-                                                           "attn_mask": attn_mask,
-                                                           "transformer_options": transformer_options},
-                                                          {"original_block": block_wrap})
-                txt = out["txt"]
-                img = out["img"]
-            else:
-                img, txt = block(img=img,
-                                 txt=txt,
-                                 vec=vec,
-                                 pe=pe,
-                                 attn_mask=attn_mask,
-                                 transformer_options=transformer_options)
-
-            if control is not None: # Controlnet
-                control_i = control.get("input")
-                if i < len(control_i):
-                    add = control_i[i]
-                    if add is not None:
-                        img[:, :add.shape[1]] += add
+        # execute double blocks
+        img, txt = self.execute_blocks("double_blocks", self.indiv_double_block_fwd, (img, txt), vec, pe, attn_mask, control, blocks_replace, transformer_options)
 
         if img.dtype == torch.float16:
             img = torch.nan_to_num(img, nan=0.0, posinf=65504, neginf=-65504)
 
         img = torch.cat((txt, img), 1)
 
-        for i, block in enumerate(self.single_blocks):
-            if ("single_block", i) in blocks_replace:
-                def block_wrap(args):
-                    out = {}
-                    out["img"] = block(args["img"],
-                                       vec=args["vec"],
-                                       pe=args["pe"],
-                                       attn_mask=args.get("attn_mask"),
-                                       transformer_options=args.get("transformer_options"))
-                    return out
-
-                out = blocks_replace[("single_block", i)]({"img": img,
-                                                           "vec": vec,
-                                                           "pe": pe,
-                                                           "attn_mask": attn_mask,
-                                                           "transformer_options": transformer_options},
-                                                          {"original_block": block_wrap})
-                img = out["img"]
-            else:
-                img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options)
-
-            if control is not None: # Controlnet
-                control_o = control.get("output")
-                if i < len(control_o):
-                    add = control_o[i]
-                    if add is not None:
-                        img[:, txt.shape[1] : txt.shape[1] + add.shape[1], ...] += add
+        # execute single blocks
+        img = self.execute_blocks("single_blocks", self.indiv_single_block_fwd, img, txt, vec, pe, attn_mask, control, blocks_replace, transformer_options)
 
         img = img[:, txt.shape[1] :, ...]
 
diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index aa7fd6cd7..8ffed3efd 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -639,7 +639,7 @@ class ModelPatcher:
         block_buffer = 3
         valid_block_types = []
         # for each block type, check if have enough room to flipflop
-        for block_info in self.model.diffusion_model.get_all_block_module_sizes():
+        for block_info in self.model.diffusion_model.get_all_block_module_sizes(reverse_sort_by_size=False):
             block_size: int = block_info[1]
            if block_size * block_buffer < lowvram_model_memory:
                 valid_block_types.append(block_info)
@@ -653,6 +653,7 @@ class ModelPatcher:
             n_fit_in_memory = int(leftover_memory // block_size)
             # if all (or more) of this block type would fit in memory, no need to flipflop with it
             if n_fit_in_memory >= total_blocks:
+                leftover_memory -= total_blocks * block_size
                 continue
             # if the amount of this block that would fit in memory is less than buffer, skip this block type
             if n_fit_in_memory < block_buffer:
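
Note for reviewers: `comfy/ldm/flipflop_transformer.py` is not included in this diff, so the sketch below only illustrates the interface the call sites above appear to rely on: a `FlipFlopModule` constructor that records which attributes hold block lists, an `execute_blocks` helper that threads the carried tensor(s) through a per-block callback, and `get_all_block_module_sizes` (with its `reverse_sort_by_size` flag) used by the `model_patcher.py` hunks. The method bodies are assumptions inferred from those call sites, and the actual flip-flop weight streaming between device buffers is deliberately omitted; this is not the real implementation.

```python
import torch.nn as nn


class FlipFlopModule(nn.Module):
    """Minimal sketch of the assumed interface; not the real flipflop_transformer implementation."""

    def __init__(self, block_types=()):
        super().__init__()
        # Names of attributes (e.g. "double_blocks", "single_blocks") that hold nn.ModuleList blocks.
        self.block_types = tuple(block_types)

    def execute_blocks(self, block_type, block_fwd, carried, *args):
        # Iterate the named block list, threading the carried tensor(s) through block_fwd.
        # For Flux, `carried` is (img, txt) for double blocks and img alone for single blocks.
        blocks = getattr(self, block_type)
        for i, block in enumerate(blocks):
            if isinstance(carried, tuple):
                carried = block_fwd(i, block, *carried, *args)
            else:
                carried = block_fwd(i, block, carried, *args)
        return carried

    def get_all_block_module_sizes(self, reverse_sort_by_size=True):
        # Return (block_type, per_block_size_in_bytes, block_count) tuples;
        # the model_patcher.py hunks only rely on block_info[1] being a size.
        sizes = []
        for block_type in self.block_types:
            blocks = getattr(self, block_type)
            if len(blocks) == 0:
                continue
            block_size = sum(p.numel() * p.element_size() for p in blocks[0].parameters())
            sizes.append((block_type, block_size, len(blocks)))
        return sorted(sizes, key=lambda s: s[1], reverse=reverse_sort_by_size)
```

Under that reading, `execute_blocks("double_blocks", self.indiv_double_block_fwd, (img, txt), ...)` unpacks the `(img, txt)` tuple into the callback and expects the pair back as the return value, which matches the `return img, txt` in `indiv_double_block_fwd`, while the single-block path carries `img` alone.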