SAG works with sampler_cfg_function

This commit is contained in:
Raphael Walker 2023-12-07 14:01:19 +01:00
parent ecd098a7fd
commit a66060de1a
No known key found for this signature in database
GPG Key ID: E6F58BE3395D3AA8

View File

@ -252,11 +252,12 @@ def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_option
uncond = None uncond = None
cond_pred, uncond_pred = calc_cond_uncond_batch(model, cond, uncond, x, timestep, model_options) cond_pred, uncond_pred = calc_cond_uncond_batch(model, cond, uncond, x, timestep, model_options)
cfg_result = uncond_pred + (cond_pred - uncond_pred) * cond_scale
if "sampler_cfg_function" in model_options: if "sampler_cfg_function" in model_options:
args = {"cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep} args = {"cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep}
return x - model_options["sampler_cfg_function"](args) cfg_result = x - model_options["sampler_cfg_function"](args)
# if cfg = 1.0, we can't do sag
elif "sag" in model_options: if "sag" in model_options:
assert uncond is not None, "SAG requires uncond guidance" assert uncond is not None, "SAG requires uncond guidance"
sag_scale = model_options["sag_scale"] sag_scale = model_options["sag_scale"]
sag_sigma = model_options["sag_sigma"] sag_sigma = model_options["sag_sigma"]
@ -265,13 +266,13 @@ def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_option
# these methods are added by the sag patcher # these methods are added by the sag patcher
uncond_attn = model.get_attn_scores() uncond_attn = model.get_attn_scores()
mid_shape = model.get_mid_block_shape() mid_shape = model.get_mid_block_shape()
# create the adversarially blurred image
degraded = create_blur_map(uncond_pred, uncond_attn, mid_shape, sag_sigma, sag_threshold) degraded = create_blur_map(uncond_pred, uncond_attn, mid_shape, sag_sigma, sag_threshold)
degraded_noised = degraded + x - uncond_pred degraded_noised = degraded + x - uncond_pred
# call into the UNet with the adversarially blurred image # call into the UNet
(sag, _) = calc_cond_uncond_batch(model, uncond, None, degraded_noised, timestep, model_options) (sag, _) = calc_cond_uncond_batch(model, uncond, None, degraded_noised, timestep, model_options)
return uncond_pred + (cond_pred - uncond_pred) * cond_scale + (degraded - sag) * sag_scale cfg_result += (degraded - sag) * sag_scale
else: return cfg_result
return uncond_pred + (cond_pred - uncond_pred) * cond_scale
def create_blur_map(x0, attn, mid_shape, sigma=3.0, threshold=1.0): def create_blur_map(x0, attn, mid_shape, sigma=3.0, threshold=1.0):
# reshape and GAP the attention map # reshape and GAP the attention map