Compare commits

...

29 Commits

Author SHA1 Message Date
rattus
494cc0baba
Merge 0eff43261b into 6165c38cb5 2026-01-14 13:52:57 +08:00
comfyanonymous
6165c38cb5
Optimize nvfp4 lora applying. (#11866)
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
This changes results a bit but it also speeds up things a lot.
2026-01-14 00:49:38 -05:00
Silver
712cca36a1
feat: throttle ProgressBar updates to reduce WebSocket flooding (#11504) 2026-01-13 22:41:44 -05:00
Johnpaul Chiwetelu
ac4d8ea9b3
feat: add CI container version bump automation (#11692)
* feat: add CI container version bump automation

Adds a workflow that triggers on releases to create PRs in the
comfyui-ci-container repo, updating the ComfyUI version in the Dockerfile.

Supports both release events and manual workflow dispatch for testing.

* feat: add CI container version bump automation

Adds a workflow that triggers on releases to create PRs in the
comfyui-ci-container repo, updating the ComfyUI version in the Dockerfile.

Supports both release events and manual workflow dispatch for testing.

* ci: update CI container repository owner

* refactor: rename `update-ci-container.yaml` workflow to `update-ci-container.yml`

* Remove post-merge instructions from the CI container update workflow.
2026-01-13 22:39:22 -05:00
nomadoor
c9196f355e
Fix scale_shorter_dimension portrait check (#11862) 2026-01-13 18:25:09 -08:00
Christian Byrne
7eb959ce93
fix: update ComfyUI repo reference to Comfy-Org/ComfyUI (#11858) 2026-01-13 21:03:16 -05:00
nomadoor
469dd9c16a
Adds crop to multiple mode to ResizeImageMaskNode. (#11838)
* Add crop-to-multiple resize mode

* Make scale-to-multiple shape handling explicit
2026-01-13 16:48:10 -08:00
comfyanonymous
eff2b9d412
Optimize nvfp4 lora applying. (#11856) 2026-01-13 19:37:19 -05:00
comfyanonymous
15b312de7a
Optimize nvfp4 lora applying. (#11854) 2026-01-13 19:23:58 -05:00
Alexander Piskun
1419047fdb
[Api Nodes]: Improve Price Badge Declarations (#11582)
* api nodes: price badges moved to nodes code

* added price badges for 4 more node-packs

* added price badges for 10 more node-packs

* added new price badges for Omni STD mode

* add support for autogrow groups

* use full names for "widgets", "inputs" and "groups"

* add strict typing for JSONata rules

* add price badge for WanReferenceVideoApi node

* add support for DynamicCombo

* sync price badges changes (https://github.com/Comfy-Org/ComfyUI_frontend/pull/7900)

* sync badges for Vidu2 nodes

* fixed incorrect price for RecraftCrispUpscaleNode

* fixed incorrect price badges for LTXV nodes

* fixed price badge for MinimaxHailuoVideoNode

* fixed price badges for PixVerse nodes
2026-01-13 16:18:28 -08:00
ric-yu
79f6bb5e4f
add blueprints dir for built-in blueprints (#11853) 2026-01-13 16:14:40 -08:00
Rattus
0eff43261b sd: empty cache on tiler fallback
This is needed for aimdo where the cache cant self recover from
fragmentation. It is however a good thing to do anyway after an OOM
so make it unconditional.
2026-01-13 21:22:49 +10:00
Rattus
013f132085 ruff 2026-01-13 20:29:13 +10:00
Rattus
94a709d813 misc cleanup 2026-01-13 19:58:06 +10:00
Rattus
2c554a112e add missing del on unpin 2026-01-13 19:58:06 +10:00
Rattus
e2d29926a9 write better tx commentary 2026-01-13 19:58:06 +10:00
Rattus
e2b440b25e mm: fix sync
Sync before deleting anything.
2026-01-13 19:58:06 +10:00
Rattus
389c334631 main: Go live with --fast dynamic_vram
Add the optional command line switch --fast dynamic_vram.

This is mutually exclusing --high-vram and --gpu-only which contradict
aimdos underlying feature.

Add appropriate installation warning and a startup message, match the
comfy debug level inconfiguring aimdo.

Add comfy-aimdo pip requirement. This will safely stub to a nop for
unsupported platforms.
2026-01-13 19:58:06 +10:00
Rattus
64c2541b05 execution: add aimdo primary pytorch cache integration
We need to general pytorch cache defragmentation on an appropriate level for
aimdo. Do in here on the per node basis, which has a reasonable chance of
purging stale shapes out of the pytorch caching allocator and saving VRAM
without costing too much garbage collector thrash.

This looks like a lot of GC but because aimdo never fails from pytorch and
saves the pytorch allocator from ever need to defrag out of demand, but it
needs a oil change every now and then so we gotta do it. Doing it here also
means the pytorch temps are cleared from task manager VRAM usage so user
anxiety can go down a little when they see their vram drop back at the end
of workflows inline with inference usage (rather than assuming full VRAM
leaks).
2026-01-13 19:58:06 +10:00
Rattus
3597b27515 models: Use CoreModelPatcher
Use CoreModelPatcher for all internal ModelPatcher implementations. This drives
conditional use of the aimdo feature, while making sure custom node packs get
to keep ModelPatcher unchanged for the moment.
2026-01-13 19:58:06 +10:00
Rattus
f75765721d ops/mp: implement aimdo
Implement a model patcher and caster for aimdo.

A new ModelPatcher implementation which backs onto comfy-aimdo to implement varying model load levels that can be adjusted during model use. The patcher defers all load processes to lazily load the model during use (e.g. the first step of a ksampler) and automatically negotiates a load level during the inference to maximize VRAM usage without OOMing. If inference requires more VRAM than is available weights are offloaded to make space before the OOM happens.

As for loading the weight onto the GPU, that happens via comfy_cast_weights which is now used in all cases. cast_bias_weight checks whether the VBAR assigned to the model has space for the weight (based on the same load priority semantics as the original ModelPatcher). If it does, the VRAM as returned by the Aimdo allocator is used as the parameter GPU side. The caster is responsible for populating the weight data. This is done using the usual offload_stream (which mean we now have asynchronous load overlapping first use compute).

Pinning works a little differently. When a weight is detected during load as unable to fit, a pin is allocated at the time of casting and the weight as used by the layer is DMAd back to the the pin using the GPU DMA TX engine, also using the asynchronous offload streams. This means you get to pin the Lora modified and requantized weights which can be a major speedup for offload+quantize+lora use cases, This works around the JIT Lora + FP8 exclusion and brings FP8MM to heavy offloading users (who probably really need it with more modest GPUs). There is a performance risk in that a CPU+RAM patch has been replace with a GPU+RAM patch but my initial performance results look good. Most users as likely to have a GPU that outruns their CPU in these woods.

Some common code is written to consolidate a layers tensors for aimdo mapping, pinning, and DMA transfers. interpret_gathered_like() allows unpacking a raw buffer as a set of tensors. This is used consistently to bundle and pack weights, quantization metadata (QuantizedTensor bits) and biases into one payload for DMA in the load process reducing Cuda overhead a little. Some Quantization metadata was missing async offload is some cases which is now added. This also pins quantization metadata and consolidates the number of cuda_host_register calls (which can be expensive).
2026-01-13 19:58:06 +10:00
Rattus
f74661edc6 mp: add mode for non comfy weight prioritization
non-comfy weights dont get async offload and a few other performance
limitations. Load them at top priority accordingly.
2026-01-13 19:55:35 +10:00
Rattus
b9ee4c6ee5 mp/mm: APi expansions for dynamic loading
Add two api expansions, a flag for whether a model patcher is dynamic
a a very basic RAM freeing system.

Implement the semantics of the dynamic model patcher which never frees
VRAM ahead of time for the sake of another dynamic model patcher.

At the same time add an API for clearing out pins on a reservation of
model size x2 heuristic, as pins consume RAM in their own right in the
dynamic patcher.

This is actually less about OOMing RAM and more about performance, as
with assign=True load semantics there needs to be plenty headroom for
the OS to load models to dosk cache on demand so err on the side of
kicking old pins out.
2026-01-13 19:55:35 +10:00
Rattus
f511367529 mp: wrap get_free_memory
Dynamic load needs to adjust these numbers based on future movements,
so wrap this in a MP API.
2026-01-13 19:55:35 +10:00
Rattus
439c178c2c pinned_memory: add python
Add a python for managing pinned memory of the weight/bias module level.
This allocates, pins and attached a tensor to a module for the pin for this
module. It does not set the weight, just allocates a singular ram buffer
for population and bulk DMA transfer.
2026-01-13 19:55:35 +10:00
Rattus
f1e8ccae5c move string_to_seed to utils.py
This needs to be visible by ops which may want to do stochastic rounding on
the fly.
2026-01-13 19:55:35 +10:00
Rattus
babccae951 mm: Implement cast buffer allocations 2026-01-13 19:55:35 +10:00
Rattus
967f848df2 ops: Do bias dtype conversion on compute stream
For consistency with weights.
2026-01-13 19:55:35 +10:00
Rattus
caf6c6aada Reduce RAM and compute time in model saving with Loras
Get the model saving logic away from force_patch_weights and instead do
the patching JIT during safetensors saving.

Firstly switch off force_patch_weights in the load for save which avoids
creating CPU side tensors with loras calculated.

Then at save time, wrap the tensor to catch safetensors call to .to() and
patch it live.

This avoids having to ever have a lora-calculated copy of offloaded
weights on the CPU.

Also take advantage of the presence of the GPU when doing this Lora
calculation. The former force_patch_weights would just do eveyrthing on
the CPU. Its generally faster to go the GPU and back even if its just
a Lora application.
2026-01-13 19:55:35 +10:00
48 changed files with 2221 additions and 156 deletions

View File

@ -13,7 +13,7 @@ jobs:
- name: Checkout ComfyUI
uses: actions/checkout@v4
with:
repository: "comfyanonymous/ComfyUI"
repository: "Comfy-Org/ComfyUI"
path: "ComfyUI"
- uses: actions/setup-python@v4
with:

View File

@ -0,0 +1,59 @@
name: "CI: Update CI Container"
on:
release:
types: [published]
workflow_dispatch:
inputs:
version:
description: 'ComfyUI version (e.g., v0.7.0)'
required: true
type: string
jobs:
update-ci-container:
runs-on: ubuntu-latest
# Skip pre-releases unless manually triggered
if: github.event_name == 'workflow_dispatch' || !github.event.release.prerelease
steps:
- name: Get version
id: version
run: |
if [ "${{ github.event_name }}" = "release" ]; then
VERSION="${{ github.event.release.tag_name }}"
else
VERSION="${{ inputs.version }}"
fi
echo "version=$VERSION" >> $GITHUB_OUTPUT
- name: Checkout comfyui-ci-container
uses: actions/checkout@v4
with:
repository: comfy-org/comfyui-ci-container
token: ${{ secrets.CI_CONTAINER_PAT }}
- name: Check current version
id: current
run: |
CURRENT=$(grep -oP 'ARG COMFYUI_VERSION=\K.*' Dockerfile || echo "unknown")
echo "current_version=$CURRENT" >> $GITHUB_OUTPUT
- name: Update Dockerfile
run: |
VERSION="${{ steps.version.outputs.version }}"
sed -i "s/^ARG COMFYUI_VERSION=.*/ARG COMFYUI_VERSION=${VERSION}/" Dockerfile
- name: Create Pull Request
id: create-pr
uses: peter-evans/create-pull-request@v7
with:
token: ${{ secrets.CI_CONTAINER_PAT }}
branch: automation/comfyui-${{ steps.version.outputs.version }}
title: "chore: bump ComfyUI to ${{ steps.version.outputs.version }}"
body: |
Updates ComfyUI version from `${{ steps.current.outputs.current_version }}` to `${{ steps.version.outputs.version }}`
**Triggered by:** ${{ github.event_name == 'release' && format('[Release {0}]({1})', github.event.release.tag_name, github.event.release.html_url) || 'Manual workflow dispatch' }}
labels: automation
commit-message: "chore: bump ComfyUI to ${{ steps.version.outputs.version }}"

View File

@ -10,6 +10,7 @@ import hashlib
class Source:
custom_node = "custom_node"
templates = "templates"
class SubgraphEntry(TypedDict):
source: str
@ -38,6 +39,18 @@ class CustomNodeSubgraphEntryInfo(TypedDict):
class SubgraphManager:
def __init__(self):
self.cached_custom_node_subgraphs: dict[SubgraphEntry] | None = None
self.cached_blueprint_subgraphs: dict[SubgraphEntry] | None = None
def _create_entry(self, file: str, source: str, node_pack: str) -> tuple[str, SubgraphEntry]:
"""Create a subgraph entry from a file path. Expects normalized path (forward slashes)."""
entry_id = hashlib.sha256(f"{source}{file}".encode()).hexdigest()
entry: SubgraphEntry = {
"source": source,
"name": os.path.splitext(os.path.basename(file))[0],
"path": file,
"info": {"node_pack": node_pack},
}
return entry_id, entry
async def load_entry_data(self, entry: SubgraphEntry):
with open(entry['path'], 'r') as f:
@ -60,53 +73,60 @@ class SubgraphManager:
return entries
async def get_custom_node_subgraphs(self, loadedModules, force_reload=False):
# if not forced to reload and cached, return cache
"""Load subgraphs from custom nodes."""
if not force_reload and self.cached_custom_node_subgraphs is not None:
return self.cached_custom_node_subgraphs
# Load subgraphs from custom nodes
subfolder = "subgraphs"
subgraphs_dict: dict[SubgraphEntry] = {}
subgraphs_dict: dict[SubgraphEntry] = {}
for folder in folder_paths.get_folder_paths("custom_nodes"):
pattern = os.path.join(folder, f"*/{subfolder}/*.json")
matched_files = glob.glob(pattern)
for file in matched_files:
# replace backslashes with forward slashes
pattern = os.path.join(folder, "*/subgraphs/*.json")
for file in glob.glob(pattern):
file = file.replace('\\', '/')
info: CustomNodeSubgraphEntryInfo = {
"node_pack": "custom_nodes." + file.split('/')[-3]
}
source = Source.custom_node
# hash source + path to make sure id will be as unique as possible, but
# reproducible across backend reloads
id = hashlib.sha256(f"{source}{file}".encode()).hexdigest()
entry: SubgraphEntry = {
"source": Source.custom_node,
"name": os.path.splitext(os.path.basename(file))[0],
"path": file,
"info": info,
}
subgraphs_dict[id] = entry
node_pack = "custom_nodes." + file.split('/')[-3]
entry_id, entry = self._create_entry(file, Source.custom_node, node_pack)
subgraphs_dict[entry_id] = entry
self.cached_custom_node_subgraphs = subgraphs_dict
return subgraphs_dict
async def get_custom_node_subgraph(self, id: str, loadedModules):
subgraphs = await self.get_custom_node_subgraphs(loadedModules)
entry: SubgraphEntry = subgraphs.get(id, None)
if entry is not None and entry.get('data', None) is None:
async def get_blueprint_subgraphs(self, force_reload=False):
"""Load subgraphs from the blueprints directory."""
if not force_reload and self.cached_blueprint_subgraphs is not None:
return self.cached_blueprint_subgraphs
subgraphs_dict: dict[SubgraphEntry] = {}
blueprints_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'blueprints')
if os.path.exists(blueprints_dir):
for file in glob.glob(os.path.join(blueprints_dir, "*.json")):
file = file.replace('\\', '/')
entry_id, entry = self._create_entry(file, Source.templates, "comfyui")
subgraphs_dict[entry_id] = entry
self.cached_blueprint_subgraphs = subgraphs_dict
return subgraphs_dict
async def get_all_subgraphs(self, loadedModules, force_reload=False):
"""Get all subgraphs from all sources (custom nodes and blueprints)."""
custom_node_subgraphs = await self.get_custom_node_subgraphs(loadedModules, force_reload)
blueprint_subgraphs = await self.get_blueprint_subgraphs(force_reload)
return {**custom_node_subgraphs, **blueprint_subgraphs}
async def get_subgraph(self, id: str, loadedModules):
"""Get a specific subgraph by ID from any source."""
entry = (await self.get_all_subgraphs(loadedModules)).get(id)
if entry is not None and entry.get('data') is None:
await self.load_entry_data(entry)
return entry
def add_routes(self, routes, loadedModules):
@routes.get("/global_subgraphs")
async def get_global_subgraphs(request):
subgraphs_dict = await self.get_custom_node_subgraphs(loadedModules)
# NOTE: we may want to include other sources of global subgraphs such as templates in the future;
# that's the reasoning for the current implementation
subgraphs_dict = await self.get_all_subgraphs(loadedModules)
return web.json_response(await self.sanitize_entries(subgraphs_dict, remove_data=True))
@routes.get("/global_subgraphs/{id}")
async def get_global_subgraph(request):
id = request.match_info.get("id", None)
subgraph = await self.get_custom_node_subgraph(id, loadedModules)
subgraph = await self.get_subgraph(id, loadedModules)
return web.json_response(await self.sanitize_entry(subgraph))

View File

View File

@ -25,11 +25,11 @@ class AudioEncoderModel():
elif model_type == "whisper3":
self.model = WhisperLargeV3(**model_config)
self.model.eval()
self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
self.model_sample_rate = 16000
def load_sd(self, sd):
return self.model.load_state_dict(sd, strict=False)
return self.model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic())
def get_sd(self):
return self.model.state_dict()

View File

@ -159,6 +159,7 @@ class PerformanceFeature(enum.Enum):
Fp8MatrixMultiplication = "fp8_matrix_mult"
CublasOps = "cublas_ops"
AutoTune = "autotune"
DynamicVRAM = "dynamic_vram"
parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. This is used to test new features so using it might crash your comfyui. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature))))
@ -257,3 +258,6 @@ elif args.fast == []:
# '--fast' is provided with a list of performance features, use that list
else:
args.fast = set(args.fast)
def enables_dynamic_vram():
return PerformanceFeature.DynamicVRAM in args.fast and not args.highvram and not args.gpu_only

View File

@ -47,10 +47,10 @@ class ClipVisionModel():
self.model = model_class(config, self.dtype, offload_device, comfy.ops.manual_cast)
self.model.eval()
self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
def load_sd(self, sd):
return self.model.load_state_dict(sd, strict=False)
return self.model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic())
def get_sd(self):
return self.model.state_dict()

View File

@ -203,7 +203,7 @@ class ControlNet(ControlBase):
self.control_model = control_model
self.load_device = load_device
if control_model is not None:
self.control_model_wrapped = comfy.model_patcher.ModelPatcher(self.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device())
self.control_model_wrapped = comfy.model_patcher.CoreModelPatcher(self.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device())
self.compression_ratio = compression_ratio
self.global_average_pooling = global_average_pooling

View File

@ -137,10 +137,44 @@ def to_blocked(input_matrix, flatten: bool = True) -> torch.Tensor:
return rearranged.reshape(padded_rows, padded_cols)
def stochastic_round_quantize_nvfp4(x, per_tensor_scale, pad_16x, seed=0):
def stochastic_round_quantize_nvfp4_block(x, per_tensor_scale, generator):
F4_E2M1_MAX = 6.0
F8_E4M3_MAX = 448.0
orig_shape = x.shape
block_size = 16
x = x.reshape(orig_shape[0], -1, block_size)
scaled_block_scales_fp8 = torch.clamp(((torch.amax(torch.abs(x), dim=-1)) / F4_E2M1_MAX) / per_tensor_scale.to(x.dtype), max=F8_E4M3_MAX).to(torch.float8_e4m3fn)
x = x / (per_tensor_scale.to(x.dtype) * scaled_block_scales_fp8.to(x.dtype)).unsqueeze(-1)
x = x.view(orig_shape).nan_to_num()
data_lp = stochastic_float_to_fp4_e2m1(x, generator=generator)
return data_lp, scaled_block_scales_fp8
def stochastic_round_quantize_nvfp4(x, per_tensor_scale, pad_16x, seed=0):
def roundup(x: int, multiple: int) -> int:
"""Round up x to the nearest multiple."""
return ((x + multiple - 1) // multiple) * multiple
generator = torch.Generator(device=x.device)
generator.manual_seed(seed)
# Handle padding
if pad_16x:
rows, cols = x.shape
padded_rows = roundup(rows, 16)
padded_cols = roundup(cols, 16)
if padded_rows != rows or padded_cols != cols:
x = torch.nn.functional.pad(x, (0, padded_cols - cols, 0, padded_rows - rows))
x, blocked_scaled = stochastic_round_quantize_nvfp4_block(x, per_tensor_scale, generator)
return x, to_blocked(blocked_scaled, flatten=False)
def stochastic_round_quantize_nvfp4_by_block(x, per_tensor_scale, pad_16x, seed=0, block_size=4096 * 4096):
def roundup(x: int, multiple: int) -> int:
"""Round up x to the nearest multiple."""
return ((x + multiple - 1) // multiple) * multiple
@ -158,28 +192,20 @@ def stochastic_round_quantize_nvfp4(x, per_tensor_scale, pad_16x, seed=0):
# what we want to produce. If we pad here, we want the padded output.
orig_shape = x.shape
block_size = 16
orig_shape = list(orig_shape)
x = x.reshape(orig_shape[0], -1, block_size)
max_abs = torch.amax(torch.abs(x), dim=-1)
block_scale = max_abs / F4_E2M1_MAX
scaled_block_scales = block_scale / per_tensor_scale.to(block_scale.dtype)
scaled_block_scales_fp8 = torch.clamp(scaled_block_scales, max=F8_E4M3_MAX).to(torch.float8_e4m3fn)
total_scale = per_tensor_scale.to(x.dtype) * scaled_block_scales_fp8.to(x.dtype)
# Handle zero blocks (from padding): avoid 0/0 NaN
zero_scale_mask = (total_scale == 0)
total_scale_safe = torch.where(zero_scale_mask, torch.ones_like(total_scale), total_scale)
x = x / total_scale_safe.unsqueeze(-1)
output_fp4 = torch.empty(orig_shape[:-1] + [orig_shape[-1] // 2], dtype=torch.uint8, device=x.device)
output_block = torch.empty(orig_shape[:-1] + [orig_shape[-1] // 16], dtype=torch.float8_e4m3fn, device=x.device)
generator = torch.Generator(device=x.device)
generator.manual_seed(seed)
x = torch.where(zero_scale_mask.unsqueeze(-1), torch.zeros_like(x), x)
num_slices = max(1, (x.numel() / block_size))
slice_size = max(1, (round(x.shape[0] / num_slices)))
x = x.view(orig_shape)
data_lp = stochastic_float_to_fp4_e2m1(x, generator=generator)
for i in range(0, x.shape[0], slice_size):
fp4, block = stochastic_round_quantize_nvfp4_block(x[i: i + slice_size], per_tensor_scale, generator=generator)
output_fp4[i:i + slice_size].copy_(fp4)
output_block[i:i + slice_size].copy_(block)
blocked_scales = to_blocked(scaled_block_scales_fp8, flatten=False)
return data_lp, blocked_scales
return output_fp4, to_blocked(output_block, flatten=False)

View File

@ -109,10 +109,10 @@ class HunyuanVideo15SRModel():
self.model_class = UPSAMPLERS.get(model_type)
self.model = self.model_class(**config).eval()
self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
def load_sd(self, sd):
return self.model.load_state_dict(sd, strict=True)
return self.model.load_state_dict(sd, strict=True, assign=self.patcher.is_dynamic())
def get_sd(self):
return self.model.state_dict()

View File

@ -0,0 +1,54 @@
from comfy.quant_ops import QuantizedTensor
import comfy_aimdo.torch
def vram_aligned_size(tensor):
if isinstance(tensor, list):
return sum([vram_aligned_size(t) for t in tensor])
if isinstance(tensor, QuantizedTensor):
inner_tensors, _ = tensor.__tensor_flatten__()
return vram_aligned_size([ getattr(tensor, attr) for attr in inner_tensors ])
if tensor is None:
return 0
size = tensor.numel() * tensor.element_size()
aligment_req = 1024
return (size + aligment_req - 1) // aligment_req * aligment_req
def interpret_gathered_like(tensors, gathered):
offset = 0
dest_views = []
if gathered.dim() != 1 or gathered.element_size() != 1:
raise ValueError(f"Buffer must be 1D and single-byte (got {gathered.dim()}D {gathered.dtype})")
for tensor in tensors:
if tensor is None:
dest_views.append(None)
continue
if isinstance(tensor, QuantizedTensor):
inner_tensors, qt_ctx = tensor.__tensor_flatten__()
templates = { attr: getattr(tensor, attr) for attr in inner_tensors }
else:
templates = { "data": tensor }
actuals = {}
for attr, template in templates.items():
size = template.numel() * template.element_size()
if offset + size > gathered.numel():
raise ValueError(f"Buffer too small: needs {offset + size} bytes, but only has {gathered.numel()}. ")
actuals[attr] = gathered[offset:offset+size].view(dtype=template.dtype).view(template.shape)
offset += vram_aligned_size(template)
if isinstance(tensor, QuantizedTensor):
dest_views.append(QuantizedTensor.__tensor_unflatten__(actuals, qt_ctx, 0, 0))
else:
dest_views.append(actuals["data"])
return dest_views
aimdo_allocator = comfy_aimdo.torch.CUDAPluggableAllocator()

View File

@ -298,7 +298,7 @@ class BaseModel(torch.nn.Module):
return out
def load_model_weights(self, sd, unet_prefix=""):
def load_model_weights(self, sd, unet_prefix="", assign=False):
to_load = {}
keys = list(sd.keys())
for k in keys:
@ -306,7 +306,7 @@ class BaseModel(torch.nn.Module):
to_load[k[len(unet_prefix):]] = sd.pop(k)
to_load = self.model_config.process_unet_state_dict(to_load)
m, u = self.diffusion_model.load_state_dict(to_load, strict=False)
m, u = self.diffusion_model.load_state_dict(to_load, strict=False, assign=assign)
if len(m) > 0:
logging.warning("unet missing: {}".format(m))
@ -321,7 +321,7 @@ class BaseModel(torch.nn.Module):
def process_latent_out(self, latent):
return self.latent_format.process_out(latent)
def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None):
def state_dict_for_saving(self, unet_state_dict, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None):
extra_sds = []
if clip_state_dict is not None:
extra_sds.append(self.model_config.process_clip_state_dict_for_saving(clip_state_dict))
@ -329,10 +329,7 @@ class BaseModel(torch.nn.Module):
extra_sds.append(self.model_config.process_vae_state_dict_for_saving(vae_state_dict))
if clip_vision_state_dict is not None:
extra_sds.append(self.model_config.process_clip_vision_state_dict_for_saving(clip_vision_state_dict))
unet_state_dict = self.diffusion_model.state_dict()
unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
if self.model_type == ModelType.V_PREDICTION:
unet_state_dict["v_pred"] = torch.tensor([])
@ -775,8 +772,8 @@ class StableAudio1(BaseModel):
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
return out
def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None):
sd = super().state_dict_for_saving(clip_state_dict=clip_state_dict, vae_state_dict=vae_state_dict, clip_vision_state_dict=clip_vision_state_dict)
def state_dict_for_saving(self, unet_state_dict, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None):
sd = super().state_dict_for_saving(unet_state_dict, clip_state_dict=clip_state_dict, vae_state_dict=vae_state_dict, clip_vision_state_dict=clip_vision_state_dict)
d = {"conditioner.conditioners.seconds_start.": self.seconds_start_embedder.state_dict(), "conditioner.conditioners.seconds_total.": self.seconds_total_embedder.state_dict()}
for k in d:
s = d[k]

View File

@ -26,6 +26,12 @@ import platform
import weakref
import gc
import os
from contextlib import nullcontext
import comfy.utils
import comfy.quant_ops
import comfy_aimdo.torch
import comfy_aimdo.model_vbar
class VRAMState(Enum):
DISABLED = 0 #No vram present: no need to move models to vram
@ -592,7 +598,7 @@ def extra_reserved_memory():
def minimum_inference_memory():
return (1024 * 1024 * 1024) * 0.8 + extra_reserved_memory()
def free_memory(memory_required, device, keep_loaded=[]):
def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_required=0):
cleanup_models_gc()
unloaded_model = []
can_unload = []
@ -607,15 +613,22 @@ def free_memory(memory_required, device, keep_loaded=[]):
for x in sorted(can_unload):
i = x[-1]
memory_to_free = None
memory_to_free = 1e32
ram_to_free = 1e32
if not DISABLE_SMART_MEMORY:
free_mem = get_free_memory(device)
if free_mem > memory_required:
break
memory_to_free = memory_required - free_mem
memory_to_free = memory_required - get_free_memory(device)
ram_to_free = ram_required - psutil.virtual_memory().available
if current_loaded_models[i].model.is_dynamic() and for_dynamic:
#don't actually unload dynamic models for the sake of other dynamic models
#as that works on-demand.
memory_required -= current_loaded_models[i].model.loaded_size()
continue
logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}")
if current_loaded_models[i].model_unload(memory_to_free):
if memory_to_free > 0 and current_loaded_models[i].model_unload(memory_to_free):
unloaded_model.append(i)
if ram_to_free > 0:
current_loaded_models[i].model.partially_unload_ram(ram_to_free)
for i in sorted(unloaded_model, reverse=True):
unloaded_models.append(current_loaded_models.pop(i))
@ -650,7 +663,10 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
models_to_load = []
free_for_dynamic=True
for x in models:
if not x.is_dynamic():
free_for_dynamic = False
loaded_model = LoadedModel(x)
try:
loaded_model_index = current_loaded_models.index(loaded_model)
@ -676,19 +692,25 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
model_to_unload.model.detach(unpatch_all=False)
model_to_unload.model_finalizer.detach()
total_memory_required = {}
total_ram_required = {}
for loaded_model in models_to_load:
total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
#x2, one to make sure the OS can fit the model for loading in disk cache, and for us to do any pinning we
#want to do.
#FIXME: This should subtract off the to_load current pin consumption.
total_ram_required[loaded_model.device] = total_ram_required.get(loaded_model.device, 0) + loaded_model.model_memory() * 2
for device in total_memory_required:
if device != torch.device("cpu"):
free_memory(total_memory_required[device] * 1.1 + extra_mem, device)
free_memory(total_memory_required[device] * 1.1 + extra_mem, device, for_dynamic=free_for_dynamic, ram_required=total_ram_required[device])
for device in total_memory_required:
if device != torch.device("cpu"):
free_mem = get_free_memory(device)
if free_mem < minimum_memory_required:
models_l = free_memory(minimum_memory_required, device)
models_l = free_memory(minimum_memory_required, device, for_dynamic=free_for_dynamic)
logging.info("{} models unloaded.".format(len(models_l)))
for loaded_model in models_to_load:
@ -732,6 +754,9 @@ def loaded_models(only_currently_used=False):
def cleanup_models_gc():
do_gc = False
reset_cast_buffers()
for i in range(len(current_loaded_models)):
cur = current_loaded_models[i]
if cur.is_dead():
@ -1051,6 +1076,49 @@ def current_stream(device):
return None
stream_counters = {}
STREAM_CAST_BUFFERS = {}
LARGEST_CASTED_WEIGHT = (None, 0)
def get_cast_buffer(offload_stream, device, size, ref):
global LARGEST_CASTED_WEIGHT
if offload_stream is not None:
wf_context = offload_stream
if hasattr(wf_context, "as_context"):
wf_context = wf_context.as_context(offload_stream)
else:
wf_context = nullcontext()
cast_buffer = STREAM_CAST_BUFFERS.get(offload_stream, None)
if cast_buffer is None or cast_buffer.numel() < size:
if ref is LARGEST_CASTED_WEIGHT[0]:
#If there is one giant weight we do not want both streams to
#allocate a buffer for it. It's up to the caster to get the other
#offload stream in this corner case
return None
if cast_buffer is not None and cast_buffer.numel() > 50 * (1024 ** 2):
#I want my wrongly sized 50MB+ of VRAM back from the caching allocator right now
torch.cuda.synchronize()
del STREAM_CAST_BUFFERS[offload_stream]
del cast_buffer
torch.cuda.empty_cache()
with wf_context:
cast_buffer = torch.empty((size), dtype=torch.int8, device=device)
STREAM_CAST_BUFFERS[offload_stream] = cast_buffer
if size > LARGEST_CASTED_WEIGHT[1]:
LARGEST_CASTED_WEIGHT = (ref, size)
return cast_buffer
def reset_cast_buffers():
global LARGEST_CASTED_WEIGHT
LARGEST_CASTED_WEIGHT = (None, 0)
torch.cuda.synchronize()
STREAM_CAST_BUFFERS.clear()
torch.cuda.empty_cache()
def get_offload_stream(device):
stream_counter = stream_counters.get(device, 0)
if NUM_STREAMS == 0:
@ -1093,7 +1161,59 @@ def sync_stream(device, stream):
return
current_stream(device).wait_stream(stream)
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None):
def cast_to_gathered(tensors, r, non_blocking=False, stream=None):
wf_context = nullcontext()
if stream is not None:
wf_context = stream
if hasattr(wf_context, "as_context"):
wf_context = wf_context.as_context(stream)
dest_views = comfy.memory_management.interpret_gathered_like(tensors, r)
with wf_context:
for tensor in tensors:
dest_view = dest_views.pop(0)
if tensor is None:
continue
dest_view.copy_(tensor, non_blocking=non_blocking)
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None, r=None):
if hasattr(weight, "_v"):
#Unexpected usage patterns. There is no reason these don't work but they
#have no testing and no callers do this.
assert r is None
assert stream is None
r = torch.empty_like(weight, dtype=dtype, device=device)
signature = comfy_aimdo.model_vbar.vbar_fault(weight._v)
if signature is not None:
raw_tensor = comfy_aimdo.torch.aimdo_to_tensor(weight._v, device)
v_tensor = comfy.memory_management.interpret_gathered_like([weight], raw_tensor)[0]
if comfy_aimdo.model_vbar.vbar_signature_compare(signature, weight._v_signature):
#always take a deep copy even if _v is good, as we have no reasonable point to unpin
#a non comfy weight
r.copy_(v_tensor)
comfy_aimdo.model_vbar.vbar_unpin(weight._v)
return r
r.copy_(weight, non_blocking=non_blocking)
#FIXME: remove hooks before PR
if hasattr(weight, "comfy_hook"):
dtype = r.dtype
r = weight.comfy_hook(r)
if r.dtype != dtype:
r = comfy.float.stochastic_rounding(r, dtype, seed=comfy.utils.string_to_seed(weight.seed_key))
if signature is not None:
v_tensor.copy_(r)
comfy_aimdo.model_vbar.vbar_unpin(weight._v)
return r
if device is None or weight.device == device:
if not copy:
if dtype is None or weight.dtype == dtype:
@ -1112,10 +1232,12 @@ def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, str
if hasattr(wf_context, "as_context"):
wf_context = wf_context.as_context(stream)
with wf_context:
r = torch.empty_like(weight, dtype=dtype, device=device)
if r is None:
r = torch.empty_like(weight, dtype=dtype, device=device)
r.copy_(weight, non_blocking=non_blocking)
else:
r = torch.empty_like(weight, dtype=dtype, device=device)
if r is None:
r = torch.empty_like(weight, dtype=dtype, device=device)
r.copy_(weight, non_blocking=non_blocking)
return r
@ -1135,7 +1257,7 @@ if not args.disable_pinned_memory:
MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.95
logging.info("Enabled pinned memory {}".format(MAX_PINNED_MEMORY // (1024 * 1024)))
PINNING_ALLOWED_TYPES = set(["Parameter", "QuantizedTensor"])
PINNING_ALLOWED_TYPES = set(["Tensor", "Parameter", "QuantizedTensor"])
def discard_cuda_async_error():
try:
@ -1557,6 +1679,7 @@ def soft_empty_cache(force=False):
elif is_mlu():
torch.mlu.empty_cache()
elif torch.cuda.is_available():
torch.cuda.synchronize()
torch.cuda.empty_cache()
torch.cuda.ipc_collect()

View File

@ -38,19 +38,7 @@ from comfy.comfy_types import UnetWrapperFunction
from comfy.quant_ops import QuantizedTensor
from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP
def string_to_seed(data):
crc = 0xFFFFFFFF
for byte in data:
if isinstance(byte, str):
byte = ord(byte)
crc ^= byte
for _ in range(8):
if crc & 1:
crc = (crc >> 1) ^ 0xEDB88320
else:
crc >>= 1
return crc ^ 0xFFFFFFFF
import comfy_aimdo.model_vbar
def set_model_options_patch_replace(model_options, patch, name, block_name, number, transformer_index=None):
to = model_options["transformer_options"].copy()
@ -212,6 +200,27 @@ class MemoryCounter:
def decrement(self, used: int):
self.value -= used
CustomTorchDevice = collections.namedtuple("FakeDevice", ["type", "index"])("comfy-lazy-caster", 0)
class LazyCastingParam(torch.nn.Parameter):
def __new__(cls, model, key, tensor):
return super().__new__(cls, tensor)
def __init__(self, model, key, tensor):
self.model = model
self.key = key
@property
def device(self):
return CustomTorchDevice
#safetensors will .to() us to the cpu which we catch here to cast on demand. The returned tensor is
#then just a short lived thing in the safetensors serialization logic inside its big for loop over
#all weights getting garbage collected per-weight
def to(self, *args, **kwargs):
return self.model.patch_weight_to_device(self.key, device_to=self.model.load_device, return_weight=True).to("cpu")
class ModelPatcher:
def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False):
self.size = size
@ -269,6 +278,9 @@ class ModelPatcher:
if not hasattr(self.model, 'model_offload_buffer_memory'):
self.model.model_offload_buffer_memory = 0
def is_dynamic(self):
return False
def model_size(self):
if self.size > 0:
return self.size
@ -284,6 +296,9 @@ class ModelPatcher:
def lowvram_patch_counter(self):
return self.model.lowvram_patch_counter
def get_free_memory(self, device):
return comfy.model_management.get_free_memory(device)
def clone(self):
n = self.__class__(self.model, self.load_device, self.offload_device, self.model_size(), weight_inplace_update=self.weight_inplace_update)
n.patches = {}
@ -611,14 +626,14 @@ class ModelPatcher:
sd.pop(k)
return sd
def patch_weight_to_device(self, key, device_to=None, inplace_update=False):
if key not in self.patches:
return
def patch_weight_to_device(self, key, device_to=None, inplace_update=False, return_weight=False):
weight, set_func, convert_func = get_key_weight(self.model, key)
if key not in self.patches:
return weight
inplace_update = self.weight_inplace_update or inplace_update
if key not in self.backup:
if key not in self.backup and not return_weight:
self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight.to(device=self.offload_device, copy=inplace_update), inplace_update)
temp_dtype = comfy.model_management.lora_compute_dtype(device_to)
@ -631,13 +646,15 @@ class ModelPatcher:
out_weight = comfy.lora.calculate_weight(self.patches[key], temp_weight, key)
if set_func is None:
out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=string_to_seed(key))
if inplace_update:
out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=comfy.utils.string_to_seed(key))
if return_weight:
return out_weight
elif inplace_update:
comfy.utils.copy_to_param(self.model, key, out_weight)
else:
comfy.utils.set_attr_param(self.model, key, out_weight)
else:
set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key))
return set_func(out_weight, inplace_update=inplace_update, seed=comfy.utils.string_to_seed(key), return_weight=return_weight)
def pin_weight_to_device(self, key):
weight, set_func, convert_func = get_key_weight(self.model, key)
@ -654,7 +671,7 @@ class ModelPatcher:
for key in list(self.pinned):
self.unpin_weight(key)
def _load_list(self):
def _load_list(self, prio_comfy_cast_weights=False):
loading = []
for n, m in self.model.named_modules():
params = []
@ -681,7 +698,8 @@ class ModelPatcher:
return 0
module_offload_mem += check_module_offload_mem("{}.weight".format(n))
module_offload_mem += check_module_offload_mem("{}.bias".format(n))
loading.append((module_offload_mem, module_mem, n, m, params))
prepend = (not hasattr(m, "comfy_cast_weights"),) if prio_comfy_cast_weights else ()
loading.append(prepend + (module_offload_mem, module_mem, n, m, params))
return loading
def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
@ -984,6 +1002,9 @@ class ModelPatcher:
return self.model.model_loaded_weight_memory - current_used
def partially_unload_ram(self, ram_to_unload):
pass
def detach(self, unpatch_all=True):
self.eject_model()
self.model_patches_to(self.offload_device)
@ -1317,10 +1338,10 @@ class ModelPatcher:
key, original_weights=original_weights)
del original_weights[key]
if set_func is None:
out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=string_to_seed(key))
out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=comfy.utils.string_to_seed(key))
comfy.utils.copy_to_param(self.model, key, out_weight)
else:
set_func(out_weight, inplace_update=True, seed=string_to_seed(key))
set_func(out_weight, inplace_update=True, seed=comfy.utils.string_to_seed(key))
if self.hook_mode == comfy.hooks.EnumHookMode.MaxSpeed:
# TODO: disable caching if not enough system RAM to do so
target_device = self.offload_device
@ -1355,7 +1376,237 @@ class ModelPatcher:
self.unpatch_hooks()
self.clear_cached_hook_weights()
def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None):
unet_state_dict = self.model.diffusion_model.state_dict()
for k, v in unet_state_dict.items():
op_keys = k.rsplit('.', 1)
if (len(op_keys) < 2) or op_keys[1] not in ["weight", "bias"]:
continue
try:
op = comfy.utils.get_attr(self.model.diffusion_model, op_keys[0])
except:
continue
if not op or not hasattr(op, "comfy_cast_weights") or \
(hasattr(op, "comfy_patched_weights") and op.comfy_patched_weights == True):
continue
key = "diffusion_model." + k
unet_state_dict[k] = LazyCastingParam(self, key, comfy.utils.get_attr(self.model, key))
return self.model.state_dict_for_saving(unet_state_dict)
def __del__(self):
self.unpin_all_weights()
self.detach(unpatch_all=False)
class ModelPatcherDynamic(ModelPatcher):
def __new__(cls, model, load_device, offload_device, size=0, weight_inplace_update=False):
if comfy.model_management.is_device_cpu(load_device):
#reroute to default MP for CPUs
return ModelPatcher(model, load_device, offload_device, size, weight_inplace_update)
return super().__new__(cls)
def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False):
super().__init__(model, load_device, offload_device, size, weight_inplace_update)
#this is now way more dynamic and we dont support the same base model for both Dynamic
#and non-dynamic patchers.
if hasattr(self.model, "model_loaded_weight_memory"):
del self.model.model_loaded_weight_memory
if not hasattr(self.model, "dynamic_vbars"):
self.model.dynamic_vbars = {}
assert load_device is not None
def is_dynamic(self):
return True
def _vbar_get(self, create=False):
if self.load_device == torch.device("cpu"):
return None
vbar = self.model.dynamic_vbars.get(self.load_device, None)
if create and vbar is None:
vbar = comfy_aimdo.model_vbar.ModelVBAR(self.model_size() * 1.2, self.load_device.index)
self.model.dynamic_vbars[self.load_device] = vbar
return vbar
def loaded_size(self):
vbar = self._vbar_get()
if vbar is None:
return 0
return vbar.loaded_size()
def get_free_memory(self, device):
#NOTE: on high condition / batch counts, estimate should have already vacated
#all non-dynamic models so this is safe even if its not 100% true that this
#would all be avaiable for inference use.
return comfy.model_management.get_total_memory(device) - self.model_size()
#Pinning is deferred to ops time. Assert against this API to avoid pin leaks.
def pin_weight_to_device(self, key):
raise RuntimeError("pin_weight_to_device invalid for dymamic weight loading")
def unpin_weight(self, key):
raise RuntimeError("unpin_weight invalid for dymamic weight loading")
def unpin_all_weights(self):
pass
def memory_required(self, input_shape):
#Pad this significantly. We are trying to get away from precise estimates. This
#estimate is only used when using the ModelPatcherDynamic after ModelPatcher. If you
#use all ModelPatcherDynamic this is ignored and its all done dynamically.
return super().memory_required(input_shape=input_shape) * 1.3 + (1024 ** 3)
def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False, dirty=False):
#Force patching doesn't make sense in Dynamic loading, as you dont know what does and
#doesn't need to be forced at this stage. The only thing you could do would be patch
#it all on CPU which consumes huge RAM.
assert not force_patch_weights
#Full load doesn't make sense as we dont actually have any loader capability here and
#now.
assert not full_load
assert device_to == self.load_device
num_patches = 0
allocated_size = 0
with self.use_ejected():
self.unpatch_hooks()
vbar = self._vbar_get(create=True)
if vbar is not None:
vbar.prioritize()
#We have way more tools for acceleration on comfy weight offloading, so always
#prioritize the non-comfy weights (note the order reverse).
loading = self._load_list(prio_comfy_cast_weights=True)
loading.sort(reverse=True)
for x in loading:
_, _, _, n, m, params = x
def set_dirty(item, dirty):
if dirty or not hasattr(item, "_v_signature"):
item._v_signature = None
if dirty:
comfy.pinned_memory.unpin_memory(item)
def setup_param(self, m, n, param_key):
nonlocal num_patches
key = "{}.{}".format(n, param_key)
weight_function = []
weight, _, _ = get_key_weight(self.model, key)
if key in self.patches:
setattr(m, param_key + "_lowvram_function", LowVramPatch(key, self.patches))
num_patches += 1
else:
setattr(m, param_key + "_lowvram_function", None)
if key in self.weight_wrapper_patches:
weight_function.extend(self.weight_wrapper_patches[key])
setattr(m, param_key + "_function", weight_function)
return comfy.memory_management.vram_aligned_size(weight)
if hasattr(m, "comfy_cast_weights"):
m.comfy_cast_weights = True
m.pin_failed = False
m.seed_key = n
set_dirty(m, dirty)
v_weight_size = 0
v_weight_size += setup_param(self, m, n, "weight")
v_weight_size += setup_param(self, m, n, "bias")
if vbar is not None and not hasattr(m, "_v"):
m._v = vbar.alloc(v_weight_size)
allocated_size += v_weight_size
else:
for param in params:
key = "{}.{}".format(n, param)
weight, _, _ = get_key_weight(self.model, key)
weight.seed_key = key
set_dirty(weight, dirty)
weight_size = weight.numel() * weight.element_size()
if vbar is not None and not hasattr(weight, "_v"):
weight._v = vbar.alloc(weight_size)
allocated_size += weight_size
logging.info(f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB Staged. {num_patches} patches attached.")
self.model.device = device_to
self.model.current_weight_patches_uuid = self.patches_uuid
for callback in self.get_all_callbacks(CallbacksMP.ON_LOAD):
#These are all super dangerous. Who knows what the custom nodes actually do here...
callback(self, device_to, lowvram_model_memory, force_patch_weights, full_load)
self.apply_hooks(self.forced_hooks, force_apply=True)
def partially_unload(self, device_to, memory_to_free=0, force_patch_weights=False):
assert not force_patch_weights #See above
assert self.load_device != torch.device("cpu")
vbar = self._vbar_get()
return 0 if vbar is None else vbar.free_memory(memory_to_free)
def partially_unload_ram(self, ram_to_unload):
loading = self._load_list(prio_comfy_cast_weights=True)
for x in loading:
_, _, _, _, m, _ = x
ram_to_unload -= comfy.pinned_memory.unpin_memory(m)
if ram_to_unload <= 0:
return
def patch_model(self, device_to=None, lowvram_model_memory=0, load_weights=True, force_patch_weights=False):
#This isn't used by the core at all and can only be to load a model out of
#the control of proper model_managment. If you are a custom node author reading
#this, the correct pattern is to call load_models_gpu() to get a proper
#managed load of your model.
assert not load_weights
return super().patch_model(load_weights=load_weights, force_patch_weights=force_patch_weights)
def unpatch_model(self, device_to=None, unpatch_weights=True):
super().unpatch_model(device_to=None, unpatch_weights=False)
if unpatch_weights:
self.partially_unload_ram(1e32)
self.partially_unload(None)
def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
assert not force_patch_weights #See above
with self.use_ejected(skip_and_inject_on_exit_only=True):
dirty = self.model.current_weight_patches_uuid is not None and (self.model.current_weight_patches_uuid != self.patches_uuid)
self.unpatch_model(self.offload_device, unpatch_weights=False)
self.patch_model(load_weights=False)
try:
self.load(device_to, dirty=dirty)
except Exception as e:
self.detach()
raise e
#ModelPatcher::partially_load returns a number on what got loaded but
#nothing in core uses this and we have no data in the Dynamic world. Hit
#the custom node devs with a None rather than a 0 that would mislead any
#logic they might have.
return None
def patch_cached_hook_weights(self, cached_weights: dict, key: str, memory_counter: MemoryCounter):
assert False #Should be unreachable - we dont ever cache in the new implementation
def patch_hook_weight_to_device(self, hooks: comfy.hooks.HookGroup, combined_patches: dict, key: str, original_weights: dict, memory_counter: MemoryCounter):
if key not in combined_patches:
return
raise RuntimeError("Hooks not implemented in ModelPatcherDynamic. Please remove --fast arguments form ComfyUI startup")
def unpatch_hooks(self, whitelist_keys_set: set[str]=None) -> None:
pass
CoreModelPatcher = ModelPatcher

View File

@ -23,6 +23,12 @@ from comfy.cli_args import args, PerformanceFeature
import comfy.float
import comfy.rmsnorm
import json
import comfy.memory_management
import comfy.pinned_memory
import comfy.utils
import comfy_aimdo.model_vbar
import comfy_aimdo.torch
def run_every_op():
if torch.compiler.is_compiling():
@ -72,7 +78,109 @@ def cast_to_input(weight, input, non_blocking=False, copy=True):
return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False):
def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype):
offload_stream = None
xfer_dest = None
signature = comfy_aimdo.model_vbar.vbar_fault(s._v)
if signature is not None:
xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device)
resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature)
if not resident:
xfer_source = [ s.weight, s.bias ]
pin = comfy.pinned_memory.get_pin(s)
if pin is not None:
xfer_source = [ pin ]
resident = True #If pinned data exists, it always has LowVram already applied
dest_size = comfy.memory_management.vram_aligned_size(xfer_source)
offload_stream = comfy.model_management.get_offload_stream(device)
if xfer_dest is None and offload_stream is not None:
xfer_dest = comfy.model_management.get_cast_buffer(offload_stream, device, dest_size, s)
if xfer_dest is None:
offload_stream = comfy.model_management.get_offload_stream(device)
xfer_dest = comfy.model_management.get_cast_buffer(offload_stream, device, dest_size, s)
if xfer_dest is None:
xfer_dest = torch.empty((dest_size,), dtype=torch.uint8, device=device)
offload_stream = None
#send it over
comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=offload_stream)
comfy.model_management.sync_stream(device, offload_stream)
pin = None
if signature is not None:
#If we are able to increase our load level (e.g. user reduces resolution or batch number)
#reclaim the pin previously used for offload.
comfy.pinned_memory.unpin_memory(s)
elif not resident:
#prepare a new pin
assert comfy.pinned_memory.get_pin(s) is None
comfy.pinned_memory.pin_memory(s)
pin = comfy.pinned_memory.get_pin(s)
params = comfy.memory_management.interpret_gathered_like([s.weight, s.bias], xfer_dest)
weight = params[0]
bias = params[1]
def post_cast(s, param_key, x, dtype, resident, update_weight):
lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
fns = getattr(s, param_key + "_function", [])
orig = x
def to_dequant(tensor, dtype):
tensor = tensor.to(dtype=dtype)
if isinstance(tensor, QuantizedTensor):
tensor = tensor.dequantize()
return tensor
if orig.dtype != dtype or len(fns) > 0:
x = to_dequant(x, dtype)
if not resident and lowvram_fn is not None:
x = to_dequant(x, dtype if compute_dtype is None else compute_dtype)
#FIXME: this is not accurate, we need to be sensitive to the compute dtype
x = lowvram_fn(x)
if (isinstance(orig, QuantizedTensor) and
(orig.dtype == dtype and len(fns) == 0 or update_weight)):
seed = comfy.utils.string_to_seed(s.seed_key)
y = QuantizedTensor.from_float(x, s.layout_type, scale="recalculate", stochastic_rounding=seed)
if orig.dtype == dtype and len(fns) == 0:
#The layer actually wants our freshly saved QT
x = y
else:
y = x
if update_weight:
orig.copy_(y)
for f in fns:
x = f(x)
return x
update_weight = signature is not None or pin is not None
weight = post_cast(s, "weight", weight, dtype, resident, update_weight)
if s.bias is not None:
bias = post_cast(s, "bias", bias, bias_dtype, resident, update_weight)
s._v_signature=signature
if pin is not None:
xfer_dest = comfy.memory_management.interpret_gathered_like([ pin ], xfer_dest)[0]
#FIXME: This might be the wrong thing to do. Some reading suggests the DMA engine
#is posted writes and the compute stream could just fire and forget here. That
#would save this sync and some stalling on the offload stream that is better off
#running ahead to the next layer to read.
if offload_stream is not None:
offload_stream.wait_stream(comfy.model_management.current_stream(device))
comfy.model_management.cast_to(xfer_dest, device=pin.device, non_blocking=non_blocking, stream=offload_stream, r=pin)
#FIXME: weird offload return protocol
return weight, bias, (offload_stream, device if signature is not None else None, None)
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False, compute_dtype=None):
# NOTE: offloadable=False is a a legacy and if you are a custom node author reading this please pass
# offloadable=True and call uncast_bias_weight() after your last usage of the weight/bias. This
# will add async-offload support to your cast and improve performance.
@ -87,22 +195,38 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
if device is None:
device = input.device
non_blocking = comfy.model_management.device_supports_non_blocking(device)
if hasattr(s, "_v"):
return cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype)
if offloadable and (device != s.weight.device or
(s.bias is not None and device != s.bias.device)):
offload_stream = comfy.model_management.get_offload_stream(device)
else:
offload_stream = None
non_blocking = comfy.model_management.device_supports_non_blocking(device)
bias = None
weight = None
if offload_stream is not None and not args.cuda_malloc:
cast_buffer_size = comfy.memory_management.vram_aligned_size([ s.weight, s.bias ])
cast_buffer = comfy.model_management.get_cast_buffer(offload_stream, device, cast_buffer_size, s)
#The streams can be uneven in buffer capability and reject us. Retry to get the other stream
if cast_buffer is None:
offload_stream = comfy.model_management.get_offload_stream(device)
cast_buffer = comfy.model_management.get_cast_buffer(offload_stream, device, cast_buffer_size, s)
params = comfy.memory_management.interpret_gathered_like([ s.weight, s.bias ], cast_buffer)
weight = params[0]
bias = params[1]
weight_has_function = len(s.weight_function) > 0
bias_has_function = len(s.bias_function) > 0
weight = comfy.model_management.cast_to(s.weight, None, device, non_blocking=non_blocking, copy=weight_has_function, stream=offload_stream)
weight = comfy.model_management.cast_to(s.weight, None, device, non_blocking=non_blocking, copy=weight_has_function, stream=offload_stream, r=weight)
bias = None
if s.bias is not None:
bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream)
bias = comfy.model_management.cast_to(s.bias, None, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream, r=bias)
comfy.model_management.sync_stream(device, offload_stream)
@ -110,6 +234,7 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
weight_a = weight
if s.bias is not None:
bias = bias.to(dtype=bias_dtype)
for f in s.bias_function:
bias = f(bias)
@ -131,14 +256,20 @@ def uncast_bias_weight(s, weight, bias, offload_stream):
if offload_stream is None:
return
os, weight_a, bias_a = offload_stream
device=None
#FIXME: This is not good RTTI
if not isinstance(weight_a, torch.Tensor):
comfy_aimdo.model_vbar.vbar_unpin(s._v)
device = weight_a
if os is None:
return
if weight_a is not None:
device = weight_a.device
else:
if bias_a is None:
return
device = bias_a.device
if device is None:
if weight_a is not None:
device = weight_a.device
else:
if bias_a is None:
return
device = bias_a.device
os.wait_stream(comfy.model_management.current_stream(device))
@ -653,8 +784,8 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
def _forward(self, input, weight, bias):
return torch.nn.functional.linear(input, weight, bias)
def forward_comfy_cast_weights(self, input):
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
def forward_comfy_cast_weights(self, input, compute_dtype=None):
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True, compute_dtype=compute_dtype)
x = self._forward(input, weight, bias)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
@ -664,6 +795,8 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
input_shape = input.shape
reshaped_3d = False
#If cast needs to apply lora, it should be done in the compute dtype
compute_dtype = input.dtype
if (getattr(self, 'layout_type', None) is not None and
not isinstance(input, QuantizedTensor) and not self._full_precision_mm and
@ -682,7 +815,8 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
scale = comfy.model_management.cast_to_device(scale, input.device, None)
input = QuantizedTensor.from_float(input_reshaped, self.layout_type, scale=scale)
output = self.forward_comfy_cast_weights(input)
output = self.forward_comfy_cast_weights(input, compute_dtype)
# Reshape output back to 3D if input was 3D
if reshaped_3d:

30
comfy/pinned_memory.py Normal file
View File

@ -0,0 +1,30 @@
import torch
import comfy.model_management
import comfy.memory_management
from comfy.cli_args import args
def get_pin(module):
return getattr(module, "_pin", None)
def pin_memory(module):
if module.pin_failed or args.disable_pinned_memory or get_pin(module) is not None:
return
#FIXME: This is a RAM cache trigger event
params = [ module.weight, module.bias ]
size = comfy.memory_management.vram_aligned_size(params)
pin = torch.empty((size,), dtype=torch.uint8)
if comfy.model_management.pin_memory(pin):
module._pin = pin
else:
module.pin_failed = True
return False
return True
def unpin_memory(module):
if get_pin(module) is None:
return 0
size = module._pin.numel() * module._pin.element_size()
comfy.model_management.unpin_memory(module._pin)
del module._pin
return size

View File

@ -104,7 +104,7 @@ class TensorCoreNVFP4Layout(_CKNvfp4Layout):
needs_padding = padded_shape != orig_shape
if stochastic_rounding > 0:
qdata, block_scale = comfy.float.stochastic_round_quantize_nvfp4(tensor, scale, pad_16x=needs_padding, seed=stochastic_rounding)
qdata, block_scale = comfy.float.stochastic_round_quantize_nvfp4_by_block(tensor, scale, pad_16x=needs_padding, seed=stochastic_rounding)
else:
qdata, block_scale = ck.quantize_nvfp4(tensor, scale, pad_16x=needs_padding)

View File

@ -9,7 +9,6 @@ if TYPE_CHECKING:
import torch
from functools import partial
import collections
from comfy import model_management
import math
import logging
import comfy.sampler_helpers
@ -260,7 +259,7 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens
to_batch_temp.reverse()
to_batch = to_batch_temp[:1]
free_memory = model_management.get_free_memory(x_in.device)
free_memory = model.current_patcher.get_free_memory(x_in.device)
for i in range(1, len(to_batch_temp) + 1):
batch_amount = to_batch_temp[:len(to_batch_temp)//i]
input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]

View File

@ -128,7 +128,7 @@ class CLIP:
logging.warning("Had to shift TE back.")
self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
self.patcher = comfy.model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
self.patcher = comfy.model_patcher.CoreModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
#Match torch.float32 hardcode upcast in TE implemention
self.patcher.set_model_compute_dtype(torch.float32)
self.patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram
@ -288,7 +288,7 @@ class CLIP:
def load_sd(self, sd, full_model=False):
if full_model:
return self.cond_stage_model.load_state_dict(sd, strict=False)
return self.cond_stage_model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic())
else:
return self.cond_stage_model.load_sd(sd)
@ -665,13 +665,6 @@ class VAE:
self.first_stage_model = AutoencoderKL(**(config['params']))
self.first_stage_model = self.first_stage_model.eval()
m, u = self.first_stage_model.load_state_dict(sd, strict=False)
if len(m) > 0:
logging.warning("Missing VAE keys {}".format(m))
if len(u) > 0:
logging.debug("Leftover VAE keys {}".format(u))
if device is None:
device = model_management.vae_device()
self.device = device
@ -682,7 +675,18 @@ class VAE:
self.first_stage_model.to(self.vae_dtype)
self.output_device = model_management.intermediate_device()
self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
mp = comfy.model_patcher.CoreModelPatcher
if self.disable_offload:
mp = comfy.model_patcher.ModelPatcher
self.patcher = mp(self.first_stage_model, load_device=self.device, offload_device=offload_device)
m, u = self.first_stage_model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic())
if len(m) > 0:
logging.warning("Missing VAE keys {}".format(m))
if len(u) > 0:
logging.debug("Leftover VAE keys {}".format(u))
logging.info("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype))
self.model_size()
@ -797,7 +801,7 @@ class VAE:
try:
memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
free_memory = model_management.get_free_memory(self.device)
free_memory = self.patcher.get_free_memory(self.device)
batch_number = int(free_memory / memory_used)
batch_number = max(1, batch_number)
@ -816,6 +820,7 @@ class VAE:
do_tile = True
if do_tile:
torch.cuda.empty_cache()
dims = samples_in.ndim - 2
if dims == 1 or self.extra_1d_channel is not None:
pixel_samples = self.decode_tiled_1d(samples_in)
@ -871,7 +876,7 @@ class VAE:
try:
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
free_memory = model_management.get_free_memory(self.device)
free_memory = self.patcher.get_free_memory(self.device)
batch_number = int(free_memory / max(1, memory_used))
batch_number = max(1, batch_number)
samples = None
@ -891,6 +896,7 @@ class VAE:
do_tile = True
if do_tile:
torch.cuda.empty_cache()
if self.latent_dim == 3:
tile = 256
overlap = tile // 4
@ -1315,7 +1321,7 @@ def load_gligen(ckpt_path):
model = gligen.load_gligen(data)
if model_management.should_use_fp16():
model = model.half()
return comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device())
return comfy.model_patcher.CoreModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device())
def model_detection_error_hint(path, state_dict):
filename = os.path.basename(path)
@ -1403,7 +1409,8 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
if output_model:
inital_load_device = model_management.unet_inital_load_device(parameters, unet_dtype)
model = model_config.get_model(sd, diffusion_model_prefix, device=inital_load_device)
model.load_model_weights(sd, diffusion_model_prefix)
model_patcher = comfy.model_patcher.CoreModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device())
model.load_model_weights(sd, diffusion_model_prefix, assign=model_patcher.is_dynamic())
if output_vae:
vae_sd = comfy.utils.state_dict_prefix_replace(sd, {k: "" for k in model_config.vae_key_prefix}, filter_keys=True)
@ -1446,7 +1453,6 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
logging.debug("left over keys: {}".format(left_over))
if output_model:
model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device())
if inital_load_device != torch.device("cpu"):
logging.info("loaded diffusion model directly to GPU")
model_management.load_models_gpu([model_patcher], force_full_load=True)
@ -1538,13 +1544,14 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
model_config.optimizations["fp8"] = True
model = model_config.get_model(new_sd, "")
model = model.to(offload_device)
model.load_model_weights(new_sd, "")
model_patcher = comfy.model_patcher.CoreModelPatcher(model, load_device=load_device, offload_device=offload_device)
if not model_management.is_device_cpu(offload_device):
model.to(offload_device)
model.load_model_weights(new_sd, "", assign=model_patcher.is_dynamic())
left_over = sd.keys()
if len(left_over) > 0:
logging.info("left over keys in diffusion model: {}".format(left_over))
return comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=offload_device)
return model_patcher
def load_diffusion_model(unet_path, model_options={}):
sd, metadata = comfy.utils.load_torch_file(unet_path, return_metadata=True)
@ -1575,9 +1582,9 @@ def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, m
if metadata is None:
metadata = {}
model_management.load_models_gpu(load_models, force_patch_weights=True)
model_management.load_models_gpu(load_models)
clip_vision_sd = clip_vision.get_sd() if clip_vision is not None else None
sd = model.model.state_dict_for_saving(clip_sd, vae_sd, clip_vision_sd)
sd = model.state_dict_for_saving(clip_sd, vae_sd, clip_vision_sd)
for k in extra_keys:
sd[k] = extra_keys[k]

