From 63d422cd7b2a2fcb4302a1213e2def7d663cac14 Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Mon, 1 Jun 2026 14:47:19 +0300 Subject: [PATCH] Use cpu generator for rng in VAE --- comfy/ldm/triposplat/vae.py | 16 +++++++++------- comfy_extras/nodes_triposplat.py | 4 ++-- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/comfy/ldm/triposplat/vae.py b/comfy/ldm/triposplat/vae.py index 077d0a612..5e1822daf 100644 --- a/comfy/ldm/triposplat/vae.py +++ b/comfy/ldm/triposplat/vae.py @@ -37,7 +37,7 @@ def hammersley_sequence(dim, n, num_samples): return [n / num_samples] + halton_sequence(dim - 1, n) -def sample_probs(probs, counts): +def sample_probs(probs, counts, generator=None): # Systematic resampling: distribute counts[r] draws across the P bins of row r batch_shape = counts.shape R = counts.numel() @@ -55,7 +55,7 @@ def sample_probs(probs, counts): return counts.new_zeros(*batch_shape, P) cnt = counts.clamp_min(1).float().unsqueeze(1) # (R, 1) grid = torch.arange(Nmax, device=device, dtype=torch.float32).unsqueeze(0) # (1, Nmax) - u = (torch.rand(R, 1, device=device) + grid) / cnt # (R, Nmax) systematic samples + u = (torch.rand(R, 1, generator=generator).to(device) + grid) / cnt # (R, Nmax) systematic samples (CPU-seeded) idx = torch.searchsorted(cdf, u.clamp(max=1.0 - 1e-12)).clamp_max(P - 1) weight = (grid < counts.unsqueeze(1)).to(cdf.dtype) # mask out j >= counts[r] out = torch.zeros(R, P, dtype=torch.float32, device=device) @@ -201,7 +201,7 @@ class OctreeProbabilityFixedlenDecoder(nn.Module): return {"logits": logits, "probs": torch.softmax(logits, dim=-1)} @staticmethod - def sample(model, cond, num_points, level, temperature=1.0): + def sample(model, cond, num_points, level, temperature=1.0, generator=None): B = cond.shape[0] device = cond.device child_offset = torch.tensor([[i, j, k] for k in [0, 1] for j in [0, 1] for i in [0, 1]], @@ -219,7 +219,7 @@ class OctreeProbabilityFixedlenDecoder(nn.Module): pred_logits = model(parent_coords_norm, res_tensor, cond)["logits"] / temperature pred_probs = torch.softmax(pred_logits, dim=-1) pred_log_probs = torch.log_softmax(pred_logits, dim=-1) - sampled = sample_probs(pred_probs, prev_counts).flatten(1, 2) + sampled = sample_probs(pred_probs, prev_counts, generator=generator).flatten(1, 2) pred_log_probs = pred_log_probs.flatten(1, 2) prev_log_probs_expanded = prev_log_probs.repeat_interleave(8, dim=1) child_coords_int = (prev_coords_int[:, :, None, :] * 2 + child_offset[None, None, :, :]).flatten(1, 2) @@ -241,7 +241,8 @@ class OctreeProbabilityFixedlenDecoder(nn.Module): res = 1 << level prev_log_probs = torch.repeat_interleave(prev_log_probs.flatten(0, 1), prev_counts.flatten(0, 1), dim=0).reshape(B, num_points) coords_int = torch.repeat_interleave(prev_coords_int.flatten(0, 1), prev_counts.flatten(0, 1), dim=0).reshape(B, num_points, -1) - coords_norm = (coords_int.to(torch.float32) + torch.rand_like(coords_int, dtype=torch.float32)) / res + rand = torch.rand(coords_int.shape, dtype=torch.float32, generator=generator).to(device) + coords_norm = (coords_int.to(torch.float32) + rand) / res return {"points": coords_norm, "log_probs": prev_log_probs} @@ -369,12 +370,13 @@ class OctreeGaussianDecoder(nn.Module): def gaussians_per_point(self) -> int: return self.gs.rep_config['num_gaussians'] - def decode(self, latent: torch.Tensor, num_gaussians: int, level: int = None): + def decode(self, latent: torch.Tensor, num_gaussians: int, level: int = None, generator=None): # level defaults to the full octree depth, a lower level is cheaper (coarser) for live previews. + # generator (a CPU torch.Generator) makes the octree sampling reproducible without touching global RNG. level = self._MAX_VOXEL_LEVEL if level is None else level num_decoder_tokens = max(1, num_gaussians // self.gaussians_per_point) points_pred = OctreeProbabilityFixedlenDecoder.sample( - self.octree, latent, num_points=num_decoder_tokens, level=level, temperature=1.0, + self.octree, latent, num_points=num_decoder_tokens, level=level, temperature=1.0, generator=generator, ) pred = self.gs(x=points_pred, cond=latent) return build_gaussian_models(self.gs, points_pred, pred) # one GaussianModel per batch item diff --git a/comfy_extras/nodes_triposplat.py b/comfy_extras/nodes_triposplat.py index d51f134b4..021b669fd 100644 --- a/comfy_extras/nodes_triposplat.py +++ b/comfy_extras/nodes_triposplat.py @@ -177,8 +177,8 @@ class VAEDecodeTripoSplat(IO.ComfyNode): memory_required = (cond_tokens * 4 + (n // gpp) * 10) * hidden * dtype_size comfy.model_management.load_models_gpu([vae.patcher], memory_required=memory_required) latent = latent.to(device=vae.device, dtype=vae.vae_dtype) - torch.manual_seed(seed) - parts = [g.render_tensors() for g in decoder.decode(latent, num_gaussians=n)] + generator = torch.Generator(device="cpu").manual_seed(seed) + parts = [g.render_tensors() for g in decoder.decode(latent, num_gaussians=n, generator=generator)] positions, scales, rotations, opacities, sh = (torch.stack(t) for t in zip(*parts)) return IO.NodeOutput(Types.SPLAT(positions, scales, rotations, opacities, sh))