Trellis2: slice cond half of x symmetrically under shape_rule pruning

This commit is contained in:
John Pollock 2026-04-19 21:26:48 -05:00
parent 036d159237
commit b443f423b4

View File

@ -863,11 +863,12 @@ class Trellis2(nn.Module):
else: # structure
orig_bsz = x.shape[0]
if shape_rule:
x = x[1].unsqueeze(0)
timestep = timestep[1].unsqueeze(0)
half = orig_bsz // 2
x = x[half:]
timestep = timestep[half:] if timestep.shape[0] > 1 else timestep
out = self.structure_model(x, timestep, context if not shape_rule else cond)
if shape_rule:
out = out.repeat(orig_bsz, 1, 1, 1, 1)
out = out.repeat(2, 1, 1, 1, 1)
if not_struct_mode:
out = out.feats