force clone inside training mode to avoid inference tensor

This commit is contained in:
Kohaku-Blueleaf 2026-05-19 00:21:51 +08:00
parent da468125da
commit a359c5b654

View File

@ -1143,8 +1143,9 @@ class TrainLoraNode(io.ComfyNode):
# Process conditioning # Process conditioning
positive = _process_conditioning(positive) positive = _process_conditioning(positive)
with torch.inference_mode(False):
# Setup model and dtype # Setup model and dtype
mp = model.clone() mp = model.clone(force_deepcopy=True)
use_grad_scaler = False use_grad_scaler = False
lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype) lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
if training_dtype != "none": if training_dtype != "none":
@ -1179,29 +1180,6 @@ class TrainLoraNode(io.ComfyNode):
# Validate and expand conditioning # Validate and expand conditioning
positive = _validate_and_expand_conditioning(positive, num_images, bucket_mode) positive = _validate_and_expand_conditioning(positive, num_images, bucket_mode)
with torch.inference_mode(False):
# Now ComfyUI will load model in inference mode
# which make all parameter is now inference mode tensors
# to make the training correctly working
# we re-build the parameters in training mode
for module in mp.model.modules():
for name, param in list(module._parameters.items()):
if param is not None:
try:
_ = param._version
except Exception:
module._parameters[name] = torch.nn.Parameter(
param.detach().clone(),
requires_grad=param.requires_grad,
)
for name, buf in list(module._buffers.items()):
if buf is not None:
try:
_ = buf._version
except Exception:
module._buffers[name] = buf.detach().clone()
# Setup models for training # Setup models for training
mp.model.requires_grad_(False).train() mp.model.requires_grad_(False).train()