mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-04 02:30:21 +08:00
returned kv cache for image generation
This commit is contained in:
parent
10a17dc85d
commit
ca119c44fb
@ -832,6 +832,7 @@ class HunyuanImage3Attention(nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
past_key_value,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
custom_pos_emb: Optional[Tuple[torch.FloatTensor]] = None,
|
||||
@ -858,6 +859,11 @@ class HunyuanImage3Attention(nn.Module):
|
||||
query_states = query_states.to(value_states.dtype)
|
||||
key_states = key_states.to(value_states.dtype)
|
||||
|
||||
if past_key_value is not None:
|
||||
cache_kwargs = {"cache_position": position_ids}
|
||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
query_states = query_states.to(key_states.dtype)
|
||||
|
||||
key_states = torch.repeat_interleave(key_states, dim=1, repeats = self.num_key_value_groups)
|
||||
value_states = torch.repeat_interleave(value_states, dim=1, repeats = self.num_key_value_groups)
|
||||
|
||||
@ -870,7 +876,7 @@ class HunyuanImage3Attention(nn.Module):
|
||||
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
return attn_output
|
||||
return attn_output, past_key_value
|
||||
|
||||
class HunyuanImage3DecoderLayer(nn.Module):
|
||||
def __init__(self, config, layer_idx: int, moe_lru=None):
|
||||
@ -900,7 +906,7 @@ class HunyuanImage3DecoderLayer(nn.Module):
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
|
||||
# Self Attention
|
||||
hidden_states = self.self_attn(
|
||||
hidden_states, past_key_value = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
@ -917,7 +923,7 @@ class HunyuanImage3DecoderLayer(nn.Module):
|
||||
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
outputs = (hidden_states,)
|
||||
outputs = (hidden_states, past_key_value)
|
||||
|
||||
return outputs
|
||||
|
||||
@ -1039,6 +1045,9 @@ class HunyuanImage3ForCausalMM(nn.Module):
|
||||
|
||||
def forward(self, x, condition, timestep, **kwargs):
|
||||
|
||||
cond, uncond = condition[:4], condition[4:]
|
||||
joint_image, cond_vae_image_mask, input_ids = cond[0], cond[1]
|
||||
|
||||
if self.kv_cache is None:
|
||||
# TODO: should change when higgsv2 gets merged
|
||||
self.kv_cache = HunyuanStaticCache(
|
||||
@ -1049,10 +1058,8 @@ class HunyuanImage3ForCausalMM(nn.Module):
|
||||
)
|
||||
|
||||
image_mask = torch.ones(x.size(1))
|
||||
image_mask[:, :5] = torch.zeros(5); image_mask[:, -4:] = torch.zeros(4)
|
||||
image_mask[:, :3] = torch.zeros(5); image_mask[:, -1] = torch.zeros(0)
|
||||
gen_timestep_scatter_index = 4
|
||||
cond, uncond = condition[:4], condition[4:]
|
||||
joint_image, cond_vae_image_mask, input_ids = cond[0], cond[1]
|
||||
joint_image[:, 2] = x[:, 2] # updates image ratio
|
||||
|
||||
position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=x.device)[None].expand(x.size(0), -1)
|
||||
@ -1126,6 +1133,11 @@ class HunyuanImage3ForCausalMM(nn.Module):
|
||||
)
|
||||
hidden_states = outputs[0]
|
||||
|
||||
# safety no-op
|
||||
past_key_value = outputs[1]
|
||||
if past_key_value is not None:
|
||||
self.kv_cache = past_key_value
|
||||
|
||||
hidden_states = hidden_states.to(input_ids.device)
|
||||
diffusion_prediction = self.ragged_final_layer(
|
||||
hidden_states, image_mask, timestep, token_h, token_w, self.first_step)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user