Merge branch 'master' into matt/asset-image-dimensions-metadata
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run

This commit is contained in:
Matt Miller 2026-05-26 21:08:37 -07:00 committed by GitHub
commit 81c4bc5fe9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
115 changed files with 29564 additions and 3022 deletions

View File

@ -160,10 +160,12 @@ def _build_asset_response(result: schemas.AssetDetailResult | schemas.UploadResu
preview_url = None
else:
preview_url = _build_preview_url_from_view(result.tags, result.ref.user_metadata)
asset_content_hash = result.asset.hash if result.asset else None
return schemas_out.Asset(
id=result.ref.id,
name=result.ref.name,
asset_hash=result.asset.hash if result.asset else None,
hash=asset_content_hash,
asset_hash=asset_content_hash,
size=int(result.asset.size_bytes) if result.asset else None,
mime_type=result.asset.mime_type if result.asset else None,
tags=result.tags,

View File

@ -10,6 +10,7 @@ class Asset(BaseModel):
id: str
name: str
hash: str | None = None
asset_hash: str | None = None
size: int | None = None
mime_type: str | None = None

View File

@ -4,7 +4,6 @@ Tier 1: Filesystem metadata (zero parsing)
Tier 2: Safetensors header metadata (fast JSON read only)
"""
from __future__ import annotations
import json
import logging

View File

@ -1,5 +1,3 @@
from __future__ import annotations
import os
import folder_paths
import glob

View File

@ -1,4 +1,3 @@
from __future__ import annotations
import argparse
import logging
import os

View File

@ -1,5 +1,3 @@
from __future__ import annotations
import os
import base64
import json

View File

@ -1,4 +1,3 @@
from __future__ import annotations
import json
import os
import re

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -1553,7 +1553,7 @@
"VHS_MetadataImage": true,
"VHS_KeepIntermediate": true
},
"category": "Image generation and editing/Canny to image",
"category": "Image generation and editing/Conditioned",
"description": "Generates an image from a Canny edge map using Z-Image-Turbo, with text conditioning."
}
]

View File

@ -3600,7 +3600,7 @@
"extra": {
"workflowRendererVersion": "LG"
},
"category": "Video generation and editing/Canny to video",
"category": "Video generation and editing/Conditioned",
"description": "Generates video from Canny edge maps using LTX-2, with optional synchronized audio."
}
]

View File

@ -1401,7 +1401,7 @@
"extra": {
"workflowRendererVersion": "LG"
},
"category": "Image generation and editing/ControlNet",
"category": "Image generation and editing/Conditioned",
"description": "Generates images from a text prompt and ControlNet conditioning (e.g. depth, canny) using Z-Image-Turbo."
}
]

View File

@ -1579,7 +1579,7 @@
"VHS_MetadataImage": true,
"VHS_KeepIntermediate": true
},
"category": "Image generation and editing/Depth to image",
"category": "Image generation and editing/Conditioned",
"description": "Generates an image from a depth map using Z-Image-Turbo with text conditioning."
},
{

View File

@ -4233,7 +4233,7 @@
"extra": {
"workflowRendererVersion": "LG"
},
"category": "Video generation and editing/Depth to video",
"category": "Video generation and editing/Conditioned",
"description": "Generates depth-controlled video with LTX-2: motion and structure follow a depth-reference video alongside text prompting, optional first-frame image conditioning, with optional synchronized audio."
},
{

View File

@ -3350,7 +3350,7 @@
}
],
"extra": {},
"category": "Video generation and editing/First-Last-Frame to Video",
"category": "Video generation and editing/Conditioned",
"description": "Generates a video interpolating between first and last keyframes using LTX-2.3."
}
]

View File

@ -3350,7 +3350,7 @@
}
],
"extra": {},
"category": "Video generation and editing/First-Last-Frame to Video",
"category": "Video generation and editing/FLF2V",
"description": "Generates a video that interpolates between the first and last keyframes using LTX-2.3, including optional audio."
}
]

File diff suppressed because it is too large Load Diff

View File

@ -310,9 +310,9 @@
"extra": {
"workflowRendererVersion": "LG"
},
"category": "Text generation/Image Captioning",
"category": "Image Tools",
"description": "Generates descriptive captions for images using Google's Gemini multimodal LLM."
}
]
}
}
}

View File

@ -1,19 +1,18 @@
{
"id": "6af0a6c1-0161-4528-8685-65776e838d44",
"revision": 0,
"last_node_id": 75,
"last_link_id": 245,
"last_node_id": 76,
"last_link_id": 0,
"nodes": [
{
"id": 75,
"type": "488652fd-6edf-4d06-8f9f-4d84d3a34eaf",
"id": 76,
"type": "96338968-1242-4f02-b6a1-d496af4bcffe",
"pos": [
600,
830
670,
1280
],
"size": [
400,
110
201.3125
],
"flags": {},
"order": 0,
@ -59,47 +58,44 @@
"links": []
}
],
"title": "Image Depth Estimation (Lotus Depth)",
"properties": {
"proxyWidgets": [
[
"-1",
"28",
"sigma"
],
[
"-1",
"10",
"unet_name"
],
[
"-1",
"14",
"vae_name"
]
],
"cnr_id": "comfy-core",
"ver": "0.14.1"
},
"widgets_values": [
999.0000000000002,
"lotus-depth-d-v1-1.safetensors",
"vae-ft-mse-840000-ema-pruned.safetensors"
]
"widgets_values": []
}
],
"links": [],
"groups": [],
"version": 0.4,
"definitions": {
"subgraphs": [
{
"id": "488652fd-6edf-4d06-8f9f-4d84d3a34eaf",
"id": "96338968-1242-4f02-b6a1-d496af4bcffe",
"version": 1,
"state": {
"lastGroupId": 1,
"lastNodeId": 75,
"lastNodeId": 76,
"lastLinkId": 245,
"lastRerouteId": 0
},
"revision": 0,
"config": {},
"name": "Image to Depth Map (Lotus)",
"name": "Image Depth Estimation (Lotus Depth)",
"inputNode": {
"id": -10,
"bounding": [
@ -191,12 +187,12 @@
"id": 10,
"type": "UNETLoader",
"pos": [
108.05555555555557,
-253.05555555555557
110,
-250
],
"size": [
254.93706597222226,
82
260,
90
],
"flags": {},
"order": 4,
@ -234,9 +230,9 @@
}
],
"properties": {
"Node name for S&R": "UNETLoader",
"cnr_id": "comfy-core",
"ver": "0.3.34",
"Node name for S&R": "UNETLoader",
"models": [
{
"name": "lotus-depth-d-v1-1.safetensors",
@ -255,12 +251,12 @@
"id": 18,
"type": "DisableNoise",
"pos": [
607.0641494069639,
-268.33337840371513
610,
-270
],
"size": [
175,
33.333333333333336
180,
40
],
"flags": {},
"order": 0,
@ -278,26 +274,25 @@
}
],
"properties": {
"Node name for S&R": "DisableNoise",
"cnr_id": "comfy-core",
"ver": "0.3.34",
"Node name for S&R": "DisableNoise",
"widget_ue_connectable": {}
},
"widgets_values": []
}
},
{
"id": 23,
"id": 74,
"type": "VAEEncode",
"pos": [
620,
160
],
"size": [
175,
180,
50
],
"flags": {},
"order": 10,
"order": 11,
"mode": 0,
"inputs": [
{
@ -325,12 +320,11 @@
}
],
"properties": {
"Node name for S&R": "VAEEncode",
"cnr_id": "comfy-core",
"ver": "0.3.34",
"Node name for S&R": "VAEEncode",
"widget_ue_connectable": {}
},
"widgets_values": []
}
},
{
"id": 21,
@ -341,7 +335,7 @@
],
"size": [
210,
58
60
],
"flags": {},
"order": 1,
@ -369,9 +363,9 @@
}
],
"properties": {
"Node name for S&R": "KSamplerSelect",
"cnr_id": "comfy-core",
"ver": "0.3.34",
"Node name for S&R": "KSamplerSelect",
"widget_ue_connectable": {}
},
"widgets_values": [
@ -386,7 +380,7 @@
-170
],
"size": [
175,
180,
50
],
"flags": {},
@ -418,12 +412,11 @@
}
],
"properties": {
"Node name for S&R": "BasicGuider",
"cnr_id": "comfy-core",
"ver": "0.3.34",
"Node name for S&R": "BasicGuider",
"widget_ue_connectable": {}
},
"widgets_values": []
}
},
{
"id": 16,
@ -433,8 +426,8 @@
-130
],
"size": [
295.99609375,
271.65798611111114
300,
280
],
"flags": {},
"order": 6,
@ -490,12 +483,11 @@
}
],
"properties": {
"Node name for S&R": "SamplerCustomAdvanced",
"cnr_id": "comfy-core",
"ver": "0.3.34",
"Node name for S&R": "SamplerCustomAdvanced",
"widget_ue_connectable": {}
},
"widgets_values": []
}
},
{
"id": 28,
@ -506,10 +498,10 @@
],
"size": [
210,
58
60
],
"flags": {},
"order": 11,
"order": 10,
"mode": 0,
"inputs": [
{
@ -540,9 +532,9 @@
}
],
"properties": {
"Node name for S&R": "SetFirstSigma",
"cnr_id": "comfy-core",
"ver": "0.3.34",
"Node name for S&R": "SetFirstSigma",
"widget_ue_connectable": {}
},
"widgets_values": [
@ -557,7 +549,7 @@
-120
],
"size": [
175,
180,
50
],
"flags": {},
@ -589,12 +581,11 @@
}
],
"properties": {
"Node name for S&R": "VAEDecode",
"cnr_id": "comfy-core",
"ver": "0.3.34",
"Node name for S&R": "VAEDecode",
"widget_ue_connectable": {}
},
"widgets_values": []
}
},
{
"id": 22,
@ -604,8 +595,8 @@
-220
],
"size": [
175,
33.333333333333336
180,
40
],
"flags": {},
"order": 9,
@ -630,12 +621,11 @@
}
],
"properties": {
"Node name for S&R": "ImageInvert",
"cnr_id": "comfy-core",
"ver": "0.3.34",
"Node name for S&R": "ImageInvert",
"widget_ue_connectable": {}
},
"widgets_values": []
}
},
{
"id": 14,
@ -645,8 +635,8 @@
-90
],
"size": [
254.93706597222226,
58
260,
60
],
"flags": {},
"order": 5,
@ -675,9 +665,9 @@
}
],
"properties": {
"Node name for S&R": "VAELoader",
"cnr_id": "comfy-core",
"ver": "0.3.34",
"Node name for S&R": "VAELoader",
"models": [
{
"name": "vae-ft-mse-840000-ema-pruned.safetensors",
@ -692,15 +682,15 @@
]
},
{
"id": 68,
"id": 75,
"type": "LotusConditioning",
"pos": [
400,
-150
],
"size": [
175,
33.333333333333336
180,
40
],
"flags": {},
"order": 2,
@ -718,12 +708,11 @@
}
],
"properties": {
"Node name for S&R": "LotusConditioning",
"cnr_id": "comfy-core",
"ver": "0.3.34",
"Node name for S&R": "LotusConditioning",
"widget_ue_connectable": {}
},
"widgets_values": []
}
},
{
"id": 20,
@ -734,7 +723,7 @@
],
"size": [
210,
106
110
],
"flags": {},
"order": 8,
@ -786,9 +775,9 @@
}
],
"properties": {
"Node name for S&R": "BasicScheduler",
"cnr_id": "comfy-core",
"ver": "0.3.34",
"Node name for S&R": "BasicScheduler",
"widget_ue_connectable": {}
},
"widgets_values": [
@ -850,7 +839,7 @@
},
{
"id": 201,
"origin_id": 23,
"origin_id": 74,
"origin_slot": 0,
"target_id": 16,
"target_slot": 4,
@ -866,7 +855,7 @@
},
{
"id": 238,
"origin_id": 68,
"origin_id": 75,
"origin_slot": 0,
"target_id": 19,
"target_slot": 1,
@ -892,7 +881,7 @@
"id": 38,
"origin_id": 14,
"origin_slot": 0,
"target_id": 23,
"target_id": 74,
"target_slot": 1,
"type": "VAE"
},
@ -908,7 +897,7 @@
"id": 37,
"origin_id": -10,
"origin_slot": 0,
"target_id": 23,
"target_id": 74,
"target_slot": 0,
"type": "IMAGE"
},
@ -948,12 +937,11 @@
"extra": {
"workflowRendererVersion": "LG"
},
"category": "Image generation and editing/Depth to image",
"category": "Conditioning & Preprocessors/Depth",
"description": "Estimates a monocular depth map from an input image using the Lotus depth estimation model."
}
]
},
"config": {},
"extra": {
"ds": {
"scale": 1.3589709866044692,
@ -961,8 +949,6 @@
-138.53613935617864,
-786.0629126022195
]
},
"workflowRendererVersion": "LG"
},
"version": 0.4
}
}
}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,779 @@
{
"revision": 0,
"last_node_id": 33,
"last_link_id": 0,
"nodes": [
{
"id": 33,
"type": "6062babb-b649-4a71-be9e-20ebce567744",
"pos": [
-450,
4240
],
"size": [
420,
400
],
"flags": {},
"order": 0,
"mode": 0,
"inputs": [
{
"localized_name": "image",
"name": "image",
"type": "IMAGE",
"link": null
},
{
"name": "face_landmarker",
"type": "FACE_LANDMARKER",
"link": null
},
{
"name": "detector_variant",
"type": "COMBO",
"widget": {
"name": "detector_variant"
},
"link": null
},
{
"name": "num_faces",
"type": "INT",
"widget": {
"name": "num_faces"
},
"link": null
},
{
"label": "custom_face_oval",
"name": "regions.face_oval",
"type": "BOOLEAN",
"widget": {
"name": "regions.face_oval"
},
"link": null
},
{
"label": "custom_lips",
"name": "regions.lips",
"type": "BOOLEAN",
"widget": {
"name": "regions.lips"
},
"link": null
},
{
"label": "custom_left_eye",
"name": "regions.left_eye",
"type": "BOOLEAN",
"widget": {
"name": "regions.left_eye"
},
"link": null
},
{
"label": "custom_right_eye",
"name": "regions.right_eye",
"type": "BOOLEAN",
"widget": {
"name": "regions.right_eye"
},
"link": null
},
{
"label": "custom_irises",
"name": "regions.irises",
"type": "BOOLEAN",
"widget": {
"name": "regions.irises"
},
"link": null
},
{
"name": "model_name",
"type": "COMBO",
"widget": {
"name": "model_name"
},
"link": null
}
],
"outputs": [
{
"localized_name": "face_landmarks",
"name": "face_landmarks",
"type": "FACE_LANDMARKS",
"links": []
},
{
"localized_name": "bboxes",
"name": "bboxes",
"type": "BOUNDING_BOX",
"links": []
},
{
"label": "mask",
"name": "MASK_1",
"type": "MASK",
"links": []
}
],
"title": "Image Face Detection (Mediapipe)",
"properties": {
"proxyWidgets": [
[
"11",
"detector_variant"
],
[
"11",
"num_faces"
],
[
"20",
"regions.face_oval"
],
[
"20",
"regions.lips"
],
[
"20",
"regions.left_eye"
],
[
"20",
"regions.right_eye"
],
[
"20",
"regions.irises"
],
[
"2",
"model_name"
]
],
"cnr_id": "comfy-core",
"ver": "0.22.0",
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65
},
"widgets_values": []
}
],
"links": [],
"version": 0.4,
"definitions": {
"subgraphs": [
{
"id": "6062babb-b649-4a71-be9e-20ebce567744",
"version": 1,
"state": {
"lastGroupId": 2,
"lastNodeId": 158,
"lastLinkId": 140,
"lastRerouteId": 0
},
"revision": 0,
"config": {},
"name": "Image Face Detection (Mediapipe)",
"description": "Detects facial landmarks from an image using MediaPipe, outputting landmark data, face bounding boxes, and an optional face-region mask.",
"inputNode": {
"id": -10,
"bounding": [
-710,
4300,
148.880859375,
248
]
},
"outputNode": {
"id": -20,
"bounding": [
140,
4480,
137.677734375,
108
]
},
"inputs": [
{
"id": "705dc1ae-6dc9-4155-92df-52f816ad451e",
"name": "image",
"type": "IMAGE",
"linkIds": [
60
],
"localized_name": "image",
"pos": [
-585.119140625,
4324
]
},
{
"id": "d6277190-732c-4604-b7cd-d3a9588bf761",
"name": "face_landmarker",
"type": "FACE_LANDMARKER",
"linkIds": [
74
],
"pos": [
-585.119140625,
4344
]
},
{
"id": "ac473a08-6a86-42a7-b460-e70c6c5e1e2b",
"name": "detector_variant",
"type": "COMBO",
"linkIds": [
75
],
"pos": [
-585.119140625,
4364
]
},
{
"id": "1bec2252-ca2d-496e-8a33-33a61d21f897",
"name": "num_faces",
"type": "INT",
"linkIds": [
76
],
"pos": [
-585.119140625,
4384
]
},
{
"id": "17994fa2-0ea0-4c9b-a70a-19789c459c80",
"name": "regions.face_oval",
"type": "BOOLEAN",
"linkIds": [
77
],
"label": "custom_face_oval",
"pos": [
-585.119140625,
4404
]
},
{
"id": "1c6c5893-2aee-4c37-b702-15ef2e20d863",
"name": "regions.lips",
"type": "BOOLEAN",
"linkIds": [
78
],
"label": "custom_lips",
"pos": [
-585.119140625,
4424
]
},
{
"id": "f353fcea-4b6f-42a1-8fdd-32b3aa1e1f09",
"name": "regions.left_eye",
"type": "BOOLEAN",
"linkIds": [
79
],
"label": "custom_left_eye",
"pos": [
-585.119140625,
4444
]
},
{
"id": "1387e121-c1fb-4522-8f0d-43459e11dd86",
"name": "regions.right_eye",
"type": "BOOLEAN",
"linkIds": [
80
],
"label": "custom_right_eye",
"pos": [
-585.119140625,
4464
]
},
{
"id": "14acb0a0-d1f4-48f3-ba31-811b26236ef9",
"name": "regions.irises",
"type": "BOOLEAN",
"linkIds": [
81
],
"label": "custom_irises",
"pos": [
-585.119140625,
4484
]
},
{
"id": "25a82859-87de-42c8-8431-09948665546e",
"name": "model_name",
"type": "COMBO",
"linkIds": [
86
],
"pos": [
-585.119140625,
4504
]
}
],
"outputs": [
{
"id": "d2ba3f92-e8b1-49c3-9590-cfad56c54cf4",
"name": "face_landmarks",
"type": "FACE_LANDMARKS",
"linkIds": [
44
],
"localized_name": "face_landmarks",
"pos": [
164,
4504
]
},
{
"id": "4f356bb0-d4c4-4f93-b4cf-0845a65c4e6d",
"name": "bboxes",
"type": "BOUNDING_BOX",
"linkIds": [
25
],
"localized_name": "bboxes",
"pos": [
164,
4524
]
},
{
"id": "f6309e1d-6397-4363-b38f-778a122abc51",
"name": "MASK_1",
"type": "MASK",
"linkIds": [
83
],
"label": "mask",
"pos": [
164,
4544
]
}
],
"widgets": [],
"nodes": [
{
"id": 11,
"type": "MediaPipeFaceLandmarker",
"pos": [
-280,
4280
],
"size": [
350,
220
],
"flags": {},
"order": 1,
"mode": 0,
"inputs": [
{
"localized_name": "face_detection_model",
"name": "face_detection_model",
"type": "FACE_DETECTION_MODEL",
"link": 66
},
{
"localized_name": "image",
"name": "image",
"type": "IMAGE",
"link": 60
},
{
"localized_name": "detector_variant",
"name": "detector_variant",
"type": "COMBO",
"widget": {
"name": "detector_variant"
},
"link": 75
},
{
"localized_name": "num_faces",
"name": "num_faces",
"type": "INT",
"widget": {
"name": "num_faces"
},
"link": 76
},
{
"localized_name": "min_confidence",
"name": "min_confidence",
"type": "FLOAT",
"widget": {
"name": "min_confidence"
},
"link": null
},
{
"localized_name": "missing_frame_fallback",
"name": "missing_frame_fallback",
"type": "COMBO",
"widget": {
"name": "missing_frame_fallback"
},
"link": null
},
{
"name": "face_landmarker",
"type": "FACE_LANDMARKER",
"link": 74
}
],
"outputs": [
{
"localized_name": "face_landmarks",
"name": "face_landmarks",
"type": "FACE_LANDMARKS",
"links": [
44,
46
]
},
{
"localized_name": "bboxes",
"name": "bboxes",
"type": "BOUNDING_BOX",
"links": [
25
]
}
],
"properties": {
"Node name for S&R": "MediaPipeFaceLandmarker",
"cnr_id": "comfy-core",
"ver": "0.22.0",
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65
},
"widgets_values": [
"full",
0,
0.5,
"empty"
]
},
{
"id": 2,
"type": "LoadMediaPipeFaceLandmarker",
"pos": [
-290,
4060
],
"size": [
350,
140
],
"flags": {},
"order": 0,
"mode": 0,
"inputs": [
{
"localized_name": "model_name",
"name": "model_name",
"type": "COMBO",
"widget": {
"name": "model_name"
},
"link": 86
}
],
"outputs": [
{
"localized_name": "FACE_DETECTION_MODEL",
"name": "FACE_DETECTION_MODEL",
"type": "FACE_DETECTION_MODEL",
"links": [
66
]
}
],
"properties": {
"Node name for S&R": "LoadMediaPipeFaceLandmarker",
"cnr_id": "comfy-core",
"ver": "0.22.0",
"models": [
{
"name": "mediapipe_face_fp32.safetensors",
"url": "https://huggingface.co/Comfy-Org/mediapipe/resolve/main/detection/mediapipe_face_fp32.safetensors",
"directory": "detection"
}
],
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65
},
"widgets_values": [
"mediapipe_face_fp32.safetensors"
]
},
{
"id": 20,
"type": "MediaPipeFaceMask",
"pos": [
-290,
4560
],
"size": [
360,
180
],
"flags": {},
"order": 2,
"mode": 0,
"inputs": [
{
"localized_name": "face_landmarks",
"name": "face_landmarks",
"type": "FACE_LANDMARKS",
"link": 46
},
{
"localized_name": "regions",
"name": "regions",
"type": "COMFY_DYNAMICCOMBO_V3",
"widget": {
"name": "regions"
},
"link": null
},
{
"localized_name": "regions.face_oval",
"name": "regions.face_oval",
"type": "BOOLEAN",
"widget": {
"name": "regions.face_oval"
},
"link": 77
},
{
"localized_name": "regions.lips",
"name": "regions.lips",
"type": "BOOLEAN",
"widget": {
"name": "regions.lips"
},
"link": 78
},
{
"localized_name": "regions.left_eye",
"name": "regions.left_eye",
"type": "BOOLEAN",
"widget": {
"name": "regions.left_eye"
},
"link": 79
},
{
"localized_name": "regions.right_eye",
"name": "regions.right_eye",
"type": "BOOLEAN",
"widget": {
"name": "regions.right_eye"
},
"link": 80
},
{
"localized_name": "regions.irises",
"name": "regions.irises",
"type": "BOOLEAN",
"widget": {
"name": "regions.irises"
},
"link": 81
}
],
"outputs": [
{
"localized_name": "MASK",
"name": "MASK",
"type": "MASK",
"links": [
83
]
}
],
"properties": {
"Node name for S&R": "MediaPipeFaceMask",
"cnr_id": "comfy-core",
"ver": "0.22.0",
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65
},
"widgets_values": [
"custom",
true,
false,
false,
false,
false
]
}
],
"groups": [],
"links": [
{
"id": 66,
"origin_id": 2,
"origin_slot": 0,
"target_id": 11,
"target_slot": 0,
"type": "FACE_DETECTION_MODEL"
},
{
"id": 46,
"origin_id": 11,
"origin_slot": 0,
"target_id": 20,
"target_slot": 0,
"type": "FACE_LANDMARKS"
},
{
"id": 60,
"origin_id": -10,
"origin_slot": 0,
"target_id": 11,
"target_slot": 1,
"type": "IMAGE"
},
{
"id": 44,
"origin_id": 11,
"origin_slot": 0,
"target_id": -20,
"target_slot": 0,
"type": "FACE_LANDMARKS"
},
{
"id": 25,
"origin_id": 11,
"origin_slot": 1,
"target_id": -20,
"target_slot": 1,
"type": "BOUNDING_BOX"
},
{
"id": 74,
"origin_id": -10,
"origin_slot": 1,
"target_id": 11,
"target_slot": 6,
"type": "FACE_LANDMARKER"
},
{
"id": 75,
"origin_id": -10,
"origin_slot": 2,
"target_id": 11,
"target_slot": 2,
"type": "COMBO"
},
{
"id": 76,
"origin_id": -10,
"origin_slot": 3,
"target_id": 11,
"target_slot": 3,
"type": "INT"
},
{
"id": 77,
"origin_id": -10,
"origin_slot": 4,
"target_id": 20,
"target_slot": 2,
"type": "BOOLEAN"
},
{
"id": 78,
"origin_id": -10,
"origin_slot": 5,
"target_id": 20,
"target_slot": 3,
"type": "BOOLEAN"
},
{
"id": 79,
"origin_id": -10,
"origin_slot": 6,
"target_id": 20,
"target_slot": 4,
"type": "BOOLEAN"
},
{
"id": 80,
"origin_id": -10,
"origin_slot": 7,
"target_id": 20,
"target_slot": 5,
"type": "BOOLEAN"
},
{
"id": 81,
"origin_id": -10,
"origin_slot": 8,
"target_id": 20,
"target_slot": 6,
"type": "BOOLEAN"
},
{
"id": 83,
"origin_id": 20,
"origin_slot": 0,
"target_id": -20,
"target_slot": 2,
"type": "MASK"
},
{
"id": 86,
"origin_id": -10,
"origin_slot": 9,
"target_id": 2,
"target_slot": 0,
"type": "COMBO"
}
],
"extra": {},
"category": "Conditioning & Preprocessors/Face Detection"
}
]
},
"extra": {}
}

View File

@ -703,7 +703,7 @@
}
],
"extra": {},
"category": "Image Tools/Image Segmentation",
"category": "Conditioning & Preprocessors/Segmentation & Mask",
"description": "Segments images into masks using Meta SAM3 from text prompts, points, or boxes."
}
]

View File

@ -1302,7 +1302,7 @@
"extra": {
"workflowRendererVersion": "LG"
},
"category": "Image generation and editing/Enhance",
"category": "Image generation and editing/Upscale",
"description": "Upscales images to higher resolution using Z-Image-Turbo."
}
]
@ -1312,4 +1312,4 @@
"workflowRendererVersion": "LG"
},
"version": 0.4
}
}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,888 @@
{
"revision": 0,
"last_node_id": 675,
"last_link_id": 0,
"nodes": [
{
"id": 675,
"type": "01b6a731-fb78-4070-9a38-c87146da9604",
"pos": [
-2480,
3400
],
"size": [
360,
433.3125
],
"flags": {},
"order": 2,
"mode": 0,
"inputs": [
{
"localized_name": "input",
"name": "input",
"type": "IMAGE,MASK",
"link": null
},
{
"label": "resize_target_longer_size",
"name": "resize_type.longer_size",
"type": "INT",
"widget": {
"name": "resize_type.longer_size"
},
"link": null
},
{
"name": "scale_method",
"type": "COMBO",
"widget": {
"name": "scale_method"
},
"link": null
},
{
"name": "draw_body",
"type": "BOOLEAN",
"widget": {
"name": "draw_body"
},
"link": null
},
{
"name": "draw_hands",
"type": "BOOLEAN",
"widget": {
"name": "draw_hands"
},
"link": null
},
{
"name": "draw_face",
"type": "BOOLEAN",
"widget": {
"name": "draw_face"
},
"link": null
},
{
"name": "draw_feet",
"type": "BOOLEAN",
"widget": {
"name": "draw_feet"
},
"link": null
},
{
"name": "stick_width",
"type": "INT",
"widget": {
"name": "stick_width"
},
"link": null
},
{
"name": "face_point_size",
"type": "INT",
"widget": {
"name": "face_point_size"
},
"link": null
},
{
"name": "score_threshold",
"type": "FLOAT",
"widget": {
"name": "score_threshold"
},
"link": null
},
{
"name": "ckpt_name",
"type": "COMBO",
"widget": {
"name": "ckpt_name"
},
"link": null
},
{
"name": "bboxes",
"shape": 7,
"type": "BOUNDING_BOX",
"link": null
}
],
"outputs": [
{
"localized_name": "IMAGE",
"name": "IMAGE",
"type": "IMAGE",
"links": []
},
{
"name": "keypoints",
"type": "POSE_KEYPOINT",
"links": null
}
],
"properties": {
"proxyWidgets": [
[
"674",
"resize_type.longer_size"
],
[
"674",
"scale_method"
],
[
"672",
"draw_body"
],
[
"672",
"draw_hands"
],
[
"672",
"draw_face"
],
[
"672",
"draw_feet"
],
[
"672",
"stick_width"
],
[
"672",
"face_point_size"
],
[
"672",
"score_threshold"
],
[
"673",
"ckpt_name"
]
],
"cnr_id": "comfy-core",
"ver": "0.15.1",
"ue_properties": {
"widget_ue_connectable": {},
"version": "7.7",
"input_ue_unconnectable": {}
}
},
"widgets_values": [],
"title": "Image to Pose Map (SDPose-OOD)"
}
],
"links": [],
"version": 0.4,
"definitions": {
"subgraphs": [
{
"id": "01b6a731-fb78-4070-9a38-c87146da9604",
"version": 1,
"state": {
"lastGroupId": 0,
"lastNodeId": 676,
"lastLinkId": 1715,
"lastRerouteId": 0
},
"revision": 0,
"config": {},
"name": "Image to Pose Map (SDPose-OOD)",
"inputNode": {
"id": -10,
"bounding": [
-3290,
3590,
190.8984375,
288
]
},
"outputNode": {
"id": -20,
"bounding": [
-1756.2451602089645,
3366,
128,
88
]
},
"inputs": [
{
"id": "e24699c3-1356-4634-9eb4-19bb58e5c0b0",
"name": "input",
"type": "IMAGE,MASK",
"linkIds": [
1700
],
"localized_name": "input",
"pos": [
-3123.1015625,
3614
]
},
{
"id": "088eefc1-cd8a-4573-993f-9e4da008a12d",
"name": "resize_type.longer_size",
"type": "INT",
"linkIds": [
1704
],
"label": "resize_target_longer_size",
"pos": [
-3123.1015625,
3634
]
},
{
"id": "b6449bd3-73d4-41c8-b81f-cf8d33f76a2e",
"name": "scale_method",
"type": "COMBO",
"linkIds": [
1705
],
"pos": [
-3123.1015625,
3654
]
},
{
"id": "4cff52ad-ed07-4c97-8803-fcbd89554fd0",
"name": "draw_body",
"type": "BOOLEAN",
"linkIds": [
1706
],
"pos": [
-3123.1015625,
3674
]
},
{
"id": "7af63dce-f7df-4d7e-8215-d7c7f60bf81c",
"name": "draw_hands",
"type": "BOOLEAN",
"linkIds": [
1707
],
"pos": [
-3123.1015625,
3694
]
},
{
"id": "af3a9bce-61f9-4aca-b530-9f65e028b35e",
"name": "draw_face",
"type": "BOOLEAN",
"linkIds": [
1708
],
"pos": [
-3123.1015625,
3714
]
},
{
"id": "4620f6a3-2c85-4b79-ad8f-35d0326b568f",
"name": "draw_feet",
"type": "BOOLEAN",
"linkIds": [
1709
],
"pos": [
-3123.1015625,
3734
]
},
{
"id": "fee5d0c9-8d4b-4934-81d8-ba2206dc56cb",
"name": "stick_width",
"type": "INT",
"linkIds": [
1710
],
"pos": [
-3123.1015625,
3754
]
},
{
"id": "aafdd060-ba81-4324-a9cc-b656e1ebc133",
"name": "face_point_size",
"type": "INT",
"linkIds": [
1711
],
"pos": [
-3123.1015625,
3774
]
},
{
"id": "514c5503-f9e6-4d23-b1ae-1d3291acb2a3",
"name": "score_threshold",
"type": "FLOAT",
"linkIds": [
1712
],
"pos": [
-3123.1015625,
3794
]
},
{
"id": "ae46de61-2cc6-483e-8ee9-87e4144a2ffa",
"name": "ckpt_name",
"type": "COMBO",
"linkIds": [
1713
],
"pos": [
-3123.1015625,
3814
]
},
{
"id": "41bec0c6-dffa-4c78-9289-ee678715ae54",
"name": "bboxes",
"type": "BOUNDING_BOX",
"linkIds": [
1714
],
"pos": [
-3123.1015625,
3834
]
}
],
"outputs": [
{
"id": "f05ed8cc-9403-4f14-8085-4364b06f8a48",
"name": "IMAGE",
"type": "IMAGE",
"linkIds": [
1701
],
"localized_name": "IMAGE",
"pos": [
-1732.2451602089645,
3390
]
},
{
"id": "29a6584e-4685-4986-8ffd-e6d8539953fd",
"name": "keypoints",
"type": "POSE_KEYPOINT",
"linkIds": [
1715
],
"pos": [
-1732.2451602089645,
3410
]
}
],
"widgets": [],
"nodes": [
{
"id": 671,
"type": "SDPoseKeypointExtractor",
"pos": [
-2470,
3250
],
"size": [
270,
180
],
"flags": {},
"order": 0,
"mode": 0,
"inputs": [
{
"localized_name": "model",
"name": "model",
"type": "MODEL",
"link": 1696
},
{
"localized_name": "vae",
"name": "vae",
"type": "VAE",
"link": 1697
},
{
"localized_name": "image",
"name": "image",
"type": "IMAGE",
"link": 1698
},
{
"localized_name": "bboxes",
"name": "bboxes",
"shape": 7,
"type": "BOUNDING_BOX",
"link": 1714
},
{
"localized_name": "batch_size",
"name": "batch_size",
"type": "INT",
"widget": {
"name": "batch_size"
},
"link": null
}
],
"outputs": [
{
"localized_name": "keypoints",
"name": "keypoints",
"type": "POSE_KEYPOINT",
"links": [
1699,
1715
]
}
],
"properties": {
"Node name for S&R": "SDPoseKeypointExtractor",
"cnr_id": "comfy-core",
"ver": "0.15.0",
"ue_properties": {
"widget_ue_connectable": {},
"version": "7.7",
"input_ue_unconnectable": {}
}
},
"widgets_values": [
16
]
},
{
"id": 674,
"type": "ResizeImageMaskNode",
"pos": [
-2960,
3490
],
"size": [
270,
110
],
"flags": {},
"order": 3,
"mode": 0,
"inputs": [
{
"localized_name": "input",
"name": "input",
"type": "IMAGE,MASK",
"link": 1700
},
{
"localized_name": "resize_type",
"name": "resize_type",
"type": "COMFY_DYNAMICCOMBO_V3",
"widget": {
"name": "resize_type"
},
"link": null
},
{
"localized_name": "resize_type.longer_size",
"name": "resize_type.longer_size",
"type": "INT",
"widget": {
"name": "resize_type.longer_size"
},
"link": 1704
},
{
"localized_name": "scale_method",
"name": "scale_method",
"type": "COMBO",
"widget": {
"name": "scale_method"
},
"link": 1705
}
],
"outputs": [
{
"localized_name": "resized",
"name": "resized",
"type": "*",
"links": [
1698
]
}
],
"properties": {
"Node name for S&R": "ResizeImageMaskNode",
"cnr_id": "comfy-core",
"ver": "0.15.0",
"ue_properties": {
"widget_ue_connectable": {},
"version": "7.7",
"input_ue_unconnectable": {}
}
},
"widgets_values": [
"scale longer dimension",
1024,
"area"
]
},
{
"id": 672,
"type": "SDPoseDrawKeypoints",
"pos": [
-2120,
3260
],
"size": [
270,
280
],
"flags": {},
"order": 1,
"mode": 0,
"inputs": [
{
"localized_name": "keypoints",
"name": "keypoints",
"type": "POSE_KEYPOINT",
"link": 1699
},
{
"localized_name": "draw_body",
"name": "draw_body",
"type": "BOOLEAN",
"widget": {
"name": "draw_body"
},
"link": 1706
},
{
"localized_name": "draw_hands",
"name": "draw_hands",
"type": "BOOLEAN",
"widget": {
"name": "draw_hands"
},
"link": 1707
},
{
"localized_name": "draw_face",
"name": "draw_face",
"type": "BOOLEAN",
"widget": {
"name": "draw_face"
},
"link": 1708
},
{
"localized_name": "draw_feet",
"name": "draw_feet",
"type": "BOOLEAN",
"widget": {
"name": "draw_feet"
},
"link": 1709
},
{
"localized_name": "stick_width",
"name": "stick_width",
"type": "INT",
"widget": {
"name": "stick_width"
},
"link": 1710
},
{
"localized_name": "face_point_size",
"name": "face_point_size",
"type": "INT",
"widget": {
"name": "face_point_size"
},
"link": 1711
},
{
"localized_name": "score_threshold",
"name": "score_threshold",
"type": "FLOAT",
"widget": {
"name": "score_threshold"
},
"link": 1712
}
],
"outputs": [
{
"localized_name": "IMAGE",
"name": "IMAGE",
"type": "IMAGE",
"links": [
1701
]
}
],
"properties": {
"Node name for S&R": "SDPoseDrawKeypoints",
"cnr_id": "comfy-core",
"ver": "0.15.0",
"ue_properties": {
"widget_ue_connectable": {},
"version": "7.7",
"input_ue_unconnectable": {}
}
},
"widgets_values": [
true,
true,
true,
true,
4,
2,
0.5
]
},
{
"id": 673,
"type": "CheckpointLoaderSimple",
"pos": [
-2960,
3250
],
"size": [
390,
190
],
"flags": {},
"order": 2,
"mode": 0,
"inputs": [
{
"localized_name": "ckpt_name",
"name": "ckpt_name",
"type": "COMBO",
"widget": {
"name": "ckpt_name"
},
"link": 1713
}
],
"outputs": [
{
"localized_name": "MODEL",
"name": "MODEL",
"type": "MODEL",
"links": [
1696
]
},
{
"localized_name": "CLIP",
"name": "CLIP",
"type": "CLIP",
"links": []
},
{
"localized_name": "VAE",
"name": "VAE",
"type": "VAE",
"links": [
1697
]
}
],
"properties": {
"Node name for S&R": "CheckpointLoaderSimple",
"cnr_id": "comfy-core",
"ver": "0.15.0",
"models": [
{
"name": "sdpose_wholebody_fp16.safetensors",
"url": "https://huggingface.co/Comfy-Org/SDPose/resolve/main/checkpoints/sdpose_wholebody_fp16.safetensors",
"directory": "checkpoints"
}
],
"ue_properties": {
"widget_ue_connectable": {},
"version": "7.7",
"input_ue_unconnectable": {}
}
},
"widgets_values": [
"sdpose_wholebody_fp16.safetensors"
]
}
],
"groups": [],
"links": [
{
"id": 1696,
"origin_id": 673,
"origin_slot": 0,
"target_id": 671,
"target_slot": 0,
"type": "MODEL"
},
{
"id": 1697,
"origin_id": 673,
"origin_slot": 2,
"target_id": 671,
"target_slot": 1,
"type": "VAE"
},
{
"id": 1698,
"origin_id": 674,
"origin_slot": 0,
"target_id": 671,
"target_slot": 2,
"type": "IMAGE"
},
{
"id": 1699,
"origin_id": 671,
"origin_slot": 0,
"target_id": 672,
"target_slot": 0,
"type": "POSE_KEYPOINT"
},
{
"id": 1700,
"origin_id": -10,
"origin_slot": 0,
"target_id": 674,
"target_slot": 0,
"type": "IMAGE,MASK"
},
{
"id": 1701,
"origin_id": 672,
"origin_slot": 0,
"target_id": -20,
"target_slot": 0,
"type": "IMAGE"
},
{
"id": 1704,
"origin_id": -10,
"origin_slot": 1,
"target_id": 674,
"target_slot": 2,
"type": "INT"
},
{
"id": 1705,
"origin_id": -10,
"origin_slot": 2,
"target_id": 674,
"target_slot": 3,
"type": "COMBO"
},
{
"id": 1706,
"origin_id": -10,
"origin_slot": 3,
"target_id": 672,
"target_slot": 1,
"type": "BOOLEAN"
},
{
"id": 1707,
"origin_id": -10,
"origin_slot": 4,
"target_id": 672,
"target_slot": 2,
"type": "BOOLEAN"
},
{
"id": 1708,
"origin_id": -10,
"origin_slot": 5,
"target_id": 672,
"target_slot": 3,
"type": "BOOLEAN"
},
{
"id": 1709,
"origin_id": -10,
"origin_slot": 6,
"target_id": 672,
"target_slot": 4,
"type": "BOOLEAN"
},
{
"id": 1710,
"origin_id": -10,
"origin_slot": 7,
"target_id": 672,
"target_slot": 5,
"type": "INT"
},
{
"id": 1711,
"origin_id": -10,
"origin_slot": 8,
"target_id": 672,
"target_slot": 6,
"type": "INT"
},
{
"id": 1712,
"origin_id": -10,
"origin_slot": 9,
"target_id": 672,
"target_slot": 7,
"type": "FLOAT"
},
{
"id": 1713,
"origin_id": -10,
"origin_slot": 10,
"target_id": 673,
"target_slot": 0,
"type": "COMBO"
},
{
"id": 1714,
"origin_id": -10,
"origin_slot": 11,
"target_id": 671,
"target_slot": 3,
"type": "BOUNDING_BOX"
},
{
"id": 1715,
"origin_id": 671,
"origin_slot": 0,
"target_id": -20,
"target_slot": 1,
"type": "POSE_KEYPOINT"
}
],
"extra": {
"workflowRendererVersion": "LG"
},
"category": "Conditioning & Preprocessors/Pose",
"description": "Extracts human pose keypoints and stick-figure visuals from an image using SDPose-OOD, with optional bounding-box input per subject."
}
]
},
"extra": {
"ue_links": []
}
}

1219
blueprints/Merge Videos.json Normal file

File diff suppressed because it is too large Load Diff

View File

@ -1298,7 +1298,7 @@
"VHS_MetadataImage": true,
"VHS_KeepIntermediate": true
},
"category": "Image generation and editing/Pose to image",
"category": "Image generation and editing/Conditioned",
"description": "Generates an image from pose keypoints using Z-Image-Turbo with text conditioning."
}
]

View File

@ -3870,7 +3870,7 @@
"extra": {
"workflowRendererVersion": "LG"
},
"category": "Video generation and editing/Pose to video",
"category": "Video generation and editing/Conditioned",
"description": "Generates video from pose reference frames using LTX-2, with optional synchronized audio."
}
]

View File

@ -270,7 +270,7 @@
"extra": {
"workflowRendererVersion": "LG"
},
"category": "Text generation/Prompt enhance",
"category": "Text Tools",
"description": "Expands short text prompts into detailed descriptions using a text generation model for better generation quality."
}
]

View File

@ -389,7 +389,7 @@
}
],
"extra": {},
"category": "Image generation and editing/Background Removal"
"category": "Image Tools/Background Removal"
}
]
},

View File

@ -0,0 +1,485 @@
{
"revision": 0,
"last_node_id": 10,
"last_link_id": 0,
"nodes": [
{
"id": 10,
"type": "3fb7557a-470d-4983-9d8c-6d5caa9788f0",
"pos": [
-250,
8590
],
"size": [
280,
360
],
"flags": {},
"order": 0,
"mode": 0,
"inputs": [
{
"localized_name": "text_per_line",
"name": "text_per_line",
"type": "STRING",
"widget": {
"name": "text_per_line"
},
"link": null
},
{
"localized_name": "index",
"name": "index",
"type": "INT",
"widget": {
"name": "index"
},
"link": null
}
],
"outputs": [
{
"localized_name": "selected_line",
"name": "selected_line",
"type": "STRING",
"links": []
}
],
"properties": {
"proxyWidgets": [
[
"2",
"string"
],
[
"3",
"value"
]
],
"cnr_id": "comfy-core",
"ver": "0.19.0",
"ue_properties": {
"widget_ue_connectable": {},
"input_ue_unconnectable": {}
}
},
"widgets_values": [],
"title": "Select Per-Line Text by Index"
}
],
"links": [],
"version": 0.4,
"definitions": {
"subgraphs": [
{
"id": "3fb7557a-470d-4983-9d8c-6d5caa9788f0",
"version": 1,
"state": {
"lastGroupId": 0,
"lastNodeId": 10,
"lastLinkId": 14,
"lastRerouteId": 0
},
"revision": 0,
"config": {},
"name": "Select Per-Line Text by Index",
"inputNode": {
"id": -10,
"bounding": [
-990,
8595,
128,
88
]
},
"outputNode": {
"id": -20,
"bounding": [
710,
8585,
128,
68
]
},
"inputs": [
{
"id": "75417d82-a934-4ac9-b667-d8dcd5a3bfb3",
"name": "text_per_line",
"type": "STRING",
"linkIds": [
13
],
"localized_name": "text_per_line",
"pos": [
-886,
8619
]
},
{
"id": "46e69a73-1804-4ca6-9175-31445bf0be96",
"name": "index",
"type": "INT",
"linkIds": [
14
],
"localized_name": "index",
"pos": [
-886,
8639
]
}
],
"outputs": [
{
"id": "e34e8ad1-84d2-4bd2-a460-eb7de6067c10",
"name": "selected_line",
"type": "STRING",
"linkIds": [
10
],
"localized_name": "selected_line",
"pos": [
734,
8609
]
}
],
"widgets": [],
"nodes": [
{
"id": 1,
"type": "PreviewAny",
"pos": [
-500,
8400
],
"size": [
230,
180
],
"flags": {},
"order": 0,
"mode": 0,
"inputs": [
{
"localized_name": "source",
"name": "source",
"type": "*",
"link": 1
}
],
"outputs": [
{
"localized_name": "STRING",
"name": "STRING",
"type": "STRING",
"links": [
6
]
}
],
"properties": {
"Node name for S&R": "PreviewAny",
"cnr_id": "comfy-core",
"ver": "0.19.0",
"ue_properties": {
"widget_ue_connectable": {},
"input_ue_unconnectable": {}
}
},
"widgets_values": [
null,
null,
null
]
},
{
"id": 2,
"type": "RegexExtract",
"pos": [
-240,
8740
],
"size": [
470,
460
],
"flags": {},
"order": 1,
"mode": 0,
"showAdvanced": false,
"inputs": [
{
"localized_name": "string",
"name": "string",
"type": "STRING",
"widget": {
"name": "string"
},
"link": 13
},
{
"localized_name": "regex_pattern",
"name": "regex_pattern",
"type": "STRING",
"widget": {
"name": "regex_pattern"
},
"link": 9
},
{
"localized_name": "mode",
"name": "mode",
"type": "COMBO",
"widget": {
"name": "mode"
},
"link": null
},
{
"localized_name": "case_insensitive",
"name": "case_insensitive",
"type": "BOOLEAN",
"widget": {
"name": "case_insensitive"
},
"link": null
},
{
"localized_name": "multiline",
"name": "multiline",
"type": "BOOLEAN",
"widget": {
"name": "multiline"
},
"link": null
},
{
"localized_name": "dotall",
"name": "dotall",
"type": "BOOLEAN",
"widget": {
"name": "dotall"
},
"link": null
},
{
"localized_name": "group_index",
"name": "group_index",
"type": "INT",
"widget": {
"name": "group_index"
},
"link": null
}
],
"outputs": [
{
"localized_name": "STRING",
"name": "STRING",
"type": "STRING",
"links": [
10
]
}
],
"properties": {
"Node name for S&R": "RegexExtract",
"cnr_id": "comfy-core",
"ver": "0.19.0",
"ue_properties": {
"widget_ue_connectable": {},
"input_ue_unconnectable": {}
}
},
"widgets_values": [
"",
"",
"First Group",
false,
false,
false,
1
]
},
{
"id": 3,
"type": "PrimitiveInt",
"pos": [
-810,
8400
],
"size": [
270,
110
],
"flags": {},
"order": 2,
"mode": 0,
"inputs": [
{
"localized_name": "value",
"name": "value",
"type": "INT",
"widget": {
"name": "value"
},
"link": 14
}
],
"outputs": [
{
"localized_name": "INT",
"name": "INT",
"type": "INT",
"links": [
1
]
}
],
"title": "Int (line index)",
"properties": {
"Node name for S&R": "Int (line index)",
"cnr_id": "comfy-core",
"ver": "0.19.0",
"ue_properties": {
"widget_ue_connectable": {},
"input_ue_unconnectable": {}
}
},
"widgets_values": [
0,
"fixed"
]
},
{
"id": 8,
"type": "StringReplace",
"pos": [
-240,
8400
],
"size": [
400,
280
],
"flags": {},
"order": 3,
"mode": 0,
"inputs": [
{
"localized_name": "string",
"name": "string",
"type": "STRING",
"widget": {
"name": "string"
},
"link": null
},
{
"localized_name": "find",
"name": "find",
"type": "STRING",
"widget": {
"name": "find"
},
"link": null
},
{
"localized_name": "replace",
"name": "replace",
"type": "STRING",
"widget": {
"name": "replace"
},
"link": 6
}
],
"outputs": [
{
"localized_name": "STRING",
"name": "STRING",
"type": "STRING",
"links": [
9
]
}
],
"properties": {
"Node name for S&R": "StringReplace",
"cnr_id": "comfy-core",
"ver": "0.19.0",
"ue_properties": {
"widget_ue_connectable": {},
"input_ue_unconnectable": {}
}
},
"widgets_values": [
"^(?:[^\\n]*\\n){index}([^\\n]*)(?:\\n|$)",
"index",
""
]
}
],
"groups": [],
"links": [
{
"id": 1,
"origin_id": 3,
"origin_slot": 0,
"target_id": 1,
"target_slot": 0,
"type": "INT"
},
{
"id": 9,
"origin_id": 8,
"origin_slot": 0,
"target_id": 2,
"target_slot": 1,
"type": "STRING"
},
{
"id": 6,
"origin_id": 1,
"origin_slot": 0,
"target_id": 8,
"target_slot": 2,
"type": "STRING"
},
{
"id": 10,
"origin_id": 2,
"origin_slot": 0,
"target_id": -20,
"target_slot": 0,
"type": "STRING"
},
{
"id": 13,
"origin_id": -10,
"origin_slot": 0,
"target_id": 2,
"target_slot": 0,
"type": "STRING"
},
{
"id": 14,
"origin_id": -10,
"origin_slot": 1,
"target_id": 3,
"target_slot": 0,
"type": "INT"
}
],
"extra": {},
"category": "Text Tools",
"description": "Selects one line from multiline text by zero-based index for batch or list-driven prompt workflows."
}
]
},
"extra": {
"ue_links": [],
"links_added_by_ue": []
}
}

View File

@ -0,0 +1,714 @@
{
"revision": 0,
"last_node_id": 251,
"last_link_id": 0,
"nodes": [
{
"id": 251,
"type": "609e1fd1-b731-4b78-89ac-d19b1156b025",
"pos": [
-1490,
130
],
"size": [
230,
164
],
"flags": {},
"order": 1,
"mode": 0,
"inputs": [
{
"localized_name": "source_image",
"name": "source_image",
"type": "IMAGE",
"link": null
},
{
"localized_name": "columns",
"name": "columns",
"type": "INT",
"widget": {
"name": "columns"
},
"link": null
},
{
"localized_name": "rows",
"name": "rows",
"type": "INT",
"widget": {
"name": "rows"
},
"link": null
}
],
"outputs": [
{
"localized_name": "tiles",
"name": "tiles",
"type": "IMAGE",
"links": []
}
],
"properties": {
"proxyWidgets": [
[
"228",
"value"
],
[
"252",
"value"
]
],
"cnr_id": "comfy-core",
"ver": "0.20.1",
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65
},
"widgets_values": [],
"title": "Split Image Grid to Tiles"
}
],
"links": [],
"version": 0.4,
"definitions": {
"subgraphs": [
{
"id": "609e1fd1-b731-4b78-89ac-d19b1156b025",
"version": 1,
"state": {
"lastGroupId": 9,
"lastNodeId": 252,
"lastLinkId": 429,
"lastRerouteId": 0
},
"revision": 0,
"config": {},
"name": "Split Image Grid to Tiles",
"inputNode": {
"id": -10,
"bounding": [
-1690,
260,
128,
108
]
},
"outputNode": {
"id": -20,
"bounding": [
-510,
590,
128,
68
]
},
"inputs": [
{
"id": "866ac798-cfbc-450a-b755-e704f86404d9",
"name": "source_image",
"type": "IMAGE",
"linkIds": [
386,
389
],
"localized_name": "source_image",
"pos": [
-1586,
284
]
},
{
"id": "bc37b1f8-8ab2-4f19-bd00-75d4fbc4feb3",
"name": "columns",
"type": "INT",
"linkIds": [
427
],
"localized_name": "columns",
"pos": [
-1586,
304
]
},
{
"id": "d45915da-e848-43dd-9ccc-e3161e9c99d9",
"name": "rows",
"type": "INT",
"linkIds": [
428
],
"localized_name": "rows",
"pos": [
-1586,
324
]
}
],
"outputs": [
{
"id": "18bc780f-064b-4038-87c6-67dba71deb08",
"name": "tiles",
"type": "IMAGE",
"linkIds": [
394
],
"localized_name": "tiles",
"shape": 6,
"pos": [
-486,
614
]
}
],
"widgets": [],
"nodes": [
{
"id": 225,
"type": "SplitImageToTileList",
"pos": [
-1010,
620
],
"size": [
290,
170
],
"flags": {},
"order": 0,
"mode": 0,
"inputs": [
{
"localized_name": "image",
"name": "image",
"type": "IMAGE",
"link": 386
},
{
"localized_name": "tile_width",
"name": "tile_width",
"type": "INT",
"widget": {
"name": "tile_width"
},
"link": 403
},
{
"localized_name": "tile_height",
"name": "tile_height",
"type": "INT",
"widget": {
"name": "tile_height"
},
"link": 404
},
{
"localized_name": "overlap",
"name": "overlap",
"type": "INT",
"widget": {
"name": "overlap"
},
"link": null
}
],
"outputs": [
{
"localized_name": "IMAGE",
"name": "IMAGE",
"shape": 6,
"type": "IMAGE",
"links": [
394
]
}
],
"properties": {
"Node name for S&R": "SplitImageToTileList",
"cnr_id": "comfy-core",
"ver": "0.20.1",
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65
},
"widgets_values": [
1024,
1024,
0
]
},
{
"id": 231,
"type": "ComfyMathExpression",
"pos": [
-1080,
330
],
"size": [
370,
190
],
"flags": {},
"order": 4,
"mode": 0,
"inputs": [
{
"label": "a",
"localized_name": "values.a",
"name": "values.a",
"type": "FLOAT,INT,BOOLEAN",
"link": 390
},
{
"label": "b",
"localized_name": "values.b",
"name": "values.b",
"shape": 7,
"type": "FLOAT,INT,BOOLEAN",
"link": 429
},
{
"label": "c",
"localized_name": "values.c",
"name": "values.c",
"shape": 7,
"type": "FLOAT,INT,BOOLEAN",
"link": null
},
{
"localized_name": "expression",
"name": "expression",
"type": "STRING",
"widget": {
"name": "expression"
},
"link": null
}
],
"outputs": [
{
"localized_name": "FLOAT",
"name": "FLOAT",
"type": "FLOAT",
"links": null
},
{
"localized_name": "INT",
"name": "INT",
"type": "INT",
"links": [
404
]
},
{
"localized_name": "BOOL",
"name": "BOOL",
"type": "BOOLEAN",
"links": null
}
],
"title": "Math Expression Height",
"properties": {
"Node name for S&R": "ComfyMathExpression",
"cnr_id": "comfy-core",
"ver": "0.18.1",
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"ue_properties": {
"widget_ue_connectable": {},
"version": "7.7",
"input_ue_unconnectable": {}
}
},
"widgets_values": [
"max(1, (int(a) + int(b) - 1) // int(b))"
]
},
{
"id": 229,
"type": "ComfyMathExpression",
"pos": [
-1090,
-30
],
"size": [
370,
190
],
"flags": {},
"order": 2,
"mode": 0,
"inputs": [
{
"label": "a",
"localized_name": "values.a",
"name": "values.a",
"type": "FLOAT,INT,BOOLEAN",
"link": 387
},
{
"label": "b",
"localized_name": "values.b",
"name": "values.b",
"shape": 7,
"type": "FLOAT,INT,BOOLEAN",
"link": 388
},
{
"label": "c",
"localized_name": "values.c",
"name": "values.c",
"shape": 7,
"type": "FLOAT,INT,BOOLEAN",
"link": null
},
{
"localized_name": "expression",
"name": "expression",
"type": "STRING",
"widget": {
"name": "expression"
},
"link": null
}
],
"outputs": [
{
"localized_name": "FLOAT",
"name": "FLOAT",
"type": "FLOAT",
"links": null
},
{
"localized_name": "INT",
"name": "INT",
"type": "INT",
"links": [
403
]
},
{
"localized_name": "BOOL",
"name": "BOOL",
"type": "BOOLEAN",
"links": null
}
],
"title": "Math Expression Width",
"properties": {
"Node name for S&R": "ComfyMathExpression",
"cnr_id": "comfy-core",
"ver": "0.18.1",
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"ue_properties": {
"widget_ue_connectable": {},
"version": "7.7",
"input_ue_unconnectable": {}
}
},
"widgets_values": [
"max(1, (int(a) + int(b) - 1) // int(b))"
]
},
{
"id": 228,
"type": "PrimitiveInt",
"pos": [
-1380,
90
],
"size": [
230,
110
],
"flags": {},
"order": 1,
"mode": 0,
"inputs": [
{
"localized_name": "value",
"name": "value",
"type": "INT",
"widget": {
"name": "value"
},
"link": 427
}
],
"outputs": [
{
"localized_name": "INT",
"name": "INT",
"type": "INT",
"links": [
388
]
}
],
"title": "Int (grid columns)",
"properties": {
"Node name for S&R": "Int (grid columns)",
"cnr_id": "comfy-core",
"ver": "0.18.1",
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"ue_properties": {
"widget_ue_connectable": {},
"version": "7.7",
"input_ue_unconnectable": {}
}
},
"widgets_values": [
2,
"fixed"
]
},
{
"id": 230,
"type": "GetImageSize",
"pos": [
-1380,
290
],
"size": [
230,
100
],
"flags": {},
"order": 3,
"mode": 0,
"inputs": [
{
"localized_name": "image",
"name": "image",
"type": "IMAGE",
"link": 389
}
],
"outputs": [
{
"localized_name": "width",
"name": "width",
"type": "INT",
"links": [
387
]
},
{
"localized_name": "height",
"name": "height",
"type": "INT",
"links": [
390
]
},
{
"localized_name": "batch_size",
"name": "batch_size",
"type": "INT",
"links": null
}
],
"properties": {
"Node name for S&R": "GetImageSize",
"cnr_id": "comfy-core",
"ver": "0.18.1",
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"ue_properties": {
"widget_ue_connectable": {},
"version": "7.7",
"input_ue_unconnectable": {}
}
}
},
{
"id": 252,
"type": "PrimitiveInt",
"pos": [
-1380,
470
],
"size": [
230,
110
],
"flags": {},
"order": 5,
"mode": 0,
"inputs": [
{
"localized_name": "value",
"name": "value",
"type": "INT",
"widget": {
"name": "value"
},
"link": 428
}
],
"outputs": [
{
"localized_name": "INT",
"name": "INT",
"type": "INT",
"links": [
429
]
}
],
"title": "Int (grid rows)",
"properties": {
"Node name for S&R": "Int (grid rows)",
"cnr_id": "comfy-core",
"ver": "0.18.1",
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"ue_properties": {
"widget_ue_connectable": {},
"version": "7.7",
"input_ue_unconnectable": {}
}
},
"widgets_values": [
3,
"fixed"
]
}
],
"groups": [],
"links": [
{
"id": 403,
"origin_id": 229,
"origin_slot": 1,
"target_id": 225,
"target_slot": 1,
"type": "INT"
},
{
"id": 404,
"origin_id": 231,
"origin_slot": 1,
"target_id": 225,
"target_slot": 2,
"type": "INT"
},
{
"id": 390,
"origin_id": 230,
"origin_slot": 1,
"target_id": 231,
"target_slot": 0,
"type": "INT"
},
{
"id": 387,
"origin_id": 230,
"origin_slot": 0,
"target_id": 229,
"target_slot": 0,
"type": "INT"
},
{
"id": 388,
"origin_id": 228,
"origin_slot": 0,
"target_id": 229,
"target_slot": 1,
"type": "INT"
},
{
"id": 386,
"origin_id": -10,
"origin_slot": 0,
"target_id": 225,
"target_slot": 0,
"type": "IMAGE"
},
{
"id": 389,
"origin_id": -10,
"origin_slot": 0,
"target_id": 230,
"target_slot": 0,
"type": "IMAGE"
},
{
"id": 394,
"origin_id": 225,
"origin_slot": 0,
"target_id": -20,
"target_slot": 0,
"type": "IMAGE"
},
{
"id": 427,
"origin_id": -10,
"origin_slot": 1,
"target_id": 228,
"target_slot": 0,
"type": "INT"
},
{
"id": 428,
"origin_id": -10,
"origin_slot": 2,
"target_id": 252,
"target_slot": 0,
"type": "INT"
},
{
"id": 429,
"origin_id": 252,
"origin_slot": 0,
"target_id": 231,
"target_slot": 1,
"type": "INT"
}
],
"extra": {},
"category": "Image Tools/Crop",
"description": "Splits an image into a configurable columns×rows grid of equal tiles for tiled generation or processing."
}
]
},
"extra": {}
}

File diff suppressed because it is too large Load Diff

View File

@ -307,9 +307,9 @@
"extra": {
"workflowRendererVersion": "LG"
},
"category": "Text generation/Video Captioning",
"category": "Video Tools",
"description": "Generates descriptive captions for video input using Google's Gemini multimodal LLM."
}
]
}
}
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -818,7 +818,7 @@
}
],
"extra": {},
"category": "Video Tools",
"category": "Conditioning & Preprocessors/Segmentation & Mask",
"description": "Segments video into temporally consistent masks using Meta SAM3 from text or interactive prompts."
}
]

View File

@ -412,7 +412,7 @@
"extra": {
"workflowRendererVersion": "LG"
},
"category": "Video generation and editing/Enhance video",
"category": "Video generation and editing/Upscale",
"description": "Upscales video to 4× resolution using a GAN-based upscaling model."
}
]

File diff suppressed because it is too large Load Diff

View File

@ -49,7 +49,7 @@ parser.add_argument("--temp-directory", type=str, default=None, help="Set the Co
parser.add_argument("--input-directory", type=str, default=None, help="Set the ComfyUI input directory. Overrides --base-directory.")
parser.add_argument("--auto-launch", action="store_true", help="Automatically launch ComfyUI in the default browser.")
parser.add_argument("--disable-auto-launch", action="store_true", help="Disable auto launching the browser.")
parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use. All other devices will not be visible.")
parser.add_argument("--cuda-device", type=str, default=None, metavar="DEVICE_ID", help="Set the ids of cuda devices this instance will use, as a comma-separated list (e.g. '0' or '0,1'). All other devices will not be visible.")
parser.add_argument("--default-device", type=int, default=None, metavar="DEFAULT_DEVICE_ID", help="Set the id of the default device, all other devices will stay visible.")
cm_group = parser.add_mutually_exclusive_group()
cm_group.add_argument("--cuda-malloc", action="store_true", help="Enable cudaMallocAsync (enabled by default for torch 2.0 and up).")

View File

@ -1,6 +1,5 @@
"""Comfy-specific type hinting"""
from __future__ import annotations
from typing import Literal, TypedDict, Optional
from typing_extensions import NotRequired
from abc import ABC, abstractmethod

View File

@ -15,13 +15,14 @@
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
from __future__ import annotations
import torch
from enum import Enum
import math
import os
import logging
import copy
import comfy.utils
import comfy.model_management
import comfy.model_detection
@ -38,7 +39,7 @@ import comfy.ldm.hydit.controlnet
import comfy.ldm.flux.controlnet
import comfy.ldm.qwen_image.controlnet
import comfy.cldm.dit_embedder
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Union
if TYPE_CHECKING:
from comfy.hooks import HookGroup
@ -64,6 +65,18 @@ class StrengthType(Enum):
CONSTANT = 1
LINEAR_UP = 2
class ControlIsolation:
'''Temporarily set a ControlBase object's previous_controlnet to None to prevent cascading calls.'''
def __init__(self, control: ControlBase):
self.control = control
self.orig_previous_controlnet = control.previous_controlnet
def __enter__(self):
self.control.previous_controlnet = None
def __exit__(self, *args):
self.control.previous_controlnet = self.orig_previous_controlnet
class ControlBase:
def __init__(self):
self.cond_hint_original = None
@ -77,7 +90,7 @@ class ControlBase:
self.compression_ratio = 8
self.upscale_algorithm = 'nearest-exact'
self.extra_args = {}
self.previous_controlnet = None
self.previous_controlnet: Union[ControlBase, None] = None
self.extra_conds = []
self.strength_type = StrengthType.CONSTANT
self.concat_mask = False
@ -85,6 +98,7 @@ class ControlBase:
self.extra_concat = None
self.extra_hooks: HookGroup = None
self.preprocess_image = lambda a: a
self.multigpu_clones: dict[torch.device, ControlBase] = {}
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None, extra_concat=[]):
self.cond_hint_original = cond_hint
@ -111,17 +125,38 @@ class ControlBase:
def cleanup(self):
if self.previous_controlnet is not None:
self.previous_controlnet.cleanup()
for device_cnet in self.multigpu_clones.values():
with ControlIsolation(device_cnet):
device_cnet.cleanup()
self.cond_hint = None
self.extra_concat = None
self.timestep_range = None
def get_models(self):
out = []
for device_cnet in self.multigpu_clones.values():
out += device_cnet.get_models_only_self()
if self.previous_controlnet is not None:
out += self.previous_controlnet.get_models()
return out
def get_models_only_self(self):
'Calls get_models, but temporarily sets previous_controlnet to None.'
with ControlIsolation(self):
return self.get_models()
def get_instance_for_device(self, device):
'Returns instance of this Control object intended for selected device.'
return self.multigpu_clones.get(device, self)
def deepclone_multigpu(self, load_device, autoregister=False):
'''
Create deep clone of Control object where model(s) is set to other devices.
When autoregister is set to True, the deep clone is also added to multigpu_clones dict.
'''
raise NotImplementedError("Classes inheriting from ControlBase should define their own deepclone_multigpu funtion.")
def get_extra_hooks(self):
out = []
if self.extra_hooks is not None:
@ -130,7 +165,7 @@ class ControlBase:
out += self.previous_controlnet.get_extra_hooks()
return out
def copy_to(self, c):
def copy_to(self, c: ControlBase):
c.cond_hint_original = self.cond_hint_original
c.strength = self.strength
c.timestep_percent_range = self.timestep_percent_range
@ -284,6 +319,14 @@ class ControlNet(ControlBase):
self.copy_to(c)
return c
def deepclone_multigpu(self, load_device, autoregister=False):
c = self.copy()
c.control_model = copy.deepcopy(c.control_model)
c.control_model_wrapped = comfy.model_patcher.ModelPatcher(c.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device())
if autoregister:
self.multigpu_clones[load_device] = c
return c
def get_models(self):
out = super().get_models()
out.append(self.control_model_wrapped)
@ -314,6 +357,10 @@ class QwenFunControlNet(ControlNet):
super().pre_run(model, percent_to_timestep_function)
self.set_extra_arg("base_model", model.diffusion_model)
def cleanup(self):
self.extra_args.pop("base_model", None)
super().cleanup()
def copy(self):
c = QwenFunControlNet(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
c.control_model = self.control_model
@ -906,6 +953,14 @@ class T2IAdapter(ControlBase):
self.copy_to(c)
return c
def deepclone_multigpu(self, load_device, autoregister=False):
c = self.copy()
c.t2i_model = copy.deepcopy(c.t2i_model)
c.device = load_device
if autoregister:
self.multigpu_clones[load_device] = c
return c
def load_t2i_adapter(t2i_data, model_options={}): #TODO: model_options
compression_ratio = 8
upscale_algorithm = 'nearest-exact'

View File

@ -799,13 +799,15 @@ class ZImagePixelSpace(ChromaRadiance):
"""
pass
class HiDreamO1Pixel(ChromaRadiance):
"""Pixel-space latent format for HiDream-O1.
No VAE model patches/unpatches raw RGB internally with patch_size=32.
"""
pass
class PixelDiTPixel(ChromaRadiance):
pass
class CogVideoX(LatentFormat):
"""Latent format for CogVideoX-2b (THUDM/CogVideoX-2b).

View File

@ -607,9 +607,13 @@ class HunYuanDiTPlain(nn.Module):
def forward(self, x, t, context, transformer_options = {}, **kwargs):
x = x.movedim(-1, -2)
if context.shape[0] >= 2:
uncond_emb, cond_emb = context.chunk(2, dim = 0)
context = torch.cat([cond_emb, uncond_emb], dim = 0)
swap_cfg_halves = context.shape[0] >= 2
if swap_cfg_halves:
first_half, second_half = context.chunk(2, dim = 0)
context = torch.cat([second_half, first_half], dim = 0)
main_condition = context
t = 1.0 - t
@ -657,8 +661,8 @@ class HunYuanDiTPlain(nn.Module):
output = self.final_layer(combined)
output = output.movedim(-2, -1) * (-1.0)
if output.shape[0] >= 2:
cond_emb, uncond_emb = output.chunk(2, dim = 0)
return torch.cat([uncond_emb, cond_emb])
else:
return output
if swap_cfg_halves:
first_half, second_half = output.chunk(2, dim = 0)
output = torch.cat([second_half, first_half], dim = 0)
return output

510
comfy/ldm/lens/model.py Normal file
View File

@ -0,0 +1,510 @@
"""Lens denoising transformer (DiT)"""
from __future__ import annotations
from typing import Any, Dict, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
import comfy.ldm.flux.layers
import comfy.patcher_extension
from comfy.ldm.flux.layers import EmbedND
from comfy.ldm.flux.math import apply_rope
from comfy.ldm.modules.attention import optimized_attention
def _lens_time_proj(t: torch.Tensor, dim: int = 256) -> torch.Tensor:
return comfy.ldm.flux.layers.timestep_embedding(t, dim)
def _lens_position_ids(
frame: int, height: int, width: int, text_seq_len: int,
scale_rope: bool = True, device=None,
) -> torch.Tensor:
"""Lens axial (frame, h, w) position ids for joint image + text sequence.
With ``scale_rope=True`` h/w are centered around 0 (negative + positive
halves) and text starts at ``max(h//2, w//2)``. Result shape ``[seq, 3]``;
caller adds a batch dim for ``EmbedND``.
"""
if scale_rope:
h_pos = torch.cat([torch.arange(-(height - height // 2), 0, device=device),
torch.arange(0, height // 2, device=device)])
w_pos = torch.cat([torch.arange(-(width - width // 2), 0, device=device),
torch.arange(0, width // 2, device=device)])
text_start = max(height // 2, width // 2)
else:
h_pos = torch.arange(height, device=device)
w_pos = torch.arange(width, device=device)
text_start = max(height, width)
f_pos = torch.arange(frame, device=device)
img_ids = torch.zeros(frame, height, width, 3, device=device)
img_ids[..., 0] = f_pos[:, None, None]
img_ids[..., 1] = h_pos[None, :, None]
img_ids[..., 2] = w_pos[None, None, :]
img_ids = img_ids.reshape(-1, 3)
# Text positions replicate across all 3 axes (matches original packing).
txt_pos = torch.arange(text_start, text_start + text_seq_len, device=device).float()
txt_ids = txt_pos[:, None].expand(text_seq_len, 3)
return torch.cat([img_ids, txt_ids], dim=0)
class _TimestepEmbedder(nn.Module):
def __init__(self, in_channels: int, time_embed_dim: int, dtype=None, device=None, operations=None) -> None:
super().__init__()
self.linear_1 = operations.Linear(in_channels, time_embed_dim, dtype=dtype, device=device)
self.linear_2 = operations.Linear(time_embed_dim, time_embed_dim, dtype=dtype, device=device)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.linear_1(x)
x = F.silu(x)
return self.linear_2(x)
class LensTimestepProjEmbeddings(nn.Module):
def __init__(self, embedding_dim: int, dtype=None, device=None, operations=None) -> None:
super().__init__()
self.timestep_embedder = _TimestepEmbedder(256, embedding_dim, dtype=dtype, device=device, operations=operations)
def forward(self, timestep: torch.Tensor, hidden_states: torch.Tensor) -> torch.Tensor:
proj = _lens_time_proj(timestep, 256)
return self.timestep_embedder(proj.to(dtype=hidden_states.dtype))
class GateMLP(nn.Module):
"""SwiGLU MLP."""
def __init__(self, dim: int, hidden_dim: int, dtype=None, device=None, operations=None) -> None:
super().__init__()
self.w1 = operations.Linear(dim, hidden_dim, bias=False, dtype=dtype, device=device)
self.w2 = operations.Linear(hidden_dim, dim, bias=False, dtype=dtype, device=device)
self.w3 = operations.Linear(dim, hidden_dim, bias=False, dtype=dtype, device=device)
def forward(self, x):
return self.w2(F.silu(self.w1(x), inplace=True).mul_(self.w3(x)))
class LensJointAttention(nn.Module):
"""Joint image+text attention with fused QKV per stream."""
def __init__(
self,
query_dim: int,
added_kv_proj_dim: int,
dim_head: int = 64,
heads: int = 8,
out_dim: Optional[int] = None,
eps: float = 1e-5,
dtype=None,
device=None,
operations=None,
) -> None:
super().__init__()
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
self.heads = self.inner_dim // dim_head
self.dim_head = dim_head
self.out_dim = out_dim if out_dim is not None else query_dim
self.norm_q = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device)
self.norm_k = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device)
self.norm_added_q = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device)
self.norm_added_k = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device)
self.img_qkv = operations.Linear(query_dim, 3 * self.inner_dim, bias=True, dtype=dtype, device=device)
self.txt_qkv = operations.Linear(added_kv_proj_dim, 3 * self.inner_dim, bias=True, dtype=dtype, device=device)
# ModuleList([Linear, Identity]) for state-dict key compatibility.
self.to_out = nn.ModuleList([
operations.Linear(self.inner_dim, self.out_dim, bias=True, dtype=dtype, device=device),
nn.Identity(),
])
self.to_add_out = operations.Linear(self.inner_dim, query_dim, bias=True, dtype=dtype, device=device)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
freqs_cis: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
transformer_options: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
bsz, seq_img, _ = hidden_states.shape
seq_txt = encoder_hidden_states.shape[1]
# image stream
img_qkv = self.img_qkv(hidden_states).view(bsz, seq_img, 3, self.heads, self.dim_head)
img_q, img_k, img_v = img_qkv.unbind(dim=2)
img_q = self.norm_q(img_q)
img_k = self.norm_k(img_k)
del img_qkv
# text stream
txt_qkv = self.txt_qkv(encoder_hidden_states).view(bsz, seq_txt, 3, self.heads, self.dim_head)
txt_q, txt_k, txt_v = txt_qkv.unbind(dim=2)
txt_q = self.norm_added_q(txt_q)
txt_k = self.norm_added_k(txt_k)
# [B, S, H, D] → [B, H, S, D] for attention, dels to avoid VRAM peaks
q = torch.cat([img_q, txt_q], dim=1).transpose(1, 2)
del img_q, txt_q
k = torch.cat([img_k, txt_k], dim=1).transpose(1, 2)
del img_k, txt_k
v = torch.cat([img_v, txt_v], dim=1).transpose(1, 2)
del img_v, txt_v
q, k = apply_rope(q, k, freqs_cis)
if attention_mask is not None:
expected = (bsz, 1, 1, seq_img + seq_txt)
if attention_mask.shape != expected:
raise ValueError(
f"attention_mask must be {expected}, got {tuple(attention_mask.shape)}"
)
attention_mask = attention_mask.to(q.dtype)
out = optimized_attention(
q, k, v, self.heads, mask=attention_mask, skip_reshape=True,
transformer_options=transformer_options,
)
img_out = self.to_out[1](self.to_out[0](out[:, :seq_img, :]))
txt_out = self.to_add_out(out[:, seq_img:, :])
return img_out, txt_out
class LensTransformerBlock(nn.Module):
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
eps: float = 1e-6,
rms_norm: bool = True,
dtype=None,
device=None,
operations=None,
) -> None:
super().__init__()
self.attn = LensJointAttention(
query_dim=dim,
added_kv_proj_dim=dim,
dim_head=attention_head_dim,
heads=num_attention_heads,
out_dim=dim,
eps=1e-5,
dtype=dtype,
device=device,
operations=operations,
)
if rms_norm:
NormCls = operations.RMSNorm
norm_kwargs = {}
else:
NormCls = operations.LayerNorm
norm_kwargs = {"elementwise_affine": False}
mlp_hidden = int(dim / 3 * 8)
# Sequential(SiLU, Linear) so state-dict lands at img_mod.1.{weight,bias}.
self.img_mod = nn.Sequential(
nn.SiLU(),
operations.Linear(dim, 6 * dim, bias=True, dtype=dtype, device=device),
)
self.img_norm1 = NormCls(dim, eps=eps, dtype=dtype, device=device, **norm_kwargs)
self.img_norm2 = NormCls(dim, eps=eps, dtype=dtype, device=device, **norm_kwargs)
self.img_mlp = GateMLP(dim, mlp_hidden, dtype=dtype, device=device, operations=operations)
self.txt_mod = nn.Sequential(
nn.SiLU(),
operations.Linear(dim, 6 * dim, bias=True, dtype=dtype, device=device),
)
self.txt_norm1 = NormCls(dim, eps=eps, dtype=dtype, device=device, **norm_kwargs)
self.txt_norm2 = NormCls(dim, eps=eps, dtype=dtype, device=device, **norm_kwargs)
self.txt_mlp = GateMLP(dim, mlp_hidden, dtype=dtype, device=device, operations=operations)
@staticmethod
def _modulate(x: torch.Tensor, mod_params: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
shift, scale, gate = mod_params.chunk(3, dim=-1)
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1), gate.unsqueeze(1)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
freqs_cis: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
transformer_options: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
img_mod1, img_mod2 = self.img_mod(temb).chunk(2, dim=-1)
txt_mod1, txt_mod2 = self.txt_mod(temb).chunk(2, dim=-1)
img_modulated, img_gate1 = self._modulate(self.img_norm1(hidden_states), img_mod1)
txt_modulated, txt_gate1 = self._modulate(self.txt_norm1(encoder_hidden_states), txt_mod1)
img_attn, txt_attn = self.attn(
hidden_states=img_modulated,
encoder_hidden_states=txt_modulated,
freqs_cis=freqs_cis,
attention_mask=attention_mask,
transformer_options=transformer_options,
)
hidden_states = hidden_states + img_gate1 * img_attn
encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn
img_modulated2, img_gate2 = self._modulate(self.img_norm2(hidden_states), img_mod2)
hidden_states = hidden_states + img_gate2 * self.img_mlp(img_modulated2)
txt_modulated2, txt_gate2 = self._modulate(self.txt_norm2(encoder_hidden_states), txt_mod2)
encoder_hidden_states = encoder_hidden_states + txt_gate2 * self.txt_mlp(txt_modulated2)
return encoder_hidden_states, hidden_states
class _AdaLayerNormContinuousNoAffine(nn.Module):
"""AdaLayerNormContinuous(elementwise_affine=False).
The reference uses ``scale, shift = chunk(2)`` (scale first) opposite
to Flux's ``LastLayer``.
"""
def __init__(self, embedding_dim: int, conditioning_embedding_dim: int, eps: float = 1e-6,
dtype=None, device=None, operations=None) -> None:
super().__init__()
self.linear = operations.Linear(
conditioning_embedding_dim, embedding_dim * 2, bias=True, dtype=dtype, device=device
)
self.eps = eps
self.embedding_dim = embedding_dim
def forward(self, x: torch.Tensor, conditioning: torch.Tensor) -> torch.Tensor:
emb = self.linear(F.silu(conditioning))
scale, shift = torch.chunk(emb, 2, dim=-1)
x = F.layer_norm(x, (self.embedding_dim,), None, None, self.eps)
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
class LensTransformer2DModel(nn.Module):
"""Lens dual-stream MMDiT (48 blocks, inner_dim=1536, multi-layer text)."""
def __init__(
self,
patch_size: int = 2,
in_channels: int = 128,
out_channels: Optional[int] = 32,
num_layers: int = 48,
attention_head_dim: int = 64,
num_attention_heads: int = 24,
enc_hidden_dim: int = 2880,
axes_dims_rope: Tuple[int, int, int] = (8, 28, 28),
rms_norm: bool = True,
multi_layer_encoder_feature: bool = True,
selected_layer_index: Tuple[int, ...] = (5, 11, 17, 23),
image_model=None, # unused; accepted for detection-side configs.
dtype=None,
device=None,
operations=None,
) -> None:
super().__init__()
self.patch_size = patch_size
self.in_channels = in_channels
self.out_channels = out_channels if out_channels is not None else in_channels
self.inner_dim = num_attention_heads * attention_head_dim
self.multi_layer_encoder_feature = multi_layer_encoder_feature
self.selected_layer_index = list(selected_layer_index)
self.dtype = dtype
self.pos_embed = EmbedND(dim=attention_head_dim, theta=10000, axes_dim=list(axes_dims_rope))
self.time_text_embed = LensTimestepProjEmbeddings(
embedding_dim=self.inner_dim, dtype=dtype, device=device, operations=operations
)
if self.multi_layer_encoder_feature:
self.txt_norm = nn.ModuleList(
[operations.RMSNorm(enc_hidden_dim, eps=1e-5, dtype=dtype, device=device)
for _ in self.selected_layer_index]
)
self.txt_in = operations.Linear(
enc_hidden_dim * len(self.selected_layer_index),
self.inner_dim, bias=True, dtype=dtype, device=device,
)
else:
self.txt_norm = operations.RMSNorm(enc_hidden_dim, eps=1e-5, dtype=dtype, device=device)
self.txt_in = operations.Linear(enc_hidden_dim, self.inner_dim, bias=True, dtype=dtype, device=device)
self.img_in = operations.Linear(in_channels, self.inner_dim, bias=True, dtype=dtype, device=device)
self.transformer_blocks = nn.ModuleList([
LensTransformerBlock(
dim=self.inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
eps=1e-6,
rms_norm=rms_norm,
dtype=dtype, device=device, operations=operations,
)
for _ in range(num_layers)
])
self.norm_out = _AdaLayerNormContinuousNoAffine(
self.inner_dim, self.inner_dim, eps=1e-6,
dtype=dtype, device=device, operations=operations,
)
self.proj_out = operations.Linear(
self.inner_dim, patch_size * patch_size * self.out_channels, bias=True,
dtype=dtype, device=device,
)
def forward(self, x: torch.Tensor, timestep: torch.Tensor, context: torch.Tensor, attention_mask: Optional[torch.Tensor] = None,
transformer_options: Optional[Dict[str, Any]] = None, **kwargs) -> torch.Tensor:
if transformer_options is None:
transformer_options = {}
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._forward, self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options),
).execute(x, timestep, context, attention_mask, transformer_options, **kwargs)
def _forward(
self,
x: torch.Tensor,
timestep: torch.Tensor,
context: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
transformer_options: Optional[Dict[str, Any]] = None,
control: Optional[Dict[str, Any]] = None,
**kwargs,
) -> torch.Tensor:
"""ComfyUI bridge: ``(x[B,128,h,w], t[B], context[B,S,L*H], mask[B,S])``."""
if transformer_options is None:
transformer_options = {}
transformer_options = transformer_options.copy()
patches = transformer_options.get("patches", {})
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
B, C, h, w = x.shape
hidden_states = x.permute(0, 2, 3, 1).reshape(B, h * w, C)
if self.multi_layer_encoder_feature:
L = len(self.selected_layer_index)
enc_dim = context.shape[-1] // L
encoder_hidden_states = list(
context.reshape(B, -1, L, enc_dim).unbind(dim=2)
)
text_seq_len = encoder_hidden_states[0].shape[1]
else:
encoder_hidden_states = context
text_seq_len = context.shape[1]
if attention_mask is None:
attention_mask = torch.ones(
(B, text_seq_len), dtype=torch.bool, device=x.device
)
img_len = h * w
joint_mask = self._build_joint_attention_mask(attention_mask, img_len)
hidden_states = self.img_in(hidden_states)
timestep = timestep.to(hidden_states.dtype)
if self.multi_layer_encoder_feature:
normed = [self.txt_norm[i](encoder_hidden_states[i]) for i in range(L)]
encoder_hidden_states = torch.cat(normed, dim=-1)
else:
encoder_hidden_states = self.txt_norm(encoder_hidden_states)
encoder_hidden_states = self.txt_in(encoder_hidden_states)
if "post_input" in patches:
for p in patches["post_input"]:
out = p({
"img": hidden_states,
"txt": encoder_hidden_states,
"transformer_options": transformer_options,
})
hidden_states = out["img"]
encoder_hidden_states = out["txt"]
temb = self.time_text_embed(timestep, hidden_states)
ids = _lens_position_ids(1, h, w, text_seq_len, device=hidden_states.device).unsqueeze(0)
freqs_cis = self.pos_embed(ids)
transformer_options["total_blocks"] = len(self.transformer_blocks)
transformer_options["block_type"] = "double"
for i, block in enumerate(self.transformer_blocks):
transformer_options["block_index"] = i
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["txt"], out["img"] = block(
hidden_states=args["img"],
encoder_hidden_states=args["txt"],
temb=args["vec"],
freqs_cis=args["pe"],
attention_mask=args.get("attn_mask"),
transformer_options=args.get("transformer_options"),
)
return out
out = blocks_replace[("double_block", i)](
{
"img": hidden_states,
"txt": encoder_hidden_states,
"vec": temb,
"pe": freqs_cis,
"attn_mask": joint_mask,
"transformer_options": transformer_options,
},
{"original_block": block_wrap},
)
encoder_hidden_states = out["txt"]
hidden_states = out["img"]
else:
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=temb,
freqs_cis=freqs_cis,
attention_mask=joint_mask,
transformer_options=transformer_options,
)
if "double_block" in patches:
for p in patches["double_block"]:
out = p({
"img": hidden_states,
"txt": encoder_hidden_states,
"x": x,
"block_index": i,
"transformer_options": transformer_options,
})
hidden_states = out["img"]
encoder_hidden_states = out["txt"]
if control is not None:
control_i = control.get("input")
if control_i is not None and i < len(control_i):
add = control_i[i]
if add is not None:
hidden_states[:, :add.shape[1]] += add
hidden_states = self.norm_out(hidden_states, temb)
out = self.proj_out(hidden_states)
return out.reshape(B, h, w, C).permute(0, 3, 1, 2).contiguous()
@staticmethod
def _build_joint_attention_mask(text_mask: torch.Tensor, img_len: int) -> torch.Tensor:
if text_mask.dtype != torch.bool:
text_mask = text_mask.bool()
bsz = text_mask.shape[0]
img_ones = torch.ones((bsz, img_len), dtype=torch.bool, device=text_mask.device)
joint = torch.cat([img_ones, text_mask], dim=1)
additive = torch.zeros_like(joint, dtype=torch.float32)
additive.masked_fill_(~joint, torch.finfo(torch.float32).min)
return additive[:, None, None, :]

View File

@ -767,25 +767,25 @@ class LTXAVModel(LTXVModel):
# Cross-attention timesteps - compress these too
av_ca_audio_scale_shift_timestep, _ = self.av_ca_audio_scale_shift_adaln_single(
timestep.max().expand_as(a_timestep_flat),
a_timestep_flat,
{"resolution": None, "aspect_ratio": None},
batch_size=batch_size,
hidden_dtype=hidden_dtype,
)
av_ca_video_scale_shift_timestep, _ = self.av_ca_video_scale_shift_adaln_single(
a_timestep.max().expand_as(timestep_flat),
timestep_flat,
{"resolution": None, "aspect_ratio": None},
batch_size=batch_size,
hidden_dtype=hidden_dtype,
)
av_ca_a2v_gate_noise_timestep, _ = self.av_ca_a2v_gate_adaln_single(
a_timestep.max().expand_as(timestep_flat) * av_ca_factor,
a_timestep_scaled.max().expand_as(timestep_flat) * av_ca_factor,
{"resolution": None, "aspect_ratio": None},
batch_size=batch_size,
hidden_dtype=hidden_dtype,
)
av_ca_v2a_gate_noise_timestep, _ = self.av_ca_v2a_gate_adaln_single(
timestep.max().expand_as(a_timestep_flat) * av_ca_factor,
timestep_scaled.max().expand_as(a_timestep_flat) * av_ca_factor,
{"resolution": None, "aspect_ratio": None},
batch_size=batch_size,
hidden_dtype=hidden_dtype,

View File

@ -1,4 +1,3 @@
from __future__ import annotations
import torch
from torch import nn
from torch.nn import functional as F

View File

@ -1,4 +1,3 @@
from __future__ import annotations
import threading
import torch
from torch import nn

View File

@ -1,5 +1,4 @@
# Code from: https://github.com/Alpha-VLLM/Lumina-Image-2.0/blob/main/models/model.py
from __future__ import annotations
from typing import List, Optional, Tuple

View File

@ -211,7 +211,7 @@ class TimestepEmbedder(nn.Module):
Embeds scalar timesteps into vector representations.
"""
def __init__(self, hidden_size, frequency_embedding_size=256, output_size=None, dtype=None, device=None, operations=None):
def __init__(self, hidden_size, frequency_embedding_size=256, output_size=None, dtype=None, device=None, operations=None, max_period=10000):
super().__init__()
if output_size is None:
output_size = hidden_size
@ -221,9 +221,10 @@ class TimestepEmbedder(nn.Module):
operations.Linear(hidden_size, output_size, bias=True, dtype=dtype, device=device),
)
self.frequency_embedding_size = frequency_embedding_size
self.max_period = max_period
def forward(self, t, dtype, **kwargs):
t_freq = timestep_embedding(t, self.frequency_embedding_size).to(dtype)
t_freq = timestep_embedding(t, self.frequency_embedding_size, max_period=self.max_period).to(dtype)
t_emb = self.mlp(t_freq)
return t_emb

View File

@ -1,6 +1,5 @@
"""Pure-torch + scipy geometry helpers for MoGe inference and mesh export."""
from __future__ import annotations
from typing import Optional, Tuple

View File

@ -4,7 +4,6 @@ V1: DINOv2 backbone + multi-output head (points, mask).
V2: DINOv2 encoder + neck + per-output heads (points, mask, normal, optional metric-scale MLP).
"""
from __future__ import annotations
from numbers import Number
from typing import Any, Dict, List, Optional, Tuple, Union

View File

@ -1,6 +1,5 @@
"""Building blocks for MoGe: residual conv stack, resamplers, MLP, DINOv2 encoder, v1 head."""
from __future__ import annotations
from typing import List, Optional, Sequence, Tuple, Union

View File

@ -6,7 +6,6 @@ equirect distance map via a multi-scale Poisson + gradient sparse solve.
Image sampling uses F.grid_sample (GPU); the sparse solve uses lsmr (CPU).
"""
from __future__ import annotations
from typing import Callable, List, Optional, Tuple

239
comfy/ldm/pixeldit/model.py Normal file
View File

@ -0,0 +1,239 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import comfy.ldm.common_dit
import comfy.patcher_extension
from comfy.ldm.flux.math import apply_rope, rope
from comfy.ldm.hidream.model import FeedForwardSwiGLU
from comfy.ldm.modules.attention import optimized_attention
from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder
from .modules import (
FinalLayer,
PatchTokenEmbedder,
PiTBlock,
PixelTokenEmbedder,
apply_adaln_,
precompute_freqs_cis_2d,
)
class MMDiTJointAttention(nn.Module):
"""Joint MMDiT attention with separate Q/K/V/proj for image and text streams.
RoPE is applied to each stream before concatenation so each stream uses its own
2D/1D positional encoding. Concat order is [text, image] (text first).
"""
def __init__(self, dim, num_heads=8, qkv_bias=False, dtype=None, device=None, operations=None):
super().__init__()
assert dim % num_heads == 0
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.qkv_x = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
self.qkv_y = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
self.q_norm_x = operations.RMSNorm(self.head_dim, eps=1e-6, dtype=dtype, device=device)
self.k_norm_x = operations.RMSNorm(self.head_dim, eps=1e-6, dtype=dtype, device=device)
self.q_norm_y = operations.RMSNorm(self.head_dim, eps=1e-6, dtype=dtype, device=device)
self.k_norm_y = operations.RMSNorm(self.head_dim, eps=1e-6, dtype=dtype, device=device)
self.proj_x = operations.Linear(dim, dim, dtype=dtype, device=device)
self.proj_y = operations.Linear(dim, dim, dtype=dtype, device=device)
def forward(self, x, y, pos_img, pos_txt=None, attn_mask=None, transformer_options={}):
B, Nx, _ = x.shape
_, Ny, _ = y.shape
H = self.num_heads
D = self.head_dim
qkv_x = self.qkv_x(x).reshape(B, Nx, 3, H, D).permute(2, 0, 3, 1, 4)
qx, kx, vx = qkv_x.unbind(0)
qx = self.q_norm_x(qx)
kx = self.k_norm_x(kx)
qkv_y = self.qkv_y(y).reshape(B, Ny, 3, H, D).permute(2, 0, 3, 1, 4)
qy, ky, vy = qkv_y.unbind(0)
qy = self.q_norm_y(qy)
ky = self.k_norm_y(ky)
qx, kx = apply_rope(qx, kx, pos_img[None, None])
if pos_txt is not None:
qy, ky = apply_rope(qy, ky, pos_txt[None, None])
q_joint = torch.cat([qy, qx], dim=2)
k_joint = torch.cat([ky, kx], dim=2)
v_joint = torch.cat([vy, vx], dim=2)
out_joint = optimized_attention(
q_joint, k_joint, v_joint, H,
mask=attn_mask, skip_reshape=True, skip_output_reshape=True,
transformer_options=transformer_options,
)
out_y = out_joint[:, :, :Ny, :].transpose(1, 2).reshape(B, Ny, H * D)
out_x = out_joint[:, :, Ny:, :].transpose(1, 2).reshape(B, Nx, H * D)
return self.proj_x(out_x), self.proj_y(out_y)
class MMDiTBlockT2I(nn.Module):
def __init__(self, hidden_size, groups, mlp_ratio=4.0, dtype=None, device=None, operations=None):
super().__init__()
self.norm_x1 = operations.RMSNorm(hidden_size, eps=1e-6, dtype=dtype, device=device)
self.norm_y1 = operations.RMSNorm(hidden_size, eps=1e-6, dtype=dtype, device=device)
self.attn = MMDiTJointAttention(hidden_size, num_heads=groups, qkv_bias=False, dtype=dtype, device=device, operations=operations)
self.norm_x2 = operations.RMSNorm(hidden_size, eps=1e-6, dtype=dtype, device=device)
self.norm_y2 = operations.RMSNorm(hidden_size, eps=1e-6, dtype=dtype, device=device)
mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.mlp_x = FeedForwardSwiGLU(hidden_size, mlp_hidden_dim, multiple_of=1, dtype=dtype, device=device, operations=operations)
self.mlp_y = FeedForwardSwiGLU(hidden_size, mlp_hidden_dim, multiple_of=1, dtype=dtype, device=device, operations=operations)
self.adaLN_modulation_img = nn.Sequential(operations.Linear(hidden_size, 6 * hidden_size, bias=True, dtype=dtype, device=device))
self.adaLN_modulation_txt = nn.Sequential(operations.Linear(hidden_size, 6 * hidden_size, bias=True, dtype=dtype, device=device))
def forward(self, x, y, c, pos_img, pos_txt=None, attn_mask=None, transformer_options={}):
shift_msa_x, scale_msa_x, gate_msa_x, shift_mlp_x, scale_mlp_x, gate_mlp_x = self.adaLN_modulation_img(c).chunk(6, dim=-1)
shift_msa_y, scale_msa_y, gate_msa_y, shift_mlp_y, scale_mlp_y, gate_mlp_y = self.adaLN_modulation_txt(c).chunk(6, dim=-1)
x_norm = apply_adaln_(self.norm_x1(x), shift_msa_x, scale_msa_x)
y_norm = apply_adaln_(self.norm_y1(y), shift_msa_y, scale_msa_y)
attn_x, attn_y = self.attn(x_norm, y_norm, pos_img, pos_txt, attn_mask, transformer_options=transformer_options)
x = torch.addcmul(x, gate_msa_x, attn_x)
y = torch.addcmul(y, gate_msa_y, attn_y)
x = torch.addcmul(x, gate_mlp_x, self.mlp_x(apply_adaln_(self.norm_x2(x), shift_mlp_x, scale_mlp_x)))
y = torch.addcmul(y, gate_mlp_y, self.mlp_y(apply_adaln_(self.norm_y2(y), shift_mlp_y, scale_mlp_y)))
return x, y
class PixDiT_T2I(nn.Module):
"""PixelDiT T2I model. Hardcoded for the released 1024px Stage-3 checkpoint
(also runs at 512px when fed the appropriate latent size and flow_shift).
Forward:
x: [B, 3, H, W] pixel-space input (no VAE)
timesteps:[B] in [0, 1000] (ComfyUI flow sampling convention)
context: [B, Ltxt, 2304] Gemma-2-2b-it hidden states (chi_prompt prepended)
Returns flow-matching velocity [B, 3, H, W].
"""
def __init__(
self,
in_channels=3,
num_groups=24,
hidden_size=1536,
pixel_hidden_size=16,
pixel_attn_hidden_size=1152,
pixel_num_groups=16,
patch_depth=14,
pixel_depth=2,
patch_size=16,
txt_embed_dim=2304,
txt_max_length=300,
use_text_rope=True,
text_rope_theta=10000.0,
image_model=None,
dtype=None,
device=None,
operations=None,
pixel_mlp_chunks=2,
):
super().__init__()
self.dtype = dtype
self.in_channels = in_channels
self.out_channels = in_channels
self.hidden_size = hidden_size
self.num_groups = num_groups
self.patch_depth = patch_depth
self.pixel_depth = pixel_depth
self.patch_size = patch_size
self.pixel_hidden_size = pixel_hidden_size
self.pixel_attn_hidden_size = pixel_attn_hidden_size
self.pixel_num_groups = pixel_num_groups
self.txt_embed_dim = txt_embed_dim
self.txt_max_length = txt_max_length
self.use_text_rope = use_text_rope
self.text_rope_theta = text_rope_theta
self.pixel_embedder = PixelTokenEmbedder(self.in_channels, self.pixel_hidden_size, dtype=dtype, device=device, operations=operations)
self.s_embedder = PatchTokenEmbedder(self.in_channels * self.patch_size ** 2, self.hidden_size, bias=True, dtype=dtype, device=device, operations=operations)
self.t_embedder = TimestepEmbedder(self.hidden_size, dtype=dtype, device=device, operations=operations, max_period=10)
self.y_embedder = PatchTokenEmbedder(self.txt_embed_dim, self.hidden_size, bias=True, use_norm=True, dtype=dtype, device=device, operations=operations)
self.y_pos_embedding = nn.Parameter(torch.empty(1, self.txt_max_length, self.hidden_size, dtype=dtype, device=device))
self.patch_blocks = nn.ModuleList([
MMDiTBlockT2I(self.hidden_size, self.num_groups,
dtype=dtype, device=device, operations=operations)
for _ in range(self.patch_depth)
])
self.pixel_blocks = nn.ModuleList([
PiTBlock(
self.pixel_hidden_size,
self.hidden_size,
patch_size=self.patch_size,
num_heads=self.num_groups,
attn_hidden_size=self.pixel_attn_hidden_size,
attn_num_heads=self.pixel_num_groups,
dtype=dtype, device=device, operations=operations,
mlp_chunks=pixel_mlp_chunks,
)
for _ in range(self.pixel_depth)
])
self.final_layer = FinalLayer(self.pixel_hidden_size, self.out_channels, dtype=dtype, device=device, operations=operations)
def _fetch_patch_pos(self, height, width, device, dtype, **rope_opts):
return precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width, device=device, dtype=dtype, **rope_opts)
def _fetch_text_pos(self, length, device, dtype):
return rope(torch.arange(length, dtype=torch.float32, device=device).reshape(1, -1), self.hidden_size // self.num_groups, self.text_rope_theta).squeeze(0).to(dtype=dtype)
def forward(self, x, timesteps, context=None, attention_mask=None, transformer_options={}, **kwargs):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._forward, self, comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options),
).execute(x, timesteps, context, attention_mask, transformer_options, **kwargs)
def _pre_patch_block(self, s, i, **kwargs):
"""Hook for subclasses to inject per-block state into the patch stream (e.g. PiD's LQ gate)."""
return s
def _forward(self, x, timesteps, context=None, attention_mask=None, transformer_options={}, **kwargs):
H_orig, W_orig = x.shape[2], x.shape[3]
x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
B, _, H, W = x.shape
Hs = H // self.patch_size
Ws = W // self.patch_size
L = Hs * Ws
pos_img = self._fetch_patch_pos(Hs, Ws, x.device, x.dtype, **(transformer_options.get("rope_options") or {}))
x_patches = F.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
t_emb = self.t_embedder(timesteps.view(-1), x.dtype).view(B, -1, self.hidden_size)
if context is None or context.dim() != 3:
raise ValueError("PixDiT_T2I requires context (text embeddings) of shape [B, L, D]")
Ltxt = min(context.shape[1], self.txt_max_length)
y = context[:, :Ltxt, :]
y_emb = self.y_embedder(y).view(B, Ltxt, self.hidden_size)
y_emb = y_emb + self.y_pos_embedding[:, :Ltxt, :].to(y_emb) # y_pos_embedding is a raw nn.Parameter
condition = F.silu(t_emb)
pos_txt = self._fetch_text_pos(Ltxt, x.device, x.dtype) if self.use_text_rope else None
s = self.s_embedder(x_patches)
for i, blk in enumerate(self.patch_blocks):
s = self._pre_patch_block(s, i, **kwargs)
s, y_emb = blk(s, y_emb, condition, pos_img, pos_txt, None, transformer_options=transformer_options)
s = F.silu(t_emb + s)
s_cond = s.view(B * L, self.hidden_size)
x_pixels = self.pixel_embedder(x, patch_size=self.patch_size)
for blk in self.pixel_blocks:
x_pixels = blk(x_pixels, s_cond, H, W, self.patch_size, mask=None, transformer_options=transformer_options)
x_pixels = self.final_layer(x_pixels)
C_out = self.out_channels
P2 = self.patch_size * self.patch_size
x_pixels = x_pixels.view(B, L, P2, C_out).permute(0, 3, 2, 1).reshape(B, C_out * P2, L)
out = F.fold(x_pixels, (H, W), kernel_size=self.patch_size, stride=self.patch_size)
return out[:, :, :H_orig, :W_orig]

View File

@ -0,0 +1,187 @@
import torch
import torch.nn as nn
from comfy.ldm.flux.math import apply_rope, rope
from comfy.ldm.modules.attention import optimized_attention
from comfy.ldm.modules.diffusionmodules.mmdit import Mlp, get_1d_sincos_pos_embed_from_grid_torch
def apply_adaln_(x, shift, scale):
return x.addcmul_(x, scale).add_(shift)
def precompute_freqs_cis_2d(dim, height, width, theta=10000.0, scale=16.0,
ref_grid_h=None, ref_grid_w=None,
scale_x=1.0, scale_y=1.0, shift_x=0.0, shift_y=0.0,
device=None, dtype=torch.float32, **kwargs):
"""2D RoPE with x/y axis frequencies interleaved at stride 2 across head dim.
rope_options:
scale_x / scale_y multiply the position range (RoPE extrapolation).
shift_x / shift_y offset the position origin (tiled / regional inference).
With ref_grid_h/w set, also applies NTK-aware per-axis theta scaling
(rope_mode='ntk_aware'): theta_axis = theta * (current/ref)^(dim_axis/(dim_axis-2)).
Returns Flux-format rotation matrices of shape [H*W, dim/2, 2, 2].
Layout of head-dim pairs: [x_0, y_0, x_1, y_1, ..., x_{dim/4-1}, y_{dim/4-1}].
"""
dim_axis = dim // 2
if ref_grid_h is not None and dim_axis > 2:
h_ntk = (height / ref_grid_h) ** (dim_axis / (dim_axis - 2))
w_ntk = (width / ref_grid_w) ** (dim_axis / (dim_axis - 2))
else:
h_ntk = w_ntk = 1.0
x_lin = torch.linspace(shift_x, scale * scale_x + shift_x, width, device=device)
y_lin = torch.linspace(shift_y, scale * scale_y + shift_y, height, device=device)
y_grid, x_grid = torch.meshgrid(y_lin, x_lin, indexing="ij")
x_rope = rope(x_grid.reshape(1, -1), dim_axis, theta * w_ntk).squeeze(0)
y_rope = rope(y_grid.reshape(1, -1), dim_axis, theta * h_ntk).squeeze(0)
out = torch.stack([x_rope, y_rope], dim=2).reshape(height * width, dim // 2, 2, 2)
return out.to(dtype=dtype)
def get_2d_sincos_pos_embed(embed_dim, height, width, device=None, dtype=torch.float32):
"""Standard 2D sin/cos absolute positional embedding (ViT-style).
first half encodes W-coordinates, second half H.
"""
assert embed_dim % 4 == 0
grid_h = torch.arange(height, dtype=torch.float32, device=device)
grid_w = torch.arange(width, dtype=torch.float32, device=device)
grid_y, grid_x = torch.meshgrid(grid_h, grid_w, indexing="ij")
emb_w = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_x.reshape(-1), device=device)
emb_h = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_y.reshape(-1), device=device)
return torch.cat([emb_w, emb_h], dim=1).to(dtype=dtype)
class RotaryAttention(nn.Module):
"""Single-stream self-attention with rotary positional encoding (used inside PiTBlock)."""
def __init__(self, dim, num_heads=8, qkv_bias=False, dtype=None, device=None, operations=None):
super().__init__()
assert dim % num_heads == 0
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
self.q_norm = operations.RMSNorm(self.head_dim, eps=1e-6, dtype=dtype, device=device)
self.k_norm = operations.RMSNorm(self.head_dim, eps=1e-6, dtype=dtype, device=device)
self.proj = operations.Linear(dim, dim, dtype=dtype, device=device)
def forward(self, x, pos, mask=None, transformer_options={}):
B, N, C = x.shape
H = self.num_heads
D = self.head_dim
qkv = self.qkv(x).reshape(B, N, 3, H, D).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0)
q, k = apply_rope(self.q_norm(q), self.k_norm(k), pos[None, None])
x = optimized_attention(q, k, v, H, mask=mask, skip_reshape=True, transformer_options=transformer_options)
return self.proj(x)
class FinalLayer(nn.Module):
def __init__(self, hidden_size, out_channels, dtype=None, device=None, operations=None):
super().__init__()
self.norm = operations.RMSNorm(hidden_size, eps=1e-6, dtype=dtype, device=device)
self.linear = operations.Linear(hidden_size, out_channels, bias=True, dtype=dtype, device=device)
def forward(self, x):
return self.linear(self.norm(x))
class PatchTokenEmbedder(nn.Module):
"""Linear projection used both for patchified-image tokens and text-feature tokens."""
def __init__(self, in_chans, embed_dim, use_norm=False, bias=True, dtype=None, device=None, operations=None):
super().__init__()
self.proj = operations.Linear(in_chans, embed_dim, bias=bias, dtype=dtype, device=device)
self.norm = operations.RMSNorm(embed_dim, eps=1e-6, dtype=dtype, device=device) if use_norm else nn.Identity()
def forward(self, x):
return self.norm(self.proj(x))
class PixelTokenEmbedder(nn.Module):
"""Pixel-level embedder: lifts each RGB pixel to hidden_size and packs into per-patch sequences."""
def __init__(self, in_channels, hidden_size_output, dtype=None, device=None, operations=None):
super().__init__()
self.in_channels = in_channels
self.hidden_size_output = hidden_size_output
self.proj = operations.Linear(self.in_channels, self.hidden_size_output, bias=True, dtype=dtype, device=device)
def forward(self, inputs, patch_size):
B, _, H, W = inputs.shape
Hs, Ws = H // patch_size, W // patch_size
P2 = patch_size * patch_size
x = inputs.permute(0, 2, 3, 1).contiguous()
x = self.proj(x)
pos_full = get_2d_sincos_pos_embed(self.hidden_size_output, H, W, device=x.device, dtype=x.dtype).view(H, W, self.hidden_size_output)
x = x + pos_full.unsqueeze(0)
x = x.view(B, Hs, patch_size, Ws, patch_size, self.hidden_size_output)
return x.permute(0, 1, 3, 2, 4, 5).reshape(B * Hs * Ws, P2, self.hidden_size_output)
class PiTBlock(nn.Module):
"""Pixel-level transformer block.
Compresses each patch's P^2 pixel tokens → 1 attention token via a linear,
runs global self-attention across patches with 2D RoPE, then expands back to P^2 tokens.
Conditioning is per-pixel adaLN from the patch-level features.
"""
def __init__(self, pixel_hidden_size, patch_hidden_size, patch_size, num_heads, mlp_ratio=4.0,
attn_hidden_size=None, attn_num_heads=None, dtype=None, device=None, operations=None, mlp_chunks=1):
super().__init__()
self.pixel_dim = pixel_hidden_size
self.context_dim = patch_hidden_size
self.attn_dim = attn_hidden_size if attn_hidden_size is not None else patch_hidden_size
self.num_heads = attn_num_heads if attn_num_heads is not None else num_heads
assert self.attn_dim % self.num_heads == 0
p2 = patch_size * patch_size
self.compress_to_attn = operations.Linear(p2 * self.pixel_dim, self.attn_dim, bias=True, dtype=dtype, device=device)
self.expand_from_attn = operations.Linear(self.attn_dim, p2 * self.pixel_dim, bias=True, dtype=dtype, device=device)
self.norm1 = operations.RMSNorm(self.pixel_dim, eps=1e-6, dtype=dtype, device=device)
self.attn = RotaryAttention(self.attn_dim, num_heads=self.num_heads, qkv_bias=False, dtype=dtype, device=device, operations=operations)
self.norm2 = operations.RMSNorm(self.pixel_dim, eps=1e-6, dtype=dtype, device=device)
self.mlp = Mlp(self.pixel_dim, hidden_features=int(self.pixel_dim * mlp_ratio), dtype=dtype, device=device, operations=operations)
self.adaLN_modulation_msa = operations.Linear(self.context_dim, 3 * self.pixel_dim * p2, bias=True, dtype=dtype, device=device)
self.adaLN_modulation_mlp = operations.Linear(self.context_dim, 3 * self.pixel_dim * p2, bias=True, dtype=dtype, device=device)
self._rope_fn = precompute_freqs_cis_2d
self.mlp_chunks = max(1, int(mlp_chunks))
def _fetch_pos(self, height, width, device, dtype, **rope_opts):
return self._rope_fn(self.attn_dim // self.num_heads, height, width, device=device, dtype=dtype, **rope_opts)
def forward(self, x, s_cond, image_height, image_width, patch_size, mask=None, transformer_options={}):
BL, P2, _ = x.shape
Hs, Ws = image_height // patch_size, image_width // patch_size
L = Hs * Ws
B = BL // L
# Attention path uses only msa params; compute, use, free before mlp params allocate.
msa_params = self.adaLN_modulation_msa(s_cond).view(BL, P2, 3 * self.pixel_dim)
shift_msa, scale_msa, gate_msa = msa_params.chunk(3, dim=-1)
x_norm = apply_adaln_(self.norm1(x), shift_msa, scale_msa)
x_flat = x_norm.view(BL, P2 * self.pixel_dim)
x_comp = self.compress_to_attn(x_flat).view(B, L, self.attn_dim)
pos_comp = self._fetch_pos(Hs, Ws, x.device, x.dtype, **(transformer_options.get("rope_options") or {}))
attn_out = self.attn(x_comp, pos_comp, mask=mask, transformer_options=transformer_options)
attn_flat = self.expand_from_attn(attn_out.view(B * L, self.attn_dim))
attn_exp = attn_flat.view(BL, P2, self.pixel_dim)
x = torch.addcmul(x, gate_msa, attn_exp)
del msa_params, shift_msa, scale_msa, gate_msa
mlp_params = self.adaLN_modulation_mlp(s_cond).view(BL, P2, 3 * self.pixel_dim)
shift_mlp, scale_mlp, gate_mlp = mlp_params.chunk(3, dim=-1)
gate_mlp = gate_mlp.contiguous() # detach from mlp_params so the del below frees shift+scale storage before the MLP
mlp_input = apply_adaln_(self.norm2(x), shift_mlp, scale_mlp)
del mlp_params, shift_mlp, scale_mlp
# MLP in chunks since the peak memory usage is huge here
chunk_size = (BL + self.mlp_chunks - 1) // self.mlp_chunks
for s in range(0, BL, chunk_size):
e = min(s + chunk_size, BL)
x[s:e].addcmul_(gate_mlp[s:e], self.mlp(mlp_input[s:e]))
return x

226
comfy/ldm/pixeldit/pid.py Normal file
View File

@ -0,0 +1,226 @@
"""PiD — Pixel Diffusion Decoder. Decodes a Flux/SD3/Flux2/Z-Image latent
directly to a 4x-upscaled image in 4 distilled flow-matching steps. PixDiT_T2I
body + LQ projection branch injected before each MMDiT patch block.
"""
from typing import List
import torch
import torch.nn as nn
import torch.nn.functional as F
from .model import PixDiT_T2I
from .modules import precompute_freqs_cis_2d
class SigmaAwareGatePerTokenPerDim(nn.Module):
"""gate = sigmoid(content_proj(cat[x, lq]) - exp(log_alpha) * sigma); out = x + gate * lq.
Trained init gives ~0.88 gate at sigma=0, ~0.05 at sigma=1.
"""
def __init__(self, dim: int, dtype=None, device=None, operations=None):
super().__init__()
self.content_proj = operations.Linear(dim * 2, dim, dtype=dtype, device=device)
self.log_alpha = nn.Parameter(torch.empty((), dtype=dtype, device=device))
def forward(self, x: torch.Tensor, lq: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
content_logit = self.content_proj(torch.cat([x, lq], dim=-1))
# log_alpha is a raw nn.Parameter -> doesn't auto-cast under dynamic VRAM.
log_alpha = self.log_alpha.to(device=x.device, dtype=torch.float32)
sigma_offset = -log_alpha.exp() * sigma.float().view(-1, 1, 1)
gate = torch.sigmoid(content_logit + sigma_offset)
return x + (gate * lq).to(x.dtype)
class ResBlock(nn.Module):
"""Pre-activation ResNet block: GN -> SiLU -> Conv -> GN -> SiLU -> Conv + skip."""
def __init__(self, channels: int, num_groups: int = 4, dtype=None, device=None, operations=None):
super().__init__()
self.block = nn.Sequential(
operations.GroupNorm(num_groups, channels, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(channels, channels, kernel_size=3, padding=1, dtype=dtype, device=device),
operations.GroupNorm(num_groups, channels, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(channels, channels, kernel_size=3, padding=1, dtype=dtype, device=device),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x + self.block(x)
class LQProjection2D(nn.Module):
"""LQ latent -> per-block patch-aligned features for controlnet-style injection."""
def __init__(
self,
latent_channels: int,
hidden_dim: int = 512,
out_dim: int = 1536,
patch_size: int = 16,
sr_scale: int = 4,
latent_spatial_down_factor: int = 8,
num_res_blocks: int = 4,
num_outputs: int = 7,
interval: int = 2,
dtype=None, device=None, operations=None,
):
super().__init__()
self.latent_channels = latent_channels
self.hidden_dim = hidden_dim
self.out_dim = out_dim
self.patch_size = patch_size
self.sr_scale = sr_scale
self.latent_spatial_down_factor = latent_spatial_down_factor
self.num_outputs = num_outputs
self.interval = interval
z_to_patch_ratio = (sr_scale * latent_spatial_down_factor) / patch_size
self.z_to_patch_ratio = z_to_patch_ratio
if z_to_patch_ratio >= 1:
self.latent_fold_factor = 0
latent_proj_in_ch = latent_channels
else:
fold_factor = int(1 / z_to_patch_ratio)
assert fold_factor * z_to_patch_ratio == 1.0
self.latent_fold_factor = fold_factor
latent_proj_in_ch = latent_channels * fold_factor * fold_factor
layers = [
operations.Conv2d(latent_proj_in_ch, hidden_dim, kernel_size=3, padding=1, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(hidden_dim, hidden_dim, kernel_size=3, padding=1, dtype=dtype, device=device),
]
for _ in range(num_res_blocks):
layers.append(ResBlock(hidden_dim, dtype=dtype, device=device, operations=operations))
self.latent_proj = nn.Sequential(*layers)
self.output_heads = nn.ModuleList(
[operations.Linear(hidden_dim, out_dim, dtype=dtype, device=device) for _ in range(num_outputs)]
)
self.gate_modules = nn.ModuleList(
[SigmaAwareGatePerTokenPerDim(out_dim, dtype=dtype, device=device, operations=operations)
for _ in range(num_outputs)]
)
def is_gate_active(self, block_idx: int) -> bool:
return block_idx % self.interval == 0
def output_index(self, block_idx: int) -> int:
return block_idx // self.interval
def gate(self, x: torch.Tensor, lq_feature: torch.Tensor, sigma: torch.Tensor, out_idx: int) -> torch.Tensor:
return self.gate_modules[out_idx](x, lq_feature, sigma)
def _align_latent_to_patch_grid(self, lq_latent: torch.Tensor, pH: int, pW: int) -> torch.Tensor:
B, z_dim = lq_latent.shape[:2]
if self.z_to_patch_ratio >= 1:
if lq_latent.shape[2] != pH or lq_latent.shape[3] != pW:
z_aligned = F.interpolate(lq_latent, size=(pH, pW), mode="nearest")
else:
z_aligned = lq_latent
else:
f = self.latent_fold_factor
zH_expected, zW_expected = pH * f, pW * f
if lq_latent.shape[2] != zH_expected or lq_latent.shape[3] != zW_expected:
lq_latent = F.interpolate(lq_latent, size=(zH_expected, zW_expected), mode="nearest")
z_aligned = lq_latent.reshape(B, z_dim, pH, f, pW, f).permute(0, 1, 3, 5, 2, 4)
z_aligned = z_aligned.reshape(B, z_dim * f * f, pH, pW)
return self.latent_proj(z_aligned)
def forward(self, lq_latent: torch.Tensor, target_pH: int, target_pW: int) -> List[torch.Tensor]:
feat = self._align_latent_to_patch_grid(lq_latent, target_pH, target_pW)
B, C, H, W = feat.shape
tokens = feat.permute(0, 2, 3, 1).contiguous().view(B, H * W, C)
return [head(tokens) for head in self.output_heads]
class PidNet(PixDiT_T2I):
"""PixDiT_T2I + LQ injection (one sigma-gated feature inserted before each patch block)."""
def __init__(
self,
lq_latent_channels: int = 16,
lq_hidden_dim: int = 512,
lq_num_res_blocks: int = 4,
lq_interval: int = 2,
sr_scale: int = 4,
latent_spatial_down_factor: int = 8,
rope_ref_h: int = 1024, # NTK ref resolution in PIXEL units: 1024px / patch=16 -> grid_ref=64.
rope_ref_w: int = 1024,
image_model=None,
dtype=None, device=None, operations=None,
**pixdit_kwargs,
):
super().__init__(dtype=dtype, device=device, operations=operations, **pixdit_kwargs)
self.rope_ref_grid_h = rope_ref_h // self.patch_size
self.rope_ref_grid_w = rope_ref_w // self.patch_size
# Parent's PiTBlocks were built with plain RoPE — swap in NTK-aware.
def _pit_rope_fn(head_dim, h, w, device=None, dtype=torch.float32, **rope_opts):
return precompute_freqs_cis_2d(head_dim, h, w, ref_grid_h=self.rope_ref_grid_h, ref_grid_w=self.rope_ref_grid_w, device=device, dtype=dtype, **rope_opts)
for blk in self.pixel_blocks:
blk._rope_fn = _pit_rope_fn
num_lq_outputs = (self.patch_depth + lq_interval - 1) // lq_interval
self.lq_proj = LQProjection2D(
latent_channels=lq_latent_channels,
hidden_dim=lq_hidden_dim,
out_dim=self.hidden_size,
patch_size=self.patch_size,
sr_scale=sr_scale,
latent_spatial_down_factor=latent_spatial_down_factor,
num_res_blocks=lq_num_res_blocks,
num_outputs=num_lq_outputs,
interval=lq_interval,
dtype=dtype,
device=device,
operations=operations,
)
def _fetch_patch_pos(self, height, width, device, dtype, **rope_opts):
return precompute_freqs_cis_2d(
self.hidden_size // self.num_groups,
height, width,
ref_grid_h=self.rope_ref_grid_h, ref_grid_w=self.rope_ref_grid_w,
device=device, dtype=dtype, **rope_opts,
)
def _pre_patch_block(self, s, i, pid_lq_features, pid_degrade_sigma, **kwargs):
if not self.lq_proj.is_gate_active(i):
return s
out_idx = self.lq_proj.output_index(i)
if out_idx >= len(pid_lq_features):
return s
return self.lq_proj.gate(s, pid_lq_features[out_idx], pid_degrade_sigma, out_idx)
def _forward(self, x, timesteps, context=None, attention_mask=None, transformer_options={}, lq_latent=None, degrade_sigma=None, **kwargs):
if lq_latent is None:
raise ValueError("PidNet requires lq_latent — attach via PiDConditioning")
expected_c = self.lq_proj.latent_channels
if lq_latent.shape[1] != expected_c:
raise ValueError(
f"Input latent has {lq_latent.shape[1]} channels, this model variant expects {expected_c}. "
f"Flux1/SD3 = 16 channels, Flux2 = 128 channels."
)
B = x.shape[0]
Hs = x.shape[2] // self.patch_size
Ws = x.shape[3] // self.patch_size
degrade_sigma = degrade_sigma.to(device=x.device, dtype=torch.float32).reshape(-1)
if degrade_sigma.numel() == 1 and B > 1:
degrade_sigma = degrade_sigma.expand(B).contiguous()
lq_features = self.lq_proj(lq_latent=lq_latent.to(x), target_pH=Hs, target_pW=Ws)
return super()._forward(
x, timesteps,
context=context, attention_mask=attention_mask,
transformer_options=transformer_options,
pid_lq_features=lq_features,
pid_degrade_sigma=degrade_sigma,
**kwargs,
)

View File

@ -16,7 +16,6 @@
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
from __future__ import annotations
import comfy.memory_management
import comfy.utils
import comfy.model_management

View File

@ -1,6 +1,5 @@
import math
import ctypes
import threading
import dataclasses
import torch
from typing import NamedTuple
@ -10,7 +9,7 @@ from comfy.quant_ops import QuantizedTensor
class TensorFileSlice(NamedTuple):
file_ref: object
thread_id: int
lock: object
offset: int
size: int
@ -43,7 +42,6 @@ def read_tensor_file_slice_into(tensor, destination, stream=None, destination2=N
file_obj = info.file_ref
if (destination.device.type != "cpu"
or file_obj is None
or threading.get_ident() != info.thread_id
or destination.numel() * destination.element_size() < info.size
or tensor.numel() * tensor.element_size() != info.size
or tensor.storage_offset() != 0
@ -57,27 +55,29 @@ def read_tensor_file_slice_into(tensor, destination, stream=None, destination2=N
if hostbuf is not None:
stream_ptr = getattr(stream, "cuda_stream", 0) if stream is not None else 0
device_ptr = destination2.data_ptr() if destination2 is not None else 0
hostbuf.read_file_slice(file_obj, info.offset, info.size,
offset=destination.data_ptr() - hostbuf.get_raw_address(),
stream=stream_ptr,
device_ptr=device_ptr,
device=None if destination2 is None else destination2.device.index)
with info.lock:
hostbuf.read_file_slice(file_obj, info.offset, info.size,
offset=destination.data_ptr() - hostbuf.get_raw_address(),
stream=stream_ptr,
device_ptr=device_ptr,
device=None if destination2 is None else destination2.device.index)
return True
buf_type = ctypes.c_ubyte * info.size
view = memoryview(buf_type.from_address(destination.data_ptr()))
try:
file_obj.seek(info.offset)
done = 0
while done < info.size:
try:
n = file_obj.readinto(view[done:])
except OSError:
return False
if n <= 0:
return False
done += n
with info.lock:
file_obj.seek(info.offset)
done = 0
while done < info.size:
try:
n = file_obj.readinto(view[done:])
except OSError:
return False
if n <= 0:
return False
done += n
return True
finally:
view.release()

View File

@ -35,6 +35,7 @@ import comfy.ldm.hydit.models
import comfy.ldm.audio.dit
import comfy.ldm.audio.embedders
import comfy.ldm.flux.model
import comfy.ldm.lens.model
import comfy.ldm.lightricks.model
import comfy.ldm.hunyuan_video.model
import comfy.ldm.cosmos.model
@ -48,6 +49,8 @@ import comfy.ldm.hunyuan3d.model
import comfy.ldm.hidream.model
import comfy.ldm.chroma.model
import comfy.ldm.chroma_radiance.model
import comfy.ldm.pixeldit.model
import comfy.ldm.pixeldit.pid
import comfy.ldm.ace.model
import comfy.ldm.omnigen.omnigen2
import comfy.ldm.qwen_image.model
@ -1058,6 +1061,27 @@ class Flux2(Flux):
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
return out
class Lens(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
super().__init__(
model_config, model_type, device=device,
unet_model=comfy.ldm.lens.model.LensTransformer2DModel,
)
def encode_adm(self, **kwargs):
return None # Lens has no pooled/ADM conditioning.
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
attention_mask = kwargs.get("attention_mask", None)
if attention_mask is not None:
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
return out
class GenmoMochi(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.genmo.joint_model.asymm_models_joint.AsymmDiTJoint)
@ -1375,6 +1399,36 @@ class ZImagePixelSpace(Lumina2):
BaseModel.__init__(self, model_config, model_type, device=device, unet_model=comfy.ldm.lumina.model.NextDiTPixelSpace)
self.memory_usage_factor_conds = ("ref_latents",)
class PixelDiTT2I(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device,
unet_model=comfy.ldm.pixeldit.model.PixDiT_T2I)
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
attention_mask = kwargs.get("attention_mask", None)
if attention_mask is not None:
out["attention_mask"] = comfy.conds.CONDRegular(attention_mask)
return out
class PiD(PixelDiTT2I):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
BaseModel.__init__(self, model_config, model_type, device=device,
unet_model=comfy.ldm.pixeldit.pid.PidNet)
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
lq_latent = kwargs.get("lq_latent", None)
if lq_latent is not None:
out["lq_latent"] = comfy.conds.CONDRegular(lq_latent)
degrade_sigma = kwargs.get("degrade_sigma", None)
if degrade_sigma is not None:
out["degrade_sigma"] = comfy.conds.CONDRegular(degrade_sigma)
return out
class WAN21(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)

View File

@ -463,6 +463,23 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["extra_per_block_abs_pos_emb_type"] = "learnable"
return dit_config
# PiD (Pixel Diffusion Decoder). Must check BEFORE plain PixelDiT_T2I.
_lq_w_key = '{}lq_proj.latent_proj.0.weight'.format(key_prefix)
if _lq_w_key in state_dict_keys:
in_ch = int(state_dict[_lq_w_key].shape[1])
_gate_prefix = '{}lq_proj.gate_modules.'.format(key_prefix)
num_gates = len({k[len(_gate_prefix):].split('.')[0]
for k in state_dict_keys if k.startswith(_gate_prefix)})
dit_config = {"image_model": "pid",
"lq_latent_channels": in_ch,
"latent_spatial_down_factor": 16 if in_ch >= 64 else 8}
if num_gates > 0:
dit_config["lq_interval"] = (14 + num_gates - 1) // num_gates
return dit_config
if '{}core.pixel_embedder.proj.weight'.format(key_prefix) in state_dict_keys: # PixelDiT T2I
return {"image_model": "pixeldit_t2i"}
if '{}cap_embedder.1.weight'.format(key_prefix) in state_dict_keys and '{}noise_refiner.0.attention.k_norm.weight'.format(key_prefix) in state_dict_keys: # Lumina 2
dit_config = {}
dit_config["image_model"] = "lumina2"
@ -755,6 +772,30 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["timestep_scale"] = 1000.0
return dit_config
if '{}transformer_blocks.0.attn.norm_added_q.weight'.format(key_prefix) in state_dict_keys \
and '{}transformer_blocks.0.img_mlp.w1.weight'.format(key_prefix) in state_dict_keys: # Lens
img_in_w = state_dict['{}img_in.weight'.format(key_prefix)]
proj_out_w = state_dict['{}proj_out.weight'.format(key_prefix)]
multi_layer = '{}txt_norm.0.weight'.format(key_prefix) in state_dict_keys
if multi_layer:
enc_hidden_dim = state_dict['{}txt_norm.0.weight'.format(key_prefix)].shape[0]
# Indices are TE-side; the DiT just consumes L layers in order.
selected_layer_index = tuple(range(count_blocks(state_dict_keys, '{}txt_norm.'.format(key_prefix) + '{}.')))
else:
enc_hidden_dim = state_dict['{}txt_norm.weight'.format(key_prefix)].shape[0]
selected_layer_index = (0,)
return {
"image_model": "lens",
"in_channels": img_in_w.shape[1],
"out_channels": proj_out_w.shape[0] // 4, # patch_size ** 2 (=2² default)
"num_layers": count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.'),
"num_attention_heads": img_in_w.shape[0] // 64, # // attention_head_dim default
"enc_hidden_dim": enc_hidden_dim,
"multi_layer_encoder_feature": multi_layer,
"selected_layer_index": selected_layer_index,
}
if '{}txt_norm.weight'.format(key_prefix) in state_dict_keys: # Qwen Image
dit_config = {}
dit_config["image_model"] = "qwen_image"

View File

@ -15,6 +15,7 @@
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
from __future__ import annotations
import psutil
import logging
@ -27,13 +28,18 @@ import platform
import weakref
import gc
import os
from contextlib import nullcontext
from contextlib import contextmanager, nullcontext
import comfy.memory_management
import comfy.utils
import comfy.quant_ops
import comfy_aimdo.host_buffer
import comfy_aimdo.vram_buffer
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from comfy.model_patcher import ModelPatcher
class VRAMState(Enum):
DISABLED = 0 #No vram present: no need to move models to vram
NO_VRAM = 1 #Very low vram: enable all the options to save vram
@ -204,6 +210,107 @@ def get_torch_device():
else:
return torch.device(torch.cuda.current_device())
def get_all_torch_devices(exclude_current=False):
global cpu_state
devices = []
if cpu_state == CPUState.GPU:
# NVIDIA + AMD/ROCm both expose their GPUs through torch.cuda.*;
# without the AMD arm, single-GPU ROCm users get an empty list
# which silently turns unload_all_models() into a no-op.
if is_nvidia() or is_amd():
for i in range(torch.cuda.device_count()):
devices.append(torch.device("cuda", i))
elif is_intel_xpu():
for i in range(torch.xpu.device_count()):
devices.append(torch.device("xpu", i))
elif is_ascend_npu():
for i in range(torch.npu.device_count()):
devices.append(torch.device("npu", i))
elif is_mlu():
for i in range(torch.mlu.device_count()):
devices.append(torch.device("mlu", i))
else:
# Fallback for unhandled GPU backends (e.g. DirectML): at least
# report the current device so callers like unload_all_models()
# do not silently no-op.
devices.append(get_torch_device())
else:
devices.append(get_torch_device())
if exclude_current:
current = get_torch_device()
if current in devices:
devices.remove(current)
return devices
def get_gpu_device_options():
"""Return list of device option strings for node widgets.
Always includes "default" and "cpu". When multiple GPUs are present,
adds "gpu:0", "gpu:1", etc. (vendor-agnostic labels).
"""
options = ["default", "cpu"]
devices = get_all_torch_devices()
if len(devices) > 1:
for i in range(len(devices)):
options.append(f"gpu:{i}")
return options
def get_gpu_device_options_no_cpu():
"""Variant of get_gpu_device_options that omits "cpu".
Intended for components like the VAE selector where running on CPU
is impractical and should not be offered as a choice.
"""
return [o for o in get_gpu_device_options() if o != "cpu"]
def resolve_gpu_device_option(option: str):
"""Resolve a device option string to a torch.device.
Returns None for "default" (let the caller use its normal default).
Returns torch.device("cpu") for "cpu".
For "gpu:N", returns the Nth torch device. Returns None if the
index is out of range, the option string is malformed, or
unrecognized (callers are expected to log their own context-rich
message before falling back to the default device).
"""
if option is None or option == "default":
return None
if option == "cpu":
return torch.device("cpu")
if option.startswith("gpu:"):
try:
idx = int(option[4:])
except ValueError:
return None
devices = get_all_torch_devices()
if 0 <= idx < len(devices):
return devices[idx]
return None
@contextmanager
def cuda_device_context(device):
"""Context manager that sets torch.cuda.current_device to match *device*.
Used when running operations on a non-default CUDA device so that custom
CUDA kernels (e.g. comfy_kitchen fp8 quantization) pick up the correct
device index. The previous device is restored on exit.
No-op when *device* is not CUDA, has no explicit index, or already matches
the current device.
"""
prev = None
if device.type == "cuda" and device.index is not None:
prev = torch.cuda.current_device()
if prev != device.index:
torch.cuda.set_device(device)
else:
prev = None
try:
yield
finally:
if prev is not None:
torch.cuda.set_device(prev)
def get_total_memory(dev=None, torch_total_too=False):
global directml_enabled
if dev is None:
@ -492,9 +599,13 @@ try:
logging.info("Device: {}".format(get_torch_device_name(get_torch_device())))
except:
logging.warning("Could not pick default device.")
try:
for device in get_all_torch_devices(exclude_current=True):
logging.info("Device: {}".format(get_torch_device_name(device)))
except:
pass
current_loaded_models = []
current_loaded_models: list[LoadedModel] = []
DIRTY_MMAPS = set()
@ -554,7 +665,7 @@ def ensure_pin_registerable(size, evict_active=False):
return shortfall <= REGISTERABLE_PIN_HYSTERESIS
class LoadedModel:
def __init__(self, model):
def __init__(self, model: ModelPatcher):
self._set_model(model)
self.device = model.load_device
self.real_model = None
@ -562,7 +673,7 @@ class LoadedModel:
self.model_finalizer = None
self._patcher_finalizer = None
def _set_model(self, model):
def _set_model(self, model: ModelPatcher):
self._model = weakref.ref(model)
if model.parent is not None:
self._parent_model = weakref.ref(model.parent)
@ -573,6 +684,7 @@ class LoadedModel:
model = self._parent_model()
if model is not None:
self._set_model(model)
self.device = model.load_device
@property
def model(self):
@ -1848,7 +1960,34 @@ def soft_empty_cache(force=False):
torch.cuda.ipc_collect()
def unload_all_models():
free_memory(1e30, get_torch_device())
for device in get_all_torch_devices():
free_memory(1e30, device)
def unload_model_and_clones(model: ModelPatcher, unload_additional_models=True, all_devices=False):
'Unload only model and its clones - primarily for multigpu cloning purposes.'
initial_keep_loaded: list[LoadedModel] = current_loaded_models.copy()
additional_models = []
if unload_additional_models:
additional_models = model.get_nested_additional_models()
keep_loaded = []
for loaded_model in initial_keep_loaded:
if loaded_model.model is not None:
if model.clone_base_uuid == loaded_model.model.clone_base_uuid:
continue
# check additional models if they are a match
skip = False
for add_model in additional_models:
if add_model.clone_base_uuid == loaded_model.model.clone_base_uuid:
skip = True
break
if skip:
continue
keep_loaded.append(loaded_model)
if not all_devices:
free_memory(1e30, get_torch_device(), keep_loaded)
else:
for device in get_all_torch_devices():
free_memory(1e30, device, keep_loaded)
def debug_memory_summary():
if is_amd() or is_nvidia():

View File

@ -78,12 +78,15 @@ def set_model_options_pre_cfg_function(model_options, pre_cfg_function, disable_
def create_model_options_clone(orig_model_options: dict):
return comfy.patcher_extension.copy_nested_dicts(orig_model_options)
def create_hook_patches_clone(orig_hook_patches):
def create_hook_patches_clone(orig_hook_patches, copy_tuples=False):
new_hook_patches = {}
for hook_ref in orig_hook_patches:
new_hook_patches[hook_ref] = {}
for k in orig_hook_patches[hook_ref]:
new_hook_patches[hook_ref][k] = orig_hook_patches[hook_ref][k][:]
if copy_tuples:
for i in range(len(new_hook_patches[hook_ref][k])):
new_hook_patches[hook_ref][k][i] = tuple(new_hook_patches[hook_ref][k][i])
return new_hook_patches
def wipe_lowvram_weight(m):
@ -329,7 +332,10 @@ class ModelPatcher:
self.is_clip = False
self.hook_mode = comfy.hooks.EnumHookMode.MaxSpeed
self.cached_patcher_init: tuple[Callable, tuple] | None = None
self.cached_patcher_init: tuple[Callable, tuple] | tuple[Callable, tuple, int] | None = None
self.is_multigpu_base_clone = False
self.clone_base_uuid = uuid.uuid4()
if not hasattr(self.model, 'model_loaded_weight_memory'):
self.model.model_loaded_weight_memory = 0
@ -366,7 +372,8 @@ class ModelPatcher:
#than pays for CFG. So return everything both torch and Aimdo could give us
aimdo_mem = 0
if comfy.memory_management.aimdo_enabled:
aimdo_mem = comfy_aimdo.model_vbar.vbars_analyze()
aimdo_device = device.index if getattr(device, "type", None) == "cuda" else None
aimdo_mem = comfy_aimdo.model_vbar.vbars_analyze(aimdo_device)
return comfy.model_management.get_free_memory(device) + aimdo_mem
def get_clone_model_override(self):
@ -380,6 +387,8 @@ class ModelPatcher:
if self.cached_patcher_init is None:
raise RuntimeError("Cannot create non-dynamic delegate: cached_patcher_init is not initialized.")
temp_model_patcher = self.cached_patcher_init[0](*self.cached_patcher_init[1], disable_dynamic=True)
if len(self.cached_patcher_init) > 2:
temp_model_patcher = temp_model_patcher[self.cached_patcher_init[2]]
model_override = temp_model_patcher.get_clone_model_override()
if model_override is None:
model_override = self.get_clone_model_override()
@ -438,19 +447,113 @@ class ModelPatcher:
n.hook_mode = self.hook_mode
n.cached_patcher_init = self.cached_patcher_init
n.is_multigpu_base_clone = self.is_multigpu_base_clone
n.clone_base_uuid = self.clone_base_uuid
for callback in self.get_all_callbacks(CallbacksMP.ON_CLONE):
callback(self, n)
return n
def deepclone_multigpu(self, new_load_device=None, models_cache: dict[uuid.UUID,ModelPatcher]=None):
logging.info(f"Creating deepclone of {self.model.__class__.__name__} for {new_load_device if new_load_device else self.load_device}.")
if self.cached_patcher_init is None:
raise RuntimeError(
f"Cannot create multigpu deepclone of {self.model.__class__.__name__}: "
"the loader that produced this model does not support multigpu "
"(cached_patcher_init is not initialized). Use a core loader "
"(CheckpointLoaderSimple, UNETLoader, CLIPLoader/DualCLIPLoader, VAELoader), "
"or have the custom loader register a cached_patcher_init factory."
)
comfy.model_management.unload_model_and_clones(self)
# Produce a freshly-loaded patcher from the loader factory so the multigpu
# clone owns its own untainted model weights (rather than relying on
# copy.deepcopy of an already-patched/already-loaded module).
temp_model_patcher: ModelPatcher | list[ModelPatcher] = self.cached_patcher_init[0](*self.cached_patcher_init[1])
if len(self.cached_patcher_init) > 2:
temp_model_patcher = temp_model_patcher[self.cached_patcher_init[2]]
# Override clone()'s normal "share self.model + share backup containers" with
# the pristine model from temp_model_patcher plus empty backup containers --
# the fresh model has no patches applied, so any deepcopy of self's stale
# backup/object_patches_backup/pinned would just propagate dead state that
# no longer corresponds to anything in n.model.
model_override = (temp_model_patcher.model, ({}, {}, {}, set()))
n = self.clone(model_override=model_override)
# clone() copies hook_backup by reference from self; reset since model is pristine.
n.hook_backup = {}
# set load device, if present
if new_load_device is not None:
n.load_device = new_load_device
# Ensure any per-device bookkeeping (e.g. ModelPatcherDynamic.dynamic_pins)
# has an entry for n.load_device on the freshly-loaded n.model. temp_model_patcher's
# __init__ only registered its own (default) load_device.
if hasattr(n, "register_load_device"):
n.register_load_device(n.load_device)
# multigpu clone should not have multigpu additional_models entry
n.remove_additional_models("multigpu")
# multigpu_clone all stored additional_models; make sure circular references are properly handled
if models_cache is None:
models_cache = {}
for key, model_list in n.additional_models.items():
for i in range(len(model_list)):
add_model = n.additional_models[key][i]
if add_model.clone_base_uuid not in models_cache:
models_cache[add_model.clone_base_uuid] = add_model.deepclone_multigpu(new_load_device=new_load_device, models_cache=models_cache)
n.additional_models[key][i] = models_cache[add_model.clone_base_uuid]
for callback in self.get_all_callbacks(CallbacksMP.ON_DEEPCLONE_MULTIGPU):
callback(self, n)
return n
def match_multigpu_clones(self):
multigpu_models = self.get_additional_models_with_key("multigpu")
if len(multigpu_models) > 0:
new_multigpu_models = []
for mm in multigpu_models:
# clone main model, but bring over relevant props from existing multigpu clone
n = self.clone()
n.load_device = mm.load_device
n.backup = mm.backup
n.object_patches_backup = mm.object_patches_backup
n.hook_backup = mm.hook_backup
n.model = mm.model
n.is_multigpu_base_clone = mm.is_multigpu_base_clone
n.remove_additional_models("multigpu")
orig_additional_models: dict[str, list[ModelPatcher]] = comfy.patcher_extension.copy_nested_dicts(n.additional_models)
n.additional_models = comfy.patcher_extension.copy_nested_dicts(mm.additional_models)
# figure out which additional models are not present in multigpu clone
models_cache = {}
for mm_add_model in mm.get_additional_models():
models_cache[mm_add_model.clone_base_uuid] = mm_add_model
remove_models_uuids = set(list(models_cache.keys()))
for key, model_list in orig_additional_models.items():
for orig_add_model in model_list:
if orig_add_model.clone_base_uuid not in models_cache:
models_cache[orig_add_model.clone_base_uuid] = orig_add_model.deepclone_multigpu(new_load_device=n.load_device, models_cache=models_cache)
existing_list = n.get_additional_models_with_key(key)
existing_list.append(models_cache[orig_add_model.clone_base_uuid])
n.set_additional_models(key, existing_list)
if orig_add_model.clone_base_uuid in remove_models_uuids:
remove_models_uuids.remove(orig_add_model.clone_base_uuid)
# remove duplicate additional models
for key, model_list in n.additional_models.items():
new_model_list = [x for x in model_list if x.clone_base_uuid not in remove_models_uuids]
n.set_additional_models(key, new_model_list)
for callback in self.get_all_callbacks(CallbacksMP.ON_MATCH_MULTIGPU_CLONES):
callback(self, n)
new_multigpu_models.append(n)
self.set_additional_models("multigpu", new_multigpu_models)
def is_clone(self, other):
if hasattr(other, 'model') and self.model is other.model:
return True
return False
def clone_has_same_weights(self, clone: 'ModelPatcher'):
if not self.is_clone(clone):
return False
def clone_has_same_weights(self, clone: ModelPatcher, allow_multigpu=False):
if allow_multigpu:
if self.clone_base_uuid != clone.clone_base_uuid:
return False
else:
if not self.is_clone(clone):
return False
if self.current_hooks != clone.current_hooks:
return False
@ -1232,7 +1335,7 @@ class ModelPatcher:
return self.additional_models.get(key, [])
def get_additional_models(self):
all_models = []
all_models: list[ModelPatcher] = []
for models in self.additional_models.values():
all_models.extend(models)
return all_models
@ -1286,9 +1389,18 @@ class ModelPatcher:
for callback in self.get_all_callbacks(CallbacksMP.ON_PRE_RUN):
callback(self)
def prepare_state(self, timestep):
def prepare_state(self, timestep, model_options):
ignore_multigpu = model_options.get("ignore_multigpu", False)
for callback in self.get_all_callbacks(CallbacksMP.ON_PREPARE_STATE):
callback(self, timestep)
callback(self, timestep, model_options)
if not ignore_multigpu and "multigpu_clones" in model_options:
model_options["ignore_multigpu"] = True
try:
for p in model_options["multigpu_clones"].values():
p: ModelPatcher
p.prepare_state(timestep, model_options)
finally:
model_options.pop("ignore_multigpu", None)
def restore_hook_patches(self):
if self.hook_patches_backup is not None:
@ -1301,12 +1413,18 @@ class ModelPatcher:
def prepare_hook_patches_current_keyframe(self, t: torch.Tensor, hook_group: comfy.hooks.HookGroup, model_options: dict[str]):
curr_t = t[0]
reset_current_hooks = False
multigpu_kf_changed_cache = None
transformer_options = model_options.get("transformer_options", {})
for hook in hook_group.hooks:
changed = hook.hook_keyframe.prepare_current_keyframe(curr_t=curr_t, transformer_options=transformer_options)
# if keyframe changed, remove any cached HookGroups that contain hook with the same hook_ref;
# this will cause the weights to be recalculated when sampling
if changed:
# cache changed for multigpu usage
if "multigpu_clones" in model_options:
if multigpu_kf_changed_cache is None:
multigpu_kf_changed_cache = []
multigpu_kf_changed_cache.append(hook)
# reset current_hooks if contains hook that changed
if self.current_hooks is not None:
for current_hook in self.current_hooks.hooks:
@ -1318,6 +1436,28 @@ class ModelPatcher:
self.cached_hook_patches.pop(cached_group)
if reset_current_hooks:
self.patch_hooks(None)
if "multigpu_clones" in model_options:
for p in model_options["multigpu_clones"].values():
p: ModelPatcher
p._handle_changed_hook_keyframes(multigpu_kf_changed_cache)
def _handle_changed_hook_keyframes(self, kf_changed_cache: list[comfy.hooks.Hook]):
'Used to handle multigpu behavior inside prepare_hook_patches_current_keyframe.'
if kf_changed_cache is None:
return
reset_current_hooks = False
# reset current_hooks if contains hook that changed
for hook in kf_changed_cache:
if self.current_hooks is not None:
for current_hook in self.current_hooks.hooks:
if current_hook == hook:
reset_current_hooks = True
break
for cached_group in list(self.cached_hook_patches.keys()):
if cached_group.contains(hook):
self.cached_hook_patches.pop(cached_group)
if reset_current_hooks:
self.patch_hooks(None)
def register_all_hook_patches(self, hooks: comfy.hooks.HookGroup, target_dict: dict[str], model_options: dict=None,
registered: comfy.hooks.HookGroup = None):
@ -1566,16 +1706,27 @@ class ModelPatcherDynamic(ModelPatcher):
self.model.dynamic_vbars = {}
if not hasattr(self.model, "dynamic_pins"):
self.model.dynamic_pins = {}
if self.load_device not in self.model.dynamic_pins:
self.model.dynamic_pins[self.load_device] = {
self.register_load_device(self.load_device)
self.non_dynamic_delegate_model = None
assert load_device is not None
def register_load_device(self, device):
"""Ensure dynamic_pins has an entry for *device*.
Called from __init__ and also from any code that retargets an
already-constructed patcher to a new load_device (e.g. the
Select{Model,CLIP,VAE}Device selector nodes); without this entry
partially_unload_ram() raises KeyError when it tries to read the
per-device pin state.
"""
if device not in self.model.dynamic_pins:
self.model.dynamic_pins[device] = {
"weights": (comfy_aimdo.host_buffer.HostBuffer(0, 0, 0), [], [-1], [0]),
"patches": (comfy_aimdo.host_buffer.HostBuffer(0, 0, 0), [], [-1], [0]),
"hostbufs_initialized": False,
"failed": False,
"active": False,
}
self.non_dynamic_delegate_model = None
assert load_device is not None
def is_dynamic(self):
return True

248
comfy/multigpu.py Normal file
View File

@ -0,0 +1,248 @@
from __future__ import annotations
import queue
import threading
import torch
import logging
from collections import namedtuple
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from comfy.model_patcher import ModelPatcher
import comfy.utils
import comfy.patcher_extension
import comfy.model_management
class MultiGPUThreadPool:
"""Persistent thread pool for multi-GPU work distribution.
Maintains one worker thread per extra GPU device. Each thread calls
torch.cuda.set_device() once at startup so that compiled kernel caches
(inductor/triton) stay warm across diffusion steps.
"""
def __init__(self, devices: list[torch.device]):
self._workers: list[threading.Thread] = []
self._work_queues: dict[torch.device, queue.Queue] = {}
self._result_queues: dict[torch.device, queue.Queue] = {}
for device in devices:
wq = queue.Queue()
rq = queue.Queue()
self._work_queues[device] = wq
self._result_queues[device] = rq
t = threading.Thread(target=self._worker_loop, args=(device, wq, rq), daemon=True)
t.start()
self._workers.append(t)
def _worker_loop(self, device: torch.device, work_q: queue.Queue, result_q: queue.Queue):
try:
torch.cuda.set_device(device)
except Exception as e:
logging.error(f"MultiGPUThreadPool: failed to set device {device}: {e}")
while True:
item = work_q.get()
if item is None:
return
result_q.put((None, e))
return
while True:
item = work_q.get()
if item is None:
break
fn, args, kwargs = item
try:
result = fn(*args, **kwargs)
result_q.put((result, None))
except Exception as e:
result_q.put((None, e))
def submit(self, device: torch.device, fn, *args, **kwargs):
self._work_queues[device].put((fn, args, kwargs))
def get_result(self, device: torch.device):
return self._result_queues[device].get()
@property
def devices(self) -> list[torch.device]:
return list(self._work_queues.keys())
def shutdown(self):
for wq in self._work_queues.values():
wq.put(None) # sentinel
for t in self._workers:
t.join(timeout=5.0)
class GPUOptions:
def __init__(self, device_index: int, relative_speed: float):
self.device_index = device_index
self.relative_speed = relative_speed
def clone(self):
return GPUOptions(self.device_index, self.relative_speed)
def create_dict(self):
return {
"relative_speed": self.relative_speed
}
class GPUOptionsGroup:
def __init__(self):
self.options: dict[int, GPUOptions] = {}
def add(self, info: GPUOptions):
self.options[info.device_index] = info
def clone(self):
c = GPUOptionsGroup()
for opt in self.options.values():
c.add(opt)
return c
def register(self, model: ModelPatcher):
opts_dict = {}
# get devices that are valid for this model
devices: list[torch.device] = [model.load_device]
for extra_model in model.get_additional_models_with_key("multigpu"):
extra_model: ModelPatcher
devices.append(extra_model.load_device)
# create dictionary with actual device mapped to its GPUOptions
device_opts_list: list[GPUOptions] = []
for device in devices:
device_opts = self.options.get(device.index, GPUOptions(device_index=device.index, relative_speed=1.0))
opts_dict[device] = device_opts.create_dict()
device_opts_list.append(device_opts)
# make relative_speed relative to 1.0
min_speed = min([x.relative_speed for x in device_opts_list])
for value in opts_dict.values():
value['relative_speed'] /= min_speed
model.model_options['multigpu_options'] = opts_dict
def create_multigpu_deepclones(model: ModelPatcher, max_gpus: int, gpu_options: GPUOptionsGroup=None, reuse_loaded=False):
'Prepare ModelPatcher to contain deepclones of its BaseModel and related properties.'
model = model.clone()
# check if multigpu is already prepared - get the load devices from them if possible to exclude
skip_devices = set()
multigpu_models = model.get_additional_models_with_key("multigpu")
if len(multigpu_models) > 0:
for mm in multigpu_models:
skip_devices.add(mm.load_device)
skip_devices = list(skip_devices)
# Exclude the primary model's actual device, not the global current device:
# after SelectModelDevice(gpu:N) the primary may not live on the process's
# current CUDA device, and excluding the wrong device picks bad extras.
all_devices = comfy.model_management.get_all_torch_devices(exclude_current=False)
full_extra_devices = [d for d in all_devices if d != model.load_device]
limit_extra_devices = full_extra_devices[:max_gpus-1]
extra_devices = limit_extra_devices.copy()
# exclude skipped devices
for skip in skip_devices:
if skip in extra_devices:
extra_devices.remove(skip)
# create new deepclones
if len(extra_devices) > 0:
for device in extra_devices:
device_patcher = None
if reuse_loaded:
# Only reuse a previously-loaded MultiGPU clone. A SelectModelDevice
# patcher on the same device shares clone_base_uuid but has
# is_multigpu_base_clone=False, which would later be filtered out by
# prepare_model_patcher_multigpu_clones() and silently shrink the
# work split back to one GPU.
loaded_models: list[ModelPatcher] = comfy.model_management.loaded_models()
for lm in loaded_models:
if lm.model is None:
continue
if lm.load_device != device:
continue
if lm.clone_base_uuid != model.clone_base_uuid:
continue
if not getattr(lm, "is_multigpu_base_clone", False):
continue
device_patcher = lm.clone()
logging.info(f"Reusing loaded multigpu deepclone of {device_patcher.model.__class__.__name__} for {device}")
break
if device_patcher is None:
device_patcher = model.deepclone_multigpu(new_load_device=device)
# Always flag the clone; whether reused or freshly deepcloned, it must
# advertise itself as a MultiGPU base clone so the cond scheduler picks
# it up in prepare_model_patcher_multigpu_clones().
device_patcher.is_multigpu_base_clone = True
multigpu_models = model.get_additional_models_with_key("multigpu")
multigpu_models.append(device_patcher)
model.set_additional_models("multigpu", multigpu_models)
model.match_multigpu_clones()
if gpu_options is None:
gpu_options = GPUOptionsGroup()
gpu_options.register(model)
else:
logging.info("No extra torch devices need initialization, skipping initializing MultiGPU Work Units.")
# only keep model clones that don't go 'past' the intended max_gpu count;
# this prunes any inherited multigpu clones whose load_device is no longer allowed
# when max_gpus is lowered between runs.
allowed_devices = set(limit_extra_devices)
allowed_devices.add(model.load_device)
multigpu_models = model.get_additional_models_with_key("multigpu")
new_multigpu_models = [m for m in multigpu_models if m.load_device in allowed_devices]
if len(new_multigpu_models) != len(multigpu_models):
model.set_additional_models("multigpu", new_multigpu_models)
model.match_multigpu_clones()
return model
LoadBalance = namedtuple('LoadBalance', ['work_per_device', 'idle_time'])
def load_balance_devices(model_options: dict[str], total_work: int, return_idle_time=False, work_normalized: int=None):
'Optimize work assigned to different devices, accounting for their relative speeds and splittable work.'
opts_dict = model_options['multigpu_options']
devices = list(model_options['multigpu_clones'].keys())
speed_per_device = []
work_per_device = []
# get sum of each device's relative_speed
total_speed = 0.0
for opts in opts_dict.values():
total_speed += opts['relative_speed']
# get relative work for each device;
# obtained by w = (W*r)/R
for device in devices:
relative_speed = opts_dict[device]['relative_speed']
relative_work = (total_work*relative_speed) / total_speed
speed_per_device.append(relative_speed)
work_per_device.append(relative_work)
# relative work must be expressed in whole numbers, but likely is a decimal;
# perform rounding while maintaining total sum equal to total work (sum of relative works)
work_per_device = round_preserved(work_per_device)
dict_work_per_device = {}
for device, relative_work in zip(devices, work_per_device):
dict_work_per_device[device] = relative_work
if not return_idle_time:
return LoadBalance(dict_work_per_device, None)
# divide relative work by relative speed to get estimated completion time of said work by each device;
# time here is relative and does not correspond to real-world units
completion_time = [w/r for w,r in zip(work_per_device, speed_per_device)]
# calculate relative time spent by the devices waiting on each other after their work is completed
idle_time = abs(min(completion_time) - max(completion_time))
# if need to compare work idle time, need to normalize to a common total work
if work_normalized:
idle_time *= (work_normalized/total_work)
return LoadBalance(dict_work_per_device, idle_time)
def round_preserved(values: list[float]):
'Round all values in a list, preserving the combined sum of values.'
# get floor of values; casting to int does it too
floored = [int(x) for x in values]
total_floored = sum(floored)
# get remainder to distribute
remainder = round(sum(values)) - total_floored
# pair values with fractional portions
fractional = [(i, x-floored[i]) for i, x in enumerate(values)]
# sort by fractional part in descending order
fractional.sort(key=lambda x: x[1], reverse=True)
# distribute the remainder
for i in range(remainder):
index = fractional[i][0]
floored[index] += 1
return floored

View File

@ -18,6 +18,7 @@
import torch
import logging
import contextlib
import comfy.model_management
from comfy.cli_args import args, PerformanceFeature
import comfy.float
@ -1047,6 +1048,144 @@ class QuantLinearFunc(torch.autograd.Function):
return grad_input, grad_weight, grad_bias, None, None, None
# Quantized-weight module helpers
def _quantized_apply(module, fn, recurse=True):
"""Re-wrap Parameters after fn so .to()/.cuda() propagate through QuantizedTensor weights."""
if recurse:
for child in module.children():
child._apply(fn)
for key, param in module._parameters.items():
if param is None:
continue
p = fn(param)
if (not torch.is_inference_mode_enabled()) and p.is_inference():
p = p.clone()
module.register_parameter(key, torch.nn.Parameter(p, requires_grad=False))
for key, buf in module._buffers.items():
if buf is not None:
module._buffers[key] = fn(buf)
return module
def _load_quantized_module(module, super_load, state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs, load_extra_params=False):
"""Shared _load_from_state_dict body for quantized-weight modules.
Pops weight (+ scales, +/- extras), populates module.weight as a Parameter
or Parameter-wrapped QuantizedTensor, then calls super_load and strips
consumed keys from missing_keys. Reads compute_dtype from factory_kwargs
and disabled formats from module._disabled_formats.
"""
device = module.factory_kwargs["device"]
compute_dtype = module.factory_kwargs["dtype"]
disabled_formats = module._disabled_formats
layer_name = prefix.rstrip('.')
weight = state_dict.pop(f"{prefix}weight", None)
if weight is None:
logging.warning(f"Missing weight for layer {layer_name}")
module.weight = None
return
manually_loaded_keys = [f"{prefix}weight"]
def pop_scale(name, dtype=None):
key = f"{prefix}{name}"
v = state_dict.pop(key, None)
if v is not None:
v = v.to(device=device)
if dtype is not None:
v = v.view(dtype=dtype)
manually_loaded_keys.append(key)
return v
layer_conf = state_dict.pop(f"{prefix}comfy_quant", None)
if layer_conf is not None:
layer_conf = json.loads(layer_conf.numpy().tobytes())
if layer_conf is None:
module.weight = torch.nn.Parameter(weight.to(device=device, dtype=compute_dtype), requires_grad=False)
else:
module.quant_format = layer_conf.get("format", None)
module._full_precision_mm_config = layer_conf.get("full_precision_matrix_mult", False)
if not module._full_precision_mm:
module._full_precision_mm = module._full_precision_mm_config
if module.quant_format in disabled_formats:
module._full_precision_mm = True
if module.quant_format is None:
raise ValueError(f"Unknown quantization format for layer {layer_name}")
qconfig = QUANT_ALGOS[module.quant_format]
module.layout_type = qconfig["comfy_tensor_layout"]
layout_cls = get_layout_class(module.layout_type)
# Per-format scales; fp8 dtype views handle both legacy uint8-on-disk and native fp8.
if module.quant_format in ("float8_e4m3fn", "float8_e5m2"):
scales = {"scale": pop_scale("weight_scale")}
elif module.quant_format == "mxfp8":
bs = pop_scale("weight_scale", torch.float8_e8m0fnu)
if bs is None:
raise ValueError(f"Missing MXFP8 block scales for layer {layer_name}")
scales = {"scale": bs}
elif module.quant_format == "nvfp4":
ts = pop_scale("weight_scale_2")
bs = pop_scale("weight_scale", torch.float8_e4m3fn)
if ts is None or bs is None:
raise ValueError(f"Missing NVFP4 scales for layer {layer_name}")
scales = {"scale": ts, "block_scale": bs}
else:
raise ValueError(f"Unsupported quantization format: {module.quant_format}")
params = layout_cls.Params(**scales, orig_dtype=compute_dtype, orig_shape=module._orig_shape)
module.weight = torch.nn.Parameter(
QuantizedTensor(weight.to(device=device, dtype=qconfig["storage_t"]), module.layout_type, params),
requires_grad=False,
)
if load_extra_params:
for param_name in qconfig["parameters"]:
if param_name in {"weight_scale", "weight_scale_2"}:
continue
param_key = f"{prefix}{param_name}"
_v = state_dict.pop(param_key, None)
if _v is None:
continue
module.register_parameter(param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
manually_loaded_keys.append(param_key)
super_load(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
for key in manually_loaded_keys:
if key in missing_keys:
missing_keys.remove(key)
def _quantized_weight_state_dict(module, sd, prefix, extra_quant_conf=None, extra_quant_params=()):
"""Shared state_dict body. extra_quant_conf merges into the comfy_quant JSON;
extra_quant_params names attributes written as additional top-level keys."""
if not hasattr(module, 'weight'):
logging.warning(f"Warning: state dict on uninitialized op {prefix}")
return sd
bias = getattr(module, 'bias', None)
if bias is not None:
sd[f"{prefix}bias"] = bias
if module.weight is None:
return sd
if isinstance(module.weight, QuantizedTensor):
sd.update(module.weight.state_dict(f"{prefix}weight"))
quant_conf = {"format": module.quant_format}
if getattr(module, '_full_precision_mm_config', False):
quant_conf["full_precision_matrix_mult"] = True
if extra_quant_conf:
quant_conf.update(extra_quant_conf)
sd[f"{prefix}comfy_quant"] = torch.tensor(list(json.dumps(quant_conf).encode("utf-8")), dtype=torch.uint8)
for name in extra_quant_params:
value = getattr(module, name, None)
if value is not None:
sd[f"{prefix}{name}"] = value
else:
sd[f"{prefix}weight"] = module.weight
return sd
def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False, disabled=[]):
class MixedPrecisionOps(manual_cast):
@ -1056,21 +1195,16 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
_disabled = disabled
class Linear(torch.nn.Module, CastWeightBiasOp):
def __init__(
self,
in_features: int,
out_features: int,
bias: bool = True,
device=None,
dtype=None,
) -> None:
_disabled_formats = disabled
def __init__(self, in_features: int, out_features: int, bias: bool = True, device=None, dtype=None):
super().__init__()
self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype}
# self.factory_kwargs = {"device": device, "dtype": dtype}
self.in_features = in_features
self.out_features = out_features
self._orig_shape = (out_features, in_features)
if bias:
self.bias = torch.nn.Parameter(torch.empty(out_features, **self.factory_kwargs))
else:
@ -1083,151 +1217,12 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
def reset_parameters(self):
return None
def _load_scale_param(self, state_dict, prefix, param_name, device, manually_loaded_keys, dtype=None):
key = f"{prefix}{param_name}"
value = state_dict.pop(key, None)
if value is not None:
value = value.to(device=device)
if dtype is not None:
value = value.view(dtype=dtype)
manually_loaded_keys.append(key)
return value
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
strict, missing_keys, unexpected_keys, error_msgs):
device = self.factory_kwargs["device"]
layer_name = prefix.rstrip('.')
weight_key = f"{prefix}weight"
weight = state_dict.pop(weight_key, None)
if weight is None:
logging.warning(f"Missing weight for layer {layer_name}")
self.weight = None
return
manually_loaded_keys = [weight_key]
layer_conf = state_dict.pop(f"{prefix}comfy_quant", None)
if layer_conf is not None:
layer_conf = json.loads(layer_conf.numpy().tobytes())
if layer_conf is None:
self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
else:
self.quant_format = layer_conf.get("format", None)
self._full_precision_mm_config = layer_conf.get("full_precision_matrix_mult", False)
if not self._full_precision_mm:
self._full_precision_mm = self._full_precision_mm_config
if self.quant_format in MixedPrecisionOps._disabled:
self._full_precision_mm = True
if self.quant_format is None:
raise ValueError(f"Unknown quantization format for layer {layer_name}")
qconfig = QUANT_ALGOS[self.quant_format]
self.layout_type = qconfig["comfy_tensor_layout"]
layout_cls = get_layout_class(self.layout_type)
# Load format-specific parameters
if self.quant_format in ["float8_e4m3fn", "float8_e5m2"]:
# FP8: single tensor scale
scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys)
params = layout_cls.Params(
scale=scale,
orig_dtype=MixedPrecisionOps._compute_dtype,
orig_shape=(self.out_features, self.in_features),
)
elif self.quant_format == "mxfp8":
# MXFP8: E8M0 block scales stored as uint8 in safetensors
block_scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys,
dtype=torch.uint8)
if block_scale is None:
raise ValueError(f"Missing MXFP8 block scales for layer {layer_name}")
block_scale = block_scale.view(torch.float8_e8m0fnu)
params = layout_cls.Params(
scale=block_scale,
orig_dtype=MixedPrecisionOps._compute_dtype,
orig_shape=(self.out_features, self.in_features),
)
elif self.quant_format == "nvfp4":
# NVFP4: tensor_scale (weight_scale_2) + block_scale (weight_scale)
tensor_scale = self._load_scale_param(state_dict, prefix, "weight_scale_2", device, manually_loaded_keys)
block_scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys,
dtype=torch.float8_e4m3fn)
if tensor_scale is None or block_scale is None:
raise ValueError(f"Missing NVFP4 scales for layer {layer_name}")
params = layout_cls.Params(
scale=tensor_scale,
block_scale=block_scale,
orig_dtype=MixedPrecisionOps._compute_dtype,
orig_shape=(self.out_features, self.in_features),
)
else:
raise ValueError(f"Unsupported quantization format: {self.quant_format}")
self.weight = torch.nn.Parameter(
QuantizedTensor(weight.to(device=device, dtype=qconfig["storage_t"]), self.layout_type, params),
requires_grad=False
)
for param_name in qconfig["parameters"]:
if param_name in {"weight_scale", "weight_scale_2"}:
continue # Already handled above
param_key = f"{prefix}{param_name}"
_v = state_dict.pop(param_key, None)
if _v is None:
continue
self.register_parameter(param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
manually_loaded_keys.append(param_key)
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
for key in manually_loaded_keys:
if key in missing_keys:
missing_keys.remove(key)
def _load_from_state_dict(self, *args):
_load_quantized_module(self, super()._load_from_state_dict, *args, load_extra_params=True)
def state_dict(self, *args, destination=None, prefix="", **kwargs):
if destination is not None:
sd = destination
else:
sd = {}
if not hasattr(self, 'weight'):
logging.warning("Warning: state dict on uninitialized op {}".format(prefix))
return sd
if self.bias is not None:
sd["{}bias".format(prefix)] = self.bias
if self.weight is None:
return sd
if isinstance(self.weight, QuantizedTensor):
sd_out = self.weight.state_dict("{}weight".format(prefix))
for k in sd_out:
sd[k] = sd_out[k]
quant_conf = {"format": self.quant_format}
if self._full_precision_mm_config:
quant_conf["full_precision_matrix_mult"] = True
sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8)
input_scale = getattr(self, 'input_scale', None)
if input_scale is not None:
sd["{}input_scale".format(prefix)] = input_scale
else:
sd["{}weight".format(prefix)] = self.weight
return sd
sd = destination if destination is not None else {}
return _quantized_weight_state_dict(self, sd, prefix, extra_quant_params=("input_scale",))
def _forward(self, input, weight, bias):
return torch.nn.functional.linear(input, weight, bias)
@ -1317,25 +1312,126 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
self.weight = torch.nn.Parameter(weight, requires_grad=False)
def _apply(self, fn, recurse=True): # This is to get torch.compile + moving weights to another device working
if recurse:
for module in self.children():
module._apply(fn)
return _quantized_apply(self, fn, recurse)
for key, param in self._parameters.items():
if param is None:
continue
p = fn(param)
if (not torch.is_inference_mode_enabled()) and p.is_inference():
p = p.clone()
self.register_parameter(key, torch.nn.Parameter(p, requires_grad=False))
for key, buf in self._buffers.items():
if buf is not None:
self._buffers[key] = fn(buf)
return self
class MoEExperts(torch.nn.Module, CastWeightBiasOp):
"""Container for E quantized expert weights, indexed via expert_weight(i).
The bank lives on self.weight as a single 3D tensor either a
compute_dtype Parameter or a Parameter wrapping a QuantizedTensor
with leading expert dim.
State-dict layout matches mixed_precision_ops.Linear with a leading
expert dim:
{prefix}.weight quant data (storage_t), leading dim = E
{prefix}.weight_scale block / per-tensor scale
{prefix}.weight_scale_2 [E] or scalar NVFP4 only
{prefix}.bias [E, out_features] optional, compute_dtype
{prefix}.comfy_quant json -> {{"format": "...", "num_experts": E}}
Without comfy_quant the weight loads as a plain compute_dtype 3D Parameter [E, out, in].
"""
_disabled_formats = disabled
def __init__(self, num_experts: int, in_features: int, out_features: int, bias: bool = True, device=None, dtype=None):
super().__init__()
self.num_experts = num_experts
self.in_features = in_features
self.out_features = out_features
self._orig_shape = (num_experts, out_features, in_features)
self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype}
if bias:
self.bias = torch.nn.Parameter(torch.empty(num_experts, out_features, **self.factory_kwargs))
else:
self.register_parameter("bias", None)
# Populated by _load_from_state_dict:
self.weight = None
self.quant_format = None
self.layout_type = None
self._full_precision_mm = MixedPrecisionOps._full_precision_mm
self._full_precision_mm_config = False
self._resident_bank = None
def reset_parameters(self):
return None
def _apply(self, fn, recurse=True):
return _quantized_apply(self, fn, recurse)
def _load_from_state_dict(self, *args):
_load_quantized_module(self, super()._load_from_state_dict, *args, load_extra_params=False)
def expert_weight(self, i: int):
"""Expert i's weight (Tensor or per-expert QuantizedTensor view)."""
if isinstance(self.weight, QuantizedTensor):
return self._expert_qt_from(self.weight, i)
return self.weight[i]
@contextlib.contextmanager
def bank_resident(self, input):
"""Cast the whole bank once; expert_linear inside reuses the cast.
Not re-entrant do not nest calls on the same instance.
"""
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
self._resident_bank = (weight, bias)
try:
yield self
finally:
self._resident_bank = None
uncast_bias_weight(self, weight, bias, offload_stream)
def expert_linear(self, input: torch.Tensor, i: int) -> torch.Tensor:
"""Linear against expert i's weight (with optional bias)."""
resident = getattr(self, "_resident_bank", None)
if resident is not None:
weight, bias = resident
return self._expert_linear_impl(input, weight, bias, i)
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
try:
return self._expert_linear_impl(input, weight, bias, i)
finally:
uncast_bias_weight(self, weight, bias, offload_stream)
def _expert_linear_impl(self, input, weight, bias, i):
if isinstance(weight, QuantizedTensor):
qw = self._expert_qt_from(weight, i)
else:
qw = weight[i]
b = cast_to_input(bias[i], input, copy=False) if bias is not None else None
if isinstance(qw, QuantizedTensor):
use_fast = (
not self._full_precision_mm
and qw.layout_cls.supports_fast_matmul()
and input.dim() == 2
)
if use_fast:
qin = QuantizedTensor.from_float(input, self.layout_type)
return torch.nn.functional.linear(qin, qw, b)
out = input @ qw.dequantize().t()
return out + b if b is not None else out
return torch.nn.functional.linear(input, qw, b)
def _expert_qt_from(self, weight: QuantizedTensor, i: int) -> QuantizedTensor:
"""Build a per-expert QuantizedTensor by indexing into a resident bank."""
params = weight._params
kwargs = {
"scale": params.scale[i] if params.scale.dim() else params.scale,
"orig_dtype": params.orig_dtype,
"orig_shape": (self.out_features, self.in_features),
}
if hasattr(params, "block_scale"): # NVFP4
kwargs["block_scale"] = params.block_scale[i]
return QuantizedTensor(weight._qdata[i], weight._layout_cls, type(params)(**kwargs))
def state_dict(self, *args, destination=None, prefix="", **kwargs):
sd = destination if destination is not None else {}
return _quantized_weight_state_dict(self, sd, prefix, extra_quant_conf={"num_experts": self.num_experts})
class Embedding(manual_cast.Embedding):
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
strict, missing_keys, unexpected_keys, error_msgs):
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
weight_key = f"{prefix}weight"
layer_conf = state_dict.pop(f"{prefix}comfy_quant", None)
if layer_conf is not None:
@ -1343,14 +1439,16 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
# Only fp8 makes sense for embeddings (per-row dequant via index select).
# Block-scaled formats (NVFP4, MXFP8) can't do per-row lookup efficiently.
quant_format = layer_conf.get("format", None) if layer_conf is not None else None
if quant_format in ["float8_e4m3fn", "float8_e5m2"] and weight_key in state_dict:
quant_format = layer_conf.get("format") if layer_conf is not None else None
manually_loaded_keys = []
if quant_format in ("float8_e4m3fn", "float8_e5m2") and weight_key in state_dict:
self.quant_format = quant_format
qconfig = QUANT_ALGOS[quant_format]
self.layout_type = qconfig["comfy_tensor_layout"]
layout_cls = get_layout_class(self.layout_type)
weight = state_dict.pop(weight_key)
manually_loaded_keys = [weight_key]
manually_loaded_keys.append(weight_key)
scale_key = f"{prefix}weight_scale"
scale = state_dict.pop(scale_key, None)
@ -1366,35 +1464,19 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
self.weight = torch.nn.Parameter(
QuantizedTensor(weight.to(dtype=qconfig["storage_t"]), qconfig["comfy_tensor_layout"], params),
requires_grad=False)
elif layer_conf is not None:
# Unsupported format — restore the marker so it round-trips; fall through to default load.
state_dict[f"{prefix}comfy_quant"] = torch.tensor(
list(json.dumps(layer_conf).encode('utf-8')), dtype=torch.uint8)
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
for k in manually_loaded_keys:
if k in missing_keys:
missing_keys.remove(k)
else:
if layer_conf is not None:
state_dict[f"{prefix}comfy_quant"] = torch.tensor(list(json.dumps(layer_conf).encode('utf-8')), dtype=torch.uint8)
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
for k in manually_loaded_keys:
if k in missing_keys:
missing_keys.remove(k)
def state_dict(self, *args, destination=None, prefix="", **kwargs):
if destination is not None:
sd = destination
else:
sd = {}
if not hasattr(self, 'weight') or self.weight is None:
return sd
if isinstance(self.weight, QuantizedTensor):
sd_out = self.weight.state_dict("{}weight".format(prefix))
for k in sd_out:
sd[k] = sd_out[k]
quant_conf = {"format": self.quant_format}
sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8)
else:
sd["{}weight".format(prefix)] = self.weight
return sd
sd = destination if destination is not None else {}
return _quantized_weight_state_dict(self, sd, prefix)
def forward_comfy_cast_weights(self, input, out_dtype=None):
weight = self.weight

View File

@ -1,8 +1,9 @@
from __future__ import annotations
from typing import Callable
class CallbacksMP:
ON_CLONE = "on_clone"
ON_DEEPCLONE_MULTIGPU = "on_deepclone_multigpu"
ON_MATCH_MULTIGPU_CLONES = "on_match_multigpu_clones"
ON_LOAD = "on_load_after"
ON_DETACH = "on_detach_after"
ON_CLEANUP = "on_cleanup"

View File

@ -1,16 +1,18 @@
from __future__ import annotations
import torch
import uuid
import math
import collections
import comfy.model_management
import comfy.conds
import comfy.model_patcher
import comfy.utils
import comfy.hooks
import comfy.patcher_extension
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from comfy.model_patcher import ModelPatcher
from comfy.model_base import BaseModel
from comfy.model_patcher import ModelPatcher
from comfy.controlnet import ControlBase
def prepare_mask(noise_mask, shape, device):
@ -119,6 +121,47 @@ def cleanup_additional_models(models):
if hasattr(m, 'cleanup'):
m.cleanup()
def preprocess_multigpu_conds(conds: dict[str, list[dict[str]]], model: ModelPatcher, model_options: dict[str]):
'''If multigpu acceleration required, creates deepclones of ControlNets and GLIGEN per device.'''
multigpu_models: list[ModelPatcher] = model.get_additional_models_with_key("multigpu")
if len(multigpu_models) == 0:
return
extra_devices = [x.load_device for x in multigpu_models]
# handle controlnets
controlnets: set[ControlBase] = set()
for k in conds:
for kk in conds[k]:
if 'control' in kk:
controlnets.add(kk['control'])
if len(controlnets) > 0:
# first, unload all controlnet clones
for cnet in list(controlnets):
cnet_models = cnet.get_models()
for cm in cnet_models:
comfy.model_management.unload_model_and_clones(cm, unload_additional_models=True)
# next, make sure each controlnet has a deepclone for all relevant devices
for cnet in controlnets:
curr_cnet = cnet
while curr_cnet is not None:
for device in extra_devices:
if device not in curr_cnet.multigpu_clones:
curr_cnet.deepclone_multigpu(device, autoregister=True)
curr_cnet = curr_cnet.previous_controlnet
# since all device clones are now present, recreate the linked list for cloned cnets per device
for cnet in controlnets:
curr_cnet = cnet
while curr_cnet is not None:
prev_cnet = curr_cnet.previous_controlnet
for device in extra_devices:
device_cnet = curr_cnet.get_instance_for_device(device)
prev_device_cnet = None
if prev_cnet is not None:
prev_device_cnet = prev_cnet.get_instance_for_device(device)
device_cnet.set_previous_controlnet(prev_device_cnet)
curr_cnet = prev_cnet
# potentially handle gligen - since not widely used, ignored for now
def estimate_memory(model, noise_shape, conds):
cond_shapes = collections.defaultdict(list)
cond_shapes_min = {}
@ -143,7 +186,8 @@ def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None
return executor.execute(model, noise_shape, conds, model_options=model_options, force_full_load=force_full_load, force_offload=force_offload)
def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False, force_offload=False):
real_model: BaseModel = None
model.match_multigpu_clones()
preprocess_multigpu_conds(conds, model, model_options)
models, inference_memory = get_additional_models(conds, model.model_dtype())
models += get_additional_models_from_model_options(model_options)
models += model.get_nested_additional_models() # TODO: does this require inference_memory update?
@ -155,7 +199,7 @@ def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=Non
memory_required += inference_memory
minimum_memory_required += inference_memory
comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required, minimum_memory_required=minimum_memory_required, force_full_load=force_full_load)
real_model = model.model
real_model: BaseModel = model.model
return real_model, conds, models
@ -201,3 +245,18 @@ def prepare_model_patcher(model: ModelPatcher, conds, model_options: dict):
comfy.patcher_extension.merge_nested_dicts(to_load_options.setdefault(wc_name, {}), model_options["transformer_options"][wc_name],
copy_dict1=False)
return to_load_options
def prepare_model_patcher_multigpu_clones(model_patcher: ModelPatcher, loaded_models: list[ModelPatcher], model_options: dict):
'''
In case multigpu acceleration is enabled, prep ModelPatchers for each device.
'''
multigpu_patchers: list[ModelPatcher] = [x for x in loaded_models if x.is_multigpu_base_clone]
if len(multigpu_patchers) > 0:
multigpu_dict: dict[torch.device, ModelPatcher] = {}
multigpu_dict[model_patcher.load_device] = model_patcher
for x in multigpu_patchers:
x.hook_patches = comfy.model_patcher.create_hook_patches_clone(model_patcher.hook_patches, copy_tuples=True)
x.hook_mode = model_patcher.hook_mode # match main model's hook_mode
multigpu_dict[x.load_device] = x
model_options["multigpu_clones"] = multigpu_dict
return multigpu_patchers

View File

@ -1,7 +1,9 @@
from __future__ import annotations
import comfy.model_management
from .k_diffusion import sampling as k_diffusion_sampling
from .extra_samplers import uni_pc
from typing import TYPE_CHECKING, Callable, NamedTuple
from typing import TYPE_CHECKING, Callable, NamedTuple, Any
if TYPE_CHECKING:
from comfy.model_patcher import ModelPatcher
from comfy.model_base import BaseModel
@ -16,6 +18,7 @@ import comfy.model_patcher
import comfy.patcher_extension
import comfy.hooks
import comfy.context_windows
import comfy.multigpu
import comfy.utils
import scipy.stats
import numpy
@ -141,7 +144,7 @@ def can_concat_cond(c1, c2):
return cond_equal_size(c1.conditioning, c2.conditioning)
def cond_cat(c_list):
def cond_cat(c_list, device=None):
temp = {}
for x in c_list:
for k in x:
@ -153,6 +156,8 @@ def cond_cat(c_list):
for k in temp:
conds = temp[k]
out[k] = conds[0].concat(conds[1:])
if device is not None and hasattr(out[k], 'to'):
out[k] = out[k].to(device)
return out
@ -212,7 +217,12 @@ def _calc_cond_batch_outer(model: BaseModel, conds: list[list[dict]], x_in: torc
)
return executor.execute(model, conds, x_in, timestep, model_options)
def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
# NOTE: keep in sync with _calc_cond_batch_multigpu below. Shared logic
# (hooked_to_run accumulation, memory-fit batching, per-chunk output
# aggregation) is duplicated there with per-device scheduling layered on top.
if 'multigpu_clones' in model_options:
return _calc_cond_batch_multigpu(model, conds, x_in, timestep, model_options)
out_conds = []
out_counts = []
# separate conds by matching hooks
@ -244,7 +254,7 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens
if has_default_conds:
finalize_default_conds(model, hooked_to_run, default_conds, x_in, timestep, model_options)
model.current_patcher.prepare_state(timestep)
model.current_patcher.prepare_state(timestep, model_options)
# run every hooked_to_run separately
for hooks, to_run in hooked_to_run.items():
@ -344,6 +354,239 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens
return out_conds
def _calc_cond_batch_multigpu(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
# NOTE: keep in sync with _calc_cond_batch above. Same conds-by-hooks
# accumulation, memory-fit batching, and output aggregation, but adds a
# per-device scheduler, per-device patcher/control lookup, tensor .to(device)
# placement, and MultiGPUThreadPool dispatch around the inner loop.
out_conds = []
out_counts = []
# separate conds by matching hooks
hooked_to_run: dict[comfy.hooks.HookGroup,list[tuple[tuple,int]]] = {}
default_conds = []
has_default_conds = False
output_device = x_in.device
for i in range(len(conds)):
out_conds.append(torch.zeros_like(x_in))
out_counts.append(torch.ones_like(x_in) * 1e-37)
cond = conds[i]
default_c = []
if cond is not None:
for x in cond:
if 'default' in x:
default_c.append(x)
has_default_conds = True
continue
p = get_area_and_mult(x, x_in, timestep)
if p is None:
continue
if p.hooks is not None:
model.current_patcher.prepare_hook_patches_current_keyframe(timestep, p.hooks, model_options)
hooked_to_run.setdefault(p.hooks, list())
hooked_to_run[p.hooks] += [(p, i)]
default_conds.append(default_c)
if has_default_conds:
finalize_default_conds(model, hooked_to_run, default_conds, x_in, timestep, model_options)
model.current_patcher.prepare_state(timestep, model_options)
devices = list(model_options['multigpu_clones'].keys())
device_batched_hooked_to_run: dict[torch.device, list[tuple[comfy.hooks.HookGroup, tuple]]] = {}
# Track conds currently scheduled per device; single source of truth for capacity checks.
device_load: dict[torch.device, int] = {d: 0 for d in devices}
total_conds = sum(len(to_run) for to_run in hooked_to_run.values())
conds_per_device = max(1, math.ceil(total_conds / len(devices)))
def next_available_device(start: int) -> tuple[int, torch.device]:
"""Return (index, device) for the next device with remaining capacity, starting at `start`.
Scans at most len(devices) positions, so this always terminates. Raises if no device
has remaining capacity, which would indicate a bug in conds_per_device accounting.
"""
for offset in range(len(devices)):
i = (start + offset) % len(devices)
if device_load[devices[i]] < conds_per_device:
return i, devices[i]
raise RuntimeError(
f"MultiGPU scheduler: all {len(devices)} devices at capacity "
f"({conds_per_device}) but conds remain to schedule"
)
# run every hooked_to_run separately
index_device = 0
for hooks, to_run in hooked_to_run.items():
while len(to_run) > 0:
index_device, current_device = next_available_device(index_device)
remaining_capacity = conds_per_device - device_load[current_device]
first = to_run[0]
first_shape = first[0][0].shape
# collect candidate indices that can be concatenated with `first`, up to remaining capacity
to_batch_temp = []
for x in range(len(to_run)):
if can_concat_cond(to_run[x][0], first[0]) and len(to_batch_temp) < remaining_capacity:
to_batch_temp += [x]
to_batch_temp.reverse()
to_batch = to_batch_temp[:1]
free_memory = comfy.model_management.get_free_memory(current_device)
for i in range(1, len(to_batch_temp) + 1):
batch_amount = to_batch_temp[:len(to_batch_temp)//i]
input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
cond_shapes = collections.defaultdict(list)
for tt in batch_amount:
for k, v in to_run[tt][0].conditioning.items():
cond_shapes[k].append(v.size())
if model.memory_required(input_shape, cond_shapes=cond_shapes) * 1.5 < free_memory:
to_batch = batch_amount
break
conds_to_batch = [to_run.pop(x) for x in to_batch]
device_load[current_device] += len(conds_to_batch)
device_batched_hooked_to_run.setdefault(current_device, []).append((hooks, conds_to_batch))
if device_load[current_device] >= conds_per_device:
index_device += 1
class thread_result(NamedTuple):
output: Any
mult: Any
area: Any
batch_chunks: int
cond_or_uncond: Any
error: Exception = None
def _handle_batch(device: torch.device, batch_tuple: tuple[comfy.hooks.HookGroup, tuple], results: list[thread_result]):
try:
# TODO: non-NVIDIA support -- guard with `if device.type == "cuda":` once
# we extend multigpu QA beyond CUDA. Unconditional call crashes on
# XPU/NPU/MPS/CPU/DirectML backends.
torch.cuda.set_device(device)
model_current: BaseModel = model_options["multigpu_clones"][device].model
# run every hooked_to_run separately
with torch.no_grad():
for hooks, to_batch in batch_tuple:
input_x = []
mult = []
c = []
cond_or_uncond = []
uuids = []
area = []
control: ControlBase = None
patches = None
for x in to_batch:
o = x
p = o[0]
input_x.append(p.input_x)
mult.append(p.mult)
c.append(p.conditioning)
area.append(p.area)
cond_or_uncond.append(o[1])
uuids.append(p.uuid)
control = p.control
patches = p.patches
batch_chunks = len(cond_or_uncond)
input_x = torch.cat(input_x).to(device)
c = cond_cat(c, device=device)
timestep_ = torch.cat([timestep.to(device)] * batch_chunks)
transformer_options = model_current.current_patcher.apply_hooks(hooks=hooks)
if 'transformer_options' in model_options:
transformer_options = comfy.patcher_extension.merge_nested_dicts(transformer_options,
model_options['transformer_options'],
copy_dict1=False)
if patches is not None:
transformer_options["patches"] = comfy.patcher_extension.merge_nested_dicts(
transformer_options.get("patches", {}),
patches
)
transformer_options["cond_or_uncond"] = cond_or_uncond[:]
transformer_options["uuids"] = uuids[:]
transformer_options["sigmas"] = timestep.to(device)
transformer_options["sample_sigmas"] = transformer_options["sample_sigmas"].to(device)
transformer_options["multigpu_thread_device"] = device
cast_transformer_options(transformer_options, device=device)
c['transformer_options'] = transformer_options
if control is not None:
device_control = control.get_instance_for_device(device)
c['control'] = device_control.get_control(input_x, timestep_, c, len(cond_or_uncond), transformer_options)
if 'model_function_wrapper' in model_options:
output = model_options['model_function_wrapper'](model_current.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).to(output_device).chunk(batch_chunks)
else:
output = model_current.apply_model(input_x, timestep_, **c).to(output_device).chunk(batch_chunks)
# TODO: non-NVIDIA support -- the `.to(output_device)` copies
# above are async on CUDA, so the main thread's aggregation
# could race with in-flight transfers. CUDA-only QA has not
# surfaced this in practice, but before extending multigpu
# beyond NVIDIA add a `torch.cuda.synchronize(output_device)`
# here (guarded by `output_device.type == "cuda"`).
results.append(thread_result(output, mult, area, batch_chunks, cond_or_uncond))
except Exception as e:
results.append(thread_result(None, None, None, None, None, error=e))
raise
def _handle_batch_pooled(device, batch_tuple):
worker_results = []
_handle_batch(device, batch_tuple, worker_results)
return worker_results
results: list[thread_result] = []
thread_pool: comfy.multigpu.MultiGPUThreadPool = model_options.get("multigpu_thread_pool")
# Submit all GPU work to pool threads
pool_devices = []
for device, batch_tuple in device_batched_hooked_to_run.items():
if thread_pool is not None:
thread_pool.submit(device, _handle_batch_pooled, device, batch_tuple)
pool_devices.append(device)
else:
# Fallback: no pool, run everything on main thread
_handle_batch(device, batch_tuple, results)
# Collect results from pool workers
for device in pool_devices:
worker_results, error = thread_pool.get_result(device)
if error is not None:
raise error
results.extend(worker_results)
for output, mult, area, batch_chunks, cond_or_uncond, error in results:
if error is not None:
raise error
for o in range(batch_chunks):
cond_index = cond_or_uncond[o]
a = area[o]
if a is None:
out_conds[cond_index] += output[o] * mult[o]
out_counts[cond_index] += mult[o]
else:
out_c = out_conds[cond_index]
out_cts = out_counts[cond_index]
dims = len(a) // 2
for i in range(dims):
out_c = out_c.narrow(i + 2, a[i + dims], a[i])
out_cts = out_cts.narrow(i + 2, a[i + dims], a[i])
out_c += output[o] * mult[o]
out_cts += mult[o]
for i in range(len(out_conds)):
out_conds[i] /= out_counts[i]
return out_conds
def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options): #TODO: remove
logging.warning("WARNING: The comfy.samplers.calc_cond_uncond_batch function is deprecated please use the calc_cond_batch one instead.")
return tuple(calc_cond_batch(model, [cond, uncond], x_in, timestep, model_options))
@ -642,12 +885,21 @@ def calculate_start_end_timesteps(model, conds):
def pre_run_control(model, conds):
s = model.model_sampling
# Per-device model lookup so multigpu control clones get the matching
# diffusion_model (e.g. QwenFunControlNet stashes it into extra_args).
device_models: dict = {}
patcher = getattr(model, "current_patcher", None)
if patcher is not None:
for p in patcher.get_additional_models_with_key("multigpu"):
device_models[p.load_device] = p.model
for t in range(len(conds)):
x = conds[t]
percent_to_timestep_function = lambda a: s.percent_to_sigma(a)
if 'control' in x:
x['control'].pre_run(model, percent_to_timestep_function)
for device, device_cnet in x['control'].multigpu_clones.items():
device_cnet.pre_run(device_models.get(device, model), percent_to_timestep_function)
def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func):
cond_cnets = []
@ -890,7 +1142,9 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None):
to_load_options = model_options.get("to_load_options", None)
if to_load_options is None:
return
cast_transformer_options(to_load_options, device, dtype)
def cast_transformer_options(transformer_options: dict[str], device=None, dtype=None):
casts = []
if device is not None:
casts.append(device)
@ -899,18 +1153,17 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None):
# if nothing to apply, do nothing
if len(casts) == 0:
return
# try to call .to on patches
if "patches" in to_load_options:
patches = to_load_options["patches"]
if "patches" in transformer_options:
patches = transformer_options["patches"]
for name in patches:
patch_list = patches[name]
for i in range(len(patch_list)):
if hasattr(patch_list[i], "to"):
for cast in casts:
patch_list[i] = patch_list[i].to(cast)
if "patches_replace" in to_load_options:
patches = to_load_options["patches_replace"]
if "patches_replace" in transformer_options:
patches = transformer_options["patches_replace"]
for name in patches:
patch_list = patches[name]
for k in patch_list:
@ -920,8 +1173,8 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None):
# try to call .to on any wrappers/callbacks
wrappers_and_callbacks = ["wrappers", "callbacks"]
for wc_name in wrappers_and_callbacks:
if wc_name in to_load_options:
wc: dict[str, list] = to_load_options[wc_name]
if wc_name in transformer_options:
wc: dict[str, list] = transformer_options[wc_name]
for wc_dict in wc.values():
for wc_list in wc_dict.values():
for i in range(len(wc_list)):
@ -929,7 +1182,6 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None):
for cast in casts:
wc_list[i] = wc_list[i].to(cast)
class CFGGuider:
def __init__(self, model_patcher: ModelPatcher):
self.model_patcher = model_patcher
@ -984,16 +1236,32 @@ class CFGGuider:
self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options)
device = self.model_patcher.load_device
noise = noise.to(device=device, dtype=torch.float32)
latent_image = latent_image.to(device=device, dtype=torch.float32)
sigmas = sigmas.to(device)
cast_to_load_options(self.model_options, device=device, dtype=self.model_patcher.model_dtype())
multigpu_patchers = comfy.sampler_helpers.prepare_model_patcher_multigpu_clones(self.model_patcher, self.loaded_models, self.model_options)
try:
self.model_patcher.pre_run()
output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=latent_shapes)
finally:
self.model_patcher.cleanup()
# Create persistent thread pool for all GPU devices (main + extras)
if multigpu_patchers:
extra_devices = [p.load_device for p in multigpu_patchers]
all_devices = [device] + extra_devices
self.model_options["multigpu_thread_pool"] = comfy.multigpu.MultiGPUThreadPool(all_devices)
with comfy.model_management.cuda_device_context(device):
try:
noise = noise.to(device=device, dtype=torch.float32)
latent_image = latent_image.to(device=device, dtype=torch.float32)
sigmas = sigmas.to(device)
cast_to_load_options(self.model_options, device=device, dtype=self.model_patcher.model_dtype())
self.model_patcher.pre_run()
for multigpu_patcher in multigpu_patchers:
multigpu_patcher.pre_run()
output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=latent_shapes)
finally:
thread_pool = self.model_options.pop("multigpu_thread_pool", None)
if thread_pool is not None:
thread_pool.shutdown()
self.model_patcher.cleanup()
for multigpu_patcher in multigpu_patchers:
multigpu_patcher.cleanup()
comfy.sampler_helpers.cleanup_models(self.conds, self.loaded_models)
del self.inner_model

