simplify and optimize model.forward

Yousef Rafat 2026-05-07 18:47:03 +03:00
parent 81ed835ffb
commit e180d4ad79


@ -779,66 +779,54 @@ class Trellis2(nn.Module):
def forward(self, x, timestep, context, **kwargs):
transformer_options = kwargs.get("transformer_options", {})
timestep = timestep.to(x.dtype)
embeds = kwargs.get("embeds")
if embeds is None:
raise ValueError("Trellis2.forward requires 'embeds' in kwargs")
# img2shape.resolution is the latent-grid size, not the input pixel size:
# 32 -> 512px path, 64 -> 1024px path.
is_1024 = self.img2shape.resolution == 64
coords = transformer_options.get("coords", None)
coord_counts = transformer_options.get("coord_counts", None)
mode = transformer_options.get("generation_mode", "structure_generation")
is_512_run = False
if mode == "shape_generation_512":
is_512_run = True
mode = "shape_generation"
if coords is not None:
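# Dense latents arrive as [B, C, N, 1]; convert to a [B, N, C] token layout
# for the sparse paths below.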
if x.ndim == 4:
    x = x.squeeze(-1).transpose(1, 2)
not_struct_mode = True
else:
mode = "structure_generation"
not_struct_mode = False
if is_1024 and not_struct_mode and not is_512_run:
context = embeds
sigmas = transformer_options.get("sigmas")[0].item()
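# A leading sigma at or below 1.0 indicates a 0-1 flow schedule; rescale to
# the 0-1000 timestep range the model expects.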
if sigmas < 1.00001:
timestep *= 1000.0
if context.size(0) > 1:
cond = context.chunk(2)[1]
else:
cond = context
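# Assumes the CFG batch is packed [uncond | cond]; the second half of the
# context is the conditional branch.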
shape_rule = sigmas < self.guidance_interval[0] or sigmas > self.guidance_interval[1]
txt_rule = sigmas < self.guidance_interval_txt[0] or sigmas > self.guidance_interval_txt[1]
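# Guidance-interval CFG: outside the configured sigma interval guidance is
# disabled, so only the cond branch needs to run.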
if not_struct_mode:
orig_bsz = x.shape[0]
rule = txt_rule if mode == "texture_generation" else shape_rule
# 1. CFG Bypass Slicing
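# Only the cond half of the batch is evaluated; the output is repeated 2x in
# the re-pad section so the sampler still receives a full CFG batch.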
if rule and orig_bsz > 1:
half = orig_bsz // 2
x_eval = x[half:]
t_eval = timestep[half:] if timestep.shape[0] > 1 else timestep
c_eval = cond
else:
x_eval = x
t_eval = timestep
@ -846,112 +834,45 @@ class Trellis2(nn.Module):
B, N, C = x_eval.shape
# 2. Vectorized SparseTensor Construction (NO FOR LOOPS!)
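# All samples share one SparseTensor; column 0 of each coord row carries the
# batch index, so no per-sample Python loop is needed.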
if mode in ["shape_generation", "texture_generation"]:
if coord_counts is not None:
logical_batch = coord_counts.shape[0]
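# coord_counts[i] is the number of active (non-padded) tokens for logical
# sample i; with CFG the latent batch B is a multiple of that logical batch.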
# Duplicate coords if CFG is active
if B > logical_batch:
    c_pos = coords.clone()
    c_pos[:, 0] += logical_batch
    batched_coords = torch.cat([coords, c_pos], dim=0)
    counts_eval = torch.cat([coord_counts, coord_counts], dim=0)
else:
    batched_coords = coords
    counts_eval = coord_counts
# Create boolean mask [B, N] to drop the padded zeros instantly
mask = torch.arange(N, device=x.device).unsqueeze(0) < counts_eval.unsqueeze(1)
feats_flat = x_eval[mask]
elif coords.shape[0] == N:
    feats_flat = x_eval.reshape(-1, C)
    coords_list = []
    for i in range(B):
        c = coords.clone()
        c[:, 0] = i
        coords_list.append(c)
    batched_coords = torch.cat(coords_list, dim=0)
    mask = None
elif coords.shape[0] == B * N:
    feats_flat = x_eval.reshape(-1, C)
    batched_coords = coords
    mask = None
else:
    raise ValueError(
        f"Trellis2 expected coords rows {N} or {B * N}, got {coords.shape[0]}"
    )
else:
    batched_coords = coords
    feats_flat = x_eval
    mask = None
x_st = SparseTensor(feats=feats_flat, coords=batched_coords.to(torch.int32))
if mode == "shape_generation":
if is_512_run:
out = self.img2shape_512(x_st, t_eval, c_eval)
else:
out = self.img2shape(x_st, t_eval, c_eval)
elif mode == "texture_generation":
if self.shape2txt is None:
raise ValueError("Checkpoint for Trellis2 doesn't include texture generation!")
@ -959,96 +880,43 @@ class Trellis2(nn.Module):
if slat is None:
raise ValueError("shape_slat can't be None")
slat_feats = slat.feats
# Duplicate shape context if CFG is active
if coord_counts is not None and B > coord_counts.shape[0]:
    slat_feats = torch.cat([slat_feats, slat_feats], dim=0)
elif coord_counts is None:
    slat_feats = slat.feats[:N].repeat(B, 1)
x_st = x_st.replace(feats=torch.cat([x_st.feats, slat_feats], dim=-1))
out = self.shape2txt(x_st, t_eval, c_eval)
else: # structure
orig_bsz = x.shape[0]
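# Dense structure path: apply the same guidance-interval skip to the raw 5-D latent.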
if shape_rule and orig_bsz > 1:
    half = orig_bsz // 2
    x_eval = x[half:]
    t_eval = timestep[half:] if timestep.shape[0] > 1 else timestep
    out = self.structure_model(x_eval, t_eval, cond)
    out = out.repeat(2, 1, 1, 1, 1)
else:
    out = self.structure_model(x, timestep, context)
# ==================================================
# RE-PAD AND FORMAT OUTPUT
# ==================================================
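# Sparse outputs are variable-length; scatter them back into the zero-padded
# rectangular [B, N] token grid the sampler expects.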
if not_struct_mode:
    if mask is not None:
        # Instantly scatter the valid tokens back into a padded rectangular tensor
        padded_out = torch.zeros((B, N, out.feats.shape[-1]), device=x.device, dtype=out.feats.dtype)
        padded_out[mask] = out.feats
        out_tensor = padded_out.transpose(1, 2).unsqueeze(-1)
    else:
        out_tensor = out.feats.view(B, N, -1).transpose(1, 2).unsqueeze(-1)
    if rule and orig_bsz > 1:
        out_tensor = out_tensor.repeat(2, 1, 1, 1)
    return out_tensor
#else:
# out = torch.nn.functional.pad(out, (0, 0, 0, 0, 0, 0, 24, 0))
return out
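For reference, the flatten/scatter round-trip that replaces the old per-sample loops reduces to a small, self-contained pattern. A minimal sketch with made-up shapes and counts, using a plain dense tensor in place of SparseTensor:

import torch

# Hypothetical sizes: 2 samples, 4 padded token slots, 3 channels.
B, N, C = 2, 4, 3
x = torch.randn(B, N, C)
counts = torch.tensor([3, 2])  # active (non-padded) tokens per sample

# Flatten: a boolean mask keeps only the active tokens (feats_flat = x_eval[mask]).
mask = torch.arange(N).unsqueeze(0) < counts.unsqueeze(1)  # [B, N] bool
feats_flat = x[mask]                                       # [counts.sum(), C]
assert feats_flat.shape[0] == int(counts.sum())

# ... the sparse model would run on feats_flat here ...

# Scatter: re-pad the variable-length result into the rectangular grid.
padded = torch.zeros(B, N, C, dtype=feats_flat.dtype)
padded[mask] = feats_flat
assert torch.equal(padded[mask], x[mask])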