mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-17 07:05:12 +08:00
Fix for TemporalScoreRescaling node when batch_size > 1
This commit is contained in:
parent
3ad36d6be6
commit
bfdfc7d11c
@ -133,11 +133,11 @@ class TemporalScoreRescaling(io.ComfyNode):
|
|||||||
def temporal_score_rescaling(args):
|
def temporal_score_rescaling(args):
|
||||||
denoised = args["denoised"]
|
denoised = args["denoised"]
|
||||||
x = args["input"]
|
x = args["input"]
|
||||||
sigma = args["sigma"]
|
sigma = args["sigma"][0]
|
||||||
curr_model = args["model"]
|
curr_model = args["model"]
|
||||||
|
|
||||||
# No rescaling (r = 1) or no noise
|
# No rescaling (r = 1) or no noise
|
||||||
if tsr_k == 1 or sigma == 0:
|
if tsr_k == 1 or sigma.item() == 0:
|
||||||
return denoised
|
return denoised
|
||||||
|
|
||||||
model_sampling = curr_model.current_patcher.get_model_object("model_sampling")
|
model_sampling = curr_model.current_patcher.get_model_object("model_sampling")
|
||||||
@ -145,7 +145,7 @@ class TemporalScoreRescaling(io.ComfyNode):
|
|||||||
snr = (2 * half_log_snr).exp()
|
snr = (2 * half_log_snr).exp()
|
||||||
|
|
||||||
# No rescaling needed (r = 1)
|
# No rescaling needed (r = 1)
|
||||||
if snr == 0:
|
if snr.item() == 0:
|
||||||
return denoised
|
return denoised
|
||||||
|
|
||||||
rescaling_r = compute_tsr_rescaling_factor(snr, tsr_k, tsr_variance)
|
rescaling_r = compute_tsr_rescaling_factor(snr, tsr_k, tsr_variance)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user