From 5fc84518c0094e838c9a82af7d2765a85b4a1d4c Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Sun, 19 Apr 2026 23:28:30 -0700 Subject: [PATCH] Fix Hunyuan 3D 2.1 multi-GPU worksplit: use cond_or_uncond instead of hardcoded chunk(2) Amp-Thread-ID: https://ampcode.com/threads/T-019da964-2cc8-77f9-9aae-23f65da233db Co-authored-by: Amp --- comfy/ldm/hunyuan3dv2_1/hunyuandit.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/comfy/ldm/hunyuan3dv2_1/hunyuandit.py b/comfy/ldm/hunyuan3dv2_1/hunyuandit.py index f67ba84e9..61d1b3dc6 100644 --- a/comfy/ldm/hunyuan3dv2_1/hunyuandit.py +++ b/comfy/ldm/hunyuan3dv2_1/hunyuandit.py @@ -607,9 +607,14 @@ class HunYuanDiTPlain(nn.Module): def forward(self, x, t, context, transformer_options = {}, **kwargs): x = x.movedim(-1, -2) - uncond_emb, cond_emb = context.chunk(2, dim = 0) - context = torch.cat([cond_emb, uncond_emb], dim = 0) + cond_or_uncond = transformer_options.get("cond_or_uncond", []) + swap_cfg_halves = len(cond_or_uncond) == 2 and set(cond_or_uncond) == {0, 1} + + if swap_cfg_halves: + first_half, second_half = context.chunk(2, dim = 0) + context = torch.cat([second_half, first_half], dim = 0) + main_condition = context t = 1.0 - t @@ -657,5 +662,8 @@ class HunYuanDiTPlain(nn.Module): output = self.final_layer(combined) output = output.movedim(-2, -1) * (-1.0) - cond_emb, uncond_emb = output.chunk(2, dim = 0) - return torch.cat([uncond_emb, cond_emb]) + if swap_cfg_halves: + first_half, second_half = output.chunk(2, dim = 0) + output = torch.cat([second_half, first_half], dim = 0) + + return output