Ensure the model trains properly both with and without gradient checkpointing

Kohaku-Blueleaf 2025-12-09 22:52:32 +08:00
parent 37139daa98
commit 305602c668

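Context for the change, as a minimal sketch rather than the node's actual code path: under gradient checkpointing (assuming it is implemented via torch.utils.checkpoint), the recomputed forward only builds an autograd graph if some involved parameters or inputs require gradients, so freshly created adapters must be explicitly put in training mode with gradients enabled.

```python
# Minimal sketch (not from this commit) of why adapters need an explicit
# .train().requires_grad_(True) when the forward runs under gradient
# checkpointing; assumes checkpointing via torch.utils.checkpoint.
import torch
from torch.utils.checkpoint import checkpoint

adapter = torch.nn.Linear(8, 8)
x = torch.randn(2, 8)

# With a frozen adapter the recomputed forward builds no autograd graph,
# so there is nothing to train.
adapter.requires_grad_(False)
out = checkpoint(adapter, x, use_reentrant=False)
print(out.requires_grad)  # False

# Explicitly enabling training, as the commit does for each adapter,
# restores gradient flow through the checkpointed segment.
adapter.train().requires_grad_(True)
out = checkpoint(adapter, x, use_reentrant=False)
out.sum().backward()
print(adapter.weight.grad is not None)  # True
```
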
@@ -623,13 +623,13 @@ def _create_weight_adapter(
         for name, parameter in train_adapter.named_parameters():
             lora_params[f"{module_name}.{name}"] = parameter
-        return train_adapter, lora_params
+        return train_adapter.train().requires_grad_(True), lora_params
     else:
         # 1D weight - use BiasDiff
         diff = torch.nn.Parameter(
             torch.zeros(module.weight.shape, dtype=lora_dtype, requires_grad=True)
         )
-        diff_module = BiasDiff(diff)
+        diff_module = BiasDiff(diff).train().requires_grad_(True)
         lora_params[f"{module_name}.diff"] = diff
         return diff_module, lora_params
@@ -648,7 +648,7 @@ def _create_bias_adapter(module, module_name, lora_dtype):
     bias = torch.nn.Parameter(
         torch.zeros(module.bias.shape, dtype=lora_dtype, requires_grad=True)
     )
-    bias_module = BiasDiff(bias)
+    bias_module = BiasDiff(bias).train().requires_grad_(True)
     lora_params = {f"{module_name}.diff_b": bias}
     return bias_module, lora_params
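
The chained form above works because nn.Module.train() and nn.Module.requires_grad_() both return the module itself; a small sketch with a stand-in BiasDiff (a simplified illustration, not ComfyUI's class):

```python
# Stand-in sketch of the chained adapter setup; this BiasDiff is only an
# illustration of the pattern, not ComfyUI's implementation.
import torch

class BiasDiff(torch.nn.Module):
    def __init__(self, diff: torch.nn.Parameter):
        super().__init__()
        self.diff = diff  # registered as a trainable parameter

    def forward(self, x):
        return x + self.diff

diff = torch.nn.Parameter(torch.zeros(16), requires_grad=True)
# train() and requires_grad_() each return self, so the adapter can be
# configured and returned in a single expression.
diff_module = BiasDiff(diff).train().requires_grad_(True)
assert diff_module.training
assert diff_module.diff.requires_grad
```

This guards the adapter against inheriting eval mode or a frozen state from the surrounding model setup.
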
@@ -961,6 +961,9 @@ class TrainLoraNode(io.ComfyNode):
         positive = _validate_and_expand_conditioning(positive, num_images, bucket_mode)
         with torch.inference_mode(False):
+            # Setup models for training
+            mp.model.requires_grad_(False)
             # Load existing LoRA weights if provided
             existing_weights, existing_steps = _load_existing_lora(existing_lora)
@@ -982,8 +985,6 @@
             ):
                 patch(m)
-            # Setup models for training
-            mp.model.requires_grad_(False)
             torch.cuda.empty_cache()
             # With force_full_load=False we should be able to have offloading
             # But for offloading in training we need custom AutoGrad hooks for fwd/bwd
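
One reading of moving mp.model.requires_grad_(False) ahead of the patch loop: a blanket freeze applied after the trainable adapters are attached could also clear their requires_grad flags, whereas freezing first and attaching afterwards leaves only the adapters trainable. An illustrative sketch with placeholder names (base, adapter are not the node's real objects, and whether the adapters end up as submodules depends on the patching mechanism):

```python
# Illustrative sketch of why the base model is frozen *before* the adapters
# are patched in; names are placeholders, not the node's real API.
import torch

base = torch.nn.Linear(8, 8)                       # stands in for mp.model
base.requires_grad_(False)                         # freeze the base first
base.adapter = torch.nn.Linear(8, 8).train().requires_grad_(True)
assert base.adapter.weight.requires_grad           # adapter stays trainable

# Reversed order: a blanket freeze after patching also freezes the adapter.
base2 = torch.nn.Linear(8, 8)
base2.adapter = torch.nn.Linear(8, 8).train().requires_grad_(True)
base2.requires_grad_(False)
assert not base2.adapter.weight.requires_grad      # adapter frozen by mistake
```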