force clone inside training mode to avoid inference tensor

This commit is contained in:
Kohaku-Blueleaf 2026-05-19 00:21:51 +08:00
parent da468125da
commit a359c5b654

View File

@ -1143,64 +1143,42 @@ class TrainLoraNode(io.ComfyNode):
# Process conditioning
positive = _process_conditioning(positive)
# Setup model and dtype
mp = model.clone()
use_grad_scaler = False
lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
if training_dtype != "none":
dtype = node_helpers.string_to_torch_dtype(training_dtype)
mp.set_model_compute_dtype(dtype)
else:
# Detect model's native dtype for autocast
model_dtype = mp.model.get_dtype()
if model_dtype == torch.float16:
dtype = torch.float16
# GradScaler only supports float16 gradients, not bfloat16.
# Only enable it when lora params will also be in float16.
if lora_dtype != torch.bfloat16:
use_grad_scaler = True
# Warn about fp16 accumulation instability during training
if PerformanceFeature.Fp16Accumulation in args.fast:
logging.warning(
"WARNING: FP16 model detected with fp16_accumulation enabled. "
"This combination can be numerically unstable during training and may cause NaN values. "
"Suggested fixes: 1) Set training_dtype to 'bf16', or 2) Disable fp16_accumulation (remove from --fast flags)."
)
else:
# For fp8, bf16, or other dtypes, use bf16 autocast
dtype = torch.bfloat16
# Prepare latents and compute counts
latents_dtype = dtype if dtype not in (None,) else torch.bfloat16
latents, num_images, multi_res = _prepare_latents_and_count(
latents, latents_dtype, bucket_mode
)
# Validate and expand conditioning
positive = _validate_and_expand_conditioning(positive, num_images, bucket_mode)
with torch.inference_mode(False):
# Now ComfyUI will load model in inference mode
# which make all parameter is now inference mode tensors
# to make the training correctly working
# we re-build the parameters in training mode
for module in mp.model.modules():
for name, param in list(module._parameters.items()):
if param is not None:
try:
_ = param._version
except Exception:
module._parameters[name] = torch.nn.Parameter(
param.detach().clone(),
requires_grad=param.requires_grad,
)
# Setup model and dtype
mp = model.clone(force_deepcopy=True)
use_grad_scaler = False
lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
if training_dtype != "none":
dtype = node_helpers.string_to_torch_dtype(training_dtype)
mp.set_model_compute_dtype(dtype)
else:
# Detect model's native dtype for autocast
model_dtype = mp.model.get_dtype()
if model_dtype == torch.float16:
dtype = torch.float16
# GradScaler only supports float16 gradients, not bfloat16.
# Only enable it when lora params will also be in float16.
if lora_dtype != torch.bfloat16:
use_grad_scaler = True
# Warn about fp16 accumulation instability during training
if PerformanceFeature.Fp16Accumulation in args.fast:
logging.warning(
"WARNING: FP16 model detected with fp16_accumulation enabled. "
"This combination can be numerically unstable during training and may cause NaN values. "
"Suggested fixes: 1) Set training_dtype to 'bf16', or 2) Disable fp16_accumulation (remove from --fast flags)."
)
else:
# For fp8, bf16, or other dtypes, use bf16 autocast
dtype = torch.bfloat16
for name, buf in list(module._buffers.items()):
if buf is not None:
try:
_ = buf._version
except Exception:
module._buffers[name] = buf.detach().clone()
# Prepare latents and compute counts
latents_dtype = dtype if dtype not in (None,) else torch.bfloat16
latents, num_images, multi_res = _prepare_latents_and_count(
latents, latents_dtype, bucket_mode
)
# Validate and expand conditioning
positive = _validate_and_expand_conditioning(positive, num_images, bucket_mode)
# Setup models for training
mp.model.requires_grad_(False).train()