Merge pull request #10 from pollockjj/issue_74_extract

Trellis2: fix structure batch pruning under shape_rule
This commit is contained in:
John Pollock 2026-04-19 19:57:29 -07:00 committed by GitHub
commit 45fcf0f9cc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -862,12 +862,13 @@ class Trellis2(nn.Module):
out = self.shape2txt(x_st, t_eval, c_eval) out = self.shape2txt(x_st, t_eval, c_eval)
else: # structure else: # structure
orig_bsz = x.shape[0] orig_bsz = x.shape[0]
if shape_rule: if shape_rule and orig_bsz > 1:
x = x[1].unsqueeze(0) half = orig_bsz // 2
timestep = timestep[1].unsqueeze(0) x = x[half:]
out = self.structure_model(x, timestep, context if not shape_rule else cond) timestep = timestep[half:] if timestep.shape[0] > 1 else timestep
if shape_rule: out = self.structure_model(x, timestep, cond if shape_rule and orig_bsz > 1 else context)
out = out.repeat(orig_bsz, 1, 1, 1, 1) if shape_rule and orig_bsz > 1:
out = out.repeat(2, 1, 1, 1, 1)
if not_struct_mode: if not_struct_mode:
out = out.feats out = out.feats