From 7b4c1e80312e0da4ea47e85d8e4f1feee1e716c7 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Fri, 14 Nov 2025 09:15:16 +0200 Subject: [PATCH] async cache revamp Added an async loading and offloading of moe layers, having consistent memory with oom errors. Used to give oom error after the third layer with 24 giga bytes gpu, now goes to the end with consistent memory with minimal latency --- comfy/ldm/hunyuan_image_3/model.py | 372 ++++++++++++----------------- 1 file changed, 158 insertions(+), 214 deletions(-) diff --git a/comfy/ldm/hunyuan_image_3/model.py b/comfy/ldm/hunyuan_image_3/model.py index 3cbca46cd..9682a270f 100644 --- a/comfy/ldm/hunyuan_image_3/model.py +++ b/comfy/ldm/hunyuan_image_3/model.py @@ -1,15 +1,17 @@ import os -import gc import math +import time import torch import psutil +import asyncio +import threading import torch.nn as nn from pathlib import Path +import concurrent.futures from einops import rearrange import torch.nn.functional as F from collections import OrderedDict from safetensors import safe_open -from contextlib import contextmanager from transformers.cache_utils import StaticCache from typing import Optional, Tuple, Any, List, Dict from comfy.ldm.modules.attention import optimized_attention @@ -19,13 +21,13 @@ INIT_MOE = torch.cuda.device_count() != 1 if not INIT_MOE: MOE_LAYER_SIZE = (1024**3) * 2.65 # approx - CPU_MOE_RATIO = None torch.cuda.set_device(0) props = torch.cuda.get_device_properties(0) - INIT_CUDA_MEM = (props.total_memory - torch.cuda.memory_reserved()) * 0.9 - ADDITIONAL_LAYERS_IN_GPU = math.floor(INIT_CUDA_MEM / MOE_LAYER_SIZE) + LAYERS_IN_CPU = math.floor((int((os.sysconf('SC_PAGE_SIZE') * os.sysconf('SC_PHYS_PAGES')) + - psutil.Process(os.getpid()).memory_info().rss + - (2*1024**3)) * 0.50) / MOE_LAYER_SIZE) class HunyuanStaticCache(StaticCache): @@ -286,20 +288,15 @@ def topkgating( logits = logits.float() gates = F.softmax(logits, dim=1) - extra = ADDITIONAL_LAYERS_IN_GPU - - values_all, indices_all = torch.topk(gates, topk + extra, dim=1) + values_all, indices_all = torch.topk(gates, topk, dim=1) expert_weight = values_all[:, :topk] expert_index = indices_all[:, :topk] - _, cpu_expert_index = torch.topk(gates, int(CPU_MOE_RATIO * 64), dim = 1) - cpu_expert_index = cpu_expert_index[:, (8 + ADDITIONAL_LAYERS_IN_GPU):] - if norm_topk_prob and topk > 1: denom = expert_weight.sum(dim=1, keepdim=True).clamp_min(torch.finfo(gates.dtype).eps) expert_weight = expert_weight / denom - return expert_weight, expert_index, cpu_expert_index, indices_all + return expert_weight, expert_index class HunyuanRMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): @@ -452,7 +449,7 @@ class HunyuanTopKGate(nn.Module): return gate_output class HunyuanMLP(nn.Module): - def __init__(self, config, layer_idx=None, is_shared_mlp=False, is_moe=False): + def __init__(self, config, layer_idx=None, is_shared_mlp=False, is_moe=False, device=None): super().__init__() self.config = config self.layer_idx = layer_idx @@ -462,8 +459,8 @@ class HunyuanMLP(nn.Module): self.act_fn = torch.nn.functional.silu self.intermediate_size *= 2 # SwiGLU - self.gate_and_up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.down_proj = nn.Linear(self.intermediate_size // 2, self.hidden_size, bias=False) + self.gate_and_up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False, device=device) + self.down_proj = nn.Linear(self.intermediate_size // 2, self.hidden_size, bias=False, device=device) def forward(self, x): self.gate_and_up_proj, self.down_proj = self.gate_and_up_proj.to(x.device), self.down_proj.to(x.device) if x.ndim == 2: @@ -474,204 +471,143 @@ class HunyuanMLP(nn.Module): return down_proj class MoELRUCache(nn.Module): - def __init__(self, cpu_mem: int = 50, safety_buffer_bytes = 3*(1024**3), max_gpu_eviction_attempts = 8): + def __init__(self): super().__init__() - global CPU_MOE_RATIO - - _, total = torch.cuda.mem_get_info() - max_gpu_mem_gb = max((total - 2 * safety_buffer_bytes) / (1024**3), 1) - - self.MAX_GPU_MEM = int(max_gpu_mem_gb * 1024**3) - self.MAX_CPU_MEM = int(cpu_mem * 1024**3) self.gpu_cache = OrderedDict() self.cpu_cache = OrderedDict() + self.offload_stream = torch.cuda.Stream() + self.load_stream = torch.cuda.Stream() - self.gpu_mem_usage = 0 - self.cpu_mem_usage = 0 - # 50% for system and headroom - try: - self.MAX_CPU_MEM = int((os.sysconf('SC_PAGE_SIZE') * os.sysconf('SC_PHYS_PAGES')) - - psutil.Process(os.getpid()).memory_info().rss - - safety_buffer_bytes) * 0.55 - except: - self.MAX_CPU_MEM = int(cpu_mem * (1024**3) * 0.5) # TODO + self.last_offload_event = None + self._loop = asyncio.new_event_loop() + threading.Thread(target=self._loop.run_forever, daemon=True).start() - ADDITIONAL_LAYERS_IN_CPU = math.floor((50 * (1024**3)) / MOE_LAYER_SIZE) - CPU_MOE_RATIO = (min(64 - ADDITIONAL_LAYERS_IN_GPU, ADDITIONAL_LAYERS_IN_CPU)) / 64 + async def _async_offload_to_cpu(self, layer_idx): + # async offload from gpu (removed) - self.MAX_GPU_MEM = int(max_gpu_mem_gb * 1024**3) - self.SAFETY_BUFFER = int(safety_buffer_bytes) - self.MAX_GPU_EVICT_ATTEMPTS = max_gpu_eviction_attempts + num_experts = 64 + moe_group = [(layer_idx * num_experts + i, self.gpu_cache[layer_idx * num_experts + i]) + for i in range(num_experts) + if (layer_idx * num_experts + i) in self.gpu_cache] + event = torch.cuda.Event() - def _gpu_free_bytes(self): - free, total = torch.cuda.mem_get_info() - return int(free) - - def _estimate_size(self, moe): - # include parameters + buffers - size = 0 - for p in moe.parameters(): - size += p.numel() * p.element_size() - for b in moe.buffers(): - size += b.numel() * b.element_size() - return int(size) + with torch.cuda.stream(self.offload_stream): + for index, moe in moe_group: + moe_cpu = HunyuanMLP(moe.config).to("cpu", non_blocking=True) + for (name, p_gpu), p_cpu in zip(moe.named_parameters(), moe_cpu.parameters()): + if p_gpu.device.type == "meta": + continue + with torch.no_grad(): + p_cpu.data = torch.empty_like(p_gpu, device="cpu", pin_memory=True) + p_cpu.copy_(p_gpu, non_blocking=True) - def _evict_until_free(self, required_bytes, max_attempts=16): - attempts = 0 - while self._gpu_free_bytes() < required_bytes and attempts < max_attempts: - evicted = self._evict_from_gpu() - if not evicted: - break - attempts += 1 - return self._gpu_free_bytes() >= required_bytes + self.cpu_cache[index] = moe_cpu - @contextmanager - def ensure_headroom(self, required_bytes): + self.offload_stream.record_event(event) - safety = getattr(self, "SAFETY_BUFFER", 0) - target_free = int(required_bytes + safety) + self.last_offload_event = event - if getattr(self, "_headroom", None) is not None: - try: - del self._headroom - except Exception: - pass - self._headroom = None + def finalize_offload_layer(): + event.synchronize() + for index, moe in moe_group: + moe.to("meta") + self.gpu_cache.pop(index, None) + del moe + torch.cuda.empty_cache() - ok = self._evict_until_free(target_free) - if not ok and self._gpu_free_bytes() < target_free: - # last ditch - try: - torch.cuda.empty_cache() - except Exception: - pass + threading.Thread(target=finalize_offload_layer, daemon=True).start() - try: - yield - finally: - if getattr(self, "_headroom", None) is None: - try: - self._headroom = torch.empty((self._headroom_bytes,), dtype=torch.uint8, device="cuda:0") - except Exception: - self._headroom = None + async def _async_load_to_gpu(self, index, moe): - def add_gpu(self, moe, index, allowed_retries=3): - size = self._estimate_size(moe) - - while self.gpu_mem_usage + size > self.MAX_GPU_MEM: - if not self._evict_from_gpu(): + # if enough memory load, otherwise wait for offload + while True: + free_bytes, _ = torch.cuda.mem_get_info() + if free_bytes > 2 * MOE_LAYER_SIZE: break - attempts = 0 - while self._gpu_free_bytes() < size + self.SAFETY_BUFFER and attempts < self.MAX_GPU_EVICT_ATTEMPTS: - if not self._evict_from_gpu(): - break - attempts += 1 + self.last_offload_event.synchronize() + torch.cuda.empty_cache() + await asyncio.sleep(0.01) - for _ in range(allowed_retries): - try: - moe_cuda = moe.to("cuda:0") - break - except RuntimeError as e: - if "out of memory" not in str(e).lower(): - raise - evicted = self._evict_from_gpu() - if not evicted: # can't evict - raise - else: - raise RuntimeError("Failed to move expert to GPU after evictions") + # async loading from cpu -> gpu + with torch.cuda.stream(self.load_stream): + moe_gpu = HunyuanMLP(moe.config).to("cuda", non_blocking=True) + for (name, p_cpu), p_gpu in zip(moe.named_parameters(), moe_gpu.parameters()): + with torch.no_grad(): + p_gpu.data = torch.empty_like(p_cpu, device="cuda") + p_gpu.copy_(p_cpu, non_blocking=True) - self.gpu_cache[index] = moe_cuda - self.gpu_cache.move_to_end(index) - self.gpu_mem_usage += size + def finalize_load(): + self.gpu_cache[index] = moe_gpu + self.cpu_cache.pop(index, None) - return + threading.Thread(target=finalize_load, daemon=True).start() def add_cpu(self, moe, index): - size = self._estimate_size(moe) - while self.cpu_mem_usage + size > self.MAX_CPU_MEM: - if not self._evict_from_cpu(): - break moe_cpu = moe.to("cpu") + + for _, p in moe_cpu.named_parameters(): + if not p.is_pinned(): + p.data = p.data.pin_memory() + self.cpu_cache[index] = moe_cpu self.cpu_cache.move_to_end(index) - self.cpu_mem_usage += size - - def get_from_device(self, index): - if index in self.gpu_cache: - moe = self.gpu_cache[index] - self.gpu_cache.move_to_end(index) - return moe - if index in self.cpu_cache: - moe = self.cpu_cache.pop(index) - self.cpu_mem_usage = max(0, self.cpu_mem_usage - self._estimate_size(moe)) - try: - self.add_gpu(moe, index) - return self.gpu_cache[index] - except RuntimeError: - self.cpu_cache[index] = moe - self.cpu_cache.move_to_end(index) - self.cpu_mem_usage += self._estimate_size(moe) - raise - - return None # load from disk - - def _evict_from_gpu(self): - if not self.gpu_cache: - return False - - idx, moe = self.gpu_cache.popitem(last=False) - size = self._estimate_size(moe) - self.gpu_mem_usage = max(0, self.gpu_mem_usage - size) - - if self.cpu_mem_usage + size <= self.MAX_CPU_MEM: - try: - moe_cpu = moe.to("cpu") - except Exception: - # drop the model if cpu is full - del moe - return True - self.cpu_cache[idx] = moe_cpu - self.cpu_cache.move_to_end(idx) - self.cpu_mem_usage += size - return True - else: - del moe - return True - - def _evict_from_cpu(self): - if not self.cpu_cache: - return False - _, moe = self.cpu_cache.popitem(last=False) - size = self._estimate_size(moe) - self.cpu_mem_usage = max(0, self.cpu_mem_usage - size) - del moe - gc.collect() - return True class LazyMoELoader(nn.Module): - def __init__(self, device): + def __init__(self, cache, config): super().__init__() - self.device = device + self.cache = cache + self.config = config + self._loop = cache._loop - def lazy_init(self, config, layer_idx, expert_idx): + def get_checkpoint(self): comfyui_dir = Path.home() / "ComfyUI" checkpoint = comfyui_dir / "models" / "checkpoint" / "hunyuan_image_3.safetensors" checkpoint = checkpoint.resolve() if not os.path.exists(checkpoint): raise ValueError(f"Hunyuan Image 3 Checkpoint on one GPU should have the path: {checkpoint}") + return checkpoint + def lazy_init(self, layer_idx, expert_idx): + checkpoint = self.get_checkpoint() prefix = f"model.layers.{layer_idx}.mlp.experts.{expert_idx}." additional_prefix = f"model.layers.{layer_idx}.mlp.gate_and_up_proj.weight" sd = {} - with safe_open(checkpoint, framework="pt", device=self.device) as f: + with safe_open(checkpoint, framework="pt", device="cpu") as f: for k in f.keys(): if k.startswith(prefix) or k.startswith(additional_prefix): new_k = k.split(f"experts.{expert_idx}.", 1)[1] sd[new_k] = f.get_tensor(k) - return HunyuanMLP(config, layer_idx=layer_idx, is_shared_mlp=False, is_moe=True).load_state_dict(sd).to(self.deivce) + return HunyuanMLP(self.config, layer_idx=layer_idx, is_shared_mlp=False, is_moe=True).load_state_dict(sd) + + async def lazy_load_from_disk(self, layer_idx, expert_idx): + loop = asyncio.get_event_loop() + return await loop.run_in_executor(None, self.lazy_init, layer_idx, expert_idx) + + def _schedule_disk_load(self, layer_idx, expert_idx): + + coro = self.lazy_load_from_disk(layer_idx, expert_idx) + future = asyncio.run_coroutine_threadsafe(coro, self._loop) + + def _on_disk_loaded(fut): + moe_cpu = fut.result() + def _add_cpu_in_main_thread(): + self.cache.add_cpu(moe_cpu, (layer_idx * 64) + expert_idx) + + asyncio.run_coroutine_threadsafe( + self.cache._async_load_to_gpu((layer_idx * 64) + expert_idx, moe_cpu), + self.cache._loop + ) + threading.Thread(target=_add_cpu_in_main_thread, daemon=True).start() + + future.add_done_callback(_on_disk_loaded) + return future + +def enough_vram(required_bytes): + free, total = torch.cuda.mem_get_info() + return free > required_bytes class HunyuanMoE(nn.Module): def __init__(self, config, layer_idx: Optional[int] = None, moe_lru=None): @@ -687,7 +623,7 @@ class HunyuanMoE(nn.Module): [HunyuanMLP(config, layer_idx=layer_idx, is_shared_mlp=False, is_moe=True) for _ in range(self.num_experts)] ) else: - self.experts = None + self.experts = [] self.moe_lru = moe_lru def forward(self, hidden_states): @@ -702,38 +638,10 @@ class HunyuanMoE(nn.Module): reshaped_input = hidden_states.reshape(-1, hidden_size) with torch.cuda.nvtx.range("MoE"): - expert_weight, expert_index, cpu_expert_index, indices_all = self.gate(hidden_states) - if not INIT_MOE: - if ADDITIONAL_LAYERS_IN_GPU > 0: - additional_expert_index = indices_all[:, expert_index.size(1): expert_index.size(1) + ADDITIONAL_LAYERS_IN_GPU] - - flat = additional_expert_index.reshape(-1).to("cpu") - counts = torch.bincount(flat, minlength=self.num_experts) - top_extra = torch.topk(counts, k=min(ADDITIONAL_LAYERS_IN_GPU, (counts>0).sum().item())).indices.tolist() - - for expert_id in top_extra: - if self.moe_lru.get_from_device(expert_id + self.layer_idx) is None: - expert_cpu = LazyMoELoader(device="cpu").lazy_init(self.config, self.layer_idx, expert_id) - self.moe_lru.add_gpu(expert_cpu, expert_id + self.layer_idx) - - if cpu_expert_index is not None and cpu_expert_index.numel() > 0: - for expert_id in torch.unique(cpu_expert_index).cpu().tolist(): - if self.moe_lru.get_from_device(expert_id + self.layer_idx) is None: - expert_cpu = LazyMoELoader(device="cpu").lazy_init(self.config, self.layer_idx, expert_id) - self.moe_lru.add_cpu(expert_cpu, expert_id + self.layer_idx) + expert_weight, expert_index = self.gate(hidden_states) combined_output = torch.zeros_like(reshaped_input) - experts_list = [] - for e in range(self.num_experts): - token_mask = (expert_index == e) - if not token_mask.any(): - continue - expert = self.moe_lru.get_from_device(e + self.layer_idx) - if expert is None: - expert = LazyMoELoader() - expert = expert.lazy_init(self.config, self.layer_idx, e) - self.moe_lru.add_gpu(expert, e + self.layer_idx) - experts_list.append((e, expert)) + experts_list = [(i, expert) for i, expert in enumerate(self.experts)] per_pos, per_tokens, per_weights = [], [], [] for e, _ in experts_list: @@ -761,6 +669,8 @@ class HunyuanMoE(nn.Module): l1, l2 = [], [] for _, expert in experts_list: + if isinstance(expert, (asyncio.Future, concurrent.futures.Future)): + expert = expert.result() l1.append(expert.gate_and_up_proj) l2.append(expert.down_proj) @@ -769,6 +679,12 @@ class HunyuanMoE(nn.Module): W1_T = W1.transpose(1, 2) W2_T = W2.transpose(1, 2) + + # wait for enough vram for the computations + while not enough_vram(5*(1024 ** 3)): + event = self.moe_lru.last_offload_event + if event is not None and not event.query(): + time.sleep(0.001) x = torch.bmm(tokens_padded, W1_T) x = F.silu(x) @@ -781,7 +697,7 @@ class HunyuanMoE(nn.Module): for i, token_positions in enumerate(per_pos): Ni = lengths[i] out_i = out_padded[i, :Ni] - combined_output.index_add_(0, token_positions.to(hidden_states.device), out_i) + combined_output.to(hidden_states.dtype).index_add_(0, token_positions.to(hidden_states.device), out_i.to(hidden_states.dtype)) #dispatched_input = torch.einsum("sec,sm->ecm", dispatch_mask.type_as(hidden_states), reshaped_input) #chunks = dispatched_input.chunk(self.num_experts, dim=0) @@ -933,6 +849,7 @@ class HunyuanImage3Model(nn.Module): super().__init__() self.padding_idx = 128009 self.vocab_size = 133120 + self.config = config self.wte = nn.Embedding(133120, config["hidden_size"], self.padding_idx) self.layers = nn.ModuleList( [HunyuanImage3DecoderLayer(config, layer_idx, moe_lru = moe_lru) for layer_idx in range(config["num_hidden_layers"])] @@ -941,6 +858,7 @@ class HunyuanImage3Model(nn.Module): self.ln_f = HunyuanRMSNorm(config["hidden_size"], eps=config["rms_norm_eps"]) self.shared_tensor = None + self.moe_lru = moe_lru def forward( self, @@ -962,19 +880,45 @@ class HunyuanImage3Model(nn.Module): hidden_states = inputs_embeds next_decoder_cache = None - for layer_idx, decoder_layer in enumerate(self.layers): + next_layers = 0 + sparse_interval = max(1, len(self.layers) // 3) - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_values, - use_cache=use_cache, - custom_pos_emb=custom_pos_emb, - mode=mode, - first_step=first_step, - gen_timestep_scatter_index=gen_timestep_scatter_index, - ) + if len(self.layers[0].mlp.experts) == 0: + experts = [LazyMoELoader(self.moe_lru, self.config) for _ in range(64)] + self.layers[0].mlp.experts = [expert._schedule_disk_load(0, i) for i, expert in enumerate(experts)] + + for layer_idx, decoder_layer in enumerate(self.layers): + + if layer_idx + 1 < len(self.layers) and len(self.layers[layer_idx + 1].mlp.experts) == 0: # not loaded + experts = [LazyMoELoader(self.moe_lru, self.config) for _ in range(64)] + self.layers[layer_idx+1].mlp.experts = [expert._schedule_disk_load(layer_idx+1, i) for i, expert in enumerate(experts)] + + if (layer_idx % sparse_interval == 0) and layer_idx > sparse_interval: + if len(self.layers[next_layers].mlp.experts) > 0: # for testing + raise ValueError("Problem with offloading") + experts = [LazyMoELoader(self.moe_lru, self.config) for _ in range(64)] + self.layers[next_layers].mlp.experts = [expert._schedule_disk_load(next_layers, i) for i, expert in enumerate(experts)] + next_layers += 1 + + with torch.no_grad(): + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + use_cache=use_cache, + custom_pos_emb=custom_pos_emb, + mode=mode, + first_step=first_step, + gen_timestep_scatter_index=gen_timestep_scatter_index, + ) + + if layer_idx >= 0: + asyncio.run_coroutine_threadsafe( + self.moe_lru._async_offload_to_cpu(layer_idx), + self.moe_lru._loop + ) + self.layers[layer_idx].mlp.experts = [] hidden_states = layer_outputs[0]