force clone inside training mode to avoid inference tensor

This commit is contained in:
Kohaku-Blueleaf 2026-05-19 00:21:51 +08:00
parent da468125da
commit a359c5b654

View File

@ -1143,64 +1143,42 @@ class TrainLoraNode(io.ComfyNode):
# Process conditioning # Process conditioning
positive = _process_conditioning(positive) positive = _process_conditioning(positive)
# Setup model and dtype
mp = model.clone()
use_grad_scaler = False
lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
if training_dtype != "none":
dtype = node_helpers.string_to_torch_dtype(training_dtype)
mp.set_model_compute_dtype(dtype)
else:
# Detect model's native dtype for autocast
model_dtype = mp.model.get_dtype()
if model_dtype == torch.float16:
dtype = torch.float16
# GradScaler only supports float16 gradients, not bfloat16.
# Only enable it when lora params will also be in float16.
if lora_dtype != torch.bfloat16:
use_grad_scaler = True
# Warn about fp16 accumulation instability during training
if PerformanceFeature.Fp16Accumulation in args.fast:
logging.warning(
"WARNING: FP16 model detected with fp16_accumulation enabled. "
"This combination can be numerically unstable during training and may cause NaN values. "
"Suggested fixes: 1) Set training_dtype to 'bf16', or 2) Disable fp16_accumulation (remove from --fast flags)."
)
else:
# For fp8, bf16, or other dtypes, use bf16 autocast
dtype = torch.bfloat16
# Prepare latents and compute counts
latents_dtype = dtype if dtype not in (None,) else torch.bfloat16
latents, num_images, multi_res = _prepare_latents_and_count(
latents, latents_dtype, bucket_mode
)
# Validate and expand conditioning
positive = _validate_and_expand_conditioning(positive, num_images, bucket_mode)
with torch.inference_mode(False): with torch.inference_mode(False):
# Now ComfyUI will load model in inference mode # Setup model and dtype
# which make all parameter is now inference mode tensors mp = model.clone(force_deepcopy=True)
# to make the training correctly working use_grad_scaler = False
# we re-build the parameters in training mode lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
for module in mp.model.modules(): if training_dtype != "none":
for name, param in list(module._parameters.items()): dtype = node_helpers.string_to_torch_dtype(training_dtype)
if param is not None: mp.set_model_compute_dtype(dtype)
try: else:
_ = param._version # Detect model's native dtype for autocast
except Exception: model_dtype = mp.model.get_dtype()
module._parameters[name] = torch.nn.Parameter( if model_dtype == torch.float16:
param.detach().clone(), dtype = torch.float16
requires_grad=param.requires_grad, # GradScaler only supports float16 gradients, not bfloat16.
) # Only enable it when lora params will also be in float16.
if lora_dtype != torch.bfloat16:
use_grad_scaler = True
# Warn about fp16 accumulation instability during training
if PerformanceFeature.Fp16Accumulation in args.fast:
logging.warning(
"WARNING: FP16 model detected with fp16_accumulation enabled. "
"This combination can be numerically unstable during training and may cause NaN values. "
"Suggested fixes: 1) Set training_dtype to 'bf16', or 2) Disable fp16_accumulation (remove from --fast flags)."
)
else:
# For fp8, bf16, or other dtypes, use bf16 autocast
dtype = torch.bfloat16
for name, buf in list(module._buffers.items()): # Prepare latents and compute counts
if buf is not None: latents_dtype = dtype if dtype not in (None,) else torch.bfloat16
try: latents, num_images, multi_res = _prepare_latents_and_count(
_ = buf._version latents, latents_dtype, bucket_mode
except Exception: )
module._buffers[name] = buf.detach().clone()
# Validate and expand conditioning
positive = _validate_and_expand_conditioning(positive, num_images, bucket_mode)
# Setup models for training # Setup models for training
mp.model.requires_grad_(False).train() mp.model.requires_grad_(False).train()