async cache revamp

Added async loading and offloading of MoE expert layers, keeping GPU memory usage flat instead of running into OOM errors.
Previously generation hit an OOM error after the third layer on a 24 GB GPU; it now runs to the last layer with steady memory usage and minimal added latency.
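
In outline, the change double-buffers expert weights: while layer N runs on the default CUDA stream, layer N+1's experts are read from the safetensors checkpoint into pinned CPU memory and copied to the GPU on a side stream, and layer N's experts are copied back and released once a CUDA event confirms the transfer. A minimal sketch of that stream/event pattern, for context only (not code from this commit; prefetch/offload are placeholder names):

    import threading
    import torch

    load_stream = torch.cuda.Stream()
    offload_stream = torch.cuda.Stream()

    def prefetch(weight_cpu):
        # H2D copy on a side stream; weight_cpu should be pinned so the copy
        # can overlap with compute running on the default stream.
        with torch.cuda.stream(load_stream):
            return weight_cpu.to("cuda", non_blocking=True)

    def offload(weight_gpu):
        # D2H copy into pinned host memory on its own stream; a CUDA event
        # marks completion so the GPU copy can be dropped safely afterwards.
        event = torch.cuda.Event()
        with torch.cuda.stream(offload_stream):
            weight_cpu = torch.empty(weight_gpu.shape, dtype=weight_gpu.dtype,
                                     device="cpu", pin_memory=True)
            weight_cpu.copy_(weight_gpu, non_blocking=True)
            offload_stream.record_event(event)

        def _finalize():
            event.synchronize()       # transfer finished
            torch.cuda.empty_cache()  # let the allocator release cached VRAM

        threading.Thread(target=_finalize, daemon=True).start()
        return weight_cpu

The commit applies the same idea per layer of experts through MoELRUCache and LazyMoELoader, with an asyncio event loop on a background thread scheduling the disk reads and copies.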
Yousef Rafat 2025-11-14 09:15:16 +02:00
parent 44346c4251
commit 7b4c1e8031

@@ -1,15 +1,17 @@
 import os
-import gc
 import math
+import time
 import torch
 import psutil
+import asyncio
+import threading
 import torch.nn as nn
 from pathlib import Path
+import concurrent.futures
 from einops import rearrange
 import torch.nn.functional as F
 from collections import OrderedDict
 from safetensors import safe_open
-from contextlib import contextmanager
 from transformers.cache_utils import StaticCache
 from typing import Optional, Tuple, Any, List, Dict
 from comfy.ldm.modules.attention import optimized_attention
@@ -19,13 +21,13 @@ INIT_MOE = torch.cuda.device_count() != 1
 if not INIT_MOE:
     MOE_LAYER_SIZE = (1024**3) * 2.65 # approx
-    CPU_MOE_RATIO = None
     torch.cuda.set_device(0)
     props = torch.cuda.get_device_properties(0)
-    INIT_CUDA_MEM = (props.total_memory - torch.cuda.memory_reserved()) * 0.9
-    ADDITIONAL_LAYERS_IN_GPU = math.floor(INIT_CUDA_MEM / MOE_LAYER_SIZE)
+    LAYERS_IN_CPU = math.floor((int((os.sysconf('SC_PAGE_SIZE') * os.sysconf('SC_PHYS_PAGES'))
+                                    - psutil.Process(os.getpid()).memory_info().rss
+                                    - (2*1024**3)) * 0.50) / MOE_LAYER_SIZE)

 class HunyuanStaticCache(StaticCache):
@@ -286,20 +288,15 @@ def topkgating(
     logits = logits.float()
     gates = F.softmax(logits, dim=1)
-    extra = ADDITIONAL_LAYERS_IN_GPU
-    values_all, indices_all = torch.topk(gates, topk + extra, dim=1)
+    values_all, indices_all = torch.topk(gates, topk, dim=1)
     expert_weight = values_all[:, :topk]
     expert_index = indices_all[:, :topk]
-    _, cpu_expert_index = torch.topk(gates, int(CPU_MOE_RATIO * 64), dim = 1)
-    cpu_expert_index = cpu_expert_index[:, (8 + ADDITIONAL_LAYERS_IN_GPU):]
     if norm_topk_prob and topk > 1:
         denom = expert_weight.sum(dim=1, keepdim=True).clamp_min(torch.finfo(gates.dtype).eps)
         expert_weight = expert_weight / denom
-    return expert_weight, expert_index, cpu_expert_index, indices_all
+    return expert_weight, expert_index

 class HunyuanRMSNorm(nn.Module):
     def __init__(self, hidden_size, eps=1e-6):
@@ -452,7 +449,7 @@ class HunyuanTopKGate(nn.Module):
         return gate_output

 class HunyuanMLP(nn.Module):
-    def __init__(self, config, layer_idx=None, is_shared_mlp=False, is_moe=False):
+    def __init__(self, config, layer_idx=None, is_shared_mlp=False, is_moe=False, device=None):
         super().__init__()
         self.config = config
         self.layer_idx = layer_idx
@@ -462,8 +459,8 @@ class HunyuanMLP(nn.Module):
         self.act_fn = torch.nn.functional.silu
         self.intermediate_size *= 2 # SwiGLU
-        self.gate_and_up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
-        self.down_proj = nn.Linear(self.intermediate_size // 2, self.hidden_size, bias=False)
+        self.gate_and_up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False, device=device)
+        self.down_proj = nn.Linear(self.intermediate_size // 2, self.hidden_size, bias=False, device=device)

     def forward(self, x):
         self.gate_and_up_proj, self.down_proj = self.gate_and_up_proj.to(x.device), self.down_proj.to(x.device)
         if x.ndim == 2:
@@ -474,204 +471,143 @@ class HunyuanMLP(nn.Module):
         return down_proj

 class MoELRUCache(nn.Module):
-    def __init__(self, cpu_mem: int = 50, safety_buffer_bytes = 3*(1024**3), max_gpu_eviction_attempts = 8):
+    def __init__(self):
         super().__init__()
-        global CPU_MOE_RATIO
-        _, total = torch.cuda.mem_get_info()
-        max_gpu_mem_gb = max((total - 2 * safety_buffer_bytes) / (1024**3), 1)
-        self.MAX_GPU_MEM = int(max_gpu_mem_gb * 1024**3)
-        self.MAX_CPU_MEM = int(cpu_mem * 1024**3)
         self.gpu_cache = OrderedDict()
         self.cpu_cache = OrderedDict()
-        self.gpu_mem_usage = 0
-        self.cpu_mem_usage = 0
-        # 50% for system and headroom
-        try:
-            self.MAX_CPU_MEM = int((os.sysconf('SC_PAGE_SIZE') * os.sysconf('SC_PHYS_PAGES'))
-                                   - psutil.Process(os.getpid()).memory_info().rss
-                                   - safety_buffer_bytes) * 0.55
-        except:
-            self.MAX_CPU_MEM = int(cpu_mem * (1024**3) * 0.5) # TODO
-        ADDITIONAL_LAYERS_IN_CPU = math.floor((50 * (1024**3)) / MOE_LAYER_SIZE)
-        CPU_MOE_RATIO = (min(64 - ADDITIONAL_LAYERS_IN_GPU, ADDITIONAL_LAYERS_IN_CPU)) / 64
-        self.MAX_GPU_MEM = int(max_gpu_mem_gb * 1024**3)
-        self.SAFETY_BUFFER = int(safety_buffer_bytes)
-        self.MAX_GPU_EVICT_ATTEMPTS = max_gpu_eviction_attempts
-
-    def _gpu_free_bytes(self):
-        free, total = torch.cuda.mem_get_info()
-        return int(free)
-
-    def _estimate_size(self, moe):
-        # include parameters + buffers
-        size = 0
-        for p in moe.parameters():
-            size += p.numel() * p.element_size()
-        for b in moe.buffers():
-            size += b.numel() * b.element_size()
-        return int(size)
-
-    def _evict_until_free(self, required_bytes, max_attempts=16):
-        attempts = 0
-        while self._gpu_free_bytes() < required_bytes and attempts < max_attempts:
-            evicted = self._evict_from_gpu()
-            if not evicted:
-                break
-            attempts += 1
-        return self._gpu_free_bytes() >= required_bytes
-
-    @contextmanager
-    def ensure_headroom(self, required_bytes):
-        safety = getattr(self, "SAFETY_BUFFER", 0)
-        target_free = int(required_bytes + safety)
-        if getattr(self, "_headroom", None) is not None:
-            try:
-                del self._headroom
-            except Exception:
-                pass
-            self._headroom = None
-        ok = self._evict_until_free(target_free)
-        if not ok and self._gpu_free_bytes() < target_free:
-            # last ditch
-            try:
-                torch.cuda.empty_cache()
-            except Exception:
-                pass
-        try:
-            yield
-        finally:
-            if getattr(self, "_headroom", None) is None:
-                try:
-                    self._headroom = torch.empty((self._headroom_bytes,), dtype=torch.uint8, device="cuda:0")
-                except Exception:
-                    self._headroom = None
-
-    def add_gpu(self, moe, index, allowed_retries=3):
-        size = self._estimate_size(moe)
-        while self.gpu_mem_usage + size > self.MAX_GPU_MEM:
-            if not self._evict_from_gpu():
-                break
-        attempts = 0
-        while self._gpu_free_bytes() < size + self.SAFETY_BUFFER and attempts < self.MAX_GPU_EVICT_ATTEMPTS:
-            if not self._evict_from_gpu():
-                break
-            attempts += 1
-        for _ in range(allowed_retries):
-            try:
-                moe_cuda = moe.to("cuda:0")
-                break
-            except RuntimeError as e:
-                if "out of memory" not in str(e).lower():
-                    raise
-                evicted = self._evict_from_gpu()
-                if not evicted: # can't evict
-                    raise
-        else:
-            raise RuntimeError("Failed to move expert to GPU after evictions")
-        self.gpu_cache[index] = moe_cuda
-        self.gpu_cache.move_to_end(index)
-        self.gpu_mem_usage += size
-        return
+        self.offload_stream = torch.cuda.Stream()
+        self.load_stream = torch.cuda.Stream()
+        self.last_offload_event = None
+        self._loop = asyncio.new_event_loop()
+        threading.Thread(target=self._loop.run_forever, daemon=True).start()
+
+    async def _async_offload_to_cpu(self, layer_idx):
+        # async offload from gpu (removed)
+        num_experts = 64
+        moe_group = [(layer_idx * num_experts + i, self.gpu_cache[layer_idx * num_experts + i])
+                     for i in range(num_experts)
+                     if (layer_idx * num_experts + i) in self.gpu_cache]
+        event = torch.cuda.Event()
+        with torch.cuda.stream(self.offload_stream):
+            for index, moe in moe_group:
+                moe_cpu = HunyuanMLP(moe.config).to("cpu", non_blocking=True)
+                for (name, p_gpu), p_cpu in zip(moe.named_parameters(), moe_cpu.parameters()):
+                    if p_gpu.device.type == "meta":
+                        continue
+                    with torch.no_grad():
+                        p_cpu.data = torch.empty_like(p_gpu, device="cpu", pin_memory=True)
+                        p_cpu.copy_(p_gpu, non_blocking=True)
+                self.cpu_cache[index] = moe_cpu
+            self.offload_stream.record_event(event)
+        self.last_offload_event = event
+
+        def finalize_offload_layer():
+            event.synchronize()
+            for index, moe in moe_group:
+                moe.to("meta")
+                self.gpu_cache.pop(index, None)
+                del moe
+            torch.cuda.empty_cache()
+
+        threading.Thread(target=finalize_offload_layer, daemon=True).start()
+
+    async def _async_load_to_gpu(self, index, moe):
+        # if enough memory load, otherwise wait for offload
+        while True:
+            free_bytes, _ = torch.cuda.mem_get_info()
+            if free_bytes > 2 * MOE_LAYER_SIZE:
+                break
+            self.last_offload_event.synchronize()
+            torch.cuda.empty_cache()
+            await asyncio.sleep(0.01)
+
+        # async loading from cpu -> gpu
+        with torch.cuda.stream(self.load_stream):
+            moe_gpu = HunyuanMLP(moe.config).to("cuda", non_blocking=True)
+            for (name, p_cpu), p_gpu in zip(moe.named_parameters(), moe_gpu.parameters()):
+                with torch.no_grad():
+                    p_gpu.data = torch.empty_like(p_cpu, device="cuda")
+                    p_gpu.copy_(p_cpu, non_blocking=True)
+
+        def finalize_load():
+            self.gpu_cache[index] = moe_gpu
+            self.cpu_cache.pop(index, None)
+
+        threading.Thread(target=finalize_load, daemon=True).start()
     def add_cpu(self, moe, index):
-        size = self._estimate_size(moe)
-        while self.cpu_mem_usage + size > self.MAX_CPU_MEM:
-            if not self._evict_from_cpu():
-                break
         moe_cpu = moe.to("cpu")
+        for _, p in moe_cpu.named_parameters():
+            if not p.is_pinned():
+                p.data = p.data.pin_memory()
         self.cpu_cache[index] = moe_cpu
         self.cpu_cache.move_to_end(index)
-        self.cpu_mem_usage += size
-
-    def get_from_device(self, index):
-        if index in self.gpu_cache:
-            moe = self.gpu_cache[index]
-            self.gpu_cache.move_to_end(index)
-            return moe
-        if index in self.cpu_cache:
-            moe = self.cpu_cache.pop(index)
-            self.cpu_mem_usage = max(0, self.cpu_mem_usage - self._estimate_size(moe))
-            try:
-                self.add_gpu(moe, index)
-                return self.gpu_cache[index]
-            except RuntimeError:
-                self.cpu_cache[index] = moe
-                self.cpu_cache.move_to_end(index)
-                self.cpu_mem_usage += self._estimate_size(moe)
-                raise
-        return None # load from disk
-
-    def _evict_from_gpu(self):
-        if not self.gpu_cache:
-            return False
-        idx, moe = self.gpu_cache.popitem(last=False)
-        size = self._estimate_size(moe)
-        self.gpu_mem_usage = max(0, self.gpu_mem_usage - size)
-        if self.cpu_mem_usage + size <= self.MAX_CPU_MEM:
-            try:
-                moe_cpu = moe.to("cpu")
-            except Exception:
-                # drop the model if cpu is full
-                del moe
-                return True
-            self.cpu_cache[idx] = moe_cpu
-            self.cpu_cache.move_to_end(idx)
-            self.cpu_mem_usage += size
-            return True
-        else:
-            del moe
-            return True
-
-    def _evict_from_cpu(self):
-        if not self.cpu_cache:
-            return False
-        _, moe = self.cpu_cache.popitem(last=False)
-        size = self._estimate_size(moe)
-        self.cpu_mem_usage = max(0, self.cpu_mem_usage - size)
-        del moe
-        gc.collect()
-        return True
 class LazyMoELoader(nn.Module):
-    def __init__(self, device):
+    def __init__(self, cache, config):
         super().__init__()
-        self.device = device
+        self.cache = cache
+        self.config = config
+        self._loop = cache._loop

-    def lazy_init(self, config, layer_idx, expert_idx):
+    def get_checkpoint(self):
         comfyui_dir = Path.home() / "ComfyUI"
         checkpoint = comfyui_dir / "models" / "checkpoint" / "hunyuan_image_3.safetensors"
         checkpoint = checkpoint.resolve()
         if not os.path.exists(checkpoint):
             raise ValueError(f"Hunyuan Image 3 Checkpoint on one GPU should have the path: {checkpoint}")
+        return checkpoint
+
+    def lazy_init(self, layer_idx, expert_idx):
+        checkpoint = self.get_checkpoint()
         prefix = f"model.layers.{layer_idx}.mlp.experts.{expert_idx}."
         additional_prefix = f"model.layers.{layer_idx}.mlp.gate_and_up_proj.weight"
         sd = {}
-        with safe_open(checkpoint, framework="pt", device=self.device) as f:
+        with safe_open(checkpoint, framework="pt", device="cpu") as f:
             for k in f.keys():
                 if k.startswith(prefix) or k.startswith(additional_prefix):
                     new_k = k.split(f"experts.{expert_idx}.", 1)[1]
                     sd[new_k] = f.get_tensor(k)
-        return HunyuanMLP(config, layer_idx=layer_idx, is_shared_mlp=False, is_moe=True).load_state_dict(sd).to(self.deivce)
+        return HunyuanMLP(self.config, layer_idx=layer_idx, is_shared_mlp=False, is_moe=True).load_state_dict(sd)
+
+    async def lazy_load_from_disk(self, layer_idx, expert_idx):
+        loop = asyncio.get_event_loop()
+        return await loop.run_in_executor(None, self.lazy_init, layer_idx, expert_idx)
+
+    def _schedule_disk_load(self, layer_idx, expert_idx):
+        coro = self.lazy_load_from_disk(layer_idx, expert_idx)
+        future = asyncio.run_coroutine_threadsafe(coro, self._loop)
+
+        def _on_disk_loaded(fut):
+            moe_cpu = fut.result()
+
+            def _add_cpu_in_main_thread():
+                self.cache.add_cpu(moe_cpu, (layer_idx * 64) + expert_idx)
+                asyncio.run_coroutine_threadsafe(
+                    self.cache._async_load_to_gpu((layer_idx * 64) + expert_idx, moe_cpu),
+                    self.cache._loop
+                )
+
+            threading.Thread(target=_add_cpu_in_main_thread, daemon=True).start()
+
+        future.add_done_callback(_on_disk_loaded)
+        return future
+
+def enough_vram(required_bytes):
+    free, total = torch.cuda.mem_get_info()
+    return free > required_bytes

 class HunyuanMoE(nn.Module):
     def __init__(self, config, layer_idx: Optional[int] = None, moe_lru=None):
@@ -687,7 +623,7 @@ class HunyuanMoE(nn.Module):
                 [HunyuanMLP(config, layer_idx=layer_idx, is_shared_mlp=False, is_moe=True) for _ in range(self.num_experts)]
             )
         else:
-            self.experts = None
+            self.experts = []
         self.moe_lru = moe_lru

     def forward(self, hidden_states):
@@ -702,38 +638,10 @@ class HunyuanMoE(nn.Module):
         reshaped_input = hidden_states.reshape(-1, hidden_size)

         with torch.cuda.nvtx.range("MoE"):
-            expert_weight, expert_index, cpu_expert_index, indices_all = self.gate(hidden_states)
-            if not INIT_MOE:
-                if ADDITIONAL_LAYERS_IN_GPU > 0:
-                    additional_expert_index = indices_all[:, expert_index.size(1): expert_index.size(1) + ADDITIONAL_LAYERS_IN_GPU]
-                    flat = additional_expert_index.reshape(-1).to("cpu")
-                    counts = torch.bincount(flat, minlength=self.num_experts)
-                    top_extra = torch.topk(counts, k=min(ADDITIONAL_LAYERS_IN_GPU, (counts>0).sum().item())).indices.tolist()
-                    for expert_id in top_extra:
-                        if self.moe_lru.get_from_device(expert_id + self.layer_idx) is None:
-                            expert_cpu = LazyMoELoader(device="cpu").lazy_init(self.config, self.layer_idx, expert_id)
-                            self.moe_lru.add_gpu(expert_cpu, expert_id + self.layer_idx)
-                if cpu_expert_index is not None and cpu_expert_index.numel() > 0:
-                    for expert_id in torch.unique(cpu_expert_index).cpu().tolist():
-                        if self.moe_lru.get_from_device(expert_id + self.layer_idx) is None:
-                            expert_cpu = LazyMoELoader(device="cpu").lazy_init(self.config, self.layer_idx, expert_id)
-                            self.moe_lru.add_cpu(expert_cpu, expert_id + self.layer_idx)
+            expert_weight, expert_index = self.gate(hidden_states)
             combined_output = torch.zeros_like(reshaped_input)
-            experts_list = []
-            for e in range(self.num_experts):
-                token_mask = (expert_index == e)
-                if not token_mask.any():
-                    continue
-                expert = self.moe_lru.get_from_device(e + self.layer_idx)
-                if expert is None:
-                    expert = LazyMoELoader()
-                    expert = expert.lazy_init(self.config, self.layer_idx, e)
-                    self.moe_lru.add_gpu(expert, e + self.layer_idx)
-                experts_list.append((e, expert))
+            experts_list = [(i, expert) for i, expert in enumerate(self.experts)]

             per_pos, per_tokens, per_weights = [], [], []
             for e, _ in experts_list:
@@ -761,6 +669,8 @@ class HunyuanMoE(nn.Module):
             l1, l2 = [], []
             for _, expert in experts_list:
+                if isinstance(expert, (asyncio.Future, concurrent.futures.Future)):
+                    expert = expert.result()
                 l1.append(expert.gate_and_up_proj)
                 l2.append(expert.down_proj)
@@ -770,6 +680,12 @@ class HunyuanMoE(nn.Module):
             W1_T = W1.transpose(1, 2)
             W2_T = W2.transpose(1, 2)

+            # wait for enough vram for the computations
+            while not enough_vram(5*(1024 ** 3)):
+                event = self.moe_lru.last_offload_event
+                if event is not None and not event.query():
+                    time.sleep(0.001)
+
             x = torch.bmm(tokens_padded, W1_T)
             x = F.silu(x)
@@ -781,7 +697,7 @@ class HunyuanMoE(nn.Module):
             for i, token_positions in enumerate(per_pos):
                 Ni = lengths[i]
                 out_i = out_padded[i, :Ni]
-                combined_output.index_add_(0, token_positions.to(hidden_states.device), out_i)
+                combined_output.to(hidden_states.dtype).index_add_(0, token_positions.to(hidden_states.device), out_i.to(hidden_states.dtype))

             #dispatched_input = torch.einsum("sec,sm->ecm", dispatch_mask.type_as(hidden_states), reshaped_input)
             #chunks = dispatched_input.chunk(self.num_experts, dim=0)
@@ -933,6 +849,7 @@ class HunyuanImage3Model(nn.Module):
         super().__init__()
         self.padding_idx = 128009
         self.vocab_size = 133120
+        self.config = config
         self.wte = nn.Embedding(133120, config["hidden_size"], self.padding_idx)
         self.layers = nn.ModuleList(
             [HunyuanImage3DecoderLayer(config, layer_idx, moe_lru = moe_lru) for layer_idx in range(config["num_hidden_layers"])]
@@ -941,6 +858,7 @@ class HunyuanImage3Model(nn.Module):
         self.ln_f = HunyuanRMSNorm(config["hidden_size"], eps=config["rms_norm_eps"])
         self.shared_tensor = None
+        self.moe_lru = moe_lru

     def forward(
         self,
@@ -962,19 +880,45 @@ class HunyuanImage3Model(nn.Module):
         hidden_states = inputs_embeds

         next_decoder_cache = None
+        next_layers = 0
+        sparse_interval = max(1, len(self.layers) // 3)
+
+        if len(self.layers[0].mlp.experts) == 0:
+            experts = [LazyMoELoader(self.moe_lru, self.config) for _ in range(64)]
+            self.layers[0].mlp.experts = [expert._schedule_disk_load(0, i) for i, expert in enumerate(experts)]

         for layer_idx, decoder_layer in enumerate(self.layers):
-            layer_outputs = decoder_layer(
-                hidden_states,
-                attention_mask=attention_mask,
-                position_ids=position_ids,
-                past_key_value=past_key_values,
-                use_cache=use_cache,
-                custom_pos_emb=custom_pos_emb,
-                mode=mode,
-                first_step=first_step,
-                gen_timestep_scatter_index=gen_timestep_scatter_index,
-            )
+            if layer_idx + 1 < len(self.layers) and len(self.layers[layer_idx + 1].mlp.experts) == 0: # not loaded
+                experts = [LazyMoELoader(self.moe_lru, self.config) for _ in range(64)]
+                self.layers[layer_idx+1].mlp.experts = [expert._schedule_disk_load(layer_idx+1, i) for i, expert in enumerate(experts)]
+
+            if (layer_idx % sparse_interval == 0) and layer_idx > sparse_interval:
+                if len(self.layers[next_layers].mlp.experts) > 0: # for testing
+                    raise ValueError("Problem with offloading")
+                experts = [LazyMoELoader(self.moe_lru, self.config) for _ in range(64)]
+                self.layers[next_layers].mlp.experts = [expert._schedule_disk_load(next_layers, i) for i, expert in enumerate(experts)]
+                next_layers += 1
+
+            with torch.no_grad():
+                layer_outputs = decoder_layer(
+                    hidden_states,
+                    attention_mask=attention_mask,
+                    position_ids=position_ids,
+                    past_key_value=past_key_values,
+                    use_cache=use_cache,
+                    custom_pos_emb=custom_pos_emb,
+                    mode=mode,
+                    first_step=first_step,
+                    gen_timestep_scatter_index=gen_timestep_scatter_index,
+                )
+
+            if layer_idx >= 0:
+                asyncio.run_coroutine_threadsafe(
+                    self.moe_lru._async_offload_to_cpu(layer_idx),
+                    self.moe_lru._loop
+                )
+                self.layers[layer_idx].mlp.experts = []
+
             hidden_states = layer_outputs[0]