Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-01-20 19:30:20 +08:00)

commit a3ac798d4e (parent ae8592ebf5)

    decrease peak memory in moe forward
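The peak-memory reduction in the MoE forward comes from running the two expert matmuls in sequence and dropping each stacked weight copy as soon as it has been used, instead of materializing W1, W2, the token batch and every intermediate at the same time. A minimal sketch of that ordering follows (illustrative names and shapes only, not the actual ComfyUI code):

    import torch
    import torch.nn.functional as F

    def fused_expert_ffn(tokens, l1_weights, l2_weights):
        # tokens: (num_experts, capacity, hidden)
        # l1_weights[i]: (2 * intermediate, hidden), l2_weights[i]: (hidden, intermediate)
        W1 = torch.stack(l1_weights, dim=0)   # (E, 2*inter, hidden) copy of the gate/up weights
        W1_T = W1.transpose(1, 2)             # a view; shares storage with W1
        del W1
        x = torch.bmm(tokens, W1_T)           # (E, capacity, 2*inter)
        del W1_T, tokens                      # the stacked gate/up copy can be freed here

        x1, x2 = x.chunk(2, dim=2)
        gated = x1 * F.silu(x2)

        W2 = torch.stack(l2_weights, dim=0)   # down-projection copy is only materialized now,
        W2_T = W2.transpose(1, 2)             # after the gate/up copy is no longer referenced
        del W2
        out = torch.bmm(gated, W2_T)          # (E, capacity, hidden)
        del W2_T
        return out

Compared to stacking both weight sets up front, the W1 copy, the W2 copy, the dispatched tokens and all intermediates are never alive at the same moment, which is where the previous peak allocation came from.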
@@ -2,7 +2,6 @@ import os
 import math
 import time
 import torch
-import psutil
 import asyncio
 import threading
 import torch.nn as nn
@@ -13,12 +12,13 @@ import torch.nn.functional as F
 from collections import OrderedDict
 from safetensors import safe_open
 from transformers.cache_utils import StaticCache
+from concurrent.futures import ThreadPoolExecutor
 from typing import Optional, Tuple, Any, List, Dict
 from comfy.ldm.modules.attention import optimized_attention
 from comfy.ldm.modules.diffusionmodules.openaimodel import ResBlock
 
 INIT_MOE = torch.cuda.device_count() != 1
-MOE_LAYER_SIZE = (1024**3) * 2.65 # approx
+MOE_LAYER_SIZE = (1024**3) * 5.15 # approx
 
 class HunyuanStaticCache(StaticCache):
 
@@ -576,7 +576,7 @@ def parse_layer_expert(key):
     return layer, expert
 
 class LazyMoELoader(nn.Module):
-    def __init__(self, cache, config):
+    def __init__(self, cache, config, max_workers = 16, max_concurrent_loads = 32):
         super().__init__()
         self.cache = cache
         self.config = config
@@ -586,6 +586,9 @@ class LazyMoELoader(nn.Module):
         self._file = safe_open(self._checkpoint, framework="pt", device="cpu", mmap=True)
         self.expert_pool = self.build_meta_experts()
 
+        self._executor = ThreadPoolExecutor(max_workers=max_workers)
+        self._semaphore = threading.Semaphore(max_concurrent_loads)
+
     def build_meta_experts(self):
         pool = {}
         for layer, experts in self.expert_key_index.items():
@@ -629,28 +632,17 @@ class LazyMoELoader(nn.Module):
             getattr(model, name).data = tensor
         return model
 
-    async def lazy_load_from_disk(self, layer_idx, expert_idx):
-        loop = asyncio.get_event_loop()
-        return await loop.run_in_executor(None, self.lazy_init, layer_idx, expert_idx)
+    def _load_single_expert(self, layer_idx, expert_idx):
+        with self._semaphore:
+            return self.lazy_init(layer_idx, expert_idx)
 
-    def _schedule_disk_load(self, layer_idx, expert_idx):
-        coro = self.lazy_load_from_disk(layer_idx, expert_idx)
-        future = asyncio.run_coroutine_threadsafe(coro, self._loop)
-
-        def _on_disk_loaded(fut):
-            moe_cpu = fut.result()
-            def _add_cpu_in_main_thread():
-                self.cache.add_cpu(moe_cpu, (layer_idx * 64) + expert_idx)
-                asyncio.run_coroutine_threadsafe(
-                    self.cache._async_load_to_gpu((layer_idx * 64) + expert_idx, moe_cpu),
-                    self.cache._loop
-                )
-            threading.Thread(target=_add_cpu_in_main_thread, daemon=True).start()
-
-        future.add_done_callback(_on_disk_loaded)
-        return future
+    def schedule_layer_load(self, layer_idx, num_experts = 64):
+        futures = []
+        for i in range(num_experts):
+            coro = asyncio.get_event_loop().run_in_executor(self._executor, self._load_single_expert, layer_idx, i)
+            fut = asyncio.run_coroutine_threadsafe(coro, self._loop)
+            futures.append(fut)
+        return futures
 
 def enough_vram(required_bytes):
     free, total = torch.cuda.mem_get_info()
@@ -701,15 +693,6 @@ class HunyuanMoE(nn.Module):
         else:
             tokens_padded = dispatched_input[used_indices]
 
-            l1_layers, l2_layers = [], []
-            for i in used_indices:
-                expert = self.experts[i]
-                if isinstance(expert, (asyncio.Future, concurrent.futures.Future)):
-                    expert = expert.result()
-                expert = expert.to(device)
-                l1_layers.append(expert.gate_and_up_proj)
-                l2_layers.append(expert.down_proj)
-
             l1, l2 = [], []
             for i in used_indices:
                 expert = self.experts[i]
@@ -720,24 +703,32 @@ class HunyuanMoE(nn.Module):
                 l2.append(expert.down_proj)
 
             compute_device = hidden_states.device
 
             l1 = [m.to(compute_device) for m in l1]
-            l2 = [m.to(compute_device) for m in l2]
 
             W1 = torch.stack([m.weight for m in l1], dim=0)
-            W2 = torch.stack([m.weight for m in l2], dim=0)
+            del l1
 
             W1_T = W1.transpose(1, 2)
 
+            del W1
+            x = torch.bmm(tokens_padded, W1_T)
+            del W1_T, tokens_padded
+
+            x1, x2 = x.chunk(2, dim=2)
+            gated = x1 * F.silu(x2)
+
+            l2 = [m.to(compute_device) for m in l2]
+            W2 = torch.stack([m.weight for m in l2], dim=0)
+            del l2
             W2_T = W2.transpose(1, 2)
+            del W2
+            out_padded = torch.bmm(gated, W2_T)
+            del W2_T
 
             while not enough_vram(3*(1024 ** 3)):
                 event = self.moe_lru.last_offload_event
                 if event is not None and not event.query():
                     time.sleep(0.001)
 
-            x = torch.bmm(tokens_padded, W1_T)
-            x1, x2 = x.chunk(2, dim=2)
-            gated = x1 * F.silu(x2)
-            out_padded = torch.bmm(gated, W2_T)
-
             combine_weights_used = combine_weights[:, used_indices, :]
@@ -746,7 +737,7 @@ class HunyuanMoE(nn.Module):
                 out_padded
             )
 
-            del tokens_padded, W1, W2, W1_T, W2_T, x, x1, x2, gated, out_padded
+            del x, x1, x2, gated, out_padded
 
             combined_output = combined_output.reshape(bsz, seq_len, hidden_size)
 
@@ -899,7 +890,8 @@ class HunyuanImage3Model(nn.Module):
 
         self.shared_tensor = None
         self.moe_lru = moe_lru
-        self.self.additional_layers_set = False
+        self.additional_layers_set = False
+        self.moe_loader = LazyMoELoader(self.moe_lru, self.config)
 
     def forward(
         self,
@@ -922,19 +914,19 @@ class HunyuanImage3Model(nn.Module):
         sparse_interval = max(1, len(self.layers) // additional_layers)
 
         if len(self.layers[0].mlp.experts) == 0:
-            experts = [LazyMoELoader(self.moe_lru, self.config) for _ in range(64)]
-            self.layers[0].mlp.experts = [expert._schedule_disk_load(0, i) for i, expert in enumerate(experts)]
+            self.layers[0].mlp.experts = self.moe_loader.schedule_layer_load(0)
 
         for layer_idx, decoder_layer in enumerate(self.layers):
 
             if layer_idx + 1 < len(self.layers) and len(self.layers[layer_idx + 1].mlp.experts) == 0: # not loaded
-                experts = [LazyMoELoader(self.moe_lru, self.config) for _ in range(64)]
-                self.layers[layer_idx+1].mlp.experts = [expert._schedule_disk_load(layer_idx+1, i) for i, expert in enumerate(experts)]
+                self.layers[layer_idx+1].mlp.experts = self.moe_loader.schedule_layer_load(layer_idx + 1)
+
+            if layer_idx + 2 < len(self.layers) and len(self.layers[layer_idx + 2].mlp.experts) == 0: # load first and second layers
+                self.layers[layer_idx+2].mlp.experts = self.moe_loader.schedule_layer_load(layer_idx + 2)
 
             if not self.additional_layers_set:
                 if (layer_idx % sparse_interval == 0) and layer_idx >= sparse_interval:
-                    experts = [LazyMoELoader(self.moe_lru, self.config) for _ in range(64)]
-                    self.layers[next_layers].mlp.experts = [expert._schedule_disk_load(next_layers, i) for i, expert in enumerate(experts)]
+                    self.layers[next_layers].mlp.experts = self.moe_loader.schedule_layer_load(next_layers)
                     next_layers += 1
 
             with torch.no_grad():
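The rewritten expert loading replaces the per-expert asyncio callbacks with a single LazyMoELoader that owns a ThreadPoolExecutor and a semaphore bounding concurrent disk reads, and the model now schedules whole layers one and two steps ahead of the layer being executed. A simplified sketch of that prefetch pattern, with hypothetical names (ExpertPrefetcher, read_expert_from_checkpoint) standing in for LazyMoELoader and its safetensors-backed loader:

    import threading
    from concurrent.futures import ThreadPoolExecutor

    class ExpertPrefetcher:
        def __init__(self, load_fn, max_workers=16, max_concurrent_loads=32):
            self._load_fn = load_fn                      # e.g. reads one expert's weights from disk
            self._executor = ThreadPoolExecutor(max_workers=max_workers)
            self._semaphore = threading.Semaphore(max_concurrent_loads)

        def _load_single_expert(self, layer_idx, expert_idx):
            # Bound how many loads run at once so disk reads do not spike host memory.
            with self._semaphore:
                return self._load_fn(layer_idx, expert_idx)

        def schedule_layer_load(self, layer_idx, num_experts=64):
            # Kick off every expert of a layer in the background; the caller keeps the
            # futures and only calls .result() when that layer is actually executed.
            return [self._executor.submit(self._load_single_expert, layer_idx, i)
                    for i in range(num_experts)]

    # usage sketch: while layer N runs, layers N+1 and N+2 are already loading
    # prefetcher = ExpertPrefetcher(load_fn=read_expert_from_checkpoint)
    # futures = prefetcher.schedule_layer_load(layer_idx + 1)
    # weights = [f.result() for f in futures]   # blocks only if a load has not finished yet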