diff --git a/comfy_extras/nodes_train.py b/comfy_extras/nodes_train.py
index 71b307389..dfbe2c1ce 100644
--- a/comfy_extras/nodes_train.py
+++ b/comfy_extras/nodes_train.py
@@ -623,13 +623,13 @@ def _create_weight_adapter(
         for name, parameter in train_adapter.named_parameters():
             lora_params[f"{module_name}.{name}"] = parameter
 
-        return train_adapter, lora_params
+        return train_adapter.train().requires_grad_(True), lora_params
     else:
         # 1D weight - use BiasDiff
         diff = torch.nn.Parameter(
             torch.zeros(module.weight.shape, dtype=lora_dtype, requires_grad=True)
         )
-        diff_module = BiasDiff(diff)
+        diff_module = BiasDiff(diff).train().requires_grad_(True)
         lora_params[f"{module_name}.diff"] = diff
         return diff_module, lora_params
 
@@ -648,7 +648,7 @@ def _create_bias_adapter(module, module_name, lora_dtype):
     bias = torch.nn.Parameter(
         torch.zeros(module.bias.shape, dtype=lora_dtype, requires_grad=True)
     )
-    bias_module = BiasDiff(bias)
+    bias_module = BiasDiff(bias).train().requires_grad_(True)
     lora_params = {f"{module_name}.diff_b": bias}
     return bias_module, lora_params
 
@@ -961,6 +961,9 @@ class TrainLoraNode(io.ComfyNode):
             positive = _validate_and_expand_conditioning(positive, num_images, bucket_mode)
 
         with torch.inference_mode(False):
+            # Setup models for training
+            mp.model.requires_grad_(False)
+
             # Load existing LoRA weights if provided
             existing_weights, existing_steps = _load_existing_lora(existing_lora)
 
@@ -982,8 +985,6 @@ class TrainLoraNode(io.ComfyNode):
             ):
                 patch(m)
 
-            # Setup models for training
-            mp.model.requires_grad_(False)
             torch.cuda.empty_cache()
             # With force_full_load=False we should be able to have offloading
             # But for offloading in training we need custom AutoGrad hooks for fwd/bwd
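
The common thread of this patch is ordering: the base model is frozen with `requires_grad_(False)` *before* the LoRA/BiasDiff adapters are created and patched in, and each adapter is returned already in train mode with gradients enabled, so the blanket freeze can no longer reach (and silently disable) adapter parameters. A minimal sketch of that invariant, using hypothetical stand-ins `base_model` and `adapter` rather than ComfyUI's `ModelPatcher` internals:

```python
import torch

base_model = torch.nn.Linear(16, 16)  # stands in for mp.model
adapter = torch.nn.Linear(16, 16)     # stands in for a LoRA/BiasDiff adapter

# Freeze the base model first, before any adapter is attached, mirroring
# the moved mp.model.requires_grad_(False) call in the diff.
base_model.requires_grad_(False)

# Hand back the adapter already in train mode with grads enabled, mirroring
# train_adapter.train().requires_grad_(True) and the BiasDiff changes.
adapter = adapter.train().requires_grad_(True)

assert all(not p.requires_grad for p in base_model.parameters())
assert all(p.requires_grad for p in adapter.parameters())
```

With the old ordering (freeze after patching), the adapter's parameters would be registered under the model when `requires_grad_(False)` runs and end up frozen; chaining `.train().requires_grad_(True)` at creation time makes each adapter trainable regardless of when the freeze happens.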