mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-17 13:30:16 +08:00
Merge a359c5b654 into 7ec7b6ffe9
This commit is contained in:
commit
84d12071a4
@ -1143,45 +1143,45 @@ class TrainLoraNode(io.ComfyNode):
|
||||
# Process conditioning
|
||||
positive = _process_conditioning(positive)
|
||||
|
||||
# Setup model and dtype
|
||||
mp = model.clone()
|
||||
use_grad_scaler = False
|
||||
lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
|
||||
if training_dtype != "none":
|
||||
dtype = node_helpers.string_to_torch_dtype(training_dtype)
|
||||
mp.set_model_compute_dtype(dtype)
|
||||
else:
|
||||
# Detect model's native dtype for autocast
|
||||
model_dtype = mp.model.get_dtype()
|
||||
if model_dtype == torch.float16:
|
||||
dtype = torch.float16
|
||||
# GradScaler only supports float16 gradients, not bfloat16.
|
||||
# Only enable it when lora params will also be in float16.
|
||||
if lora_dtype != torch.bfloat16:
|
||||
use_grad_scaler = True
|
||||
# Warn about fp16 accumulation instability during training
|
||||
if PerformanceFeature.Fp16Accumulation in args.fast:
|
||||
logging.warning(
|
||||
"WARNING: FP16 model detected with fp16_accumulation enabled. "
|
||||
"This combination can be numerically unstable during training and may cause NaN values. "
|
||||
"Suggested fixes: 1) Set training_dtype to 'bf16', or 2) Disable fp16_accumulation (remove from --fast flags)."
|
||||
)
|
||||
else:
|
||||
# For fp8, bf16, or other dtypes, use bf16 autocast
|
||||
dtype = torch.bfloat16
|
||||
|
||||
# Prepare latents and compute counts
|
||||
latents_dtype = dtype if dtype not in (None,) else torch.bfloat16
|
||||
latents, num_images, multi_res = _prepare_latents_and_count(
|
||||
latents, latents_dtype, bucket_mode
|
||||
)
|
||||
|
||||
# Validate and expand conditioning
|
||||
positive = _validate_and_expand_conditioning(positive, num_images, bucket_mode)
|
||||
|
||||
with torch.inference_mode(False):
|
||||
# Setup model and dtype
|
||||
mp = model.clone(force_deepcopy=True)
|
||||
use_grad_scaler = False
|
||||
lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
|
||||
if training_dtype != "none":
|
||||
dtype = node_helpers.string_to_torch_dtype(training_dtype)
|
||||
mp.set_model_compute_dtype(dtype)
|
||||
else:
|
||||
# Detect model's native dtype for autocast
|
||||
model_dtype = mp.model.get_dtype()
|
||||
if model_dtype == torch.float16:
|
||||
dtype = torch.float16
|
||||
# GradScaler only supports float16 gradients, not bfloat16.
|
||||
# Only enable it when lora params will also be in float16.
|
||||
if lora_dtype != torch.bfloat16:
|
||||
use_grad_scaler = True
|
||||
# Warn about fp16 accumulation instability during training
|
||||
if PerformanceFeature.Fp16Accumulation in args.fast:
|
||||
logging.warning(
|
||||
"WARNING: FP16 model detected with fp16_accumulation enabled. "
|
||||
"This combination can be numerically unstable during training and may cause NaN values. "
|
||||
"Suggested fixes: 1) Set training_dtype to 'bf16', or 2) Disable fp16_accumulation (remove from --fast flags)."
|
||||
)
|
||||
else:
|
||||
# For fp8, bf16, or other dtypes, use bf16 autocast
|
||||
dtype = torch.bfloat16
|
||||
|
||||
# Prepare latents and compute counts
|
||||
latents_dtype = dtype if dtype not in (None,) else torch.bfloat16
|
||||
latents, num_images, multi_res = _prepare_latents_and_count(
|
||||
latents, latents_dtype, bucket_mode
|
||||
)
|
||||
|
||||
# Validate and expand conditioning
|
||||
positive = _validate_and_expand_conditioning(positive, num_images, bucket_mode)
|
||||
|
||||
# Setup models for training
|
||||
mp.model.requires_grad_(False)
|
||||
mp.model.requires_grad_(False).train()
|
||||
|
||||
# Load existing LoRA weights if provided
|
||||
existing_weights, existing_steps = _load_existing_lora(existing_lora)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user