diff --git a/comfy/ldm/hunyuan_image_3/model.py b/comfy/ldm/hunyuan_image_3/model.py
index 750928f03..a09024ca4 100644
--- a/comfy/ldm/hunyuan_image_3/model.py
+++ b/comfy/ldm/hunyuan_image_3/model.py
@@ -493,42 +493,44 @@ class MoELRUCache(nn.Module):
         self.last_offload_event = None
         self._loop = asyncio.new_event_loop()
+        self._gpu_sem = asyncio.Semaphore(1) # maybe 2
         threading.Thread(target=self._loop.run_forever, daemon=True).start()
 
     async def _async_offload_to_cpu(self, layer_idx):
         # async offload from gpu (removed)
-        num_experts = 64
-        moe_group = [(layer_idx * num_experts + i, self.gpu_cache[layer_idx * num_experts + i])
-                     for i in range(num_experts)
-                     if (layer_idx * num_experts + i) in self.gpu_cache]
-        event = torch.cuda.Event()
-
-        with torch.cuda.stream(self.offload_stream):
-            for index, moe in moe_group:
-                moe_cpu = HunyuanMLP(moe.config).to("cpu", non_blocking=True)
-                for (name, p_gpu), p_cpu in zip(moe.named_parameters(), moe_cpu.parameters()):
-                    if p_gpu.device.type == "meta":
-                        continue
-                    with torch.no_grad():
-                        p_cpu.data = torch.empty_like(p_gpu, device="cpu", pin_memory=True)
-                        p_cpu.copy_(p_gpu, non_blocking=True)
+        async with self._gpu_sem:
+            num_experts = 64
+            moe_group = [(layer_idx * num_experts + i, self.gpu_cache[layer_idx * num_experts + i])
+                         for i in range(num_experts)
+                         if (layer_idx * num_experts + i) in self.gpu_cache]
+            event = torch.cuda.Event()
+
+            with torch.cuda.stream(self.offload_stream):
+                for index, moe in moe_group:
+                    moe_cpu = HunyuanMLP(moe.config).to("cpu", non_blocking=True)
+                    for (name, p_gpu), p_cpu in zip(moe.named_parameters(), moe_cpu.parameters()):
+                        if p_gpu.device.type == "meta":
+                            continue
+                        with torch.no_grad():
+                            p_cpu.data = torch.empty_like(p_gpu, device="cpu", pin_memory=True)
+                            p_cpu.copy_(p_gpu, non_blocking=True)
 
-                self.cpu_cache[index] = moe_cpu
+                    self.cpu_cache[index] = moe_cpu
 
-            self.offload_stream.record_event(event)
+                self.offload_stream.record_event(event)
 
-        self.last_offload_event = event
+            self.last_offload_event = event
 
-        def finalize_offload_layer():
-            event.synchronize()
-            for index, moe in moe_group:
-                moe.to("meta")
-                self.gpu_cache.pop(index, None)
-                del moe
-            torch.cuda.empty_cache()
+            def finalize_offload_layer():
+                event.synchronize()
+                for index, moe in moe_group:
+                    moe.to("meta")
+                    self.gpu_cache.pop(index, None)
+                    del moe
+                torch.cuda.empty_cache()
 
-        threading.Thread(target=finalize_offload_layer, daemon=True).start()
+            threading.Thread(target=finalize_offload_layer, daemon=True).start()
 
     async def _async_load_to_gpu(self, index, moe):
@@ -632,37 +634,29 @@ class LazyMoELoader(nn.Module):
                 getattr(model, name).data = tensor
         return model
 
-    def _register_expert_sync(self, layer_idx, expert_idx, moe_cpu):
-        self.cache.add_cpu(moe_cpu, (layer_idx * 64) + expert_idx)
-        asyncio.run_coroutine_threadsafe(
-            self.cache._async_load_to_gpu((layer_idx * 64) + expert_idx, moe_cpu),
-            self.cache._loop
-        )
-
     async def lazy_load_from_disk(self, layer_idx, expert_idx):
         loop = asyncio.get_event_loop()
