Mirror of https://github.com/comfyanonymous/ComfyUI.git, synced 2026-01-11 06:40:48 +08:00

Commit cac6690481: Add known SD3 model files, merge branch 'master' of github.com:comfyanonymous/ComfyUI

README.md (29 lines changed)
@@ -18,6 +18,35 @@ A vanilla, up-to-date fork of [ComfyUI](https://github.com/comfyanonymous/comfyu

- Automated tests for new features.
- Automatic model downloading for well-known models.

### Upstream Features

- Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything.
- Fully supports SD1.x, SD2.x, [SDXL](https://comfyanonymous.github.io/ComfyUI_examples/sdxl/), [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/) and [Stable Cascade](https://comfyanonymous.github.io/ComfyUI_examples/stable_cascade/)
- Asynchronous Queue system
- Many optimizations: Only re-executes the parts of the workflow that change between executions.
- Command line option: ```--lowvram``` to make it work on GPUs with less than 3GB vram (enabled automatically on GPUs with low vram)
- Works even if you don't have a GPU with: ```--cpu``` (slow)
- Can load ckpt, safetensors and diffusers models/checkpoints. Standalone VAEs and CLIP models.
- Embeddings/Textual inversion
- [Loras (regular, locon and loha)](https://comfyanonymous.github.io/ComfyUI_examples/lora/)
- [Hypernetworks](https://comfyanonymous.github.io/ComfyUI_examples/hypernetworks/)
- Loading full workflows (with seeds) from generated PNG files.
- Saving/Loading workflows as Json files.
- Nodes interface can be used to create complex workflows like one for [Hires fix](https://comfyanonymous.github.io/ComfyUI_examples/2_pass_txt2img/) or much more advanced ones.
- [Area Composition](https://comfyanonymous.github.io/ComfyUI_examples/area_composition/)
- [Inpainting](https://comfyanonymous.github.io/ComfyUI_examples/inpaint/) with both regular and inpainting models.
- [ControlNet and T2I-Adapter](https://comfyanonymous.github.io/ComfyUI_examples/controlnet/)
- [Upscale Models (ESRGAN, ESRGAN variants, SwinIR, Swin2SR, etc...)](https://comfyanonymous.github.io/ComfyUI_examples/upscale_models/)
- [unCLIP Models](https://comfyanonymous.github.io/ComfyUI_examples/unclip/)
- [GLIGEN](https://comfyanonymous.github.io/ComfyUI_examples/gligen/)
- [Model Merging](https://comfyanonymous.github.io/ComfyUI_examples/model_merging/)
- [LCM models and Loras](https://comfyanonymous.github.io/ComfyUI_examples/lcm/)
- [SDXL Turbo](https://comfyanonymous.github.io/ComfyUI_examples/sdturbo/)
- Latent previews with [TAESD](#how-to-show-high-quality-previews)
- Starts up very fast.
- Works fully offline: will never download anything.
- [Config file](extra_model_paths.yaml.example) to set the search paths for models.

### Table of Contents

- [Workflows](https://comfyanonymous.github.io/ComfyUI_examples/)
@@ -29,7 +29,12 @@ class CONDRegular:


 class CONDNoiseShape(CONDRegular):
     def process_cond(self, batch_size, device, area, **kwargs):
-        data = self.cond[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
+        data = self.cond
+        if area is not None:
+            dims = len(area) // 2
+            for i in range(dims):
+                data = data.narrow(i + 2, area[i + dims], area[i])

         return self._copy_with(utils.repeat_to_batch_size(data, batch_size).to(device))
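A small standalone sketch (not part of the commit) of why the `narrow` loop above generalizes the old hard-coded 2D slice: the `area` tuple is read as sizes followed by offsets, one pair per spatial dimension, so the same code also handles latents with more than two spatial dims.

```python
import torch

cond = torch.randn(1, 4, 64, 64)   # (batch, channels, height, width)
area = (32, 32, 8, 8)              # (h_size, w_size, h_offset, w_offset)

# old behaviour: explicit 2D slicing
old = cond[:, :, area[2]:area[0] + area[2], area[3]:area[1] + area[3]]

# new behaviour: narrow each spatial dimension in turn
data = cond
dims = len(area) // 2
for i in range(dims):
    data = data.narrow(i + 2, area[i + dims], area[i])

assert torch.equal(old, data)
```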
@@ -835,72 +835,11 @@ class MMDiT(nn.Module):
         )

         self.final_layer = FinalLayer(self.hidden_size, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations)
         # self.initialize_weights()

         if compile_core:
             assert False
             self.forward_core_with_concat = torch.compile(self.forward_core_with_concat)

-    def initialize_weights(self):
-        # TODO: Init context_embedder?
-        # Initialize transformer layers:
-        def _basic_init(module):
-            if isinstance(module, nn.Linear):
-                torch.nn.init.xavier_uniform_(module.weight)
-                if module.bias is not None:
-                    nn.init.constant_(module.bias, 0)
-
-        self.apply(_basic_init)
-
-        # Initialize (and freeze) pos_embed by sin-cos embedding
-        if self.pos_embed is not None:
-            pos_embed_grid_size = (
-                int(self.x_embedder.num_patches**0.5)
-                if self.pos_embed_max_size is None
-                else self.pos_embed_max_size
-            )
-            pos_embed = get_2d_sincos_pos_embed(
-                self.pos_embed.shape[-1],
-                int(self.x_embedder.num_patches**0.5),
-                pos_embed_grid_size,
-                scaling_factor=self.pos_embed_scaling_factor,
-                offset=self.pos_embed_offset,
-            )
-
-            pos_embed = get_2d_sincos_pos_embed(
-                self.pos_embed.shape[-1],
-                int(self.pos_embed.shape[-2]**0.5),
-                scaling_factor=self.pos_embed_scaling_factor,
-            )
-            self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
-
-        # Initialize patch_embed like nn.Linear (instead of nn.Conv2d)
-        w = self.x_embedder.proj.weight.data
-        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
-        nn.init.constant_(self.x_embedder.proj.bias, 0)
-
-        if hasattr(self, "y_embedder"):
-            nn.init.normal_(self.y_embedder.mlp[0].weight, std=0.02)
-            nn.init.normal_(self.y_embedder.mlp[2].weight, std=0.02)
-
-        # Initialize timestep embedding MLP:
-        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
-        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
-
-        # Zero-out adaLN modulation layers in DiT blocks:
-        for block in self.joint_blocks:
-            nn.init.constant_(block.x_block.adaLN_modulation[-1].weight, 0)
-            nn.init.constant_(block.x_block.adaLN_modulation[-1].bias, 0)
-            nn.init.constant_(block.context_block.adaLN_modulation[-1].weight, 0)
-            nn.init.constant_(block.context_block.adaLN_modulation[-1].bias, 0)
-
-        # Zero-out output layers:
-        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
-        nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
-        nn.init.constant_(self.final_layer.linear.weight, 0)
-        nn.init.constant_(self.final_layer.linear.bias, 0)

     def cropped_pos_embed(self, hw, device=None):
         p = self.x_embedder.patch_size[0]
         h, w = hw
@@ -995,7 +934,7 @@ class MMDiT(nn.Module):
             context = self.context_processor(context)

         hw = x.shape[-2:]
-        x = self.x_embedder(x) + self.cropped_pos_embed(hw, device=x.device).to(dtype=x.dtype)
+        x = self.x_embedder(x) + self.cropped_pos_embed(hw, device=x.device).to(dtype=x.dtype, device=x.device)
         c = self.t_embedder(t, dtype=x.dtype)  # (N, D)
         if y is not None and self.y_embedder is not None:
             y = self.y_embedder(y)  # (N, D)
@@ -206,9 +206,6 @@ class BaseModel(torch.nn.Module):
         unet_state_dict = self.diffusion_model.state_dict()
         unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)

-        if self.get_dtype() == torch.float16:
-            extra_sds = map(lambda sd: utils.convert_sd_to(sd, torch.float16), extra_sds)
-
         if self.model_type == ModelType.V_PREDICTION:
             unet_state_dict["v_pred"] = torch.tensor([])
@@ -572,13 +569,20 @@ class SD3(BaseModel):
         return kwargs["pooled_output"]

     def extra_conds(self, **kwargs):
-        out = {}
-        adm = self.encode_adm(**kwargs)
-        if adm is not None:
-            out['y'] = conds.CONDRegular(adm)
-
+        out = super().extra_conds(**kwargs)
         cross_attn = kwargs.get("cross_attn", None)
         if cross_attn is not None:
             out['c_crossattn'] = conds.CONDRegular(cross_attn)
         return out

+    def memory_required(self, input_shape):
+        if model_management.xformers_enabled() or model_management.pytorch_attention_flash_attention():
+            dtype = self.get_dtype()
+            if self.manual_cast_dtype is not None:
+                dtype = self.manual_cast_dtype
+            #TODO: this probably needs to be tweaked
+            area = input_shape[0] * input_shape[2] * input_shape[3]
+            return (area * model_management.dtype_size(dtype) * 0.012) * (1024 * 1024)
+        else:
+            area = input_shape[0] * input_shape[2] * input_shape[3]
+            return (area * 0.3) * (1024 * 1024)
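A quick arithmetic sketch (not part of the commit) of what `memory_required` above returns for a single 1024x1024 SD3 image, assuming an fp16 model with flash/xformers attention available; the 16-channel, 128x128 latent shape is the SD3 layout used elsewhere in this commit.

```python
import torch

# latent input shape for one 1024x1024 image: (batch, 16, 128, 128)
input_shape = (1, 16, 128, 128)
dtype_size = torch.tensor([], dtype=torch.float16).element_size()  # 2 bytes

area = input_shape[0] * input_shape[2] * input_shape[3]             # 16384
estimate_bytes = (area * dtype_size * 0.012) * (1024 * 1024)
print(f"{estimate_bytes / (1024 ** 3):.2f} GiB")                    # ~0.38 GiB
```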
@@ -8,6 +8,7 @@ from typing import List, Any, Optional, Union

 import tqdm
 from huggingface_hub import hf_hub_download, scan_cache_dir
+from huggingface_hub.utils import GatedRepoError
 from requests import Session
 from safetensors import safe_open
 from safetensors.torch import save_file
@@ -74,6 +75,7 @@ def get_or_download(folder_name: str, filename: str, known_files: List[HuggingFi
                 filename=known_file.filename,
                 local_dir=hf_destination_dir,
                 repo_type=known_file.repo_type,
+                revision=known_file.revision,
             )

             if known_file.convert_to_16_bit and file_size is not None and file_size != 0:
@@ -134,8 +136,14 @@ def get_or_download(folder_name: str, filename: str, known_files: List[HuggingFi
             assert path is not None
         except StopIteration:
             pass
-        except Exception as exc:
-            logging.error("Error while trying to download a file", exc_info=exc)
+        except GatedRepoError as exc_info:
+            exc_info.append_to_message(f"""
+Visit the repository, accept the terms, and then do one of the following:
+
+ - Set the HF_TOKEN environment variable to your Hugging Face token; or,
+ - Login to Hugging Face in your terminal using `huggingface-cli login`
+""")
+            raise exc_info
         finally:
             # a path was found for any reason, so we should invalidate the cache
             if path is not None:
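For reference, a hedged sketch (not part of the commit) of the two authentication routes the message above points at; the token value is a placeholder.

```python
import os
from huggingface_hub import login

# option 1: environment variable, picked up automatically by huggingface_hub
os.environ["HF_TOKEN"] = "hf_..."  # placeholder token

# option 2: programmatic equivalent of `huggingface-cli login`
login(token=os.environ["HF_TOKEN"])
```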
@@ -167,13 +175,16 @@ KNOWN_CHECKPOINTS = [
     CivitFile(133005, 357609, filename="juggernautXL_v9Rundiffusionphoto2.safetensors"),
     CivitFile(112902, 351306, filename="dreamshaperXL_v21TurboDPMSDE.safetensors"),
     CivitFile(139562, 344487, filename="realvisxlV40_v40Bakedvae.safetensors"),
-    HuggingFile("SG161222/Realistic_Vision_V6.0_B1_noVAE","Realistic_Vision_V6.0_NV_B1_fp16.safetensors"),
-    HuggingFile("SG161222/Realistic_Vision_V5.1_noVAE","Realistic_Vision_V5.1_fp16-no-ema.safetensors"),
+    HuggingFile("SG161222/Realistic_Vision_V6.0_B1_noVAE", "Realistic_Vision_V6.0_NV_B1_fp16.safetensors"),
+    HuggingFile("SG161222/Realistic_Vision_V5.1_noVAE", "Realistic_Vision_V5.1_fp16-no-ema.safetensors"),
     CivitFile(4384, 128713, filename="dreamshaper_8.safetensors"),
     CivitFile(7371, 425083, filename="revAnimated_v2Rebirth.safetensors"),
     CivitFile(4468, 57618, filename="counterfeitV30_v30.safetensors"),
     CivitFile(241415, 272376, filename="picxReal_10.safetensors"),
     CivitFile(23900, 95489, filename="anyloraCheckpoint_bakedvaeBlessedFp16.safetensors"),
+    HuggingFile("stabilityai/stable-diffusion-3-medium", filename="sd3_medium.safetensors"),
+    HuggingFile("stabilityai/stable-diffusion-3-medium", filename="sd3_medium_incl_clips.safetensors"),
+    HuggingFile("stabilityai/stable-diffusion-3-medium", filename="sd3_medium_incl_clips_t5xxlfp8.safetensors"),
 ]

 KNOWN_UNCLIP_CHECKPOINTS = [
@@ -323,7 +334,13 @@ KNOWN_UNET_MODELS: List[Union[CivitFile | HuggingFile]] = [
     HuggingFile("ByteDance/Hyper-SD", "Hyper-SDXL-1step-Unet-Comfyui.fp16.safetensors")
 ]

-KNOWN_CLIP_MODELS: List[Union[CivitFile | HuggingFile]] = []
+KNOWN_CLIP_MODELS: List[Union[CivitFile | HuggingFile]] = [
+    # todo: is this correct?
+    HuggingFile("stabilityai/stable-diffusion-3-medium", "text_encoders/t5xxl_fp16.safetensors", save_with_filename="t5xxl_fp16.safetensors"),
+    HuggingFile("stabilityai/stable-diffusion-3-medium", "text_encoders/t5xxl_fp8_e4m3fn.safetensors", save_with_filename="t5xxl_fp8_e4m3fn.safetensors"),
+    HuggingFile("stabilityai/stable-diffusion-3-medium", "text_encoders/clip_g.safetensors", save_with_filename="clip_g.safetensors"),
+    HuggingFile("stabilityai/stable-diffusion-3-medium", "text_encoders/clip_l.safetensors", save_with_filename="clip_l.safetensors"),
+]


 def add_known_models(folder_name: str, symbol: List[Union[CivitFile, HuggingFile]], *models: Union[CivitFile, HuggingFile]) -> List[Union[CivitFile, HuggingFile]]:
@@ -54,6 +54,7 @@ class HuggingFile:
     size: Optional[int] = None
     force_save_in_repo_id: Optional[bool] = False
     repo_type: Optional[str] = 'model'
+    revision: Optional[str] = None

     def __str__(self):
         return self.save_with_filename or split(self.filename)[-1]
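A self-contained stand-in (not the repository's class; the `repo_id` field name is an assumption) that mirrors the fields and `__str__` behaviour shown above, to illustrate how `save_with_filename` determines the on-disk name of entries like the SD3 text encoders.

```python
from dataclasses import dataclass
from os.path import split
from typing import Optional

@dataclass
class HuggingFileSketch:
    repo_id: str                                   # assumed first positional field
    filename: str
    save_with_filename: Optional[str] = None
    size: Optional[int] = None
    force_save_in_repo_id: Optional[bool] = False
    repo_type: Optional[str] = 'model'
    revision: Optional[str] = None

    def __str__(self):
        # the UI label / saved filename: explicit override, else the basename
        return self.save_with_filename or split(self.filename)[-1]

clip_g = HuggingFileSketch("stabilityai/stable-diffusion-3-medium",
                           "text_encoders/clip_g.safetensors",
                           save_with_filename="clip_g.safetensors")
assert str(clip_g) == "clip_g.safetensors"
```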
@@ -703,6 +703,22 @@ def supports_dtype(device, dtype): # TODO
         return True
     return False

+def supports_cast(device, dtype): #TODO
+    if dtype == torch.float32:
+        return True
+    if dtype == torch.float16:
+        return True
+    if is_device_mps(device):
+        return False
+    if directml_enabled: #TODO: test this
+        return False
+    if dtype == torch.bfloat16:
+        return True
+    if dtype == torch.float8_e4m3fn:
+        return True
+    if dtype == torch.float8_e5m2:
+        return True
+    return False
+
 def device_supports_non_blocking(device):
     if is_device_mps(device):
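A usage sketch (not part of the commit; it assumes this fork is installed as the `comfy` package): fall back to the offload device when any dtype the text encoder stack uses cannot be cast on the preferred load device, mirroring the loop added to `CLIP.__init__` in comfy/sd.py later in this commit.

```python
import torch
from comfy import model_management

load_device = model_management.text_encoder_device()
offload_device = model_management.text_encoder_offload_device()

# dtypes a mixed CLIP (fp16) + T5 (fp8) stack might report
for dt in (torch.float16, torch.float8_e4m3fn):
    if not model_management.supports_cast(load_device, dt):
        load_device = offload_device

print("text encoders will load on:", load_device)
```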
@@ -815,7 +815,7 @@ class CLIPLoader:
     @classmethod
     def INPUT_TYPES(s):
         return {"required": { "clip_name": (get_filename_list_with_downloadable("clip", KNOWN_CLIP_MODELS),),
-                              "type": (["stable_diffusion", "stable_cascade"], ),
+                              "type": (["stable_diffusion", "stable_cascade", "sd3"], ),
                             }}
     RETURN_TYPES = ("CLIP",)
     FUNCTION = "load_clip"
@@ -823,10 +823,11 @@ class CLIPLoader:
     CATEGORY = "advanced/loaders"

     def load_clip(self, clip_name, type="stable_diffusion"):
-        if type == "stable_diffusion":
-            clip_type = sd.CLIPType.STABLE_DIFFUSION
-        elif type == "stable_cascade":
+        clip_type = sd.CLIPType.STABLE_DIFFUSION
+        if type == "stable_cascade":
             clip_type = sd.CLIPType.STABLE_CASCADE
+        elif type == "sd3":
+            clip_type = sd.CLIPType.SD3
         else:
             logging.warning(f"Unknown clip type argument passed: {type} for model {clip_name}")
@@ -839,16 +840,22 @@ class DualCLIPLoader:
     def INPUT_TYPES(s):
         return {"required": { "clip_name1": (folder_paths.get_filename_list("clip"),), "clip_name2": (
             folder_paths.get_filename_list("clip"),),
-                             }}
+                              "type": (["sdxl", "sd3"], ),
+                             }}
     RETURN_TYPES = ("CLIP",)
     FUNCTION = "load_clip"

     CATEGORY = "advanced/loaders"

-    def load_clip(self, clip_name1, clip_name2):
+    def load_clip(self, clip_name1, clip_name2, type):
         clip_path1 = folder_paths.get_full_path("clip", clip_name1)
         clip_path2 = folder_paths.get_full_path("clip", clip_name2)
-        clip = sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=folder_paths.get_folder_paths("embeddings"))
+        if type == "sdxl":
+            clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION
+        elif type == "sd3":
+            clip_type = comfy.sd.CLIPType.SD3
+
+        clip = sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type)
         return (clip,)

 class CLIPVisionLoader:
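A usage sketch (not part of the commit) of the updated node; the import path `comfy.nodes.base_nodes` is an assumption based on the imports shown later in this commit, and the two text encoder files are assumed to already be present in the `clip` folder.

```python
from comfy.nodes.base_nodes import DualCLIPLoader  # assumed module path

loader = DualCLIPLoader()
# pair CLIP-G with T5-XXL and tell the loader this is an SD3 text encoder stack
(clip,) = loader.load_clip("clip_g.safetensors", "t5xxl_fp16.safetensors", type="sd3")
```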
comfy/sd.py (23 lines changed)
@@ -100,15 +100,21 @@ class CLIP:
         load_device = model_management.text_encoder_device()
         offload_device = model_management.text_encoder_offload_device()
         params['device'] = offload_device
-        params['dtype'] = model_management.text_encoder_dtype(load_device)
+        dtype = model_management.text_encoder_dtype(load_device)
+        params['dtype'] = dtype
         if "textmodel_json_config" not in params and textmodel_json_config is not None:
             params['textmodel_json_config'] = textmodel_json_config

         self.cond_stage_model = clip(**(params))

+        for dt in self.cond_stage_model.dtypes:
+            if not model_management.supports_cast(load_device, dt):
+                load_device = offload_device
+
         self.tokenizer: "sd1_clip.SD1Tokenizer" = tokenizer(embedding_directory=embedding_directory)
         self.patcher = model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
         self.layer_idx = None
         logging.debug("CLIP model load device: {}, offload device: {}".format(load_device, offload_device))

     def clone(self):
         n = CLIP(no_init=True)
@@ -372,6 +378,7 @@ def load_style_model(ckpt_path):
 class CLIPType(Enum):
     STABLE_DIFFUSION = 1
     STABLE_CASCADE = 2
+    SD3 = 3


 @dataclasses.dataclass
@@ -407,12 +414,20 @@ def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DI
         elif "text_model.encoder.layers.22.mlp.fc1.weight" in clip_data[0]:
             clip_target.clip = sd2_clip.SD2ClipModel
             clip_target.tokenizer = sd2_clip.SD2Tokenizer
+        elif "encoder.block.23.layer.1.DenseReluDense.wi_1.weight" in clip_data[0]:
+            dtype_t5 = clip_data[0]["encoder.block.23.layer.1.DenseReluDense.wi_1.weight"].dtype
+            clip_target.clip = sd3_clip.sd3_clip(clip_l=False, clip_g=False, t5=True, dtype_t5=dtype_t5)
+            clip_target.tokenizer = sd3_clip.SD3Tokenizer
         else:
             clip_target.clip = sd1_clip.SD1ClipModel
             clip_target.tokenizer = sd1_clip.SD1Tokenizer
     elif len(clip_data) == 2:
-        clip_target.clip = sdxl_clip.SDXLClipModel
-        clip_target.tokenizer = sdxl_clip.SDXLTokenizer
+        if clip_type == CLIPType.SD3:
+            clip_target.clip = sd3_clip.sd3_clip(clip_l=True, clip_g=True, t5=False)
+            clip_target.tokenizer = sd3_clip.SD3Tokenizer
+        else:
+            clip_target.clip = sdxl_clip.SDXLClipModel
+            clip_target.tokenizer = sdxl_clip.SDXLTokenizer
     elif len(clip_data) == 3:
         clip_target.clip = sd3_clip.SD3ClipModel
         clip_target.tokenizer = sd3_clip.SD3Tokenizer
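An illustrative sketch (not part of the commit) of the detection rule added above: a lone T5-XXL checkpoint is recognized by one of its transformer-block keys, and the dtype it was saved in is reused when building the text encoder. The tensor below is a small stand-in; the fp8 variant of the file would report `torch.float8_e4m3fn` instead.

```python
import torch

t5_key = "encoder.block.23.layer.1.DenseReluDense.wi_1.weight"
clip_data = [{t5_key: torch.zeros(8, 8, dtype=torch.float16)}]  # stand-in state dict

if t5_key in clip_data[0]:
    dtype_t5 = clip_data[0][t5_key].dtype
    print("standalone T5-XXL detected, stored dtype:", dtype_t5)
```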
@@ -501,7 +516,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
         vae = VAE(sd=vae_sd)

     if output_clip:
-        clip_target = model_config.clip_target()
+        clip_target = model_config.clip_target(state_dict=sd)
         if clip_target is not None:
             clip_sd = model_config.process_clip_state_dict(sd)
             if len(clip_sd) > 0:
@@ -8,7 +8,7 @@ import zipfile
 from typing import Tuple, Sequence, TypeVar

 import torch
-from transformers import CLIPTokenizer, PreTrainedTokenizerBase
+from transformers import CLIPTokenizer, PreTrainedTokenizerBase, SpecialTokensMixin

 from . import clip_model
 from . import model_management
@@ -410,7 +410,7 @@ class SDTokenizer:
     def clone(self) -> SDTokenizerT:
         sd_tokenizer = copy.copy(self)
         # correctly copy additional vocab
-        sd_tokenizer.tokenizer = self.tokenizer_class.from_pretrained(self.tokenizer_path)
+        sd_tokenizer.tokenizer = self.tokenizer_class.from_pretrained(self.tokenizer_path, legacy=True)
         sd_tokenizer.add_tokens(sd_tokenizer.additional_tokens)
         return sd_tokenizer
@@ -568,6 +568,10 @@ class SD1ClipModel(torch.nn.Module):
         self.clip = "clip_{}".format(self.clip_name)
         setattr(self, self.clip, clip_model(device=device, dtype=dtype, textmodel_json_config=textmodel_json_config, **kwargs))

+        self.dtypes = set()
+        if dtype is not None:
+            self.dtypes.add(dtype)
+
     def set_clip_options(self, options):
         getattr(self, self.clip).set_clip_options(options)
@@ -5,6 +5,7 @@ import comfy.t5
 import torch
 import os
 import comfy.model_management
+import logging

 class T5XXLModel(sd1_clip.SDClipModel):
     def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None):
@@ -43,42 +44,94 @@ class SD3Tokenizer:
         return self.clip_g.untokenize(token_weight_pair)

 class SD3ClipModel(torch.nn.Module):
-    def __init__(self, device="cpu", dtype=None):
+    def __init__(self, clip_l=True, clip_g=True, t5=True, dtype_t5=None, device="cpu", dtype=None):
         super().__init__()
-        self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, return_projected_pooled=False)
-        self.clip_g = sdxl_clip.SDXLClipG(device=device, dtype=dtype)
-        self.t5xxl = T5XXLModel(device=device, dtype=dtype)
+        self.dtypes = set()
+        if clip_l:
+            self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, return_projected_pooled=False)
+            self.dtypes.add(dtype)
+        else:
+            self.clip_l = None
+
+        if clip_g:
+            self.clip_g = sdxl_clip.SDXLClipG(device=device, dtype=dtype)
+            self.dtypes.add(dtype)
+        else:
+            self.clip_g = None
+
+        if t5:
+            if dtype_t5 is None:
+                dtype_t5 = dtype
+            elif comfy.model_management.dtype_size(dtype_t5) > comfy.model_management.dtype_size(dtype):
+                dtype_t5 = dtype
+
+            if not comfy.model_management.supports_cast(device, dtype_t5):
+                dtype_t5 = dtype
+
+            self.t5xxl = T5XXLModel(device=device, dtype=dtype_t5)
+            self.dtypes.add(dtype_t5)
+        else:
+            self.t5xxl = None
+
+        logging.debug("Created SD3 text encoder with: clip_l {}, clip_g {}, t5xxl {}:{}".format(clip_l, clip_g, t5, dtype_t5))

     def set_clip_options(self, options):
-        self.clip_l.set_clip_options(options)
-        self.clip_g.set_clip_options(options)
-        self.t5xxl.set_clip_options(options)
+        if self.clip_l is not None:
+            self.clip_l.set_clip_options(options)
+        if self.clip_g is not None:
+            self.clip_g.set_clip_options(options)
+        if self.t5xxl is not None:
+            self.t5xxl.set_clip_options(options)

     def reset_clip_options(self):
-        self.clip_g.reset_clip_options()
-        self.clip_l.reset_clip_options()
-        self.t5xxl.reset_clip_options()
+        if self.clip_l is not None:
+            self.clip_l.reset_clip_options()
+        if self.clip_g is not None:
+            self.clip_g.reset_clip_options()
+        if self.t5xxl is not None:
+            self.t5xxl.reset_clip_options()

     def encode_token_weights(self, token_weight_pairs):
         token_weight_pairs_l = token_weight_pairs["l"]
         token_weight_pairs_g = token_weight_pairs["g"]
         token_weight_pars_t5 = token_weight_pairs["t5xxl"]
         lg_out = None
-        if len(token_weight_pairs_g) > 0 or len(token_weight_pairs_l) > 0:
-            l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l)
-            g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g)
-            lg_out = torch.cat([l_out, g_out], dim=-1)
-            lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1]))
-            out = lg_out
-            pooled = torch.cat((l_pooled, g_pooled), dim=-1)
-        else:
-            pooled = torch.zeros((1, 1280 + 768), device=comfy.model_management.intermediate_device())
+        pooled = None
+        out = None

-        t5_out, t5_pooled = self.t5xxl.encode_token_weights(token_weight_pars_t5)
-        if lg_out is not None:
-            out = torch.cat([lg_out, t5_out], dim=-2)
-        else:
-            out = t5_out
+        if len(token_weight_pairs_g) > 0 or len(token_weight_pairs_l) > 0:
+            if self.clip_l is not None:
+                lg_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l)
+            else:
+                l_pooled = torch.zeros((1, 768), device=comfy.model_management.intermediate_device())
+
+            if self.clip_g is not None:
+                g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g)
+                if lg_out is not None:
+                    lg_out = torch.cat([lg_out, g_out], dim=-1)
+                else:
+                    lg_out = torch.nn.functional.pad(g_out, (768, 0))
+            else:
+                g_out = None
+                g_pooled = torch.zeros((1, 1280), device=comfy.model_management.intermediate_device())
+
+            if lg_out is not None:
+                lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1]))
+                out = lg_out
+            pooled = torch.cat((l_pooled, g_pooled), dim=-1)
+
+        if self.t5xxl is not None:
+            t5_out, t5_pooled = self.t5xxl.encode_token_weights(token_weight_pars_t5)
+            if lg_out is not None:
+                out = torch.cat([lg_out, t5_out], dim=-2)
+            else:
+                out = t5_out
+
+        if out is None:
+            out = torch.zeros((1, 77, 4096), device=comfy.model_management.intermediate_device())
+
+        if pooled is None:
+            pooled = torch.zeros((1, 768 + 1280), device=comfy.model_management.intermediate_device())

         return out, pooled
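A numeric sketch (not part of the commit) of the padding scheme in `encode_token_weights` above; batch and sequence sizes are illustrative. CLIP-L contributes 768 channels and CLIP-G 1280, the joint context is right-padded to T5's 4096 channels, and a missing CLIP-L is emulated by left-padding the CLIP-G output by 768.

```python
import torch
import torch.nn.functional as F

l_out = torch.randn(1, 77, 768)
g_out = torch.randn(1, 77, 1280)

lg_out = torch.cat([l_out, g_out], dim=-1)             # (1, 77, 2048)
lg_out = F.pad(lg_out, (0, 4096 - lg_out.shape[-1]))   # (1, 77, 4096)

# with clip_l missing, g_out alone is shifted into the same layout:
g_only = F.pad(g_out, (768, 0))                        # (1, 77, 2048)

t5_out = torch.randn(1, 77, 4096)
out = torch.cat([lg_out, t5_out], dim=-2)              # (1, 154, 4096)
```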
@@ -89,3 +142,9 @@ class SD3ClipModel(torch.nn.Module):
             return self.clip_l.load_sd(sd)
         else:
             return self.t5xxl.load_sd(sd)
+
+def sd3_clip(clip_l=True, clip_g=True, t5=True, dtype_t5=None):
+    class SD3ClipModel_(SD3ClipModel):
+        def __init__(self, device="cpu", dtype=None):
+            super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5, device=device, dtype=dtype)
+    return SD3ClipModel_
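The `sd3_clip` helper above uses a closure-over-a-subclass pattern so that callers which only know the `(device, dtype)` constructor signature still get the right encoder combination baked in. A generic, self-contained illustration of that pattern (the names here are made up):

```python
class TextEncoderStack:
    def __init__(self, use_t5=True, device="cpu", dtype=None):
        self.use_t5, self.device, self.dtype = use_t5, device, dtype

def make_stack(use_t5=True):
    # returns a subclass whose constructor matches the base (device, dtype) signature
    class TextEncoderStack_(TextEncoderStack):
        def __init__(self, device="cpu", dtype=None):
            super().__init__(use_t5=use_t5, device=device, dtype=dtype)
    return TextEncoderStack_

NoT5 = make_stack(use_t5=False)
enc = NoT5(device="cpu")      # constructed exactly like the base class
assert enc.use_t5 is False
```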
@@ -51,6 +51,7 @@ class SDXLClipModel(torch.nn.Module):
         super().__init__()
         self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False)
         self.clip_g = SDXLClipG(device=device, dtype=dtype)
+        self.dtypes = set([dtype])

     def set_clip_options(self, options):
         self.clip_l.set_clip_options(options)
@@ -54,7 +54,7 @@ class SD15(supported_models_base.BASE):
         replace_prefix = {"clip_l.": "cond_stage_model."}
         return utils.state_dict_prefix_replace(state_dict, replace_prefix)

-    def clip_target(self):
+    def clip_target(self, state_dict={}):
         return supported_models_base.ClipTarget(sd1_clip.SD1Tokenizer, sd1_clip.SD1ClipModel)

 class SD20(supported_models_base.BASE):
@@ -97,7 +97,7 @@ class SD20(supported_models_base.BASE):
         state_dict = diffusers_convert.convert_text_enc_state_dict_v20(state_dict)
         return state_dict

-    def clip_target(self):
+    def clip_target(self, state_dict={}):
         return supported_models_base.ClipTarget(sd2_clip.SD2Tokenizer, sd2_clip.SD2ClipModel)

 class SD21UnclipL(SD20):
@@ -159,7 +159,7 @@ class SDXLRefiner(supported_models_base.BASE):
         state_dict_g = utils.state_dict_prefix_replace(state_dict_g, replace_prefix)
         return state_dict_g

-    def clip_target(self):
+    def clip_target(self, state_dict={}):
         return supported_models_base.ClipTarget(sdxl_clip.SDXLTokenizer, sdxl_clip.SDXLRefinerClipModel)

 class SDXL(supported_models_base.BASE):
@@ -228,7 +228,7 @@ class SDXL(supported_models_base.BASE):
         state_dict_g = utils.state_dict_prefix_replace(state_dict_g, replace_prefix)
         return state_dict_g

-    def clip_target(self):
+    def clip_target(self, state_dict={}):
         return supported_models_base.ClipTarget(sdxl_clip.SDXLTokenizer, sdxl_clip.SDXLClipModel)

 class SSD1B(SDXL):
@@ -299,7 +299,7 @@ class SVD_img2vid(supported_models_base.BASE):
         out = model_base.SVD_img2vid(self, device=device)
         return out

-    def clip_target(self):
+    def clip_target(self, state_dict={}):
         return None

 class SV3D_u(SVD_img2vid):
@@ -365,7 +365,7 @@ class Stable_Zero123(supported_models_base.BASE):
         out = model_base.Stable_Zero123(self, device=device, cc_projection_weight=state_dict["cc_projection.weight"], cc_projection_bias=state_dict["cc_projection.bias"])
         return out

-    def clip_target(self):
+    def clip_target(self, state_dict={}):
         return None

 class SD_X4Upscaler(SD20):
@@ -439,7 +439,7 @@ class Stable_Cascade_C(supported_models_base.BASE):
         out = model_base.StableCascade_C(self, device=device)
         return out

-    def clip_target(self):
+    def clip_target(self, state_dict={}):
         return supported_models_base.ClipTarget(sdxl_clip.StableCascadeTokenizer, sdxl_clip.StableCascadeClipModel)

 class Stable_Cascade_B(Stable_Cascade_C):
@@ -501,14 +501,28 @@ class SD3(supported_models_base.BASE):

     unet_extra_config = {}
     latent_format = latent_formats.SD3
-    text_encoder_key_prefix = ["text_encoders."] #TODO?
+    text_encoder_key_prefix = ["text_encoders."]

     def get_model(self, state_dict, prefix="", device=None):
         out = model_base.SD3(self, device=device)
         return out

-    def clip_target(self):
-        return supported_models_base.ClipTarget(sd3_clip.SD3Tokenizer, sd3_clip.SD3ClipModel) #TODO?
+    def clip_target(self, state_dict={}):
+        clip_l = False
+        clip_g = False
+        t5 = False
+        dtype_t5 = None
+        pref = self.text_encoder_key_prefix[0]
+        if "{}clip_l.transformer.text_model.final_layer_norm.weight".format(pref) in state_dict:
+            clip_l = True
+        if "{}clip_g.transformer.text_model.final_layer_norm.weight".format(pref) in state_dict:
+            clip_g = True
+        t5_key = "{}t5xxl.transformer.encoder.final_layer_norm.weight".format(pref)
+        if t5_key in state_dict:
+            t5 = True
+            dtype_t5 = state_dict[t5_key].dtype
+
+        return supported_models_base.ClipTarget(sd3_clip.SD3Tokenizer, sd3_clip.sd3_clip(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5))


 models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3]
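A standalone sketch (not part of the commit) of the key probing done by `clip_target` above, using the exact key names from the hunk; the state dict below is a tiny stand-in for a loaded `sd3_medium_incl_clips.safetensors`, and the tensor shapes are placeholders.

```python
import torch

pref = "text_encoders."
state_dict = {
    pref + "clip_l.transformer.text_model.final_layer_norm.weight": torch.zeros(768),
    pref + "clip_g.transformer.text_model.final_layer_norm.weight": torch.zeros(1280),
}

clip_l = "{}clip_l.transformer.text_model.final_layer_norm.weight".format(pref) in state_dict
clip_g = "{}clip_g.transformer.text_model.final_layer_norm.weight".format(pref) in state_dict
t5_key = "{}t5xxl.transformer.encoder.final_layer_norm.weight".format(pref)
t5 = t5_key in state_dict
dtype_t5 = state_dict[t5_key].dtype if t5 else None

print(clip_l, clip_g, t5, dtype_t5)  # True True False None for this file
```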
@@ -405,7 +405,10 @@ class SamplerCustom:
     def sample(self, model, add_noise, noise_seed, cfg, positive, negative, sampler, sigmas, latent_image):
         latent = latent_image
         latent_image = latent["samples"]
+        latent = latent.copy()
+        latent_image = comfy.sample.fix_empty_latent_channels(model, latent_image)
+        latent["samples"] = latent_image

         if not add_noise:
             noise = Noise_EmptyNoise().generate_noise(latent)
         else:
@@ -564,7 +567,9 @@ class SamplerCustomAdvanced:
     def sample(self, noise, guider, sampler, sigmas, latent_image):
         latent = latent_image
         latent_image = latent["samples"]
+        latent = latent.copy()
+        latent_image = comfy.sample.fix_empty_latent_channels(guider.model_patcher, latent_image)
+        latent["samples"] = latent_image

         noise_mask = None
         if "noise_mask" in latent:
@@ -135,7 +135,7 @@ class ModelSamplingSD3:
     @classmethod
     def INPUT_TYPES(s):
         return {"required": { "model": ("MODEL",),
-                              "shift": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01}),
+                              "shift": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 100.0, "step":0.01}),
                               }}

     RETURN_TYPES = ("MODEL",)
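For context (hedged, not this repository's implementation): the SD3 paper's timestep shift maps a uniform t toward noisier sigmas as `shift` grows, which is why moving the node default from 1.0 to 3.0 noticeably changes sampling. One common form of that mapping:

```python
def shift_sigma(t: float, shift: float) -> float:
    # shift = 1.0 leaves t unchanged; larger shifts push samples toward higher noise
    return shift * t / (1 + (shift - 1) * t)

for t in (0.25, 0.5, 0.75):
    print(t, shift_sigma(t, 1.0), shift_sigma(t, 3.0))
# with shift=3.0: 0.25 -> 0.5, 0.5 -> 0.75, 0.75 -> 0.9
```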
@@ -182,6 +182,8 @@ def save_checkpoint(model, clip=None, vae=None, clip_vision=None, filename_prefi
             metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-refiner"
         elif isinstance(model.model, model_base.SVD_img2vid):
             metadata["modelspec.architecture"] = "stable-video-diffusion-img2vid-v1"
+        elif isinstance(model.model, comfy.model_base.SD3):
+            metadata["modelspec.architecture"] = "stable-diffusion-v3-medium" #TODO: other SD3 variants
         else:
             enable_modelspec = False
@@ -1,35 +1,42 @@
-from comfy.cmd import folder_paths
-import comfy.sd
-import comfy.model_management
-from comfy.nodes import base_nodes as nodes
 import torch

+import comfy.model_management
+import comfy.sd
+from comfy.cmd import folder_paths
+from comfy.model_downloader import get_or_download, get_filename_list_with_downloadable, KNOWN_CLIP_MODELS
+from comfy.nodes import base_nodes as nodes


 class TripleCLIPLoader:
     @classmethod
     def INPUT_TYPES(s):
-        return {"required": { "clip_name1": (folder_paths.get_filename_list("clip"), ), "clip_name2": (folder_paths.get_filename_list("clip"), ), "clip_name3": (folder_paths.get_filename_list("clip"), )
+        filename_list = get_filename_list_with_downloadable("clip", KNOWN_CLIP_MODELS)
+        return {"required": {"clip_name1": (filename_list,), "clip_name2": (filename_list,), "clip_name3": (filename_list,)
                              }}

     RETURN_TYPES = ("CLIP",)
     FUNCTION = "load_clip"

     CATEGORY = "advanced/loaders"

     def load_clip(self, clip_name1, clip_name2, clip_name3):
-        clip_path1 = folder_paths.get_full_path("clip", clip_name1)
-        clip_path2 = folder_paths.get_full_path("clip", clip_name2)
-        clip_path3 = folder_paths.get_full_path("clip", clip_name3)
+        clip_path1 = get_or_download("clip", clip_name1, KNOWN_CLIP_MODELS)
+        clip_path2 = get_or_download("clip", clip_name2, KNOWN_CLIP_MODELS)
+        clip_path3 = get_or_download("clip", clip_name3, KNOWN_CLIP_MODELS)
         clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2, clip_path3], embedding_directory=folder_paths.get_folder_paths("embeddings"))
         return (clip,)


 class EmptySD3LatentImage:
     def __init__(self):
         self.device = comfy.model_management.intermediate_device()

     @classmethod
     def INPUT_TYPES(s):
-        return {"required": { "width": ("INT", {"default": 512, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
-                              "height": ("INT", {"default": 512, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
-                              "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}}
+        return {"required": {"width": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
+                             "height": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
+                             "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}}

     RETURN_TYPES = ("LATENT",)
     FUNCTION = "generate"
@@ -37,18 +44,20 @@ class EmptySD3LatentImage:

     def generate(self, width, height, batch_size=1):
         latent = torch.ones([batch_size, 16, height // 8, width // 8], device=self.device) * 0.0609
-        return ({"samples":latent}, )
+        return ({"samples": latent},)

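A standalone sketch (not part of the commit) of what `generate` above produces at the new 1024x1024 default: SD3 uses a 16-channel latent, and an "empty" latent is filled with the constant 0.0609 rather than zeros.

```python
import torch

width, height, batch_size = 1024, 1024, 1
latent = torch.ones([batch_size, 16, height // 8, width // 8]) * 0.0609
print(latent.shape)  # torch.Size([1, 16, 128, 128])
```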
 class CLIPTextEncodeSD3:
     @classmethod
     def INPUT_TYPES(s):
         return {"required": {
-            "clip": ("CLIP", ),
+            "clip": ("CLIP",),
             "clip_l": ("STRING", {"multiline": True, "dynamicPrompts": True}),
             "clip_g": ("STRING", {"multiline": True, "dynamicPrompts": True}),
             "t5xxl": ("STRING", {"multiline": True, "dynamicPrompts": True}),
-            "empty_padding": (["none", "empty_prompt"], )
-            }}
+            "empty_padding": (["none", "empty_prompt"],)
+        }}

     RETURN_TYPES = ("CONDITIONING",)
     FUNCTION = "encode"

@@ -67,7 +76,7 @@ class CLIPTextEncodeSD3:
             tokens["l"] = clip.tokenize(clip_l)["l"]

         if len(t5xxl) == 0 and no_padding:
             tokens["t5xxl"] = []
         else:
             tokens["t5xxl"] = clip.tokenize(t5xxl)["t5xxl"]
         if len(tokens["l"]) != len(tokens["g"]):
@@ -77,7 +86,7 @@ class CLIPTextEncodeSD3:
         while len(tokens["l"]) > len(tokens["g"]):
             tokens["g"] += empty["g"]
         cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True)
-        return ([[cond, {"pooled_output": pooled}]], )
+        return ([[cond, {"pooled_output": pooled}]],)


 NODE_CLASS_MAPPINGS = {
@@ -8,7 +8,7 @@ transformers>=4.29.1
 peft
 torchinfo
 fschat[model_worker]
-safetensors>=0.3.0
+safetensors>=0.4.2
 bitsandbytes
 pytorch-lightning>=2.0.0
 aiohttp>=3.8.4