Trellis2: guard structure shape_rule pruning to CFG batches

This commit is contained in:
John Pollock 2026-04-19 21:38:45 -05:00
parent b443f423b4
commit 70511a9a91

View File

@ -862,12 +862,12 @@ class Trellis2(nn.Module):
out = self.shape2txt(x_st, t_eval, c_eval)
else: # structure
orig_bsz = x.shape[0]
if shape_rule:
if shape_rule and orig_bsz > 1:
half = orig_bsz // 2
x = x[half:]
timestep = timestep[half:] if timestep.shape[0] > 1 else timestep
out = self.structure_model(x, timestep, context if not shape_rule else cond)
if shape_rule:
out = self.structure_model(x, timestep, cond if shape_rule and orig_bsz > 1 else context)
if shape_rule and orig_bsz > 1:
out = out.repeat(2, 1, 1, 1, 1)
if not_struct_mode: