Add TemporalScoreRescaling node (#10351)

* Add TemporalScoreRescaling node

* Mention image generation in tsr_k's tooltip
chaObserv 2025-10-16 06:12:25 +08:00 committed by GitHub
parent 1c10b33f9b
commit f72c6616b2

@ -1,5 +1,7 @@
import torch
from typing_extensions import override
from comfy.k_diffusion.sampling import sigma_to_half_log_snr
from comfy_api.latest import ComfyExtension, io
@ -63,12 +65,105 @@ class EpsilonScaling(io.ComfyNode):
return io.NodeOutput(model_clone)
def compute_tsr_rescaling_factor(
snr: torch.Tensor, tsr_k: float, tsr_variance: float
) -> torch.Tensor:
"""Compute the rescaling score ratio in Temporal Score Rescaling.
See equation (6) in https://arxiv.org/pdf/2510.01184v1.
"""
posinf_mask = torch.isposinf(snr)
rescaling_factor = (snr * tsr_variance + 1) / (snr * tsr_variance / tsr_k + 1)
return torch.where(posinf_mask, tsr_k, rescaling_factor) # when snr → inf, r = tsr_k
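
# Illustrative limiting behavior (a sketch; the example values are assumed, not
# from the paper): with tsr_variance = 1.0 and tsr_k = 0.95, the factor moves
# from 1 at low SNR (early, high-noise steps are left untouched) toward tsr_k:
#   compute_tsr_rescaling_factor(torch.tensor(0.0), 0.95, 1.0)           # -> 1.0
#   compute_tsr_rescaling_factor(torch.tensor(1.0), 0.95, 1.0)           # -> ~0.974
#   compute_tsr_rescaling_factor(torch.tensor(float("inf")), 0.95, 1.0)  # -> 0.95 (= tsr_k)
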
class TemporalScoreRescaling(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="TemporalScoreRescaling",
display_name="TSR - Temporal Score Rescaling",
category="model_patches/unet",
inputs=[
io.Model.Input("model"),
io.Float.Input(
"tsr_k",
tooltip=(
"Controls the rescaling strength.\n"
"Lower k produces more detailed results; higher k produces smoother results in image generation. Setting k = 1 disables rescaling."
),
default=0.95,
min=0.01,
max=100.0,
step=0.001,
display_mode=io.NumberDisplay.number,
),
io.Float.Input(
"tsr_sigma",
tooltip=(
"Controls how early rescaling takes effect.\n"
"Larger values take effect earlier."
),
default=1.0,
min=0.01,
max=100.0,
step=0.001,
display_mode=io.NumberDisplay.number,
),
],
outputs=[
io.Model.Output(
display_name="patched_model",
),
],
description=(
"[Post-CFG Function]\n"
"TSR - Temporal Score Rescaling (2510.01184)\n\n"
"Rescaling the model's score or noise to steer the sampling diversity.\n"
),
)
@classmethod
def execute(cls, model, tsr_k, tsr_sigma) -> io.NodeOutput:
tsr_variance = tsr_sigma**2
def temporal_score_rescaling(args):
denoised = args["denoised"]
x = args["input"]
sigma = args["sigma"]
curr_model = args["model"]
# No rescaling (r = 1) or no noise
if tsr_k == 1 or sigma == 0:
return denoised
model_sampling = curr_model.current_patcher.get_model_object("model_sampling")
half_log_snr = sigma_to_half_log_snr(sigma, model_sampling)
snr = (2 * half_log_snr).exp()
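            # half_log_snr = log(alpha / sigma), so snr = (alpha / sigma) ** 2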
# No rescaling needed (r = 1)
if snr == 0:
return denoised
rescaling_r = compute_tsr_rescaling_factor(snr, tsr_k, tsr_variance)
# Derived from scaled_denoised = (x - r * sigma * noise) / alpha
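            # Since denoised = (x - sigma * noise) / alpha, we have
            # sigma * noise = x - alpha * denoised, hence
            # scaled_denoised = (1 - r) * x / alpha + r * denoised
            #                 = lerp(x / alpha, denoised, r)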
alpha = sigma * half_log_snr.exp()
return torch.lerp(x / alpha, denoised, rescaling_r)
m = model.clone()
m.set_model_sampler_post_cfg_function(temporal_score_rescaling)
return io.NodeOutput(m)
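
# Usage sketch (hypothetical direct call, for illustration only): patch a model
# and hand the clone to any sampler; the post-CFG hook then rescales each
# denoised prediction before the sampler's next step.
#   out = TemporalScoreRescaling.execute(model, tsr_k=0.95, tsr_sigma=1.0)
#   # out is an io.NodeOutput wrapping the patched model clone.
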
class EpsilonScalingExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
EpsilonScaling,
TemporalScoreRescaling,
]
async def comfy_entrypoint() -> EpsilonScalingExtension:
return EpsilonScalingExtension()