decrease peak memory in moe forward

Author: Yousef Rafat
Date: 2025-11-22 23:02:42 +02:00
Parent: ae8592ebf5
Commit: a3ac798d4e


@@ -2,7 +2,6 @@ import os
 import math
 import time
 import torch
-import psutil
 import asyncio
 import threading
 import torch.nn as nn
@@ -13,12 +12,13 @@ import torch.nn.functional as F
 from collections import OrderedDict
 from safetensors import safe_open
 from transformers.cache_utils import StaticCache
+from concurrent.futures import ThreadPoolExecutor
 from typing import Optional, Tuple, Any, List, Dict
 from comfy.ldm.modules.attention import optimized_attention
 from comfy.ldm.modules.diffusionmodules.openaimodel import ResBlock
 
 INIT_MOE = torch.cuda.device_count() != 1
-MOE_LAYER_SIZE = (1024**3) * 2.65 # approx
+MOE_LAYER_SIZE = (1024**3) * 5.15 # approx
 
 
 class HunyuanStaticCache(StaticCache):
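MOE_LAYER_SIZE is an approximate per-layer byte footprint for the expert weights, raised here from roughly 2.65 GiB to 5.15 GiB. A constant like this is typically compared against free VRAM before more expert weights are staged; the sketch below mirrors the `enough_vram` helper that appears later in this diff, and the call site is illustrative only, not code from this file.

```python
# Hedged sketch: how an approximate per-layer size estimate can gate prefetching.
# `enough_vram` mirrors the helper shown later in this diff; the return expression
# and the call site are assumptions for illustration.
import torch

MOE_LAYER_SIZE = (1024**3) * 5.15  # approx bytes occupied by one layer's experts

def enough_vram(required_bytes):
    free, total = torch.cuda.mem_get_info()  # bytes free / total on the current device
    return free > required_bytes

if enough_vram(MOE_LAYER_SIZE):
    pass  # safe to schedule another layer's expert weights onto the GPU
```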
@@ -576,7 +576,7 @@ def parse_layer_expert(key):
     return layer, expert
 
 class LazyMoELoader(nn.Module):
-    def __init__(self, cache, config):
+    def __init__(self, cache, config, max_workers = 16, max_concurrent_loads = 32):
         super().__init__()
         self.cache = cache
         self.config = config
@@ -586,6 +586,9 @@ class LazyMoELoader(nn.Module):
         self._file = safe_open(self._checkpoint, framework="pt", device="cpu", mmap=True)
         self.expert_pool = self.build_meta_experts()
+
+        self._executor = ThreadPoolExecutor(max_workers=max_workers)
+        self._semaphore = threading.Semaphore(max_concurrent_loads)
 
     def build_meta_experts(self):
         pool = {}
         for layer, experts in self.expert_key_index.items():
@@ -629,28 +632,17 @@ class LazyMoELoader(nn.Module):
             getattr(model, name).data = tensor
         return model
 
-    async def lazy_load_from_disk(self, layer_idx, expert_idx):
-        loop = asyncio.get_event_loop()
-        return await loop.run_in_executor(None, self.lazy_init, layer_idx, expert_idx)
-
-    def _schedule_disk_load(self, layer_idx, expert_idx):
-        coro = self.lazy_load_from_disk(layer_idx, expert_idx)
-        future = asyncio.run_coroutine_threadsafe(coro, self._loop)
-
-        def _on_disk_loaded(fut):
-            moe_cpu = fut.result()
-            def _add_cpu_in_main_thread():
-                self.cache.add_cpu(moe_cpu, (layer_idx * 64) + expert_idx)
-                asyncio.run_coroutine_threadsafe(
-                    self.cache._async_load_to_gpu((layer_idx * 64) + expert_idx, moe_cpu),
-                    self.cache._loop
-                )
-            threading.Thread(target=_add_cpu_in_main_thread, daemon=True).start()
-
-        future.add_done_callback(_on_disk_loaded)
-        return future
+    def _load_single_expert(self, layer_idx, expert_idx):
+        with self._semaphore:
+            return self.lazy_init(layer_idx, expert_idx)
+
+    def schedule_layer_load(self, layer_idx, num_experts = 64):
+        futures = []
+        for i in range(num_experts):
+            coro = asyncio.get_event_loop().run_in_executor(self._executor, self._load_single_expert, layer_idx, i)
+            fut = asyncio.run_coroutine_threadsafe(coro, self._loop)
+            futures.append(fut)
+        return futures
 
 def enough_vram(required_bytes):
     free, total = torch.cuda.mem_get_info()
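The per-expert async loaders above are replaced by one shared ThreadPoolExecutor plus a semaphore that caps how many expert reads run at once, with `schedule_layer_load` returning one future per expert. Below is a self-contained sketch of the same bounded-concurrency pattern; it submits work to the executor directly instead of routing through an asyncio loop as the diff does, and `load_expert` is a hypothetical stand-in for `LazyMoELoader.lazy_init`.

```python
# Sketch of the bounded-concurrency loading pattern (simplified: no asyncio loop).
# `load_expert` is a hypothetical stand-in for LazyMoELoader.lazy_init.
import threading
from concurrent.futures import ThreadPoolExecutor

def load_expert(layer_idx, expert_idx):
    ...  # read one expert's tensors from disk (placeholder)

class BoundedExpertLoader:
    def __init__(self, max_workers=16, max_concurrent_loads=32):
        self._executor = ThreadPoolExecutor(max_workers=max_workers)
        self._semaphore = threading.Semaphore(max_concurrent_loads)

    def _load_single_expert(self, layer_idx, expert_idx):
        with self._semaphore:  # at most max_concurrent_loads reads in flight
            return load_expert(layer_idx, expert_idx)

    def schedule_layer_load(self, layer_idx, num_experts=64):
        # One future per expert; callers resolve them later with .result().
        return [self._executor.submit(self._load_single_expert, layer_idx, i)
                for i in range(num_experts)]
```

The semaphore keeps disk and host-memory pressure bounded even when all 64 experts of a layer are scheduled at once.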
@@ -700,16 +692,7 @@ class HunyuanMoE(nn.Module):
                 pass
             else:
                 tokens_padded = dispatched_input[used_indices]
-
-                l1_layers, l2_layers = [], []
-                for i in used_indices:
-                    expert = self.experts[i]
-                    if isinstance(expert, (asyncio.Future, concurrent.futures.Future)):
-                        expert = expert.result()
-                    expert = expert.to(device)
-                    l1_layers.append(expert.gate_and_up_proj)
-                    l2_layers.append(expert.down_proj)
 
                 l1, l2 = [], []
                 for i in used_indices:
                     expert = self.experts[i]
@@ -720,24 +703,32 @@ class HunyuanMoE(nn.Module):
                     l2.append(expert.down_proj)
 
                 compute_device = hidden_states.device
-                l1 = [m.to(compute_device) for m in l1]
-                l2 = [m.to(compute_device) for m in l2]
 
+                l1 = [m.to(compute_device) for m in l1]
                 W1 = torch.stack([m.weight for m in l1], dim=0)
-                W2 = torch.stack([m.weight for m in l2], dim=0)
+                del l1
                 W1_T = W1.transpose(1, 2)
+                del W1
+
+                x = torch.bmm(tokens_padded, W1_T)
+                del W1_T, tokens_padded
+                x1, x2 = x.chunk(2, dim=2)
+                gated = x1 * F.silu(x2)
+
+                l2 = [m.to(compute_device) for m in l2]
+                W2 = torch.stack([m.weight for m in l2], dim=0)
+                del l2
                 W2_T = W2.transpose(1, 2)
+                del W2
+                out_padded = torch.bmm(gated, W2_T)
+                del W2_T
 
                 while not enough_vram(3*(1024 ** 3)):
                     event = self.moe_lru.last_offload_event
                     if event is not None and not event.query():
                         time.sleep(0.001)
 
-                x = torch.bmm(tokens_padded, W1_T)
-                x1, x2 = x.chunk(2, dim=2)
-                gated = x1 * F.silu(x2)
-                out_padded = torch.bmm(gated, W2_T)
-
                 combine_weights_used = combine_weights[:, used_indices, :]
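The reordered block above is where the peak-memory drop comes from: instead of materializing W1, W2, and both transposes before the two `bmm` calls, the new code stages the gate/up weights, consumes them, and frees them before the down-projection weights are ever stacked. A shape-only sketch of that staging pattern, with made-up sizes and none of the caching or VRAM-wait logic:

```python
# Shape-only sketch: build, use, and free one stacked weight group before
# materializing the next, rather than holding both at once. Sizes are placeholders.
import torch
import torch.nn.functional as F

E, T, H, I = 4, 8, 64, 128                             # experts, tokens/expert, hidden, intermediate
tokens_padded = torch.randn(E, T, H)
gate_up = [torch.randn(2 * I, H) for _ in range(E)]    # per-expert gate_and_up_proj weights
down = [torch.randn(H, I) for _ in range(E)]           # per-expert down_proj weights

W1_T = torch.stack(gate_up).transpose(1, 2)            # (E, H, 2I)
x = torch.bmm(tokens_padded, W1_T)                     # (E, T, 2I)
del W1_T, tokens_padded                                # release before touching the down weights
x1, x2 = x.chunk(2, dim=2)
gated = x1 * F.silu(x2)                                # (E, T, I)

W2_T = torch.stack(down).transpose(1, 2)               # (E, I, H)
out_padded = torch.bmm(gated, W2_T)                    # (E, T, H)
del W2_T
```

Since `transpose` returns a view, the stacked storage is actually released once its last reference is dropped, so freeing eagerly right after each `bmm` is what keeps the two stacked weight tensors from coexisting on the GPU.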
@@ -746,7 +737,7 @@ class HunyuanMoE(nn.Module):
                     out_padded
                 )
 
-                del tokens_padded, W1, W2, W1_T, W2_T, x, x1, x2, gated, out_padded
+                del x, x1, x2, gated, out_padded
 
                 combined_output = combined_output.reshape(bsz, seq_len, hidden_size)
@@ -899,7 +890,8 @@ class HunyuanImage3Model(nn.Module):
         self.shared_tensor = None
         self.moe_lru = moe_lru
-        self.self.additional_layers_set = False
+        self.additional_layers_set = False
+        self.moe_loader = LazyMoELoader(self.moe_lru, self.config)
 
     def forward(
         self,
@@ -922,19 +914,19 @@ class HunyuanImage3Model(nn.Module):
         sparse_interval = max(1, len(self.layers) // additional_layers)
 
         if len(self.layers[0].mlp.experts) == 0:
-            experts = [LazyMoELoader(self.moe_lru, self.config) for _ in range(64)]
-            self.layers[0].mlp.experts = [expert._schedule_disk_load(0, i) for i, expert in enumerate(experts)]
+            self.layers[0].mlp.experts = self.moe_loader.schedule_layer_load(0)
 
         for layer_idx, decoder_layer in enumerate(self.layers):
             if layer_idx + 1 < len(self.layers) and len(self.layers[layer_idx + 1].mlp.experts) == 0: # not loaded
-                experts = [LazyMoELoader(self.moe_lru, self.config) for _ in range(64)]
-                self.layers[layer_idx+1].mlp.experts = [expert._schedule_disk_load(layer_idx+1, i) for i, expert in enumerate(experts)]
+                self.layers[layer_idx+1].mlp.experts = self.moe_loader.schedule_layer_load(layer_idx + 1)
+
+            if layer_idx + 2 < len(self.layers) and len(self.layers[layer_idx + 2].mlp.experts) == 0: # load first and second layers
+                self.layers[layer_idx+2].mlp.experts = self.moe_loader.schedule_layer_load(layer_idx + 2)
 
             if not self.additional_layers_set:
                 if (layer_idx % sparse_interval == 0) and layer_idx >= sparse_interval:
-                    experts = [LazyMoELoader(self.moe_lru, self.config) for _ in range(64)]
-                    self.layers[next_layers].mlp.experts = [expert._schedule_disk_load(next_layers, i) for i, expert in enumerate(experts)]
+                    self.layers[next_layers].mlp.experts = self.moe_loader.schedule_layer_load(next_layers)
                     next_layers += 1
 
             with torch.no_grad():
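With a single `moe_loader` owned by the model, the forward pass now only asks it to prefetch: before running layer i it makes sure layers i+1 and i+2 already have their expert-load futures scheduled, instead of constructing 64 fresh LazyMoELoader objects per layer. A small illustrative sketch of that look-ahead loop; `layers`, `run_layer`, and `hidden` are stand-ins, and only the `schedule_layer_load` call follows the diff.

```python
# Illustrative look-ahead prefetch over decoder layers: keep the next two layers'
# expert-load futures in flight while the current layer computes.
# `layers`, `run_layer`, and `hidden` are stand-ins, not names from the file.
def forward_with_prefetch(layers, moe_loader, hidden, run_layer):
    if len(layers[0].mlp.experts) == 0:
        layers[0].mlp.experts = moe_loader.schedule_layer_load(0)

    for i, layer in enumerate(layers):
        for ahead in (1, 2):
            j = i + ahead
            if j < len(layers) and len(layers[j].mlp.experts) == 0:
                layers[j].mlp.experts = moe_loader.schedule_layer_load(j)
        hidden = run_layer(layer, hidden)  # placeholder for the real decoder-layer call
    return hidden
```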