mirror of https://github.com/comfyanonymous/ComfyUI.git
ensure the model trains properly both with and without grad ckpt
parent 37139daa98
commit 305602c668
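Background for the change: when the whole base model is frozen and the freshly created adapters are not explicitly marked trainable, a gradient-checkpointed block can end up containing nothing that requires grad, so no gradients ever reach the adapter weights. The toy sketch below is plain PyTorch, not ComfyUI code; `base`, `adapter`, and `block` are made-up names used only to illustrate the failure mode and the fix applied in this commit.

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

base = nn.Linear(8, 8).requires_grad_(False)   # frozen base weight
adapter = nn.Linear(8, 8, bias=False)
adapter.requires_grad_(False)                  # accidentally frozen along with the base

def block(x):
    # stand-in for one transformer block plus a LoRA-style add-on
    return base(x) + adapter(x)

x = torch.randn(2, 8)                          # upstream activations, no grad
out = checkpoint(block, x, use_reentrant=False)
print(out.requires_grad)                       # False -> nothing would train

adapter.train().requires_grad_(True)           # the pattern this commit applies
out = checkpoint(block, x, use_reentrant=False)
out.sum().backward()
print(adapter.weight.grad is not None)         # True -> adapter trains under grad ckpt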
@@ -623,13 +623,13 @@ def _create_weight_adapter(
         for name, parameter in train_adapter.named_parameters():
             lora_params[f"{module_name}.{name}"] = parameter
 
-        return train_adapter, lora_params
+        return train_adapter.train().requires_grad_(True), lora_params
     else:
         # 1D weight - use BiasDiff
         diff = torch.nn.Parameter(
             torch.zeros(module.weight.shape, dtype=lora_dtype, requires_grad=True)
         )
-        diff_module = BiasDiff(diff)
+        diff_module = BiasDiff(diff).train().requires_grad_(True)
         lora_params[f"{module_name}.diff"] = diff
         return diff_module, lora_params
 
@@ -648,7 +648,7 @@ def _create_bias_adapter(module, module_name, lora_dtype):
     bias = torch.nn.Parameter(
         torch.zeros(module.bias.shape, dtype=lora_dtype, requires_grad=True)
     )
-    bias_module = BiasDiff(bias)
+    bias_module = BiasDiff(bias).train().requires_grad_(True)
     lora_params = {f"{module_name}.diff_b": bias}
     return bias_module, lora_params
 
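A small note on the pattern used in both hunks above: `nn.Module.train()` and `nn.Module.requires_grad_()` each return the module itself, which is why the calls can be chained directly in the return statement and the `BiasDiff(...)` expression without changing what gets returned. Minimal check in plain PyTorch:

import torch.nn as nn

m = nn.Linear(4, 4)
same = m.train().requires_grad_(True)   # both methods return self
assert same is m
assert m.training
assert all(p.requires_grad for p in m.parameters())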
@@ -961,6 +961,9 @@ class TrainLoraNode(io.ComfyNode):
         positive = _validate_and_expand_conditioning(positive, num_images, bucket_mode)
 
         with torch.inference_mode(False):
+            # Setup models for training
+            mp.model.requires_grad_(False)
+
             # Load existing LoRA weights if provided
             existing_weights, existing_steps = _load_existing_lora(existing_lora)
 
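The added lines move the base-model freeze to the top of the training setup, before any LoRA adapters are patched in (the old location is removed in the next hunk). A toy illustration of why the order plausibly matters, using made-up module names and no ComfyUI APIs:

import torch.nn as nn

def attach_adapter(model):
    # stand-in for patching a trainable adapter into the model
    model.add_module("adapter", nn.Linear(4, 4).train().requires_grad_(True))

# New order: freeze first, then patch -> the adapter stays trainable.
model = nn.Sequential(nn.Linear(4, 4))
model.requires_grad_(False)
attach_adapter(model)
assert all(p.requires_grad for p in model.adapter.parameters())

# Old order: patch first, freeze afterwards -> the freeze clobbers the adapter too.
model = nn.Sequential(nn.Linear(4, 4))
attach_adapter(model)
model.requires_grad_(False)
assert not any(p.requires_grad for p in model.adapter.parameters())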
@@ -982,8 +985,6 @@ class TrainLoraNode(io.ComfyNode):
             ):
                 patch(m)
 
-            # Setup models for training
-            mp.model.requires_grad_(False)
             torch.cuda.empty_cache()
             # With force_full_load=False we should be able to have offloading
             # But for offloading in training we need custom AutoGrad hooks for fwd/bwd
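The trailing comments note that offloading during training would need custom autograd hooks for the forward/backward passes; this commit does not add them. As background only, PyTorch's saved-tensor pack/unpack hooks are one existing mechanism of that kind: they can park the tensors stashed for backward on the CPU during the forward pass and fetch them back when backward needs them. Illustrative sketch, assuming a CUDA device, and not part of this change:

import torch
import torch.nn as nn

def pack_to_cpu(t):
    return t.to("cpu", non_blocking=True)    # called when autograd saves a tensor

def unpack_to_gpu(t):
    return t.to("cuda", non_blocking=True)   # called when backward needs it again

model = nn.Linear(1024, 1024).cuda()
x = torch.randn(8, 1024, device="cuda", requires_grad=True)

with torch.autograd.graph.saved_tensors_hooks(pack_to_cpu, unpack_to_gpu):
    y = model(x)                             # saved activations live on the CPU
y.sum().backward()                           # copied back to the GPU on demand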