diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index 99c930c8b..3ed6e114d 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -1100,7 +1100,8 @@ class Trellis2(nn.Module): # Pre-computed per-stage back-projected features proj_feats = kwargs.get("trellis2_proj_feats") - sampling_preview.set_context(mode=mode, coords=coords, coord_counts=coord_counts) + sampling_preview.set_context(mode=mode, coords=coords, coord_counts=coord_counts, + model_frame=kwargs.get("trellis2_model_frame")) is_first_shape_pass = False if mode == "shape_generation_512": diff --git a/comfy/ldm/trellis2/sampling_preview.py b/comfy/ldm/trellis2/sampling_preview.py index 76b2ef2c4..ddab00a5c 100644 --- a/comfy/ldm/trellis2/sampling_preview.py +++ b/comfy/ldm/trellis2/sampling_preview.py @@ -17,10 +17,11 @@ _context = {} _tex_rgb = None -def set_context(mode=None, coords=None, coord_counts=None): +def set_context(mode=None, coords=None, coord_counts=None, model_frame=None): _context["mode"] = mode _context["coords"] = coords _context["coord_counts"] = coord_counts + _context["model_frame"] = model_frame def get_context(): diff --git a/comfy/model_base.py b/comfy/model_base.py index 3c87a36a0..f1b56dc33 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1741,7 +1741,7 @@ class Trellis2(BaseModel): # CONDConstant: shared across pos/neg for k in ("trellis2_coords", "trellis2_coord_counts", "trellis2_generation_mode", "trellis2_shape_slat", - "trellis2_proj_feats"): + "trellis2_proj_feats", "trellis2_model_frame"): v = kwargs.get(k) if v is not None: out[k] = comfy.conds.CONDConstant(v) diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 9ff444014..ceace5b7d 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -861,13 +861,16 @@ class Trellis2TextureStage(IO.ComfyNode): shape_slat = shape_slat.squeeze(-1).transpose(1, 2).reshape(-1, channels) latent = torch.zeros(batch_size, channels, max_tokens, 1) + proj_pack = _proj_pack_from_conditioning(positive) + model_frame = shape_latent.get("model_frame", + "y_up" if proj_pack is not None else "z_up") extras = { "trellis2_generation_mode": "texture_generation", "trellis2_coords": coords, "trellis2_coord_counts": counts, "trellis2_shape_slat": shape_slat, + "trellis2_model_frame": model_frame, } - proj_pack = _proj_pack_from_conditioning(positive) if proj_pack is not None and coord_resolution is not None: extras["trellis2_proj_feats"] = compute_stage_proj_feats( proj_pack, "tex_1024", coords=coords, coord_resolution=coord_resolution, diff --git a/latent_preview.py b/latent_preview.py index 40c9e0fe3..d97c538fe 100644 --- a/latent_preview.py +++ b/latent_preview.py @@ -178,7 +178,7 @@ class Trellis3DPreviewer(LatentPreviewer): pmax = proj.amax(dim=0, keepdim=True) return ((proj - pmin) / (pmax - pmin + 1e-8)).clamp(0, 1) - def _texture(self, x0, coords): + def _texture(self, x0, coords, model_frame=None): if coords.shape[-1] == 4: b0 = coords[:, 0] == 0 spatial = coords[b0][:, 1:4].float() @@ -187,9 +187,11 @@ class Trellis3DPreviewer(LatentPreviewer): n0 = spatial.shape[0] if n0 == 0: return None + if model_frame == "z_up": + spatial = torch.stack([spatial[:, 0], spatial[:, 2], -spatial[:, 1]], dim=-1) latent = x0[0, :, :n0, 0].float().transpose(0, 1) # [n0, C] colors = self._latent_color(latent) # [n0, 3] - res = float(spatial.max().item()) + 1.0 + res = float(spatial.abs().max().item()) + 1.0 rad = max(1, int(round(self._SIZE * self._FILL / max(res, 1) / 2))) return self._splat(spatial, colors, rad) @@ -202,7 +204,7 @@ class Trellis3DPreviewer(LatentPreviewer): mode = ctx.get("mode") coords = ctx.get("coords") if mode == "texture_generation" and coords is not None: - return self._texture(x0, coords) + return self._texture(x0, coords, model_frame=ctx.get("model_frame")) except Exception as e: logging.debug(f"Trellis3DPreviewer: skipping preview ({e})") return None