Merge branch 'comfyanonymous:master' into master

This commit is contained in:
patientx 2024-08-20 23:07:54 +03:00 committed by GitHub
commit 5656b5b956
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 121 additions and 19 deletions

View File

@ -123,6 +123,7 @@ parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha
parser.add_argument("--disable-smart-memory", action="store_true", help="Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can.")
parser.add_argument("--deterministic", action="store_true", help="Make pytorch use slower deterministic algorithms when it can. Note that this might not make images deterministic in all cases.")
parser.add_argument("--fast", action="store_true", help="Enable some untested and potentially quality deteriorating optimizations.")
parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")

View File

@ -88,10 +88,11 @@ class CLIPTextModel_(torch.nn.Module):
heads = config_dict["num_attention_heads"]
intermediate_size = config_dict["intermediate_size"]
intermediate_activation = config_dict["hidden_act"]
num_positions = config_dict["max_position_embeddings"]
self.eos_token_id = config_dict["eos_token_id"]
super().__init__()
self.embeddings = CLIPEmbeddings(embed_dim, dtype=dtype, device=device, operations=operations)
self.embeddings = CLIPEmbeddings(embed_dim, num_positions=num_positions, dtype=dtype, device=device, operations=operations)
self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
self.final_layer_norm = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
@ -123,7 +124,6 @@ class CLIPTextModel(torch.nn.Module):
self.text_model = CLIPTextModel_(config_dict, dtype, device, operations)
embed_dim = config_dict["hidden_size"]
self.text_projection = operations.Linear(embed_dim, embed_dim, bias=False, dtype=dtype, device=device)
self.text_projection.weight.copy_(torch.eye(embed_dim))
self.dtype = dtype
def get_input_embeddings(self):

View File

@ -318,7 +318,7 @@ def model_lora_keys_unet(model, key_map={}):
for k in diffusers_keys:
if k.endswith(".weight"):
to = diffusers_keys[k]
key_lora = "transformer.{}".format(k[:-len(".weight")]) #simpletrainer and probably regular diffusers flux lora format
key_map[key_lora] = to
key_map["transformer.{}".format(k[:-len(".weight")])] = to #simpletrainer and probably regular diffusers flux lora format
key_map["lycoris_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #simpletrainer lycoris
return key_map

View File

@ -96,10 +96,7 @@ class BaseModel(torch.nn.Module):
if not unet_config.get("disable_unet_model_creation", False):
if model_config.custom_operations is None:
if self.manual_cast_dtype is not None:
operations = comfy.ops.manual_cast
else:
operations = comfy.ops.disable_weight_init
operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype)
else:
operations = model_config.custom_operations
self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)

View File

@ -993,7 +993,7 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
if torch.version.hip:
return True
props = torch.cuda.get_device_properties("cuda")
props = torch.cuda.get_device_properties(device)
if props.major >= 8:
return True
@ -1049,7 +1049,7 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
if is_intel_xpu():
return True
props = torch.cuda.get_device_properties("cuda")
props = torch.cuda.get_device_properties(device)
if props.major >= 8:
return True
@ -1062,6 +1062,16 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
return False
def supports_fp8_compute(device=None):
props = torch.cuda.get_device_properties(device)
if props.major >= 9:
return True
if props.major < 8:
return False
if props.minor < 9:
return False
return True
def soft_empty_cache(force=False):
global cpu_state
if cpu_state == CPUState.MPS:

View File

@ -18,9 +18,11 @@
import torch
import comfy.model_management
from comfy.cli_args import args
def cast_to(weight, dtype=None, device=None, non_blocking=False):
if (dtype is None or weight.dtype == dtype) and (device is None or weight.device == device):
return weight
r = torch.empty_like(weight, dtype=dtype, device=device)
r.copy_(weight, non_blocking=non_blocking)
return r
@ -240,3 +242,42 @@ class manual_cast(disable_weight_init):
class Embedding(disable_weight_init.Embedding):
comfy_cast_weights = True
def fp8_linear(self, input):
dtype = self.weight.dtype
if dtype not in [torch.float8_e4m3fn]:
return None
if len(input.shape) == 3:
out = torch.empty((input.shape[0], input.shape[1], self.weight.shape[0]), device=input.device, dtype=input.dtype)
inn = input.to(dtype)
non_blocking = comfy.model_management.device_supports_non_blocking(input.device)
w = cast_to(self.weight, device=input.device, non_blocking=non_blocking).t()
for i in range(input.shape[0]):
if self.bias is not None:
o, _ = torch._scaled_mm(inn[i], w, out_dtype=input.dtype, bias=cast_to_input(self.bias, input, non_blocking=non_blocking))
else:
o, _ = torch._scaled_mm(inn[i], w, out_dtype=input.dtype)
out[i] = o
return out
return None
class fp8_ops(manual_cast):
class Linear(manual_cast.Linear):
def forward_comfy_cast_weights(self, input):
out = fp8_linear(self, input)
if out is not None:
return out
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.linear(input, weight, bias)
def pick_operations(weight_dtype, compute_dtype, load_device=None):
if compute_dtype is None or weight_dtype == compute_dtype:
return disable_weight_init
if args.fast:
if comfy.model_management.supports_fp8_compute(load_device):
return fp8_ops
return manual_cast

