mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-10 09:12:31 +08:00
simplify and optimize model.forward
This commit is contained in:
parent
81ed835ffb
commit
e180d4ad79
@ -779,66 +779,54 @@ class Trellis2(nn.Module):
|
||||
|
||||
def forward(self, x, timestep, context, **kwargs):
|
||||
transformer_options = kwargs.get("transformer_options", {})
|
||||
timestep = timestep.to(x.dtype)
|
||||
embeds = kwargs.get("embeds")
|
||||
if embeds is None:
|
||||
raise ValueError("Trellis2.forward requires 'embeds' in kwargs")
|
||||
# img2shape.resolution is the latent-grid size, not the input pixel size:
|
||||
# 32 -> 512px path, 64 -> 1024px path.
|
||||
uses_1024_conditioning = self.img2shape.resolution == 64
|
||||
|
||||
is_1024 = self.img2shape.resolution == 1024
|
||||
coords = transformer_options.get("coords", None)
|
||||
coord_counts = transformer_options.get("coord_counts")
|
||||
coord_counts = transformer_options.get("coord_counts", None)
|
||||
mode = transformer_options.get("generation_mode", "structure_generation")
|
||||
|
||||
is_512_run = False
|
||||
timestep = timestep.to(self.dtype)
|
||||
if mode == "shape_generation_512":
|
||||
is_512_run = True
|
||||
mode = "shape_generation"
|
||||
|
||||
if coords is not None:
|
||||
x = x.squeeze(-1).transpose(1, 2)
|
||||
if x.ndim == 4:
|
||||
x = x.squeeze(-1).transpose(1, 2)
|
||||
not_struct_mode = True
|
||||
else:
|
||||
mode = "structure_generation"
|
||||
not_struct_mode = False
|
||||
|
||||
if uses_1024_conditioning and not_struct_mode and not is_512_run:
|
||||
if is_1024 and not_struct_mode and not is_512_run:
|
||||
context = embeds
|
||||
|
||||
sigmas = transformer_options.get("sigmas")[0].item()
|
||||
if sigmas < 1.00001:
|
||||
timestep *= 1000.0
|
||||
|
||||
if context.size(0) > 1:
|
||||
cond = context.chunk(2)[1]
|
||||
else:
|
||||
cond = context
|
||||
|
||||
shape_rule = sigmas < self.guidance_interval[0] or sigmas > self.guidance_interval[1]
|
||||
txt_rule = sigmas < self.guidance_interval_txt[0] or sigmas > self.guidance_interval_txt[1]
|
||||
dense_out = None
|
||||
cond_or_uncond = transformer_options.get("cond_or_uncond") or []
|
||||
|
||||
def cond_group_indices(batch_groups):
|
||||
if len(cond_or_uncond) == batch_groups:
|
||||
cond_groups = [i for i, marker in enumerate(cond_or_uncond) if marker == 0]
|
||||
if len(cond_groups) > 0:
|
||||
return cond_groups
|
||||
return [batch_groups - 1]
|
||||
|
||||
if not_struct_mode:
|
||||
orig_bsz = x.shape[0]
|
||||
rule = txt_rule if mode == "texture_generation" else shape_rule
|
||||
|
||||
logical_batch = coord_counts.shape[0] if coord_counts is not None else 1
|
||||
if rule and orig_bsz > logical_batch:
|
||||
batch_groups = orig_bsz // logical_batch
|
||||
selected_groups = cond_group_indices(batch_groups)
|
||||
x_groups = x.reshape(batch_groups, logical_batch, *x.shape[1:])
|
||||
x_eval = x_groups[selected_groups].reshape(-1, *x.shape[1:])
|
||||
if timestep.shape[0] > 1:
|
||||
t_groups = timestep.reshape(batch_groups, logical_batch, *timestep.shape[1:])
|
||||
t_eval = t_groups[selected_groups].reshape(-1, *timestep.shape[1:])
|
||||
else:
|
||||
t_eval = timestep
|
||||
c_groups = context.reshape(batch_groups, logical_batch, *context.shape[1:])
|
||||
c_eval = c_groups[selected_groups].reshape(-1, *context.shape[1:])
|
||||
# 1. CFG Bypass Slicing
|
||||
if rule and orig_bsz > 1:
|
||||
half = orig_bsz // 2
|
||||
x_eval = x[half:]
|
||||
t_eval = timestep[half:] if timestep.shape[0] > 1 else timestep
|
||||
c_eval = cond
|
||||
else:
|
||||
x_eval = x
|
||||
t_eval = timestep
|
||||
@ -846,112 +834,45 @@ class Trellis2(nn.Module):
|
||||
|
||||
B, N, C = x_eval.shape
|
||||
|
||||
# 2. Vectorized SparseTensor Construction (NO FOR LOOPS!)
|
||||
if mode in ["shape_generation", "texture_generation"]:
|
||||
if coord_counts is not None:
|
||||
logical_batch = coord_counts.shape[0]
|
||||
if B % logical_batch != 0:
|
||||
raise ValueError(
|
||||
f"Trellis2 coord_counts batch {logical_batch} doesn't divide latent batch {B}"
|
||||
)
|
||||
if int(coord_counts.sum().item()) != coords.shape[0]:
|
||||
raise ValueError(
|
||||
f"Trellis2 coord_counts total {int(coord_counts.sum().item())} does not match coords rows {coords.shape[0]}"
|
||||
)
|
||||
batch_ids = coords[:, 0].to(torch.int64)
|
||||
order = torch.argsort(batch_ids, stable=True)
|
||||
sorted_coords = coords.index_select(0, order)
|
||||
sorted_batch_ids = batch_ids.index_select(0, order)
|
||||
offsets = coord_counts.cumsum(0) - coord_counts
|
||||
coords_by_batch = []
|
||||
for i in range(logical_batch):
|
||||
count = int(coord_counts[i].item())
|
||||
start = int(offsets[i].item())
|
||||
coords_i = sorted_coords[start:start + count]
|
||||
ids_i = sorted_batch_ids[start:start + count]
|
||||
if coords_i.shape[0] != count or not torch.all(ids_i == i):
|
||||
raise ValueError(
|
||||
f"Trellis2 coords rows for batch {i} expected {count}, got {coords_i.shape[0]}"
|
||||
)
|
||||
coords_by_batch.append(coords_i)
|
||||
repeat_factor = B // logical_batch
|
||||
sparse_outs = []
|
||||
active_coord_counts = []
|
||||
for rep in range(repeat_factor):
|
||||
for i in range(logical_batch):
|
||||
out_index = rep * logical_batch + i
|
||||
count = int(coord_counts[i].item())
|
||||
if count > N:
|
||||
raise ValueError(
|
||||
f"Trellis2 coord count {count} exceeds latent token dimension {N} for batch {i}"
|
||||
)
|
||||
coords_i = coords_by_batch[i].clone()
|
||||
coords_i[:, 0] = 0
|
||||
feats_i = x_eval[out_index, :count].clone()
|
||||
x_st_i = SparseTensor(feats=feats_i, coords=coords_i.to(torch.int32))
|
||||
t_i = t_eval[out_index].unsqueeze(0).clone() if t_eval.shape[0] > 1 else t_eval
|
||||
c_i = c_eval[out_index].unsqueeze(0).clone() if c_eval.shape[0] > 1 else c_eval
|
||||
# Duplicate coords if CFG is active
|
||||
if B > logical_batch:
|
||||
c_pos = coords.clone()
|
||||
c_pos[:, 0] += logical_batch
|
||||
batched_coords = torch.cat([coords, c_pos], dim=0)
|
||||
counts_eval = torch.cat([coord_counts, coord_counts], dim=0)
|
||||
else:
|
||||
batched_coords = coords
|
||||
counts_eval = coord_counts
|
||||
|
||||
if mode == "shape_generation":
|
||||
if is_512_run:
|
||||
sparse_out = self.img2shape_512(x_st_i, t_i, c_i)
|
||||
else:
|
||||
sparse_out = self.img2shape(x_st_i, t_i, c_i)
|
||||
else:
|
||||
slat = transformer_options.get("shape_slat")
|
||||
if slat is None:
|
||||
raise ValueError("shape_slat can't be None")
|
||||
if slat.ndim == 3:
|
||||
if slat.shape[0] != logical_batch:
|
||||
raise ValueError(
|
||||
f"shape_slat batch {slat.shape[0]} doesn't match coord_counts batch {logical_batch}"
|
||||
)
|
||||
if slat.shape[1] < count:
|
||||
raise ValueError(
|
||||
f"shape_slat tokens {slat.shape[1]} can't cover coord count {count} for batch {i}"
|
||||
)
|
||||
slat_feats = slat[i, :count].to(x_st_i.device)
|
||||
else:
|
||||
slat_feats = slat[:count].to(x_st_i.device)
|
||||
x_st_i = x_st_i.replace(feats=torch.cat([x_st_i.feats, slat_feats], dim=-1))
|
||||
sparse_out = self.shape2txt(x_st_i, t_i, c_i)
|
||||
|
||||
sparse_outs.append(sparse_out.feats)
|
||||
active_coord_counts.append(count)
|
||||
|
||||
out_channels = sparse_outs[0].shape[-1]
|
||||
padded = sparse_outs[0].new_zeros((B, N, out_channels))
|
||||
for out_index, (count, feats_i) in enumerate(zip(active_coord_counts, sparse_outs)):
|
||||
padded[out_index, :count] = feats_i
|
||||
dense_out = padded.transpose(1, 2).unsqueeze(-1)
|
||||
elif coords.shape[0] == N:
|
||||
# Create boolean mask [B, N] to drop the padded zeros instantly
|
||||
mask = torch.arange(N, device=x.device).unsqueeze(0) < counts_eval.unsqueeze(1)
|
||||
feats_flat = x_eval[mask]
|
||||
else:
|
||||
feats_flat = x_eval.reshape(-1, C)
|
||||
coords_list = []
|
||||
coords_list =[]
|
||||
for i in range(B):
|
||||
c = coords.clone()
|
||||
c[:, 0] = i
|
||||
coords_list.append(c)
|
||||
batched_coords = torch.cat(coords_list, dim=0)
|
||||
elif coords.shape[0] == B * N:
|
||||
feats_flat = x_eval.reshape(-1, C)
|
||||
batched_coords = coords
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Trellis2 expected coords rows {N} or {B * N}, got {coords.shape[0]}"
|
||||
)
|
||||
mask = None
|
||||
else:
|
||||
batched_coords = coords
|
||||
feats_flat = x_eval
|
||||
mask = None
|
||||
|
||||
if dense_out is None:
|
||||
x_st = SparseTensor(feats=feats_flat, coords=batched_coords.to(torch.int32))
|
||||
x_st = SparseTensor(feats=feats_flat, coords=batched_coords.to(torch.int32))
|
||||
|
||||
if dense_out is not None:
|
||||
out = dense_out
|
||||
elif mode == "shape_generation":
|
||||
if mode == "shape_generation":
|
||||
if is_512_run:
|
||||
out = self.img2shape_512(x_st, t_eval, c_eval)
|
||||
else:
|
||||
out = self.img2shape(x_st, t_eval, c_eval)
|
||||
|
||||
elif mode == "texture_generation":
|
||||
if self.shape2txt is None:
|
||||
raise ValueError("Checkpoint for Trellis2 doesn't include texture generation!")
|
||||
@ -959,96 +880,43 @@ class Trellis2(nn.Module):
|
||||
if slat is None:
|
||||
raise ValueError("shape_slat can't be None")
|
||||
|
||||
if slat.ndim == 3:
|
||||
if coord_counts is not None:
|
||||
logical_batch = coord_counts.shape[0]
|
||||
if slat.shape[0] != logical_batch:
|
||||
raise ValueError(
|
||||
f"shape_slat batch {slat.shape[0]} doesn't match coord_counts batch {logical_batch}"
|
||||
)
|
||||
if B % logical_batch != 0:
|
||||
raise ValueError(
|
||||
f"Trellis2 coord_counts batch {logical_batch} doesn't divide latent batch {B}"
|
||||
)
|
||||
repeat_factor = B // logical_batch
|
||||
slat_list = []
|
||||
for _ in range(repeat_factor):
|
||||
for i in range(logical_batch):
|
||||
count = int(coord_counts[i].item())
|
||||
if slat.shape[1] < count:
|
||||
raise ValueError(
|
||||
f"shape_slat tokens {slat.shape[1]} can't cover coord count {count} for batch {i}"
|
||||
)
|
||||
slat_list.append(slat[i, :count])
|
||||
slat_feats_batched = torch.cat(slat_list, dim=0).to(x_st.device)
|
||||
else:
|
||||
if slat.shape[0] != B:
|
||||
raise ValueError(f"shape_slat batch {slat.shape[0]} doesn't match latent batch {B}")
|
||||
if slat.shape[1] != N:
|
||||
raise ValueError(f"shape_slat tokens {slat.shape[1]} doesn't match latent tokens {N}")
|
||||
slat_feats_batched = slat.reshape(B * N, -1).to(x_st.device)
|
||||
else:
|
||||
base_slat_feats = slat[:N]
|
||||
slat_feats_batched = base_slat_feats.repeat(B, 1).to(x_st.device)
|
||||
x_st = x_st.replace(feats=torch.cat([x_st.feats, slat_feats_batched], dim=-1))
|
||||
slat_feats = slat.feats
|
||||
# Duplicate shape context if CFG is active
|
||||
if coord_counts is not None and B > coord_counts.shape[0]:
|
||||
slat_feats = torch.cat([slat_feats, slat_feats], dim=0)
|
||||
elif coord_counts is None:
|
||||
slat_feats = slat.feats[:N].repeat(B, 1)
|
||||
|
||||
x_st = x_st.replace(feats=torch.cat([x_st.feats, slat_feats], dim=-1))
|
||||
out = self.shape2txt(x_st, t_eval, c_eval)
|
||||
|
||||
else: # structure
|
||||
orig_bsz = x.shape[0]
|
||||
batch_groups = len(cond_or_uncond) if len(cond_or_uncond) > 0 and orig_bsz % len(cond_or_uncond) == 0 else 1
|
||||
logical_batch = orig_bsz // batch_groups
|
||||
if logical_batch > 1:
|
||||
x_groups = x.reshape(batch_groups, logical_batch, *x.shape[1:])
|
||||
if timestep.shape[0] > 1:
|
||||
t_groups = timestep.reshape(batch_groups, logical_batch, *timestep.shape[1:])
|
||||
else:
|
||||
t_groups = timestep
|
||||
c_groups = context.reshape(batch_groups, logical_batch, *context.shape[1:])
|
||||
|
||||
if shape_rule and batch_groups > 1:
|
||||
selected_group_indices = cond_group_indices(batch_groups)
|
||||
else:
|
||||
selected_group_indices = list(range(batch_groups))
|
||||
|
||||
out_groups = []
|
||||
for sample_index in range(logical_batch):
|
||||
if shape_rule and batch_groups > 1:
|
||||
x_i = x_groups[selected_group_indices, sample_index]
|
||||
if timestep.shape[0] > 1:
|
||||
t_i = t_groups[selected_group_indices, sample_index]
|
||||
else:
|
||||
t_i = timestep
|
||||
c_i = c_groups[selected_group_indices, sample_index]
|
||||
else:
|
||||
x_i = x_groups[selected_group_indices, sample_index]
|
||||
if timestep.shape[0] > 1:
|
||||
t_i = t_groups[selected_group_indices, sample_index]
|
||||
else:
|
||||
t_i = timestep
|
||||
c_i = c_groups[selected_group_indices, sample_index]
|
||||
out_groups.append(self.structure_model(x_i, t_i, c_i))
|
||||
|
||||
out = out_groups[0].new_zeros((orig_bsz, *out_groups[0].shape[1:]))
|
||||
for sample_index, out_sample in enumerate(out_groups):
|
||||
if shape_rule and batch_groups > 1:
|
||||
repeated = out_sample[0]
|
||||
for group_index in range(batch_groups):
|
||||
out[group_index * logical_batch + sample_index] = repeated
|
||||
else:
|
||||
for local_group_index, group_index in enumerate(selected_group_indices):
|
||||
out[group_index * logical_batch + sample_index] = out_sample[local_group_index]
|
||||
if shape_rule and orig_bsz > 1:
|
||||
half = orig_bsz // 2
|
||||
x_eval = x[half:]
|
||||
t_eval = timestep[half:] if timestep.shape[0] > 1 else timestep
|
||||
out = self.structure_model(x_eval, t_eval, cond)
|
||||
out = out.repeat(2, 1, 1, 1, 1)
|
||||
else:
|
||||
if shape_rule and orig_bsz > 1:
|
||||
half = orig_bsz // 2
|
||||
x = x[half:]
|
||||
timestep = timestep[half:] if timestep.shape[0] > 1 else timestep
|
||||
out = self.structure_model(x, timestep, cond if shape_rule and orig_bsz > 1 else context)
|
||||
if shape_rule and orig_bsz > 1:
|
||||
out = out.repeat(2, 1, 1, 1, 1)
|
||||
out = self.structure_model(x, timestep, context)
|
||||
|
||||
# ==================================================
|
||||
# RE-PAD AND FORMAT OUTPUT
|
||||
# ==================================================
|
||||
if not_struct_mode:
|
||||
if dense_out is None:
|
||||
out = out.feats
|
||||
out = out.view(B, N, -1).transpose(1, 2).unsqueeze(-1)
|
||||
if rule and orig_bsz > B:
|
||||
out = out.repeat(orig_bsz // B, 1, 1, 1)
|
||||
if mask is not None:
|
||||
# Instantly scatter the valid tokens back into a padded rectangular tensor
|
||||
padded_out = torch.zeros((B, N, out.feats.shape[-1]), device=x.device, dtype=out.feats.dtype)
|
||||
padded_out[mask] = out.feats
|
||||
out_tensor = padded_out.transpose(1, 2).unsqueeze(-1)
|
||||
else:
|
||||
out_tensor = out.feats.view(B, N, -1).transpose(1, 2).unsqueeze(-1)
|
||||
|
||||
if rule and orig_bsz > 1:
|
||||
out_tensor = out_tensor.repeat(2, 1, 1, 1)
|
||||
return out_tensor
|
||||
#else:
|
||||
# out = torch.nn.functional.pad(out, (0, 0, 0, 0, 0, 0, 24, 0))
|
||||
|
||||
return out
|
||||
|
||||
Loading…
Reference in New Issue
Block a user