View File

@ -1042,7 +1042,7 @@ class ZImage(Lumina2):
"shift": 3.0,
}
memory_usage_factor = 2.0
memory_usage_factor = 2.8
supported_inference_dtypes = [torch.bfloat16, torch.float32]

View File

@ -30,6 +30,7 @@ from torch.nn.functional import interpolate
from einops import rearrange
from comfy.cli_args import args
import json
import time
MMAP_TORCH_FILES = args.mmap_torch_files
DISABLE_MMAP = args.disable_mmap
@ -1097,6 +1098,10 @@ def set_progress_bar_global_hook(function):
global PROGRESS_BAR_HOOK
PROGRESS_BAR_HOOK = function
# Throttle settings for progress bar updates to reduce WebSocket flooding
PROGRESS_THROTTLE_MIN_INTERVAL = 0.1 # 100ms minimum between updates
PROGRESS_THROTTLE_MIN_PERCENT = 0.5 # 0.5% minimum progress change
class ProgressBar:
def __init__(self, total, node_id=None):
global PROGRESS_BAR_HOOK
@ -1104,6 +1109,8 @@ class ProgressBar:
self.current = 0
self.hook = PROGRESS_BAR_HOOK
self.node_id = node_id
self._last_update_time = 0.0
self._last_sent_value = -1
def update_absolute(self, value, total=None, preview=None):
if total is not None:
@ -1112,7 +1119,29 @@ class ProgressBar:
value = self.total
self.current = value
if self.hook is not None:
self.hook(self.current, self.total, preview, node_id=self.node_id)
current_time = time.perf_counter()
is_first = (self._last_sent_value < 0)
is_final = (value >= self.total)
has_preview = (preview is not None)
# Always send immediately for previews, first update, or final update
if has_preview or is_first or is_final:
self.hook(self.current, self.total, preview, node_id=self.node_id)
self._last_update_time = current_time
self._last_sent_value = value
return
# Apply throttling for regular progress updates
if self.total > 0:
percent_changed = ((value - max(0, self._last_sent_value)) / self.total) * 100
else:
percent_changed = 100
time_elapsed = current_time - self._last_update_time
if time_elapsed >= PROGRESS_THROTTLE_MIN_INTERVAL and percent_changed >= PROGRESS_THROTTLE_MIN_PERCENT:
self.hook(self.current, self.total, preview, node_id=self.node_id)
self._last_update_time = current_time
self._last_sent_value = value
def update(self, value):
self.update_absolute(self.current + value)
@ -1267,3 +1296,16 @@ def convert_old_quants(state_dict, model_prefix="", metadata={}):
state_dict["{}.comfy_quant".format(k)] = torch.tensor(list(json.dumps(v).encode('utf-8')), dtype=torch.uint8)
return state_dict, metadata
def string_to_seed(data):
crc = 0xFFFFFFFF
for byte in data:
if isinstance(byte, str):
byte = ord(byte)
crc ^= byte
for _ in range(8):
if crc & 1:
crc = (crc >> 1) ^ 0xEDB88320
else:
crc >>= 1
return crc ^ 0xFFFFFFFF

