Add known SD3 model files, merge branch 'master' of github.com:comfyanonymous/ComfyUI

This commit is contained in:
doctorpangloss 2024-06-12 10:56:41 -07:00
commit cac6690481
18 changed files with 269 additions and 142 deletions

View File

@ -18,6 +18,35 @@ A vanilla, up-to-date fork of [ComfyUI](https://github.com/comfyanonymous/comfyu
- Automated tests for new features.
- Automatic model downloading for well-known models.
### Upstream Features
- Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything.
- Fully supports SD1.x, SD2.x, [SDXL](https://comfyanonymous.github.io/ComfyUI_examples/sdxl/), [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/) and [Stable Cascade](https://comfyanonymous.github.io/ComfyUI_examples/stable_cascade/)
- Asynchronous Queue system
- Many optimizations: Only re-executes the parts of the workflow that changes between executions.
- Command line option: ```--lowvram``` to make it work on GPUs with less than 3GB vram (enabled automatically on GPUs with low vram)
- Works even if you don't have a GPU with: ```--cpu``` (slow)
- Can load ckpt, safetensors and diffusers models/checkpoints. Standalone VAEs and CLIP models.
- Embeddings/Textual inversion
- [Loras (regular, locon and loha)](https://comfyanonymous.github.io/ComfyUI_examples/lora/)
- [Hypernetworks](https://comfyanonymous.github.io/ComfyUI_examples/hypernetworks/)
- Loading full workflows (with seeds) from generated PNG files.
- Saving/Loading workflows as Json files.
- Nodes interface can be used to create complex workflows like one for [Hires fix](https://comfyanonymous.github.io/ComfyUI_examples/2_pass_txt2img/) or much more advanced ones.
- [Area Composition](https://comfyanonymous.github.io/ComfyUI_examples/area_composition/)
- [Inpainting](https://comfyanonymous.github.io/ComfyUI_examples/inpaint/) with both regular and inpainting models.
- [ControlNet and T2I-Adapter](https://comfyanonymous.github.io/ComfyUI_examples/controlnet/)
- [Upscale Models (ESRGAN, ESRGAN variants, SwinIR, Swin2SR, etc...)](https://comfyanonymous.github.io/ComfyUI_examples/upscale_models/)
- [unCLIP Models](https://comfyanonymous.github.io/ComfyUI_examples/unclip/)
- [GLIGEN](https://comfyanonymous.github.io/ComfyUI_examples/gligen/)
- [Model Merging](https://comfyanonymous.github.io/ComfyUI_examples/model_merging/)
- [LCM models and Loras](https://comfyanonymous.github.io/ComfyUI_examples/lcm/)
- [SDXL Turbo](https://comfyanonymous.github.io/ComfyUI_examples/sdturbo/)
- Latent previews with [TAESD](#how-to-show-high-quality-previews)
- Starts up very fast.
- Works fully offline: will never download anything.
- [Config file](extra_model_paths.yaml.example) to set the search paths for models.
### Table of Contents
- [Workflows](https://comfyanonymous.github.io/ComfyUI_examples/)

View File

@ -29,7 +29,12 @@ class CONDRegular:
class CONDNoiseShape(CONDRegular):
def process_cond(self, batch_size, device, area, **kwargs):
data = self.cond[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
data = self.cond
if area is not None:
dims = len(area) // 2
for i in range(dims):
data = data.narrow(i + 2, area[i + dims], area[i])
return self._copy_with(utils.repeat_to_batch_size(data, batch_size).to(device))

View File

@ -835,72 +835,11 @@ class MMDiT(nn.Module):
)
self.final_layer = FinalLayer(self.hidden_size, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations)
# self.initialize_weights()
if compile_core:
assert False
self.forward_core_with_concat = torch.compile(self.forward_core_with_concat)
def initialize_weights(self):
# TODO: Init context_embedder?
# Initialize transformer layers:
def _basic_init(module):
if isinstance(module, nn.Linear):
torch.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
self.apply(_basic_init)
# Initialize (and freeze) pos_embed by sin-cos embedding
if self.pos_embed is not None:
pos_embed_grid_size = (
int(self.x_embedder.num_patches**0.5)
if self.pos_embed_max_size is None
else self.pos_embed_max_size
)
pos_embed = get_2d_sincos_pos_embed(
self.pos_embed.shape[-1],
int(self.x_embedder.num_patches**0.5),
pos_embed_grid_size,
scaling_factor=self.pos_embed_scaling_factor,
offset=self.pos_embed_offset,
)
pos_embed = get_2d_sincos_pos_embed(
self.pos_embed.shape[-1],
int(self.pos_embed.shape[-2]**0.5),
scaling_factor=self.pos_embed_scaling_factor,
)
self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d)
w = self.x_embedder.proj.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
nn.init.constant_(self.x_embedder.proj.bias, 0)
if hasattr(self, "y_embedder"):
nn.init.normal_(self.y_embedder.mlp[0].weight, std=0.02)
nn.init.normal_(self.y_embedder.mlp[2].weight, std=0.02)
# Initialize timestep embedding MLP:
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
# Zero-out adaLN modulation layers in DiT blocks:
for block in self.joint_blocks:
nn.init.constant_(block.x_block.adaLN_modulation[-1].weight, 0)
nn.init.constant_(block.x_block.adaLN_modulation[-1].bias, 0)
nn.init.constant_(block.context_block.adaLN_modulation[-1].weight, 0)
nn.init.constant_(block.context_block.adaLN_modulation[-1].bias, 0)
# Zero-out output layers:
nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
nn.init.constant_(self.final_layer.linear.weight, 0)
nn.init.constant_(self.final_layer.linear.bias, 0)
def cropped_pos_embed(self, hw, device=None):
p = self.x_embedder.patch_size[0]
h, w = hw
@ -995,7 +934,7 @@ class MMDiT(nn.Module):
context = self.context_processor(context)
hw = x.shape[-2:]
x = self.x_embedder(x) + self.cropped_pos_embed(hw, device=x.device).to(dtype=x.dtype)
x = self.x_embedder(x) + self.cropped_pos_embed(hw, device=x.device).to(dtype=x.dtype, device=x.device)
c = self.t_embedder(t, dtype=x.dtype) # (N, D)
if y is not None and self.y_embedder is not None:
y = self.y_embedder(y) # (N, D)

View File

@ -206,9 +206,6 @@ class BaseModel(torch.nn.Module):
unet_state_dict = self.diffusion_model.state_dict()
unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
if self.get_dtype() == torch.float16:
extra_sds = map(lambda sd: utils.convert_sd_to(sd, torch.float16), extra_sds)
if self.model_type == ModelType.V_PREDICTION:
unet_state_dict["v_pred"] = torch.tensor([])
@ -572,13 +569,20 @@ class SD3(BaseModel):
return kwargs["pooled_output"]
def extra_conds(self, **kwargs):
out = {}
adm = self.encode_adm(**kwargs)
if adm is not None:
out['y'] = conds.CONDRegular(adm)
out = super().extra_conds(**kwargs)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = conds.CONDRegular(cross_attn)
return out
def memory_required(self, input_shape):
if model_management.xformers_enabled() or model_management.pytorch_attention_flash_attention():
dtype = self.get_dtype()
if self.manual_cast_dtype is not None:
dtype = self.manual_cast_dtype
#TODO: this probably needs to be tweaked
area = input_shape[0] * input_shape[2] * input_shape[3]
return (area * model_management.dtype_size(dtype) * 0.012) * (1024 * 1024)
else:
area = input_shape[0] * input_shape[2] * input_shape[3]
return (area * 0.3) * (1024 * 1024)

View File

@ -8,6 +8,7 @@ from typing import List, Any, Optional, Union
import tqdm
from huggingface_hub import hf_hub_download, scan_cache_dir
from huggingface_hub.utils import GatedRepoError
from requests import Session
from safetensors import safe_open
from safetensors.torch import save_file
@ -74,6 +75,7 @@ def get_or_download(folder_name: str, filename: str, known_files: List[HuggingFi
filename=known_file.filename,
local_dir=hf_destination_dir,
repo_type=known_file.repo_type,
revision=known_file.revision,
)
if known_file.convert_to_16_bit and file_size is not None and file_size != 0:
@ -134,8 +136,14 @@ def get_or_download(folder_name: str, filename: str, known_files: List[HuggingFi
assert path is not None
except StopIteration:
pass
except Exception as exc:
logging.error("Error while trying to download a file", exc_info=exc)
except GatedRepoError as exc_info:
exc_info.append_to_message(f"""
Visit the repository, accept the terms, and then do one of the following:
- Set the HF_TOKEN environment variable to your Hugging Face token; or,
- Login to Hugging Face in your terminal using `huggingface-cli login`
""")
raise exc_info
finally:
# a path was found for any reason, so we should invalidate the cache
if path is not None:
@ -167,13 +175,16 @@ KNOWN_CHECKPOINTS = [
CivitFile(133005, 357609, filename="juggernautXL_v9Rundiffusionphoto2.safetensors"),
CivitFile(112902, 351306, filename="dreamshaperXL_v21TurboDPMSDE.safetensors"),
CivitFile(139562, 344487, filename="realvisxlV40_v40Bakedvae.safetensors"),
HuggingFile("SG161222/Realistic_Vision_V6.0_B1_noVAE","Realistic_Vision_V6.0_NV_B1_fp16.safetensors"),
HuggingFile("SG161222/Realistic_Vision_V5.1_noVAE","Realistic_Vision_V5.1_fp16-no-ema.safetensors"),
HuggingFile("SG161222/Realistic_Vision_V6.0_B1_noVAE", "Realistic_Vision_V6.0_NV_B1_fp16.safetensors"),
HuggingFile("SG161222/Realistic_Vision_V5.1_noVAE", "Realistic_Vision_V5.1_fp16-no-ema.safetensors"),
CivitFile(4384, 128713, filename="dreamshaper_8.safetensors"),
CivitFile(7371, 425083, filename="revAnimated_v2Rebirth.safetensors"),
CivitFile(4468, 57618, filename="counterfeitV30_v30.safetensors"),
CivitFile(241415, 272376, filename="picxReal_10.safetensors"),
CivitFile(23900, 95489, filename="anyloraCheckpoint_bakedvaeBlessedFp16.safetensors"),
HuggingFile("stabilityai/stable-diffusion-3-medium", filename="sd3_medium.safetensors"),
HuggingFile("stabilityai/stable-diffusion-3-medium", filename="sd3_medium_incl_clips.safetensors"),
HuggingFile("stabilityai/stable-diffusion-3-medium", filename="sd3_medium_incl_clips_t5xxlfp8.safetensors"),
]
KNOWN_UNCLIP_CHECKPOINTS = [
@ -323,7 +334,13 @@ KNOWN_UNET_MODELS: List[Union[CivitFile | HuggingFile]] = [
HuggingFile("ByteDance/Hyper-SD", "Hyper-SDXL-1step-Unet-Comfyui.fp16.safetensors")
]
KNOWN_CLIP_MODELS: List[Union[CivitFile | HuggingFile]] = []
KNOWN_CLIP_MODELS: List[Union[CivitFile | HuggingFile]] = [
# todo: is this correct?
HuggingFile("stabilityai/stable-diffusion-3-medium", "text_encoders/t5xxl_fp16.safetensors", save_with_filename="t5xxl_fp16.safetensors"),
HuggingFile("stabilityai/stable-diffusion-3-medium", "text_encoders/t5xxl_fp8_e4m3fn.safetensors", save_with_filename="t5xxl_fp8_e4m3fn.safetensors"),
HuggingFile("stabilityai/stable-diffusion-3-medium", "text_encoders/clip_g.safetensors", save_with_filename="clip_g.safetensors"),
HuggingFile("stabilityai/stable-diffusion-3-medium", "text_encoders/clip_l.safetensors", save_with_filename="clip_l.safetensors"),
]
def add_known_models(folder_name: str, symbol: List[Union[CivitFile, HuggingFile]], *models: Union[CivitFile, HuggingFile]) -> List[Union[CivitFile, HuggingFile]]:

View File

@ -54,6 +54,7 @@ class HuggingFile:
size: Optional[int] = None
force_save_in_repo_id: Optional[bool] = False
repo_type: Optional[str] = 'model'
revision: Optional[str] = None
def __str__(self):
return self.save_with_filename or split(self.filename)[-1]

View File

@ -703,6 +703,22 @@ def supports_dtype(device, dtype): # TODO
return True
return False
def supports_cast(device, dtype): #TODO
if dtype == torch.float32:
return True
if dtype == torch.float16:
return True
if is_device_mps(device):
return False
if directml_enabled: #TODO: test this
return False
if dtype == torch.bfloat16:
return True
if dtype == torch.float8_e4m3fn:
return True
if dtype == torch.float8_e5m2:
return True
return False
def device_supports_non_blocking(device):
if is_device_mps(device):

View File

@ -815,7 +815,7 @@ class CLIPLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip_name": (get_filename_list_with_downloadable("clip", KNOWN_CLIP_MODELS),),
"type": (["stable_diffusion", "stable_cascade"], ),
"type": (["stable_diffusion", "stable_cascade", "sd3"], ),
}}
RETURN_TYPES = ("CLIP",)
FUNCTION = "load_clip"
@ -823,10 +823,11 @@ class CLIPLoader:
CATEGORY = "advanced/loaders"
def load_clip(self, clip_name, type="stable_diffusion"):
if type == "stable_diffusion":
clip_type = sd.CLIPType.STABLE_DIFFUSION
elif type == "stable_cascade":
clip_type = sd.CLIPType.STABLE_DIFFUSION
if type == "stable_cascade":
clip_type = sd.CLIPType.STABLE_CASCADE
elif type == "sd3":
clip_type = sd.CLIPType.SD3
else:
logging.warning(f"Unknown clip type argument passed: {type} for model {clip_name}")
@ -839,16 +840,22 @@ class DualCLIPLoader:
def INPUT_TYPES(s):
return {"required": { "clip_name1": (folder_paths.get_filename_list("clip"),), "clip_name2": (
folder_paths.get_filename_list("clip"),),
}}
"type": (["sdxl", "sd3"], ),
}}
RETURN_TYPES = ("CLIP",)
FUNCTION = "load_clip"
CATEGORY = "advanced/loaders"
def load_clip(self, clip_name1, clip_name2):
def load_clip(self, clip_name1, clip_name2, type):
clip_path1 = folder_paths.get_full_path("clip", clip_name1)
clip_path2 = folder_paths.get_full_path("clip", clip_name2)
clip = sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=folder_paths.get_folder_paths("embeddings"))
if type == "sdxl":
clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION
elif type == "sd3":
clip_type = comfy.sd.CLIPType.SD3
clip = sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type)
return (clip,)
class CLIPVisionLoader:

View File

@ -100,15 +100,21 @@ class CLIP:
load_device = model_management.text_encoder_device()
offload_device = model_management.text_encoder_offload_device()
params['device'] = offload_device
params['dtype'] = model_management.text_encoder_dtype(load_device)
dtype = model_management.text_encoder_dtype(load_device)
params['dtype'] = dtype
if "textmodel_json_config" not in params and textmodel_json_config is not None:
params['textmodel_json_config'] = textmodel_json_config
self.cond_stage_model = clip(**(params))
for dt in self.cond_stage_model.dtypes:
if not model_management.supports_cast(load_device, dt):
load_device = offload_device
self.tokenizer: "sd1_clip.SD1Tokenizer" = tokenizer(embedding_directory=embedding_directory)
self.patcher = model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
self.layer_idx = None
logging.debug("CLIP model load device: {}, offload device: {}".format(load_device, offload_device))
def clone(self):
n = CLIP(no_init=True)
@ -372,6 +378,7 @@ def load_style_model(ckpt_path):
class CLIPType(Enum):
STABLE_DIFFUSION = 1
STABLE_CASCADE = 2
SD3 = 3
@dataclasses.dataclass
@ -407,12 +414,20 @@ def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DI
elif "text_model.encoder.layers.22.mlp.fc1.weight" in clip_data[0]:
clip_target.clip = sd2_clip.SD2ClipModel
clip_target.tokenizer = sd2_clip.SD2Tokenizer
elif "encoder.block.23.layer.1.DenseReluDense.wi_1.weight" in clip_data[0]:
dtype_t5 = clip_data[0]["encoder.block.23.layer.1.DenseReluDense.wi_1.weight"].dtype
clip_target.clip = sd3_clip.sd3_clip(clip_l=False, clip_g=False, t5=True, dtype_t5=dtype_t5)
clip_target.tokenizer = sd3_clip.SD3Tokenizer
else:
clip_target.clip = sd1_clip.SD1ClipModel
clip_target.tokenizer = sd1_clip.SD1Tokenizer
elif len(clip_data) == 2:
clip_target.clip = sdxl_clip.SDXLClipModel
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
if clip_type == CLIPType.SD3:
clip_target.clip = sd3_clip.sd3_clip(clip_l=True, clip_g=True, t5=False)
clip_target.tokenizer = sd3_clip.SD3Tokenizer
else:
clip_target.clip = sdxl_clip.SDXLClipModel
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
elif len(clip_data) == 3:
clip_target.clip = sd3_clip.SD3ClipModel
clip_target.tokenizer = sd3_clip.SD3Tokenizer
@ -501,7 +516,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
vae = VAE(sd=vae_sd)
if output_clip:
clip_target = model_config.clip_target()
clip_target = model_config.clip_target(state_dict=sd)
if clip_target is not None:
clip_sd = model_config.process_clip_state_dict(sd)
if len(clip_sd) > 0:

View File

@ -8,7 +8,7 @@ import zipfile
from typing import Tuple, Sequence, TypeVar
import torch
from transformers import CLIPTokenizer, PreTrainedTokenizerBase
from transformers import CLIPTokenizer, PreTrainedTokenizerBase, SpecialTokensMixin
from . import clip_model
from . import model_management
@ -410,7 +410,7 @@ class SDTokenizer:
def clone(self) -> SDTokenizerT:
sd_tokenizer = copy.copy(self)
# correctly copy additional vocab
sd_tokenizer.tokenizer = self.tokenizer_class.from_pretrained(self.tokenizer_path)
sd_tokenizer.tokenizer = self.tokenizer_class.from_pretrained(self.tokenizer_path, legacy=True)
sd_tokenizer.add_tokens(sd_tokenizer.additional_tokens)
return sd_tokenizer
@ -568,6 +568,10 @@ class SD1ClipModel(torch.nn.Module):
self.clip = "clip_{}".format(self.clip_name)
setattr(self, self.clip, clip_model(device=device, dtype=dtype, textmodel_json_config=textmodel_json_config, **kwargs))
self.dtypes = set()
if dtype is not None:
self.dtypes.add(dtype)
def set_clip_options(self, options):
getattr(self, self.clip).set_clip_options(options)

View File

@ -5,6 +5,7 @@ import comfy.t5
import torch
import os
import comfy.model_management
import logging
class T5XXLModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None):
@ -43,42 +44,94 @@ class SD3Tokenizer:
return self.clip_g.untokenize(token_weight_pair)
class SD3ClipModel(torch.nn.Module):
def __init__(self, device="cpu", dtype=None):
def __init__(self, clip_l=True, clip_g=True, t5=True, dtype_t5=None, device="cpu", dtype=None):
super().__init__()
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, return_projected_pooled=False)
self.clip_g = sdxl_clip.SDXLClipG(device=device, dtype=dtype)
self.t5xxl = T5XXLModel(device=device, dtype=dtype)
self.dtypes = set()
if clip_l:
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, return_projected_pooled=False)
self.dtypes.add(dtype)
else:
self.clip_l = None
if clip_g:
self.clip_g = sdxl_clip.SDXLClipG(device=device, dtype=dtype)
self.dtypes.add(dtype)
else:
self.clip_g = None
if t5:
if dtype_t5 is None:
dtype_t5 = dtype
elif comfy.model_management.dtype_size(dtype_t5) > comfy.model_management.dtype_size(dtype):
dtype_t5 = dtype
if not comfy.model_management.supports_cast(device, dtype_t5):
dtype_t5 = dtype
self.t5xxl = T5XXLModel(device=device, dtype=dtype_t5)
self.dtypes.add(dtype_t5)
else:
self.t5xxl = None
logging.debug("Created SD3 text encoder with: clip_l {}, clip_g {}, t5xxl {}:{}".format(clip_l, clip_g, t5, dtype_t5))
def set_clip_options(self, options):
self.clip_l.set_clip_options(options)
self.clip_g.set_clip_options(options)
self.t5xxl.set_clip_options(options)
if self.clip_l is not None:
self.clip_l.set_clip_options(options)
if self.clip_g is not None:
self.clip_g.set_clip_options(options)
if self.t5xxl is not None:
self.t5xxl.set_clip_options(options)
def reset_clip_options(self):
self.clip_g.reset_clip_options()
self.clip_l.reset_clip_options()
self.t5xxl.reset_clip_options()
if self.clip_l is not None:
self.clip_l.reset_clip_options()
if self.clip_g is not None:
self.clip_g.reset_clip_options()
if self.t5xxl is not None:
self.t5xxl.reset_clip_options()
def encode_token_weights(self, token_weight_pairs):
token_weight_pairs_l = token_weight_pairs["l"]
token_weight_pairs_g = token_weight_pairs["g"]
token_weight_pars_t5 = token_weight_pairs["t5xxl"]
lg_out = None
if len(token_weight_pairs_g) > 0 or len(token_weight_pairs_l) > 0:
l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l)
g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g)
lg_out = torch.cat([l_out, g_out], dim=-1)
lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1]))
out = lg_out
pooled = torch.cat((l_pooled, g_pooled), dim=-1)
else:
pooled = torch.zeros((1, 1280 + 768), device=comfy.model_management.intermediate_device())
pooled = None
out = None
t5_out, t5_pooled = self.t5xxl.encode_token_weights(token_weight_pars_t5)
if lg_out is not None:
out = torch.cat([lg_out, t5_out], dim=-2)
else:
out = t5_out
if len(token_weight_pairs_g) > 0 or len(token_weight_pairs_l) > 0:
if self.clip_l is not None:
lg_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l)
else:
l_pooled = torch.zeros((1, 768), device=comfy.model_management.intermediate_device())
if self.clip_g is not None:
g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g)
if lg_out is not None:
lg_out = torch.cat([lg_out, g_out], dim=-1)
else:
lg_out = torch.nn.functional.pad(g_out, (768, 0))
else:
g_out = None
g_pooled = torch.zeros((1, 1280), device=comfy.model_management.intermediate_device())
if lg_out is not None:
lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1]))
out = lg_out
pooled = torch.cat((l_pooled, g_pooled), dim=-1)
if self.t5xxl is not None:
t5_out, t5_pooled = self.t5xxl.encode_token_weights(token_weight_pars_t5)
if lg_out is not None:
out = torch.cat([lg_out, t5_out], dim=-2)
else:
out = t5_out
if out is None:
out = torch.zeros((1, 77, 4096), device=comfy.model_management.intermediate_device())
if pooled is None:
pooled = torch.zeros((1, 768 + 1280), device=comfy.model_management.intermediate_device())
return out, pooled
@ -89,3 +142,9 @@ class SD3ClipModel(torch.nn.Module):
return self.clip_l.load_sd(sd)
else:
return self.t5xxl.load_sd(sd)
def sd3_clip(clip_l=True, clip_g=True, t5=True, dtype_t5=None):
class SD3ClipModel_(SD3ClipModel):
def __init__(self, device="cpu", dtype=None):
super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5, device=device, dtype=dtype)
return SD3ClipModel_

