Epsilon Scaling method

This commit is contained in:
Koratahiu~ 2025-10-01 06:19:43 +03:00
parent bab8ba20bf
commit 1664f191a5
2 changed files with 67 additions and 0 deletions

66
comfy_extras/nodes_eps.py Normal file
View File

@ -0,0 +1,66 @@
import torch
class EpsilonScaling:
"""
Implements the Epsilon Scaling method from 'Elucidating the Exposure Bias in Diffusion Models'
(https://arxiv.org/abs/2308.15321v6).
This method mitigates exposure bias by scaling the predicted noise during sampling,
which can significantly improve sample quality. This implementation uses the "uniform schedule"
recommended by the paper for its practicality and effectiveness.
"""
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"model": ("MODEL",),
"scaling_factor": ("FLOAT", {
"default": 1.005,
"min": 0.5,
"max": 1.5,
"step": 0.001,
"display": "number"
}),
}
}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "model_patches/unet"
def patch(self, model, scaling_factor):
# Prevent division by zero, though the UI's min value should prevent this.
if scaling_factor == 0:
scaling_factor = 1e-9
def epsilon_scaling_function(args):
"""
This function is applied after the CFG guidance has been calculated.
It recalculates the denoised latent by scaling the predicted noise.
"""
denoised = args["denoised"]
x = args["input"]
noise_pred = x - denoised
scaled_noise_pred = noise_pred / scaling_factor
new_denoised = x - scaled_noise_pred
return new_denoised
# Clone the model patcher to avoid modifying the original model in place
model_clone = model.clone()
# Apply the patch using set_model_sampler_post_cfg_function.
# disable_cfg1_optimization=True is crucial. This patch needs the outputs of both the
# conditional and unconditional models to correctly calculate the guided noise, even when
# the CFG scale is 1.0. Disabling the optimization ensures both are always computed.
model_clone.set_model_sampler_post_cfg_function(epsilon_scaling_function, disable_cfg1_optimization=True)
return (model_clone,)
NODE_CLASS_MAPPINGS = {
"Epsilon Scaling": EpsilonScaling
}

View File

@ -2297,6 +2297,7 @@ async def init_builtin_extra_nodes():
"nodes_gits.py",
"nodes_controlnet.py",
"nodes_hunyuan.py",
"nodes_eps.py",
"nodes_flux.py",
"nodes_lora_extract.py",
"nodes_torch_compile.py",