diff --git a/comfy/ldm/hunyuan3dv2_1/hunyuandit.py b/comfy/ldm/hunyuan3dv2_1/hunyuandit.py index f67ba84e9..61d1b3dc6 100644 --- a/comfy/ldm/hunyuan3dv2_1/hunyuandit.py +++ b/comfy/ldm/hunyuan3dv2_1/hunyuandit.py @@ -607,9 +607,14 @@ class HunYuanDiTPlain(nn.Module): def forward(self, x, t, context, transformer_options = {}, **kwargs): x = x.movedim(-1, -2) - uncond_emb, cond_emb = context.chunk(2, dim = 0) - context = torch.cat([cond_emb, uncond_emb], dim = 0) + cond_or_uncond = transformer_options.get("cond_or_uncond", []) + swap_cfg_halves = len(cond_or_uncond) == 2 and set(cond_or_uncond) == {0, 1} + + if swap_cfg_halves: + first_half, second_half = context.chunk(2, dim = 0) + context = torch.cat([second_half, first_half], dim = 0) + main_condition = context t = 1.0 - t @@ -657,5 +662,8 @@ class HunYuanDiTPlain(nn.Module): output = self.final_layer(combined) output = output.movedim(-2, -1) * (-1.0) - cond_emb, uncond_emb = output.chunk(2, dim = 0) - return torch.cat([uncond_emb, cond_emb]) + if swap_cfg_halves: + first_half, second_half = output.chunk(2, dim = 0) + output = torch.cat([second_half, first_half], dim = 0) + + return output