Reformat model_base

This commit is contained in:
doctorpangloss 2024-06-19 20:43:20 -07:00
parent facf68e7b9
commit 185ba7e990

View File

@ -1,21 +1,23 @@
import torch
import logging
from enum import Enum
import math
import torch
from . import conds
from . import latent_formats
from . import model_management
from . import ops
from . import utils
from .ldm.audio.dit import AudioDiffusionTransformer
from .ldm.audio.embedders import NumberConditioner
from .ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
from .ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
from .ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
from .ldm.modules.diffusionmodules.mmdit import OpenAISignatureMMDITWrapper
from . import model_management
from . import conds
from . import ops
from .ldm.cascade.stage_c import StageC
from .ldm.cascade.stage_b import StageB
from enum import Enum
from . import utils
from . import latent_formats
import math
from .ldm.cascade.stage_c import StageC
from .ldm.modules.diffusionmodules.mmdit import OpenAISignatureMMDITWrapper
from .ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
from .ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
from .ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
class ModelType(Enum):
EPS = 1
@ -146,7 +148,7 @@ class BaseModel(torch.nn.Module):
if denoise_mask is not None:
if len(denoise_mask.shape) == len(noise.shape):
denoise_mask = denoise_mask[:,:1]
denoise_mask = denoise_mask[:, :1]
denoise_mask = denoise_mask.reshape((-1, 1, denoise_mask.shape[-2], denoise_mask.shape[-1]))
if denoise_mask.shape[-2:] != noise.shape[-2:]:
@ -158,10 +160,10 @@ class BaseModel(torch.nn.Module):
if ck == "mask":
cond_concat.append(denoise_mask.to(device))
elif ck == "masked_image":
cond_concat.append(concat_latent_image.to(device)) #NOTE: the latent_image should be masked by the mask in pixel space
cond_concat.append(concat_latent_image.to(device)) # NOTE: the latent_image should be masked by the mask in pixel space
else:
if ck == "mask":
cond_concat.append(torch.ones_like(noise)[:,:1])
cond_concat.append(torch.ones_like(noise)[:, :1])
elif ck == "masked_image":
cond_concat.append(self.blank_inpaint_image_like(noise))
data = torch.cat(cond_concat, dim=1)
@ -230,14 +232,16 @@ class BaseModel(torch.nn.Module):
def set_inpaint(self):
self.concat_keys = ("mask", "masked_image")
def blank_inpaint_image_like(latent_image):
blank_image = torch.ones_like(latent_image)
# these are the values for "zero" in pixel space translated to latent space
blank_image[:,0] *= 0.8223
blank_image[:,1] *= -0.6876
blank_image[:,2] *= 0.6364
blank_image[:,3] *= 0.1380
blank_image[:, 0] *= 0.8223
blank_image[:, 1] *= -0.6876
blank_image[:, 2] *= 0.6364
blank_image[:, 3] *= 0.1380
return blank_image
self.blank_inpaint_image_like = blank_inpaint_image_like
def memory_required(self, input_shape):
@ -245,11 +249,11 @@ class BaseModel(torch.nn.Module):
dtype = self.get_dtype()
if self.manual_cast_dtype is not None:
dtype = self.manual_cast_dtype
#TODO: this needs to be tweaked
# TODO: this needs to be tweaked
area = input_shape[0] * math.prod(input_shape[2:])
return (area * model_management.dtype_size(dtype) / 50) * (1024 * 1024)
else:
#TODO: this formula might be too aggressive since I tweaked the sub-quad and split algorithms to use less memory.
# TODO: this formula might be too aggressive since I tweaked the sub-quad and split algorithms to use less memory.
area = input_shape[0] * math.prod(input_shape[2:])
return (((area * 0.6) / 0.9) + 1024) * (1024 * 1024)
@ -278,6 +282,7 @@ def unclip_adm(unclip_conditioning, device, noise_augmentor, noise_augment_merge
return adm_out
class SD21UNCLIP(BaseModel):
def __init__(self, model_config, noise_aug_config, model_type=ModelType.V_PREDICTION, device=None):
super().__init__(model_config, model_type, device=device)
@ -291,12 +296,14 @@ class SD21UNCLIP(BaseModel):
else:
return unclip_adm(unclip_conditioning, device, self.noise_augmentor, kwargs.get("unclip_noise_augment_merge", 0.05), kwargs.get("seed", 0) - 10)
def sdxl_pooled(args, noise_augmentor):
if "unclip_conditioning" in args:
return unclip_adm(args.get("unclip_conditioning", None), args["device"], noise_augmentor, seed=args.get("seed", 0) - 10)[:,:1280]
return unclip_adm(args.get("unclip_conditioning", None), args["device"], noise_augmentor, seed=args.get("seed", 0) - 10)[:, :1280]
else:
return args["pooled_output"]
class SDXLRefiner(BaseModel):
def __init__(self, model_config, model_type=ModelType.EPS, device=None):
super().__init__(model_config, model_type, device=device)
@ -324,6 +331,7 @@ class SDXLRefiner(BaseModel):
flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0).repeat(clip_pooled.shape[0], 1)
return torch.cat((clip_pooled.to(flat.device), flat), dim=1)
class SDXL(BaseModel):
def __init__(self, model_config, model_type=ModelType.EPS, device=None):
super().__init__(model_config, model_type, device=device)
@ -349,6 +357,7 @@ class SDXL(BaseModel):
flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0).repeat(clip_pooled.shape[0], 1)
return torch.cat((clip_pooled.to(flat.device), flat), dim=1)
class SVD_img2vid(BaseModel):
def __init__(self, model_config, model_type=ModelType.V_PREDICTION_EDM, device=None):
super().__init__(model_config, model_type, device=device)
@ -397,6 +406,7 @@ class SVD_img2vid(BaseModel):
out['num_video_frames'] = conds.CONDConstant(noise.shape[0])
return out
class SV3D_u(SVD_img2vid):
def encode_adm(self, **kwargs):
augmentation = kwargs.get("augmentation_level", 0)
@ -407,6 +417,7 @@ class SV3D_u(SVD_img2vid):
flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0)
return flat
class SV3D_p(SVD_img2vid):
def __init__(self, model_config, model_type=ModelType.V_PREDICTION_EDM, device=None):
super().__init__(model_config, model_type, device=device)
@ -414,7 +425,7 @@ class SV3D_p(SVD_img2vid):
def encode_adm(self, **kwargs):
augmentation = kwargs.get("augmentation_level", 0)
elevation = kwargs.get("elevation", 0) #elevation and azimuth are in degrees here
elevation = kwargs.get("elevation", 0) # elevation and azimuth are in degrees here
azimuth = kwargs.get("azimuth", 0)
noise = kwargs.get("noise", None)
@ -457,6 +468,7 @@ class Stable_Zero123(BaseModel):
out['c_crossattn'] = conds.CONDCrossAttn(cross_attn)
return out
class SD_X4Upscaler(BaseModel):
def __init__(self, model_config, model_type=ModelType.V_PREDICTION, device=None):
super().__init__(model_config, model_type, device=device)
@ -474,7 +486,7 @@ class SD_X4Upscaler(BaseModel):
noise_level = round((self.noise_augmentor.max_noise_level) * noise_augment)
if image is None:
image = torch.zeros_like(noise)[:,:3]
image = torch.zeros_like(noise)[:, :3]
if image.shape[1:] != noise.shape[1:]:
image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
@ -489,6 +501,7 @@ class SD_X4Upscaler(BaseModel):
out['y'] = conds.CONDRegular(noise_level)
return out
class IP2P(BaseModel):
def process_ip2p_image_in(self, image):
return None
@ -514,18 +527,20 @@ class IP2P(BaseModel):
out['y'] = conds.CONDRegular(adm)
return out
class SD15_instructpix2pix(IP2P, BaseModel):
def __init__(self, model_config, model_type=ModelType.EPS, device=None):
super().__init__(model_config, model_type, device=device)
self.process_ip2p_image_in = lambda image: image
class SDXL_instructpix2pix(IP2P, SDXL):
def __init__(self, model_config, model_type=ModelType.EPS, device=None):
super().__init__(model_config, model_type, device=device)
if model_type == ModelType.V_PREDICTION_EDM:
self.process_ip2p_image_in = lambda image: latent_formats.SDXL().process_in(image) #cosxl ip2p
self.process_ip2p_image_in = lambda image: latent_formats.SDXL().process_in(image) # cosxl ip2p
else:
self.process_ip2p_image_in = lambda image: image #diffusers ip2p
self.process_ip2p_image_in = lambda image: image # diffusers ip2p
class StableCascade_C(BaseModel):
@ -570,7 +585,7 @@ class StableCascade_B(BaseModel):
if clip_text_pooled is not None:
out['clip'] = conds.CONDRegular(clip_text_pooled)
#size of prior doesn't really matter if zeros because it gets resized but I still want it to get batched
# size of prior doesn't really matter if zeros because it gets resized but I still want it to get batched
prior = kwargs.get("stable_cascade_prior", torch.zeros((1, 16, (noise.shape[2] * 4) // 42, (noise.shape[3] * 4) // 42), dtype=noise.dtype, layout=noise.layout, device=noise.device))
out["effnet"] = conds.CONDRegular(prior)
@ -597,7 +612,7 @@ class SD3(BaseModel):
dtype = self.get_dtype()
if self.manual_cast_dtype is not None:
dtype = self.manual_cast_dtype
#TODO: this probably needs to be tweaked
# TODO: this probably needs to be tweaked
area = input_shape[0] * input_shape[2] * input_shape[3]
return (area * model_management.dtype_size(dtype) * 0.012) * (1024 * 1024)
else: