From ca119c44fb68d1bb694fdbf7a83d97c9095b8968 Mon Sep 17 00:00:00 2001
From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com>
Date: Sat, 1 Nov 2025 23:06:11 +0200
Subject: [PATCH] returned kv cache for image generation

---
 comfy/ldm/hunyuan_image_3/model.py | 24 ++++++++++++++++++------
 1 file changed, 18 insertions(+), 6 deletions(-)

diff --git a/comfy/ldm/hunyuan_image_3/model.py b/comfy/ldm/hunyuan_image_3/model.py
index ba2c1e90c..682fd8781 100644
--- a/comfy/ldm/hunyuan_image_3/model.py
+++ b/comfy/ldm/hunyuan_image_3/model.py
@@ -832,6 +832,7 @@ class HunyuanImage3Attention(nn.Module):
     def forward(
         self,
         hidden_states: torch.Tensor,
+        past_key_value=None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         custom_pos_emb: Optional[Tuple[torch.FloatTensor]] = None,
@@ -858,6 +859,11 @@ class HunyuanImage3Attention(nn.Module):
         query_states = query_states.to(value_states.dtype)
         key_states = key_states.to(value_states.dtype)
 
+        if past_key_value is not None:
+            cache_kwargs = {"cache_position": position_ids}
+            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+            query_states = query_states.to(key_states.dtype)
+
         key_states = torch.repeat_interleave(key_states, dim=1, repeats = self.num_key_value_groups)
         value_states = torch.repeat_interleave(value_states, dim=1, repeats = self.num_key_value_groups)
 
@@ -870,7 +876,7 @@ class HunyuanImage3Attention(nn.Module):
 
         attn_output = self.o_proj(attn_output)
 
-        return attn_output
+        return attn_output, past_key_value
 
 class HunyuanImage3DecoderLayer(nn.Module):
     def __init__(self, config, layer_idx: int, moe_lru=None):
@@ -900,7 +906,7 @@ class HunyuanImage3DecoderLayer(nn.Module):
         hidden_states = self.input_layernorm(hidden_states)
 
         # Self Attention
-        hidden_states = self.self_attn(
+        hidden_states, past_key_value = self.self_attn(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
             position_ids=position_ids,
@@ -917,7 +923,7 @@ class HunyuanImage3DecoderLayer(nn.Module):
 
         hidden_states = residual + hidden_states
 
-        outputs = (hidden_states,)
+        outputs = (hidden_states, past_key_value)
 
         return outputs
 
@@ -1039,6 +1045,9 @@ class HunyuanImage3ForCausalMM(nn.Module):
 
     def forward(self, x, condition, timestep, **kwargs):
 
+        cond, uncond = condition[:4], condition[4:]
+        joint_image, cond_vae_image_mask, input_ids = cond[0], cond[1]
+
         if self.kv_cache is None:
             # TODO: should change when higgsv2 gets merged
             self.kv_cache = HunyuanStaticCache(
@@ -1049,10 +1058,8 @@ class HunyuanImage3ForCausalMM(nn.Module):
             )
 
         image_mask = torch.ones(x.size(1))
-        image_mask[:, :5] = torch.zeros(5); image_mask[:, -4:] = torch.zeros(4)
+        image_mask[:, :3] = torch.zeros(3); image_mask[:, -1:] = torch.zeros(1)
         gen_timestep_scatter_index = 4
-        cond, uncond = condition[:4], condition[4:]
-        joint_image, cond_vae_image_mask, input_ids = cond[0], cond[1]
         joint_image[:, 2] = x[:, 2] # updates image ratio
 
         position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=x.device)[None].expand(x.size(0), -1)
@@ -1126,6 +1133,11 @@ class HunyuanImage3ForCausalMM(nn.Module):
         )
 
         hidden_states = outputs[0]
+        # safety no-op: the decoder normally returns the same cache object; keep the reference in case it was replaced
+        past_key_value = outputs[1]
+        if past_key_value is not None:
+            self.kv_cache = past_key_value
+
         hidden_states = hidden_states.to(input_ids.device)
         diffusion_prediction = self.ragged_final_layer(
             hidden_states, image_mask, timestep, token_h, token_w, self.first_step)
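
For reference, a minimal sketch of the caching pattern this patch wires up: the attention step pushes new keys/values into a cache via update(...) and hands the cache back, and the caller keeps the returned object across steps. SimpleKVCache and attention_step below are illustrative stand-ins, not the HunyuanStaticCache / HunyuanImage3Attention API; the only interface taken from the patch is cache.update(key_states, value_states, layer_idx, cache_kwargs).

    # Illustrative sketch only, not ComfyUI code.
    import torch

    class SimpleKVCache:
        """Per-layer key/value store that grows along the token dimension."""
        def __init__(self, num_layers):
            self.keys = [None] * num_layers
            self.values = [None] * num_layers

        def update(self, key_states, value_states, layer_idx, cache_kwargs=None):
            # Append the new keys/values for this layer and return the full history.
            if self.keys[layer_idx] is None:
                self.keys[layer_idx] = key_states
                self.values[layer_idx] = value_states
            else:
                self.keys[layer_idx] = torch.cat([self.keys[layer_idx], key_states], dim=2)
                self.values[layer_idx] = torch.cat([self.values[layer_idx], value_states], dim=2)
            return self.keys[layer_idx], self.values[layer_idx]

    def attention_step(q, k, v, layer_idx, past_key_value=None):
        # Mirrors the patched flow: update the cache (if any), attend over the
        # cached keys/values, then hand the cache back to the caller.
        if past_key_value is not None:
            k, v = past_key_value.update(k, v, layer_idx, {"cache_position": None})
        attn = torch.softmax(q @ k.transpose(-2, -1) / k.size(-1) ** 0.5, dim=-1) @ v
        return attn, past_key_value

    # The caller keeps the returned cache, as the model-level forward does with self.kv_cache.
    cache = SimpleKVCache(num_layers=1)
    for step in range(2):
        q = torch.randn(1, 4, 1, 8)   # (batch, heads, new tokens, head_dim)
        k = torch.randn(1, 4, 1, 8)
        v = torch.randn(1, 4, 1, 8)
        out, cache = attention_step(q, k, v, layer_idx=0, past_key_value=cache)
        print(step, out.shape, cache.keys[0].shape)  # cached keys grow: 1 token, then 2

Returning the cache from each layer, rather than only mutating it in place, is what lets the model-level forward persist outputs[1] into self.kv_cache across steps.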