From e180d4ad799f533b82bb2e8e83f977317d458ff9 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Thu, 7 May 2026 18:47:03 +0300 Subject: [PATCH] simplify and optimize model.forward --- comfy/ldm/trellis2/model.py | 272 ++++++++++-------------------------- 1 file changed, 70 insertions(+), 202 deletions(-) diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index e8ed39aed..a54e4ca94 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -779,66 +779,54 @@ class Trellis2(nn.Module): def forward(self, x, timestep, context, **kwargs): transformer_options = kwargs.get("transformer_options", {}) + timestep = timestep.to(x.dtype) embeds = kwargs.get("embeds") if embeds is None: raise ValueError("Trellis2.forward requires 'embeds' in kwargs") - # img2shape.resolution is the latent-grid size, not the input pixel size: - # 32 -> 512px path, 64 -> 1024px path. - uses_1024_conditioning = self.img2shape.resolution == 64 + + is_1024 = self.img2shape.resolution == 1024 coords = transformer_options.get("coords", None) - coord_counts = transformer_options.get("coord_counts") + coord_counts = transformer_options.get("coord_counts", None) mode = transformer_options.get("generation_mode", "structure_generation") + is_512_run = False - timestep = timestep.to(self.dtype) if mode == "shape_generation_512": is_512_run = True mode = "shape_generation" + if coords is not None: - x = x.squeeze(-1).transpose(1, 2) + if x.ndim == 4: + x = x.squeeze(-1).transpose(1, 2) not_struct_mode = True else: mode = "structure_generation" not_struct_mode = False - if uses_1024_conditioning and not_struct_mode and not is_512_run: + if is_1024 and not_struct_mode and not is_512_run: context = embeds sigmas = transformer_options.get("sigmas")[0].item() if sigmas < 1.00001: timestep *= 1000.0 + if context.size(0) > 1: cond = context.chunk(2)[1] else: cond = context + shape_rule = sigmas < self.guidance_interval[0] or sigmas > self.guidance_interval[1] txt_rule = sigmas < self.guidance_interval_txt[0] or sigmas > self.guidance_interval_txt[1] - dense_out = None - cond_or_uncond = transformer_options.get("cond_or_uncond") or [] - - def cond_group_indices(batch_groups): - if len(cond_or_uncond) == batch_groups: - cond_groups = [i for i, marker in enumerate(cond_or_uncond) if marker == 0] - if len(cond_groups) > 0: - return cond_groups - return [batch_groups - 1] if not_struct_mode: orig_bsz = x.shape[0] rule = txt_rule if mode == "texture_generation" else shape_rule - logical_batch = coord_counts.shape[0] if coord_counts is not None else 1 - if rule and orig_bsz > logical_batch: - batch_groups = orig_bsz // logical_batch - selected_groups = cond_group_indices(batch_groups) - x_groups = x.reshape(batch_groups, logical_batch, *x.shape[1:]) - x_eval = x_groups[selected_groups].reshape(-1, *x.shape[1:]) - if timestep.shape[0] > 1: - t_groups = timestep.reshape(batch_groups, logical_batch, *timestep.shape[1:]) - t_eval = t_groups[selected_groups].reshape(-1, *timestep.shape[1:]) - else: - t_eval = timestep - c_groups = context.reshape(batch_groups, logical_batch, *context.shape[1:]) - c_eval = c_groups[selected_groups].reshape(-1, *context.shape[1:]) + # 1. CFG Bypass Slicing + if rule and orig_bsz > 1: + half = orig_bsz // 2 + x_eval = x[half:] + t_eval = timestep[half:] if timestep.shape[0] > 1 else timestep + c_eval = cond else: x_eval = x t_eval = timestep @@ -846,112 +834,45 @@ class Trellis2(nn.Module): B, N, C = x_eval.shape + # 2. Vectorized SparseTensor Construction (NO FOR LOOPS!) if mode in ["shape_generation", "texture_generation"]: if coord_counts is not None: logical_batch = coord_counts.shape[0] - if B % logical_batch != 0: - raise ValueError( - f"Trellis2 coord_counts batch {logical_batch} doesn't divide latent batch {B}" - ) - if int(coord_counts.sum().item()) != coords.shape[0]: - raise ValueError( - f"Trellis2 coord_counts total {int(coord_counts.sum().item())} does not match coords rows {coords.shape[0]}" - ) - batch_ids = coords[:, 0].to(torch.int64) - order = torch.argsort(batch_ids, stable=True) - sorted_coords = coords.index_select(0, order) - sorted_batch_ids = batch_ids.index_select(0, order) - offsets = coord_counts.cumsum(0) - coord_counts - coords_by_batch = [] - for i in range(logical_batch): - count = int(coord_counts[i].item()) - start = int(offsets[i].item()) - coords_i = sorted_coords[start:start + count] - ids_i = sorted_batch_ids[start:start + count] - if coords_i.shape[0] != count or not torch.all(ids_i == i): - raise ValueError( - f"Trellis2 coords rows for batch {i} expected {count}, got {coords_i.shape[0]}" - ) - coords_by_batch.append(coords_i) - repeat_factor = B // logical_batch - sparse_outs = [] - active_coord_counts = [] - for rep in range(repeat_factor): - for i in range(logical_batch): - out_index = rep * logical_batch + i - count = int(coord_counts[i].item()) - if count > N: - raise ValueError( - f"Trellis2 coord count {count} exceeds latent token dimension {N} for batch {i}" - ) - coords_i = coords_by_batch[i].clone() - coords_i[:, 0] = 0 - feats_i = x_eval[out_index, :count].clone() - x_st_i = SparseTensor(feats=feats_i, coords=coords_i.to(torch.int32)) - t_i = t_eval[out_index].unsqueeze(0).clone() if t_eval.shape[0] > 1 else t_eval - c_i = c_eval[out_index].unsqueeze(0).clone() if c_eval.shape[0] > 1 else c_eval + # Duplicate coords if CFG is active + if B > logical_batch: + c_pos = coords.clone() + c_pos[:, 0] += logical_batch + batched_coords = torch.cat([coords, c_pos], dim=0) + counts_eval = torch.cat([coord_counts, coord_counts], dim=0) + else: + batched_coords = coords + counts_eval = coord_counts - if mode == "shape_generation": - if is_512_run: - sparse_out = self.img2shape_512(x_st_i, t_i, c_i) - else: - sparse_out = self.img2shape(x_st_i, t_i, c_i) - else: - slat = transformer_options.get("shape_slat") - if slat is None: - raise ValueError("shape_slat can't be None") - if slat.ndim == 3: - if slat.shape[0] != logical_batch: - raise ValueError( - f"shape_slat batch {slat.shape[0]} doesn't match coord_counts batch {logical_batch}" - ) - if slat.shape[1] < count: - raise ValueError( - f"shape_slat tokens {slat.shape[1]} can't cover coord count {count} for batch {i}" - ) - slat_feats = slat[i, :count].to(x_st_i.device) - else: - slat_feats = slat[:count].to(x_st_i.device) - x_st_i = x_st_i.replace(feats=torch.cat([x_st_i.feats, slat_feats], dim=-1)) - sparse_out = self.shape2txt(x_st_i, t_i, c_i) - - sparse_outs.append(sparse_out.feats) - active_coord_counts.append(count) - - out_channels = sparse_outs[0].shape[-1] - padded = sparse_outs[0].new_zeros((B, N, out_channels)) - for out_index, (count, feats_i) in enumerate(zip(active_coord_counts, sparse_outs)): - padded[out_index, :count] = feats_i - dense_out = padded.transpose(1, 2).unsqueeze(-1) - elif coords.shape[0] == N: + # Create boolean mask [B, N] to drop the padded zeros instantly + mask = torch.arange(N, device=x.device).unsqueeze(0) < counts_eval.unsqueeze(1) + feats_flat = x_eval[mask] + else: feats_flat = x_eval.reshape(-1, C) - coords_list = [] + coords_list =[] for i in range(B): c = coords.clone() c[:, 0] = i coords_list.append(c) batched_coords = torch.cat(coords_list, dim=0) - elif coords.shape[0] == B * N: - feats_flat = x_eval.reshape(-1, C) - batched_coords = coords - else: - raise ValueError( - f"Trellis2 expected coords rows {N} or {B * N}, got {coords.shape[0]}" - ) + mask = None else: batched_coords = coords feats_flat = x_eval + mask = None - if dense_out is None: - x_st = SparseTensor(feats=feats_flat, coords=batched_coords.to(torch.int32)) + x_st = SparseTensor(feats=feats_flat, coords=batched_coords.to(torch.int32)) - if dense_out is not None: - out = dense_out - elif mode == "shape_generation": + if mode == "shape_generation": if is_512_run: out = self.img2shape_512(x_st, t_eval, c_eval) else: out = self.img2shape(x_st, t_eval, c_eval) + elif mode == "texture_generation": if self.shape2txt is None: raise ValueError("Checkpoint for Trellis2 doesn't include texture generation!") @@ -959,96 +880,43 @@ class Trellis2(nn.Module): if slat is None: raise ValueError("shape_slat can't be None") - if slat.ndim == 3: - if coord_counts is not None: - logical_batch = coord_counts.shape[0] - if slat.shape[0] != logical_batch: - raise ValueError( - f"shape_slat batch {slat.shape[0]} doesn't match coord_counts batch {logical_batch}" - ) - if B % logical_batch != 0: - raise ValueError( - f"Trellis2 coord_counts batch {logical_batch} doesn't divide latent batch {B}" - ) - repeat_factor = B // logical_batch - slat_list = [] - for _ in range(repeat_factor): - for i in range(logical_batch): - count = int(coord_counts[i].item()) - if slat.shape[1] < count: - raise ValueError( - f"shape_slat tokens {slat.shape[1]} can't cover coord count {count} for batch {i}" - ) - slat_list.append(slat[i, :count]) - slat_feats_batched = torch.cat(slat_list, dim=0).to(x_st.device) - else: - if slat.shape[0] != B: - raise ValueError(f"shape_slat batch {slat.shape[0]} doesn't match latent batch {B}") - if slat.shape[1] != N: - raise ValueError(f"shape_slat tokens {slat.shape[1]} doesn't match latent tokens {N}") - slat_feats_batched = slat.reshape(B * N, -1).to(x_st.device) - else: - base_slat_feats = slat[:N] - slat_feats_batched = base_slat_feats.repeat(B, 1).to(x_st.device) - x_st = x_st.replace(feats=torch.cat([x_st.feats, slat_feats_batched], dim=-1)) + slat_feats = slat.feats + # Duplicate shape context if CFG is active + if coord_counts is not None and B > coord_counts.shape[0]: + slat_feats = torch.cat([slat_feats, slat_feats], dim=0) + elif coord_counts is None: + slat_feats = slat.feats[:N].repeat(B, 1) + + x_st = x_st.replace(feats=torch.cat([x_st.feats, slat_feats], dim=-1)) out = self.shape2txt(x_st, t_eval, c_eval) + else: # structure orig_bsz = x.shape[0] - batch_groups = len(cond_or_uncond) if len(cond_or_uncond) > 0 and orig_bsz % len(cond_or_uncond) == 0 else 1 - logical_batch = orig_bsz // batch_groups - if logical_batch > 1: - x_groups = x.reshape(batch_groups, logical_batch, *x.shape[1:]) - if timestep.shape[0] > 1: - t_groups = timestep.reshape(batch_groups, logical_batch, *timestep.shape[1:]) - else: - t_groups = timestep - c_groups = context.reshape(batch_groups, logical_batch, *context.shape[1:]) - - if shape_rule and batch_groups > 1: - selected_group_indices = cond_group_indices(batch_groups) - else: - selected_group_indices = list(range(batch_groups)) - - out_groups = [] - for sample_index in range(logical_batch): - if shape_rule and batch_groups > 1: - x_i = x_groups[selected_group_indices, sample_index] - if timestep.shape[0] > 1: - t_i = t_groups[selected_group_indices, sample_index] - else: - t_i = timestep - c_i = c_groups[selected_group_indices, sample_index] - else: - x_i = x_groups[selected_group_indices, sample_index] - if timestep.shape[0] > 1: - t_i = t_groups[selected_group_indices, sample_index] - else: - t_i = timestep - c_i = c_groups[selected_group_indices, sample_index] - out_groups.append(self.structure_model(x_i, t_i, c_i)) - - out = out_groups[0].new_zeros((orig_bsz, *out_groups[0].shape[1:])) - for sample_index, out_sample in enumerate(out_groups): - if shape_rule and batch_groups > 1: - repeated = out_sample[0] - for group_index in range(batch_groups): - out[group_index * logical_batch + sample_index] = repeated - else: - for local_group_index, group_index in enumerate(selected_group_indices): - out[group_index * logical_batch + sample_index] = out_sample[local_group_index] + if shape_rule and orig_bsz > 1: + half = orig_bsz // 2 + x_eval = x[half:] + t_eval = timestep[half:] if timestep.shape[0] > 1 else timestep + out = self.structure_model(x_eval, t_eval, cond) + out = out.repeat(2, 1, 1, 1, 1) else: - if shape_rule and orig_bsz > 1: - half = orig_bsz // 2 - x = x[half:] - timestep = timestep[half:] if timestep.shape[0] > 1 else timestep - out = self.structure_model(x, timestep, cond if shape_rule and orig_bsz > 1 else context) - if shape_rule and orig_bsz > 1: - out = out.repeat(2, 1, 1, 1, 1) + out = self.structure_model(x, timestep, context) + # ================================================== + # RE-PAD AND FORMAT OUTPUT + # ================================================== if not_struct_mode: - if dense_out is None: - out = out.feats - out = out.view(B, N, -1).transpose(1, 2).unsqueeze(-1) - if rule and orig_bsz > B: - out = out.repeat(orig_bsz // B, 1, 1, 1) + if mask is not None: + # Instantly scatter the valid tokens back into a padded rectangular tensor + padded_out = torch.zeros((B, N, out.feats.shape[-1]), device=x.device, dtype=out.feats.dtype) + padded_out[mask] = out.feats + out_tensor = padded_out.transpose(1, 2).unsqueeze(-1) + else: + out_tensor = out.feats.view(B, N, -1).transpose(1, 2).unsqueeze(-1) + + if rule and orig_bsz > 1: + out_tensor = out_tensor.repeat(2, 1, 1, 1) + return out_tensor + #else: + # out = torch.nn.functional.pad(out, (0, 0, 0, 0, 0, 0, 24, 0)) + return out