diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 59683f645..692952f32 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -489,6 +489,8 @@ if XFORMERS_IS_AVAILBLE == False or "--disable-xformers" in sys.argv: if "--use-pytorch-cross-attention" in sys.argv: print("Using pytorch cross attention") torch.backends.cuda.enable_math_sdp(False) + torch.backends.cuda.enable_flash_sdp(True) + torch.backends.cuda.enable_mem_efficient_sdp(True) CrossAttention = CrossAttentionPytorch else: print("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention") @@ -497,6 +499,7 @@ else: print("Using xformers cross attention") CrossAttention = MemoryEfficientCrossAttention + class BasicTransformerBlock(nn.Module): def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, disable_self_attn=False): diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py index 6f0b41dce..01ab2ede9 100644 --- a/comfy/ldm/modules/diffusionmodules/model.py +++ b/comfy/ldm/modules/diffusionmodules/model.py @@ -7,6 +7,7 @@ from einops import rearrange from typing import Optional, Any from ldm.modules.attention import MemoryEfficientCrossAttention +import model_management try: import xformers @@ -199,12 +200,7 @@ class AttnBlock(nn.Module): r1 = torch.zeros_like(k, device=q.device) - stats = torch.cuda.memory_stats(q.device) - mem_active = stats['active_bytes.all.current'] - mem_reserved = stats['reserved_bytes.all.current'] - mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device()) - mem_free_torch = mem_reserved - mem_active - mem_free_total = mem_free_cuda + mem_free_torch + mem_free_total = model_management.get_free_memory(q.device) gb = 1024 ** 3 tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size() diff --git a/comfy/sd.py b/comfy/sd.py index 19722113a..eb4ea7938 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -613,11 +613,7 @@ class T2IAdapter: def load_t2i_adapter(ckpt_path, model=None): t2i_data = load_torch_file(ckpt_path) keys = t2i_data.keys() - if "style_embedding" in keys: - pass - # TODO - # model_ad = adapter.StyleAdapter(width=1024, context_dim=768, num_head=8, n_layes=3, num_token=8) - elif "body.0.in_conv.weight" in keys: + if "body.0.in_conv.weight" in keys: cin = t2i_data['body.0.in_conv.weight'].shape[1] model_ad = adapter.Adapter_light(cin=cin, channels=[320, 640, 1280, 1280], nums_rb=4) else: @@ -626,6 +622,26 @@ def load_t2i_adapter(ckpt_path, model=None): model_ad.load_state_dict(t2i_data) return T2IAdapter(model_ad, cin // 64) + +class StyleModel: + def __init__(self, model, device="cpu"): + self.model = model + + def get_cond(self, input): + return self.model(input.last_hidden_state) + + +def load_style_model(ckpt_path): + model_data = load_torch_file(ckpt_path) + keys = model_data.keys() + if "style_embedding" in keys: + model = adapter.StyleAdapter(width=1024, context_dim=768, num_head=8, n_layes=3, num_token=8) + else: + raise Exception("invalid style model {}".format(ckpt_path)) + model.load_state_dict(model_data) + return StyleModel(model) + + def load_clip(ckpt_path, embedding_directory=None): clip_data = load_torch_file(ckpt_path) config = {} diff --git a/comfy_extras/clip_vision.py b/comfy_extras/clip_vision.py new file mode 100644 index 000000000..58d79a83e --- /dev/null +++ b/comfy_extras/clip_vision.py @@ -0,0 +1,32 @@ +from transformers import 
CLIPVisionModel, CLIPVisionConfig, CLIPImageProcessor +from comfy.sd import load_torch_file +import os + +class ClipVisionModel(): + def __init__(self): + json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config.json") + config = CLIPVisionConfig.from_json_file(json_config) + self.model = CLIPVisionModel(config) + self.processor = CLIPImageProcessor(crop_size=224, + do_center_crop=True, + do_convert_rgb=True, + do_normalize=True, + do_resize=True, + image_mean=[ 0.48145466,0.4578275,0.40821073], + image_std=[0.26862954,0.26130258,0.27577711], + resample=3, #bicubic + size=224) + + def load_sd(self, sd): + self.model.load_state_dict(sd, strict=False) + + def encode_image(self, image): + inputs = self.processor(images=[image[0]], return_tensors="pt") + outputs = self.model(**inputs) + return outputs + +def load(ckpt_path): + clip_data = load_torch_file(ckpt_path) + clip = ClipVisionModel() + clip.load_sd(clip_data) + return clip diff --git a/comfy_extras/clip_vision_config.json b/comfy_extras/clip_vision_config.json new file mode 100644 index 000000000..0e4db13d9 --- /dev/null +++ b/comfy_extras/clip_vision_config.json @@ -0,0 +1,23 @@ +{ + "_name_or_path": "openai/clip-vit-large-patch14", + "architectures": [ + "CLIPVisionModel" + ], + "attention_dropout": 0.0, + "dropout": 0.0, + "hidden_act": "quick_gelu", + "hidden_size": 1024, + "image_size": 224, + "initializer_factor": 1.0, + "initializer_range": 0.02, + "intermediate_size": 4096, + "layer_norm_eps": 1e-05, + "model_type": "clip_vision_model", + "num_attention_heads": 16, + "num_channels": 3, + "num_hidden_layers": 24, + "patch_size": 14, + "projection_dim": 768, + "torch_dtype": "float32", + "transformers_version": "4.24.0" +} diff --git a/models/clip_vision/put_clip_vision_models_here b/models/clip_vision/put_clip_vision_models_here new file mode 100644 index 000000000..e69de29bb diff --git a/models/style_models/put_t2i_style_model_here b/models/style_models/put_t2i_style_model_here new file mode 100644 index 000000000..e69de29bb diff --git a/nodes.py b/nodes.py index 26dad5729..84510a052 100644 --- a/nodes.py +++ b/nodes.py @@ -18,6 +18,8 @@ import comfy.samplers import comfy.sd import comfy.utils +import comfy_extras.clip_vision + import model_management import importlib @@ -370,6 +372,76 @@ class CLIPLoader: clip = comfy.sd.load_clip(ckpt_path=clip_path, embedding_directory=CheckpointLoader.embedding_directory) return (clip,) +class CLIPVisionLoader: + models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models") + clip_dir = os.path.join(models_dir, "clip_vision") + @classmethod + def INPUT_TYPES(s): + return {"required": { "clip_name": (filter_files_extensions(recursive_search(s.clip_dir), supported_pt_extensions), ), + }} + RETURN_TYPES = ("CLIP_VISION",) + FUNCTION = "load_clip" + + CATEGORY = "loaders" + + def load_clip(self, clip_name): + clip_path = os.path.join(self.clip_dir, clip_name) + clip_vision = comfy_extras.clip_vision.load(clip_path) + return (clip_vision,) + +class CLIPVisionEncode: + @classmethod + def INPUT_TYPES(s): + return {"required": { "clip_vision": ("CLIP_VISION",), + "image": ("IMAGE",) + }} + RETURN_TYPES = ("CLIP_VISION_OUTPUT",) + FUNCTION = "encode" + + CATEGORY = "conditioning/style_model" + + def encode(self, clip_vision, image): + output = clip_vision.encode_image(image) + return (output,) + +class StyleModelLoader: + models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models") + style_model_dir = os.path.join(models_dir, 
"style_models") + @classmethod + def INPUT_TYPES(s): + return {"required": { "style_model_name": (filter_files_extensions(recursive_search(s.style_model_dir), supported_pt_extensions), )}} + + RETURN_TYPES = ("STYLE_MODEL",) + FUNCTION = "load_style_model" + + CATEGORY = "loaders" + + def load_style_model(self, style_model_name): + style_model_path = os.path.join(self.style_model_dir, style_model_name) + style_model = comfy.sd.load_style_model(style_model_path) + return (style_model,) + + +class StyleModelApply: + @classmethod + def INPUT_TYPES(s): + return {"required": {"conditioning": ("CONDITIONING", ), + "style_model": ("STYLE_MODEL", ), + "clip_vision_output": ("CLIP_VISION_OUTPUT", ), + }} + RETURN_TYPES = ("CONDITIONING",) + FUNCTION = "apply_stylemodel" + + CATEGORY = "conditioning/style_model" + + def apply_stylemodel(self, clip_vision_output, style_model, conditioning): + cond = style_model.get_cond(clip_vision_output) + c = [] + for t in conditioning: + n = [torch.cat((t[0], cond), dim=1), t[1].copy()] + c.append(n) + return (c, ) + class EmptyLatentImage: def __init__(self, device="cpu"): self.device = device @@ -419,7 +491,7 @@ class LatentRotate: RETURN_TYPES = ("LATENT",) FUNCTION = "rotate" - CATEGORY = "latent" + CATEGORY = "latent/transform" def rotate(self, samples, rotation): s = samples.copy() @@ -443,7 +515,7 @@ class LatentFlip: RETURN_TYPES = ("LATENT",) FUNCTION = "flip" - CATEGORY = "latent" + CATEGORY = "latent/transform" def flip(self, samples, flip_method): s = samples.copy() @@ -508,7 +580,7 @@ class LatentCrop: RETURN_TYPES = ("LATENT",) FUNCTION = "crop" - CATEGORY = "latent" + CATEGORY = "latent/transform" def crop(self, samples, width, height, x, y): s = samples.copy() @@ -866,10 +938,14 @@ NODE_CLASS_MAPPINGS = { "LatentCrop": LatentCrop, "LoraLoader": LoraLoader, "CLIPLoader": CLIPLoader, + "CLIPVisionEncode": CLIPVisionEncode, + "StyleModelApply": StyleModelApply, "ControlNetApply": ControlNetApply, "ControlNetLoader": ControlNetLoader, "DiffControlNetLoader": DiffControlNetLoader, "T2IAdapterLoader": T2IAdapterLoader, + "StyleModelLoader": StyleModelLoader, + "CLIPVisionLoader": CLIPVisionLoader, "VAEDecodeTiled": VAEDecodeTiled, } diff --git a/notebooks/comfyui_colab.ipynb b/notebooks/comfyui_colab.ipynb index 5315ab08e..06278b273 100644 --- a/notebooks/comfyui_colab.ipynb +++ b/notebooks/comfyui_colab.ipynb @@ -89,6 +89,11 @@ "#!wget -c https://huggingface.co/TencentARC/T2I-Adapter/resolve/main/models/t2iadapter_color_sd14v1.pth -P ./models/t2i_adapter/\n", "#!wget -c https://huggingface.co/TencentARC/T2I-Adapter/resolve/main/models/t2iadapter_canny_sd14v1.pth -P ./models/t2i_adapter/\n", "\n", + "# T2I Styles Model\n", + "#!wget -c https://huggingface.co/TencentARC/T2I-Adapter/resolve/main/models/t2iadapter_style_sd14v1.pth -P ./models/style_models/\n", + "\n", + "# CLIPVision model (needed for styles model)\n", + "#!wget -c https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/pytorch_model.bin -O ./models/clip_vision/clip_vit14.bin\n", "\n", "\n", "# ControlNet\n",