diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 266f67043..f57b1b018 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -334,12 +334,16 @@ class Trellis2UpsampleStage(IO.ComfyNode): ) @staticmethod - def _quantize_unique(hr_coords: torch.Tensor, lr_resolution: int, hr_resolution: int) -> torch.Tensor: - # Fold the two scalar divisions into one and chain the float math in-place - # to avoid 3 full M*3 fp32 transients per call. - scale = (hr_resolution // 16) / lr_resolution + def _quantize_unique(hr_coords: torch.Tensor, lr_resolution: int, hr_resolution: int, pixal3d_mode: bool = False) -> torch.Tensor: + # Trellis2 uses `floor((c+0.5) * grid_res / lr_res) + # Pixal3D uses `round((c+0.5) * (grid_res-1) / lr_res)` + # this is a half-cell spatial shift. Branch so each upstream is matched bit-for-bit. + grid_res = hr_resolution // 16 spatial = hr_coords[:, 1:].float() - spatial.add_(0.5).mul_(scale) + if pixal3d_mode: + spatial.add_(0.5).mul_((grid_res - 1) / lr_resolution).round_() + else: + spatial.add_(0.5).mul_(grid_res / lr_resolution) quant = torch.cat([hr_coords[:, :1], spatial.int()], dim=1) return quant.unique(dim=0) @@ -352,6 +356,8 @@ class Trellis2UpsampleStage(IO.ComfyNode): shape_vae = vae.first_stage_model lr_resolution = 512 target_resolution = int(target_resolution) + proj_pack = _proj_pack_from_conditioning(positive) + pixal3d_mode = proj_pack is not None # Decode each sample's HR coords, then search for the largest hr_resolution # that fits under max_tokens across all samples. @@ -380,7 +386,7 @@ class Trellis2UpsampleStage(IO.ComfyNode): quant_unique_list = [] exceeds_limit = False for hr_coords_i in sample_hr_coords: - qu = cls._quantize_unique(hr_coords_i, lr_resolution, hr_resolution) + qu = cls._quantize_unique(hr_coords_i, lr_resolution, hr_resolution, pixal3d_mode) quant_unique_list.append(qu) if qu.shape[0] >= max_tokens: exceeds_limit = True @@ -390,7 +396,7 @@ class Trellis2UpsampleStage(IO.ComfyNode): if hr_resolution <= 1024: for k in range(len(quant_unique_list), len(sample_hr_coords)): quant_unique_list.append( - cls._quantize_unique(sample_hr_coords[k], lr_resolution, hr_resolution) + cls._quantize_unique(sample_hr_coords[k], lr_resolution, hr_resolution, pixal3d_mode) ) break hr_resolution -= 128 @@ -412,7 +418,6 @@ class Trellis2UpsampleStage(IO.ComfyNode): "trellis2_coords": coords, "trellis2_coord_counts": counts, } - proj_pack = _proj_pack_from_conditioning(positive) if proj_pack is not None: extras["trellis2_proj_feats"] = compute_stage_proj_feats( proj_pack, "shape_1024", coords=coords, coord_resolution=coord_resolution, @@ -1188,9 +1193,11 @@ class CFGGuidanceInterval(IO.ComfyNode): is done via model_sampling.percent_to_sigma so the window is portable across schedules (flow / EDM / discrete) and shift settings. - Defaults are full-range (no bypass). For Trellis2's upstream behavior, - wire (start_percent=0.0, end_percent=0.667) on the SS / shape KSamplers; - texture defaults to cfg=1 so the node is moot there.""" + Defaults are full-range (no bypass). Upstream Trellis2 / Pixal3D + pipeline.json sets guidance_interval=[0.6, 1.0] (upstream t-space) on the + SS and shape samplers — CFG active only in the first 40% of sampling. + Wire (start_percent=0.0, end_percent=0.4) on the SS / shape KSamplers to + match. Texture defaults to cfg=1 so the node is moot there.""" @classmethod def define_schema(cls): return IO.Schema(