View File

@ -51,6 +51,7 @@ class SDXLClipModel(torch.nn.Module):
super().__init__()
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False)
self.clip_g = SDXLClipG(device=device, dtype=dtype)
self.dtypes = set([dtype])
def set_clip_options(self, options):
self.clip_l.set_clip_options(options)

View File

@ -54,7 +54,7 @@ class SD15(supported_models_base.BASE):
replace_prefix = {"clip_l.": "cond_stage_model."}
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
def clip_target(self):
def clip_target(self, state_dict={}):
return supported_models_base.ClipTarget(sd1_clip.SD1Tokenizer, sd1_clip.SD1ClipModel)
class SD20(supported_models_base.BASE):
@ -97,7 +97,7 @@ class SD20(supported_models_base.BASE):
state_dict = diffusers_convert.convert_text_enc_state_dict_v20(state_dict)
return state_dict
def clip_target(self):
def clip_target(self, state_dict={}):
return supported_models_base.ClipTarget(sd2_clip.SD2Tokenizer, sd2_clip.SD2ClipModel)
class SD21UnclipL(SD20):
@ -159,7 +159,7 @@ class SDXLRefiner(supported_models_base.BASE):
state_dict_g = utils.state_dict_prefix_replace(state_dict_g, replace_prefix)
return state_dict_g
def clip_target(self):
def clip_target(self, state_dict={}):
return supported_models_base.ClipTarget(sdxl_clip.SDXLTokenizer, sdxl_clip.SDXLRefinerClipModel)
class SDXL(supported_models_base.BASE):
@ -228,7 +228,7 @@ class SDXL(supported_models_base.BASE):
state_dict_g = utils.state_dict_prefix_replace(state_dict_g, replace_prefix)
return state_dict_g
def clip_target(self):
def clip_target(self, state_dict={}):
return supported_models_base.ClipTarget(sdxl_clip.SDXLTokenizer, sdxl_clip.SDXLClipModel)
class SSD1B(SDXL):
@ -299,7 +299,7 @@ class SVD_img2vid(supported_models_base.BASE):
out = model_base.SVD_img2vid(self, device=device)
return out
def clip_target(self):
def clip_target(self, state_dict={}):
return None
class SV3D_u(SVD_img2vid):
@ -365,7 +365,7 @@ class Stable_Zero123(supported_models_base.BASE):
out = model_base.Stable_Zero123(self, device=device, cc_projection_weight=state_dict["cc_projection.weight"], cc_projection_bias=state_dict["cc_projection.bias"])
return out
def clip_target(self):
def clip_target(self, state_dict={}):
return None
class SD_X4Upscaler(SD20):
@ -439,7 +439,7 @@ class Stable_Cascade_C(supported_models_base.BASE):
out = model_base.StableCascade_C(self, device=device)
return out
def clip_target(self):
def clip_target(self, state_dict={}):
return supported_models_base.ClipTarget(sdxl_clip.StableCascadeTokenizer, sdxl_clip.StableCascadeClipModel)
class Stable_Cascade_B(Stable_Cascade_C):
@ -501,14 +501,28 @@ class SD3(supported_models_base.BASE):
unet_extra_config = {}
latent_format = latent_formats.SD3
text_encoder_key_prefix = ["text_encoders."] #TODO?
text_encoder_key_prefix = ["text_encoders."]
def get_model(self, state_dict, prefix="", device=None):
out = model_base.SD3(self, device=device)
return out
def clip_target(self):
return supported_models_base.ClipTarget(sd3_clip.SD3Tokenizer, sd3_clip.SD3ClipModel) #TODO?
def clip_target(self, state_dict={}):
clip_l = False
clip_g = False
t5 = False
dtype_t5 = None
pref = self.text_encoder_key_prefix[0]
if "{}clip_l.transformer.text_model.final_layer_norm.weight".format(pref) in state_dict:
clip_l = True
if "{}clip_g.transformer.text_model.final_layer_norm.weight".format(pref) in state_dict:
clip_g = True
t5_key = "{}t5xxl.transformer.encoder.final_layer_norm.weight".format(pref)
if t5_key in state_dict:
t5 = True
dtype_t5 = state_dict[t5_key].dtype
return supported_models_base.ClipTarget(sd3_clip.SD3Tokenizer, sd3_clip.sd3_clip(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5))
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3]

