commit e75c662b35
Author: edikius
Date:   2023-03-06 14:07:00 +01:00

9 changed files with 165 additions and 14 deletions

View File

@@ -489,6 +489,8 @@ if XFORMERS_IS_AVAILBLE == False or "--disable-xformers" in sys.argv:
    if "--use-pytorch-cross-attention" in sys.argv:
        print("Using pytorch cross attention")
        torch.backends.cuda.enable_math_sdp(False)
        torch.backends.cuda.enable_flash_sdp(True)
        torch.backends.cuda.enable_mem_efficient_sdp(True)
        CrossAttention = CrossAttentionPytorch
    else:
        print("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention")
@@ -497,6 +499,7 @@ else:
    print("Using xformers cross attention")
    CrossAttention = MemoryEfficientCrossAttention


class BasicTransformerBlock(nn.Module):
    def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
                 disable_self_attn=False):
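
With math SDP disabled, PyTorch's scaled_dot_product_attention can only dispatch to the flash or memory-efficient kernels enabled above. A minimal sketch of what these flags select, assuming PyTorch >= 2.0 and a CUDA device (the q/k/v shapes are hypothetical):

    import torch
    import torch.nn.functional as F

    # Same backend configuration as in the diff above.
    torch.backends.cuda.enable_math_sdp(False)
    torch.backends.cuda.enable_flash_sdp(True)
    torch.backends.cuda.enable_mem_efficient_sdp(True)

    # (batch, heads, tokens, head_dim)
    q = torch.randn(1, 8, 4096, 64, device="cuda", dtype=torch.float16)
    k = torch.randn(1, 8, 4096, 64, device="cuda", dtype=torch.float16)
    v = torch.randn(1, 8, 4096, 64, device="cuda", dtype=torch.float16)
    out = F.scaled_dot_product_attention(q, k, v)  # dispatches to an enabled backend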

View File

@@ -7,6 +7,7 @@ from einops import rearrange
from typing import Optional, Any
from ldm.modules.attention import MemoryEfficientCrossAttention
import model_management

try:
    import xformers
@@ -199,12 +200,7 @@ class AttnBlock(nn.Module):
         r1 = torch.zeros_like(k, device=q.device)

-        stats = torch.cuda.memory_stats(q.device)
-        mem_active = stats['active_bytes.all.current']
-        mem_reserved = stats['reserved_bytes.all.current']
-        mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
-        mem_free_torch = mem_reserved - mem_active
-        mem_free_total = mem_free_cuda + mem_free_torch
+        mem_free_total = model_management.get_free_memory(q.device)

         gb = 1024 ** 3
         tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
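
The body of model_management.get_free_memory is not part of this diff; judging from the lines it replaces, a sketch of what the helper most likely computes (an assumption, not the actual implementation):

    import torch

    def get_free_memory(device):
        # Free CUDA memory plus memory PyTorch has reserved but is not
        # actively using -- the same quantity the removed lines computed inline.
        stats = torch.cuda.memory_stats(device)
        mem_active = stats['active_bytes.all.current']
        mem_reserved = stats['reserved_bytes.all.current']
        mem_free_cuda, _ = torch.cuda.mem_get_info(device)
        mem_free_torch = mem_reserved - mem_active
        return mem_free_cuda + mem_free_torch

Centralizing the heuristic in one helper lets every module that imports model_management share the same estimate.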

View File

@@ -613,11 +613,7 @@ class T2IAdapter:
 def load_t2i_adapter(ckpt_path, model=None):
     t2i_data = load_torch_file(ckpt_path)
     keys = t2i_data.keys()
-    if "style_embedding" in keys:
-        pass
-        # TODO
-        # model_ad = adapter.StyleAdapter(width=1024, context_dim=768, num_head=8, n_layes=3, num_token=8)
-    elif "body.0.in_conv.weight" in keys:
+    if "body.0.in_conv.weight" in keys:
         cin = t2i_data['body.0.in_conv.weight'].shape[1]
         model_ad = adapter.Adapter_light(cin=cin, channels=[320, 640, 1280, 1280], nums_rb=4)
     else:
@@ -626,6 +622,26 @@ def load_t2i_adapter(ckpt_path, model=None):
    model_ad.load_state_dict(t2i_data)
    return T2IAdapter(model_ad, cin // 64)


class StyleModel:
    def __init__(self, model, device="cpu"):
        self.model = model

    def get_cond(self, input):
        return self.model(input.last_hidden_state)


def load_style_model(ckpt_path):
    model_data = load_torch_file(ckpt_path)
    keys = model_data.keys()
    if "style_embedding" in keys:
        model = adapter.StyleAdapter(width=1024, context_dim=768, num_head=8, n_layes=3, num_token=8)
    else:
        raise Exception("invalid style model {}".format(ckpt_path))
    model.load_state_dict(model_data)
    return StyleModel(model)


def load_clip(ckpt_path, embedding_directory=None):
    clip_data = load_torch_file(ckpt_path)
    config = {}
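
The style branch that was a TODO in load_t2i_adapter now lives in its own loader. A hypothetical usage sketch; the checkpoint filename comes from the notebook cell at the end of this commit, and clip_vision_output is assumed to come from the CLIP vision encoder added below:

    style_model = load_style_model("models/style_models/t2iadapter_style_sd14v1.pth")
    # For the ViT-L/14 vision tower, clip_vision_output.last_hidden_state has
    # shape (1, 257, 1024); the StyleAdapter projects it to 8 tokens of width 768.
    style_cond = style_model.get_cond(clip_vision_output)

Note that n_layes=3 is the parameter's actual (misspelled) name in the upstream TencentARC T2I-Adapter StyleAdapter, so it has to stay as written.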

View File

@@ -0,0 +1,32 @@
from transformers import CLIPVisionModel, CLIPVisionConfig, CLIPImageProcessor
from comfy.sd import load_torch_file
import os

class ClipVisionModel():
    def __init__(self):
        json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config.json")
        config = CLIPVisionConfig.from_json_file(json_config)
        self.model = CLIPVisionModel(config)
        self.processor = CLIPImageProcessor(crop_size=224,
                                            do_center_crop=True,
                                            do_convert_rgb=True,
                                            do_normalize=True,
                                            do_resize=True,
                                            image_mean=[0.48145466, 0.4578275, 0.40821073],
                                            image_std=[0.26862954, 0.26130258, 0.27577711],
                                            resample=3, #bicubic
                                            size=224)

    def load_sd(self, sd):
        self.model.load_state_dict(sd, strict=False)

    def encode_image(self, image):
        inputs = self.processor(images=[image[0]], return_tensors="pt")
        outputs = self.model(**inputs)
        return outputs

def load(ckpt_path):
    clip_data = load_torch_file(ckpt_path)
    clip = ClipVisionModel()
    clip.load_sd(clip_data)
    return clip
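
A hypothetical usage sketch, assuming the ViT-L/14 weights downloaded by the notebook cell at the end of this commit; encode_image indexes image[0], so a one-element list of PIL images works:

    from PIL import Image

    clip = load("models/clip_vision/clip_vit14.bin")
    image = [Image.open("example.png").convert("RGB")]  # hypothetical input image
    outputs = clip.encode_image(image)
    print(outputs.last_hidden_state.shape)  # torch.Size([1, 257, 1024])

Loading with strict=False lets the full openai/clip-vit-large-patch14 checkpoint load even though it also carries text-tower weights this vision-only model ignores.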

View File

@@ -0,0 +1,23 @@
{
  "_name_or_path": "openai/clip-vit-large-patch14",
  "architectures": [
    "CLIPVisionModel"
  ],
  "attention_dropout": 0.0,
  "dropout": 0.0,
  "hidden_act": "quick_gelu",
  "hidden_size": 1024,
  "image_size": 224,
  "initializer_factor": 1.0,
  "initializer_range": 0.02,
  "intermediate_size": 4096,
  "layer_norm_eps": 1e-05,
  "model_type": "clip_vision_model",
  "num_attention_heads": 16,
  "num_channels": 3,
  "num_hidden_layers": 24,
  "patch_size": 14,
  "projection_dim": 768,
  "torch_dtype": "float32",
  "transformers_version": "4.24.0"
}
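
These values describe the vision tower of openai/clip-vit-large-patch14: an image_size of 224 with patch_size 14 yields (224 / 14)^2 = 256 patch tokens plus one class token, each of hidden_size 1024, which is why the last_hidden_state consumed by the style model has shape (1, 257, 1024).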

View File

@@ -18,6 +18,8 @@ import comfy.samplers
import comfy.sd
import comfy.utils
import comfy_extras.clip_vision
import model_management
import importlib
@@ -370,6 +372,76 @@ class CLIPLoader:
        clip = comfy.sd.load_clip(ckpt_path=clip_path, embedding_directory=CheckpointLoader.embedding_directory)
        return (clip,)


class CLIPVisionLoader:
    models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
    clip_dir = os.path.join(models_dir, "clip_vision")

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"clip_name": (filter_files_extensions(recursive_search(s.clip_dir), supported_pt_extensions), ),
                             }}
    RETURN_TYPES = ("CLIP_VISION",)
    FUNCTION = "load_clip"

    CATEGORY = "loaders"

    def load_clip(self, clip_name):
        clip_path = os.path.join(self.clip_dir, clip_name)
        clip_vision = comfy_extras.clip_vision.load(clip_path)
        return (clip_vision,)


class CLIPVisionEncode:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"clip_vision": ("CLIP_VISION",),
                             "image": ("IMAGE",)
                             }}
    RETURN_TYPES = ("CLIP_VISION_OUTPUT",)
    FUNCTION = "encode"

    CATEGORY = "conditioning/style_model"

    def encode(self, clip_vision, image):
        output = clip_vision.encode_image(image)
        return (output,)


class StyleModelLoader:
    models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
    style_model_dir = os.path.join(models_dir, "style_models")

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"style_model_name": (filter_files_extensions(recursive_search(s.style_model_dir), supported_pt_extensions), )}}

    RETURN_TYPES = ("STYLE_MODEL",)
    FUNCTION = "load_style_model"

    CATEGORY = "loaders"

    def load_style_model(self, style_model_name):
        style_model_path = os.path.join(self.style_model_dir, style_model_name)
        style_model = comfy.sd.load_style_model(style_model_path)
        return (style_model,)


class StyleModelApply:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"conditioning": ("CONDITIONING", ),
                             "style_model": ("STYLE_MODEL", ),
                             "clip_vision_output": ("CLIP_VISION_OUTPUT", ),
                             }}
    RETURN_TYPES = ("CONDITIONING",)
    FUNCTION = "apply_stylemodel"

    CATEGORY = "conditioning/style_model"

    def apply_stylemodel(self, clip_vision_output, style_model, conditioning):
        cond = style_model.get_cond(clip_vision_output)
        c = []
        for t in conditioning:
            n = [torch.cat((t[0], cond), dim=1), t[1].copy()]
            c.append(n)
        return (c, )


class EmptyLatentImage:
    def __init__(self, device="cpu"):
        self.device = device
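
apply_stylemodel appends the style tokens after the text tokens along the sequence dimension, so cross attention attends over both. A sketch with hypothetical tensors, using context_dim=768 and num_token=8 from load_style_model:

    import torch

    text_cond = torch.randn(1, 77, 768)    # CLIP text conditioning, 77 tokens
    style_tokens = torch.randn(1, 8, 768)  # 8 style tokens from the StyleAdapter
    combined = torch.cat((text_cond, style_tokens), dim=1)
    print(combined.shape)  # torch.Size([1, 85, 768])
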
@@ -419,7 +491,7 @@ class LatentRotate:
     RETURN_TYPES = ("LATENT",)
     FUNCTION = "rotate"

-    CATEGORY = "latent"
+    CATEGORY = "latent/transform"

     def rotate(self, samples, rotation):
         s = samples.copy()
@@ -443,7 +515,7 @@ class LatentFlip:
     RETURN_TYPES = ("LATENT",)
     FUNCTION = "flip"

-    CATEGORY = "latent"
+    CATEGORY = "latent/transform"

     def flip(self, samples, flip_method):
         s = samples.copy()
@@ -508,7 +580,7 @@ class LatentCrop:
     RETURN_TYPES = ("LATENT",)
     FUNCTION = "crop"

-    CATEGORY = "latent"
+    CATEGORY = "latent/transform"

     def crop(self, samples, width, height, x, y):
         s = samples.copy()
@@ -866,10 +938,14 @@ NODE_CLASS_MAPPINGS = {
     "LatentCrop": LatentCrop,
     "LoraLoader": LoraLoader,
     "CLIPLoader": CLIPLoader,
+    "CLIPVisionEncode": CLIPVisionEncode,
+    "StyleModelApply": StyleModelApply,
     "ControlNetApply": ControlNetApply,
     "ControlNetLoader": ControlNetLoader,
     "DiffControlNetLoader": DiffControlNetLoader,
     "T2IAdapterLoader": T2IAdapterLoader,
+    "StyleModelLoader": StyleModelLoader,
+    "CLIPVisionLoader": CLIPVisionLoader,
     "VAEDecodeTiled": VAEDecodeTiled,
 }
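
Taken together, the new mappings wire up like any other graph: the loaders produce CLIP_VISION and STYLE_MODEL values, CLIPVisionEncode turns an IMAGE into a CLIP_VISION_OUTPUT, and StyleModelApply folds it into the text conditioning. A hypothetical call sequence outside the UI (the filenames, image and conditioning inputs are assumed to exist):

    clip_vision, = CLIPVisionLoader().load_clip("clip_vit14.bin")
    style_model, = StyleModelLoader().load_style_model("t2iadapter_style_sd14v1.pth")
    clip_vision_output, = CLIPVisionEncode().encode(clip_vision, image)
    styled_cond, = StyleModelApply().apply_stylemodel(clip_vision_output, style_model, conditioning)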

View File

@@ -89,6 +89,11 @@
"#!wget -c https://huggingface.co/TencentARC/T2I-Adapter/resolve/main/models/t2iadapter_color_sd14v1.pth -P ./models/t2i_adapter/\n",
"#!wget -c https://huggingface.co/TencentARC/T2I-Adapter/resolve/main/models/t2iadapter_canny_sd14v1.pth -P ./models/t2i_adapter/\n",
"\n",
"# T2I Styles Model\n",
"#!wget -c https://huggingface.co/TencentARC/T2I-Adapter/resolve/main/models/t2iadapter_style_sd14v1.pth -P ./models/style_models/\n",
"\n",
"# CLIPVision model (needed for styles model)\n",
"#!wget -c https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/pytorch_model.bin -O ./models/clip_vision/clip_vit14.bin\n",
"\n",
"\n",
"# ControlNet\n",