mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-28 02:39:26 +08:00
* [Partner Nodes] feat: add Krea 2 Medium Turbo model (#14280)
* [Partner Nodes] feat: add seed input to Flux Erase node (#14283)
Signed-off-by: bigcat88 <bigcat88@icloud.com>
* chore: update workflow templates to v0.9.98 (#14284)
* Bump comfyui-frontend-package to 1.45.15 (#14265)
* Fix ideogram if model dtype gets set to fp8. (#14291)
* Consolidate audio nodes into SaveAudioAdvanced node (CORE-202) (#13871)
* Enable cfg1 optimization for DualModelGuider with CFGGuider (#14290)
* Enable cfg1 optimization for DualModelGuider
* Fix CFG Override tooltip
* Fix interoperation with external source of pinned memory pressure (#14252)
* mm: split off registration helper to doer and headroom calc
* pinned_memory: implement registration comfy side
Move away from Aimdo buffer registrations which seem fraught with
danger and do it comfy side. Just start with the basic move.
* pinned_memory: do registrations as portable memory
* pinned_memory: discard async errors on registration fail
Like the good ol days.
* pinned_memory: implement abs shortfall retry
If pinned registration still fails after the earlier budget ensure,
compute the allocation shortfall, ensure that much again, and retry.
This allows comfy pins to interoperate with other software that may be
pinning substantial amounts of memory.
* aimdo 049 (#14300)
* [Partner Nodes] feat: add new Gemini text node (#14299)
* [Partner Nodes] feat: add temperature and top_p to NanoBanan node (#14305)
* feat: add PreviewGaussianSplat + PreviewPointCloud nodes (#14194)
* Update AMD portable readme. (#14303)
* BE-1172 fix(3d): save Preview3DAdvanced / PreviewGaussianSplat / PreviewPointCloud to temp/, rename viewport input (#14294)
* feat(3d): reorder Preview3DAdvanced / PreviewGaussianSplat / PreviewPointCloud inputs and outputs (#14308)
* Update line endings check to ignore .ci files. (#14319)
* Use windows line endings for windows portable readmes. (#14334)
* Add SeedVR2 support (CORE-6) (#14110)
* chore: update embedded docs to v0.5.3 (#14350)
* Add Color primitive (#14260)
* Improve ResolutionSelector (#14309)
* feat(assets): extract image dimensions at ingest and emit on asset responses (#13991)
* feat(assets): extract image dimensions at ingest and emit on asset responses
Image assets now carry width/height under the existing `metadata` field on
asset responses, shaped as `{"kind": "image", "width": W, "height": H}`.
This lets consumers get original dimensions (e.g. for clients that render
server-side thumbnails and can't recover them from naturalWidth/Height)
without an extra round-trip.
Dimensions are written to AssetReference.system_metadata across three
ingest paths:
- Direct file ingest (upload, in-place registration): Pillow reads the
image header right after hashing, while the file is still in OS page
cache. Non-image MIME types are skipped without touching the file.
- From-hash registration: this path never reads the file bytes, so
dimensions are best-effort copied from any prior sibling reference of
the same asset that already carries kind=image metadata. Missing
siblings, non-image siblings, or absent dimension keys leave the new
reference's metadata unchanged.
- Scanner enrichment: extends the existing system_metadata write in
enrich_asset so scanner-registered images get the same treatment as
uploaded ones.
Existing system_metadata keys (e.g. safetensors fields written by the
enricher, download provenance) are preserved through merge. Existing
assets ingested before this change retain their current metadata — no
automatic backfill in this PR.
Tests cover image emission, non-image no-op, merge preservation, and the
from-hash sibling back-fill (including the no-sibling and non-image-sibling
cases).
* fix(assets): validate sibling dimensions before backfilling
Per CodeRabbit review on #13991: the previous loop accepted any sibling
with `kind == "image"` and copied whichever dimension keys happened to
be present, then returned. A partial sibling (kind set but missing or
invalid width/height) could persist incomplete metadata onto the new
reference even when a later sibling had valid dimensions.
Now we validate that the sibling has both width and height as positive
integers before adopting its dimensions, and continue scanning to the
next sibling otherwise.
* fix(assets): reject booleans in sibling dimension validation (use type-is)
Per CodeRabbit follow-up on #13991: bool is a subclass of int in Python,
so isinstance(True, int) is True. The previous strict-int gate would
have accepted width=True (truthy + > 0) as a valid dimension.
Realistic occurrence is low (extract_image_dimensions returns proper
ints, JSON doesn't serialize bools as numbers), but the validation gate
exists for defense-in-depth so it should be actually strict.
---------
Co-authored-by: guill <jacob.e.segal@gmail.com>
* Revert "Add SeedVR2 support (CORE-6) (#14110)" (#14359)
This reverts commit 7863cf0e53.
* chore(openapi): sync shared API contract from cloud@5273c30 (#14266)
* fix: Add back apply_rotary_emb for Qwen Image (#14364)
* Allow custom templates with Ideogram4 TE (#14374)
* main/server: Add --debug-hang (#14371)
Add an option to debug a hang with ctrl-C, dumping the backtraces to
see where it's stuck or slow.
* Add LoRA key mapping for LTXV/LTXAV models (#14349)
* feat: Add model support for SCAIL-2 (#14373)
* initial SCAIL2 support
* Move bg_removal_model input socket to first position for nicer display (#14353)
* mm: dont reset cast buffers in cleanup_models_gc() (#14372)
cleanup_models_gc can be called once per load_models_gpu via
free_memory, which in turn can de-activate an active model via
this reset_cast_buffers.
cleanup_models_gc() could also come via obscure garbage collector
paths so limit reset_cast_buffers to the post-node callsite instead.
* Ensure conditions are not trainable to avoid bugs (#14368)
* feat: Add Bernini-R model support (Wan video) (CORE-279) (#14216)
* Depth anything 3 (Core-135) (#13853)
Co-authored-by: Alexis Rolland <alexisrolland@hotmail.com>
* Always enable cuda malloc on cu130 and higher. (#14381)
* chore(openapi): sync shared API contract from cloud@ca12913 (#14367)
* [Trainer/bug] Ensure model is not inference mode (CORE-72) (#13400)
* Ensure model is not inference mode
* force clone inside training mode to avoid inference tensor
* Allow force deepcopy for model patcher
* chore(assets): drop vestigial tags.tag_type column (#14248)
tag_type was always "user" in practice — no code path ever set it to anything
else (no system/seeded classification was wired up) and nothing queried it. The
column, its ix_tags_tag_type index, and the TagUsage.type API field were dead
weight, so they're removed. Adds alembic migration 0004 to drop the column and
index.
Verified: asset-seeder tests pass; migration applies cleanly on a fresh SQLite
(tags retains only name; tag_type column + index dropped).
Co-authored-by: guill <jacob.e.segal@gmail.com>
* feat(assets): cursor-based pagination on GET /api/assets (#14014)
* spec(assets): add cursor pagination params to GET /api/assets
Add 'after' query param and 'next_cursor' response field for keyset
pagination. Matches the cloud Go implementation (BE-893) so frontend
sees a unified contract across runtimes. Offset/limit remain as a
deprecated fallback.
* feat(assets): add cursor encode/decode helpers for keyset pagination
Port of cloud common/pagination/cursor.go. Wire format is base64url of
{"s", "v", "id"} JSON; times are Unix microseconds UTC to match
PostgreSQL timestamp precision.
Includes a byte-identity fixture pinned against the cloud Go wire
format so cross-runtime FE pagination can't silently drift.
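As a rough sketch of the wire format described in this commit (base64url over a JSON object with `s`, `v`, and `id`, times as Unix microseconds), the following shows one way to encode and decode such a token. The function names, the compact separators, and the stripped `=` padding are assumptions for illustration; the real helpers may differ in all three.

```python
import base64
import json
from datetime import datetime, timezone

_EPOCH = datetime(1970, 1, 1, tzinfo=timezone.utc)


def to_unix_micros(dt: datetime) -> int:
    # Integer arithmetic avoids the float rounding of dt.timestamp() * 1e6.
    delta = dt.astimezone(timezone.utc) - _EPOCH
    return (delta.days * 86400 + delta.seconds) * 1_000_000 + delta.microseconds


def encode_cursor(sort_field: str, value, row_id: str) -> str:
    payload = json.dumps(
        {"s": sort_field, "v": value, "id": row_id}, separators=(",", ":")
    )
    raw = base64.urlsafe_b64encode(payload.encode("utf-8"))
    return raw.rstrip(b"=").decode("ascii")  # assumption: padding is stripped


def decode_cursor(token: str) -> dict:
    padded = token + "=" * (-len(token) % 4)  # restore stripped padding
    try:
        payload = json.loads(base64.urlsafe_b64decode(padded.encode("ascii")))
    except ValueError as exc:  # covers base64, UTF-8 and JSON decode errors
        raise ValueError("invalid cursor") from exc
    if not isinstance(payload, dict) or not {"s", "v", "id"} <= payload.keys():
        raise ValueError("invalid cursor")
    return payload
```

A caller would treat any `ValueError` from `decode_cursor` as a client error rather than a server fault.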
* feat(assets): thread cursor through schemas, service, and query layer
list_assets_page accepts an opaque 'after' cursor and returns
next_cursor when more pages are available. The query applies a keyset
WHERE clause and a secondary ORDER BY id for deterministic tiebreak.
Cursor sort field is validated against the request sort, and a
last_access_time sort (OSS-only) falls back to offset/limit. Offset is
ignored whenever a cursor is supplied.
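The keyset WHERE clause with a secondary ORDER BY id can be sketched against an in-memory SQLite table. The table name, columns, and `fetch_page` helper are hypothetical; only the shape of the predicate is meant to match the description above.

```python
import sqlite3


def fetch_page(conn, limit, after=None):
    """Return up to `limit` rows ordered by (created_at DESC, id DESC)."""
    sql = "SELECT id, created_at FROM assets"
    params = []
    if after is not None:
        last_created_at, last_id = after
        # Keyset predicate: rows strictly after the boundary row in sort
        # order, with id breaking ties between equal created_at values.
        sql += " WHERE created_at < ? OR (created_at = ? AND id < ?)"
        params += [last_created_at, last_created_at, last_id]
    sql += " ORDER BY created_at DESC, id DESC LIMIT ?"
    params.append(limit)
    return conn.execute(sql, params).fetchall()


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE assets (id TEXT PRIMARY KEY, created_at INTEGER)")
conn.executemany(
    "INSERT INTO assets VALUES (?, ?)",
    [("a", 100), ("b", 100), ("c", 100), ("d", 200), ("e", 50)],
)

# Walk every page; three rows share created_at=100, so ties straddle a boundary.
seen = []
after = None
while True:
    page = fetch_page(conn, 2, after)
    if not page:
        break
    seen.extend(row[0] for row in page)
    after = (page[-1][1], page[-1][0])
```

Because the predicate compares the full (sort value, id) pair, rows that tie on the primary sort value are neither skipped nor repeated across pages.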
* feat(assets): wire cursor pagination through GET /api/assets handler
Adds integration tests for: full cursor walk, invalid-cursor 400,
sort/cursor mismatch 400, cursor-wins-over-offset, absent next_cursor
when no more results, and pagination stability across deletes.
* fix(assets): address cursor-review verified findings
- Mint next_cursor on every cursor-supported sort, not only when 'after'
was supplied. A first request (no 'after') previously returned
next_cursor=None, leaving cursor mode unreachable from a clean start.
- Over-fetch limit+1 so an exactly-full terminal page doesn't mint a
spurious cursor pointing at a phantom next page.
- Map crafted out-of-range microsecond cursors (OverflowError / OSError
in datetime construction) to 400 INVALID_CURSOR instead of leaking 500.
- Bump MAX_CURSOR_VALUE_LENGTH 256 -> 512 to match the AssetReference
name column max; without this, a long-named asset minted a cursor the
same server then refused on the next request. Cross-runtime byte
identity with cloud is unaffected because no cloud cursor ever carries
a value > 256 (cloud schema doesn't permit it).
- Return None from _encode_next_cursor when the boundary row carries a
NULL sort value (e.g. an Asset without size_bytes backfilled), instead
of silently encoding 0 and mis-positioning the keyset.
- Fix schemas_in.py comment so it matches actual handler behavior
(last_access_time + 'after' raises 400, does not fall back).
- Add AssetsApiError schema + 400 response to GET /api/assets in
openapi.yaml so generated clients know the INVALID_CURSOR envelope.
- Extend integration coverage: first-page mint, exact-multiple terminal
page, cursor walks for created_at/updated_at/size sorts, datetime
overflow surfaces as 400 not 500.
- Add unit coverage for datetime overflow and 512-char round-trip.
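The over-fetch idea from the list above (request limit+1 rows, mint a cursor only if the extra row exists) can be sketched as follows. `paginate` and its `fetch` callback are hypothetical names, not the service's actual interface.

```python
from typing import Callable, List, Optional, Sequence, Tuple, TypeVar

Row = TypeVar("Row")


def paginate(
    fetch: Callable[[int], Sequence[Row]], limit: int
) -> Tuple[List[Row], Optional[Row]]:
    """Fetch limit+1 rows; return (page, boundary_row_or_None).

    The boundary row is the last row of the page and is only returned when
    an extra row proved that another page exists.
    """
    rows = list(fetch(limit + 1))
    has_more = len(rows) > limit
    page = rows[:limit]
    boundary = page[-1] if has_more and page else None
    return page, boundary
```

With this shape, an exactly-full terminal page yields no boundary row, so no cursor is minted for a phantom next page.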
* feat(assets): bind cursor to sort order + Go-compat JSON escaping
Address three needs-judgment items from the cursor-review judge synthesis:
1. Cursor wire format now includes an "o" key carrying the sort
direction ("asc" / "desc") it was minted under. A request that
replays the cursor with a flipped `order` parameter is rejected
with 400 INVALID_CURSOR instead of silently walking the wrong
direction. Legacy cursors without "o" still decode (the binding
is best-effort until cloud mirrors the field — follow-up filed
separately).
2. JSON serialization now escapes `<`, `>`, `&`, U+2028, U+2029
to mirror Go's default `json.Marshal` behavior. Without this, an
asset name containing those characters produced different bytes on
Python vs cloud Go. The escaped form is what both runtimes emit.
3. Add direct query-layer tests for the keyset tiebreaker — the secondary
ORDER BY id branch was previously unexercised. Two scenarios: all
rows share a primary sort value, and mixed ties straddle page
boundaries. Both assert no row is dropped or duplicated across the
walk.
Wire-format note: Python cursors now differ from current cloud cursors
by exactly the "o" key. Cloud follow-up will bring the two back into
byte alignment.
* fix(assets): address bot review comments
- Soften offset param prose: it's not deprecated, just not preferred for
sequential walks. Random-access UIs (jump-to-page, item count displays)
legitimately still want offset, so dropping the 'deprecated' framing
rather than promoting it to a machine-readable deprecated:true flag.
- Add explicit HTTP status assertions before every json() / next_cursor
read in test_list_cursor.py so a failing request surfaces as an HTTP
error instead of a confusing KeyError on a 4xx/5xx body.
* feat(assets): require cursor o field, drop legacy permissive path
Cursor pagination hasn't shipped on either runtime yet — this PR is
still draft and cloud's mirror is just behind it — so there are no
legacy no-o cursors in the wild. Make o mandatory from day one
rather than landing permissive and tightening later.
decode_cursor now rejects any payload without o (or with a non-string
o) as malformed. CursorPayload.order becomes a required str. Tests
that constructed CursorPayload directly now pass order="desc";
test_legacy_cursor_without_order_accepted flips to
test_cursor_without_order_rejected.
* chore(assets): drop cross-repo prose from cursor comments
Strip prose references to sibling Go implementations and external
ticket IDs from cursor.py, the cursor tests, the keyset integration
tests, asset_management's sort-field comment, and the legacy
prompt_id alias comment. Pure docstring/comment scrub — no behavior
or wire-format changes. x-runtime: [cloud] field annotations in
openapi.yaml are unchanged; those are the spec's structural
cross-runtime convention, not internal references.
* test(assets): include 'o' in microsecond-boundary cursor payload
The boundary test was building a cursor without the required `o` key, so
decode failed on the missing-order branch before reaching the µs-overflow
path the test is asserting. Both paths return 400 INVALID_CURSOR so the
assertion passed for the wrong reason. Add `o` to the payload and matching
`order=` to the request so the decode reaches the intended branch.
* fix(assets): address ultrareview findings on cursor pagination
Six fact-checked findings from the multi-model review pass:
- Encoder/decoder length asymmetry: encode_cursor now rejects empty id,
oversized id (>128), oversized value (>512), and invalid order tokens
symmetrically with decode_cursor. Prevents the same server from minting
a cursor it then 400s on the next request (e.g. a filesystem-scanned
asset name >512 chars). The bad-order path now raises InvalidCursorError
(still subclasses ValueError) so route-layer handling stays uniform.
- Raw U+2028/U+2029 in cursor.py source: ripgrep treated those lines as
line-terminators, confirming the bytes were the actual separators. Any
editor save / autoformat / git tooling that normalizes invisibles would
silently break the encoder. Replaced with explicit \u2028 / \u2029
Python escape sequences.
- set(seen) == set(names) hid ordering regressions: a cursor walk that
dropped a row at a page boundary or returned duplicates could pass.
Reworked the assertion to (1) reject duplicates, (2) require full
coverage, and (3) assert strict positional order for size sort, the
only field with a clock-independent ordering.
- Flaky time.sleep(0.05) between inserts: Windows CI clock resolution is
~15ms, so back-to-back inserts under load could collide and exercise
the tiebreaker instead of the documented path. Removed the sleep and
let the strengthened assertion above carry coverage / no-duplicates,
with size sort carrying strict order.
- Cursor error envelope diverged from the rest of routes.py: cursor 400s
emitted {error: {code, message}} while every other 400 in the file
emits {error: {code, message, details}} via _build_error_response.
Switched to _build_error_response and added the details field to the
AssetsApiError schema in openapi.yaml.
- "Byte-identity fixtures" only checked substring containment, defeating
the test class's stated purpose of pinning the wire format. Switched
to exact-bytes equality against an inline expected payload string per
fixture, so any whitespace / key-order / escape drift fails loudly.
Also dropped Go / json.Marshal references from docstrings — the byte
format is the contract, not the runtime that mints it.
* fix(assets): cap cursors by encoded wire size, not just char count
Char-count guards on value/id can still let multibyte or escape-heavy
inputs blow past MAX_ENCODED_CURSOR_LENGTH once UTF-8 + escape expansion
+ base64url runs. A 512-character name of 'é' (2 bytes UTF-8) or '<'
(serializes to the 6-byte '\u003c' escape) passes the char check, mints
a ~1500-byte cursor, then 400s when handed back on the next request.
Compute the final encoded form and reject it before returning if it
exceeds the wire cap. Adds regression tests for both inflation paths.
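The inflation described here can be measured with a small sketch. The payload keys mirror the cursor format mentioned earlier in this log, but `encoded_cursor_size` and the `ensure_ascii=False` setting are assumptions; with ASCII-escaped JSON the multibyte case would grow even larger.

```python
import base64
import json


def encoded_cursor_size(value: str) -> int:
    """Length of a base64url-encoded cursor carrying `value` (illustrative)."""
    payload = json.dumps(
        {"s": "name", "v": value, "id": "x"},
        ensure_ascii=False,
        separators=(",", ":"),
    )
    return len(base64.urlsafe_b64encode(payload.encode("utf-8")))


# Both values are 512 characters long, so a character-count guard treats
# them identically, yet their encoded sizes differ substantially.
ascii_size = encoded_cursor_size("a" * 512)
multibyte_size = encoded_cursor_size("é" * 512)
```

Comparing the two sizes shows why a cap on the final encoded form is a different check from a cap on character count.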
* refactor(assets): extract cursor JSON escaping helper; size wire cap above per-field caps
Addresses review feedback on cursor.py:
- Extract the inline escape chain into _apply_wire_compatible_json_escapes()
with a comment pinning it to the wire format's escape set, so the parity
intent is explicit rather than reading as an ad-hoc transform.
- Raise MAX_ENCODED_CURSOR_LENGTH to 8192 (comfortably above the ~5.2KB
worst-case the per-field caps can produce) and drop the mint-time length
guard. Encoder/decoder symmetry now holds by construction: the encoder
can't produce a cursor the decode path rejects, so there is no confusing
user-visible 'cursor too long' failure at mint time.
- Rewrite the two over-wire-cap tests to assert worst-case multibyte and
escape-heavy values mint and round-trip, instead of being rejected.
* refactor(assets): drop cross-runtime cursor escaping; cursors are opaque
The custom JSON escaping of <, >, &, U+2028, and U+2029 existed only to
keep the encoded cursor byte-identical with the Cloud implementation of
the same payload format. Cursors are opaque tokens, so byte-level
compatibility across implementations is not needed — plain json.dumps
output is sufficient. Remove the escaping helper and the byte-identity
test fixtures that pinned the wire format; keep round-trip coverage for
the affected characters.
---------
Co-authored-by: guill <jacob.e.segal@gmail.com>
* fix(assets): remove unused delete_content param from deleteAsset (#14241)
* fix(assets): remove unused delete_content param from deleteAsset
The delete_content query param on DELETE /api/assets/{id} was introduced
in #12125 and had its default flipped to false in #12621. In practice no
client sends it: the frontend issues a bare DELETE /assets/{id}, so every
real caller already gets the default soft-delete (the reference is hidden,
content preserved). The only thing that set delete_content=true was this
repo's own test teardown.
Remove the param from the route and the OpenAPI spec so the contract
matches what clients actually use (and lines up with the cloud surface).
The route now always soft-deletes. The underlying delete_asset_reference
helper keeps its delete_content_if_orphan option, so orphan reclamation
remains available internally for a future GC path — it's just no longer
exposed on the public endpoint. Tests that used delete_content=true for
hard cleanup now soft-delete; test_delete_upon_reference_count asserts
content preservation instead of orphan removal.
* test/docs: address review on deleteAsset delete_content removal
- Rename test_delete_upon_reference_count ->
test_soft_delete_preserves_asset_identity_across_references; the old name
implied last-ref cleanup, but it now verifies the opposite (soft delete
preserves identity across references).
- Strengthen the re-association assertion: also check asset_hash == src_hash
so it proves content reuse rather than relying on the now-tautological
created_new is False.
- Document delete_asset_reference: the orphan-reclamation branch is
intentionally internal-only; the public endpoint always soft-deletes.
- Normalize the soft-delete comment phrasing.
* test(assets): make seed content unique per test for isolation
Removing the delete_content param means delete is always a soft delete, so
content created by one test now survives into the next. The suite had been
relying on hard-delete teardown for isolation, so shared fixed-content
fixtures started colliding: seeded_asset (b"A"*4096) and
make_asset_bytes (deterministic on name) produced the same hash every test,
so the second seed deduped to the surviving asset and returned 200 instead
of 201, cascading into ~14 failures/errors.
Salt both fixtures with a per-test uuid so each test creates fresh content
(created_new True, 201), while keeping content deterministic within a test
(same name/size -> same bytes) and preserving exact byte length so size-based
list/sort assertions are unaffected.
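A minimal sketch of the salting approach, assuming a hypothetical `make_salted_bytes` helper rather than the suite's real fixtures: content stays deterministic for a given salt and name, differs across salts, and keeps the exact requested length.

```python
import uuid


def make_salted_bytes(name: str, size: int, salt: str) -> bytes:
    """Deterministic bytes for (salt, name), truncated to exactly `size`."""
    seed = f"{salt}:{name}:".encode("utf-8")
    return (seed * (size // len(seed) + 1))[:size]


# In a test fixture the salt would be generated once per test.
per_test_salt = uuid.uuid4().hex
content = make_salted_bytes("model.bin", 4096, per_test_salt)
```

Since each test draws its own salt, two tests seeding the same name no longer produce the same hash, while size-based assertions still see the same byte length.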
* main: force cudnn.benchmark to false (#14390)
Some custom nodes try to set this true globally. It messes with dynamic
VRAM with one-off spikes that can OOM but this is also very high risk
for windows where such allocations might get serviced by shared memory
fallback.
Trump it.
* feat(assets): add job_ids filter to GET /api/assets (#13998)
* feat(assets): add job_ids filter to GET /api/assets
Mirrors the existing cloud `job_ids` query param on the local Python server:
clients can pass a comma-separated list (or repeated query params) of UUIDs
to filter assets by their associated job.
The `AssetReference.job_id` column already exists, so no migration is
needed — this just plumbs the filter through schema → service → query.
Marks the parameter as available in both runtimes by dropping the
`[cloud-only]` description prefix and the `x-runtime: [cloud]` tag from
the OpenAPI spec, per the OSS field-drift convention (absent runtime tag
= populated by both local and cloud).
* fix(assets): tighten job_ids — array schema, max_length, narrow except
From cursor-reviews on the parent commit:
- OpenAPI: declare job_ids as `type: array, items: string format: uuid`
with `style: form, explode: true` so it matches the documented
contract (and matches sibling include_tags/exclude_tags shape).
Description now states both accepted shapes explicitly.
- Schema: cap `job_ids` at 500 entries (max_length on the Pydantic
field) so a client can't splice an unbounded list into the IN clauses.
- Schema: drop `AttributeError` from the except — `raw` only contains
`str` items by construction, so `uuid.UUID(<str>)` raises `ValueError`
exclusively; the second clause was dead code.
* fix(assets): tighten job_ids validator + add schema-level tests
Aligns with the parallel hardening from draft PR #13848 (now closed as
a duplicate). The validator now:
- Raises ValueError on non-string list items (was: silently dropped).
- Raises ValueError on non-string / non-list top-level values like dict
or int (was: silently passed through to Pydantic's downstream coercion).
Adds tests-unit/assets_test/queries/test_list_assets_query.py covering
the validator end-to-end: CSV canonicalization, dedup order, default
empty, invalid UUID, non-string list item, non-string non-list value,
and the max_length=500 boundary.
* feat(prompt): enforce canonical UUID prompt_id at job creation
POST /prompt previously accepted any client-supplied prompt_id verbatim,
str()-coercing even non-strings, and minting the literal job id "None"
for an explicit JSON null. The new GET /api/assets job_ids filter matches
stored job ids as canonical UUIDs exactly, so a non-UUID id minted a job
whose assets could never be filtered.
- validate_job_id (comfy_execution/jobs.py): requires a string in the
canonical lowercase hyphenated UUID form; raises ValueError otherwise,
including parseable-but-non-canonical spellings (uppercase, braced, URN,
bare hex), which would otherwise be silently rewritten and then miss
every exact-match lookup downstream (history keys, websocket
correlation, /interrupt, the assets job_ids filter).
- POST /prompt: absent or null prompt_id means the server mints uuid4;
invalid means 400 invalid_prompt_id on the standard error envelope.
- openapi.yaml: document the request-side prompt_id (format uuid,
nullable) on PromptRequest.
- tests: unit matrix for validate_job_id; integration tests against the
booted server covering rejection, acceptance, and null handling.
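The canonical-form requirement can be sketched by round-tripping through `uuid.UUID`, which accepts uppercase, braced, URN, and bare-hex spellings but always stringifies to the lowercase hyphenated form. `validate_canonical_uuid` is a hypothetical stand-in, not the repository's `validate_job_id`.

```python
import uuid


def validate_canonical_uuid(value: object) -> str:
    """Return `value` if it is a canonical lowercase hyphenated UUID string."""
    if not isinstance(value, str):
        raise ValueError("prompt_id must be a string")
    try:
        parsed = uuid.UUID(value)
    except ValueError as exc:
        raise ValueError("prompt_id is not a UUID") from exc
    # str(UUID) is always lowercase and hyphenated, so any other spelling
    # that parsed successfully fails this equality check.
    if str(parsed) != value:
        raise ValueError("prompt_id must be in canonical UUID form")
    return value
```

Rejecting non-canonical spellings, instead of rewriting them, keeps the id the client sent identical to the id used for later exact-match lookups.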
---------
Co-authored-by: guill <jacob.e.segal@gmail.com>
* feat(assets): include asset id in executed WebSocket message (#13862)
* feat(assets): enrich executed WS message with asset metadata
When --enable-assets is set, each file-type output entry in the
`executed` WebSocket message now includes id, name, asset_hash, size,
and mime_type — matching the shape already returned by /upload/image.
The enrichment lives in comfy_execution/asset_enrichment.py (no torch
dependency) and is called from both send sites in execution.py: freshly
executed nodes register the file inline via register_file_in_place;
cached node re-sends look up the existing AssetReference by file path
to avoid re-hashing. Errors are caught per-entry so a failure never
blocks the WS message from sending.
* fix(assets): inject only id in executed WS message per Asset Identity RFC
Per the Asset Identity RFC, the executed WebSocket payload should carry
id alone — hash is already encoded in the filename, and name/preview_url/
size belong behind GET /api/assets/{id} rather than being pushed eagerly.
Simplifies the DB lookup path: we only need ref.id, so the asset.hash
null-check is no longer required as a fallback trigger.
* fix(assets): reject path traversal when resolving output abs_path
Subfolder/filename were joined and absolutized without containment check,
so '..' segments or an absolute filename could escape the type's base
directory and register an unrelated on-disk file as an asset.
Add commonpath-based containment check; skip enrichment (warn, leave
entry unchanged) when the resolved path escapes base. Catches ValueError
from cross-drive paths on Windows.
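A commonpath-based containment check along these lines might look like the sketch below. `resolve_within` is a hypothetical helper; the actual enrichment code may structure the check and its logging differently.

```python
import os
from typing import Optional


def resolve_within(base_dir: str, subfolder: str, filename: str) -> Optional[str]:
    """Return the absolute path if it stays inside base_dir, else None."""
    base = os.path.abspath(base_dir)
    candidate = os.path.abspath(os.path.join(base, subfolder, filename))
    try:
        contained = os.path.commonpath([base, candidate]) == base
    except ValueError:
        # os.path.commonpath raises ValueError when the paths are on
        # different drives on Windows.
        contained = False
    return candidate if contained else None
```

A `None` result corresponds to skipping enrichment for that entry and leaving it unchanged.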
* docs(assets): drop Asset Identity RFC reference from docstring
* docs(assets): trim docstring to what enrichment does, not what it doesn't
* test(assets): use real platform paths so containment check works on Windows
The previous test setup patched os.path.abspath to identity and used a
POSIX-style '/output' base, which collided with Windows path separators
in os.path.commonpath. Drop the abspath/join patches and use a real
tempdir-rooted base so the containment check runs against actual
platform paths.
* refactor(assets): enrich at output-processing time, not in the WS send path
Per review: enrichment lived inside the client_id-guarded send sites, so a
headless run (no websocket client) never registered assets at all, and
ui_outputs/history stored the un-enriched entries.
Now output_ui is enriched once, right after the node produces it and before
it is stored in ui_outputs — so registration happens regardless of connected
clients, and the asset id flows into history and the execution cache for
free. _send_cached_ui re-sends the stored (already-enriched) dict verbatim,
which lets the DB-lookup-by-path fallback be deleted: every enrichment is
now a fresh output, and register_file_in_place re-hashes on upsert so an
overwritten path can never carry a stale id.
* revert(assets): drop job_ids filter from GET /api/assets (#14408)
The job_ids query filter added in #13998 has no live consumer: the
frontend Generated tab kept sourcing from GET /jobs, and the cloud side
removed its equivalent filter from the shared asset spec. Carrying it on
the local server only re-introduces Core<->Cloud drift on the shared
contract, so remove it to match.
Removed: the job_ids field + validator on ListAssetsQuery, the IN(...)
clauses in list_references_page, the service/route passthrough, and the
filter-only tests.
Kept: the canonical-UUID prompt_id enforcement at job creation (also
landed in #13998). It stands on its own -- job ids are matched verbatim
by history keys, websocket correlation, and /interrupt -- and cloud
inherits it by running core for execution, so no divergence is created.
* chore(openapi): sync shared API contract from cloud@e3c52ad (#14406)
* I don't think this actually works anymore. (#14403)
* ops: tolerate already force casted dynamic weight (#14410)
Some custom nodes .to weights completely out of load context, which
can wreak havoc if it's for a model that is not active. Detect this
condition and just let it fall through to the non-dynamic loader
straight up.
---------
Signed-off-by: bigcat88 <bigcat88@icloud.com>
Co-authored-by: Alexander Piskun <13381981+bigcat88@users.noreply.github.com>
Co-authored-by: Daxiong (Lin) <contact@comfyui-wiki.com>
Co-authored-by: Comfy Org PR Bot <snomiao+comfy-pr@gmail.com>
Co-authored-by: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com>
Co-authored-by: Alexis Rolland <alexisrolland@hotmail.com>
Co-authored-by: Jukka Seppänen <40791699+kijai@users.noreply.github.com>
Co-authored-by: rattus <46076784+rattus128@users.noreply.github.com>
Co-authored-by: Terry Jia <terryjia88@gmail.com>
Co-authored-by: John Pollock <pollockjj@users.noreply.github.com>
Co-authored-by: Silver <65376327+silveroxides@users.noreply.github.com>
Co-authored-by: Matt Miller <mattmiller@comfy.org>
Co-authored-by: guill <jacob.e.segal@gmail.com>
Co-authored-by: kelseyee <971704395@qq.com>
Co-authored-by: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com>
Co-authored-by: Talmaj <Talmaj@users.noreply.github.com>
2024 lines
65 KiB
Python
"""
    This file is part of ComfyUI.
    Copyright (C) 2024 Comfy

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program.  If not, see <https://www.gnu.org/licenses/>.
"""
from __future__ import annotations

import psutil
import logging
from enum import Enum
from comfy.cli_args import args, PerformanceFeature
import threading
import torch
import sys
import platform
import weakref
import gc
import os
from contextlib import contextmanager, nullcontext
import comfy.memory_management
import comfy.utils
import comfy.quant_ops
import comfy_aimdo.host_buffer
import comfy_aimdo.vram_buffer

from typing import TYPE_CHECKING
if TYPE_CHECKING:
    from comfy.model_patcher import ModelPatcher


class VRAMState(Enum):
    DISABLED = 0    #No vram present: no need to move models to vram
    NO_VRAM = 1     #Very low vram: enable all the options to save vram
    LOW_VRAM = 2
    NORMAL_VRAM = 3
    HIGH_VRAM = 4
    SHARED = 5      #No dedicated vram: memory shared between CPU and GPU but models still need to be moved between both.

class CPUState(Enum):
    GPU = 0
    CPU = 1
    MPS = 2

# Determine VRAM State
vram_state = VRAMState.NORMAL_VRAM
set_vram_to = VRAMState.NORMAL_VRAM
cpu_state = CPUState.GPU

total_vram = 0


# Training Related State
in_training = False
training_fp8_bwd = False


def get_supported_float8_types():
    float8_types = []
    try:
        float8_types.append(torch.float8_e4m3fn)
    except:
        pass
    try:
        float8_types.append(torch.float8_e4m3fnuz)
    except:
        pass
    try:
        float8_types.append(torch.float8_e5m2)
    except:
        pass
    try:
        float8_types.append(torch.float8_e5m2fnuz)
    except:
        pass
    try:
        float8_types.append(torch.float8_e8m0fnu)
    except:
        pass
    return float8_types

FLOAT8_TYPES = get_supported_float8_types()

xpu_available = False
torch_version = ""
try:
    torch_version = torch.version.__version__
    temp = torch_version.split(".")
    torch_version_numeric = (int(temp[0]), int(temp[1]))
except:
    pass

lowvram_available = True
if args.deterministic:
    logging.info("Using deterministic algorithms for pytorch")
    torch.use_deterministic_algorithms(True, warn_only=True)

directml_enabled = False
if args.directml is not None:
    logging.warning("WARNING: torch-directml barely works, is very slow, has not been updated in over 1 year and might be removed soon, please don't use it, there are better options.")
    import torch_directml
    directml_enabled = True
    device_index = args.directml
    if device_index < 0:
        directml_device = torch_directml.device()
    else:
        directml_device = torch_directml.device(device_index)
    logging.info("Using directml with device: {}".format(torch_directml.device_name(device_index)))
    # torch_directml.disable_tiled_resources(True)
    lowvram_available = False #TODO: need to find a way to get free memory in directml before this can be enabled by default.


try:
|
|
_ = torch.xpu.device_count()
|
|
xpu_available = torch.xpu.is_available()
|
|
except:
|
|
xpu_available = False
|
|
|
|
try:
|
|
if torch.backends.mps.is_available():
|
|
cpu_state = CPUState.MPS
|
|
import torch.mps
|
|
except:
|
|
pass
|
|
|
|
try:
|
|
import torch_npu # noqa: F401
|
|
_ = torch.npu.device_count()
|
|
npu_available = torch.npu.is_available()
|
|
except:
|
|
npu_available = False
|
|
|
|
try:
|
|
import torch_mlu # noqa: F401
|
|
_ = torch.mlu.device_count()
|
|
mlu_available = torch.mlu.is_available()
|
|
except:
|
|
mlu_available = False
|
|
|
|
try:
|
|
ixuca_available = hasattr(torch, "corex")
|
|
except:
|
|
ixuca_available = False
|
|
|
|
if args.cpu:
|
|
cpu_state = CPUState.CPU
|
|
|
|
def is_intel_xpu():
|
|
global cpu_state
|
|
global xpu_available
|
|
if cpu_state == CPUState.GPU:
|
|
if xpu_available:
|
|
return True
|
|
return False
|
|
|
|
def is_ascend_npu():
|
|
global npu_available
|
|
if npu_available:
|
|
return True
|
|
return False
|
|
|
|
def is_mlu():
|
|
global mlu_available
|
|
if mlu_available:
|
|
return True
|
|
return False
|
|
|
|
def is_ixuca():
|
|
global ixuca_available
|
|
if ixuca_available:
|
|
return True
|
|
return False
|
|
|
|
def is_wsl():
|
|
version = platform.uname().release
|
|
if version.endswith("-Microsoft"):
|
|
return True
|
|
elif version.endswith("microsoft-standard-WSL2"):
|
|
return True
|
|
return False
|
|
|
|
def get_torch_device():
|
|
global directml_enabled
|
|
global cpu_state
|
|
if directml_enabled:
|
|
global directml_device
|
|
return directml_device
|
|
if cpu_state == CPUState.MPS:
|
|
return torch.device("mps")
|
|
if cpu_state == CPUState.CPU:
|
|
return torch.device("cpu")
|
|
else:
|
|
if is_intel_xpu():
|
|
return torch.device("xpu", torch.xpu.current_device())
|
|
elif is_ascend_npu():
|
|
return torch.device("npu", torch.npu.current_device())
|
|
elif is_mlu():
|
|
return torch.device("mlu", torch.mlu.current_device())
|
|
else:
|
|
return torch.device(torch.cuda.current_device())
|
|
|
|
def get_all_torch_devices(exclude_current=False):
    global cpu_state
    devices = []
    if cpu_state == CPUState.GPU:
        # NVIDIA + AMD/ROCm both expose their GPUs through torch.cuda.*;
        # without the AMD arm, single-GPU ROCm users get an empty list
        # which silently turns unload_all_models() into a no-op.
        if is_nvidia() or is_amd():
            for i in range(torch.cuda.device_count()):
                devices.append(torch.device("cuda", i))
        elif is_intel_xpu():
            for i in range(torch.xpu.device_count()):
                devices.append(torch.device("xpu", i))
        elif is_ascend_npu():
            for i in range(torch.npu.device_count()):
                devices.append(torch.device("npu", i))
        elif is_mlu():
            for i in range(torch.mlu.device_count()):
                devices.append(torch.device("mlu", i))
        else:
            # Fallback for unhandled GPU backends (e.g. DirectML): at least
            # report the current device so callers like unload_all_models()
            # do not silently no-op.
            devices.append(get_torch_device())
    else:
        devices.append(get_torch_device())
    if exclude_current:
        current = get_torch_device()
        if current in devices:
            devices.remove(current)
    return devices
|
|
|
|
def get_gpu_device_options():
    """Return list of device option strings for node widgets.

    Always includes "default" and "cpu". When multiple GPUs are present,
    adds "gpu:0", "gpu:1", etc. (vendor-agnostic labels).
    """
    options = ["default", "cpu"]
    devices = get_all_torch_devices()
    if len(devices) > 1:
        for i in range(len(devices)):
            options.append(f"gpu:{i}")
    return options
|
|
|
|
def get_gpu_device_options_no_cpu():
    """Variant of get_gpu_device_options that omits "cpu".

    Intended for components like the VAE selector where running on CPU
    is impractical and should not be offered as a choice.
    """
    return [o for o in get_gpu_device_options() if o != "cpu"]
|
|
|
|
def resolve_gpu_device_option(option: str):
    """Resolve a device option string to a torch.device.

    Returns None for "default" (let the caller use its normal default).
    Returns torch.device("cpu") for "cpu".
    For "gpu:N", returns the Nth torch device. Returns None if the
    index is out of range, the option string is malformed, or
    unrecognized (callers are expected to log their own context-rich
    message before falling back to the default device).
    """
    if option is None or option == "default":
        return None
    if option == "cpu":
        return torch.device("cpu")
    if option.startswith("gpu:"):
        try:
            idx = int(option[4:])
        except ValueError:
            return None
        devices = get_all_torch_devices()
        if 0 <= idx < len(devices):
            return devices[idx]
    return None
|
|
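# Illustration only, not part of the upstream module: a minimal sketch of how the
# "default" / "cpu" / "gpu:N" option strings round-trip. It assumes plain strings
# stand in for torch.device objects so it runs without PyTorch installed;
# sketch_device_options and sketch_resolve_option are hypothetical names.

```python
def sketch_device_options(devices):
    # "default" and "cpu" are always offered; "gpu:N" only with 2+ devices.
    options = ["default", "cpu"]
    if len(devices) > 1:
        options.extend(f"gpu:{i}" for i in range(len(devices)))
    return options


def sketch_resolve_option(option, devices):
    # None means "let the caller fall back to its normal default device".
    if option is None or option == "default":
        return None
    if option == "cpu":
        return "cpu"
    if option.startswith("gpu:"):
        try:
            idx = int(option[4:])
        except ValueError:
            return None  # malformed index such as "gpu:x"
        if 0 <= idx < len(devices):
            return devices[idx]
    return None  # out of range or unrecognized


fake_devices = ["cuda:0", "cuda:1"]  # stand-ins for torch.device objects
options = sketch_device_options(fake_devices)
resolved = {o: sketch_resolve_option(o, fake_devices) for o in options}
```

# With a single device the "gpu:N" entries are not offered at all, so a stale
# "gpu:1" option saved on another machine resolves to None and the caller falls back.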
|
|
@contextmanager
def cuda_device_context(device):
    """Context manager that sets torch.cuda.current_device to match *device*.

    Used when running operations on a non-default CUDA device so that custom
    CUDA kernels (e.g. comfy_kitchen fp8 quantization) pick up the correct
    device index. The previous device is restored on exit.

    No-op when *device* is not CUDA, has no explicit index, or already matches
    the current device.
    """
    prev = None
    if device.type == "cuda" and device.index is not None:
        prev = torch.cuda.current_device()
        if prev != device.index:
            torch.cuda.set_device(device)
        else:
            prev = None
    try:
        yield
    finally:
        if prev is not None:
            torch.cuda.set_device(prev)
|
|
|
|
def get_total_memory(dev=None, torch_total_too=False):
|
|
global directml_enabled
|
|
if dev is None:
|
|
dev = get_torch_device()
|
|
|
|
if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
|
|
mem_total = psutil.virtual_memory().total
|
|
mem_total_torch = mem_total
|
|
else:
|
|
if directml_enabled:
|
|
mem_total = 1024 * 1024 * 1024 #TODO
|
|
mem_total_torch = mem_total
|
|
elif is_intel_xpu():
|
|
stats = torch.xpu.memory_stats(dev)
|
|
mem_reserved = stats['reserved_bytes.all.current']
|
|
mem_total_xpu = torch.xpu.get_device_properties(dev).total_memory
|
|
mem_total_torch = mem_reserved
|
|
mem_total = mem_total_xpu
|
|
elif is_ascend_npu():
|
|
stats = torch.npu.memory_stats(dev)
|
|
mem_reserved = stats['reserved_bytes.all.current']
|
|
_, mem_total_npu = torch.npu.mem_get_info(dev)
|
|
mem_total_torch = mem_reserved
|
|
mem_total = mem_total_npu
|
|
elif is_mlu():
|
|
stats = torch.mlu.memory_stats(dev)
|
|
mem_reserved = stats['reserved_bytes.all.current']
|
|
_, mem_total_mlu = torch.mlu.mem_get_info(dev)
|
|
mem_total_torch = mem_reserved
|
|
mem_total = mem_total_mlu
|
|
else:
|
|
stats = torch.cuda.memory_stats(dev)
|
|
mem_reserved = stats['reserved_bytes.all.current']
|
|
_, mem_total_cuda = torch.cuda.mem_get_info(dev)
|
|
mem_total_torch = mem_reserved
|
|
mem_total = mem_total_cuda
|
|
|
|
if torch_total_too:
|
|
return (mem_total, mem_total_torch)
|
|
else:
|
|
return mem_total
|
|
|
|
def mac_version():
|
|
try:
|
|
return tuple(int(n) for n in platform.mac_ver()[0].split("."))
|
|
except:
|
|
return None
|
|
|
|
total_vram = get_total_memory(get_torch_device()) / (1024 * 1024)
|
|
total_ram = psutil.virtual_memory().total / (1024 * 1024)
|
|
logging.info("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram))
|
|
|
|
try:
|
|
logging.info("pytorch version: {}".format(torch_version))
|
|
mac_ver = mac_version()
|
|
if mac_ver is not None:
|
|
logging.info("Mac Version {}".format(mac_ver))
|
|
except:
|
|
pass
|
|
|
|
try:
|
|
OOM_EXCEPTION = torch.cuda.OutOfMemoryError
|
|
except:
|
|
OOM_EXCEPTION = Exception
|
|
|
|
try:
|
|
ACCELERATOR_ERROR = torch.AcceleratorError
|
|
except AttributeError:
|
|
ACCELERATOR_ERROR = RuntimeError
|
|
|
|
def is_oom(e):
|
|
if isinstance(e, OOM_EXCEPTION):
|
|
return True
|
|
if isinstance(e, ACCELERATOR_ERROR) and (getattr(e, 'error_code', None) == 2 or "out of memory" in str(e).lower()):
|
|
discard_cuda_async_error()
|
|
return True
|
|
return False
|
|
|
|
def raise_non_oom(e):
|
|
if not is_oom(e):
|
|
raise e
|
|
|
|
XFORMERS_VERSION = ""
|
|
XFORMERS_ENABLED_VAE = True
|
|
if args.disable_xformers:
|
|
XFORMERS_IS_AVAILABLE = False
|
|
else:
|
|
try:
|
|
import xformers
|
|
import xformers.ops
|
|
XFORMERS_IS_AVAILABLE = True
|
|
try:
|
|
XFORMERS_IS_AVAILABLE = xformers._has_cpp_library
|
|
except:
|
|
pass
|
|
try:
|
|
XFORMERS_VERSION = xformers.version.__version__
|
|
logging.info("xformers version: {}".format(XFORMERS_VERSION))
|
|
if XFORMERS_VERSION.startswith("0.0.18"):
|
|
logging.warning("\nWARNING: This version of xformers has a major bug where you will get black images when generating high resolution images.")
|
|
logging.warning("Please downgrade or upgrade xformers to a different version.\n")
|
|
XFORMERS_ENABLED_VAE = False
|
|
except:
|
|
pass
|
|
except:
|
|
XFORMERS_IS_AVAILABLE = False
|
|
|
|
def is_nvidia():
|
|
global cpu_state
|
|
if cpu_state == CPUState.GPU:
|
|
if torch.version.cuda:
|
|
return True
|
|
return False
|
|
|
|
def is_amd():
|
|
global cpu_state
|
|
if cpu_state == CPUState.GPU:
|
|
if torch.version.hip:
|
|
return True
|
|
return False
|
|
|
|
def amd_min_version(device=None, min_rdna_version=0):
    if not is_amd():
        return False

    if is_device_cpu(device):
        return False

    arch = torch.cuda.get_device_properties(device).gcnArchName
    if arch.startswith('gfx') and len(arch) == 7:
        # Derive an approximate RDNA generation from the fifth character of
        # the arch name, e.g. "gfx1100" -> 1 + 2 = 3.
        try:
            cmp_rdna_version = int(arch[4]) + 2
        except:
            cmp_rdna_version = 0
        if cmp_rdna_version >= min_rdna_version:
            return True

    return False
|
|
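# Illustration only, not part of the upstream module: a minimal sketch of the
# arch-name arithmetic used by amd_min_version, run on literal strings instead of
# torch.cuda.get_device_properties(). sketch_rdna_generation is a hypothetical
# name, and the generation it returns is only the approximation the check above uses.

```python
def sketch_rdna_generation(arch):
    """Return the generation implied by a 7-character gfx name, else None."""
    if arch.startswith("gfx") and len(arch) == 7:
        try:
            # Fifth character plus two: "gfx1100" -> int("1") + 2 = 3
            return int(arch[4]) + 2
        except ValueError:
            return 0
    return None  # names of other lengths (e.g. "gfx942") are not classified


generations = {
    name: sketch_rdna_generation(name)
    for name in ("gfx1030", "gfx1100", "gfx1201", "gfx942")
}
```

# Names that are not exactly seven characters long never pass the check, which is
# why amd_min_version returns False for them regardless of min_rdna_version.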
|
|
MIN_WEIGHT_MEMORY_RATIO = 0.4
|
|
if is_nvidia():
|
|
MIN_WEIGHT_MEMORY_RATIO = 0.0
|
|
|
|
ENABLE_PYTORCH_ATTENTION = False
|
|
if args.use_pytorch_cross_attention:
|
|
ENABLE_PYTORCH_ATTENTION = True
|
|
XFORMERS_IS_AVAILABLE = False
|
|
|
|
try:
|
|
if is_nvidia():
|
|
if torch_version_numeric[0] >= 2:
|
|
if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
|
|
ENABLE_PYTORCH_ATTENTION = True
|
|
if is_intel_xpu() or is_ascend_npu() or is_mlu() or is_ixuca():
|
|
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
|
|
ENABLE_PYTORCH_ATTENTION = True
|
|
except:
|
|
pass
|
|
|
|
|
|
SUPPORT_FP8_OPS = args.supports_fp8_compute
|
|
|
|
AMD_RDNA2_AND_OLDER_ARCH = ["gfx1030", "gfx1031", "gfx1010", "gfx1011", "gfx1012", "gfx906", "gfx900", "gfx803"]
|
|
AMD_ENABLE_MIOPEN_ENV = 'COMFYUI_ENABLE_MIOPEN'
|
|
|
|
try:
|
|
if is_amd():
|
|
arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName.split(':')[0]
|
|
if not (any((a in arch) for a in AMD_RDNA2_AND_OLDER_ARCH)):
|
|
if os.getenv(AMD_ENABLE_MIOPEN_ENV) != '1':
|
|
torch.backends.cudnn.enabled = False # Seems to improve things a lot on AMD
|
|
logging.info("Set: torch.backends.cudnn.enabled = False for better AMD performance.")
|
|
|
|
try:
|
|
rocm_version = tuple(map(int, str(torch.version.hip).split(".")[:2]))
|
|
except:
|
|
rocm_version = (6, -1)
|
|
|
|
def aotriton_supported(gpu_arch):
|
|
path = torch.__path__[0]
|
|
path = os.path.join(os.path.join(path, "lib"), "aotriton.images")
|
|
gfx = set(map(lambda a: a[4:], filter(lambda a: a.startswith("amd-gfx"), os.listdir(path))))
|
|
if gpu_arch in gfx:
|
|
return True
|
|
if "{}x".format(gpu_arch[:-1]) in gfx:
|
|
return True
|
|
if "{}xx".format(gpu_arch[:-2]) in gfx:
|
|
return True
|
|
return False
|
|
|
|
logging.info("AMD arch: {}".format(arch))
|
|
logging.info("ROCm version: {}".format(rocm_version))
|
|
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
|
|
if aotriton_supported(arch): # AMD efficient attention implementation depends on aotriton.
|
|
if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much
|
|
if any((a in arch) for a in ["gfx90a", "gfx942", "gfx950", "gfx1100", "gfx1101", "gfx1150", "gfx1151"]): # TODO: more arches, TODO: gfx950
|
|
ENABLE_PYTORCH_ATTENTION = True
|
|
if rocm_version >= (7, 0):
|
|
if any((a in arch) for a in ["gfx1200", "gfx1201"]):
|
|
ENABLE_PYTORCH_ATTENTION = True
|
|
if torch_version_numeric >= (2, 7) and rocm_version >= (6, 4):
|
|
if any((a in arch) for a in ["gfx1200", "gfx1201", "gfx950"]): # TODO: more arches, "gfx942" gives error on pytorch nightly 2.10 1013 rocm7.0
|
|
SUPPORT_FP8_OPS = True
|
|
|
|
except:
|
|
pass
|
|
|
|
|
|
if ENABLE_PYTORCH_ATTENTION:
|
|
torch.backends.cuda.enable_math_sdp(True)
|
|
torch.backends.cuda.enable_flash_sdp(True)
|
|
torch.backends.cuda.enable_mem_efficient_sdp(True)
|
|
|
|
|
|
PRIORITIZE_FP16 = False # TODO: remove and replace with something that shows exactly which dtype is faster than the other
|
|
try:
|
|
if (is_nvidia() or is_amd()) and PerformanceFeature.Fp16Accumulation in args.fast:
|
|
torch.backends.cuda.matmul.allow_fp16_accumulation = True
|
|
PRIORITIZE_FP16 = True # TODO: limit to cards where it actually boosts performance
|
|
logging.info("Enabled fp16 accumulation.")
|
|
except:
|
|
pass
|
|
|
|
|
|
def set_cudnn_benchmark():
|
|
if torch.cuda.is_available() and torch.backends.cudnn.is_available():
|
|
torch.backends.cudnn.benchmark = PerformanceFeature.AutoTune in args.fast
|
|
|
|
try:
|
|
if torch_version_numeric >= (2, 5):
|
|
torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp(True)
|
|
except:
|
|
logging.warning("Warning, could not set allow_fp16_bf16_reduction_math_sdp")
|
|
|
|
if args.lowvram:
|
|
set_vram_to = VRAMState.LOW_VRAM
|
|
lowvram_available = True
|
|
elif args.novram:
|
|
set_vram_to = VRAMState.NO_VRAM
|
|
elif args.highvram or args.gpu_only:
|
|
vram_state = VRAMState.HIGH_VRAM
|
|
|
|
FORCE_FP32 = False
|
|
if args.force_fp32:
|
|
logging.info("Forcing FP32, if this improves things please report it.")
|
|
FORCE_FP32 = True
|
|
|
|
if lowvram_available:
|
|
if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM):
|
|
vram_state = set_vram_to
|
|
|
|
|
|
if cpu_state != CPUState.GPU:
|
|
vram_state = VRAMState.DISABLED
|
|
|
|
if cpu_state == CPUState.MPS:
|
|
vram_state = VRAMState.SHARED
|
|
|
|
logging.info(f"Set vram state to: {vram_state.name}")
|
|
|
|
DISABLE_SMART_MEMORY = args.disable_smart_memory
|
|
|
|
if DISABLE_SMART_MEMORY:
|
|
logging.info("Disabling smart memory management")
|
|
|
|
def get_torch_device_name(device):
|
|
if hasattr(device, 'type'):
|
|
if device.type == "cuda":
|
|
try:
|
|
allocator_backend = torch.cuda.get_allocator_backend()
|
|
except:
|
|
allocator_backend = ""
|
|
return "{} {} : {}".format(device, torch.cuda.get_device_name(device), allocator_backend)
|
|
elif device.type == "xpu":
|
|
return "{} {}".format(device, torch.xpu.get_device_name(device))
|
|
else:
|
|
return "{}".format(device.type)
|
|
elif is_intel_xpu():
|
|
return "{} {}".format(device, torch.xpu.get_device_name(device))
|
|
elif is_ascend_npu():
|
|
return "{} {}".format(device, torch.npu.get_device_name(device))
|
|
elif is_mlu():
|
|
return "{} {}".format(device, torch.mlu.get_device_name(device))
|
|
else:
|
|
return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device))
|
|
|
|
try:
|
|
logging.info("Device: {}".format(get_torch_device_name(get_torch_device())))
|
|
except:
|
|
logging.warning("Could not pick default device.")
|
|
try:
|
|
for device in get_all_torch_devices(exclude_current=True):
|
|
logging.info("Device: {}".format(get_torch_device_name(device)))
|
|
except:
|
|
pass
|
|
|
|
current_loaded_models: list[LoadedModel] = []
|
|
|
|
DIRTY_MMAPS = set()
|
|
|
|
PIN_PRESSURE_HYSTERESIS = 256 * 1024 * 1024
|
|
|
|
# Freeing registerables on pressure does imply a GPU sync, so go big on
# the hysteresis so each expensive sync gives us back a good chunk.
REGISTERABLE_PIN_HYSTERESIS = 2048 * 1024 * 1024
|
|
|
|
def module_size(module):
|
|
module_mem = 0
|
|
sd = module.state_dict()
|
|
for k in sd:
|
|
t = sd[k]
|
|
module_mem += t.nbytes
|
|
return module_mem
|
|
|
|
def mark_mmap_dirty(storage):
|
|
mmap_refs = getattr(storage, "_comfy_tensor_mmap_refs", None)
|
|
if mmap_refs is not None:
|
|
DIRTY_MMAPS.add(mmap_refs[0])
|
|
|
|
def free_pins(size, evict_active=False):
    freed_total = 0
    for loaded_model in reversed(current_loaded_models):
        if size <= 0:
            return freed_total
        model = loaded_model.model
        if model is not None and model.is_dynamic() and (evict_active or not model.model.dynamic_pins[model.load_device]["active"]):
            freed = model.partially_unload_ram(size)
            freed_total += freed
            size -= freed
    return freed_total
|
|
|
|
def ensure_pin_budget(size, evict_active=False):
    if args.fast_disk:
        shortfall = TOTAL_PINNED_MEMORY + size - MAX_PINNED_MEMORY
    else:
        shortfall = size + max(comfy.memory_management.RAM_CACHE_HEADROOM / 2, 2048 * 1024 ** 2) - psutil.virtual_memory().available
    if shortfall <= 0:
        return True

    to_free = shortfall + PIN_PRESSURE_HYSTERESIS
    return free_pins(to_free, evict_active=evict_active) >= shortfall
|
|
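# Illustration only, not part of the upstream module: a minimal sketch of the
# shortfall-plus-hysteresis arithmetic in ensure_pin_budget (the branch that looks
# at available RAM). It assumes the headroom and available-RAM figures are passed
# in as plain integers instead of being read from psutil; sketch_pin_budget is a
# hypothetical name.

```python
MiB = 1024 * 1024
GiB = 1024 * MiB
PIN_PRESSURE_HYSTERESIS = 256 * MiB  # same value as the module constant


def sketch_pin_budget(requested, available, headroom):
    """Return (shortfall, amount_to_try_freeing); both 0 when the budget fits."""
    shortfall = requested + headroom - available
    if shortfall <= 0:
        return 0, 0
    # Ask for a little more than strictly needed so the next request of
    # similar size does not immediately trigger another eviction pass.
    return shortfall, shortfall + PIN_PRESSURE_HYSTERESIS


tight = sketch_pin_budget(requested=3 * GiB, available=4 * GiB, headroom=2 * GiB)
roomy = sketch_pin_budget(requested=1 * GiB, available=8 * GiB, headroom=2 * GiB)
```

# In the tight case 1 GiB is missing, so 1 GiB + 256 MiB is requested from
# free_pins, but success is still judged against the 1 GiB shortfall alone.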
|
|
def free_registrations(shortfall, evict_active=True):
|
|
if MAX_PINNED_MEMORY <= 0:
|
|
return False
|
|
if shortfall <= 0:
|
|
return True
|
|
|
|
shortfall += REGISTERABLE_PIN_HYSTERESIS
|
|
for loaded_model in reversed(current_loaded_models):
|
|
model = loaded_model.model
|
|
if model is not None and model.is_dynamic() and not model.model.dynamic_pins[model.load_device]["active"]:
|
|
shortfall -= model.unregister_inactive_pins(shortfall)
|
|
if shortfall <= 0:
|
|
return True
|
|
if evict_active:
|
|
for loaded_model in current_loaded_models:
|
|
model = loaded_model.model
|
|
if model is not None and model.is_dynamic() and model.model.dynamic_pins[model.load_device]["active"]:
|
|
shortfall -= model.unregister_inactive_pins(shortfall)
|
|
if shortfall <= 0:
|
|
return True
|
|
return shortfall <= REGISTERABLE_PIN_HYSTERESIS
|
|
|
|
def ensure_pin_registerable(size, evict_active=True):
|
|
return free_registrations(TOTAL_PINNED_MEMORY + size - MAX_PINNED_MEMORY, evict_active=evict_active)
|
|
|
|
class LoadedModel:
|
|
def __init__(self, model: ModelPatcher):
|
|
self._set_model(model)
|
|
self.device = model.load_device
|
|
self.real_model = None
|
|
self.currently_used = True
|
|
self.model_finalizer = None
|
|
self._patcher_finalizer = None
|
|
|
|
def _set_model(self, model: ModelPatcher):
|
|
self._model = weakref.ref(model)
|
|
if model.parent is not None:
|
|
self._parent_model = weakref.ref(model.parent)
|
|
self._patcher_finalizer = weakref.finalize(model, self._switch_parent)
|
|
self._patcher_finalizer.atexit = False
|
|
|
|
def _switch_parent(self):
|
|
model = self._parent_model()
|
|
if model is not None:
|
|
self._set_model(model)
|
|
self.device = model.load_device
|
|
|
|
@property
|
|
def model(self):
|
|
return self._model()
|
|
|
|
def model_memory(self):
|
|
return self.model.model_size()
|
|
|
|
def model_loaded_memory(self):
|
|
return self.model.loaded_size()
|
|
|
|
def model_offloaded_memory(self):
|
|
return self.model.model_size() - self.model.loaded_size()
|
|
|
|
def model_memory_required(self, device):
|
|
if device == self.model.current_loaded_device():
|
|
return self.model_offloaded_memory()
|
|
else:
|
|
return self.model_memory()
|
|
|
|
def model_load(self, lowvram_model_memory=0, force_patch_weights=False):
|
|
self.model.model_patches_to(self.device)
|
|
self.model.model_patches_to(self.model.model_dtype())
|
|
|
|
# if self.model.loaded_size() > 0:
|
|
use_more_vram = lowvram_model_memory
|
|
if use_more_vram == 0:
|
|
use_more_vram = 1e32
|
|
self.model_use_more_vram(use_more_vram, force_patch_weights=force_patch_weights)
|
|
|
|
real_model = self.model.model
|
|
|
|
|
|
self.real_model = weakref.ref(real_model)
|
|
self.model_finalizer = weakref.finalize(real_model, cleanup_models)
|
|
self.model_finalizer.atexit = False
|
|
return real_model
|
|
|
|
def should_reload_model(self, force_patch_weights=False):
|
|
if force_patch_weights and self.model.lowvram_patch_counter() > 0:
|
|
return True
|
|
return False
|
|
|
|
def model_unload(self, memory_to_free=None, unpatch_weights=True):
|
|
if memory_to_free is not None:
|
|
if memory_to_free < self.model.loaded_size():
|
|
freed = self.model.partially_unload(self.model.offload_device, memory_to_free)
|
|
if freed >= memory_to_free:
|
|
return False
|
|
self.model.detach(unpatch_weights)
|
|
self.model_finalizer.detach()
|
|
self.model_finalizer = None
|
|
self.real_model = None
|
|
return True
|
|
|
|
def model_use_more_vram(self, extra_memory, force_patch_weights=False):
|
|
return self.model.partially_load(self.device, extra_memory, force_patch_weights=force_patch_weights)
|
|
|
|
def __eq__(self, other):
|
|
return self.model is other.model
|
|
|
|
def __del__(self):
|
|
if self._patcher_finalizer is not None:
|
|
self._patcher_finalizer.detach()
|
|
|
|
def is_dead(self):
|
|
return self.real_model() is not None and self.model is None
|
|
|
|
|
|
def use_more_memory(extra_memory, loaded_models, device):
|
|
for m in loaded_models:
|
|
if m.device == device:
|
|
extra_memory -= m.model_use_more_vram(extra_memory)
|
|
if extra_memory <= 0:
|
|
break
|
|
|
|
def offloaded_memory(loaded_models, device):
|
|
offloaded_mem = 0
|
|
for m in loaded_models:
|
|
if m.device == device:
|
|
offloaded_mem += m.model_offloaded_memory()
|
|
return offloaded_mem
|
|
|
|
WINDOWS = any(platform.win32_ver())
|
|
|
|
EXTRA_RESERVED_VRAM = 400 * 1024 * 1024
|
|
if WINDOWS:
|
|
EXTRA_RESERVED_VRAM = 600 * 1024 * 1024 #Windows is higher because of the shared vram issue
|
|
if total_vram > (15 * 1024): # more extra reserved vram on 16GB+ cards
|
|
EXTRA_RESERVED_VRAM += 100 * 1024 * 1024
|
|
|
|
if args.reserve_vram is not None:
|
|
EXTRA_RESERVED_VRAM = args.reserve_vram * 1024 * 1024 * 1024
|
|
logging.debug("Reserving {}MB vram for other applications.".format(EXTRA_RESERVED_VRAM / (1024 * 1024)))
|
|
|
|
def extra_reserved_memory():
|
|
return EXTRA_RESERVED_VRAM
|
|
|
|
def minimum_inference_memory():
|
|
return (1024 * 1024 * 1024) * 0.8 + extra_reserved_memory()
|
|
|
|
def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins_required=0, ram_required=0):
|
|
cleanup_models_gc()
|
|
unloaded_model = []
|
|
can_unload = []
|
|
unloaded_models = []
|
|
|
|
for i in range(len(current_loaded_models) -1, -1, -1):
|
|
shift_model = current_loaded_models[i]
|
|
if device is None or shift_model.device == device:
|
|
if shift_model not in keep_loaded and not shift_model.is_dead():
|
|
can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
|
|
shift_model.currently_used = False
|
|
|
|
can_unload_sorted = sorted(can_unload)
|
|
for x in can_unload_sorted:
|
|
i = x[-1]
|
|
memory_to_free = 1e32
|
|
if not DISABLE_SMART_MEMORY or device is None:
|
|
memory_to_free = 0 if device is None else memory_required - get_free_memory(device)
|
|
if current_loaded_models[i].model.is_dynamic() and for_dynamic:
|
|
#don't actually unload dynamic models for the sake of other dynamic models
|
|
#as that works on-demand.
|
|
memory_required -= current_loaded_models[i].model.loaded_size()
|
|
memory_to_free = 0
|
|
if memory_to_free > 0 and current_loaded_models[i].model_unload(memory_to_free):
|
|
logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}")
|
|
unloaded_model.append(i)
|
|
|
|
for i in sorted(unloaded_model, reverse=True):
|
|
unloaded_models.append(current_loaded_models.pop(i))
|
|
|
|
if not for_dynamic and pins_required > 0:
|
|
ensure_pin_budget(pins_required)
|
|
ensure_pin_registerable(pins_required)
|
|
|
|
if len(unloaded_model) > 0:
|
|
soft_empty_cache()
|
|
elif device is not None:
|
|
if vram_state != VRAMState.HIGH_VRAM:
|
|
mem_free_total, mem_free_torch = get_free_memory(device, torch_free_too=True)
|
|
if mem_free_torch > mem_free_total * 0.25:
|
|
soft_empty_cache()
|
|
return unloaded_models
|
|
|
|
def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
|
|
cleanup_models_gc()
|
|
global vram_state
|
|
|
|
inference_memory = minimum_inference_memory()
|
|
extra_mem = max(inference_memory, memory_required + extra_reserved_memory())
|
|
if minimum_memory_required is None:
|
|
minimum_memory_required = extra_mem
|
|
else:
|
|
minimum_memory_required = max(inference_memory, minimum_memory_required + extra_reserved_memory())
|
|
|
|
# Order-preserving dedup. A plain set() would randomize iteration order across runs
|
|
models_temp = {}
|
|
for m in models:
|
|
models_temp[m] = None
|
|
for mm in m.model_patches_models():
|
|
models_temp[mm] = None
|
|
|
|
models = list(models_temp)
|
|
models.reverse()
|
|
|
|
models_to_load = []
|
|
|
|
free_for_dynamic=True
|
|
for x in models:
|
|
if not x.is_dynamic():
|
|
free_for_dynamic = False
|
|
loaded_model = LoadedModel(x)
|
|
try:
|
|
loaded_model_index = current_loaded_models.index(loaded_model)
|
|
except:
|
|
loaded_model_index = None
|
|
|
|
if loaded_model_index is not None:
|
|
loaded = current_loaded_models[loaded_model_index]
|
|
loaded.currently_used = True
|
|
models_to_load.append(loaded)
|
|
else:
|
|
if hasattr(x, "model"):
|
|
logging.info(f"Requested to load {x.model.__class__.__name__}")
|
|
models_to_load.append(loaded_model)
|
|
|
|
for loaded_model in models_to_load:
|
|
to_unload = []
|
|
for i in range(len(current_loaded_models)):
|
|
if loaded_model.model.is_clone(current_loaded_models[i].model):
|
|
to_unload = [i] + to_unload
|
|
for i in to_unload:
|
|
model_to_unload = current_loaded_models.pop(i)
|
|
model_to_unload.model.detach(unpatch_all=False)
|
|
model_to_unload.model_finalizer.detach()
|
|
|
|
total_memory_required = {}
|
|
total_pins_required = {}
|
|
for loaded_model in models_to_load:
|
|
device = loaded_model.device
|
|
total_memory_required[device] = total_memory_required.get(device, 0) + loaded_model.model_memory_required(device)
|
|
if not loaded_model.model.is_dynamic():
|
|
total_pins_required[device] = total_pins_required.get(device, 0) + loaded_model.model_memory()
|
|
|
|
for device in total_memory_required:
|
|
if device != torch.device("cpu"):
|
|
free_memory(total_memory_required[device] * 1.1 + extra_mem,
|
|
device,
|
|
for_dynamic=free_for_dynamic,
|
|
pins_required=total_pins_required.get(device, 0))
|
|
|
|
for device in total_memory_required:
|
|
if device != torch.device("cpu"):
|
|
free_mem = get_free_memory(device)
|
|
if free_mem < minimum_memory_required:
|
|
models_l = free_memory(minimum_memory_required, device, for_dynamic=free_for_dynamic)
|
|
logging.info("{} models unloaded.".format(len(models_l)))
|
|
|
|
for loaded_model in models_to_load:
|
|
model = loaded_model.model
|
|
torch_dev = model.load_device
|
|
if is_device_cpu(torch_dev):
|
|
vram_set_state = VRAMState.DISABLED
|
|
else:
|
|
vram_set_state = vram_state
|
|
lowvram_model_memory = 0
|
|
if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM) and not force_full_load:
|
|
loaded_memory = loaded_model.model_loaded_memory()
|
|
current_free_mem = get_free_memory(torch_dev) + loaded_memory
|
|
|
|
lowvram_model_memory = max(0, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory()))
|
|
lowvram_model_memory = lowvram_model_memory - loaded_memory
|
|
|
|
if lowvram_model_memory == 0:
|
|
lowvram_model_memory = 0.1
|
|
|
|
if vram_set_state == VRAMState.NO_VRAM:
|
|
lowvram_model_memory = 0.1
|
|
|
|
loaded_model.model_load(lowvram_model_memory, force_patch_weights=force_patch_weights)
|
|
current_loaded_models.insert(0, loaded_model)
|
|
return
|
|
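# Illustration only, not part of the upstream module: a minimal sketch of the
# weight-budget formula load_models_gpu uses when partial loading is allowed.
# It assumes all quantities are plain numbers in MiB rather than live values from
# get_free_memory(); sketch_weight_budget and its parameter names are hypothetical.

```python
def sketch_weight_budget(free_mem, loaded, min_required, min_inference, min_ratio):
    """Extra memory that may be filled with weights beyond what is already loaded."""
    # Memory already held by this model counts as available to it.
    current_free = free_mem + loaded
    budget = max(
        0,
        current_free - min_required,                                  # leave room for inference
        min(current_free * min_ratio, current_free - min_inference),  # ratio-based floor
    )
    return budget - loaded


# Plenty of room: everything above the 2048 MiB requirement goes to weights.
roomy = sketch_weight_budget(6144, 0, 2048, 1228, 0.4)

# Tight: the requirement alone would leave nothing, so the ratio-based floor
# (capped by what remains after the minimum inference memory) applies.
tight = sketch_weight_budget(1536, 0, 2048, 1228, 0.4)

# With a ratio of 0.0 (the value the module sets for NVIDIA) the floor vanishes.
tight_no_floor = sketch_weight_budget(1536, 0, 2048, 1228, 0.0)
```

# The floor is what lets a nearly full device still accept some weights instead of
# streaming every layer from system RAM.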
|
|
def load_model_gpu(model):
|
|
return load_models_gpu([model])
|
|
|
|
def loaded_models(only_currently_used=False):
|
|
output = []
|
|
for m in current_loaded_models:
|
|
if only_currently_used:
|
|
if not m.currently_used:
|
|
continue
|
|
|
|
output.append(m.model)
|
|
return output
|
|
|
|
|
|
def cleanup_models_gc():
|
|
do_gc = False
|
|
|
|
for i in range(len(current_loaded_models)):
|
|
cur = current_loaded_models[i]
|
|
if cur.is_dead():
|
|
logging.info("Potential memory leak detected with model {}, doing a full garbage collect, for maximum performance avoid circular references in the model code.".format(cur.real_model().__class__.__name__))
|
|
do_gc = True
|
|
break
|
|
|
|
if do_gc:
|
|
gc.collect()
|
|
soft_empty_cache()
|
|
|
|
for i in range(len(current_loaded_models)):
|
|
cur = current_loaded_models[i]
|
|
if cur.is_dead():
|
|
logging.warning("WARNING, memory leak with model {}. Please make sure it is not being referenced from somewhere.".format(cur.real_model().__class__.__name__))
|
|
|
|
|
|
def archive_model_dtypes(model):
|
|
for name, module in model.named_modules():
|
|
for param_name, param in module.named_parameters(recurse=False):
|
|
setattr(module, f"{param_name}_comfy_model_dtype", param.dtype)
|
|
for buf_name, buf in module.named_buffers(recurse=False):
|
|
setattr(module, f"{buf_name}_comfy_model_dtype", buf.dtype)
|
|
|
|
|
|
def cleanup_models():
|
|
to_delete = []
|
|
for i in range(len(current_loaded_models)):
|
|
if current_loaded_models[i].real_model() is None:
|
|
to_delete = [i] + to_delete
|
|
|
|
for i in to_delete:
|
|
x = current_loaded_models.pop(i)
|
|
del x
|
|
|
|
def dtype_size(dtype):
    # Size of one element in bytes. The local is named "size" so that it does
    # not shadow the function name.
    size = 4
    if dtype == torch.float16 or dtype == torch.bfloat16:
        size = 2
    elif dtype == torch.float32:
        size = 4
    else:
        try:
            size = dtype.itemsize
        except:  # Old pytorch doesn't have .itemsize
            pass
    return size
|
|
|
|
def unet_offload_device():
|
|
if vram_state == VRAMState.HIGH_VRAM:
|
|
return get_torch_device()
|
|
else:
|
|
return torch.device("cpu")
|
|
|
|
def unet_inital_load_device(parameters, dtype):
|
|
cpu_dev = torch.device("cpu")
|
|
if comfy.memory_management.aimdo_enabled:
|
|
return cpu_dev
|
|
|
|
torch_dev = get_torch_device()
|
|
if vram_state == VRAMState.HIGH_VRAM or vram_state == VRAMState.SHARED:
|
|
return torch_dev
|
|
|
|
if DISABLE_SMART_MEMORY or vram_state == VRAMState.NO_VRAM:
|
|
return cpu_dev
|
|
|
|
model_size = dtype_size(dtype) * parameters
|
|
|
|
mem_dev = get_free_memory(torch_dev)
|
|
mem_cpu = get_free_memory(cpu_dev)
|
|
if mem_dev > mem_cpu and model_size < mem_dev:
|
|
return torch_dev
|
|
else:
|
|
return cpu_dev
|
|
|
|
def maximum_vram_for_weights(device=None):
|
|
return (get_total_memory(device) * 0.88 - minimum_inference_memory())
|
|
|
|
def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32], weight_dtype=None):
    if model_params < 0:
        model_params = 1000000000000000000000
    if args.fp32_unet:
        return torch.float32
    if args.fp64_unet:
        return torch.float64
    if args.bf16_unet:
        return torch.bfloat16
    if args.fp16_unet:
        return torch.float16
    if args.fp8_e4m3fn_unet:
        return torch.float8_e4m3fn
    if args.fp8_e5m2_unet:
        return torch.float8_e5m2
    if args.fp8_e8m0fnu_unet:
        return torch.float8_e8m0fnu

    fp8_dtype = None
    if weight_dtype in FLOAT8_TYPES:
        fp8_dtype = weight_dtype

    if fp8_dtype is not None:
        if supports_fp8_compute(device): #if fp8 compute is supported the casting is most likely not expensive
            return fp8_dtype

        free_model_memory = maximum_vram_for_weights(device)
        if model_params * 2 > free_model_memory:
            return fp8_dtype

    if PRIORITIZE_FP16 or weight_dtype == torch.float16:
        if torch.float16 in supported_dtypes and should_use_fp16(device=device, model_params=model_params):
            return torch.float16

    for dt in supported_dtypes:
        if dt == torch.float16 and should_use_fp16(device=device, model_params=model_params):
            if torch.float16 in supported_dtypes:
                return torch.float16
        if dt == torch.bfloat16 and should_use_bf16(device, model_params=model_params):
            if torch.bfloat16 in supported_dtypes:
                return torch.bfloat16

    for dt in supported_dtypes:
        if dt == torch.float16 and should_use_fp16(device=device, model_params=model_params, manual_cast=True):
            if torch.float16 in supported_dtypes:
                return torch.float16
        if dt == torch.bfloat16 and should_use_bf16(device, model_params=model_params, manual_cast=True):
            if torch.bfloat16 in supported_dtypes:
                return torch.bfloat16

    return torch.float32

# None means no manual cast
def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
    if weight_dtype == torch.float32 or weight_dtype == torch.float64:
        return None

    fp16_supported = should_use_fp16(inference_device, prioritize_performance=False)
    if fp16_supported and weight_dtype == torch.float16:
        return None

    bf16_supported = should_use_bf16(inference_device)
    if bf16_supported and weight_dtype == torch.bfloat16:
        return None

    fp16_supported = should_use_fp16(inference_device, prioritize_performance=True)
    if PRIORITIZE_FP16 and fp16_supported and torch.float16 in supported_dtypes:
        return torch.float16

    for dt in supported_dtypes:
        if dt == torch.float16 and fp16_supported:
            return torch.float16
        if dt == torch.bfloat16 and bf16_supported:
            return torch.bfloat16

    return torch.float32

def text_encoder_offload_device():
    if args.gpu_only:
        return get_torch_device()
    else:
        return torch.device("cpu")

def text_encoder_device():
    if args.gpu_only:
        return get_torch_device()
    elif vram_state in (VRAMState.HIGH_VRAM, VRAMState.NORMAL_VRAM) or comfy.memory_management.aimdo_enabled:
        if should_use_fp16(prioritize_performance=False):
            return get_torch_device()
        else:
            return torch.device("cpu")
    else:
        return torch.device("cpu")

def text_encoder_initial_device(load_device, offload_device, model_size=0):
    if comfy.memory_management.aimdo_enabled:
        return offload_device

    if load_device == offload_device or model_size <= 1024 * 1024 * 1024:
        return offload_device

    if is_device_mps(load_device):
        return load_device

    mem_l = get_free_memory(load_device)
    mem_o = get_free_memory(offload_device)
    if mem_l > (mem_o * 0.5) and model_size * 1.2 < mem_l:
        return load_device
    else:
        return offload_device

def text_encoder_dtype(device=None):
    if args.fp8_e4m3fn_text_enc:
        return torch.float8_e4m3fn
    elif args.fp8_e5m2_text_enc:
        return torch.float8_e5m2
    elif args.fp16_text_enc:
        return torch.float16
    elif args.bf16_text_enc:
        return torch.bfloat16
    elif args.fp32_text_enc:
        return torch.float32

    if is_device_cpu(device):
        return torch.float16

    return torch.float16


def intermediate_device():
    if args.gpu_only:
        return get_torch_device()
    else:
        return torch.device("cpu")

def intermediate_dtype():
    if args.fp16_intermediates:
        return torch.float16
    else:
        return torch.float32

def vae_device():
    if args.cpu_vae:
        return torch.device("cpu")
    return get_torch_device()

def vae_offload_device():
    if args.gpu_only:
        return get_torch_device()
    else:
        return torch.device("cpu")

def vae_dtype(device=None, allowed_dtypes=[]):
    if args.fp16_vae:
        return torch.float16
    elif args.bf16_vae:
        return torch.bfloat16
    elif args.fp32_vae:
        return torch.float32

    for d in allowed_dtypes:
        if d == torch.float16 and should_use_fp16(device):
            return d

        if d == torch.bfloat16 and should_use_bf16(device):
            return d

    return torch.float32

def get_autocast_device(dev):
    if hasattr(dev, 'type'):
        return dev.type
    return "cuda"

def supports_dtype(device, dtype): #TODO
    if dtype == torch.float32:
        return True
    if is_device_cpu(device):
        return False
    if dtype == torch.float16:
        return True
    if dtype == torch.bfloat16:
        return True
    return False

def supports_cast(device, dtype): #TODO
    if dtype == torch.float32:
        return True
    if dtype == torch.float16:
        return True
    if directml_enabled: #TODO: test this
        return False
    if dtype == torch.bfloat16:
        return True
    if is_device_mps(device):
        return False
    if dtype == torch.float8_e4m3fn:
        return True
    if dtype == torch.float8_e5m2:
        return True
    return False

def pick_weight_dtype(dtype, fallback_dtype, device=None):
    if dtype is None:
        dtype = fallback_dtype
    elif dtype_size(dtype) > dtype_size(fallback_dtype):
        dtype = fallback_dtype

    if not supports_cast(device, dtype):
        dtype = fallback_dtype

    return dtype

def device_supports_non_blocking(device):
    if args.force_non_blocking:
        return True
    if is_device_mps(device):
        return False #pytorch bug? mps doesn't support non blocking
    if is_intel_xpu(): #xpu does support non blocking but it is slower on iGPUs for some reason so disable by default until situation changes
        return False
    if args.deterministic: #TODO: figure out why deterministic breaks non blocking from gpu to cpu (previews)
        return False
    if directml_enabled:
        return False
    return True

def force_channels_last():
    if args.force_channels_last:
        return True

    #TODO
    return False


STREAMS = {}
NUM_STREAMS = 0
if args.async_offload is not None:
    NUM_STREAMS = args.async_offload
else:
    # Enable by default on Nvidia and AMD
    if is_nvidia() or is_amd():
        NUM_STREAMS = 2

if args.disable_async_offload:
    NUM_STREAMS = 0

if NUM_STREAMS > 0:
    logging.info("Using async weight offloading with {} streams".format(NUM_STREAMS))

def current_stream(device):
    if device is None:
        return None
    if is_device_cuda(device):
        return torch.cuda.current_stream()
    elif is_device_xpu(device):
        return torch.xpu.current_stream()
    else:
        return None

stream_counters = {}

STREAM_CAST_BUFFERS = {}
LARGEST_CASTED_WEIGHT = (None, 0)
STREAM_AIMDO_CAST_BUFFERS = {}
LARGEST_AIMDO_CASTED_WEIGHT = (None, 0)

DEFAULT_AIMDO_CAST_BUFFER_RESERVATION_SIZE = 16 * 1024 ** 3

def get_cast_buffer(offload_stream, device, size, ref):
    global LARGEST_CASTED_WEIGHT

    if offload_stream is not None:
        wf_context = offload_stream
        if hasattr(wf_context, "as_context"):
            wf_context = wf_context.as_context(offload_stream)
    else:
        wf_context = nullcontext()

    cast_buffer = STREAM_CAST_BUFFERS.get(offload_stream, None)
    if cast_buffer is None or cast_buffer.numel() < size:
        if ref is LARGEST_CASTED_WEIGHT[0]:
            #If there is one giant weight we do not want both streams to
            #allocate a buffer for it. It's up to the caster to get the other
            #offload stream in this corner case
            return None
        if cast_buffer is not None and cast_buffer.numel() > 50 * (1024 ** 2):
            #I want my wrongly sized 50MB+ of VRAM back from the caching allocator right now
            synchronize()
            del STREAM_CAST_BUFFERS[offload_stream]
            del cast_buffer
            soft_empty_cache()
        with wf_context:
            cast_buffer = torch.empty((size), dtype=torch.int8, device=device)
            STREAM_CAST_BUFFERS[offload_stream] = cast_buffer

        if size > LARGEST_CASTED_WEIGHT[1]:
            LARGEST_CASTED_WEIGHT = (ref, size)

    return cast_buffer

def get_aimdo_cast_buffer(offload_stream, device):
    cast_buffer = STREAM_AIMDO_CAST_BUFFERS.get(offload_stream, None)
    if cast_buffer is None:
        cast_buffer = comfy_aimdo.vram_buffer.VRAMBuffer(DEFAULT_AIMDO_CAST_BUFFER_RESERVATION_SIZE, device.index)
        STREAM_AIMDO_CAST_BUFFERS[offload_stream] = cast_buffer
    return cast_buffer

def reset_cast_buffers():
    global LARGEST_CASTED_WEIGHT
    global LARGEST_AIMDO_CASTED_WEIGHT

    LARGEST_CASTED_WEIGHT = (None, 0)
    LARGEST_AIMDO_CASTED_WEIGHT = (None, 0)
    for offload_stream in set(STREAM_CAST_BUFFERS) | set(STREAM_AIMDO_CAST_BUFFERS):
        if offload_stream is not None:
            offload_stream.synchronize()
    synchronize()

    for mmap_obj in DIRTY_MMAPS:
        mmap_obj.bounce()
    DIRTY_MMAPS.clear()

    for loaded_model in current_loaded_models:
        model = loaded_model.model
        if model is not None and model.is_dynamic():
            pin_state = model.model.dynamic_pins[model.load_device]

            if pin_state["active"]:
                *_, buckets = pin_state["weights"]
                for size, bucket in list(buckets.items()):
                    bucket[:] = [ entry for entry in bucket if entry[-1] is not None ]
                    if not bucket:
                        del buckets[size]

            pin_state["active"] = False
            model.partially_unload_ram(1e30, subsets=[ "patches" ])
            model.model.dynamic_pins[model.load_device]["patches"] = (comfy_aimdo.host_buffer.HostBuffer(0, 8 * 1024 * 1024, pinned_hostbuf_size(model.model_size())), [], [-1], [0], [0], {})

    STREAM_CAST_BUFFERS.clear()
    STREAM_AIMDO_CAST_BUFFERS.clear()
    soft_empty_cache()

def get_offload_stream(device):
    stream_counter = stream_counters.get(device, 0)
    if NUM_STREAMS == 0:
        return None

    if torch.compiler.is_compiling():
        return None

    if device in STREAMS:
        ss = STREAMS[device]
        #Sync the oldest stream in the queue with the current
        ss[stream_counter].wait_stream(current_stream(device))
        stream_counter = (stream_counter + 1) % len(ss)
        stream_counters[device] = stream_counter
        return ss[stream_counter]
    elif is_device_cuda(device):
        ss = []
        for k in range(NUM_STREAMS):
            s1 = torch.cuda.Stream(device=device, priority=0)
            s1.as_context = torch.cuda.stream
            ss.append(s1)
        STREAMS[device] = ss
        s = ss[stream_counter]
        stream_counters[device] = stream_counter
        return s
    elif is_device_xpu(device):
        ss = []
        for k in range(NUM_STREAMS):
            s1 = torch.xpu.Stream(device=device, priority=0)
            s1.as_context = torch.xpu.stream
            ss.append(s1)
        STREAMS[device] = ss
        s = ss[stream_counter]
        stream_counters[device] = stream_counter
        return s
    return None

def sync_stream(device, stream):
    if stream is None or current_stream(device) is None:
        return
    current_stream(device).wait_stream(stream)


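# Illustration only, not part of the module: a minimal sketch of the
# round-robin selection that get_offload_stream() performs per device, with
# strings standing in for stream objects. next_worker, _pools and _counters
# are hypothetical names; the wait_stream() synchronisation step is omitted.

```python
_pools = {}     # key -> list of workers (plays the role of STREAMS)
_counters = {}  # key -> index of the most recently returned worker

def next_worker(key, make_pool):
    """Create the pool on first use and return its first entry; on later
    calls advance the counter modulo the pool size and return that entry."""
    counter = _counters.get(key, 0)
    if key in _pools:
        pool = _pools[key]
        counter = (counter + 1) % len(pool)
    else:
        pool = _pools[key] = make_pool()
    _counters[key] = counter
    return pool[counter]

# With two workers the selection alternates between them.
order = [next_worker("cuda:0", lambda: ["s0", "s1"]) for _ in range(4)]
print(order)
```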
def cast_to_gathered(tensors, r, non_blocking=False, stream=None, r2=None):
    wf_context = nullcontext()
    if stream is not None:
        wf_context = stream
        if hasattr(wf_context, "as_context"):
            wf_context = wf_context.as_context(stream)

    dest_views = comfy.memory_management.interpret_gathered_like(tensors, r) if r is not None else [None] * len(tensors)
    dest2_views = comfy.memory_management.interpret_gathered_like(tensors, r2) if r2 is not None else None
    with wf_context:
        for tensor in tensors:
            dest_view = dest_views.pop(0)
            dest2_view = dest2_views.pop(0) if dest2_views is not None else None
            if tensor is None:
                continue
            if comfy.memory_management.read_tensor_file_slice_into(tensor, dest_view, stream=stream, destination2=dest2_view):
                continue
            storage = tensor._qdata.untyped_storage() if isinstance(tensor, comfy.quant_ops.QuantizedTensor) else tensor.untyped_storage()
            mark_mmap_dirty(storage)
            if dest_view is not None:
                dest_view.copy_(tensor, non_blocking=non_blocking)
            if dest2_view is not None:
                dest2_view.copy_(tensor if dest_view is None else dest_view, non_blocking=non_blocking)


def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None, r=None):
    if device is None or weight.device == device:
        if not copy:
            if dtype is None or weight.dtype == dtype:
                return weight
        if stream is not None:
            wf_context = stream
            if hasattr(wf_context, "as_context"):
                wf_context = wf_context.as_context(stream)
            with wf_context:
                return weight.to(dtype=dtype, copy=copy)
        return weight.to(dtype=dtype, copy=copy)


    if stream is not None:
        wf_context = stream
        if hasattr(wf_context, "as_context"):
            wf_context = wf_context.as_context(stream)
        with wf_context:
            if r is None:
                r = torch.empty_like(weight, dtype=dtype, device=device)
            r.copy_(weight, non_blocking=non_blocking)
    else:
        if r is None:
            r = torch.empty_like(weight, dtype=dtype, device=device)
        r.copy_(weight, non_blocking=non_blocking)
    return r

def cast_to_device(tensor, device, dtype, copy=False):
    non_blocking = device_supports_non_blocking(device)
    return cast_to(tensor, dtype=dtype, device=device, non_blocking=non_blocking, copy=copy)


PINNED_MEMORY = {}
TOTAL_PINNED_MEMORY = 0
MAX_PINNED_MEMORY = -1
if not args.disable_pinned_memory:
    if is_nvidia() or is_amd():
        ram = get_total_memory(torch.device("cpu"))
        if WINDOWS:
            MAX_PINNED_MEMORY = ram * 0.40 # Windows limit is apparently 50%
        else:
            MAX_PINNED_MEMORY = ram * 0.90
        logging.info("Enabled pinned memory {}".format(MAX_PINNED_MEMORY // (1024 * 1024)))

PINNING_ALLOWED_TYPES = set(["Tensor", "Parameter", "QuantizedTensor"])

def pinned_hostbuf_size(size):
    return max(0, int(min(size, MAX_PINNED_MEMORY) * 2))

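# Illustration only, not part of the module: a worked example of the
# arithmetic in pinned_hostbuf_size(), with the pinned-memory cap passed as
# an argument instead of read from MAX_PINNED_MEMORY. hostbuf_size is a
# hypothetical name and the byte counts are made-up inputs.

```python
def hostbuf_size(size, max_pinned):
    """Twice the requested size, with the request first capped at
    max_pinned; a disabled cap (-1) clamps the result to 0."""
    return max(0, int(min(size, max_pinned) * 2))

print(hostbuf_size(10, 100))   # request below the cap
print(hostbuf_size(200, 100))  # request above the cap
print(hostbuf_size(10, -1))    # pinning disabled
```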
def discard_cuda_async_error():
    try:
        a = torch.tensor([1], dtype=torch.uint8, device=get_torch_device())
        b = torch.tensor([1], dtype=torch.uint8, device=get_torch_device())
        _ = a + b
        synchronize()
    except RuntimeError:
        #Dump it! We already know about it from the synchronous return
        pass

def pin_memory(tensor):
    global TOTAL_PINNED_MEMORY
    if MAX_PINNED_MEMORY <= 0:
        return False

    if type(tensor).__name__ not in PINNING_ALLOWED_TYPES:
        return False

    if not is_device_cpu(tensor.device):
        return False

    if tensor.is_pinned():
        #NOTE: CUDA does detect when a tensor is already pinned and would
        #error below, but there are proven cases where this also queues an error
        #on the GPU async. So don't trust the CUDA API and guard here
        return False

    if not tensor.is_contiguous():
        return False

    size = tensor.nbytes
    comfy.memory_management.extra_ram_release(comfy.memory_management.RAM_CACHE_HEADROOM)
    ensure_pin_registerable(size)

    ptr = tensor.data_ptr()
    if ptr == 0:
        return False

    if torch.cuda.cudart().cudaHostRegister(ptr, size, 1) == 0:
        PINNED_MEMORY[ptr] = size
        TOTAL_PINNED_MEMORY += size
        return True
    else:
        logging.warning("Pin error.")
        discard_cuda_async_error()

    return False

def unpin_memory(tensor):
    global TOTAL_PINNED_MEMORY
    if MAX_PINNED_MEMORY <= 0:
        return False

    if not is_device_cpu(tensor.device):
        return False

    ptr = tensor.data_ptr()
    size = tensor.nbytes

    size_stored = PINNED_MEMORY.get(ptr, None)
    if size_stored is None:
        logging.warning("Tried to unpin tensor not pinned by ComfyUI")
        return False

    if size != size_stored:
        logging.warning("Size of pinned tensor changed")
        return False

    if torch.cuda.cudart().cudaHostUnregister(ptr) == 0:
        size = PINNED_MEMORY.pop(ptr)
        TOTAL_PINNED_MEMORY -= size
        return True
    else:
        logging.warning("Unpin error.")
        discard_cuda_async_error()

    return False

def sage_attention_enabled():
    return args.use_sage_attention

def flash_attention_enabled():
    return args.use_flash_attention

def xformers_enabled():
    global directml_enabled
    global cpu_state
    if cpu_state != CPUState.GPU:
        return False
    if is_intel_xpu():
        return False
    if is_ascend_npu():
        return False
    if is_mlu():
        return False
    if is_ixuca():
        return False
    if directml_enabled:
        return False
    return XFORMERS_IS_AVAILABLE


def xformers_enabled_vae():
    enabled = xformers_enabled()
    if not enabled:
        return False

    return XFORMERS_ENABLED_VAE

def pytorch_attention_enabled():
    global ENABLE_PYTORCH_ATTENTION
    return ENABLE_PYTORCH_ATTENTION

def pytorch_attention_enabled_vae():
    if is_amd():
        return False # enabling pytorch attention on AMD currently causes crash when doing high res
    return pytorch_attention_enabled()

def pytorch_attention_flash_attention():
    global ENABLE_PYTORCH_ATTENTION
    if ENABLE_PYTORCH_ATTENTION:
        #TODO: more reliable way of checking for flash attention?
        if is_nvidia():
            return True
        if is_intel_xpu():
            return True
        if is_ascend_npu():
            return True
        if is_mlu():
            return True
        if is_amd():
            return True #if you have pytorch attention enabled on AMD it probably supports at least mem efficient attention
        if is_ixuca():
            return True
    return False

def force_upcast_attention_dtype():
    upcast = args.force_upcast_attention

    macos_version = mac_version()
    if macos_version is not None and ((14, 5) <= macos_version): # black image bug on recent versions of macOS, I don't think it's ever getting fixed
        upcast = True

    if upcast:
        return {torch.float16: torch.float32}
    else:
        return None

def get_free_memory(dev=None, torch_free_too=False):
    global directml_enabled
    if dev is None:
        dev = get_torch_device()

    if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
        mem_free_total = psutil.virtual_memory().available
        mem_free_torch = mem_free_total
    else:
        if directml_enabled:
            mem_free_total = 1024 * 1024 * 1024 #TODO
            mem_free_torch = mem_free_total
        elif is_intel_xpu():
            stats = torch.xpu.memory_stats(dev)
            mem_active = stats['active_bytes.all.current']
            mem_reserved = stats['reserved_bytes.all.current']
            mem_free_xpu = torch.xpu.get_device_properties(dev).total_memory - mem_reserved
            mem_free_torch = mem_reserved - mem_active
            mem_free_total = mem_free_xpu + mem_free_torch
        elif is_ascend_npu():
            stats = torch.npu.memory_stats(dev)
            mem_active = stats['active_bytes.all.current']
            mem_reserved = stats['reserved_bytes.all.current']
            mem_free_npu, _ = torch.npu.mem_get_info(dev)
            mem_free_torch = mem_reserved - mem_active
            mem_free_total = mem_free_npu + mem_free_torch
        elif is_mlu():
            stats = torch.mlu.memory_stats(dev)
            mem_active = stats['active_bytes.all.current']
            mem_reserved = stats['reserved_bytes.all.current']
            mem_free_mlu, _ = torch.mlu.mem_get_info(dev)
            mem_free_torch = mem_reserved - mem_active
            mem_free_total = mem_free_mlu + mem_free_torch
        else:
            stats = torch.cuda.memory_stats(dev)
            mem_active = stats['active_bytes.all.current']
            mem_reserved = stats['reserved_bytes.all.current']
            mem_free_cuda, _ = torch.cuda.mem_get_info(dev)
            mem_free_torch = mem_reserved - mem_active
            mem_free_total = mem_free_cuda + mem_free_torch

    if torch_free_too:
        return (mem_free_total, mem_free_torch)
    else:
        return mem_free_total

def cpu_mode():
    global cpu_state
    return cpu_state == CPUState.CPU

def mps_mode():
    global cpu_state
    return cpu_state == CPUState.MPS

def is_device_type(device, type):
    if hasattr(device, 'type'):
        if (device.type == type):
            return True
    return False

def is_device_cpu(device):
    return is_device_type(device, 'cpu')

def is_device_mps(device):
    return is_device_type(device, 'mps')

def is_device_xpu(device):
    return is_device_type(device, 'xpu')

def is_device_cuda(device):
    return is_device_type(device, 'cuda')

def set_torch_device(device):
    """Set the current device for the given torch device. Supports CUDA and XPU."""
    if is_device_cuda(device):
        torch.cuda.set_device(device)
    elif is_device_xpu(device):
        torch.xpu.set_device(device)

def is_directml_enabled():
    global directml_enabled
    if directml_enabled:
        return True

    return False

def should_use_fp16(device=None, model_params=0, prioritize_performance=True, manual_cast=False):
    if device is not None:
        if is_device_cpu(device):
            return False

    if args.force_fp16:
        return True

    if FORCE_FP32:
        return False

    if is_directml_enabled():
        return True

    if (device is not None and is_device_mps(device)) or mps_mode():
        return True

    if cpu_mode():
        return False

    if is_intel_xpu():
        return torch.xpu.get_device_properties(device).has_fp16

    if is_ascend_npu():
        return True

    if is_mlu():
        return True

    if is_ixuca():
        return True

    if torch.version.hip:
        return True

    props = torch.cuda.get_device_properties(device)
    if props.major >= 8:
        return True

    if props.major < 6:
        return False

    #FP16 is confirmed working on a 1080 (GP104) and on latest pytorch actually seems faster than fp32
    nvidia_10_series = ["1080", "1070", "titan x", "p3000", "p3200", "p4000", "p4200", "p5000", "p5200", "p6000", "1060", "1050", "p40", "p100", "p6", "p4"]
    for x in nvidia_10_series:
        if x in props.name.lower():
            if WINDOWS or manual_cast:
                return True
            else:
                return False #weird linux behavior where fp32 is faster

    if manual_cast:
        free_model_memory = maximum_vram_for_weights(device)
        if (not prioritize_performance) or model_params * 4 > free_model_memory:
            return True

    if props.major < 7:
        return False

    #FP16 is just broken on these cards
    nvidia_16_series = ["1660", "1650", "1630", "T500", "T550", "T600", "MX550", "MX450", "CMP 30HX", "T2000", "T1000", "T1200"]
    for x in nvidia_16_series:
        if x in props.name:
            return False

    return True

def should_use_bf16(device=None, model_params=0, prioritize_performance=True, manual_cast=False):
    if device is not None:
        if is_device_cpu(device): #TODO ? bf16 works on CPU but is extremely slow
            return False

    if FORCE_FP32:
        return False

    if directml_enabled:
        return False

    if (device is not None and is_device_mps(device)) or mps_mode():
        if mac_version() < (14,):
            return False
        return True

    if cpu_mode():
        return False

    if is_intel_xpu():
        return torch.xpu.is_bf16_supported()

    if is_ascend_npu():
        return True

    if is_ixuca():
        return True

    if is_amd():
        arch = torch.cuda.get_device_properties(device).gcnArchName
        if any((a in arch) for a in AMD_RDNA2_AND_OLDER_ARCH): # RDNA2 and older don't support bf16
            if manual_cast:
                return True
            return False

    props = torch.cuda.get_device_properties(device)

    if is_mlu():
        if props.major > 3:
            return True

    if props.major >= 8:
        return True

    bf16_works = torch.cuda.is_bf16_supported()

    if bf16_works and manual_cast:
        free_model_memory = maximum_vram_for_weights(device)
        if (not prioritize_performance) or model_params * 4 > free_model_memory:
            return True

    return False

def supports_fp8_compute(device=None):
    if SUPPORT_FP8_OPS:
        return True

    if not is_nvidia():
        return False

    props = torch.cuda.get_device_properties(device)
    if props.major >= 9:
        return True
    if props.major < 8:
        return False
    if props.minor < 9:
        return False

    if torch_version_numeric < (2, 3):
        return False

    if WINDOWS:
        if torch_version_numeric < (2, 4):
            return False

    return True

def supports_nvfp4_compute(device=None):
    if not is_nvidia():
        return False

    props = torch.cuda.get_device_properties(device)
    if props.major < 10:
        return False

    return True

def supports_mxfp8_compute(device=None):
    if not is_nvidia():
        return False

    if torch_version_numeric < (2, 10):
        return False

    props = torch.cuda.get_device_properties(device)
    if props.major < 10:
        return False

    return True

def supports_fp64(device=None):
    if is_device_mps(device):
        return False

    if is_intel_xpu():
        return False

    if is_directml_enabled():
        return False

    if is_ixuca():
        return False

    return True

def extended_fp16_support():
    # TODO: check why some models work with fp16 on newer torch versions but not on older
    if torch_version_numeric < (2, 7):
        return False

    return True

LORA_COMPUTE_DTYPES = {}
def lora_compute_dtype(device):
    dtype = LORA_COMPUTE_DTYPES.get(device, None)
    if dtype is not None:
        return dtype

    if should_use_fp16(device):
        dtype = torch.float16
    else:
        dtype = torch.float32

    LORA_COMPUTE_DTYPES[device] = dtype
    return dtype

def synchronize():
    if cpu_mode():
        return
    if is_intel_xpu():
        torch.xpu.synchronize()
    elif torch.cuda.is_available():
        torch.cuda.synchronize()

def soft_empty_cache(force=False):
    if cpu_mode():
        return
    global cpu_state
    if cpu_state == CPUState.MPS:
        torch.mps.empty_cache()
    elif is_intel_xpu():
        torch.xpu.synchronize()
        torch.xpu.empty_cache()
    elif is_ascend_npu():
        torch.npu.empty_cache()
    elif is_mlu():
        torch.mlu.empty_cache()
    elif torch.cuda.is_available():
        torch.cuda.synchronize()
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()

def unload_all_models():
    for device in get_all_torch_devices():
        free_memory(1e30, device)

def unload_model_and_clones(model: ModelPatcher, unload_additional_models=True, all_devices=False):
    'Unload only model and its clones - primarily for multigpu cloning purposes.'
    initial_keep_loaded: list[LoadedModel] = current_loaded_models.copy()
    additional_models = []
    if unload_additional_models:
        additional_models = model.get_nested_additional_models()
    keep_loaded = []
    for loaded_model in initial_keep_loaded:
        if loaded_model.model is not None:
            if model.clone_base_uuid == loaded_model.model.clone_base_uuid:
                continue
            # check additional models if they are a match
            skip = False
            for add_model in additional_models:
                if add_model.clone_base_uuid == loaded_model.model.clone_base_uuid:
                    skip = True
                    break
            if skip:
                continue
            keep_loaded.append(loaded_model)
    if not all_devices:
        free_memory(1e30, get_torch_device(), keep_loaded)
    else:
        for device in get_all_torch_devices():
            free_memory(1e30, device, keep_loaded)

def debug_memory_summary():
    if is_amd() or is_nvidia():
        return torch.cuda.memory.memory_summary()
    return ""

class InterruptProcessingException(BaseException):
    pass

interrupt_processing_mutex = threading.RLock()

interrupt_processing = False
def interrupt_current_processing(value=True):
    global interrupt_processing
    global interrupt_processing_mutex
    with interrupt_processing_mutex:
        interrupt_processing = value

def processing_interrupted():
    global interrupt_processing
    global interrupt_processing_mutex
    with interrupt_processing_mutex:
        return interrupt_processing

def throw_exception_if_processing_interrupted():
    global interrupt_processing
    global interrupt_processing_mutex
    with interrupt_processing_mutex:
        if interrupt_processing:
            interrupt_processing = False
            raise InterruptProcessingException()
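
# Illustration only, not part of the module: a minimal sketch of how a
# lock-guarded flag like interrupt_processing can stop a loop cooperatively.
# Interrupted, request_interrupt, check_interrupt and run_steps are
# hypothetical names; the loop body stands in for real sampling steps.

```python
import threading

class Interrupted(BaseException):
    pass

_lock = threading.RLock()
_flag = False

def request_interrupt(value=True):
    global _flag
    with _lock:
        _flag = value

def check_interrupt():
    """Raise once if an interrupt was requested, clearing the flag first."""
    global _flag
    with _lock:
        if _flag:
            _flag = False
            raise Interrupted()

def run_steps(n):
    """Run up to n steps, polling the flag before each one."""
    done = 0
    try:
        for i in range(n):
            check_interrupt()
            done += 1
            if i == 2:
                request_interrupt()  # simulate a cancel arriving mid-run
    except Interrupted:
        pass
    return done

completed = run_steps(10)
print(completed)
```

# Because the flag is cleared before the exception is raised, the next run
# starts clean without an explicit reset.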