diff --git a/comfy/ldm/hunyuan3dv2_1/hunyuandit.py b/comfy/ldm/hunyuan3dv2_1/hunyuandit.py
index f67ba84e9..bc36b8998 100644
--- a/comfy/ldm/hunyuan3dv2_1/hunyuandit.py
+++ b/comfy/ldm/hunyuan3dv2_1/hunyuandit.py
@@ -328,7 +328,7 @@ class CrossAttention(nn.Module):
 
         kv = torch.cat((k, v), dim=-1)
         split_size = kv.shape[-1] // self.num_heads // 2
-        kv = kv.view(1, -1, self.num_heads, split_size * 2)
+        kv = kv.view(b, -1, self.num_heads, split_size * 2)
         k, v = torch.split(kv, split_size, dim=-1)
 
         q = q.view(b, s1, self.num_heads, self.head_dim)
@@ -398,7 +398,7 @@ class Attention(nn.Module):
 
         qkv_combined = torch.cat((query, key, value), dim=-1)
         split_size = qkv_combined.shape[-1] // self.num_heads // 3
-        qkv = qkv_combined.view(1, -1, self.num_heads, split_size * 3)
+        qkv = qkv_combined.view(B, -1, self.num_heads, split_size * 3)
        query, key, value = torch.split(qkv, split_size, dim=-1)
 
         query = query.reshape(B, N, self.num_heads, self.head_dim)
@@ -607,9 +607,9 @@ class HunYuanDiTPlain(nn.Module):
 
     def forward(self, x, t, context, transformer_options = {}, **kwargs):
         x = x.movedim(-1, -2)
-        uncond_emb, cond_emb = context.chunk(2, dim = 0)
-
-        context = torch.cat([cond_emb, uncond_emb], dim = 0)
+        if context.shape[0] >= 2:
+            uncond_emb, cond_emb = context.chunk(2, dim = 0)
+            context = torch.cat([cond_emb, uncond_emb], dim = 0)
         main_condition = context
 
         t = 1.0 - t
@@ -657,5 +657,8 @@ class HunYuanDiTPlain(nn.Module):
         output = self.final_layer(combined)
 
         output = output.movedim(-2, -1) * (-1.0)
-        cond_emb, uncond_emb = output.chunk(2, dim = 0)
-        return torch.cat([uncond_emb, cond_emb])
+        if output.shape[0] >= 2:
+            cond_emb, uncond_emb = output.chunk(2, dim = 0)
+            return torch.cat([uncond_emb, cond_emb])
+        else:
+            return output