diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py
index dbe84b8bb..5471763bb 100644
--- a/comfy/ldm/wan/model.py
+++ b/comfy/ldm/wan/model.py
@@ -421,7 +421,7 @@ class WanModel(torch.nn.Module):
         e0 = self.time_projection(e).unflatten(1, (6, self.dim))
 
         # context
-        context = self.text_embedding(torch.cat([context, context.new_zeros(context.size(0), self.text_len - context.size(1), context.size(2))], dim=1))
+        context = self.text_embedding(context)
 
         if clip_fea is not None and self.img_emb is not None:
             context_clip = self.img_emb(clip_fea)  # bs x 257 x dim
diff --git a/comfy/model_management.py b/comfy/model_management.py
index a6c2c4ef9..a9e10bb46 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -96,6 +96,13 @@ try:
 except:
     npu_available = False
 
+try:
+    import torch_mlu  # noqa: F401
+    _ = torch.mlu.device_count()
+    mlu_available = torch.mlu.is_available()
+except:
+    mlu_available = False
+
 if args.cpu:
     cpu_state = CPUState.CPU
 
@@ -113,6 +120,12 @@ def is_ascend_npu():
         return True
     return False
 
+def is_mlu():
+    global mlu_available
+    if mlu_available:
+        return True
+    return False
+
 def get_torch_device():
     global directml_enabled
     global cpu_state
@@ -128,6 +141,8 @@ def get_torch_device():
             return torch.device("xpu", torch.xpu.current_device())
         elif is_ascend_npu():
             return torch.device("npu", torch.npu.current_device())
+        elif is_mlu():
+            return torch.device("mlu", torch.mlu.current_device())
         else:
             return torch.device(torch.cuda.current_device())
 
@@ -154,6 +169,12 @@ def get_total_memory(dev=None, torch_total_too=False):
             _, mem_total_npu = torch.npu.mem_get_info(dev)
             mem_total_torch = mem_reserved
             mem_total = mem_total_npu
+        elif is_mlu():
+            stats = torch.mlu.memory_stats(dev)
+            mem_reserved = stats['reserved_bytes.all.current']
+            _, mem_total_mlu = torch.mlu.mem_get_info(dev)
+            mem_total_torch = mem_reserved
+            mem_total = mem_total_mlu
         else:
             stats = torch.cuda.memory_stats(dev)
             mem_reserved = stats['reserved_bytes.all.current']
@@ -233,7 +254,7 @@ try:
         if torch_version_numeric[0] >= 2:
             if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
                 ENABLE_PYTORCH_ATTENTION = True
-    if is_intel_xpu() or is_ascend_npu():
+    if is_intel_xpu() or is_ascend_npu() or is_mlu():
         if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
             ENABLE_PYTORCH_ATTENTION = True
 except:
@@ -317,6 +338,8 @@ def get_torch_device_name(device):
         return "{} {}".format(device, torch.xpu.get_device_name(device))
     elif is_ascend_npu():
         return "{} {}".format(device, torch.npu.get_device_name(device))
+    elif is_mlu():
+        return "{} {}".format(device, torch.mlu.get_device_name(device))
     else:
         return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device))
 
@@ -906,6 +929,8 @@ def xformers_enabled():
         return False
     if is_ascend_npu():
         return False
+    if is_mlu():
+        return False
     if directml_enabled:
         return False
     return XFORMERS_IS_AVAILABLE
@@ -937,6 +962,8 @@ def pytorch_attention_flash_attention():
             return True
         if is_ascend_npu():
             return True
+        if is_mlu():
+            return True
         if is_amd():
             return True #if you have pytorch attention enabled on AMD it probably supports at least mem efficient attention
     return False
@@ -985,6 +1012,13 @@ def get_free_memory(dev=None, torch_free_too=False):
             mem_free_npu, _ = torch.npu.mem_get_info(dev)
             mem_free_torch = mem_reserved - mem_active
             mem_free_total = mem_free_npu + mem_free_torch
+        elif is_mlu():
+            stats = torch.mlu.memory_stats(dev)
+            mem_active = stats['active_bytes.all.current']
+            mem_reserved = stats['reserved_bytes.all.current']
+            mem_free_mlu, _ = torch.mlu.mem_get_info(dev)
+            mem_free_torch = mem_reserved - mem_active
+            mem_free_total = mem_free_mlu + mem_free_torch
         else:
             stats = torch.cuda.memory_stats(dev)
             mem_active = stats['active_bytes.all.current']
@@ -1054,6 +1088,9 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
     if is_ascend_npu():
         return True
 
+    if is_mlu():
+        return True
+
     if torch.version.hip:
         return True
 
@@ -1122,6 +1159,11 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
         return False
 
     props = torch.cuda.get_device_properties(device)
+
+    if is_mlu():
+        if props.major > 3:
+            return True
+
     if props.major >= 8:
         return True
 
diff --git a/comfy/text_encoders/wan.py b/comfy/text_encoders/wan.py
index d98c9ad28..971ac8fa8 100644
--- a/comfy/text_encoders/wan.py
+++ b/comfy/text_encoders/wan.py
@@ -11,7 +11,7 @@ class UMT5XXlModel(sd1_clip.SDClipModel):
 class UMT5XXlTokenizer(sd1_clip.SDTokenizer):
     def __init__(self, embedding_directory=None, tokenizer_data={}):
         tokenizer = tokenizer_data.get("spiece_model", None)
-        super().__init__(tokenizer, pad_with_end=False, embedding_size=4096, embedding_key='umt5xxl', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=0)
+        super().__init__(tokenizer, pad_with_end=False, embedding_size=4096, embedding_key='umt5xxl', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_token=0)
 
     def state_dict(self):
         return {"spiece_model": self.tokenizer.serialize_model()}
diff --git a/comfyui_version.py b/comfyui_version.py
index 2c00ff181..9d69edfc1 100644
--- a/comfyui_version.py
+++ b/comfyui_version.py
@@ -1,3 +1,3 @@
 # This file is automatically generated by the build process when version is
 # updated in pyproject.toml.
-__version__ = "0.3.17"
+__version__ = "0.3.18"
diff --git a/pyproject.toml b/pyproject.toml
index d119e9834..be52c6028 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "ComfyUI"
-version = "0.3.17"
+version = "0.3.18"
 readme = "README.md"
 license = { file = "LICENSE" }
 requires-python = ">=3.9"
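
Usage note (illustrative, not part of the patch): because the new branches are added inside the existing model_management entry points, callers do not need any MLU-specific code once the Cambricon torch_mlu package is importable. A minimal check script, assuming ComfyUI and torch_mlu are installed on an MLU host:

    # Hypothetical sketch; device handling goes through the branches added above.
    import comfy.model_management as mm

    dev = mm.get_torch_device()                    # torch.device("mlu", 0) when is_mlu() is True
    print(mm.get_torch_device_name(dev))           # reported via torch.mlu.get_device_name
    print(mm.get_total_memory(dev))                # via torch.mlu.memory_stats / mem_get_info
    print(mm.get_free_memory(dev))
    print(mm.pytorch_attention_flash_attention())  # True on MLU once pytorch attention is enabled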