View File

@ -1,4 +1,3 @@
from __future__ import annotations
import json
import torch
from enum import Enum
@ -50,6 +49,7 @@ import comfy.text_encoders.lt
import comfy.text_encoders.hunyuan_video
import comfy.text_encoders.cosmos
import comfy.text_encoders.lumina2
import comfy.text_encoders.pixeldit
import comfy.text_encoders.wan
import comfy.text_encoders.hidream
import comfy.text_encoders.ace
@ -69,6 +69,7 @@ import comfy.text_encoders.ernie
import comfy.text_encoders.gemma4
import comfy.text_encoders.cogvideo
import comfy.text_encoders.sa3
import comfy.text_encoders.gpt_oss
import comfy.model_patcher
import comfy.lora
@ -335,41 +336,43 @@ class CLIP:
self.cond_stage_model.set_clip_options({"projected_pooled": False})
self.load_model(tokens)
self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
device = self.patcher.load_device
self.cond_stage_model.set_clip_options({"execution_device": device})
all_hooks.reset()
self.patcher.patch_hooks(None)
if show_pbar:
pbar = ProgressBar(len(scheduled_keyframes))
for scheduled_opts in scheduled_keyframes:
t_range = scheduled_opts[0]
# don't bother encoding any conds outside of start_percent and end_percent bounds
if "start_percent" in add_dict:
if t_range[1] < add_dict["start_percent"]:
continue
if "end_percent" in add_dict:
if t_range[0] > add_dict["end_percent"]:
continue
hooks_keyframes = scheduled_opts[1]
for hook, keyframe in hooks_keyframes:
hook.hook_keyframe._current_keyframe = keyframe
# apply appropriate hooks with values that match new hook_keyframe
self.patcher.patch_hooks(all_hooks)
# perform encoding as normal
o = self.cond_stage_model.encode_token_weights(tokens)
cond, pooled = o[:2]
pooled_dict = {"pooled_output": pooled}
# add clip_start_percent and clip_end_percent in pooled
pooled_dict["clip_start_percent"] = t_range[0]
pooled_dict["clip_end_percent"] = t_range[1]
# add/update any keys with the provided add_dict
pooled_dict.update(add_dict)
# add hooks stored on clip
self.add_hooks_to_dict(pooled_dict)
all_cond_pooled.append([cond, pooled_dict])
if show_pbar:
pbar.update(1)
model_management.throw_exception_if_processing_interrupted()
with model_management.cuda_device_context(device):
for scheduled_opts in scheduled_keyframes:
t_range = scheduled_opts[0]
# don't bother encoding any conds outside of start_percent and end_percent bounds
if "start_percent" in add_dict:
if t_range[1] < add_dict["start_percent"]:
continue
if "end_percent" in add_dict:
if t_range[0] > add_dict["end_percent"]:
continue
hooks_keyframes = scheduled_opts[1]
for hook, keyframe in hooks_keyframes:
hook.hook_keyframe._current_keyframe = keyframe
# apply appropriate hooks with values that match new hook_keyframe
self.patcher.patch_hooks(all_hooks)
# perform encoding as normal
o = self.cond_stage_model.encode_token_weights(tokens)
cond, pooled = o[:2]
pooled_dict = {"pooled_output": pooled}
# add clip_start_percent and clip_end_percent in pooled
pooled_dict["clip_start_percent"] = t_range[0]
pooled_dict["clip_end_percent"] = t_range[1]
# add/update any keys with the provided add_dict
pooled_dict.update(add_dict)
# add hooks stored on clip
self.add_hooks_to_dict(pooled_dict)
all_cond_pooled.append([cond, pooled_dict])
if show_pbar:
pbar.update(1)
model_management.throw_exception_if_processing_interrupted()
all_hooks.reset()
return all_cond_pooled
@ -383,8 +386,12 @@ class CLIP:
self.cond_stage_model.set_clip_options({"projected_pooled": False})
self.load_model(tokens)
self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
o = self.cond_stage_model.encode_token_weights(tokens)
device = self.patcher.load_device
self.cond_stage_model.set_clip_options({"execution_device": device})
with model_management.cuda_device_context(device):
o = self.cond_stage_model.encode_token_weights(tokens)
cond, pooled = o[:2]
if return_dict:
out = {"cond": cond, "pooled_output": pooled}
@ -446,9 +453,12 @@ class CLIP:
self.cond_stage_model.reset_clip_options()
self.load_model(tokens)
device = self.patcher.load_device
self.cond_stage_model.set_clip_options({"layer": None})
self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
return self.cond_stage_model.generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed, presence_penalty=presence_penalty)
self.cond_stage_model.set_clip_options({"execution_device": device})
with model_management.cuda_device_context(device):
return self.cond_stage_model.generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed, presence_penalty=presence_penalty)
def decode(self, token_ids, skip_special_tokens=True):
return self.tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
@ -1026,50 +1036,52 @@ class VAE:
do_tile = False
if self.latent_dim == 2 and samples_in.ndim == 5:
samples_in = samples_in[:, :, 0]
try:
memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
free_memory = self.patcher.get_free_memory(self.device)
batch_number = int(free_memory / memory_used)
batch_number = max(1, batch_number)
# Pre-allocate output for VAEs that support direct buffer writes
preallocated = False
if getattr(self.first_stage_model, 'comfy_has_chunked_io', False):
pixel_samples = torch.empty(self.first_stage_model.decode_output_shape(samples_in.shape), device=self.output_device, dtype=self.vae_output_dtype())
preallocated = True
with model_management.cuda_device_context(self.device):
try:
memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
free_memory = self.patcher.get_free_memory(self.device)
batch_number = int(free_memory / memory_used)
batch_number = max(1, batch_number)
for x in range(0, samples_in.shape[0], batch_number):
samples = samples_in[x:x + batch_number].to(device=self.device, dtype=self.vae_dtype)
if preallocated:
self.first_stage_model.decode(samples, output_buffer=pixel_samples[x:x+batch_number], **vae_options)
else:
out = self.first_stage_model.decode(samples, **vae_options).to(device=self.output_device, dtype=self.vae_output_dtype(), copy=True)
if pixel_samples is None:
pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
pixel_samples[x:x+batch_number].copy_(out)
del out
self.process_output(pixel_samples[x:x+batch_number])
except Exception as e:
model_management.raise_non_oom(e)
logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
#NOTE: We don't know what tensors were allocated to stack variables at the time of the
#exception and the exception itself refs them all until we get out of this except block.
#So we just set a flag for tiler fallback so that tensor gc can happen once the
#exception is fully off the books.
do_tile = True
# Pre-allocate output for VAEs that support direct buffer writes
preallocated = False
if getattr(self.first_stage_model, 'comfy_has_chunked_io', False):
pixel_samples = torch.empty(self.first_stage_model.decode_output_shape(samples_in.shape), device=self.output_device, dtype=self.vae_output_dtype())
preallocated = True
if do_tile:
comfy.model_management.soft_empty_cache()
dims = samples_in.ndim - 2
if dims == 1 or self.extra_1d_channel is not None:
pixel_samples = self.decode_tiled_1d(samples_in)
elif dims == 2:
pixel_samples = self.decode_tiled_(samples_in)
elif dims == 3:
tile = 256 // self.spacial_compression_decode()
overlap = tile // 4
pixel_samples = self.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
for x in range(0, samples_in.shape[0], batch_number):
samples = samples_in[x:x + batch_number].to(device=self.device, dtype=self.vae_dtype)
if preallocated:
self.first_stage_model.decode(samples, output_buffer=pixel_samples[x:x+batch_number], **vae_options)
else:
out = self.first_stage_model.decode(samples, **vae_options).to(device=self.output_device, dtype=self.vae_output_dtype(), copy=True)
if pixel_samples is None:
pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
pixel_samples[x:x+batch_number].copy_(out)
del out
self.process_output(pixel_samples[x:x+batch_number])
except Exception as e:
model_management.raise_non_oom(e)
logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
#NOTE: We don't know what tensors were allocated to stack variables at the time of the
#exception and the exception itself refs them all until we get out of this except block.
#So we just set a flag for tiler fallback so that tensor gc can happen once the
#exception is fully off the books.
do_tile = True
if do_tile:
comfy.model_management.soft_empty_cache()
dims = samples_in.ndim - 2
if dims == 1 or self.extra_1d_channel is not None:
pixel_samples = self.decode_tiled_1d(samples_in)
elif dims == 2:
pixel_samples = self.decode_tiled_(samples_in)
elif dims == 3:
tile = 256 // self.spacial_compression_decode()
overlap = tile // 4
pixel_samples = self.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
return pixel_samples
@ -1087,20 +1099,21 @@ class VAE:
if overlap is not None:
args["overlap"] = overlap
if dims == 1 or self.extra_1d_channel is not None:
args.pop("tile_y")
output = self.decode_tiled_1d(samples, **args)
elif dims == 2:
output = self.decode_tiled_(samples, **args)
elif dims == 3:
if overlap_t is None:
args["overlap"] = (1, overlap, overlap)
else:
args["overlap"] = (max(1, overlap_t), overlap, overlap)
if tile_t is not None:
args["tile_t"] = max(2, tile_t)
with model_management.cuda_device_context(self.device):
if dims == 1 or self.extra_1d_channel is not None:
args.pop("tile_y")
output = self.decode_tiled_1d(samples, **args)
elif dims == 2:
output = self.decode_tiled_(samples, **args)
elif dims == 3:
if overlap_t is None:
args["overlap"] = (1, overlap, overlap)
else:
args["overlap"] = (max(1, overlap_t), overlap, overlap)
if tile_t is not None:
args["tile_t"] = max(2, tile_t)
output = self.decode_tiled_3d(samples, **args)
output = self.decode_tiled_3d(samples, **args)
return output.movedim(1, -1)
def encode(self, pixel_samples):
@ -1113,44 +1126,46 @@ class VAE:
pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
else:
pixel_samples = pixel_samples.unsqueeze(2)
try:
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
free_memory = self.patcher.get_free_memory(self.device)
batch_number = int(free_memory / max(1, memory_used))
batch_number = max(1, batch_number)
samples = None
for x in range(0, pixel_samples.shape[0], batch_number):
pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype)
if getattr(self.first_stage_model, 'comfy_has_chunked_io', False):
out = self.first_stage_model.encode(pixels_in, device=self.device)
with model_management.cuda_device_context(self.device):
try:
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
free_memory = self.patcher.get_free_memory(self.device)
batch_number = int(free_memory / max(1, memory_used))
batch_number = max(1, batch_number)
samples = None
for x in range(0, pixel_samples.shape[0], batch_number):
pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype)
if getattr(self.first_stage_model, 'comfy_has_chunked_io', False):
out = self.first_stage_model.encode(pixels_in, device=self.device)
else:
pixels_in = pixels_in.to(self.device)
out = self.first_stage_model.encode(pixels_in)
out = out.to(self.output_device).to(dtype=self.vae_output_dtype())
if samples is None:
samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
samples[x:x + batch_number] = out
except Exception as e:
model_management.raise_non_oom(e)
logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
#NOTE: We don't know what tensors were allocated to stack variables at the time of the
#exception and the exception itself refs them all until we get out of this except block.
#So we just set a flag for tiler fallback so that tensor gc can happen once the
#exception is fully off the books.
do_tile = True
if do_tile:
comfy.model_management.soft_empty_cache()
if self.latent_dim == 3:
tile = 256
overlap = tile // 4
samples = self.encode_tiled_3d(pixel_samples, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
elif self.latent_dim == 1 or self.extra_1d_channel is not None:
samples = self.encode_tiled_1d(pixel_samples)
else:
pixels_in = pixels_in.to(self.device)
out = self.first_stage_model.encode(pixels_in)
out = out.to(self.output_device).to(dtype=self.vae_output_dtype())
if samples is None:
samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
samples[x:x + batch_number] = out
except Exception as e:
model_management.raise_non_oom(e)
logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
#NOTE: We don't know what tensors were allocated to stack variables at the time of the
#exception and the exception itself refs them all until we get out of this except block.
#So we just set a flag for tiler fallback so that tensor gc can happen once the
#exception is fully off the books.
do_tile = True
if do_tile:
comfy.model_management.soft_empty_cache()
if self.latent_dim == 3:
tile = 256
overlap = tile // 4
samples = self.encode_tiled_3d(pixel_samples, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
elif self.latent_dim == 1 or self.extra_1d_channel is not None:
samples = self.encode_tiled_1d(pixel_samples)
else:
samples = self.encode_tiled_(pixel_samples)
samples = self.encode_tiled_(pixel_samples)
return samples
@ -1176,26 +1191,27 @@ class VAE:
if overlap is not None:
args["overlap"] = overlap
if dims == 1:
args.pop("tile_y")
samples = self.encode_tiled_1d(pixel_samples, **args)
elif dims == 2:
samples = self.encode_tiled_(pixel_samples, **args)
elif dims == 3:
if tile_t is not None:
tile_t_latent = max(2, self.downscale_ratio[0](tile_t))
else:
tile_t_latent = 9999
args["tile_t"] = self.upscale_ratio[0](tile_t_latent)
with model_management.cuda_device_context(self.device):
if dims == 1:
args.pop("tile_y")
samples = self.encode_tiled_1d(pixel_samples, **args)
elif dims == 2:
samples = self.encode_tiled_(pixel_samples, **args)
elif dims == 3:
if tile_t is not None:
tile_t_latent = max(2, self.downscale_ratio[0](tile_t))
else:
tile_t_latent = 9999
args["tile_t"] = self.upscale_ratio[0](tile_t_latent)
if overlap_t is None:
args["overlap"] = (1, overlap, overlap)
else:
args["overlap"] = (self.upscale_ratio[0](max(1, min(tile_t_latent // 2, self.downscale_ratio[0](overlap_t)))), overlap, overlap)
maximum = pixel_samples.shape[2]
maximum = self.upscale_ratio[0](self.downscale_ratio[0](maximum))
if overlap_t is None:
args["overlap"] = (1, overlap, overlap)
else:
args["overlap"] = (self.upscale_ratio[0](max(1, min(tile_t_latent // 2, self.downscale_ratio[0](overlap_t)))), overlap, overlap)
maximum = pixel_samples.shape[2]
maximum = self.upscale_ratio[0](self.downscale_ratio[0](maximum))
samples = self.encode_tiled_3d(pixel_samples[:,:,:maximum], **args)
samples = self.encode_tiled_3d(pixel_samples[:,:,:maximum], **args)
return samples
@ -1269,6 +1285,8 @@ class CLIPType(Enum):
FLUX2 = 25
LONGCAT_IMAGE = 26
COGVIDEOX = 27
LENS = 28
PIXELDIT = 29
@ -1321,6 +1339,7 @@ class TEModel(Enum):
GEMMA_4_E2B = 30
GEMMA_4_31B = 31
T5_GEMMA = 32
GPT_OSS_20B = 33
def detect_te_model(sd):
@ -1362,6 +1381,9 @@ def detect_te_model(sd):
else:
return TEModel.GEMMA_3_4B
return TEModel.GEMMA_2_2B
# Must precede the Qwen2.5-7B k_proj.bias=512 check (GPT-OSS also has 8*64=512).
if "layers.0.self_attn.sinks" in sd and "layers.0.mlp.experts.gate_up_proj.weight" in sd:
return TEModel.GPT_OSS_20B
if 'model.layers.0.self_attn.k_proj.bias' in sd:
weight = sd['model.layers.0.self_attn.k_proj.bias']
if weight.shape[0] == 256:
@ -1508,8 +1530,12 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
clip_target.tokenizer = variant.tokenizer
tokenizer_data["tokenizer_json"] = clip_data[0].get("tokenizer_json", None)
elif te_model == TEModel.GEMMA_2_2B:
clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.lumina2.LuminaTokenizer
if clip_type == CLIPType.PIXELDIT:
clip_target.clip = comfy.text_encoders.pixeldit.pixeldit_te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.pixeldit.PixelDiTGemma2Tokenizer
else:
clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.lumina2.LuminaTokenizer
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
elif te_model == TEModel.GEMMA_3_4B:
clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data), model_type="gemma3_4b")
@ -1544,6 +1570,10 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
clip_target.clip = comfy.text_encoders.flux.flux2_te(**llama_detect(clip_data), pruned=te_model == TEModel.MISTRAL3_24B_PRUNED_FLUX2)
clip_target.tokenizer = comfy.text_encoders.flux.Flux2Tokenizer
tokenizer_data["tekken_model"] = clip_data[0].get("tekken_model", None)
elif te_model == TEModel.GPT_OSS_20B:
clip_target.clip = comfy.text_encoders.gpt_oss.lens_te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.gpt_oss.LensTokenizer
tokenizer_data["tokenizer_json"] = clip_data[0].get("tokenizer_json", None)
elif te_model == TEModel.QWEN3_4B:
if clip_type == CLIPType.FLUX or clip_type == CLIPType.FLUX2:
clip_target.clip = comfy.text_encoders.flux.klein_te(**llama_detect(clip_data), model_type="qwen3_4b")
@ -1710,12 +1740,52 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata, disable_dynamic=disable_dynamic)
if out is None:
raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(ckpt_path, model_detection_error_hint(ckpt_path, sd)))
if output_model and out[0] is not None:
out[0].cached_patcher_init = (load_checkpoint_guess_config_model_only, (ckpt_path, embedding_directory, model_options, te_model_options))
if output_clip and out[1] is not None:
out[1].patcher.cached_patcher_init = (load_checkpoint_guess_config_clip_only, (ckpt_path, embedding_directory, model_options, te_model_options))
if out[0] is not None:
out[0].cached_patcher_init = (load_checkpoint_guess_config, (ckpt_path, False, False, False, embedding_directory, output_model, model_options, te_model_options), 0)
# Register reload factories for the CLIP and VAE produced by the same checkpoint so
# ModelPatcher.deepclone_multigpu can spawn per-device copies (Select{CLIP,VAE}Device,
# MultiGPU work-units, etc.) without falling back to copy.deepcopy of an
# already-loaded module.
if out[1] is not None and getattr(out[1], "patcher", None) is not None:
out[1].patcher.cached_patcher_init = (load_checkpoint_clip_patcher, (ckpt_path, embedding_directory, model_options, te_model_options))
if out[2] is not None and getattr(out[2], "patcher", None) is not None:
out[2].patcher.cached_patcher_init = (load_checkpoint_vae_patcher, (ckpt_path, embedding_directory, model_options, te_model_options))
return out
def load_checkpoint_clip_patcher(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False):
"""Reload only the CLIP patcher from a checkpoint. Used as the cached_patcher_init
factory for the CLIP returned by load_checkpoint_guess_config."""
_, clip, _, _ = load_checkpoint_guess_config(
ckpt_path,
output_vae=False,
output_clip=True,
output_clipvision=False,
embedding_directory=embedding_directory,
output_model=False,
model_options=model_options,
te_model_options=te_model_options,
disable_dynamic=disable_dynamic,
)
return clip.patcher
def load_checkpoint_vae_patcher(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False):
"""Reload only the VAE patcher from a checkpoint. Used as the cached_patcher_init
factory for the VAE returned by load_checkpoint_guess_config."""
_, _, vae, _ = load_checkpoint_guess_config(
ckpt_path,
output_vae=True,
output_clip=False,
output_clipvision=False,
embedding_directory=embedding_directory,
output_model=False,
model_options=model_options,
te_model_options=te_model_options,
disable_dynamic=disable_dynamic,
)
return vae.patcher
def load_checkpoint_guess_config_model_only(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False):
model, *_ = load_checkpoint_guess_config(ckpt_path, False, False, False,
embedding_directory=embedding_directory,
@ -1742,7 +1812,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)
parameters = comfy.utils.calculate_parameters(sd, diffusion_model_prefix)
weight_dtype = comfy.utils.weight_dtype(sd, diffusion_model_prefix)
load_device = model_management.get_torch_device()
load_device = model_options.get("load_device", model_management.get_torch_device())
custom_operations = model_options.get("custom_operations", None)
if custom_operations is None:
@ -1782,13 +1852,15 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
inital_load_device = model_management.unet_inital_load_device(parameters, unet_dtype)
model = model_config.get_model(sd, diffusion_model_prefix, device=inital_load_device)
ModelPatcher = comfy.model_patcher.ModelPatcher if disable_dynamic else comfy.model_patcher.CoreModelPatcher
model_patcher = ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device())
offload_device = model_options.get("offload_device", model_management.unet_offload_device())
model_patcher = ModelPatcher(model, load_device=load_device, offload_device=offload_device)
model.load_model_weights(sd, diffusion_model_prefix, assign=model_patcher.is_dynamic())
if output_vae:
vae_sd = comfy.utils.state_dict_prefix_replace(sd, {k: "" for k in model_config.vae_key_prefix}, filter_keys=True)
vae_sd = model_config.process_vae_state_dict(vae_sd)
vae = VAE(sd=vae_sd, metadata=metadata)
vae_device = model_options.get("load_device", None)
vae = VAE(sd=vae_sd, metadata=metadata, device=vae_device)
if output_clip:
if te_model_options.get("custom_operations", None) is None:
@ -1872,7 +1944,7 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None, disable
parameters = comfy.utils.calculate_parameters(sd)
weight_dtype = comfy.utils.weight_dtype(sd)
load_device = model_management.get_torch_device()
load_device = model_options.get("load_device", model_management.get_torch_device())
model_config = model_detection.model_config_from_unet(sd, "", metadata=metadata)
if model_config is not None:
@ -1897,7 +1969,7 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None, disable
else:
logging.warning("{} {}".format(diffusers_keys[k], k))
offload_device = model_management.unet_offload_device()
offload_device = model_options.get("offload_device", model_management.unet_offload_device())
unet_weight_dtype = list(model_config.supported_inference_dtypes)
if model_config.quant_config is not None:
weight_dtype = None
@ -1939,6 +2011,26 @@ def load_diffusion_model(unet_path, model_options={}, disable_dynamic=False):
model.cached_patcher_init = (load_diffusion_model, (unet_path, model_options))
return model
def load_vae_patcher(vae_path, metadata=None, device=None, disable_dynamic=False):
"""Reload a disk-backed VAE from ``vae_path`` and return its patcher.
Used as the ``cached_patcher_init`` factory on ``VAE.patcher`` so
:meth:`comfy.model_patcher.ModelPatcher.deepclone_multigpu` can produce a
fresh, untainted VAE patcher (no inherited per-device load state, no
in-place quantization fallout) for multigpu work-units and the
SelectVAEDevice node. The optional ``device`` matches the source loader's
VAE initialization path; the deepclone's ``load_device`` still controls
where the cloned patcher is targeted.
"""
if metadata is None:
sd, metadata = comfy.utils.load_torch_file(vae_path, return_metadata=True)
else:
sd = comfy.utils.load_torch_file(vae_path)
vae = VAE(sd=sd, metadata=metadata, device=device)
vae.throw_exception_if_invalid()
return vae.patcher
def load_unet(unet_path, dtype=None):
logging.warning("The load_unet function has been deprecated and will be removed please switch to: load_diffusion_model")
return load_diffusion_model(unet_path, model_options={"dtype": dtype})

View File

@ -30,6 +30,7 @@ import comfy.text_encoders.longcat_image
import comfy.text_encoders.ernie
import comfy.text_encoders.cogvideo
import comfy.text_encoders.hidream_o1
import comfy.text_encoders.pixeldit
from . import supported_models_base
from . import latent_formats
@ -829,6 +830,50 @@ class Flux2(Flux):
return None
class Lens(supported_models_base.BASE):
"""Microsoft Lens (3.8B dual-stream MMDiT, GPT-OSS-20B text features, Flux2 VAE)."""
unet_config = {
"image_model": "lens",
}
sampling_settings = {
"shift": 1.829, # Default mu for 1440x1440 (and any seq_len > 4300
}
unet_extra_config = {}
latent_format = latent_formats.Flux2
memory_usage_factor = 4.0
supported_inference_dtypes = [torch.bfloat16, torch.float32] # fp16 causes NaNs
vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."]
def __init__(self, unet_config):
super().__init__(unet_config)
def get_model(self, state_dict, prefix="", device=None):
return model_base.Lens(self, model_type=model_base.ModelType.FLUX, device=device)
def clip_target(self, state_dict={}):
pref = self.text_encoder_key_prefix[0]
for hint in ("gpt_oss.transformer.", ""):
full_prefix = "{}{}".format(pref, hint)
if "{}layers.0.self_attn.sinks".format(full_prefix) in state_dict:
detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, full_prefix)
return supported_models_base.ClipTarget(
comfy.text_encoders.gpt_oss.LensTokenizer,
comfy.text_encoders.gpt_oss.lens_te(**detect),
)
return supported_models_base.ClipTarget(
comfy.text_encoders.gpt_oss.LensTokenizer,
comfy.text_encoders.gpt_oss.lens_te(),
)
class GenmoMochi(supported_models_base.BASE):
unet_config = {
"image_model": "mochi_preview",
@ -1159,6 +1204,72 @@ class ZImagePixelSpace(ZImage):
def get_model(self, state_dict, prefix="", device=None):
return model_base.ZImagePixelSpace(self, device=device)
class PixelDiTT2I(supported_models_base.BASE):
unet_config = {
"image_model": "pixeldit_t2i",
}
unet_extra_config = {}
sampling_settings = {
"shift": 4.0, # 1024px stage 3 default; 2.0 for 512px
}
latent_format = latent_formats.PixelDiTPixel
memory_usage_factor = 0.04
supported_inference_dtypes = [torch.bfloat16, torch.float32]
vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."]
def get_model(self, state_dict, prefix="", device=None):
return model_base.PixelDiTT2I(self, device=device)
def process_unet_state_dict(self, state_dict):
# pixel_dim from pixel_embedder.proj.weight = (pixel_dim, in_channels); p2 derived per-weight from total // (6 * pixel_dim).
pixel_dim = next(v for k, v in state_dict.items() if k.endswith("pixel_embedder.proj.weight")).shape[0]
out = {}
marker = ".adaLN_modulation.0."
for k, v in state_dict.items():
if k.startswith("_repa_projector") or k.startswith("net_ema."):
continue
if k.startswith("core."):
k = k[len("core."):]
elif k.startswith("net."):
k = k[len("net."):]
if "pixel_blocks." in k and marker in k:
# Split into msa (chunks 0-2) and mlp (chunks 3-5) for the two-Linear PiTBlock to reduce peak VRAM
p2 = v.shape[0] // (6 * pixel_dim)
trail = v.shape[1:] # () for bias, (in_dim,) for weight
vv = v.view(p2, 6, pixel_dim, *trail)
base, suffix = k.split(marker)
out[f"{base}.adaLN_modulation_msa.{suffix}"] = vv[:, 0:3].reshape(3 * p2 * pixel_dim, *trail).contiguous()
out[f"{base}.adaLN_modulation_mlp.{suffix}"] = vv[:, 3:6].reshape(3 * p2 * pixel_dim, *trail).contiguous()
else:
out[k] = v
return out
def clip_target(self, state_dict={}):
return supported_models_base.ClipTarget(
comfy.text_encoders.pixeldit.PixelDiTGemma2Tokenizer,
comfy.text_encoders.pixeldit.PixelDiTGemma2TE,
)
class PiD(PixelDiTT2I):
unet_config = {
"image_model": "pid",
}
sampling_settings = {
"shift": 1.5, # close approximation of the original distill 4 steps [0.999, 0.866, 0.634, 0.342, 0]
}
memory_usage_factor = 0.04
def get_model(self, state_dict, prefix="", device=None):
return model_base.PiD(self, device=device)
class WAN21_T2V(supported_models_base.BASE):
unet_config = {
"image_model": "wan2.1",
@ -2069,6 +2180,8 @@ models = [
CosmosI2VPredict2,
ZImagePixelSpace,
ZImage,
PiD,
PixelDiTT2I,
Lumina2,
WAN22_T2V,
WAN21_CausalAR_T2V,
@ -2096,6 +2209,7 @@ models = [
Omnigen2,
QwenImage,
Flux2,
Lens,
Kandinsky5Image,
Kandinsky5,
Anima,

View File

@ -0,0 +1,600 @@
"""GPT-OSS text encoder for Lens."""
from __future__ import annotations
import math
from dataclasses import dataclass
from typing import Any, List, Optional, Sequence
import torch
import torch.nn as nn
import torch.nn.functional as F
import comfy.ops
from comfy import sd1_clip
from comfy.ldm.modules.attention import TORCH_HAS_GQA, optimized_attention_for_device
from comfy.text_encoders.llama import RMSNorm, apply_rope
@dataclass
class GptOss20BConfig:
vocab_size: int = 201088
hidden_size: int = 2880
intermediate_size: int = 2880
num_hidden_layers: int = 24
num_attention_heads: int = 64
num_key_value_heads: int = 8
head_dim: int = 64
num_local_experts: int = 32
num_experts_per_tok: int = 4
sliding_window: int = 128
original_max_position_embeddings: int = 4096
rope_theta: float = 150000.0
rope_factor: float = 32.0
rope_beta_fast: float = 32.0
rope_beta_slow: float = 1.0
rope_truncate: bool = False
rms_norm_eps: float = 1e-5
attention_bias: bool = True
layer_types: Optional[List[str]] = None
moe_alpha: float = 1.702
moe_limit: float = 7.0
def __post_init__(self):
if self.layer_types is None:
self.layer_types = [
"sliding_attention" if (i + 1) % 2 else "full_attention"
for i in range(self.num_hidden_layers)
]
def _yarn_inv_freq(head_dim: int, base: float, factor: float, beta_fast: float, beta_slow: float,
original_max_position_embeddings: int, truncate: bool, device=None) -> tuple[torch.Tensor, float]:
"""YARN inv_freq + attention scaling (matches transformers)."""
dim = head_dim
def find_correction_dim(num_rotations: float) -> float:
return (dim * math.log(original_max_position_embeddings / (num_rotations * 2 * math.pi))) / (
2 * math.log(base)
)
def find_correction_range() -> tuple[float, float]:
low = find_correction_dim(beta_fast)
high = find_correction_dim(beta_slow)
if truncate:
low = math.floor(low)
high = math.ceil(high)
return max(low, 0), min(high, dim - 1)
def linear_ramp_factor(min_: float, max_: float, n: int) -> torch.Tensor:
if min_ == max_:
max_ += 0.001
linear = (torch.arange(n, dtype=torch.float32, device=device) - min_) / (max_ - min_)
return torch.clamp(linear, 0, 1)
def get_mscale(scale: float) -> float:
if scale <= 1:
return 1.0
return 0.1 * math.log(scale) + 1.0
attention_scaling = get_mscale(factor)
pos_freqs = base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)
inv_freq_extrapolation = 1.0 / pos_freqs
inv_freq_interpolation = 1.0 / (factor * pos_freqs)
low, high = find_correction_range()
extrap_factor = 1 - linear_ramp_factor(low, high, dim // 2)
inv_freq = inv_freq_interpolation * (1 - extrap_factor) + inv_freq_extrapolation * extrap_factor
return inv_freq, attention_scaling
def _build_freqs_cis(inv_freq: torch.Tensor, attention_scaling: float, position_ids: torch.Tensor, dtype: torch.dtype,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
inv_freq_e = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
pos_e = position_ids[:, None, :].float()
freqs = (inv_freq_e @ pos_e).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = (emb.cos() * attention_scaling).to(dtype).unsqueeze(1)
sin = (emb.sin() * attention_scaling).to(dtype).unsqueeze(1)
sin_split = sin.shape[-1] // 2
return cos, sin[..., :sin_split], -sin[..., sin_split:]
def _attention_with_sinks(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, sinks: torch.Tensor,
attention_mask: Optional[torch.Tensor], num_heads: int, num_kv_groups: int) -> torch.Tensor:
"""Attention with per-head sinks.
Sinks add a learned term to each row's softmax denominator but contribute
nothing to the output. We fake this by appending one zero k/v position and
putting the sink logit in the mask at that column.
"""
if num_kv_groups > 1 and not TORCH_HAS_GQA:
k = k.repeat_interleave(num_kv_groups, dim=1)
v = v.repeat_interleave(num_kv_groups, dim=1)
B, _, S_q, D = q.shape
H_kv = k.shape[1]
S_kv = k.shape[-2]
k = torch.cat([k, k.new_zeros(B, H_kv, 1, D)], dim=-2)
v = torch.cat([v, v.new_zeros(B, H_kv, 1, D)], dim=-2)
sinks_col = sinks.to(q.dtype).view(1, num_heads, 1, 1).expand(B, num_heads, S_q, 1)
if attention_mask is not None:
mask_left = attention_mask[..., :S_kv].expand(B, num_heads, S_q, S_kv)
else:
mask_left = q.new_zeros(B, num_heads, S_q, S_kv)
mask = torch.cat([mask_left, sinks_col], dim=-1)
op = optimized_attention_for_device(q.device, mask=True, small_input=True)
return op(q, k, v, num_heads, mask=mask, skip_reshape=True, enable_gqa=True)
class GptOssAttention(nn.Module):
def __init__(self, config: GptOss20BConfig, layer_idx: int, device=None, dtype=None, ops: Any = None):
super().__init__()
self.layer_idx = layer_idx
self.layer_type = config.layer_types[layer_idx]
self.num_heads = config.num_attention_heads
self.num_kv_heads = config.num_key_value_heads
self.num_kv_groups = self.num_heads // self.num_kv_heads
self.head_dim = config.head_dim
self.hidden_size = config.hidden_size
self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None
bias = config.attention_bias
self.q_proj = ops.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=bias, device=device, dtype=dtype)
self.k_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=bias, device=device, dtype=dtype)
self.v_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=bias, device=device, dtype=dtype)
self.o_proj = ops.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=bias, device=device, dtype=dtype)
self.sinks = nn.Parameter(torch.empty(self.num_heads, device=device, dtype=dtype))
def forward(self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor], freqs_cis) -> torch.Tensor:
B, S, _ = hidden_states.shape
q = self.q_proj(hidden_states).view(B, S, self.num_heads, self.head_dim).transpose(1, 2)
k = self.k_proj(hidden_states).view(B, S, self.num_kv_heads, self.head_dim).transpose(1, 2)
v = self.v_proj(hidden_states).view(B, S, self.num_kv_heads, self.head_dim).transpose(1, 2)
q, k = apply_rope(q, k, freqs_cis)
out = _attention_with_sinks(q, k, v, self.sinks, attention_mask, self.num_heads, self.num_kv_groups)
return self.o_proj(out)
# Mixture of Experts
class GptOssTopKRouter(nn.Module):
def __init__(self, config: GptOss20BConfig, device=None, dtype=None):
super().__init__()
self.top_k = config.num_experts_per_tok
self.num_experts = config.num_local_experts
self.weight = nn.Parameter(torch.empty(config.num_local_experts, config.hidden_size, device=device, dtype=dtype))
self.bias = nn.Parameter(torch.empty(config.num_local_experts, device=device, dtype=dtype))
def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
weight = comfy.ops.cast_to_input(self.weight, hidden_states, copy=False)
bias = comfy.ops.cast_to_input(self.bias, hidden_states, copy=False)
logits = F.linear(hidden_states, weight, bias)
top_vals, top_idx = torch.topk(logits, self.top_k, dim=-1)
# Softmax over top-k slice only
scores = F.softmax(top_vals, dim=-1, dtype=top_vals.dtype)
return scores, top_idx
class GptOssExperts(nn.Module):
def __init__(self, config: GptOss20BConfig, device=None, dtype=None, ops: Any = None):
super().__init__()
self.num_experts = config.num_local_experts
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.alpha = config.moe_alpha
self.limit = config.moe_limit
E = self.num_experts
H = self.hidden_size
I = self.intermediate_size
self.gate_up_proj = ops.MoEExperts(num_experts=E, in_features=H, out_features=2 * I, bias=True, device=device, dtype=dtype)
self.down_proj = ops.MoEExperts(num_experts=E, in_features=I, out_features=H, bias=True, device=device, dtype=dtype)
def _apply_gate(self, gate_up: torch.Tensor) -> torch.Tensor:
gate = gate_up[..., ::2]
up = gate_up[..., 1::2]
gate = gate.clamp(max=self.limit)
up = up.clamp(min=-self.limit, max=self.limit)
glu = gate * torch.sigmoid(gate * self.alpha)
return torch.addcmul(glu, up, glu)
def forward(self, hidden_states: torch.Tensor, router_indices: torch.Tensor, routing_weights: torch.Tensor) -> torch.Tensor:
N = hidden_states.shape[0]
top_k = router_indices.shape[-1]
H = hidden_states.shape[-1]
per_pair = torch.zeros((N * top_k, H), dtype=hidden_states.dtype, device=hidden_states.device)
expert_mask = F.one_hot(router_indices, num_classes=self.num_experts).permute(2, 1, 0)
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
with self.gate_up_proj.bank_resident(hidden_states) as gate_up_bank, \
self.down_proj.bank_resident(hidden_states) as down_bank:
for ei in expert_hit:
expert_idx = int(ei.item())
top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
current = hidden_states[token_idx]
gate_up = gate_up_bank.expert_linear(current, expert_idx)
gated = self._apply_gate(gate_up)
expert_out = down_bank.expert_linear(gated, expert_idx)
weighted = expert_out * routing_weights[token_idx, top_k_pos, None]
flat_idx = token_idx * top_k + top_k_pos
per_pair[flat_idx] = weighted.to(per_pair.dtype)
return per_pair.view(N, top_k, H).sum(dim=1)
class GptOssMLP(nn.Module):
def __init__(self, config: GptOss20BConfig, device=None, dtype=None, ops: Any = None):
super().__init__()
self.router = GptOssTopKRouter(config, device=device, dtype=dtype)
self.experts = GptOssExperts(config, device=device, dtype=dtype, ops=ops)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
B, S, H = hidden_states.shape
flat = hidden_states.reshape(-1, H)
scores, idx = self.router(flat)
out = self.experts(flat, idx, scores)
return out.reshape(B, S, H)
# Decoder layer + model
class GptOssDecoderLayer(nn.Module):
def __init__(self, config: GptOss20BConfig, layer_idx: int, device=None, dtype=None, ops: Any = None):
super().__init__()
self.self_attn = GptOssAttention(config, layer_idx, device=device, dtype=dtype, ops=ops)
self.mlp = GptOssMLP(config, device=device, dtype=dtype, ops=ops)
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype)
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype)
self.layer_type = config.layer_types[layer_idx]
def forward(self, x: torch.Tensor, attention_masks: dict[str, Optional[torch.Tensor]], freqs_cis) -> torch.Tensor:
residual = x
x = self.input_layernorm(x)
x = self.self_attn(x, attention_masks[self.layer_type], freqs_cis)
x = residual + x
residual = x
x = self.post_attention_layernorm(x)
x = self.mlp(x)
x = residual + x
return x
def _make_full_causal_mask(B: int, S: int, key_padding_mask: Optional[torch.Tensor], dtype, device):
neg = torch.finfo(dtype).min
mask = torch.full((S, S), neg, dtype=dtype, device=device).triu_(1)
mask = mask.unsqueeze(0).unsqueeze(0).expand(B, 1, S, S).contiguous()
if key_padding_mask is not None:
kp = key_padding_mask.to(dtype=dtype)
kp = (1.0 - kp).reshape(B, 1, 1, S) * neg
mask = mask + kp
return mask
def _make_sliding_causal_mask(B: int, S: int, window: int, key_padding_mask: Optional[torch.Tensor], dtype, device):
neg = torch.finfo(dtype).min
i = torch.arange(S, device=device).view(-1, 1)
j = torch.arange(S, device=device).view(1, -1)
keep = (j <= i) & (j > i - window)
mask = torch.where(keep, torch.zeros((), dtype=dtype, device=device), torch.full((), neg, dtype=dtype, device=device))
mask = mask.unsqueeze(0).unsqueeze(0).expand(B, 1, S, S).contiguous()
if key_padding_mask is not None:
kp = key_padding_mask.to(dtype=dtype)
kp = (1.0 - kp).reshape(B, 1, 1, S) * neg
mask = mask + kp
return mask
class GptOssModel(nn.Module):
"""GPT-OSS decoder with multi-layer hidden-state capture + early exit."""
def __init__(self, config: GptOss20BConfig, device=None, dtype=None, ops: Any = None):
super().__init__()
self.config = config
self.dtype = dtype
self.embed_tokens = ops.Embedding(config.vocab_size, config.hidden_size, device=device, dtype=dtype)
self.layers = nn.ModuleList(
[
GptOssDecoderLayer(config, i, device=device, dtype=dtype, ops=ops)
for i in range(config.num_hidden_layers)
]
)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype)
# Always build on CPU so the buffer survives meta-device construction.
inv_freq, attn_scaling = _yarn_inv_freq(
head_dim=config.head_dim,
base=config.rope_theta,
factor=config.rope_factor,
beta_fast=config.rope_beta_fast,
beta_slow=config.rope_beta_slow,
original_max_position_embeddings=config.original_max_position_embeddings,
truncate=config.rope_truncate,
device=torch.device("cpu"),
)
self.register_buffer("rope_inv_freq", inv_freq, persistent=False)
self.rope_attention_scaling = float(attn_scaling)
@property
def num_layers(self) -> int:
return self.config.num_hidden_layers
def get_input_embeddings(self):
return self.embed_tokens
def _build_attention_masks(self, B: int, S: int, attention_mask: Optional[torch.Tensor], dtype: torch.dtype, device,
) -> dict[str, torch.Tensor]:
full = _make_full_causal_mask(B, S, attention_mask, dtype, device)
masks = {"full_attention": full}
if any(t == "sliding_attention" for t in self.config.layer_types):
masks["sliding_attention"] = _make_sliding_causal_mask(
B, S, self.config.sliding_window, attention_mask, dtype, device
)
return masks
def forward(self, input_ids: torch.LongTensor, attention_mask: Optional[torch.Tensor] = None,
capture_layers: Optional[Sequence[int]] = None) -> dict[str, Any]:
B, S = input_ids.shape
device = input_ids.device
dtype = self.dtype
hidden_states = self.embed_tokens(input_ids, out_dtype=dtype)
position_ids = torch.arange(S, device=device).unsqueeze(0).expand(B, -1)
freqs_cis = _build_freqs_cis(self.rope_inv_freq.to(device=device), self.rope_attention_scaling, position_ids, dtype)
attn_masks = self._build_attention_masks(B, S, attention_mask, dtype, device)
capture_layers = list(capture_layers) if capture_layers else None
if capture_layers:
max_layer = max(capture_layers)
wanted = {idx: pos for pos, idx in enumerate(capture_layers)}
captured: List[Optional[torch.Tensor]] = [None] * len(capture_layers)
else:
max_layer = self.config.num_hidden_layers - 1
wanted = None
captured = None
for i, layer in enumerate(self.layers):
hidden_states = layer(hidden_states, attn_masks, freqs_cis)
if wanted is not None and i in wanted:
captured[wanted[i]] = hidden_states
if i >= max_layer:
break
if captured is not None:
return {"hidden_states": captured}
return {"last_hidden_state": self.norm(hidden_states)}
# Lens chat-template constants (verbatim from the reference pipeline).
_LENS_CHAT_SYSTEM = (
"Describe the image by detailing the color, shape, size, texture, "
"quantity, text, spatial relationships of the objects and background."
)
_LENS_CHAT_ASSISTANT_THINKING = "Need to generate one image according to the description."
LENS_TXT_OFFSET = 97
LENS_SELECTED_LAYERS = (5, 11, 17, 23)
LENS_MAX_TOKENS = 512
# The reference GPT-OSS Harmony template injects today's date here
_LENS_CHAT_DATE = "2026-05-23"
def _lens_render_chat(prompt: str) -> str:
"""Render the Lens prompt in GPT-OSS Harmony format."""
return (
f"<|start|>system<|message|>"
f"You are ChatGPT, a large language model trained by OpenAI.\n"
f"Knowledge cutoff: 2024-06\n"
f"Current date: {_LENS_CHAT_DATE}\n\n"
f"Reasoning: medium\n\n"
f"# Valid channels: analysis, commentary, final. "
f"Channel must be included for every message.<|end|>"
f"<|start|>developer<|message|># Instructions\n\n"
f"{_LENS_CHAT_SYSTEM}\n\n<|end|>"
f"<|start|>user<|message|>{prompt}<|end|>"
f"<|start|>assistant<|channel|>analysis<|message|>"
f"{_LENS_CHAT_ASSISTANT_THINKING}<|end|>"
f"<|start|>assistant<|channel|>final<|message|>"
)
# GPT-OSS-20B fixed token IDs (from the tokenizer's added-tokens table).
_LENS_PAD_TOKEN_ID = 199999 # <|endoftext|>
class _GptOssRawTokenizer:
"""Raw ``tokenizers.Tokenizer`` wrapper.
The tokenizer JSON ships as a byte tensor inside the encoder checkpoint
(``tokenizer_json`` key) rather than as a committed file. Extracted
it in ``sd.py`` and passes it here via ``tokenizer_data``.
"""
def __init__(self, tokenizer_json_bytes=None, **kwargs):
from tokenizers import Tokenizer
if isinstance(tokenizer_json_bytes, torch.Tensor):
tokenizer_json_bytes = bytes(tokenizer_json_bytes.tolist())
if tokenizer_json_bytes is None:
raise ValueError(
"Lens tokenizer requires the ``tokenizer_json`` byte tensor in the "
"encoder state dict. Re-bundle the encoder via bundle_te.py so it "
"embeds the tokenizer."
)
self.tokenizer = Tokenizer.from_str(tokenizer_json_bytes.decode("utf-8"))
@classmethod
def from_pretrained(cls, tokenizer_data, **kwargs):
return cls(tokenizer_json_bytes=tokenizer_data, **kwargs)
def __call__(self, text):
return {"input_ids": self.tokenizer.encode(text, add_special_tokens=False).ids}
def get_vocab(self):
return self.tokenizer.get_vocab()
def convert_tokens_to_ids(self, tokens):
return [self.tokenizer.token_to_id(t) for t in tokens]
def decode(self, ids, **kwargs):
return self.tokenizer.decode(ids, skip_special_tokens=kwargs.get("skip_special_tokens", False))
class LensGptOssTokenizer(sd1_clip.SDTokenizer):
tokenizer_json_data = None
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_json = tokenizer_data.get("tokenizer_json", None)
self.tokenizer_json_data = tokenizer_json
super().__init__(
tokenizer_json,
embedding_directory=embedding_directory,
pad_with_end=False,
embedding_size=2880,
embedding_key="gpt_oss",
tokenizer_class=_GptOssRawTokenizer,
has_start_token=False,
has_end_token=False,
pad_to_max_length=False,
max_length=99999999,
min_length=1,
pad_left=False,
disable_weights=True,
tokenizer_data=tokenizer_data,
)
self.pad_token_id = _LENS_PAD_TOKEN_ID
def tokenize_with_weights(self, text: str, return_word_ids=False, **kwargs):
# Empty prompt -> empty list; encode_token_weights returns zeros (uncond).
if not text or not text.strip():
return [[]]
rendered = _lens_render_chat(text)
ids = self.tokenizer(rendered)["input_ids"]
if len(ids) > LENS_MAX_TOKENS:
ids = ids[:LENS_MAX_TOKENS]
return [[(int(t), 1.0) for t in ids]]
def state_dict(self):
if self.tokenizer_json_data is not None:
return {"tokenizer_json": self.tokenizer_json_data}
return {}
class LensTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(
embedding_directory=embedding_directory,
tokenizer_data=tokenizer_data,
name="gpt_oss",
tokenizer=LensGptOssTokenizer,
)
class LensGptOssClipModel(nn.Module):
"""SDClipModel-shaped Lens GPT-OSS encoder (multi-layer feature extractor)."""
def __init__(self, device="cpu", dtype=None, model_options=None, **kwargs):
super().__init__()
model_options = dict(model_options or {})
operations = model_options.get("custom_operations")
if operations is None:
quant_config = model_options.get("quantization_metadata") or {}
operations = comfy.ops.mixed_precision_ops(quant_config, dtype, full_precision_mm=True)
self.operations = operations
cfg_overrides = model_options.get("gpt_oss_config", {})
self.config = GptOss20BConfig(**cfg_overrides)
self.selected_layers = tuple(model_options.get("selected_layers", LENS_SELECTED_LAYERS))
self.txt_offset = int(model_options.get("txt_offset", LENS_TXT_OFFSET))
self.transformer = GptOssModel(self.config, device=device, dtype=dtype, ops=operations)
self.num_layers = self.config.num_hidden_layers
self.dtype = dtype
self.execution_device = None
self._pad_token_id = _LENS_PAD_TOKEN_ID
def set_clip_options(self, options):
self.execution_device = options.get("execution_device", self.execution_device)
def reset_clip_options(self):
self.execution_device = None
def _gather_tokens(self, token_weight_pairs):
ids_list = [[int(t[0]) for t in batch] for batch in token_weight_pairs]
pad_id = self._pad_token_id
max_len = max(len(x) for x in ids_list)
device = self.execution_device
ids = torch.full((len(ids_list), max_len), pad_id, dtype=torch.long, device=device)
mask = torch.zeros((len(ids_list), max_len), dtype=torch.long, device=device)
for i, x in enumerate(ids_list):
ids[i, : len(x)] = torch.tensor(x, dtype=torch.long, device=device)
mask[i, : len(x)] = 1
return ids, mask
def encode_token_weights(self, token_weight_pairs):
# Empty negative: emit zero-length features + zero mask
if all(len(batch) == 0 for batch in token_weight_pairs):
device = self.execution_device
B = len(token_weight_pairs)
L = len(self.selected_layers)
H = self.config.hidden_size
flat = torch.zeros(B, 0, L * H, dtype=self.dtype, device=device)
mask = torch.zeros(B, 0, dtype=torch.long, device=device)
return flat, None, {"attention_mask": mask, "num_layers_stacked": L}
input_ids, attn_mask = self._gather_tokens(token_weight_pairs)
out = self.transformer(input_ids, attention_mask=attn_mask, capture_layers=self.selected_layers)
layers = out["hidden_states"] # list of L × [B, S, H]
stacked = torch.stack(layers, dim=2) # [B, S, L, H]
offset = self.txt_offset
if stacked.shape[1] > offset:
stacked = stacked[:, offset:].contiguous()
mask_trim = attn_mask[:, offset:]
else:
stacked = stacked[:, :0]
mask_trim = attn_mask[:, :0]
B, S, L, H = stacked.shape
flat = stacked.reshape(B, S, L * H)
extra = {"attention_mask": mask_trim, "num_layers_stacked": L}
return flat, None, extra
def load_sd(self, sd):
return self.transformer.load_state_dict(sd, strict=False, assign=True)
class LensTEModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options=None):
super().__init__(device=device, dtype=dtype, name="gpt_oss", clip_model=LensGptOssClipModel, model_options=model_options or {})
def lens_te(dtype_llama=None, llama_quantization_metadata=None):
class LensTEModel_(LensTEModel):
def __init__(self, device="cpu", dtype=None, model_options=None):
mo = dict(model_options or {})
if llama_quantization_metadata is not None:
mo["quantization_metadata"] = llama_quantization_metadata
if dtype is None and dtype_llama is not None:
dtype = dtype_llama
super().__init__(device=device, dtype=dtype, model_options=mo)
return LensTEModel_

