Merge branch 'comfyanonymous:master' into master

Commit c4fb9f2a63 by patientx, 2025-02-27 13:06:17 +03:00 (committed by GitHub)
GPG signature: key ID B5690EEEBB952194 (no known key found for this signature in database)

5 changed files with 47 additions and 5 deletions

View File

@@ -421,7 +421,7 @@ class WanModel(torch.nn.Module):
         e0 = self.time_projection(e).unflatten(1, (6, self.dim))
 
         # context
-        context = self.text_embedding(torch.cat([context, context.new_zeros(context.size(0), self.text_len - context.size(1), context.size(2))], dim=1))
+        context = self.text_embedding(context)
 
         if clip_fea is not None and self.img_emb is not None:
             context_clip = self.img_emb(clip_fea) # bs x 257 x dim
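
Note on the hunk above: the model no longer zero-pads the text context up to self.text_len before embedding it; that padding now happens in the tokenizer (see the min_length=512 change further down in this merge). As a minimal standalone sketch of what the removed torch.cat(...) expression computed, assuming a context tensor of shape (batch, seq, dim) and a target length text_len (the helper name is hypothetical, for illustration only):

    import torch

    def pad_context(context: torch.Tensor, text_len: int) -> torch.Tensor:
        # Append zero vectors along the sequence dimension until the context
        # reaches text_len tokens, mirroring the removed expression above.
        pad = context.new_zeros(context.size(0), text_len - context.size(1), context.size(2))
        return torch.cat([context, pad], dim=1)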

View File

@@ -96,6 +96,13 @@ try:
 except:
     npu_available = False
 
+try:
+    import torch_mlu # noqa: F401
+    _ = torch.mlu.device_count()
+    mlu_available = torch.mlu.is_available()
+except:
+    mlu_available = False
+
 if args.cpu:
     cpu_state = CPUState.CPU
@@ -113,6 +120,12 @@ def is_ascend_npu():
         return True
     return False
 
+def is_mlu():
+    global mlu_available
+    if mlu_available:
+        return True
+    return False
+
 def get_torch_device():
     global directml_enabled
     global cpu_state
@@ -128,6 +141,8 @@ def get_torch_device():
             return torch.device("xpu", torch.xpu.current_device())
         elif is_ascend_npu():
             return torch.device("npu", torch.npu.current_device())
+        elif is_mlu():
+            return torch.device("mlu", torch.mlu.current_device())
         else:
             return torch.device(torch.cuda.current_device())
@@ -154,6 +169,12 @@ def get_total_memory(dev=None, torch_total_too=False):
             _, mem_total_npu = torch.npu.mem_get_info(dev)
             mem_total_torch = mem_reserved
             mem_total = mem_total_npu
+        elif is_mlu():
+            stats = torch.mlu.memory_stats(dev)
+            mem_reserved = stats['reserved_bytes.all.current']
+            _, mem_total_mlu = torch.mlu.mem_get_info(dev)
+            mem_total_torch = mem_reserved
+            mem_total = mem_total_mlu
         else:
             stats = torch.cuda.memory_stats(dev)
             mem_reserved = stats['reserved_bytes.all.current']
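
For orientation, a hypothetical usage sketch of the new MLU code path (the module path comfy.model_management and the presence of a Cambricon MLU build of PyTorch with torch_mlu installed are assumptions, not shown in this diff):

    import comfy.model_management as mm

    if mm.is_mlu():
        dev = mm.get_torch_device()       # torch.device("mlu", <current device index>)
        total = mm.get_total_memory(dev)  # total bytes, reported via torch.mlu.mem_get_info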
@@ -233,7 +254,7 @@ try:
         if torch_version_numeric[0] >= 2:
             if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
                 ENABLE_PYTORCH_ATTENTION = True
-    if is_intel_xpu() or is_ascend_npu():
+    if is_intel_xpu() or is_ascend_npu() or is_mlu():
         if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
             ENABLE_PYTORCH_ATTENTION = True
 except:
@@ -317,6 +338,8 @@ def get_torch_device_name(device):
         return "{} {}".format(device, torch.xpu.get_device_name(device))
     elif is_ascend_npu():
         return "{} {}".format(device, torch.npu.get_device_name(device))
+    elif is_mlu():
+        return "{} {}".format(device, torch.mlu.get_device_name(device))
     else:
         return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device))
@@ -906,6 +929,8 @@ def xformers_enabled():
         return False
     if is_ascend_npu():
         return False
+    if is_mlu():
+        return False
     if directml_enabled:
         return False
     return XFORMERS_IS_AVAILABLE
@@ -937,6 +962,8 @@ def pytorch_attention_flash_attention():
             return True
         if is_ascend_npu():
             return True
+        if is_mlu():
+            return True
         if is_amd():
             return True #if you have pytorch attention enabled on AMD it probably supports at least mem efficient attention
     return False
@@ -985,6 +1012,13 @@ def get_free_memory(dev=None, torch_free_too=False):
             mem_free_npu, _ = torch.npu.mem_get_info(dev)
             mem_free_torch = mem_reserved - mem_active
             mem_free_total = mem_free_npu + mem_free_torch
+        elif is_mlu():
+            stats = torch.mlu.memory_stats(dev)
+            mem_active = stats['active_bytes.all.current']
+            mem_reserved = stats['reserved_bytes.all.current']
+            mem_free_mlu, _ = torch.mlu.mem_get_info(dev)
+            mem_free_torch = mem_reserved - mem_active
+            mem_free_total = mem_free_mlu + mem_free_torch
         else:
             stats = torch.cuda.memory_stats(dev)
             mem_active = stats['active_bytes.all.current']
@@ -1054,6 +1088,9 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
     if is_ascend_npu():
         return True
 
+    if is_mlu():
+        return True
+
     if torch.version.hip:
         return True
@@ -1122,6 +1159,11 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
         return False
 
     props = torch.cuda.get_device_properties(device)
+
+    if is_mlu():
+        if props.major > 3:
+            return True
+
     if props.major >= 8:
         return True

View File

@@ -11,7 +11,7 @@ class UMT5XXlModel(sd1_clip.SDClipModel):
 class UMT5XXlTokenizer(sd1_clip.SDTokenizer):
     def __init__(self, embedding_directory=None, tokenizer_data={}):
         tokenizer = tokenizer_data.get("spiece_model", None)
-        super().__init__(tokenizer, pad_with_end=False, embedding_size=4096, embedding_key='umt5xxl', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=0)
+        super().__init__(tokenizer, pad_with_end=False, embedding_size=4096, embedding_key='umt5xxl', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_token=0)
 
     def state_dict(self):
         return {"spiece_model": self.tokenizer.serialize_model()}
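
The min_length bump pairs with the first hunk in this merge: the Wan model no longer pads the text context itself, so the tokenizer now pads every prompt to at least 512 tokens with pad_token=0. Assuming min_length behaves the way the parameter name suggests (the base SDTokenizer implementation is not shown here), the effect on a token id list is roughly:

    def pad_to_min_length(tokens, min_length=512, pad_token=0):
        # Illustrative sketch only: extend short prompts with the pad token id.
        return tokens + [pad_token] * max(0, min_length - len(tokens))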

View File

@@ -1,3 +1,3 @@
 # This file is automatically generated by the build process when version is
 # updated in pyproject.toml.
-__version__ = "0.3.17"
+__version__ = "0.3.18"

View File

@@ -1,6 +1,6 @@
 [project]
 name = "ComfyUI"
-version = "0.3.17"
+version = "0.3.18"
 readme = "README.md"
 license = { file = "LICENSE" }
 requires-python = ">=3.9"