Mirror of https://github.com/comfyanonymous/ComfyUI.git, synced 2025-12-17 01:52:59 +08:00
In-progress commit on making flipflop async weight streaming native; also gave the "loaded partially" / "loaded completely" log messages labels, because having to memorize their meaning for dev work is annoying
This commit is contained in:
parent d0bd221495
commit 01f4512bf8
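Context for the diff below: the flip/flop scheme keeps only two copies of a transformer block resident on the GPU ("flip" and "flop") and refills the idle copy with the next block's weights on a separate CUDA copy stream while the compute stream runs the current block, with CUDA events ordering the two streams. The following is a minimal, self-contained sketch of that double-buffering pattern, not the ComfyUI implementation; the function name flipflop_forward, the plain x = block(x) call signature, and the two-blocks-ahead prefetch schedule are illustrative assumptions.

import copy
import torch
import torch.nn as nn

# Sketch of "flip/flop" double-buffered weight streaming (illustrative, not ComfyUI code).
# Two GPU-resident buffers alternate running blocks; a copy stream refills the idle buffer
# with the next block's CPU weights, and CUDA events order the copies against compute.
# Assumes a CUDA device and at least two blocks.
@torch.no_grad()
def flipflop_forward(cpu_blocks: list[nn.Module], x: torch.Tensor, device: str = "cuda") -> torch.Tensor:
    device = torch.device(device)
    compute = torch.cuda.default_stream(device)
    copier = torch.cuda.Stream(device)
    bufs = [copy.deepcopy(cpu_blocks[i]).to(device) for i in range(2)]  # "flip" and "flop"
    filled = [torch.cuda.Event() for _ in range(2)]  # buffer holds the right weights
    done = [torch.cuda.Event() for _ in range(2)]    # compute is finished with the buffer
    for e in (*filled, *done):
        e.record(compute)  # initial state: both buffers valid and unused

    def refill(buf: int, block_idx: int):
        # Runs on the copy stream; must not overwrite weights compute is still using.
        with torch.cuda.stream(copier):
            copier.wait_event(done[buf])
            dst = bufs[buf].state_dict()
            src = cpu_blocks[block_idx].state_dict()
            for k, v in dst.items():
                v.copy_(src[k], non_blocking=True)  # ideally from pinned CPU memory
            filled[buf].record(copier)

    n = len(cpu_blocks)
    for i in range(n):
        buf = i % 2
        compute.wait_event(filled[buf])  # weights for block i must be resident
        x = bufs[buf](x)
        done[buf].record(compute)
        if i + 2 < n:
            refill(buf, i + 2)  # prefetch two blocks ahead into the buffer just freed
    return x

The code in this commit does the same dance through FlipFlopHolder / FlipFlopContext, streaming whole state dicts between the CPU-resident blocks and the two resident GPU blocks.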
@@ -3,42 +3,9 @@ import torch
 import torch.cuda as cuda
 import copy
 from typing import List, Tuple
-from dataclasses import dataclass
 
 import comfy.model_management
-
-FLIPFLOP_REGISTRY = {}
-
-def register(name):
-    def decorator(cls):
-        FLIPFLOP_REGISTRY[name] = cls
-        return cls
-    return decorator
-
-
-@dataclass
-class FlipFlopConfig:
-    block_name: str
-    block_wrap_fn: callable
-    out_names: Tuple[str]
-    overwrite_forward: str
-    pinned_staging: bool = False
-    inference_device: str = "cuda"
-    offloading_device: str = "cpu"
-
-
-def patch_model_from_config(model, config: FlipFlopConfig):
-    block_list = getattr(model, config.block_name)
-    flip_flop_transformer = FlipFlopTransformer(block_list,
-                                                block_wrap_fn=config.block_wrap_fn,
-                                                out_names=config.out_names,
-                                                offloading_device=config.offloading_device,
-                                                inference_device=config.inference_device,
-                                                pinned_staging=config.pinned_staging)
-    delattr(model, config.block_name)
-    setattr(model, config.block_name, flip_flop_transformer)
-    setattr(model, config.overwrite_forward, flip_flop_transformer.__call__)
 
 
 class FlipFlopContext:
     def __init__(self, holder: FlipFlopHolder):
@@ -46,11 +13,12 @@ class FlipFlopContext:
         self.reset()
 
     def reset(self):
-        self.num_blocks = len(self.holder.transformer_blocks)
+        self.num_blocks = len(self.holder.blocks)
         self.first_flip = True
         self.first_flop = True
         self.last_flip = False
         self.last_flop = False
+        # TODO: the 'i' that's passed into func needs to be properly offset to do patches correctly
 
     def __enter__(self):
         self.reset()
@@ -71,9 +39,9 @@ class FlipFlopContext:
             next_flop_i = next_flop_i - self.num_blocks
             self.last_flip = True
         if not self.first_flip:
-            self.holder._copy_state_dict(self.holder.flop.state_dict(), self.holder.transformer_blocks[next_flop_i].state_dict(), self.holder.event_flop, self.holder.cpy_end_event)
+            self.holder._copy_state_dict(self.holder.flop.state_dict(), self.holder.blocks[next_flop_i].state_dict(), self.holder.event_flop, self.holder.cpy_end_event)
         if self.last_flip:
-            self.holder._copy_state_dict(self.holder.flip.state_dict(), self.holder.transformer_blocks[0].state_dict(), cpy_start_event=self.holder.event_flip)
+            self.holder._copy_state_dict(self.holder.flip.state_dict(), self.holder.blocks[0].state_dict(), cpy_start_event=self.holder.event_flip)
         self.first_flip = False
         return out
 
@@ -89,9 +57,9 @@ class FlipFlopContext:
         if next_flip_i >= self.num_blocks:
             next_flip_i = next_flip_i - self.num_blocks
             self.last_flop = True
-        self.holder._copy_state_dict(self.holder.flip.state_dict(), self.holder.transformer_blocks[next_flip_i].state_dict(), self.holder.event_flip, self.holder.cpy_end_event)
+        self.holder._copy_state_dict(self.holder.flip.state_dict(), self.holder.blocks[next_flip_i].state_dict(), self.holder.event_flip, self.holder.cpy_end_event)
         if self.last_flop:
-            self.holder._copy_state_dict(self.holder.flop.state_dict(), self.holder.transformer_blocks[1].state_dict(), cpy_start_event=self.holder.event_flop)
+            self.holder._copy_state_dict(self.holder.flop.state_dict(), self.holder.blocks[1].state_dict(), cpy_start_event=self.holder.event_flop)
         self.first_flop = False
         return out
 
@@ -106,19 +74,20 @@ class FlipFlopContext:
 
 
 class FlipFlopHolder:
-    def __init__(self, transformer_blocks: List[torch.nn.Module], inference_device="cuda", offloading_device="cpu"):
-        self.load_device = torch.device(inference_device)
-        self.offload_device = torch.device(offloading_device)
-        self.transformer_blocks = transformer_blocks
+    def __init__(self, blocks: List[torch.nn.Module], flip_amount: int, load_device="cuda", offload_device="cpu"):
+        self.load_device = torch.device(load_device)
+        self.offload_device = torch.device(offload_device)
+        self.blocks = blocks
+        self.flip_amount = flip_amount
 
         self.block_module_size = 0
-        if len(self.transformer_blocks) > 0:
-            self.block_module_size = comfy.model_management.module_size(self.transformer_blocks[0])
+        if len(self.blocks) > 0:
+            self.block_module_size = comfy.model_management.module_size(self.blocks[0])
 
         self.flip: torch.nn.Module = None
         self.flop: torch.nn.Module = None
         # TODO: make initialization happen in model management code/model patcher, not here
-        self.initialize_flipflop_blocks(self.load_device)
+        self.init_flipflop_blocks(self.load_device)
 
         self.compute_stream = cuda.default_stream(self.load_device)
         self.cpy_stream = cuda.Stream(self.load_device)
@@ -142,10 +111,57 @@ class FlipFlopHolder:
     def context(self):
         return FlipFlopContext(self)
 
-    def initialize_flipflop_blocks(self, load_device: torch.device):
-        self.flip = copy.deepcopy(self.transformer_blocks[0]).to(device=load_device)
-        self.flop = copy.deepcopy(self.transformer_blocks[1]).to(device=load_device)
+    def init_flipflop_blocks(self, load_device: torch.device):
+        self.flip = copy.deepcopy(self.blocks[0]).to(device=load_device)
+        self.flop = copy.deepcopy(self.blocks[1]).to(device=load_device)
 
+    def clean_flipflop_blocks(self):
+        del self.flip
+        del self.flop
+        self.flip = None
+        self.flop = None
+
+
+class FlopFlopModule(torch.nn.Module):
+    def __init__(self, block_types: tuple[str, ...]):
+        super().__init__()
+        self.block_types = block_types
+        self.flipflop: dict[str, FlipFlopHolder] = {}
+
+    def setup_flipflop_holders(self, block_percentage: float):
+        for block_type in self.block_types:
+            if block_type in self.flipflop:
+                continue
+            num_blocks = int(len(self.transformer_blocks) * (1.0-block_percentage))
+            self.flipflop["transformer_blocks"] = FlipFlopHolder(self.transformer_blocks[num_blocks:], num_blocks)
+
+    def clean_flipflop_holders(self):
+        for block_type in self.flipflop.keys():
+            self.flipflop[block_type].clean_flipflop_blocks()
+            del self.flipflop[block_type]
+
+    def get_blocks(self, block_type: str) -> torch.nn.ModuleList:
+        if block_type not in self.block_types:
+            raise ValueError(f"Block type {block_type} not found in {self.block_types}")
+        if block_type in self.flipflop:
+            return getattr(self, block_type)[:self.flipflop[block_type].flip_amount]
+        return getattr(self, block_type)
+
+    def get_all_block_module_sizes(self, sort_by_size: bool = False) -> list[tuple[str, int]]:
+        '''
+        Returns a list of (block_type, size).
+        If sort_by_size is True, the list is sorted by size.
+        '''
+        sizes = [(block_type, self.get_block_module_size(block_type)) for block_type in self.block_types]
+        if sort_by_size:
+            sizes.sort(key=lambda x: x[1])
+        return sizes
+
+    def get_block_module_size(self, block_type: str) -> int:
+        return comfy.model_management.module_size(getattr(self, block_type)[0])
+
+
+# Below is the implementation from contentis' prototype flip flop
 class FlipFlopTransformer:
     def __init__(self, transformer_blocks: List[torch.nn.Module], block_wrap_fn, out_names: Tuple[str], pinned_staging: bool = False, inference_device="cuda", offloading_device="cpu"):
         self.transformer_blocks = transformer_blocks
@@ -379,28 +395,26 @@ class FlipFlopTransformer:
 #         patch_model_from_config(model, Wan.blocks_config)
 #         return model
 
-@register("QwenImageTransformer2DModel")
-class QwenImage:
-    @staticmethod
-    def qwen_blocks_wrap(block, **kwargs):
-        kwargs["encoder_hidden_states"], kwargs["hidden_states"] = block(hidden_states=kwargs["hidden_states"],
-                                                                         encoder_hidden_states=kwargs["encoder_hidden_states"],
-                                                                         encoder_hidden_states_mask=kwargs["encoder_hidden_states_mask"],
-                                                                         temb=kwargs["temb"],
-                                                                         image_rotary_emb=kwargs["image_rotary_emb"],
-                                                                         transformer_options=kwargs["transformer_options"])
-        return kwargs
-
-    blocks_config = FlipFlopConfig(block_name="transformer_blocks",
-                                   block_wrap_fn=qwen_blocks_wrap,
-                                   out_names=("encoder_hidden_states", "hidden_states"),
-                                   overwrite_forward="blocks_fwd",
-                                   pinned_staging=False)
-
-    @staticmethod
-    def patch(model):
-        patch_model_from_config(model, QwenImage.blocks_config)
-        return model
+# @register("QwenImageTransformer2DModel")
+# class QwenImage:
+#     @staticmethod
+#     def qwen_blocks_wrap(block, **kwargs):
+#         kwargs["encoder_hidden_states"], kwargs["hidden_states"] = block(hidden_states=kwargs["hidden_states"],
+#                                                                          encoder_hidden_states=kwargs["encoder_hidden_states"],
+#                                                                          encoder_hidden_states_mask=kwargs["encoder_hidden_states_mask"],
+#                                                                          temb=kwargs["temb"],
+#                                                                          image_rotary_emb=kwargs["image_rotary_emb"],
+#                                                                          transformer_options=kwargs["transformer_options"])
+#         return kwargs
+
+#     blocks_config = FlipFlopConfig(block_name="transformer_blocks",
+#                                    block_wrap_fn=qwen_blocks_wrap,
+#                                    out_names=("encoder_hidden_states", "hidden_states"),
+#                                    overwrite_forward="blocks_fwd",
+#                                    pinned_staging=False)
+
+#     @staticmethod
+#     def patch(model):
+#         patch_model_from_config(model, QwenImage.blocks_config)
+#         return model
 
@@ -343,11 +343,36 @@ class QwenImageTransformer2DModel(nn.Module):
         self.proj_out = operations.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True, dtype=dtype, device=device)
 
     def setup_flipflop_holders(self, block_percentage: float):
+        if "transformer_blocks" in self.flipflop:
+            return
+        import comfy.model_management
         # We hackily move any flipflopped blocks into holder so that our model management system does not see them.
         num_blocks = int(len(self.transformer_blocks) * (1.0-block_percentage))
-        self.flipflop["blocks_fwd"] = FlipFlopHolder(self.transformer_blocks[num_blocks:])
+        loading = []
+        for n, m in self.named_modules():
+            params = []
+            skip = False
+            for name, param in m.named_parameters(recurse=False):
+                params.append(name)
+            for name, param in m.named_parameters(recurse=True):
+                if name not in params:
+                    skip = True # skip random weights in non leaf modules
+                    break
+            if not skip and (hasattr(m, "comfy_cast_weights") or len(params) > 0):
+                loading.append((comfy.model_management.module_size(m), n, m, params))
+        self.flipflop["transformer_blocks"] = FlipFlopHolder(self.transformer_blocks[num_blocks:], num_blocks)
         self.transformer_blocks = nn.ModuleList(self.transformer_blocks[:num_blocks])
 
+    def clean_flipflop_holders(self):
+        if "transformer_blocks" in self.flipflop:
+            self.flipflop["transformer_blocks"].clean_flipflop_blocks()
+            del self.flipflop["transformer_blocks"]
+
+    def get_transformer_blocks(self):
+        if "transformer_blocks" in self.flipflop:
+            return self.transformer_blocks[:self.flipflop["transformer_blocks"].flip_amount]
+        return self.transformer_blocks
+
     def process_img(self, x, index=0, h_offset=0, w_offset=0):
         bs, c, t, h, w = x.shape
         patch_size = self.patch_size
@@ -409,17 +434,6 @@ class QwenImageTransformer2DModel(nn.Module):
 
         return encoder_hidden_states, hidden_states
 
-    def blocks_fwd(self, hidden_states, encoder_hidden_states, encoder_hidden_states_mask, temb, image_rotary_emb, patches, control, blocks_replace, x, transformer_options):
-        for i, block in enumerate(self.transformer_blocks):
-            encoder_hidden_states, hidden_states = self.indiv_block_fwd(i, block, hidden_states, encoder_hidden_states, encoder_hidden_states_mask, temb, image_rotary_emb, patches, control, blocks_replace, x, transformer_options)
-        if "blocks_fwd" in self.flipflop:
-            holder = self.flipflop["blocks_fwd"]
-            with holder.context() as ctx:
-                for i, block in enumerate(holder.transformer_blocks):
-                    encoder_hidden_states, hidden_states = ctx(self.indiv_block_fwd, i, block, hidden_states, encoder_hidden_states, encoder_hidden_states_mask, temb, image_rotary_emb, patches, control, blocks_replace, x, transformer_options)
-
-        return encoder_hidden_states, hidden_states
-
     def _forward(
         self,
         x,
@@ -487,12 +501,14 @@ class QwenImageTransformer2DModel(nn.Module):
         patches = transformer_options.get("patches", {})
         blocks_replace = patches_replace.get("dit", {})
 
-        encoder_hidden_states, hidden_states = self.blocks_fwd(hidden_states=hidden_states,
-                                                                encoder_hidden_states=encoder_hidden_states,
-                                                                encoder_hidden_states_mask=encoder_hidden_states_mask,
-                                                                temb=temb, image_rotary_emb=image_rotary_emb,
-                                                                patches=patches, control=control, blocks_replace=blocks_replace, x=x,
-                                                                transformer_options=transformer_options)
+        for i, block in enumerate(self.get_transformer_blocks()):
+            encoder_hidden_states, hidden_states = self.indiv_block_fwd(i, block, hidden_states, encoder_hidden_states, encoder_hidden_states_mask, temb, image_rotary_emb, patches, control, blocks_replace, x, transformer_options)
+        if "transformer_blocks" in self.flipflop:
+            holder = self.flipflop["transformer_blocks"]
+            with holder.context() as ctx:
+                for i, block in enumerate(holder.blocks):
+                    encoder_hidden_states, hidden_states = ctx(self.indiv_block_fwd, i, block, hidden_states, encoder_hidden_states, encoder_hidden_states_mask, temb, image_rotary_emb, patches, control, blocks_replace, x, transformer_options)
+
 
         hidden_states = self.norm_out(hidden_states, temb)
         hidden_states = self.proj_out(hidden_states)
@@ -605,7 +605,27 @@ class ModelPatcher:
             set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key))
 
     def supports_flipflop(self):
-        return hasattr(self.model.diffusion_model, "flipflop")
+        # flipflop requires diffusion_model, explicit flipflop support, NVIDIA CUDA streams, and loading/offloading VRAM
+        if not hasattr(self.model, "diffusion_model"):
+            return False
+        if not hasattr(self.model.diffusion_model, "flipflop"):
+            return False
+        if not comfy.model_management.is_nvidia():
+            return False
+        if comfy.model_management.vram_state in (comfy.model_management.VRAMState.HIGH_VRAM, comfy.model_management.VRAMState.SHARED):
+            return False
+        return True
+
+    def init_flipflop(self):
+        if not self.supports_flipflop():
+            return
+        # figure out how many b
+        self.model.diffusion_model.setup_flipflop_holders(self.model_options["flipflop_block_percentage"])
+
+    def clean_flipflop(self):
+        if not self.supports_flipflop():
+            return
+        self.model.diffusion_model.clean_flipflop_holders()
 
     def _load_list(self):
         loading = []
@@ -628,6 +648,9 @@ class ModelPatcher:
         mem_counter = 0
         patch_counter = 0
         lowvram_counter = 0
+        lowvram_mem_counter = 0
+        if self.supports_flipflop():
+            ...
         loading = self._load_list()
 
         load_completely = []
@@ -647,6 +670,7 @@ class ModelPatcher:
             if mem_counter + module_mem >= lowvram_model_memory:
                 lowvram_weight = True
                 lowvram_counter += 1
+                lowvram_mem_counter += module_mem
                 if hasattr(m, "prev_comfy_cast_weights"): #Already lowvramed
                     continue
 
@@ -709,10 +733,10 @@ class ModelPatcher:
                     x[2].to(device_to)
 
         if lowvram_counter > 0:
-            logging.info("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter))
+            logging.info(f"loaded partially; {lowvram_model_memory / (1024 * 1024):.2f} MB usable memory, {mem_counter / (1024 * 1024):.2f} MB loaded, {lowvram_mem_counter / (1024 * 1024):.2f} MB offloaded, lowvram patches: {patch_counter}")
             self.model.model_lowvram = True
         else:
-            logging.info("loaded completely {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
+            logging.info(f"loaded completely; {lowvram_model_memory / (1024 * 1024):.2f} MB usable memory, {mem_counter / (1024 * 1024):.2f} MB loaded, full load: {full_load}")
             self.model.model_lowvram = False
             if full_load:
                 self.model.to(device_to)
@@ -3,33 +3,6 @@ from typing_extensions import override
 
 from comfy_api.latest import ComfyExtension, io
 
-from comfy.ldm.flipflop_transformer import FLIPFLOP_REGISTRY
-
-class FlipFlopOld(io.ComfyNode):
-    @classmethod
-    def define_schema(cls) -> io.Schema:
-        return io.Schema(
-            node_id="FlipFlop",
-            display_name="FlipFlop (Old)",
-            category="_for_testing",
-            inputs=[
-                io.Model.Input(id="model")
-            ],
-            outputs=[
-                io.Model.Output()
-            ],
-            description="Apply FlipFlop transformation to model using registry-based patching"
-        )
-
-    @classmethod
-    def execute(cls, model) -> io.NodeOutput:
-        patch_cls = FLIPFLOP_REGISTRY.get(model.model.diffusion_model.__class__.__name__, None)
-        if patch_cls is None:
-            raise ValueError(f"Model {model.model.diffusion_model.__class__.__name__} not supported")
-
-        model.model.diffusion_model = patch_cls.patch(model.model.diffusion_model)
-
-        return io.NodeOutput(model)
 
 class FlipFlop(io.ComfyNode):
     @classmethod
@@ -62,7 +35,6 @@ class FlipFlopExtension(ComfyExtension):
     @override
     async def get_node_list(self) -> list[type[io.ComfyNode]]:
        return [
-            FlipFlopOld,
             FlipFlop,
         ]
 
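As a side note on the ModelPatcher log change above: the commit swaps the unlabeled "loaded partially {} {} {}" / "loaded completely {} {} {}" messages for labeled f-strings. A quick illustration with made-up numbers (the values below are purely illustrative, not real output):

import logging
logging.basicConfig(level=logging.INFO)

lowvram_model_memory = 9216.0 * 1024 * 1024   # made-up example values
mem_counter = 7890.5 * 1024 * 1024
lowvram_mem_counter = 1325.5 * 1024 * 1024
patch_counter = 0

# Old, unlabeled form: the reader has to remember which number is which.
logging.info("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter))
# New, labeled form from this commit:
logging.info(f"loaded partially; {lowvram_model_memory / (1024 * 1024):.2f} MB usable memory, {mem_counter / (1024 * 1024):.2f} MB loaded, {lowvram_mem_counter / (1024 * 1024):.2f} MB offloaded, lowvram patches: {patch_counter}")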