Merge branch 'comfyanonymous:master' into master

This commit is contained in:
ssitu 2023-06-19 00:04:05 -04:00 committed by GitHub
commit 7dbe4fce51
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 29 additions and 11 deletions

View File

@ -21,7 +21,7 @@ class ClipVisionModel():
size=224)
def load_sd(self, sd):
self.model.load_state_dict(sd, strict=False)
return self.model.load_state_dict(sd, strict=False)
def encode_image(self, image):
img = torch.clip((255. * image[0]), 0, 255).round().int()
@ -59,7 +59,13 @@ def load_clipvision_from_sd(sd):
else:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json")
clip = ClipVisionModel(json_config)
clip.load_sd(sd)
m, u = clip.load_sd(sd)
u = set(u)
keys = list(sd.keys())
for k in keys:
if k not in u:
t = sd.pop(k)
del t
return clip
def load(ckpt_path):

View File

@ -260,7 +260,8 @@ class Gligen(nn.Module):
return r
return func_lowvram
else:
def func(key, x):
def func(x, extra_options):
key = extra_options["transformer_index"]
module = self.module_list[key]
return module(x, objs)
return func

View File

@ -134,7 +134,7 @@ class CompVisDenoiser(DiscreteEpsDDPMDenoiser):
"""A wrapper for CompVis diffusion models."""
def __init__(self, model, quantize=False, device='cpu'):
super().__init__(model, model.alphas_cumprod.float(), quantize=quantize)
super().__init__(model, model.alphas_cumprod, quantize=quantize)
def get_eps(self, *args, **kwargs):
return self.inner_model.apply_model(*args, **kwargs)
@ -173,7 +173,7 @@ class CompVisVDenoiser(DiscreteVDDPMDenoiser):
"""A wrapper for CompVis diffusion models that output v."""
def __init__(self, model, quantize=False, device='cpu'):
super().__init__(model, model.alphas_cumprod.float(), quantize=quantize)
super().__init__(model, model.alphas_cumprod, quantize=quantize)
def get_v(self, x, t, cond, **kwargs):
return self.inner_model.apply_model(x, t, cond)

View File

@ -524,9 +524,11 @@ class BasicTransformerBlock(nn.Module):
return checkpoint(self._forward, (x, context, transformer_options), self.parameters(), self.checkpoint)
def _forward(self, x, context=None, transformer_options={}):
current_index = None
extra_options = {}
if "current_index" in transformer_options:
current_index = transformer_options["current_index"]
extra_options["transformer_index"] = transformer_options["current_index"]
if "block_index" in transformer_options:
extra_options["block_index"] = transformer_options["block_index"]
if "patches" in transformer_options:
transformer_patches = transformer_options["patches"]
else:
@ -545,7 +547,7 @@ class BasicTransformerBlock(nn.Module):
context_attn1 = n
value_attn1 = context_attn1
for p in patch:
n, context_attn1, value_attn1 = p(current_index, n, context_attn1, value_attn1)
n, context_attn1, value_attn1 = p(n, context_attn1, value_attn1, extra_options)
if "tomesd" in transformer_options:
m, u = tomesd.get_functions(x, transformer_options["tomesd"]["ratio"], transformer_options["original_shape"])
@ -557,7 +559,7 @@ class BasicTransformerBlock(nn.Module):
if "middle_patch" in transformer_patches:
patch = transformer_patches["middle_patch"]
for p in patch:
x = p(current_index, x)
x = p(x, extra_options)
n = self.norm2(x)
@ -567,10 +569,15 @@ class BasicTransformerBlock(nn.Module):
patch = transformer_patches["attn2_patch"]
value_attn2 = context_attn2
for p in patch:
n, context_attn2, value_attn2 = p(current_index, n, context_attn2, value_attn2)
n, context_attn2, value_attn2 = p(n, context_attn2, value_attn2, extra_options)
n = self.attn2(n, context=context_attn2, value=value_attn2)
if "attn2_output_patch" in transformer_patches:
patch = transformer_patches["attn2_output_patch"]
for p in patch:
n = p(n, extra_options)
x += n
x = self.ff(self.norm3(x)) + x
return x
@ -631,6 +638,7 @@ class SpatialTransformer(nn.Module):
if self.use_linear:
x = self.proj_in(x)
for i, block in enumerate(self.transformer_blocks):
transformer_options["block_index"] = i
x = block(x, context=context[i], transformer_options=transformer_options)
if self.use_linear:
x = self.proj_out(x)

View File

@ -331,6 +331,9 @@ class ModelPatcher:
def set_model_attn2_patch(self, patch):
self.set_model_patch(patch, "attn2_patch")
def set_model_attn2_output_patch(self, patch):
self.set_model_patch(patch, "attn2_output_patch")
def model_patches_to(self, device):
to = self.model_options["transformer_options"]
if "patches" in to:

View File

@ -68,7 +68,7 @@ def load_hypernetwork_patch(path, strength):
def __init__(self, hypernet, strength):
self.hypernet = hypernet
self.strength = strength
def __call__(self, current_index, q, k, v):
def __call__(self, q, k, v, extra_options):
dim = k.shape[-1]
if dim in self.hypernet:
hn = self.hypernet[dim]