From f02de13316b24436eb69222d7bc8181b73eeccb2 Mon Sep 17 00:00:00 2001
From: chaObserv <154517000+chaObserv@users.noreply.github.com>
Date: Tue, 1 Jul 2025 14:33:07 +0800
Subject: [PATCH 1/4] Add TCFG node (#8730)

---
 comfy_extras/nodes_tcfg.py | 71 ++++++++++++++++++++++++++++++++++++++
 nodes.py                   |  1 +
 2 files changed, 72 insertions(+)
 create mode 100644 comfy_extras/nodes_tcfg.py

diff --git a/comfy_extras/nodes_tcfg.py b/comfy_extras/nodes_tcfg.py
new file mode 100644
index 000000000..35b89a73f
--- /dev/null
+++ b/comfy_extras/nodes_tcfg.py
@@ -0,0 +1,71 @@
+# TCFG: Tangential Damping Classifier-free Guidance - (arXiv: https://arxiv.org/abs/2503.18137)
+
+import torch
+
+from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict
+
+
+def score_tangential_damping(cond_score: torch.Tensor, uncond_score: torch.Tensor) -> torch.Tensor:
+    """Drop tangential components from uncond score to align with cond score."""
+    # (B, 1, ...)
+    batch_num = cond_score.shape[0]
+    cond_score_flat = cond_score.reshape(batch_num, 1, -1).float()
+    uncond_score_flat = uncond_score.reshape(batch_num, 1, -1).float()
+
+    # Score matrix A (B, 2, ...)
+    score_matrix = torch.cat((uncond_score_flat, cond_score_flat), dim=1)
+    try:
+        _, _, Vh = torch.linalg.svd(score_matrix, full_matrices=False)
+    except RuntimeError:
+        # Fallback to CPU
+        _, _, Vh = torch.linalg.svd(score_matrix.cpu(), full_matrices=False)
+
+    # Drop the tangential components
+    v1 = Vh[:, 0:1, :].to(uncond_score_flat.device)  # (B, 1, ...)
+    uncond_score_td = (uncond_score_flat @ v1.transpose(-2, -1)) * v1
+    return uncond_score_td.reshape_as(uncond_score).to(uncond_score.dtype)
+
+
+class TCFG(ComfyNodeABC):
+    @classmethod
+    def INPUT_TYPES(cls) -> InputTypeDict:
+        return {
+            "required": {
+                "model": (IO.MODEL, {}),
+            }
+        }
+
+    RETURN_TYPES = (IO.MODEL,)
+    RETURN_NAMES = ("patched_model",)
+    FUNCTION = "patch"
+
+    CATEGORY = "advanced/guidance"
+    DESCRIPTION = "TCFG – Tangential Damping CFG (2503.18137)\n\nRefine the uncond (negative) to align with the cond (positive) for improving quality."
+
+    def patch(self, model):
+        m = model.clone()
+
+        def tangential_damping_cfg(args):
+            # Assume [cond, uncond, ...]
+            x = args["input"]
+            conds_out = args["conds_out"]
+            if len(conds_out) <= 1 or None in args["conds"][:2]:
+                # Skip when either cond or uncond is None
+                return conds_out
+            cond_pred = conds_out[0]
+            uncond_pred = conds_out[1]
+            uncond_td = score_tangential_damping(x - cond_pred, x - uncond_pred)
+            uncond_pred_td = x - uncond_td
+            return [cond_pred, uncond_pred_td] + conds_out[2:]
+
+        m.set_model_sampler_pre_cfg_function(tangential_damping_cfg)
+        return (m,)
+
+
+NODE_CLASS_MAPPINGS = {
+    "TCFG": TCFG,
+}
+
+NODE_DISPLAY_NAME_MAPPINGS = {
+    "TCFG": "Tangential Damping CFG",
+}
diff --git a/nodes.py b/nodes.py
index 99411a1fe..1b465b9e6 100644
--- a/nodes.py
+++ b/nodes.py
@@ -2283,6 +2283,7 @@ def init_builtin_extra_nodes():
         "nodes_string.py",
         "nodes_camera_trajectory.py",
         "nodes_edit_model.py",
+        "nodes_tcfg.py"
     ]
 
     import_failed = []
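[Reviewer note — not part of the patch] The pre-CFG hook only refines what the sampler already computes. Below is a minimal standalone sketch of the same flow; score_tangential_damping is condensed from the diff above (CPU fallback omitted), while the random tensors and the final CFG combination line are illustrative assumptions about the surrounding sampler, not code from this PR.

import torch

def score_tangential_damping(cond_score: torch.Tensor, uncond_score: torch.Tensor) -> torch.Tensor:
    # As in nodes_tcfg.py: keep only the projection of the uncond score onto the
    # principal right-singular direction of the stacked [uncond, cond] score matrix.
    batch_num = cond_score.shape[0]
    cond_score_flat = cond_score.reshape(batch_num, 1, -1).float()
    uncond_score_flat = uncond_score.reshape(batch_num, 1, -1).float()
    score_matrix = torch.cat((uncond_score_flat, cond_score_flat), dim=1)  # (B, 2, N)
    _, _, Vh = torch.linalg.svd(score_matrix, full_matrices=False)
    v1 = Vh[:, 0:1, :]  # dominant direction, (B, 1, N)
    uncond_score_td = (uncond_score_flat @ v1.transpose(-2, -1)) * v1
    return uncond_score_td.reshape_as(uncond_score).to(uncond_score.dtype)

# Illustrative stand-ins for the hook's args["input"] and args["conds_out"].
x = torch.randn(2, 4, 8, 8)
cond_pred, uncond_pred = torch.randn_like(x), torch.randn_like(x)

uncond_td = score_tangential_damping(x - cond_pred, x - uncond_pred)
uncond_pred_td = x - uncond_td

# Assumption: downstream, ordinary CFG then runs unchanged on the refined uncond
# (the cfg value here is arbitrary).
cfg = 7.0
denoised = uncond_pred_td + cfg * (cond_pred - uncond_pred_td)
print(denoised.shape)  # torch.Size([2, 4, 8, 8])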
From b22e97dcfa1736190cfcafd6091c4da885fcf48a Mon Sep 17 00:00:00 2001
From: chaObserv <154517000+chaObserv@users.noreply.github.com>
Date: Tue, 1 Jul 2025 14:38:52 +0800
Subject: [PATCH 2/4] Migrate ER-SDE from VE to VP algorithm and add its
 sampler node (#8744)

Apply alpha scaling in the algorithm for reverse-time SDE
and add custom ER-SDE sampler node for other solver types (SDE, ODE).
---
 comfy/k_diffusion/sampling.py        | 65 ++++++++++++++++------------
 comfy_extras/nodes_custom_sampler.py | 42 ++++++++++++++++++
 2 files changed, 80 insertions(+), 27 deletions(-)

diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py
index 739468872..e231d6a3d 100644
--- a/comfy/k_diffusion/sampling.py
+++ b/comfy/k_diffusion/sampling.py
@@ -1447,14 +1447,15 @@ def sample_gradient_estimation(model, x, sigmas, extra_args=None, callback=None,
         old_d = d
     return x
 
+
 @torch.no_grad()
 def sample_gradient_estimation_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, ge_gamma=2.):
     return sample_gradient_estimation(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, ge_gamma=ge_gamma, cfg_pp=True)
 
+
 @torch.no_grad()
-def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., noise_sampler=None, noise_scaler=None, max_stage=3):
-    """
-    Extended Reverse-Time SDE solver (VE ER-SDE-Solver-3). Arxiv: https://arxiv.org/abs/2309.06169.
+def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1.0, noise_sampler=None, noise_scaler=None, max_stage=3):
+    """Extended Reverse-Time SDE solver (VP ER-SDE-Solver-3). arXiv: https://arxiv.org/abs/2309.06169.
     Code reference: https://github.com/QinpengCui/ER-SDE-Solver/blob/main/er_sde_solver.py.
     """
     extra_args = {} if extra_args is None else extra_args
@@ -1462,12 +1463,18 @@ def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None
     noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
     s_in = x.new_ones([x.shape[0]])
 
-    def default_noise_scaler(sigma):
-        return sigma * ((sigma ** 0.3).exp() + 10.0)
-    noise_scaler = default_noise_scaler if noise_scaler is None else noise_scaler
+    def default_er_sde_noise_scaler(x):
+        return x * ((x ** 0.3).exp() + 10.0)
+
+    noise_scaler = default_er_sde_noise_scaler if noise_scaler is None else noise_scaler
     num_integration_points = 200.0
     point_indice = torch.arange(0, num_integration_points, dtype=torch.float32, device=x.device)
 
+    model_sampling = model.inner_model.model_patcher.get_model_object("model_sampling")
+    sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
+    half_log_snrs = sigma_to_half_log_snr(sigmas, model_sampling)
+    er_lambdas = half_log_snrs.neg().exp()  # er_lambda_t = sigma_t / alpha_t
+
     old_denoised = None
     old_denoised_d = None
 
@@ -1478,32 +1485,36 @@ def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None
         stage_used = min(max_stage, i + 1)
         if sigmas[i + 1] == 0:
             x = denoised
-        elif stage_used == 1:
-            r = noise_scaler(sigmas[i + 1]) / noise_scaler(sigmas[i])
-            x = r * x + (1 - r) * denoised
         else:
-            r = noise_scaler(sigmas[i + 1]) / noise_scaler(sigmas[i])
-            x = r * x + (1 - r) * denoised
+            er_lambda_s, er_lambda_t = er_lambdas[i], er_lambdas[i + 1]
+            alpha_s = sigmas[i] / er_lambda_s
+            alpha_t = sigmas[i + 1] / er_lambda_t
+            r_alpha = alpha_t / alpha_s
+            r = noise_scaler(er_lambda_t) / noise_scaler(er_lambda_s)
 
-            dt = sigmas[i + 1] - sigmas[i]
-            sigma_step_size = -dt / num_integration_points
-            sigma_pos = sigmas[i + 1] + point_indice * sigma_step_size
-            scaled_pos = noise_scaler(sigma_pos)
+            # Stage 1 Euler
+            x = r_alpha * r * x + alpha_t * (1 - r) * denoised
 
-            # Stage 2
-            s = torch.sum(1 / scaled_pos) * sigma_step_size
-            denoised_d = (denoised - old_denoised) / (sigmas[i] - sigmas[i - 1])
-            x = x + (dt + s * noise_scaler(sigmas[i + 1])) * denoised_d
+            if stage_used >= 2:
+                dt = er_lambda_t - er_lambda_s
+                lambda_step_size = -dt / num_integration_points
+                lambda_pos = er_lambda_t + point_indice * lambda_step_size
+                scaled_pos = noise_scaler(lambda_pos)
 
-            if stage_used >= 3:
-                # Stage 3
-                s_u = torch.sum((sigma_pos - sigmas[i]) / scaled_pos) * sigma_step_size
-                denoised_u = (denoised_d - old_denoised_d) / ((sigmas[i] - sigmas[i - 2]) / 2)
-                x = x + ((dt ** 2) / 2 + s_u * noise_scaler(sigmas[i + 1])) * denoised_u
-                old_denoised_d = denoised_d
+                # Stage 2
+                s = torch.sum(1 / scaled_pos) * lambda_step_size
+                denoised_d = (denoised - old_denoised) / (er_lambda_s - er_lambdas[i - 1])
+                x = x + alpha_t * (dt + s * noise_scaler(er_lambda_t)) * denoised_d
 
-            if s_noise != 0 and sigmas[i + 1] > 0:
-                x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * (sigmas[i + 1] ** 2 - sigmas[i] ** 2 * r ** 2).sqrt().nan_to_num(nan=0.0)
+                if stage_used >= 3:
+                    # Stage 3
+                    s_u = torch.sum((lambda_pos - er_lambda_s) / scaled_pos) * lambda_step_size
+                    denoised_u = (denoised_d - old_denoised_d) / ((er_lambda_s - er_lambdas[i - 2]) / 2)
+                    x = x + alpha_t * ((dt ** 2) / 2 + s_u * noise_scaler(er_lambda_t)) * denoised_u
+                    old_denoised_d = denoised_d
+
+            if s_noise > 0:
+                x = x + alpha_t * noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * (er_lambda_t ** 2 - er_lambda_s ** 2 * r ** 2).sqrt().nan_to_num(nan=0.0)
         old_denoised = denoised
     return x
 
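[Reviewer note — not part of the patch] The core of the VE-to-VP migration is the change of solver variable from sigma_t to er_lambda_t = sigma_t / alpha_t, with every increment rescaled by alpha_t. A small sketch of that relation follows, assuming a hypothetical VP schedule alpha_t = 1/sqrt(1 + sigma_t^2) purely for illustration; sigma_to_half_log_snr and offset_first_sigma_for_snr are ComfyUI helpers not reproduced here.

import torch

sigmas = torch.tensor([14.615, 7.0, 3.0, 1.0, 0.25])

# Hypothetical VP schedule, for illustration only.
alphas = (1.0 + sigmas ** 2).rsqrt()

# half_log_snr = log(alpha_t / sigma_t), so exp(-half_log_snr) = sigma_t / alpha_t,
# matching the er_lambdas line in the diff above.
half_log_snrs = (alphas / sigmas).log()
er_lambdas = half_log_snrs.neg().exp()
assert torch.allclose(er_lambdas, sigmas / alphas)

# For a VE model alpha_t == 1, so er_lambda_t == sigma_t and the solver
# reduces to the previous VE behaviour.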
diff --git a/comfy_extras/nodes_custom_sampler.py b/comfy_extras/nodes_custom_sampler.py
index fc506a0cc..b3a772714 100644
--- a/comfy_extras/nodes_custom_sampler.py
+++ b/comfy_extras/nodes_custom_sampler.py
@@ -2,6 +2,7 @@ import math
 import comfy.samplers
 import comfy.sample
 from comfy.k_diffusion import sampling as k_diffusion_sampling
+from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict
 import latent_preview
 import torch
 import comfy.utils
@@ -480,6 +481,46 @@ class SamplerDPMAdaptative:
                                                        "s_noise":s_noise })
         return (sampler, )
 
+
+class SamplerER_SDE(ComfyNodeABC):
+    @classmethod
+    def INPUT_TYPES(cls) -> InputTypeDict:
+        return {
+            "required": {
+                "solver_type": (IO.COMBO, {"options": ["ER-SDE", "Reverse-time SDE", "ODE"]}),
+                "max_stage": (IO.INT, {"default": 3, "min": 1, "max": 3}),
+                "eta": (
+                    IO.FLOAT,
+                    {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01, "round": False, "tooltip": "Stochastic strength of reverse-time SDE.\nWhen eta=0, it reduces to deterministic ODE. This setting doesn't apply to ER-SDE solver type."},
+                ),
+                "s_noise": (IO.FLOAT, {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01, "round": False}),
+            }
+        }
+
+    RETURN_TYPES = (IO.SAMPLER,)
+    CATEGORY = "sampling/custom_sampling/samplers"
+
+    FUNCTION = "get_sampler"
+
+    def get_sampler(self, solver_type, max_stage, eta, s_noise):
+        if solver_type == "ODE" or (solver_type == "Reverse-time SDE" and eta == 0):
+            eta = 0
+            s_noise = 0
+
+        def reverse_time_sde_noise_scaler(x):
+            return x ** (eta + 1)
+
+        if solver_type == "ER-SDE":
+            # Use the default one in sample_er_sde()
+            noise_scaler = None
+        else:
+            noise_scaler = reverse_time_sde_noise_scaler
+
+        sampler_name = "er_sde"
+        sampler = comfy.samplers.ksampler(sampler_name, {"s_noise": s_noise, "noise_scaler": noise_scaler, "max_stage": max_stage})
+        return (sampler,)
+
+
 class Noise_EmptyNoise:
     def __init__(self):
         self.seed = 0
@@ -787,6 +828,7 @@ NODE_CLASS_MAPPINGS = {
     "SamplerDPMPP_SDE": SamplerDPMPP_SDE,
     "SamplerDPMPP_2S_Ancestral": SamplerDPMPP_2S_Ancestral,
     "SamplerDPMAdaptative": SamplerDPMAdaptative,
+    "SamplerER_SDE": SamplerER_SDE,
     "SplitSigmas": SplitSigmas,
     "SplitSigmasDenoise": SplitSigmasDenoise,
     "FlipSigmas": FlipSigmas,
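[Reviewer note — not part of the patch] The three solver types differ only in the noise_scaler closure handed to sample_er_sde and in whether noise is injected. The sketch below restates the get_sampler logic above as a standalone function; reverse_time_sde_noise_scaler comes from the diff, the other names are hypothetical.

def er_sde_config(solver_type: str, eta: float, s_noise: float):
    """Return (noise_scaler, s_noise) the way SamplerER_SDE.get_sampler would."""
    if solver_type == "ODE" or (solver_type == "Reverse-time SDE" and eta == 0):
        eta = 0
        s_noise = 0  # no injected noise -> deterministic ODE

    def reverse_time_sde_noise_scaler(x):
        return x ** (eta + 1)

    if solver_type == "ER-SDE":
        # None makes sample_er_sde fall back to x * ((x ** 0.3).exp() + 10.0).
        return None, s_noise
    return reverse_time_sde_noise_scaler, s_noise

for solver in ("ER-SDE", "Reverse-time SDE", "ODE"):
    scaler, sn = er_sde_config(solver, eta=1.0, s_noise=1.0)
    print(solver, "default scaler" if scaler is None else "custom scaler", "s_noise =", sn)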
From 772de7c00653fc3a825762f555e836d071a4dc80 Mon Sep 17 00:00:00 2001
From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com>
Date: Tue, 1 Jul 2025 00:09:07 -0700
Subject: [PATCH 3/4] PerpNeg Guider optimizations. (#8753)

---
 comfy_extras/nodes_perpneg.py | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/comfy_extras/nodes_perpneg.py b/comfy_extras/nodes_perpneg.py
index f051cbf9a..89e5eef90 100644
--- a/comfy_extras/nodes_perpneg.py
+++ b/comfy_extras/nodes_perpneg.py
@@ -4,6 +4,7 @@ import comfy.sampler_helpers
 import comfy.samplers
 import comfy.utils
 import node_helpers
+import math
 
 def perp_neg(x, noise_pred_pos, noise_pred_neg, noise_pred_nocond, neg_scale, cond_scale):
     pos = noise_pred_pos - noise_pred_nocond
@@ -69,6 +70,12 @@ class Guider_PerpNeg(comfy.samplers.CFGGuider):
         negative_cond = self.conds.get("negative", None)
         empty_cond = self.conds.get("empty_negative_prompt", None)
 
+        if model_options.get("disable_cfg1_optimization", False) == False:
+            if math.isclose(self.neg_scale, 0.0):
+                negative_cond = None
+            if math.isclose(self.cfg, 1.0):
+                empty_cond = None
+
         conds = [positive_cond, negative_cond, empty_cond]
 
         out = comfy.samplers.calc_cond_batch(self.inner_model, conds, x, timestep, model_options)
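[Reviewer note — not part of the patch] The optimization drops conds whose model call cannot change the output, mirroring the existing disable_cfg1_optimization fast path for plain CFG. A standalone sketch of the pruning follows; math.isclose and the [positive, negative, empty] ordering come from the diff, the function name and print are hypothetical.

import math

def prune_conds(cfg, neg_scale, positive_cond, negative_cond, empty_cond,
                disable_cfg1_optimization=False):
    """Replace conds with None so calc_cond_batch can skip their model calls."""
    if not disable_cfg1_optimization:
        if math.isclose(neg_scale, 0.0):
            negative_cond = None  # weight 0: the negative branch contributes nothing
        if math.isclose(cfg, 1.0):
            empty_cond = None     # same fast path as cfg == 1 in plain CFG
    return [positive_cond, negative_cond, empty_cond]

# With neg_scale = 0 and cfg = 1, only the positive cond is evaluated.
print(prune_conds(1.0, 0.0, "positive", "negative", "empty"))
# ['positive', None, None]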
From 79ed75274874590967ff13ac73c5d84262d489d0 Mon Sep 17 00:00:00 2001
From: Terry Jia
Date: Tue, 1 Jul 2025 20:43:48 -0400
Subject: [PATCH 4/4] support upload 3d model to custom subfolder (#8597)

---
 comfy_extras/nodes_load_3d.py | 20 ++++++++++++++++++--
 1 file changed, 18 insertions(+), 2 deletions(-)

diff --git a/comfy_extras/nodes_load_3d.py b/comfy_extras/nodes_load_3d.py
index 40d03e18a..899608149 100644
--- a/comfy_extras/nodes_load_3d.py
+++ b/comfy_extras/nodes_load_3d.py
@@ -5,6 +5,8 @@ import os
 from comfy.comfy_types import IO
 from comfy_api.input_impl import VideoFromFile
 
+from pathlib import Path
+
 def normalize_path(path):
     return path.replace('\\', '/')
 
@@ -16,7 +18,14 @@ class Load3D():
 
         os.makedirs(input_dir, exist_ok=True)
 
-        files = [normalize_path(os.path.join("3d", f)) for f in os.listdir(input_dir) if f.endswith(('.gltf', '.glb', '.obj', '.fbx', '.stl'))]
+        input_path = Path(input_dir)
+        base_path = Path(folder_paths.get_input_directory())
+
+        files = [
+            normalize_path(str(file_path.relative_to(base_path)))
+            for file_path in input_path.rglob("*")
+            if file_path.suffix.lower() in {'.gltf', '.glb', '.obj', '.fbx', '.stl'}
+        ]
 
         return {"required": {
             "model_file": (sorted(files), {"file_upload": True}),
@@ -61,7 +70,14 @@ class Load3DAnimation():
 
        os.makedirs(input_dir, exist_ok=True)
 
-        files = [normalize_path(os.path.join("3d", f)) for f in os.listdir(input_dir) if f.endswith(('.gltf', '.glb', '.fbx'))]
+        input_path = Path(input_dir)
+        base_path = Path(folder_paths.get_input_directory())
+
+        files = [
+            normalize_path(str(file_path.relative_to(base_path)))
+            for file_path in input_path.rglob("*")
+            if file_path.suffix.lower() in {'.gltf', '.glb', '.fbx'}
+        ]
 
         return {"required": {
             "model_file": (sorted(files), {"file_upload": True}),
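[Reviewer note — not part of the patch] The listing now recurses with Path.rglob and keeps paths relative to the ComfyUI input root, so files uploaded into custom subfolders still resolve. A standalone sketch of that logic follows; the suffix set matches the diff, while the function name and the example layout are hypothetical.

from pathlib import Path

MODEL_SUFFIXES = frozenset({'.gltf', '.glb', '.obj', '.fbx', '.stl'})

def list_3d_files(input_dir: str, base_dir: str):
    """Recurse into subfolders; return input-root-relative paths with forward slashes."""
    base = Path(base_dir)
    return sorted(
        str(p.relative_to(base)).replace('\\', '/')
        for p in Path(input_dir).rglob('*')
        if p.suffix.lower() in MODEL_SUFFIXES
    )

# Hypothetical layout: <input>/3d/props/chair.glb is now listed as "3d/props/chair.glb",
# where the old os.listdir version only saw files directly under <input>/3d.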