View File

@ -0,0 +1,104 @@
import torch
from comfy import sd1_clip
from .lumina2 import Gemma2BTokenizer, LuminaModel
import comfy.text_encoders.llama
class PixelDiTGemma2_2BModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
if llama_quantization_metadata is not None:
model_options = model_options.copy()
model_options["quantization_metadata"] = llama_quantization_metadata
super().__init__(
device=device, layer=layer, layer_idx=layer_idx,
textmodel_json_config={}, dtype=dtype,
special_tokens={"start": 2, "pad": 0},
layer_norm_hidden_state=False,
model_class=comfy.text_encoders.llama.Gemma2_2B,
enable_attention_masks=attention_mask,
return_attention_masks=attention_mask,
model_options=model_options,
)
_PIXELDIT_CHI_PROMPT = (
'Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions '
"suitable for image generation. Evaluate the level of detail in the user prompt:\n"
"- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, "
"and spatial relationships to create vivid and concrete scenes.\n"
"- If the prompt is already detailed, refine and enhance the existing details slightly without "
"overcomplicating.\n"
"Here are examples of how to transform or refine prompts:\n"
"- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, "
"sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.\n"
"- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring "
"glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus "
"passing by towering glass skyscrapers.\n"
"Please generate only the enhanced description for the prompt below and avoid including any "
"additional commentary or evaluations:\n"
"User Prompt: "
)
_PIXELDIT_MAX_LENGTH = 300
_PIXELDIT_CHI_PROMPT_DETECT_PREFIX = 'Given a user prompt, generate an "Enhanced prompt"'
class PixelDiTGemma2Tokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data=None):
if tokenizer_data is None:
tokenizer_data = {}
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data,
name="gemma2_2b", tokenizer=Gemma2BTokenizer)
def tokenize_with_weights(self, text, return_word_ids=False, **kwargs):
if not text.strip():
return super().tokenize_with_weights("", return_word_ids=return_word_ids, disable_weights=True, min_length=_PIXELDIT_MAX_LENGTH)
chi_token_count = len(self.gemma2_2b.tokenizer(_PIXELDIT_CHI_PROMPT)["input_ids"])
combined = text if text.startswith(_PIXELDIT_CHI_PROMPT_DETECT_PREFIX) else _PIXELDIT_CHI_PROMPT + text
max_length_all = chi_token_count + _PIXELDIT_MAX_LENGTH - 2
out = super().tokenize_with_weights(combined, return_word_ids=return_word_ids,
disable_weights=True, min_length=max_length_all)
out["gemma2_2b"] = [out["gemma2_2b"][0][:max_length_all]]
return out
def untokenize(self, token_weight_pair):
return self.gemma2_2b.untokenize(token_weight_pair)
def state_dict(self):
return self.gemma2_2b.state_dict()
class PixelDiTGemma2TE(LuminaModel):
# PixelDiT's select_index: keep BOS + last 299 embeddings of the padded sequence.
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(device=device, dtype=dtype, name="gemma2_2b",
clip_model=PixelDiTGemma2_2BModel, model_options=model_options)
def encode_token_weights(self, token_weight_pairs):
result = super().encode_token_weights(token_weight_pairs)
cond, pooled = result[0], result[1]
extra = result[2] if len(result) > 2 else None
if cond.shape[1] > _PIXELDIT_MAX_LENGTH:
cond = torch.cat([cond[:, :1], cond[:, -(_PIXELDIT_MAX_LENGTH - 1):]], dim=1)
if extra is not None and "attention_mask" in extra:
am = extra["attention_mask"]
extra["attention_mask"] = torch.cat([am[..., :1], am[..., -(_PIXELDIT_MAX_LENGTH - 1):]], dim=-1)
if extra is not None:
return cond, pooled, extra
return cond, pooled
def pixeldit_te(dtype_llama=None, llama_quantization_metadata=None):
class PixelDiTTE_(PixelDiTGemma2TE):
def __init__(self, device="cpu", dtype=None, model_options={}):
if llama_quantization_metadata is not None:
model_options = model_options.copy()
model_options["llama_quantization_metadata"] = llama_quantization_metadata
if dtype_llama is not None:
dtype = dtype_llama
super().__init__(device=device, dtype=dtype, model_options=model_options)
return PixelDiTTE_

