diff --git a/comfy/ldm/hunyuan_image_3/model.py b/comfy/ldm/hunyuan_image_3/model.py
index 11eb29d9e..60ab43340 100644
--- a/comfy/ldm/hunyuan_image_3/model.py
+++ b/comfy/ldm/hunyuan_image_3/model.py
@@ -548,7 +548,10 @@ class MoELRUCache(nn.Module):
 
         for _, p in moe_cpu.named_parameters():
             if not p.is_pinned():
-                p.data = p.data.pin_memory()
+                if p.device.type == "cpu":
+                    p.data = p.data.pin_memory()
+                else:
+                    return
 
         self.cpu_cache[index] = moe_cpu
         self.cpu_cache.move_to_end(index)
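Note on the MoELRUCache hunk: pin_memory() is only defined for CPU tensors, so the guard above skips parameters that already sit on an accelerator (or on the meta device) instead of raising. Below is a minimal standalone sketch of that rule, not taken from the patch; the helper name pin_cpu_params is made up for illustration.

import torch
import torch.nn as nn

def pin_cpu_params(module: nn.Module) -> None:
    # Pin only parameters that actually live in pageable CPU memory;
    # calling pin_memory() on a CUDA or meta tensor would raise.
    for _, p in module.named_parameters():
        if p.device.type == "cpu" and not p.is_pinned():
            p.data = p.data.pin_memory()

if __name__ == "__main__" and torch.cuda.is_available():
    layer = nn.Linear(8, 8)           # parameters start on CPU
    pin_cpu_params(layer)
    print(layer.weight.is_pinned())   # True

Pinned (page-locked) host memory is what lets the cache's later asynchronous CPU/GPU copies overlap with compute instead of blocking.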
@@ -643,74 +646,81 @@ class HunyuanMoE(nn.Module):
 
         with torch.cuda.nvtx.range("MoE"):
             expert_weight, expert_index = self.gate(hidden_states)
-
-            combined_output = torch.zeros_like(reshaped_input)
-            experts_list = [(i, expert) for i, expert in enumerate(self.experts)]
-
-            per_pos, per_tokens, per_weights = [], [], []
-            for e, _ in experts_list:
+            device = hidden_states.device
+            dtype = reshaped_input.dtype
+
+            combined_output = torch.zeros_like(reshaped_input, device=device, dtype=dtype)
+
+            per_pos = [None] * self.num_experts
+            per_tokens = [None] * self.num_experts
+            per_weights = [None] * self.num_experts
+            for e in range(self.num_experts):
                 token_mask = (expert_index == e)
-
                 token_ids = token_mask.nonzero(as_tuple=False)
+                if token_ids.numel() == 0:
+                    continue
                 token_positions = token_ids[:, 0]
                 topk_slot = token_ids[:, 1]
+                per_pos[e] = token_positions
+                per_tokens[e] = reshaped_input[token_positions]
+                per_weights[e] = expert_weight[token_positions, topk_slot]
+
+            used = [i for i, t in enumerate(per_tokens) if t is not None]
+            if len(used) == 0:
+                pass
+            else:
+                tokens_list = [per_tokens[i] for i in used]
+                weights_list = [per_weights[i] for i in used]
+                lengths = [t.shape[0] for t in tokens_list]
+                U = len(tokens_list)
+                L = max(lengths)
+                H = hidden_size
+
+                tokens_padded = torch.zeros((U, L, H), device=device, dtype=dtype)
+                weights_padded = torch.zeros((U, L), device=device, dtype=weights_list[0].dtype)
+                for idx, t in enumerate(tokens_list):
+                    n = t.shape[0]
+                    tokens_padded[idx, :n] = t
+                    weights_padded[idx, :n] = weights_list[idx]
+
+                l1, l2 = [], []
+                for i in used:
+                    expert = self.experts[i]
+                    if isinstance(expert, (asyncio.Future, concurrent.futures.Future)):
+                        expert = expert.result()
+                    expert = expert.to(device)
+                    l1.append(expert.gate_and_up_proj)
+                    l2.append(expert.down_proj)
-                tokens = reshaped_input[token_positions]
-                weights = expert_weight[token_positions, topk_slot]
+                compute_device = hidden_states.device
+                l1 = [m.to(compute_device) for m in l1]
+                l2 = [m.to(compute_device) for m in l2]
-                per_pos.append(token_positions)
-                per_tokens.append(tokens)
-                per_weights.append(weights)
-
-            lengths = [t.shape[0] for t in per_tokens]
-            E = len(per_tokens)
-            L = max(lengths)
-            tokens_padded = torch.zeros((E, L, hidden_size), device=hidden_states.device, dtype=reshaped_input.dtype)
-            weights_padded = torch.zeros((E, L), device=hidden_states.device, dtype=per_weights[0].dtype)
-            for i, t in enumerate(per_tokens):
-                tokens_padded[i, : t.shape[0]] = t
-                weights_padded[i, : t.shape[0]] = per_weights[i]
-
-            l1, l2 = [], []
-            for _, expert in experts_list:
-                if isinstance(expert, (asyncio.Future, concurrent.futures.Future)):
-                    expert = expert.result()
-                l1.append(expert.gate_and_up_proj)
-                l2.append(expert.down_proj)
-
-            W1 = torch.stack([l.weight for l in l1]).to(hidden_states.device)
-            W2 = torch.stack([l.weight for l in l2]).to(hidden_states.device)
-
-            W1_T = W1.transpose(1, 2)
-            W2_T = W2.transpose(1, 2)
-
-            # wait for enough vram for the computations
-            while not enough_vram(5*(1024 ** 3)):
-                event = self.moe_lru.last_offload_event
-                if event is not None and not event.query():
-                    time.sleep(0.001)
-
-            x = torch.bmm(tokens_padded, W1_T)
-            x = F.silu(x)
-
-            x1, x2 = x.chunk(2, dim=2)
-            out_padded = torch.bmm(x1 * F.silu(x2), W2_T)
-
-            out_padded = out_padded * weights_padded.unsqueeze(-1)
-
-            for i, token_positions in enumerate(per_pos):
-                Ni = lengths[i]
-                out_i = out_padded[i, :Ni]
-                combined_output.to(hidden_states.dtype).index_add_(0, token_positions.to(hidden_states.device), out_i.to(hidden_states.dtype))
-
-            #dispatched_input = torch.einsum("sec,sm->ecm", dispatch_mask.type_as(hidden_states), reshaped_input)
-            #chunks = dispatched_input.chunk(self.num_experts, dim=0)
-            #expert_outputs = []
-            #for chunk, expert in zip(chunks, self.experts):
-            #    expert_outputs.append(expert(chunk))
-
-            #expert_output = torch.cat(expert_outputs, dim=0)
-            #combined_output = torch.einsum("sec,ecm->sm", combine_weights.type_as(hidden_states), expert_output)
+                W1 = torch.stack([m.weight for m in l1], dim=0)
+                W2 = torch.stack([m.weight for m in l2], dim=0)
+
+                W1_T = W1.transpose(1, 2)
+                W2_T = W2.transpose(1, 2)
+
+                while not enough_vram(3*(1024 ** 3)):
+                    event = self.moe_lru.last_offload_event
+                    if event is not None and not event.query():
+                        time.sleep(0.001)
+
+                x = torch.bmm(tokens_padded, W1_T)
+                x = F.silu(x)
+                x1, x2 = x.chunk(2, dim=2)
+                out_padded = torch.bmm(x1 * F.silu(x2), W2_T)
+
+                out_padded = out_padded * weights_padded.unsqueeze(-1)
+
+                for idx, orig_expert_idx in enumerate(used):
+                    pos = per_pos[orig_expert_idx]
+                    n = lengths[idx]
+                    out_i = out_padded[idx, :n]
+                    combined_output.index_add_(0, pos.to(device), out_i.to(combined_output.dtype))
+
+                del tokens_padded, weights_padded, W1, W2, W1_T, W2_T, x, x1, x2, out_padded
 
         combined_output = combined_output.reshape(bsz, seq_len, hidden_size)
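Note on the HunyuanMoE hunk above: instead of one matmul per expert, the rewritten forward groups routed tokens per expert, pads each group to a common length, runs all used experts through two torch.bmm calls over the stacked gate_and_up_proj / down_proj weights, and scatters the results back with index_add_. Below is a minimal, self-contained sketch of that dispatch pattern, not taken from the patch: the Linear experts, random top-1 routing, and the plain a * silu(b) gating are illustrative assumptions, and a naive per-expert loop is kept as a reference check.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
torch.set_grad_enabled(False)  # inference-only sketch
num_experts, hidden, inter, n_tokens = 4, 8, 16, 10
x = torch.randn(n_tokens, hidden)
expert_index = torch.randint(num_experts, (n_tokens, 1))   # top-1 routing slots
expert_weight = torch.rand(n_tokens, 1)                     # per-token routing weights

up = [torch.nn.Linear(hidden, 2 * inter, bias=False) for _ in range(num_experts)]
down = [torch.nn.Linear(inter, hidden, bias=False) for _ in range(num_experts)]

# 1) Group token positions, tokens and routing weights per expert.
used, per_pos, per_tok, per_w = [], [], [], []
for e in range(num_experts):
    ids = (expert_index == e).nonzero(as_tuple=False)
    if ids.numel() == 0:
        continue
    pos, slot = ids[:, 0], ids[:, 1]
    used.append(e)
    per_pos.append(pos)
    per_tok.append(x[pos])
    per_w.append(expert_weight[pos, slot])

# 2) Pad every group to the longest one so all experts fit in one batched GEMM.
U, L = len(used), max(t.shape[0] for t in per_tok)
tok_pad = torch.zeros(U, L, hidden)
w_pad = torch.zeros(U, L)
for i, t in enumerate(per_tok):
    tok_pad[i, : t.shape[0]] = t
    w_pad[i, : t.shape[0]] = per_w[i]

# 3) Two bmm calls over stacked expert weights instead of one matmul per expert.
W1 = torch.stack([up[e].weight for e in used])     # (U, 2*inter, hidden)
W2 = torch.stack([down[e].weight for e in used])   # (U, hidden, inter)
h = torch.bmm(tok_pad, W1.transpose(1, 2))
a, b = h.chunk(2, dim=2)
out_pad = torch.bmm(a * F.silu(b), W2.transpose(1, 2)) * w_pad.unsqueeze(-1)

# 4) Scatter the unpadded rows back to their original token positions.
combined = torch.zeros_like(x)
for i, pos in enumerate(per_pos):
    combined.index_add_(0, pos, out_pad[i, : pos.shape[0]])

# Reference: the sequential per-expert loop gives the same result.
ref = torch.zeros_like(x)
for e, pos, w in zip(used, per_pos, per_w):
    h = up[e](x[pos])
    a, b = h.chunk(2, dim=1)
    ref.index_add_(0, pos, down[e](a * F.silu(b)) * w.unsqueeze(-1))
print(torch.allclose(combined, ref, atol=1e-5))     # -> True

Padding to the longest group wastes some compute on zero rows, but it replaces up to num_experts small GEMMs and kernel launches with two batched ones.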
@@ -863,6 +873,7 @@ class HunyuanImage3Model(nn.Module):
 
         self.shared_tensor = None
         self.moe_lru = moe_lru
+        self.additional_layers_set = False
 
     def forward(
         self,
@@ -885,7 +896,8 @@
         next_decoder_cache = None
         next_layers = 0
 
-        sparse_interval = max(1, len(self.layers) // 3)
+        additional_layers = max(1, torch.cuda.mem_get_info()[0] // (MOE_LAYER_SIZE * 2))  # at least 1 to avoid dividing by zero below
+        sparse_interval = max(1, len(self.layers) // additional_layers)
 
         if len(self.layers[0].mlp.experts) == 0:
             experts = [LazyMoELoader(self.moe_lru, self.config) for _ in range(64)]
@@ -896,13 +908,12 @@
             if layer_idx + 1 < len(self.layers) and len(self.layers[layer_idx + 1].mlp.experts) == 0: # not loaded
                 experts = [LazyMoELoader(self.moe_lru, self.config) for _ in range(64)]
                 self.layers[layer_idx+1].mlp.experts = [expert._schedule_disk_load(layer_idx+1, i) for i, expert in enumerate(experts)]
-
-            if (layer_idx % sparse_interval == 0) and layer_idx > sparse_interval:
-                if len(self.layers[next_layers].mlp.experts) > 0: # for testing
-                    raise ValueError("Problem with offloading")
-                experts = [LazyMoELoader(self.moe_lru, self.config) for _ in range(64)]
-                self.layers[next_layers].mlp.experts = [expert._schedule_disk_load(next_layers, i) for i, expert in enumerate(experts)]
-                next_layers += 1
+
+            if not self.additional_layers_set:
+                if (layer_idx % sparse_interval == 0) and layer_idx >= sparse_interval:
+                    experts = [LazyMoELoader(self.moe_lru, self.config) for _ in range(64)]
+                    self.layers[next_layers].mlp.experts = [expert._schedule_disk_load(next_layers, i) for i, expert in enumerate(experts)]
+                    next_layers += 1
 
             with torch.no_grad():
                 layer_outputs = decoder_layer(
@@ -918,11 +929,15 @@ class HunyuanImage3Model(nn.Module):
                 )
 
             if layer_idx >= 0:
-                asyncio.run_coroutine_threadsafe(
-                    self.moe_lru._async_offload_to_cpu(layer_idx),
-                    self.moe_lru._loop
-                )
-                self.layers[layer_idx].mlp.experts = []
+                if self.additional_layers_set and layer_idx <= self.additional_layers_set:
+                    pass
+                else:
+                    torch.cuda.synchronize()
+                    asyncio.run_coroutine_threadsafe(
+                        self.moe_lru._async_offload_to_cpu(layer_idx),
+                        self.moe_lru._loop
+                    )
+                    self.layers[layer_idx].mlp.experts = []
 
             hidden_states = layer_outputs[0]
 
@@ -932,7 +947,7 @@ class HunyuanImage3Model(nn.Module):
 
         next_cache = None
         if use_cache:
            next_cache = next_decoder_cache
-
+        self.additional_layers_set = True
        return tuple(v for v in [hidden_states, next_cache] if v is not None)
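Note on the HunyuanImage3Model hunks: the number of extra resident MoE layers is now sized from free VRAM (torch.cuda.mem_get_info) instead of a fixed len(self.layers) // 3 interval, their disk loads are spread across the first forward pass, and additional_layers_set is flagged afterwards so the pre-loads are not re-scheduled on later passes. A rough standalone illustration of that budget arithmetic follows; the MOE_LAYER_SIZE and free-VRAM figures are made-up example values, not taken from the patch.

# Illustration only: how a free-VRAM budget turns into preload trigger points.
MOE_LAYER_SIZE = 9 * (1024 ** 3)        # assumed ~9 GiB of expert weights per layer
free_vram = 54 * (1024 ** 3)            # e.g. torch.cuda.mem_get_info()[0]
num_layers = 32

additional_layers = max(1, free_vram // (MOE_LAYER_SIZE * 2))   # keep 2x headroom -> 3
sparse_interval = max(1, num_layers // additional_layers)        # 32 // 3 -> 10

# Every sparse_interval-th layer index (from sparse_interval onwards) schedules
# one more disk load, so the extra loads are spread over the pass.
preload_points = [i for i in range(num_layers)
                  if i % sparse_interval == 0 and i >= sparse_interval]
print(additional_layers, sparse_interval, preload_points)        # 3 10 [10, 20, 30]

Each trigger point loads the experts of layer next_layers (0, 1, 2, ...) rather than those of the triggering layer itself, so it is the lowest layers that end up staying resident.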