From 9ed3c5cc09c55d2fffa67b59d9d21e3b44d7653e Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Sat, 20 Sep 2025 18:10:39 -0700 Subject: [PATCH] [Reviving #5709] Add strength input to Differential Diffusion (#9957) * Update nodes_differential_diffusion.py * Update nodes_differential_diffusion.py * Make strength optional to avoid validation errors when loading old workflows, adjust step --------- Co-authored-by: ThereforeGames --- comfy_extras/nodes_differential_diffusion.py | 33 +++++++++++++++----- 1 file changed, 26 insertions(+), 7 deletions(-) diff --git a/comfy_extras/nodes_differential_diffusion.py b/comfy_extras/nodes_differential_diffusion.py index 98dbbf102..255ac420d 100644 --- a/comfy_extras/nodes_differential_diffusion.py +++ b/comfy_extras/nodes_differential_diffusion.py @@ -5,19 +5,30 @@ import torch class DifferentialDiffusion(): @classmethod def INPUT_TYPES(s): - return {"required": {"model": ("MODEL", ), - }} + return { + "required": { + "model": ("MODEL", ), + }, + "optional": { + "strength": ("FLOAT", { + "default": 1.0, + "min": 0.0, + "max": 1.0, + "step": 0.01, + }), + } + } RETURN_TYPES = ("MODEL",) FUNCTION = "apply" CATEGORY = "_for_testing" INIT = False - def apply(self, model): + def apply(self, model, strength=1.0): model = model.clone() - model.set_model_denoise_mask_function(self.forward) - return (model,) + model.set_model_denoise_mask_function(lambda *args, **kwargs: self.forward(*args, **kwargs, strength=strength)) + return (model, ) - def forward(self, sigma: torch.Tensor, denoise_mask: torch.Tensor, extra_options: dict): + def forward(self, sigma: torch.Tensor, denoise_mask: torch.Tensor, extra_options: dict, strength: float): model = extra_options["model"] step_sigmas = extra_options["sigmas"] sigma_to = model.inner_model.model_sampling.sigma_min @@ -31,7 +42,15 @@ class DifferentialDiffusion(): threshold = (current_ts - ts_to) / (ts_from - ts_to) - return (denoise_mask >= threshold).to(denoise_mask.dtype) + # Generate the binary mask based on the threshold + binary_mask = (denoise_mask >= threshold).to(denoise_mask.dtype) + + # Blend binary mask with the original denoise_mask using strength + if strength and strength < 1: + blended_mask = strength * binary_mask + (1 - strength) * denoise_mask + return blended_mask + else: + return binary_mask NODE_CLASS_MAPPINGS = {