mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-05 03:00:33 +08:00
Fix some issues with mac. (#12247)
This commit is contained in:
parent
f5030e26fd
commit
affe881354
@ -57,8 +57,9 @@ def sample_manual_loop_no_classes(
|
|||||||
if eos_token_id is not None and eos_token_id < audio_start_id and min_tokens < step:
|
if eos_token_id is not None and eos_token_id < audio_start_id and min_tokens < step:
|
||||||
eos_score = cfg_logits[:, eos_token_id].clone()
|
eos_score = cfg_logits[:, eos_token_id].clone()
|
||||||
|
|
||||||
|
remove_logit_value = torch.finfo(cfg_logits.dtype).min
|
||||||
# Only generate audio tokens
|
# Only generate audio tokens
|
||||||
cfg_logits[:, :audio_start_id] = float('-inf')
|
cfg_logits[:, :audio_start_id] = remove_logit_value
|
||||||
|
|
||||||
if eos_token_id is not None and eos_token_id < audio_start_id and min_tokens < step:
|
if eos_token_id is not None and eos_token_id < audio_start_id and min_tokens < step:
|
||||||
cfg_logits[:, eos_token_id] = eos_score
|
cfg_logits[:, eos_token_id] = eos_score
|
||||||
@ -66,7 +67,7 @@ def sample_manual_loop_no_classes(
|
|||||||
if top_k is not None and top_k > 0:
|
if top_k is not None and top_k > 0:
|
||||||
top_k_vals, _ = torch.topk(cfg_logits, top_k)
|
top_k_vals, _ = torch.topk(cfg_logits, top_k)
|
||||||
min_val = top_k_vals[..., -1, None]
|
min_val = top_k_vals[..., -1, None]
|
||||||
cfg_logits[cfg_logits < min_val] = float('-inf')
|
cfg_logits[cfg_logits < min_val] = remove_logit_value
|
||||||
|
|
||||||
if top_p is not None and top_p < 1.0:
|
if top_p is not None and top_p < 1.0:
|
||||||
sorted_logits, sorted_indices = torch.sort(cfg_logits, descending=True)
|
sorted_logits, sorted_indices = torch.sort(cfg_logits, descending=True)
|
||||||
@ -75,7 +76,7 @@ def sample_manual_loop_no_classes(
|
|||||||
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
|
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
|
||||||
sorted_indices_to_remove[..., 0] = 0
|
sorted_indices_to_remove[..., 0] = 0
|
||||||
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
|
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
|
||||||
cfg_logits[indices_to_remove] = float('-inf')
|
cfg_logits[indices_to_remove] = remove_logit_value
|
||||||
|
|
||||||
if temperature > 0:
|
if temperature > 0:
|
||||||
cfg_logits = cfg_logits / temperature
|
cfg_logits = cfg_logits / temperature
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user