View File

@ -86,6 +86,7 @@ def load_safetensors(ckpt):
import comfy_aimdo.model_mmap
f = open(ckpt, "rb", buffering=0)
file_lock = threading.Lock()
model_mmap = comfy_aimdo.model_mmap.ModelMMAP(ckpt)
file_size = os.path.getsize(ckpt)
mv = memoryview((ctypes.c_uint8 * file_size).from_address(model_mmap.get()))
@ -111,7 +112,7 @@ def load_safetensors(ckpt):
storage = tensor.untyped_storage()
setattr(storage,
"_comfy_tensor_file_slice",
comfy.memory_management.TensorFileSlice(f, threading.get_ident(), data_base_offset + start, end - start))
comfy.memory_management.TensorFileSlice(f, file_lock, data_base_offset + start, end - start))
setattr(storage, "_comfy_tensor_mmap_refs", (model_mmap, mv))
sd[name] = tensor

View File

@ -1,5 +1,3 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING
from comfy_api.internal import ComfyAPIBase

View File

@ -1,4 +1,3 @@
from __future__ import annotations
from av.container import InputContainer
from av.subtitles.stream import SubtitleStream
from fractions import Fraction

View File

@ -1,4 +1,3 @@
from __future__ import annotations
from dataclasses import dataclass
from enum import Enum
from fractions import Fraction

