Ensure model is not in inference mode

This commit is contained in:
Kohaku-Blueleaf 2026-04-14 14:39:14 +08:00
parent b615af1c65
commit 90fd638c5d

View File

@ -1180,8 +1180,30 @@ class TrainLoraNode(io.ComfyNode):
positive = _validate_and_expand_conditioning(positive, num_images, bucket_mode) positive = _validate_and_expand_conditioning(positive, num_images, bucket_mode)
with torch.inference_mode(False): with torch.inference_mode(False):
# ComfyUI loads the model under inference mode,
# which makes every parameter an inference-mode tensor.
# For training to work correctly, we re-build those
# parameters and buffers as regular (training-mode) tensors.
for module in mp.model.modules():
for name, param in list(module._parameters.items()):
if param is not None:
try:
_ = param._version
except Exception:
module._parameters[name] = torch.nn.Parameter(
param.detach().clone(),
requires_grad=param.requires_grad,
)
for name, buf in list(module._buffers.items()):
if buf is not None:
try:
_ = buf._version
except Exception:
module._buffers[name] = buf.detach().clone()
# Setup models for training # Setup models for training
mp.model.requires_grad_(False) mp.model.requires_grad_(False).train()
# Load existing LoRA weights if provided # Load existing LoRA weights if provided
existing_weights, existing_steps = _load_existing_lora(existing_lora) existing_weights, existing_steps = _load_existing_lora(existing_lora)