mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-21 15:29:32 +08:00
force clone inside training mode to avoid inference tensor
This commit is contained in:
parent
da468125da
commit
a359c5b654
@ -1143,8 +1143,9 @@ class TrainLoraNode(io.ComfyNode):
|
|||||||
# Process conditioning
|
# Process conditioning
|
||||||
positive = _process_conditioning(positive)
|
positive = _process_conditioning(positive)
|
||||||
|
|
||||||
|
with torch.inference_mode(False):
|
||||||
# Setup model and dtype
|
# Setup model and dtype
|
||||||
mp = model.clone()
|
mp = model.clone(force_deepcopy=True)
|
||||||
use_grad_scaler = False
|
use_grad_scaler = False
|
||||||
lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
|
lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
|
||||||
if training_dtype != "none":
|
if training_dtype != "none":
|
||||||
@ -1179,29 +1180,6 @@ class TrainLoraNode(io.ComfyNode):
|
|||||||
# Validate and expand conditioning
|
# Validate and expand conditioning
|
||||||
positive = _validate_and_expand_conditioning(positive, num_images, bucket_mode)
|
positive = _validate_and_expand_conditioning(positive, num_images, bucket_mode)
|
||||||
|
|
||||||
with torch.inference_mode(False):
|
|
||||||
# Now ComfyUI will load model in inference mode
|
|
||||||
# which make all parameter is now inference mode tensors
|
|
||||||
# to make the training correctly working
|
|
||||||
# we re-build the parameters in training mode
|
|
||||||
for module in mp.model.modules():
|
|
||||||
for name, param in list(module._parameters.items()):
|
|
||||||
if param is not None:
|
|
||||||
try:
|
|
||||||
_ = param._version
|
|
||||||
except Exception:
|
|
||||||
module._parameters[name] = torch.nn.Parameter(
|
|
||||||
param.detach().clone(),
|
|
||||||
requires_grad=param.requires_grad,
|
|
||||||
)
|
|
||||||
|
|
||||||
for name, buf in list(module._buffers.items()):
|
|
||||||
if buf is not None:
|
|
||||||
try:
|
|
||||||
_ = buf._version
|
|
||||||
except Exception:
|
|
||||||
module._buffers[name] = buf.detach().clone()
|
|
||||||
|
|
||||||
# Setup models for training
|
# Setup models for training
|
||||||
mp.model.requires_grad_(False).train()
|
mp.model.requires_grad_(False).train()
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user