zero-copy and optimized MoE loader

This commit is contained in:
Yousef Rafat 2025-11-22 12:37:10 +02:00
parent 4d982e83f6
commit ae8592ebf5


@@ -568,6 +568,12 @@ class MoELRUCache(nn.Module):
         self.cpu_cache[index] = moe_cpu
         self.cpu_cache.move_to_end(index)
 
+def parse_layer_expert(key):
+    parts = key.split(".")
+    layer = int(parts[2])
+    expert = int(parts[5])
+    return layer, expert
+
 class LazyMoELoader(nn.Module):
     def __init__(self, cache, config):
@@ -575,6 +581,22 @@ class LazyMoELoader(nn.Module):
         self.cache = cache
         self.config = config
         self._loop = cache._loop
+        self.expert_key_index = self.index_safetensors()
+        self._checkpoint = self.get_checkpoint()
+        self._file = safe_open(self._checkpoint, framework="pt", device="cpu", mmap=True)
+        self.expert_pool = self.build_meta_experts()
+
+    def build_meta_experts(self):
+        pool = {}
+        for layer, experts in self.expert_key_index.items():
+            pool[layer] = {}
+            for expert in experts:
+                pool[layer][expert] = HunyuanMLP(
+                    self.config,
+                    layer_idx=layer,
+                    device="meta",
+                )
+        return pool
 
     def get_checkpoint(self):
         comfyui_dir = Path.home() / "ComfyUI"
@@ -583,23 +605,28 @@ class LazyMoELoader(nn.Module):
         if not os.path.exists(checkpoint):
             raise ValueError(f"Hunyuan Image 3 Checkpoint on one GPU should have the path: {checkpoint}")
         return checkpoint
 
-    def lazy_init(self, layer_idx, expert_idx):
-        checkpoint = self.get_checkpoint()
-        prefix = f"model.layers.{layer_idx}.mlp.experts.{expert_idx}."
-        additional_prefix = f"model.layers.{layer_idx}.mlp.gate_and_up_proj.weight"
-        sd = {}
+    def index_safetensors(self):
+        checkpoint = self.get_checkpoint()
+        index = {}
         with safe_open(checkpoint, framework="pt", device="cpu") as f:
             for k in f.keys():
-                if k.startswith(prefix) or k.startswith(additional_prefix):
-                    new_k = k.split(f"experts.{expert_idx}.", 1)[1]
-                    sd[new_k] = f.get_tensor(k)
+                if "experts." in k:
+                    layer, expert = parse_layer_expert(k)
+                    index.setdefault(layer, {}).setdefault(expert, []).append(k)
+        return index
 
-        model = HunyuanMLP(self.config, layer_idx=layer_idx, is_shared_mlp=False, is_moe=True, device="meta")
-        model.to_empty(device = "cpu")
-        model.load_state_dict(sd)
+    def lazy_init(self, layer_idx, expert_idx):
+        keys = self.expert_key_index[layer_idx][expert_idx]
+        model = self.expert_pool[layer_idx][expert_idx]
+        def strip_expert_prefix(k):
+            return k.split(f"experts.{expert_idx}.", 1)[1]
+        sd = { strip_expert_prefix(k): self._file.get_tensor(k) for k in keys }
+        for name, tensor in sd.items():
+            getattr(model, name).data = tensor
         return model
 
     async def lazy_load_from_disk(self, layer_idx, expert_idx):
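
For readers following the change, here is a minimal, self-contained sketch of the loading pattern the new code adopts: scan the safetensors checkpoint once to index expert keys, keep a single open handle, pre-build shape-only modules on the meta device, and attach real tensors to a module only when an expert is requested. ExpertMLP, its parameter names, the tensor sizes, and the "model.safetensors" path below are illustrative stand-ins, not the repository's actual classes or files.

import torch
import torch.nn as nn
from safetensors import safe_open

CHECKPOINT = "model.safetensors"  # placeholder path, not the repository's real checkpoint

class ExpertMLP(nn.Module):
    # Illustrative stand-in for HunyuanMLP; the parameter names are assumptions.
    def __init__(self, hidden=64, inter=128, device=None):
        super().__init__()
        self.gate_and_up_proj = nn.Linear(hidden, 2 * inter, bias=False, device=device)
        self.down_proj = nn.Linear(inter, hidden, bias=False, device=device)

def parse_layer_expert(key):
    # "model.layers.12.mlp.experts.7.down_proj.weight" -> (12, 7),
    # matching the indices used by the diff's parse_layer_expert.
    parts = key.split(".")
    return int(parts[2]), int(parts[5])

# 1) Scan the checkpoint once and group every expert key by (layer, expert),
#    instead of re-reading the full key list on each expert load.
index = {}
with safe_open(CHECKPOINT, framework="pt", device="cpu") as f:
    for k in f.keys():
        if "experts." in k:
            layer, expert = parse_layer_expert(k)
            index.setdefault(layer, {}).setdefault(expert, []).append(k)

# 2) Keep one file handle open for the lifetime of the loader and build
#    shape-only modules on the meta device, which allocate no weight storage.
handle = safe_open(CHECKPOINT, framework="pt", device="cpu")
pool = {
    layer: {expert: ExpertMLP(device="meta") for expert in experts}
    for layer, experts in index.items()
}

# 3) Materialize one expert on demand: point each parameter's .data at the
#    tensor returned by the open handle, instead of allocating an empty module
#    and copying into it with load_state_dict().
def lazy_init(layer, expert):
    module = pool[layer][expert]
    for key in index[layer][expert]:
        name = key.split(f"experts.{expert}.", 1)[1]  # e.g. "down_proj.weight"
        module.get_parameter(name).data = handle.get_tensor(key)
    return module

Compared with the previous lazy_init, which reopened the checkpoint, iterated every key, built a fresh module with to_empty(), and then copied weights through load_state_dict(), this path skips the per-call key scan and the intermediate empty-CPU allocation; how close get_tensor() gets to true zero-copy depends on the safetensors version in use.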