Fix TRELLIS2 texture sampling preview orientation

This commit is contained in:
kijai 2026-06-17 02:27:44 +03:00
parent 6ef69849a0
commit 805e7d5ae3
5 changed files with 14 additions and 7 deletions

View File

@ -1100,7 +1100,8 @@ class Trellis2(nn.Module):
# Pre-computed per-stage back-projected features
proj_feats = kwargs.get("trellis2_proj_feats")
sampling_preview.set_context(mode=mode, coords=coords, coord_counts=coord_counts)
sampling_preview.set_context(mode=mode, coords=coords, coord_counts=coord_counts,
model_frame=kwargs.get("trellis2_model_frame"))
is_first_shape_pass = False
if mode == "shape_generation_512":

View File

@ -17,10 +17,11 @@ _context = {}
_tex_rgb = None
def set_context(mode=None, coords=None, coord_counts=None):
def set_context(mode=None, coords=None, coord_counts=None, model_frame=None):
_context["mode"] = mode
_context["coords"] = coords
_context["coord_counts"] = coord_counts
_context["model_frame"] = model_frame
def get_context():

View File

@ -1741,7 +1741,7 @@ class Trellis2(BaseModel):
# CONDConstant: shared across pos/neg
for k in ("trellis2_coords", "trellis2_coord_counts",
"trellis2_generation_mode", "trellis2_shape_slat",
"trellis2_proj_feats"):
"trellis2_proj_feats", "trellis2_model_frame"):
v = kwargs.get(k)
if v is not None:
out[k] = comfy.conds.CONDConstant(v)

View File

@ -861,13 +861,16 @@ class Trellis2TextureStage(IO.ComfyNode):
shape_slat = shape_slat.squeeze(-1).transpose(1, 2).reshape(-1, channels)
latent = torch.zeros(batch_size, channels, max_tokens, 1)
proj_pack = _proj_pack_from_conditioning(positive)
model_frame = shape_latent.get("model_frame",
"y_up" if proj_pack is not None else "z_up")
extras = {
"trellis2_generation_mode": "texture_generation",
"trellis2_coords": coords,
"trellis2_coord_counts": counts,
"trellis2_shape_slat": shape_slat,
"trellis2_model_frame": model_frame,
}
proj_pack = _proj_pack_from_conditioning(positive)
if proj_pack is not None and coord_resolution is not None:
extras["trellis2_proj_feats"] = compute_stage_proj_feats(
proj_pack, "tex_1024", coords=coords, coord_resolution=coord_resolution,

View File

@ -178,7 +178,7 @@ class Trellis3DPreviewer(LatentPreviewer):
pmax = proj.amax(dim=0, keepdim=True)
return ((proj - pmin) / (pmax - pmin + 1e-8)).clamp(0, 1)
def _texture(self, x0, coords):
def _texture(self, x0, coords, model_frame=None):
if coords.shape[-1] == 4:
b0 = coords[:, 0] == 0
spatial = coords[b0][:, 1:4].float()
@ -187,9 +187,11 @@ class Trellis3DPreviewer(LatentPreviewer):
n0 = spatial.shape[0]
if n0 == 0:
return None
if model_frame == "z_up":
spatial = torch.stack([spatial[:, 0], spatial[:, 2], -spatial[:, 1]], dim=-1)
latent = x0[0, :, :n0, 0].float().transpose(0, 1) # [n0, C]
colors = self._latent_color(latent) # [n0, 3]
res = float(spatial.max().item()) + 1.0
res = float(spatial.abs().max().item()) + 1.0
rad = max(1, int(round(self._SIZE * self._FILL / max(res, 1) / 2)))
return self._splat(spatial, colors, rad)
@ -202,7 +204,7 @@ class Trellis3DPreviewer(LatentPreviewer):
mode = ctx.get("mode")
coords = ctx.get("coords")
if mode == "texture_generation" and coords is not None:
return self._texture(x0, coords)
return self._texture(x0, coords, model_frame=ctx.get("model_frame"))
except Exception as e:
logging.debug(f"Trellis3DPreviewer: skipping preview ({e})")
return None