Align Pixal3D grid_res math to upstream

This commit is contained in:
kijai 2026-05-23 13:27:33 +03:00
parent 3edbf7c4a7
commit 56a03e748f

View File

@ -334,12 +334,16 @@ class Trellis2UpsampleStage(IO.ComfyNode):
)
@staticmethod
def _quantize_unique(hr_coords: torch.Tensor, lr_resolution: int, hr_resolution: int) -> torch.Tensor:
# Fold the two scalar divisions into one and chain the float math in-place
# to avoid 3 full M*3 fp32 transients per call.
scale = (hr_resolution // 16) / lr_resolution
def _quantize_unique(hr_coords: torch.Tensor, lr_resolution: int, hr_resolution: int, pixal3d_mode: bool = False) -> torch.Tensor:
# Trellis2 uses `floor((c+0.5) * grid_res / lr_res)`;
# Pixal3D uses `round((c+0.5) * (grid_res-1) / lr_res)`.
# The two differ by a half-cell spatial shift, so branch so that each upstream is matched bit-for-bit.
grid_res = hr_resolution // 16
spatial = hr_coords[:, 1:].float()
spatial.add_(0.5).mul_(scale)
if pixal3d_mode:
spatial.add_(0.5).mul_((grid_res - 1) / lr_resolution).round_()
else:
spatial.add_(0.5).mul_(grid_res / lr_resolution)
quant = torch.cat([hr_coords[:, :1], spatial.int()], dim=1)
return quant.unique(dim=0)
@ -352,6 +356,8 @@ class Trellis2UpsampleStage(IO.ComfyNode):
shape_vae = vae.first_stage_model
lr_resolution = 512
target_resolution = int(target_resolution)
proj_pack = _proj_pack_from_conditioning(positive)
pixal3d_mode = proj_pack is not None
# Decode each sample's HR coords, then search for the largest hr_resolution
# that fits under max_tokens across all samples.
@ -380,7 +386,7 @@ class Trellis2UpsampleStage(IO.ComfyNode):
quant_unique_list = []
exceeds_limit = False
for hr_coords_i in sample_hr_coords:
qu = cls._quantize_unique(hr_coords_i, lr_resolution, hr_resolution)
qu = cls._quantize_unique(hr_coords_i, lr_resolution, hr_resolution, pixal3d_mode)
quant_unique_list.append(qu)
if qu.shape[0] >= max_tokens:
exceeds_limit = True
@ -390,7 +396,7 @@ class Trellis2UpsampleStage(IO.ComfyNode):
if hr_resolution <= 1024:
for k in range(len(quant_unique_list), len(sample_hr_coords)):
quant_unique_list.append(
cls._quantize_unique(sample_hr_coords[k], lr_resolution, hr_resolution)
cls._quantize_unique(sample_hr_coords[k], lr_resolution, hr_resolution, pixal3d_mode)
)
break
hr_resolution -= 128
@ -412,7 +418,6 @@ class Trellis2UpsampleStage(IO.ComfyNode):
"trellis2_coords": coords,
"trellis2_coord_counts": counts,
}
proj_pack = _proj_pack_from_conditioning(positive)
if proj_pack is not None:
extras["trellis2_proj_feats"] = compute_stage_proj_feats(
proj_pack, "shape_1024", coords=coords, coord_resolution=coord_resolution,
@ -1188,9 +1193,11 @@ class CFGGuidanceInterval(IO.ComfyNode):
is done via model_sampling.percent_to_sigma so the window is portable
across schedules (flow / EDM / discrete) and shift settings.
Defaults are full-range (no bypass). For Trellis2's upstream behavior,
wire (start_percent=0.0, end_percent=0.667) on the SS / shape KSamplers;
texture defaults to cfg=1 so the node is moot there."""
Defaults are full-range (no bypass). Upstream Trellis2 / Pixal3D
pipeline.json sets guidance_interval=[0.6, 1.0] (upstream t-space) on the
SS and shape samplers, i.e. CFG is active only in the first 40% of sampling.
Wire (start_percent=0.0, end_percent=0.4) on the SS / shape KSamplers to
match. Texture defaults to cfg=1 so the node is moot there."""
@classmethod
def define_schema(cls):
return IO.Schema(