mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-12-17 01:52:59 +08:00
Fix LoRA Trainer bugs with FP8 models. (#9854)
* Fix adapter weight init
* Fix fp8 model training
* Avoid inference tensor
This commit is contained in:
parent
9ed3c5cc09
commit
7be2b49b6b
comfy/ops.py
@@ -365,12 +365,13 @@ class fp8_ops(manual_cast):
             return None
 
         def forward_comfy_cast_weights(self, input):
-            try:
-                out = fp8_linear(self, input)
-                if out is not None:
-                    return out
-            except Exception as e:
-                logging.info("Exception during fp8 op: {}".format(e))
+            if not self.training:
+                try:
+                    out = fp8_linear(self, input)
+                    if out is not None:
+                        return out
+                except Exception as e:
+                    logging.info("Exception during fp8 op: {}".format(e))
 
             weight, bias = cast_bias_weight(self, input)
             return torch.nn.functional.linear(input, weight, bias)
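Note on the hunk above: while the module is in training mode, the fused fp8 path is skipped so the forward pass always runs a standard, differentiable linear on the cast weights. Below is a minimal sketch of that dispatch pattern, assuming a hypothetical SketchLinear class with a stand-in for fp8_linear; it is not ComfyUI's actual fp8_ops code.

import torch

class SketchLinear(torch.nn.Linear):
    def _fused_fp8(self, x):
        # Stand-in for fp8_linear(): assume it may return None when unsupported
        # and that its output does not participate in autograd.
        with torch.no_grad():
            return torch.nn.functional.linear(x, self.weight, self.bias)

    def forward(self, x):
        if not self.training:              # eval: the fast path is acceptable
            out = self._fused_fp8(x)
            if out is not None:
                return out
        # Training (or fast path unavailable): plain linear, so gradients
        # flow to the (LoRA-patched) weights.
        return torch.nn.functional.linear(x, self.weight, self.bias)

m = SketchLinear(8, 4)
m.train()
loss = m(torch.randn(2, 8)).sum()
loss.backward()                            # gradients reach m.weight on the train path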
@@ -130,12 +130,12 @@ class LoHaAdapter(WeightAdapterBase):
     def create_train(cls, weight, rank=1, alpha=1.0):
         out_dim = weight.shape[0]
         in_dim = weight.shape[1:].numel()
-        mat1 = torch.empty(out_dim, rank, device=weight.device, dtype=weight.dtype)
-        mat2 = torch.empty(rank, in_dim, device=weight.device, dtype=weight.dtype)
+        mat1 = torch.empty(out_dim, rank, device=weight.device, dtype=torch.float32)
+        mat2 = torch.empty(rank, in_dim, device=weight.device, dtype=torch.float32)
         torch.nn.init.normal_(mat1, 0.1)
         torch.nn.init.constant_(mat2, 0.0)
-        mat3 = torch.empty(out_dim, rank, device=weight.device, dtype=weight.dtype)
-        mat4 = torch.empty(rank, in_dim, device=weight.device, dtype=weight.dtype)
+        mat3 = torch.empty(out_dim, rank, device=weight.device, dtype=torch.float32)
+        mat4 = torch.empty(rank, in_dim, device=weight.device, dtype=torch.float32)
         torch.nn.init.normal_(mat3, 0.1)
         torch.nn.init.normal_(mat4, 0.01)
         return LohaDiff(
@@ -89,8 +89,8 @@ class LoKrAdapter(WeightAdapterBase):
         in_dim = weight.shape[1:].numel()
         out1, out2 = factorization(out_dim, rank)
         in1, in2 = factorization(in_dim, rank)
-        mat1 = torch.empty(out1, in1, device=weight.device, dtype=weight.dtype)
-        mat2 = torch.empty(out2, in2, device=weight.device, dtype=weight.dtype)
+        mat1 = torch.empty(out1, in1, device=weight.device, dtype=torch.float32)
+        mat2 = torch.empty(out2, in2, device=weight.device, dtype=torch.float32)
         torch.nn.init.kaiming_uniform_(mat2, a=5**0.5)
         torch.nn.init.constant_(mat1, 0.0)
         return LokrDiff(
@@ -66,8 +66,8 @@ class LoRAAdapter(WeightAdapterBase):
     def create_train(cls, weight, rank=1, alpha=1.0):
         out_dim = weight.shape[0]
         in_dim = weight.shape[1:].numel()
-        mat1 = torch.empty(out_dim, rank, device=weight.device, dtype=weight.dtype)
-        mat2 = torch.empty(rank, in_dim, device=weight.device, dtype=weight.dtype)
+        mat1 = torch.empty(out_dim, rank, device=weight.device, dtype=torch.float32)
+        mat2 = torch.empty(rank, in_dim, device=weight.device, dtype=torch.float32)
         torch.nn.init.kaiming_uniform_(mat1, a=5**0.5)
         torch.nn.init.constant_(mat2, 0.0)
         return LoraDiff(
@@ -68,7 +68,7 @@ class OFTAdapter(WeightAdapterBase):
     def create_train(cls, weight, rank=1, alpha=1.0):
         out_dim = weight.shape[0]
         block_size, block_num = factorization(out_dim, rank)
-        block = torch.zeros(block_num, block_size, block_size, device=weight.device, dtype=weight.dtype)
+        block = torch.zeros(block_num, block_size, block_size, device=weight.device, dtype=torch.float32)
         return OFTDiff(
             (block, None, alpha, None)
         )
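The four adapter hunks above all make the same change: the newly created trainable matrices are allocated in torch.float32 instead of weight.dtype. A hedged sketch of the idea follows, using a hypothetical make_lora_pair helper rather than the adapter classes themselves: when the base weight is stored in fp8 (or another low-precision format), in-place RNG initializers are typically not implemented for that dtype and optimizers generally expect full-precision parameters, so the adapter weights are created in float32 regardless of the base dtype.

import torch

def make_lora_pair(weight, rank=1):
    out_dim = weight.shape[0]
    in_dim = weight.shape[1:].numel()
    # float32, not weight.dtype, so kaiming init and the optimizer work even
    # when `weight` itself is an fp8 or bf16 tensor.
    mat1 = torch.empty(out_dim, rank, device=weight.device, dtype=torch.float32)
    mat2 = torch.empty(rank, in_dim, device=weight.device, dtype=torch.float32)
    torch.nn.init.kaiming_uniform_(mat1, a=5**0.5)
    torch.nn.init.constant_(mat2, 0.0)
    return mat1, mat2

base = torch.zeros(16, 32, dtype=torch.bfloat16)   # stand-in for a model weight
a, b = make_lora_pair(base, rank=4)
assert a.dtype == torch.float32 and b.dtype == torch.float32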
@@ -38,6 +38,23 @@ def make_batch_extra_option_dict(d, indicies, full_size=None):
     return new_dict
 
 
+def process_cond_list(d, prefix=""):
+    if hasattr(d, "__iter__") and not hasattr(d, "items"):
+        for index, item in enumerate(d):
+            process_cond_list(item, f"{prefix}.{index}")
+        return d
+    elif hasattr(d, "items"):
+        for k, v in list(d.items()):
+            if isinstance(v, dict):
+                process_cond_list(v, f"{prefix}.{k}")
+            elif isinstance(v, torch.Tensor):
+                d[k] = v.clone()
+            elif isinstance(v, (list, tuple)):
+                for index, item in enumerate(v):
+                    process_cond_list(item, f"{prefix}.{k}.{index}")
+        return d
+
+
 class TrainSampler(comfy.samplers.Sampler):
     def __init__(self, loss_fn, optimizer, loss_callback=None, batch_size=1, grad_acc=1, total_steps=1, seed=0, training_dtype=torch.bfloat16):
         self.loss_fn = loss_fn
@@ -50,6 +67,7 @@ class TrainSampler(comfy.samplers.Sampler):
         self.training_dtype = training_dtype
 
     def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
+        model_wrap.conds = process_cond_list(model_wrap.conds)
         cond = model_wrap.conds["positive"]
         dataset_size = sigmas.size(0)
         torch.cuda.empty_cache()
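The process_cond_list helper added above clones every conditioning tensor before training starts, which is what the "Avoid inference tensor" bullet in the commit message refers to: tensors captured under torch.inference_mode() cannot later participate in autograd, while a clone made outside inference mode behaves like a normal tensor. A small standalone illustration (hypothetical shapes, not ComfyUI code):

import torch

with torch.inference_mode():
    cond = torch.randn(1, 77, 768)       # e.g. a cached conditioning tensor

usable = cond.clone()                     # ordinary tensor, safe for training

w = torch.randn(768, 768, requires_grad=True)
out = (usable @ w).sum()
out.backward()                            # works; using `cond` directly here would raise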