decrease peak memory in moe forward

This commit is contained in:
Yousef Rafat 2025-11-22 23:02:42 +02:00
parent ae8592ebf5
commit a3ac798d4e

View File

@ -2,7 +2,6 @@ import os
import math
import time
import torch
import psutil
import asyncio
import threading
import torch.nn as nn
@ -13,12 +12,13 @@ import torch.nn.functional as F
from collections import OrderedDict
from safetensors import safe_open
from transformers.cache_utils import StaticCache
from concurrent.futures import ThreadPoolExecutor
from typing import Optional, Tuple, Any, List, Dict
from comfy.ldm.modules.attention import optimized_attention
from comfy.ldm.modules.diffusionmodules.openaimodel import ResBlock
INIT_MOE = torch.cuda.device_count() != 1
MOE_LAYER_SIZE = (1024**3) * 2.65 # approx
MOE_LAYER_SIZE = (1024**3) * 5.15 # approx
class HunyuanStaticCache(StaticCache):
@ -576,7 +576,7 @@ def parse_layer_expert(key):
return layer, expert
class LazyMoELoader(nn.Module):
def __init__(self, cache, config):
def __init__(self, cache, config, max_workers = 16, max_concurrent_loads = 32):
super().__init__()
self.cache = cache
self.config = config
@ -586,6 +586,9 @@ class LazyMoELoader(nn.Module):
self._file = safe_open(self._checkpoint, framework="pt", device="cpu", mmap=True)
self.expert_pool = self.build_meta_experts()
self._executor = ThreadPoolExecutor(max_workers=max_workers)
self._semaphore = threading.Semaphore(max_concurrent_loads)
def build_meta_experts(self):
pool = {}
for layer, experts in self.expert_key_index.items():
@ -629,28 +632,17 @@ class LazyMoELoader(nn.Module):
getattr(model, name).data = tensor
return model
async def lazy_load_from_disk(self, layer_idx, expert_idx):
loop = asyncio.get_event_loop()
return await loop.run_in_executor(None, self.lazy_init, layer_idx, expert_idx)
def _schedule_disk_load(self, layer_idx, expert_idx):
def _load_single_expert(self, layer_idx, expert_idx):
with self._semaphore:
return self.lazy_init(layer_idx, expert_idx)
coro = self.lazy_load_from_disk(layer_idx, expert_idx)
future = asyncio.run_coroutine_threadsafe(coro, self._loop)
def _on_disk_loaded(fut):
moe_cpu = fut.result()
def _add_cpu_in_main_thread():
self.cache.add_cpu(moe_cpu, (layer_idx * 64) + expert_idx)
asyncio.run_coroutine_threadsafe(
self.cache._async_load_to_gpu((layer_idx * 64) + expert_idx, moe_cpu),
self.cache._loop
)
threading.Thread(target=_add_cpu_in_main_thread, daemon=True).start()
future.add_done_callback(_on_disk_loaded)
return future
def schedule_layer_load(self, layer_idx, num_experts = 64):
futures = []
for i in range(num_experts):
coro = asyncio.get_event_loop().run_in_executor(self._executor, self._load_single_expert, layer_idx, i)
fut = asyncio.run_coroutine_threadsafe(coro, self._loop)
futures.append(fut)
return futures
def enough_vram(required_bytes):
free, total = torch.cuda.mem_get_info()
@ -700,16 +692,7 @@ class HunyuanMoE(nn.Module):
pass
else:
tokens_padded = dispatched_input[used_indices]
l1_layers, l2_layers = [], []
for i in used_indices:
expert = self.experts[i]
if isinstance(expert, (asyncio.Future, concurrent.futures.Future)):
expert = expert.result()
expert = expert.to(device)
l1_layers.append(expert.gate_and_up_proj)
l2_layers.append(expert.down_proj)
l1, l2 = [], []
for i in used_indices:
expert = self.experts[i]
@ -720,24 +703,32 @@ class HunyuanMoE(nn.Module):
l2.append(expert.down_proj)
compute_device = hidden_states.device
l1 = [m.to(compute_device) for m in l1]
l2 = [m.to(compute_device) for m in l2]
l1 = [m.to(compute_device) for m in l1]
W1 = torch.stack([m.weight for m in l1], dim=0)
W2 = torch.stack([m.weight for m in l2], dim=0)
del l1
W1_T = W1.transpose(1, 2)
del W1
x = torch.bmm(tokens_padded, W1_T)
del W1_T, tokens_padded
x1, x2 = x.chunk(2, dim=2)
gated = x1 * F.silu(x2)
l2 = [m.to(compute_device) for m in l2]
W2 = torch.stack([m.weight for m in l2], dim=0)
del l2
W2_T = W2.transpose(1, 2)
del W2
out_padded = torch.bmm(gated, W2_T)
del W2_T
while not enough_vram(3*(1024 ** 3)):
event = self.moe_lru.last_offload_event
if event is not None and not event.query():
time.sleep(0.001)
x = torch.bmm(tokens_padded, W1_T)
x1, x2 = x.chunk(2, dim=2)
gated = x1 * F.silu(x2)
out_padded = torch.bmm(gated, W2_T)
combine_weights_used = combine_weights[:, used_indices, :]
@ -746,7 +737,7 @@ class HunyuanMoE(nn.Module):
out_padded
)
del tokens_padded, W1, W2, W1_T, W2_T, x, x1, x2, gated, out_padded
del x, x1, x2, gated, out_padded
combined_output = combined_output.reshape(bsz, seq_len, hidden_size)
@ -899,7 +890,8 @@ class HunyuanImage3Model(nn.Module):
self.shared_tensor = None
self.moe_lru = moe_lru
self.self.additional_layers_set = False
self.additional_layers_set = False
self.moe_loader = LazyMoELoader(self.moe_lru, self.config)
def forward(
self,
@ -922,19 +914,19 @@ class HunyuanImage3Model(nn.Module):
sparse_interval = max(1, len(self.layers) // additional_layers)
if len(self.layers[0].mlp.experts) == 0:
experts = [LazyMoELoader(self.moe_lru, self.config) for _ in range(64)]
self.layers[0].mlp.experts = [expert._schedule_disk_load(0, i) for i, expert in enumerate(experts)]
self.layers[0].mlp.experts = self.moe_loader.schedule_layer_load(0)
for layer_idx, decoder_layer in enumerate(self.layers):
if layer_idx + 1 < len(self.layers) and len(self.layers[layer_idx + 1].mlp.experts) == 0: # not loaded
experts = [LazyMoELoader(self.moe_lru, self.config) for _ in range(64)]
self.layers[layer_idx+1].mlp.experts = [expert._schedule_disk_load(layer_idx+1, i) for i, expert in enumerate(experts)]
self.layers[layer_idx+1].mlp.experts = self.moe_loader.schedule_layer_load(layer_idx + 1)
if layer_idx + 2 < len(self.layers) and len(self.layers[layer_idx + 2].mlp.experts) == 0: # load first and second layers
self.layers[layer_idx+2].mlp.experts = self.moe_loader.schedule_layer_load(layer_idx + 2)
if not self.additional_layers_set:
if (layer_idx % sparse_interval == 0) and layer_idx >= sparse_interval:
experts = [LazyMoELoader(self.moe_lru, self.config) for _ in range(64)]
self.layers[next_layers].mlp.experts = [expert._schedule_disk_load(next_layers, i) for i, expert in enumerate(experts)]
self.layers[next_layers].mlp.experts = self.moe_loader.schedule_layer_load(next_layers)
next_layers += 1
with torch.no_grad():