View File

@ -405,7 +405,10 @@ class SamplerCustom:
def sample(self, model, add_noise, noise_seed, cfg, positive, negative, sampler, sigmas, latent_image):
latent = latent_image
latent_image = latent["samples"]
latent = latent.copy()
latent_image = comfy.sample.fix_empty_latent_channels(model, latent_image)
latent["samples"] = latent_image
if not add_noise:
noise = Noise_EmptyNoise().generate_noise(latent)
else:
@ -564,7 +567,9 @@ class SamplerCustomAdvanced:
def sample(self, noise, guider, sampler, sigmas, latent_image):
latent = latent_image
latent_image = latent["samples"]
latent = latent.copy()
latent_image = comfy.sample.fix_empty_latent_channels(guider.model_patcher, latent_image)
latent["samples"] = latent_image
noise_mask = None
if "noise_mask" in latent:

View File

@ -135,7 +135,7 @@ class ModelSamplingSD3:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"shift": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01}),
"shift": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 100.0, "step":0.01}),
}}
RETURN_TYPES = ("MODEL",)

View File

@ -182,6 +182,8 @@ def save_checkpoint(model, clip=None, vae=None, clip_vision=None, filename_prefi
metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-refiner"
elif isinstance(model.model, model_base.SVD_img2vid):
metadata["modelspec.architecture"] = "stable-video-diffusion-img2vid-v1"
elif isinstance(model.model, comfy.model_base.SD3):
metadata["modelspec.architecture"] = "stable-diffusion-v3-medium" #TODO: other SD3 variants
else:
enable_modelspec = False

View File

@ -1,35 +1,42 @@
from comfy.cmd import folder_paths
import comfy.sd
import comfy.model_management
from comfy.nodes import base_nodes as nodes
import torch
import comfy.model_management
import comfy.sd
from comfy.cmd import folder_paths
from comfy.model_downloader import get_or_download, get_filename_list_with_downloadable, KNOWN_CLIP_MODELS
from comfy.nodes import base_nodes as nodes
class TripleCLIPLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip_name1": (folder_paths.get_filename_list("clip"), ), "clip_name2": (folder_paths.get_filename_list("clip"), ), "clip_name3": (folder_paths.get_filename_list("clip"), )
filename_list = get_filename_list_with_downloadable("clip", KNOWN_CLIP_MODELS)
return {"required": {"clip_name1": (filename_list,), "clip_name2": (filename_list,), "clip_name3": (filename_list,)
}}
RETURN_TYPES = ("CLIP",)
FUNCTION = "load_clip"
CATEGORY = "advanced/loaders"
def load_clip(self, clip_name1, clip_name2, clip_name3):
clip_path1 = folder_paths.get_full_path("clip", clip_name1)
clip_path2 = folder_paths.get_full_path("clip", clip_name2)
clip_path3 = folder_paths.get_full_path("clip", clip_name3)
clip_path1 = get_or_download("clip", clip_name1, KNOWN_CLIP_MODELS)
clip_path2 = get_or_download("clip", clip_name2, KNOWN_CLIP_MODELS)
clip_path3 = get_or_download("clip", clip_name3, KNOWN_CLIP_MODELS)
clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2, clip_path3], embedding_directory=folder_paths.get_folder_paths("embeddings"))
return (clip,)
class EmptySD3LatentImage:
def __init__(self):
self.device = comfy.model_management.intermediate_device()
@classmethod
def INPUT_TYPES(s):
return {"required": { "width": ("INT", {"default": 512, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
"height": ("INT", {"default": 512, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}}
return {"required": {"width": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
"height": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "generate"
@ -37,18 +44,20 @@ class EmptySD3LatentImage:
def generate(self, width, height, batch_size=1):
latent = torch.ones([batch_size, 16, height // 8, width // 8], device=self.device) * 0.0609
return ({"samples":latent}, )
return ({"samples": latent},)
class CLIPTextEncodeSD3:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"clip": ("CLIP", ),
"clip": ("CLIP",),
"clip_l": ("STRING", {"multiline": True, "dynamicPrompts": True}),
"clip_g": ("STRING", {"multiline": True, "dynamicPrompts": True}),
"t5xxl": ("STRING", {"multiline": True, "dynamicPrompts": True}),
"empty_padding": (["none", "empty_prompt"], )
}}
"empty_padding": (["none", "empty_prompt"],)
}}
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "encode"
@ -67,7 +76,7 @@ class CLIPTextEncodeSD3:
tokens["l"] = clip.tokenize(clip_l)["l"]
if len(t5xxl) == 0 and no_padding:
tokens["t5xxl"] = []
tokens["t5xxl"] = []
else:
tokens["t5xxl"] = clip.tokenize(t5xxl)["t5xxl"]
if len(tokens["l"]) != len(tokens["g"]):
@ -77,7 +86,7 @@ class CLIPTextEncodeSD3:
while len(tokens["l"]) > len(tokens["g"]):
tokens["g"] += empty["g"]
cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True)
return ([[cond, {"pooled_output": pooled}]], )
return ([[cond, {"pooled_output": pooled}]],)
NODE_CLASS_MAPPINGS = {

View File

@ -8,7 +8,7 @@ transformers>=4.29.1
peft
torchinfo
fschat[model_worker]
safetensors>=0.3.0
safetensors>=0.4.2
bitsandbytes
pytorch-lightning>=2.0.0
aiohttp>=3.8.4