-        async with self._semaphore:
-            moe_cpu = await loop.run_in_executor(None, self.lazy_init, layer_idx, expert_idx)
-            self._loop.call_soon_threadsafe(self._register_expert_sync, layer_idx, expert_idx, moe_cpu)
-            return moe_cpu
-
-    async def schedule_layer_load_progressive(self, layer_idx, num_experts = 64):
-        tasks = [asyncio.create_task(self.lazy_load_from_disk(layer_idx, i)) for i in range(num_experts)]
-        results = await asyncio.gather(*tasks, return_exceptions=False)
-        return results
-
-    def schedule_layer_load(self, layer_idx, num_experts = 64):
-        fut = asyncio.run_coroutine_threadsafe(
-            self.schedule_layer_load_progressive(layer_idx, num_experts),
-            self._loop
-        )
-        return fut
-
-    async def schedule_layer_load_(self, layer_idx):
-        tasks = [self.lazy_load_from_disk(layer_idx, i) for i in range(64)]
-        experts = await asyncio.gather(*tasks)
-        return experts
+        return await loop.run_in_executor(None, self.lazy_init, layer_idx, expert_idx)
+
+    def _schedule_disk_load(self, layer_idx, expert_idx):
+
+        coro = self.lazy_load_from_disk(layer_idx, expert_idx)
+        future = asyncio.run_coroutine_threadsafe(coro, self._loop)
+
+        def _on_disk_loaded(fut):
+            moe_cpu = fut.result()
+            def _add_cpu_in_main_thread():
+                self.cache.add_cpu(moe_cpu, (layer_idx * 64) + expert_idx)
+
+                asyncio.run_coroutine_threadsafe(
+                    self.cache._async_load_to_gpu((layer_idx * 64) + expert_idx, moe_cpu),
+                    self.cache._loop
+                )
+            threading.Thread(target=_add_cpu_in_main_thread, daemon=True).start()
+
+        future.add_done_callback(_on_disk_loaded)
+        return future
+
 def enough_vram(required_bytes):
     free, total = torch.cuda.mem_get_info()
     return free > required_bytes
@@ -748,7 +742,6 @@ class HunyuanMoE(nn.Module):
 
         if event is not None and not event.query():
            time.sleep(0.001)