View File

@ -3,7 +3,6 @@
# timestamp: 2025-07-30T08:54:00+00:00
# pylint: disable
from __future__ import annotations
from datetime import date, datetime
from enum import Enum

View File

@ -1,5 +1,3 @@
from __future__ import annotations
from enum import Enum
from typing import Any, Dict, Optional

View File

@ -1,5 +1,3 @@
from __future__ import annotations
from enum import Enum
from typing import Optional

View File

@ -1,4 +1,3 @@
from __future__ import annotations
from typing import Type, Literal
import nodes

View File

@ -1,5 +1,3 @@
from __future__ import annotations
from typing import TypedDict, Dict, Optional, Tuple
from typing_extensions import override
from PIL import Image

View File

@ -1,4 +1,3 @@
from __future__ import annotations
from comfy_api.latest import IO

View File

@ -2,7 +2,6 @@
+ weighted Procrustes solver. Computes the 4x4 facial transformation matrix.
"""
from __future__ import annotations
import math
import numpy as np

View File

@ -1,7 +1,6 @@
"""Pure-PyTorch port of MediaPipe's face_landmarker_v2_with_blendshapes.task:
BlazeFace detector FaceMesh v2 ARKit-52 blendshapes."""
from __future__ import annotations
import math
from functools import lru_cache

View File

@ -1,5 +1,3 @@
from __future__ import annotations
import av
import torchaudio
import torch

View File

@ -57,24 +57,55 @@ class CFGNorm(io.ComfyNode):
inputs=[
io.Model.Input("model"),
io.Float.Input("strength", default=1.0, min=0.0, max=100.0, step=0.01),
io.Boolean.Input(
"pre_cfg",
default=False,
optional=True,
tooltip=(
"If true, rescale the combined noise BEFORE the sampler's CFG combine, "
"without clamping (can amplify). Matches the norm-scaled CFG used by "
"models like Lens. Default false keeps the original post-CFG x0-space "
"attenuate-only behavior."
),
),
],
outputs=[io.Model.Output(display_name="patched_model")],
is_experimental=True,
)
@classmethod
def execute(cls, model, strength) -> io.NodeOutput:
def execute(cls, model, strength, pre_cfg=False) -> io.NodeOutput:
m = model.clone()
def cfg_norm(args):
cond_p = args['cond_denoised']
pred_text_ = args["denoised"]
if pre_cfg:
def cfg_norm_pre(args):
cond = args["cond"]
uncond = args["uncond"]
cond_scale = args["cond_scale"]
comb = uncond + cond_scale * (cond - uncond)
cond_norm = torch.linalg.vector_norm(cond, dim=1, keepdim=True)
comb_norm = torch.linalg.vector_norm(comb, dim=1, keepdim=True)
rescale = torch.where(
comb_norm > 0,
cond_norm / comb_norm.clamp_min(1e-12),
torch.ones_like(comb_norm),
)
rescaled = comb * rescale
# strength blends back toward standard linear CFG (1.0 = full rescale).
if strength != 1.0:
rescaled = strength * rescaled + (1.0 - strength) * comb
return rescaled
m.set_model_sampler_cfg_function(cfg_norm_pre)
else:
def cfg_norm(args):
cond_p = args['cond_denoised']
pred_text_ = args["denoised"]
norm_full_cond = torch.norm(cond_p, dim=1, keepdim=True)
norm_pred_text = torch.norm(pred_text_, dim=1, keepdim=True)
scale = (norm_full_cond / (norm_pred_text + 1e-8)).clamp(min=0.0, max=1.0)
return pred_text_ * scale * strength
norm_full_cond = torch.norm(cond_p, dim=1, keepdim=True)
norm_pred_text = torch.norm(pred_text_, dim=1, keepdim=True)
scale = (norm_full_cond / (norm_pred_text + 1e-8)).clamp(min=0.0, max=1.0)
return pred_text_ * scale * strength
m.set_model_sampler_post_cfg_function(cfg_norm)
m.set_model_sampler_post_cfg_function(cfg_norm)
return io.NodeOutput(m)

View File

@ -1,4 +1,3 @@
from __future__ import annotations
from comfy_api.latest import ComfyExtension, io
import comfy.context_windows
import nodes

View File

@ -1,5 +1,3 @@
from __future__ import annotations
import numpy as np
from comfy_api.latest import ComfyExtension, io

View File

@ -1,5 +1,3 @@
from __future__ import annotations
import nodes
import folder_paths

View File

@ -1,4 +1,3 @@
from __future__ import annotations
from typing import TypedDict
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io

View File

@ -226,10 +226,20 @@ def get_noise_mask(latent):
noise_mask = noise_mask.clone()
return noise_mask
def get_keyframe_idxs(cond):
def get_keyframe_idxs(cond, latent_shape=None):
keyframe_idxs = conditioning_get_any_value(cond, "keyframe_idxs", None)
if keyframe_idxs is None:
return None, 0
# Get number of keyframes from latent_shape or guide_attention_entries if available
if latent_shape is not None and len(latent_shape) == 5:
tokens_per_frame = latent_shape[-2] * latent_shape[-1]
num_keyframes = keyframe_idxs.shape[2] // tokens_per_frame
return keyframe_idxs, num_keyframes
entries = conditioning_get_any_value(cond, "guide_attention_entries", None)
if entries:
num_keyframes = sum(e["latent_shape"][0] for e in entries)
return keyframe_idxs, num_keyframes
# fallback, may under-count if keyframes share t-start
# keyframe_idxs contains start/end positions (last dimension), checking for unqiue values only for start
num_keyframes = torch.unique(keyframe_idxs[:, 0, :, 0]).shape[0]
return keyframe_idxs, num_keyframes
@ -322,9 +332,9 @@ class LTXVAddGuide(io.ComfyNode):
return factor
@classmethod
def get_latent_index(cls, cond, latent_length, guide_length, frame_idx, scale_factors):
def get_latent_index(cls, cond, latent_length, guide_length, frame_idx, scale_factors, latent_shape=None):
time_scale_factor, _, _ = scale_factors
_, num_keyframes = get_keyframe_idxs(cond)
_, num_keyframes = get_keyframe_idxs(cond, latent_shape)
latent_count = latent_length - num_keyframes
frame_idx = frame_idx if frame_idx >= 0 else max((latent_count - 1) * time_scale_factor + 1 + frame_idx, 0)
if guide_length > 1 and frame_idx != 0:
@ -436,7 +446,7 @@ class LTXVAddGuide(io.ComfyNode):
num_frames_to_keep = ((image.shape[0] - 1) // time_scale_factor) * time_scale_factor + 1
resolved_frame_idx = frame_idx
if frame_idx < 0:
_, num_keyframes = get_keyframe_idxs(positive)
_, num_keyframes = get_keyframe_idxs(positive, latent_image.shape)
resolved_frame_idx = max((latent_length - num_keyframes - 1) * time_scale_factor + 1 + frame_idx, 0)
causal_fix = resolved_frame_idx == 0 or num_frames_to_keep == 1
@ -454,7 +464,7 @@ class LTXVAddGuide(io.ComfyNode):
if latent_downscale_factor > 1:
t, guide_mask = cls.dilate_latent(t, latent_downscale_factor)
frame_idx, latent_idx = cls.get_latent_index(positive, latent_length, len(image), frame_idx, scale_factors)
frame_idx, latent_idx = cls.get_latent_index(positive, latent_length, len(image), frame_idx, scale_factors, latent_shape=latent_image.shape)
assert latent_idx + t.shape[2] <= latent_length, "Conditioning frames exceed the length of the latent sequence."
positive, negative, latent_image, noise_mask = cls.append_keyframe(
@ -506,7 +516,7 @@ class LTXVCropGuides(io.ComfyNode):
latent_image = latent["samples"].clone()
noise_mask = get_noise_mask(latent)
_, num_keyframes = get_keyframe_idxs(positive)
_, num_keyframes = get_keyframe_idxs(positive, latent_image.shape)
if num_keyframes == 0:
return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask},)

View File

@ -4,7 +4,6 @@ Provides a ComfyMathExpression node that evaluates math expressions
against dynamically-grown numeric inputs.
"""
from __future__ import annotations
import math
import string

