[Reviving #5709] Add strength input to Differential Diffusion (#9957)

* Update nodes_differential_diffusion.py

* Update nodes_differential_diffusion.py

* Make strength optional to avoid validation errors when loading old workflows, adjust step

---------

Co-authored-by: ThereforeGames <eric@sparknight.io>
This commit is contained in:
Jedrzej Kosinski 2025-09-20 18:10:39 -07:00 committed by GitHub
parent 66241cef31
commit 9ed3c5cc09
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -5,19 +5,30 @@ import torch
class DifferentialDiffusion(): class DifferentialDiffusion():
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required": {"model": ("MODEL", ), return {
}} "required": {
"model": ("MODEL", ),
},
"optional": {
"strength": ("FLOAT", {
"default": 1.0,
"min": 0.0,
"max": 1.0,
"step": 0.01,
}),
}
}
RETURN_TYPES = ("MODEL",) RETURN_TYPES = ("MODEL",)
FUNCTION = "apply" FUNCTION = "apply"
CATEGORY = "_for_testing" CATEGORY = "_for_testing"
INIT = False INIT = False
def apply(self, model): def apply(self, model, strength=1.0):
model = model.clone() model = model.clone()
model.set_model_denoise_mask_function(self.forward) model.set_model_denoise_mask_function(lambda *args, **kwargs: self.forward(*args, **kwargs, strength=strength))
return (model, ) return (model, )
def forward(self, sigma: torch.Tensor, denoise_mask: torch.Tensor, extra_options: dict): def forward(self, sigma: torch.Tensor, denoise_mask: torch.Tensor, extra_options: dict, strength: float):
model = extra_options["model"] model = extra_options["model"]
step_sigmas = extra_options["sigmas"] step_sigmas = extra_options["sigmas"]
sigma_to = model.inner_model.model_sampling.sigma_min sigma_to = model.inner_model.model_sampling.sigma_min
@ -31,7 +42,15 @@ class DifferentialDiffusion():
threshold = (current_ts - ts_to) / (ts_from - ts_to) threshold = (current_ts - ts_to) / (ts_from - ts_to)
return (denoise_mask >= threshold).to(denoise_mask.dtype) # Generate the binary mask based on the threshold
binary_mask = (denoise_mask >= threshold).to(denoise_mask.dtype)
# Blend binary mask with the original denoise_mask using strength
if strength and strength < 1:
blended_mask = strength * binary_mask + (1 - strength) * denoise_mask
return blended_mask
else:
return binary_mask
NODE_CLASS_MAPPINGS = { NODE_CLASS_MAPPINGS = {