diff --git a/README.md b/README.md
index 7f2cdb274..94f2bacb4 100644
--- a/README.md
+++ b/README.md
@@ -18,6 +18,35 @@ A vanilla, up-to-date fork of [ComfyUI](https://github.com/comfyanonymous/comfyu
 - Automated tests for new features.
 - Automatic model downloading for well-known models.
 
+### Upstream Features
+
+- Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything.
+- Fully supports SD1.x, SD2.x, [SDXL](https://comfyanonymous.github.io/ComfyUI_examples/sdxl/), [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/) and [Stable Cascade](https://comfyanonymous.github.io/ComfyUI_examples/stable_cascade/)
+- Asynchronous queue system
+- Many optimizations: only re-executes the parts of the workflow that change between executions.
+- Command line option: ```--lowvram``` to make it work on GPUs with less than 3GB of VRAM (enabled automatically on GPUs with low VRAM)
+- Works even if you don't have a GPU with: ```--cpu``` (slow)
+- Can load ckpt, safetensors and diffusers models/checkpoints, as well as standalone VAEs and CLIP models.
+- Embeddings/Textual inversion
+- [Loras (regular, locon and loha)](https://comfyanonymous.github.io/ComfyUI_examples/lora/)
+- [Hypernetworks](https://comfyanonymous.github.io/ComfyUI_examples/hypernetworks/)
+- Loading full workflows (with seeds) from generated PNG files.
+- Saving/loading workflows as JSON files.
+- Nodes interface can be used to create complex workflows like one for [Hires fix](https://comfyanonymous.github.io/ComfyUI_examples/2_pass_txt2img/) or much more advanced ones.
+- [Area Composition](https://comfyanonymous.github.io/ComfyUI_examples/area_composition/)
+- [Inpainting](https://comfyanonymous.github.io/ComfyUI_examples/inpaint/) with both regular and inpainting models.
+- [ControlNet and T2I-Adapter](https://comfyanonymous.github.io/ComfyUI_examples/controlnet/)
+- [Upscale Models (ESRGAN, ESRGAN variants, SwinIR, Swin2SR, etc.)](https://comfyanonymous.github.io/ComfyUI_examples/upscale_models/)
+- [unCLIP Models](https://comfyanonymous.github.io/ComfyUI_examples/unclip/)
+- [GLIGEN](https://comfyanonymous.github.io/ComfyUI_examples/gligen/)
+- [Model Merging](https://comfyanonymous.github.io/ComfyUI_examples/model_merging/)
+- [LCM models and Loras](https://comfyanonymous.github.io/ComfyUI_examples/lcm/)
+- [SDXL Turbo](https://comfyanonymous.github.io/ComfyUI_examples/sdturbo/)
+- Latent previews with [TAESD](#how-to-show-high-quality-previews)
+- Starts up very fast.
+- Works fully offline: will never download anything.
+- [Config file](extra_model_paths.yaml.example) to set the search paths for models.
+
 ### Table of Contents
 
 - [Workflows](https://comfyanonymous.github.io/ComfyUI_examples/)
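The README bullet about loading full workflows from generated PNGs works because ComfyUI embeds the graph in PNG text chunks. A minimal sketch of recovering it, assuming default PNG saving and an illustrative filename:

```python
import json
from PIL import Image

img = Image.open("ComfyUI_00001_.png")       # illustrative filename
workflow = json.loads(img.info["workflow"])  # the full node graph, seeds included
prompt = json.loads(img.info["prompt"])      # the executed-prompt form of the graph
```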
diff --git a/comfy/conds.py b/comfy/conds.py
index 76ca1cfd1..1317469ab 100644
--- a/comfy/conds.py
+++ b/comfy/conds.py
@@ -29,7 +29,12 @@ class CONDRegular:
 
 class CONDNoiseShape(CONDRegular):
     def process_cond(self, batch_size, device, area, **kwargs):
-        data = self.cond[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
+        data = self.cond
+        if area is not None:
+            dims = len(area) // 2
+            for i in range(dims):
+                data = data.narrow(i + 2, area[i + dims], area[i])
+
         return self._copy_with(utils.repeat_to_batch_size(data, batch_size).to(device))
 
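The rewritten `CONDNoiseShape.process_cond` replaces hard-coded 2-D slicing with dimension-generic `torch.Tensor.narrow` calls, so the same loop can crop conds with any number of trailing spatial dimensions (e.g. video latents). A quick equivalence check, assuming the `(height, width, y, x)` area layout the old slicing implies:

```python
import torch

cond = torch.randn(1, 4, 64, 64)
area = (32, 24, 8, 16)  # (height, width, y, x) in latent pixels

# Old, 2-D-only slicing:
old = cond[:, :, area[2]:area[0] + area[2], area[3]:area[1] + area[3]]

# New generalized form: narrow(dim, start, length) per spatial dim; the same
# loop would handle a 3-D cond with area = (t, h, w, t0, y0, x0).
data = cond
dims = len(area) // 2
for i in range(dims):
    data = data.narrow(i + 2, area[i + dims], area[i])

assert torch.equal(old, data)
```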
diff --git a/comfy/ldm/modules/diffusionmodules/mmdit.py b/comfy/ldm/modules/diffusionmodules/mmdit.py
index 5e7afc8d3..0cb6bd312 100644
--- a/comfy/ldm/modules/diffusionmodules/mmdit.py
+++ b/comfy/ldm/modules/diffusionmodules/mmdit.py
@@ -835,72 +835,11 @@ class MMDiT(nn.Module):
         )
 
         self.final_layer = FinalLayer(self.hidden_size, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations)
-        # self.initialize_weights()
 
         if compile_core:
             assert False
             self.forward_core_with_concat = torch.compile(self.forward_core_with_concat)
 
-    def initialize_weights(self):
-        # TODO: Init context_embedder?
-        # Initialize transformer layers:
-        def _basic_init(module):
-            if isinstance(module, nn.Linear):
-                torch.nn.init.xavier_uniform_(module.weight)
-                if module.bias is not None:
-                    nn.init.constant_(module.bias, 0)
-
-        self.apply(_basic_init)
-
-        # Initialize (and freeze) pos_embed by sin-cos embedding
-        if self.pos_embed is not None:
-            pos_embed_grid_size = (
-                int(self.x_embedder.num_patches**0.5)
-                if self.pos_embed_max_size is None
-                else self.pos_embed_max_size
-            )
-            pos_embed = get_2d_sincos_pos_embed(
-                self.pos_embed.shape[-1],
-                int(self.x_embedder.num_patches**0.5),
-                pos_embed_grid_size,
-                scaling_factor=self.pos_embed_scaling_factor,
-                offset=self.pos_embed_offset,
-            )
-
-            pos_embed = get_2d_sincos_pos_embed(
-                self.pos_embed.shape[-1],
-                int(self.pos_embed.shape[-2]**0.5),
-                scaling_factor=self.pos_embed_scaling_factor,
-            )
-            self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
-
-        # Initialize patch_embed like nn.Linear (instead of nn.Conv2d)
-        w = self.x_embedder.proj.weight.data
-        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
-        nn.init.constant_(self.x_embedder.proj.bias, 0)
-
-        if hasattr(self, "y_embedder"):
-            nn.init.normal_(self.y_embedder.mlp[0].weight, std=0.02)
-            nn.init.normal_(self.y_embedder.mlp[2].weight, std=0.02)
-
-        # Initialize timestep embedding MLP:
-        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
-        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
-
-        # Zero-out adaLN modulation layers in DiT blocks:
-        for block in self.joint_blocks:
-            nn.init.constant_(block.x_block.adaLN_modulation[-1].weight, 0)
-            nn.init.constant_(block.x_block.adaLN_modulation[-1].bias, 0)
-            nn.init.constant_(block.context_block.adaLN_modulation[-1].weight, 0)
-            nn.init.constant_(block.context_block.adaLN_modulation[-1].bias, 0)
-
-        # Zero-out output layers:
-        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
-        nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
-        nn.init.constant_(self.final_layer.linear.weight, 0)
-        nn.init.constant_(self.final_layer.linear.bias, 0)
-
     def cropped_pos_embed(self, hw, device=None):
         p = self.x_embedder.patch_size[0]
         h, w = hw
@@ -995,7 +934,7 @@ class MMDiT(nn.Module):
             context = self.context_processor(context)
 
         hw = x.shape[-2:]
-        x = self.x_embedder(x) + self.cropped_pos_embed(hw, device=x.device).to(dtype=x.dtype)
+        x = self.x_embedder(x) + self.cropped_pos_embed(hw, device=x.device).to(dtype=x.dtype, device=x.device)
         c = self.t_embedder(t, dtype=x.dtype)  # (N, D)
         if y is not None and self.y_embedder is not None:
             y = self.y_embedder(y)  # (N, D)
diff --git a/comfy/model_base.py b/comfy/model_base.py
index 374d889a3..a559b1be3 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -206,9 +206,6 @@ class BaseModel(torch.nn.Module):
         unet_state_dict = self.diffusion_model.state_dict()
         unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
 
-        if self.get_dtype() == torch.float16:
-            extra_sds = map(lambda sd: utils.convert_sd_to(sd, torch.float16), extra_sds)
-
         if self.model_type == ModelType.V_PREDICTION:
             unet_state_dict["v_pred"] = torch.tensor([])
 
@@ -572,13 +569,20 @@ class SD3(BaseModel):
         return kwargs["pooled_output"]
 
     def extra_conds(self, **kwargs):
-        out = {}
-        adm = self.encode_adm(**kwargs)
-        if adm is not None:
-            out['y'] = conds.CONDRegular(adm)
-
+        out = super().extra_conds(**kwargs)
         cross_attn = kwargs.get("cross_attn", None)
         if cross_attn is not None:
             out['c_crossattn'] = conds.CONDRegular(cross_attn)
         return out
+
+    def memory_required(self, input_shape):
+        if model_management.xformers_enabled() or model_management.pytorch_attention_flash_attention():
+            dtype = self.get_dtype()
+            if self.manual_cast_dtype is not None:
+                dtype = self.manual_cast_dtype
+            #TODO: this probably needs to be tweaked
+            area = input_shape[0] * input_shape[2] * input_shape[3]
+            return (area * model_management.dtype_size(dtype) * 0.012) * (1024 * 1024)
+        else:
+            area = input_shape[0] * input_shape[2] * input_shape[3]
+            return (area * 0.3) * (1024 * 1024)
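The new `SD3.memory_required` heuristic estimates VRAM from the latent area, with a much smaller constant when a memory-efficient attention backend is available. A rough worked example, assuming a 1024x1024 image (SD3 latents are 1/8 scale) and fp16 weights (2 bytes per element); the 0.012 and 0.3 constants come from the diff itself:

```python
batch, height, width = 1, 1024 // 8, 1024 // 8
area = batch * height * width              # 16384
flash = area * 2 * 0.012 * (1024 * 1024)   # ~393 MiB with xformers/flash attention
fallback = area * 0.3 * (1024 * 1024)      # ~4.8 GiB without it
```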
HuggingFile("SG161222/Realistic_Vision_V6.0_B1_noVAE","Realistic_Vision_V6.0_NV_B1_fp16.safetensors"), - HuggingFile("SG161222/Realistic_Vision_V5.1_noVAE","Realistic_Vision_V5.1_fp16-no-ema.safetensors"), + HuggingFile("SG161222/Realistic_Vision_V6.0_B1_noVAE", "Realistic_Vision_V6.0_NV_B1_fp16.safetensors"), + HuggingFile("SG161222/Realistic_Vision_V5.1_noVAE", "Realistic_Vision_V5.1_fp16-no-ema.safetensors"), CivitFile(4384, 128713, filename="dreamshaper_8.safetensors"), CivitFile(7371, 425083, filename="revAnimated_v2Rebirth.safetensors"), CivitFile(4468, 57618, filename="counterfeitV30_v30.safetensors"), CivitFile(241415, 272376, filename="picxReal_10.safetensors"), CivitFile(23900, 95489, filename="anyloraCheckpoint_bakedvaeBlessedFp16.safetensors"), + HuggingFile("stabilityai/stable-diffusion-3-medium", filename="sd3_medium.safetensors"), + HuggingFile("stabilityai/stable-diffusion-3-medium", filename="sd3_medium_incl_clips.safetensors"), + HuggingFile("stabilityai/stable-diffusion-3-medium", filename="sd3_medium_incl_clips_t5xxlfp8.safetensors"), ] KNOWN_UNCLIP_CHECKPOINTS = [ @@ -323,7 +334,13 @@ KNOWN_UNET_MODELS: List[Union[CivitFile | HuggingFile]] = [ HuggingFile("ByteDance/Hyper-SD", "Hyper-SDXL-1step-Unet-Comfyui.fp16.safetensors") ] -KNOWN_CLIP_MODELS: List[Union[CivitFile | HuggingFile]] = [] +KNOWN_CLIP_MODELS: List[Union[CivitFile | HuggingFile]] = [ + # todo: is this correct? + HuggingFile("stabilityai/stable-diffusion-3-medium", "text_encoders/t5xxl_fp16.safetensors", save_with_filename="t5xxl_fp16.safetensors"), + HuggingFile("stabilityai/stable-diffusion-3-medium", "text_encoders/t5xxl_fp8_e4m3fn.safetensors", save_with_filename="t5xxl_fp8_e4m3fn.safetensors"), + HuggingFile("stabilityai/stable-diffusion-3-medium", "text_encoders/clip_g.safetensors", save_with_filename="clip_g.safetensors"), + HuggingFile("stabilityai/stable-diffusion-3-medium", "text_encoders/clip_l.safetensors", save_with_filename="clip_l.safetensors"), +] def add_known_models(folder_name: str, symbol: List[Union[CivitFile, HuggingFile]], *models: Union[CivitFile, HuggingFile]) -> List[Union[CivitFile, HuggingFile]]: diff --git a/comfy/model_downloader_types.py b/comfy/model_downloader_types.py index b0ef948b0..bea8c6f59 100644 --- a/comfy/model_downloader_types.py +++ b/comfy/model_downloader_types.py @@ -54,6 +54,7 @@ class HuggingFile: size: Optional[int] = None force_save_in_repo_id: Optional[bool] = False repo_type: Optional[str] = 'model' + revision: Optional[str] = None def __str__(self): return self.save_with_filename or split(self.filename)[-1] diff --git a/comfy/model_management.py b/comfy/model_management.py index e2381b90b..55283f1de 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -703,6 +703,22 @@ def supports_dtype(device, dtype): # TODO return True return False +def supports_cast(device, dtype): #TODO + if dtype == torch.float32: + return True + if dtype == torch.float16: + return True + if is_device_mps(device): + return False + if directml_enabled: #TODO: test this + return False + if dtype == torch.bfloat16: + return True + if dtype == torch.float8_e4m3fn: + return True + if dtype == torch.float8_e5m2: + return True + return False def device_supports_non_blocking(device): if is_device_mps(device): diff --git a/comfy/nodes/base_nodes.py b/comfy/nodes/base_nodes.py index 11c1114f2..8847ee782 100644 --- a/comfy/nodes/base_nodes.py +++ b/comfy/nodes/base_nodes.py @@ -815,7 +815,7 @@ class CLIPLoader: @classmethod def INPUT_TYPES(s): return {"required": { 
"clip_name": (get_filename_list_with_downloadable("clip", KNOWN_CLIP_MODELS),), - "type": (["stable_diffusion", "stable_cascade"], ), + "type": (["stable_diffusion", "stable_cascade", "sd3"], ), }} RETURN_TYPES = ("CLIP",) FUNCTION = "load_clip" @@ -823,10 +823,11 @@ class CLIPLoader: CATEGORY = "advanced/loaders" def load_clip(self, clip_name, type="stable_diffusion"): - if type == "stable_diffusion": - clip_type = sd.CLIPType.STABLE_DIFFUSION - elif type == "stable_cascade": + clip_type = sd.CLIPType.STABLE_DIFFUSION + if type == "stable_cascade": clip_type = sd.CLIPType.STABLE_CASCADE + elif type == "sd3": + clip_type = sd.CLIPType.SD3 else: logging.warning(f"Unknown clip type argument passed: {type} for model {clip_name}") @@ -839,16 +840,22 @@ class DualCLIPLoader: def INPUT_TYPES(s): return {"required": { "clip_name1": (folder_paths.get_filename_list("clip"),), "clip_name2": ( folder_paths.get_filename_list("clip"),), - }} + "type": (["sdxl", "sd3"], ), + }} RETURN_TYPES = ("CLIP",) FUNCTION = "load_clip" CATEGORY = "advanced/loaders" - def load_clip(self, clip_name1, clip_name2): + def load_clip(self, clip_name1, clip_name2, type): clip_path1 = folder_paths.get_full_path("clip", clip_name1) clip_path2 = folder_paths.get_full_path("clip", clip_name2) - clip = sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=folder_paths.get_folder_paths("embeddings")) + if type == "sdxl": + clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION + elif type == "sd3": + clip_type = comfy.sd.CLIPType.SD3 + + clip = sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type) return (clip,) class CLIPVisionLoader: diff --git a/comfy/sd.py b/comfy/sd.py index b66d05e39..d18515290 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -100,15 +100,21 @@ class CLIP: load_device = model_management.text_encoder_device() offload_device = model_management.text_encoder_offload_device() params['device'] = offload_device - params['dtype'] = model_management.text_encoder_dtype(load_device) + dtype = model_management.text_encoder_dtype(load_device) + params['dtype'] = dtype if "textmodel_json_config" not in params and textmodel_json_config is not None: params['textmodel_json_config'] = textmodel_json_config self.cond_stage_model = clip(**(params)) + for dt in self.cond_stage_model.dtypes: + if not model_management.supports_cast(load_device, dt): + load_device = offload_device + self.tokenizer: "sd1_clip.SD1Tokenizer" = tokenizer(embedding_directory=embedding_directory) self.patcher = model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device) self.layer_idx = None + logging.debug("CLIP model load device: {}, offload device: {}".format(load_device, offload_device)) def clone(self): n = CLIP(no_init=True) @@ -372,6 +378,7 @@ def load_style_model(ckpt_path): class CLIPType(Enum): STABLE_DIFFUSION = 1 STABLE_CASCADE = 2 + SD3 = 3 @dataclasses.dataclass @@ -407,12 +414,20 @@ def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DI elif "text_model.encoder.layers.22.mlp.fc1.weight" in clip_data[0]: clip_target.clip = sd2_clip.SD2ClipModel clip_target.tokenizer = sd2_clip.SD2Tokenizer + elif "encoder.block.23.layer.1.DenseReluDense.wi_1.weight" in clip_data[0]: + dtype_t5 = clip_data[0]["encoder.block.23.layer.1.DenseReluDense.wi_1.weight"].dtype + clip_target.clip = sd3_clip.sd3_clip(clip_l=False, clip_g=False, t5=True, dtype_t5=dtype_t5) + clip_target.tokenizer = 
diff --git a/comfy/nodes/base_nodes.py b/comfy/nodes/base_nodes.py
index 11c1114f2..8847ee782 100644
--- a/comfy/nodes/base_nodes.py
+++ b/comfy/nodes/base_nodes.py
@@ -815,7 +815,7 @@ class CLIPLoader:
     @classmethod
     def INPUT_TYPES(s):
         return {"required": { "clip_name": (get_filename_list_with_downloadable("clip", KNOWN_CLIP_MODELS),),
-                              "type": (["stable_diffusion", "stable_cascade"], ),
+                              "type": (["stable_diffusion", "stable_cascade", "sd3"], ),
                               }}
     RETURN_TYPES = ("CLIP",)
     FUNCTION = "load_clip"
@@ -823,10 +823,11 @@ class CLIPLoader:
     CATEGORY = "advanced/loaders"
 
     def load_clip(self, clip_name, type="stable_diffusion"):
-        if type == "stable_diffusion":
-            clip_type = sd.CLIPType.STABLE_DIFFUSION
-        elif type == "stable_cascade":
+        clip_type = sd.CLIPType.STABLE_DIFFUSION
+        if type == "stable_cascade":
             clip_type = sd.CLIPType.STABLE_CASCADE
-        else:
+        elif type == "sd3":
+            clip_type = sd.CLIPType.SD3
+        elif type != "stable_diffusion":
             logging.warning(f"Unknown clip type argument passed: {type} for model {clip_name}")
 
@@ -839,16 +840,22 @@ class DualCLIPLoader:
     def INPUT_TYPES(s):
         return {"required": { "clip_name1": (folder_paths.get_filename_list("clip"),),
                               "clip_name2": ( folder_paths.get_filename_list("clip"),),
-                              }}
+                              "type": (["sdxl", "sd3"], ),
+                              }}
     RETURN_TYPES = ("CLIP",)
     FUNCTION = "load_clip"
 
     CATEGORY = "advanced/loaders"
 
-    def load_clip(self, clip_name1, clip_name2):
+    def load_clip(self, clip_name1, clip_name2, type):
         clip_path1 = folder_paths.get_full_path("clip", clip_name1)
         clip_path2 = folder_paths.get_full_path("clip", clip_name2)
-        clip = sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=folder_paths.get_folder_paths("embeddings"))
+        if type == "sdxl":
+            clip_type = sd.CLIPType.STABLE_DIFFUSION
+        elif type == "sd3":
+            clip_type = sd.CLIPType.SD3
+
+        clip = sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type)
         return (clip,)
 
 class CLIPVisionLoader:
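For clarity, a hypothetical direct invocation of the updated `DualCLIPLoader` outside the graph UI; the filenames are illustrative, and node entry points are ordinary methods that return tuples:

```python
loader = DualCLIPLoader()
(clip,) = loader.load_clip("clip_g.safetensors", "t5xxl_fp16.safetensors", type="sd3")
```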
diff --git a/comfy/sd.py b/comfy/sd.py
index b66d05e39..d18515290 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -100,15 +100,21 @@ class CLIP:
         load_device = model_management.text_encoder_device()
         offload_device = model_management.text_encoder_offload_device()
         params['device'] = offload_device
-        params['dtype'] = model_management.text_encoder_dtype(load_device)
+        dtype = model_management.text_encoder_dtype(load_device)
+        params['dtype'] = dtype
 
         if "textmodel_json_config" not in params and textmodel_json_config is not None:
             params['textmodel_json_config'] = textmodel_json_config
 
         self.cond_stage_model = clip(**(params))
 
+        for dt in self.cond_stage_model.dtypes:
+            if not model_management.supports_cast(load_device, dt):
+                load_device = offload_device
+
         self.tokenizer: "sd1_clip.SD1Tokenizer" = tokenizer(embedding_directory=embedding_directory)
         self.patcher = model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
         self.layer_idx = None
+        logging.debug("CLIP model load device: {}, offload device: {}".format(load_device, offload_device))
 
     def clone(self):
         n = CLIP(no_init=True)
@@ -372,6 +378,7 @@ def load_style_model(ckpt_path):
 class CLIPType(Enum):
     STABLE_DIFFUSION = 1
     STABLE_CASCADE = 2
+    SD3 = 3
 
 
 @dataclasses.dataclass
@@ -407,12 +414,20 @@ def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DI
         elif "text_model.encoder.layers.22.mlp.fc1.weight" in clip_data[0]:
             clip_target.clip = sd2_clip.SD2ClipModel
             clip_target.tokenizer = sd2_clip.SD2Tokenizer
+        elif "encoder.block.23.layer.1.DenseReluDense.wi_1.weight" in clip_data[0]:
+            dtype_t5 = clip_data[0]["encoder.block.23.layer.1.DenseReluDense.wi_1.weight"].dtype
+            clip_target.clip = sd3_clip.sd3_clip(clip_l=False, clip_g=False, t5=True, dtype_t5=dtype_t5)
+            clip_target.tokenizer = sd3_clip.SD3Tokenizer
         else:
            clip_target.clip = sd1_clip.SD1ClipModel
             clip_target.tokenizer = sd1_clip.SD1Tokenizer
     elif len(clip_data) == 2:
-        clip_target.clip = sdxl_clip.SDXLClipModel
-        clip_target.tokenizer = sdxl_clip.SDXLTokenizer
+        if clip_type == CLIPType.SD3:
+            clip_target.clip = sd3_clip.sd3_clip(clip_l=True, clip_g=True, t5=False)
+            clip_target.tokenizer = sd3_clip.SD3Tokenizer
+        else:
+            clip_target.clip = sdxl_clip.SDXLClipModel
+            clip_target.tokenizer = sdxl_clip.SDXLTokenizer
     elif len(clip_data) == 3:
         clip_target.clip = sd3_clip.SD3ClipModel
         clip_target.tokenizer = sd3_clip.SD3Tokenizer
@@ -501,7 +516,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
             vae = VAE(sd=vae_sd)
 
     if output_clip:
-        clip_target = model_config.clip_target()
+        clip_target = model_config.clip_target(state_dict=sd)
         if clip_target is not None:
             clip_sd = model_config.process_clip_state_dict(sd)
             if len(clip_sd) > 0:
diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py
index 5665b129f..68c3b69ad 100644
--- a/comfy/sd1_clip.py
+++ b/comfy/sd1_clip.py
@@ -8,7 +8,7 @@ import zipfile
 from typing import Tuple, Sequence, TypeVar
 
 import torch
-from transformers import CLIPTokenizer, PreTrainedTokenizerBase
+from transformers import CLIPTokenizer, PreTrainedTokenizerBase, SpecialTokensMixin
 
 from . import clip_model
 from . import model_management
@@ -410,7 +410,7 @@ class SDTokenizer:
     def clone(self) -> SDTokenizerT:
         sd_tokenizer = copy.copy(self)
         # correctly copy additional vocab
-        sd_tokenizer.tokenizer = self.tokenizer_class.from_pretrained(self.tokenizer_path)
+        sd_tokenizer.tokenizer = self.tokenizer_class.from_pretrained(self.tokenizer_path, legacy=True)
         sd_tokenizer.add_tokens(sd_tokenizer.additional_tokens)
         return sd_tokenizer
 
@@ -568,6 +568,10 @@ class SD1ClipModel(torch.nn.Module):
         self.clip = "clip_{}".format(self.clip_name)
         setattr(self, self.clip, clip_model(device=device, dtype=dtype, textmodel_json_config=textmodel_json_config, **kwargs))
 
+        self.dtypes = set()
+        if dtype is not None:
+            self.dtypes.add(dtype)
+
     def set_clip_options(self, options):
         getattr(self, self.clip).set_clip_options(options)
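`load_clip` classifies a standalone text-encoder file by probing for state-dict keys unique to each architecture; the T5-XXL branch also reads the stored dtype so fp8 checkpoints stay fp8. A condensed sketch of the detection idea, using the key strings from this diff in a hypothetical helper:

```python
def guess_text_encoder(state_dict):
    if "encoder.block.23.layer.1.DenseReluDense.wi_1.weight" in state_dict:
        return "t5xxl"  # only a 24-block T5 encoder has block 23
    if "text_model.encoder.layers.22.mlp.fc1.weight" in state_dict:
        return "sd2"    # the OpenCLIP-H text model is deeper than CLIP-L
    return "sd1"        # default: 12-layer CLIP-L
```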
diff --git a/comfy/sd3_clip.py b/comfy/sd3_clip.py
index bbbf6affd..0713eb285 100644
--- a/comfy/sd3_clip.py
+++ b/comfy/sd3_clip.py
@@ -5,6 +5,7 @@ import comfy.t5
 import torch
 import os
 import comfy.model_management
+import logging
 
 class T5XXLModel(sd1_clip.SDClipModel):
     def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None):
@@ -43,42 +44,94 @@ class SD3Tokenizer:
         return self.clip_g.untokenize(token_weight_pair)
 
 class SD3ClipModel(torch.nn.Module):
-    def __init__(self, device="cpu", dtype=None):
+    def __init__(self, clip_l=True, clip_g=True, t5=True, dtype_t5=None, device="cpu", dtype=None):
         super().__init__()
-        self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, return_projected_pooled=False)
-        self.clip_g = sdxl_clip.SDXLClipG(device=device, dtype=dtype)
-        self.t5xxl = T5XXLModel(device=device, dtype=dtype)
+        self.dtypes = set()
+        if clip_l:
+            self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, return_projected_pooled=False)
+            self.dtypes.add(dtype)
+        else:
+            self.clip_l = None
+
+        if clip_g:
+            self.clip_g = sdxl_clip.SDXLClipG(device=device, dtype=dtype)
+            self.dtypes.add(dtype)
+        else:
+            self.clip_g = None
+
+        if t5:
+            if dtype_t5 is None:
+                dtype_t5 = dtype
+            elif comfy.model_management.dtype_size(dtype_t5) > comfy.model_management.dtype_size(dtype):
+                dtype_t5 = dtype
+
+            if not comfy.model_management.supports_cast(device, dtype_t5):
+                dtype_t5 = dtype
+
+            self.t5xxl = T5XXLModel(device=device, dtype=dtype_t5)
+            self.dtypes.add(dtype_t5)
+        else:
+            self.t5xxl = None
+
+        logging.debug("Created SD3 text encoder with: clip_l {}, clip_g {}, t5xxl {}:{}".format(clip_l, clip_g, t5, dtype_t5))
 
     def set_clip_options(self, options):
-        self.clip_l.set_clip_options(options)
-        self.clip_g.set_clip_options(options)
-        self.t5xxl.set_clip_options(options)
+        if self.clip_l is not None:
+            self.clip_l.set_clip_options(options)
+        if self.clip_g is not None:
+            self.clip_g.set_clip_options(options)
+        if self.t5xxl is not None:
+            self.t5xxl.set_clip_options(options)
 
     def reset_clip_options(self):
-        self.clip_g.reset_clip_options()
-        self.clip_l.reset_clip_options()
-        self.t5xxl.reset_clip_options()
+        if self.clip_l is not None:
+            self.clip_l.reset_clip_options()
+        if self.clip_g is not None:
+            self.clip_g.reset_clip_options()
+        if self.t5xxl is not None:
+            self.t5xxl.reset_clip_options()
 
     def encode_token_weights(self, token_weight_pairs):
         token_weight_pairs_l = token_weight_pairs["l"]
         token_weight_pairs_g = token_weight_pairs["g"]
         token_weight_pars_t5 = token_weight_pairs["t5xxl"]
         lg_out = None
-        if len(token_weight_pairs_g) > 0 or len(token_weight_pairs_l) > 0:
-            l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l)
-            g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g)
-            lg_out = torch.cat([l_out, g_out], dim=-1)
-            lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1]))
-            out = lg_out
-            pooled = torch.cat((l_pooled, g_pooled), dim=-1)
-        else:
-            pooled = torch.zeros((1, 1280 + 768), device=comfy.model_management.intermediate_device())
+        pooled = None
+        out = None
 
-        t5_out, t5_pooled = self.t5xxl.encode_token_weights(token_weight_pars_t5)
-        if lg_out is not None:
-            out = torch.cat([lg_out, t5_out], dim=-2)
-        else:
-            out = t5_out
+        if len(token_weight_pairs_g) > 0 or len(token_weight_pairs_l) > 0:
+            if self.clip_l is not None:
+                lg_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l)
+            else:
+                l_pooled = torch.zeros((1, 768), device=comfy.model_management.intermediate_device())
+
+            if self.clip_g is not None:
+                g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g)
+                if lg_out is not None:
+                    lg_out = torch.cat([lg_out, g_out], dim=-1)
+                else:
+                    lg_out = torch.nn.functional.pad(g_out, (768, 0))
+            else:
+                g_out = None
+                g_pooled = torch.zeros((1, 1280), device=comfy.model_management.intermediate_device())
+
+            if lg_out is not None:
+                lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1]))
+                out = lg_out
+            pooled = torch.cat((l_pooled, g_pooled), dim=-1)
+
+        if self.t5xxl is not None:
+            t5_out, t5_pooled = self.t5xxl.encode_token_weights(token_weight_pars_t5)
+            if lg_out is not None:
+                out = torch.cat([lg_out, t5_out], dim=-2)
+            else:
+                out = t5_out
+
+        if out is None:
+            out = torch.zeros((1, 77, 4096), device=comfy.model_management.intermediate_device())
+
+        if pooled is None:
+            pooled = torch.zeros((1, 768 + 1280), device=comfy.model_management.intermediate_device())
 
         return out, pooled
@@ -89,3 +142,9 @@ class SD3ClipModel(torch.nn.Module):
             return self.clip_l.load_sd(sd)
         else:
             return self.t5xxl.load_sd(sd)
+
+def sd3_clip(clip_l=True, clip_g=True, t5=True, dtype_t5=None):
+    class SD3ClipModel_(SD3ClipModel):
+        def __init__(self, device="cpu", dtype=None):
+            super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5, device=device, dtype=dtype)
+    return SD3ClipModel_
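The padding arithmetic in `encode_token_weights` is easier to follow with concrete shapes. A rough walkthrough, assuming batch size 1 and 77-token prompts:

```python
import torch

l_out = torch.randn(1, 77, 768)    # CLIP-L hidden states
g_out = torch.randn(1, 77, 1280)   # CLIP-G hidden states
t5_out = torch.randn(1, 77, 4096)  # T5-XXL hidden states

lg = torch.cat([l_out, g_out], dim=-1)              # (1, 77, 2048)
lg = torch.nn.functional.pad(lg, (0, 4096 - 2048))  # (1, 77, 4096), zero padded
out = torch.cat([lg, t5_out], dim=-2)               # (1, 154, 4096): T5 tokens appended
pooled = torch.cat((torch.randn(1, 768), torch.randn(1, 1280)), dim=-1)  # (1, 2048)
```

When an encoder is absent, the corresponding block is replaced by zeros of the same shape, which is what the new `torch.zeros` branches do.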
diff --git a/comfy/sdxl_clip.py b/comfy/sdxl_clip.py
index 40988dc25..52c904c52 100644
--- a/comfy/sdxl_clip.py
+++ b/comfy/sdxl_clip.py
@@ -51,6 +51,7 @@ class SDXLClipModel(torch.nn.Module):
         super().__init__()
         self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False)
         self.clip_g = SDXLClipG(device=device, dtype=dtype)
+        self.dtypes = set([dtype])
 
     def set_clip_options(self, options):
         self.clip_l.set_clip_options(options)
diff --git a/comfy/supported_models.py b/comfy/supported_models.py
index 6bb76c96f..c8ddf3e2c 100644
--- a/comfy/supported_models.py
+++ b/comfy/supported_models.py
@@ -54,7 +54,7 @@ class SD15(supported_models_base.BASE):
         replace_prefix = {"clip_l.": "cond_stage_model."}
         return utils.state_dict_prefix_replace(state_dict, replace_prefix)
 
-    def clip_target(self):
+    def clip_target(self, state_dict={}):
         return supported_models_base.ClipTarget(sd1_clip.SD1Tokenizer, sd1_clip.SD1ClipModel)
 
 class SD20(supported_models_base.BASE):
@@ -97,7 +97,7 @@ class SD20(supported_models_base.BASE):
             state_dict = diffusers_convert.convert_text_enc_state_dict_v20(state_dict)
         return state_dict
 
-    def clip_target(self):
+    def clip_target(self, state_dict={}):
         return supported_models_base.ClipTarget(sd2_clip.SD2Tokenizer, sd2_clip.SD2ClipModel)
 
 class SD21UnclipL(SD20):
@@ -159,7 +159,7 @@ class SDXLRefiner(supported_models_base.BASE):
         state_dict_g = utils.state_dict_prefix_replace(state_dict_g, replace_prefix)
         return state_dict_g
 
-    def clip_target(self):
+    def clip_target(self, state_dict={}):
         return supported_models_base.ClipTarget(sdxl_clip.SDXLTokenizer, sdxl_clip.SDXLRefinerClipModel)
 
 class SDXL(supported_models_base.BASE):
@@ -228,7 +228,7 @@ class SDXL(supported_models_base.BASE):
         state_dict_g = utils.state_dict_prefix_replace(state_dict_g, replace_prefix)
         return state_dict_g
 
-    def clip_target(self):
+    def clip_target(self, state_dict={}):
         return supported_models_base.ClipTarget(sdxl_clip.SDXLTokenizer, sdxl_clip.SDXLClipModel)
 
 class SSD1B(SDXL):
@@ -299,7 +299,7 @@ class SVD_img2vid(supported_models_base.BASE):
         out = model_base.SVD_img2vid(self, device=device)
         return out
 
-    def clip_target(self):
+    def clip_target(self, state_dict={}):
         return None
 
 class SV3D_u(SVD_img2vid):
@@ -365,7 +365,7 @@ class Stable_Zero123(supported_models_base.BASE):
         out = model_base.Stable_Zero123(self, device=device, cc_projection_weight=state_dict["cc_projection.weight"], cc_projection_bias=state_dict["cc_projection.bias"])
         return out
 
-    def clip_target(self):
+    def clip_target(self, state_dict={}):
         return None
 
 class SD_X4Upscaler(SD20):
@@ -439,7 +439,7 @@ class Stable_Cascade_C(supported_models_base.BASE):
         out = model_base.StableCascade_C(self, device=device)
         return out
 
-    def clip_target(self):
+    def clip_target(self, state_dict={}):
         return supported_models_base.ClipTarget(sdxl_clip.StableCascadeTokenizer, sdxl_clip.StableCascadeClipModel)
 
 class Stable_Cascade_B(Stable_Cascade_C):
@@ -501,14 +501,28 @@ class SD3(supported_models_base.BASE):
     unet_extra_config = {}
     latent_format = latent_formats.SD3
 
-    text_encoder_key_prefix = ["text_encoders."] #TODO?
+    text_encoder_key_prefix = ["text_encoders."]
 
     def get_model(self, state_dict, prefix="", device=None):
         out = model_base.SD3(self, device=device)
         return out
 
-    def clip_target(self):
-        return supported_models_base.ClipTarget(sd3_clip.SD3Tokenizer, sd3_clip.SD3ClipModel) #TODO?
+    def clip_target(self, state_dict={}):
+        clip_l = False
+        clip_g = False
+        t5 = False
+        dtype_t5 = None
+        pref = self.text_encoder_key_prefix[0]
+        if "{}clip_l.transformer.text_model.final_layer_norm.weight".format(pref) in state_dict:
+            clip_l = True
+        if "{}clip_g.transformer.text_model.final_layer_norm.weight".format(pref) in state_dict:
+            clip_g = True
+        t5_key = "{}t5xxl.transformer.encoder.final_layer_norm.weight".format(pref)
+        if t5_key in state_dict:
+            t5 = True
+            dtype_t5 = state_dict[t5_key].dtype
+
+        return supported_models_base.ClipTarget(sd3_clip.SD3Tokenizer, sd3_clip.sd3_clip(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5))
 
 models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3]
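`SD3.clip_target` now inspects the checkpoint to see which text encoders are actually present, then hands back a class built by the `sd3_clip` factory. The factory exists because callers instantiate clip targets with only `(device, dtype)`; the configuration is captured in a closure. A generic rendering of the same pattern, with illustrative names not taken from the repo:

```python
def make_configured(base_class, **baked_kwargs):
    class Configured(base_class):
        def __init__(self, device="cpu", dtype=None):
            # baked_kwargs come from the enclosing scope, the rest from the caller
            super().__init__(device=device, dtype=dtype, **baked_kwargs)
    return Configured

# e.g. sd3_clip(clip_l=True, clip_g=True, t5=False) behaves like
# make_configured(SD3ClipModel, clip_l=True, clip_g=True, t5=False)
```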
diff --git a/comfy_extras/nodes/nodes_custom_sampler.py b/comfy_extras/nodes/nodes_custom_sampler.py
index 5e8c897e3..abd721330 100644
--- a/comfy_extras/nodes/nodes_custom_sampler.py
+++ b/comfy_extras/nodes/nodes_custom_sampler.py
@@ -405,7 +405,10 @@ class SamplerCustom:
     def sample(self, model, add_noise, noise_seed, cfg, positive, negative, sampler, sigmas, latent_image):
         latent = latent_image
         latent_image = latent["samples"]
+        latent = latent.copy()
         latent_image = comfy.sample.fix_empty_latent_channels(model, latent_image)
+        latent["samples"] = latent_image
+
         if not add_noise:
             noise = Noise_EmptyNoise().generate_noise(latent)
         else:
@@ -564,7 +567,9 @@ class SamplerCustomAdvanced:
     def sample(self, noise, guider, sampler, sigmas, latent_image):
         latent = latent_image
         latent_image = latent["samples"]
+        latent = latent.copy()
         latent_image = comfy.sample.fix_empty_latent_channels(guider.model_patcher, latent_image)
+        latent["samples"] = latent_image
 
         noise_mask = None
         if "noise_mask" in latent:
diff --git a/comfy_extras/nodes/nodes_model_advanced.py b/comfy_extras/nodes/nodes_model_advanced.py
index 497cb501a..801ad0b77 100644
--- a/comfy_extras/nodes/nodes_model_advanced.py
+++ b/comfy_extras/nodes/nodes_model_advanced.py
@@ -135,7 +135,7 @@ class ModelSamplingSD3:
     @classmethod
     def INPUT_TYPES(s):
         return {"required": { "model": ("MODEL",),
-                              "shift": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01}),
+                              "shift": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 100.0, "step":0.01}),
                               }}
 
     RETURN_TYPES = ("MODEL",)
diff --git a/comfy_extras/nodes/nodes_model_merging.py b/comfy_extras/nodes/nodes_model_merging.py
index 8562a09f3..293a7ae9b 100644
--- a/comfy_extras/nodes/nodes_model_merging.py
+++ b/comfy_extras/nodes/nodes_model_merging.py
@@ -182,6 +182,8 @@ def save_checkpoint(model, clip=None, vae=None, clip_vision=None, filename_prefi
             metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-refiner"
         elif isinstance(model.model, model_base.SVD_img2vid):
             metadata["modelspec.architecture"] = "stable-video-diffusion-img2vid-v1"
+        elif isinstance(model.model, model_base.SD3):
+            metadata["modelspec.architecture"] = "stable-diffusion-v3-medium" #TODO: other SD3 variants
         else:
             enable_modelspec = False
diff --git a/comfy_extras/nodes/nodes_sd3.py b/comfy_extras/nodes/nodes_sd3.py
index 350311be5..706de7ae6 100644
--- a/comfy_extras/nodes/nodes_sd3.py
+++ b/comfy_extras/nodes/nodes_sd3.py
@@ -1,35 +1,42 @@
-from comfy.cmd import folder_paths
-import comfy.sd
-import comfy.model_management
-from comfy.nodes import base_nodes as nodes
 import torch
 
+import comfy.model_management
+import comfy.sd
+from comfy.cmd import folder_paths
+from comfy.model_downloader import get_or_download, get_filename_list_with_downloadable, KNOWN_CLIP_MODELS
+from comfy.nodes import base_nodes as nodes
+
+
 class TripleCLIPLoader:
     @classmethod
     def INPUT_TYPES(s):
-        return {"required": { "clip_name1": (folder_paths.get_filename_list("clip"), ), "clip_name2": (folder_paths.get_filename_list("clip"), ), "clip_name3": (folder_paths.get_filename_list("clip"), )
+        filename_list = get_filename_list_with_downloadable("clip", KNOWN_CLIP_MODELS)
+        return {"required": {"clip_name1": (filename_list,), "clip_name2": (filename_list,), "clip_name3": (filename_list,)
                              }}
+
     RETURN_TYPES = ("CLIP",)
     FUNCTION = "load_clip"
 
     CATEGORY = "advanced/loaders"
 
     def load_clip(self, clip_name1, clip_name2, clip_name3):
-        clip_path1 = folder_paths.get_full_path("clip", clip_name1)
-        clip_path2 = folder_paths.get_full_path("clip", clip_name2)
-        clip_path3 = folder_paths.get_full_path("clip", clip_name3)
+        clip_path1 = get_or_download("clip", clip_name1, KNOWN_CLIP_MODELS)
+        clip_path2 = get_or_download("clip", clip_name2, KNOWN_CLIP_MODELS)
+        clip_path3 = get_or_download("clip", clip_name3, KNOWN_CLIP_MODELS)
         clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2, clip_path3], embedding_directory=folder_paths.get_folder_paths("embeddings"))
         return (clip,)
 
+
 class EmptySD3LatentImage:
     def __init__(self):
         self.device = comfy.model_management.intermediate_device()
 
     @classmethod
     def INPUT_TYPES(s):
-        return {"required": { "width": ("INT", {"default": 512, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
-                              "height": ("INT", {"default": 512, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
-                              "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}}
+        return {"required": {"width": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
+                             "height": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
+                             "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}}
+
     RETURN_TYPES = ("LATENT",)
     FUNCTION = "generate"
 
@@ -37,18 +44,20 @@ class EmptySD3LatentImage:
 
     def generate(self, width, height, batch_size=1):
         latent = torch.ones([batch_size, 16, height // 8, width // 8], device=self.device) * 0.0609
-        return ({"samples":latent}, )
+        return ({"samples": latent},)
+
 
 class CLIPTextEncodeSD3:
     @classmethod
     def INPUT_TYPES(s):
         return {"required": {
-            "clip": ("CLIP", ),
+            "clip": ("CLIP",),
             "clip_l": ("STRING", {"multiline": True, "dynamicPrompts": True}),
             "clip_g": ("STRING", {"multiline": True, "dynamicPrompts": True}),
             "t5xxl": ("STRING", {"multiline": True, "dynamicPrompts": True}),
-            "empty_padding": (["none", "empty_prompt"], )
-            }}
+            "empty_padding": (["none", "empty_prompt"],)
+        }}
+
     RETURN_TYPES = ("CONDITIONING",)
     FUNCTION = "encode"
 
@@ -67,7 +76,7 @@ class CLIPTextEncodeSD3:
             tokens["l"] = clip.tokenize(clip_l)["l"]
 
         if len(t5xxl) == 0 and no_padding:
-            tokens["t5xxl"] =  []
+            tokens["t5xxl"] = []
         else:
             tokens["t5xxl"] = clip.tokenize(t5xxl)["t5xxl"]
         if len(tokens["l"]) != len(tokens["g"]):
@@ -77,7 +86,7 @@ class CLIPTextEncodeSD3:
             while len(tokens["l"]) > len(tokens["g"]):
                 tokens["g"] += empty["g"]
         cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True)
-        return ([[cond, {"pooled_output": pooled}]], )
+        return ([[cond, {"pooled_output": pooled}]],)
 
 
 NODE_CLASS_MAPPINGS = {
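The `latent.copy()` additions in the two sampler nodes guard against a shared-state bug: node inputs are dicts the executor may cache and reuse, so reassigning `latent["samples"]` in place would leak the channel-fixed tensor back to other consumers. A minimal sketch of the hazard, with an illustrative stand-in for `fix_empty_latent_channels`:

```python
import torch

def prepare(latent: dict) -> dict:
    latent = latent.copy()  # shallow copy: keys can now be reassigned safely
    latent["samples"] = latent["samples"].float()  # stand-in transform
    return latent

shared = {"samples": torch.zeros(1, 4, 8, 8, dtype=torch.float16)}
prepared = prepare(shared)
assert shared["samples"].dtype == torch.float16  # caller's dict is untouched
```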
diff --git a/requirements.txt b/requirements.txt
index c4dcf9ce1..8bc0a6766 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -8,7 +8,7 @@ transformers>=4.29.1
 peft
 torchinfo
 fschat[model_worker]
-safetensors>=0.3.0
+safetensors>=0.4.2
 bitsandbytes
 pytorch-lightning>=2.0.0
 aiohttp>=3.8.4