mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-01-11 06:40:48 +08:00)

commit 185ba7e990
parent facf68e7b9

    Reformat model_base
@@ -1,21 +1,23 @@
-import torch
 import logging
+from enum import Enum
+
+import math
+import torch
+
+from . import conds
+from . import latent_formats
+from . import model_management
+from . import ops
+from . import utils
 from .ldm.audio.dit import AudioDiffusionTransformer
 from .ldm.audio.embedders import NumberConditioner
-from .ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
-from .ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
-from .ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
-from .ldm.modules.diffusionmodules.mmdit import OpenAISignatureMMDITWrapper
-from . import model_management
-from . import conds
-from . import ops
-from .ldm.cascade.stage_c import StageC
 from .ldm.cascade.stage_b import StageB
-from enum import Enum
-from . import utils
-from . import latent_formats
-import math
+from .ldm.cascade.stage_c import StageC
+from .ldm.modules.diffusionmodules.mmdit import OpenAISignatureMMDITWrapper
+from .ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
+from .ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
+from .ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
 
 
 class ModelType(Enum):
     EPS = 1
@@ -146,7 +148,7 @@ class BaseModel(torch.nn.Module):
 
         if denoise_mask is not None:
             if len(denoise_mask.shape) == len(noise.shape):
-                denoise_mask = denoise_mask[:,:1]
+                denoise_mask = denoise_mask[:, :1]
 
             denoise_mask = denoise_mask.reshape((-1, 1, denoise_mask.shape[-2], denoise_mask.shape[-1]))
             if denoise_mask.shape[-2:] != noise.shape[-2:]:
@@ -158,10 +160,10 @@ class BaseModel(torch.nn.Module):
                 if ck == "mask":
                     cond_concat.append(denoise_mask.to(device))
                 elif ck == "masked_image":
-                    cond_concat.append(concat_latent_image.to(device)) #NOTE: the latent_image should be masked by the mask in pixel space
+                    cond_concat.append(concat_latent_image.to(device))  # NOTE: the latent_image should be masked by the mask in pixel space
             else:
                 if ck == "mask":
-                    cond_concat.append(torch.ones_like(noise)[:,:1])
+                    cond_concat.append(torch.ones_like(noise)[:, :1])
                 elif ck == "masked_image":
                     cond_concat.append(self.blank_inpaint_image_like(noise))
             data = torch.cat(cond_concat, dim=1)
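Context for the hunk above: the concat keys feed an inpainting UNet that takes extra input channels alongside the noisy latent. A minimal sketch of the resulting shapes, assuming an SD1.5-style inpaint model (4 latent channels + 1 mask channel + 4 masked-image channels = 9 UNet input channels); the tensor sizes are illustrative, not taken from the commit:

import torch

noise = torch.randn(1, 4, 64, 64)           # noisy latent
denoise_mask = torch.ones(1, 1, 64, 64)     # "mask" concat key: 1 = denoise this pixel
masked_latent = torch.zeros(1, 4, 64, 64)   # "masked_image" concat key: VAE-encoded masked image

data = torch.cat([denoise_mask, masked_latent], dim=1)
unet_in = torch.cat((noise, data), dim=1)
print(unet_in.shape)                        # torch.Size([1, 9, 64, 64])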
@@ -230,14 +232,16 @@ class BaseModel(torch.nn.Module):
 
     def set_inpaint(self):
         self.concat_keys = ("mask", "masked_image")
+
         def blank_inpaint_image_like(latent_image):
             blank_image = torch.ones_like(latent_image)
             # these are the values for "zero" in pixel space translated to latent space
-            blank_image[:,0] *= 0.8223
-            blank_image[:,1] *= -0.6876
-            blank_image[:,2] *= 0.6364
-            blank_image[:,3] *= 0.1380
+            blank_image[:, 0] *= 0.8223
+            blank_image[:, 1] *= -0.6876
+            blank_image[:, 2] *= 0.6364
+            blank_image[:, 3] *= 0.1380
             return blank_image
+
         self.blank_inpaint_image_like = blank_inpaint_image_like
 
     def memory_required(self, input_shape):
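The four constants in blank_inpaint_image_like are the per-channel latent values that a black (pixel-space zero) image maps to under the SD1.5 VAE. A rough way to reproduce them, assuming `vae` is a loaded SD1.5 VAE wrapper whose encode() returns scaled latents; the exact numbers vary slightly by checkpoint:

import torch

# Hypothetical check, not part of the commit: encode a black image and
# inspect the per-channel means of the resulting latent.
black = torch.zeros(1, 512, 512, 3)  # pixel-space "zero" image, values in 0..1
latent = vae.encode(black)           # assumed -> tensor of shape (1, 4, 64, 64)
print(latent.mean(dim=(0, 2, 3)))    # approx. [0.8223, -0.6876, 0.6364, 0.1380]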
@@ -245,11 +249,11 @@ class BaseModel(torch.nn.Module):
             dtype = self.get_dtype()
             if self.manual_cast_dtype is not None:
                 dtype = self.manual_cast_dtype
-            #TODO: this needs to be tweaked
+            # TODO: this needs to be tweaked
             area = input_shape[0] * math.prod(input_shape[2:])
             return (area * model_management.dtype_size(dtype) / 50) * (1024 * 1024)
         else:
-            #TODO: this formula might be too aggressive since I tweaked the sub-quad and split algorithms to use less memory.
+            # TODO: this formula might be too aggressive since I tweaked the sub-quad and split algorithms to use less memory.
             area = input_shape[0] * math.prod(input_shape[2:])
             return (((area * 0.6) / 0.9) + 1024) * (1024 * 1024)
 
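For a sense of scale, the first branch of memory_required works out as follows for a 1024x1024 image at fp16, with positive and negative conditioning batched together; the latent input_shape (2, 4, 128, 128) is an assumption for illustration:

import math

input_shape = (2, 4, 128, 128)  # batch of 2 SD latents for a 1024x1024 image
dtype_size = 2                  # bytes per element at fp16

area = input_shape[0] * math.prod(input_shape[2:])  # 2 * 128 * 128 = 32768
mem = (area * dtype_size / 50) * (1024 * 1024)      # in bytes
print(mem / (1024 ** 3))                            # ~1.28 GiB for the model pass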
@@ -278,6 +282,7 @@ def unclip_adm(unclip_conditioning, device, noise_augmentor, noise_augment_merge
 
     return adm_out
 
+
 class SD21UNCLIP(BaseModel):
     def __init__(self, model_config, noise_aug_config, model_type=ModelType.V_PREDICTION, device=None):
         super().__init__(model_config, model_type, device=device)
@@ -291,12 +296,14 @@ class SD21UNCLIP(BaseModel):
         else:
             return unclip_adm(unclip_conditioning, device, self.noise_augmentor, kwargs.get("unclip_noise_augment_merge", 0.05), kwargs.get("seed", 0) - 10)
 
+
 def sdxl_pooled(args, noise_augmentor):
     if "unclip_conditioning" in args:
-        return unclip_adm(args.get("unclip_conditioning", None), args["device"], noise_augmentor, seed=args.get("seed", 0) - 10)[:,:1280]
+        return unclip_adm(args.get("unclip_conditioning", None), args["device"], noise_augmentor, seed=args.get("seed", 0) - 10)[:, :1280]
     else:
         return args["pooled_output"]
 
+
 class SDXLRefiner(BaseModel):
     def __init__(self, model_config, model_type=ModelType.EPS, device=None):
         super().__init__(model_config, model_type, device=device)
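On the [:, :1280] slice in sdxl_pooled: unclip_adm concatenates the noise-augmented pooled CLIP embedding with a noise-level embedding along dim 1, and the slice keeps only the pooled part (1280 is the CLIP-G pooled width used by SDXL). A sketch with illustrative shapes; the 1280-wide noise-level embedding is an assumption for this checkpoint family, not taken from the diff:

import torch

c_adm = torch.randn(2, 1280)            # noise-augmented pooled embedding
noise_level_emb = torch.randn(2, 1280)  # embedding of the augmentation level
adm_out = torch.cat((c_adm, noise_level_emb), 1)  # (2, 2560)
pooled = adm_out[:, :1280]              # keep only the pooled part
print(pooled.shape)                     # torch.Size([2, 1280])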
@@ -324,6 +331,7 @@ class SDXLRefiner(BaseModel):
         flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0).repeat(clip_pooled.shape[0], 1)
         return torch.cat((clip_pooled.to(flat.device), flat), dim=1)
 
+
 class SDXL(BaseModel):
     def __init__(self, model_config, model_type=ModelType.EPS, device=None):
         super().__init__(model_config, model_type, device=device)
@@ -349,6 +357,7 @@ class SDXL(BaseModel):
         flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0).repeat(clip_pooled.shape[0], 1)
         return torch.cat((clip_pooled.to(flat.device), flat), dim=1)
 
+
 class SVD_img2vid(BaseModel):
     def __init__(self, model_config, model_type=ModelType.V_PREDICTION_EDM, device=None):
         super().__init__(model_config, model_type, device=device)
@@ -397,6 +406,7 @@ class SVD_img2vid(BaseModel):
         out['num_video_frames'] = conds.CONDConstant(noise.shape[0])
         return out
 
+
 class SV3D_u(SVD_img2vid):
     def encode_adm(self, **kwargs):
         augmentation = kwargs.get("augmentation_level", 0)
@@ -407,6 +417,7 @@ class SV3D_u(SVD_img2vid):
         flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0)
         return flat
 
+
 class SV3D_p(SVD_img2vid):
     def __init__(self, model_config, model_type=ModelType.V_PREDICTION_EDM, device=None):
         super().__init__(model_config, model_type, device=device)
@@ -414,7 +425,7 @@ class SV3D_p(SVD_img2vid):
 
     def encode_adm(self, **kwargs):
         augmentation = kwargs.get("augmentation_level", 0)
-        elevation = kwargs.get("elevation", 0) #elevation and azimuth are in degrees here
+        elevation = kwargs.get("elevation", 0)  # elevation and azimuth are in degrees here
         azimuth = kwargs.get("azimuth", 0)
         noise = kwargs.get("noise", None)
 
@@ -457,6 +468,7 @@ class Stable_Zero123(BaseModel):
         out['c_crossattn'] = conds.CONDCrossAttn(cross_attn)
         return out
 
+
 class SD_X4Upscaler(BaseModel):
     def __init__(self, model_config, model_type=ModelType.V_PREDICTION, device=None):
         super().__init__(model_config, model_type, device=device)
@@ -474,7 +486,7 @@ class SD_X4Upscaler(BaseModel):
         noise_level = round((self.noise_augmentor.max_noise_level) * noise_augment)
 
         if image is None:
-            image = torch.zeros_like(noise)[:,:3]
+            image = torch.zeros_like(noise)[:, :3]
 
         if image.shape[1:] != noise.shape[1:]:
             image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
@@ -489,6 +501,7 @@ class SD_X4Upscaler(BaseModel):
         out['y'] = conds.CONDRegular(noise_level)
         return out
 
+
 class IP2P(BaseModel):
     def process_ip2p_image_in(self, image):
         return None
@@ -514,18 +527,20 @@ class IP2P(BaseModel):
         out['y'] = conds.CONDRegular(adm)
         return out
 
+
 class SD15_instructpix2pix(IP2P, BaseModel):
     def __init__(self, model_config, model_type=ModelType.EPS, device=None):
         super().__init__(model_config, model_type, device=device)
         self.process_ip2p_image_in = lambda image: image
 
+
 class SDXL_instructpix2pix(IP2P, SDXL):
     def __init__(self, model_config, model_type=ModelType.EPS, device=None):
         super().__init__(model_config, model_type, device=device)
         if model_type == ModelType.V_PREDICTION_EDM:
-            self.process_ip2p_image_in = lambda image: latent_formats.SDXL().process_in(image) #cosxl ip2p
+            self.process_ip2p_image_in = lambda image: latent_formats.SDXL().process_in(image)  # cosxl ip2p
         else:
-            self.process_ip2p_image_in = lambda image: image #diffusers ip2p
+            self.process_ip2p_image_in = lambda image: image  # diffusers ip2p
 
 
 class StableCascade_C(BaseModel):
@@ -570,7 +585,7 @@ class StableCascade_B(BaseModel):
         if clip_text_pooled is not None:
             out['clip'] = conds.CONDRegular(clip_text_pooled)
 
-        #size of prior doesn't really matter if zeros because it gets resized but I still want it to get batched
+        # size of prior doesn't really matter if zeros because it gets resized but I still want it to get batched
         prior = kwargs.get("stable_cascade_prior", torch.zeros((1, 16, (noise.shape[2] * 4) // 42, (noise.shape[3] * 4) // 42), dtype=noise.dtype, layout=noise.layout, device=noise.device))
 
         out["effnet"] = conds.CONDRegular(prior)
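On the (noise.shape[2] * 4) // 42 sizing in the hunk above: stage B latents are pixels downscaled 4x, while the stage C prior lives at Stable Cascade's full 42x compression, so the zero prior is sized by converting back to pixel dimensions and dividing by 42. A quick check of the arithmetic for a 1024x1024 image:

height = 1024
stage_b_latent = height // 4           # 256: stage B latent height
prior_h = (stage_b_latent * 4) // 42   # 1024 // 42 = 24: stage C prior height
print(prior_h)                         # 24, matching stage C's 24x24 latent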
@@ -597,7 +612,7 @@ class SD3(BaseModel):
             dtype = self.get_dtype()
             if self.manual_cast_dtype is not None:
                 dtype = self.manual_cast_dtype
-            #TODO: this probably needs to be tweaked
+            # TODO: this probably needs to be tweaked
             area = input_shape[0] * input_shape[2] * input_shape[3]
             return (area * model_management.dtype_size(dtype) * 0.012) * (1024 * 1024)
         else:
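As with BaseModel.memory_required above, the SD3 estimate is area-based; a minimal sketch of the same arithmetic for a 1024x1024 image at fp16, where the latent input_shape (2, 16, 128, 128) is assumed for illustration (SD3 uses 16 latent channels):

input_shape = (2, 16, 128, 128)  # batch of 2 SD3 latents for a 1024x1024 image
dtype_size = 2                   # bytes per element at fp16

area = input_shape[0] * input_shape[2] * input_shape[3]  # 2 * 128 * 128 = 32768
mem = (area * dtype_size * 0.012) * (1024 * 1024)        # in bytes
print(mem / (1024 ** 3))                                 # ~0.77 GiB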