mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-19 22:39:24 +08:00
force clone inside training mode to avoid inference tensor
This commit is contained in:
parent
da468125da
commit
a359c5b654
@ -1143,64 +1143,42 @@ class TrainLoraNode(io.ComfyNode):
|
|||||||
# Process conditioning
|
# Process conditioning
|
||||||
positive = _process_conditioning(positive)
|
positive = _process_conditioning(positive)
|
||||||
|
|
||||||
# Setup model and dtype
|
|
||||||
mp = model.clone()
|
|
||||||
use_grad_scaler = False
|
|
||||||
lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
|
|
||||||
if training_dtype != "none":
|
|
||||||
dtype = node_helpers.string_to_torch_dtype(training_dtype)
|
|
||||||
mp.set_model_compute_dtype(dtype)
|
|
||||||
else:
|
|
||||||
# Detect model's native dtype for autocast
|
|
||||||
model_dtype = mp.model.get_dtype()
|
|
||||||
if model_dtype == torch.float16:
|
|
||||||
dtype = torch.float16
|
|
||||||
# GradScaler only supports float16 gradients, not bfloat16.
|
|
||||||
# Only enable it when lora params will also be in float16.
|
|
||||||
if lora_dtype != torch.bfloat16:
|
|
||||||
use_grad_scaler = True
|
|
||||||
# Warn about fp16 accumulation instability during training
|
|
||||||
if PerformanceFeature.Fp16Accumulation in args.fast:
|
|
||||||
logging.warning(
|
|
||||||
"WARNING: FP16 model detected with fp16_accumulation enabled. "
|
|
||||||
"This combination can be numerically unstable during training and may cause NaN values. "
|
|
||||||
"Suggested fixes: 1) Set training_dtype to 'bf16', or 2) Disable fp16_accumulation (remove from --fast flags)."
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# For fp8, bf16, or other dtypes, use bf16 autocast
|
|
||||||
dtype = torch.bfloat16
|
|
||||||
|
|
||||||
# Prepare latents and compute counts
|
|
||||||
latents_dtype = dtype if dtype not in (None,) else torch.bfloat16
|
|
||||||
latents, num_images, multi_res = _prepare_latents_and_count(
|
|
||||||
latents, latents_dtype, bucket_mode
|
|
||||||
)
|
|
||||||
|
|
||||||
# Validate and expand conditioning
|
|
||||||
positive = _validate_and_expand_conditioning(positive, num_images, bucket_mode)
|
|
||||||
|
|
||||||
with torch.inference_mode(False):
|
with torch.inference_mode(False):
|
||||||
# ComfyUI loads the model in inference mode,
|
# Setup model and dtype
|
||||||
# which makes all of its parameters inference-mode tensors.
|
mp = model.clone(force_deepcopy=True)
|
||||||
# For training to work correctly,
|
use_grad_scaler = False
|
||||||
# we re-build the parameters in training mode
|
lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
|
||||||
for module in mp.model.modules():
|
if training_dtype != "none":
|
||||||
for name, param in list(module._parameters.items()):
|
dtype = node_helpers.string_to_torch_dtype(training_dtype)
|
||||||
if param is not None:
|
mp.set_model_compute_dtype(dtype)
|
||||||
try:
|
else:
|
||||||
_ = param._version
|
# Detect model's native dtype for autocast
|
||||||
except Exception:
|
model_dtype = mp.model.get_dtype()
|
||||||
module._parameters[name] = torch.nn.Parameter(
|
if model_dtype == torch.float16:
|
||||||
param.detach().clone(),
|
dtype = torch.float16
|
||||||
requires_grad=param.requires_grad,
|
# GradScaler only supports float16 gradients, not bfloat16.
|
||||||
)
|
# Only enable it when lora params will also be in float16.
|
||||||
|
if lora_dtype != torch.bfloat16:
|
||||||
|
use_grad_scaler = True
|
||||||
|
# Warn about fp16 accumulation instability during training
|
||||||
|
if PerformanceFeature.Fp16Accumulation in args.fast:
|
||||||
|
logging.warning(
|
||||||
|
"WARNING: FP16 model detected with fp16_accumulation enabled. "
|
||||||
|
"This combination can be numerically unstable during training and may cause NaN values. "
|
||||||
|
"Suggested fixes: 1) Set training_dtype to 'bf16', or 2) Disable fp16_accumulation (remove from --fast flags)."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# For fp8, bf16, or other dtypes, use bf16 autocast
|
||||||
|
dtype = torch.bfloat16
|
||||||
|
|
||||||
for name, buf in list(module._buffers.items()):
|
# Prepare latents and compute counts
|
||||||
if buf is not None:
|
latents_dtype = dtype if dtype not in (None,) else torch.bfloat16
|
||||||
try:
|
latents, num_images, multi_res = _prepare_latents_and_count(
|
||||||
_ = buf._version
|
latents, latents_dtype, bucket_mode
|
||||||
except Exception:
|
)
|
||||||
module._buffers[name] = buf.detach().clone()
|
|
||||||
|
# Validate and expand conditioning
|
||||||
|
positive = _validate_and_expand_conditioning(positive, num_images, bucket_mode)
|
||||||
|
|
||||||
# Setup models for training
|
# Setup models for training
|
||||||
mp.model.requires_grad_(False).train()
|
mp.model.requires_grad_(False).train()
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user