Add transformer_options for attention

kijai 2025-11-25 00:45:04 +02:00
parent 8fd49be716
commit 0920cdcf63
2 changed files with 32 additions and 37 deletions
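For orientation, this is the pattern the diff applies: a transformer_options dict is threaded from the model's forward pass down into every self- and cross-attention call, so the attention backend can consult per-run options at the point of the call. Below is a minimal sketch of that pattern only; it is not ComfyUI's actual attention implementation, and the "attention_override" key is a made-up example of the kind of hook such a dict can carry.

import torch.nn.functional as F

def attention_sketch(q, k, v, heads, transformer_options={}):
    # Hypothetical hook: if the options dict carries an override, delegate to it.
    # The key name "attention_override" is illustrative, not ComfyUI's real API.
    override = transformer_options.get("attention_override")
    if override is not None:
        return override(q, k, v, heads, transformer_options)
    # Default path: reshape (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim)
    # and run plain scaled dot-product attention.
    b, s, dim = q.shape
    head_dim = dim // heads
    q = q.view(b, -1, heads, head_dim).transpose(1, 2)
    k = k.view(b, -1, heads, head_dim).transpose(1, 2)
    v = v.view(b, -1, heads, head_dim).transpose(1, 2)
    out = F.scaled_dot_product_attention(q, k, v)
    return out.transpose(1, 2).reshape(b, s, dim)

The transformer_options={} default mirrors the diff: existing callers keep working unchanged, while callers that do supply options have them forwarded all the way down to the attention call.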

View File

@@ -108,13 +108,13 @@ class SelfAttention(nn.Module):
         v = self.to_value(x)

         shape = q.shape[:-1]
-        q = q.reshape(*shape, self.num_heads, -1)
-        k = k.reshape(*shape, self.num_heads, -1)
-        v = v.reshape(*shape, self.num_heads, -1)
+        q = q.view(*shape, self.num_heads, -1)
+        k = k.view(*shape, self.num_heads, -1)
+        v = v.view(*shape, self.num_heads, -1)
         return q, k, v

-    def forward(self, x, freqs):
+    def forward(self, x, freqs, transformer_options={}):
         q, k, v = self.get_qkv(x)

         q = apply_rope1(self.query_norm(q), freqs)
@@ -124,9 +124,10 @@ class SelfAttention(nn.Module):
             q.flatten(-2, -1),
             k.flatten(-2, -1),
             v.flatten(-2, -1),
-            heads=self.num_heads)
-        out = self.out_layer(out)
-        return out
+            heads=self.num_heads,
+            transformer_options=transformer_options
+        )
+        return self.out_layer(out)

 class CrossAttention(nn.Module):
@@ -145,18 +146,18 @@ class CrossAttention(nn.Module):
         self.out_layer = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))

     def get_qkv(self, x, cond):
-        query = self.to_query(x)
-        key = self.to_key(cond)
-        value = self.to_value(cond)
+        q = self.to_query(x)
+        k = self.to_key(cond)
+        v = self.to_value(cond)

-        shape, cond_shape = query.shape[:-1], key.shape[:-1]
-        query = query.reshape(*shape, self.num_heads, -1)
-        key = key.reshape(*cond_shape, self.num_heads, -1)
-        value = value.reshape(*cond_shape, self.num_heads, -1)
+        shape, cond_shape = q.shape[:-1], k.shape[:-1]
+        q = q.view(*shape, self.num_heads, -1)
+        k = k.view(*cond_shape, self.num_heads, -1)
+        v = v.view(*cond_shape, self.num_heads, -1)

-        return query, key, value
+        return q, k, v

-    def forward(self, x, cond):
+    def forward(self, x, cond, transformer_options={}):
         q, k, v = self.get_qkv(x, cond)
         q = self.query_norm(q)
         k = self.key_norm(k)
@@ -165,10 +166,11 @@ class CrossAttention(nn.Module):
             q.flatten(-2, -1),
             k.flatten(-2, -1),
             v.flatten(-2, -1),
-            heads=self.num_heads)
+            heads=self.num_heads,
+            transformer_options=transformer_options
+        )

-        out = self.out_layer(out)
-        return out
+        return self.out_layer(out)

 class FeedForward(nn.Module):
@@ -220,11 +222,11 @@ class TransformerEncoderBlock(nn.Module):
         self.feed_forward_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
         self.feed_forward = FeedForward(model_dim, ff_dim, operation_settings=operation_settings)

-    def forward(self, x, time_embed, freqs):
+    def forward(self, x, time_embed, freqs, transformer_options={}):
         self_attn_params, ff_params = torch.chunk(self.text_modulation(time_embed), 2, dim=-1)
         shift, scale, gate = get_shift_scale_gate(self_attn_params)
         out = apply_scale_shift_norm(self.self_attention_norm, x, scale, shift)
-        out = self.self_attention(out, freqs)
+        out = self.self_attention(out, freqs, transformer_options=transformer_options)
         x = apply_gate_sum(x, out, gate)

         shift, scale, gate = get_shift_scale_gate(ff_params)
@@ -249,19 +251,17 @@ class TransformerDecoderBlock(nn.Module):
         self.feed_forward_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
         self.feed_forward = FeedForward(model_dim, ff_dim, operation_settings=operation_settings)

-    def forward(self, visual_embed, text_embed, time_embed, freqs):
-        self_attn_params, cross_attn_params, ff_params = torch.chunk(
-            self.visual_modulation(time_embed), 3, dim=-1
-        )
+    def forward(self, visual_embed, text_embed, time_embed, freqs, transformer_options={}):
+        self_attn_params, cross_attn_params, ff_params = torch.chunk(self.visual_modulation(time_embed), 3, dim=-1)
         # self attention
         shift, scale, gate = get_shift_scale_gate(self_attn_params)
         visual_out = apply_scale_shift_norm(self.self_attention_norm, visual_embed, scale, shift)
-        visual_out = self.self_attention(visual_out, freqs)
+        visual_out = self.self_attention(visual_out, freqs, transformer_options=transformer_options)
         visual_embed = apply_gate_sum(visual_embed, visual_out, gate)
         # cross attention
         shift, scale, gate = get_shift_scale_gate(cross_attn_params)
         visual_out = apply_scale_shift_norm(self.cross_attention_norm, visual_embed, scale, shift)
-        visual_out = self.cross_attention(visual_out, text_embed)
+        visual_out = self.cross_attention(visual_out, text_embed, transformer_options=transformer_options)
         visual_embed = apply_gate_sum(visual_embed, visual_out, gate)
         # feed forward
         shift, scale, gate = get_shift_scale_gate(ff_params)
@@ -374,13 +374,13 @@ class Kandinsky5(nn.Module):
         time_embed = self.time_embeddings(timestep) + self.pooled_text_embeddings(y)

         for block in self.text_transformer_blocks:
-            context = block(context, time_embed, freqs_text)
+            context = block(context, time_embed, freqs_text, transformer_options=transformer_options)

         visual_embed = self.visual_embeddings(x)
         visual_shape = visual_embed.shape[:-1]

         for block in self.visual_transformer_blocks:
-            visual_embed = block(visual_embed.flatten(1, -2), context, time_embed, freqs=freqs)
+            visual_embed = block(visual_embed.flatten(1, -2), context, time_embed, freqs=freqs, transformer_options=transformer_options)

         visual_embed = visual_embed.reshape(*visual_shape, -1)
         return self.out_layer(visual_embed, time_embed)
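A hedged call-site sketch for the file above: in ComfyUI, per-run options are typically carried under model_options["transformer_options"], and from there the dict now reaches every self- and cross-attention layer shown in this diff. The exact Kandinsky5.forward signature is not part of the hunk, so the argument names below (x, timestep, context, y) are assumptions taken from the locals visible in it.

def call_with_options(diffusion_model, x, timestep, context, y, model_options):
    # Pull the per-run options dict; the forward signature used here is assumed,
    # and all variable names are placeholders rather than ComfyUI internals.
    transformer_options = model_options.get("transformer_options", {})
    return diffusion_model(
        x, timestep,
        context=context,
        y=y,
        transformer_options=transformer_options,
    )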

View File

@@ -6,12 +6,11 @@ from .llama import Qwen25_7BVLI
 class Kandinsky5Tokenizer(QwenImageTokenizer):
     def __init__(self, embedding_directory=None, tokenizer_data={}):
         super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
-        # yes the typo "promt" was also in the original template...
+        # yes the typos "promt", "scren" were also in the original template... todo: check if it matters
         self.llama_template = "<|im_start|>system\nYou are a promt engineer. Describe the video in detail.\nDescribe how the camera moves or shakes, describe the zoom and view angle, whether it follows the objects.\nDescribe the location of the video, main characters or objects and their action.\nDescribe the dynamism of the video and presented actions.\nName the visual style of the video: whether it is a professional footage, user generated content, some kind of animation, video game or scren content.\nDescribe the visual effects, postprocessing and transitions if they are presented in the video.\nPay attention to the order of key actions shown in the scene.<|im_end|>\n<|im_start|>user\n{}<|im_end|>"
         self.llama_template_image2video = "<|im_start|>system\nYou are a promt engineer. Your task is to create a highly detailed and effective video description based on a provided input image.\nDescribe how the camera moves or shakes, describe the zoom and view angle, whether it follows the objects.\nDescribe main characters actions.\nDescribe the dynamism of the video and presented actions.\nName the visual style of the video: whether it is a professional footage, user generated content, some kind of animation, video game or scren content.\nDescribe the visual effects, postprocessing and transitions if they are presented in the video.\nPay attention to the order of key actions shown in the scene.<|im_end|>\n<|im_start|>user\n{}<|im_end|>"
         self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)

     def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
         out = super().tokenize_with_weights(text, return_word_ids, **kwargs)
         out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids, **kwargs)
@@ -34,12 +33,8 @@ class Kandinsky5TEModel(QwenImageTEModel):
         self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options)

     def encode_token_weights(self, token_weight_pairs):
-        #tok_pairs = token_weight_pairs["qwen25_7b"][0]
-        token_weight_pairs_l = token_weight_pairs["l"]
-        template_end = -1
-        cond, p, extra = super().encode_token_weights(token_weight_pairs, template_end=template_end)
-        l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l)
+        cond, p, extra = super().encode_token_weights(token_weight_pairs, template_end=-1)
+        l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs["l"])

         return cond, l_pooled, extra
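The simplified encode_token_weights keeps the same contract as before: the Qwen2.5-VL branch supplies the token-level conditioning and the extra outputs, while CLIP-L supplies only the pooled vector, which replaces the Qwen pooled output in the returned tuple. A generic sketch of that dual-encoder pattern, with made-up names rather than the actual ComfyUI classes:

def encode_dual(primary, pooled_encoder, token_weight_pairs):
    # primary: token-level encoder (here the Qwen2.5-VL model); its own pooled output is discarded.
    # pooled_encoder: second encoder (here CLIP-L) used only for its pooled vector.
    cond, _unused_pooled, extra = primary.encode_token_weights(token_weight_pairs)
    _seq, pooled = pooled_encoder.encode_token_weights(token_weight_pairs["l"])
    return cond, pooled, extra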