ace15: Use dynamic_vram friendly trange (#12409)

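Factor the ksampler's warm-up-aware trange out of the k-diffusion sampler
into comfy.utils.model_trange and use it for ACE LLM sampling as well. This
prevents the progress bar from silently stalling at 0 and keeps the reported
step rate from being distorted by the first-step model load.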
This commit is contained in:
rattus 2026-02-11 11:53:42 -08:00 committed by GitHub
parent d297a749a2
commit 2a4328d639
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 30 additions and 32 deletions

View File

@ -1,12 +1,11 @@
import math import math
import time
from functools import partial from functools import partial
from scipy import integrate from scipy import integrate
import torch import torch
from torch import nn from torch import nn
import torchsde import torchsde
from tqdm.auto import trange as trange_, tqdm from tqdm.auto import tqdm
from . import utils from . import utils
from . import deis from . import deis
@ -15,34 +14,7 @@ import comfy.model_patcher
import comfy.model_sampling import comfy.model_sampling
import comfy.memory_management import comfy.memory_management
from comfy.utils import model_trange as trange
def trange(*args, **kwargs):
    """Create a tqdm ``trange`` bar, with first-step warm-up handling when enabled.

    All positional and keyword arguments are forwarded to tqdm's ``trange``
    (imported here as ``trange_``).

    When ``comfy.memory_management.aimdo_allocator`` is ``None`` the plain
    tqdm bar is returned unchanged. Otherwise the bar is created with
    ``smoothing=1.0``, shows a " Model Initializing ... " postfix, and has
    its ``update`` method wrapped so that:

    * the first ``update`` call records the current time and switches the
      postfix to " Model Initialization complete! ";
    * the second ``update`` call rewrites ``pbar.start_t`` and clears the
      postfix.

    Returns:
        The tqdm progress bar (iterable over the requested range).
    """
    if comfy.memory_management.aimdo_allocator is None:
        return trange_(*args, **kwargs)

    # NOTE(review): a caller that passes its own ``smoothing`` keyword will
    # hit a duplicate-keyword TypeError on this call.
    pbar = trange_(*args, **kwargs, smoothing=1.0)
    pbar._i = 0  # number of update() calls seen so far
    pbar.set_postfix_str(" Model Initializing ... ")

    # Keep the bound original so the wrapper can delegate to it.
    _update = pbar.update

    def warmup_update(n=1):
        pbar._i += 1
        if pbar._i == 1:
            # Timestamp of the end of the first (presumably load-heavy) step.
            pbar.i1_time = time.time()
            pbar.set_postfix_str(" Model Initialization complete! ")
        elif pbar._i == 2:
            # Bring forward the effective start time based on the diff between
            # the first and second iteration, to attempt to remove load
            # overhead from the final step rate estimate.
            pbar.start_t = pbar.i1_time - (time.time() - pbar.i1_time)
            pbar.set_postfix_str("")

        _update(n)

    # Instance-level override: shadows the class method for this bar only.
    pbar.update = warmup_update
    return pbar
def append_zero(x):
    """Return ``x`` with one trailing zero concatenated onto it.

    The zero is created with ``x.new_zeros``, so it shares ``x``'s dtype
    and device.
    """
    zero_tail = x.new_zeros([1])
    return torch.cat([x, zero_tail])

View File

@ -3,7 +3,6 @@ import comfy.text_encoders.llama
from comfy import sd1_clip from comfy import sd1_clip
import torch import torch
import math import math
from tqdm.auto import trange
import yaml import yaml
import comfy.utils import comfy.utils
@ -52,7 +51,7 @@ def sample_manual_loop_no_classes(
progress_bar = comfy.utils.ProgressBar(max_new_tokens) progress_bar = comfy.utils.ProgressBar(max_new_tokens)
for step in trange(max_new_tokens, desc="LM sampling"): for step in comfy.utils.model_trange(max_new_tokens, desc="LM sampling"):
outputs = model.transformer(None, attention_mask, embeds=embeds.to(execution_dtype), num_tokens=num_tokens, intermediate_output=None, dtype=execution_dtype, embeds_info=embeds_info, past_key_values=past_key_values) outputs = model.transformer(None, attention_mask, embeds=embeds.to(execution_dtype), num_tokens=num_tokens, intermediate_output=None, dtype=execution_dtype, embeds_info=embeds_info, past_key_values=past_key_values)
next_token_logits = model.transformer.logits(outputs[0])[:, -1] next_token_logits = model.transformer.logits(outputs[0])[:, -1]
past_key_values = outputs[2] past_key_values = outputs[2]

View File

@ -27,6 +27,7 @@ from PIL import Image
import logging import logging
import itertools import itertools
from torch.nn.functional import interpolate from torch.nn.functional import interpolate
from tqdm.auto import trange
from einops import rearrange from einops import rearrange
from comfy.cli_args import args, enables_dynamic_vram from comfy.cli_args import args, enables_dynamic_vram
import json import json
@ -1155,6 +1156,32 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None):
    """Two-axis convenience wrapper around ``tiled_scale_multidim``.

    Packs ``tile_y`` and ``tile_x`` into a ``(tile_y, tile_x)`` tile tuple and
    forwards every other argument unchanged; the result of
    ``tiled_scale_multidim`` is returned as-is.
    """
    tile = (tile_y, tile_x)
    forwarded = dict(
        overlap=overlap,
        upscale_amount=upscale_amount,
        out_channels=out_channels,
        output_device=output_device,
        pbar=pbar,
    )
    return tiled_scale_multidim(samples, function, tile, **forwarded)
def model_trange(*args, **kwargs):
    """Create a tqdm ``trange`` bar that compensates for a slow first step.

    All positional and keyword arguments are forwarded to tqdm's ``trange``.

    When ``comfy.memory_management.aimdo_allocator`` is ``None`` the plain
    tqdm bar is returned unchanged. Otherwise:

    * the bar is created with ``smoothing=1.0`` (tqdm: instantaneous speed
      rather than an average), overriding any ``smoothing`` the caller gave;
    * a " Model Initializing ... " postfix is shown until the first
      ``update`` call, so the bar does not appear to stall silently at 0;
    * on the second ``update`` call ``pbar.start_t`` is moved so that the
      first step is accounted for as if it had run at the pace observed
      afterwards, keeping the load overhead out of the overall rate.

    Returns:
        The tqdm progress bar (iterable over the requested range).
    """
    if comfy.memory_management.aimdo_allocator is None:
        return trange(*args, **kwargs)

    # Assign into kwargs instead of passing ``smoothing=`` next to **kwargs:
    # the latter raises TypeError when the caller also supplies ``smoothing``.
    kwargs["smoothing"] = 1.0
    pbar = trange(*args, **kwargs)
    pbar._i = 0  # number of update() calls seen so far
    pbar.set_postfix_str(" Model Initializing ... ")

    # Keep the bound original so the wrapper can delegate to it.
    _update = pbar.update

    def warmup_update(n=1):
        pbar._i += 1
        if pbar._i == 1:
            # End of the first (load-heavy) stretch: remember when it ended
            # and how many iterations it covered.
            pbar.i1_time = time.time()
            pbar.i1_n = n
            pbar.set_postfix_str(" Model Initialization complete! ")
        elif pbar._i == 2:
            # Bring forward the effective start time based on the pace seen
            # between the first and second update, to attempt to remove load
            # overhead from the final step rate estimate.
            #
            # When iterating, tqdm only calls update() once its minimum
            # display interval has elapsed, so ``n`` may cover several
            # iterations. Scale by the iteration counts so the first stretch
            # is charged at the per-iteration pace, not the whole interval.
            elapsed = time.time() - pbar.i1_time
            if n:
                warm_first = elapsed * pbar.i1_n / n
            else:
                # No iterations reported: fall back to the raw interval.
                warm_first = elapsed
            pbar.start_t = pbar.i1_time - warm_first
            pbar.set_postfix_str("")

        # Preserve tqdm's return value (whether a display was triggered).
        return _update(n)

    # Instance-level override: shadows the class method for this bar only.
    pbar.update = warmup_update
    return pbar
PROGRESS_BAR_ENABLED = True PROGRESS_BAR_ENABLED = True
def set_progress_bar_enabled(enabled): def set_progress_bar_enabled(enabled):
global PROGRESS_BAR_ENABLED global PROGRESS_BAR_ENABLED