View File

@ -1225,6 +1225,7 @@ class NodeInfoV1:
deprecated: bool=None
experimental: bool=None
api_node: bool=None
price_badge: dict | None = None
@dataclass
class NodeInfoV3:
@ -1234,11 +1235,77 @@ class NodeInfoV3:
name: str=None
display_name: str=None
description: str=None
python_module: Any = None
category: str=None
output_node: bool=None
deprecated: bool=None
experimental: bool=None
api_node: bool=None
price_badge: dict | None = None
@dataclass
class PriceBadgeDepends:
widgets: list[str] = field(default_factory=list)
inputs: list[str] = field(default_factory=list)
input_groups: list[str] = field(default_factory=list)
def validate(self) -> None:
if not isinstance(self.widgets, list) or any(not isinstance(x, str) for x in self.widgets):
raise ValueError("PriceBadgeDepends.widgets must be a list[str].")
if not isinstance(self.inputs, list) or any(not isinstance(x, str) for x in self.inputs):
raise ValueError("PriceBadgeDepends.inputs must be a list[str].")
if not isinstance(self.input_groups, list) or any(not isinstance(x, str) for x in self.input_groups):
raise ValueError("PriceBadgeDepends.input_groups must be a list[str].")
def as_dict(self, schema_inputs: list["Input"]) -> dict[str, Any]:
# Build lookup: widget_id -> io_type
input_types: dict[str, str] = {}
for inp in schema_inputs:
all_inputs = inp.get_all()
input_types[inp.id] = inp.get_io_type() # First input is always the parent itself
for nested_inp in all_inputs[1:]:
# For DynamicCombo/DynamicSlot, nested inputs are prefixed with parent ID
# to match frontend naming convention (e.g., "should_texture.enable_pbr")
prefixed_id = f"{inp.id}.{nested_inp.id}"
input_types[prefixed_id] = nested_inp.get_io_type()
# Enrich widgets with type information, raising error for unknown widgets
widgets_data: list[dict[str, str]] = []
for w in self.widgets:
if w not in input_types:
raise ValueError(
f"PriceBadge depends_on.widgets references unknown widget '{w}'. "
f"Available widgets: {list(input_types.keys())}"
)
widgets_data.append({"name": w, "type": input_types[w]})
return {
"widgets": widgets_data,
"inputs": self.inputs,
"input_groups": self.input_groups,
}
@dataclass
class PriceBadge:
expr: str
depends_on: PriceBadgeDepends = field(default_factory=PriceBadgeDepends)
engine: str = field(default="jsonata")
def validate(self) -> None:
if self.engine != "jsonata":
raise ValueError(f"Unsupported PriceBadge.engine '{self.engine}'. Only 'jsonata' is supported.")
if not isinstance(self.expr, str) or not self.expr.strip():
raise ValueError("PriceBadge.expr must be a non-empty string.")
self.depends_on.validate()
def as_dict(self, schema_inputs: list["Input"]) -> dict[str, Any]:
return {
"engine": self.engine,
"depends_on": self.depends_on.as_dict(schema_inputs),
"expr": self.expr,
}
@dataclass
@ -1284,6 +1351,8 @@ class Schema:
"""Flags a node as experimental, informing users that it may change or not work as expected."""
is_api_node: bool=False
"""Flags a node as an API node. See: https://docs.comfy.org/tutorials/api-nodes/overview."""
price_badge: PriceBadge | None = None
"""Optional client-evaluated pricing badge declaration for this node."""
not_idempotent: bool=False
"""Flags a node as not idempotent; when True, the node will run and not reuse the cached outputs when identical inputs are provided on a different node in the graph."""
enable_expand: bool=False
@ -1314,6 +1383,8 @@ class Schema:
input.validate()
for output in self.outputs:
output.validate()
if self.price_badge is not None:
self.price_badge.validate()
def finalize(self):
"""Add hidden based on selected schema options, and give outputs without ids default ids."""
@ -1387,7 +1458,8 @@ class Schema:
deprecated=self.is_deprecated,
experimental=self.is_experimental,
api_node=self.is_api_node,
python_module=getattr(cls, "RELATIVE_PYTHON_MODULE", "nodes")
python_module=getattr(cls, "RELATIVE_PYTHON_MODULE", "nodes"),
price_badge=self.price_badge.as_dict(self.inputs) if self.price_badge is not None else None,
)
return info
@ -1419,7 +1491,8 @@ class Schema:
deprecated=self.is_deprecated,
experimental=self.is_experimental,
api_node=self.is_api_node,
python_module=getattr(cls, "RELATIVE_PYTHON_MODULE", "nodes")
python_module=getattr(cls, "RELATIVE_PYTHON_MODULE", "nodes"),
price_badge=self.price_badge.as_dict(self.inputs) if self.price_badge is not None else None,
)
return info
@ -1971,4 +2044,6 @@ __all__ = [
"add_to_dict_v3",
"V3Data",
"ImageCompare",
"PriceBadgeDepends",
"PriceBadge",
]

