mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-12-21 03:50:50 +08:00
ensure the model train properly in both grad ckpt or not
This commit is contained in:
parent
37139daa98
commit
305602c668
@ -623,13 +623,13 @@ def _create_weight_adapter(
|
||||
for name, parameter in train_adapter.named_parameters():
|
||||
lora_params[f"{module_name}.{name}"] = parameter
|
||||
|
||||
return train_adapter, lora_params
|
||||
return train_adapter.train().requires_grad_(True), lora_params
|
||||
else:
|
||||
# 1D weight - use BiasDiff
|
||||
diff = torch.nn.Parameter(
|
||||
torch.zeros(module.weight.shape, dtype=lora_dtype, requires_grad=True)
|
||||
)
|
||||
diff_module = BiasDiff(diff)
|
||||
diff_module = BiasDiff(diff).train().requires_grad_(True)
|
||||
lora_params[f"{module_name}.diff"] = diff
|
||||
return diff_module, lora_params
|
||||
|
||||
@ -648,7 +648,7 @@ def _create_bias_adapter(module, module_name, lora_dtype):
|
||||
bias = torch.nn.Parameter(
|
||||
torch.zeros(module.bias.shape, dtype=lora_dtype, requires_grad=True)
|
||||
)
|
||||
bias_module = BiasDiff(bias)
|
||||
bias_module = BiasDiff(bias).train().requires_grad_(True)
|
||||
lora_params = {f"{module_name}.diff_b": bias}
|
||||
return bias_module, lora_params
|
||||
|
||||
@ -961,6 +961,9 @@ class TrainLoraNode(io.ComfyNode):
|
||||
positive = _validate_and_expand_conditioning(positive, num_images, bucket_mode)
|
||||
|
||||
with torch.inference_mode(False):
|
||||
# Setup models for training
|
||||
mp.model.requires_grad_(False)
|
||||
|
||||
# Load existing LoRA weights if provided
|
||||
existing_weights, existing_steps = _load_existing_lora(existing_lora)
|
||||
|
||||
@ -982,8 +985,6 @@ class TrainLoraNode(io.ComfyNode):
|
||||
):
|
||||
patch(m)
|
||||
|
||||
# Setup models for training
|
||||
mp.model.requires_grad_(False)
|
||||
torch.cuda.empty_cache()
|
||||
# With force_full_load=False we should be able to have offloading
|
||||
# But for offloading in training we need custom AutoGrad hooks for fwd/bwd
|
||||
|
||||
Loading…
Reference in New Issue
Block a user