diff --git a/comfy_extras/nodes_train.py b/comfy_extras/nodes_train.py index 0616dfc2d..d6a062d06 100644 --- a/comfy_extras/nodes_train.py +++ b/comfy_extras/nodes_train.py @@ -914,10 +914,11 @@ def _run_training_loop( """ sigmas = torch.tensor(range(num_images)) noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(seed) + ndim = latents[0].ndim if bucket_mode: # Use first bucket's first latent as dummy for guider - dummy_latent = latents[0][:1].repeat(num_images, 1, 1, 1) + dummy_latent = latents[0][:1].repeat(num_images, *[1]*(ndim-1)) guider.sample( noise.generate_noise({"samples": dummy_latent}), dummy_latent, @@ -927,7 +928,7 @@ def _run_training_loop( ) elif multi_res: # use first latent as dummy latent if multi_res - latents = latents[0].repeat(num_images, 1, 1, 1) + latents = latents[0].repeat(num_images, *[1]*(ndim-1)) guider.sample( noise.generate_noise({"samples": latents}), latents,