View File

@ -24,6 +24,7 @@ import comfy.text_encoders.sa_t5
import comfy.text_encoders.aura_t5
import comfy.text_encoders.hydit
import comfy.text_encoders.flux
import comfy.text_encoders.long_clipl
import comfy.model_patcher
import comfy.lora
@ -443,8 +444,13 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
clip_target.clip = comfy.text_encoders.sa_t5.SAT5Model
clip_target.tokenizer = comfy.text_encoders.sa_t5.SAT5Tokenizer
else:
clip_target.clip = sd1_clip.SD1ClipModel
clip_target.tokenizer = sd1_clip.SD1Tokenizer
w = clip_data[0].get("text_model.embeddings.position_embedding.weight", None)
if w is not None and w.shape[0] == 248:
clip_target.clip = comfy.text_encoders.long_clipl.LongClipModel
clip_target.tokenizer = comfy.text_encoders.long_clipl.LongClipTokenizer
else:
clip_target.clip = sd1_clip.SD1ClipModel
clip_target.tokenizer = sd1_clip.SD1Tokenizer
elif len(clip_data) == 2:
if clip_type == CLIPType.SD3:
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=True, clip_g=True, t5=False)

View File

@ -75,7 +75,6 @@ class ClipTokenWeightEncoder:
return r
class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
"""Uses the CLIP transformer encoder for text (from huggingface)"""
LAYERS = [
"last",
"pooled",
@ -556,8 +555,12 @@ class SD1Tokenizer:
def state_dict(self):
return {}
class SD1CheckpointClipModel(SDClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(device=device, return_projected_pooled=False, dtype=dtype, model_options=model_options)
class SD1ClipModel(torch.nn.Module):
def __init__(self, device="cpu", dtype=None, model_options={}, clip_name="l", clip_model=SDClipModel, name=None, **kwargs):
def __init__(self, device="cpu", dtype=None, model_options={}, clip_name="l", clip_model=SD1CheckpointClipModel, name=None, **kwargs):
super().__init__()
if name is not None:

View File

@ -10,7 +10,7 @@ class SDXLClipG(sd1_clip.SDClipModel):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json")
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype,
special_tokens={"start": 49406, "end": 49407, "pad": 0}, layer_norm_hidden_state=False, model_options=model_options)
special_tokens={"start": 49406, "end": 49407, "pad": 0}, layer_norm_hidden_state=False, return_projected_pooled=True, model_options=model_options)
def load_sd(self, sd):
return super().load_sd(sd)
@ -82,7 +82,7 @@ class StableCascadeClipG(sd1_clip.SDClipModel):
def __init__(self, device="cpu", max_length=77, freeze=True, layer="hidden", layer_idx=-1, dtype=None, model_options={}):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json")
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype,
special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=False, enable_attention_masks=True, model_options=model_options)
special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=False, enable_attention_masks=True, return_projected_pooled=True, model_options=model_options)
def load_sd(self, sd):
return super().load_sd(sd)

View File

@ -0,0 +1,25 @@
{
"_name_or_path": "openai/clip-vit-large-patch14",
"architectures": [
"CLIPTextModel"
],
"attention_dropout": 0.0,
"bos_token_id": 0,
"dropout": 0.0,
"eos_token_id": 49407,
"hidden_act": "quick_gelu",
"hidden_size": 768,
"initializer_factor": 1.0,
"initializer_range": 0.02,
"intermediate_size": 3072,
"layer_norm_eps": 1e-05,
"max_position_embeddings": 248,
"model_type": "clip_text_model",
"num_attention_heads": 12,
"num_hidden_layers": 12,
"pad_token_id": 1,
"projection_dim": 768,
"torch_dtype": "float32",
"transformers_version": "4.24.0",
"vocab_size": 49408
}

View File

@ -0,0 +1,19 @@
from comfy import sd1_clip
import os
class LongClipTokenizer_(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(max_length=248, embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
class LongClipModel_(sd1_clip.SDClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "long_clipl.json")
super().__init__(device=device, textmodel_json_config=textmodel_json_config, return_projected_pooled=False, dtype=dtype, model_options=model_options)
class LongClipTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, tokenizer=LongClipTokenizer_)
class LongClipModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
super().__init__(device=device, dtype=dtype, model_options=model_options, clip_model=LongClipModel_, **kwargs)

View File

@ -8,7 +8,7 @@ class SD2ClipHModel(sd1_clip.SDClipModel):
layer_idx=-2
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd2_clip_config.json")
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 0}, model_options=model_options)
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 0}, return_projected_pooled=True, model_options=model_options)
class SD2ClipHTokenizer(sd1_clip.SDTokenizer):
def __init__(self, tokenizer_path=None, embedding_directory=None, tokenizer_data={}):

View File

@ -15,7 +15,7 @@ class T5XXLModel(sd1_clip.SDClipModel):
class T5XXLTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=77)
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=77)
class SD3Tokenizer: