Mirror of https://github.com/comfyanonymous/ComfyUI.git, synced 2026-01-23 21:00:16 +08:00
async cache revamp
Added async loading and offloading of MoE layers, giving consistent memory usage without OOM errors. Previously this ran out of memory after the third layer on a 24 GB GPU; it now runs to the end with consistent memory usage and minimal latency.
This commit is contained in:
parent 44346c4251
commit 7b4c1e8031
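
Editorial note, not part of the commit: the memory behaviour described above comes from a standard pattern — stage expert weights in pinned host memory, copy them to the GPU on a dedicated CUDA stream with non_blocking=True, and gate first use on a CUDA event so the copy overlaps with compute. A minimal self-contained sketch of that pattern follows; the names (prefetch, prefetch_stream) are illustrative only and do not appear in the ComfyUI code.

# Sketch only, assuming a CUDA-capable machine; not part of the diff below.
import torch

def prefetch(module_cpu, stream):
    # Begin an async H2D copy of a CPU module whose tensors live in pinned memory.
    module_gpu = torch.nn.Linear(1024, 1024, device="cuda")
    event = torch.cuda.Event()
    with torch.cuda.stream(stream):
        for p_gpu, p_cpu in zip(module_gpu.parameters(), module_cpu.parameters()):
            p_gpu.data.copy_(p_cpu.data, non_blocking=True)  # overlaps with compute on the default stream
        stream.record_event(event)
    return module_gpu, event

if torch.cuda.is_available():
    prefetch_stream = torch.cuda.Stream()
    cpu_layer = torch.nn.Linear(1024, 1024)
    for p in cpu_layer.parameters():
        p.data = p.data.pin_memory()          # pinned host memory makes non_blocking copies truly async
    gpu_layer, ready = prefetch(cpu_layer, prefetch_stream)
    x = torch.randn(8, 1024, device="cuda")   # work queued on the default stream
    torch.cuda.current_stream().wait_event(ready)  # gate use of the prefetched weights
    y = gpu_layer(x)
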
@@ -1,15 +1,17 @@
 import os
-import gc
 import math
+import time
 import torch
 import psutil
+import asyncio
+import threading
 import torch.nn as nn
 from pathlib import Path
+import concurrent.futures
 from einops import rearrange
 import torch.nn.functional as F
 from collections import OrderedDict
 from safetensors import safe_open
-from contextlib import contextmanager
 from transformers.cache_utils import StaticCache
 from typing import Optional, Tuple, Any, List, Dict
 from comfy.ldm.modules.attention import optimized_attention
@@ -19,13 +21,13 @@ INIT_MOE = torch.cuda.device_count() != 1

 if not INIT_MOE:
     MOE_LAYER_SIZE = (1024**3) * 2.65 # approx
-    CPU_MOE_RATIO = None

     torch.cuda.set_device(0)
     props = torch.cuda.get_device_properties(0)

-    INIT_CUDA_MEM = (props.total_memory - torch.cuda.memory_reserved()) * 0.9
-    ADDITIONAL_LAYERS_IN_GPU = math.floor(INIT_CUDA_MEM / MOE_LAYER_SIZE)
+    LAYERS_IN_CPU = math.floor((int((os.sysconf('SC_PAGE_SIZE') * os.sysconf('SC_PHYS_PAGES'))
+                                    - psutil.Process(os.getpid()).memory_info().rss
+                                    - (2*1024**3)) * 0.50) / MOE_LAYER_SIZE)

 class HunyuanStaticCache(StaticCache):

@@ -286,20 +288,15 @@ def topkgating(
     logits = logits.float()
     gates = F.softmax(logits, dim=1)

-    extra = ADDITIONAL_LAYERS_IN_GPU
-
-    values_all, indices_all = torch.topk(gates, topk + extra, dim=1)
+    values_all, indices_all = torch.topk(gates, topk, dim=1)
     expert_weight = values_all[:, :topk]
     expert_index = indices_all[:, :topk]

-    _, cpu_expert_index = torch.topk(gates, int(CPU_MOE_RATIO * 64), dim = 1)
-    cpu_expert_index = cpu_expert_index[:, (8 + ADDITIONAL_LAYERS_IN_GPU):]
-
     if norm_topk_prob and topk > 1:
         denom = expert_weight.sum(dim=1, keepdim=True).clamp_min(torch.finfo(gates.dtype).eps)
         expert_weight = expert_weight / denom

-    return expert_weight, expert_index, cpu_expert_index, indices_all
+    return expert_weight, expert_index

 class HunyuanRMSNorm(nn.Module):
     def __init__(self, hidden_size, eps=1e-6):
@@ -452,7 +449,7 @@ class HunyuanTopKGate(nn.Module):
         return gate_output

 class HunyuanMLP(nn.Module):
-    def __init__(self, config, layer_idx=None, is_shared_mlp=False, is_moe=False):
+    def __init__(self, config, layer_idx=None, is_shared_mlp=False, is_moe=False, device=None):
         super().__init__()
         self.config = config
         self.layer_idx = layer_idx
@@ -462,8 +459,8 @@ class HunyuanMLP(nn.Module):

         self.act_fn = torch.nn.functional.silu
         self.intermediate_size *= 2 # SwiGLU
-        self.gate_and_up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
-        self.down_proj = nn.Linear(self.intermediate_size // 2, self.hidden_size, bias=False)
+        self.gate_and_up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False, device=device)
+        self.down_proj = nn.Linear(self.intermediate_size // 2, self.hidden_size, bias=False, device=device)
     def forward(self, x):
         self.gate_and_up_proj, self.down_proj = self.gate_and_up_proj.to(x.device), self.down_proj.to(x.device)
         if x.ndim == 2:
@@ -474,204 +471,143 @@ class HunyuanMLP(nn.Module):
         return down_proj

 class MoELRUCache(nn.Module):
-    def __init__(self, cpu_mem: int = 50, safety_buffer_bytes = 3*(1024**3), max_gpu_eviction_attempts = 8):
+    def __init__(self):
         super().__init__()
-        global CPU_MOE_RATIO

-        _, total = torch.cuda.mem_get_info()
-        max_gpu_mem_gb = max((total - 2 * safety_buffer_bytes) / (1024**3), 1)
-
-        self.MAX_GPU_MEM = int(max_gpu_mem_gb * 1024**3)
-        self.MAX_CPU_MEM = int(cpu_mem * 1024**3)
         self.gpu_cache = OrderedDict()
         self.cpu_cache = OrderedDict()
+        self.offload_stream = torch.cuda.Stream()
+        self.load_stream = torch.cuda.Stream()

-        self.gpu_mem_usage = 0
-        self.cpu_mem_usage = 0
-        # 50% for system and headroom
-        try:
-            self.MAX_CPU_MEM = int((os.sysconf('SC_PAGE_SIZE') * os.sysconf('SC_PHYS_PAGES'))
-                                   - psutil.Process(os.getpid()).memory_info().rss
-                                   - safety_buffer_bytes) * 0.55
-        except:
-            self.MAX_CPU_MEM = int(cpu_mem * (1024**3) * 0.5) # TODO
-
-        ADDITIONAL_LAYERS_IN_CPU = math.floor((50 * (1024**3)) / MOE_LAYER_SIZE)
-        CPU_MOE_RATIO = (min(64 - ADDITIONAL_LAYERS_IN_GPU, ADDITIONAL_LAYERS_IN_CPU)) / 64
-
-        self.MAX_GPU_MEM = int(max_gpu_mem_gb * 1024**3)
-        self.SAFETY_BUFFER = int(safety_buffer_bytes)
-        self.MAX_GPU_EVICT_ATTEMPTS = max_gpu_eviction_attempts
-
-    def _gpu_free_bytes(self):
-        free, total = torch.cuda.mem_get_info()
-        return int(free)
-
-    def _estimate_size(self, moe):
-        # include parameters + buffers
-        size = 0
-        for p in moe.parameters():
-            size += p.numel() * p.element_size()
-        for b in moe.buffers():
-            size += b.numel() * b.element_size()
-        return int(size)
-
-    def _evict_until_free(self, required_bytes, max_attempts=16):
-        attempts = 0
-        while self._gpu_free_bytes() < required_bytes and attempts < max_attempts:
-            evicted = self._evict_from_gpu()
-            if not evicted:
-                break
-            attempts += 1
-        return self._gpu_free_bytes() >= required_bytes
-
-    @contextmanager
-    def ensure_headroom(self, required_bytes):
-
-        safety = getattr(self, "SAFETY_BUFFER", 0)
-        target_free = int(required_bytes + safety)
-
-        if getattr(self, "_headroom", None) is not None:
-            try:
-                del self._headroom
-            except Exception:
-                pass
-            self._headroom = None
-
-        ok = self._evict_until_free(target_free)
-        if not ok and self._gpu_free_bytes() < target_free:
-            # last ditch
-            try:
-                torch.cuda.empty_cache()
-            except Exception:
-                pass
-
-        try:
-            yield
-        finally:
-            if getattr(self, "_headroom", None) is None:
-                try:
-                    self._headroom = torch.empty((self._headroom_bytes,), dtype=torch.uint8, device="cuda:0")
-                except Exception:
-                    self._headroom = None
-
-    def add_gpu(self, moe, index, allowed_retries=3):
-        size = self._estimate_size(moe)
-
-        while self.gpu_mem_usage + size > self.MAX_GPU_MEM:
-            if not self._evict_from_gpu():
-                break
-
-        attempts = 0
-        while self._gpu_free_bytes() < size + self.SAFETY_BUFFER and attempts < self.MAX_GPU_EVICT_ATTEMPTS:
-            if not self._evict_from_gpu():
-                break
-            attempts += 1
-
-        for _ in range(allowed_retries):
-            try:
-                moe_cuda = moe.to("cuda:0")
-                break
-            except RuntimeError as e:
-                if "out of memory" not in str(e).lower():
-                    raise
-                evicted = self._evict_from_gpu()
-                if not evicted: # can't evict
-                    raise
-        else:
-            raise RuntimeError("Failed to move expert to GPU after evictions")
-
-        self.gpu_cache[index] = moe_cuda
-        self.gpu_cache.move_to_end(index)
-        self.gpu_mem_usage += size
-
-        return
+        self.last_offload_event = None
+        self._loop = asyncio.new_event_loop()
+        threading.Thread(target=self._loop.run_forever, daemon=True).start()
+
+    async def _async_offload_to_cpu(self, layer_idx):
+        # async offload from gpu (removed)
+
+        num_experts = 64
+        moe_group = [(layer_idx * num_experts + i, self.gpu_cache[layer_idx * num_experts + i])
+                     for i in range(num_experts)
+                     if (layer_idx * num_experts + i) in self.gpu_cache]
+        event = torch.cuda.Event()
+
+        with torch.cuda.stream(self.offload_stream):
+            for index, moe in moe_group:
+                moe_cpu = HunyuanMLP(moe.config).to("cpu", non_blocking=True)
+                for (name, p_gpu), p_cpu in zip(moe.named_parameters(), moe_cpu.parameters()):
+                    if p_gpu.device.type == "meta":
+                        continue
+                    with torch.no_grad():
+                        p_cpu.data = torch.empty_like(p_gpu, device="cpu", pin_memory=True)
+                        p_cpu.copy_(p_gpu, non_blocking=True)
+
+                self.cpu_cache[index] = moe_cpu
+
+            self.offload_stream.record_event(event)
+
+        self.last_offload_event = event
+
+        def finalize_offload_layer():
+            event.synchronize()
+            for index, moe in moe_group:
+                moe.to("meta")
+                self.gpu_cache.pop(index, None)
+                del moe
+            torch.cuda.empty_cache()
+
+        threading.Thread(target=finalize_offload_layer, daemon=True).start()
+
+    async def _async_load_to_gpu(self, index, moe):
+
+        # if enough memory load, otherwise wait for offload
+        while True:
+            free_bytes, _ = torch.cuda.mem_get_info()
+            if free_bytes > 2 * MOE_LAYER_SIZE:
+                break
+
+            self.last_offload_event.synchronize()
+            torch.cuda.empty_cache()
+            await asyncio.sleep(0.01)
+
+        # async loading from cpu -> gpu
+        with torch.cuda.stream(self.load_stream):
+            moe_gpu = HunyuanMLP(moe.config).to("cuda", non_blocking=True)
+            for (name, p_cpu), p_gpu in zip(moe.named_parameters(), moe_gpu.parameters()):
+                with torch.no_grad():
+                    p_gpu.data = torch.empty_like(p_cpu, device="cuda")
+                    p_gpu.copy_(p_cpu, non_blocking=True)
+
+        def finalize_load():
+            self.gpu_cache[index] = moe_gpu
+            self.cpu_cache.pop(index, None)
+
+        threading.Thread(target=finalize_load, daemon=True).start()

     def add_cpu(self, moe, index):
-        size = self._estimate_size(moe)
-        while self.cpu_mem_usage + size > self.MAX_CPU_MEM:
-            if not self._evict_from_cpu():
-                break
         moe_cpu = moe.to("cpu")

+        for _, p in moe_cpu.named_parameters():
+            if not p.is_pinned():
+                p.data = p.data.pin_memory()
+
         self.cpu_cache[index] = moe_cpu
         self.cpu_cache.move_to_end(index)
-        self.cpu_mem_usage += size
-
-    def get_from_device(self, index):
-        if index in self.gpu_cache:
-            moe = self.gpu_cache[index]
-            self.gpu_cache.move_to_end(index)
-            return moe
-        if index in self.cpu_cache:
-            moe = self.cpu_cache.pop(index)
-            self.cpu_mem_usage = max(0, self.cpu_mem_usage - self._estimate_size(moe))
-            try:
-                self.add_gpu(moe, index)
-                return self.gpu_cache[index]
-            except RuntimeError:
-                self.cpu_cache[index] = moe
-                self.cpu_cache.move_to_end(index)
-                self.cpu_mem_usage += self._estimate_size(moe)
-                raise
-
-        return None # load from disk
-
-    def _evict_from_gpu(self):
-        if not self.gpu_cache:
-            return False
-
-        idx, moe = self.gpu_cache.popitem(last=False)
-        size = self._estimate_size(moe)
-        self.gpu_mem_usage = max(0, self.gpu_mem_usage - size)
-
-        if self.cpu_mem_usage + size <= self.MAX_CPU_MEM:
-            try:
-                moe_cpu = moe.to("cpu")
-            except Exception:
-                # drop the model if cpu is full
-                del moe
-                return True
-            self.cpu_cache[idx] = moe_cpu
-            self.cpu_cache.move_to_end(idx)
-            self.cpu_mem_usage += size
-            return True
-        else:
-            del moe
-            return True
-
-    def _evict_from_cpu(self):
-        if not self.cpu_cache:
-            return False
-        _, moe = self.cpu_cache.popitem(last=False)
-        size = self._estimate_size(moe)
-        self.cpu_mem_usage = max(0, self.cpu_mem_usage - size)
-        del moe
-        gc.collect()
-        return True

 class LazyMoELoader(nn.Module):
-    def __init__(self, device):
+    def __init__(self, cache, config):
         super().__init__()
-        self.device = device
+        self.cache = cache
+        self.config = config
+        self._loop = cache._loop

-    def lazy_init(self, config, layer_idx, expert_idx):
+    def get_checkpoint(self):
         comfyui_dir = Path.home() / "ComfyUI"
         checkpoint = comfyui_dir / "models" / "checkpoint" / "hunyuan_image_3.safetensors"
         checkpoint = checkpoint.resolve()
         if not os.path.exists(checkpoint):
             raise ValueError(f"Hunyuan Image 3 Checkpoint on one GPU should have the path: {checkpoint}")
+        return checkpoint
+
+    def lazy_init(self, layer_idx, expert_idx):
+        checkpoint = self.get_checkpoint()
         prefix = f"model.layers.{layer_idx}.mlp.experts.{expert_idx}."
         additional_prefix = f"model.layers.{layer_idx}.mlp.gate_and_up_proj.weight"
         sd = {}

-        with safe_open(checkpoint, framework="pt", device=self.device) as f:
+        with safe_open(checkpoint, framework="pt", device="cpu") as f:
             for k in f.keys():
                 if k.startswith(prefix) or k.startswith(additional_prefix):
                     new_k = k.split(f"experts.{expert_idx}.", 1)[1]
                     sd[new_k] = f.get_tensor(k)

-        return HunyuanMLP(config, layer_idx=layer_idx, is_shared_mlp=False, is_moe=True).load_state_dict(sd).to(self.deivce)
+        return HunyuanMLP(self.config, layer_idx=layer_idx, is_shared_mlp=False, is_moe=True).load_state_dict(sd)
+
+    async def lazy_load_from_disk(self, layer_idx, expert_idx):
+        loop = asyncio.get_event_loop()
+        return await loop.run_in_executor(None, self.lazy_init, layer_idx, expert_idx)
+
+    def _schedule_disk_load(self, layer_idx, expert_idx):
+
+        coro = self.lazy_load_from_disk(layer_idx, expert_idx)
+        future = asyncio.run_coroutine_threadsafe(coro, self._loop)
+
+        def _on_disk_loaded(fut):
+            moe_cpu = fut.result()
+            def _add_cpu_in_main_thread():
+                self.cache.add_cpu(moe_cpu, (layer_idx * 64) + expert_idx)
+
+                asyncio.run_coroutine_threadsafe(
+                    self.cache._async_load_to_gpu((layer_idx * 64) + expert_idx, moe_cpu),
+                    self.cache._loop
+                )
+            threading.Thread(target=_add_cpu_in_main_thread, daemon=True).start()
+
+        future.add_done_callback(_on_disk_loaded)
+        return future
+
+def enough_vram(required_bytes):
+    free, total = torch.cuda.mem_get_info()
+    return free > required_bytes

 class HunyuanMoE(nn.Module):
     def __init__(self, config, layer_idx: Optional[int] = None, moe_lru=None):
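
Editorial aside, not part of the diff: `_schedule_disk_load` above uses the standard pattern of running an asyncio event loop in a daemon thread and handing it coroutines via `asyncio.run_coroutine_threadsafe`, which returns a `concurrent.futures.Future` that synchronous code can later resolve with `.result()` (as the MoE forward pass does). A minimal self-contained sketch of that pattern, with made-up names:

# Sketch only; `load_expert` and `schedule` are illustrative, not names from the commit.
import asyncio
import threading

loop = asyncio.new_event_loop()
threading.Thread(target=loop.run_forever, daemon=True).start()

async def load_expert(idx):
    await asyncio.sleep(0.01)   # stands in for disk I/O plus the H2D copy
    return f"expert-{idx}"

def schedule(idx):
    # Returns a concurrent.futures.Future; .result() blocks until the coroutine finishes.
    return asyncio.run_coroutine_threadsafe(load_expert(idx), loop)

futures = [schedule(i) for i in range(4)]
print([f.result() for f in futures])
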
@@ -687,7 +623,7 @@ class HunyuanMoE(nn.Module):
                 [HunyuanMLP(config, layer_idx=layer_idx, is_shared_mlp=False, is_moe=True) for _ in range(self.num_experts)]
             )
         else:
-            self.experts = None
+            self.experts = []
         self.moe_lru = moe_lru

     def forward(self, hidden_states):
@@ -702,38 +638,10 @@ class HunyuanMoE(nn.Module):
         reshaped_input = hidden_states.reshape(-1, hidden_size)

         with torch.cuda.nvtx.range("MoE"):
-            expert_weight, expert_index, cpu_expert_index, indices_all = self.gate(hidden_states)
-            if not INIT_MOE:
-                if ADDITIONAL_LAYERS_IN_GPU > 0:
-                    additional_expert_index = indices_all[:, expert_index.size(1): expert_index.size(1) + ADDITIONAL_LAYERS_IN_GPU]
-
-                    flat = additional_expert_index.reshape(-1).to("cpu")
-                    counts = torch.bincount(flat, minlength=self.num_experts)
-                    top_extra = torch.topk(counts, k=min(ADDITIONAL_LAYERS_IN_GPU, (counts>0).sum().item())).indices.tolist()
-
-                    for expert_id in top_extra:
-                        if self.moe_lru.get_from_device(expert_id + self.layer_idx) is None:
-                            expert_cpu = LazyMoELoader(device="cpu").lazy_init(self.config, self.layer_idx, expert_id)
-                            self.moe_lru.add_gpu(expert_cpu, expert_id + self.layer_idx)
-
-                if cpu_expert_index is not None and cpu_expert_index.numel() > 0:
-                    for expert_id in torch.unique(cpu_expert_index).cpu().tolist():
-                        if self.moe_lru.get_from_device(expert_id + self.layer_idx) is None:
-                            expert_cpu = LazyMoELoader(device="cpu").lazy_init(self.config, self.layer_idx, expert_id)
-                            self.moe_lru.add_cpu(expert_cpu, expert_id + self.layer_idx)
+            expert_weight, expert_index = self.gate(hidden_states)

             combined_output = torch.zeros_like(reshaped_input)
-            experts_list = []
-            for e in range(self.num_experts):
-                token_mask = (expert_index == e)
-                if not token_mask.any():
-                    continue
-                expert = self.moe_lru.get_from_device(e + self.layer_idx)
-                if expert is None:
-                    expert = LazyMoELoader()
-                    expert = expert.lazy_init(self.config, self.layer_idx, e)
-                    self.moe_lru.add_gpu(expert, e + self.layer_idx)
-                experts_list.append((e, expert))
+            experts_list = [(i, expert) for i, expert in enumerate(self.experts)]

             per_pos, per_tokens, per_weights = [], [], []
             for e, _ in experts_list:
@@ -761,6 +669,8 @@ class HunyuanMoE(nn.Module):

             l1, l2 = [], []
             for _, expert in experts_list:
+                if isinstance(expert, (asyncio.Future, concurrent.futures.Future)):
+                    expert = expert.result()
                 l1.append(expert.gate_and_up_proj)
                 l2.append(expert.down_proj)

@@ -770,6 +680,12 @@ class HunyuanMoE(nn.Module):
             W1_T = W1.transpose(1, 2)
             W2_T = W2.transpose(1, 2)

+            # wait for enough vram for the computations
+            while not enough_vram(5*(1024 ** 3)):
+                event = self.moe_lru.last_offload_event
+                if event is not None and not event.query():
+                    time.sleep(0.001)
+
             x = torch.bmm(tokens_padded, W1_T)
             x = F.silu(x)

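
A side note, not part of the diff: the busy-wait added above simply polls torch.cuda.mem_get_info() until enough free VRAM is reported before doing the batched matmuls. A minimal standalone version of that check, with an assumed threshold and timeout, might look like this:

# Sketch only; threshold, poll interval and timeout are example values, not from the commit.
import time
import torch

def wait_for_vram(required_bytes, poll_s=0.001, timeout_s=30.0):
    deadline = time.monotonic() + timeout_s
    while time.monotonic() < deadline:
        free, _total = torch.cuda.mem_get_info()  # returns (free, total) in bytes
        if free > required_bytes:
            return True
        time.sleep(poll_s)
    return False
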
@@ -781,7 +697,7 @@ class HunyuanMoE(nn.Module):
             for i, token_positions in enumerate(per_pos):
                 Ni = lengths[i]
                 out_i = out_padded[i, :Ni]
-                combined_output.index_add_(0, token_positions.to(hidden_states.device), out_i)
+                combined_output.to(hidden_states.dtype).index_add_(0, token_positions.to(hidden_states.device), out_i.to(hidden_states.dtype))

             #dispatched_input = torch.einsum("sec,sm->ecm", dispatch_mask.type_as(hidden_states), reshaped_input)
             #chunks = dispatched_input.chunk(self.num_experts, dim=0)
@@ -933,6 +849,7 @@ class HunyuanImage3Model(nn.Module):
         super().__init__()
         self.padding_idx = 128009
         self.vocab_size = 133120
+        self.config = config
         self.wte = nn.Embedding(133120, config["hidden_size"], self.padding_idx)
         self.layers = nn.ModuleList(
             [HunyuanImage3DecoderLayer(config, layer_idx, moe_lru = moe_lru) for layer_idx in range(config["num_hidden_layers"])]
@@ -941,6 +858,7 @@ class HunyuanImage3Model(nn.Module):
         self.ln_f = HunyuanRMSNorm(config["hidden_size"], eps=config["rms_norm_eps"])

         self.shared_tensor = None
+        self.moe_lru = moe_lru

     def forward(
         self,
@@ -962,19 +880,45 @@ class HunyuanImage3Model(nn.Module):
         hidden_states = inputs_embeds

         next_decoder_cache = None
+        next_layers = 0
+        sparse_interval = max(1, len(self.layers) // 3)
+
+        if len(self.layers[0].mlp.experts) == 0:
+            experts = [LazyMoELoader(self.moe_lru, self.config) for _ in range(64)]
+            self.layers[0].mlp.experts = [expert._schedule_disk_load(0, i) for i, expert in enumerate(experts)]
+
         for layer_idx, decoder_layer in enumerate(self.layers):

-            layer_outputs = decoder_layer(
-                hidden_states,
-                attention_mask=attention_mask,
-                position_ids=position_ids,
-                past_key_value=past_key_values,
-                use_cache=use_cache,
-                custom_pos_emb=custom_pos_emb,
-                mode=mode,
-                first_step=first_step,
-                gen_timestep_scatter_index=gen_timestep_scatter_index,
-            )
+            if layer_idx + 1 < len(self.layers) and len(self.layers[layer_idx + 1].mlp.experts) == 0: # not loaded
+                experts = [LazyMoELoader(self.moe_lru, self.config) for _ in range(64)]
+                self.layers[layer_idx+1].mlp.experts = [expert._schedule_disk_load(layer_idx+1, i) for i, expert in enumerate(experts)]
+
+            if (layer_idx % sparse_interval == 0) and layer_idx > sparse_interval:
+                if len(self.layers[next_layers].mlp.experts) > 0: # for testing
+                    raise ValueError("Problem with offloading")
+                experts = [LazyMoELoader(self.moe_lru, self.config) for _ in range(64)]
+                self.layers[next_layers].mlp.experts = [expert._schedule_disk_load(next_layers, i) for i, expert in enumerate(experts)]
+                next_layers += 1
+
+            with torch.no_grad():
+                layer_outputs = decoder_layer(
+                    hidden_states,
+                    attention_mask=attention_mask,
+                    position_ids=position_ids,
+                    past_key_value=past_key_values,
+                    use_cache=use_cache,
+                    custom_pos_emb=custom_pos_emb,
+                    mode=mode,
+                    first_step=first_step,
+                    gen_timestep_scatter_index=gen_timestep_scatter_index,
+                )
+
+            if layer_idx >= 0:
+                asyncio.run_coroutine_threadsafe(
+                    self.moe_lru._async_offload_to_cpu(layer_idx),
+                    self.moe_lru._loop
+                )
+                self.layers[layer_idx].mlp.experts = []

             hidden_states = layer_outputs[0]
