Ensure train node support real 5D tensor data

This commit is contained in:
Kohaku-Blueleaf 2026-04-28 01:12:57 +08:00
parent 43ff315b0a
commit 025bce5ab6

View File

@ -914,10 +914,11 @@ def _run_training_loop(
"""
sigmas = torch.tensor(range(num_images))
noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(seed)
ndim = latents[0].ndim
if bucket_mode:
# Use first bucket's first latent as dummy for guider
dummy_latent = latents[0][:1].repeat(num_images, 1, 1, 1)
dummy_latent = latents[0][:1].repeat(num_images, *[1]*(ndim-1))
guider.sample(
noise.generate_noise({"samples": dummy_latent}),
dummy_latent,
@ -927,7 +928,7 @@ def _run_training_loop(
)
elif multi_res:
# use first latent as dummy latent if multi_res
latents = latents[0].repeat(num_images, 1, 1, 1)
latents = latents[0].repeat(num_images, *[1]*(ndim-1))
guider.sample(
noise.generate_noise({"samples": latents}),
latents,