async cache revamp

Added async loading and offloading of MoE layers, so GPU memory usage stays consistent instead of climbing until it hits an OOM error.
Previously generation OOMed after the third layer on a 24 GB GPU; it now runs to the end with consistent memory usage and minimal added latency.
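
For readers skimming the diff, the core mechanism is roughly the pattern sketched below: a dedicated CUDA stream copies an expert's CPU-resident weights to the GPU just before they are needed, and the GPU copy is released once the layer has run. This is a minimal illustrative sketch under those assumptions, not the code in this commit; prefetch_expert, release_expert and copy_stream are hypothetical names.

import torch
import torch.nn as nn

# Transfers run on a side stream so they can overlap with compute on the default stream.
copy_stream = torch.cuda.Stream()

def prefetch_expert(expert: nn.Module) -> torch.cuda.Event:
    # Begin an async CPU -> GPU copy of the expert's weights; returns an event to wait on.
    done = torch.cuda.Event()
    cpu_weights = [p.data for p in expert.parameters()]  # keep the CPU (ideally pinned) sources alive
    with torch.cuda.stream(copy_stream):
        for p, src in zip(expert.parameters(), cpu_weights):
            p.data = src.to("cuda", non_blocking=True)
        copy_stream.record_event(done)
    expert._cpu_weights = cpu_weights  # hypothetical attribute, cleared again in release_expert()
    return done

def release_expert(expert: nn.Module) -> None:
    # Point the weights back at their CPU copies once the layer has run, freeing VRAM.
    torch.cuda.current_stream().synchronize()  # make sure compute that used the GPU weights is done
    for p, src in zip(expert.parameters(), expert._cpu_weights):
        p.data = src
    expert._cpu_weights = None
    torch.cuda.empty_cache()

# Usage: evt = prefetch_expert(expert); evt.synchronize(); out = expert(x); release_expert(expert)
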
Yousef Rafat 2025-11-14 09:15:16 +02:00
parent 44346c4251
commit 7b4c1e8031


@@ -1,15 +1,17 @@
import os
import gc
import math
import time
import torch
import psutil
import asyncio
import threading
import torch.nn as nn
from pathlib import Path
import concurrent.futures
from einops import rearrange
import torch.nn.functional as F
from collections import OrderedDict
from safetensors import safe_open
from contextlib import contextmanager
from transformers.cache_utils import StaticCache
from typing import Optional, Tuple, Any, List, Dict
from comfy.ldm.modules.attention import optimized_attention
@@ -19,13 +21,13 @@ INIT_MOE = torch.cuda.device_count() != 1
if not INIT_MOE:
MOE_LAYER_SIZE = (1024**3) * 2.65 # approx
CPU_MOE_RATIO = None
torch.cuda.set_device(0)
props = torch.cuda.get_device_properties(0)
INIT_CUDA_MEM = (props.total_memory - torch.cuda.memory_reserved()) * 0.9
ADDITIONAL_LAYERS_IN_GPU = math.floor(INIT_CUDA_MEM / MOE_LAYER_SIZE)
LAYERS_IN_CPU = math.floor((int((os.sysconf('SC_PAGE_SIZE') * os.sysconf('SC_PHYS_PAGES'))
- psutil.Process(os.getpid()).memory_info().rss
- (2*1024**3)) * 0.50) / MOE_LAYER_SIZE)
class HunyuanStaticCache(StaticCache):
@@ -286,20 +288,15 @@ def topkgating(
logits = logits.float()
gates = F.softmax(logits, dim=1)
extra = ADDITIONAL_LAYERS_IN_GPU
values_all, indices_all = torch.topk(gates, topk + extra, dim=1)
values_all, indices_all = torch.topk(gates, topk, dim=1)
expert_weight = values_all[:, :topk]
expert_index = indices_all[:, :topk]
_, cpu_expert_index = torch.topk(gates, int(CPU_MOE_RATIO * 64), dim = 1)
cpu_expert_index = cpu_expert_index[:, (8 + ADDITIONAL_LAYERS_IN_GPU):]
if norm_topk_prob and topk > 1:
denom = expert_weight.sum(dim=1, keepdim=True).clamp_min(torch.finfo(gates.dtype).eps)
expert_weight = expert_weight / denom
return expert_weight, expert_index, cpu_expert_index, indices_all
return expert_weight, expert_index
class HunyuanRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
@@ -452,7 +449,7 @@ class HunyuanTopKGate(nn.Module):
return gate_output
class HunyuanMLP(nn.Module):
def __init__(self, config, layer_idx=None, is_shared_mlp=False, is_moe=False):
def __init__(self, config, layer_idx=None, is_shared_mlp=False, is_moe=False, device=None):
super().__init__()
self.config = config
self.layer_idx = layer_idx
@@ -462,8 +459,8 @@ class HunyuanMLP(nn.Module):
self.act_fn = torch.nn.functional.silu
self.intermediate_size *= 2 # SwiGLU
self.gate_and_up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.down_proj = nn.Linear(self.intermediate_size // 2, self.hidden_size, bias=False)
self.gate_and_up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False, device=device)
self.down_proj = nn.Linear(self.intermediate_size // 2, self.hidden_size, bias=False, device=device)
def forward(self, x):
self.gate_and_up_proj, self.down_proj = self.gate_and_up_proj.to(x.device), self.down_proj.to(x.device)
if x.ndim == 2:
@@ -474,204 +471,143 @@ class HunyuanMLP(nn.Module):
return down_proj
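# LRU cache of MoE expert MLPs split across GPU and CPU: experts are uploaded to the GPU on a dedicated load stream before use, offloaded back to pinned CPU memory on an offload stream afterwards, and evicted when the GPU or CPU budget is exceeded.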
class MoELRUCache(nn.Module):
def __init__(self, cpu_mem: int = 50, safety_buffer_bytes = 3*(1024**3), max_gpu_eviction_attempts = 8):
def __init__(self):
super().__init__()
global CPU_MOE_RATIO
_, total = torch.cuda.mem_get_info()
max_gpu_mem_gb = max((total - 2 * safety_buffer_bytes) / (1024**3), 1)
self.MAX_GPU_MEM = int(max_gpu_mem_gb * 1024**3)
self.MAX_CPU_MEM = int(cpu_mem * 1024**3)
self.gpu_cache = OrderedDict()
self.cpu_cache = OrderedDict()
self.offload_stream = torch.cuda.Stream()
self.load_stream = torch.cuda.Stream()
self.gpu_mem_usage = 0
self.cpu_mem_usage = 0
# 50% for system and headroom
try:
self.MAX_CPU_MEM = int((os.sysconf('SC_PAGE_SIZE') * os.sysconf('SC_PHYS_PAGES'))
- psutil.Process(os.getpid()).memory_info().rss
- safety_buffer_bytes) * 0.55
except:
self.MAX_CPU_MEM = int(cpu_mem * (1024**3) * 0.5) # TODO
self.last_offload_event = None
self._loop = asyncio.new_event_loop()
threading.Thread(target=self._loop.run_forever, daemon=True).start()
ADDITIONAL_LAYERS_IN_CPU = math.floor((50 * (1024**3)) / MOE_LAYER_SIZE)
CPU_MOE_RATIO = (min(64 - ADDITIONAL_LAYERS_IN_GPU, ADDITIONAL_LAYERS_IN_CPU)) / 64
async def _async_offload_to_cpu(self, layer_idx):
# async offload from gpu (removed)
self.MAX_GPU_MEM = int(max_gpu_mem_gb * 1024**3)
self.SAFETY_BUFFER = int(safety_buffer_bytes)
self.MAX_GPU_EVICT_ATTEMPTS = max_gpu_eviction_attempts
num_experts = 64
moe_group = [(layer_idx * num_experts + i, self.gpu_cache[layer_idx * num_experts + i])
for i in range(num_experts)
if (layer_idx * num_experts + i) in self.gpu_cache]
event = torch.cuda.Event()
def _gpu_free_bytes(self):
free, total = torch.cuda.mem_get_info()
return int(free)
def _estimate_size(self, moe):
# include parameters + buffers
size = 0
for p in moe.parameters():
size += p.numel() * p.element_size()
for b in moe.buffers():
size += b.numel() * b.element_size()
return int(size)
with torch.cuda.stream(self.offload_stream):
for index, moe in moe_group:
moe_cpu = HunyuanMLP(moe.config).to("cpu", non_blocking=True)
for (name, p_gpu), p_cpu in zip(moe.named_parameters(), moe_cpu.parameters()):
if p_gpu.device.type == "meta":
continue
with torch.no_grad():
p_cpu.data = torch.empty_like(p_gpu, device="cpu", pin_memory=True)
p_cpu.copy_(p_gpu, non_blocking=True)
def _evict_until_free(self, required_bytes, max_attempts=16):
attempts = 0
while self._gpu_free_bytes() < required_bytes and attempts < max_attempts:
evicted = self._evict_from_gpu()
if not evicted:
break
attempts += 1
return self._gpu_free_bytes() >= required_bytes
self.cpu_cache[index] = moe_cpu
@contextmanager
def ensure_headroom(self, required_bytes):
self.offload_stream.record_event(event)
safety = getattr(self, "SAFETY_BUFFER", 0)
target_free = int(required_bytes + safety)
self.last_offload_event = event
if getattr(self, "_headroom", None) is not None:
try:
del self._headroom
except Exception:
pass
self._headroom = None
def finalize_offload_layer():
event.synchronize()
for index, moe in moe_group:
moe.to("meta")
self.gpu_cache.pop(index, None)
del moe
torch.cuda.empty_cache()
ok = self._evict_until_free(target_free)
if not ok and self._gpu_free_bytes() < target_free:
# last ditch
try:
torch.cuda.empty_cache()
except Exception:
pass
threading.Thread(target=finalize_offload_layer, daemon=True).start()
try:
yield
finally:
if getattr(self, "_headroom", None) is None:
try:
self._headroom = torch.empty((self._headroom_bytes,), dtype=torch.uint8, device="cuda:0")
except Exception:
self._headroom = None
async def _async_load_to_gpu(self, index, moe):
def add_gpu(self, moe, index, allowed_retries=3):
size = self._estimate_size(moe)
while self.gpu_mem_usage + size > self.MAX_GPU_MEM:
if not self._evict_from_gpu():
# if there is enough free memory, load; otherwise wait for the in-flight offload to finish
while True:
free_bytes, _ = torch.cuda.mem_get_info()
if free_bytes > 2 * MOE_LAYER_SIZE:
break
attempts = 0
while self._gpu_free_bytes() < size + self.SAFETY_BUFFER and attempts < self.MAX_GPU_EVICT_ATTEMPTS:
if not self._evict_from_gpu():
break
attempts += 1
self.last_offload_event.synchronize()
torch.cuda.empty_cache()
await asyncio.sleep(0.01)
for _ in range(allowed_retries):
try:
moe_cuda = moe.to("cuda:0")
break
except RuntimeError as e:
if "out of memory" not in str(e).lower():
raise
evicted = self._evict_from_gpu()
if not evicted: # can't evict
raise
else:
raise RuntimeError("Failed to move expert to GPU after evictions")
# async loading from cpu -> gpu
with torch.cuda.stream(self.load_stream):
moe_gpu = HunyuanMLP(moe.config).to("cuda", non_blocking=True)
for (name, p_cpu), p_gpu in zip(moe.named_parameters(), moe_gpu.parameters()):
with torch.no_grad():
p_gpu.data = torch.empty_like(p_cpu, device="cuda")
p_gpu.copy_(p_cpu, non_blocking=True)
self.gpu_cache[index] = moe_cuda
self.gpu_cache.move_to_end(index)
self.gpu_mem_usage += size
def finalize_load():
self.gpu_cache[index] = moe_gpu
self.cpu_cache.pop(index, None)
return
threading.Thread(target=finalize_load, daemon=True).start()
def add_cpu(self, moe, index):
size = self._estimate_size(moe)
while self.cpu_mem_usage + size > self.MAX_CPU_MEM:
if not self._evict_from_cpu():
break
moe_cpu = moe.to("cpu")
for _, p in moe_cpu.named_parameters():
if not p.is_pinned():
p.data = p.data.pin_memory()
self.cpu_cache[index] = moe_cpu
self.cpu_cache.move_to_end(index)
self.cpu_mem_usage += size
def get_from_device(self, index):
if index in self.gpu_cache:
moe = self.gpu_cache[index]
self.gpu_cache.move_to_end(index)
return moe
if index in self.cpu_cache:
moe = self.cpu_cache.pop(index)
self.cpu_mem_usage = max(0, self.cpu_mem_usage - self._estimate_size(moe))
try:
self.add_gpu(moe, index)
return self.gpu_cache[index]
except RuntimeError:
self.cpu_cache[index] = moe
self.cpu_cache.move_to_end(index)
self.cpu_mem_usage += self._estimate_size(moe)
raise
return None # load from disk
def _evict_from_gpu(self):
if not self.gpu_cache:
return False
idx, moe = self.gpu_cache.popitem(last=False)
size = self._estimate_size(moe)
self.gpu_mem_usage = max(0, self.gpu_mem_usage - size)
if self.cpu_mem_usage + size <= self.MAX_CPU_MEM:
try:
moe_cpu = moe.to("cpu")
except Exception:
# drop the model if cpu is full
del moe
return True
self.cpu_cache[idx] = moe_cpu
self.cpu_cache.move_to_end(idx)
self.cpu_mem_usage += size
return True
else:
del moe
return True
def _evict_from_cpu(self):
if not self.cpu_cache:
return False
_, moe = self.cpu_cache.popitem(last=False)
size = self._estimate_size(moe)
self.cpu_mem_usage = max(0, self.cpu_mem_usage - size)
del moe
gc.collect()
return True
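# Reads a single expert's weights (model.layers.{layer_idx}.mlp.experts.{expert_idx}.*) from the safetensors checkpoint and schedules the disk read on the cache's background event loop; once it completes, the expert is added to the CPU cache and queued for async upload to the GPU.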
class LazyMoELoader(nn.Module):
def __init__(self, device):
def __init__(self, cache, config):
super().__init__()
self.device = device
self.cache = cache
self.config = config
self._loop = cache._loop
def lazy_init(self, config, layer_idx, expert_idx):
def get_checkpoint(self):
comfyui_dir = Path.home() / "ComfyUI"
checkpoint = comfyui_dir / "models" / "checkpoint" / "hunyuan_image_3.safetensors"
checkpoint = checkpoint.resolve()
if not os.path.exists(checkpoint):
raise ValueError(f"Hunyuan Image 3 Checkpoint on one GPU should have the path: {checkpoint}")
return checkpoint
def lazy_init(self, layer_idx, expert_idx):
checkpoint = self.get_checkpoint()
prefix = f"model.layers.{layer_idx}.mlp.experts.{expert_idx}."
additional_prefix = f"model.layers.{layer_idx}.mlp.gate_and_up_proj.weight"
sd = {}
with safe_open(checkpoint, framework="pt", device=self.device) as f:
with safe_open(checkpoint, framework="pt", device="cpu") as f:
for k in f.keys():
if k.startswith(prefix) or k.startswith(additional_prefix):
new_k = k.split(f"experts.{expert_idx}.", 1)[1]
sd[new_k] = f.get_tensor(k)
return HunyuanMLP(config, layer_idx=layer_idx, is_shared_mlp=False, is_moe=True).load_state_dict(sd).to(self.deivce)
return HunyuanMLP(self.config, layer_idx=layer_idx, is_shared_mlp=False, is_moe=True).load_state_dict(sd)
async def lazy_load_from_disk(self, layer_idx, expert_idx):
loop = asyncio.get_event_loop()
return await loop.run_in_executor(None, self.lazy_init, layer_idx, expert_idx)
def _schedule_disk_load(self, layer_idx, expert_idx):
coro = self.lazy_load_from_disk(layer_idx, expert_idx)
future = asyncio.run_coroutine_threadsafe(coro, self._loop)
def _on_disk_loaded(fut):
moe_cpu = fut.result()
def _add_cpu_in_main_thread():
self.cache.add_cpu(moe_cpu, (layer_idx * 64) + expert_idx)
asyncio.run_coroutine_threadsafe(
self.cache._async_load_to_gpu((layer_idx * 64) + expert_idx, moe_cpu),
self.cache._loop
)
threading.Thread(target=_add_cpu_in_main_thread, daemon=True).start()
future.add_done_callback(_on_disk_loaded)
return future
def enough_vram(required_bytes):
free, total = torch.cuda.mem_get_info()
return free > required_bytes
class HunyuanMoE(nn.Module):
def __init__(self, config, layer_idx: Optional[int] = None, moe_lru=None):
@@ -687,7 +623,7 @@ class HunyuanMoE(nn.Module):
[HunyuanMLP(config, layer_idx=layer_idx, is_shared_mlp=False, is_moe=True) for _ in range(self.num_experts)]
)
else:
self.experts = None
self.experts = []
self.moe_lru = moe_lru
def forward(self, hidden_states):
@@ -702,38 +638,10 @@ class HunyuanMoE(nn.Module):
reshaped_input = hidden_states.reshape(-1, hidden_size)
with torch.cuda.nvtx.range("MoE"):
expert_weight, expert_index, cpu_expert_index, indices_all = self.gate(hidden_states)
if not INIT_MOE:
if ADDITIONAL_LAYERS_IN_GPU > 0:
additional_expert_index = indices_all[:, expert_index.size(1): expert_index.size(1) + ADDITIONAL_LAYERS_IN_GPU]
flat = additional_expert_index.reshape(-1).to("cpu")
counts = torch.bincount(flat, minlength=self.num_experts)
top_extra = torch.topk(counts, k=min(ADDITIONAL_LAYERS_IN_GPU, (counts>0).sum().item())).indices.tolist()
for expert_id in top_extra:
if self.moe_lru.get_from_device(expert_id + self.layer_idx) is None:
expert_cpu = LazyMoELoader(device="cpu").lazy_init(self.config, self.layer_idx, expert_id)
self.moe_lru.add_gpu(expert_cpu, expert_id + self.layer_idx)
if cpu_expert_index is not None and cpu_expert_index.numel() > 0:
for expert_id in torch.unique(cpu_expert_index).cpu().tolist():
if self.moe_lru.get_from_device(expert_id + self.layer_idx) is None:
expert_cpu = LazyMoELoader(device="cpu").lazy_init(self.config, self.layer_idx, expert_id)
self.moe_lru.add_cpu(expert_cpu, expert_id + self.layer_idx)
expert_weight, expert_index = self.gate(hidden_states)
combined_output = torch.zeros_like(reshaped_input)
experts_list = []
for e in range(self.num_experts):
token_mask = (expert_index == e)
if not token_mask.any():
continue
expert = self.moe_lru.get_from_device(e + self.layer_idx)
if expert is None:
expert = LazyMoELoader()
expert = expert.lazy_init(self.config, self.layer_idx, e)
self.moe_lru.add_gpu(expert, e + self.layer_idx)
experts_list.append((e, expert))
experts_list = [(i, expert) for i, expert in enumerate(self.experts)]
per_pos, per_tokens, per_weights = [], [], []
for e, _ in experts_list:
@@ -761,6 +669,8 @@ class HunyuanMoE(nn.Module):
l1, l2 = [], []
for _, expert in experts_list:
if isinstance(expert, (asyncio.Future, concurrent.futures.Future)):
expert = expert.result()
l1.append(expert.gate_and_up_proj)
l2.append(expert.down_proj)
@@ -769,6 +679,12 @@ class HunyuanMoE(nn.Module):
W1_T = W1.transpose(1, 2)
W2_T = W2.transpose(1, 2)
# wait for enough vram for the computations
while not enough_vram(5*(1024 ** 3)):
event = self.moe_lru.last_offload_event
if event is not None and not event.query():
time.sleep(0.001)
x = torch.bmm(tokens_padded, W1_T)
x = F.silu(x)
@@ -781,7 +697,7 @@ class HunyuanMoE(nn.Module):
for i, token_positions in enumerate(per_pos):
Ni = lengths[i]
out_i = out_padded[i, :Ni]
combined_output.index_add_(0, token_positions.to(hidden_states.device), out_i)
combined_output.to(hidden_states.dtype).index_add_(0, token_positions.to(hidden_states.device), out_i.to(hidden_states.dtype))
#dispatched_input = torch.einsum("sec,sm->ecm", dispatch_mask.type_as(hidden_states), reshaped_input)
#chunks = dispatched_input.chunk(self.num_experts, dim=0)
@@ -933,6 +849,7 @@ class HunyuanImage3Model(nn.Module):
super().__init__()
self.padding_idx = 128009
self.vocab_size = 133120
self.config = config
self.wte = nn.Embedding(133120, config["hidden_size"], self.padding_idx)
self.layers = nn.ModuleList(
[HunyuanImage3DecoderLayer(config, layer_idx, moe_lru = moe_lru) for layer_idx in range(config["num_hidden_layers"])]
@@ -941,6 +858,7 @@ class HunyuanImage3Model(nn.Module):
self.ln_f = HunyuanRMSNorm(config["hidden_size"], eps=config["rms_norm_eps"])
self.shared_tensor = None
self.moe_lru = moe_lru
def forward(
self,
@@ -962,19 +880,45 @@ class HunyuanImage3Model(nn.Module):
hidden_states = inputs_embeds
next_decoder_cache = None
for layer_idx, decoder_layer in enumerate(self.layers):
next_layers = 0
sparse_interval = max(1, len(self.layers) // 3)
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_values,
use_cache=use_cache,
custom_pos_emb=custom_pos_emb,
mode=mode,
first_step=first_step,
gen_timestep_scatter_index=gen_timestep_scatter_index,
)
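# Prefetch pipeline: layer 0's experts are scheduled from disk up front; before each layer runs, the next layer's experts are scheduled so the disk reads overlap with compute, and once a layer finishes its experts are handed to the cache for async offload and its expert list is cleared.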
if len(self.layers[0].mlp.experts) == 0:
experts = [LazyMoELoader(self.moe_lru, self.config) for _ in range(64)]
self.layers[0].mlp.experts = [expert._schedule_disk_load(0, i) for i, expert in enumerate(experts)]
for layer_idx, decoder_layer in enumerate(self.layers):
if layer_idx + 1 < len(self.layers) and len(self.layers[layer_idx + 1].mlp.experts) == 0: # not loaded
experts = [LazyMoELoader(self.moe_lru, self.config) for _ in range(64)]
self.layers[layer_idx+1].mlp.experts = [expert._schedule_disk_load(layer_idx+1, i) for i, expert in enumerate(experts)]
if (layer_idx % sparse_interval == 0) and layer_idx > sparse_interval:
if len(self.layers[next_layers].mlp.experts) > 0: # for testing
raise ValueError("Problem with offloading")
experts = [LazyMoELoader(self.moe_lru, self.config) for _ in range(64)]
self.layers[next_layers].mlp.experts = [expert._schedule_disk_load(next_layers, i) for i, expert in enumerate(experts)]
next_layers += 1
with torch.no_grad():
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_values,
use_cache=use_cache,
custom_pos_emb=custom_pos_emb,
mode=mode,
first_step=first_step,
gen_timestep_scatter_index=gen_timestep_scatter_index,
)
if layer_idx >= 0:
asyncio.run_coroutine_threadsafe(
self.moe_lru._async_offload_to_cpu(layer_idx),
self.moe_lru._loop
)
self.layers[layer_idx].mlp.experts = []
hidden_states = layer_outputs[0]