This commit is contained in:
Kohaku-Blueleaf 2026-05-20 11:05:14 +08:00 committed by GitHub
commit 84d12071a4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1143,45 +1143,45 @@ class TrainLoraNode(io.ComfyNode):
# Process conditioning
positive = _process_conditioning(positive)
# Setup model and dtype
mp = model.clone()
use_grad_scaler = False
lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
if training_dtype != "none":
dtype = node_helpers.string_to_torch_dtype(training_dtype)
mp.set_model_compute_dtype(dtype)
else:
# Detect model's native dtype for autocast
model_dtype = mp.model.get_dtype()
if model_dtype == torch.float16:
dtype = torch.float16
# GradScaler only supports float16 gradients, not bfloat16.
# Only enable it when lora params will also be in float16.
if lora_dtype != torch.bfloat16:
use_grad_scaler = True
# Warn about fp16 accumulation instability during training
if PerformanceFeature.Fp16Accumulation in args.fast:
logging.warning(
"WARNING: FP16 model detected with fp16_accumulation enabled. "
"This combination can be numerically unstable during training and may cause NaN values. "
"Suggested fixes: 1) Set training_dtype to 'bf16', or 2) Disable fp16_accumulation (remove from --fast flags)."
)
else:
# For fp8, bf16, or other dtypes, use bf16 autocast
dtype = torch.bfloat16
# Prepare latents and compute counts
latents_dtype = dtype if dtype not in (None,) else torch.bfloat16
latents, num_images, multi_res = _prepare_latents_and_count(
latents, latents_dtype, bucket_mode
)
# Validate and expand conditioning
positive = _validate_and_expand_conditioning(positive, num_images, bucket_mode)
with torch.inference_mode(False):
# Setup model and dtype
mp = model.clone(force_deepcopy=True)
use_grad_scaler = False
lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
if training_dtype != "none":
dtype = node_helpers.string_to_torch_dtype(training_dtype)
mp.set_model_compute_dtype(dtype)
else:
# Detect model's native dtype for autocast
model_dtype = mp.model.get_dtype()
if model_dtype == torch.float16:
dtype = torch.float16
# GradScaler only supports float16 gradients, not bfloat16.
# Only enable it when lora params will also be in float16.
if lora_dtype != torch.bfloat16:
use_grad_scaler = True
# Warn about fp16 accumulation instability during training
if PerformanceFeature.Fp16Accumulation in args.fast:
logging.warning(
"WARNING: FP16 model detected with fp16_accumulation enabled. "
"This combination can be numerically unstable during training and may cause NaN values. "
"Suggested fixes: 1) Set training_dtype to 'bf16', or 2) Disable fp16_accumulation (remove from --fast flags)."
)
else:
# For fp8, bf16, or other dtypes, use bf16 autocast
dtype = torch.bfloat16
# Prepare latents and compute counts
latents_dtype = dtype if dtype not in (None,) else torch.bfloat16
latents, num_images, multi_res = _prepare_latents_and_count(
latents, latents_dtype, bucket_mode
)
# Validate and expand conditioning
positive = _validate_and_expand_conditioning(positive, num_images, bucket_mode)
# Setup models for training
mp.model.requires_grad_(False)
mp.model.requires_grad_(False).train()
# Load existing LoRA weights if provided
existing_weights, existing_steps = _load_existing_lora(existing_lora)