View File

@ -97,6 +97,9 @@ class FluxProUltraImageNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.06}""",
),
)
@classmethod
@ -352,6 +355,9 @@ class FluxProExpandNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.05}""",
),
)
@classmethod
@ -458,6 +464,9 @@ class FluxProFillNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.05}""",
),
)
@classmethod
@ -511,6 +520,21 @@ class Flux2ProImageNode(IO.ComfyNode):
NODE_ID = "Flux2ProImageNode"
DISPLAY_NAME = "Flux.2 [pro] Image"
API_ENDPOINT = "/proxy/bfl/flux-2-pro/generate"
PRICE_BADGE_EXPR = """
(
$MP := 1024 * 1024;
$outMP := $max([1, $floor(((widgets.width * widgets.height) + $MP - 1) / $MP)]);
$outputCost := 0.03 + 0.015 * ($outMP - 1);
inputs.images.connected
? {
"type":"range_usd",
"min_usd": $outputCost + 0.015,
"max_usd": $outputCost + 0.12,
"format": { "approximate": true }
}
: {"type":"usd","usd": $outputCost}
)
"""
@classmethod
def define_schema(cls) -> IO.Schema:
@ -563,6 +587,10 @@ class Flux2ProImageNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["width", "height"], inputs=["images"]),
expr=cls.PRICE_BADGE_EXPR,
),
)
@classmethod
@ -623,6 +651,22 @@ class Flux2MaxImageNode(Flux2ProImageNode):
NODE_ID = "Flux2MaxImageNode"
DISPLAY_NAME = "Flux.2 [max] Image"
API_ENDPOINT = "/proxy/bfl/flux-2-max/generate"
PRICE_BADGE_EXPR = """
(
$MP := 1024 * 1024;
$outMP := $max([1, $floor(((widgets.width * widgets.height) + $MP - 1) / $MP)]);
$outputCost := 0.07 + 0.03 * ($outMP - 1);
inputs.images.connected
? {
"type":"range_usd",
"min_usd": $outputCost + 0.03,
"max_usd": $outputCost + 0.24,
"format": { "approximate": true }
}
: {"type":"usd","usd": $outputCost}
)
"""
class BFLExtension(ComfyExtension):

View File

@ -126,6 +126,9 @@ class ByteDanceImageNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.03}""",
),
)
@classmethod
@ -367,6 +370,19 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["model"]),
expr="""
(
$price := $contains(widgets.model, "seedream-4-5-251128") ? 0.04 : 0.03;
{
"type":"usd",
"usd": $price,
"format": { "suffix":" x images/Run", "approximate": true }
}
)
""",
),
)
@classmethod
@ -522,6 +538,7 @@ class ByteDanceTextToVideoNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=PRICE_BADGE_VIDEO,
)
@classmethod
@ -632,6 +649,7 @@ class ByteDanceImageToVideoNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=PRICE_BADGE_VIDEO,
)
@classmethod
@ -754,6 +772,7 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=PRICE_BADGE_VIDEO,
)
@classmethod
@ -877,6 +896,7 @@ class ByteDanceImageReferenceNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=PRICE_BADGE_VIDEO,
)
@classmethod
@ -946,6 +966,52 @@ def raise_if_text_params(prompt: str, text_params: list[str]) -> None:
)
PRICE_BADGE_VIDEO = IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["model", "duration", "resolution"]),
expr="""
(
$priceByModel := {
"seedance-1-0-pro": {
"480p":[0.23,0.24],
"720p":[0.51,0.56],
"1080p":[1.18,1.22]
},
"seedance-1-0-pro-fast": {
"480p":[0.09,0.1],
"720p":[0.21,0.23],
"1080p":[0.47,0.49]
},
"seedance-1-0-lite": {
"480p":[0.17,0.18],
"720p":[0.37,0.41],
"1080p":[0.85,0.88]
}
};
$model := widgets.model;
$modelKey :=
$contains($model, "seedance-1-0-pro-fast") ? "seedance-1-0-pro-fast" :
$contains($model, "seedance-1-0-pro") ? "seedance-1-0-pro" :
"seedance-1-0-lite";
$resolution := widgets.resolution;
$resKey :=
$contains($resolution, "1080") ? "1080p" :
$contains($resolution, "720") ? "720p" :
"480p";
$modelPrices := $lookup($priceByModel, $modelKey);
$baseRange := $lookup($modelPrices, $resKey);
$min10s := $baseRange[0];
$max10s := $baseRange[1];
$scale := widgets.duration / 10;
$minCost := $min10s * $scale;
$maxCost := $max10s * $scale;
($minCost = $maxCost)
? {"type":"usd","usd": $minCost}
: {"type":"range_usd","min_usd": $minCost, "max_usd": $maxCost}
)
""",
)
class ByteDanceExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:

