mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-27 10:52:31 +08:00
Ensure model is not inference mode
This commit is contained in:
parent
b615af1c65
commit
90fd638c5d
@ -1180,8 +1180,30 @@ class TrainLoraNode(io.ComfyNode):
|
||||
positive = _validate_and_expand_conditioning(positive, num_images, bucket_mode)
|
||||
|
||||
with torch.inference_mode(False):
|
||||
# Now ComfyUI will load model in inference mode
|
||||
# which make all parameter is now inference mode tensors
|
||||
# to make the training correctly working
|
||||
# we re-build the parameters in training mode
|
||||
for module in mp.model.modules():
|
||||
for name, param in list(module._parameters.items()):
|
||||
if param is not None:
|
||||
try:
|
||||
_ = param._version
|
||||
except Exception:
|
||||
module._parameters[name] = torch.nn.Parameter(
|
||||
param.detach().clone(),
|
||||
requires_grad=param.requires_grad,
|
||||
)
|
||||
|
||||
for name, buf in list(module._buffers.items()):
|
||||
if buf is not None:
|
||||
try:
|
||||
_ = buf._version
|
||||
except Exception:
|
||||
module._buffers[name] = buf.detach().clone()
|
||||
|
||||
# Setup models for training
|
||||
mp.model.requires_grad_(False)
|
||||
mp.model.requires_grad_(False).train()
|
||||
|
||||
# Load existing LoRA weights if provided
|
||||
existing_weights, existing_steps = _load_existing_lora(existing_lora)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user