From 51d54775794a9810bb5d9ac3f8fcf8d2b68581aa Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Tue, 18 Jul 2023 00:25:53 -0400
Subject: [PATCH 1/2] Add key to indicate checkpoint is v_prediction when saving.

---
 comfy/model_base.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/comfy/model_base.py b/comfy/model_base.py
index c73f2aa07..2d2d35814 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -99,6 +99,10 @@ class BaseModel(torch.nn.Module):
         if self.get_dtype() == torch.float16:
             clip_state_dict = utils.convert_sd_to(clip_state_dict, torch.float16)
             vae_state_dict = utils.convert_sd_to(vae_state_dict, torch.float16)
+
+        if self.model_type == ModelType.V_PREDICTION:
+            unet_state_dict["v_pred"] = torch.tensor([])
+
         return {**unet_state_dict, **vae_state_dict, **clip_state_dict}

From 9ba440995a41f3da266012e014d1ad55ad91a032 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Tue, 18 Jul 2023 21:36:35 -0400
Subject: [PATCH 2/2] It's actually possible to torch.compile the unet now.

---
 comfy/ldm/modules/attention.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index 2284bcbdb..1379b7704 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -36,7 +36,7 @@ def uniq(arr):
 def default(val, d):
     if exists(val):
         return val
-    return d() if isfunction(d) else d
+    return d
 
 def max_neg_value(t):
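
A minimal sketch of how downstream code could use the two changes above: patch 1 only writes an empty marker tensor under the key "v_pred" into saved checkpoints, and patch 2 drops the isfunction() dispatch in default() that apparently blocked tracing the unet. The checkpoint path, the safetensors loading call, and the model.diffusion_model attribute name are assumptions for illustration; neither function below is part of this patch series.

import torch
import safetensors.torch

def prediction_type_from_checkpoint(path):
    # Patch 1 stores an empty tensor under the key "v_pred" when the model
    # uses v-prediction; its mere presence is the signal.
    state_dict = safetensors.torch.load_file(path)  # assumes a .safetensors checkpoint
    return "v_prediction" if "v_pred" in state_dict else "epsilon"

def compile_unet(model):
    # Patch 2 makes default() return its fallback directly instead of checking
    # isfunction() and possibly calling it, per the commit message this is what
    # lets torch.compile handle the unet.
    # `model.diffusion_model` is assumed to be the unet nn.Module.
    model.diffusion_model = torch.compile(model.diffusion_model)
    return model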