mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-07-03 21:20:49 +08:00
Fix TRELLIS2 texture sampling preview orientation
This commit is contained in:
parent
6ef69849a0
commit
805e7d5ae3
@ -1100,7 +1100,8 @@ class Trellis2(nn.Module):
|
||||
# Pre-computed per-stage back-projected features
|
||||
proj_feats = kwargs.get("trellis2_proj_feats")
|
||||
|
||||
sampling_preview.set_context(mode=mode, coords=coords, coord_counts=coord_counts)
|
||||
sampling_preview.set_context(mode=mode, coords=coords, coord_counts=coord_counts,
|
||||
model_frame=kwargs.get("trellis2_model_frame"))
|
||||
|
||||
is_first_shape_pass = False
|
||||
if mode == "shape_generation_512":
|
||||
|
||||
@ -17,10 +17,11 @@ _context = {}
|
||||
_tex_rgb = None
|
||||
|
||||
|
||||
def set_context(mode=None, coords=None, coord_counts=None):
|
||||
def set_context(mode=None, coords=None, coord_counts=None, model_frame=None):
|
||||
_context["mode"] = mode
|
||||
_context["coords"] = coords
|
||||
_context["coord_counts"] = coord_counts
|
||||
_context["model_frame"] = model_frame
|
||||
|
||||
|
||||
def get_context():
|
||||
|
||||
@ -1741,7 +1741,7 @@ class Trellis2(BaseModel):
|
||||
# CONDConstant: shared across pos/neg
|
||||
for k in ("trellis2_coords", "trellis2_coord_counts",
|
||||
"trellis2_generation_mode", "trellis2_shape_slat",
|
||||
"trellis2_proj_feats"):
|
||||
"trellis2_proj_feats", "trellis2_model_frame"):
|
||||
v = kwargs.get(k)
|
||||
if v is not None:
|
||||
out[k] = comfy.conds.CONDConstant(v)
|
||||
|
||||
@ -861,13 +861,16 @@ class Trellis2TextureStage(IO.ComfyNode):
|
||||
shape_slat = shape_slat.squeeze(-1).transpose(1, 2).reshape(-1, channels)
|
||||
|
||||
latent = torch.zeros(batch_size, channels, max_tokens, 1)
|
||||
proj_pack = _proj_pack_from_conditioning(positive)
|
||||
model_frame = shape_latent.get("model_frame",
|
||||
"y_up" if proj_pack is not None else "z_up")
|
||||
extras = {
|
||||
"trellis2_generation_mode": "texture_generation",
|
||||
"trellis2_coords": coords,
|
||||
"trellis2_coord_counts": counts,
|
||||
"trellis2_shape_slat": shape_slat,
|
||||
"trellis2_model_frame": model_frame,
|
||||
}
|
||||
proj_pack = _proj_pack_from_conditioning(positive)
|
||||
if proj_pack is not None and coord_resolution is not None:
|
||||
extras["trellis2_proj_feats"] = compute_stage_proj_feats(
|
||||
proj_pack, "tex_1024", coords=coords, coord_resolution=coord_resolution,
|
||||
|
||||
@ -178,7 +178,7 @@ class Trellis3DPreviewer(LatentPreviewer):
|
||||
pmax = proj.amax(dim=0, keepdim=True)
|
||||
return ((proj - pmin) / (pmax - pmin + 1e-8)).clamp(0, 1)
|
||||
|
||||
def _texture(self, x0, coords):
|
||||
def _texture(self, x0, coords, model_frame=None):
|
||||
if coords.shape[-1] == 4:
|
||||
b0 = coords[:, 0] == 0
|
||||
spatial = coords[b0][:, 1:4].float()
|
||||
@ -187,9 +187,11 @@ class Trellis3DPreviewer(LatentPreviewer):
|
||||
n0 = spatial.shape[0]
|
||||
if n0 == 0:
|
||||
return None
|
||||
if model_frame == "z_up":
|
||||
spatial = torch.stack([spatial[:, 0], spatial[:, 2], -spatial[:, 1]], dim=-1)
|
||||
latent = x0[0, :, :n0, 0].float().transpose(0, 1) # [n0, C]
|
||||
colors = self._latent_color(latent) # [n0, 3]
|
||||
res = float(spatial.max().item()) + 1.0
|
||||
res = float(spatial.abs().max().item()) + 1.0
|
||||
rad = max(1, int(round(self._SIZE * self._FILL / max(res, 1) / 2)))
|
||||
return self._splat(spatial, colors, rad)
|
||||
|
||||
@ -202,7 +204,7 @@ class Trellis3DPreviewer(LatentPreviewer):
|
||||
mode = ctx.get("mode")
|
||||
coords = ctx.get("coords")
|
||||
if mode == "texture_generation" and coords is not None:
|
||||
return self._texture(x0, coords)
|
||||
return self._texture(x0, coords, model_frame=ctx.get("model_frame"))
|
||||
except Exception as e:
|
||||
logging.debug(f"Trellis3DPreviewer: skipping preview ({e})")
|
||||
return None
|
||||
|
||||
Loading…
Reference in New Issue
Block a user