mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-10 01:02:56 +08:00
fix: Hunyuan3D 2.1 batch size crashes in attention and forward pass
- CrossAttention.forward: hardcoded `1` in kv.view() replaced with actual batch size `b`
- Attention.forward: hardcoded `1` in qkv_combined.view() replaced with actual batch size `B`
- HunYuanDiTPlain.forward: context.chunk(2) and output.chunk(2) now guarded with a shape[0] >= 2 check to avoid a crash when running without negative conditioning

Fixes Comfy-Org/ComfyUI#10142
This commit is contained in:
parent
f3ea976cba
commit
43b0dab91e
@ -328,7 +328,7 @@ class CrossAttention(nn.Module):
|
||||
kv = torch.cat((k, v), dim=-1)
|
||||
split_size = kv.shape[-1] // self.num_heads // 2
|
||||
|
||||
kv = kv.view(1, -1, self.num_heads, split_size * 2)
|
||||
kv = kv.view(b, -1, self.num_heads, split_size * 2)
|
||||
k, v = torch.split(kv, split_size, dim=-1)
|
||||
|
||||
q = q.view(b, s1, self.num_heads, self.head_dim)
|
||||
@ -398,7 +398,7 @@ class Attention(nn.Module):
|
||||
qkv_combined = torch.cat((query, key, value), dim=-1)
|
||||
split_size = qkv_combined.shape[-1] // self.num_heads // 3
|
||||
|
||||
qkv = qkv_combined.view(1, -1, self.num_heads, split_size * 3)
|
||||
qkv = qkv_combined.view(B, -1, self.num_heads, split_size * 3)
|
||||
query, key, value = torch.split(qkv, split_size, dim=-1)
|
||||
|
||||
query = query.reshape(B, N, self.num_heads, self.head_dim)
|
||||
@ -607,9 +607,9 @@ class HunYuanDiTPlain(nn.Module):
|
||||
def forward(self, x, t, context, transformer_options = {}, **kwargs):
|
||||
|
||||
x = x.movedim(-1, -2)
|
||||
uncond_emb, cond_emb = context.chunk(2, dim = 0)
|
||||
|
||||
context = torch.cat([cond_emb, uncond_emb], dim = 0)
|
||||
if context.shape[0] >= 2:
|
||||
uncond_emb, cond_emb = context.chunk(2, dim = 0)
|
||||
context = torch.cat([cond_emb, uncond_emb], dim = 0)
|
||||
main_condition = context
|
||||
|
||||
t = 1.0 - t
|
||||
@ -657,5 +657,8 @@ class HunYuanDiTPlain(nn.Module):
|
||||
output = self.final_layer(combined)
|
||||
output = output.movedim(-2, -1) * (-1.0)
|
||||
|
||||
cond_emb, uncond_emb = output.chunk(2, dim = 0)
|
||||
return torch.cat([uncond_emb, cond_emb])
|
||||
if output.shape[0] >= 2:
|
||||
cond_emb, uncond_emb = output.chunk(2, dim = 0)
|
||||
return torch.cat([uncond_emb, cond_emb])
|
||||
else:
|
||||
return output
|
||||
|
||||
Loading…
Reference in New Issue
Block a user