From aab3b6c5a4e161180fed2e0c70919e716aba6183 Mon Sep 17 00:00:00 2001 From: doctorpangloss <@hiddenswitch.com> Date: Sun, 16 Jun 2024 11:27:01 -0700 Subject: [PATCH] Remove unused k_diffusion utils code --- comfy/k_diffusion/utils.py | 304 ------------------------------------- 1 file changed, 304 deletions(-) diff --git a/comfy/k_diffusion/utils.py b/comfy/k_diffusion/utils.py index a644df2f3..7d2f0f787 100644 --- a/comfy/k_diffusion/utils.py +++ b/comfy/k_diffusion/utils.py @@ -1,23 +1,3 @@ -from contextlib import contextmanager -import hashlib -import math -from pathlib import Path -import shutil -import urllib -import warnings - -from PIL import Image -import torch -from torch import nn, optim -from torch.utils import data - - -def hf_datasets_augs_helper(examples, transform, image_key, mode='RGB'): - """Apply passed in transforms for HuggingFace Datasets.""" - images = [transform(image.convert(mode)) for image in examples[image_key]] - return {image_key: images} - - def append_dims(x, target_dims): """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" dims_to_append = target_dims - x.ndim @@ -27,287 +7,3 @@ def append_dims(x, target_dims): # MPS will get inf values if it tries to index into the new axes, but detaching fixes this. # https://github.com/pytorch/pytorch/issues/84364 return expanded.detach().clone() if expanded.device.type == 'mps' else expanded - - -def n_params(module): - """Returns the number of trainable parameters in a module.""" - return sum(p.numel() for p in module.parameters()) - - -def download_file(path, url, digest=None): - """Downloads a file if it does not exist, optionally checking its SHA-256 hash.""" - path = Path(path) - path.parent.mkdir(parents=True, exist_ok=True) - if not path.exists(): - with urllib.request.urlopen(url) as response, open(path, 'wb') as f: - shutil.copyfileobj(response, f) - if digest is not None: - file_digest = hashlib.sha256(open(path, 'rb').read()).hexdigest() - if digest != file_digest: - raise OSError(f'hash of {path} (url: {url}) failed to validate') - return path - - -@contextmanager -def train_mode(model, mode=True): - """A context manager that places a model into training mode and restores - the previous mode on exit.""" - modes = [module.training for module in model.modules()] - try: - yield model.train(mode) - finally: - for i, module in enumerate(model.modules()): - module.training = modes[i] - - -def eval_mode(model): - """A context manager that places a model into evaluation mode and restores - the previous mode on exit.""" - return train_mode(model, False) - - -@torch.no_grad() -def ema_update(model, averaged_model, decay): - """Incorporates updated model parameters into an exponential moving averaged - version of a model. It should be called after each optimizer step.""" - model_params = dict(model.named_parameters()) - averaged_params = dict(averaged_model.named_parameters()) - assert model_params.keys() == averaged_params.keys() - - for name, param in model_params.items(): - averaged_params[name].mul_(decay).add_(param, alpha=1 - decay) - - model_buffers = dict(model.named_buffers()) - averaged_buffers = dict(averaged_model.named_buffers()) - assert model_buffers.keys() == averaged_buffers.keys() - - for name, buf in model_buffers.items(): - averaged_buffers[name].copy_(buf) - - -class EMAWarmup: - """Implements an EMA warmup using an inverse decay schedule. - If inv_gamma=1 and power=1, implements a simple average. inv_gamma=1, power=2/3 are - good values for models you plan to train for a million or more steps (reaches decay - factor 0.999 at 31.6K steps, 0.9999 at 1M steps), inv_gamma=1, power=3/4 for models - you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 at - 215.4k steps). - Args: - inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1. - power (float): Exponential factor of EMA warmup. Default: 1. - min_value (float): The minimum EMA decay rate. Default: 0. - max_value (float): The maximum EMA decay rate. Default: 1. - start_at (int): The epoch to start averaging at. Default: 0. - last_epoch (int): The index of last epoch. Default: 0. - """ - - def __init__(self, inv_gamma=1., power=1., min_value=0., max_value=1., start_at=0, - last_epoch=0): - self.inv_gamma = inv_gamma - self.power = power - self.min_value = min_value - self.max_value = max_value - self.start_at = start_at - self.last_epoch = last_epoch - - def state_dict(self): - """Returns the state of the class as a :class:`dict`.""" - return dict(self.__dict__.items()) - - def load_state_dict(self, state_dict): - """Loads the class's state. - Args: - state_dict (dict): scaler state. Should be an object returned - from a call to :meth:`state_dict`. - """ - self.__dict__.update(state_dict) - - def get_value(self): - """Gets the current EMA decay rate.""" - epoch = max(0, self.last_epoch - self.start_at) - value = 1 - (1 + epoch / self.inv_gamma) ** -self.power - return 0. if epoch < 0 else min(self.max_value, max(self.min_value, value)) - - def step(self): - """Updates the step count.""" - self.last_epoch += 1 - - -class InverseLR(optim.lr_scheduler._LRScheduler): - """Implements an inverse decay learning rate schedule with an optional exponential - warmup. When last_epoch=-1, sets initial lr as lr. - inv_gamma is the number of steps/epochs required for the learning rate to decay to - (1 / 2)**power of its original value. - Args: - optimizer (Optimizer): Wrapped optimizer. - inv_gamma (float): Inverse multiplicative factor of learning rate decay. Default: 1. - power (float): Exponential factor of learning rate decay. Default: 1. - warmup (float): Exponential warmup factor (0 <= warmup < 1, 0 to disable) - Default: 0. - min_lr (float): The minimum learning rate. Default: 0. - last_epoch (int): The index of last epoch. Default: -1. - verbose (bool): If ``True``, prints a message to stdout for - each update. Default: ``False``. - """ - - def __init__(self, optimizer, inv_gamma=1., power=1., warmup=0., min_lr=0., - last_epoch=-1, verbose=False): - self.inv_gamma = inv_gamma - self.power = power - if not 0. <= warmup < 1: - raise ValueError('Invalid value for warmup') - self.warmup = warmup - self.min_lr = min_lr - super().__init__(optimizer, last_epoch, verbose) - - def get_lr(self): - if not self._get_lr_called_within_step: - warnings.warn("To get the last learning rate computed by the scheduler, " - "please use `get_last_lr()`.") - - return self._get_closed_form_lr() - - def _get_closed_form_lr(self): - warmup = 1 - self.warmup ** (self.last_epoch + 1) - lr_mult = (1 + self.last_epoch / self.inv_gamma) ** -self.power - return [warmup * max(self.min_lr, base_lr * lr_mult) - for base_lr in self.base_lrs] - - -class ExponentialLR(optim.lr_scheduler._LRScheduler): - """Implements an exponential learning rate schedule with an optional exponential - warmup. When last_epoch=-1, sets initial lr as lr. Decays the learning rate - continuously by decay (default 0.5) every num_steps steps. - Args: - optimizer (Optimizer): Wrapped optimizer. - num_steps (float): The number of steps to decay the learning rate by decay in. - decay (float): The factor by which to decay the learning rate every num_steps - steps. Default: 0.5. - warmup (float): Exponential warmup factor (0 <= warmup < 1, 0 to disable) - Default: 0. - min_lr (float): The minimum learning rate. Default: 0. - last_epoch (int): The index of last epoch. Default: -1. - verbose (bool): If ``True``, prints a message to stdout for - each update. Default: ``False``. - """ - - def __init__(self, optimizer, num_steps, decay=0.5, warmup=0., min_lr=0., - last_epoch=-1, verbose=False): - self.num_steps = num_steps - self.decay = decay - if not 0. <= warmup < 1: - raise ValueError('Invalid value for warmup') - self.warmup = warmup - self.min_lr = min_lr - super().__init__(optimizer, last_epoch, verbose) - - def get_lr(self): - if not self._get_lr_called_within_step: - warnings.warn("To get the last learning rate computed by the scheduler, " - "please use `get_last_lr()`.") - - return self._get_closed_form_lr() - - def _get_closed_form_lr(self): - warmup = 1 - self.warmup ** (self.last_epoch + 1) - lr_mult = (self.decay ** (1 / self.num_steps)) ** self.last_epoch - return [warmup * max(self.min_lr, base_lr * lr_mult) - for base_lr in self.base_lrs] - - -def rand_log_normal(shape, loc=0., scale=1., device='cpu', dtype=torch.float32): - """Draws samples from an lognormal distribution.""" - return (torch.randn(shape, device=device, dtype=dtype) * scale + loc).exp() - - -def rand_log_logistic(shape, loc=0., scale=1., min_value=0., max_value=float('inf'), device='cpu', dtype=torch.float32): - """Draws samples from an optionally truncated log-logistic distribution.""" - min_value = torch.as_tensor(min_value, device=device, dtype=torch.float64) - max_value = torch.as_tensor(max_value, device=device, dtype=torch.float64) - min_cdf = min_value.log().sub(loc).div(scale).sigmoid() - max_cdf = max_value.log().sub(loc).div(scale).sigmoid() - u = torch.rand(shape, device=device, dtype=torch.float64) * (max_cdf - min_cdf) + min_cdf - return u.logit().mul(scale).add(loc).exp().to(dtype) - - -def rand_log_uniform(shape, min_value, max_value, device='cpu', dtype=torch.float32): - """Draws samples from an log-uniform distribution.""" - min_value = math.log(min_value) - max_value = math.log(max_value) - return (torch.rand(shape, device=device, dtype=dtype) * (max_value - min_value) + min_value).exp() - - -def rand_v_diffusion(shape, sigma_data=1., min_value=0., max_value=float('inf'), device='cpu', dtype=torch.float32): - """Draws samples from a truncated v-diffusion training timestep distribution.""" - min_cdf = math.atan(min_value / sigma_data) * 2 / math.pi - max_cdf = math.atan(max_value / sigma_data) * 2 / math.pi - u = torch.rand(shape, device=device, dtype=dtype) * (max_cdf - min_cdf) + min_cdf - return torch.tan(u * math.pi / 2) * sigma_data - - -def rand_split_log_normal(shape, loc, scale_1, scale_2, device='cpu', dtype=torch.float32): - """Draws samples from a split lognormal distribution.""" - n = torch.randn(shape, device=device, dtype=dtype).abs() - u = torch.rand(shape, device=device, dtype=dtype) - n_left = n * -scale_1 + loc - n_right = n * scale_2 + loc - ratio = scale_1 / (scale_1 + scale_2) - return torch.where(u < ratio, n_left, n_right).exp() - - -class FolderOfImages(data.Dataset): - """Recursively finds all images in a directory. It does not support - classes/targets.""" - - IMG_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp'} - - def __init__(self, root, transform=None): - super().__init__() - self.root = Path(root) - self.transform = nn.Identity() if transform is None else transform - self.paths = sorted(path for path in self.root.rglob('*') if path.suffix.lower() in self.IMG_EXTENSIONS) - - def __repr__(self): - return f'FolderOfImages(root="{self.root}", len: {len(self)})' - - def __len__(self): - return len(self.paths) - - def __getitem__(self, key): - path = self.paths[key] - with open(path, 'rb') as f: - image = Image.open(f).convert('RGB') - image = self.transform(image) - return image, - - -class CSVLogger: - def __init__(self, filename, columns): - self.filename = Path(filename) - self.columns = columns - if self.filename.exists(): - self.file = open(self.filename, 'a') - else: - self.file = open(self.filename, 'w') - self.write(*self.columns) - - def write(self, *args): - print(*args, sep=',', file=self.file, flush=True) - - -@contextmanager -def tf32_mode(cudnn=None, matmul=None): - """A context manager that sets whether TF32 is allowed on cuDNN or matmul.""" - cudnn_old = torch.backends.cudnn.allow_tf32 - matmul_old = torch.backends.cuda.matmul.allow_tf32 - try: - if cudnn is not None: - torch.backends.cudnn.allow_tf32 = cudnn - if matmul is not None: - torch.backends.cuda.matmul.allow_tf32 = matmul - yield - finally: - if cudnn is not None: - torch.backends.cudnn.allow_tf32 = cudnn_old - if matmul is not None: - torch.backends.cuda.matmul.allow_tf32 = matmul_old