View File

@ -10,7 +10,6 @@ Custom IO types:
MediaPipeFaceLandmarker also emits the core BOUNDING_BOX type pair with DrawBBoxes.
"""
from __future__ import annotations
import numpy as np
import torch

View File

@ -1,6 +1,5 @@
"""ComfyUI nodes for the native MoGe (Monocular Geometry Estimation) integration."""
from __future__ import annotations
import torch

View File

@ -0,0 +1,408 @@
from __future__ import annotations
import copy
import logging
from inspect import cleandoc
from typing import TYPE_CHECKING
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
if TYPE_CHECKING:
from comfy.model_patcher import ModelPatcher
from comfy.sd import CLIP, VAE
import torch
import comfy.model_management
import comfy.multigpu
class MultiGPUCFGSplitNode(io.ComfyNode):
"""
Prepares model to have sampling accelerated via splitting work units.
Should be placed after nodes that modify the model object itself, such as compile or attention-switch nodes.
Other than those exceptions, this node can be placed in any order.
"""
@classmethod
def define_schema(cls):
return io.Schema(
node_id="MultiGPU_WorkUnits",
display_name="MultiGPU CFG Split",
category="advanced/multigpu",
description=cleandoc(cls.__doc__),
inputs=[
io.Model.Input("model"),
io.Int.Input("max_gpus", default=2, min=1, step=1),
],
outputs=[
io.Model.Output(),
],
)
@classmethod
def execute(cls, model: ModelPatcher, max_gpus: int) -> io.NodeOutput:
model = comfy.multigpu.create_multigpu_deepclones(model, max_gpus, reuse_loaded=True)
return io.NodeOutput(model)
def _force_supported_compute_dtype(patcher: ModelPatcher, device: torch.device):
"""Cast compute dtype to one the device supports; no-op if already supported."""
weight_dtype = patcher.model_dtype()
cast_dtype = comfy.model_management.unet_manual_cast(weight_dtype, device)
if cast_dtype is None:
return
logging.info(f"Select Model Device: using {cast_dtype} compute dtype on {device} (model weight dtype was {weight_dtype}).")
patcher.set_model_compute_dtype(cast_dtype)
def _remember_base_devices(patcher: ModelPatcher):
"""Stash the original load/offload device on the underlying model.
Stored on patcher.model (which is shared with the input patcher), so
later "default" selections can recover the loader's original routing.
Only the first Select on a given chain writes these attrs; subsequent
deepclones inherit them onto their freshly-loaded model below.
"""
if not hasattr(patcher.model, "_select_base_load_device"):
patcher.model._select_base_load_device = patcher.load_device
patcher.model._select_base_offload_device = patcher.offload_device
def _propagate_base_devices(src_model, dst_model):
"""Carry the loader-original device attrs onto the freshly-deepcloned model."""
if hasattr(src_model, "_select_base_load_device") and not hasattr(dst_model, "_select_base_load_device"):
dst_model._select_base_load_device = src_model._select_base_load_device
dst_model._select_base_offload_device = src_model._select_base_offload_device
def _retarget_patcher(patcher: ModelPatcher, target_load_device, target_offload_device):
"""Return a patcher whose actual model weights live on *target_load_device*.
If *patcher* is already on *target_load_device* we just retarget the
(already-cloned) patcher's metadata in place. Otherwise we call
:meth:`ModelPatcher.deepclone_multigpu` to spawn a fresh model from
the loader's ``cached_patcher_init`` factory -- the only safe way to
move weights that may already be partially loaded onto another device.
NOTE: reusing the input patcher's model when the requested device
matches its current load_device is a deliberate fast path. Anything
that has already mutated the original model (e.g. a prior KSampler
invocation on the same model) will be observed here. This is by
design and documented on the SelectXDeviceNode docstrings -- placing
Select X Device after a node that consumes the same model is not
recommended.
"""
if patcher.load_device == target_load_device:
# Fast path: weights already on the desired device, just update offload.
patcher.offload_device = target_offload_device
return patcher
src_model = patcher.model
patcher = patcher.deepclone_multigpu(new_load_device=target_load_device)
patcher.offload_device = target_offload_device
_propagate_base_devices(src_model, patcher.model)
if hasattr(patcher, "register_load_device"):
patcher.register_load_device(patcher.load_device)
return patcher
def _apply_patcher_device(patcher: ModelPatcher, resolved, base_offload_override=None):
"""Resolve the requested device and produce a patcher routed there.
For "default" we restore the loader's original load/offload pair.
For CPU we pin both load and offload to CPU (and, on a dynamic
patcher, downgrade to a plain ModelPatcher so the dynamic-only
code paths are bypassed).
For an explicit GPU we keep the loader's original offload but
target the requested load device; if that differs from the current
load device the patcher is deepcloned onto the new device.
"""
_remember_base_devices(patcher)
base_load = patcher.model._select_base_load_device
base_offload = base_offload_override if base_offload_override is not None else patcher.model._select_base_offload_device
if resolved is None:
# "default" -> route back to the loader's original devices.
return _retarget_patcher(patcher, base_load, base_offload)
if resolved.type == "cpu":
if patcher.is_dynamic():
# clone(disable_dynamic=True) requires cached_patcher_init; let the
# exception surface to the caller (Select*DeviceNode.execute), which
# will translate it into a passthrough+log so unsupported loaders
# don't hard-fail the workflow.
patcher = patcher.clone(disable_dynamic=True)
patcher.load_device = resolved
patcher.offload_device = resolved
return patcher
return _retarget_patcher(patcher, resolved, base_offload)
def _prune_multigpu_collision(model: ModelPatcher, primary_device):
"""Drop any multigpu clone whose load_device matches *primary_device*.
Without pruning, MultiGPU CFG Split would have stacked a clone on
the same device the primary now occupies (i.e. the workflow places
MultiGPU CFG Split before Select Model Device). Keeps the clone set
consistent with the new primary placement.
"""
multigpu_models = model.get_additional_models_with_key("multigpu")
if not multigpu_models:
return
filtered = [m for m in multigpu_models if m.load_device != primary_device]
if len(filtered) != len(multigpu_models):
logging.info(f"Select Model Device: pruning MultiGPU clone on {primary_device} that now collides with the primary model.")
model.set_additional_models("multigpu", filtered)
if hasattr(model, "match_multigpu_clones"):
model.match_multigpu_clones()
class SelectModelDeviceNode(io.ComfyNode):
"""
Place the diffusion model on a specific device (default / cpu / gpu:N).
- "default" restores the device assigned by the loader (even after a
prior Select Model Device call).
- "cpu" pins both the load and offload device to CPU.
- "gpu:N" pins the load device to the Nth available GPU; the offload
device is restored to the loader's original choice.
When the requested device differs from the device the input model is
already on, a fresh model is spawned via the loader's reload factory
(cached_patcher_init) so the new patcher owns independent weights on
the new device. Loaders that don't support multigpu (no factory) will
cause the node to pass through unchanged with a warning.
If the workflow already has MultiGPU CFG Split applied and the chosen
GPU collides with one of the existing multigpu clones, that clone is
dropped so two patchers don't end up bound to the same device.
When the selected device does not exist on the current machine
(e.g. a workflow built on a 2-GPU box opened on a 1-GPU box),
the node passes the model through unchanged and logs a message
instead of failing.
NOTE: Placing Select Model Device *after* a node that has already
consumed the same model (e.g. a KSampler that ran on this model on
the original device) is not recommended -- any state the prior
consumer mutated on the original model will be observed when the
selected device matches the original (fast path). Place Select Model
Device before any consumer of the model.
"""
@classmethod
def define_schema(cls):
return io.Schema(
node_id="SelectModelDevice",
display_name="Select Model Device",
category="advanced/multigpu",
description=cleandoc(cls.__doc__),
inputs=[
io.Model.Input("model"),
io.Combo.Input("device", options=comfy.model_management.get_gpu_device_options()),
],
outputs=[
io.Model.Output(),
],
)
@classmethod
def validate_inputs(cls, device="default"):
# Allow unknown gpu:N values so portable workflows do not error
# at validation time; runtime fallback will handle them.
return True
@classmethod
def execute(cls, model: ModelPatcher, device: str = "default") -> io.NodeOutput:
model = model.clone()
resolved = comfy.model_management.resolve_gpu_device_option(device)
if resolved is None and device not in (None, "default"):
logging.info(f"Select Model Device: requested device '{device}' not available, passing through unchanged.")
return io.NodeOutput(model)
try:
model = _apply_patcher_device(model, resolved)
except RuntimeError as e:
logging.warning(f"Select Model Device: cannot retarget model, passing through unchanged. ({e})")
return io.NodeOutput(model)
if resolved is not None:
_force_supported_compute_dtype(model, resolved)
_prune_multigpu_collision(model, model.load_device)
return io.NodeOutput(model)
class SelectCLIPDeviceNode(io.ComfyNode):
"""
Place the CLIP text encoder on a specific device (default / cpu / gpu:N).
- "default" restores the device assigned by the loader.
- "cpu" pins both the load and offload device to CPU.
- "gpu:N" pins the load device to the Nth available GPU.
When the selected device does not exist on the current machine
(e.g. a workflow built on a 2-GPU box opened on a 1-GPU box),
the node passes the CLIP through unchanged and logs a message
instead of failing.
"""
@classmethod
def define_schema(cls):
return io.Schema(
node_id="SelectCLIPDevice",
display_name="Select CLIP Device",
category="advanced/multigpu",
description=cleandoc(cls.__doc__),
inputs=[
io.Clip.Input("clip"),
io.Combo.Input("device", options=comfy.model_management.get_gpu_device_options()),
],
outputs=[
io.Clip.Output(),
],
)
@classmethod
def validate_inputs(cls, device="default"):
return True
@classmethod
def execute(cls, clip: CLIP, device: str = "default") -> io.NodeOutput:
clip = clip.clone()
resolved = comfy.model_management.resolve_gpu_device_option(device)
if resolved is None and device not in (None, "default"):
logging.info(f"Select CLIP Device: requested device '{device}' not available, passing through unchanged.")
return io.NodeOutput(clip)
try:
clip.patcher = _apply_patcher_device(clip.patcher, resolved)
except RuntimeError as e:
logging.warning(f"Select CLIP Device: cannot retarget CLIP, passing through unchanged. ({e})")
return io.NodeOutput(clip)
class SelectVAEDeviceNode(io.ComfyNode):
"""
Place the VAE on a specific device (default / gpu:N).
- "default" restores the device assigned by the loader.
- "gpu:N" pins the load device to the Nth available GPU; the offload
device is set to the standard VAE offload device.
CPU is intentionally not exposed in the UI for the VAE; if a workflow
supplies "cpu" anyway (e.g. opened from another machine), the request
is dropped with a log message and the VAE is passed through unchanged.
When the selected device does not exist on the current machine
(e.g. a workflow built on a 2-GPU box opened on a 1-GPU box),
the node passes the VAE through unchanged and logs a message
instead of failing.
"""
@classmethod
def define_schema(cls):
return io.Schema(
node_id="SelectVAEDevice",
display_name="Select VAE Device",
category="advanced/multigpu",
description=cleandoc(cls.__doc__),
inputs=[
io.Vae.Input("vae"),
io.Combo.Input("device", options=comfy.model_management.get_gpu_device_options_no_cpu()),
],
outputs=[
io.Vae.Output(),
],
)
@classmethod
def validate_inputs(cls, device="default"):
return True
@classmethod
def execute(cls, vae: VAE, device: str = "default") -> io.NodeOutput:
# VAE has no .clone(); shallow-copy the wrapper and clone the patcher
# so we can retarget load/offload device without affecting the input VAE.
vae = copy.copy(vae)
vae.patcher = vae.patcher.clone()
resolved = comfy.model_management.resolve_gpu_device_option(device)
if resolved is None and device not in (None, "default"):
logging.info(f"Select VAE Device: requested device '{device}' not available, passing through unchanged.")
return io.NodeOutput(vae)
if resolved is not None and resolved.type == "cpu":
logging.info("Select VAE Device: CPU is not a supported choice, passing through unchanged.")
return io.NodeOutput(vae)
if not hasattr(vae, "_select_base_device"):
vae._select_base_device = vae.device
try:
vae.patcher = _apply_patcher_device(
vae.patcher, resolved,
base_offload_override=comfy.model_management.vae_offload_device(),
)
except RuntimeError as e:
logging.warning(f"Select VAE Device: cannot retarget VAE, passing through unchanged. ({e})")
return io.NodeOutput(vae)
# Keep VAE wrapper in sync with whatever model the patcher now owns;
# deepclone_multigpu may have produced a fresh first_stage_model.
vae.first_stage_model = vae.patcher.model
vae.device = vae._select_base_device if resolved is None else resolved
return io.NodeOutput(vae)
class MultiGPUOptionsNode(io.ComfyNode):
"""
Select the relative speed of GPUs in the special case they have significantly different performance from one another.
NOTE (not registered yet, see MultiGPUExtension.get_node_list below):
The output GPUOptionsGroup is plumbed through create_multigpu_deepclones() and stored on
model.model_options['multigpu_options'] via GPUOptionsGroup.register(), but the cond
scheduler in comfy/samplers.py (calc_cond_batch_outer_multigpu) does NOT yet consult
relative_speed when distributing conds across devices; it uses a uniform conds_per_device
round-robin via next_available_device(). Before re-enabling this node, wire its
relative_speed into the scheduler (e.g. via comfy.multigpu.load_balance_devices(),
which already implements the proportional split) so the input actually affects work
distribution.
"""
@classmethod
def define_schema(cls):
return io.Schema(
node_id="MultiGPU_Options",
display_name="MultiGPU Options",
category="advanced/multigpu",
description=cleandoc(cls.__doc__),
inputs=[
io.Int.Input("device_index", default=0, min=0, max=64),
io.Float.Input("relative_speed", default=1.0, min=0.0, step=0.01),
io.Custom("GPU_OPTIONS").Input("gpu_options", optional=True),
],
outputs=[
io.Custom("GPU_OPTIONS").Output(),
],
)
@classmethod
def execute(cls, device_index: int, relative_speed: float, gpu_options: comfy.multigpu.GPUOptionsGroup = None) -> io.NodeOutput:
if not gpu_options:
gpu_options = comfy.multigpu.GPUOptionsGroup()
else:
gpu_options = gpu_options.clone()
opt = comfy.multigpu.GPUOptions(device_index=device_index, relative_speed=relative_speed)
gpu_options.add(opt)
return io.NodeOutput(gpu_options)
class MultiGPUExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
MultiGPUCFGSplitNode,
SelectModelDeviceNode,
SelectCLIPDeviceNode,
SelectVAEDeviceNode,
# MultiGPUOptionsNode,
]
async def comfy_entrypoint() -> MultiGPUExtension:
return MultiGPUExtension()

View File

@ -4,7 +4,6 @@ Provides a single node that converts INT, FLOAT, STRING, and BOOL
inputs into FLOAT and INT outputs.
"""
from __future__ import annotations
import math

View File

@ -1,5 +1,3 @@
from __future__ import annotations
import hashlib
import os

Some files were not shown because too many files have changed in this diff Show More