Return the KV cache for image generation

This commit is contained in:
Yousef Rafat 2025-11-01 23:06:11 +02:00
parent 10a17dc85d
commit ca119c44fb

View File

@ -832,6 +832,7 @@ class HunyuanImage3Attention(nn.Module):
def forward(
self,
hidden_states: torch.Tensor,
past_key_value,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
custom_pos_emb: Optional[Tuple[torch.FloatTensor]] = None,
@ -858,6 +859,11 @@ class HunyuanImage3Attention(nn.Module):
query_states = query_states.to(value_states.dtype)
key_states = key_states.to(value_states.dtype)
if past_key_value is not None:
cache_kwargs = {"cache_position": position_ids}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
query_states = query_states.to(key_states.dtype)
key_states = torch.repeat_interleave(key_states, dim=1, repeats = self.num_key_value_groups)
value_states = torch.repeat_interleave(value_states, dim=1, repeats = self.num_key_value_groups)
@ -870,7 +876,7 @@ class HunyuanImage3Attention(nn.Module):
attn_output = self.o_proj(attn_output)
return attn_output
return attn_output, past_key_value
class HunyuanImage3DecoderLayer(nn.Module):
def __init__(self, config, layer_idx: int, moe_lru=None):
@ -900,7 +906,7 @@ class HunyuanImage3DecoderLayer(nn.Module):
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states = self.self_attn(
hidden_states, past_key_value = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
@ -917,7 +923,7 @@ class HunyuanImage3DecoderLayer(nn.Module):
hidden_states = residual + hidden_states
outputs = (hidden_states,)
outputs = (hidden_states, past_key_value)
return outputs
@ -1039,6 +1045,9 @@ class HunyuanImage3ForCausalMM(nn.Module):
def forward(self, x, condition, timestep, **kwargs):
cond, uncond = condition[:4], condition[4:]
joint_image, cond_vae_image_mask, input_ids = cond[0], cond[1]
if self.kv_cache is None:
# TODO: should change when higgsv2 gets merged
self.kv_cache = HunyuanStaticCache(
@ -1049,10 +1058,8 @@ class HunyuanImage3ForCausalMM(nn.Module):
)
image_mask = torch.ones(x.size(1))
image_mask[:, :5] = torch.zeros(5); image_mask[:, -4:] = torch.zeros(4)
image_mask[:, :3] = torch.zeros(5); image_mask[:, -1] = torch.zeros(0)
gen_timestep_scatter_index = 4
cond, uncond = condition[:4], condition[4:]
joint_image, cond_vae_image_mask, input_ids = cond[0], cond[1]
joint_image[:, 2] = x[:, 2] # updates image ratio
position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=x.device)[None].expand(x.size(0), -1)
@ -1126,6 +1133,11 @@ class HunyuanImage3ForCausalMM(nn.Module):
)
hidden_states = outputs[0]
# safety no-op
past_key_value = outputs[1]
if past_key_value is not None:
self.kv_cache = past_key_value
hidden_states = hidden_states.to(input_ids.device)
diffusion_prediction = self.ragged_final_layer(
hidden_states, image_mask, timestep, token_h, token_w, self.first_step)