From 21f04fe632a5192d8f29f1bc0c852b24eb9dce2f Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Wed, 14 Jun 2023 19:46:08 -0400
Subject: [PATCH 1/2] Disable default weight values in unet conv2d for faster loading.

---
 comfy/ldm/modules/diffusionmodules/util.py | 6 +++---
 comfy/ops.py                               | 4 ++++
 2 files changed, 7 insertions(+), 3 deletions(-)

diff --git a/comfy/ldm/modules/diffusionmodules/util.py b/comfy/ldm/modules/diffusionmodules/util.py
index 82ea3f0a6..d6a4778e4 100644
--- a/comfy/ldm/modules/diffusionmodules/util.py
+++ b/comfy/ldm/modules/diffusionmodules/util.py
@@ -16,7 +16,7 @@ import numpy as np
 from einops import repeat
 
 from comfy.ldm.util import instantiate_from_config
-
+import comfy.ops
 
 def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
     if schedule == "linear":
@@ -233,7 +233,7 @@ def conv_nd(dims, *args, **kwargs):
     if dims == 1:
         return nn.Conv1d(*args, **kwargs)
     elif dims == 2:
-        return nn.Conv2d(*args, **kwargs)
+        return comfy.ops.Conv2d(*args, **kwargs)
     elif dims == 3:
         return nn.Conv3d(*args, **kwargs)
     raise ValueError(f"unsupported dimensions: {dims}")
@@ -243,7 +243,7 @@ def linear(*args, **kwargs):
     """
     Create a linear module.
     """
-    return nn.Linear(*args, **kwargs)
+    return comfy.ops.Linear(*args, **kwargs)
 
 
 def avg_pool_nd(dims, *args, **kwargs):
diff --git a/comfy/ops.py b/comfy/ops.py
index 0654dbcd9..c39b994ab 100644
--- a/comfy/ops.py
+++ b/comfy/ops.py
@@ -15,3 +15,7 @@ class Linear(torch.nn.Module):
 
     def forward(self, input):
         return torch.nn.functional.linear(input, self.weight, self.bias)
+
+class Conv2d(torch.nn.Conv2d):
+    def reset_parameters(self):
+        return None

From bb1f45d6e8628f8ee4970b294fa2771e5cdb34ba Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Wed, 14 Jun 2023 20:13:08 -0400
Subject: [PATCH 2/2] Properly disable weight initialization in clip models.
---
 comfy/clip_vision.py | 6 ++++--
 comfy/ops.py         | 11 +++++++++++
 comfy/sd1_clip.py    | 6 ++++--
 3 files changed, 19 insertions(+), 4 deletions(-)

diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py
index a95707e46..7a59ef6e2 100644
--- a/comfy/clip_vision.py
+++ b/comfy/clip_vision.py
@@ -2,12 +2,14 @@ from transformers import CLIPVisionModelWithProjection, CLIPVisionConfig, CLIPIm
 from .utils import load_torch_file, transformers_convert
 import os
 import torch
+import comfy.ops
 
 class ClipVisionModel():
     def __init__(self, json_config):
         config = CLIPVisionConfig.from_json_file(json_config)
-        with modeling_utils.no_init_weights():
-            self.model = CLIPVisionModelWithProjection(config)
+        with comfy.ops.use_comfy_ops():
+            with modeling_utils.no_init_weights():
+                self.model = CLIPVisionModelWithProjection(config)
         self.processor = CLIPImageProcessor(crop_size=224,
                                             do_center_crop=True,
                                             do_convert_rgb=True,
diff --git a/comfy/ops.py b/comfy/ops.py
index c39b994ab..2e72030bd 100644
--- a/comfy/ops.py
+++ b/comfy/ops.py
@@ -1,4 +1,5 @@
 import torch
+from contextlib import contextmanager
 
 class Linear(torch.nn.Module):
     def __init__(self, in_features: int, out_features: int, bias: bool = True,
@@ -19,3 +20,13 @@ class Linear(torch.nn.Module):
 class Conv2d(torch.nn.Conv2d):
     def reset_parameters(self):
         return None
+
+
+@contextmanager
+def use_comfy_ops(): # Kind of an ugly hack but I can't think of a better way
+    old_torch_nn_linear = torch.nn.Linear
+    torch.nn.Linear = Linear
+    try:
+        yield
+    finally:
+        torch.nn.Linear = old_torch_nn_linear
diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py
index 0df3d9d91..c2d4df092 100644
--- a/comfy/sd1_clip.py
+++ b/comfy/sd1_clip.py
@@ -1,6 +1,7 @@
 import os
 
 from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextConfig, modeling_utils
+import comfy.ops
 import torch
 import traceback
 import zipfile
@@ -38,8 +39,9 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
             if textmodel_json_config is None:
                 textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json")
             config = CLIPTextConfig.from_json_file(textmodel_json_config)
-            with modeling_utils.no_init_weights():
-                self.transformer = CLIPTextModel(config)
+            with comfy.ops.use_comfy_ops():
+                with modeling_utils.no_init_weights():
+                    self.transformer = CLIPTextModel(config)
 
         self.device = device
         self.max_length = max_length
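
Note: both patches exploit the same fact: torch.nn modules allocate their weight
tensors with torch.empty() and then spend time filling them inside
reset_parameters(), work that is wasted when a checkpoint is loaded over the
parameters immediately afterwards. Below is a minimal standalone sketch of the
pattern; the names NoInitConv2d and patched_conv2d are illustrative only and do
not exist in comfy/ops.py, which instead defines Conv2d and a use_comfy_ops()
context manager that swaps torch.nn.Linear.

    import time
    from contextlib import contextmanager

    import torch

    class NoInitConv2d(torch.nn.Conv2d):
        # torch.nn.Conv2d.__init__ allocates self.weight with torch.empty()
        # and then calls reset_parameters() to fill it with Kaiming-uniform
        # values; a no-op override skips only the fill, not the allocation.
        def reset_parameters(self):
            return None

    @contextmanager
    def patched_conv2d():
        # Temporarily rebind the torch.nn.Conv2d symbol so constructors we
        # don't control (e.g. a transformers model) build uninitialized
        # layers, mirroring what use_comfy_ops() does for torch.nn.Linear.
        old_conv2d = torch.nn.Conv2d
        torch.nn.Conv2d = NoInitConv2d
        try:
            yield
        finally:
            torch.nn.Conv2d = old_conv2d

    if __name__ == "__main__":
        start = time.time()
        with patched_conv2d():
            layers = [torch.nn.Conv2d(512, 512, 3) for _ in range(50)]
        print(f"without default init: {time.time() - start:.3f}s")

        start = time.time()
        layers = [torch.nn.Conv2d(512, 512, 3) for _ in range(50)]
        print(f"with default init:    {time.time() - start:.3f}s")

Weights built under the context manager contain whatever garbage torch.empty()
returned, so the trick is only safe when every skipped parameter is guaranteed
to be overwritten, e.g. by load_state_dict(), before the module is used.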