ensure the model trains properly with gradient checkpointing enabled or disabled

Kohaku-Blueleaf 2025-12-09 22:52:32 +08:00
parent 37139daa98
commit 305602c668


@@ -623,13 +623,13 @@ def _create_weight_adapter(
         for name, parameter in train_adapter.named_parameters():
             lora_params[f"{module_name}.{name}"] = parameter
-        return train_adapter, lora_params
+        return train_adapter.train().requires_grad_(True), lora_params
     else:
         # 1D weight - use BiasDiff
         diff = torch.nn.Parameter(
             torch.zeros(module.weight.shape, dtype=lora_dtype, requires_grad=True)
         )
-        diff_module = BiasDiff(diff)
+        diff_module = BiasDiff(diff).train().requires_grad_(True)
         lora_params[f"{module_name}.diff"] = diff
         return diff_module, lora_params
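For context on why the returned adapters are explicitly switched to train mode with gradients enabled: a checkpointed segment only produces gradients if something inside it actually requires them, so an adapter that ends up frozen (for example by a blanket requires_grad_(False) on the patched model) silently trains nothing. A minimal sketch of that behaviour, using a plain nn.Linear as a stand-in for the repo's adapter modules (not the real BiasDiff/LoRA classes):

```python
import torch
from torch.utils.checkpoint import checkpoint

# Stand-in for a patched adapter module; nn.Linear is illustrative only.
adapter = torch.nn.Linear(8, 8)
x = torch.randn(2, 8)

# Frozen adapter: the checkpointed forward yields an output with no grad_fn,
# so a backward pass would update nothing.
adapter.requires_grad_(False)
y = checkpoint(adapter, x, use_reentrant=False)
print(y.requires_grad)  # False

# Explicitly restoring train mode and gradients, as this commit does,
# makes the checkpointed segment differentiable again.
adapter.train().requires_grad_(True)
y = checkpoint(adapter, x, use_reentrant=False)
y.sum().backward()
print(all(p.grad is not None for p in adapter.parameters()))  # True
```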
@@ -648,7 +648,7 @@ def _create_bias_adapter(module, module_name, lora_dtype):
     bias = torch.nn.Parameter(
         torch.zeros(module.bias.shape, dtype=lora_dtype, requires_grad=True)
     )
-    bias_module = BiasDiff(bias)
+    bias_module = BiasDiff(bias).train().requires_grad_(True)
     lora_params = {f"{module_name}.diff_b": bias}
     return bias_module, lora_params
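Both adapter paths also hand back a flat dict keyed by "<module_name>.<suffix>" so the trainer can gather every trainable tensor into a single LoRA state dict. A rough sketch of that collection pattern; the module name and adapter below are placeholders, not objects from the repo:

```python
import torch

# Placeholder adapter and module path; the real code walks the patched
# model and uses its actual module names.
module_name = "diffusion_model.input_blocks.0.proj"
adapter = torch.nn.Linear(16, 16).train().requires_grad_(True)

lora_params = {}
for name, parameter in adapter.named_parameters():
    lora_params[f"{module_name}.{name}"] = parameter

# Every entry is a trainable Parameter ready to be optimized and saved.
print(sorted(lora_params))
print(all(p.requires_grad for p in lora_params.values()))  # True
```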
@@ -961,6 +961,9 @@ class TrainLoraNode(io.ComfyNode):
         positive = _validate_and_expand_conditioning(positive, num_images, bucket_mode)
         with torch.inference_mode(False):
+            # Setup models for training
+            mp.model.requires_grad_(False)
             # Load existing LoRA weights if provided
             existing_weights, existing_steps = _load_existing_lora(existing_lora)
@@ -982,8 +985,6 @@
             ):
                 patch(m)
-            # Setup models for training
-            mp.model.requires_grad_(False)
             torch.cuda.empty_cache()
             # With force_full_load=False we should be able to have offloading
             # But for offloading in training we need custom AutoGrad hooks for fwd/bwd
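The last two hunks relocate the base-model freeze from after the adapter patching to just inside the inference_mode(False) block, before anything is patched. Freezing first means the blanket requires_grad_(False) can no longer clobber adapters attached afterwards. A small sketch of why the order matters, assuming the adapters become reachable as submodules of the patched model; all names below are placeholders, not ComfyUI's ModelPatcher API:

```python
import torch

# Wrong order: patch the adapter into the model, then freeze everything.
# The freeze recurses into the adapter and disables its gradients too.
base = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 8))
adapter = torch.nn.Linear(8, 8)
base.adapter = adapter              # stand-in for the real patching step
base.requires_grad_(False)
print(any(p.requires_grad for p in adapter.parameters()))  # False

# Order used after this commit: freeze the base model first, then attach
# adapters that are explicitly put in train mode with gradients enabled.
base = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 8))
base.requires_grad_(False)          # analogous to mp.model.requires_grad_(False)
adapter = torch.nn.Linear(8, 8).train().requires_grad_(True)
base.adapter = adapter
print(all(p.requires_grad for p in adapter.parameters()))   # True
print(any(p.requires_grad for p in base[0].parameters()))   # False
```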