Mirror of https://github.com/comfyanonymous/ComfyUI.git
Made Chroma work with optimized_attention_override
commit 8be3edb606
parent d644aba6bc
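The change threads the per-call transformer_options dict from the Chroma model down through its double- and single-stream blocks and into every attention() call, so an attention override carried in transformer_options can take effect inside Chroma. The sketch below is illustrative rather than ComfyUI's actual helper: the "optimized_attention_override" key follows the commit title, and the override signature shown is an assumption.

import torch.nn.functional as F

def attention_sketch(q, k, v, pe=None, mask=None, transformer_options={}):
    # Illustrative only: q/k/v are (batch, heads, tokens, head_dim) tensors.
    # If the caller registered an override, delegate to it; otherwise fall back
    # to plain scaled dot-product attention (the real helper also applies the
    # rotary embedding pe before attending).
    override = (transformer_options or {}).get("optimized_attention_override")
    if override is not None:
        return override(q, k, v, pe=pe, mask=mask)
    return F.scaled_dot_product_attention(q, k, v, attn_mask=mask)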
@@ -76,7 +76,7 @@ class DoubleStreamBlock(nn.Module):
         )
         self.flipped_img_txt = flipped_img_txt
 
-    def forward(self, img: Tensor, txt: Tensor, pe: Tensor, vec: Tensor, attn_mask=None):
+    def forward(self, img: Tensor, txt: Tensor, pe: Tensor, vec: Tensor, attn_mask=None, transformer_options={}):
         (img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec
 
         # prepare image for attention
@@ -95,7 +95,7 @@ class DoubleStreamBlock(nn.Module):
         attn = attention(torch.cat((txt_q, img_q), dim=2),
                          torch.cat((txt_k, img_k), dim=2),
                          torch.cat((txt_v, img_v), dim=2),
-                         pe=pe, mask=attn_mask)
+                         pe=pe, mask=attn_mask, transformer_options=transformer_options)
 
         txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
 
@@ -148,7 +148,7 @@ class SingleStreamBlock(nn.Module):
 
         self.mlp_act = nn.GELU(approximate="tanh")
 
-    def forward(self, x: Tensor, pe: Tensor, vec: Tensor, attn_mask=None) -> Tensor:
+    def forward(self, x: Tensor, pe: Tensor, vec: Tensor, attn_mask=None, transformer_options={}) -> Tensor:
         mod = vec
         x_mod = torch.addcmul(mod.shift, 1 + mod.scale, self.pre_norm(x))
         qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
@@ -157,7 +157,7 @@ class SingleStreamBlock(nn.Module):
         q, k = self.norm(q, k, v)
 
         # compute attention
-        attn = attention(q, k, v, pe=pe, mask=attn_mask)
+        attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
         # compute activation in mlp stream, cat again and run second linear layer
         output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
         x.addcmul_(mod.gate, output)
@@ -193,14 +193,16 @@ class Chroma(nn.Module):
                                                        txt=args["txt"],
                                                        vec=args["vec"],
                                                        pe=args["pe"],
-                                                       attn_mask=args.get("attn_mask"))
+                                                       attn_mask=args.get("attn_mask"),
+                                                       transformer_options=args.get("transformer_options"))
                         return out
 
                     out = blocks_replace[("double_block", i)]({"img": img,
                                                                "txt": txt,
                                                                "vec": double_mod,
                                                                "pe": pe,
-                                                               "attn_mask": attn_mask},
+                                                               "attn_mask": attn_mask,
+                                                               "transformer_options": transformer_options},
                                                               {"original_block": block_wrap})
                     txt = out["txt"]
                     img = out["img"]
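One side effect visible in the hunk above: the args dict handed to a ("double_block", i) replacement patch now also carries "transformer_options", so existing patches keep working and can forward the options to the original block. A minimal sketch of such a patch (the function name is made up for illustration):

def my_double_block_patch(args, extra):
    # args now includes "transformer_options" next to img/txt/vec/pe/attn_mask,
    # so the patch sees the same options the stock block would receive.
    return extra["original_block"](args)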
@@ -209,7 +211,8 @@ class Chroma(nn.Module):
                                      txt=txt,
                                      vec=double_mod,
                                      pe=pe,
-                                     attn_mask=attn_mask)
+                                     attn_mask=attn_mask,
+                                     transformer_options=transformer_options)
 
                 if control is not None: # Controlnet
                     control_i = control.get("input")
@@ -229,17 +232,19 @@ class Chroma(nn.Module):
                         out["img"] = block(args["img"],
                                            vec=args["vec"],
                                            pe=args["pe"],
-                                           attn_mask=args.get("attn_mask"))
+                                           attn_mask=args.get("attn_mask"),
+                                           transformer_options=args.get("transformer_options"))
                         return out
 
                     out = blocks_replace[("single_block", i)]({"img": img,
                                                                "vec": single_mod,
                                                                "pe": pe,
-                                                               "attn_mask": attn_mask},
+                                                               "attn_mask": attn_mask,
+                                                               "transformer_options": transformer_options},
                                                               {"original_block": block_wrap})
                     img = out["img"]
                 else:
-                    img = block(img, vec=single_mod, pe=pe, attn_mask=attn_mask)
+                    img = block(img, vec=single_mod, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options)
 
                 if control is not None: # Controlnet
                     control_o = control.get("output")
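On the caller side, the override presumably reaches these blocks through the transformer_options dict ComfyUI passes into the model forward. A hedged sketch of registering one on a ModelPatcher; the key consumed by the optimized attention path is an assumption based on the commit title:

def register_attention_override(model_patcher, override_fn):
    # Assumption: the override lives under "optimized_attention_override" in the
    # model's transformer_options; adjust the key if the attention helper differs.
    opts = model_patcher.model_options.setdefault("transformer_options", {})
    opts["optimized_attention_override"] = override_fn
    return model_patcher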