Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-01-30 00:00:26 +08:00)

Commit 185ba7e990 ("Reformat model_base"), parent facf68e7b9
@@ -1,21 +1,23 @@
-import torch
 import logging
+from enum import Enum
+
+import math
+import torch
+
+from . import conds
+from . import latent_formats
+from . import model_management
+from . import ops
+from . import utils
 from .ldm.audio.dit import AudioDiffusionTransformer
 from .ldm.audio.embedders import NumberConditioner
-from .ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
-from .ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
-from .ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
-from .ldm.modules.diffusionmodules.mmdit import OpenAISignatureMMDITWrapper
-from . import model_management
-from . import conds
-from . import ops
-from .ldm.cascade.stage_c import StageC
 from .ldm.cascade.stage_b import StageB
-from enum import Enum
-from . import utils
-from . import latent_formats
-import math
+from .ldm.cascade.stage_c import StageC
+from .ldm.modules.diffusionmodules.mmdit import OpenAISignatureMMDITWrapper
+from .ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
+from .ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
+from .ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation


 class ModelType(Enum):
     EPS = 1
@@ -146,7 +148,7 @@ class BaseModel(torch.nn.Module):

            if denoise_mask is not None:
                if len(denoise_mask.shape) == len(noise.shape):
-                    denoise_mask = denoise_mask[:,:1]
+                    denoise_mask = denoise_mask[:, :1]

                denoise_mask = denoise_mask.reshape((-1, 1, denoise_mask.shape[-2], denoise_mask.shape[-1]))
                if denoise_mask.shape[-2:] != noise.shape[-2:]:
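Note: this block normalizes whatever mask arrives into shape (-1, 1, H, W) and, in the unchanged line that follows the hunk, bilinearly resizes it to the latent resolution. A minimal sketch of that flow, assuming a hypothetical pixel-space single-channel mask:

import torch
import torch.nn.functional as F

noise = torch.randn(1, 4, 64, 64)        # latent-space noise
denoise_mask = torch.rand(1, 512, 512)   # hypothetical pixel-space mask
denoise_mask = denoise_mask.reshape((-1, 1, denoise_mask.shape[-2], denoise_mask.shape[-1]))
if denoise_mask.shape[-2:] != noise.shape[-2:]:
    denoise_mask = F.interpolate(denoise_mask, size=noise.shape[-2:], mode="bilinear")
print(denoise_mask.shape)                # torch.Size([1, 1, 64, 64])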
@@ -158,10 +160,10 @@ class BaseModel(torch.nn.Module):
                    if ck == "mask":
                        cond_concat.append(denoise_mask.to(device))
                    elif ck == "masked_image":
-                        cond_concat.append(concat_latent_image.to(device)) #NOTE: the latent_image should be masked by the mask in pixel space
+                        cond_concat.append(concat_latent_image.to(device)) # NOTE: the latent_image should be masked by the mask in pixel space
                else:
                    if ck == "mask":
-                        cond_concat.append(torch.ones_like(noise)[:,:1])
+                        cond_concat.append(torch.ones_like(noise)[:, :1])
                    elif ck == "masked_image":
                        cond_concat.append(self.blank_inpaint_image_like(noise))
            data = torch.cat(cond_concat, dim=1)
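Note: this loop assembles the inpaint conditioning that gets concatenated onto the model input along the channel dim: a 1-channel mask plus the masked latent image, falling back to an all-ones mask and a blank latent when no denoise mask is supplied. A toy sketch of the resulting channel layout, assuming a 4-channel latent:

import torch

noise = torch.randn(2, 4, 64, 64)
cond_concat = [
    torch.ones_like(noise)[:, :1],  # "mask" fallback: keep every pixel
    torch.zeros_like(noise),        # stand-in for the blank/masked latent image
]
data = torch.cat(cond_concat, dim=1)
print(data.shape)                   # torch.Size([2, 5, 64, 64])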
@@ -230,14 +232,16 @@ class BaseModel(torch.nn.Module):

    def set_inpaint(self):
        self.concat_keys = ("mask", "masked_image")
+
        def blank_inpaint_image_like(latent_image):
            blank_image = torch.ones_like(latent_image)
            # these are the values for "zero" in pixel space translated to latent space
-            blank_image[:,0] *= 0.8223
-            blank_image[:,1] *= -0.6876
-            blank_image[:,2] *= 0.6364
-            blank_image[:,3] *= 0.1380
+            blank_image[:, 0] *= 0.8223
+            blank_image[:, 1] *= -0.6876
+            blank_image[:, 2] *= 0.6364
+            blank_image[:, 3] *= 0.1380
            return blank_image
+
        self.blank_inpaint_image_like = blank_inpaint_image_like

    def memory_required(self, input_shape):
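Note: the nested helper computes the latent that corresponds to an all-zero pixel image without running the VAE, by scaling each of the four SD1.x latent channels by a precomputed constant. A standalone sketch under that assumption (constants copied from the hunk; the function name is hypothetical):

import torch

ZERO_LATENT = (0.8223, -0.6876, 0.6364, 0.1380)  # "zero" pixels in latent space

def blank_inpaint_latent(batch: int, height: int, width: int) -> torch.Tensor:
    latent = torch.ones((batch, 4, height, width))
    for channel, value in enumerate(ZERO_LATENT):
        latent[:, channel] *= value
    return latent

print(blank_inpaint_latent(1, 64, 64)[0, :, 0, 0])  # tensor([0.8223, -0.6876, 0.6364, 0.1380])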
@@ -245,11 +249,11 @@ class BaseModel(torch.nn.Module):
            dtype = self.get_dtype()
            if self.manual_cast_dtype is not None:
                dtype = self.manual_cast_dtype
-            #TODO: this needs to be tweaked
+            # TODO: this needs to be tweaked
            area = input_shape[0] * math.prod(input_shape[2:])
            return (area * model_management.dtype_size(dtype) / 50) * (1024 * 1024)
        else:
-            #TODO: this formula might be too aggressive since I tweaked the sub-quad and split algorithms to use less memory.
+            # TODO: this formula might be too aggressive since I tweaked the sub-quad and split algorithms to use less memory.
            area = input_shape[0] * math.prod(input_shape[2:])
            return (((area * 0.6) / 0.9) + 1024) * (1024 * 1024)

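Note: both branches of memory_required estimate attention memory from an "area" term (batch size times spatial elements); the first branch also scales by the element size of the compute dtype. A free-standing sketch of the same arithmetic (the function name and fast_attention flag are stand-ins) with a worked example assuming fp16 and a batch-1 128x128 latent:

import math

def memory_required_bytes(input_shape, dtype_size=2, fast_attention=True):
    # Mirrors the two branches in the hunk above; dtype_size=2 assumes fp16.
    area = input_shape[0] * math.prod(input_shape[2:])
    if fast_attention:
        return (area * dtype_size / 50) * (1024 * 1024)
    return (((area * 0.6) / 0.9) + 1024) * (1024 * 1024)

# area = 1 * 128 * 128 = 16384, and 16384 * 2 / 50 = 655.36, i.e. about 655 MiB
print(memory_required_bytes((1, 4, 128, 128)) / 1024**2)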
@@ -278,6 +282,7 @@ def unclip_adm(unclip_conditioning, device, noise_augmentor, noise_augment_merge

    return adm_out

+
class SD21UNCLIP(BaseModel):
    def __init__(self, model_config, noise_aug_config, model_type=ModelType.V_PREDICTION, device=None):
        super().__init__(model_config, model_type, device=device)
@@ -291,12 +296,14 @@ class SD21UNCLIP(BaseModel):
        else:
            return unclip_adm(unclip_conditioning, device, self.noise_augmentor, kwargs.get("unclip_noise_augment_merge", 0.05), kwargs.get("seed", 0) - 10)

+
def sdxl_pooled(args, noise_augmentor):
    if "unclip_conditioning" in args:
-        return unclip_adm(args.get("unclip_conditioning", None), args["device"], noise_augmentor, seed=args.get("seed", 0) - 10)[:,:1280]
+        return unclip_adm(args.get("unclip_conditioning", None), args["device"], noise_augmentor, seed=args.get("seed", 0) - 10)[:, :1280]
    else:
        return args["pooled_output"]

+
class SDXLRefiner(BaseModel):
    def __init__(self, model_config, model_type=ModelType.EPS, device=None):
        super().__init__(model_config, model_type, device=device)
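Note: sdxl_pooled picks between two pooled-conditioning paths: with unCLIP conditioning present it noise-augments the embedding and keeps only its first 1280 features, otherwise it forwards the text encoder's pooled output unchanged. A sketch of just the slice, using a stand-in width for the augmented embedding:

import torch

adm_out = torch.randn(2, 2048)  # hypothetical stand-in for unclip_adm()'s output
pooled = adm_out[:, :1280]      # keep the leading 1280 features, as in the hunk
print(pooled.shape)             # torch.Size([2, 1280])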
@@ -324,6 +331,7 @@ class SDXLRefiner(BaseModel):
        flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0).repeat(clip_pooled.shape[0], 1)
        return torch.cat((clip_pooled.to(flat.device), flat), dim=1)

+
class SDXL(BaseModel):
    def __init__(self, model_config, model_type=ModelType.EPS, device=None):
        super().__init__(model_config, model_type, device=device)
@@ -349,6 +357,7 @@ class SDXL(BaseModel):
        flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0).repeat(clip_pooled.shape[0], 1)
        return torch.cat((clip_pooled.to(flat.device), flat), dim=1)

+
class SVD_img2vid(BaseModel):
    def __init__(self, model_config, model_type=ModelType.V_PREDICTION_EDM, device=None):
        super().__init__(model_config, model_type, device=device)
@@ -397,6 +406,7 @@ class SVD_img2vid(BaseModel):
        out['num_video_frames'] = conds.CONDConstant(noise.shape[0])
        return out

+
class SV3D_u(SVD_img2vid):
    def encode_adm(self, **kwargs):
        augmentation = kwargs.get("augmentation_level", 0)
@@ -407,6 +417,7 @@ class SV3D_u(SVD_img2vid):
        flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0)
        return flat

+
class SV3D_p(SVD_img2vid):
    def __init__(self, model_config, model_type=ModelType.V_PREDICTION_EDM, device=None):
        super().__init__(model_config, model_type, device=device)
@@ -414,7 +425,7 @@ class SV3D_p(SVD_img2vid):

    def encode_adm(self, **kwargs):
        augmentation = kwargs.get("augmentation_level", 0)
-        elevation = kwargs.get("elevation", 0) #elevation and azimuth are in degrees here
+        elevation = kwargs.get("elevation", 0) # elevation and azimuth are in degrees here
        azimuth = kwargs.get("azimuth", 0)
        noise = kwargs.get("noise", None)

@@ -457,6 +468,7 @@ class Stable_Zero123(BaseModel):
        out['c_crossattn'] = conds.CONDCrossAttn(cross_attn)
        return out

+
class SD_X4Upscaler(BaseModel):
    def __init__(self, model_config, model_type=ModelType.V_PREDICTION, device=None):
        super().__init__(model_config, model_type, device=device)
@@ -474,7 +486,7 @@ class SD_X4Upscaler(BaseModel):
        noise_level = round((self.noise_augmentor.max_noise_level) * noise_augment)

        if image is None:
-            image = torch.zeros_like(noise)[:,:3]
+            image = torch.zeros_like(noise)[:, :3]

        if image.shape[1:] != noise.shape[1:]:
            image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
@@ -489,6 +501,7 @@ class SD_X4Upscaler(BaseModel):
        out['y'] = conds.CONDRegular(noise_level)
        return out

+
class IP2P(BaseModel):
    def process_ip2p_image_in(self, image):
        return None
@@ -514,18 +527,20 @@ class IP2P(BaseModel):
        out['y'] = conds.CONDRegular(adm)
        return out

+
class SD15_instructpix2pix(IP2P, BaseModel):
    def __init__(self, model_config, model_type=ModelType.EPS, device=None):
        super().__init__(model_config, model_type, device=device)
        self.process_ip2p_image_in = lambda image: image

+
class SDXL_instructpix2pix(IP2P, SDXL):
    def __init__(self, model_config, model_type=ModelType.EPS, device=None):
        super().__init__(model_config, model_type, device=device)
        if model_type == ModelType.V_PREDICTION_EDM:
-            self.process_ip2p_image_in = lambda image: latent_formats.SDXL().process_in(image) #cosxl ip2p
+            self.process_ip2p_image_in = lambda image: latent_formats.SDXL().process_in(image) # cosxl ip2p
        else:
-            self.process_ip2p_image_in = lambda image: image #diffusers ip2p
+            self.process_ip2p_image_in = lambda image: image # diffusers ip2p


class StableCascade_C(BaseModel):
@@ -570,7 +585,7 @@ class StableCascade_B(BaseModel):
        if clip_text_pooled is not None:
            out['clip'] = conds.CONDRegular(clip_text_pooled)

-        #size of prior doesn't really matter if zeros because it gets resized but I still want it to get batched
+        # size of prior doesn't really matter if zeros because it gets resized but I still want it to get batched
        prior = kwargs.get("stable_cascade_prior", torch.zeros((1, 16, (noise.shape[2] * 4) // 42, (noise.shape[3] * 4) // 42), dtype=noise.dtype, layout=noise.layout, device=noise.device))

        out["effnet"] = conds.CONDRegular(prior)
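Note: the zero prior's shape converts the Stage B latent size back to pixels and then to the prior's resolution, hence (noise.shape[2] * 4) // 42; this matches Stable Cascade's 4x compression for Stage B latents and 42x for the Stage C prior, assuming those factors hold for this checkpoint. A quick arithmetic check:

# A 1024px-tall image: Stage B latent height 256, Stage C prior height 24.
stage_b_h = 256
print((stage_b_h * 4) // 42)  # (256 * 4) // 42 == 24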
@@ -597,7 +612,7 @@ class SD3(BaseModel):
            dtype = self.get_dtype()
            if self.manual_cast_dtype is not None:
                dtype = self.manual_cast_dtype
-            #TODO: this probably needs to be tweaked
+            # TODO: this probably needs to be tweaked
            area = input_shape[0] * input_shape[2] * input_shape[3]
            return (area * model_management.dtype_size(dtype) * 0.012) * (1024 * 1024)
        else: