Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-01-30 00:00:26 +08:00)

Commit 30539d0d13: Merge branch 'comfyanonymous:master' into master
@@ -1210,39 +1210,21 @@ def sample_deis(model, x, sigmas, extra_args=None, callback=None, disable=None,
     return x_next


-@torch.no_grad()
-def sample_euler_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None):
-    extra_args = {} if extra_args is None else extra_args
-
-    temp = [0]
-    def post_cfg_function(args):
-        temp[0] = args["uncond_denoised"]
-        return args["denoised"]
-
-    model_options = extra_args.get("model_options", {}).copy()
-    extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
-
-    s_in = x.new_ones([x.shape[0]])
-    for i in trange(len(sigmas) - 1, disable=disable):
-        sigma_hat = sigmas[i]
-        denoised = model(x, sigma_hat * s_in, **extra_args)
-        d = to_d(x, sigma_hat, temp[0])
-        if callback is not None:
-            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
-        # Euler method
-        x = denoised + d * sigmas[i + 1]
-    return x
-
 @torch.no_grad()
 def sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
-    """Ancestral sampling with Euler method steps."""
+    """Ancestral sampling with Euler method steps (CFG++)."""
     extra_args = {} if extra_args is None else extra_args
     seed = extra_args.get("seed", None)
     noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
-
-    temp = [0]
+    model_sampling = model.inner_model.model_patcher.get_model_object("model_sampling")
+    lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
+
+    uncond_denoised = None
+
     def post_cfg_function(args):
-        temp[0] = args["uncond_denoised"]
+        nonlocal uncond_denoised
+        uncond_denoised = args["uncond_denoised"]
         return args["denoised"]

     model_options = extra_args.get("model_options", {}).copy()

@@ -1251,15 +1233,33 @@ def sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=No
     s_in = x.new_ones([x.shape[0]])
     for i in trange(len(sigmas) - 1, disable=disable):
         denoised = model(x, sigmas[i] * s_in, **extra_args)
-        sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
         if callback is not None:
             callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
-        d = to_d(x, sigmas[i], temp[0])
-        # Euler method
-        x = denoised + d * sigma_down
-        if sigmas[i + 1] > 0:
-            x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
+        if sigmas[i + 1] == 0:
+            # Denoising step
+            x = denoised
+        else:
+            alpha_s = sigmas[i] * lambda_fn(sigmas[i]).exp()
+            alpha_t = sigmas[i + 1] * lambda_fn(sigmas[i + 1]).exp()
+            d = to_d(x, sigmas[i], alpha_s * uncond_denoised) # to noise
+
+            # DDIM stochastic sampling
+            sigma_down, sigma_up = get_ancestral_step(sigmas[i] / alpha_s, sigmas[i + 1] / alpha_t, eta=eta)
+            sigma_down = alpha_t * sigma_down
+
+            # Euler method
+            x = alpha_t * denoised + sigma_down * d
+            if eta > 0 and s_noise > 0:
+                x = x + alpha_t * noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
     return x
+
+
+@torch.no_grad()
+def sample_euler_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None):
+    """Euler method steps (CFG++)."""
+    return sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=0.0, s_noise=0.0, noise_sampler=None)
+
+
 @torch.no_grad()
 def sample_dpmpp_2s_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
     """Ancestral sampling with DPM-Solver++(2S) second-order steps."""
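Note: the rewritten sampler converts sigmas to alphas through the half-log-SNR so the CFG++ update also holds for schedules where alpha != 1. A minimal sketch of a single step under that reading (lambda(sigma) = log(alpha/sigma), matching the lambda_fn in the hunk above; the helper and variable names here are illustrative, not part of the commit):

    import torch

    def get_ancestral_step(sigma_from, sigma_to, eta=1.0):
        # same formula as the k-diffusion helper referenced in the hunk
        if not eta:
            return sigma_to, 0.0
        sigma_up = min(sigma_to, eta * (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5)
        sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5
        return sigma_down, sigma_up

    def euler_ancestral_cfg_pp_step(x, sigma, sigma_next, denoised, uncond_denoised, half_log_snr, noise, eta=1.0, s_noise=1.0):
        alpha_s = sigma * half_log_snr(sigma).exp()            # alpha = sigma * exp(log(alpha / sigma))
        alpha_t = sigma_next * half_log_snr(sigma_next).exp()
        d = (x - alpha_s * uncond_denoised) / sigma            # to_d(): noise direction built from the uncond prediction (CFG++)
        sigma_down, sigma_up = get_ancestral_step(sigma / alpha_s, sigma_next / alpha_t, eta=eta)
        x = alpha_t * denoised + alpha_t * sigma_down * d      # deterministic DDIM-style part
        if eta > 0 and s_noise > 0:
            x = x + alpha_t * noise * s_noise * sigma_up       # ancestral noise injection
        return x

With eta=0 and s_noise=0 this collapses to a plain Euler CFG++ step, which is exactly how the new sample_euler_cfg_pp wrapper reuses the ancestral sampler.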
@@ -187,10 +187,7 @@ def get_total_memory(dev=None, torch_total_too=False):
     elif is_intel_xpu():
         stats = torch.xpu.memory_stats(dev)
         mem_reserved = stats['reserved_bytes.all.current']
-        if torch_version_numeric < (2, 6):
-            mem_total_xpu = torch.xpu.get_device_properties(dev).total_memory
-        else:
-            _, mem_total_xpu = torch.xpu.mem_get_info(dev)
+        mem_total_xpu = torch.xpu.get_device_properties(dev).total_memory
         mem_total_torch = mem_reserved
         mem_total = mem_total_xpu
     elif is_ascend_npu():

@@ -312,7 +309,10 @@ try:
         logging.info("ROCm version: {}".format(rocm_version))
         if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
             if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much
-                if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx1201 and gfx950
+                if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950
+                    ENABLE_PYTORCH_ATTENTION = True
+                if torch_version_numeric >= (2, 8):
+                    if any((a in arch) for a in ["gfx1201"]):
                         ENABLE_PYTORCH_ATTENTION = True
             if torch_version_numeric >= (2, 7) and rocm_version >= (6, 4):
                 if any((a in arch) for a in ["gfx1201", "gfx942", "gfx950"]): # TODO: more arches

@@ -1111,10 +1111,7 @@ def get_free_memory(dev=None, torch_free_too=False):
         stats = torch.xpu.memory_stats(dev)
         mem_active = stats['active_bytes.all.current']
         mem_reserved = stats['reserved_bytes.all.current']
-        if torch_version_numeric < (2, 6):
-            mem_free_xpu = torch.xpu.get_device_properties(dev).total_memory - mem_reserved
-        else:
-            mem_free_xpu, _ = torch.xpu.mem_get_info(dev)
+        mem_free_xpu = torch.xpu.get_device_properties(dev).total_memory - mem_reserved
         mem_free_torch = mem_reserved - mem_active
         mem_free_total = mem_free_xpu + mem_free_torch
     elif is_ascend_npu():
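Note: the XPU branches now read the device total straight from get_device_properties() again and keep reporting torch's own footprint from memory_stats(). A minimal sketch of that query pattern (the function name is illustrative):

    import torch

    def xpu_memory_totals(dev):
        stats = torch.xpu.memory_stats(dev)
        mem_reserved = stats['reserved_bytes.all.current']               # bytes torch has reserved
        mem_total = torch.xpu.get_device_properties(dev).total_memory    # whole device capacity
        return mem_total, mem_reserved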
@@ -15,9 +15,20 @@ adapters: list[type[WeightAdapterBase]] = [
     OFTAdapter,
     BOFTAdapter,
 ]
+adapter_maps: dict[str, type[WeightAdapterBase]] = {
+    "LoRA": LoRAAdapter,
+    "LoHa": LoHaAdapter,
+    "LoKr": LoKrAdapter,
+    "OFT": OFTAdapter,
+    ## We disable not implemented algo for now
+    # "GLoRA": GLoRAAdapter,
+    # "BOFT": BOFTAdapter,
+}
+

 __all__ = [
     "WeightAdapterBase",
     "WeightAdapterTrainBase",
-    "adapters"
+    "adapters",
+    "adapter_maps",
 ] + [a.__name__ for a in adapters]
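Note: adapter_maps lets callers pick a weight-adapter class by name instead of hard-coding adapters[0]; the TrainLoraNode hunks further down use it exactly this way. A minimal usage sketch (the tensor shapes and rank are made up for illustration):

    import torch
    from comfy.weight_adapter import adapter_maps

    weight = torch.randn(128, 64)            # a weight tensor to train an adapter for
    adapter_cls = adapter_maps["LoKr"]       # e.g. the node's "algorithm" input
    train_adapter = adapter_cls.create_train(weight, rank=8, alpha=8.0)
    patched = train_adapter(weight)          # the train adapters are callables that return the patched weight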
@@ -133,3 +133,43 @@ def tucker_weight_from_conv(up, down, mid):
 def tucker_weight(wa, wb, t):
     temp = torch.einsum("i j ..., j r -> i r ...", t, wb)
     return torch.einsum("i j ..., i r -> r j ...", temp, wa)
+
+
+def factorization(dimension: int, factor: int = -1) -> tuple[int, int]:
+    """
+    return a tuple of two value of input dimension decomposed by the number closest to factor
+    second value is higher or equal than first value.
+
+    examples)
+    factor
+        -1               2               4               8               16 ...
+    127 -> 1, 127   127 -> 1, 127   127 -> 1, 127   127 -> 1, 127   127 -> 1, 127
+    128 -> 8, 16    128 -> 2, 64    128 -> 4, 32    128 -> 8, 16    128 -> 8, 16
+    250 -> 10, 25   250 -> 2, 125   250 -> 2, 125   250 -> 5, 50    250 -> 10, 25
+    360 -> 8, 45    360 -> 2, 180   360 -> 4, 90    360 -> 8, 45    360 -> 12, 30
+    512 -> 16, 32   512 -> 2, 256   512 -> 4, 128   512 -> 8, 64    512 -> 16, 32
+    1024 -> 32, 32  1024 -> 2, 512  1024 -> 4, 256  1024 -> 8, 128  1024 -> 16, 64
+    """
+
+    if factor > 0 and (dimension % factor) == 0 and dimension >= factor**2:
+        m = factor
+        n = dimension // factor
+        if m > n:
+            n, m = m, n
+        return m, n
+    if factor < 0:
+        factor = dimension
+    m, n = 1, dimension
+    length = m + n
+    while m < n:
+        new_m = m + 1
+        while dimension % new_m != 0:
+            new_m += 1
+        new_n = dimension // new_m
+        if new_m + new_n > length or new_m > factor:
+            break
+        else:
+            m, n = new_m, new_n
+    if m > n:
+        n, m = m, n
+    return m, n
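Note: factorization() is what LoKr and OFT use below to split a dimension into two factors near the requested size, and the docstring table doubles as a quick check. A small sanity sketch, assuming the function is imported from the same base module the later hunks import from:

    from comfy.weight_adapter.base import factorization

    assert factorization(128, 8) == (8, 16)
    assert factorization(127, 8) == (1, 127)    # a prime can only split as 1 x n
    assert factorization(360, 8) == (8, 45)
    assert factorization(1024, -1) == (32, 32)  # factor=-1 searches for the most balanced split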
@@ -3,7 +3,120 @@ from typing import Optional

 import torch
 import comfy.model_management
-from .base import WeightAdapterBase, weight_decompose
+from .base import WeightAdapterBase, WeightAdapterTrainBase, weight_decompose


+class HadaWeight(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, w1u, w1d, w2u, w2d, scale=torch.tensor(1)):
+        ctx.save_for_backward(w1d, w1u, w2d, w2u, scale)
+        diff_weight = ((w1u @ w1d) * (w2u @ w2d)) * scale
+        return diff_weight
+
+    @staticmethod
+    def backward(ctx, grad_out):
+        (w1d, w1u, w2d, w2u, scale) = ctx.saved_tensors
+        grad_out = grad_out * scale
+        temp = grad_out * (w2u @ w2d)
+        grad_w1u = temp @ w1d.T
+        grad_w1d = w1u.T @ temp
+
+        temp = grad_out * (w1u @ w1d)
+        grad_w2u = temp @ w2d.T
+        grad_w2d = w2u.T @ temp
+
+        del temp
+        return grad_w1u, grad_w1d, grad_w2u, grad_w2d, None
+
+
+class HadaWeightTucker(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, t1, w1u, w1d, t2, w2u, w2d, scale=torch.tensor(1)):
+        ctx.save_for_backward(t1, w1d, w1u, t2, w2d, w2u, scale)
+
+        rebuild1 = torch.einsum("i j ..., j r, i p -> p r ...", t1, w1d, w1u)
+        rebuild2 = torch.einsum("i j ..., j r, i p -> p r ...", t2, w2d, w2u)
+
+        return rebuild1 * rebuild2 * scale
+
+    @staticmethod
+    def backward(ctx, grad_out):
+        (t1, w1d, w1u, t2, w2d, w2u, scale) = ctx.saved_tensors
+        grad_out = grad_out * scale
+
+        temp = torch.einsum("i j ..., j r -> i r ...", t2, w2d)
+        rebuild = torch.einsum("i j ..., i r -> r j ...", temp, w2u)
+
+        grad_w = rebuild * grad_out
+        del rebuild
+
+        grad_w1u = torch.einsum("r j ..., i j ... -> r i", temp, grad_w)
+        grad_temp = torch.einsum("i j ..., i r -> r j ...", grad_w, w1u.T)
+        del grad_w, temp
+
+        grad_w1d = torch.einsum("i r ..., i j ... -> r j", t1, grad_temp)
+        grad_t1 = torch.einsum("i j ..., j r -> i r ...", grad_temp, w1d.T)
+        del grad_temp
+
+        temp = torch.einsum("i j ..., j r -> i r ...", t1, w1d)
+        rebuild = torch.einsum("i j ..., i r -> r j ...", temp, w1u)
+
+        grad_w = rebuild * grad_out
+        del rebuild
+
+        grad_w2u = torch.einsum("r j ..., i j ... -> r i", temp, grad_w)
+        grad_temp = torch.einsum("i j ..., i r -> r j ...", grad_w, w2u.T)
+        del grad_w, temp
+
+        grad_w2d = torch.einsum("i r ..., i j ... -> r j", t2, grad_temp)
+        grad_t2 = torch.einsum("i j ..., j r -> i r ...", grad_temp, w2d.T)
+        del grad_temp
+        return grad_t1, grad_w1u, grad_w1d, grad_t2, grad_w2u, grad_w2d, None
+
+
+class LohaDiff(WeightAdapterTrainBase):
+    def __init__(self, weights):
+        super().__init__()
+        # Unpack weights tuple from LoHaAdapter
+        w1a, w1b, alpha, w2a, w2b, t1, t2, _ = weights
+
+        # Create trainable parameters
+        self.hada_w1_a = torch.nn.Parameter(w1a)
+        self.hada_w1_b = torch.nn.Parameter(w1b)
+        self.hada_w2_a = torch.nn.Parameter(w2a)
+        self.hada_w2_b = torch.nn.Parameter(w2b)
+
+        self.use_tucker = False
+        if t1 is not None and t2 is not None:
+            self.use_tucker = True
+            self.hada_t1 = torch.nn.Parameter(t1)
+            self.hada_t2 = torch.nn.Parameter(t2)
+        else:
+            # Keep the attributes for consistent access
+            self.hada_t1 = None
+            self.hada_t2 = None
+
+        # Store rank and non-trainable alpha
+        self.rank = w1b.shape[0]
+        self.alpha = torch.nn.Parameter(torch.tensor(alpha), requires_grad=False)
+
+    def __call__(self, w):
+        org_dtype = w.dtype
+
+        scale = self.alpha / self.rank
+        if self.use_tucker:
+            diff_weight = HadaWeightTucker.apply(self.hada_t1, self.hada_w1_a, self.hada_w1_b, self.hada_t2, self.hada_w2_a, self.hada_w2_b, scale)
+        else:
+            diff_weight = HadaWeight.apply(self.hada_w1_a, self.hada_w1_b, self.hada_w2_a, self.hada_w2_b, scale)
+
+        # Add the scaled difference to the original weight
+        weight = w.to(diff_weight) + diff_weight.reshape(w.shape)
+
+        return weight.to(org_dtype)
+
+    def passive_memory_usage(self):
+        """Calculates memory usage of the trainable parameters."""
+        return sum(param.numel() * param.element_size() for param in self.parameters())
+
+
 class LoHaAdapter(WeightAdapterBase):

@@ -13,6 +126,25 @@ class LoHaAdapter(WeightAdapterBase):
         self.loaded_keys = loaded_keys
         self.weights = weights

+    @classmethod
+    def create_train(cls, weight, rank=1, alpha=1.0):
+        out_dim = weight.shape[0]
+        in_dim = weight.shape[1:].numel()
+        mat1 = torch.empty(out_dim, rank, device=weight.device, dtype=weight.dtype)
+        mat2 = torch.empty(rank, in_dim, device=weight.device, dtype=weight.dtype)
+        torch.nn.init.normal_(mat1, 0.1)
+        torch.nn.init.constant_(mat2, 0.0)
+        mat3 = torch.empty(out_dim, rank, device=weight.device, dtype=weight.dtype)
+        mat4 = torch.empty(rank, in_dim, device=weight.device, dtype=weight.dtype)
+        torch.nn.init.normal_(mat3, 0.1)
+        torch.nn.init.normal_(mat4, 0.01)
+        return LohaDiff(
+            (mat1, mat2, alpha, mat3, mat4, None, None, None)
+        )
+
+    def to_train(self):
+        return LohaDiff(self.weights)
+
     @classmethod
     def load(
         cls,
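Note: LoHa parameterizes the weight delta as a Hadamard (element-wise) product of two low-rank products; the custom autograd Functions above mainly recompute the rebuilt matrices during backward instead of keeping them saved. A minimal sketch of the math with plain tensors (shapes chosen for illustration):

    import torch

    out_dim, in_dim, rank, alpha = 64, 32, 4, 4.0
    w1_a, w1_b = torch.randn(out_dim, rank), torch.randn(rank, in_dim)
    w2_a, w2_b = torch.randn(out_dim, rank), torch.randn(rank, in_dim)

    # LoHa delta: element-wise product of two rank-limited matrices, scaled by alpha/rank
    diff_weight = (w1_a @ w1_b) * (w2_a @ w2_b) * (alpha / rank)
    assert diff_weight.shape == (out_dim, in_dim)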
@@ -3,7 +3,77 @@ from typing import Optional

 import torch
 import comfy.model_management
-from .base import WeightAdapterBase, weight_decompose
+from .base import (
+    WeightAdapterBase,
+    WeightAdapterTrainBase,
+    weight_decompose,
+    factorization,
+)


+class LokrDiff(WeightAdapterTrainBase):
+    def __init__(self, weights):
+        super().__init__()
+        (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2, dora_scale) = weights
+        self.use_tucker = False
+        if lokr_w1_a is not None:
+            _, rank_a = lokr_w1_a.shape[0], lokr_w1_a.shape[1]
+            rank_a, _ = lokr_w1_b.shape[0], lokr_w1_b.shape[1]
+            self.lokr_w1_a = torch.nn.Parameter(lokr_w1_a)
+            self.lokr_w1_b = torch.nn.Parameter(lokr_w1_b)
+            self.w1_rebuild = True
+            self.ranka = rank_a
+
+        if lokr_w2_a is not None:
+            _, rank_b = lokr_w2_a.shape[0], lokr_w2_a.shape[1]
+            rank_b, _ = lokr_w2_b.shape[0], lokr_w2_b.shape[1]
+            self.lokr_w2_a = torch.nn.Parameter(lokr_w2_a)
+            self.lokr_w2_b = torch.nn.Parameter(lokr_w2_b)
+            if lokr_t2 is not None:
+                self.use_tucker = True
+                self.lokr_t2 = torch.nn.Parameter(lokr_t2)
+            self.w2_rebuild = True
+            self.rankb = rank_b
+
+        if lokr_w1 is not None:
+            self.lokr_w1 = torch.nn.Parameter(lokr_w1)
+            self.w1_rebuild = False
+
+        if lokr_w2 is not None:
+            self.lokr_w2 = torch.nn.Parameter(lokr_w2)
+            self.w2_rebuild = False
+
+        self.alpha = torch.nn.Parameter(torch.tensor(alpha), requires_grad=False)
+
+    @property
+    def w1(self):
+        if self.w1_rebuild:
+            return (self.lokr_w1_a @ self.lokr_w1_b) * (self.alpha / self.ranka)
+        else:
+            return self.lokr_w1
+
+    @property
+    def w2(self):
+        if self.w2_rebuild:
+            if self.use_tucker:
+                w2 = torch.einsum(
+                    'i j k l, j r, i p -> p r k l',
+                    self.lokr_t2,
+                    self.lokr_w2_b,
+                    self.lokr_w2_a
+                )
+            else:
+                w2 = self.lokr_w2_a @ self.lokr_w2_b
+            return w2 * (self.alpha / self.rankb)
+        else:
+            return self.lokr_w2
+
+    def __call__(self, w):
+        diff = torch.kron(self.w1, self.w2)
+        return w + diff.reshape(w.shape).to(w)
+
+    def passive_memory_usage(self):
+        return sum(param.numel() * param.element_size() for param in self.parameters())
+
+
 class LoKrAdapter(WeightAdapterBase):

@@ -13,6 +83,20 @@ class LoKrAdapter(WeightAdapterBase):
         self.loaded_keys = loaded_keys
         self.weights = weights

+    @classmethod
+    def create_train(cls, weight, rank=1, alpha=1.0):
+        out_dim = weight.shape[0]
+        in_dim = weight.shape[1:].numel()
+        out1, out2 = factorization(out_dim, rank)
+        in1, in2 = factorization(in_dim, rank)
+        mat1 = torch.empty(out1, in1, device=weight.device, dtype=weight.dtype)
+        mat2 = torch.empty(out2, in2, device=weight.device, dtype=weight.dtype)
+        torch.nn.init.kaiming_uniform_(mat2, a=5**0.5)
+        torch.nn.init.constant_(mat1, 0.0)
+        return LokrDiff(
+            (mat1, mat2, alpha, None, None, None, None, None, None)
+        )
+
     @classmethod
     def load(
         cls,
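Note: LoKr stores the delta as a Kronecker product of two much smaller blocks whose shapes come from factorization(); rebuilding the full matrix is a single torch.kron. A minimal sketch (shapes for illustration):

    import torch

    out1, out2 = 8, 16   # factorization(128, 8)
    in1, in2 = 8, 8      # factorization(64, 8)
    w1 = torch.randn(out1, in1)
    w2 = torch.randn(out2, in2)

    diff = torch.kron(w1, w2)        # Kronecker product rebuilds the full (128, 64) delta
    assert diff.shape == (128, 64)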
@@ -3,7 +3,58 @@ from typing import Optional

 import torch
 import comfy.model_management
-from .base import WeightAdapterBase, weight_decompose
+from .base import WeightAdapterBase, WeightAdapterTrainBase, weight_decompose, factorization


+class OFTDiff(WeightAdapterTrainBase):
+    def __init__(self, weights):
+        super().__init__()
+        # Unpack weights tuple from LoHaAdapter
+        blocks, rescale, alpha, _ = weights
+
+        # Create trainable parameters
+        self.oft_blocks = torch.nn.Parameter(blocks)
+        if rescale is not None:
+            self.rescale = torch.nn.Parameter(rescale)
+            self.rescaled = True
+        else:
+            self.rescaled = False
+        self.block_num, self.block_size, _ = blocks.shape
+        self.constraint = float(alpha)
+        self.alpha = torch.nn.Parameter(torch.tensor(alpha), requires_grad=False)
+
+    def __call__(self, w):
+        org_dtype = w.dtype
+        I = torch.eye(self.block_size, device=self.oft_blocks.device)
+
+        ## generate r
+        # for Q = -Q^T
+        q = self.oft_blocks - self.oft_blocks.transpose(1, 2)
+        normed_q = q
+        if self.constraint:
+            q_norm = torch.norm(q) + 1e-8
+            if q_norm > self.constraint:
+                normed_q = q * self.constraint / q_norm
+        # use float() to prevent unsupported type
+        r = (I + normed_q) @ (I - normed_q).float().inverse()
+
+        ## Apply chunked matmul on weight
+        _, *shape = w.shape
+        org_weight = w.to(dtype=r.dtype)
+        org_weight = org_weight.unflatten(0, (self.block_num, self.block_size))
+        # Init R=0, so add I on it to ensure the output of step0 is original model output
+        weight = torch.einsum(
+            "k n m, k n ... -> k m ...",
+            r,
+            org_weight,
+        ).flatten(0, 1)
+        if self.rescaled:
+            weight = self.rescale * weight
+        return weight.to(org_dtype)
+
+    def passive_memory_usage(self):
+        """Calculates memory usage of the trainable parameters."""
+        return sum(param.numel() * param.element_size() for param in self.parameters())
+
+
 class OFTAdapter(WeightAdapterBase):

@@ -13,6 +64,18 @@ class OFTAdapter(WeightAdapterBase):
         self.loaded_keys = loaded_keys
         self.weights = weights

+    @classmethod
+    def create_train(cls, weight, rank=1, alpha=1.0):
+        out_dim = weight.shape[0]
+        block_size, block_num = factorization(out_dim, rank)
+        block = torch.zeros(block_num, block_size, block_size, device=weight.device, dtype=weight.dtype)
+        return OFTDiff(
+            (block, None, alpha, None)
+        )
+
+    def to_train(self):
+        return OFTDiff(self.weights)
+
     @classmethod
     def load(
         cls,

@@ -60,6 +123,8 @@ class OFTAdapter(WeightAdapterBase):
         blocks = v[0]
         rescale = v[1]
         alpha = v[2]
+        if alpha is None:
+            alpha = 0
         dora_scale = v[3]

         blocks = comfy.model_management.cast_to_device(blocks, weight.device, intermediate_dtype)
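Note: OFT learns per-block skew-symmetric matrices and maps them to orthogonal rotations with a Cayley transform, so the all-zero blocks that create_train() starts from give exactly the identity (step 0 reproduces the base model output). A minimal sketch of that construction (torch.linalg.inv stands in for the .inverse() call above):

    import torch

    block_num, block_size = 4, 8
    blocks = torch.zeros(block_num, block_size, block_size)   # zero-initialized, as in create_train()

    q = blocks - blocks.transpose(1, 2)                        # skew-symmetric: Q = -Q^T
    eye = torch.eye(block_size).expand(block_num, -1, -1)
    r = (eye + q) @ torch.linalg.inv(eye - q)                  # Cayley transform -> orthogonal blocks

    assert torch.allclose(r, eye)                              # zero blocks rotate by the identity
    assert torch.allclose(r @ r.transpose(1, 2), eye, atol=1e-5)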
@@ -20,7 +20,7 @@ import folder_paths
 import node_helpers
 from comfy.cli_args import args
 from comfy.comfy_types.node_typing import IO
-from comfy.weight_adapter import adapters
+from comfy.weight_adapter import adapters, adapter_maps


 def make_batch_extra_option_dict(d, indicies, full_size=None):

@@ -39,13 +39,13 @@ def make_batch_extra_option_dict(d, indicies, full_size=None):


 class TrainSampler(comfy.samplers.Sampler):
-    def __init__(self, loss_fn, optimizer, loss_callback=None, batch_size=1, total_steps=1, seed=0, training_dtype=torch.bfloat16):
+    def __init__(self, loss_fn, optimizer, loss_callback=None, batch_size=1, grad_acc=1, total_steps=1, seed=0, training_dtype=torch.bfloat16):
         self.loss_fn = loss_fn
         self.optimizer = optimizer
         self.loss_callback = loss_callback
         self.batch_size = batch_size
         self.total_steps = total_steps
+        self.grad_acc = grad_acc
         self.seed = seed
         self.training_dtype = training_dtype

@@ -92,8 +92,9 @@ class TrainSampler(comfy.samplers.Sampler):
                 self.loss_callback(loss.item())
             pbar.set_postfix({"loss": f"{loss.item():.4f}"})

-            self.optimizer.step()
-            self.optimizer.zero_grad()
+            if (i+1) % self.grad_acc == 0:
+                self.optimizer.step()
+                self.optimizer.zero_grad()
         torch.cuda.empty_cache()
         return torch.zeros_like(latent_image)

@@ -419,6 +420,16 @@ class TrainLoraNode:
                         "tooltip": "The batch size to use for training.",
                     },
                 ),
+                "grad_accumulation_steps": (
+                    IO.INT,
+                    {
+                        "default": 1,
+                        "min": 1,
+                        "max": 1024,
+                        "step": 1,
+                        "tooltip": "The number of gradient accumulation steps to use for training.",
+                    }
+                ),
                 "steps": (
                     IO.INT,
                     {

@@ -478,6 +489,17 @@ class TrainLoraNode:
                     ["bf16", "fp32"],
                     {"default": "bf16", "tooltip": "The dtype to use for lora."},
                 ),
+                "algorithm": (
+                    list(adapter_maps.keys()),
+                    {"default": list(adapter_maps.keys())[0], "tooltip": "The algorithm to use for training."},
+                ),
+                "gradient_checkpointing": (
+                    IO.BOOLEAN,
+                    {
+                        "default": True,
+                        "tooltip": "Use gradient checkpointing for training.",
+                    }
+                ),
                 "existing_lora": (
                     folder_paths.get_filename_list("loras") + ["[None]"],
                     {

@@ -501,6 +523,7 @@ class TrainLoraNode:
         positive,
         batch_size,
         steps,
+        grad_accumulation_steps,
         learning_rate,
         rank,
         optimizer,

@@ -508,6 +531,8 @@ class TrainLoraNode:
         seed,
         training_dtype,
         lora_dtype,
+        algorithm,
+        gradient_checkpointing,
         existing_lora,
     ):
         mp = model.clone()

@@ -558,10 +583,8 @@ class TrainLoraNode:
             if existing_adapter is not None:
                 break
         else:
-            # If no existing adapter found, use LoRA
-            # We will add algo option in the future
             existing_adapter = None
-            adapter_cls = adapters[0]
+            adapter_cls = adapter_maps[algorithm]

         if existing_adapter is not None:
             train_adapter = existing_adapter.to_train().to(lora_dtype)

@@ -615,8 +638,9 @@ class TrainLoraNode:
         criterion = torch.nn.SmoothL1Loss()

         # setup models
-        for m in find_all_highest_child_module_with_forward(mp.model.diffusion_model):
-            patch(m)
+        if gradient_checkpointing:
+            for m in find_all_highest_child_module_with_forward(mp.model.diffusion_model):
+                patch(m)
         mp.model.requires_grad_(False)
         comfy.model_management.load_models_gpu([mp], memory_required=1e20, force_full_load=True)

@@ -629,7 +653,8 @@ class TrainLoraNode:
             optimizer,
             loss_callback=loss_callback,
             batch_size=batch_size,
-            total_steps=steps,
+            grad_acc=grad_accumulation_steps,
+            total_steps=steps*grad_accumulation_steps,
             seed=seed,
             training_dtype=dtype
         )
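Note: with grad_accumulation_steps > 1 the sampler keeps accumulating gradients and only steps and zeroes the optimizer every grad_acc iterations, which is also why the node now passes total_steps=steps*grad_accumulation_steps. A minimal sketch of the stepping pattern on a stand-in model (not the node's actual training loop):

    import torch

    model = torch.nn.Linear(4, 4)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    grad_acc, total_steps = 4, 20

    for i in range(total_steps):
        x = torch.randn(8, 4)
        loss = torch.nn.functional.mse_loss(model(x), torch.zeros(8, 4))
        loss.backward()                      # gradients accumulate across iterations
        if (i + 1) % grad_acc == 0:          # same condition as the TrainSampler hunk above
            optimizer.step()
            optimizer.zero_grad()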
@@ -77,7 +77,8 @@ if not args.cuda_malloc:
                 module = importlib.util.module_from_spec(spec)
                 spec.loader.exec_module(module)
                 version = module.__version__
-        if int(version[0]) >= 2: #enable by default for torch version 2.0 and up
+
+        if int(version[0]) >= 2 and "+cu" in version: #enable by default for torch version 2.0 and up only on cuda torch
             args.cuda_malloc = cuda_malloc_supported()
     except:
         pass
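Note: the added "+cu" check keeps cudaMallocAsync from being turned on by default for non-CUDA builds, whose version strings carry a different local suffix (for example "+rocm6.2" or "+cpu" instead of "+cu128"). A small sketch of the gate:

    def default_cuda_malloc(torch_version: str) -> bool:
        # "2.7.1+cu128" -> True, "2.7.1+rocm6.2" -> False, "2.7.1+cpu" -> False
        return int(torch_version[0]) >= 2 and "+cu" in torch_version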