Merge branch 'comfyanonymous:master' into master

This commit is contained in:
patientx 2025-02-27 13:06:17 +03:00 committed by GitHub
commit c4fb9f2a63
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 47 additions and 5 deletions

View File

@ -421,7 +421,7 @@ class WanModel(torch.nn.Module):
e0 = self.time_projection(e).unflatten(1, (6, self.dim))
# context
context = self.text_embedding(torch.cat([context, context.new_zeros(context.size(0), self.text_len - context.size(1), context.size(2))], dim=1))
context = self.text_embedding(context)
if clip_fea is not None and self.img_emb is not None:
context_clip = self.img_emb(clip_fea) # bs x 257 x dim

View File

@ -96,6 +96,13 @@ try:
except:
npu_available = False
# Probe for the MLU backend once at import time. Any failure (extension not
# installed, driver/runtime error while querying devices) means "not available".
try:
    import torch_mlu  # noqa: F401
    # Query the device count first so a broken runtime fails here, inside the
    # try, rather than later at first use.
    _ = torch.mlu.device_count()
    mlu_available = torch.mlu.is_available()
except Exception:
    # Deliberately best-effort, but do not swallow KeyboardInterrupt/SystemExit
    # the way a bare ``except:`` would.
    mlu_available = False
if args.cpu:
cpu_state = CPUState.CPU
@ -113,6 +120,12 @@ def is_ascend_npu():
return True
return False
def is_mlu():
    """Report whether the MLU backend was detected.

    Returns:
        True when the module-level ``mlu_available`` flag is truthy,
        False otherwise.
    """
    # Only reading the module-level flag, so no ``global`` statement is needed.
    return True if mlu_available else False
def get_torch_device():
global directml_enabled
global cpu_state
@ -128,6 +141,8 @@ def get_torch_device():
return torch.device("xpu", torch.xpu.current_device())
elif is_ascend_npu():
return torch.device("npu", torch.npu.current_device())
elif is_mlu():
return torch.device("mlu", torch.mlu.current_device())
else:
return torch.device(torch.cuda.current_device())
@ -154,6 +169,12 @@ def get_total_memory(dev=None, torch_total_too=False):
_, mem_total_npu = torch.npu.mem_get_info(dev)
mem_total_torch = mem_reserved
mem_total = mem_total_npu
elif is_mlu():
stats = torch.mlu.memory_stats(dev)
mem_reserved = stats['reserved_bytes.all.current']
_, mem_total_mlu = torch.mlu.mem_get_info(dev)
mem_total_torch = mem_reserved
mem_total = mem_total_mlu
else:
stats = torch.cuda.memory_stats(dev)
mem_reserved = stats['reserved_bytes.all.current']
@ -233,7 +254,7 @@ try:
if torch_version_numeric[0] >= 2:
if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
ENABLE_PYTORCH_ATTENTION = True
if is_intel_xpu() or is_ascend_npu():
if is_intel_xpu() or is_ascend_npu() or is_mlu():
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
ENABLE_PYTORCH_ATTENTION = True
except:
@ -317,6 +338,8 @@ def get_torch_device_name(device):
return "{} {}".format(device, torch.xpu.get_device_name(device))
elif is_ascend_npu():
return "{} {}".format(device, torch.npu.get_device_name(device))
elif is_mlu():
return "{} {}".format(device, torch.mlu.get_device_name(device))
else:
return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device))
@ -906,6 +929,8 @@ def xformers_enabled():
return False
if is_ascend_npu():
return False
if is_mlu():
return False
if directml_enabled:
return False
return XFORMERS_IS_AVAILABLE
@ -937,6 +962,8 @@ def pytorch_attention_flash_attention():
return True
if is_ascend_npu():
return True
if is_mlu():
return True
if is_amd():
return True #if you have pytorch attention enabled on AMD it probably supports at least mem efficient attention
return False
@ -985,6 +1012,13 @@ def get_free_memory(dev=None, torch_free_too=False):
mem_free_npu, _ = torch.npu.mem_get_info(dev)
mem_free_torch = mem_reserved - mem_active
mem_free_total = mem_free_npu + mem_free_torch
elif is_mlu():
stats = torch.mlu.memory_stats(dev)
mem_active = stats['active_bytes.all.current']
mem_reserved = stats['reserved_bytes.all.current']
mem_free_mlu, _ = torch.mlu.mem_get_info(dev)
mem_free_torch = mem_reserved - mem_active
mem_free_total = mem_free_mlu + mem_free_torch
else:
stats = torch.cuda.memory_stats(dev)
mem_active = stats['active_bytes.all.current']
@ -1054,6 +1088,9 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
if is_ascend_npu():
return True
if is_mlu():
return True
if torch.version.hip:
return True
@ -1122,6 +1159,11 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
return False
props = torch.cuda.get_device_properties(device)
if is_mlu():
if props.major > 3:
return True
if props.major >= 8:
return True

View File

@ -11,7 +11,7 @@ class UMT5XXlModel(sd1_clip.SDClipModel):
class UMT5XXlTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer = tokenizer_data.get("spiece_model", None)
super().__init__(tokenizer, pad_with_end=False, embedding_size=4096, embedding_key='umt5xxl', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=0)
super().__init__(tokenizer, pad_with_end=False, embedding_size=4096, embedding_key='umt5xxl', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_token=0)
def state_dict(self):
return {"spiece_model": self.tokenizer.serialize_model()}

View File

@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
__version__ = "0.3.17"
__version__ = "0.3.18"

View File

@ -1,6 +1,6 @@
[project]
name = "ComfyUI"
version = "0.3.17"
version = "0.3.18"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.9"