From 860894d83f7a309971e26438e1d086055bb6336d Mon Sep 17 00:00:00 2001 From: FizzleDorf <1fizzledorf@gmail.com> Date: Fri, 1 Dec 2023 09:15:05 -0500 Subject: [PATCH] seed scheduling added --- comfy/sample.py | 36 +++++++++++++++++++++++------------- 1 file changed, 23 insertions(+), 13 deletions(-) diff --git a/comfy/sample.py b/comfy/sample.py index 034db97ee..4e28f2ec8 100644 --- a/comfy/sample.py +++ b/comfy/sample.py @@ -6,22 +6,32 @@ import comfy.utils import math import numpy as np -def prepare_noise(latent_image, seed, noise_inds=None): +def prepare_noise(latent_image, seeds, noise_inds=None): """ - creates random noise given a latent image and a seed. - optional arg skip can be used to skip and discard x number of noise generations for a given seed + Creates random noise given a latent image and a seed or a list of seeds. + Optional arg noise_inds can be used to select specific noise indices. """ - generator = torch.manual_seed(seed) - if noise_inds is None: - return torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu") - - unique_inds, inverse = np.unique(noise_inds, return_inverse=True) + num_latents = latent_image.size(0) + + if not isinstance(seeds, list): + seeds = [seeds] + + generator = torch.Generator() + generator.manual_seed(seeds[0]) # Use the first seed as the default generator seed + noises = [] - for i in range(unique_inds[-1]+1): - noise = torch.randn([1] + list(latent_image.size())[1:], dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu") - if i in unique_inds: - noises.append(noise) - noises = [noises[i] for i in inverse] + + for i in range(num_latents): + if i < len(seeds): # Use the provided seeds if available + seed = seeds[i] + else: + seed = torch.randint(0, 2**32, (1,)).item() # Generate a random seed for additional latents + generator.manual_seed(seed) + print("seed:", seed) + noise = torch.randn([1] + list(latent_image.size())[1:], dtype=latent_image.dtype, layout=latent_image.layout, device="cpu", generator=generator) + noises.append(noise) + + noises = [noises[i] for i in range(num_latents)] noises = torch.cat(noises, axis=0) return noises