View File

@ -319,6 +319,30 @@ class GeminiNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["model"]),
expr="""
(
$m := widgets.model;
$contains($m, "gemini-2.5-flash") ? {
"type": "list_usd",
"usd": [0.0003, 0.0025],
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens"}
}
: $contains($m, "gemini-2.5-pro") ? {
"type": "list_usd",
"usd": [0.00125, 0.01],
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
}
: $contains($m, "gemini-3-pro-preview") ? {
"type": "list_usd",
"usd": [0.002, 0.012],
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
}
: {"type":"text", "text":"Token-based"}
)
""",
),
)
@classmethod
@ -580,6 +604,9 @@ class GeminiImage(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.039,"format":{"suffix":"/Image (1K)","approximate":true}}""",
),
)
@classmethod
@ -710,6 +737,19 @@ class GeminiImage2(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["resolution"]),
expr="""
(
$r := widgets.resolution;
($contains($r,"1k") or $contains($r,"2k"))
? {"type":"usd","usd":0.134,"format":{"suffix":"/Image","approximate":true}}
: $contains($r,"4k")
? {"type":"usd","usd":0.24,"format":{"suffix":"/Image","approximate":true}}
: {"type":"text","text":"Token-based"}
)
""",
),
)
@classmethod

View File

@ -236,7 +236,6 @@ class IdeogramV1(IO.ComfyNode):
display_name="Ideogram V1",
category="api node/image/Ideogram",
description="Generates images using the Ideogram V1 model.",
is_api_node=True,
inputs=[
IO.String.Input(
"prompt",
@ -298,6 +297,17 @@ class IdeogramV1(IO.ComfyNode):
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["num_images", "turbo"]),
expr="""
(
$n := widgets.num_images;
$base := (widgets.turbo = true) ? 0.0286 : 0.0858;
{"type":"usd","usd": $round($base * $n, 2)}
)
""",
),
)
@classmethod
@ -351,7 +361,6 @@ class IdeogramV2(IO.ComfyNode):
display_name="Ideogram V2",
category="api node/image/Ideogram",
description="Generates images using the Ideogram V2 model.",
is_api_node=True,
inputs=[
IO.String.Input(
"prompt",
@ -436,6 +445,17 @@ class IdeogramV2(IO.ComfyNode):
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["num_images", "turbo"]),
expr="""
(
$n := widgets.num_images;
$base := (widgets.turbo = true) ? 0.0715 : 0.1144;
{"type":"usd","usd": $round($base * $n, 2)}
)
""",
),
)
@classmethod
@ -506,7 +526,6 @@ class IdeogramV3(IO.ComfyNode):
category="api node/image/Ideogram",
description="Generates images using the Ideogram V3 model. "
"Supports both regular image generation from text prompts and image editing with mask.",
is_api_node=True,
inputs=[
IO.String.Input(
"prompt",
@ -591,6 +610,23 @@ class IdeogramV3(IO.ComfyNode):
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["rendering_speed", "num_images"], inputs=["character_image"]),
expr="""
(
$n := widgets.num_images;
$speed := widgets.rendering_speed;
$hasChar := inputs.character_image.connected;
$base :=
$contains($speed,"quality") ? ($hasChar ? 0.286 : 0.1287) :
$contains($speed,"default") ? ($hasChar ? 0.2145 : 0.0858) :
$contains($speed,"turbo") ? ($hasChar ? 0.143 : 0.0429) :
0.0858;
{"type":"usd","usd": $round($base * $n, 2)}
)
""",
),
)
@classmethod

View File

@ -764,6 +764,33 @@ class KlingTextToVideoNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["mode"]),
expr="""
(
$m := widgets.mode;
$contains($m,"v2-5-turbo")
? ($contains($m,"10") ? {"type":"usd","usd":0.7} : {"type":"usd","usd":0.35})
: $contains($m,"v2-1-master")
? ($contains($m,"10s") ? {"type":"usd","usd":2.8} : {"type":"usd","usd":1.4})
: $contains($m,"v2-master")
? ($contains($m,"10s") ? {"type":"usd","usd":2.8} : {"type":"usd","usd":1.4})
: $contains($m,"v1-6")
? (
$contains($m,"pro")
? ($contains($m,"10s") ? {"type":"usd","usd":0.98} : {"type":"usd","usd":0.49})
: ($contains($m,"10s") ? {"type":"usd","usd":0.56} : {"type":"usd","usd":0.28})
)
: $contains($m,"v1")
? (
$contains($m,"pro")
? ($contains($m,"10s") ? {"type":"usd","usd":0.98} : {"type":"usd","usd":0.49})
: ($contains($m,"10s") ? {"type":"usd","usd":0.28} : {"type":"usd","usd":0.14})
)
: {"type":"usd","usd":0.14}
)
""",
),
)
@classmethod
@ -818,6 +845,16 @@ class OmniProTextToVideoNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution"]),
expr="""
(
$mode := (widgets.resolution = "720p") ? "std" : "pro";
$rates := {"std": 0.084, "pro": 0.112};
{"type":"usd","usd": $lookup($rates, $mode) * widgets.duration}
)
""",
),
)
@classmethod
@ -886,6 +923,16 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution"]),
expr="""
(
$mode := (widgets.resolution = "720p") ? "std" : "pro";
$rates := {"std": 0.084, "pro": 0.112};
{"type":"usd","usd": $lookup($rates, $mode) * widgets.duration}
)
""",
),
)
@classmethod
@ -981,6 +1028,16 @@ class OmniProImageToVideoNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution"]),
expr="""
(
$mode := (widgets.resolution = "720p") ? "std" : "pro";
$rates := {"std": 0.084, "pro": 0.112};
{"type":"usd","usd": $lookup($rates, $mode) * widgets.duration}
)
""",
),
)
@classmethod
@ -1056,6 +1113,16 @@ class OmniProVideoToVideoNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution"]),
expr="""
(
$mode := (widgets.resolution = "720p") ? "std" : "pro";
$rates := {"std": 0.126, "pro": 0.168};
{"type":"usd","usd": $lookup($rates, $mode) * widgets.duration}
)
""",
),
)
@classmethod
@ -1142,6 +1209,16 @@ class OmniProEditVideoNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["resolution"]),
expr="""
(
$mode := (widgets.resolution = "720p") ? "std" : "pro";
$rates := {"std": 0.126, "pro": 0.168};
{"type":"usd","usd": $lookup($rates, $mode), "format":{"suffix":"/second"}}
)
""",
),
)
@classmethod
@ -1228,6 +1305,9 @@ class OmniProImageNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.028}""",
),
)
@classmethod
@ -1313,6 +1393,9 @@ class KlingCameraControlT2VNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.14}""",
),
)
@classmethod
@ -1375,6 +1458,33 @@ class KlingImage2VideoNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["mode", "model_name", "duration"]),
expr="""
(
$mode := widgets.mode;
$model := widgets.model_name;
$dur := widgets.duration;
$contains($model,"v2-5-turbo")
? ($contains($dur,"10") ? {"type":"usd","usd":0.7} : {"type":"usd","usd":0.35})
: ($contains($model,"v2-1-master") or $contains($model,"v2-master"))
? ($contains($dur,"10") ? {"type":"usd","usd":2.8} : {"type":"usd","usd":1.4})
: ($contains($model,"v2-1") or $contains($model,"v1-6") or $contains($model,"v1-5"))
? (
$contains($mode,"pro")
? ($contains($dur,"10") ? {"type":"usd","usd":0.98} : {"type":"usd","usd":0.49})
: ($contains($dur,"10") ? {"type":"usd","usd":0.56} : {"type":"usd","usd":0.28})
)
: $contains($model,"v1")
? (
$contains($mode,"pro")
? ($contains($dur,"10") ? {"type":"usd","usd":0.98} : {"type":"usd","usd":0.49})
: ($contains($dur,"10") ? {"type":"usd","usd":0.28} : {"type":"usd","usd":0.14})
)
: {"type":"usd","usd":0.14}
)
""",
),
)
@classmethod
@ -1448,6 +1558,9 @@ class KlingCameraControlI2VNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.49}""",
),
)
@classmethod
@ -1518,6 +1631,33 @@ class KlingStartEndFrameNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["mode"]),
expr="""
(
$m := widgets.mode;
$contains($m,"v2-5-turbo")
? ($contains($m,"10") ? {"type":"usd","usd":0.7} : {"type":"usd","usd":0.35})
: $contains($m,"v2-1")
? ($contains($m,"10s") ? {"type":"usd","usd":0.98} : {"type":"usd","usd":0.49})
: $contains($m,"v2-master")
? ($contains($m,"10s") ? {"type":"usd","usd":2.8} : {"type":"usd","usd":1.4})
: $contains($m,"v1-6")
? (
$contains($m,"pro")
? ($contains($m,"10s") ? {"type":"usd","usd":0.98} : {"type":"usd","usd":0.49})
: ($contains($m,"10s") ? {"type":"usd","usd":0.56} : {"type":"usd","usd":0.28})
)
: $contains($m,"v1")
? (
$contains($m,"pro")
? ($contains($m,"10s") ? {"type":"usd","usd":0.98} : {"type":"usd","usd":0.49})
: ($contains($m,"10s") ? {"type":"usd","usd":0.28} : {"type":"usd","usd":0.14})
)
: {"type":"usd","usd":0.14}
)
""",
),
)
@classmethod
@ -1583,6 +1723,9 @@ class KlingVideoExtendNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.28}""",
),
)
@classmethod
@ -1664,6 +1807,29 @@ class KlingDualCharacterVideoEffectNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["mode", "model_name", "duration"]),
expr="""
(
$mode := widgets.mode;
$model := widgets.model_name;
$dur := widgets.duration;
($contains($model,"v1-6") or $contains($model,"v1-5"))
? (
$contains($mode,"pro")
? ($contains($dur,"10") ? {"type":"usd","usd":0.98} : {"type":"usd","usd":0.49})
: ($contains($dur,"10") ? {"type":"usd","usd":0.56} : {"type":"usd","usd":0.28})
)
: $contains($model,"v1")
? (
$contains($mode,"pro")
? ($contains($dur,"10") ? {"type":"usd","usd":0.98} : {"type":"usd","usd":0.49})
: ($contains($dur,"10") ? {"type":"usd","usd":0.28} : {"type":"usd","usd":0.14})
)
: {"type":"usd","usd":0.14}
)
""",
),
)
@classmethod
@ -1728,6 +1894,16 @@ class KlingSingleImageVideoEffectNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["effect_scene"]),
expr="""
(
($contains(widgets.effect_scene,"dizzydizzy") or $contains(widgets.effect_scene,"bloombloom"))
? {"type":"usd","usd":0.49}
: {"type":"usd","usd":0.28}
)
""",
),
)
@classmethod
@ -1782,6 +1958,9 @@ class KlingLipSyncAudioToVideoNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.1,"format":{"approximate":true}}""",
),
)
@classmethod
@ -1842,6 +2021,9 @@ class KlingLipSyncTextToVideoNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.1,"format":{"approximate":true}}""",
),
)
@classmethod
@ -1892,6 +2074,9 @@ class KlingVirtualTryOnNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.7}""",
),
)
@classmethod
@ -1991,6 +2176,19 @@ class KlingImageGenerationNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["model_name", "n"], inputs=["image"]),
expr="""
(
$m := widgets.model_name;
$base :=
$contains($m,"kling-v1-5")
? (inputs.image.connected ? 0.028 : 0.014)
: ($contains($m,"kling-v1") ? 0.0035 : 0.014);
{"type":"usd","usd": $base * widgets.n}
)
""",
),
)
@classmethod
@ -2074,6 +2272,10 @@ class TextToVideoWithAudio(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["duration", "generate_audio"]),
expr="""{"type":"usd","usd": 0.07 * widgets.duration * (widgets.generate_audio ? 2 : 1)}""",
),
)
@classmethod
@ -2138,6 +2340,10 @@ class ImageToVideoWithAudio(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["duration", "generate_audio"]),
expr="""{"type":"usd","usd": 0.07 * widgets.duration * (widgets.generate_audio ? 2 : 1)}""",
),
)
@classmethod
@ -2218,6 +2424,15 @@ class MotionControl(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["mode"]),
expr="""
(
$prices := {"std": 0.07, "pro": 0.112};
{"type":"usd","usd": $lookup($prices, widgets.mode), "format":{"suffix":"/second"}}
)
""",
),
)
@classmethod

