mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-16 22:58:19 +08:00
Merge branch 'master' into feature/adapt-docker-resource-fetch
This commit is contained in:
commit
afa96152b8
@ -223,12 +223,19 @@ class DoubleStreamBlock(nn.Module):
|
|||||||
del txt_k, img_k
|
del txt_k, img_k
|
||||||
v = torch.cat((txt_v, img_v), dim=2)
|
v = torch.cat((txt_v, img_v), dim=2)
|
||||||
del txt_v, img_v
|
del txt_v, img_v
|
||||||
|
|
||||||
|
extra_options["img_slice"] = [txt.shape[1], q.shape[2]]
|
||||||
|
if "attn1_patch" in transformer_patches:
|
||||||
|
patch = transformer_patches["attn1_patch"]
|
||||||
|
for p in patch:
|
||||||
|
out = p(q, k, v, pe=pe, attn_mask=attn_mask, extra_options=extra_options)
|
||||||
|
q, k, v, pe, attn_mask = out.get("q", q), out.get("k", k), out.get("v", v), out.get("pe", pe), out.get("attn_mask", attn_mask)
|
||||||
|
|
||||||
# run actual attention
|
# run actual attention
|
||||||
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
|
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
|
||||||
del q, k, v
|
del q, k, v
|
||||||
|
|
||||||
if "attn1_output_patch" in transformer_patches:
|
if "attn1_output_patch" in transformer_patches:
|
||||||
extra_options["img_slice"] = [txt.shape[1], attn.shape[1]]
|
|
||||||
patch = transformer_patches["attn1_output_patch"]
|
patch = transformer_patches["attn1_output_patch"]
|
||||||
for p in patch:
|
for p in patch:
|
||||||
attn = p(attn, extra_options)
|
attn = p(attn, extra_options)
|
||||||
@ -321,6 +328,12 @@ class SingleStreamBlock(nn.Module):
|
|||||||
del qkv
|
del qkv
|
||||||
q, k = self.norm(q, k, v)
|
q, k = self.norm(q, k, v)
|
||||||
|
|
||||||
|
if "attn1_patch" in transformer_patches:
|
||||||
|
patch = transformer_patches["attn1_patch"]
|
||||||
|
for p in patch:
|
||||||
|
out = p(q, k, v, pe=pe, attn_mask=attn_mask, extra_options=extra_options)
|
||||||
|
q, k, v, pe, attn_mask = out.get("q", q), out.get("k", k), out.get("v", v), out.get("pe", pe), out.get("attn_mask", attn_mask)
|
||||||
|
|
||||||
# compute attention
|
# compute attention
|
||||||
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
|
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
|
||||||
del q, k, v
|
del q, k, v
|
||||||
|
|||||||
@ -31,6 +31,8 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
|
|||||||
|
|
||||||
def _apply_rope1(x: Tensor, freqs_cis: Tensor):
|
def _apply_rope1(x: Tensor, freqs_cis: Tensor):
|
||||||
x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
|
x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
|
||||||
|
if x_.shape[2] != 1 and freqs_cis.shape[2] != 1 and x_.shape[2] != freqs_cis.shape[2]:
|
||||||
|
freqs_cis = freqs_cis[:, :, :x_.shape[2]]
|
||||||
|
|
||||||
x_out = freqs_cis[..., 0] * x_[..., 0]
|
x_out = freqs_cis[..., 0] * x_[..., 0]
|
||||||
x_out.addcmul_(freqs_cis[..., 1], x_[..., 1])
|
x_out.addcmul_(freqs_cis[..., 1], x_[..., 1])
|
||||||
|
|||||||
@ -170,7 +170,7 @@ class Flux(nn.Module):
|
|||||||
|
|
||||||
if "post_input" in patches:
|
if "post_input" in patches:
|
||||||
for p in patches["post_input"]:
|
for p in patches["post_input"]:
|
||||||
out = p({"img": img, "txt": txt, "img_ids": img_ids, "txt_ids": txt_ids})
|
out = p({"img": img, "txt": txt, "img_ids": img_ids, "txt_ids": txt_ids, "transformer_options": transformer_options})
|
||||||
img = out["img"]
|
img = out["img"]
|
||||||
txt = out["txt"]
|
txt = out["txt"]
|
||||||
img_ids = out["img_ids"]
|
img_ids = out["img_ids"]
|
||||||
|
|||||||
@ -372,7 +372,8 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
|
|||||||
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
|
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
|
||||||
del s2
|
del s2
|
||||||
break
|
break
|
||||||
except model_management.OOM_EXCEPTION as e:
|
except Exception as e:
|
||||||
|
model_management.raise_non_oom(e)
|
||||||
if first_op_done == False:
|
if first_op_done == False:
|
||||||
model_management.soft_empty_cache(True)
|
model_management.soft_empty_cache(True)
|
||||||
if cleared_cache == False:
|
if cleared_cache == False:
|
||||||
|
|||||||
@ -258,7 +258,8 @@ def slice_attention(q, k, v):
|
|||||||
r1[:, :, i:end] = torch.bmm(v, s2)
|
r1[:, :, i:end] = torch.bmm(v, s2)
|
||||||
del s2
|
del s2
|
||||||
break
|
break
|
||||||
except model_management.OOM_EXCEPTION as e:
|
except Exception as e:
|
||||||
|
model_management.raise_non_oom(e)
|
||||||
model_management.soft_empty_cache(True)
|
model_management.soft_empty_cache(True)
|
||||||
steps *= 2
|
steps *= 2
|
||||||
if steps > 128:
|
if steps > 128:
|
||||||
@ -314,7 +315,8 @@ def pytorch_attention(q, k, v):
|
|||||||
try:
|
try:
|
||||||
out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
|
out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
|
||||||
out = out.transpose(2, 3).reshape(orig_shape)
|
out = out.transpose(2, 3).reshape(orig_shape)
|
||||||
except model_management.OOM_EXCEPTION:
|
except Exception as e:
|
||||||
|
model_management.raise_non_oom(e)
|
||||||
logging.warning("scaled_dot_product_attention OOMed: switched to slice attention")
|
logging.warning("scaled_dot_product_attention OOMed: switched to slice attention")
|
||||||
oom_fallback = True
|
oom_fallback = True
|
||||||
if oom_fallback:
|
if oom_fallback:
|
||||||
|
|||||||
@ -169,7 +169,8 @@ def _get_attention_scores_no_kv_chunking(
|
|||||||
try:
|
try:
|
||||||
attn_probs = attn_scores.softmax(dim=-1)
|
attn_probs = attn_scores.softmax(dim=-1)
|
||||||
del attn_scores
|
del attn_scores
|
||||||
except model_management.OOM_EXCEPTION:
|
except Exception as e:
|
||||||
|
model_management.raise_non_oom(e)
|
||||||
logging.warning("ran out of memory while running softmax in _get_attention_scores_no_kv_chunking, trying slower in place softmax instead")
|
logging.warning("ran out of memory while running softmax in _get_attention_scores_no_kv_chunking, trying slower in place softmax instead")
|
||||||
attn_scores -= attn_scores.max(dim=-1, keepdim=True).values # noqa: F821 attn_scores is not defined
|
attn_scores -= attn_scores.max(dim=-1, keepdim=True).values # noqa: F821 attn_scores is not defined
|
||||||
torch.exp(attn_scores, out=attn_scores)
|
torch.exp(attn_scores, out=attn_scores)
|
||||||
|
|||||||
@ -99,6 +99,9 @@ def model_lora_keys_clip(model, key_map={}):
|
|||||||
for k in sdk:
|
for k in sdk:
|
||||||
if k.endswith(".weight"):
|
if k.endswith(".weight"):
|
||||||
key_map["text_encoders.{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names
|
key_map["text_encoders.{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names
|
||||||
|
tp = k.find(".transformer.") #also map without wrapper prefix for composite text encoder models
|
||||||
|
if tp > 0 and not k.startswith("clip_"):
|
||||||
|
key_map["text_encoders.{}".format(k[tp + 1:-len(".weight")])] = k
|
||||||
|
|
||||||
text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}"
|
text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}"
|
||||||
clip_l_present = False
|
clip_l_present = False
|
||||||
|
|||||||
@ -1,4 +1,5 @@
|
|||||||
import json
|
import json
|
||||||
|
import comfy.memory_management
|
||||||
import comfy.supported_models
|
import comfy.supported_models
|
||||||
import comfy.supported_models_base
|
import comfy.supported_models_base
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
@ -1118,8 +1119,13 @@ def convert_diffusers_mmdit(state_dict, output_prefix=""):
|
|||||||
new[:old_weight.shape[0]] = old_weight
|
new[:old_weight.shape[0]] = old_weight
|
||||||
old_weight = new
|
old_weight = new
|
||||||
|
|
||||||
|
if old_weight is out_sd.get(t[0], None) and comfy.memory_management.aimdo_enabled:
|
||||||
|
old_weight = old_weight.clone()
|
||||||
|
|
||||||
w = old_weight.narrow(offset[0], offset[1], offset[2])
|
w = old_weight.narrow(offset[0], offset[1], offset[2])
|
||||||
else:
|
else:
|
||||||
|
if comfy.memory_management.aimdo_enabled:
|
||||||
|
weight = weight.clone()
|
||||||
old_weight = weight
|
old_weight = weight
|
||||||
w = weight
|
w = weight
|
||||||
w[:] = fun(weight)
|
w[:] = fun(weight)
|
||||||
|
|||||||
@ -292,6 +292,18 @@ try:
|
|||||||
except:
|
except:
|
||||||
OOM_EXCEPTION = Exception
|
OOM_EXCEPTION = Exception
|
||||||
|
|
||||||
|
def is_oom(e):
|
||||||
|
if isinstance(e, OOM_EXCEPTION):
|
||||||
|
return True
|
||||||
|
if isinstance(e, torch.AcceleratorError) and getattr(e, 'error_code', None) == 2:
|
||||||
|
discard_cuda_async_error()
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def raise_non_oom(e):
|
||||||
|
if not is_oom(e):
|
||||||
|
raise e
|
||||||
|
|
||||||
XFORMERS_VERSION = ""
|
XFORMERS_VERSION = ""
|
||||||
XFORMERS_ENABLED_VAE = True
|
XFORMERS_ENABLED_VAE = True
|
||||||
if args.disable_xformers:
|
if args.disable_xformers:
|
||||||
|
|||||||
@ -954,7 +954,8 @@ class VAE:
|
|||||||
if pixel_samples is None:
|
if pixel_samples is None:
|
||||||
pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
|
pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
|
||||||
pixel_samples[x:x+batch_number] = out
|
pixel_samples[x:x+batch_number] = out
|
||||||
except model_management.OOM_EXCEPTION:
|
except Exception as e:
|
||||||
|
model_management.raise_non_oom(e)
|
||||||
logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
|
logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
|
||||||
#NOTE: We don't know what tensors were allocated to stack variables at the time of the
|
#NOTE: We don't know what tensors were allocated to stack variables at the time of the
|
||||||
#exception and the exception itself refs them all until we get out of this except block.
|
#exception and the exception itself refs them all until we get out of this except block.
|
||||||
@ -1029,7 +1030,8 @@ class VAE:
|
|||||||
samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
|
samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
|
||||||
samples[x:x + batch_number] = out
|
samples[x:x + batch_number] = out
|
||||||
|
|
||||||
except model_management.OOM_EXCEPTION:
|
except Exception as e:
|
||||||
|
model_management.raise_non_oom(e)
|
||||||
logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
|
logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
|
||||||
#NOTE: We don't know what tensors were allocated to stack variables at the time of the
|
#NOTE: We don't know what tensors were allocated to stack variables at the time of the
|
||||||
#exception and the exception itself refs them all until we get out of this except block.
|
#exception and the exception itself refs them all until we get out of this except block.
|
||||||
|
|||||||
@ -86,7 +86,8 @@ class ImageUpscaleWithModel(io.ComfyNode):
|
|||||||
pbar = comfy.utils.ProgressBar(steps)
|
pbar = comfy.utils.ProgressBar(steps)
|
||||||
s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar)
|
s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar)
|
||||||
oom = False
|
oom = False
|
||||||
except model_management.OOM_EXCEPTION as e:
|
except Exception as e:
|
||||||
|
model_management.raise_non_oom(e)
|
||||||
tile //= 2
|
tile //= 2
|
||||||
if tile < 128:
|
if tile < 128:
|
||||||
raise e
|
raise e
|
||||||
|
|||||||
@ -612,7 +612,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
|||||||
logging.error(traceback.format_exc())
|
logging.error(traceback.format_exc())
|
||||||
tips = ""
|
tips = ""
|
||||||
|
|
||||||
if isinstance(ex, comfy.model_management.OOM_EXCEPTION):
|
if comfy.model_management.is_oom(ex):
|
||||||
tips = "This error means you ran out of memory on your GPU.\n\nTIPS: If the workflow worked before you might have accidentally set the batch_size to a large number."
|
tips = "This error means you ran out of memory on your GPU.\n\nTIPS: If the workflow worked before you might have accidentally set the batch_size to a large number."
|
||||||
logging.info("Memory summary: {}".format(comfy.model_management.debug_memory_summary()))
|
logging.info("Memory summary: {}".format(comfy.model_management.debug_memory_summary()))
|
||||||
logging.error("Got an OOM, unloading all loaded models.")
|
logging.error("Got an OOM, unloading all loaded models.")
|
||||||
|
|||||||
13
main.py
13
main.py
@ -3,6 +3,7 @@ comfy.options.enable_args_parsing()
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
import importlib.util
|
import importlib.util
|
||||||
|
import shutil
|
||||||
import importlib.metadata
|
import importlib.metadata
|
||||||
import folder_paths
|
import folder_paths
|
||||||
import time
|
import time
|
||||||
@ -64,8 +65,15 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
|
|
||||||
def handle_comfyui_manager_unavailable():
|
def handle_comfyui_manager_unavailable():
|
||||||
if not args.windows_standalone_build:
|
manager_req_path = os.path.join(os.path.dirname(os.path.abspath(folder_paths.__file__)), "manager_requirements.txt")
|
||||||
logging.warning(f"\n\nYou appear to be running comfyui-manager from source, this is not recommended. Please install comfyui-manager using the following command:\ncommand:\n\t{sys.executable} -m pip install --pre comfyui_manager\n")
|
uv_available = shutil.which("uv") is not None
|
||||||
|
|
||||||
|
pip_cmd = f"{sys.executable} -m pip install -r {manager_req_path}"
|
||||||
|
msg = f"\n\nTo use the `--enable-manager` feature, the `comfyui-manager` package must be installed first.\ncommand:\n\t{pip_cmd}"
|
||||||
|
if uv_available:
|
||||||
|
msg += f"\nor using uv:\n\tuv pip install -r {manager_req_path}"
|
||||||
|
msg += "\n"
|
||||||
|
logging.warning(msg)
|
||||||
args.enable_manager = False
|
args.enable_manager = False
|
||||||
|
|
||||||
|
|
||||||
@ -173,7 +181,6 @@ execute_prestartup_script()
|
|||||||
|
|
||||||
# Main code
|
# Main code
|
||||||
import asyncio
|
import asyncio
|
||||||
import shutil
|
|
||||||
import threading
|
import threading
|
||||||
import gc
|
import gc
|
||||||
|
|
||||||
|
|||||||
@ -1 +1 @@
|
|||||||
comfyui_manager==4.1b1
|
comfyui_manager==4.1b2
|
||||||
@ -1,5 +1,5 @@
|
|||||||
comfyui-frontend-package==1.39.19
|
comfyui-frontend-package==1.39.19
|
||||||
comfyui-workflow-templates==0.9.11
|
comfyui-workflow-templates==0.9.18
|
||||||
comfyui-embedded-docs==0.4.3
|
comfyui-embedded-docs==0.4.3
|
||||||
torch
|
torch
|
||||||
torchsde
|
torchsde
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user