From 305602c66831c5b76e34a7648d205ab572c94f6e Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Tue, 9 Dec 2025 22:52:32 +0800 Subject: [PATCH] ensure the model train properly in both grad ckpt or not --- comfy_extras/nodes_train.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/comfy_extras/nodes_train.py b/comfy_extras/nodes_train.py index 71b307389..dfbe2c1ce 100644 --- a/comfy_extras/nodes_train.py +++ b/comfy_extras/nodes_train.py @@ -623,13 +623,13 @@ def _create_weight_adapter( for name, parameter in train_adapter.named_parameters(): lora_params[f"{module_name}.{name}"] = parameter - return train_adapter, lora_params + return train_adapter.train().requires_grad_(True), lora_params else: # 1D weight - use BiasDiff diff = torch.nn.Parameter( torch.zeros(module.weight.shape, dtype=lora_dtype, requires_grad=True) ) - diff_module = BiasDiff(diff) + diff_module = BiasDiff(diff).train().requires_grad_(True) lora_params[f"{module_name}.diff"] = diff return diff_module, lora_params @@ -648,7 +648,7 @@ def _create_bias_adapter(module, module_name, lora_dtype): bias = torch.nn.Parameter( torch.zeros(module.bias.shape, dtype=lora_dtype, requires_grad=True) ) - bias_module = BiasDiff(bias) + bias_module = BiasDiff(bias).train().requires_grad_(True) lora_params = {f"{module_name}.diff_b": bias} return bias_module, lora_params @@ -961,6 +961,9 @@ class TrainLoraNode(io.ComfyNode): positive = _validate_and_expand_conditioning(positive, num_images, bucket_mode) with torch.inference_mode(False): + # Setup models for training + mp.model.requires_grad_(False) + # Load existing LoRA weights if provided existing_weights, existing_steps = _load_existing_lora(existing_lora) @@ -982,8 +985,6 @@ class TrainLoraNode(io.ComfyNode): ): patch(m) - # Setup models for training - mp.model.requires_grad_(False) torch.cuda.empty_cache() # With force_full_load=False we should be able to have offloading # But for offloading in training we need custom AutoGrad hooks for fwd/bwd