View File

@ -28,6 +28,22 @@ class ExecuteTaskRequest(BaseModel):
image_uri: str | None = Field(None)
PRICE_BADGE = IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["model", "duration", "resolution"]),
expr="""
(
$prices := {
"ltx-2 (pro)": {"1920x1080":0.06,"2560x1440":0.12,"3840x2160":0.24},
"ltx-2 (fast)": {"1920x1080":0.04,"2560x1440":0.08,"3840x2160":0.16}
};
$modelPrices := $lookup($prices, $lowercase(widgets.model));
$pps := $lookup($modelPrices, widgets.resolution);
{"type":"usd","usd": $pps * widgets.duration}
)
""",
)
class TextToVideoNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
@ -69,6 +85,7 @@ class TextToVideoNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=PRICE_BADGE,
)
@classmethod
@ -145,6 +162,7 @@ class ImageToVideoNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=PRICE_BADGE,
)
@classmethod

View File

@ -189,6 +189,19 @@ class LumaImageGenerationNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["model"]),
expr="""
(
$m := widgets.model;
$contains($m,"photon-flash-1")
? {"type":"usd","usd":0.0027}
: $contains($m,"photon-1")
? {"type":"usd","usd":0.0104}
: {"type":"usd","usd":0.0246}
)
""",
),
)
@classmethod
@ -303,6 +316,19 @@ class LumaImageModifyNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["model"]),
expr="""
(
$m := widgets.model;
$contains($m,"photon-flash-1")
? {"type":"usd","usd":0.0027}
: $contains($m,"photon-1")
? {"type":"usd","usd":0.0104}
: {"type":"usd","usd":0.0246}
)
""",
),
)
@classmethod
@ -395,6 +421,7 @@ class LumaTextToVideoGenerationNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=PRICE_BADGE_VIDEO,
)
@classmethod
@ -505,6 +532,8 @@ class LumaImageToVideoGenerationNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=PRICE_BADGE_VIDEO,
)
@classmethod
@ -568,6 +597,53 @@ class LumaImageToVideoGenerationNode(IO.ComfyNode):
return LumaKeyframes(frame0=frame0, frame1=frame1)
PRICE_BADGE_VIDEO = IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["model", "resolution", "duration"]),
expr="""
(
$p := {
"ray-flash-2": {
"5s": {"4k":3.13,"1080p":0.79,"720p":0.34,"540p":0.2},
"9s": {"4k":5.65,"1080p":1.42,"720p":0.61,"540p":0.36}
},
"ray-2": {
"5s": {"4k":9.11,"1080p":2.27,"720p":1.02,"540p":0.57},
"9s": {"4k":16.4,"1080p":4.1,"720p":1.83,"540p":1.03}
}
};
$m := widgets.model;
$d := widgets.duration;
$r := widgets.resolution;
$modelKey :=
$contains($m,"ray-flash-2") ? "ray-flash-2" :
$contains($m,"ray-2") ? "ray-2" :
$contains($m,"ray-1-6") ? "ray-1-6" :
"other";
$durKey := $contains($d,"5s") ? "5s" : $contains($d,"9s") ? "9s" : "";
$resKey :=
$contains($r,"4k") ? "4k" :
$contains($r,"1080p") ? "1080p" :
$contains($r,"720p") ? "720p" :
$contains($r,"540p") ? "540p" : "";
$modelPrices := $lookup($p, $modelKey);
$durPrices := $lookup($modelPrices, $durKey);
$v := $lookup($durPrices, $resKey);
$price :=
($modelKey = "ray-1-6") ? 0.5 :
($modelKey = "other") ? 0.79 :
($exists($v) ? $v : 0.79);
{"type":"usd","usd": $price}
)
""",
)
class LumaExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:

