diff --git a/comfy/model_base.py b/comfy/model_base.py index ec96b522b..8b712e812 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -126,6 +126,7 @@ class BaseModel(torch.nn.Module): logging.debug("model_type {}".format(model_type.name)) logging.debug("adm {}".format(self.adm_channels)) self.memory_usage_factor = model_config.memory_usage_factor + self.training = False def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs): sigma = t diff --git a/comfy/samplers.py b/comfy/samplers.py index 6d718e50e..27835259e 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -687,7 +687,7 @@ class CFGGuider: def predict_noise(self, x, timestep, model_options={}, seed=None): return sampling_function(self.inner_model, x, timestep, self.conds.get("negative", None), self.conds.get("positive", None), self.cfg, model_options=model_options, seed=seed) - def inner_sample(self, noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed): + def inner_sample(self, noise, latent_image, device, sampler: KSAMPLER, sigmas, denoise_mask, callback, disable_pbar, seed): if latent_image is not None and torch.count_nonzero(latent_image) > 0: #Don't shift the empty latent image. latent_image = self.inner_model.process_latent_in(latent_image) @@ -698,7 +698,7 @@ class CFGGuider: samples = sampler.sample(self, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar) return self.inner_model.process_latent_out(samples.to(torch.float32)) - def sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None): + def sample(self, noise, latent_image, sampler: KSAMPLER, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None): if sigmas.shape[-1] == 0: return latent_image diff --git a/comfy_extras/nodes/nodes_custom_sampler.py b/comfy_extras/nodes/nodes_custom_sampler.py index 7c05e0191..31d39aaa7 100644 --- a/comfy_extras/nodes/nodes_custom_sampler.py +++ b/comfy_extras/nodes/nodes_custom_sampler.py @@ -8,6 +8,7 @@ from comfy.cmd import latent_preview import torch from comfy import utils from comfy import node_helpers +from comfy.samplers import KSAMPLER class BasicScheduler: @@ -597,7 +598,7 @@ class SamplerCustomAdvanced: CATEGORY = "sampling/custom_sampling" - def sample(self, noise, guider, sampler, sigmas, latent_image): + def sample(self, noise, guider: comfy.samplers.CFGGuider, sampler: KSAMPLER, sigmas, latent_image): latent = latent_image latent_image = latent["samples"] latent = latent.copy()