diff --git a/comfy/model_management.py b/comfy/model_management.py index 72348258b..b6291f340 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1724,11 +1724,9 @@ def soft_empty_cache(force=False): elif is_mlu(): torch.mlu.empty_cache() elif torch.cuda.is_available(): - if comfy.memory_management.aimdo_allocator is None: - #Pytorch 2.7 and earlier crashes if you try and empty_cache when mempools exist - torch.cuda.synchronize() - torch.cuda.empty_cache() - torch.cuda.ipc_collect() + torch.cuda.synchronize() + torch.cuda.empty_cache() + torch.cuda.ipc_collect() def unload_all_models(): free_memory(1e30, get_torch_device()) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index cdf289395..d888dbcfb 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -1400,7 +1400,7 @@ class ModelPatcher: continue key = "diffusion_model." + k unet_state_dict[k] = LazyCastingParam(self, key, comfy.utils.get_attr(self.model, key)) - return self.model.state_dict_for_saving(unet_state_dict) + return self.model.state_dict_for_saving(unet_state_dict, clip_state_dict=clip_state_dict, vae_state_dict=vae_state_dict, clip_vision_state_dict=clip_vision_state_dict) def __del__(self): self.unpin_all_weights() diff --git a/comfy/ops.py b/comfy/ops.py index 53c5e4dc3..0f4eca7c7 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -54,6 +54,8 @@ try: SDPA_BACKEND_PRIORITY.insert(0, SDPBackend.CUDNN_ATTENTION) def scaled_dot_product_attention(q, k, v, *args, **kwargs): + if q.nelement() < 1024 * 128: # arbitrary number, for small inputs cudnn attention seems slower + return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs) with sdpa_kernel(SDPA_BACKEND_PRIORITY, set_priority=True): return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs) else: diff --git a/comfy/text_encoders/ace15.py b/comfy/text_encoders/ace15.py index 73d710671..fce2b67ce 100644 --- a/comfy/text_encoders/ace15.py +++ b/comfy/text_encoders/ace15.py @@ -19,6 +19,7 @@ def sample_manual_loop_no_classes( min_tokens: int = 1, max_new_tokens: int = 2048, audio_start_id: int = 151669, # The cutoff ID for audio codes + audio_end_id: int = 215669, eos_token_id: int = 151645, ): device = model.execution_device @@ -60,6 +61,7 @@ def sample_manual_loop_no_classes( remove_logit_value = torch.finfo(cfg_logits.dtype).min # Only generate audio tokens cfg_logits[:, :audio_start_id] = remove_logit_value + cfg_logits[:, audio_end_id:] = remove_logit_value if eos_token_id is not None and eos_token_id < audio_start_id and min_tokens < step: cfg_logits[:, eos_token_id] = eos_score diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py index 3afd094d1..b6735d210 100644 --- a/comfy/text_encoders/llama.py +++ b/comfy/text_encoders/llama.py @@ -651,10 +651,10 @@ class Llama2_(nn.Module): mask = None if attention_mask is not None: mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, seq_len, attention_mask.shape[-1]) - mask = mask.masked_fill(mask.to(torch.bool), torch.finfo(x.dtype).min) + mask = mask.masked_fill(mask.to(torch.bool), torch.finfo(x.dtype).min / 4) if seq_len > 1: - causal_mask = torch.empty(past_len + seq_len, past_len + seq_len, dtype=x.dtype, device=x.device).fill_(torch.finfo(x.dtype).min).triu_(1) + causal_mask = torch.empty(past_len + seq_len, past_len + seq_len, dtype=x.dtype, device=x.device).fill_(torch.finfo(x.dtype).min / 4).triu_(1) if mask is not None: mask += causal_mask else: diff --git a/comfy/utils.py b/comfy/utils.py index c1b536833..1337e2205 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -82,14 +82,12 @@ _TYPES = { def load_safetensors(ckpt): f = open(ckpt, "rb") mapping = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) + mv = memoryview(mapping) header_size = struct.unpack("