View File

@ -134,6 +134,9 @@ class MinimaxTextToVideoNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.43}""",
),
)
@classmethod
@ -197,6 +200,9 @@ class MinimaxImageToVideoNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.43}""",
),
)
@classmethod
@ -340,6 +346,20 @@ class MinimaxHailuoVideoNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["resolution", "duration"]),
expr="""
(
$prices := {
"768p": {"6": 0.28, "10": 0.56},
"1080p": {"6": 0.49}
};
$resPrices := $lookup($prices, $lowercase(widgets.resolution));
$price := $lookup($resPrices, $string(widgets.duration));
{"type":"usd","usd": $price ? $price : 0.43}
)
""",
),
)
@classmethod

View File

@ -233,6 +233,10 @@ class MoonvalleyImg2VideoNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(),
expr="""{"type":"usd","usd": 1.5}""",
),
)
@classmethod
@ -351,6 +355,10 @@ class MoonvalleyVideo2VideoNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(),
expr="""{"type":"usd","usd": 2.25}""",
),
)
@classmethod
@ -471,6 +479,10 @@ class MoonvalleyTxt2VideoNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(),
expr="""{"type":"usd","usd": 1.5}""",
),
)
@classmethod

View File

@ -160,6 +160,23 @@ class OpenAIDalle2(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["size", "n"]),
expr="""
(
$size := widgets.size;
$nRaw := widgets.n;
$n := ($nRaw != null and $nRaw != 0) ? $nRaw : 1;
$base :=
$contains($size, "256x256") ? 0.016 :
$contains($size, "512x512") ? 0.018 :
0.02;
{"type":"usd","usd": $round($base * $n, 3)}
)
""",
),
)
@classmethod
@ -287,6 +304,25 @@ class OpenAIDalle3(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["size", "quality"]),
expr="""
(
$size := widgets.size;
$q := widgets.quality;
$hd := $contains($q, "hd");
$price :=
$contains($size, "1024x1024")
? ($hd ? 0.08 : 0.04)
: (($contains($size, "1792x1024") or $contains($size, "1024x1792"))
? ($hd ? 0.12 : 0.08)
: 0.04);
{"type":"usd","usd": $price}
)
""",
),
)
@classmethod
@ -411,6 +447,28 @@ class OpenAIGPTImage1(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["quality", "n"]),
expr="""
(
$ranges := {
"low": [0.011, 0.02],
"medium": [0.046, 0.07],
"high": [0.167, 0.3]
};
$range := $lookup($ranges, widgets.quality);
$n := widgets.n;
($n = 1)
? {"type":"range_usd","min_usd": $range[0], "max_usd": $range[1]}
: {
"type":"range_usd",
"min_usd": $range[0],
"max_usd": $range[1],
"format": { "suffix": " x " & $string($n) & "/Run" }
}
)
""",
),
)
@classmethod
@ -566,6 +624,75 @@ class OpenAIChatNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["model"]),
expr="""
(
$m := widgets.model;
$contains($m, "o4-mini") ? {
"type": "list_usd",
"usd": [0.0011, 0.0044],
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
}
: $contains($m, "o1-pro") ? {
"type": "list_usd",
"usd": [0.15, 0.6],
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
}
: $contains($m, "o1") ? {
"type": "list_usd",
"usd": [0.015, 0.06],
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
}
: $contains($m, "o3-mini") ? {
"type": "list_usd",
"usd": [0.0011, 0.0044],
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
}
: $contains($m, "o3") ? {
"type": "list_usd",
"usd": [0.01, 0.04],
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
}
: $contains($m, "gpt-4o") ? {
"type": "list_usd",
"usd": [0.0025, 0.01],
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
}
: $contains($m, "gpt-4.1-nano") ? {
"type": "list_usd",
"usd": [0.0001, 0.0004],
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
}
: $contains($m, "gpt-4.1-mini") ? {
"type": "list_usd",
"usd": [0.0004, 0.0016],
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
}
: $contains($m, "gpt-4.1") ? {
"type": "list_usd",
"usd": [0.002, 0.008],
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
}
: $contains($m, "gpt-5-nano") ? {
"type": "list_usd",
"usd": [0.00005, 0.0004],
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
}
: $contains($m, "gpt-5-mini") ? {
"type": "list_usd",
"usd": [0.00025, 0.002],
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
}
: $contains($m, "gpt-5") ? {
"type": "list_usd",
"usd": [0.00125, 0.01],
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
}
: {"type": "text", "text": "Token-based"}
)
""",
),
)
@classmethod

View File

@ -128,6 +128,7 @@ class PixverseTextToVideoNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=PRICE_BADGE_VIDEO,
)
@classmethod
@ -242,6 +243,7 @@ class PixverseImageToVideoNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=PRICE_BADGE_VIDEO,
)
@classmethod
@ -355,6 +357,7 @@ class PixverseTransitionVideoNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=PRICE_BADGE_VIDEO,
)
@classmethod
@ -416,6 +419,33 @@ class PixverseTransitionVideoNode(IO.ComfyNode):
return IO.NodeOutput(await download_url_to_video_output(response_poll.Resp.url))
PRICE_BADGE_VIDEO = IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["duration_seconds", "quality", "motion_mode"]),
expr="""
(
$prices := {
"5": {
"1080p": {"normal": 1.2, "fast": 1.2},
"720p": {"normal": 0.6, "fast": 1.2},
"540p": {"normal": 0.45, "fast": 0.9},
"360p": {"normal": 0.45, "fast": 0.9}
},
"8": {
"1080p": {"normal": 1.2, "fast": 1.2},
"720p": {"normal": 1.2, "fast": 1.2},
"540p": {"normal": 0.9, "fast": 1.2},
"360p": {"normal": 0.9, "fast": 1.2}
}
};
$durPrices := $lookup($prices, $string(widgets.duration_seconds));
$qualityPrices := $lookup($durPrices, widgets.quality);
$price := $lookup($qualityPrices, widgets.motion_mode);
{"type":"usd","usd": $price ? $price : 0.9}
)
""",
)
class PixVerseExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:

View File

@ -378,6 +378,10 @@ class RecraftTextToImageNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["n"]),
expr="""{"type":"usd","usd": $round(0.04 * widgets.n, 2)}""",
),
)
@classmethod
@ -490,6 +494,10 @@ class RecraftImageToImageNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["n"]),
expr="""{"type":"usd","usd": $round(0.04 * widgets.n, 2)}""",
),
)
@classmethod
@ -591,6 +599,10 @@ class RecraftImageInpaintingNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["n"]),
expr="""{"type":"usd","usd": $round(0.04 * widgets.n, 2)}""",
),
)
@classmethod
@ -692,6 +704,10 @@ class RecraftTextToVectorNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["n"]),
expr="""{"type":"usd","usd": $round(0.08 * widgets.n, 2)}""",
),
)
@classmethod
@ -759,6 +775,10 @@ class RecraftVectorizeImageNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(),
expr="""{"type":"usd","usd": 0.01}""",
),
)
@classmethod
@ -817,6 +837,9 @@ class RecraftReplaceBackgroundNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.04}""",
),
)
@classmethod
@ -883,6 +906,9 @@ class RecraftRemoveBackgroundNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.01}""",
),
)
@classmethod
@ -929,6 +955,9 @@ class RecraftCrispUpscaleNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.004}""",
),
)
@classmethod
@ -972,6 +1001,9 @@ class RecraftCreativeUpscaleNode(RecraftCrispUpscaleNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.25}""",
),
)

View File

@ -241,6 +241,9 @@ class Rodin3D_Regular(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.4}""",
),
)
@classmethod
@ -294,6 +297,9 @@ class Rodin3D_Detail(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.4}""",
),
)
@classmethod
@ -347,6 +353,9 @@ class Rodin3D_Smooth(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.4}""",
),
)
@classmethod
@ -406,6 +415,9 @@ class Rodin3D_Sketch(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.4}""",
),
)
@classmethod

View File

@ -184,6 +184,10 @@ class RunwayImageToVideoNodeGen3a(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["duration"]),
expr="""{"type":"usd","usd": 0.0715 * widgets.duration}""",
),
)
@classmethod
@ -274,6 +278,10 @@ class RunwayImageToVideoNodeGen4(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["duration"]),
expr="""{"type":"usd","usd": 0.0715 * widgets.duration}""",
),
)
@classmethod
@ -372,6 +380,10 @@ class RunwayFirstLastFrameNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["duration"]),
expr="""{"type":"usd","usd": 0.0715 * widgets.duration}""",
),
)
@classmethod
@ -457,6 +469,9 @@ class RunwayTextToImageNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.11}""",
),
)
@classmethod

View File

@ -89,6 +89,24 @@ class OpenAIVideoSora2(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["model", "size", "duration"]),
expr="""
(
$m := widgets.model;
$size := widgets.size;
$dur := widgets.duration;
$isPro := $contains($m, "sora-2-pro");
$isSora2 := $contains($m, "sora-2");
$isProSize := ($size = "1024x1792" or $size = "1792x1024");
$perSec :=
$isPro ? ($isProSize ? 0.5 : 0.3) :
$isSora2 ? 0.1 :
($isProSize ? 0.5 : 0.1);
{"type":"usd","usd": $round($perSec * $dur, 2)}
)
""",
),
)
@classmethod

View File

@ -127,6 +127,9 @@ class StabilityStableImageUltraNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.08}""",
),
)
@classmethod
@ -264,6 +267,16 @@ class StabilityStableImageSD_3_5Node(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["model"]),
expr="""
(
$contains(widgets.model,"large")
? {"type":"usd","usd":0.065}
: {"type":"usd","usd":0.035}
)
""",
),
)
@classmethod
@ -382,6 +395,9 @@ class StabilityUpscaleConservativeNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.25}""",
),
)
@classmethod
@ -486,6 +502,9 @@ class StabilityUpscaleCreativeNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.25}""",
),
)
@classmethod
@ -566,6 +585,9 @@ class StabilityUpscaleFastNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.01}""",
),
)
@classmethod
@ -648,6 +670,9 @@ class StabilityTextToAudio(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.2}""",
),
)
@classmethod
@ -732,6 +757,9 @@ class StabilityAudioToAudio(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.2}""",
),
)
@classmethod
@ -828,6 +856,9 @@ class StabilityAudioInpaint(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.2}""",
),
)
@classmethod

View File

