mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-11 06:40:48 +08:00
Merge branch 'comfyanonymous:master' into refactor/onprompt
This commit is contained in:
commit
548392e4c2
@ -21,7 +21,7 @@ class ClipVisionModel():
|
|||||||
size=224)
|
size=224)
|
||||||
|
|
||||||
def load_sd(self, sd):
|
def load_sd(self, sd):
|
||||||
self.model.load_state_dict(sd, strict=False)
|
return self.model.load_state_dict(sd, strict=False)
|
||||||
|
|
||||||
def encode_image(self, image):
|
def encode_image(self, image):
|
||||||
img = torch.clip((255. * image[0]), 0, 255).round().int()
|
img = torch.clip((255. * image[0]), 0, 255).round().int()
|
||||||
@ -59,7 +59,13 @@ def load_clipvision_from_sd(sd):
|
|||||||
else:
|
else:
|
||||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json")
|
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json")
|
||||||
clip = ClipVisionModel(json_config)
|
clip = ClipVisionModel(json_config)
|
||||||
clip.load_sd(sd)
|
m, u = clip.load_sd(sd)
|
||||||
|
u = set(u)
|
||||||
|
keys = list(sd.keys())
|
||||||
|
for k in keys:
|
||||||
|
if k not in u:
|
||||||
|
t = sd.pop(k)
|
||||||
|
del t
|
||||||
return clip
|
return clip
|
||||||
|
|
||||||
def load(ckpt_path):
|
def load(ckpt_path):
|
||||||
|
|||||||
@ -260,7 +260,8 @@ class Gligen(nn.Module):
|
|||||||
return r
|
return r
|
||||||
return func_lowvram
|
return func_lowvram
|
||||||
else:
|
else:
|
||||||
def func(key, x):
|
def func(x, extra_options):
|
||||||
|
key = extra_options["transformer_index"]
|
||||||
module = self.module_list[key]
|
module = self.module_list[key]
|
||||||
return module(x, objs)
|
return module(x, objs)
|
||||||
return func
|
return func
|
||||||
|
|||||||
@ -134,7 +134,7 @@ class CompVisDenoiser(DiscreteEpsDDPMDenoiser):
|
|||||||
"""A wrapper for CompVis diffusion models."""
|
"""A wrapper for CompVis diffusion models."""
|
||||||
|
|
||||||
def __init__(self, model, quantize=False, device='cpu'):
|
def __init__(self, model, quantize=False, device='cpu'):
|
||||||
super().__init__(model, model.alphas_cumprod.float(), quantize=quantize)
|
super().__init__(model, model.alphas_cumprod, quantize=quantize)
|
||||||
|
|
||||||
def get_eps(self, *args, **kwargs):
|
def get_eps(self, *args, **kwargs):
|
||||||
return self.inner_model.apply_model(*args, **kwargs)
|
return self.inner_model.apply_model(*args, **kwargs)
|
||||||
@ -173,7 +173,7 @@ class CompVisVDenoiser(DiscreteVDDPMDenoiser):
|
|||||||
"""A wrapper for CompVis diffusion models that output v."""
|
"""A wrapper for CompVis diffusion models that output v."""
|
||||||
|
|
||||||
def __init__(self, model, quantize=False, device='cpu'):
|
def __init__(self, model, quantize=False, device='cpu'):
|
||||||
super().__init__(model, model.alphas_cumprod.float(), quantize=quantize)
|
super().__init__(model, model.alphas_cumprod, quantize=quantize)
|
||||||
|
|
||||||
def get_v(self, x, t, cond, **kwargs):
|
def get_v(self, x, t, cond, **kwargs):
|
||||||
return self.inner_model.apply_model(x, t, cond)
|
return self.inner_model.apply_model(x, t, cond)
|
||||||
|
|||||||
@ -524,9 +524,11 @@ class BasicTransformerBlock(nn.Module):
|
|||||||
return checkpoint(self._forward, (x, context, transformer_options), self.parameters(), self.checkpoint)
|
return checkpoint(self._forward, (x, context, transformer_options), self.parameters(), self.checkpoint)
|
||||||
|
|
||||||
def _forward(self, x, context=None, transformer_options={}):
|
def _forward(self, x, context=None, transformer_options={}):
|
||||||
current_index = None
|
extra_options = {}
|
||||||
if "current_index" in transformer_options:
|
if "current_index" in transformer_options:
|
||||||
current_index = transformer_options["current_index"]
|
extra_options["transformer_index"] = transformer_options["current_index"]
|
||||||
|
if "block_index" in transformer_options:
|
||||||
|
extra_options["block_index"] = transformer_options["block_index"]
|
||||||
if "patches" in transformer_options:
|
if "patches" in transformer_options:
|
||||||
transformer_patches = transformer_options["patches"]
|
transformer_patches = transformer_options["patches"]
|
||||||
else:
|
else:
|
||||||
@ -545,7 +547,7 @@ class BasicTransformerBlock(nn.Module):
|
|||||||
context_attn1 = n
|
context_attn1 = n
|
||||||
value_attn1 = context_attn1
|
value_attn1 = context_attn1
|
||||||
for p in patch:
|
for p in patch:
|
||||||
n, context_attn1, value_attn1 = p(current_index, n, context_attn1, value_attn1)
|
n, context_attn1, value_attn1 = p(n, context_attn1, value_attn1, extra_options)
|
||||||
|
|
||||||
if "tomesd" in transformer_options:
|
if "tomesd" in transformer_options:
|
||||||
m, u = tomesd.get_functions(x, transformer_options["tomesd"]["ratio"], transformer_options["original_shape"])
|
m, u = tomesd.get_functions(x, transformer_options["tomesd"]["ratio"], transformer_options["original_shape"])
|
||||||
@ -557,7 +559,7 @@ class BasicTransformerBlock(nn.Module):
|
|||||||
if "middle_patch" in transformer_patches:
|
if "middle_patch" in transformer_patches:
|
||||||
patch = transformer_patches["middle_patch"]
|
patch = transformer_patches["middle_patch"]
|
||||||
for p in patch:
|
for p in patch:
|
||||||
x = p(current_index, x)
|
x = p(x, extra_options)
|
||||||
|
|
||||||
n = self.norm2(x)
|
n = self.norm2(x)
|
||||||
|
|
||||||
@ -567,10 +569,15 @@ class BasicTransformerBlock(nn.Module):
|
|||||||
patch = transformer_patches["attn2_patch"]
|
patch = transformer_patches["attn2_patch"]
|
||||||
value_attn2 = context_attn2
|
value_attn2 = context_attn2
|
||||||
for p in patch:
|
for p in patch:
|
||||||
n, context_attn2, value_attn2 = p(current_index, n, context_attn2, value_attn2)
|
n, context_attn2, value_attn2 = p(n, context_attn2, value_attn2, extra_options)
|
||||||
|
|
||||||
n = self.attn2(n, context=context_attn2, value=value_attn2)
|
n = self.attn2(n, context=context_attn2, value=value_attn2)
|
||||||
|
|
||||||
|
if "attn2_output_patch" in transformer_patches:
|
||||||
|
patch = transformer_patches["attn2_output_patch"]
|
||||||
|
for p in patch:
|
||||||
|
n = p(n, extra_options)
|
||||||
|
|
||||||
x += n
|
x += n
|
||||||
x = self.ff(self.norm3(x)) + x
|
x = self.ff(self.norm3(x)) + x
|
||||||
return x
|
return x
|
||||||
@ -631,6 +638,7 @@ class SpatialTransformer(nn.Module):
|
|||||||
if self.use_linear:
|
if self.use_linear:
|
||||||
x = self.proj_in(x)
|
x = self.proj_in(x)
|
||||||
for i, block in enumerate(self.transformer_blocks):
|
for i, block in enumerate(self.transformer_blocks):
|
||||||
|
transformer_options["block_index"] = i
|
||||||
x = block(x, context=context[i], transformer_options=transformer_options)
|
x = block(x, context=context[i], transformer_options=transformer_options)
|
||||||
if self.use_linear:
|
if self.use_linear:
|
||||||
x = self.proj_out(x)
|
x = self.proj_out(x)
|
||||||
|
|||||||
@ -331,6 +331,9 @@ class ModelPatcher:
|
|||||||
def set_model_attn2_patch(self, patch):
|
def set_model_attn2_patch(self, patch):
|
||||||
self.set_model_patch(patch, "attn2_patch")
|
self.set_model_patch(patch, "attn2_patch")
|
||||||
|
|
||||||
|
def set_model_attn2_output_patch(self, patch):
|
||||||
|
self.set_model_patch(patch, "attn2_output_patch")
|
||||||
|
|
||||||
def model_patches_to(self, device):
|
def model_patches_to(self, device):
|
||||||
to = self.model_options["transformer_options"]
|
to = self.model_options["transformer_options"]
|
||||||
if "patches" in to:
|
if "patches" in to:
|
||||||
|
|||||||
@ -68,7 +68,7 @@ def load_hypernetwork_patch(path, strength):
|
|||||||
def __init__(self, hypernet, strength):
|
def __init__(self, hypernet, strength):
|
||||||
self.hypernet = hypernet
|
self.hypernet = hypernet
|
||||||
self.strength = strength
|
self.strength = strength
|
||||||
def __call__(self, current_index, q, k, v):
|
def __call__(self, q, k, v, extra_options):
|
||||||
dim = k.shape[-1]
|
dim = k.shape[-1]
|
||||||
if dim in self.hypernet:
|
if dim in self.hypernet:
|
||||||
hn = self.hypernet[dim]
|
hn = self.hypernet[dim]
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user