Mirror of https://github.com/comfyanonymous/ComfyUI.git, synced 2026-01-11 14:50:49 +08:00
decrease peak memory in moe forward

parent ae8592ebf5
commit a3ac798d4e
@@ -2,7 +2,6 @@ import os
 import math
 import time
 import torch
 import psutil
 import asyncio
 import threading
 import torch.nn as nn
@@ -13,12 +12,13 @@ import torch.nn.functional as F
 from collections import OrderedDict
 from safetensors import safe_open
 from transformers.cache_utils import StaticCache
+from concurrent.futures import ThreadPoolExecutor
 from typing import Optional, Tuple, Any, List, Dict
 from comfy.ldm.modules.attention import optimized_attention
 from comfy.ldm.modules.diffusionmodules.openaimodel import ResBlock
 
 INIT_MOE = torch.cuda.device_count() != 1
-MOE_LAYER_SIZE = (1024**3) * 2.65 # approx
+MOE_LAYER_SIZE = (1024**3) * 5.15 # approx
 
 class HunyuanStaticCache(StaticCache):
 
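Note: the per-layer expert payload estimate roughly doubles here (≈2.65 GiB to ≈5.15 GiB). This hunk does not show how the constant is consumed; below is only a sketch of the budgeting role such a constant typically plays. The helper name and the 3 GiB activation reserve are assumptions, the reserve chosen to mirror the `enough_vram(3*(1024 ** 3))` gate later in this commit.

```python
import torch

GiB = 1024 ** 3
MOE_LAYER_SIZE = GiB * 5.15  # approx bytes of expert weights per MoE layer

def prefetchable_layers(reserve_bytes=3 * GiB):
    # Hypothetical helper: how many layers' experts fit in free VRAM
    # after keeping a reserve for activations.
    free, _total = torch.cuda.mem_get_info()
    return max(0, int((free - reserve_bytes) // MOE_LAYER_SIZE))
```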
@@ -576,7 +576,7 @@ def parse_layer_expert(key):
     return layer, expert
 
 class LazyMoELoader(nn.Module):
-    def __init__(self, cache, config):
+    def __init__(self, cache, config, max_workers = 16, max_concurrent_loads = 32):
        super().__init__()
        self.cache = cache
        self.config = config
@@ -586,6 +586,9 @@ class LazyMoELoader(nn.Module):
         self._file = safe_open(self._checkpoint, framework="pt", device="cpu", mmap=True)
         self.expert_pool = self.build_meta_experts()
 
+        self._executor = ThreadPoolExecutor(max_workers=max_workers)
+        self._semaphore = threading.Semaphore(max_concurrent_loads)
+
     def build_meta_experts(self):
         pool = {}
         for layer, experts in self.expert_key_index.items():
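Note: the loader now owns one thread pool plus a semaphore that bounds in-flight expert reads. A minimal stand-alone sketch of the same gating pattern (stdlib only; the sleep stands in for a safetensors read):

```python
import threading
import time
from concurrent.futures import ThreadPoolExecutor

executor = ThreadPoolExecutor(max_workers=16)
semaphore = threading.Semaphore(32)

def load_one(idx):
    with semaphore:          # caps concurrent loads independently of pool size
        time.sleep(0.01)     # stand-in for one expert's safetensors read
        return idx

futures = [executor.submit(load_one, i) for i in range(64)]
print([f.result() for f in futures])
```

With the defaults above, the pool is the binding limit (16 workers < 32 permits), so the semaphore only starts to matter if the worker count is raised or the semaphore is shared more widely.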
@@ -629,28 +632,17 @@ class LazyMoELoader(nn.Module):
             getattr(model, name).data = tensor
         return model
 
-    async def lazy_load_from_disk(self, layer_idx, expert_idx):
-        loop = asyncio.get_event_loop()
-        return await loop.run_in_executor(None, self.lazy_init, layer_idx, expert_idx)
-
-    def _schedule_disk_load(self, layer_idx, expert_idx):
+    def _load_single_expert(self, layer_idx, expert_idx):
+        with self._semaphore:
+            return self.lazy_init(layer_idx, expert_idx)
 
-        coro = self.lazy_load_from_disk(layer_idx, expert_idx)
-        future = asyncio.run_coroutine_threadsafe(coro, self._loop)
-
-        def _on_disk_loaded(fut):
-            moe_cpu = fut.result()
-            def _add_cpu_in_main_thread():
-                self.cache.add_cpu(moe_cpu, (layer_idx * 64) + expert_idx)
-
-                asyncio.run_coroutine_threadsafe(
-                    self.cache._async_load_to_gpu((layer_idx * 64) + expert_idx, moe_cpu),
-                    self.cache._loop
-                )
-            threading.Thread(target=_add_cpu_in_main_thread, daemon=True).start()
-
-        future.add_done_callback(_on_disk_loaded)
-        return future
+    def schedule_layer_load(self, layer_idx, num_experts = 64):
+        futures = []
+        for i in range(num_experts):
+            coro = asyncio.get_event_loop().run_in_executor(self._executor, self._load_single_expert, layer_idx, i)
+            fut = asyncio.run_coroutine_threadsafe(coro, self._loop)
+            futures.append(fut)
+        return futures
 
 def enough_vram(required_bytes):
     free, total = torch.cuda.mem_get_info()
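Note: the per-expert async chain (`lazy_load_from_disk` → `_schedule_disk_load` → nested callbacks plus a daemon thread per expert) collapses into one batched call per layer. The diff still marshals each load through a private event loop (`self._loop`, created elsewhere in the class and not shown here); a simplified sketch of just the scheduling shape, assuming plain `concurrent.futures` suffices:

```python
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Callable, List

def schedule_layer_load(executor: ThreadPoolExecutor,
                        load_fn: Callable[[int, int], object],
                        layer_idx: int, num_experts: int = 64) -> List[Future]:
    # One future per expert; callers keep the futures as placeholders
    # and call .result() only when an expert is actually routed to.
    return [executor.submit(load_fn, layer_idx, i) for i in range(num_experts)]
```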
@@ -700,16 +692,7 @@ class HunyuanMoE(nn.Module):
             pass
         else:
             tokens_padded = dispatched_input[used_indices]
-
-            l1_layers, l2_layers = [], []
-            for i in used_indices:
-                expert = self.experts[i]
-                if isinstance(expert, (asyncio.Future, concurrent.futures.Future)):
-                    expert = expert.result()
-                expert = expert.to(device)
-                l1_layers.append(expert.gate_and_up_proj)
-                l2_layers.append(expert.down_proj)
-
+
             l1, l2 = [], []
             for i in used_indices:
                 expert = self.experts[i]
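Note: the deleted block duplicated the surviving `l1`/`l2` loop but additionally resolved every pending future and called `.to(device)` up front, putting all touched experts on the GPU before any VRAM gate ran; dropping it defers the device transfer to the staged path in the next hunk. The placeholder-resolution idiom both loops rely on, as a toy:

```python
import concurrent.futures as cf

def resolve(slot):
    # Expert slots hold either a loaded module or a still-pending Future.
    return slot.result() if isinstance(slot, cf.Future) else slot

with cf.ThreadPoolExecutor() as pool:
    slots = [pool.submit(lambda i=i: f"expert-{i}") for i in range(4)]
    print([resolve(s) for s in slots])  # blocks only on unfinished loads
```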
@@ -720,24 +703,32 @@ class HunyuanMoE(nn.Module):
                 l2.append(expert.down_proj)
 
             compute_device = hidden_states.device
-            l1 = [m.to(compute_device) for m in l1]
-            l2 = [m.to(compute_device) for m in l2]
-
+            l1 = [m.to(compute_device) for m in l1]
             W1 = torch.stack([m.weight for m in l1], dim=0)
-            W2 = torch.stack([m.weight for m in l2], dim=0)
-
+            del l1
             W1_T = W1.transpose(1, 2)
-
+            del W1
+            x = torch.bmm(tokens_padded, W1_T)
+            del W1_T, tokens_padded
+
+            x1, x2 = x.chunk(2, dim=2)
+            gated = x1 * F.silu(x2)
+
+            l2 = [m.to(compute_device) for m in l2]
+            W2 = torch.stack([m.weight for m in l2], dim=0)
+            del l2
             W2_T = W2.transpose(1, 2)
+            del W2
+            out_padded = torch.bmm(gated, W2_T)
+            del W2_T
 
             while not enough_vram(3*(1024 ** 3)):
                 event = self.moe_lru.last_offload_event
                 if event is not None and not event.query():
                     time.sleep(0.001)
 
-            x = torch.bmm(tokens_padded, W1_T)
-            x1, x2 = x.chunk(2, dim=2)
-            gated = x1 * F.silu(x2)
-            out_padded = torch.bmm(gated, W2_T)
 
             combine_weights_used = combine_weights[:, used_indices, :]
 
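Note: this hunk is the change the commit message names. Before, `W1`, `W2`, both transposes, `tokens_padded`, and all intermediates were alive simultaneously; now each stacked weight is dropped before the next large allocation, so the weight contribution to peak memory falls from roughly |W1| + |W2| to max(|W1|, |W2|). A self-contained sketch of the pattern (names and shapes invented: `E` experts, `T` tokens per expert, hidden size `H`, FFN size `FF`):

```python
import torch
import torch.nn.functional as F

def staged_expert_mlp(tokens, w1_weights, w2_weights):
    # tokens: (E, T, H); w1_weights: E tensors of (2*FF, H);
    # w2_weights: E tensors of (H, FF). SwiGLU-style gated MLP.
    W1 = torch.stack(w1_weights, dim=0)           # stage 1: up-proj weights resident
    x = torch.bmm(tokens, W1.transpose(1, 2))     # (E, T, 2*FF)
    del W1                                        # last reference -> block reusable
    x1, x2 = x.chunk(2, dim=2)
    gated = x1 * F.silu(x2)                       # (E, T, FF)
    del x, x1, x2
    W2 = torch.stack(w2_weights, dim=0)           # stage 2: only now the down-proj
    out = torch.bmm(gated, W2.transpose(1, 2))    # (E, T, H)
    del W2, gated
    return out
```

One subtlety the diff handles correctly: `transpose` returns a view, so `del W1` alone does not release the storage; it is the later `del W1_T, tokens_padded` after the first `bmm` that drops the last reference and lets PyTorch's caching allocator reuse the block for `W2`.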
@@ -746,7 +737,7 @@ class HunyuanMoE(nn.Module):
                 out_padded
             )
 
-            del tokens_padded, W1, W2, W1_T, W2_T, x, x1, x2, gated, out_padded
+            del x, x1, x2, gated, out_padded
 
             combined_output = combined_output.reshape(bsz, seq_len, hidden_size)
 
@@ -899,7 +890,8 @@ class HunyuanImage3Model(nn.Module):
 
         self.shared_tensor = None
         self.moe_lru = moe_lru
-        self.self.additional_layers_set = False
+        self.additional_layers_set = False
+        self.moe_loader = LazyMoELoader(self.moe_lru, self.config)
 
     def forward(
         self,
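Note: a single `LazyMoELoader` now lives on the model and is reused for every layer; the code it replaces in the next hunk constructed 64 throwaway loaders per layer, each reopening the checkpoint via `safe_open` and rebuilding `expert_pool`, just to schedule that layer's loads. The shape of the fix, as a generic sketch:

```python
from concurrent.futures import ThreadPoolExecutor

class SharedLoader:
    # Heavy setup (file handle, thread pool, ...) happens once at model init...
    def __init__(self, max_workers=16):
        self.pool = ThreadPoolExecutor(max_workers=max_workers)

    # ...and per-layer scheduling reuses it instead of rebuilding state.
    def schedule_layer_load(self, load_fn, layer_idx, num_experts=64):
        return [self.pool.submit(load_fn, layer_idx, i)
                for i in range(num_experts)]
```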
@@ -922,19 +914,19 @@ class HunyuanImage3Model(nn.Module):
         sparse_interval = max(1, len(self.layers) // additional_layers)
 
         if len(self.layers[0].mlp.experts) == 0:
-            experts = [LazyMoELoader(self.moe_lru, self.config) for _ in range(64)]
-            self.layers[0].mlp.experts = [expert._schedule_disk_load(0, i) for i, expert in enumerate(experts)]
+            self.layers[0].mlp.experts = self.moe_loader.schedule_layer_load(0)
 
         for layer_idx, decoder_layer in enumerate(self.layers):
 
             if layer_idx + 1 < len(self.layers) and len(self.layers[layer_idx + 1].mlp.experts) == 0: # not loaded
-                experts = [LazyMoELoader(self.moe_lru, self.config) for _ in range(64)]
-                self.layers[layer_idx+1].mlp.experts = [expert._schedule_disk_load(layer_idx+1, i) for i, expert in enumerate(experts)]
+                self.layers[layer_idx+1].mlp.experts = self.moe_loader.schedule_layer_load(layer_idx + 1)
+
+            if layer_idx + 2 < len(self.layers) and len(self.layers[layer_idx + 2].mlp.experts) == 0: # load first and second layers
+                self.layers[layer_idx+2].mlp.experts = self.moe_loader.schedule_layer_load(layer_idx + 2)
 
             if not self.additional_layers_set:
                 if (layer_idx % sparse_interval == 0) and layer_idx >= sparse_interval:
-                    experts = [LazyMoELoader(self.moe_lru, self.config) for _ in range(64)]
-                    self.layers[next_layers].mlp.experts = [expert._schedule_disk_load(next_layers, i) for i, expert in enumerate(experts)]
+                    self.layers[next_layers].mlp.experts = self.moe_loader.schedule_layer_load(next_layers)
                     next_layers += 1
 
         with torch.no_grad():
 
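Note: with the shared loader, the forward pass keeps a two-layer lookahead window: before layer `i` runs, layers `i+1` and `i+2` are guaranteed to have loads in flight, so disk reads overlap compute instead of stalling each layer. A toy version of the same loop (stdlib only; the sleep stands in for reading one layer's experts):

```python
import time
from concurrent.futures import ThreadPoolExecutor

NUM_LAYERS, LOOKAHEAD = 8, 2
pool = ThreadPoolExecutor(max_workers=4)
pending = {}

def load_layer(i):
    time.sleep(0.05)                   # stand-in for one layer's expert reads
    return f"weights[{i}]"

for i in range(NUM_LAYERS):
    for j in range(i, min(i + LOOKAHEAD + 1, NUM_LAYERS)):
        if j not in pending:           # schedule i, i+1, i+2 exactly once
            pending[j] = pool.submit(load_layer, j)
    weights = pending.pop(i).result()  # blocks only if prefetch fell behind
    # ... run layer i with `weights` ...
pool.shutdown()
```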