@ -117,6 +117,38 @@ class TripoTextToModelNode(IO.ComfyNode):
],
is_api_node=True,
is_output_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(
widgets=[
"model_version",
"style",
"texture",
"pbr",
"quad",
"texture_quality",
"geometry_quality",
],
),
expr="""
(
$isV14 := $contains(widgets.model_version,"v1.4");
$style := widgets.style;
$hasStyle := ($style != "" and $style != "none");
$withTexture := widgets.texture or widgets.pbr;
$isHdTexture := (widgets.texture_quality = "detailed");
$isDetailedGeometry := (widgets.geometry_quality = "detailed");
$baseCredits :=
$isV14 ? 20 : ($withTexture ? 20 : 10);
$credits :=
$baseCredits
+ ($hasStyle ? 5 : 0)
+ (widgets.quad ? 5 : 0)
+ ($isHdTexture ? 10 : 0)
+ ($isDetailedGeometry ? 20 : 0);
{"type":"usd","usd": $round($credits * 0.01, 2)}
)
""",
),
)
@classmethod
@ -210,6 +242,38 @@ class TripoImageToModelNode(IO.ComfyNode):
],
is_api_node=True,
is_output_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(
widgets=[
"model_version",
"style",
"texture",
"pbr",
"quad",
"texture_quality",
"geometry_quality",
],
),
expr="""
(
$isV14 := $contains(widgets.model_version,"v1.4");
$style := widgets.style;
$hasStyle := ($style != "" and $style != "none");
$withTexture := widgets.texture or widgets.pbr;
$isHdTexture := (widgets.texture_quality = "detailed");
$isDetailedGeometry := (widgets.geometry_quality = "detailed");
$baseCredits :=
$isV14 ? 30 : ($withTexture ? 30 : 20);
$credits :=
$baseCredits
+ ($hasStyle ? 5 : 0)
+ (widgets.quad ? 5 : 0)
+ ($isHdTexture ? 10 : 0)
+ ($isDetailedGeometry ? 20 : 0);
{"type":"usd","usd": $round($credits * 0.01, 2)}
)
""",
),
)
@classmethod
@ -314,6 +378,34 @@ class TripoMultiviewToModelNode(IO.ComfyNode):
],
is_api_node=True,
is_output_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(
widgets=[
"model_version",
"texture",
"pbr",
"quad",
"texture_quality",
"geometry_quality",
],
),
expr="""
(
$isV14 := $contains(widgets.model_version,"v1.4");
$withTexture := widgets.texture or widgets.pbr;
$isHdTexture := (widgets.texture_quality = "detailed");
$isDetailedGeometry := (widgets.geometry_quality = "detailed");
$baseCredits :=
$isV14 ? 30 : ($withTexture ? 30 : 20);
$credits :=
$baseCredits
+ (widgets.quad ? 5 : 0)
+ ($isHdTexture ? 10 : 0)
+ ($isDetailedGeometry ? 20 : 0);
{"type":"usd","usd": $round($credits * 0.01, 2)}
)
""",
),
)
@classmethod
@ -405,6 +497,15 @@ class TripoTextureNode(IO.ComfyNode):
],
is_api_node=True,
is_output_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["texture_quality"]),
expr="""
(
$tq := widgets.texture_quality;
{"type":"usd","usd": ($contains($tq,"detailed") ? 0.2 : 0.1)}
)
""",
),
)
@classmethod
@ -456,6 +557,9 @@ class TripoRefineNode(IO.ComfyNode):
],
is_api_node=True,
is_output_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.3}""",
),
)
@classmethod
@ -489,6 +593,9 @@ class TripoRigNode(IO.ComfyNode):
],
is_api_node=True,
is_output_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.25}""",
),
)
@classmethod
@ -545,6 +652,9 @@ class TripoRetargetNode(IO.ComfyNode):
],
is_api_node=True,
is_output_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.1}""",
),
)
@classmethod
@ -638,6 +748,60 @@ class TripoConversionNode(IO.ComfyNode):
],
is_api_node=True,
is_output_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(
widgets=[
"quad",
"face_limit",
"texture_size",
"texture_format",
"force_symmetry",
"flatten_bottom",
"flatten_bottom_threshold",
"pivot_to_center_bottom",
"scale_factor",
"with_animation",
"pack_uv",
"bake",
"part_names",
"fbx_preset",
"export_vertex_colors",
"export_orientation",
"animate_in_place",
],
),
expr="""
(
$face := (widgets.face_limit != null) ? widgets.face_limit : -1;
$texSize := (widgets.texture_size != null) ? widgets.texture_size : 4096;
$flatThresh := (widgets.flatten_bottom_threshold != null) ? widgets.flatten_bottom_threshold : 0;
$scale := (widgets.scale_factor != null) ? widgets.scale_factor : 1;
$texFmt := (widgets.texture_format != "" ? widgets.texture_format : "jpeg");
$part := widgets.part_names;
$fbx := (widgets.fbx_preset != "" ? widgets.fbx_preset : "blender");
$orient := (widgets.export_orientation != "" ? widgets.export_orientation : "default");
$advanced :=
widgets.quad or
widgets.force_symmetry or
widgets.flatten_bottom or
widgets.pivot_to_center_bottom or
widgets.with_animation or
widgets.pack_uv or
widgets.bake or
widgets.export_vertex_colors or
widgets.animate_in_place or
($face != -1) or
($texSize != 4096) or
($flatThresh != 0) or
($scale != 1) or
($texFmt != "jpeg") or
($part != "") or
($fbx != "blender") or
($orient != "default");
{"type":"usd","usd": ($advanced ? 0.1 : 0.05)}
)
""",
),
)
@classmethod

View File

@ -122,6 +122,10 @@ class VeoVideoGenerationNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["duration_seconds"]),
expr="""{"type":"usd","usd": 0.5 * widgets.duration_seconds}""",
),
)
@classmethod
@ -347,6 +351,20 @@ class Veo3VideoGenerationNode(VeoVideoGenerationNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["model", "generate_audio"]),
expr="""
(
$m := widgets.model;
$a := widgets.generate_audio;
($contains($m,"veo-3.0-fast-generate-001") or $contains($m,"veo-3.1-fast-generate"))
? {"type":"usd","usd": ($a ? 1.2 : 0.8)}
: ($contains($m,"veo-3.0-generate-001") or $contains($m,"veo-3.1-generate"))
? {"type":"usd","usd": ($a ? 3.2 : 1.6)}
: {"type":"range_usd","min_usd":0.8,"max_usd":3.2}
)
""",
),
)
@ -420,6 +438,30 @@ class Veo3FirstLastFrameNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["model", "generate_audio", "duration"]),
expr="""
(
$prices := {
"veo-3.1-fast-generate": { "audio": 0.15, "no_audio": 0.10 },
"veo-3.1-generate": { "audio": 0.40, "no_audio": 0.20 }
};
$m := widgets.model;
$ga := (widgets.generate_audio = "true");
$seconds := widgets.duration;
$modelKey :=
$contains($m, "veo-3.1-fast-generate") ? "veo-3.1-fast-generate" :
$contains($m, "veo-3.1-generate") ? "veo-3.1-generate" :
"";
$audioKey := $ga ? "audio" : "no_audio";
$modelPrices := $lookup($prices, $modelKey);
$pps := $lookup($modelPrices, $audioKey);
($pps != null)
? {"type":"usd","usd": $pps * $seconds}
: {"type":"range_usd","min_usd": 0.4, "max_usd": 3.2}
)
""",
),
)
@classmethod

View File

@ -121,6 +121,9 @@ class ViduTextToVideoNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.4}""",
),
)
@classmethod
@ -214,6 +217,9 @@ class ViduImageToVideoNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.4}""",
),
)
@classmethod
@ -317,6 +323,9 @@ class ViduReferenceVideoNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.4}""",
),
)
@classmethod
@ -426,6 +435,9 @@ class ViduStartEndToVideoNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.4}""",
),
)
@classmethod
@ -507,6 +519,17 @@ class Vidu2TextToVideoNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution"]),
expr="""
(
$is1080 := widgets.resolution = "1080p";
$base := $is1080 ? 0.1 : 0.075;
$perSec := $is1080 ? 0.05 : 0.025;
{"type":"usd","usd": $base + $perSec * (widgets.duration - 1)}
)
""",
),
)
@classmethod
@ -594,6 +617,39 @@ class Vidu2ImageToVideoNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["model", "duration", "resolution"]),
expr="""
(
$m := widgets.model;
$d := widgets.duration;
$is1080 := widgets.resolution = "1080p";
$contains($m, "pro-fast")
? (
$base := $is1080 ? 0.08 : 0.04;
$perSec := $is1080 ? 0.02 : 0.01;
{"type":"usd","usd": $base + $perSec * ($d - 1)}
)
: $contains($m, "pro")
? (
$base := $is1080 ? 0.275 : 0.075;
$perSec := $is1080 ? 0.075 : 0.05;
{"type":"usd","usd": $base + $perSec * ($d - 1)}
)
: $contains($m, "turbo")
? (
$is1080
? {"type":"usd","usd": 0.175 + 0.05 * ($d - 1)}
: (
$d <= 1 ? {"type":"usd","usd": 0.04}
: $d <= 2 ? {"type":"usd","usd": 0.05}
: {"type":"usd","usd": 0.05 + 0.05 * ($d - 2)}
)
)
: {"type":"usd","usd": 0.04}
)
""",
),
)
@classmethod
@ -698,6 +754,18 @@ class Vidu2ReferenceVideoNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["audio", "duration", "resolution"]),
expr="""
(
$is1080 := widgets.resolution = "1080p";
$base := $is1080 ? 0.375 : 0.125;
$perSec := $is1080 ? 0.05 : 0.025;
$audioCost := widgets.audio = true ? 0.075 : 0;
{"type":"usd","usd": $base + $perSec * (widgets.duration - 1) + $audioCost}
)
""",
),
)
@classmethod
@ -804,6 +872,38 @@ class Vidu2StartEndToVideoNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["model", "duration", "resolution"]),
expr="""
(
$m := widgets.model;
$d := widgets.duration;
$is1080 := widgets.resolution = "1080p";
$contains($m, "pro-fast")
? (
$base := $is1080 ? 0.08 : 0.04;
$perSec := $is1080 ? 0.02 : 0.01;
{"type":"usd","usd": $base + $perSec * ($d - 1)}
)
: $contains($m, "pro")
? (
$base := $is1080 ? 0.275 : 0.075;
$perSec := $is1080 ? 0.075 : 0.05;
{"type":"usd","usd": $base + $perSec * ($d - 1)}
)
: $contains($m, "turbo")
? (
$is1080
? {"type":"usd","usd": 0.175 + 0.05 * ($d - 1)}
: (
$d <= 2 ? {"type":"usd","usd": 0.05}
: {"type":"usd","usd": 0.05 + 0.05 * ($d - 2)}
)
)
: {"type":"usd","usd": 0.04}
)
""",
),
)
@classmethod

View File

@ -244,6 +244,9 @@ class WanTextToImageApi(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.03}""",
),
)
@classmethod
@ -363,6 +366,9 @@ class WanImageToImageApi(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.03}""",
),
)
@classmethod
@ -520,6 +526,17 @@ class WanTextToVideoApi(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["duration", "size"]),
expr="""
(
$ppsTable := { "480p": 0.05, "720p": 0.1, "1080p": 0.15 };
$resKey := $substringBefore(widgets.size, ":");
$pps := $lookup($ppsTable, $resKey);
{ "type": "usd", "usd": $round($pps * widgets.duration, 2) }
)
""",
),
)
@classmethod
@ -681,6 +698,16 @@ class WanImageToVideoApi(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution"]),
expr="""
(
$ppsTable := { "480p": 0.05, "720p": 0.1, "1080p": 0.15 };
$pps := $lookup($ppsTable, widgets.resolution);
{ "type": "usd", "usd": $round($pps * widgets.duration, 2) }
)
""",
),
)
@classmethod
@ -828,6 +855,22 @@ class WanReferenceVideoApi(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["size", "duration"]),
expr="""
(
$rate := $contains(widgets.size, "1080p") ? 0.15 : 0.10;
$inputMin := 2 * $rate;
$inputMax := 5 * $rate;
$outputPrice := widgets.duration * $rate;
{
"type": "range_usd",
"min_usd": $inputMin + $outputPrice,
"max_usd": $inputMax + $outputPrice
}
)
""",
),
)
@classmethod

View File

@ -258,9 +258,9 @@ class ModelPatchLoader:
config['broken'] = True
model = comfy.ldm.lumina.controlnet.ZImage_Control(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast, **config)
model.load_state_dict(sd)
model = comfy.model_patcher.ModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
return (model,)
model_patcher = comfy.model_patcher.CoreModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
model.load_state_dict(sd, assign=self.model_patcher.is_dynamic())
return (model_patcher,)
class DiffSynthCnetPatch:

View File

@ -254,6 +254,7 @@ class ResizeType(str, Enum):
SCALE_HEIGHT = "scale height"
SCALE_TOTAL_PIXELS = "scale total pixels"
MATCH_SIZE = "match size"
SCALE_TO_MULTIPLE = "scale to multiple"
def is_image(input: torch.Tensor) -> bool:
# images have 4 dimensions: [batch, height, width, channels]
@ -328,7 +329,7 @@ def scale_shorter_dimension(input: torch.Tensor, shorter_size: int, scale_method
if height < width:
width = round((width / height) * shorter_size)
height = shorter_size
elif width > height:
elif width < height:
height = round((height / width) * shorter_size)
width = shorter_size
else:
@ -363,6 +364,43 @@ def scale_match_size(input: torch.Tensor, match: torch.Tensor, scale_method: str
input = finalize_image_mask_input(input, is_type_image)
return input
def scale_to_multiple_cover(input: torch.Tensor, multiple: int, scale_method: str) -> torch.Tensor:
if multiple <= 1:
return input
is_type_image = is_image(input)
if is_type_image:
_, height, width, _ = input.shape
else:
_, height, width = input.shape
target_w = (width // multiple) * multiple
target_h = (height // multiple) * multiple
if target_w == 0 or target_h == 0:
return input
if target_w == width and target_h == height:
return input
s_w = target_w / width
s_h = target_h / height
if s_w >= s_h:
scaled_w = target_w
scaled_h = int(math.ceil(height * s_w))
if scaled_h < target_h:
scaled_h = target_h
else:
scaled_h = target_h
scaled_w = int(math.ceil(width * s_h))
if scaled_w < target_w:
scaled_w = target_w
input = init_image_mask_input(input, is_type_image)
input = comfy.utils.common_upscale(input, scaled_w, scaled_h, scale_method, "disabled")
input = finalize_image_mask_input(input, is_type_image)
x0 = (scaled_w - target_w) // 2
y0 = (scaled_h - target_h) // 2
x1 = x0 + target_w
y1 = y0 + target_h
if is_type_image:
return input[:, y0:y1, x0:x1, :]
return input[:, y0:y1, x0:x1]
class ResizeImageMaskNode(io.ComfyNode):
scale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"]
@ -378,6 +416,7 @@ class ResizeImageMaskNode(io.ComfyNode):
longer_size: int
shorter_size: int
megapixels: float
multiple: int
@classmethod
def define_schema(cls):
@ -417,6 +456,9 @@ class ResizeImageMaskNode(io.ComfyNode):
io.MultiType.Input("match", [io.Image, io.Mask]),
crop_combo,
]),
io.DynamicCombo.Option(ResizeType.SCALE_TO_MULTIPLE, [
io.Int.Input("multiple", default=8, min=1, max=MAX_RESOLUTION, step=1),
]),
]),
io.Combo.Input("scale_method", options=cls.scale_methods, default="area"),
],
@ -442,6 +484,8 @@ class ResizeImageMaskNode(io.ComfyNode):
return io.NodeOutput(scale_total_pixels(input, resize_type["megapixels"], scale_method))
elif selected_type == ResizeType.MATCH_SIZE:
return io.NodeOutput(scale_match_size(input, resize_type["match"], scale_method, resize_type["crop"]))
elif selected_type == ResizeType.SCALE_TO_MULTIPLE:
return io.NodeOutput(scale_to_multiple_cover(input, resize_type["multiple"], scale_method))
raise ValueError(f"Unsupported resize type: {selected_type}")
def batch_images(images: list[torch.Tensor]) -> torch.Tensor | None:

View File

@ -1,8 +1,10 @@
import os
import importlib.util
from comfy.cli_args import args, PerformanceFeature
from comfy.cli_args import args, PerformanceFeature, enables_dynamic_vram
import subprocess
import comfy_aimdo.control
#Can't use pytorch to get the GPU names because the cuda malloc has to be set before the first import.
def get_gpu_names():
if os.name == 'nt':
@ -85,8 +87,14 @@ if not args.cuda_malloc:
except:
pass
if enables_dynamic_vram() and comfy_aimdo.control.lib is not None:
args.cuda_malloc = False
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = ""
if args.cuda_malloc and not args.disable_cuda_malloc:
if args.disable_cuda_malloc:
args.cuda_malloc = False
if args.cuda_malloc:
env_var = os.environ.get('PYTORCH_CUDA_ALLOC_CONF', None)
if env_var is None:
env_var = "backend:cudaMallocAsync"

View File

@ -1,3 +1,4 @@
import gc
import copy
import heapq
import inspect
@ -9,9 +10,11 @@ import traceback
from enum import Enum
from typing import List, Literal, NamedTuple, Optional, Union
import asyncio
from contextlib import nullcontext
import torch
import comfy.memory_management
import comfy.model_management
from latent_preview import set_preview_method
import nodes
@ -515,7 +518,21 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
def pre_execute_cb(call_index):
# TODO - How to handle this with async functions without contextvars (which requires Python 3.12)?
GraphBuilder.set_default_prefix(unique_id, call_index, 0)
output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data)
#Do comfy_aimdo mempool chunking here on the per-node level. Multi-model workflows
#will cause all sorts of incompatible memory shapes to fragment the pytorch alloc
#that we just want to cull out each model run.
allocator = comfy.memory_management.aimdo_allocator
with nullcontext() if allocator is None else torch.cuda.use_mem_pool(torch.cuda.MemPool(allocator.allocator())):
output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data)
torch.cuda.synchronize()
if allocator is not None:
#FIXME: this is probably a little zealous
# Torch code comments says some stuff about not actually freeing tensors on mempool
#context release. Explicitly garbage collect now.
gc.collect()
torch.cuda.empty_cache()
if has_pending_tasks:
pending_async_nodes[unique_id] = output_data
unblock = execution_list.add_external_block(unique_id)

35
main.py
View File

@ -5,7 +5,7 @@ import os
import importlib.util
import folder_paths
import time
from comfy.cli_args import args
from comfy.cli_args import args, enables_dynamic_vram
from app.logger import setup_logger
from app.assets.scanner import seed_assets
import itertools
@ -173,6 +173,30 @@ import gc
if 'torch' in sys.modules:
logging.warning("WARNING: Potential Error in code: Torch already imported, torch should never be imported before this point.")
has_aimdo = False
import comfy_aimdo.control
if comfy_aimdo.control.lib is not None:
if args.verbose == 'DEBUG':
comfy_aimdo.control.set_log_debug()
elif args.verbose == 'CRITICAL':
comfy_aimdo.control.set_log_critical()
elif args.verbose == 'ERROR':
comfy_aimdo.control.set_log_error()
elif args.verbose == 'WARNING':
comfy_aimdo.control.set_log_warning()
else: #INFO
comfy_aimdo.control.set_log_info()
if enables_dynamic_vram():
logging.info("DynamicVRAM support detected and enabled")
has_aimdo = True
else:
if enables_dynamic_vram():
logging.info("No native comfy-aimdo install detected. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows")
import comfy.utils
import execution
@ -184,6 +208,15 @@ import comfyui_version
import app.logger
import hook_breaker_ac10a0
import comfy.memory_management
import comfy.model_patcher
if has_aimdo:
comfy.model_patcher.CoreModelPatcher = comfy.model_patcher.ModelPatcherDynamic
comfy_aimdo.control.init_vram_guard(comfy.model_management.get_torch_device().index)
else:
comfy.memory_management.aimdo_allocator = None
def cuda_malloc_warning():
device = comfy.model_management.get_torch_device()
device_name = comfy.model_management.get_torch_device_name(device)

View File

@ -22,6 +22,7 @@ alembic
SQLAlchemy
av>=14.2.0
comfy-kitchen>=0.2.6
comfy-aimdo>=0.1.1
#non essential dependencies:
kornia>=0.7.1