From 43b0dab91e7074143afc7cc10c150d25310b8d82 Mon Sep 17 00:00:00 2001
From: Kivylius
Date: Mon, 4 May 2026 13:04:57 +0100
Subject: [PATCH] fix: Hunyuan3D 2.1 batch size crashes in attention and
 forward pass

- CrossAttention.forward: hardcoded `1` in kv.view() replaced with actual
  batch size `b`
- Attention.forward: hardcoded `1` in qkv_combined.view() replaced with
  actual batch size `B`
- HunYuanDiTPlain.forward: context.chunk(2) and output.chunk(2) now guarded
  with shape[0] >= 2 check to avoid crash when running without negative
  conditioning

Fixes Comfy-Org/ComfyUI#10142
---
 comfy/ldm/hunyuan3dv2_1/hunyuandit.py | 17 ++++++++++-------
 1 file changed, 10 insertions(+), 7 deletions(-)

diff --git a/comfy/ldm/hunyuan3dv2_1/hunyuandit.py b/comfy/ldm/hunyuan3dv2_1/hunyuandit.py
index f67ba84e9..bc36b8998 100644
--- a/comfy/ldm/hunyuan3dv2_1/hunyuandit.py
+++ b/comfy/ldm/hunyuan3dv2_1/hunyuandit.py
@@ -328,7 +328,7 @@ class CrossAttention(nn.Module):
 
         kv = torch.cat((k, v), dim=-1)
         split_size = kv.shape[-1] // self.num_heads // 2
-        kv = kv.view(1, -1, self.num_heads, split_size * 2)
+        kv = kv.view(b, -1, self.num_heads, split_size * 2)
         k, v = torch.split(kv, split_size, dim=-1)
 
         q = q.view(b, s1, self.num_heads, self.head_dim)
@@ -398,7 +398,7 @@ class Attention(nn.Module):
 
         qkv_combined = torch.cat((query, key, value), dim=-1)
         split_size = qkv_combined.shape[-1] // self.num_heads // 3
-        qkv = qkv_combined.view(1, -1, self.num_heads, split_size * 3)
+        qkv = qkv_combined.view(B, -1, self.num_heads, split_size * 3)
         query, key, value = torch.split(qkv, split_size, dim=-1)
 
         query = query.reshape(B, N, self.num_heads, self.head_dim)
@@ -607,9 +607,9 @@ class HunYuanDiTPlain(nn.Module):
     def forward(self, x, t, context, transformer_options = {}, **kwargs):
         x = x.movedim(-1, -2)
 
-        uncond_emb, cond_emb = context.chunk(2, dim = 0)
-
-        context = torch.cat([cond_emb, uncond_emb], dim = 0)
+        if context.shape[0] >= 2:
+            uncond_emb, cond_emb = context.chunk(2, dim = 0)
+            context = torch.cat([cond_emb, uncond_emb], dim = 0)
         main_condition = context
 
         t = 1.0 - t
@@ -657,5 +657,8 @@ class HunYuanDiTPlain(nn.Module):
         output = self.final_layer(combined)
         output = output.movedim(-2, -1) * (-1.0)
 
-        cond_emb, uncond_emb = output.chunk(2, dim = 0)
-        return torch.cat([uncond_emb, cond_emb])
+        if output.shape[0] >= 2:
+            cond_emb, uncond_emb = output.chunk(2, dim = 0)
+            return torch.cat([uncond_emb, cond_emb])
+        else:
+            return output