Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-01-23 21:00:16 +08:00)
Remove unused k_diffusion utils code
This commit is contained in:
parent 426e8e66a5
commit aab3b6c5a4
@@ -1,23 +1,3 @@
-from contextlib import contextmanager
-import hashlib
-import math
-from pathlib import Path
-import shutil
-import urllib
-import warnings
-
-from PIL import Image
-import torch
-from torch import nn, optim
-from torch.utils import data
-
-
-def hf_datasets_augs_helper(examples, transform, image_key, mode='RGB'):
-    """Apply passed in transforms for HuggingFace Datasets."""
-    images = [transform(image.convert(mode)) for image in examples[image_key]]
-    return {image_key: images}
-
-
 def append_dims(x, target_dims):
     """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
     dims_to_append = target_dims - x.ndim
@@ -27,287 +7,3 @@ def append_dims(x, target_dims):
     # MPS will get inf values if it tries to index into the new axes, but detaching fixes this.
     # https://github.com/pytorch/pytorch/issues/84364
     return expanded.detach().clone() if expanded.device.type == 'mps' else expanded
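
Note: append_dims is the only helper this file keeps. A minimal usage sketch (the sample tensors and the simplified re-definition below are illustrative, not part of this commit): it pads trailing singleton dimensions so a per-sample sigma broadcasts against an NCHW batch.

import torch

def append_dims(x, target_dims):
    # Simplified sketch of the helper kept above (omits the MPS detach workaround).
    dims_to_append = target_dims - x.ndim
    return x[(...,) + (None,) * dims_to_append]

sigma = torch.tensor([0.5, 1.0])             # one noise level per sample, shape (2,)
latents = torch.randn(2, 4, 64, 64)          # NCHW latent batch
noised = latents * append_dims(sigma, latents.ndim)  # sigma broadcast as shape (2, 1, 1, 1)
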
-
-
-def n_params(module):
-    """Returns the number of trainable parameters in a module."""
-    return sum(p.numel() for p in module.parameters())
-
-
-def download_file(path, url, digest=None):
-    """Downloads a file if it does not exist, optionally checking its SHA-256 hash."""
-    path = Path(path)
-    path.parent.mkdir(parents=True, exist_ok=True)
-    if not path.exists():
-        with urllib.request.urlopen(url) as response, open(path, 'wb') as f:
-            shutil.copyfileobj(response, f)
-    if digest is not None:
-        file_digest = hashlib.sha256(open(path, 'rb').read()).hexdigest()
-        if digest != file_digest:
-            raise OSError(f'hash of {path} (url: {url}) failed to validate')
-    return path
-
-
-@contextmanager
-def train_mode(model, mode=True):
-    """A context manager that places a model into training mode and restores
-    the previous mode on exit."""
-    modes = [module.training for module in model.modules()]
-    try:
-        yield model.train(mode)
-    finally:
-        for i, module in enumerate(model.modules()):
-            module.training = modes[i]
-
-
-def eval_mode(model):
-    """A context manager that places a model into evaluation mode and restores
-    the previous mode on exit."""
-    return train_mode(model, False)
-
-
-@torch.no_grad()
-def ema_update(model, averaged_model, decay):
-    """Incorporates updated model parameters into an exponential moving averaged
-    version of a model. It should be called after each optimizer step."""
-    model_params = dict(model.named_parameters())
-    averaged_params = dict(averaged_model.named_parameters())
-    assert model_params.keys() == averaged_params.keys()
-
-    for name, param in model_params.items():
-        averaged_params[name].mul_(decay).add_(param, alpha=1 - decay)
-
-    model_buffers = dict(model.named_buffers())
-    averaged_buffers = dict(averaged_model.named_buffers())
-    assert model_buffers.keys() == averaged_buffers.keys()
-
-    for name, buf in model_buffers.items():
-        averaged_buffers[name].copy_(buf)
-
-
-class EMAWarmup:
-    """Implements an EMA warmup using an inverse decay schedule.
-    If inv_gamma=1 and power=1, implements a simple average. inv_gamma=1, power=2/3 are
-    good values for models you plan to train for a million or more steps (reaches decay
-    factor 0.999 at 31.6K steps, 0.9999 at 1M steps), inv_gamma=1, power=3/4 for models
-    you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 at
-    215.4k steps).
-    Args:
-        inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
-        power (float): Exponential factor of EMA warmup. Default: 1.
-        min_value (float): The minimum EMA decay rate. Default: 0.
-        max_value (float): The maximum EMA decay rate. Default: 1.
-        start_at (int): The epoch to start averaging at. Default: 0.
-        last_epoch (int): The index of last epoch. Default: 0.
-    """
-
-    def __init__(self, inv_gamma=1., power=1., min_value=0., max_value=1., start_at=0,
-                 last_epoch=0):
-        self.inv_gamma = inv_gamma
-        self.power = power
-        self.min_value = min_value
-        self.max_value = max_value
-        self.start_at = start_at
-        self.last_epoch = last_epoch
-
-    def state_dict(self):
-        """Returns the state of the class as a :class:`dict`."""
-        return dict(self.__dict__.items())
-
-    def load_state_dict(self, state_dict):
-        """Loads the class's state.
-        Args:
-            state_dict (dict): scaler state. Should be an object returned
-                from a call to :meth:`state_dict`.
-        """
-        self.__dict__.update(state_dict)
-
-    def get_value(self):
-        """Gets the current EMA decay rate."""
-        epoch = max(0, self.last_epoch - self.start_at)
-        value = 1 - (1 + epoch / self.inv_gamma) ** -self.power
-        return 0. if epoch < 0 else min(self.max_value, max(self.min_value, value))
-
-    def step(self):
-        """Updates the step count."""
-        self.last_epoch += 1
-
-
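
Note: per the docstrings above, ema_update was meant to be called after each optimizer step, with the decay supplied by an EMAWarmup schedule. A minimal training-loop sketch under that assumption; the model, data, and hyperparameters are placeholders, and it relies on the ema_update and EMAWarmup definitions quoted above.

import copy
import torch
from torch import nn, optim

model = nn.Linear(8, 8)                           # placeholder model
model_ema = copy.deepcopy(model)                  # EMA copy with identical parameter names
opt = optim.Adam(model.parameters(), lr=1e-3)
ema_sched = EMAWarmup(inv_gamma=1., power=2 / 3)  # decay ramps from 0 toward 1 over training

for step in range(1000):
    x = torch.randn(4, 8)
    loss = (model(x) - x).pow(2).mean()           # dummy objective
    opt.zero_grad()
    loss.backward()
    opt.step()
    ema_update(model, model_ema, ema_sched.get_value())  # after each optimizer step
    ema_sched.step()
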
-class InverseLR(optim.lr_scheduler._LRScheduler):
-    """Implements an inverse decay learning rate schedule with an optional exponential
-    warmup. When last_epoch=-1, sets initial lr as lr.
-    inv_gamma is the number of steps/epochs required for the learning rate to decay to
-    (1 / 2)**power of its original value.
-    Args:
-        optimizer (Optimizer): Wrapped optimizer.
-        inv_gamma (float): Inverse multiplicative factor of learning rate decay. Default: 1.
-        power (float): Exponential factor of learning rate decay. Default: 1.
-        warmup (float): Exponential warmup factor (0 <= warmup < 1, 0 to disable)
-            Default: 0.
-        min_lr (float): The minimum learning rate. Default: 0.
-        last_epoch (int): The index of last epoch. Default: -1.
-        verbose (bool): If ``True``, prints a message to stdout for
-            each update. Default: ``False``.
-    """
-
-    def __init__(self, optimizer, inv_gamma=1., power=1., warmup=0., min_lr=0.,
-                 last_epoch=-1, verbose=False):
-        self.inv_gamma = inv_gamma
-        self.power = power
-        if not 0. <= warmup < 1:
-            raise ValueError('Invalid value for warmup')
-        self.warmup = warmup
-        self.min_lr = min_lr
-        super().__init__(optimizer, last_epoch, verbose)
-
-    def get_lr(self):
-        if not self._get_lr_called_within_step:
-            warnings.warn("To get the last learning rate computed by the scheduler, "
-                          "please use `get_last_lr()`.")
-
-        return self._get_closed_form_lr()
-
-    def _get_closed_form_lr(self):
-        warmup = 1 - self.warmup ** (self.last_epoch + 1)
-        lr_mult = (1 + self.last_epoch / self.inv_gamma) ** -self.power
-        return [warmup * max(self.min_lr, base_lr * lr_mult)
-                for base_lr in self.base_lrs]
-
-
-class ExponentialLR(optim.lr_scheduler._LRScheduler):
-    """Implements an exponential learning rate schedule with an optional exponential
-    warmup. When last_epoch=-1, sets initial lr as lr. Decays the learning rate
-    continuously by decay (default 0.5) every num_steps steps.
-    Args:
-        optimizer (Optimizer): Wrapped optimizer.
-        num_steps (float): The number of steps to decay the learning rate by decay in.
-        decay (float): The factor by which to decay the learning rate every num_steps
-            steps. Default: 0.5.
-        warmup (float): Exponential warmup factor (0 <= warmup < 1, 0 to disable)
-            Default: 0.
-        min_lr (float): The minimum learning rate. Default: 0.
-        last_epoch (int): The index of last epoch. Default: -1.
-        verbose (bool): If ``True``, prints a message to stdout for
-            each update. Default: ``False``.
-    """
-
-    def __init__(self, optimizer, num_steps, decay=0.5, warmup=0., min_lr=0.,
-                 last_epoch=-1, verbose=False):
-        self.num_steps = num_steps
-        self.decay = decay
-        if not 0. <= warmup < 1:
-            raise ValueError('Invalid value for warmup')
-        self.warmup = warmup
-        self.min_lr = min_lr
-        super().__init__(optimizer, last_epoch, verbose)
-
-    def get_lr(self):
-        if not self._get_lr_called_within_step:
-            warnings.warn("To get the last learning rate computed by the scheduler, "
-                          "please use `get_last_lr()`.")
-
-        return self._get_closed_form_lr()
-
-    def _get_closed_form_lr(self):
-        warmup = 1 - self.warmup ** (self.last_epoch + 1)
-        lr_mult = (self.decay ** (1 / self.num_steps)) ** self.last_epoch
-        return [warmup * max(self.min_lr, base_lr * lr_mult)
-                for base_lr in self.base_lrs]
-
-
-def rand_log_normal(shape, loc=0., scale=1., device='cpu', dtype=torch.float32):
-    """Draws samples from an lognormal distribution."""
-    return (torch.randn(shape, device=device, dtype=dtype) * scale + loc).exp()
-
-
-def rand_log_logistic(shape, loc=0., scale=1., min_value=0., max_value=float('inf'), device='cpu', dtype=torch.float32):
-    """Draws samples from an optionally truncated log-logistic distribution."""
-    min_value = torch.as_tensor(min_value, device=device, dtype=torch.float64)
-    max_value = torch.as_tensor(max_value, device=device, dtype=torch.float64)
-    min_cdf = min_value.log().sub(loc).div(scale).sigmoid()
-    max_cdf = max_value.log().sub(loc).div(scale).sigmoid()
-    u = torch.rand(shape, device=device, dtype=torch.float64) * (max_cdf - min_cdf) + min_cdf
-    return u.logit().mul(scale).add(loc).exp().to(dtype)
-
-
-def rand_log_uniform(shape, min_value, max_value, device='cpu', dtype=torch.float32):
-    """Draws samples from an log-uniform distribution."""
-    min_value = math.log(min_value)
-    max_value = math.log(max_value)
-    return (torch.rand(shape, device=device, dtype=dtype) * (max_value - min_value) + min_value).exp()
-
-
-def rand_v_diffusion(shape, sigma_data=1., min_value=0., max_value=float('inf'), device='cpu', dtype=torch.float32):
-    """Draws samples from a truncated v-diffusion training timestep distribution."""
-    min_cdf = math.atan(min_value / sigma_data) * 2 / math.pi
-    max_cdf = math.atan(max_value / sigma_data) * 2 / math.pi
-    u = torch.rand(shape, device=device, dtype=dtype) * (max_cdf - min_cdf) + min_cdf
-    return torch.tan(u * math.pi / 2) * sigma_data
-
-
-def rand_split_log_normal(shape, loc, scale_1, scale_2, device='cpu', dtype=torch.float32):
-    """Draws samples from a split lognormal distribution."""
-    n = torch.randn(shape, device=device, dtype=dtype).abs()
-    u = torch.rand(shape, device=device, dtype=dtype)
-    n_left = n * -scale_1 + loc
-    n_right = n * scale_2 + loc
-    ratio = scale_1 / (scale_1 + scale_2)
-    return torch.where(u < ratio, n_left, n_right).exp()
-
-
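
Note: the removed rand_* helpers drew per-sample noise levels for diffusion training. A short sketch of how they were called; the shapes and parameter values are illustrative, and it relies on the definitions quoted above plus the math/torch imports removed in the first hunk.

# 16 sigmas from a log-normal distribution (loc/scale chosen for illustration)
sigmas = rand_log_normal((16,), loc=-1.2, scale=1.2)

# 16 noise levels from a truncated v-diffusion timestep distribution
sigmas_v = rand_v_diffusion((16,), sigma_data=1.0, min_value=1e-3, max_value=80.0)
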
-class FolderOfImages(data.Dataset):
-    """Recursively finds all images in a directory. It does not support
-    classes/targets."""
-
-    IMG_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp'}
-
-    def __init__(self, root, transform=None):
-        super().__init__()
-        self.root = Path(root)
-        self.transform = nn.Identity() if transform is None else transform
-        self.paths = sorted(path for path in self.root.rglob('*') if path.suffix.lower() in self.IMG_EXTENSIONS)
-
-    def __repr__(self):
-        return f'FolderOfImages(root="{self.root}", len: {len(self)})'
-
-    def __len__(self):
-        return len(self.paths)
-
-    def __getitem__(self, key):
-        path = self.paths[key]
-        with open(path, 'rb') as f:
-            image = Image.open(f).convert('RGB')
-        image = self.transform(image)
-        return image,
-
-
-class CSVLogger:
-    def __init__(self, filename, columns):
-        self.filename = Path(filename)
-        self.columns = columns
-        if self.filename.exists():
-            self.file = open(self.filename, 'a')
-        else:
-            self.file = open(self.filename, 'w')
-            self.write(*self.columns)
-
-    def write(self, *args):
-        print(*args, sep=',', file=self.file, flush=True)
-
-
-@contextmanager
-def tf32_mode(cudnn=None, matmul=None):
-    """A context manager that sets whether TF32 is allowed on cuDNN or matmul."""
-    cudnn_old = torch.backends.cudnn.allow_tf32
-    matmul_old = torch.backends.cuda.matmul.allow_tf32
-    try:
-        if cudnn is not None:
-            torch.backends.cudnn.allow_tf32 = cudnn
-        if matmul is not None:
-            torch.backends.cuda.matmul.allow_tf32 = matmul
-        yield
-    finally:
-        if cudnn is not None:
-            torch.backends.cudnn.allow_tf32 = cudnn_old
-        if matmul is not None:
-            torch.backends.cuda.matmul.allow_tf32 = matmul_old
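
Note: the removed tf32_mode context manager saved the current cuDNN and matmul TF32 flags, applied the requested settings for the body of the with block, and restored the old flags on exit. A one-line usage sketch, assuming the definition above and a placeholder model and input:

with tf32_mode(cudnn=True, matmul=True):
    y = model(x)  # runs with TF32 allowed; the previous backend flags are restored afterwards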