mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-16 00:12:33 +08:00
SAG works with sampler_cfg_function
This commit is contained in:
parent
ecd098a7fd
commit
a66060de1a
@ -252,11 +252,12 @@ def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_option
|
|||||||
uncond = None
|
uncond = None
|
||||||
|
|
||||||
cond_pred, uncond_pred = calc_cond_uncond_batch(model, cond, uncond, x, timestep, model_options)
|
cond_pred, uncond_pred = calc_cond_uncond_batch(model, cond, uncond, x, timestep, model_options)
|
||||||
|
cfg_result = uncond_pred + (cond_pred - uncond_pred) * cond_scale
|
||||||
if "sampler_cfg_function" in model_options:
|
if "sampler_cfg_function" in model_options:
|
||||||
args = {"cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep}
|
args = {"cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep}
|
||||||
return x - model_options["sampler_cfg_function"](args)
|
cfg_result = x - model_options["sampler_cfg_function"](args)
|
||||||
# if cfg = 1.0, we can't do sag
|
|
||||||
elif "sag" in model_options:
|
if "sag" in model_options:
|
||||||
assert uncond is not None, "SAG requires uncond guidance"
|
assert uncond is not None, "SAG requires uncond guidance"
|
||||||
sag_scale = model_options["sag_scale"]
|
sag_scale = model_options["sag_scale"]
|
||||||
sag_sigma = model_options["sag_sigma"]
|
sag_sigma = model_options["sag_sigma"]
|
||||||
@ -265,13 +266,13 @@ def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_option
|
|||||||
# these methods are added by the sag patcher
|
# these methods are added by the sag patcher
|
||||||
uncond_attn = model.get_attn_scores()
|
uncond_attn = model.get_attn_scores()
|
||||||
mid_shape = model.get_mid_block_shape()
|
mid_shape = model.get_mid_block_shape()
|
||||||
|
# create the adversarially blurred image
|
||||||
degraded = create_blur_map(uncond_pred, uncond_attn, mid_shape, sag_sigma, sag_threshold)
|
degraded = create_blur_map(uncond_pred, uncond_attn, mid_shape, sag_sigma, sag_threshold)
|
||||||
degraded_noised = degraded + x - uncond_pred
|
degraded_noised = degraded + x - uncond_pred
|
||||||
# call into the UNet with the adversarially blurred image
|
# call into the UNet
|
||||||
(sag, _) = calc_cond_uncond_batch(model, uncond, None, degraded_noised, timestep, model_options)
|
(sag, _) = calc_cond_uncond_batch(model, uncond, None, degraded_noised, timestep, model_options)
|
||||||
return uncond_pred + (cond_pred - uncond_pred) * cond_scale + (degraded - sag) * sag_scale
|
cfg_result += (degraded - sag) * sag_scale
|
||||||
else:
|
return cfg_result
|
||||||
return uncond_pred + (cond_pred - uncond_pred) * cond_scale
|
|
||||||
|
|
||||||
def create_blur_map(x0, attn, mid_shape, sigma=3.0, threshold=1.0):
|
def create_blur_map(x0, attn, mid_shape, sigma=3.0, threshold=1.0):
|
||||||
# reshape and GAP the attention map
|
# reshape and GAP the attention map
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user