diff --git a/comfy/cli_args.py b/comfy/cli_args.py
index 77009c912..c98d4dfaf 100644
--- a/comfy/cli_args.py
+++ b/comfy/cli_args.py
@@ -123,6 +123,7 @@ parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha
 parser.add_argument("--disable-smart-memory", action="store_true", help="Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can.")
 parser.add_argument("--deterministic", action="store_true", help="Make pytorch use slower deterministic algorithms when it can. Note that this might not make images deterministic in all cases.")
+parser.add_argument("--fast", action="store_true", help="Enable some untested and potentially quality deteriorating optimizations.")
 parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
 parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")
diff --git a/comfy/clip_model.py b/comfy/clip_model.py
index 3c67b737a..42cdc4f6e 100644
--- a/comfy/clip_model.py
+++ b/comfy/clip_model.py
@@ -88,10 +88,11 @@ class CLIPTextModel_(torch.nn.Module):
         heads = config_dict["num_attention_heads"]
         intermediate_size = config_dict["intermediate_size"]
         intermediate_activation = config_dict["hidden_act"]
+        num_positions = config_dict["max_position_embeddings"]
         self.eos_token_id = config_dict["eos_token_id"]

         super().__init__()
-        self.embeddings = CLIPEmbeddings(embed_dim, dtype=dtype, device=device, operations=operations)
+        self.embeddings = CLIPEmbeddings(embed_dim, num_positions=num_positions, dtype=dtype, device=device, operations=operations)
         self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
         self.final_layer_norm = operations.LayerNorm(embed_dim, dtype=dtype, device=device)

@@ -123,7 +124,6 @@ class CLIPTextModel(torch.nn.Module):
         self.text_model = CLIPTextModel_(config_dict, dtype, device, operations)
         embed_dim = config_dict["hidden_size"]
         self.text_projection = operations.Linear(embed_dim, embed_dim, bias=False, dtype=dtype, device=device)
-        self.text_projection.weight.copy_(torch.eye(embed_dim))
         self.dtype = dtype

     def get_input_embeddings(self):
diff --git a/comfy/lora.py b/comfy/lora.py
index 3b8b6c162..7cf056d1a 100644
--- a/comfy/lora.py
+++ b/comfy/lora.py
@@ -318,7 +318,7 @@ def model_lora_keys_unet(model, key_map={}):
         for k in diffusers_keys:
             if k.endswith(".weight"):
                 to = diffusers_keys[k]
-                key_lora = "transformer.{}".format(k[:-len(".weight")]) #simpletrainer and probably regular diffusers flux lora format
-                key_map[key_lora] = to
+                key_map["transformer.{}".format(k[:-len(".weight")])] = to #simpletrainer and probably regular diffusers flux lora format
+                key_map["lycoris_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #simpletrainer lycoris

     return key_map
diff --git a/comfy/model_base.py b/comfy/model_base.py
index 830bcc68c..9bfdb3b3e 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -96,10 +96,7 @@ class BaseModel(torch.nn.Module):

         if not unet_config.get("disable_unet_model_creation", False):
             if model_config.custom_operations is None:
-                if self.manual_cast_dtype is not None:
-                    operations = comfy.ops.manual_cast
-                else:
-                    operations = comfy.ops.disable_weight_init
+                operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype)
             else:
                 operations = model_config.custom_operations
             self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
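
The lora.py hunk above maps each diffusers-style weight key to two LoRA key names. A minimal sketch of that string transformation (separate from the patch; the sample key is a hypothetical flux-style name):

# Derive both LoRA key names recognized for a diffusers-format weight key.
def lora_key_names(k):
    base = k[:-len(".weight")]
    return ("transformer.{}".format(base),                # simpletrainer / diffusers flux format
            "lycoris_{}".format(base.replace(".", "_")))  # simpletrainer lycoris format

print(lora_key_names("single_transformer_blocks.0.attn.to_q.weight"))
# -> ('transformer.single_transformer_blocks.0.attn.to_q',
#     'lycoris_single_transformer_blocks_0_attn_to_q')
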
diff --git a/comfy/model_management.py b/comfy/model_management.py
index c6fbd34cc..ddc79c818 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -993,7 +993,7 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
     if torch.version.hip:
         return True

-    props = torch.cuda.get_device_properties("cuda")
+    props = torch.cuda.get_device_properties(device)
     if props.major >= 8:
         return True

@@ -1049,7 +1049,7 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
     if is_intel_xpu():
         return True

-    props = torch.cuda.get_device_properties("cuda")
+    props = torch.cuda.get_device_properties(device)
     if props.major >= 8:
         return True

@@ -1062,6 +1062,16 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma

     return False

+def supports_fp8_compute(device=None):
+    props = torch.cuda.get_device_properties(device)
+    if props.major >= 9:
+        return True
+    if props.major < 8:
+        return False
+    if props.minor < 9:
+        return False
+    return True
+
 def soft_empty_cache(force=False):
     global cpu_state
     if cpu_state == CPUState.MPS:
diff --git a/comfy/ops.py b/comfy/ops.py
index ce132911c..fc78dd830 100644
--- a/comfy/ops.py
+++ b/comfy/ops.py
@@ -18,9 +18,11 @@
 import torch
 import comfy.model_management
-
+from comfy.cli_args import args

 def cast_to(weight, dtype=None, device=None, non_blocking=False):
+    if (dtype is None or weight.dtype == dtype) and (device is None or weight.device == device):
+        return weight
     r = torch.empty_like(weight, dtype=dtype, device=device)
     r.copy_(weight, non_blocking=non_blocking)
     return r

@@ -240,3 +242,42 @@ class manual_cast(disable_weight_init):

     class Embedding(disable_weight_init.Embedding):
         comfy_cast_weights = True
+
+
+def fp8_linear(self, input):
+    dtype = self.weight.dtype
+    if dtype not in [torch.float8_e4m3fn]:
+        return None
+
+    if len(input.shape) == 3:
+        out = torch.empty((input.shape[0], input.shape[1], self.weight.shape[0]), device=input.device, dtype=input.dtype)
+        inn = input.to(dtype)
+        non_blocking = comfy.model_management.device_supports_non_blocking(input.device)
+        w = cast_to(self.weight, device=input.device, non_blocking=non_blocking).t()
+        for i in range(input.shape[0]):
+            if self.bias is not None:
+                o, _ = torch._scaled_mm(inn[i], w, out_dtype=input.dtype, bias=cast_to_input(self.bias, input, non_blocking=non_blocking))
+            else:
+                o, _ = torch._scaled_mm(inn[i], w, out_dtype=input.dtype)
+            out[i] = o
+        return out
+    return None
+
+class fp8_ops(manual_cast):
+    class Linear(manual_cast.Linear):
+        def forward_comfy_cast_weights(self, input):
+            out = fp8_linear(self, input)
+            if out is not None:
+                return out
+
+            weight, bias = cast_bias_weight(self, input)
+            return torch.nn.functional.linear(input, weight, bias)
+
+
+def pick_operations(weight_dtype, compute_dtype, load_device=None):
+    if compute_dtype is None or weight_dtype == compute_dtype:
+        return disable_weight_init
+    if args.fast:
+        if comfy.model_management.supports_fp8_compute(load_device):
+            return fp8_ops
+    return manual_cast
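
supports_fp8_compute() gates the new fp8_ops path behind CUDA compute capability 8.9 (Ada) or 9.0 and newer (Hopper), and pick_operations() only takes that path when --fast is set and the weight and compute dtypes differ. A standalone mock of the capability check (separate from the patch; it takes major/minor directly instead of querying torch.cuda):

# Mirror of the supports_fp8_compute() decision on a (major, minor) capability.
def fp8_capable(major, minor):
    if major >= 9:       # Hopper and newer
        return True
    if major < 8:        # pre-Ampere: no fp8 support
        return False
    return minor >= 9    # Ada (8.9) yes, Ampere (8.0/8.6) no

for cap in [(7, 5), (8, 6), (8, 9), (9, 0)]:
    print(cap, fp8_capable(*cap))
# (7, 5) False, (8, 6) False, (8, 9) True, (9, 0) True
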
diff --git a/comfy/sd.py b/comfy/sd.py
index 86ebba2b5..1e1a6594b 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -24,6 +24,7 @@ import comfy.text_encoders.sa_t5
 import comfy.text_encoders.aura_t5
 import comfy.text_encoders.hydit
 import comfy.text_encoders.flux
+import comfy.text_encoders.long_clipl

 import comfy.model_patcher
 import comfy.lora
@@ -443,8 +444,13 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
             clip_target.clip = comfy.text_encoders.sa_t5.SAT5Model
             clip_target.tokenizer = comfy.text_encoders.sa_t5.SAT5Tokenizer
         else:
-            clip_target.clip = sd1_clip.SD1ClipModel
-            clip_target.tokenizer = sd1_clip.SD1Tokenizer
+            w = clip_data[0].get("text_model.embeddings.position_embedding.weight", None)
+            if w is not None and w.shape[0] == 248:
+                clip_target.clip = comfy.text_encoders.long_clipl.LongClipModel
+                clip_target.tokenizer = comfy.text_encoders.long_clipl.LongClipTokenizer
+            else:
+                clip_target.clip = sd1_clip.SD1ClipModel
+                clip_target.tokenizer = sd1_clip.SD1Tokenizer
     elif len(clip_data) == 2:
         if clip_type == CLIPType.SD3:
             clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=True, clip_g=True, t5=False)
diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py
index dc8413b7b..676653f77 100644
--- a/comfy/sd1_clip.py
+++ b/comfy/sd1_clip.py
@@ -75,7 +75,6 @@ class ClipTokenWeightEncoder:
         return r

 class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
-    """Uses the CLIP transformer encoder for text (from huggingface)"""
     LAYERS = [
         "last",
         "pooled",
@@ -556,8 +555,12 @@ class SD1Tokenizer:
     def state_dict(self):
         return {}

+class SD1CheckpointClipModel(SDClipModel):
+    def __init__(self, device="cpu", dtype=None, model_options={}):
+        super().__init__(device=device, return_projected_pooled=False, dtype=dtype, model_options=model_options)
+
 class SD1ClipModel(torch.nn.Module):
-    def __init__(self, device="cpu", dtype=None, model_options={}, clip_name="l", clip_model=SDClipModel, name=None, **kwargs):
+    def __init__(self, device="cpu", dtype=None, model_options={}, clip_name="l", clip_model=SD1CheckpointClipModel, name=None, **kwargs):
         super().__init__()

         if name is not None:
diff --git a/comfy/sdxl_clip.py b/comfy/sdxl_clip.py
index 860900ccd..a0145caa4 100644
--- a/comfy/sdxl_clip.py
+++ b/comfy/sdxl_clip.py
@@ -10,7 +10,7 @@ class SDXLClipG(sd1_clip.SDClipModel):

         textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json")
         super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype,
-                         special_tokens={"start": 49406, "end": 49407, "pad": 0}, layer_norm_hidden_state=False, model_options=model_options)
+                         special_tokens={"start": 49406, "end": 49407, "pad": 0}, layer_norm_hidden_state=False, return_projected_pooled=True, model_options=model_options)

     def load_sd(self, sd):
         return super().load_sd(sd)
@@ -82,7 +82,7 @@ class StableCascadeClipG(sd1_clip.SDClipModel):
     def __init__(self, device="cpu", max_length=77, freeze=True, layer="hidden", layer_idx=-1, dtype=None, model_options={}):
         textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json")
         super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype,
-                         special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=False, enable_attention_masks=True, model_options=model_options)
+                         special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=False, enable_attention_masks=True, return_projected_pooled=True, model_options=model_options)

     def load_sd(self, sd):
         return super().load_sd(sd)
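
The sd.py hunk above auto-detects LongCLIP-L purely from tensor shape: LongCLIP stretches the position-embedding table from the usual 77 rows to 248, so the shape of a single state-dict tensor distinguishes the two. A minimal sketch of the check (separate from the patch; the state dict here is a zero-filled stand-in):

import torch

# Stand-in state dict; a real one comes from loading the text encoder file.
sd = {"text_model.embeddings.position_embedding.weight": torch.zeros(248, 768)}
w = sd.get("text_model.embeddings.position_embedding.weight", None)
is_long_clip = w is not None and w.shape[0] == 248
print(is_long_clip)  # True -> route to LongClipModel / LongClipTokenizer
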
"attention_dropout": 0.0, + "bos_token_id": 0, + "dropout": 0.0, + "eos_token_id": 49407, + "hidden_act": "quick_gelu", + "hidden_size": 768, + "initializer_factor": 1.0, + "initializer_range": 0.02, + "intermediate_size": 3072, + "layer_norm_eps": 1e-05, + "max_position_embeddings": 248, + "model_type": "clip_text_model", + "num_attention_heads": 12, + "num_hidden_layers": 12, + "pad_token_id": 1, + "projection_dim": 768, + "torch_dtype": "float32", + "transformers_version": "4.24.0", + "vocab_size": 49408 +} diff --git a/comfy/text_encoders/long_clipl.py b/comfy/text_encoders/long_clipl.py new file mode 100644 index 000000000..4677fb3b0 --- /dev/null +++ b/comfy/text_encoders/long_clipl.py @@ -0,0 +1,19 @@ +from comfy import sd1_clip +import os + +class LongClipTokenizer_(sd1_clip.SDTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + super().__init__(max_length=248, embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) + +class LongClipModel_(sd1_clip.SDClipModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "long_clipl.json") + super().__init__(device=device, textmodel_json_config=textmodel_json_config, return_projected_pooled=False, dtype=dtype, model_options=model_options) + +class LongClipTokenizer(sd1_clip.SD1Tokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, tokenizer=LongClipTokenizer_) + +class LongClipModel(sd1_clip.SD1ClipModel): + def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs): + super().__init__(device=device, dtype=dtype, model_options=model_options, clip_model=LongClipModel_, **kwargs) diff --git a/comfy/text_encoders/sd2_clip.py b/comfy/text_encoders/sd2_clip.py index 0c98cd853..31fc89869 100644 --- a/comfy/text_encoders/sd2_clip.py +++ b/comfy/text_encoders/sd2_clip.py @@ -8,7 +8,7 @@ class SD2ClipHModel(sd1_clip.SDClipModel): layer_idx=-2 textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd2_clip_config.json") - super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 0}, model_options=model_options) + super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 0}, return_projected_pooled=True, model_options=model_options) class SD2ClipHTokenizer(sd1_clip.SDTokenizer): def __init__(self, tokenizer_path=None, embedding_directory=None, tokenizer_data={}): diff --git a/comfy/text_encoders/sd3_clip.py b/comfy/text_encoders/sd3_clip.py index 83e8fa1f3..e3832ac24 100644 --- a/comfy/text_encoders/sd3_clip.py +++ b/comfy/text_encoders/sd3_clip.py @@ -15,7 +15,7 @@ class T5XXLModel(sd1_clip.SDClipModel): class T5XXLTokenizer(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer") - super().__init__(tokenizer_path, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=77) + super().__init__(tokenizer_path, embedding_directory=embedding_directory, 
diff --git a/comfy/text_encoders/sd3_clip.py b/comfy/text_encoders/sd3_clip.py
index 83e8fa1f3..e3832ac24 100644
--- a/comfy/text_encoders/sd3_clip.py
+++ b/comfy/text_encoders/sd3_clip.py
@@ -15,7 +15,7 @@ class T5XXLModel(sd1_clip.SDClipModel):

 class T5XXLTokenizer(sd1_clip.SDTokenizer):
     def __init__(self, embedding_directory=None, tokenizer_data={}):
         tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
-        super().__init__(tokenizer_path, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=77)
+        super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=77)

 class SD3Tokenizer:
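
The sd3_clip.py hunk fixes a dropped argument: T5XXLTokenizer.__init__ accepted embedding_directory but never forwarded it to SDTokenizer, so custom embeddings could not resolve for T5-XXL. A toy reproduction with a stand-in base class (separate from the patch):

class SDTokenizer:
    def __init__(self, tokenizer_path=None, embedding_directory=None, **kwargs):
        self.embedding_directory = embedding_directory

class Before(SDTokenizer):
    def __init__(self, embedding_directory=None):
        super().__init__("t5_tokenizer")  # embedding_directory silently dropped

class After(SDTokenizer):
    def __init__(self, embedding_directory=None):
        super().__init__("t5_tokenizer", embedding_directory=embedding_directory)

print(Before(embedding_directory="models/embeddings").embedding_directory)  # None
print(After(embedding_directory="models/embeddings").embedding_directory)   # models/embeddings
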