-
        combine_weights_used = combine_weights[:, used_indices, :]
 
        combined_output = torch.einsum("suc,ucm->sm",
@@ -900,6 +893,7 @@ class HunyuanImage3Model(nn.Module):
         self.padding_idx = 128009
         self.vocab_size = 133120
         self.config = config
+        self.wte = nn.Embedding(133120, config["hidden_size"], self.padding_idx)
         self.layers = nn.ModuleList(
             [HunyuanImage3DecoderLayer(config, layer_idx, moe_lru = moe_lru) for layer_idx in range(config["num_hidden_layers"])]
         )
@@ -932,20 +926,27 @@ class HunyuanImage3Model(nn.Module):
         sparse_interval = max(1, len(self.layers) // additional_layers)
 
         if len(self.layers[0].mlp.experts) == 0:
-            self.layers[0].mlp.experts = self.moe_loader.schedule_layer_load(0)
+            experts = [LazyMoELoader(self.moe_lru, self.config) for _ in range(64)]
+            self.layers[0].mlp.experts = [expert._schedule_disk_load(0, i) for i, expert in enumerate(experts)]
 
         for layer_idx, decoder_layer in enumerate(self.layers):
 
-            if layer_idx + 1 < len(self.layers) and len(self.layers[layer_idx + 1].mlp.experts) == 0: # not loaded
-                self.layers[layer_idx+1].mlp.experts = self.moe_loader.schedule_layer_load(layer_idx + 1)
+            # maybe the second layer loading should depend on how much gpu memory is there
+            next_layer = layer_idx + 1 if isinstance(self.layers[layer_idx + 1].mlp.experts, list) else layer_idx + 2
+            second_next_layer = next_layer + 1 if isinstance(self.layers[layer_idx + 2].mlp.experts, list) else next_layer + 2
 
-            if layer_idx + 2 < len(self.layers) and len(self.layers[layer_idx + 2].mlp.experts) == 0: # load first and second layers
-                self.layers[layer_idx+2].mlp.experts = self.moe_loader.schedule_layer_load(layer_idx + 2)
-
-            if not self.additional_layers_set:
-                if (layer_idx % sparse_interval == 0) and layer_idx >= sparse_interval:
-                    self.layers[next_layers].mlp.experts = self.moe_loader.schedule_layer_load(next_layers)
-                    next_layers += 1
+            if next_layer < len(self.layers) and len(self.layers[next_layer].mlp.experts) == 0: # not loaded
+                experts = [LazyMoELoader(self.moe_lru, self.config) for _ in range(64)]
+                self.layers[next_layer].mlp.experts = [expert._schedule_disk_load(next_layer, i) for i, expert in enumerate(experts)]
+
+            if second_next_layer < len(self.layers) and len(self.layers[second_next_layer].mlp.experts) == 0: # not loaded
+                experts = [LazyMoELoader(self.moe_lru, self.config) for _ in range(64)]
+                self.layers[second_next_layer].mlp.experts = [expert._schedule_disk_load(second_next_layer, i) for i, expert in enumerate(experts)]
+
+            if (layer_idx % sparse_interval == 0) and layer_idx > sparse_interval:
+                experts = [LazyMoELoader(self.moe_lru, self.config) for _ in range(64)]
+                self.layers[next_layers].mlp.experts = [expert._schedule_disk_load(next_layers, i) for i, expert in enumerate(experts)]
+                next_layers += 1
 
             with torch.no_grad():
                 layer_outputs = decoder_layer(
@@ -1020,7 +1021,8 @@ class HunyuanImage3ForCausalMM(nn.Module):
 
         self.first_step = True
         self.kv_cache = None
-        self.token_dims = ()
+        self.encode_tok = None
+        self.special_tok = None
 
     @staticmethod
     def get_pos_emb(custom_pos_emb, position_ids):
@@ -1043,24 +1045,36 @@ class HunyuanImage3ForCausalMM(nn.Module):
         joint_image, cond_vae_image_mask, inputs_embeds, uncond_joint, uncond_vae_mask, uncond_inputs = condition.unbind()
+        bsz, seq_len, n_embd = inputs_embeds.shape
 
         cond_exists = (joint_image[:, 0, :] != -100.0).any(dim=1).any()
+        height, width = x.size(2) * 16, x.size(3) * 16
 
         gen_timestep_scatter_index = 4
+
+        def fn(string, func = self.encode_tok):
+            return self.model.wte(torch.tensor(func(string) if not isinstance(func, dict) else func[string], device=inputs_embeds.device))\
+                .unsqueeze(0).expand(bsz, -1, -1)
 
         if cond_exists:
             with torch.no_grad():
-                joint_image[:, 2:3, :] = x[:, 2:3, :] # updates image ratio
-
-        if self.first_step:
-            token_height, token_width = x[:, -2:, 0].tolist()[0]
-            self.token_dims = (int(token_height), int(token_width))
-            x = x[:, :-2, :]
-        else:
-            token_height, token_width = self.token_dims
+                joint_image[:, 2:3, :] = fn(f"", self.special_tok) # updates image ratio
 
         img_slices = []
-        bsz, seq_len, n_embd = inputs_embeds.shape
+        cond_timestep = torch.zeros(x.size(0))
+        t_emb = self.time_embed(timestep)
+
+        if self.first_step:
+            x, token_height, token_width = self.patch_embed(x, t_emb)
+            x = torch.cat([fn(""), fn("", func = self.special_tok), fn(f"", self.special_tok), fn("", self.special_tok), x, fn("")], dim = 1)
+            x[:, gen_timestep_scatter_index:gen_timestep_scatter_index+1, :] = self.timestep_emb(timestep.reshape(-1)).reshape(bsz, -1, n_embd)
+        else:
+            x, token_height, token_width = self.patch_embed(x, t_emb)
+            timestep_emb = self.timestep_emb(timestep).reshape(bsz, -1, n_embd)
+            inputs_embeds = torch.cat([timestep_emb, x], dim=1)
+
+        input_args = [inputs_embeds, x] if self.first_step else [inputs_embeds]
+
         for i in range(x.size(0)):
             gen_offset = seq_len + x.size(1)
             if cond_exists:
@@ -1085,27 +1099,13 @@ class HunyuanImage3ForCausalMM(nn.Module):
             rope_img = [(img_s[0], (token_height, token_width))]
             rope_image_info = [rope_img if len(joint_slices_i) == 0 else rope_img + [(img_s[1], (384 // 16, 384 // 16)), (img_s[2], (256 // 16, 256 // 16))]]
 
-        cond_timestep = torch.zeros(inputs_embeds.size(0))
-        t_emb = self.time_embed(cond_timestep)
-
-
-        if self.first_step:
-            x[:, gen_timestep_scatter_index:gen_timestep_scatter_index+1, :] = self.timestep_emb(timestep.reshape(-1)).reshape(bsz, -1, n_embd)
-        else:
-            t_emb = self.time_embed(timestep)
-            x[:, 3:-1], token_height, token_width = self.patch_embed(x[:, 3:-1], t_emb)
-            timestep_emb = self.timestep_emb(timestep).reshape(bsz, -1, n_embd)
-            inputs_embeds = torch.cat([timestep_emb, x], dim=1)
-
-        input_args = [inputs_embeds, x] if self.first_step else [inputs_embeds]
-
         #/////////////
         # cond_vae_images
         # cond_timestep_scatter_index
         if cond_exists:
             with torch.no_grad():
-                joint_image[:, 3:4, :] = self.timestep_emb(timestep.reshape(-1)).reshape(bsz, -1, n_embd)
+                joint_image[:, 3:4, :] = self.timestep_emb(cond_timestep.reshape(-1)).reshape(bsz, -1, n_embd)
 
             inputs_embeds = torch.cat([*input_args, joint_image], dim = 1)
 
         else:
diff --git a/comfy/supported_models.py b/comfy/supported_models.py
index a3256aa35..43bc4b2e4 100644
--- a/comfy/supported_models.py
+++ b/comfy/supported_models.py
@@ -1342,6 +1342,11 @@ class HunyuanImage3(supported_models_base.BASE):
         state_dict["text_encoders.wte"] = state_dict["model.model.wte"]
         state_dict.pop("model.model.wte", None)
         model = model_base.HunyuanImage3(self, device = device)
+
+        temp_tokenizer = comfy.text_encoders.hunyuan_image.HunyuanImage3Tokenizer()
+        model.encode_tok = temp_tokenizer.tokenizer.convert_tokens_to_ids
+        model.special_tok = temp_tokenizer.tokenizer.added_tokens_encoder
+
         return model
 
     def clip_target(self, state_dict={}):
         return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_image.HunyuanImage3Tokenizer, comfy.text_encoders.hunyuan_image.HunyuanImage3)
diff --git a/comfy_extras/nodes_hunyuan_image.py b/comfy_extras/nodes_hunyuan_image.py
index 1989e4b3d..7d765d53a 100644
--- a/comfy_extras/nodes_hunyuan_image.py
+++ b/comfy_extras/nodes_hunyuan_image.py
@@ -22,34 +22,13 @@ class EmptyLatentHunyuanImage3(io.ComfyNode):
                 io.Int.Input("height", min = 1, default = 512),
                 io.Int.Input("width", min = 1, default = 512),
                 io.Int.Input("batch_size", min = 1, max = 48_000, default = 1),
-                io.Clip.Input("clip"),
-                io.Model.Input("model")
             ],
             outputs=[io.Latent.Output(display_name="latent")]
         )
 
     @classmethod
-    def execute(cls, height, width, batch_size, clip, model):
-        encode_fn = clip.tokenizer.tokenizer.convert_tokens_to_ids
-        special_fn = clip.tokenizer.tokenizer.added_tokens_encoder
-
-        word_embed = clip.wte
-        patch_embed = model.patch_embed
-        t_embed = model.time_embed
-
+    def execute(cls, height, width, batch_size):
         height, width = get_target_size(height, width)
         latent = torch.randn(batch_size, 32, int(height) // 16, int(width) // 16, device=comfy.model_management.intermediate_device())
-
-        latent, tk_height, tk_width = patch_embed(latent, t_embed(torch.tensor([0]).repeat(batch_size)))
-
-        def tk_fn(token):
-            return torch.tensor([token], device = latent.device, dtype = latent.dtype).unsqueeze(1).expand(batch_size, 1, latent.size(-1))
-
-        def fn(string, func = encode_fn):
-            return word_embed(torch.tensor(func(string) if not isinstance(func, dict) else func[string], device=comfy.model_management.intermediate_device()))\
-                .unsqueeze(0).expand(batch_size, -1, -1)
-
-        latent = torch.cat([fn(""), fn("", func = special_fn), fn(f"", special_fn), fn("", special_fn), latent, fn("")], dim = 1)
-        latent = torch.cat([latent, tk_fn(tk_height), tk_fn(tk_width)], dim = 1)
 
         return io.NodeOutput({"samples": latent, "type": "hunyuan_image_3"}, )
 
 class HunyuanImage3Conditioning(io.ComfyNode):