diff --git a/comfy/model_management.py b/comfy/model_management.py
index b6291f340..6018c1ab6 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -19,7 +19,7 @@
import psutil
import logging
from enum import Enum
-from comfy.cli_args import args, PerformanceFeature, enables_dynamic_vram
+from comfy.cli_args import args, PerformanceFeature
import threading
import torch
import sys
@@ -651,7 +651,7 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_
soft_empty_cache()
return unloaded_models
-def load_models_gpu_orig(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
+def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
cleanup_models_gc()
global vram_state
@@ -747,26 +747,6 @@ def load_models_gpu_orig(models, memory_required=0, force_patch_weights=False, m
current_loaded_models.insert(0, loaded_model)
return
-def load_models_gpu_thread(models, memory_required, force_patch_weights, minimum_memory_required, force_full_load):
- with torch.inference_mode():
- load_models_gpu_orig(models, memory_required, force_patch_weights, minimum_memory_required, force_full_load)
- soft_empty_cache()
-
-def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
- #Deliberately load models outside of the Aimdo mempool so they can be retained accross
- #nodes. Use a dummy thread to do it as pytorch documents that mempool contexts are
- #thread local. So exploit that to escape context
- if enables_dynamic_vram():
- t = threading.Thread(
- target=load_models_gpu_thread,
- args=(models, memory_required, force_patch_weights, minimum_memory_required, force_full_load)
- )
- t.start()
- t.join()
- else:
- load_models_gpu_orig(models, memory_required=memory_required, force_patch_weights=force_patch_weights,
- minimum_memory_required=minimum_memory_required, force_full_load=force_full_load)
-
def load_model_gpu(model):
return load_models_gpu([model])
@@ -1226,21 +1206,16 @@ def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, str
if dtype is None:
dtype = weight._model_dtype
- r = torch.empty_like(weight, dtype=dtype, device=device)
-
signature = comfy_aimdo.model_vbar.vbar_fault(weight._v)
if signature is not None:
- raw_tensor = comfy_aimdo.torch.aimdo_to_tensor(weight._v, device)
- v_tensor = comfy.memory_management.interpret_gathered_like(cast_geometry, raw_tensor)[0]
+ v_tensor = comfy.memory_management.interpret_gathered_like(cast_geometry, weight._v_tensor)[0]
if not comfy_aimdo.model_vbar.vbar_signature_compare(signature, weight._v_signature):
weight._v_signature = signature
#Send it over
v_tensor.copy_(weight, non_blocking=non_blocking)
- #always take a deep copy even if _v is good, as we have no reasonable point to unpin
- #a non comfy weight
- r.copy_(v_tensor)
- comfy_aimdo.model_vbar.vbar_unpin(weight._v)
- return r
+ return v_tensor.to(dtype=dtype)
+
+ r = torch.empty_like(weight, dtype=dtype, device=device)
if weight.dtype != r.dtype and weight.dtype != weight._model_dtype:
#Offloaded casting could skip this, however it would make the quantizations
diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index d888dbcfb..b9a117a7c 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -1492,7 +1492,9 @@ class ModelPatcherDynamic(ModelPatcher):
if vbar is not None:
vbar.prioritize()
- #We have way more tools for acceleration on comfy weight offloading, so always
+    #We force reserve VRAM for the non comfy-weight so we don't have to deal
+    #with pin and unpin synchronization which can be expensive for small weights
+ #with a high layer rate (e.g. autoregressive LLMs).
#prioritize the non-comfy weights (note the order reverse).
loading = self._load_list(prio_comfy_cast_weights=True)
loading.sort(reverse=True)
@@ -1541,6 +1543,7 @@ class ModelPatcherDynamic(ModelPatcher):
if vbar is not None and not hasattr(m, "_v"):
m._v = vbar.alloc(v_weight_size)
+ m._v_tensor = comfy_aimdo.torch.aimdo_to_tensor(m._v, device_to)
allocated_size += v_weight_size
else:
@@ -1555,8 +1558,10 @@ class ModelPatcherDynamic(ModelPatcher):
weight_size = geometry.numel() * geometry.element_size()
if vbar is not None and not hasattr(weight, "_v"):
weight._v = vbar.alloc(weight_size)
+ weight._v_tensor = comfy_aimdo.torch.aimdo_to_tensor(weight._v, device_to)
weight._model_dtype = model_dtype
allocated_size += weight_size
+ vbar.set_watermark_limit(allocated_size)
logging.info(f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB Staged. {num_patches} patches attached.")
diff --git a/comfy/ops.py b/comfy/ops.py
index 0f4eca7c7..ea0d70702 100644
--- a/comfy/ops.py
+++ b/comfy/ops.py
@@ -87,7 +87,7 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
signature = comfy_aimdo.model_vbar.vbar_fault(s._v)
if signature is not None:
- xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device)
+ xfer_dest = s._v_tensor
resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature)
if not resident:
diff --git a/comfy/text_encoders/ace15.py b/comfy/text_encoders/ace15.py
index 00dd5ba90..73697b3c1 100644
--- a/comfy/text_encoders/ace15.py
+++ b/comfy/text_encoders/ace15.py
@@ -3,6 +3,7 @@ import comfy.text_encoders.llama
from comfy import sd1_clip
import torch
import math
+from tqdm.auto import trange
import yaml
import comfy.utils
@@ -23,6 +24,8 @@ def sample_manual_loop_no_classes(
audio_end_id: int = 215669,
eos_token_id: int = 151645,
):
+ if ids is None:
+ return []
device = model.execution_device
if execution_dtype is None:
@@ -32,6 +35,7 @@ def sample_manual_loop_no_classes(
execution_dtype = torch.float32
embeds, attention_mask, num_tokens, embeds_info = model.process_tokens(ids, device)
+ embeds_batch = embeds.shape[0]
for i, t in enumerate(paddings):
attention_mask[i, :t] = 0
attention_mask[i, t:] = 1
@@ -41,22 +45,27 @@ def sample_manual_loop_no_classes(
generator = torch.Generator(device=device)
generator.manual_seed(seed)
model_config = model.transformer.model.config
+ past_kv_shape = [embeds_batch, model_config.num_key_value_heads, embeds.shape[1] + min_tokens, model_config.head_dim]
for x in range(model_config.num_hidden_layers):
- past_key_values.append((torch.empty([embeds.shape[0], model_config.num_key_value_heads, embeds.shape[1] + min_tokens, model_config.head_dim], device=device, dtype=execution_dtype), torch.empty([embeds.shape[0], model_config.num_key_value_heads, embeds.shape[1] + min_tokens, model_config.head_dim], device=device, dtype=execution_dtype), 0))
+ past_key_values.append((torch.empty(past_kv_shape, device=device, dtype=execution_dtype), torch.empty(past_kv_shape, device=device, dtype=execution_dtype), 0))
progress_bar = comfy.utils.ProgressBar(max_new_tokens)
- for step in range(max_new_tokens):
+ for step in trange(max_new_tokens, desc="LM sampling"):
outputs = model.transformer(None, attention_mask, embeds=embeds.to(execution_dtype), num_tokens=num_tokens, intermediate_output=None, dtype=execution_dtype, embeds_info=embeds_info, past_key_values=past_key_values)
next_token_logits = model.transformer.logits(outputs[0])[:, -1]
past_key_values = outputs[2]
- cond_logits = next_token_logits[0:1]
- uncond_logits = next_token_logits[1:2]
- cfg_logits = uncond_logits + cfg_scale * (cond_logits - uncond_logits)
+ if cfg_scale != 1.0:
+ cond_logits = next_token_logits[0:1]
+ uncond_logits = next_token_logits[1:2]
+ cfg_logits = uncond_logits + cfg_scale * (cond_logits - uncond_logits)
+ else:
+ cfg_logits = next_token_logits[0:1]
- if eos_token_id is not None and eos_token_id < audio_start_id and min_tokens < step:
+ use_eos_score = eos_token_id is not None and eos_token_id < audio_start_id and min_tokens < step
+ if use_eos_score:
eos_score = cfg_logits[:, eos_token_id].clone()
remove_logit_value = torch.finfo(cfg_logits.dtype).min
@@ -64,7 +73,7 @@ def sample_manual_loop_no_classes(
cfg_logits[:, :audio_start_id] = remove_logit_value
cfg_logits[:, audio_end_id:] = remove_logit_value
- if eos_token_id is not None and eos_token_id < audio_start_id and min_tokens < step:
+ if use_eos_score:
cfg_logits[:, eos_token_id] = eos_score
if top_k is not None and top_k > 0:
@@ -93,8 +102,8 @@ def sample_manual_loop_no_classes(
break
embed, _, _, _ = model.process_tokens([[token]], device)
- embeds = embed.repeat(2, 1, 1)
- attention_mask = torch.cat([attention_mask, torch.ones((2, 1), device=device, dtype=attention_mask.dtype)], dim=1)
+ embeds = embed.repeat(embeds_batch, 1, 1)
+ attention_mask = torch.cat([attention_mask, torch.ones((embeds_batch, 1), device=device, dtype=attention_mask.dtype)], dim=1)
output_audio_codes.append(token - audio_start_id)
progress_bar.update_absolute(step)
@@ -104,22 +113,29 @@ def sample_manual_loop_no_classes(
def generate_audio_codes(model, positive, negative, min_tokens=1, max_tokens=1024, seed=0, cfg_scale=2.0, temperature=0.85, top_p=0.9, top_k=0):
positive = [[token for token, _ in inner_list] for inner_list in positive]
- negative = [[token for token, _ in inner_list] for inner_list in negative]
positive = positive[0]
- negative = negative[0]
- neg_pad = 0
- if len(negative) < len(positive):
- neg_pad = (len(positive) - len(negative))
- negative = [model.special_tokens["pad"]] * neg_pad + negative
+ if cfg_scale != 1.0:
+ negative = [[token for token, _ in inner_list] for inner_list in negative]
+ negative = negative[0]
- pos_pad = 0
- if len(negative) > len(positive):
- pos_pad = (len(negative) - len(positive))
- positive = [model.special_tokens["pad"]] * pos_pad + positive
+ neg_pad = 0
+ if len(negative) < len(positive):
+ neg_pad = (len(positive) - len(negative))
+ negative = [model.special_tokens["pad"]] * neg_pad + negative
- paddings = [pos_pad, neg_pad]
- return sample_manual_loop_no_classes(model, [positive, negative], paddings, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k, seed=seed, min_tokens=min_tokens, max_new_tokens=max_tokens)
+ pos_pad = 0
+ if len(negative) > len(positive):
+ pos_pad = (len(negative) - len(positive))
+ positive = [model.special_tokens["pad"]] * pos_pad + positive
+
+ paddings = [pos_pad, neg_pad]
+ ids = [positive, negative]
+ else:
+ paddings = []
+ ids = [positive]
+
+ return sample_manual_loop_no_classes(model, ids, paddings, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k, seed=seed, min_tokens=min_tokens, max_new_tokens=max_tokens)
class ACE15Tokenizer(sd1_clip.SD1Tokenizer):
@@ -129,12 +145,12 @@ class ACE15Tokenizer(sd1_clip.SD1Tokenizer):
def _metas_to_cot(self, *, return_yaml: bool = False, **kwargs) -> str:
user_metas = {
k: kwargs.pop(k)
- for k in ("bpm", "duration", "keyscale", "timesignature", "language", "caption")
+ for k in ("bpm", "duration", "keyscale", "timesignature")
if k in kwargs
}
timesignature = user_metas.get("timesignature")
if isinstance(timesignature, str) and timesignature.endswith("/4"):
- user_metas["timesignature"] = timesignature.rsplit("/", 1)[0]
+ user_metas["timesignature"] = timesignature[:-2]
user_metas = {
k: v if not isinstance(v, str) or not v.isdigit() else int(v)
for k, v in user_metas.items()
@@ -147,8 +163,11 @@ class ACE15Tokenizer(sd1_clip.SD1Tokenizer):
return f"\n{meta_yaml}\n" if not return_yaml else meta_yaml
def _metas_to_cap(self, **kwargs) -> str:
- use_keys = ("bpm", "duration", "keyscale", "timesignature")
+ use_keys = ("bpm", "timesignature", "keyscale", "duration")
user_metas = { k: kwargs.pop(k, "N/A") for k in use_keys }
+ timesignature = user_metas.get("timesignature")
+ if isinstance(timesignature, str) and timesignature.endswith("/4"):
+ user_metas["timesignature"] = timesignature[:-2]
duration = user_metas["duration"]
if duration == "N/A":
user_metas["duration"] = "30 seconds"
@@ -159,9 +178,13 @@ class ACE15Tokenizer(sd1_clip.SD1Tokenizer):
return "\n".join(f"- {k}: {user_metas[k]}" for k in use_keys)
def tokenize_with_weights(self, text, return_word_ids=False, **kwargs):
- out = {}
+ text = text.strip()
+ text_negative = kwargs.get("caption_negative", text).strip()
lyrics = kwargs.get("lyrics", "")
+ lyrics_negative = kwargs.get("lyrics_negative", lyrics)
duration = kwargs.get("duration", 120)
+ if isinstance(duration, str):
+ duration = float(duration.split(None, 1)[0])
language = kwargs.get("language")
seed = kwargs.get("seed", 0)
@@ -171,21 +194,46 @@ class ACE15Tokenizer(sd1_clip.SD1Tokenizer):
top_p = kwargs.get("top_p", 0.9)
top_k = kwargs.get("top_k", 0.0)
-
duration = math.ceil(duration)
kwargs["duration"] = duration
+ tokens_duration = duration * 5
+ min_tokens = int(kwargs.get("min_tokens", tokens_duration))
+ max_tokens = int(kwargs.get("max_tokens", tokens_duration))
- cot_text = self._metas_to_cot(caption = text, **kwargs)
+ metas_negative = {
+ k.rsplit("_", 1)[0]: kwargs.pop(k)
+ for k in ("bpm_negative", "duration_negative", "keyscale_negative", "timesignature_negative", "language_negative", "caption_negative")
+ if k in kwargs
+ }
+ if not kwargs.get("use_negative_caption"):
+ _ = metas_negative.pop("caption", None)
+
+ cot_text = self._metas_to_cot(caption=text, **kwargs)
+ cot_text_negative = "\n\n" if not metas_negative else self._metas_to_cot(**metas_negative)
meta_cap = self._metas_to_cap(**kwargs)
- lm_template = "<|im_start|>system\n# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n<|im_end|>\n<|im_start|>user\n# Caption\n{}\n# Lyric\n{}\n<|im_end|>\n<|im_start|>assistant\n{}\n<|im_end|>\n"
+ lm_template = "<|im_start|>system\n# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n<|im_end|>\n<|im_start|>user\n# Caption\n{}\n\n# Lyric\n{}\n<|im_end|>\n<|im_start|>assistant\n{}\n\n<|im_end|>\n"
+ lyrics_template = "# Languages\n{}\n\n# Lyric\n{}<|endoftext|><|endoftext|>"
+ qwen3_06b_template = "# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n# Caption\n{}\n\n# Metas\n{}\n<|endoftext|>\n<|endoftext|>"
- out["lm_prompt"] = self.qwen3_06b.tokenize_with_weights(lm_template.format(text, lyrics, cot_text), disable_weights=True)
- out["lm_prompt_negative"] = self.qwen3_06b.tokenize_with_weights(lm_template.format(text, lyrics, "\n"), disable_weights=True)
+ llm_prompts = {
+ "lm_prompt": lm_template.format(text, lyrics.strip(), cot_text),
+ "lm_prompt_negative": lm_template.format(text_negative, lyrics_negative.strip(), cot_text_negative),
+ "lyrics": lyrics_template.format(language if language is not None else "", lyrics),
+ "qwen3_06b": qwen3_06b_template.format(text, meta_cap),
+ }
- out["lyrics"] = self.qwen3_06b.tokenize_with_weights("# Languages\n{}\n\n# Lyric\n{}<|endoftext|><|endoftext|>".format(language if language is not None else "", lyrics), return_word_ids, disable_weights=True, **kwargs)
- out["qwen3_06b"] = self.qwen3_06b.tokenize_with_weights("# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n# Caption\n{}\n# Metas\n{}\n<|endoftext|>\n<|endoftext|>".format(text, meta_cap), return_word_ids, **kwargs)
- out["lm_metadata"] = {"min_tokens": duration * 5,
+ out = {
+ prompt_key: self.qwen3_06b.tokenize_with_weights(
+ prompt,
+ prompt_key == "qwen3_06b" and return_word_ids,
+ disable_weights = True,
+ **kwargs,
+ )
+ for prompt_key, prompt in llm_prompts.items()
+ }
+ out["lm_metadata"] = {"min_tokens": min_tokens,
+ "max_tokens": max_tokens,
"seed": seed,
"generate_audio_codes": generate_audio_codes,
"cfg_scale": cfg_scale,
@@ -252,7 +300,7 @@ class ACE15TEModel(torch.nn.Module):
lm_metadata = token_weight_pairs["lm_metadata"]
if lm_metadata["generate_audio_codes"]:
- audio_codes = generate_audio_codes(getattr(self, self.lm_model, self.qwen3_06b), token_weight_pairs["lm_prompt"], token_weight_pairs["lm_prompt_negative"], min_tokens=lm_metadata["min_tokens"], max_tokens=lm_metadata["min_tokens"], seed=lm_metadata["seed"], cfg_scale=lm_metadata["cfg_scale"], temperature=lm_metadata["temperature"], top_p=lm_metadata["top_p"], top_k=lm_metadata["top_k"])
+ audio_codes = generate_audio_codes(getattr(self, self.lm_model, self.qwen3_06b), token_weight_pairs["lm_prompt"], token_weight_pairs["lm_prompt_negative"], min_tokens=lm_metadata["min_tokens"], max_tokens=lm_metadata["max_tokens"], seed=lm_metadata["seed"], cfg_scale=lm_metadata["cfg_scale"], temperature=lm_metadata["temperature"], top_p=lm_metadata["top_p"], top_k=lm_metadata["top_k"])
out["audio_codes"] = [audio_codes]
return base_out, None, out
diff --git a/comfy_extras/nodes_latent.py b/comfy_extras/nodes_latent.py
index 6aecf1561..8d2d7297a 100644
--- a/comfy_extras/nodes_latent.py
+++ b/comfy_extras/nodes_latent.py
@@ -391,8 +391,9 @@ class LatentOperationTonemapReinhard(io.ComfyNode):
latent_vector_magnitude = (torch.linalg.vector_norm(latent, dim=(1)) + 0.0000000001)[:,None]
normalized_latent = latent / latent_vector_magnitude
- mean = torch.mean(latent_vector_magnitude, dim=(1,2,3), keepdim=True)
- std = torch.std(latent_vector_magnitude, dim=(1,2,3), keepdim=True)
+ dims = list(range(1, latent_vector_magnitude.ndim))
+ mean = torch.mean(latent_vector_magnitude, dim=dims, keepdim=True)
+ std = torch.std(latent_vector_magnitude, dim=dims, keepdim=True)
top = (std * 5 + mean) * multiplier
diff --git a/execution.py b/execution.py
index 3dbab82e6..896862c6b 100644
--- a/execution.py
+++ b/execution.py
@@ -13,8 +13,11 @@ from contextlib import nullcontext
import torch
+from comfy.cli_args import args
import comfy.memory_management
import comfy.model_management
+import comfy_aimdo.model_vbar
+
from latent_preview import set_preview_method
import nodes
from comfy_execution.caching import (
@@ -527,8 +530,10 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data)
finally:
if allocator is not None:
+ if args.verbose == "DEBUG":
+ comfy_aimdo.model_vbar.vbars_analyze()
comfy.model_management.reset_cast_buffers()
- torch.cuda.synchronize()
+ comfy_aimdo.model_vbar.vbars_reset_watermark_limits()
if has_pending_tasks:
pending_async_nodes[unique_id] = output_data
diff --git a/requirements.txt b/requirements.txt
index 5e34a2a49..4fda07fde 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -22,7 +22,7 @@ alembic
SQLAlchemy
av>=14.2.0
comfy-kitchen>=0.2.7
-comfy-aimdo>=0.1.7
+comfy-aimdo>=0.1.8
requests
#non essential dependencies: