Use CPU generator for RNG in VAE

This commit is contained in:
kijai 2026-06-01 14:47:19 +03:00
parent ee9a1ffe19
commit 63d422cd7b
2 changed files with 11 additions and 9 deletions

View File

@ -37,7 +37,7 @@ def hammersley_sequence(dim, n, num_samples):
return [n / num_samples] + halton_sequence(dim - 1, n) return [n / num_samples] + halton_sequence(dim - 1, n)
def sample_probs(probs, counts): def sample_probs(probs, counts, generator=None):
# Systematic resampling: distribute counts[r] draws across the P bins of row r # Systematic resampling: distribute counts[r] draws across the P bins of row r
batch_shape = counts.shape batch_shape = counts.shape
R = counts.numel() R = counts.numel()
@ -55,7 +55,7 @@ def sample_probs(probs, counts):
return counts.new_zeros(*batch_shape, P) return counts.new_zeros(*batch_shape, P)
cnt = counts.clamp_min(1).float().unsqueeze(1) # (R, 1) cnt = counts.clamp_min(1).float().unsqueeze(1) # (R, 1)
grid = torch.arange(Nmax, device=device, dtype=torch.float32).unsqueeze(0) # (1, Nmax) grid = torch.arange(Nmax, device=device, dtype=torch.float32).unsqueeze(0) # (1, Nmax)
u = (torch.rand(R, 1, device=device) + grid) / cnt # (R, Nmax) systematic samples u = (torch.rand(R, 1, generator=generator).to(device) + grid) / cnt # (R, Nmax) systematic samples (CPU-seeded)
idx = torch.searchsorted(cdf, u.clamp(max=1.0 - 1e-12)).clamp_max(P - 1) idx = torch.searchsorted(cdf, u.clamp(max=1.0 - 1e-12)).clamp_max(P - 1)
weight = (grid < counts.unsqueeze(1)).to(cdf.dtype) # mask out j >= counts[r] weight = (grid < counts.unsqueeze(1)).to(cdf.dtype) # mask out j >= counts[r]
out = torch.zeros(R, P, dtype=torch.float32, device=device) out = torch.zeros(R, P, dtype=torch.float32, device=device)
@ -201,7 +201,7 @@ class OctreeProbabilityFixedlenDecoder(nn.Module):
return {"logits": logits, "probs": torch.softmax(logits, dim=-1)} return {"logits": logits, "probs": torch.softmax(logits, dim=-1)}
@staticmethod @staticmethod
def sample(model, cond, num_points, level, temperature=1.0): def sample(model, cond, num_points, level, temperature=1.0, generator=None):
B = cond.shape[0] B = cond.shape[0]
device = cond.device device = cond.device
child_offset = torch.tensor([[i, j, k] for k in [0, 1] for j in [0, 1] for i in [0, 1]], child_offset = torch.tensor([[i, j, k] for k in [0, 1] for j in [0, 1] for i in [0, 1]],
@ -219,7 +219,7 @@ class OctreeProbabilityFixedlenDecoder(nn.Module):
pred_logits = model(parent_coords_norm, res_tensor, cond)["logits"] / temperature pred_logits = model(parent_coords_norm, res_tensor, cond)["logits"] / temperature
pred_probs = torch.softmax(pred_logits, dim=-1) pred_probs = torch.softmax(pred_logits, dim=-1)
pred_log_probs = torch.log_softmax(pred_logits, dim=-1) pred_log_probs = torch.log_softmax(pred_logits, dim=-1)
sampled = sample_probs(pred_probs, prev_counts).flatten(1, 2) sampled = sample_probs(pred_probs, prev_counts, generator=generator).flatten(1, 2)
pred_log_probs = pred_log_probs.flatten(1, 2) pred_log_probs = pred_log_probs.flatten(1, 2)
prev_log_probs_expanded = prev_log_probs.repeat_interleave(8, dim=1) prev_log_probs_expanded = prev_log_probs.repeat_interleave(8, dim=1)
child_coords_int = (prev_coords_int[:, :, None, :] * 2 + child_offset[None, None, :, :]).flatten(1, 2) child_coords_int = (prev_coords_int[:, :, None, :] * 2 + child_offset[None, None, :, :]).flatten(1, 2)
@ -241,7 +241,8 @@ class OctreeProbabilityFixedlenDecoder(nn.Module):
res = 1 << level res = 1 << level
prev_log_probs = torch.repeat_interleave(prev_log_probs.flatten(0, 1), prev_counts.flatten(0, 1), dim=0).reshape(B, num_points) prev_log_probs = torch.repeat_interleave(prev_log_probs.flatten(0, 1), prev_counts.flatten(0, 1), dim=0).reshape(B, num_points)
coords_int = torch.repeat_interleave(prev_coords_int.flatten(0, 1), prev_counts.flatten(0, 1), dim=0).reshape(B, num_points, -1) coords_int = torch.repeat_interleave(prev_coords_int.flatten(0, 1), prev_counts.flatten(0, 1), dim=0).reshape(B, num_points, -1)
coords_norm = (coords_int.to(torch.float32) + torch.rand_like(coords_int, dtype=torch.float32)) / res rand = torch.rand(coords_int.shape, dtype=torch.float32, generator=generator).to(device)
coords_norm = (coords_int.to(torch.float32) + rand) / res
return {"points": coords_norm, "log_probs": prev_log_probs} return {"points": coords_norm, "log_probs": prev_log_probs}
@ -369,12 +370,13 @@ class OctreeGaussianDecoder(nn.Module):
def gaussians_per_point(self) -> int: def gaussians_per_point(self) -> int:
return self.gs.rep_config['num_gaussians'] return self.gs.rep_config['num_gaussians']
def decode(self, latent: torch.Tensor, num_gaussians: int, level: int = None): def decode(self, latent: torch.Tensor, num_gaussians: int, level: int = None, generator=None):
# level defaults to the full octree depth, a lower level is cheaper (coarser) for live previews. # level defaults to the full octree depth, a lower level is cheaper (coarser) for live previews.
# generator (a CPU torch.Generator) makes the octree sampling reproducible without touching global RNG.
level = self._MAX_VOXEL_LEVEL if level is None else level level = self._MAX_VOXEL_LEVEL if level is None else level
num_decoder_tokens = max(1, num_gaussians // self.gaussians_per_point) num_decoder_tokens = max(1, num_gaussians // self.gaussians_per_point)
points_pred = OctreeProbabilityFixedlenDecoder.sample( points_pred = OctreeProbabilityFixedlenDecoder.sample(
self.octree, latent, num_points=num_decoder_tokens, level=level, temperature=1.0, self.octree, latent, num_points=num_decoder_tokens, level=level, temperature=1.0, generator=generator,
) )
pred = self.gs(x=points_pred, cond=latent) pred = self.gs(x=points_pred, cond=latent)
return build_gaussian_models(self.gs, points_pred, pred) # one GaussianModel per batch item return build_gaussian_models(self.gs, points_pred, pred) # one GaussianModel per batch item

View File

@ -177,8 +177,8 @@ class VAEDecodeTripoSplat(IO.ComfyNode):
memory_required = (cond_tokens * 4 + (n // gpp) * 10) * hidden * dtype_size memory_required = (cond_tokens * 4 + (n // gpp) * 10) * hidden * dtype_size
comfy.model_management.load_models_gpu([vae.patcher], memory_required=memory_required) comfy.model_management.load_models_gpu([vae.patcher], memory_required=memory_required)
latent = latent.to(device=vae.device, dtype=vae.vae_dtype) latent = latent.to(device=vae.device, dtype=vae.vae_dtype)
torch.manual_seed(seed) generator = torch.Generator(device="cpu").manual_seed(seed)
parts = [g.render_tensors() for g in decoder.decode(latent, num_gaussians=n)] parts = [g.render_tensors() for g in decoder.decode(latent, num_gaussians=n, generator=generator)]
positions, scales, rotations, opacities, sh = (torch.stack(t) for t in zip(*parts)) positions, scales, rotations, opacities, sh = (torch.stack(t) for t in zip(*parts))
return IO.NodeOutput(Types.SPLAT(positions, scales, rotations, opacities, sh)) return IO.NodeOutput(Types.SPLAT(positions, scales, rotations, opacities, sh))