Made Chroma work with optimized_attention_override

Jedrzej Kosinski 2025-08-28 22:45:31 -07:00
parent d644aba6bc
commit 8be3edb606
2 changed files with 15 additions and 10 deletions
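The change is mechanical: DoubleStreamBlock.forward, SingleStreamBlock.forward, and the block calls inside the Chroma model all gain a transformer_options parameter that is forwarded, unmodified, into each attention() call, so an attention override carried in that dict can reach the actual attention computation. Below is a minimal sketch of what a caller might place in transformer_options; the key name "optimized_attention_override" and the callback shape are assumptions based on the commit title, not something this diff defines.

# Hypothetical override: observe the call, then defer to whatever default
# implementation the attention helper hands it. The key name and signature
# here are assumed for illustration, not taken from ComfyUI's API.
def my_attention_override(default_attention, q, k, v, mask=None):
    print(f"attention over {q.shape[-2]} query tokens")
    return default_attention(q, k, v, mask=mask)

transformer_options = {"optimized_attention_override": my_attention_override}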

View File

@@ -76,7 +76,7 @@ class DoubleStreamBlock(nn.Module):
         )
         self.flipped_img_txt = flipped_img_txt
 
-    def forward(self, img: Tensor, txt: Tensor, pe: Tensor, vec: Tensor, attn_mask=None):
+    def forward(self, img: Tensor, txt: Tensor, pe: Tensor, vec: Tensor, attn_mask=None, transformer_options={}):
         (img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec
 
         # prepare image for attention
@@ -95,7 +95,7 @@ class DoubleStreamBlock(nn.Module):
         attn = attention(torch.cat((txt_q, img_q), dim=2),
                          torch.cat((txt_k, img_k), dim=2),
                          torch.cat((txt_v, img_v), dim=2),
-                         pe=pe, mask=attn_mask)
+                         pe=pe, mask=attn_mask, transformer_options=transformer_options)
 
         txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
@@ -148,7 +148,7 @@ class SingleStreamBlock(nn.Module):
         self.mlp_act = nn.GELU(approximate="tanh")
 
-    def forward(self, x: Tensor, pe: Tensor, vec: Tensor, attn_mask=None) -> Tensor:
+    def forward(self, x: Tensor, pe: Tensor, vec: Tensor, attn_mask=None, transformer_options={}) -> Tensor:
         mod = vec
         x_mod = torch.addcmul(mod.shift, 1 + mod.scale, self.pre_norm(x))
         qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
@@ -157,7 +157,7 @@ class SingleStreamBlock(nn.Module):
         q, k = self.norm(q, k, v)
 
         # compute attention
-        attn = attention(q, k, v, pe=pe, mask=attn_mask)
+        attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
         # compute activation in mlp stream, cat again and run second linear layer
         output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
         x.addcmul_(mod.gate, output)
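The attention() helper that now receives this keyword is the flux-style helper used by these blocks: it applies the rotary embedding pe to q and k and then runs the optimized attention kernel. The sketch below only illustrates the dispatch pattern such a helper could use to honor an override found in transformer_options; it is not ComfyUI's actual implementation, the key name is an assumption, and the RoPE step is omitted.

from typing import Callable, Optional
import torch
import torch.nn.functional as F

def default_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                      mask: Optional[torch.Tensor] = None) -> torch.Tensor:
    # stock path: plain scaled-dot-product attention
    return F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

def attention(q, k, v, pe=None, mask=None, transformer_options={}):
    # assumed key; the diff above only shows transformer_options being forwarded here
    override: Optional[Callable] = transformer_options.get("optimized_attention_override")
    if override is not None:
        # hand the override the tensors plus the default path to fall back on
        return override(default_attention, q, k, v, mask=mask)
    return default_attention(q, k, v, mask=mask)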

View File

@@ -193,14 +193,16 @@ class Chroma(nn.Module):
                                                    txt=args["txt"],
                                                    vec=args["vec"],
                                                    pe=args["pe"],
-                                                   attn_mask=args.get("attn_mask"))
+                                                   attn_mask=args.get("attn_mask"),
+                                                   transformer_options=args.get("transformer_options"))
                     return out
 
                 out = blocks_replace[("double_block", i)]({"img": img,
                                                            "txt": txt,
                                                            "vec": double_mod,
                                                            "pe": pe,
-                                                           "attn_mask": attn_mask},
+                                                           "attn_mask": attn_mask,
+                                                           "transformer_options": transformer_options},
                                                           {"original_block": block_wrap})
                 txt = out["txt"]
                 img = out["img"]
@@ -209,7 +211,8 @@ class Chroma(nn.Module):
                                  txt=txt,
                                  vec=double_mod,
                                  pe=pe,
-                                 attn_mask=attn_mask)
+                                 attn_mask=attn_mask,
+                                 transformer_options=transformer_options)
 
             if control is not None: # Controlnet
                 control_i = control.get("input")
@@ -229,17 +232,19 @@ class Chroma(nn.Module):
                     out["img"] = block(args["img"],
                                        vec=args["vec"],
                                        pe=args["pe"],
-                                       attn_mask=args.get("attn_mask"))
+                                       attn_mask=args.get("attn_mask"),
+                                       transformer_options=args.get("transformer_options"))
                     return out
 
                 out = blocks_replace[("single_block", i)]({"img": img,
                                                            "vec": single_mod,
                                                            "pe": pe,
-                                                           "attn_mask": attn_mask},
+                                                           "attn_mask": attn_mask,
+                                                           "transformer_options": transformer_options},
                                                           {"original_block": block_wrap})
                 img = out["img"]
             else:
-                img = block(img, vec=single_mod, pe=pe, attn_mask=attn_mask)
+                img = block(img, vec=single_mod, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options)
 
             if control is not None: # Controlnet
                 control_o = control.get("output")
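The dict-based patch interface visible in the hunks above is unchanged apart from the extra key: a replacement registered under ("double_block", i) or ("single_block", i) is still called with the args dict and a second dict holding "original_block", and after this commit args also carries "transformer_options". A hypothetical pass-through patch, matching the keys shown in the diff, looks like this:

def my_single_block_patch(args, extra):
    # args holds "img", "vec", "pe", "attn_mask" and now "transformer_options";
    # extra["original_block"] is the block_wrap closure from the diff above
    out = extra["original_block"](args)
    # out["img"] could be inspected or modified here before returning it
    return out

# Hypothetical registration; how blocks_replace is populated (normally through
# ComfyUI's usual patches_replace machinery) is outside the scope of this diff:
# blocks_replace[("single_block", 0)] = my_single_block_patch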