mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-04 21:37:40 +08:00
Use cpu generator for rng in VAE
This commit is contained in:
parent
ee9a1ffe19
commit
63d422cd7b
@ -37,7 +37,7 @@ def hammersley_sequence(dim, n, num_samples):
|
|||||||
return [n / num_samples] + halton_sequence(dim - 1, n)
|
return [n / num_samples] + halton_sequence(dim - 1, n)
|
||||||
|
|
||||||
|
|
||||||
def sample_probs(probs, counts):
|
def sample_probs(probs, counts, generator=None):
|
||||||
# Systematic resampling: distribute counts[r] draws across the P bins of row r
|
# Systematic resampling: distribute counts[r] draws across the P bins of row r
|
||||||
batch_shape = counts.shape
|
batch_shape = counts.shape
|
||||||
R = counts.numel()
|
R = counts.numel()
|
||||||
@ -55,7 +55,7 @@ def sample_probs(probs, counts):
|
|||||||
return counts.new_zeros(*batch_shape, P)
|
return counts.new_zeros(*batch_shape, P)
|
||||||
cnt = counts.clamp_min(1).float().unsqueeze(1) # (R, 1)
|
cnt = counts.clamp_min(1).float().unsqueeze(1) # (R, 1)
|
||||||
grid = torch.arange(Nmax, device=device, dtype=torch.float32).unsqueeze(0) # (1, Nmax)
|
grid = torch.arange(Nmax, device=device, dtype=torch.float32).unsqueeze(0) # (1, Nmax)
|
||||||
u = (torch.rand(R, 1, device=device) + grid) / cnt # (R, Nmax) systematic samples
|
u = (torch.rand(R, 1, generator=generator).to(device) + grid) / cnt # (R, Nmax) systematic samples (CPU-seeded)
|
||||||
idx = torch.searchsorted(cdf, u.clamp(max=1.0 - 1e-12)).clamp_max(P - 1)
|
idx = torch.searchsorted(cdf, u.clamp(max=1.0 - 1e-12)).clamp_max(P - 1)
|
||||||
weight = (grid < counts.unsqueeze(1)).to(cdf.dtype) # mask out j >= counts[r]
|
weight = (grid < counts.unsqueeze(1)).to(cdf.dtype) # mask out j >= counts[r]
|
||||||
out = torch.zeros(R, P, dtype=torch.float32, device=device)
|
out = torch.zeros(R, P, dtype=torch.float32, device=device)
|
||||||
@ -201,7 +201,7 @@ class OctreeProbabilityFixedlenDecoder(nn.Module):
|
|||||||
return {"logits": logits, "probs": torch.softmax(logits, dim=-1)}
|
return {"logits": logits, "probs": torch.softmax(logits, dim=-1)}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def sample(model, cond, num_points, level, temperature=1.0):
|
def sample(model, cond, num_points, level, temperature=1.0, generator=None):
|
||||||
B = cond.shape[0]
|
B = cond.shape[0]
|
||||||
device = cond.device
|
device = cond.device
|
||||||
child_offset = torch.tensor([[i, j, k] for k in [0, 1] for j in [0, 1] for i in [0, 1]],
|
child_offset = torch.tensor([[i, j, k] for k in [0, 1] for j in [0, 1] for i in [0, 1]],
|
||||||
@ -219,7 +219,7 @@ class OctreeProbabilityFixedlenDecoder(nn.Module):
|
|||||||
pred_logits = model(parent_coords_norm, res_tensor, cond)["logits"] / temperature
|
pred_logits = model(parent_coords_norm, res_tensor, cond)["logits"] / temperature
|
||||||
pred_probs = torch.softmax(pred_logits, dim=-1)
|
pred_probs = torch.softmax(pred_logits, dim=-1)
|
||||||
pred_log_probs = torch.log_softmax(pred_logits, dim=-1)
|
pred_log_probs = torch.log_softmax(pred_logits, dim=-1)
|
||||||
sampled = sample_probs(pred_probs, prev_counts).flatten(1, 2)
|
sampled = sample_probs(pred_probs, prev_counts, generator=generator).flatten(1, 2)
|
||||||
pred_log_probs = pred_log_probs.flatten(1, 2)
|
pred_log_probs = pred_log_probs.flatten(1, 2)
|
||||||
prev_log_probs_expanded = prev_log_probs.repeat_interleave(8, dim=1)
|
prev_log_probs_expanded = prev_log_probs.repeat_interleave(8, dim=1)
|
||||||
child_coords_int = (prev_coords_int[:, :, None, :] * 2 + child_offset[None, None, :, :]).flatten(1, 2)
|
child_coords_int = (prev_coords_int[:, :, None, :] * 2 + child_offset[None, None, :, :]).flatten(1, 2)
|
||||||
@ -241,7 +241,8 @@ class OctreeProbabilityFixedlenDecoder(nn.Module):
|
|||||||
res = 1 << level
|
res = 1 << level
|
||||||
prev_log_probs = torch.repeat_interleave(prev_log_probs.flatten(0, 1), prev_counts.flatten(0, 1), dim=0).reshape(B, num_points)
|
prev_log_probs = torch.repeat_interleave(prev_log_probs.flatten(0, 1), prev_counts.flatten(0, 1), dim=0).reshape(B, num_points)
|
||||||
coords_int = torch.repeat_interleave(prev_coords_int.flatten(0, 1), prev_counts.flatten(0, 1), dim=0).reshape(B, num_points, -1)
|
coords_int = torch.repeat_interleave(prev_coords_int.flatten(0, 1), prev_counts.flatten(0, 1), dim=0).reshape(B, num_points, -1)
|
||||||
coords_norm = (coords_int.to(torch.float32) + torch.rand_like(coords_int, dtype=torch.float32)) / res
|
rand = torch.rand(coords_int.shape, dtype=torch.float32, generator=generator).to(device)
|
||||||
|
coords_norm = (coords_int.to(torch.float32) + rand) / res
|
||||||
return {"points": coords_norm, "log_probs": prev_log_probs}
|
return {"points": coords_norm, "log_probs": prev_log_probs}
|
||||||
|
|
||||||
|
|
||||||
@ -369,12 +370,13 @@ class OctreeGaussianDecoder(nn.Module):
|
|||||||
def gaussians_per_point(self) -> int:
|
def gaussians_per_point(self) -> int:
|
||||||
return self.gs.rep_config['num_gaussians']
|
return self.gs.rep_config['num_gaussians']
|
||||||
|
|
||||||
def decode(self, latent: torch.Tensor, num_gaussians: int, level: int = None):
|
def decode(self, latent: torch.Tensor, num_gaussians: int, level: int = None, generator=None):
|
||||||
# level defaults to the full octree depth, a lower level is cheaper (coarser) for live previews.
|
# level defaults to the full octree depth, a lower level is cheaper (coarser) for live previews.
|
||||||
|
# generator (a CPU torch.Generator) makes the octree sampling reproducible without touching global RNG.
|
||||||
level = self._MAX_VOXEL_LEVEL if level is None else level
|
level = self._MAX_VOXEL_LEVEL if level is None else level
|
||||||
num_decoder_tokens = max(1, num_gaussians // self.gaussians_per_point)
|
num_decoder_tokens = max(1, num_gaussians // self.gaussians_per_point)
|
||||||
points_pred = OctreeProbabilityFixedlenDecoder.sample(
|
points_pred = OctreeProbabilityFixedlenDecoder.sample(
|
||||||
self.octree, latent, num_points=num_decoder_tokens, level=level, temperature=1.0,
|
self.octree, latent, num_points=num_decoder_tokens, level=level, temperature=1.0, generator=generator,
|
||||||
)
|
)
|
||||||
pred = self.gs(x=points_pred, cond=latent)
|
pred = self.gs(x=points_pred, cond=latent)
|
||||||
return build_gaussian_models(self.gs, points_pred, pred) # one GaussianModel per batch item
|
return build_gaussian_models(self.gs, points_pred, pred) # one GaussianModel per batch item
|
||||||
|
|||||||
@ -177,8 +177,8 @@ class VAEDecodeTripoSplat(IO.ComfyNode):
|
|||||||
memory_required = (cond_tokens * 4 + (n // gpp) * 10) * hidden * dtype_size
|
memory_required = (cond_tokens * 4 + (n // gpp) * 10) * hidden * dtype_size
|
||||||
comfy.model_management.load_models_gpu([vae.patcher], memory_required=memory_required)
|
comfy.model_management.load_models_gpu([vae.patcher], memory_required=memory_required)
|
||||||
latent = latent.to(device=vae.device, dtype=vae.vae_dtype)
|
latent = latent.to(device=vae.device, dtype=vae.vae_dtype)
|
||||||
torch.manual_seed(seed)
|
generator = torch.Generator(device="cpu").manual_seed(seed)
|
||||||
parts = [g.render_tensors() for g in decoder.decode(latent, num_gaussians=n)]
|
parts = [g.render_tensors() for g in decoder.decode(latent, num_gaussians=n, generator=generator)]
|
||||||
positions, scales, rotations, opacities, sh = (torch.stack(t) for t in zip(*parts))
|
positions, scales, rotations, opacities, sh = (torch.stack(t) for t in zip(*parts))
|
||||||
return IO.NodeOutput(Types.SPLAT(positions, scales, rotations, opacities, sh))
|
return IO.NodeOutput(Types.SPLAT(positions, scales, rotations, opacities, sh))
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user