mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-10 09:12:31 +08:00
Merge 43b0dab91e into c5ecd231a2
This commit is contained in:
commit
47d8bd5bb0
@ -328,7 +328,7 @@ class CrossAttention(nn.Module):
|
|||||||
kv = torch.cat((k, v), dim=-1)
|
kv = torch.cat((k, v), dim=-1)
|
||||||
split_size = kv.shape[-1] // self.num_heads // 2
|
split_size = kv.shape[-1] // self.num_heads // 2
|
||||||
|
|
||||||
kv = kv.view(1, -1, self.num_heads, split_size * 2)
|
kv = kv.view(b, -1, self.num_heads, split_size * 2)
|
||||||
k, v = torch.split(kv, split_size, dim=-1)
|
k, v = torch.split(kv, split_size, dim=-1)
|
||||||
|
|
||||||
q = q.view(b, s1, self.num_heads, self.head_dim)
|
q = q.view(b, s1, self.num_heads, self.head_dim)
|
||||||
@ -398,7 +398,7 @@ class Attention(nn.Module):
|
|||||||
qkv_combined = torch.cat((query, key, value), dim=-1)
|
qkv_combined = torch.cat((query, key, value), dim=-1)
|
||||||
split_size = qkv_combined.shape[-1] // self.num_heads // 3
|
split_size = qkv_combined.shape[-1] // self.num_heads // 3
|
||||||
|
|
||||||
qkv = qkv_combined.view(1, -1, self.num_heads, split_size * 3)
|
qkv = qkv_combined.view(B, -1, self.num_heads, split_size * 3)
|
||||||
query, key, value = torch.split(qkv, split_size, dim=-1)
|
query, key, value = torch.split(qkv, split_size, dim=-1)
|
||||||
|
|
||||||
query = query.reshape(B, N, self.num_heads, self.head_dim)
|
query = query.reshape(B, N, self.num_heads, self.head_dim)
|
||||||
@ -607,9 +607,9 @@ class HunYuanDiTPlain(nn.Module):
|
|||||||
def forward(self, x, t, context, transformer_options = {}, **kwargs):
|
def forward(self, x, t, context, transformer_options = {}, **kwargs):
|
||||||
|
|
||||||
x = x.movedim(-1, -2)
|
x = x.movedim(-1, -2)
|
||||||
uncond_emb, cond_emb = context.chunk(2, dim = 0)
|
if context.shape[0] >= 2:
|
||||||
|
uncond_emb, cond_emb = context.chunk(2, dim = 0)
|
||||||
context = torch.cat([cond_emb, uncond_emb], dim = 0)
|
context = torch.cat([cond_emb, uncond_emb], dim = 0)
|
||||||
main_condition = context
|
main_condition = context
|
||||||
|
|
||||||
t = 1.0 - t
|
t = 1.0 - t
|
||||||
@ -657,5 +657,8 @@ class HunYuanDiTPlain(nn.Module):
|
|||||||
output = self.final_layer(combined)
|
output = self.final_layer(combined)
|
||||||
output = output.movedim(-2, -1) * (-1.0)
|
output = output.movedim(-2, -1) * (-1.0)
|
||||||
|
|
||||||
cond_emb, uncond_emb = output.chunk(2, dim = 0)
|
if output.shape[0] >= 2:
|
||||||
return torch.cat([uncond_emb, cond_emb])
|
cond_emb, uncond_emb = output.chunk(2, dim = 0)
|
||||||
|
return torch.cat([uncond_emb, cond_emb])
|
||||||
|
else:
|
||||||
|
return output
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user