mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-27 19:02:31 +08:00
Ensure model is not inference mode
This commit is contained in:
parent
b615af1c65
commit
90fd638c5d
@ -1180,8 +1180,30 @@ class TrainLoraNode(io.ComfyNode):
|
|||||||
positive = _validate_and_expand_conditioning(positive, num_images, bucket_mode)
|
positive = _validate_and_expand_conditioning(positive, num_images, bucket_mode)
|
||||||
|
|
||||||
with torch.inference_mode(False):
|
with torch.inference_mode(False):
|
||||||
|
# Now ComfyUI will load model in inference mode
|
||||||
|
# which make all parameter is now inference mode tensors
|
||||||
|
# to make the training correctly working
|
||||||
|
# we re-build the parameters in training mode
|
||||||
|
for module in mp.model.modules():
|
||||||
|
for name, param in list(module._parameters.items()):
|
||||||
|
if param is not None:
|
||||||
|
try:
|
||||||
|
_ = param._version
|
||||||
|
except Exception:
|
||||||
|
module._parameters[name] = torch.nn.Parameter(
|
||||||
|
param.detach().clone(),
|
||||||
|
requires_grad=param.requires_grad,
|
||||||
|
)
|
||||||
|
|
||||||
|
for name, buf in list(module._buffers.items()):
|
||||||
|
if buf is not None:
|
||||||
|
try:
|
||||||
|
_ = buf._version
|
||||||
|
except Exception:
|
||||||
|
module._buffers[name] = buf.detach().clone()
|
||||||
|
|
||||||
# Setup models for training
|
# Setup models for training
|
||||||
mp.model.requires_grad_(False)
|
mp.model.requires_grad_(False).train()
|
||||||
|
|
||||||
# Load existing LoRA weights if provided
|
# Load existing LoRA weights if provided
|
||||||
existing_weights, existing_steps = _load_existing_lora(existing_lora)
|
existing_weights, existing_steps = _load_existing_lora(existing_lora)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user