diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py
index 7a59ef6e2..2036175b8 100644
--- a/comfy/clip_vision.py
+++ b/comfy/clip_vision.py
@@ -21,7 +21,7 @@ class ClipVisionModel():
                                              size=224)

     def load_sd(self, sd):
-        self.model.load_state_dict(sd, strict=False)
+        return self.model.load_state_dict(sd, strict=False)

     def encode_image(self, image):
         img = torch.clip((255. * image[0]), 0, 255).round().int()
@@ -59,7 +59,13 @@ def load_clipvision_from_sd(sd):
     else:
         json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json")
     clip = ClipVisionModel(json_config)
-    clip.load_sd(sd)
+    m, u = clip.load_sd(sd)
+    u = set(u)
+    keys = list(sd.keys())
+    for k in keys:
+        if k not in u:
+            t = sd.pop(k)
+            del t
     return clip

 def load(ckpt_path):
diff --git a/comfy/gligen.py b/comfy/gligen.py
index 8c7cb432e..fe3895c48 100644
--- a/comfy/gligen.py
+++ b/comfy/gligen.py
@@ -260,7 +260,8 @@ class Gligen(nn.Module):
                 return r
             return func_lowvram
         else:
-            def func(key, x):
+            def func(x, extra_options):
+                key = extra_options["transformer_index"]
                 module = self.module_list[key]
                 return module(x, objs)
             return func
diff --git a/comfy/k_diffusion/external.py b/comfy/k_diffusion/external.py
index d15eb5951..49ce5ae39 100644
--- a/comfy/k_diffusion/external.py
+++ b/comfy/k_diffusion/external.py
@@ -134,7 +134,7 @@ class CompVisDenoiser(DiscreteEpsDDPMDenoiser):
     """A wrapper for CompVis diffusion models."""

     def __init__(self, model, quantize=False, device='cpu'):
-        super().__init__(model, model.alphas_cumprod.float(), quantize=quantize)
+        super().__init__(model, model.alphas_cumprod, quantize=quantize)

     def get_eps(self, *args, **kwargs):
         return self.inner_model.apply_model(*args, **kwargs)
@@ -173,7 +173,7 @@ class CompVisVDenoiser(DiscreteVDDPMDenoiser):
     """A wrapper for CompVis diffusion models that output v."""

     def __init__(self, model, quantize=False, device='cpu'):
-        super().__init__(model, model.alphas_cumprod.float(), quantize=quantize)
+        super().__init__(model, model.alphas_cumprod, quantize=quantize)

     def get_v(self, x, t, cond, **kwargs):
         return self.inner_model.apply_model(x, t, cond)
diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index 62707dfd2..a0d695693 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -524,9 +524,11 @@ class BasicTransformerBlock(nn.Module):
         return checkpoint(self._forward, (x, context, transformer_options), self.parameters(), self.checkpoint)

     def _forward(self, x, context=None, transformer_options={}):
-        current_index = None
+        extra_options = {}
         if "current_index" in transformer_options:
-            current_index = transformer_options["current_index"]
+            extra_options["transformer_index"] = transformer_options["current_index"]
+        if "block_index" in transformer_options:
+            extra_options["block_index"] = transformer_options["block_index"]
         if "patches" in transformer_options:
             transformer_patches = transformer_options["patches"]
         else:
@@ -545,7 +547,7 @@ class BasicTransformerBlock(nn.Module):
                 context_attn1 = n
             value_attn1 = context_attn1
             for p in patch:
-                n, context_attn1, value_attn1 = p(current_index, n, context_attn1, value_attn1)
+                n, context_attn1, value_attn1 = p(n, context_attn1, value_attn1, extra_options)

         if "tomesd" in transformer_options:
             m, u = tomesd.get_functions(x, transformer_options["tomesd"]["ratio"], transformer_options["original_shape"])
@@ -557,7 +559,7 @@ class BasicTransformerBlock(nn.Module):
         if "middle_patch" in transformer_patches:
             patch = transformer_patches["middle_patch"]
             for p in patch:
-                x = p(current_index, x)
+                x = p(x, extra_options)

         n = self.norm2(x)

@@ -567,10 +569,15 @@ class BasicTransformerBlock(nn.Module):
             patch = transformer_patches["attn2_patch"]
             value_attn2 = context_attn2
             for p in patch:
-                n, context_attn2, value_attn2 = p(current_index, n, context_attn2, value_attn2)
+                n, context_attn2, value_attn2 = p(n, context_attn2, value_attn2, extra_options)

         n = self.attn2(n, context=context_attn2, value=value_attn2)

+        if "attn2_output_patch" in transformer_patches:
+            patch = transformer_patches["attn2_output_patch"]
+            for p in patch:
+                n = p(n, extra_options)
+
         x += n
         x = self.ff(self.norm3(x)) + x
         return x
@@ -631,6 +638,7 @@ class SpatialTransformer(nn.Module):
         if self.use_linear:
             x = self.proj_in(x)
         for i, block in enumerate(self.transformer_blocks):
+            transformer_options["block_index"] = i
             x = block(x, context=context[i], transformer_options=transformer_options)
         if self.use_linear:
             x = self.proj_out(x)
diff --git a/comfy/sd.py b/comfy/sd.py
index 7f04ae3a7..e6cda5131 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -331,6 +331,9 @@ class ModelPatcher:
     def set_model_attn2_patch(self, patch):
         self.set_model_patch(patch, "attn2_patch")

+    def set_model_attn2_output_patch(self, patch):
+        self.set_model_patch(patch, "attn2_output_patch")
+
     def model_patches_to(self, device):
         to = self.model_options["transformer_options"]
         if "patches" in to:
diff --git a/comfy_extras/nodes_hypernetwork.py b/comfy_extras/nodes_hypernetwork.py
index c19b5e4c7..d16c49aeb 100644
--- a/comfy_extras/nodes_hypernetwork.py
+++ b/comfy_extras/nodes_hypernetwork.py
@@ -68,7 +68,7 @@ def load_hypernetwork_patch(path, strength):
         def __init__(self, hypernet, strength):
            self.hypernet = hypernet
            self.strength = strength
-        def __call__(self, current_index, q, k, v):
+        def __call__(self, q, k, v, extra_options):
            dim = k.shape[-1]
            if dim in self.hypernet:
                hn = self.hypernet[dim]