mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-10 08:17:44 +08:00
Backport Hunyuan 3D 2.1 attention batch-size fixes from #13699
CrossAttention.kv.view and Attention.qkv_combined.view both hardcoded batch=1 in the reshape, crashing or silently mis-shaping whenever the actual batch dimension was greater than 1. These were fixed on master in #13699 as part of the same patch that gated the chunk(2) swap, but worksplit-multigpu only picked up the chunk(2) gate. Bring the two view() fixes over so we have parity with master. Amp-Thread-ID: https://ampcode.com/threads/T-019e4a00-fe3d-76bd-a2f2-a8c8c4040082 Co-authored-by: Amp <amp@ampcode.com>
This commit is contained in:
parent
019261ed96
commit
fd79f22bdf
@ -328,7 +328,7 @@ class CrossAttention(nn.Module):
|
||||
kv = torch.cat((k, v), dim=-1)
|
||||
split_size = kv.shape[-1] // self.num_heads // 2
|
||||
|
||||
kv = kv.view(1, -1, self.num_heads, split_size * 2)
|
||||
kv = kv.view(b, -1, self.num_heads, split_size * 2)
|
||||
k, v = torch.split(kv, split_size, dim=-1)
|
||||
|
||||
q = q.view(b, s1, self.num_heads, self.head_dim)
|
||||
@ -398,7 +398,7 @@ class Attention(nn.Module):
|
||||
qkv_combined = torch.cat((query, key, value), dim=-1)
|
||||
split_size = qkv_combined.shape[-1] // self.num_heads // 3
|
||||
|
||||
qkv = qkv_combined.view(1, -1, self.num_heads, split_size * 3)
|
||||
qkv = qkv_combined.view(B, -1, self.num_heads, split_size * 3)
|
||||
query, key, value = torch.split(qkv, split_size, dim=-1)
|
||||
|
||||
query = query.reshape(B, N, self.num_heads, self.head_dim)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user