From fd79f22bdfceaf58fe6bdd6fb65a267b14710fec Mon Sep 17 00:00:00 2001 From: Kosinkadink Date: Thu, 21 May 2026 12:17:24 -0700 Subject: [PATCH] Backport Hunyuan 3D 2.1 attention batch-size fixes from #13699 CrossAttention.kv.view and Attention.qkv_combined.view both hardcoded batch=1 in the reshape, crashing or silently mis-shaping whenever the actual batch dimension was greater than 1. These were fixed on master in #13699 as part of the same patch that gated the chunk(2) swap, but worksplit-multigpu only picked up the chunk(2) gate. Bring the two view() fixes over so we have parity with master. Amp-Thread-ID: https://ampcode.com/threads/T-019e4a00-fe3d-76bd-a2f2-a8c8c4040082 Co-authored-by: Amp --- comfy/ldm/hunyuan3dv2_1/hunyuandit.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy/ldm/hunyuan3dv2_1/hunyuandit.py b/comfy/ldm/hunyuan3dv2_1/hunyuandit.py index cb260e0a8..4e4819fe3 100644 --- a/comfy/ldm/hunyuan3dv2_1/hunyuandit.py +++ b/comfy/ldm/hunyuan3dv2_1/hunyuandit.py @@ -328,7 +328,7 @@ class CrossAttention(nn.Module): kv = torch.cat((k, v), dim=-1) split_size = kv.shape[-1] // self.num_heads // 2 - kv = kv.view(1, -1, self.num_heads, split_size * 2) + kv = kv.view(b, -1, self.num_heads, split_size * 2) k, v = torch.split(kv, split_size, dim=-1) q = q.view(b, s1, self.num_heads, self.head_dim) @@ -398,7 +398,7 @@ class Attention(nn.Module): qkv_combined = torch.cat((query, key, value), dim=-1) split_size = qkv_combined.shape[-1] // self.num_heads // 3 - qkv = qkv_combined.view(1, -1, self.num_heads, split_size * 3) + qkv = qkv_combined.view(B, -1, self.num_heads, split_size * 3) query, key, value = torch.split(qkv, split_size, dim=-1) query = query.reshape(B, N, self.num_heads, self.head_dim)