Reformat model_base

This commit is contained in:
doctorpangloss 2024-06-19 20:43:20 -07:00
parent facf68e7b9
commit 185ba7e990

View File

@ -1,21 +1,23 @@
import torch
import logging import logging
from enum import Enum
import math
import torch
from . import conds
from . import latent_formats
from . import model_management
from . import ops
from . import utils
from .ldm.audio.dit import AudioDiffusionTransformer from .ldm.audio.dit import AudioDiffusionTransformer
from .ldm.audio.embedders import NumberConditioner from .ldm.audio.embedders import NumberConditioner
from .ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
from .ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
from .ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
from .ldm.modules.diffusionmodules.mmdit import OpenAISignatureMMDITWrapper
from . import model_management
from . import conds
from . import ops
from .ldm.cascade.stage_c import StageC
from .ldm.cascade.stage_b import StageB from .ldm.cascade.stage_b import StageB
from enum import Enum from .ldm.cascade.stage_c import StageC
from . import utils from .ldm.modules.diffusionmodules.mmdit import OpenAISignatureMMDITWrapper
from . import latent_formats from .ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
import math from .ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
from .ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
class ModelType(Enum): class ModelType(Enum):
EPS = 1 EPS = 1
@ -146,7 +148,7 @@ class BaseModel(torch.nn.Module):
if denoise_mask is not None: if denoise_mask is not None:
if len(denoise_mask.shape) == len(noise.shape): if len(denoise_mask.shape) == len(noise.shape):
denoise_mask = denoise_mask[:,:1] denoise_mask = denoise_mask[:, :1]
denoise_mask = denoise_mask.reshape((-1, 1, denoise_mask.shape[-2], denoise_mask.shape[-1])) denoise_mask = denoise_mask.reshape((-1, 1, denoise_mask.shape[-2], denoise_mask.shape[-1]))
if denoise_mask.shape[-2:] != noise.shape[-2:]: if denoise_mask.shape[-2:] != noise.shape[-2:]:
@ -158,10 +160,10 @@ class BaseModel(torch.nn.Module):
if ck == "mask": if ck == "mask":
cond_concat.append(denoise_mask.to(device)) cond_concat.append(denoise_mask.to(device))
elif ck == "masked_image": elif ck == "masked_image":
cond_concat.append(concat_latent_image.to(device)) #NOTE: the latent_image should be masked by the mask in pixel space cond_concat.append(concat_latent_image.to(device)) # NOTE: the latent_image should be masked by the mask in pixel space
else: else:
if ck == "mask": if ck == "mask":
cond_concat.append(torch.ones_like(noise)[:,:1]) cond_concat.append(torch.ones_like(noise)[:, :1])
elif ck == "masked_image": elif ck == "masked_image":
cond_concat.append(self.blank_inpaint_image_like(noise)) cond_concat.append(self.blank_inpaint_image_like(noise))
data = torch.cat(cond_concat, dim=1) data = torch.cat(cond_concat, dim=1)
@ -230,14 +232,16 @@ class BaseModel(torch.nn.Module):
def set_inpaint(self): def set_inpaint(self):
self.concat_keys = ("mask", "masked_image") self.concat_keys = ("mask", "masked_image")
def blank_inpaint_image_like(latent_image): def blank_inpaint_image_like(latent_image):
blank_image = torch.ones_like(latent_image) blank_image = torch.ones_like(latent_image)
# these are the values for "zero" in pixel space translated to latent space # these are the values for "zero" in pixel space translated to latent space
blank_image[:,0] *= 0.8223 blank_image[:, 0] *= 0.8223
blank_image[:,1] *= -0.6876 blank_image[:, 1] *= -0.6876
blank_image[:,2] *= 0.6364 blank_image[:, 2] *= 0.6364
blank_image[:,3] *= 0.1380 blank_image[:, 3] *= 0.1380
return blank_image return blank_image
self.blank_inpaint_image_like = blank_inpaint_image_like self.blank_inpaint_image_like = blank_inpaint_image_like
def memory_required(self, input_shape): def memory_required(self, input_shape):
@ -245,11 +249,11 @@ class BaseModel(torch.nn.Module):
dtype = self.get_dtype() dtype = self.get_dtype()
if self.manual_cast_dtype is not None: if self.manual_cast_dtype is not None:
dtype = self.manual_cast_dtype dtype = self.manual_cast_dtype
#TODO: this needs to be tweaked # TODO: this needs to be tweaked
area = input_shape[0] * math.prod(input_shape[2:]) area = input_shape[0] * math.prod(input_shape[2:])
return (area * model_management.dtype_size(dtype) / 50) * (1024 * 1024) return (area * model_management.dtype_size(dtype) / 50) * (1024 * 1024)
else: else:
#TODO: this formula might be too aggressive since I tweaked the sub-quad and split algorithms to use less memory. # TODO: this formula might be too aggressive since I tweaked the sub-quad and split algorithms to use less memory.
area = input_shape[0] * math.prod(input_shape[2:]) area = input_shape[0] * math.prod(input_shape[2:])
return (((area * 0.6) / 0.9) + 1024) * (1024 * 1024) return (((area * 0.6) / 0.9) + 1024) * (1024 * 1024)
@ -278,6 +282,7 @@ def unclip_adm(unclip_conditioning, device, noise_augmentor, noise_augment_merge
return adm_out return adm_out
class SD21UNCLIP(BaseModel): class SD21UNCLIP(BaseModel):
def __init__(self, model_config, noise_aug_config, model_type=ModelType.V_PREDICTION, device=None): def __init__(self, model_config, noise_aug_config, model_type=ModelType.V_PREDICTION, device=None):
super().__init__(model_config, model_type, device=device) super().__init__(model_config, model_type, device=device)
@ -291,12 +296,14 @@ class SD21UNCLIP(BaseModel):
else: else:
return unclip_adm(unclip_conditioning, device, self.noise_augmentor, kwargs.get("unclip_noise_augment_merge", 0.05), kwargs.get("seed", 0) - 10) return unclip_adm(unclip_conditioning, device, self.noise_augmentor, kwargs.get("unclip_noise_augment_merge", 0.05), kwargs.get("seed", 0) - 10)
def sdxl_pooled(args, noise_augmentor): def sdxl_pooled(args, noise_augmentor):
if "unclip_conditioning" in args: if "unclip_conditioning" in args:
return unclip_adm(args.get("unclip_conditioning", None), args["device"], noise_augmentor, seed=args.get("seed", 0) - 10)[:,:1280] return unclip_adm(args.get("unclip_conditioning", None), args["device"], noise_augmentor, seed=args.get("seed", 0) - 10)[:, :1280]
else: else:
return args["pooled_output"] return args["pooled_output"]
class SDXLRefiner(BaseModel): class SDXLRefiner(BaseModel):
def __init__(self, model_config, model_type=ModelType.EPS, device=None): def __init__(self, model_config, model_type=ModelType.EPS, device=None):
super().__init__(model_config, model_type, device=device) super().__init__(model_config, model_type, device=device)
@ -324,6 +331,7 @@ class SDXLRefiner(BaseModel):
flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0).repeat(clip_pooled.shape[0], 1) flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0).repeat(clip_pooled.shape[0], 1)
return torch.cat((clip_pooled.to(flat.device), flat), dim=1) return torch.cat((clip_pooled.to(flat.device), flat), dim=1)
class SDXL(BaseModel): class SDXL(BaseModel):
def __init__(self, model_config, model_type=ModelType.EPS, device=None): def __init__(self, model_config, model_type=ModelType.EPS, device=None):
super().__init__(model_config, model_type, device=device) super().__init__(model_config, model_type, device=device)
@ -349,6 +357,7 @@ class SDXL(BaseModel):
flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0).repeat(clip_pooled.shape[0], 1) flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0).repeat(clip_pooled.shape[0], 1)
return torch.cat((clip_pooled.to(flat.device), flat), dim=1) return torch.cat((clip_pooled.to(flat.device), flat), dim=1)
class SVD_img2vid(BaseModel): class SVD_img2vid(BaseModel):
def __init__(self, model_config, model_type=ModelType.V_PREDICTION_EDM, device=None): def __init__(self, model_config, model_type=ModelType.V_PREDICTION_EDM, device=None):
super().__init__(model_config, model_type, device=device) super().__init__(model_config, model_type, device=device)
@ -397,6 +406,7 @@ class SVD_img2vid(BaseModel):
out['num_video_frames'] = conds.CONDConstant(noise.shape[0]) out['num_video_frames'] = conds.CONDConstant(noise.shape[0])
return out return out
class SV3D_u(SVD_img2vid): class SV3D_u(SVD_img2vid):
def encode_adm(self, **kwargs): def encode_adm(self, **kwargs):
augmentation = kwargs.get("augmentation_level", 0) augmentation = kwargs.get("augmentation_level", 0)
@ -407,6 +417,7 @@ class SV3D_u(SVD_img2vid):
flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0) flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0)
return flat return flat
class SV3D_p(SVD_img2vid): class SV3D_p(SVD_img2vid):
def __init__(self, model_config, model_type=ModelType.V_PREDICTION_EDM, device=None): def __init__(self, model_config, model_type=ModelType.V_PREDICTION_EDM, device=None):
super().__init__(model_config, model_type, device=device) super().__init__(model_config, model_type, device=device)
@ -414,7 +425,7 @@ class SV3D_p(SVD_img2vid):
def encode_adm(self, **kwargs): def encode_adm(self, **kwargs):
augmentation = kwargs.get("augmentation_level", 0) augmentation = kwargs.get("augmentation_level", 0)
elevation = kwargs.get("elevation", 0) #elevation and azimuth are in degrees here elevation = kwargs.get("elevation", 0) # elevation and azimuth are in degrees here
azimuth = kwargs.get("azimuth", 0) azimuth = kwargs.get("azimuth", 0)
noise = kwargs.get("noise", None) noise = kwargs.get("noise", None)
@ -457,6 +468,7 @@ class Stable_Zero123(BaseModel):
out['c_crossattn'] = conds.CONDCrossAttn(cross_attn) out['c_crossattn'] = conds.CONDCrossAttn(cross_attn)
return out return out
class SD_X4Upscaler(BaseModel): class SD_X4Upscaler(BaseModel):
def __init__(self, model_config, model_type=ModelType.V_PREDICTION, device=None): def __init__(self, model_config, model_type=ModelType.V_PREDICTION, device=None):
super().__init__(model_config, model_type, device=device) super().__init__(model_config, model_type, device=device)
@ -474,7 +486,7 @@ class SD_X4Upscaler(BaseModel):
noise_level = round((self.noise_augmentor.max_noise_level) * noise_augment) noise_level = round((self.noise_augmentor.max_noise_level) * noise_augment)
if image is None: if image is None:
image = torch.zeros_like(noise)[:,:3] image = torch.zeros_like(noise)[:, :3]
if image.shape[1:] != noise.shape[1:]: if image.shape[1:] != noise.shape[1:]:
image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center") image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
@ -489,6 +501,7 @@ class SD_X4Upscaler(BaseModel):
out['y'] = conds.CONDRegular(noise_level) out['y'] = conds.CONDRegular(noise_level)
return out return out
class IP2P(BaseModel): class IP2P(BaseModel):
def process_ip2p_image_in(self, image): def process_ip2p_image_in(self, image):
return None return None
@ -514,18 +527,20 @@ class IP2P(BaseModel):
out['y'] = conds.CONDRegular(adm) out['y'] = conds.CONDRegular(adm)
return out return out
class SD15_instructpix2pix(IP2P, BaseModel): class SD15_instructpix2pix(IP2P, BaseModel):
def __init__(self, model_config, model_type=ModelType.EPS, device=None): def __init__(self, model_config, model_type=ModelType.EPS, device=None):
super().__init__(model_config, model_type, device=device) super().__init__(model_config, model_type, device=device)
self.process_ip2p_image_in = lambda image: image self.process_ip2p_image_in = lambda image: image
class SDXL_instructpix2pix(IP2P, SDXL): class SDXL_instructpix2pix(IP2P, SDXL):
def __init__(self, model_config, model_type=ModelType.EPS, device=None): def __init__(self, model_config, model_type=ModelType.EPS, device=None):
super().__init__(model_config, model_type, device=device) super().__init__(model_config, model_type, device=device)
if model_type == ModelType.V_PREDICTION_EDM: if model_type == ModelType.V_PREDICTION_EDM:
self.process_ip2p_image_in = lambda image: latent_formats.SDXL().process_in(image) #cosxl ip2p self.process_ip2p_image_in = lambda image: latent_formats.SDXL().process_in(image) # cosxl ip2p
else: else:
self.process_ip2p_image_in = lambda image: image #diffusers ip2p self.process_ip2p_image_in = lambda image: image # diffusers ip2p
class StableCascade_C(BaseModel): class StableCascade_C(BaseModel):
@ -570,7 +585,7 @@ class StableCascade_B(BaseModel):
if clip_text_pooled is not None: if clip_text_pooled is not None:
out['clip'] = conds.CONDRegular(clip_text_pooled) out['clip'] = conds.CONDRegular(clip_text_pooled)
#size of prior doesn't really matter if zeros because it gets resized but I still want it to get batched # size of prior doesn't really matter if zeros because it gets resized but I still want it to get batched
prior = kwargs.get("stable_cascade_prior", torch.zeros((1, 16, (noise.shape[2] * 4) // 42, (noise.shape[3] * 4) // 42), dtype=noise.dtype, layout=noise.layout, device=noise.device)) prior = kwargs.get("stable_cascade_prior", torch.zeros((1, 16, (noise.shape[2] * 4) // 42, (noise.shape[3] * 4) // 42), dtype=noise.dtype, layout=noise.layout, device=noise.device))
out["effnet"] = conds.CONDRegular(prior) out["effnet"] = conds.CONDRegular(prior)
@ -597,7 +612,7 @@ class SD3(BaseModel):
dtype = self.get_dtype() dtype = self.get_dtype()
if self.manual_cast_dtype is not None: if self.manual_cast_dtype is not None:
dtype = self.manual_cast_dtype dtype = self.manual_cast_dtype
#TODO: this probably needs to be tweaked # TODO: this probably needs to be tweaked
area = input_shape[0] * input_shape[2] * input_shape[3] area = input_shape[0] * input_shape[2] * input_shape[3]
return (area * model_management.dtype_size(dtype) * 0.012) * (1024 * 1024) return (area * model_management.dtype_size(dtype) * 0.012) * (1024 * 1024)
else: else: