diff --git a/comfy/model_base.py b/comfy/model_base.py
index 443ecd57f..7e7a8bae0 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -1,21 +1,23 @@
-import torch
 import logging
+from enum import Enum
+import math
+import torch
+
+from . import conds
+from . import latent_formats
+from . import model_management
+from . import ops
+from . import utils
 from .ldm.audio.dit import AudioDiffusionTransformer
 from .ldm.audio.embedders import NumberConditioner
-from .ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
-from .ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
-from .ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
-from .ldm.modules.diffusionmodules.mmdit import OpenAISignatureMMDITWrapper
-from . import model_management
-from . import conds
-from . import ops
-from .ldm.cascade.stage_c import StageC
 from .ldm.cascade.stage_b import StageB
-from enum import Enum
-from . import utils
-from . import latent_formats
-import math
+from .ldm.cascade.stage_c import StageC
+from .ldm.modules.diffusionmodules.mmdit import OpenAISignatureMMDITWrapper
+from .ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
+from .ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
+from .ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
+
 
 class ModelType(Enum):
     EPS = 1
     V_PREDICTION = 2
@@ -146,7 +148,7 @@ class BaseModel(torch.nn.Module):
 
         if denoise_mask is not None:
             if len(denoise_mask.shape) == len(noise.shape):
-                denoise_mask = denoise_mask[:,:1]
+                denoise_mask = denoise_mask[:, :1]
 
             denoise_mask = denoise_mask.reshape((-1, 1, denoise_mask.shape[-2], denoise_mask.shape[-1]))
             if denoise_mask.shape[-2:] != noise.shape[-2:]:
@@ -158,10 +160,10 @@ class BaseModel(torch.nn.Module):
                 if ck == "mask":
                     cond_concat.append(denoise_mask.to(device))
                 elif ck == "masked_image":
-                    cond_concat.append(concat_latent_image.to(device)) #NOTE: the latent_image should be masked by the mask in pixel space
+                    cond_concat.append(concat_latent_image.to(device)) # NOTE: the latent_image should be masked by the mask in pixel space
             else:
                 if ck == "mask":
-                    cond_concat.append(torch.ones_like(noise)[:,:1])
+                    cond_concat.append(torch.ones_like(noise)[:, :1])
                 elif ck == "masked_image":
                     cond_concat.append(self.blank_inpaint_image_like(noise))
         data = torch.cat(cond_concat, dim=1)
@@ -230,14 +232,16 @@ class BaseModel(torch.nn.Module):
 
     def set_inpaint(self):
         self.concat_keys = ("mask", "masked_image")
+
        def blank_inpaint_image_like(latent_image):
            blank_image = torch.ones_like(latent_image)
            # these are the values for "zero" in pixel space translated to latent space
-           blank_image[:,0] *= 0.8223
-           blank_image[:,1] *= -0.6876
-           blank_image[:,2] *= 0.6364
-           blank_image[:,3] *= 0.1380
+           blank_image[:, 0] *= 0.8223
+           blank_image[:, 1] *= -0.6876
+           blank_image[:, 2] *= 0.6364
+           blank_image[:, 3] *= 0.1380
            return blank_image
+
        self.blank_inpaint_image_like = blank_inpaint_image_like
 
     def memory_required(self, input_shape):
@@ -245,11 +249,11 @@ class BaseModel(torch.nn.Module):
         if model_management.xformers_enabled() or model_management.pytorch_attention_flash_attention():
             dtype = self.get_dtype()
             if self.manual_cast_dtype is not None:
                 dtype = self.manual_cast_dtype
-            #TODO: this needs to be tweaked
+            # TODO: this needs to be tweaked
             area = input_shape[0] * math.prod(input_shape[2:])
             return (area * model_management.dtype_size(dtype) / 50) * (1024 * 1024)
         else:
-            #TODO: this formula might be too aggressive since I tweaked the sub-quad and split algorithms to use less memory.
+            # TODO: this formula might be too aggressive since I tweaked the sub-quad and split algorithms to use less memory.
             area = input_shape[0] * math.prod(input_shape[2:])
             return (((area * 0.6) / 0.9) + 1024) * (1024 * 1024)
@@ -278,6 +282,7 @@ def unclip_adm(unclip_conditioning, device, noise_augmentor, noise_augment_merge
     return adm_out
 
+
 class SD21UNCLIP(BaseModel):
     def __init__(self, model_config, noise_aug_config, model_type=ModelType.V_PREDICTION, device=None):
         super().__init__(model_config, model_type, device=device)
         self.noise_augmentor = CLIPEmbeddingNoiseAugmentation(**noise_aug_config)
@@ -291,12 +296,14 @@ class SD21UNCLIP(BaseModel):
         else:
             return unclip_adm(unclip_conditioning, device, self.noise_augmentor, kwargs.get("unclip_noise_augment_merge", 0.05), kwargs.get("seed", 0) - 10)
 
+
 def sdxl_pooled(args, noise_augmentor):
     if "unclip_conditioning" in args:
-        return unclip_adm(args.get("unclip_conditioning", None), args["device"], noise_augmentor, seed=args.get("seed", 0) - 10)[:,:1280]
+        return unclip_adm(args.get("unclip_conditioning", None), args["device"], noise_augmentor, seed=args.get("seed", 0) - 10)[:, :1280]
     else:
         return args["pooled_output"]
 
+
 class SDXLRefiner(BaseModel):
     def __init__(self, model_config, model_type=ModelType.EPS, device=None):
         super().__init__(model_config, model_type, device=device)
@@ -324,6 +331,7 @@ class SDXLRefiner(BaseModel):
         flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0).repeat(clip_pooled.shape[0], 1)
         return torch.cat((clip_pooled.to(flat.device), flat), dim=1)
 
+
 class SDXL(BaseModel):
     def __init__(self, model_config, model_type=ModelType.EPS, device=None):
         super().__init__(model_config, model_type, device=device)
@@ -349,6 +357,7 @@ class SDXL(BaseModel):
         flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0).repeat(clip_pooled.shape[0], 1)
         return torch.cat((clip_pooled.to(flat.device), flat), dim=1)
 
+
 class SVD_img2vid(BaseModel):
     def __init__(self, model_config, model_type=ModelType.V_PREDICTION_EDM, device=None):
         super().__init__(model_config, model_type, device=device)
@@ -397,6 +406,7 @@ class SVD_img2vid(BaseModel):
         out['num_video_frames'] = conds.CONDConstant(noise.shape[0])
         return out
 
+
 class SV3D_u(SVD_img2vid):
     def encode_adm(self, **kwargs):
         augmentation = kwargs.get("augmentation_level", 0)
@@ -407,6 +417,7 @@ class SV3D_u(SVD_img2vid):
         flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0)
         return flat
 
+
 class SV3D_p(SVD_img2vid):
     def __init__(self, model_config, model_type=ModelType.V_PREDICTION_EDM, device=None):
         super().__init__(model_config, model_type, device=device)
@@ -414,7 +425,7 @@ class SV3D_p(SVD_img2vid):
 
     def encode_adm(self, **kwargs):
         augmentation = kwargs.get("augmentation_level", 0)
-        elevation = kwargs.get("elevation", 0) #elevation and azimuth are in degrees here
+        elevation = kwargs.get("elevation", 0) # elevation and azimuth are in degrees here
         azimuth = kwargs.get("azimuth", 0)
         noise = kwargs.get("noise", None)
 
@@ -457,6 +468,7 @@ class Stable_Zero123(BaseModel):
         out['c_crossattn'] = conds.CONDCrossAttn(cross_attn)
         return out
 
+
 class SD_X4Upscaler(BaseModel):
     def __init__(self, model_config, model_type=ModelType.V_PREDICTION, device=None):
         super().__init__(model_config, model_type, device=device)
@@ -474,7 +486,7 @@ class SD_X4Upscaler(BaseModel):
         noise_level = round((self.noise_augmentor.max_noise_level) * noise_augment)
 
         if image is None:
-            image = torch.zeros_like(noise)[:,:3]
+            image = torch.zeros_like(noise)[:, :3]
 
         if image.shape[1:] != noise.shape[1:]:
             image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
@@ -489,6 +501,7 @@ class SD_X4Upscaler(BaseModel):
         out['y'] = conds.CONDRegular(noise_level)
         return out
 
+
 class IP2P(BaseModel):
     def process_ip2p_image_in(self, image):
         return None
@@ -514,18 +527,20 @@ class IP2P(BaseModel):
             out['y'] = conds.CONDRegular(adm)
         return out
 
+
 class SD15_instructpix2pix(IP2P, BaseModel):
     def __init__(self, model_config, model_type=ModelType.EPS, device=None):
         super().__init__(model_config, model_type, device=device)
         self.process_ip2p_image_in = lambda image: image
 
+
 class SDXL_instructpix2pix(IP2P, SDXL):
     def __init__(self, model_config, model_type=ModelType.EPS, device=None):
         super().__init__(model_config, model_type, device=device)
         if model_type == ModelType.V_PREDICTION_EDM:
-            self.process_ip2p_image_in = lambda image: latent_formats.SDXL().process_in(image) #cosxl ip2p
+            self.process_ip2p_image_in = lambda image: latent_formats.SDXL().process_in(image) # cosxl ip2p
         else:
-            self.process_ip2p_image_in = lambda image: image #diffusers ip2p
+            self.process_ip2p_image_in = lambda image: image # diffusers ip2p
 
 
 class StableCascade_C(BaseModel):
@@ -570,7 +585,7 @@ class StableCascade_B(BaseModel):
         if clip_text_pooled is not None:
             out['clip'] = conds.CONDRegular(clip_text_pooled)
 
-        #size of prior doesn't really matter if zeros because it gets resized but I still want it to get batched
+        # size of prior doesn't really matter if zeros because it gets resized but I still want it to get batched
         prior = kwargs.get("stable_cascade_prior", torch.zeros((1, 16, (noise.shape[2] * 4) // 42, (noise.shape[3] * 4) // 42), dtype=noise.dtype, layout=noise.layout, device=noise.device))
 
         out["effnet"] = conds.CONDRegular(prior)
@@ -597,7 +612,7 @@ class SD3(BaseModel):
             dtype = self.get_dtype()
             if self.manual_cast_dtype is not None:
                 dtype = self.manual_cast_dtype
-            #TODO: this probably needs to be tweaked
+            # TODO: this probably needs to be tweaked
             area = input_shape[0] * input_shape[2] * input_shape[3]
             return (area * model_management.dtype_size(dtype) * 0.012) * (1024 * 1024)
         else:
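
Note on the patch above: every hunk is a style-only cleanup (sorted imports, two blank lines between top-level definitions, spaces after commas in slices, and comments spelled "# ..."), so runtime behavior should be unchanged. Below is a minimal sanity check of that claim, assuming only a working torch install; the 2x4x128x128 fp16 latent is an illustrative assumption, not a value taken from the patch.

import math

import torch

# Slice spacing is purely cosmetic: both spellings compile to the same
# __getitem__ call, so hunks like denoise_mask[:, :1] cannot change results.
x = torch.randn(2, 4, 128, 128)
assert torch.equal(x[:,:1], x[:, :1])

# The BaseModel.memory_required() estimate that the TODO comments refer to,
# worked by hand for this hypothetical fp16 input (dtype size of 2 bytes):
area = x.shape[0] * math.prod(x.shape[2:])   # 2 * 128 * 128 = 32768
estimate = (area * 2 / 50) * (1024 * 1024)   # 1310.72 MiB, about 1.28 GiB
print(f"{estimate / 1024 ** 3:.2f} GiB")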