Improving performance and fixing a race condition

Yousef Rafat 2025-11-16 16:19:39 +02:00
parent 12cc6924ac
commit d731c58353


@@ -548,7 +548,10 @@ class MoELRUCache(nn.Module):
         for _, p in moe_cpu.named_parameters():
             if not p.is_pinned():
-                p.data = p.data.pin_memory()
+                if p.device.type == "cpu":
+                    p.data = p.data.pin_memory()
+                else:
+                    return
         self.cpu_cache[index] = moe_cpu
         self.cpu_cache.move_to_end(index)
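
For context, Tensor.pin_memory() is only defined for CPU tensors, so if a concurrent prefetch has already moved an expert's parameters to the GPU, the old unconditional call could throw mid-eviction; the added guard simply skips such modules. A minimal standalone sketch of that guarded-pinning pattern (the helper name maybe_pin_module is hypothetical, not part of the patched code):

import torch
import torch.nn as nn

def maybe_pin_module(module: nn.Module) -> bool:
    """Pin a module's parameters so async host-to-device copies are possible.

    Returns False if a parameter is no longer on the CPU (e.g. a racing
    prefetch already moved it to CUDA), mirroring the early return added
    to the cache-insertion path above.
    """
    for _, p in module.named_parameters():
        if p.is_pinned():
            continue
        if p.device.type != "cpu":
            return False  # only CPU tensors can be pinned; bail out instead of raising
        p.data = p.data.pin_memory()  # page-locked memory allows non_blocking=True copies
    return True

# usage sketch (pinning needs a CUDA-enabled build)
if torch.cuda.is_available():
    expert = nn.Linear(4, 4)  # stand-in for an offloaded MoE expert
    print(maybe_pin_module(expert))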
@@ -643,74 +646,81 @@ class HunyuanMoE(nn.Module):
         with torch.cuda.nvtx.range("MoE"):
             expert_weight, expert_index = self.gate(hidden_states)
-            combined_output = torch.zeros_like(reshaped_input)
-            experts_list = [(i, expert) for i, expert in enumerate(self.experts)]
-            per_pos, per_tokens, per_weights = [], [], []
-            for e, _ in experts_list:
-                token_mask = (expert_index == e)
-                token_ids = token_mask.nonzero(as_tuple=False)
-                token_positions = token_ids[:, 0]
-                topk_slot = token_ids[:, 1]
-                tokens = reshaped_input[token_positions]
-                weights = expert_weight[token_positions, topk_slot]
-                per_pos.append(token_positions)
-                per_tokens.append(tokens)
-                per_weights.append(weights)
-            lengths = [t.shape[0] for t in per_tokens]
-            E = len(per_tokens)
-            L = max(lengths)
-            tokens_padded = torch.zeros((E, L, hidden_size), device=hidden_states.device, dtype=reshaped_input.dtype)
-            weights_padded = torch.zeros((E, L), device=hidden_states.device, dtype=per_weights[0].dtype)
-            for i, t in enumerate(per_tokens):
-                tokens_padded[i, : t.shape[0]] = t
-                weights_padded[i, : t.shape[0]] = per_weights[i]
-            l1, l2 = [], []
-            for _, expert in experts_list:
-                if isinstance(expert, (asyncio.Future, concurrent.futures.Future)):
-                    expert = expert.result()
-                l1.append(expert.gate_and_up_proj)
-                l2.append(expert.down_proj)
-            W1 = torch.stack([l.weight for l in l1]).to(hidden_states.device)
-            W2 = torch.stack([l.weight for l in l2]).to(hidden_states.device)
-            W1_T = W1.transpose(1, 2)
-            W2_T = W2.transpose(1, 2)
-            # wait for enough vram for the computations
-            while not enough_vram(5*(1024 ** 3)):
-                event = self.moe_lru.last_offload_event
-                if event is not None and not event.query():
-                    time.sleep(0.001)
-            x = torch.bmm(tokens_padded, W1_T)
-            x = F.silu(x)
-            x1, x2 = x.chunk(2, dim=2)
-            out_padded = torch.bmm(x1 * F.silu(x2), W2_T)
-            out_padded = out_padded * weights_padded.unsqueeze(-1)
-            for i, token_positions in enumerate(per_pos):
-                Ni = lengths[i]
-                out_i = out_padded[i, :Ni]
-                combined_output.to(hidden_states.dtype).index_add_(0, token_positions.to(hidden_states.device), out_i.to(hidden_states.dtype))
-            #dispatched_input = torch.einsum("sec,sm->ecm", dispatch_mask.type_as(hidden_states), reshaped_input)
-            #chunks = dispatched_input.chunk(self.num_experts, dim=0)
-            #expert_outputs = []
-            #for chunk, expert in zip(chunks, self.experts):
-            #    expert_outputs.append(expert(chunk))
-            #expert_output = torch.cat(expert_outputs, dim=0)
-            #combined_output = torch.einsum("sec,ecm->sm", combine_weights.type_as(hidden_states), expert_output)
+            device = hidden_states.device
+            dtype = reshaped_input.dtype
+            combined_output = torch.zeros_like(reshaped_input, device=device, dtype=dtype)
+            per_pos = [None] * self.num_experts
+            per_tokens = [None] * self.num_experts
+            per_weights = [None] * self.num_experts
+            for e in range(self.num_experts):
+                token_mask = (expert_index == e)
+                token_ids = token_mask.nonzero(as_tuple=False)
+                if token_ids.numel() == 0:
+                    continue
+                token_positions = token_ids[:, 0]
+                topk_slot = token_ids[:, 1]
+                per_pos[e] = token_positions
+                per_tokens[e] = reshaped_input[token_positions]
+                per_weights[e] = expert_weight[token_positions, topk_slot]
+            used = [i for i, t in enumerate(per_tokens) if t is not None]
+            if len(used) == 0:
+                pass
+            else:
+                tokens_list = [per_tokens[i] for i in used]
+                weights_list = [per_weights[i] for i in used]
+                lengths = [t.shape[0] for t in tokens_list]
+                U = len(tokens_list)
+                L = max(lengths)
+                H = hidden_size
+                tokens_padded = torch.zeros((U, L, H), device=device, dtype=dtype)
+                weights_padded = torch.zeros((U, L), device=device, dtype=weights_list[0].dtype)
+                for idx, t in enumerate(tokens_list):
+                    n = t.shape[0]
+                    tokens_padded[idx, :n] = t
+                    weights_padded[idx, :n] = weights_list[idx]
+                l1, l2 = [], []
+                for i in used:
+                    expert = self.experts[i]
+                    if isinstance(expert, (asyncio.Future, concurrent.futures.Future)):
+                        expert = expert.result()
+                    expert = expert.to(device)
+                    l1.append(expert.gate_and_up_proj)
+                    l2.append(expert.down_proj)
+                compute_device = hidden_states.device
+                l1 = [m.to(compute_device) for m in l1]
+                l2 = [m.to(compute_device) for m in l2]
+                W1 = torch.stack([m.weight for m in l1], dim=0)
+                W2 = torch.stack([m.weight for m in l2], dim=0)
+                W1_T = W1.transpose(1, 2)
+                W2_T = W2.transpose(1, 2)
+                while not enough_vram(3*(1024 ** 3)):
+                    event = self.moe_lru.last_offload_event
+                    if event is not None and not event.query():
+                        time.sleep(0.001)
+                x = torch.bmm(tokens_padded, W1_T)
+                x = F.silu(x)
+                x1, x2 = x.chunk(2, dim=2)
+                out_padded = torch.bmm(x1 * F.silu(x2), W2_T)
+                out_padded = out_padded * weights_padded.unsqueeze(-1)
+                for idx, orig_expert_idx in enumerate(used):
+                    pos = per_pos[orig_expert_idx]
+                    n = lengths[idx]
+                    out_i = out_padded[idx, :n]
+                    combined_output.index_add_(0, pos.to(device), out_i.to(combined_output.dtype))
+                del tokens_padded, weights_padded, W1, W2, W1_T, W2_T, x, x1, x2, out_padded

             combined_output = combined_output.reshape(bsz, seq_len, hidden_size)
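
The rewritten MoE path gathers each expert's tokens, right-pads them into a single (experts, max_tokens, hidden) tensor, and replaces the per-expert Python loop with one batched torch.bmm per projection; padded slots carry zero routing weights, so they do not affect the combined output. A self-contained sketch of that padded grouped matmul (shapes and names are illustrative, the SwiGLU details are omitted):

import torch

def padded_expert_matmul(tokens_by_expert, expert_weights):
    """tokens_by_expert: list of (n_i, H) tensors, one per active expert.
    expert_weights:      (E, H, H_out) stacked weight matrices.
    Returns a list of (n_i, H_out) outputs computed with a single bmm."""
    device = expert_weights.device
    lengths = [t.shape[0] for t in tokens_by_expert]
    E, L, H = len(tokens_by_expert), max(lengths), tokens_by_expert[0].shape[1]
    padded = torch.zeros((E, L, H), device=device, dtype=expert_weights.dtype)
    for i, t in enumerate(tokens_by_expert):
        padded[i, : t.shape[0]] = t
    out = torch.bmm(padded, expert_weights)  # (E, L, H_out) in one kernel launch
    return [out[i, :n] for i, n in enumerate(lengths)]

# usage sketch: 3 experts with uneven token counts
torch.manual_seed(0)
tokens = [torch.randn(n, 8) for n in (5, 2, 7)]
W = torch.randn(3, 8, 16)
outs = padded_expert_matmul(tokens, W)
print([o.shape for o in outs])  # [torch.Size([5, 16]), torch.Size([2, 16]), torch.Size([7, 16])]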
@@ -863,6 +873,7 @@ class HunyuanImage3Model(nn.Module):
         self.shared_tensor = None
         self.moe_lru = moe_lru
+        self.additional_layers_set = False

     def forward(
         self,
@@ -885,7 +896,8 @@ class HunyuanImage3Model(nn.Module):
         next_decoder_cache = None
         next_layers = 0
-        sparse_interval = max(1, len(self.layers) // 3)
+        additional_layers = torch.cuda.mem_get_info()[0] // (MOE_LAYER_SIZE * 2)
+        sparse_interval = max(1, len(self.layers) // additional_layers)

         if len(self.layers[0].mlp.experts) == 0:
             experts = [LazyMoELoader(self.moe_lru, self.config) for _ in range(64)]
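
torch.cuda.mem_get_info() returns the free and total bytes on the current device, so instead of a fixed interval the patch derives sparse_interval from how many extra MoE layers, at roughly MOE_LAYER_SIZE bytes each with 2x headroom, currently fit in free VRAM. A small sketch of that sizing arithmetic (the constant value and the zero-guard below are assumptions, not taken from the patch):

import torch

MOE_LAYER_SIZE = 6 * 1024**3  # hypothetical per-layer footprint in bytes

def plan_sparse_interval(num_layers: int) -> int:
    if torch.cuda.is_available():
        free_bytes, _total = torch.cuda.mem_get_info()
    else:
        free_bytes = 24 * 1024**3  # fallback so the sketch runs without a GPU
    # guard against zero so the integer division below cannot fail
    additional_layers = max(1, free_bytes // (MOE_LAYER_SIZE * 2))
    return max(1, num_layers // additional_layers)

print(plan_sparse_interval(32))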
@@ -897,12 +909,11 @@ class HunyuanImage3Model(nn.Module):
                 experts = [LazyMoELoader(self.moe_lru, self.config) for _ in range(64)]
                 self.layers[layer_idx+1].mlp.experts = [expert._schedule_disk_load(layer_idx+1, i) for i, expert in enumerate(experts)]
-            if (layer_idx % sparse_interval == 0) and layer_idx > sparse_interval:
-                if len(self.layers[next_layers].mlp.experts) > 0: # for testing
-                    raise ValueError("Problem with offloading")
-                experts = [LazyMoELoader(self.moe_lru, self.config) for _ in range(64)]
-                self.layers[next_layers].mlp.experts = [expert._schedule_disk_load(next_layers, i) for i, expert in enumerate(experts)]
-                next_layers += 1
+            if not self.additional_layers_set:
+                if (layer_idx % sparse_interval == 0) and layer_idx >= sparse_interval:
+                    experts = [LazyMoELoader(self.moe_lru, self.config) for _ in range(64)]
+                    self.layers[next_layers].mlp.experts = [expert._schedule_disk_load(next_layers, i) for i, expert in enumerate(experts)]
+                    next_layers += 1

             with torch.no_grad():
                 layer_outputs = decoder_layer(
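
_schedule_disk_load hands back futures, so the experts of upcoming layers can stream in from disk while the current layer is still computing, and the MoE forward only blocks in .result() if a load has not finished yet. A stripped-down sketch of that prefetch-with-futures pattern (the loader function here is a placeholder, not the repository's implementation):

from concurrent.futures import Future, ThreadPoolExecutor

_pool = ThreadPoolExecutor(max_workers=4)

def load_expert_from_disk(layer_idx: int, expert_idx: int) -> dict:
    # placeholder for reading one expert's weights from a checkpoint shard
    return {"layer": layer_idx, "expert": expert_idx}

def schedule_disk_load(layer_idx: int, expert_idx: int) -> Future:
    # returns immediately; the read happens on a worker thread
    return _pool.submit(load_expert_from_disk, layer_idx, expert_idx)

# prefetch all experts of the next layer, consume them later
futures = [schedule_disk_load(1, i) for i in range(4)]
experts = [f.result() for f in futures]  # blocks only if a load is still in flight
print(experts[0])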
@@ -918,11 +929,15 @@ class HunyuanImage3Model(nn.Module):
                 )

             if layer_idx >= 0:
-                asyncio.run_coroutine_threadsafe(
-                    self.moe_lru._async_offload_to_cpu(layer_idx),
-                    self.moe_lru._loop
-                )
-                self.layers[layer_idx].mlp.experts = []
+                if self.additional_layers_set and layer_idx <= self.additional_layers_set:
+                    pass
+                else:
+                    torch.cuda.synchronize()
+                    asyncio.run_coroutine_threadsafe(
+                        self.moe_lru._async_offload_to_cpu(layer_idx),
+                        self.moe_lru._loop
+                    )
+                    self.layers[layer_idx].mlp.experts = []

             hidden_states = layer_outputs[0]
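
The offload side relies on a background asyncio event loop running in its own thread; asyncio.run_coroutine_threadsafe queues the CPU-offload coroutine there, and the newly added torch.cuda.synchronize() makes sure no kernel is still reading the expert weights when the copy starts. A minimal fire-and-forget sketch of that pattern (the offload_layer coroutine is a stand-in for _async_offload_to_cpu):

import asyncio
import threading
import torch

# background loop thread, analogous to moe_lru._loop in the patch
loop = asyncio.new_event_loop()
threading.Thread(target=loop.run_forever, daemon=True).start()

async def offload_layer(layer_idx: int) -> None:
    await asyncio.sleep(0)  # placeholder for copying expert weights to pinned CPU memory
    print(f"offloaded layer {layer_idx}")

def schedule_offload(layer_idx: int):
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # ensure compute that uses these weights has finished
    # returns a concurrent.futures.Future; the forward pass does not wait on it
    return asyncio.run_coroutine_threadsafe(offload_layer(layer_idx), loop)

schedule_offload(0).result(timeout=1)  # waiting here only to make the demo deterministic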
@@ -932,7 +947,7 @@ class HunyuanImage3Model(nn.Module):
         next_cache = None
         if use_cache:
             next_cache = next_decoder_cache
+        self.additional_layers_set = True

         return tuple(v for v in [hidden_states, next_cache] if v is not None)