mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-27 17:37:39 +08:00
Merge branch 'master' into matt/asset-image-dimensions-metadata
This commit is contained in:
commit
81c4bc5fe9
@ -160,10 +160,12 @@ def _build_asset_response(result: schemas.AssetDetailResult | schemas.UploadResu
|
||||
preview_url = None
|
||||
else:
|
||||
preview_url = _build_preview_url_from_view(result.tags, result.ref.user_metadata)
|
||||
asset_content_hash = result.asset.hash if result.asset else None
|
||||
return schemas_out.Asset(
|
||||
id=result.ref.id,
|
||||
name=result.ref.name,
|
||||
asset_hash=result.asset.hash if result.asset else None,
|
||||
hash=asset_content_hash,
|
||||
asset_hash=asset_content_hash,
|
||||
size=int(result.asset.size_bytes) if result.asset else None,
|
||||
mime_type=result.asset.mime_type if result.asset else None,
|
||||
tags=result.tags,
|
||||
|
||||
@ -10,6 +10,7 @@ class Asset(BaseModel):
|
||||
|
||||
id: str
|
||||
name: str
|
||||
hash: str | None = None
|
||||
asset_hash: str | None = None
|
||||
size: int | None = None
|
||||
mime_type: str | None = None
|
||||
|
||||
@ -4,7 +4,6 @@ Tier 1: Filesystem metadata (zero parsing)
|
||||
Tier 2: Safetensors header metadata (fast JSON read only)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
|
||||
@ -1,5 +1,3 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import folder_paths
|
||||
import glob
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
from __future__ import annotations
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
|
||||
@ -1,5 +1,3 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import base64
|
||||
import json
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
from __future__ import annotations
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
|
||||
2091
blueprints/Audio Generation (Stable Audio 3 Medium Base).json
Normal file
2091
blueprints/Audio Generation (Stable Audio 3 Medium Base).json
Normal file
File diff suppressed because one or more lines are too long
2091
blueprints/Audio Generation (Stable Audio 3 Medium).json
Normal file
2091
blueprints/Audio Generation (Stable Audio 3 Medium).json
Normal file
File diff suppressed because one or more lines are too long
@ -1553,7 +1553,7 @@
|
||||
"VHS_MetadataImage": true,
|
||||
"VHS_KeepIntermediate": true
|
||||
},
|
||||
"category": "Image generation and editing/Canny to image",
|
||||
"category": "Image generation and editing/Conditioned",
|
||||
"description": "Generates an image from a Canny edge map using Z-Image-Turbo, with text conditioning."
|
||||
}
|
||||
]
|
||||
|
||||
@ -3600,7 +3600,7 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Video generation and editing/Canny to video",
|
||||
"category": "Video generation and editing/Conditioned",
|
||||
"description": "Generates video from Canny edge maps using LTX-2, with optional synchronized audio."
|
||||
}
|
||||
]
|
||||
|
||||
@ -1401,7 +1401,7 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Image generation and editing/ControlNet",
|
||||
"category": "Image generation and editing/Conditioned",
|
||||
"description": "Generates images from a text prompt and ControlNet conditioning (e.g. depth, canny) using Z-Image-Turbo."
|
||||
}
|
||||
]
|
||||
|
||||
@ -1579,7 +1579,7 @@
|
||||
"VHS_MetadataImage": true,
|
||||
"VHS_KeepIntermediate": true
|
||||
},
|
||||
"category": "Image generation and editing/Depth to image",
|
||||
"category": "Image generation and editing/Conditioned",
|
||||
"description": "Generates an image from a depth map using Z-Image-Turbo with text conditioning."
|
||||
},
|
||||
{
|
||||
|
||||
@ -4233,7 +4233,7 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Video generation and editing/Depth to video",
|
||||
"category": "Video generation and editing/Conditioned",
|
||||
"description": "Generates depth-controlled video with LTX-2: motion and structure follow a depth-reference video alongside text prompting, optional first-frame image conditioning, with optional synchronized audio."
|
||||
},
|
||||
{
|
||||
|
||||
@ -3350,7 +3350,7 @@
|
||||
}
|
||||
],
|
||||
"extra": {},
|
||||
"category": "Video generation and editing/First-Last-Frame to Video",
|
||||
"category": "Video generation and editing/Conditioned",
|
||||
"description": "Generates a video interpolating between first and last keyframes using LTX-2.3."
|
||||
}
|
||||
]
|
||||
|
||||
@ -3350,7 +3350,7 @@
|
||||
}
|
||||
],
|
||||
"extra": {},
|
||||
"category": "Video generation and editing/First-Last-Frame to Video",
|
||||
"category": "Video generation and editing/FLF2V",
|
||||
"description": "Generates a video that interpolates between the first and last keyframes using LTX-2.3, including optional audio."
|
||||
}
|
||||
]
|
||||
|
||||
1266
blueprints/Geometry Estimation (MoGe).json
Normal file
1266
blueprints/Geometry Estimation (MoGe).json
Normal file
File diff suppressed because it is too large
Load Diff
@ -310,9 +310,9 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Text generation/Image Captioning",
|
||||
"category": "Image Tools",
|
||||
"description": "Generates descriptive captions for images using Google's Gemini multimodal LLM."
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1,19 +1,18 @@
|
||||
{
|
||||
"id": "6af0a6c1-0161-4528-8685-65776e838d44",
|
||||
"revision": 0,
|
||||
"last_node_id": 75,
|
||||
"last_link_id": 245,
|
||||
"last_node_id": 76,
|
||||
"last_link_id": 0,
|
||||
"nodes": [
|
||||
{
|
||||
"id": 75,
|
||||
"type": "488652fd-6edf-4d06-8f9f-4d84d3a34eaf",
|
||||
"id": 76,
|
||||
"type": "96338968-1242-4f02-b6a1-d496af4bcffe",
|
||||
"pos": [
|
||||
600,
|
||||
830
|
||||
670,
|
||||
1280
|
||||
],
|
||||
"size": [
|
||||
400,
|
||||
110
|
||||
201.3125
|
||||
],
|
||||
"flags": {},
|
||||
"order": 0,
|
||||
@ -59,47 +58,44 @@
|
||||
"links": []
|
||||
}
|
||||
],
|
||||
"title": "Image Depth Estimation (Lotus Depth)",
|
||||
"properties": {
|
||||
"proxyWidgets": [
|
||||
[
|
||||
"-1",
|
||||
"28",
|
||||
"sigma"
|
||||
],
|
||||
[
|
||||
"-1",
|
||||
"10",
|
||||
"unet_name"
|
||||
],
|
||||
[
|
||||
"-1",
|
||||
"14",
|
||||
"vae_name"
|
||||
]
|
||||
],
|
||||
"cnr_id": "comfy-core",
|
||||
"ver": "0.14.1"
|
||||
},
|
||||
"widgets_values": [
|
||||
999.0000000000002,
|
||||
"lotus-depth-d-v1-1.safetensors",
|
||||
"vae-ft-mse-840000-ema-pruned.safetensors"
|
||||
]
|
||||
"widgets_values": []
|
||||
}
|
||||
],
|
||||
"links": [],
|
||||
"groups": [],
|
||||
"version": 0.4,
|
||||
"definitions": {
|
||||
"subgraphs": [
|
||||
{
|
||||
"id": "488652fd-6edf-4d06-8f9f-4d84d3a34eaf",
|
||||
"id": "96338968-1242-4f02-b6a1-d496af4bcffe",
|
||||
"version": 1,
|
||||
"state": {
|
||||
"lastGroupId": 1,
|
||||
"lastNodeId": 75,
|
||||
"lastNodeId": 76,
|
||||
"lastLinkId": 245,
|
||||
"lastRerouteId": 0
|
||||
},
|
||||
"revision": 0,
|
||||
"config": {},
|
||||
"name": "Image to Depth Map (Lotus)",
|
||||
"name": "Image Depth Estimation (Lotus Depth)",
|
||||
"inputNode": {
|
||||
"id": -10,
|
||||
"bounding": [
|
||||
@ -191,12 +187,12 @@
|
||||
"id": 10,
|
||||
"type": "UNETLoader",
|
||||
"pos": [
|
||||
108.05555555555557,
|
||||
-253.05555555555557
|
||||
110,
|
||||
-250
|
||||
],
|
||||
"size": [
|
||||
254.93706597222226,
|
||||
82
|
||||
260,
|
||||
90
|
||||
],
|
||||
"flags": {},
|
||||
"order": 4,
|
||||
@ -234,9 +230,9 @@
|
||||
}
|
||||
],
|
||||
"properties": {
|
||||
"Node name for S&R": "UNETLoader",
|
||||
"cnr_id": "comfy-core",
|
||||
"ver": "0.3.34",
|
||||
"Node name for S&R": "UNETLoader",
|
||||
"models": [
|
||||
{
|
||||
"name": "lotus-depth-d-v1-1.safetensors",
|
||||
@ -255,12 +251,12 @@
|
||||
"id": 18,
|
||||
"type": "DisableNoise",
|
||||
"pos": [
|
||||
607.0641494069639,
|
||||
-268.33337840371513
|
||||
610,
|
||||
-270
|
||||
],
|
||||
"size": [
|
||||
175,
|
||||
33.333333333333336
|
||||
180,
|
||||
40
|
||||
],
|
||||
"flags": {},
|
||||
"order": 0,
|
||||
@ -278,26 +274,25 @@
|
||||
}
|
||||
],
|
||||
"properties": {
|
||||
"Node name for S&R": "DisableNoise",
|
||||
"cnr_id": "comfy-core",
|
||||
"ver": "0.3.34",
|
||||
"Node name for S&R": "DisableNoise",
|
||||
"widget_ue_connectable": {}
|
||||
},
|
||||
"widgets_values": []
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": 23,
|
||||
"id": 74,
|
||||
"type": "VAEEncode",
|
||||
"pos": [
|
||||
620,
|
||||
160
|
||||
],
|
||||
"size": [
|
||||
175,
|
||||
180,
|
||||
50
|
||||
],
|
||||
"flags": {},
|
||||
"order": 10,
|
||||
"order": 11,
|
||||
"mode": 0,
|
||||
"inputs": [
|
||||
{
|
||||
@ -325,12 +320,11 @@
|
||||
}
|
||||
],
|
||||
"properties": {
|
||||
"Node name for S&R": "VAEEncode",
|
||||
"cnr_id": "comfy-core",
|
||||
"ver": "0.3.34",
|
||||
"Node name for S&R": "VAEEncode",
|
||||
"widget_ue_connectable": {}
|
||||
},
|
||||
"widgets_values": []
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": 21,
|
||||
@ -341,7 +335,7 @@
|
||||
],
|
||||
"size": [
|
||||
210,
|
||||
58
|
||||
60
|
||||
],
|
||||
"flags": {},
|
||||
"order": 1,
|
||||
@ -369,9 +363,9 @@
|
||||
}
|
||||
],
|
||||
"properties": {
|
||||
"Node name for S&R": "KSamplerSelect",
|
||||
"cnr_id": "comfy-core",
|
||||
"ver": "0.3.34",
|
||||
"Node name for S&R": "KSamplerSelect",
|
||||
"widget_ue_connectable": {}
|
||||
},
|
||||
"widgets_values": [
|
||||
@ -386,7 +380,7 @@
|
||||
-170
|
||||
],
|
||||
"size": [
|
||||
175,
|
||||
180,
|
||||
50
|
||||
],
|
||||
"flags": {},
|
||||
@ -418,12 +412,11 @@
|
||||
}
|
||||
],
|
||||
"properties": {
|
||||
"Node name for S&R": "BasicGuider",
|
||||
"cnr_id": "comfy-core",
|
||||
"ver": "0.3.34",
|
||||
"Node name for S&R": "BasicGuider",
|
||||
"widget_ue_connectable": {}
|
||||
},
|
||||
"widgets_values": []
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": 16,
|
||||
@ -433,8 +426,8 @@
|
||||
-130
|
||||
],
|
||||
"size": [
|
||||
295.99609375,
|
||||
271.65798611111114
|
||||
300,
|
||||
280
|
||||
],
|
||||
"flags": {},
|
||||
"order": 6,
|
||||
@ -490,12 +483,11 @@
|
||||
}
|
||||
],
|
||||
"properties": {
|
||||
"Node name for S&R": "SamplerCustomAdvanced",
|
||||
"cnr_id": "comfy-core",
|
||||
"ver": "0.3.34",
|
||||
"Node name for S&R": "SamplerCustomAdvanced",
|
||||
"widget_ue_connectable": {}
|
||||
},
|
||||
"widgets_values": []
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": 28,
|
||||
@ -506,10 +498,10 @@
|
||||
],
|
||||
"size": [
|
||||
210,
|
||||
58
|
||||
60
|
||||
],
|
||||
"flags": {},
|
||||
"order": 11,
|
||||
"order": 10,
|
||||
"mode": 0,
|
||||
"inputs": [
|
||||
{
|
||||
@ -540,9 +532,9 @@
|
||||
}
|
||||
],
|
||||
"properties": {
|
||||
"Node name for S&R": "SetFirstSigma",
|
||||
"cnr_id": "comfy-core",
|
||||
"ver": "0.3.34",
|
||||
"Node name for S&R": "SetFirstSigma",
|
||||
"widget_ue_connectable": {}
|
||||
},
|
||||
"widgets_values": [
|
||||
@ -557,7 +549,7 @@
|
||||
-120
|
||||
],
|
||||
"size": [
|
||||
175,
|
||||
180,
|
||||
50
|
||||
],
|
||||
"flags": {},
|
||||
@ -589,12 +581,11 @@
|
||||
}
|
||||
],
|
||||
"properties": {
|
||||
"Node name for S&R": "VAEDecode",
|
||||
"cnr_id": "comfy-core",
|
||||
"ver": "0.3.34",
|
||||
"Node name for S&R": "VAEDecode",
|
||||
"widget_ue_connectable": {}
|
||||
},
|
||||
"widgets_values": []
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": 22,
|
||||
@ -604,8 +595,8 @@
|
||||
-220
|
||||
],
|
||||
"size": [
|
||||
175,
|
||||
33.333333333333336
|
||||
180,
|
||||
40
|
||||
],
|
||||
"flags": {},
|
||||
"order": 9,
|
||||
@ -630,12 +621,11 @@
|
||||
}
|
||||
],
|
||||
"properties": {
|
||||
"Node name for S&R": "ImageInvert",
|
||||
"cnr_id": "comfy-core",
|
||||
"ver": "0.3.34",
|
||||
"Node name for S&R": "ImageInvert",
|
||||
"widget_ue_connectable": {}
|
||||
},
|
||||
"widgets_values": []
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": 14,
|
||||
@ -645,8 +635,8 @@
|
||||
-90
|
||||
],
|
||||
"size": [
|
||||
254.93706597222226,
|
||||
58
|
||||
260,
|
||||
60
|
||||
],
|
||||
"flags": {},
|
||||
"order": 5,
|
||||
@ -675,9 +665,9 @@
|
||||
}
|
||||
],
|
||||
"properties": {
|
||||
"Node name for S&R": "VAELoader",
|
||||
"cnr_id": "comfy-core",
|
||||
"ver": "0.3.34",
|
||||
"Node name for S&R": "VAELoader",
|
||||
"models": [
|
||||
{
|
||||
"name": "vae-ft-mse-840000-ema-pruned.safetensors",
|
||||
@ -692,15 +682,15 @@
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": 68,
|
||||
"id": 75,
|
||||
"type": "LotusConditioning",
|
||||
"pos": [
|
||||
400,
|
||||
-150
|
||||
],
|
||||
"size": [
|
||||
175,
|
||||
33.333333333333336
|
||||
180,
|
||||
40
|
||||
],
|
||||
"flags": {},
|
||||
"order": 2,
|
||||
@ -718,12 +708,11 @@
|
||||
}
|
||||
],
|
||||
"properties": {
|
||||
"Node name for S&R": "LotusConditioning",
|
||||
"cnr_id": "comfy-core",
|
||||
"ver": "0.3.34",
|
||||
"Node name for S&R": "LotusConditioning",
|
||||
"widget_ue_connectable": {}
|
||||
},
|
||||
"widgets_values": []
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": 20,
|
||||
@ -734,7 +723,7 @@
|
||||
],
|
||||
"size": [
|
||||
210,
|
||||
106
|
||||
110
|
||||
],
|
||||
"flags": {},
|
||||
"order": 8,
|
||||
@ -786,9 +775,9 @@
|
||||
}
|
||||
],
|
||||
"properties": {
|
||||
"Node name for S&R": "BasicScheduler",
|
||||
"cnr_id": "comfy-core",
|
||||
"ver": "0.3.34",
|
||||
"Node name for S&R": "BasicScheduler",
|
||||
"widget_ue_connectable": {}
|
||||
},
|
||||
"widgets_values": [
|
||||
@ -850,7 +839,7 @@
|
||||
},
|
||||
{
|
||||
"id": 201,
|
||||
"origin_id": 23,
|
||||
"origin_id": 74,
|
||||
"origin_slot": 0,
|
||||
"target_id": 16,
|
||||
"target_slot": 4,
|
||||
@ -866,7 +855,7 @@
|
||||
},
|
||||
{
|
||||
"id": 238,
|
||||
"origin_id": 68,
|
||||
"origin_id": 75,
|
||||
"origin_slot": 0,
|
||||
"target_id": 19,
|
||||
"target_slot": 1,
|
||||
@ -892,7 +881,7 @@
|
||||
"id": 38,
|
||||
"origin_id": 14,
|
||||
"origin_slot": 0,
|
||||
"target_id": 23,
|
||||
"target_id": 74,
|
||||
"target_slot": 1,
|
||||
"type": "VAE"
|
||||
},
|
||||
@ -908,7 +897,7 @@
|
||||
"id": 37,
|
||||
"origin_id": -10,
|
||||
"origin_slot": 0,
|
||||
"target_id": 23,
|
||||
"target_id": 74,
|
||||
"target_slot": 0,
|
||||
"type": "IMAGE"
|
||||
},
|
||||
@ -948,12 +937,11 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Image generation and editing/Depth to image",
|
||||
"category": "Conditioning & Preprocessors/Depth",
|
||||
"description": "Estimates a monocular depth map from an input image using the Lotus depth estimation model."
|
||||
}
|
||||
]
|
||||
},
|
||||
"config": {},
|
||||
"extra": {
|
||||
"ds": {
|
||||
"scale": 1.3589709866044692,
|
||||
@ -961,8 +949,6 @@
|
||||
-138.53613935617864,
|
||||
-786.0629126022195
|
||||
]
|
||||
},
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"version": 0.4
|
||||
}
|
||||
}
|
||||
}
|
||||
1154
blueprints/Image Depth Estimation (MoGe).json
Normal file
1154
blueprints/Image Depth Estimation (MoGe).json
Normal file
File diff suppressed because it is too large
Load Diff
779
blueprints/Image Face Detection (Mediapipe).json
Normal file
779
blueprints/Image Face Detection (Mediapipe).json
Normal file
@ -0,0 +1,779 @@
|
||||
{
|
||||
"revision": 0,
|
||||
"last_node_id": 33,
|
||||
"last_link_id": 0,
|
||||
"nodes": [
|
||||
{
|
||||
"id": 33,
|
||||
"type": "6062babb-b649-4a71-be9e-20ebce567744",
|
||||
"pos": [
|
||||
-450,
|
||||
4240
|
||||
],
|
||||
"size": [
|
||||
420,
|
||||
400
|
||||
],
|
||||
"flags": {},
|
||||
"order": 0,
|
||||
"mode": 0,
|
||||
"inputs": [
|
||||
{
|
||||
"localized_name": "image",
|
||||
"name": "image",
|
||||
"type": "IMAGE",
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"name": "face_landmarker",
|
||||
"type": "FACE_LANDMARKER",
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"name": "detector_variant",
|
||||
"type": "COMBO",
|
||||
"widget": {
|
||||
"name": "detector_variant"
|
||||
},
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"name": "num_faces",
|
||||
"type": "INT",
|
||||
"widget": {
|
||||
"name": "num_faces"
|
||||
},
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"label": "custom_face_oval",
|
||||
"name": "regions.face_oval",
|
||||
"type": "BOOLEAN",
|
||||
"widget": {
|
||||
"name": "regions.face_oval"
|
||||
},
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"label": "custom_lips",
|
||||
"name": "regions.lips",
|
||||
"type": "BOOLEAN",
|
||||
"widget": {
|
||||
"name": "regions.lips"
|
||||
},
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"label": "custom_left_eye",
|
||||
"name": "regions.left_eye",
|
||||
"type": "BOOLEAN",
|
||||
"widget": {
|
||||
"name": "regions.left_eye"
|
||||
},
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"label": "custom_right_eye",
|
||||
"name": "regions.right_eye",
|
||||
"type": "BOOLEAN",
|
||||
"widget": {
|
||||
"name": "regions.right_eye"
|
||||
},
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"label": "custom_irises",
|
||||
"name": "regions.irises",
|
||||
"type": "BOOLEAN",
|
||||
"widget": {
|
||||
"name": "regions.irises"
|
||||
},
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"name": "model_name",
|
||||
"type": "COMBO",
|
||||
"widget": {
|
||||
"name": "model_name"
|
||||
},
|
||||
"link": null
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"localized_name": "face_landmarks",
|
||||
"name": "face_landmarks",
|
||||
"type": "FACE_LANDMARKS",
|
||||
"links": []
|
||||
},
|
||||
{
|
||||
"localized_name": "bboxes",
|
||||
"name": "bboxes",
|
||||
"type": "BOUNDING_BOX",
|
||||
"links": []
|
||||
},
|
||||
{
|
||||
"label": "mask",
|
||||
"name": "MASK_1",
|
||||
"type": "MASK",
|
||||
"links": []
|
||||
}
|
||||
],
|
||||
"title": "Image Face Detection (Mediapipe)",
|
||||
"properties": {
|
||||
"proxyWidgets": [
|
||||
[
|
||||
"11",
|
||||
"detector_variant"
|
||||
],
|
||||
[
|
||||
"11",
|
||||
"num_faces"
|
||||
],
|
||||
[
|
||||
"20",
|
||||
"regions.face_oval"
|
||||
],
|
||||
[
|
||||
"20",
|
||||
"regions.lips"
|
||||
],
|
||||
[
|
||||
"20",
|
||||
"regions.left_eye"
|
||||
],
|
||||
[
|
||||
"20",
|
||||
"regions.right_eye"
|
||||
],
|
||||
[
|
||||
"20",
|
||||
"regions.irises"
|
||||
],
|
||||
[
|
||||
"2",
|
||||
"model_name"
|
||||
]
|
||||
],
|
||||
"cnr_id": "comfy-core",
|
||||
"ver": "0.22.0",
|
||||
"enableTabs": false,
|
||||
"tabWidth": 65,
|
||||
"tabXOffset": 10,
|
||||
"hasSecondTab": false,
|
||||
"secondTabText": "Send Back",
|
||||
"secondTabOffset": 80,
|
||||
"secondTabWidth": 65
|
||||
},
|
||||
"widgets_values": []
|
||||
}
|
||||
],
|
||||
"links": [],
|
||||
"version": 0.4,
|
||||
"definitions": {
|
||||
"subgraphs": [
|
||||
{
|
||||
"id": "6062babb-b649-4a71-be9e-20ebce567744",
|
||||
"version": 1,
|
||||
"state": {
|
||||
"lastGroupId": 2,
|
||||
"lastNodeId": 158,
|
||||
"lastLinkId": 140,
|
||||
"lastRerouteId": 0
|
||||
},
|
||||
"revision": 0,
|
||||
"config": {},
|
||||
"name": "Image Face Detection (Mediapipe)",
|
||||
"description": "Detects facial landmarks from an image using MediaPipe, outputting landmark data, face bounding boxes, and an optional face-region mask.",
|
||||
"inputNode": {
|
||||
"id": -10,
|
||||
"bounding": [
|
||||
-710,
|
||||
4300,
|
||||
148.880859375,
|
||||
248
|
||||
]
|
||||
},
|
||||
"outputNode": {
|
||||
"id": -20,
|
||||
"bounding": [
|
||||
140,
|
||||
4480,
|
||||
137.677734375,
|
||||
108
|
||||
]
|
||||
},
|
||||
"inputs": [
|
||||
{
|
||||
"id": "705dc1ae-6dc9-4155-92df-52f816ad451e",
|
||||
"name": "image",
|
||||
"type": "IMAGE",
|
||||
"linkIds": [
|
||||
60
|
||||
],
|
||||
"localized_name": "image",
|
||||
"pos": [
|
||||
-585.119140625,
|
||||
4324
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": "d6277190-732c-4604-b7cd-d3a9588bf761",
|
||||
"name": "face_landmarker",
|
||||
"type": "FACE_LANDMARKER",
|
||||
"linkIds": [
|
||||
74
|
||||
],
|
||||
"pos": [
|
||||
-585.119140625,
|
||||
4344
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": "ac473a08-6a86-42a7-b460-e70c6c5e1e2b",
|
||||
"name": "detector_variant",
|
||||
"type": "COMBO",
|
||||
"linkIds": [
|
||||
75
|
||||
],
|
||||
"pos": [
|
||||
-585.119140625,
|
||||
4364
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": "1bec2252-ca2d-496e-8a33-33a61d21f897",
|
||||
"name": "num_faces",
|
||||
"type": "INT",
|
||||
"linkIds": [
|
||||
76
|
||||
],
|
||||
"pos": [
|
||||
-585.119140625,
|
||||
4384
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": "17994fa2-0ea0-4c9b-a70a-19789c459c80",
|
||||
"name": "regions.face_oval",
|
||||
"type": "BOOLEAN",
|
||||
"linkIds": [
|
||||
77
|
||||
],
|
||||
"label": "custom_face_oval",
|
||||
"pos": [
|
||||
-585.119140625,
|
||||
4404
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": "1c6c5893-2aee-4c37-b702-15ef2e20d863",
|
||||
"name": "regions.lips",
|
||||
"type": "BOOLEAN",
|
||||
"linkIds": [
|
||||
78
|
||||
],
|
||||
"label": "custom_lips",
|
||||
"pos": [
|
||||
-585.119140625,
|
||||
4424
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": "f353fcea-4b6f-42a1-8fdd-32b3aa1e1f09",
|
||||
"name": "regions.left_eye",
|
||||
"type": "BOOLEAN",
|
||||
"linkIds": [
|
||||
79
|
||||
],
|
||||
"label": "custom_left_eye",
|
||||
"pos": [
|
||||
-585.119140625,
|
||||
4444
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": "1387e121-c1fb-4522-8f0d-43459e11dd86",
|
||||
"name": "regions.right_eye",
|
||||
"type": "BOOLEAN",
|
||||
"linkIds": [
|
||||
80
|
||||
],
|
||||
"label": "custom_right_eye",
|
||||
"pos": [
|
||||
-585.119140625,
|
||||
4464
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": "14acb0a0-d1f4-48f3-ba31-811b26236ef9",
|
||||
"name": "regions.irises",
|
||||
"type": "BOOLEAN",
|
||||
"linkIds": [
|
||||
81
|
||||
],
|
||||
"label": "custom_irises",
|
||||
"pos": [
|
||||
-585.119140625,
|
||||
4484
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": "25a82859-87de-42c8-8431-09948665546e",
|
||||
"name": "model_name",
|
||||
"type": "COMBO",
|
||||
"linkIds": [
|
||||
86
|
||||
],
|
||||
"pos": [
|
||||
-585.119140625,
|
||||
4504
|
||||
]
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"id": "d2ba3f92-e8b1-49c3-9590-cfad56c54cf4",
|
||||
"name": "face_landmarks",
|
||||
"type": "FACE_LANDMARKS",
|
||||
"linkIds": [
|
||||
44
|
||||
],
|
||||
"localized_name": "face_landmarks",
|
||||
"pos": [
|
||||
164,
|
||||
4504
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": "4f356bb0-d4c4-4f93-b4cf-0845a65c4e6d",
|
||||
"name": "bboxes",
|
||||
"type": "BOUNDING_BOX",
|
||||
"linkIds": [
|
||||
25
|
||||
],
|
||||
"localized_name": "bboxes",
|
||||
"pos": [
|
||||
164,
|
||||
4524
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": "f6309e1d-6397-4363-b38f-778a122abc51",
|
||||
"name": "MASK_1",
|
||||
"type": "MASK",
|
||||
"linkIds": [
|
||||
83
|
||||
],
|
||||
"label": "mask",
|
||||
"pos": [
|
||||
164,
|
||||
4544
|
||||
]
|
||||
}
|
||||
],
|
||||
"widgets": [],
|
||||
"nodes": [
|
||||
{
|
||||
"id": 11,
|
||||
"type": "MediaPipeFaceLandmarker",
|
||||
"pos": [
|
||||
-280,
|
||||
4280
|
||||
],
|
||||
"size": [
|
||||
350,
|
||||
220
|
||||
],
|
||||
"flags": {},
|
||||
"order": 1,
|
||||
"mode": 0,
|
||||
"inputs": [
|
||||
{
|
||||
"localized_name": "face_detection_model",
|
||||
"name": "face_detection_model",
|
||||
"type": "FACE_DETECTION_MODEL",
|
||||
"link": 66
|
||||
},
|
||||
{
|
||||
"localized_name": "image",
|
||||
"name": "image",
|
||||
"type": "IMAGE",
|
||||
"link": 60
|
||||
},
|
||||
{
|
||||
"localized_name": "detector_variant",
|
||||
"name": "detector_variant",
|
||||
"type": "COMBO",
|
||||
"widget": {
|
||||
"name": "detector_variant"
|
||||
},
|
||||
"link": 75
|
||||
},
|
||||
{
|
||||
"localized_name": "num_faces",
|
||||
"name": "num_faces",
|
||||
"type": "INT",
|
||||
"widget": {
|
||||
"name": "num_faces"
|
||||
},
|
||||
"link": 76
|
||||
},
|
||||
{
|
||||
"localized_name": "min_confidence",
|
||||
"name": "min_confidence",
|
||||
"type": "FLOAT",
|
||||
"widget": {
|
||||
"name": "min_confidence"
|
||||
},
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"localized_name": "missing_frame_fallback",
|
||||
"name": "missing_frame_fallback",
|
||||
"type": "COMBO",
|
||||
"widget": {
|
||||
"name": "missing_frame_fallback"
|
||||
},
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"name": "face_landmarker",
|
||||
"type": "FACE_LANDMARKER",
|
||||
"link": 74
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"localized_name": "face_landmarks",
|
||||
"name": "face_landmarks",
|
||||
"type": "FACE_LANDMARKS",
|
||||
"links": [
|
||||
44,
|
||||
46
|
||||
]
|
||||
},
|
||||
{
|
||||
"localized_name": "bboxes",
|
||||
"name": "bboxes",
|
||||
"type": "BOUNDING_BOX",
|
||||
"links": [
|
||||
25
|
||||
]
|
||||
}
|
||||
],
|
||||
"properties": {
|
||||
"Node name for S&R": "MediaPipeFaceLandmarker",
|
||||
"cnr_id": "comfy-core",
|
||||
"ver": "0.22.0",
|
||||
"enableTabs": false,
|
||||
"tabWidth": 65,
|
||||
"tabXOffset": 10,
|
||||
"hasSecondTab": false,
|
||||
"secondTabText": "Send Back",
|
||||
"secondTabOffset": 80,
|
||||
"secondTabWidth": 65
|
||||
},
|
||||
"widgets_values": [
|
||||
"full",
|
||||
0,
|
||||
0.5,
|
||||
"empty"
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": 2,
|
||||
"type": "LoadMediaPipeFaceLandmarker",
|
||||
"pos": [
|
||||
-290,
|
||||
4060
|
||||
],
|
||||
"size": [
|
||||
350,
|
||||
140
|
||||
],
|
||||
"flags": {},
|
||||
"order": 0,
|
||||
"mode": 0,
|
||||
"inputs": [
|
||||
{
|
||||
"localized_name": "model_name",
|
||||
"name": "model_name",
|
||||
"type": "COMBO",
|
||||
"widget": {
|
||||
"name": "model_name"
|
||||
},
|
||||
"link": 86
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"localized_name": "FACE_DETECTION_MODEL",
|
||||
"name": "FACE_DETECTION_MODEL",
|
||||
"type": "FACE_DETECTION_MODEL",
|
||||
"links": [
|
||||
66
|
||||
]
|
||||
}
|
||||
],
|
||||
"properties": {
|
||||
"Node name for S&R": "LoadMediaPipeFaceLandmarker",
|
||||
"cnr_id": "comfy-core",
|
||||
"ver": "0.22.0",
|
||||
"models": [
|
||||
{
|
||||
"name": "mediapipe_face_fp32.safetensors",
|
||||
"url": "https://huggingface.co/Comfy-Org/mediapipe/resolve/main/detection/mediapipe_face_fp32.safetensors",
|
||||
"directory": "detection"
|
||||
}
|
||||
],
|
||||
"enableTabs": false,
|
||||
"tabWidth": 65,
|
||||
"tabXOffset": 10,
|
||||
"hasSecondTab": false,
|
||||
"secondTabText": "Send Back",
|
||||
"secondTabOffset": 80,
|
||||
"secondTabWidth": 65
|
||||
},
|
||||
"widgets_values": [
|
||||
"mediapipe_face_fp32.safetensors"
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": 20,
|
||||
"type": "MediaPipeFaceMask",
|
||||
"pos": [
|
||||
-290,
|
||||
4560
|
||||
],
|
||||
"size": [
|
||||
360,
|
||||
180
|
||||
],
|
||||
"flags": {},
|
||||
"order": 2,
|
||||
"mode": 0,
|
||||
"inputs": [
|
||||
{
|
||||
"localized_name": "face_landmarks",
|
||||
"name": "face_landmarks",
|
||||
"type": "FACE_LANDMARKS",
|
||||
"link": 46
|
||||
},
|
||||
{
|
||||
"localized_name": "regions",
|
||||
"name": "regions",
|
||||
"type": "COMFY_DYNAMICCOMBO_V3",
|
||||
"widget": {
|
||||
"name": "regions"
|
||||
},
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"localized_name": "regions.face_oval",
|
||||
"name": "regions.face_oval",
|
||||
"type": "BOOLEAN",
|
||||
"widget": {
|
||||
"name": "regions.face_oval"
|
||||
},
|
||||
"link": 77
|
||||
},
|
||||
{
|
||||
"localized_name": "regions.lips",
|
||||
"name": "regions.lips",
|
||||
"type": "BOOLEAN",
|
||||
"widget": {
|
||||
"name": "regions.lips"
|
||||
},
|
||||
"link": 78
|
||||
},
|
||||
{
|
||||
"localized_name": "regions.left_eye",
|
||||
"name": "regions.left_eye",
|
||||
"type": "BOOLEAN",
|
||||
"widget": {
|
||||
"name": "regions.left_eye"
|
||||
},
|
||||
"link": 79
|
||||
},
|
||||
{
|
||||
"localized_name": "regions.right_eye",
|
||||
"name": "regions.right_eye",
|
||||
"type": "BOOLEAN",
|
||||
"widget": {
|
||||
"name": "regions.right_eye"
|
||||
},
|
||||
"link": 80
|
||||
},
|
||||
{
|
||||
"localized_name": "regions.irises",
|
||||
"name": "regions.irises",
|
||||
"type": "BOOLEAN",
|
||||
"widget": {
|
||||
"name": "regions.irises"
|
||||
},
|
||||
"link": 81
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"localized_name": "MASK",
|
||||
"name": "MASK",
|
||||
"type": "MASK",
|
||||
"links": [
|
||||
83
|
||||
]
|
||||
}
|
||||
],
|
||||
"properties": {
|
||||
"Node name for S&R": "MediaPipeFaceMask",
|
||||
"cnr_id": "comfy-core",
|
||||
"ver": "0.22.0",
|
||||
"enableTabs": false,
|
||||
"tabWidth": 65,
|
||||
"tabXOffset": 10,
|
||||
"hasSecondTab": false,
|
||||
"secondTabText": "Send Back",
|
||||
"secondTabOffset": 80,
|
||||
"secondTabWidth": 65
|
||||
},
|
||||
"widgets_values": [
|
||||
"custom",
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false
|
||||
]
|
||||
}
|
||||
],
|
||||
"groups": [],
|
||||
"links": [
|
||||
{
|
||||
"id": 66,
|
||||
"origin_id": 2,
|
||||
"origin_slot": 0,
|
||||
"target_id": 11,
|
||||
"target_slot": 0,
|
||||
"type": "FACE_DETECTION_MODEL"
|
||||
},
|
||||
{
|
||||
"id": 46,
|
||||
"origin_id": 11,
|
||||
"origin_slot": 0,
|
||||
"target_id": 20,
|
||||
"target_slot": 0,
|
||||
"type": "FACE_LANDMARKS"
|
||||
},
|
||||
{
|
||||
"id": 60,
|
||||
"origin_id": -10,
|
||||
"origin_slot": 0,
|
||||
"target_id": 11,
|
||||
"target_slot": 1,
|
||||
"type": "IMAGE"
|
||||
},
|
||||
{
|
||||
"id": 44,
|
||||
"origin_id": 11,
|
||||
"origin_slot": 0,
|
||||
"target_id": -20,
|
||||
"target_slot": 0,
|
||||
"type": "FACE_LANDMARKS"
|
||||
},
|
||||
{
|
||||
"id": 25,
|
||||
"origin_id": 11,
|
||||
"origin_slot": 1,
|
||||
"target_id": -20,
|
||||
"target_slot": 1,
|
||||
"type": "BOUNDING_BOX"
|
||||
},
|
||||
{
|
||||
"id": 74,
|
||||
"origin_id": -10,
|
||||
"origin_slot": 1,
|
||||
"target_id": 11,
|
||||
"target_slot": 6,
|
||||
"type": "FACE_LANDMARKER"
|
||||
},
|
||||
{
|
||||
"id": 75,
|
||||
"origin_id": -10,
|
||||
"origin_slot": 2,
|
||||
"target_id": 11,
|
||||
"target_slot": 2,
|
||||
"type": "COMBO"
|
||||
},
|
||||
{
|
||||
"id": 76,
|
||||
"origin_id": -10,
|
||||
"origin_slot": 3,
|
||||
"target_id": 11,
|
||||
"target_slot": 3,
|
||||
"type": "INT"
|
||||
},
|
||||
{
|
||||
"id": 77,
|
||||
"origin_id": -10,
|
||||
"origin_slot": 4,
|
||||
"target_id": 20,
|
||||
"target_slot": 2,
|
||||
"type": "BOOLEAN"
|
||||
},
|
||||
{
|
||||
"id": 78,
|
||||
"origin_id": -10,
|
||||
"origin_slot": 5,
|
||||
"target_id": 20,
|
||||
"target_slot": 3,
|
||||
"type": "BOOLEAN"
|
||||
},
|
||||
{
|
||||
"id": 79,
|
||||
"origin_id": -10,
|
||||
"origin_slot": 6,
|
||||
"target_id": 20,
|
||||
"target_slot": 4,
|
||||
"type": "BOOLEAN"
|
||||
},
|
||||
{
|
||||
"id": 80,
|
||||
"origin_id": -10,
|
||||
"origin_slot": 7,
|
||||
"target_id": 20,
|
||||
"target_slot": 5,
|
||||
"type": "BOOLEAN"
|
||||
},
|
||||
{
|
||||
"id": 81,
|
||||
"origin_id": -10,
|
||||
"origin_slot": 8,
|
||||
"target_id": 20,
|
||||
"target_slot": 6,
|
||||
"type": "BOOLEAN"
|
||||
},
|
||||
{
|
||||
"id": 83,
|
||||
"origin_id": 20,
|
||||
"origin_slot": 0,
|
||||
"target_id": -20,
|
||||
"target_slot": 2,
|
||||
"type": "MASK"
|
||||
},
|
||||
{
|
||||
"id": 86,
|
||||
"origin_id": -10,
|
||||
"origin_slot": 9,
|
||||
"target_id": 2,
|
||||
"target_slot": 0,
|
||||
"type": "COMBO"
|
||||
}
|
||||
],
|
||||
"extra": {},
|
||||
"category": "Conditioning & Preprocessors/Face Detection"
|
||||
}
|
||||
]
|
||||
},
|
||||
"extra": {}
|
||||
}
|
||||
@ -703,7 +703,7 @@
|
||||
}
|
||||
],
|
||||
"extra": {},
|
||||
"category": "Image Tools/Image Segmentation",
|
||||
"category": "Conditioning & Preprocessors/Segmentation & Mask",
|
||||
"description": "Segments images into masks using Meta SAM3 from text prompts, points, or boxes."
|
||||
}
|
||||
]
|
||||
|
||||
@ -1302,7 +1302,7 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Image generation and editing/Enhance",
|
||||
"category": "Image generation and editing/Upscale",
|
||||
"description": "Upscales images to higher resolution using Z-Image-Turbo."
|
||||
}
|
||||
]
|
||||
@ -1312,4 +1312,4 @@
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"version": 0.4
|
||||
}
|
||||
}
|
||||
1206
blueprints/Image to Pose Map (SDPose Multi-Person).json
Normal file
1206
blueprints/Image to Pose Map (SDPose Multi-Person).json
Normal file
File diff suppressed because it is too large
Load Diff
888
blueprints/Image to Pose Map (SDPose-OOD).json
Normal file
888
blueprints/Image to Pose Map (SDPose-OOD).json
Normal file
@ -0,0 +1,888 @@
|
||||
{
|
||||
"revision": 0,
|
||||
"last_node_id": 675,
|
||||
"last_link_id": 0,
|
||||
"nodes": [
|
||||
{
|
||||
"id": 675,
|
||||
"type": "01b6a731-fb78-4070-9a38-c87146da9604",
|
||||
"pos": [
|
||||
-2480,
|
||||
3400
|
||||
],
|
||||
"size": [
|
||||
360,
|
||||
433.3125
|
||||
],
|
||||
"flags": {},
|
||||
"order": 2,
|
||||
"mode": 0,
|
||||
"inputs": [
|
||||
{
|
||||
"localized_name": "input",
|
||||
"name": "input",
|
||||
"type": "IMAGE,MASK",
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"label": "resize_target_longer_size",
|
||||
"name": "resize_type.longer_size",
|
||||
"type": "INT",
|
||||
"widget": {
|
||||
"name": "resize_type.longer_size"
|
||||
},
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"name": "scale_method",
|
||||
"type": "COMBO",
|
||||
"widget": {
|
||||
"name": "scale_method"
|
||||
},
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"name": "draw_body",
|
||||
"type": "BOOLEAN",
|
||||
"widget": {
|
||||
"name": "draw_body"
|
||||
},
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"name": "draw_hands",
|
||||
"type": "BOOLEAN",
|
||||
"widget": {
|
||||
"name": "draw_hands"
|
||||
},
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"name": "draw_face",
|
||||
"type": "BOOLEAN",
|
||||
"widget": {
|
||||
"name": "draw_face"
|
||||
},
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"name": "draw_feet",
|
||||
"type": "BOOLEAN",
|
||||
"widget": {
|
||||
"name": "draw_feet"
|
||||
},
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"name": "stick_width",
|
||||
"type": "INT",
|
||||
"widget": {
|
||||
"name": "stick_width"
|
||||
},
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"name": "face_point_size",
|
||||
"type": "INT",
|
||||
"widget": {
|
||||
"name": "face_point_size"
|
||||
},
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"name": "score_threshold",
|
||||
"type": "FLOAT",
|
||||
"widget": {
|
||||
"name": "score_threshold"
|
||||
},
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"name": "ckpt_name",
|
||||
"type": "COMBO",
|
||||
"widget": {
|
||||
"name": "ckpt_name"
|
||||
},
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"name": "bboxes",
|
||||
"shape": 7,
|
||||
"type": "BOUNDING_BOX",
|
||||
"link": null
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"localized_name": "IMAGE",
|
||||
"name": "IMAGE",
|
||||
"type": "IMAGE",
|
||||
"links": []
|
||||
},
|
||||
{
|
||||
"name": "keypoints",
|
||||
"type": "POSE_KEYPOINT",
|
||||
"links": null
|
||||
}
|
||||
],
|
||||
"properties": {
|
||||
"proxyWidgets": [
|
||||
[
|
||||
"674",
|
||||
"resize_type.longer_size"
|
||||
],
|
||||
[
|
||||
"674",
|
||||
"scale_method"
|
||||
],
|
||||
[
|
||||
"672",
|
||||
"draw_body"
|
||||
],
|
||||
[
|
||||
"672",
|
||||
"draw_hands"
|
||||
],
|
||||
[
|
||||
"672",
|
||||
"draw_face"
|
||||
],
|
||||
[
|
||||
"672",
|
||||
"draw_feet"
|
||||
],
|
||||
[
|
||||
"672",
|
||||
"stick_width"
|
||||
],
|
||||
[
|
||||
"672",
|
||||
"face_point_size"
|
||||
],
|
||||
[
|
||||
"672",
|
||||
"score_threshold"
|
||||
],
|
||||
[
|
||||
"673",
|
||||
"ckpt_name"
|
||||
]
|
||||
],
|
||||
"cnr_id": "comfy-core",
|
||||
"ver": "0.15.1",
|
||||
"ue_properties": {
|
||||
"widget_ue_connectable": {},
|
||||
"version": "7.7",
|
||||
"input_ue_unconnectable": {}
|
||||
}
|
||||
},
|
||||
"widgets_values": [],
|
||||
"title": "Image to Pose Map (SDPose-OOD)"
|
||||
}
|
||||
],
|
||||
"links": [],
|
||||
"version": 0.4,
|
||||
"definitions": {
|
||||
"subgraphs": [
|
||||
{
|
||||
"id": "01b6a731-fb78-4070-9a38-c87146da9604",
|
||||
"version": 1,
|
||||
"state": {
|
||||
"lastGroupId": 0,
|
||||
"lastNodeId": 676,
|
||||
"lastLinkId": 1715,
|
||||
"lastRerouteId": 0
|
||||
},
|
||||
"revision": 0,
|
||||
"config": {},
|
||||
"name": "Image to Pose Map (SDPose-OOD)",
|
||||
"inputNode": {
|
||||
"id": -10,
|
||||
"bounding": [
|
||||
-3290,
|
||||
3590,
|
||||
190.8984375,
|
||||
288
|
||||
]
|
||||
},
|
||||
"outputNode": {
|
||||
"id": -20,
|
||||
"bounding": [
|
||||
-1756.2451602089645,
|
||||
3366,
|
||||
128,
|
||||
88
|
||||
]
|
||||
},
|
||||
"inputs": [
|
||||
{
|
||||
"id": "e24699c3-1356-4634-9eb4-19bb58e5c0b0",
|
||||
"name": "input",
|
||||
"type": "IMAGE,MASK",
|
||||
"linkIds": [
|
||||
1700
|
||||
],
|
||||
"localized_name": "input",
|
||||
"pos": [
|
||||
-3123.1015625,
|
||||
3614
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": "088eefc1-cd8a-4573-993f-9e4da008a12d",
|
||||
"name": "resize_type.longer_size",
|
||||
"type": "INT",
|
||||
"linkIds": [
|
||||
1704
|
||||
],
|
||||
"label": "resize_target_longer_size",
|
||||
"pos": [
|
||||
-3123.1015625,
|
||||
3634
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": "b6449bd3-73d4-41c8-b81f-cf8d33f76a2e",
|
||||
"name": "scale_method",
|
||||
"type": "COMBO",
|
||||
"linkIds": [
|
||||
1705
|
||||
],
|
||||
"pos": [
|
||||
-3123.1015625,
|
||||
3654
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": "4cff52ad-ed07-4c97-8803-fcbd89554fd0",
|
||||
"name": "draw_body",
|
||||
"type": "BOOLEAN",
|
||||
"linkIds": [
|
||||
1706
|
||||
],
|
||||
"pos": [
|
||||
-3123.1015625,
|
||||
3674
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": "7af63dce-f7df-4d7e-8215-d7c7f60bf81c",
|
||||
"name": "draw_hands",
|
||||
"type": "BOOLEAN",
|
||||
"linkIds": [
|
||||
1707
|
||||
],
|
||||
"pos": [
|
||||
-3123.1015625,
|
||||
3694
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": "af3a9bce-61f9-4aca-b530-9f65e028b35e",
|
||||
"name": "draw_face",
|
||||
"type": "BOOLEAN",
|
||||
"linkIds": [
|
||||
1708
|
||||
],
|
||||
"pos": [
|
||||
-3123.1015625,
|
||||
3714
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": "4620f6a3-2c85-4b79-ad8f-35d0326b568f",
|
||||
"name": "draw_feet",
|
||||
"type": "BOOLEAN",
|
||||
"linkIds": [
|
||||
1709
|
||||
],
|
||||
"pos": [
|
||||
-3123.1015625,
|
||||
3734
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": "fee5d0c9-8d4b-4934-81d8-ba2206dc56cb",
|
||||
"name": "stick_width",
|
||||
"type": "INT",
|
||||
"linkIds": [
|
||||
1710
|
||||
],
|
||||
"pos": [
|
||||
-3123.1015625,
|
||||
3754
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": "aafdd060-ba81-4324-a9cc-b656e1ebc133",
|
||||
"name": "face_point_size",
|
||||
"type": "INT",
|
||||
"linkIds": [
|
||||
1711
|
||||
],
|
||||
"pos": [
|
||||
-3123.1015625,
|
||||
3774
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": "514c5503-f9e6-4d23-b1ae-1d3291acb2a3",
|
||||
"name": "score_threshold",
|
||||
"type": "FLOAT",
|
||||
"linkIds": [
|
||||
1712
|
||||
],
|
||||
"pos": [
|
||||
-3123.1015625,
|
||||
3794
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": "ae46de61-2cc6-483e-8ee9-87e4144a2ffa",
|
||||
"name": "ckpt_name",
|
||||
"type": "COMBO",
|
||||
"linkIds": [
|
||||
1713
|
||||
],
|
||||
"pos": [
|
||||
-3123.1015625,
|
||||
3814
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": "41bec0c6-dffa-4c78-9289-ee678715ae54",
|
||||
"name": "bboxes",
|
||||
"type": "BOUNDING_BOX",
|
||||
"linkIds": [
|
||||
1714
|
||||
],
|
||||
"pos": [
|
||||
-3123.1015625,
|
||||
3834
|
||||
]
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"id": "f05ed8cc-9403-4f14-8085-4364b06f8a48",
|
||||
"name": "IMAGE",
|
||||
"type": "IMAGE",
|
||||
"linkIds": [
|
||||
1701
|
||||
],
|
||||
"localized_name": "IMAGE",
|
||||
"pos": [
|
||||
-1732.2451602089645,
|
||||
3390
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": "29a6584e-4685-4986-8ffd-e6d8539953fd",
|
||||
"name": "keypoints",
|
||||
"type": "POSE_KEYPOINT",
|
||||
"linkIds": [
|
||||
1715
|
||||
],
|
||||
"pos": [
|
||||
-1732.2451602089645,
|
||||
3410
|
||||
]
|
||||
}
|
||||
],
|
||||
"widgets": [],
|
||||
"nodes": [
|
||||
{
|
||||
"id": 671,
|
||||
"type": "SDPoseKeypointExtractor",
|
||||
"pos": [
|
||||
-2470,
|
||||
3250
|
||||
],
|
||||
"size": [
|
||||
270,
|
||||
180
|
||||
],
|
||||
"flags": {},
|
||||
"order": 0,
|
||||
"mode": 0,
|
||||
"inputs": [
|
||||
{
|
||||
"localized_name": "model",
|
||||
"name": "model",
|
||||
"type": "MODEL",
|
||||
"link": 1696
|
||||
},
|
||||
{
|
||||
"localized_name": "vae",
|
||||
"name": "vae",
|
||||
"type": "VAE",
|
||||
"link": 1697
|
||||
},
|
||||
{
|
||||
"localized_name": "image",
|
||||
"name": "image",
|
||||
"type": "IMAGE",
|
||||
"link": 1698
|
||||
},
|
||||
{
|
||||
"localized_name": "bboxes",
|
||||
"name": "bboxes",
|
||||
"shape": 7,
|
||||
"type": "BOUNDING_BOX",
|
||||
"link": 1714
|
||||
},
|
||||
{
|
||||
"localized_name": "batch_size",
|
||||
"name": "batch_size",
|
||||
"type": "INT",
|
||||
"widget": {
|
||||
"name": "batch_size"
|
||||
},
|
||||
"link": null
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"localized_name": "keypoints",
|
||||
"name": "keypoints",
|
||||
"type": "POSE_KEYPOINT",
|
||||
"links": [
|
||||
1699,
|
||||
1715
|
||||
]
|
||||
}
|
||||
],
|
||||
"properties": {
|
||||
"Node name for S&R": "SDPoseKeypointExtractor",
|
||||
"cnr_id": "comfy-core",
|
||||
"ver": "0.15.0",
|
||||
"ue_properties": {
|
||||
"widget_ue_connectable": {},
|
||||
"version": "7.7",
|
||||
"input_ue_unconnectable": {}
|
||||
}
|
||||
},
|
||||
"widgets_values": [
|
||||
16
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": 674,
|
||||
"type": "ResizeImageMaskNode",
|
||||
"pos": [
|
||||
-2960,
|
||||
3490
|
||||
],
|
||||
"size": [
|
||||
270,
|
||||
110
|
||||
],
|
||||
"flags": {},
|
||||
"order": 3,
|
||||
"mode": 0,
|
||||
"inputs": [
|
||||
{
|
||||
"localized_name": "input",
|
||||
"name": "input",
|
||||
"type": "IMAGE,MASK",
|
||||
"link": 1700
|
||||
},
|
||||
{
|
||||
"localized_name": "resize_type",
|
||||
"name": "resize_type",
|
||||
"type": "COMFY_DYNAMICCOMBO_V3",
|
||||
"widget": {
|
||||
"name": "resize_type"
|
||||
},
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"localized_name": "resize_type.longer_size",
|
||||
"name": "resize_type.longer_size",
|
||||
"type": "INT",
|
||||
"widget": {
|
||||
"name": "resize_type.longer_size"
|
||||
},
|
||||
"link": 1704
|
||||
},
|
||||
{
|
||||
"localized_name": "scale_method",
|
||||
"name": "scale_method",
|
||||
"type": "COMBO",
|
||||
"widget": {
|
||||
"name": "scale_method"
|
||||
},
|
||||
"link": 1705
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"localized_name": "resized",
|
||||
"name": "resized",
|
||||
"type": "*",
|
||||
"links": [
|
||||
1698
|
||||
]
|
||||
}
|
||||
],
|
||||
"properties": {
|
||||
"Node name for S&R": "ResizeImageMaskNode",
|
||||
"cnr_id": "comfy-core",
|
||||
"ver": "0.15.0",
|
||||
"ue_properties": {
|
||||
"widget_ue_connectable": {},
|
||||
"version": "7.7",
|
||||
"input_ue_unconnectable": {}
|
||||
}
|
||||
},
|
||||
"widgets_values": [
|
||||
"scale longer dimension",
|
||||
1024,
|
||||
"area"
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": 672,
|
||||
"type": "SDPoseDrawKeypoints",
|
||||
"pos": [
|
||||
-2120,
|
||||
3260
|
||||
],
|
||||
"size": [
|
||||
270,
|
||||
280
|
||||
],
|
||||
"flags": {},
|
||||
"order": 1,
|
||||
"mode": 0,
|
||||
"inputs": [
|
||||
{
|
||||
"localized_name": "keypoints",
|
||||
"name": "keypoints",
|
||||
"type": "POSE_KEYPOINT",
|
||||
"link": 1699
|
||||
},
|
||||
{
|
||||
"localized_name": "draw_body",
|
||||
"name": "draw_body",
|
||||
"type": "BOOLEAN",
|
||||
"widget": {
|
||||
"name": "draw_body"
|
||||
},
|
||||
"link": 1706
|
||||
},
|
||||
{
|
||||
"localized_name": "draw_hands",
|
||||
"name": "draw_hands",
|
||||
"type": "BOOLEAN",
|
||||
"widget": {
|
||||
"name": "draw_hands"
|
||||
},
|
||||
"link": 1707
|
||||
},
|
||||
{
|
||||
"localized_name": "draw_face",
|
||||
"name": "draw_face",
|
||||
"type": "BOOLEAN",
|
||||
"widget": {
|
||||
"name": "draw_face"
|
||||
},
|
||||
"link": 1708
|
||||
},
|
||||
{
|
||||
"localized_name": "draw_feet",
|
||||
"name": "draw_feet",
|
||||
"type": "BOOLEAN",
|
||||
"widget": {
|
||||
"name": "draw_feet"
|
||||
},
|
||||
"link": 1709
|
||||
},
|
||||
{
|
||||
"localized_name": "stick_width",
|
||||
"name": "stick_width",
|
||||
"type": "INT",
|
||||
"widget": {
|
||||
"name": "stick_width"
|
||||
},
|
||||
"link": 1710
|
||||
},
|
||||
{
|
||||
"localized_name": "face_point_size",
|
||||
"name": "face_point_size",
|
||||
"type": "INT",
|
||||
"widget": {
|
||||
"name": "face_point_size"
|
||||
},
|
||||
"link": 1711
|
||||
},
|
||||
{
|
||||
"localized_name": "score_threshold",
|
||||
"name": "score_threshold",
|
||||
"type": "FLOAT",
|
||||
"widget": {
|
||||
"name": "score_threshold"
|
||||
},
|
||||
"link": 1712
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"localized_name": "IMAGE",
|
||||
"name": "IMAGE",
|
||||
"type": "IMAGE",
|
||||
"links": [
|
||||
1701
|
||||
]
|
||||
}
|
||||
],
|
||||
"properties": {
|
||||
"Node name for S&R": "SDPoseDrawKeypoints",
|
||||
"cnr_id": "comfy-core",
|
||||
"ver": "0.15.0",
|
||||
"ue_properties": {
|
||||
"widget_ue_connectable": {},
|
||||
"version": "7.7",
|
||||
"input_ue_unconnectable": {}
|
||||
}
|
||||
},
|
||||
"widgets_values": [
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
4,
|
||||
2,
|
||||
0.5
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": 673,
|
||||
"type": "CheckpointLoaderSimple",
|
||||
"pos": [
|
||||
-2960,
|
||||
3250
|
||||
],
|
||||
"size": [
|
||||
390,
|
||||
190
|
||||
],
|
||||
"flags": {},
|
||||
"order": 2,
|
||||
"mode": 0,
|
||||
"inputs": [
|
||||
{
|
||||
"localized_name": "ckpt_name",
|
||||
"name": "ckpt_name",
|
||||
"type": "COMBO",
|
||||
"widget": {
|
||||
"name": "ckpt_name"
|
||||
},
|
||||
"link": 1713
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"localized_name": "MODEL",
|
||||
"name": "MODEL",
|
||||
"type": "MODEL",
|
||||
"links": [
|
||||
1696
|
||||
]
|
||||
},
|
||||
{
|
||||
"localized_name": "CLIP",
|
||||
"name": "CLIP",
|
||||
"type": "CLIP",
|
||||
"links": []
|
||||
},
|
||||
{
|
||||
"localized_name": "VAE",
|
||||
"name": "VAE",
|
||||
"type": "VAE",
|
||||
"links": [
|
||||
1697
|
||||
]
|
||||
}
|
||||
],
|
||||
"properties": {
|
||||
"Node name for S&R": "CheckpointLoaderSimple",
|
||||
"cnr_id": "comfy-core",
|
||||
"ver": "0.15.0",
|
||||
"models": [
|
||||
{
|
||||
"name": "sdpose_wholebody_fp16.safetensors",
|
||||
"url": "https://huggingface.co/Comfy-Org/SDPose/resolve/main/checkpoints/sdpose_wholebody_fp16.safetensors",
|
||||
"directory": "checkpoints"
|
||||
}
|
||||
],
|
||||
"ue_properties": {
|
||||
"widget_ue_connectable": {},
|
||||
"version": "7.7",
|
||||
"input_ue_unconnectable": {}
|
||||
}
|
||||
},
|
||||
"widgets_values": [
|
||||
"sdpose_wholebody_fp16.safetensors"
|
||||
]
|
||||
}
|
||||
],
|
||||
"groups": [],
|
||||
"links": [
|
||||
{
|
||||
"id": 1696,
|
||||
"origin_id": 673,
|
||||
"origin_slot": 0,
|
||||
"target_id": 671,
|
||||
"target_slot": 0,
|
||||
"type": "MODEL"
|
||||
},
|
||||
{
|
||||
"id": 1697,
|
||||
"origin_id": 673,
|
||||
"origin_slot": 2,
|
||||
"target_id": 671,
|
||||
"target_slot": 1,
|
||||
"type": "VAE"
|
||||
},
|
||||
{
|
||||
"id": 1698,
|
||||
"origin_id": 674,
|
||||
"origin_slot": 0,
|
||||
"target_id": 671,
|
||||
"target_slot": 2,
|
||||
"type": "IMAGE"
|
||||
},
|
||||
{
|
||||
"id": 1699,
|
||||
"origin_id": 671,
|
||||
"origin_slot": 0,
|
||||
"target_id": 672,
|
||||
"target_slot": 0,
|
||||
"type": "POSE_KEYPOINT"
|
||||
},
|
||||
{
|
||||
"id": 1700,
|
||||
"origin_id": -10,
|
||||
"origin_slot": 0,
|
||||
"target_id": 674,
|
||||
"target_slot": 0,
|
||||
"type": "IMAGE,MASK"
|
||||
},
|
||||
{
|
||||
"id": 1701,
|
||||
"origin_id": 672,
|
||||
"origin_slot": 0,
|
||||
"target_id": -20,
|
||||
"target_slot": 0,
|
||||
"type": "IMAGE"
|
||||
},
|
||||
{
|
||||
"id": 1704,
|
||||
"origin_id": -10,
|
||||
"origin_slot": 1,
|
||||
"target_id": 674,
|
||||
"target_slot": 2,
|
||||
"type": "INT"
|
||||
},
|
||||
{
|
||||
"id": 1705,
|
||||
"origin_id": -10,
|
||||
"origin_slot": 2,
|
||||
"target_id": 674,
|
||||
"target_slot": 3,
|
||||
"type": "COMBO"
|
||||
},
|
||||
{
|
||||
"id": 1706,
|
||||
"origin_id": -10,
|
||||
"origin_slot": 3,
|
||||
"target_id": 672,
|
||||
"target_slot": 1,
|
||||
"type": "BOOLEAN"
|
||||
},
|
||||
{
|
||||
"id": 1707,
|
||||
"origin_id": -10,
|
||||
"origin_slot": 4,
|
||||
"target_id": 672,
|
||||
"target_slot": 2,
|
||||
"type": "BOOLEAN"
|
||||
},
|
||||
{
|
||||
"id": 1708,
|
||||
"origin_id": -10,
|
||||
"origin_slot": 5,
|
||||
"target_id": 672,
|
||||
"target_slot": 3,
|
||||
"type": "BOOLEAN"
|
||||
},
|
||||
{
|
||||
"id": 1709,
|
||||
"origin_id": -10,
|
||||
"origin_slot": 6,
|
||||
"target_id": 672,
|
||||
"target_slot": 4,
|
||||
"type": "BOOLEAN"
|
||||
},
|
||||
{
|
||||
"id": 1710,
|
||||
"origin_id": -10,
|
||||
"origin_slot": 7,
|
||||
"target_id": 672,
|
||||
"target_slot": 5,
|
||||
"type": "INT"
|
||||
},
|
||||
{
|
||||
"id": 1711,
|
||||
"origin_id": -10,
|
||||
"origin_slot": 8,
|
||||
"target_id": 672,
|
||||
"target_slot": 6,
|
||||
"type": "INT"
|
||||
},
|
||||
{
|
||||
"id": 1712,
|
||||
"origin_id": -10,
|
||||
"origin_slot": 9,
|
||||
"target_id": 672,
|
||||
"target_slot": 7,
|
||||
"type": "FLOAT"
|
||||
},
|
||||
{
|
||||
"id": 1713,
|
||||
"origin_id": -10,
|
||||
"origin_slot": 10,
|
||||
"target_id": 673,
|
||||
"target_slot": 0,
|
||||
"type": "COMBO"
|
||||
},
|
||||
{
|
||||
"id": 1714,
|
||||
"origin_id": -10,
|
||||
"origin_slot": 11,
|
||||
"target_id": 671,
|
||||
"target_slot": 3,
|
||||
"type": "BOUNDING_BOX"
|
||||
},
|
||||
{
|
||||
"id": 1715,
|
||||
"origin_id": 671,
|
||||
"origin_slot": 0,
|
||||
"target_id": -20,
|
||||
"target_slot": 1,
|
||||
"type": "POSE_KEYPOINT"
|
||||
}
|
||||
],
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Conditioning & Preprocessors/Pose",
|
||||
"description": "Extracts human pose keypoints and stick-figure visuals from an image using SDPose-OOD, with optional bounding-box input per subject."
|
||||
}
|
||||
]
|
||||
},
|
||||
"extra": {
|
||||
"ue_links": []
|
||||
}
|
||||
}
|
||||
1219
blueprints/Merge Videos.json
Normal file
1219
blueprints/Merge Videos.json
Normal file
File diff suppressed because it is too large
Load Diff
@ -1298,7 +1298,7 @@
|
||||
"VHS_MetadataImage": true,
|
||||
"VHS_KeepIntermediate": true
|
||||
},
|
||||
"category": "Image generation and editing/Pose to image",
|
||||
"category": "Image generation and editing/Conditioned",
|
||||
"description": "Generates an image from pose keypoints using Z-Image-Turbo with text conditioning."
|
||||
}
|
||||
]
|
||||
|
||||
@ -3870,7 +3870,7 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Video generation and editing/Pose to video",
|
||||
"category": "Video generation and editing/Conditioned",
|
||||
"description": "Generates video from pose reference frames using LTX-2, with optional synchronized audio."
|
||||
}
|
||||
]
|
||||
|
||||
@ -270,7 +270,7 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Text generation/Prompt enhance",
|
||||
"category": "Text Tools",
|
||||
"description": "Expands short text prompts into detailed descriptions using a text generation model for better generation quality."
|
||||
}
|
||||
]
|
||||
|
||||
@ -389,7 +389,7 @@
|
||||
}
|
||||
],
|
||||
"extra": {},
|
||||
"category": "Image generation and editing/Background Removal"
|
||||
"category": "Image Tools/Background Removal"
|
||||
}
|
||||
]
|
||||
},
|
||||
|
||||
485
blueprints/Select Per-Line Text by Index.json
Normal file
485
blueprints/Select Per-Line Text by Index.json
Normal file
@ -0,0 +1,485 @@
|
||||
{
|
||||
"revision": 0,
|
||||
"last_node_id": 10,
|
||||
"last_link_id": 0,
|
||||
"nodes": [
|
||||
{
|
||||
"id": 10,
|
||||
"type": "3fb7557a-470d-4983-9d8c-6d5caa9788f0",
|
||||
"pos": [
|
||||
-250,
|
||||
8590
|
||||
],
|
||||
"size": [
|
||||
280,
|
||||
360
|
||||
],
|
||||
"flags": {},
|
||||
"order": 0,
|
||||
"mode": 0,
|
||||
"inputs": [
|
||||
{
|
||||
"localized_name": "text_per_line",
|
||||
"name": "text_per_line",
|
||||
"type": "STRING",
|
||||
"widget": {
|
||||
"name": "text_per_line"
|
||||
},
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"localized_name": "index",
|
||||
"name": "index",
|
||||
"type": "INT",
|
||||
"widget": {
|
||||
"name": "index"
|
||||
},
|
||||
"link": null
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"localized_name": "selected_line",
|
||||
"name": "selected_line",
|
||||
"type": "STRING",
|
||||
"links": []
|
||||
}
|
||||
],
|
||||
"properties": {
|
||||
"proxyWidgets": [
|
||||
[
|
||||
"2",
|
||||
"string"
|
||||
],
|
||||
[
|
||||
"3",
|
||||
"value"
|
||||
]
|
||||
],
|
||||
"cnr_id": "comfy-core",
|
||||
"ver": "0.19.0",
|
||||
"ue_properties": {
|
||||
"widget_ue_connectable": {},
|
||||
"input_ue_unconnectable": {}
|
||||
}
|
||||
},
|
||||
"widgets_values": [],
|
||||
"title": "Select Per-Line Text by Index"
|
||||
}
|
||||
],
|
||||
"links": [],
|
||||
"version": 0.4,
|
||||
"definitions": {
|
||||
"subgraphs": [
|
||||
{
|
||||
"id": "3fb7557a-470d-4983-9d8c-6d5caa9788f0",
|
||||
"version": 1,
|
||||
"state": {
|
||||
"lastGroupId": 0,
|
||||
"lastNodeId": 10,
|
||||
"lastLinkId": 14,
|
||||
"lastRerouteId": 0
|
||||
},
|
||||
"revision": 0,
|
||||
"config": {},
|
||||
"name": "Select Per-Line Text by Index",
|
||||
"inputNode": {
|
||||
"id": -10,
|
||||
"bounding": [
|
||||
-990,
|
||||
8595,
|
||||
128,
|
||||
88
|
||||
]
|
||||
},
|
||||
"outputNode": {
|
||||
"id": -20,
|
||||
"bounding": [
|
||||
710,
|
||||
8585,
|
||||
128,
|
||||
68
|
||||
]
|
||||
},
|
||||
"inputs": [
|
||||
{
|
||||
"id": "75417d82-a934-4ac9-b667-d8dcd5a3bfb3",
|
||||
"name": "text_per_line",
|
||||
"type": "STRING",
|
||||
"linkIds": [
|
||||
13
|
||||
],
|
||||
"localized_name": "text_per_line",
|
||||
"pos": [
|
||||
-886,
|
||||
8619
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": "46e69a73-1804-4ca6-9175-31445bf0be96",
|
||||
"name": "index",
|
||||
"type": "INT",
|
||||
"linkIds": [
|
||||
14
|
||||
],
|
||||
"localized_name": "index",
|
||||
"pos": [
|
||||
-886,
|
||||
8639
|
||||
]
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"id": "e34e8ad1-84d2-4bd2-a460-eb7de6067c10",
|
||||
"name": "selected_line",
|
||||
"type": "STRING",
|
||||
"linkIds": [
|
||||
10
|
||||
],
|
||||
"localized_name": "selected_line",
|
||||
"pos": [
|
||||
734,
|
||||
8609
|
||||
]
|
||||
}
|
||||
],
|
||||
"widgets": [],
|
||||
"nodes": [
|
||||
{
|
||||
"id": 1,
|
||||
"type": "PreviewAny",
|
||||
"pos": [
|
||||
-500,
|
||||
8400
|
||||
],
|
||||
"size": [
|
||||
230,
|
||||
180
|
||||
],
|
||||
"flags": {},
|
||||
"order": 0,
|
||||
"mode": 0,
|
||||
"inputs": [
|
||||
{
|
||||
"localized_name": "source",
|
||||
"name": "source",
|
||||
"type": "*",
|
||||
"link": 1
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"localized_name": "STRING",
|
||||
"name": "STRING",
|
||||
"type": "STRING",
|
||||
"links": [
|
||||
6
|
||||
]
|
||||
}
|
||||
],
|
||||
"properties": {
|
||||
"Node name for S&R": "PreviewAny",
|
||||
"cnr_id": "comfy-core",
|
||||
"ver": "0.19.0",
|
||||
"ue_properties": {
|
||||
"widget_ue_connectable": {},
|
||||
"input_ue_unconnectable": {}
|
||||
}
|
||||
},
|
||||
"widgets_values": [
|
||||
null,
|
||||
null,
|
||||
null
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": 2,
|
||||
"type": "RegexExtract",
|
||||
"pos": [
|
||||
-240,
|
||||
8740
|
||||
],
|
||||
"size": [
|
||||
470,
|
||||
460
|
||||
],
|
||||
"flags": {},
|
||||
"order": 1,
|
||||
"mode": 0,
|
||||
"showAdvanced": false,
|
||||
"inputs": [
|
||||
{
|
||||
"localized_name": "string",
|
||||
"name": "string",
|
||||
"type": "STRING",
|
||||
"widget": {
|
||||
"name": "string"
|
||||
},
|
||||
"link": 13
|
||||
},
|
||||
{
|
||||
"localized_name": "regex_pattern",
|
||||
"name": "regex_pattern",
|
||||
"type": "STRING",
|
||||
"widget": {
|
||||
"name": "regex_pattern"
|
||||
},
|
||||
"link": 9
|
||||
},
|
||||
{
|
||||
"localized_name": "mode",
|
||||
"name": "mode",
|
||||
"type": "COMBO",
|
||||
"widget": {
|
||||
"name": "mode"
|
||||
},
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"localized_name": "case_insensitive",
|
||||
"name": "case_insensitive",
|
||||
"type": "BOOLEAN",
|
||||
"widget": {
|
||||
"name": "case_insensitive"
|
||||
},
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"localized_name": "multiline",
|
||||
"name": "multiline",
|
||||
"type": "BOOLEAN",
|
||||
"widget": {
|
||||
"name": "multiline"
|
||||
},
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"localized_name": "dotall",
|
||||
"name": "dotall",
|
||||
"type": "BOOLEAN",
|
||||
"widget": {
|
||||
"name": "dotall"
|
||||
},
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"localized_name": "group_index",
|
||||
"name": "group_index",
|
||||
"type": "INT",
|
||||
"widget": {
|
||||
"name": "group_index"
|
||||
},
|
||||
"link": null
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"localized_name": "STRING",
|
||||
"name": "STRING",
|
||||
"type": "STRING",
|
||||
"links": [
|
||||
10
|
||||
]
|
||||
}
|
||||
],
|
||||
"properties": {
|
||||
"Node name for S&R": "RegexExtract",
|
||||
"cnr_id": "comfy-core",
|
||||
"ver": "0.19.0",
|
||||
"ue_properties": {
|
||||
"widget_ue_connectable": {},
|
||||
"input_ue_unconnectable": {}
|
||||
}
|
||||
},
|
||||
"widgets_values": [
|
||||
"",
|
||||
"",
|
||||
"First Group",
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
1
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": 3,
|
||||
"type": "PrimitiveInt",
|
||||
"pos": [
|
||||
-810,
|
||||
8400
|
||||
],
|
||||
"size": [
|
||||
270,
|
||||
110
|
||||
],
|
||||
"flags": {},
|
||||
"order": 2,
|
||||
"mode": 0,
|
||||
"inputs": [
|
||||
{
|
||||
"localized_name": "value",
|
||||
"name": "value",
|
||||
"type": "INT",
|
||||
"widget": {
|
||||
"name": "value"
|
||||
},
|
||||
"link": 14
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"localized_name": "INT",
|
||||
"name": "INT",
|
||||
"type": "INT",
|
||||
"links": [
|
||||
1
|
||||
]
|
||||
}
|
||||
],
|
||||
"title": "Int (line index)",
|
||||
"properties": {
|
||||
"Node name for S&R": "Int (line index)",
|
||||
"cnr_id": "comfy-core",
|
||||
"ver": "0.19.0",
|
||||
"ue_properties": {
|
||||
"widget_ue_connectable": {},
|
||||
"input_ue_unconnectable": {}
|
||||
}
|
||||
},
|
||||
"widgets_values": [
|
||||
0,
|
||||
"fixed"
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": 8,
|
||||
"type": "StringReplace",
|
||||
"pos": [
|
||||
-240,
|
||||
8400
|
||||
],
|
||||
"size": [
|
||||
400,
|
||||
280
|
||||
],
|
||||
"flags": {},
|
||||
"order": 3,
|
||||
"mode": 0,
|
||||
"inputs": [
|
||||
{
|
||||
"localized_name": "string",
|
||||
"name": "string",
|
||||
"type": "STRING",
|
||||
"widget": {
|
||||
"name": "string"
|
||||
},
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"localized_name": "find",
|
||||
"name": "find",
|
||||
"type": "STRING",
|
||||
"widget": {
|
||||
"name": "find"
|
||||
},
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"localized_name": "replace",
|
||||
"name": "replace",
|
||||
"type": "STRING",
|
||||
"widget": {
|
||||
"name": "replace"
|
||||
},
|
||||
"link": 6
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"localized_name": "STRING",
|
||||
"name": "STRING",
|
||||
"type": "STRING",
|
||||
"links": [
|
||||
9
|
||||
]
|
||||
}
|
||||
],
|
||||
"properties": {
|
||||
"Node name for S&R": "StringReplace",
|
||||
"cnr_id": "comfy-core",
|
||||
"ver": "0.19.0",
|
||||
"ue_properties": {
|
||||
"widget_ue_connectable": {},
|
||||
"input_ue_unconnectable": {}
|
||||
}
|
||||
},
|
||||
"widgets_values": [
|
||||
"^(?:[^\\n]*\\n){index}([^\\n]*)(?:\\n|$)",
|
||||
"index",
|
||||
""
|
||||
]
|
||||
}
|
||||
],
|
||||
"groups": [],
|
||||
"links": [
|
||||
{
|
||||
"id": 1,
|
||||
"origin_id": 3,
|
||||
"origin_slot": 0,
|
||||
"target_id": 1,
|
||||
"target_slot": 0,
|
||||
"type": "INT"
|
||||
},
|
||||
{
|
||||
"id": 9,
|
||||
"origin_id": 8,
|
||||
"origin_slot": 0,
|
||||
"target_id": 2,
|
||||
"target_slot": 1,
|
||||
"type": "STRING"
|
||||
},
|
||||
{
|
||||
"id": 6,
|
||||
"origin_id": 1,
|
||||
"origin_slot": 0,
|
||||
"target_id": 8,
|
||||
"target_slot": 2,
|
||||
"type": "STRING"
|
||||
},
|
||||
{
|
||||
"id": 10,
|
||||
"origin_id": 2,
|
||||
"origin_slot": 0,
|
||||
"target_id": -20,
|
||||
"target_slot": 0,
|
||||
"type": "STRING"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"origin_id": -10,
|
||||
"origin_slot": 0,
|
||||
"target_id": 2,
|
||||
"target_slot": 0,
|
||||
"type": "STRING"
|
||||
},
|
||||
{
|
||||
"id": 14,
|
||||
"origin_id": -10,
|
||||
"origin_slot": 1,
|
||||
"target_id": 3,
|
||||
"target_slot": 0,
|
||||
"type": "INT"
|
||||
}
|
||||
],
|
||||
"extra": {},
|
||||
"category": "Text Tools",
|
||||
"description": "Selects one line from multiline text by zero-based index for batch or list-driven prompt workflows."
|
||||
}
|
||||
]
|
||||
},
|
||||
"extra": {
|
||||
"ue_links": [],
|
||||
"links_added_by_ue": []
|
||||
}
|
||||
}
|
||||
714
blueprints/Split Image Grid to Tiles.json
Normal file
714
blueprints/Split Image Grid to Tiles.json
Normal file
@ -0,0 +1,714 @@
|
||||
{
|
||||
"revision": 0,
|
||||
"last_node_id": 251,
|
||||
"last_link_id": 0,
|
||||
"nodes": [
|
||||
{
|
||||
"id": 251,
|
||||
"type": "609e1fd1-b731-4b78-89ac-d19b1156b025",
|
||||
"pos": [
|
||||
-1490,
|
||||
130
|
||||
],
|
||||
"size": [
|
||||
230,
|
||||
164
|
||||
],
|
||||
"flags": {},
|
||||
"order": 1,
|
||||
"mode": 0,
|
||||
"inputs": [
|
||||
{
|
||||
"localized_name": "source_image",
|
||||
"name": "source_image",
|
||||
"type": "IMAGE",
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"localized_name": "columns",
|
||||
"name": "columns",
|
||||
"type": "INT",
|
||||
"widget": {
|
||||
"name": "columns"
|
||||
},
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"localized_name": "rows",
|
||||
"name": "rows",
|
||||
"type": "INT",
|
||||
"widget": {
|
||||
"name": "rows"
|
||||
},
|
||||
"link": null
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"localized_name": "tiles",
|
||||
"name": "tiles",
|
||||
"type": "IMAGE",
|
||||
"links": []
|
||||
}
|
||||
],
|
||||
"properties": {
|
||||
"proxyWidgets": [
|
||||
[
|
||||
"228",
|
||||
"value"
|
||||
],
|
||||
[
|
||||
"252",
|
||||
"value"
|
||||
]
|
||||
],
|
||||
"cnr_id": "comfy-core",
|
||||
"ver": "0.20.1",
|
||||
"enableTabs": false,
|
||||
"tabWidth": 65,
|
||||
"tabXOffset": 10,
|
||||
"hasSecondTab": false,
|
||||
"secondTabText": "Send Back",
|
||||
"secondTabOffset": 80,
|
||||
"secondTabWidth": 65
|
||||
},
|
||||
"widgets_values": [],
|
||||
"title": "Split Image Grid to Tiles"
|
||||
}
|
||||
],
|
||||
"links": [],
|
||||
"version": 0.4,
|
||||
"definitions": {
|
||||
"subgraphs": [
|
||||
{
|
||||
"id": "609e1fd1-b731-4b78-89ac-d19b1156b025",
|
||||
"version": 1,
|
||||
"state": {
|
||||
"lastGroupId": 9,
|
||||
"lastNodeId": 252,
|
||||
"lastLinkId": 429,
|
||||
"lastRerouteId": 0
|
||||
},
|
||||
"revision": 0,
|
||||
"config": {},
|
||||
"name": "Split Image Grid to Tiles",
|
||||
"inputNode": {
|
||||
"id": -10,
|
||||
"bounding": [
|
||||
-1690,
|
||||
260,
|
||||
128,
|
||||
108
|
||||
]
|
||||
},
|
||||
"outputNode": {
|
||||
"id": -20,
|
||||
"bounding": [
|
||||
-510,
|
||||
590,
|
||||
128,
|
||||
68
|
||||
]
|
||||
},
|
||||
"inputs": [
|
||||
{
|
||||
"id": "866ac798-cfbc-450a-b755-e704f86404d9",
|
||||
"name": "source_image",
|
||||
"type": "IMAGE",
|
||||
"linkIds": [
|
||||
386,
|
||||
389
|
||||
],
|
||||
"localized_name": "source_image",
|
||||
"pos": [
|
||||
-1586,
|
||||
284
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": "bc37b1f8-8ab2-4f19-bd00-75d4fbc4feb3",
|
||||
"name": "columns",
|
||||
"type": "INT",
|
||||
"linkIds": [
|
||||
427
|
||||
],
|
||||
"localized_name": "columns",
|
||||
"pos": [
|
||||
-1586,
|
||||
304
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": "d45915da-e848-43dd-9ccc-e3161e9c99d9",
|
||||
"name": "rows",
|
||||
"type": "INT",
|
||||
"linkIds": [
|
||||
428
|
||||
],
|
||||
"localized_name": "rows",
|
||||
"pos": [
|
||||
-1586,
|
||||
324
|
||||
]
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"id": "18bc780f-064b-4038-87c6-67dba71deb08",
|
||||
"name": "tiles",
|
||||
"type": "IMAGE",
|
||||
"linkIds": [
|
||||
394
|
||||
],
|
||||
"localized_name": "tiles",
|
||||
"shape": 6,
|
||||
"pos": [
|
||||
-486,
|
||||
614
|
||||
]
|
||||
}
|
||||
],
|
||||
"widgets": [],
|
||||
"nodes": [
|
||||
{
|
||||
"id": 225,
|
||||
"type": "SplitImageToTileList",
|
||||
"pos": [
|
||||
-1010,
|
||||
620
|
||||
],
|
||||
"size": [
|
||||
290,
|
||||
170
|
||||
],
|
||||
"flags": {},
|
||||
"order": 0,
|
||||
"mode": 0,
|
||||
"inputs": [
|
||||
{
|
||||
"localized_name": "image",
|
||||
"name": "image",
|
||||
"type": "IMAGE",
|
||||
"link": 386
|
||||
},
|
||||
{
|
||||
"localized_name": "tile_width",
|
||||
"name": "tile_width",
|
||||
"type": "INT",
|
||||
"widget": {
|
||||
"name": "tile_width"
|
||||
},
|
||||
"link": 403
|
||||
},
|
||||
{
|
||||
"localized_name": "tile_height",
|
||||
"name": "tile_height",
|
||||
"type": "INT",
|
||||
"widget": {
|
||||
"name": "tile_height"
|
||||
},
|
||||
"link": 404
|
||||
},
|
||||
{
|
||||
"localized_name": "overlap",
|
||||
"name": "overlap",
|
||||
"type": "INT",
|
||||
"widget": {
|
||||
"name": "overlap"
|
||||
},
|
||||
"link": null
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"localized_name": "IMAGE",
|
||||
"name": "IMAGE",
|
||||
"shape": 6,
|
||||
"type": "IMAGE",
|
||||
"links": [
|
||||
394
|
||||
]
|
||||
}
|
||||
],
|
||||
"properties": {
|
||||
"Node name for S&R": "SplitImageToTileList",
|
||||
"cnr_id": "comfy-core",
|
||||
"ver": "0.20.1",
|
||||
"enableTabs": false,
|
||||
"tabWidth": 65,
|
||||
"tabXOffset": 10,
|
||||
"hasSecondTab": false,
|
||||
"secondTabText": "Send Back",
|
||||
"secondTabOffset": 80,
|
||||
"secondTabWidth": 65
|
||||
},
|
||||
"widgets_values": [
|
||||
1024,
|
||||
1024,
|
||||
0
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": 231,
|
||||
"type": "ComfyMathExpression",
|
||||
"pos": [
|
||||
-1080,
|
||||
330
|
||||
],
|
||||
"size": [
|
||||
370,
|
||||
190
|
||||
],
|
||||
"flags": {},
|
||||
"order": 4,
|
||||
"mode": 0,
|
||||
"inputs": [
|
||||
{
|
||||
"label": "a",
|
||||
"localized_name": "values.a",
|
||||
"name": "values.a",
|
||||
"type": "FLOAT,INT,BOOLEAN",
|
||||
"link": 390
|
||||
},
|
||||
{
|
||||
"label": "b",
|
||||
"localized_name": "values.b",
|
||||
"name": "values.b",
|
||||
"shape": 7,
|
||||
"type": "FLOAT,INT,BOOLEAN",
|
||||
"link": 429
|
||||
},
|
||||
{
|
||||
"label": "c",
|
||||
"localized_name": "values.c",
|
||||
"name": "values.c",
|
||||
"shape": 7,
|
||||
"type": "FLOAT,INT,BOOLEAN",
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"localized_name": "expression",
|
||||
"name": "expression",
|
||||
"type": "STRING",
|
||||
"widget": {
|
||||
"name": "expression"
|
||||
},
|
||||
"link": null
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"localized_name": "FLOAT",
|
||||
"name": "FLOAT",
|
||||
"type": "FLOAT",
|
||||
"links": null
|
||||
},
|
||||
{
|
||||
"localized_name": "INT",
|
||||
"name": "INT",
|
||||
"type": "INT",
|
||||
"links": [
|
||||
404
|
||||
]
|
||||
},
|
||||
{
|
||||
"localized_name": "BOOL",
|
||||
"name": "BOOL",
|
||||
"type": "BOOLEAN",
|
||||
"links": null
|
||||
}
|
||||
],
|
||||
"title": "Math Expression (Height)",
|
||||
"properties": {
|
||||
"Node name for S&R": "ComfyMathExpression",
|
||||
"cnr_id": "comfy-core",
|
||||
"ver": "0.18.1",
|
||||
"enableTabs": false,
|
||||
"tabWidth": 65,
|
||||
"tabXOffset": 10,
|
||||
"hasSecondTab": false,
|
||||
"secondTabText": "Send Back",
|
||||
"secondTabOffset": 80,
|
||||
"secondTabWidth": 65,
|
||||
"ue_properties": {
|
||||
"widget_ue_connectable": {},
|
||||
"version": "7.7",
|
||||
"input_ue_unconnectable": {}
|
||||
}
|
||||
},
|
||||
"widgets_values": [
|
||||
"max(1, (int(a) + int(b) - 1) // int(b))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": 229,
|
||||
"type": "ComfyMathExpression",
|
||||
"pos": [
|
||||
-1090,
|
||||
-30
|
||||
],
|
||||
"size": [
|
||||
370,
|
||||
190
|
||||
],
|
||||
"flags": {},
|
||||
"order": 2,
|
||||
"mode": 0,
|
||||
"inputs": [
|
||||
{
|
||||
"label": "a",
|
||||
"localized_name": "values.a",
|
||||
"name": "values.a",
|
||||
"type": "FLOAT,INT,BOOLEAN",
|
||||
"link": 387
|
||||
},
|
||||
{
|
||||
"label": "b",
|
||||
"localized_name": "values.b",
|
||||
"name": "values.b",
|
||||
"shape": 7,
|
||||
"type": "FLOAT,INT,BOOLEAN",
|
||||
"link": 388
|
||||
},
|
||||
{
|
||||
"label": "c",
|
||||
"localized_name": "values.c",
|
||||
"name": "values.c",
|
||||
"shape": 7,
|
||||
"type": "FLOAT,INT,BOOLEAN",
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"localized_name": "expression",
|
||||
"name": "expression",
|
||||
"type": "STRING",
|
||||
"widget": {
|
||||
"name": "expression"
|
||||
},
|
||||
"link": null
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"localized_name": "FLOAT",
|
||||
"name": "FLOAT",
|
||||
"type": "FLOAT",
|
||||
"links": null
|
||||
},
|
||||
{
|
||||
"localized_name": "INT",
|
||||
"name": "INT",
|
||||
"type": "INT",
|
||||
"links": [
|
||||
403
|
||||
]
|
||||
},
|
||||
{
|
||||
"localized_name": "BOOL",
|
||||
"name": "BOOL",
|
||||
"type": "BOOLEAN",
|
||||
"links": null
|
||||
}
|
||||
],
|
||||
"title": "Math Expression (Width)",
|
||||
"properties": {
|
||||
"Node name for S&R": "ComfyMathExpression",
|
||||
"cnr_id": "comfy-core",
|
||||
"ver": "0.18.1",
|
||||
"enableTabs": false,
|
||||
"tabWidth": 65,
|
||||
"tabXOffset": 10,
|
||||
"hasSecondTab": false,
|
||||
"secondTabText": "Send Back",
|
||||
"secondTabOffset": 80,
|
||||
"secondTabWidth": 65,
|
||||
"ue_properties": {
|
||||
"widget_ue_connectable": {},
|
||||
"version": "7.7",
|
||||
"input_ue_unconnectable": {}
|
||||
}
|
||||
},
|
||||
"widgets_values": [
|
||||
"max(1, (int(a) + int(b) - 1) // int(b))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": 228,
|
||||
"type": "PrimitiveInt",
|
||||
"pos": [
|
||||
-1380,
|
||||
90
|
||||
],
|
||||
"size": [
|
||||
230,
|
||||
110
|
||||
],
|
||||
"flags": {},
|
||||
"order": 1,
|
||||
"mode": 0,
|
||||
"inputs": [
|
||||
{
|
||||
"localized_name": "value",
|
||||
"name": "value",
|
||||
"type": "INT",
|
||||
"widget": {
|
||||
"name": "value"
|
||||
},
|
||||
"link": 427
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"localized_name": "INT",
|
||||
"name": "INT",
|
||||
"type": "INT",
|
||||
"links": [
|
||||
388
|
||||
]
|
||||
}
|
||||
],
|
||||
"title": "Int (grid columns)",
|
||||
"properties": {
|
||||
"Node name for S&R": "Int (grid columns)",
|
||||
"cnr_id": "comfy-core",
|
||||
"ver": "0.18.1",
|
||||
"enableTabs": false,
|
||||
"tabWidth": 65,
|
||||
"tabXOffset": 10,
|
||||
"hasSecondTab": false,
|
||||
"secondTabText": "Send Back",
|
||||
"secondTabOffset": 80,
|
||||
"secondTabWidth": 65,
|
||||
"ue_properties": {
|
||||
"widget_ue_connectable": {},
|
||||
"version": "7.7",
|
||||
"input_ue_unconnectable": {}
|
||||
}
|
||||
},
|
||||
"widgets_values": [
|
||||
2,
|
||||
"fixed"
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": 230,
|
||||
"type": "GetImageSize",
|
||||
"pos": [
|
||||
-1380,
|
||||
290
|
||||
],
|
||||
"size": [
|
||||
230,
|
||||
100
|
||||
],
|
||||
"flags": {},
|
||||
"order": 3,
|
||||
"mode": 0,
|
||||
"inputs": [
|
||||
{
|
||||
"localized_name": "image",
|
||||
"name": "image",
|
||||
"type": "IMAGE",
|
||||
"link": 389
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"localized_name": "width",
|
||||
"name": "width",
|
||||
"type": "INT",
|
||||
"links": [
|
||||
387
|
||||
]
|
||||
},
|
||||
{
|
||||
"localized_name": "height",
|
||||
"name": "height",
|
||||
"type": "INT",
|
||||
"links": [
|
||||
390
|
||||
]
|
||||
},
|
||||
{
|
||||
"localized_name": "batch_size",
|
||||
"name": "batch_size",
|
||||
"type": "INT",
|
||||
"links": null
|
||||
}
|
||||
],
|
||||
"properties": {
|
||||
"Node name for S&R": "GetImageSize",
|
||||
"cnr_id": "comfy-core",
|
||||
"ver": "0.18.1",
|
||||
"enableTabs": false,
|
||||
"tabWidth": 65,
|
||||
"tabXOffset": 10,
|
||||
"hasSecondTab": false,
|
||||
"secondTabText": "Send Back",
|
||||
"secondTabOffset": 80,
|
||||
"secondTabWidth": 65,
|
||||
"ue_properties": {
|
||||
"widget_ue_connectable": {},
|
||||
"version": "7.7",
|
||||
"input_ue_unconnectable": {}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": 252,
|
||||
"type": "PrimitiveInt",
|
||||
"pos": [
|
||||
-1380,
|
||||
470
|
||||
],
|
||||
"size": [
|
||||
230,
|
||||
110
|
||||
],
|
||||
"flags": {},
|
||||
"order": 5,
|
||||
"mode": 0,
|
||||
"inputs": [
|
||||
{
|
||||
"localized_name": "value",
|
||||
"name": "value",
|
||||
"type": "INT",
|
||||
"widget": {
|
||||
"name": "value"
|
||||
},
|
||||
"link": 428
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"localized_name": "INT",
|
||||
"name": "INT",
|
||||
"type": "INT",
|
||||
"links": [
|
||||
429
|
||||
]
|
||||
}
|
||||
],
|
||||
"title": "Int (grid rows)",
|
||||
"properties": {
|
||||
"Node name for S&R": "Int (grid rows)",
|
||||
"cnr_id": "comfy-core",
|
||||
"ver": "0.18.1",
|
||||
"enableTabs": false,
|
||||
"tabWidth": 65,
|
||||
"tabXOffset": 10,
|
||||
"hasSecondTab": false,
|
||||
"secondTabText": "Send Back",
|
||||
"secondTabOffset": 80,
|
||||
"secondTabWidth": 65,
|
||||
"ue_properties": {
|
||||
"widget_ue_connectable": {},
|
||||
"version": "7.7",
|
||||
"input_ue_unconnectable": {}
|
||||
}
|
||||
},
|
||||
"widgets_values": [
|
||||
3,
|
||||
"fixed"
|
||||
]
|
||||
}
|
||||
],
|
||||
"groups": [],
|
||||
"links": [
|
||||
{
|
||||
"id": 403,
|
||||
"origin_id": 229,
|
||||
"origin_slot": 1,
|
||||
"target_id": 225,
|
||||
"target_slot": 1,
|
||||
"type": "INT"
|
||||
},
|
||||
{
|
||||
"id": 404,
|
||||
"origin_id": 231,
|
||||
"origin_slot": 1,
|
||||
"target_id": 225,
|
||||
"target_slot": 2,
|
||||
"type": "INT"
|
||||
},
|
||||
{
|
||||
"id": 390,
|
||||
"origin_id": 230,
|
||||
"origin_slot": 1,
|
||||
"target_id": 231,
|
||||
"target_slot": 0,
|
||||
"type": "INT"
|
||||
},
|
||||
{
|
||||
"id": 387,
|
||||
"origin_id": 230,
|
||||
"origin_slot": 0,
|
||||
"target_id": 229,
|
||||
"target_slot": 0,
|
||||
"type": "INT"
|
||||
},
|
||||
{
|
||||
"id": 388,
|
||||
"origin_id": 228,
|
||||
"origin_slot": 0,
|
||||
"target_id": 229,
|
||||
"target_slot": 1,
|
||||
"type": "INT"
|
||||
},
|
||||
{
|
||||
"id": 386,
|
||||
"origin_id": -10,
|
||||
"origin_slot": 0,
|
||||
"target_id": 225,
|
||||
"target_slot": 0,
|
||||
"type": "IMAGE"
|
||||
},
|
||||
{
|
||||
"id": 389,
|
||||
"origin_id": -10,
|
||||
"origin_slot": 0,
|
||||
"target_id": 230,
|
||||
"target_slot": 0,
|
||||
"type": "IMAGE"
|
||||
},
|
||||
{
|
||||
"id": 394,
|
||||
"origin_id": 225,
|
||||
"origin_slot": 0,
|
||||
"target_id": -20,
|
||||
"target_slot": 0,
|
||||
"type": "IMAGE"
|
||||
},
|
||||
{
|
||||
"id": 427,
|
||||
"origin_id": -10,
|
||||
"origin_slot": 1,
|
||||
"target_id": 228,
|
||||
"target_slot": 0,
|
||||
"type": "INT"
|
||||
},
|
||||
{
|
||||
"id": 428,
|
||||
"origin_id": -10,
|
||||
"origin_slot": 2,
|
||||
"target_id": 252,
|
||||
"target_slot": 0,
|
||||
"type": "INT"
|
||||
},
|
||||
{
|
||||
"id": 429,
|
||||
"origin_id": 252,
|
||||
"origin_slot": 0,
|
||||
"target_id": 231,
|
||||
"target_slot": 1,
|
||||
"type": "INT"
|
||||
}
|
||||
],
|
||||
"extra": {},
|
||||
"category": "Image Tools/Crop",
|
||||
"description": "Splits an image into a configurable columns×rows grid of equal tiles for tiled generation or processing."
|
||||
}
|
||||
]
|
||||
},
|
||||
"extra": {}
|
||||
}
|
||||
1085
blueprints/Text to Image (Anima).json
Normal file
1085
blueprints/Text to Image (Anima).json
Normal file
File diff suppressed because it is too large
Load Diff
@ -307,9 +307,9 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Text generation/Video Captioning",
|
||||
"category": "Video Tools",
|
||||
"description": "Generates descriptive captions for video input using Google's Gemini multimodal LLM."
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
1226
blueprints/Video Depth Estimation (MoGe).json
Normal file
1226
blueprints/Video Depth Estimation (MoGe).json
Normal file
File diff suppressed because it is too large
Load Diff
1109
blueprints/Video Face Detection (Mediapipe).json
Normal file
1109
blueprints/Video Face Detection (Mediapipe).json
Normal file
File diff suppressed because it is too large
Load Diff
4340
blueprints/Video Inpaint (VOID).json
Normal file
4340
blueprints/Video Inpaint (VOID).json
Normal file
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
4196
blueprints/Video Inpainting (Wan2.1 VACE).json
Normal file
4196
blueprints/Video Inpainting (Wan2.1 VACE).json
Normal file
File diff suppressed because it is too large
Load Diff
@ -818,7 +818,7 @@
|
||||
}
|
||||
],
|
||||
"extra": {},
|
||||
"category": "Video Tools",
|
||||
"category": "Conditioning & Preprocessors/Segmentation & Mask",
|
||||
"description": "Segments video into temporally consistent masks using Meta SAM3 from text or interactive prompts."
|
||||
}
|
||||
]
|
||||
|
||||
@ -412,7 +412,7 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Video generation and editing/Enhance video",
|
||||
"category": "Video generation and editing/Upscale",
|
||||
"description": "Upscales video to 4× resolution using a GAN-based upscaling model."
|
||||
}
|
||||
]
|
||||
|
||||
1323
blueprints/Video to Pose Map (SDPose Multi-Person).json
Normal file
1323
blueprints/Video to Pose Map (SDPose Multi-Person).json
Normal file
File diff suppressed because it is too large
Load Diff
@ -49,7 +49,7 @@ parser.add_argument("--temp-directory", type=str, default=None, help="Set the Co
|
||||
parser.add_argument("--input-directory", type=str, default=None, help="Set the ComfyUI input directory. Overrides --base-directory.")
|
||||
parser.add_argument("--auto-launch", action="store_true", help="Automatically launch ComfyUI in the default browser.")
|
||||
parser.add_argument("--disable-auto-launch", action="store_true", help="Disable auto launching the browser.")
|
||||
parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use. All other devices will not be visible.")
|
||||
parser.add_argument("--cuda-device", type=str, default=None, metavar="DEVICE_ID", help="Set the ids of cuda devices this instance will use, as a comma-separated list (e.g. '0' or '0,1'). All other devices will not be visible.")
|
||||
parser.add_argument("--default-device", type=int, default=None, metavar="DEFAULT_DEVICE_ID", help="Set the id of the default device, all other devices will stay visible.")
|
||||
cm_group = parser.add_mutually_exclusive_group()
|
||||
cm_group.add_argument("--cuda-malloc", action="store_true", help="Enable cudaMallocAsync (enabled by default for torch 2.0 and up).")
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
"""Comfy-specific type hinting"""
|
||||
|
||||
from __future__ import annotations
|
||||
from typing import Literal, TypedDict, Optional
|
||||
from typing_extensions import NotRequired
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
@ -15,13 +15,14 @@
|
||||
You should have received a copy of the GNU General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
from enum import Enum
|
||||
import math
|
||||
import os
|
||||
import logging
|
||||
import copy
|
||||
import comfy.utils
|
||||
import comfy.model_management
|
||||
import comfy.model_detection
|
||||
@ -38,7 +39,7 @@ import comfy.ldm.hydit.controlnet
|
||||
import comfy.ldm.flux.controlnet
|
||||
import comfy.ldm.qwen_image.controlnet
|
||||
import comfy.cldm.dit_embedder
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Union
|
||||
if TYPE_CHECKING:
|
||||
from comfy.hooks import HookGroup
|
||||
|
||||
@ -64,6 +65,18 @@ class StrengthType(Enum):
|
||||
CONSTANT = 1
|
||||
LINEAR_UP = 2
|
||||
|
||||
class ControlIsolation:
|
||||
'''Temporarily set a ControlBase object's previous_controlnet to None to prevent cascading calls.'''
|
||||
def __init__(self, control: ControlBase):
|
||||
self.control = control
|
||||
self.orig_previous_controlnet = control.previous_controlnet
|
||||
|
||||
def __enter__(self):
|
||||
self.control.previous_controlnet = None
|
||||
|
||||
def __exit__(self, *args):
|
||||
self.control.previous_controlnet = self.orig_previous_controlnet
|
||||
|
||||
class ControlBase:
|
||||
def __init__(self):
|
||||
self.cond_hint_original = None
|
||||
@ -77,7 +90,7 @@ class ControlBase:
|
||||
self.compression_ratio = 8
|
||||
self.upscale_algorithm = 'nearest-exact'
|
||||
self.extra_args = {}
|
||||
self.previous_controlnet = None
|
||||
self.previous_controlnet: Union[ControlBase, None] = None
|
||||
self.extra_conds = []
|
||||
self.strength_type = StrengthType.CONSTANT
|
||||
self.concat_mask = False
|
||||
@ -85,6 +98,7 @@ class ControlBase:
|
||||
self.extra_concat = None
|
||||
self.extra_hooks: HookGroup = None
|
||||
self.preprocess_image = lambda a: a
|
||||
self.multigpu_clones: dict[torch.device, ControlBase] = {}
|
||||
|
||||
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None, extra_concat=[]):
|
||||
self.cond_hint_original = cond_hint
|
||||
@ -111,17 +125,38 @@ class ControlBase:
|
||||
def cleanup(self):
|
||||
if self.previous_controlnet is not None:
|
||||
self.previous_controlnet.cleanup()
|
||||
|
||||
for device_cnet in self.multigpu_clones.values():
|
||||
with ControlIsolation(device_cnet):
|
||||
device_cnet.cleanup()
|
||||
self.cond_hint = None
|
||||
self.extra_concat = None
|
||||
self.timestep_range = None
|
||||
|
||||
def get_models(self):
|
||||
out = []
|
||||
for device_cnet in self.multigpu_clones.values():
|
||||
out += device_cnet.get_models_only_self()
|
||||
if self.previous_controlnet is not None:
|
||||
out += self.previous_controlnet.get_models()
|
||||
return out
|
||||
|
||||
def get_models_only_self(self):
|
||||
'Calls get_models, but temporarily sets previous_controlnet to None.'
|
||||
with ControlIsolation(self):
|
||||
return self.get_models()
|
||||
|
||||
def get_instance_for_device(self, device):
|
||||
'Returns instance of this Control object intended for selected device.'
|
||||
return self.multigpu_clones.get(device, self)
|
||||
|
||||
def deepclone_multigpu(self, load_device, autoregister=False):
|
||||
'''
|
||||
Create deep clone of Control object where model(s) is set to other devices.
|
||||
|
||||
When autoregister is set to True, the deep clone is also added to multigpu_clones dict.
|
||||
'''
|
||||
raise NotImplementedError("Classes inheriting from ControlBase should define their own deepclone_multigpu funtion.")
|
||||
|
||||
def get_extra_hooks(self):
|
||||
out = []
|
||||
if self.extra_hooks is not None:
|
||||
@ -130,7 +165,7 @@ class ControlBase:
|
||||
out += self.previous_controlnet.get_extra_hooks()
|
||||
return out
|
||||
|
||||
def copy_to(self, c):
|
||||
def copy_to(self, c: ControlBase):
|
||||
c.cond_hint_original = self.cond_hint_original
|
||||
c.strength = self.strength
|
||||
c.timestep_percent_range = self.timestep_percent_range
|
||||
@ -284,6 +319,14 @@ class ControlNet(ControlBase):
|
||||
self.copy_to(c)
|
||||
return c
|
||||
|
||||
def deepclone_multigpu(self, load_device, autoregister=False):
|
||||
c = self.copy()
|
||||
c.control_model = copy.deepcopy(c.control_model)
|
||||
c.control_model_wrapped = comfy.model_patcher.ModelPatcher(c.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device())
|
||||
if autoregister:
|
||||
self.multigpu_clones[load_device] = c
|
||||
return c
|
||||
|
||||
def get_models(self):
|
||||
out = super().get_models()
|
||||
out.append(self.control_model_wrapped)
|
||||
@ -314,6 +357,10 @@ class QwenFunControlNet(ControlNet):
|
||||
super().pre_run(model, percent_to_timestep_function)
|
||||
self.set_extra_arg("base_model", model.diffusion_model)
|
||||
|
||||
def cleanup(self):
|
||||
self.extra_args.pop("base_model", None)
|
||||
super().cleanup()
|
||||
|
||||
def copy(self):
|
||||
c = QwenFunControlNet(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
|
||||
c.control_model = self.control_model
|
||||
@ -906,6 +953,14 @@ class T2IAdapter(ControlBase):
|
||||
self.copy_to(c)
|
||||
return c
|
||||
|
||||
def deepclone_multigpu(self, load_device, autoregister=False):
|
||||
c = self.copy()
|
||||
c.t2i_model = copy.deepcopy(c.t2i_model)
|
||||
c.device = load_device
|
||||
if autoregister:
|
||||
self.multigpu_clones[load_device] = c
|
||||
return c
|
||||
|
||||
def load_t2i_adapter(t2i_data, model_options={}): #TODO: model_options
|
||||
compression_ratio = 8
|
||||
upscale_algorithm = 'nearest-exact'
|
||||
|
||||
@ -799,13 +799,15 @@ class ZImagePixelSpace(ChromaRadiance):
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class HiDreamO1Pixel(ChromaRadiance):
|
||||
"""Pixel-space latent format for HiDream-O1.
|
||||
No VAE — model patches/unpatches raw RGB internally with patch_size=32.
|
||||
"""
|
||||
pass
|
||||
|
||||
class PixelDiTPixel(ChromaRadiance):
|
||||
pass
|
||||
|
||||
class CogVideoX(LatentFormat):
|
||||
"""Latent format for CogVideoX-2b (THUDM/CogVideoX-2b).
|
||||
|
||||
|
||||
@ -607,9 +607,13 @@ class HunYuanDiTPlain(nn.Module):
|
||||
def forward(self, x, t, context, transformer_options = {}, **kwargs):
|
||||
|
||||
x = x.movedim(-1, -2)
|
||||
if context.shape[0] >= 2:
|
||||
uncond_emb, cond_emb = context.chunk(2, dim = 0)
|
||||
context = torch.cat([cond_emb, uncond_emb], dim = 0)
|
||||
|
||||
swap_cfg_halves = context.shape[0] >= 2
|
||||
|
||||
if swap_cfg_halves:
|
||||
first_half, second_half = context.chunk(2, dim = 0)
|
||||
context = torch.cat([second_half, first_half], dim = 0)
|
||||
|
||||
main_condition = context
|
||||
|
||||
t = 1.0 - t
|
||||
@ -657,8 +661,8 @@ class HunYuanDiTPlain(nn.Module):
|
||||
output = self.final_layer(combined)
|
||||
output = output.movedim(-2, -1) * (-1.0)
|
||||
|
||||
if output.shape[0] >= 2:
|
||||
cond_emb, uncond_emb = output.chunk(2, dim = 0)
|
||||
return torch.cat([uncond_emb, cond_emb])
|
||||
else:
|
||||
return output
|
||||
if swap_cfg_halves:
|
||||
first_half, second_half = output.chunk(2, dim = 0)
|
||||
output = torch.cat([second_half, first_half], dim = 0)
|
||||
|
||||
return output
|
||||
|
||||
510
comfy/ldm/lens/model.py
Normal file
510
comfy/ldm/lens/model.py
Normal file
@ -0,0 +1,510 @@
|
||||
"""Lens denoising transformer (DiT)"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import comfy.ldm.flux.layers
|
||||
import comfy.patcher_extension
|
||||
from comfy.ldm.flux.layers import EmbedND
|
||||
from comfy.ldm.flux.math import apply_rope
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
|
||||
|
||||
def _lens_time_proj(t: torch.Tensor, dim: int = 256) -> torch.Tensor:
|
||||
return comfy.ldm.flux.layers.timestep_embedding(t, dim)
|
||||
|
||||
|
||||
def _lens_position_ids(
|
||||
frame: int, height: int, width: int, text_seq_len: int,
|
||||
scale_rope: bool = True, device=None,
|
||||
) -> torch.Tensor:
|
||||
"""Lens axial (frame, h, w) position ids for joint image + text sequence.
|
||||
|
||||
With ``scale_rope=True`` h/w are centered around 0 (negative + positive
|
||||
halves) and text starts at ``max(h//2, w//2)``. Result shape ``[seq, 3]``;
|
||||
caller adds a batch dim for ``EmbedND``.
|
||||
"""
|
||||
if scale_rope:
|
||||
h_pos = torch.cat([torch.arange(-(height - height // 2), 0, device=device),
|
||||
torch.arange(0, height // 2, device=device)])
|
||||
w_pos = torch.cat([torch.arange(-(width - width // 2), 0, device=device),
|
||||
torch.arange(0, width // 2, device=device)])
|
||||
text_start = max(height // 2, width // 2)
|
||||
else:
|
||||
h_pos = torch.arange(height, device=device)
|
||||
w_pos = torch.arange(width, device=device)
|
||||
text_start = max(height, width)
|
||||
|
||||
f_pos = torch.arange(frame, device=device)
|
||||
img_ids = torch.zeros(frame, height, width, 3, device=device)
|
||||
img_ids[..., 0] = f_pos[:, None, None]
|
||||
img_ids[..., 1] = h_pos[None, :, None]
|
||||
img_ids[..., 2] = w_pos[None, None, :]
|
||||
img_ids = img_ids.reshape(-1, 3)
|
||||
|
||||
# Text positions replicate across all 3 axes (matches original packing).
|
||||
txt_pos = torch.arange(text_start, text_start + text_seq_len, device=device).float()
|
||||
txt_ids = txt_pos[:, None].expand(text_seq_len, 3)
|
||||
|
||||
return torch.cat([img_ids, txt_ids], dim=0)
|
||||
|
||||
|
||||
class _TimestepEmbedder(nn.Module):
|
||||
def __init__(self, in_channels: int, time_embed_dim: int, dtype=None, device=None, operations=None) -> None:
|
||||
super().__init__()
|
||||
self.linear_1 = operations.Linear(in_channels, time_embed_dim, dtype=dtype, device=device)
|
||||
self.linear_2 = operations.Linear(time_embed_dim, time_embed_dim, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.linear_1(x)
|
||||
x = F.silu(x)
|
||||
return self.linear_2(x)
|
||||
|
||||
|
||||
class LensTimestepProjEmbeddings(nn.Module):
|
||||
def __init__(self, embedding_dim: int, dtype=None, device=None, operations=None) -> None:
|
||||
super().__init__()
|
||||
self.timestep_embedder = _TimestepEmbedder(256, embedding_dim, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
def forward(self, timestep: torch.Tensor, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
proj = _lens_time_proj(timestep, 256)
|
||||
return self.timestep_embedder(proj.to(dtype=hidden_states.dtype))
|
||||
|
||||
|
||||
class GateMLP(nn.Module):
|
||||
"""SwiGLU MLP."""
|
||||
|
||||
def __init__(self, dim: int, hidden_dim: int, dtype=None, device=None, operations=None) -> None:
|
||||
super().__init__()
|
||||
self.w1 = operations.Linear(dim, hidden_dim, bias=False, dtype=dtype, device=device)
|
||||
self.w2 = operations.Linear(hidden_dim, dim, bias=False, dtype=dtype, device=device)
|
||||
self.w3 = operations.Linear(dim, hidden_dim, bias=False, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x):
|
||||
return self.w2(F.silu(self.w1(x), inplace=True).mul_(self.w3(x)))
|
||||
|
||||
|
||||
class LensJointAttention(nn.Module):
|
||||
"""Joint image+text attention with fused QKV per stream."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
query_dim: int,
|
||||
added_kv_proj_dim: int,
|
||||
dim_head: int = 64,
|
||||
heads: int = 8,
|
||||
out_dim: Optional[int] = None,
|
||||
eps: float = 1e-5,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
|
||||
self.heads = self.inner_dim // dim_head
|
||||
self.dim_head = dim_head
|
||||
self.out_dim = out_dim if out_dim is not None else query_dim
|
||||
|
||||
self.norm_q = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device)
|
||||
self.norm_k = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device)
|
||||
self.norm_added_q = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device)
|
||||
self.norm_added_k = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device)
|
||||
|
||||
self.img_qkv = operations.Linear(query_dim, 3 * self.inner_dim, bias=True, dtype=dtype, device=device)
|
||||
self.txt_qkv = operations.Linear(added_kv_proj_dim, 3 * self.inner_dim, bias=True, dtype=dtype, device=device)
|
||||
|
||||
# ModuleList([Linear, Identity]) for state-dict key compatibility.
|
||||
self.to_out = nn.ModuleList([
|
||||
operations.Linear(self.inner_dim, self.out_dim, bias=True, dtype=dtype, device=device),
|
||||
nn.Identity(),
|
||||
])
|
||||
self.to_add_out = operations.Linear(self.inner_dim, query_dim, bias=True, dtype=dtype, device=device)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
freqs_cis: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
transformer_options: Optional[Dict[str, Any]] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
bsz, seq_img, _ = hidden_states.shape
|
||||
seq_txt = encoder_hidden_states.shape[1]
|
||||
|
||||
# image stream
|
||||
img_qkv = self.img_qkv(hidden_states).view(bsz, seq_img, 3, self.heads, self.dim_head)
|
||||
img_q, img_k, img_v = img_qkv.unbind(dim=2)
|
||||
img_q = self.norm_q(img_q)
|
||||
img_k = self.norm_k(img_k)
|
||||
del img_qkv
|
||||
|
||||
# text stream
|
||||
txt_qkv = self.txt_qkv(encoder_hidden_states).view(bsz, seq_txt, 3, self.heads, self.dim_head)
|
||||
txt_q, txt_k, txt_v = txt_qkv.unbind(dim=2)
|
||||
txt_q = self.norm_added_q(txt_q)
|
||||
txt_k = self.norm_added_k(txt_k)
|
||||
|
||||
# [B, S, H, D] → [B, H, S, D] for attention, dels to avoid VRAM peaks
|
||||
q = torch.cat([img_q, txt_q], dim=1).transpose(1, 2)
|
||||
del img_q, txt_q
|
||||
k = torch.cat([img_k, txt_k], dim=1).transpose(1, 2)
|
||||
del img_k, txt_k
|
||||
v = torch.cat([img_v, txt_v], dim=1).transpose(1, 2)
|
||||
del img_v, txt_v
|
||||
|
||||
q, k = apply_rope(q, k, freqs_cis)
|
||||
|
||||
if attention_mask is not None:
|
||||
expected = (bsz, 1, 1, seq_img + seq_txt)
|
||||
if attention_mask.shape != expected:
|
||||
raise ValueError(
|
||||
f"attention_mask must be {expected}, got {tuple(attention_mask.shape)}"
|
||||
)
|
||||
attention_mask = attention_mask.to(q.dtype)
|
||||
|
||||
out = optimized_attention(
|
||||
q, k, v, self.heads, mask=attention_mask, skip_reshape=True,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
img_out = self.to_out[1](self.to_out[0](out[:, :seq_img, :]))
|
||||
txt_out = self.to_add_out(out[:, seq_img:, :])
|
||||
return img_out, txt_out
|
||||
|
||||
|
||||
class LensTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
eps: float = 1e-6,
|
||||
rms_norm: bool = True,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.attn = LensJointAttention(
|
||||
query_dim=dim,
|
||||
added_kv_proj_dim=dim,
|
||||
dim_head=attention_head_dim,
|
||||
heads=num_attention_heads,
|
||||
out_dim=dim,
|
||||
eps=1e-5,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
if rms_norm:
|
||||
NormCls = operations.RMSNorm
|
||||
norm_kwargs = {}
|
||||
else:
|
||||
NormCls = operations.LayerNorm
|
||||
norm_kwargs = {"elementwise_affine": False}
|
||||
|
||||
mlp_hidden = int(dim / 3 * 8)
|
||||
|
||||
# Sequential(SiLU, Linear) so state-dict lands at img_mod.1.{weight,bias}.
|
||||
self.img_mod = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operations.Linear(dim, 6 * dim, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
self.img_norm1 = NormCls(dim, eps=eps, dtype=dtype, device=device, **norm_kwargs)
|
||||
self.img_norm2 = NormCls(dim, eps=eps, dtype=dtype, device=device, **norm_kwargs)
|
||||
self.img_mlp = GateMLP(dim, mlp_hidden, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.txt_mod = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operations.Linear(dim, 6 * dim, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
self.txt_norm1 = NormCls(dim, eps=eps, dtype=dtype, device=device, **norm_kwargs)
|
||||
self.txt_norm2 = NormCls(dim, eps=eps, dtype=dtype, device=device, **norm_kwargs)
|
||||
self.txt_mlp = GateMLP(dim, mlp_hidden, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
@staticmethod
|
||||
def _modulate(x: torch.Tensor, mod_params: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
shift, scale, gate = mod_params.chunk(3, dim=-1)
|
||||
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1), gate.unsqueeze(1)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
freqs_cis: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
transformer_options: Optional[Dict[str, Any]] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
img_mod1, img_mod2 = self.img_mod(temb).chunk(2, dim=-1)
|
||||
txt_mod1, txt_mod2 = self.txt_mod(temb).chunk(2, dim=-1)
|
||||
|
||||
img_modulated, img_gate1 = self._modulate(self.img_norm1(hidden_states), img_mod1)
|
||||
txt_modulated, txt_gate1 = self._modulate(self.txt_norm1(encoder_hidden_states), txt_mod1)
|
||||
|
||||
img_attn, txt_attn = self.attn(
|
||||
hidden_states=img_modulated,
|
||||
encoder_hidden_states=txt_modulated,
|
||||
freqs_cis=freqs_cis,
|
||||
attention_mask=attention_mask,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
hidden_states = hidden_states + img_gate1 * img_attn
|
||||
encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn
|
||||
|
||||
img_modulated2, img_gate2 = self._modulate(self.img_norm2(hidden_states), img_mod2)
|
||||
hidden_states = hidden_states + img_gate2 * self.img_mlp(img_modulated2)
|
||||
|
||||
txt_modulated2, txt_gate2 = self._modulate(self.txt_norm2(encoder_hidden_states), txt_mod2)
|
||||
encoder_hidden_states = encoder_hidden_states + txt_gate2 * self.txt_mlp(txt_modulated2)
|
||||
|
||||
return encoder_hidden_states, hidden_states
|
||||
|
||||
|
||||
class _AdaLayerNormContinuousNoAffine(nn.Module):
|
||||
"""AdaLayerNormContinuous(elementwise_affine=False).
|
||||
|
||||
The reference uses ``scale, shift = chunk(2)`` (scale first) — opposite
|
||||
to Flux's ``LastLayer``.
|
||||
"""
|
||||
|
||||
def __init__(self, embedding_dim: int, conditioning_embedding_dim: int, eps: float = 1e-6,
|
||||
dtype=None, device=None, operations=None) -> None:
|
||||
super().__init__()
|
||||
self.linear = operations.Linear(
|
||||
conditioning_embedding_dim, embedding_dim * 2, bias=True, dtype=dtype, device=device
|
||||
)
|
||||
self.eps = eps
|
||||
self.embedding_dim = embedding_dim
|
||||
|
||||
def forward(self, x: torch.Tensor, conditioning: torch.Tensor) -> torch.Tensor:
|
||||
emb = self.linear(F.silu(conditioning))
|
||||
scale, shift = torch.chunk(emb, 2, dim=-1)
|
||||
x = F.layer_norm(x, (self.embedding_dim,), None, None, self.eps)
|
||||
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
||||
|
||||
|
||||
class LensTransformer2DModel(nn.Module):
|
||||
"""Lens dual-stream MMDiT (48 blocks, inner_dim=1536, multi-layer text)."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: int = 2,
|
||||
in_channels: int = 128,
|
||||
out_channels: Optional[int] = 32,
|
||||
num_layers: int = 48,
|
||||
attention_head_dim: int = 64,
|
||||
num_attention_heads: int = 24,
|
||||
enc_hidden_dim: int = 2880,
|
||||
axes_dims_rope: Tuple[int, int, int] = (8, 28, 28),
|
||||
rms_norm: bool = True,
|
||||
multi_layer_encoder_feature: bool = True,
|
||||
selected_layer_index: Tuple[int, ...] = (5, 11, 17, 23),
|
||||
image_model=None, # unused; accepted for detection-side configs.
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.patch_size = patch_size
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels if out_channels is not None else in_channels
|
||||
self.inner_dim = num_attention_heads * attention_head_dim
|
||||
self.multi_layer_encoder_feature = multi_layer_encoder_feature
|
||||
self.selected_layer_index = list(selected_layer_index)
|
||||
self.dtype = dtype
|
||||
|
||||
self.pos_embed = EmbedND(dim=attention_head_dim, theta=10000, axes_dim=list(axes_dims_rope))
|
||||
self.time_text_embed = LensTimestepProjEmbeddings(
|
||||
embedding_dim=self.inner_dim, dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
|
||||
if self.multi_layer_encoder_feature:
|
||||
self.txt_norm = nn.ModuleList(
|
||||
[operations.RMSNorm(enc_hidden_dim, eps=1e-5, dtype=dtype, device=device)
|
||||
for _ in self.selected_layer_index]
|
||||
)
|
||||
self.txt_in = operations.Linear(
|
||||
enc_hidden_dim * len(self.selected_layer_index),
|
||||
self.inner_dim, bias=True, dtype=dtype, device=device,
|
||||
)
|
||||
else:
|
||||
self.txt_norm = operations.RMSNorm(enc_hidden_dim, eps=1e-5, dtype=dtype, device=device)
|
||||
self.txt_in = operations.Linear(enc_hidden_dim, self.inner_dim, bias=True, dtype=dtype, device=device)
|
||||
|
||||
self.img_in = operations.Linear(in_channels, self.inner_dim, bias=True, dtype=dtype, device=device)
|
||||
|
||||
self.transformer_blocks = nn.ModuleList([
|
||||
LensTransformerBlock(
|
||||
dim=self.inner_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
eps=1e-6,
|
||||
rms_norm=rms_norm,
|
||||
dtype=dtype, device=device, operations=operations,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
])
|
||||
|
||||
self.norm_out = _AdaLayerNormContinuousNoAffine(
|
||||
self.inner_dim, self.inner_dim, eps=1e-6,
|
||||
dtype=dtype, device=device, operations=operations,
|
||||
)
|
||||
self.proj_out = operations.Linear(
|
||||
self.inner_dim, patch_size * patch_size * self.out_channels, bias=True,
|
||||
dtype=dtype, device=device,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor, timestep: torch.Tensor, context: torch.Tensor, attention_mask: Optional[torch.Tensor] = None,
|
||||
transformer_options: Optional[Dict[str, Any]] = None, **kwargs) -> torch.Tensor:
|
||||
if transformer_options is None:
|
||||
transformer_options = {}
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward, self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options),
|
||||
).execute(x, timestep, context, attention_mask, transformer_options, **kwargs)
|
||||
|
||||
def _forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
timestep: torch.Tensor,
|
||||
context: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
transformer_options: Optional[Dict[str, Any]] = None,
|
||||
control: Optional[Dict[str, Any]] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
"""ComfyUI bridge: ``(x[B,128,h,w], t[B], context[B,S,L*H], mask[B,S])``."""
|
||||
if transformer_options is None:
|
||||
transformer_options = {}
|
||||
transformer_options = transformer_options.copy()
|
||||
patches = transformer_options.get("patches", {})
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
|
||||
B, C, h, w = x.shape
|
||||
hidden_states = x.permute(0, 2, 3, 1).reshape(B, h * w, C)
|
||||
|
||||
if self.multi_layer_encoder_feature:
|
||||
L = len(self.selected_layer_index)
|
||||
enc_dim = context.shape[-1] // L
|
||||
encoder_hidden_states = list(
|
||||
context.reshape(B, -1, L, enc_dim).unbind(dim=2)
|
||||
)
|
||||
text_seq_len = encoder_hidden_states[0].shape[1]
|
||||
else:
|
||||
encoder_hidden_states = context
|
||||
text_seq_len = context.shape[1]
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(
|
||||
(B, text_seq_len), dtype=torch.bool, device=x.device
|
||||
)
|
||||
|
||||
img_len = h * w
|
||||
joint_mask = self._build_joint_attention_mask(attention_mask, img_len)
|
||||
|
||||
hidden_states = self.img_in(hidden_states)
|
||||
timestep = timestep.to(hidden_states.dtype)
|
||||
|
||||
if self.multi_layer_encoder_feature:
|
||||
normed = [self.txt_norm[i](encoder_hidden_states[i]) for i in range(L)]
|
||||
encoder_hidden_states = torch.cat(normed, dim=-1)
|
||||
else:
|
||||
encoder_hidden_states = self.txt_norm(encoder_hidden_states)
|
||||
encoder_hidden_states = self.txt_in(encoder_hidden_states)
|
||||
|
||||
if "post_input" in patches:
|
||||
for p in patches["post_input"]:
|
||||
out = p({
|
||||
"img": hidden_states,
|
||||
"txt": encoder_hidden_states,
|
||||
"transformer_options": transformer_options,
|
||||
})
|
||||
hidden_states = out["img"]
|
||||
encoder_hidden_states = out["txt"]
|
||||
|
||||
temb = self.time_text_embed(timestep, hidden_states)
|
||||
ids = _lens_position_ids(1, h, w, text_seq_len, device=hidden_states.device).unsqueeze(0)
|
||||
freqs_cis = self.pos_embed(ids)
|
||||
|
||||
transformer_options["total_blocks"] = len(self.transformer_blocks)
|
||||
transformer_options["block_type"] = "double"
|
||||
for i, block in enumerate(self.transformer_blocks):
|
||||
transformer_options["block_index"] = i
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["txt"], out["img"] = block(
|
||||
hidden_states=args["img"],
|
||||
encoder_hidden_states=args["txt"],
|
||||
temb=args["vec"],
|
||||
freqs_cis=args["pe"],
|
||||
attention_mask=args.get("attn_mask"),
|
||||
transformer_options=args.get("transformer_options"),
|
||||
)
|
||||
return out
|
||||
out = blocks_replace[("double_block", i)](
|
||||
{
|
||||
"img": hidden_states,
|
||||
"txt": encoder_hidden_states,
|
||||
"vec": temb,
|
||||
"pe": freqs_cis,
|
||||
"attn_mask": joint_mask,
|
||||
"transformer_options": transformer_options,
|
||||
},
|
||||
{"original_block": block_wrap},
|
||||
)
|
||||
encoder_hidden_states = out["txt"]
|
||||
hidden_states = out["img"]
|
||||
else:
|
||||
encoder_hidden_states, hidden_states = block(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
temb=temb,
|
||||
freqs_cis=freqs_cis,
|
||||
attention_mask=joint_mask,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
if "double_block" in patches:
|
||||
for p in patches["double_block"]:
|
||||
out = p({
|
||||
"img": hidden_states,
|
||||
"txt": encoder_hidden_states,
|
||||
"x": x,
|
||||
"block_index": i,
|
||||
"transformer_options": transformer_options,
|
||||
})
|
||||
hidden_states = out["img"]
|
||||
encoder_hidden_states = out["txt"]
|
||||
|
||||
if control is not None:
|
||||
control_i = control.get("input")
|
||||
if control_i is not None and i < len(control_i):
|
||||
add = control_i[i]
|
||||
if add is not None:
|
||||
hidden_states[:, :add.shape[1]] += add
|
||||
|
||||
hidden_states = self.norm_out(hidden_states, temb)
|
||||
out = self.proj_out(hidden_states)
|
||||
return out.reshape(B, h, w, C).permute(0, 3, 1, 2).contiguous()
|
||||
|
||||
@staticmethod
|
||||
def _build_joint_attention_mask(text_mask: torch.Tensor, img_len: int) -> torch.Tensor:
|
||||
if text_mask.dtype != torch.bool:
|
||||
text_mask = text_mask.bool()
|
||||
bsz = text_mask.shape[0]
|
||||
img_ones = torch.ones((bsz, img_len), dtype=torch.bool, device=text_mask.device)
|
||||
joint = torch.cat([img_ones, text_mask], dim=1)
|
||||
additive = torch.zeros_like(joint, dtype=torch.float32)
|
||||
additive.masked_fill_(~joint, torch.finfo(torch.float32).min)
|
||||
return additive[:, None, None, :]
|
||||
@ -767,25 +767,25 @@ class LTXAVModel(LTXVModel):
|
||||
|
||||
# Cross-attention timesteps - compress these too
|
||||
av_ca_audio_scale_shift_timestep, _ = self.av_ca_audio_scale_shift_adaln_single(
|
||||
timestep.max().expand_as(a_timestep_flat),
|
||||
a_timestep_flat,
|
||||
{"resolution": None, "aspect_ratio": None},
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=hidden_dtype,
|
||||
)
|
||||
av_ca_video_scale_shift_timestep, _ = self.av_ca_video_scale_shift_adaln_single(
|
||||
a_timestep.max().expand_as(timestep_flat),
|
||||
timestep_flat,
|
||||
{"resolution": None, "aspect_ratio": None},
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=hidden_dtype,
|
||||
)
|
||||
av_ca_a2v_gate_noise_timestep, _ = self.av_ca_a2v_gate_adaln_single(
|
||||
a_timestep.max().expand_as(timestep_flat) * av_ca_factor,
|
||||
a_timestep_scaled.max().expand_as(timestep_flat) * av_ca_factor,
|
||||
{"resolution": None, "aspect_ratio": None},
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=hidden_dtype,
|
||||
)
|
||||
av_ca_v2a_gate_noise_timestep, _ = self.av_ca_v2a_gate_adaln_single(
|
||||
timestep.max().expand_as(a_timestep_flat) * av_ca_factor,
|
||||
timestep_scaled.max().expand_as(a_timestep_flat) * av_ca_factor,
|
||||
{"resolution": None, "aspect_ratio": None},
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=hidden_dtype,
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
from __future__ import annotations
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
from __future__ import annotations
|
||||
import threading
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
# Code from: https://github.com/Alpha-VLLM/Lumina-Image-2.0/blob/main/models/model.py
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
|
||||
@ -211,7 +211,7 @@ class TimestepEmbedder(nn.Module):
|
||||
Embeds scalar timesteps into vector representations.
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_size, frequency_embedding_size=256, output_size=None, dtype=None, device=None, operations=None):
|
||||
def __init__(self, hidden_size, frequency_embedding_size=256, output_size=None, dtype=None, device=None, operations=None, max_period=10000):
|
||||
super().__init__()
|
||||
if output_size is None:
|
||||
output_size = hidden_size
|
||||
@ -221,9 +221,10 @@ class TimestepEmbedder(nn.Module):
|
||||
operations.Linear(hidden_size, output_size, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
self.frequency_embedding_size = frequency_embedding_size
|
||||
self.max_period = max_period
|
||||
|
||||
def forward(self, t, dtype, **kwargs):
|
||||
t_freq = timestep_embedding(t, self.frequency_embedding_size).to(dtype)
|
||||
t_freq = timestep_embedding(t, self.frequency_embedding_size, max_period=self.max_period).to(dtype)
|
||||
t_emb = self.mlp(t_freq)
|
||||
return t_emb
|
||||
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
"""Pure-torch + scipy geometry helpers for MoGe inference and mesh export."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional, Tuple
|
||||
|
||||
|
||||
@ -4,7 +4,6 @@ V1: DINOv2 backbone + multi-output head (points, mask).
|
||||
V2: DINOv2 encoder + neck + per-output heads (points, mask, normal, optional metric-scale MLP).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from numbers import Number
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
"""Building blocks for MoGe: residual conv stack, resamplers, MLP, DINOv2 encoder, v1 head."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List, Optional, Sequence, Tuple, Union
|
||||
|
||||
|
||||
@ -6,7 +6,6 @@ equirect distance map via a multi-scale Poisson + gradient sparse solve.
|
||||
Image sampling uses F.grid_sample (GPU); the sparse solve uses lsmr (CPU).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Callable, List, Optional, Tuple
|
||||
|
||||
|
||||
239
comfy/ldm/pixeldit/model.py
Normal file
239
comfy/ldm/pixeldit/model.py
Normal file
@ -0,0 +1,239 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import comfy.ldm.common_dit
|
||||
import comfy.patcher_extension
|
||||
from comfy.ldm.flux.math import apply_rope, rope
|
||||
from comfy.ldm.hidream.model import FeedForwardSwiGLU
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder
|
||||
|
||||
from .modules import (
|
||||
FinalLayer,
|
||||
PatchTokenEmbedder,
|
||||
PiTBlock,
|
||||
PixelTokenEmbedder,
|
||||
apply_adaln_,
|
||||
precompute_freqs_cis_2d,
|
||||
)
|
||||
|
||||
|
||||
class MMDiTJointAttention(nn.Module):
|
||||
"""Joint MMDiT attention with separate Q/K/V/proj for image and text streams.
|
||||
|
||||
RoPE is applied to each stream before concatenation so each stream uses its own
|
||||
2D/1D positional encoding. Concat order is [text, image] (text first).
|
||||
"""
|
||||
def __init__(self, dim, num_heads=8, qkv_bias=False, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
assert dim % num_heads == 0
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
|
||||
self.qkv_x = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
|
||||
self.qkv_y = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
|
||||
|
||||
self.q_norm_x = operations.RMSNorm(self.head_dim, eps=1e-6, dtype=dtype, device=device)
|
||||
self.k_norm_x = operations.RMSNorm(self.head_dim, eps=1e-6, dtype=dtype, device=device)
|
||||
self.q_norm_y = operations.RMSNorm(self.head_dim, eps=1e-6, dtype=dtype, device=device)
|
||||
self.k_norm_y = operations.RMSNorm(self.head_dim, eps=1e-6, dtype=dtype, device=device)
|
||||
|
||||
self.proj_x = operations.Linear(dim, dim, dtype=dtype, device=device)
|
||||
self.proj_y = operations.Linear(dim, dim, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x, y, pos_img, pos_txt=None, attn_mask=None, transformer_options={}):
|
||||
B, Nx, _ = x.shape
|
||||
_, Ny, _ = y.shape
|
||||
H = self.num_heads
|
||||
D = self.head_dim
|
||||
|
||||
qkv_x = self.qkv_x(x).reshape(B, Nx, 3, H, D).permute(2, 0, 3, 1, 4)
|
||||
qx, kx, vx = qkv_x.unbind(0)
|
||||
qx = self.q_norm_x(qx)
|
||||
kx = self.k_norm_x(kx)
|
||||
|
||||
qkv_y = self.qkv_y(y).reshape(B, Ny, 3, H, D).permute(2, 0, 3, 1, 4)
|
||||
qy, ky, vy = qkv_y.unbind(0)
|
||||
qy = self.q_norm_y(qy)
|
||||
ky = self.k_norm_y(ky)
|
||||
|
||||
qx, kx = apply_rope(qx, kx, pos_img[None, None])
|
||||
if pos_txt is not None:
|
||||
qy, ky = apply_rope(qy, ky, pos_txt[None, None])
|
||||
|
||||
q_joint = torch.cat([qy, qx], dim=2)
|
||||
k_joint = torch.cat([ky, kx], dim=2)
|
||||
v_joint = torch.cat([vy, vx], dim=2)
|
||||
|
||||
out_joint = optimized_attention(
|
||||
q_joint, k_joint, v_joint, H,
|
||||
mask=attn_mask, skip_reshape=True, skip_output_reshape=True,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
out_y = out_joint[:, :, :Ny, :].transpose(1, 2).reshape(B, Ny, H * D)
|
||||
out_x = out_joint[:, :, Ny:, :].transpose(1, 2).reshape(B, Nx, H * D)
|
||||
|
||||
return self.proj_x(out_x), self.proj_y(out_y)
|
||||
|
||||
|
||||
class MMDiTBlockT2I(nn.Module):
|
||||
def __init__(self, hidden_size, groups, mlp_ratio=4.0, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.norm_x1 = operations.RMSNorm(hidden_size, eps=1e-6, dtype=dtype, device=device)
|
||||
self.norm_y1 = operations.RMSNorm(hidden_size, eps=1e-6, dtype=dtype, device=device)
|
||||
self.attn = MMDiTJointAttention(hidden_size, num_heads=groups, qkv_bias=False, dtype=dtype, device=device, operations=operations)
|
||||
self.norm_x2 = operations.RMSNorm(hidden_size, eps=1e-6, dtype=dtype, device=device)
|
||||
self.norm_y2 = operations.RMSNorm(hidden_size, eps=1e-6, dtype=dtype, device=device)
|
||||
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
self.mlp_x = FeedForwardSwiGLU(hidden_size, mlp_hidden_dim, multiple_of=1, dtype=dtype, device=device, operations=operations)
|
||||
self.mlp_y = FeedForwardSwiGLU(hidden_size, mlp_hidden_dim, multiple_of=1, dtype=dtype, device=device, operations=operations)
|
||||
self.adaLN_modulation_img = nn.Sequential(operations.Linear(hidden_size, 6 * hidden_size, bias=True, dtype=dtype, device=device))
|
||||
self.adaLN_modulation_txt = nn.Sequential(operations.Linear(hidden_size, 6 * hidden_size, bias=True, dtype=dtype, device=device))
|
||||
|
||||
def forward(self, x, y, c, pos_img, pos_txt=None, attn_mask=None, transformer_options={}):
|
||||
shift_msa_x, scale_msa_x, gate_msa_x, shift_mlp_x, scale_mlp_x, gate_mlp_x = self.adaLN_modulation_img(c).chunk(6, dim=-1)
|
||||
shift_msa_y, scale_msa_y, gate_msa_y, shift_mlp_y, scale_mlp_y, gate_mlp_y = self.adaLN_modulation_txt(c).chunk(6, dim=-1)
|
||||
|
||||
x_norm = apply_adaln_(self.norm_x1(x), shift_msa_x, scale_msa_x)
|
||||
y_norm = apply_adaln_(self.norm_y1(y), shift_msa_y, scale_msa_y)
|
||||
attn_x, attn_y = self.attn(x_norm, y_norm, pos_img, pos_txt, attn_mask, transformer_options=transformer_options)
|
||||
x = torch.addcmul(x, gate_msa_x, attn_x)
|
||||
y = torch.addcmul(y, gate_msa_y, attn_y)
|
||||
|
||||
x = torch.addcmul(x, gate_mlp_x, self.mlp_x(apply_adaln_(self.norm_x2(x), shift_mlp_x, scale_mlp_x)))
|
||||
y = torch.addcmul(y, gate_mlp_y, self.mlp_y(apply_adaln_(self.norm_y2(y), shift_mlp_y, scale_mlp_y)))
|
||||
return x, y
|
||||
|
||||
|
||||
class PixDiT_T2I(nn.Module):
|
||||
"""PixelDiT T2I model. Hardcoded for the released 1024px Stage-3 checkpoint
|
||||
(also runs at 512px when fed the appropriate latent size and flow_shift).
|
||||
|
||||
Forward:
|
||||
x: [B, 3, H, W] pixel-space input (no VAE)
|
||||
timesteps:[B] in [0, 1000] (ComfyUI flow sampling convention)
|
||||
context: [B, Ltxt, 2304] Gemma-2-2b-it hidden states (chi_prompt prepended)
|
||||
Returns flow-matching velocity [B, 3, H, W].
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
in_channels=3,
|
||||
num_groups=24,
|
||||
hidden_size=1536,
|
||||
pixel_hidden_size=16,
|
||||
pixel_attn_hidden_size=1152,
|
||||
pixel_num_groups=16,
|
||||
patch_depth=14,
|
||||
pixel_depth=2,
|
||||
patch_size=16,
|
||||
txt_embed_dim=2304,
|
||||
txt_max_length=300,
|
||||
use_text_rope=True,
|
||||
text_rope_theta=10000.0,
|
||||
image_model=None,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
pixel_mlp_chunks=2,
|
||||
):
|
||||
super().__init__()
|
||||
self.dtype = dtype
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = in_channels
|
||||
self.hidden_size = hidden_size
|
||||
self.num_groups = num_groups
|
||||
self.patch_depth = patch_depth
|
||||
self.pixel_depth = pixel_depth
|
||||
self.patch_size = patch_size
|
||||
self.pixel_hidden_size = pixel_hidden_size
|
||||
self.pixel_attn_hidden_size = pixel_attn_hidden_size
|
||||
self.pixel_num_groups = pixel_num_groups
|
||||
self.txt_embed_dim = txt_embed_dim
|
||||
self.txt_max_length = txt_max_length
|
||||
self.use_text_rope = use_text_rope
|
||||
self.text_rope_theta = text_rope_theta
|
||||
|
||||
self.pixel_embedder = PixelTokenEmbedder(self.in_channels, self.pixel_hidden_size, dtype=dtype, device=device, operations=operations)
|
||||
self.s_embedder = PatchTokenEmbedder(self.in_channels * self.patch_size ** 2, self.hidden_size, bias=True, dtype=dtype, device=device, operations=operations)
|
||||
self.t_embedder = TimestepEmbedder(self.hidden_size, dtype=dtype, device=device, operations=operations, max_period=10)
|
||||
self.y_embedder = PatchTokenEmbedder(self.txt_embed_dim, self.hidden_size, bias=True, use_norm=True, dtype=dtype, device=device, operations=operations)
|
||||
self.y_pos_embedding = nn.Parameter(torch.empty(1, self.txt_max_length, self.hidden_size, dtype=dtype, device=device))
|
||||
|
||||
self.patch_blocks = nn.ModuleList([
|
||||
MMDiTBlockT2I(self.hidden_size, self.num_groups,
|
||||
dtype=dtype, device=device, operations=operations)
|
||||
for _ in range(self.patch_depth)
|
||||
])
|
||||
self.pixel_blocks = nn.ModuleList([
|
||||
PiTBlock(
|
||||
self.pixel_hidden_size,
|
||||
self.hidden_size,
|
||||
patch_size=self.patch_size,
|
||||
num_heads=self.num_groups,
|
||||
attn_hidden_size=self.pixel_attn_hidden_size,
|
||||
attn_num_heads=self.pixel_num_groups,
|
||||
dtype=dtype, device=device, operations=operations,
|
||||
mlp_chunks=pixel_mlp_chunks,
|
||||
)
|
||||
for _ in range(self.pixel_depth)
|
||||
])
|
||||
|
||||
self.final_layer = FinalLayer(self.pixel_hidden_size, self.out_channels, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
def _fetch_patch_pos(self, height, width, device, dtype, **rope_opts):
|
||||
return precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width, device=device, dtype=dtype, **rope_opts)
|
||||
|
||||
def _fetch_text_pos(self, length, device, dtype):
|
||||
return rope(torch.arange(length, dtype=torch.float32, device=device).reshape(1, -1), self.hidden_size // self.num_groups, self.text_rope_theta).squeeze(0).to(dtype=dtype)
|
||||
|
||||
def forward(self, x, timesteps, context=None, attention_mask=None, transformer_options={}, **kwargs):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward, self, comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options),
|
||||
).execute(x, timesteps, context, attention_mask, transformer_options, **kwargs)
|
||||
|
||||
def _pre_patch_block(self, s, i, **kwargs):
|
||||
"""Hook for subclasses to inject per-block state into the patch stream (e.g. PiD's LQ gate)."""
|
||||
return s
|
||||
|
||||
def _forward(self, x, timesteps, context=None, attention_mask=None, transformer_options={}, **kwargs):
|
||||
H_orig, W_orig = x.shape[2], x.shape[3]
|
||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
|
||||
B, _, H, W = x.shape
|
||||
Hs = H // self.patch_size
|
||||
Ws = W // self.patch_size
|
||||
L = Hs * Ws
|
||||
|
||||
pos_img = self._fetch_patch_pos(Hs, Ws, x.device, x.dtype, **(transformer_options.get("rope_options") or {}))
|
||||
x_patches = F.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
|
||||
|
||||
t_emb = self.t_embedder(timesteps.view(-1), x.dtype).view(B, -1, self.hidden_size)
|
||||
|
||||
if context is None or context.dim() != 3:
|
||||
raise ValueError("PixDiT_T2I requires context (text embeddings) of shape [B, L, D]")
|
||||
Ltxt = min(context.shape[1], self.txt_max_length)
|
||||
y = context[:, :Ltxt, :]
|
||||
y_emb = self.y_embedder(y).view(B, Ltxt, self.hidden_size)
|
||||
y_emb = y_emb + self.y_pos_embedding[:, :Ltxt, :].to(y_emb) # y_pos_embedding is a raw nn.Parameter
|
||||
|
||||
condition = F.silu(t_emb)
|
||||
pos_txt = self._fetch_text_pos(Ltxt, x.device, x.dtype) if self.use_text_rope else None
|
||||
|
||||
s = self.s_embedder(x_patches)
|
||||
for i, blk in enumerate(self.patch_blocks):
|
||||
s = self._pre_patch_block(s, i, **kwargs)
|
||||
s, y_emb = blk(s, y_emb, condition, pos_img, pos_txt, None, transformer_options=transformer_options)
|
||||
s = F.silu(t_emb + s)
|
||||
|
||||
s_cond = s.view(B * L, self.hidden_size)
|
||||
x_pixels = self.pixel_embedder(x, patch_size=self.patch_size)
|
||||
for blk in self.pixel_blocks:
|
||||
x_pixels = blk(x_pixels, s_cond, H, W, self.patch_size, mask=None, transformer_options=transformer_options)
|
||||
|
||||
x_pixels = self.final_layer(x_pixels)
|
||||
C_out = self.out_channels
|
||||
P2 = self.patch_size * self.patch_size
|
||||
x_pixels = x_pixels.view(B, L, P2, C_out).permute(0, 3, 2, 1).reshape(B, C_out * P2, L)
|
||||
out = F.fold(x_pixels, (H, W), kernel_size=self.patch_size, stride=self.patch_size)
|
||||
return out[:, :, :H_orig, :W_orig]
|
||||
187
comfy/ldm/pixeldit/modules.py
Normal file
187
comfy/ldm/pixeldit/modules.py
Normal file
@ -0,0 +1,187 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from comfy.ldm.flux.math import apply_rope, rope
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
from comfy.ldm.modules.diffusionmodules.mmdit import Mlp, get_1d_sincos_pos_embed_from_grid_torch
|
||||
|
||||
|
||||
def apply_adaln_(x, shift, scale):
|
||||
return x.addcmul_(x, scale).add_(shift)
|
||||
|
||||
|
||||
def precompute_freqs_cis_2d(dim, height, width, theta=10000.0, scale=16.0,
|
||||
ref_grid_h=None, ref_grid_w=None,
|
||||
scale_x=1.0, scale_y=1.0, shift_x=0.0, shift_y=0.0,
|
||||
device=None, dtype=torch.float32, **kwargs):
|
||||
"""2D RoPE with x/y axis frequencies interleaved at stride 2 across head dim.
|
||||
|
||||
rope_options:
|
||||
scale_x / scale_y multiply the position range (RoPE extrapolation).
|
||||
shift_x / shift_y offset the position origin (tiled / regional inference).
|
||||
With ref_grid_h/w set, also applies NTK-aware per-axis theta scaling
|
||||
(rope_mode='ntk_aware'): theta_axis = theta * (current/ref)^(dim_axis/(dim_axis-2)).
|
||||
Returns Flux-format rotation matrices of shape [H*W, dim/2, 2, 2].
|
||||
Layout of head-dim pairs: [x_0, y_0, x_1, y_1, ..., x_{dim/4-1}, y_{dim/4-1}].
|
||||
"""
|
||||
dim_axis = dim // 2
|
||||
if ref_grid_h is not None and dim_axis > 2:
|
||||
h_ntk = (height / ref_grid_h) ** (dim_axis / (dim_axis - 2))
|
||||
w_ntk = (width / ref_grid_w) ** (dim_axis / (dim_axis - 2))
|
||||
else:
|
||||
h_ntk = w_ntk = 1.0
|
||||
|
||||
x_lin = torch.linspace(shift_x, scale * scale_x + shift_x, width, device=device)
|
||||
y_lin = torch.linspace(shift_y, scale * scale_y + shift_y, height, device=device)
|
||||
y_grid, x_grid = torch.meshgrid(y_lin, x_lin, indexing="ij")
|
||||
x_rope = rope(x_grid.reshape(1, -1), dim_axis, theta * w_ntk).squeeze(0)
|
||||
y_rope = rope(y_grid.reshape(1, -1), dim_axis, theta * h_ntk).squeeze(0)
|
||||
out = torch.stack([x_rope, y_rope], dim=2).reshape(height * width, dim // 2, 2, 2)
|
||||
return out.to(dtype=dtype)
|
||||
|
||||
|
||||
def get_2d_sincos_pos_embed(embed_dim, height, width, device=None, dtype=torch.float32):
|
||||
"""Standard 2D sin/cos absolute positional embedding (ViT-style).
|
||||
|
||||
first half encodes W-coordinates, second half H.
|
||||
"""
|
||||
assert embed_dim % 4 == 0
|
||||
grid_h = torch.arange(height, dtype=torch.float32, device=device)
|
||||
grid_w = torch.arange(width, dtype=torch.float32, device=device)
|
||||
grid_y, grid_x = torch.meshgrid(grid_h, grid_w, indexing="ij")
|
||||
emb_w = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_x.reshape(-1), device=device)
|
||||
emb_h = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_y.reshape(-1), device=device)
|
||||
return torch.cat([emb_w, emb_h], dim=1).to(dtype=dtype)
|
||||
|
||||
|
||||
class RotaryAttention(nn.Module):
|
||||
"""Single-stream self-attention with rotary positional encoding (used inside PiTBlock)."""
|
||||
def __init__(self, dim, num_heads=8, qkv_bias=False, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
assert dim % num_heads == 0
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
|
||||
self.q_norm = operations.RMSNorm(self.head_dim, eps=1e-6, dtype=dtype, device=device)
|
||||
self.k_norm = operations.RMSNorm(self.head_dim, eps=1e-6, dtype=dtype, device=device)
|
||||
self.proj = operations.Linear(dim, dim, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x, pos, mask=None, transformer_options={}):
|
||||
B, N, C = x.shape
|
||||
H = self.num_heads
|
||||
D = self.head_dim
|
||||
qkv = self.qkv(x).reshape(B, N, 3, H, D).permute(2, 0, 3, 1, 4)
|
||||
q, k, v = qkv.unbind(0)
|
||||
q, k = apply_rope(self.q_norm(q), self.k_norm(k), pos[None, None])
|
||||
x = optimized_attention(q, k, v, H, mask=mask, skip_reshape=True, transformer_options=transformer_options)
|
||||
return self.proj(x)
|
||||
|
||||
|
||||
class FinalLayer(nn.Module):
|
||||
def __init__(self, hidden_size, out_channels, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.norm = operations.RMSNorm(hidden_size, eps=1e-6, dtype=dtype, device=device)
|
||||
self.linear = operations.Linear(hidden_size, out_channels, bias=True, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x):
|
||||
return self.linear(self.norm(x))
|
||||
|
||||
|
||||
class PatchTokenEmbedder(nn.Module):
|
||||
"""Linear projection used both for patchified-image tokens and text-feature tokens."""
|
||||
def __init__(self, in_chans, embed_dim, use_norm=False, bias=True, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.proj = operations.Linear(in_chans, embed_dim, bias=bias, dtype=dtype, device=device)
|
||||
self.norm = operations.RMSNorm(embed_dim, eps=1e-6, dtype=dtype, device=device) if use_norm else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
return self.norm(self.proj(x))
|
||||
|
||||
|
||||
class PixelTokenEmbedder(nn.Module):
|
||||
"""Pixel-level embedder: lifts each RGB pixel to hidden_size and packs into per-patch sequences."""
|
||||
def __init__(self, in_channels, hidden_size_output, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.hidden_size_output = hidden_size_output
|
||||
self.proj = operations.Linear(self.in_channels, self.hidden_size_output, bias=True, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, inputs, patch_size):
|
||||
B, _, H, W = inputs.shape
|
||||
Hs, Ws = H // patch_size, W // patch_size
|
||||
P2 = patch_size * patch_size
|
||||
x = inputs.permute(0, 2, 3, 1).contiguous()
|
||||
x = self.proj(x)
|
||||
pos_full = get_2d_sincos_pos_embed(self.hidden_size_output, H, W, device=x.device, dtype=x.dtype).view(H, W, self.hidden_size_output)
|
||||
x = x + pos_full.unsqueeze(0)
|
||||
x = x.view(B, Hs, patch_size, Ws, patch_size, self.hidden_size_output)
|
||||
return x.permute(0, 1, 3, 2, 4, 5).reshape(B * Hs * Ws, P2, self.hidden_size_output)
|
||||
|
||||
|
||||
class PiTBlock(nn.Module):
|
||||
"""Pixel-level transformer block.
|
||||
|
||||
Compresses each patch's P^2 pixel tokens → 1 attention token via a linear,
|
||||
runs global self-attention across patches with 2D RoPE, then expands back to P^2 tokens.
|
||||
Conditioning is per-pixel adaLN from the patch-level features.
|
||||
"""
|
||||
def __init__(self, pixel_hidden_size, patch_hidden_size, patch_size, num_heads, mlp_ratio=4.0,
|
||||
attn_hidden_size=None, attn_num_heads=None, dtype=None, device=None, operations=None, mlp_chunks=1):
|
||||
super().__init__()
|
||||
self.pixel_dim = pixel_hidden_size
|
||||
self.context_dim = patch_hidden_size
|
||||
self.attn_dim = attn_hidden_size if attn_hidden_size is not None else patch_hidden_size
|
||||
self.num_heads = attn_num_heads if attn_num_heads is not None else num_heads
|
||||
assert self.attn_dim % self.num_heads == 0
|
||||
|
||||
p2 = patch_size * patch_size
|
||||
self.compress_to_attn = operations.Linear(p2 * self.pixel_dim, self.attn_dim, bias=True, dtype=dtype, device=device)
|
||||
self.expand_from_attn = operations.Linear(self.attn_dim, p2 * self.pixel_dim, bias=True, dtype=dtype, device=device)
|
||||
|
||||
self.norm1 = operations.RMSNorm(self.pixel_dim, eps=1e-6, dtype=dtype, device=device)
|
||||
self.attn = RotaryAttention(self.attn_dim, num_heads=self.num_heads, qkv_bias=False, dtype=dtype, device=device, operations=operations)
|
||||
self.norm2 = operations.RMSNorm(self.pixel_dim, eps=1e-6, dtype=dtype, device=device)
|
||||
self.mlp = Mlp(self.pixel_dim, hidden_features=int(self.pixel_dim * mlp_ratio), dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.adaLN_modulation_msa = operations.Linear(self.context_dim, 3 * self.pixel_dim * p2, bias=True, dtype=dtype, device=device)
|
||||
self.adaLN_modulation_mlp = operations.Linear(self.context_dim, 3 * self.pixel_dim * p2, bias=True, dtype=dtype, device=device)
|
||||
|
||||
self._rope_fn = precompute_freqs_cis_2d
|
||||
self.mlp_chunks = max(1, int(mlp_chunks))
|
||||
|
||||
def _fetch_pos(self, height, width, device, dtype, **rope_opts):
|
||||
return self._rope_fn(self.attn_dim // self.num_heads, height, width, device=device, dtype=dtype, **rope_opts)
|
||||
|
||||
def forward(self, x, s_cond, image_height, image_width, patch_size, mask=None, transformer_options={}):
|
||||
BL, P2, _ = x.shape
|
||||
Hs, Ws = image_height // patch_size, image_width // patch_size
|
||||
L = Hs * Ws
|
||||
B = BL // L
|
||||
|
||||
# Attention path uses only msa params; compute, use, free before mlp params allocate.
|
||||
msa_params = self.adaLN_modulation_msa(s_cond).view(BL, P2, 3 * self.pixel_dim)
|
||||
shift_msa, scale_msa, gate_msa = msa_params.chunk(3, dim=-1)
|
||||
|
||||
x_norm = apply_adaln_(self.norm1(x), shift_msa, scale_msa)
|
||||
x_flat = x_norm.view(BL, P2 * self.pixel_dim)
|
||||
|
||||
x_comp = self.compress_to_attn(x_flat).view(B, L, self.attn_dim)
|
||||
pos_comp = self._fetch_pos(Hs, Ws, x.device, x.dtype, **(transformer_options.get("rope_options") or {}))
|
||||
attn_out = self.attn(x_comp, pos_comp, mask=mask, transformer_options=transformer_options)
|
||||
attn_flat = self.expand_from_attn(attn_out.view(B * L, self.attn_dim))
|
||||
attn_exp = attn_flat.view(BL, P2, self.pixel_dim)
|
||||
x = torch.addcmul(x, gate_msa, attn_exp)
|
||||
del msa_params, shift_msa, scale_msa, gate_msa
|
||||
|
||||
mlp_params = self.adaLN_modulation_mlp(s_cond).view(BL, P2, 3 * self.pixel_dim)
|
||||
shift_mlp, scale_mlp, gate_mlp = mlp_params.chunk(3, dim=-1)
|
||||
gate_mlp = gate_mlp.contiguous() # detach from mlp_params so the del below frees shift+scale storage before the MLP
|
||||
mlp_input = apply_adaln_(self.norm2(x), shift_mlp, scale_mlp)
|
||||
del mlp_params, shift_mlp, scale_mlp
|
||||
|
||||
# MLP in chunks since the peak memory usage is huge here
|
||||
chunk_size = (BL + self.mlp_chunks - 1) // self.mlp_chunks
|
||||
for s in range(0, BL, chunk_size):
|
||||
e = min(s + chunk_size, BL)
|
||||
x[s:e].addcmul_(gate_mlp[s:e], self.mlp(mlp_input[s:e]))
|
||||
return x
|
||||
226
comfy/ldm/pixeldit/pid.py
Normal file
226
comfy/ldm/pixeldit/pid.py
Normal file
@ -0,0 +1,226 @@
|
||||
"""PiD — Pixel Diffusion Decoder. Decodes a Flux/SD3/Flux2/Z-Image latent
|
||||
directly to a 4x-upscaled image in 4 distilled flow-matching steps. PixDiT_T2I
|
||||
body + LQ projection branch injected before each MMDiT patch block.
|
||||
"""
|
||||
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .model import PixDiT_T2I
|
||||
from .modules import precompute_freqs_cis_2d
|
||||
|
||||
|
||||
class SigmaAwareGatePerTokenPerDim(nn.Module):
|
||||
"""gate = sigmoid(content_proj(cat[x, lq]) - exp(log_alpha) * sigma); out = x + gate * lq.
|
||||
|
||||
Trained init gives ~0.88 gate at sigma=0, ~0.05 at sigma=1.
|
||||
"""
|
||||
|
||||
def __init__(self, dim: int, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.content_proj = operations.Linear(dim * 2, dim, dtype=dtype, device=device)
|
||||
self.log_alpha = nn.Parameter(torch.empty((), dtype=dtype, device=device))
|
||||
|
||||
def forward(self, x: torch.Tensor, lq: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
|
||||
content_logit = self.content_proj(torch.cat([x, lq], dim=-1))
|
||||
# log_alpha is a raw nn.Parameter -> doesn't auto-cast under dynamic VRAM.
|
||||
log_alpha = self.log_alpha.to(device=x.device, dtype=torch.float32)
|
||||
sigma_offset = -log_alpha.exp() * sigma.float().view(-1, 1, 1)
|
||||
gate = torch.sigmoid(content_logit + sigma_offset)
|
||||
return x + (gate * lq).to(x.dtype)
|
||||
|
||||
|
||||
class ResBlock(nn.Module):
|
||||
"""Pre-activation ResNet block: GN -> SiLU -> Conv -> GN -> SiLU -> Conv + skip."""
|
||||
|
||||
def __init__(self, channels: int, num_groups: int = 4, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.block = nn.Sequential(
|
||||
operations.GroupNorm(num_groups, channels, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Conv2d(channels, channels, kernel_size=3, padding=1, dtype=dtype, device=device),
|
||||
operations.GroupNorm(num_groups, channels, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Conv2d(channels, channels, kernel_size=3, padding=1, dtype=dtype, device=device),
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return x + self.block(x)
|
||||
|
||||
|
||||
class LQProjection2D(nn.Module):
|
||||
"""LQ latent -> per-block patch-aligned features for controlnet-style injection."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
latent_channels: int,
|
||||
hidden_dim: int = 512,
|
||||
out_dim: int = 1536,
|
||||
patch_size: int = 16,
|
||||
sr_scale: int = 4,
|
||||
latent_spatial_down_factor: int = 8,
|
||||
num_res_blocks: int = 4,
|
||||
num_outputs: int = 7,
|
||||
interval: int = 2,
|
||||
dtype=None, device=None, operations=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.latent_channels = latent_channels
|
||||
self.hidden_dim = hidden_dim
|
||||
self.out_dim = out_dim
|
||||
self.patch_size = patch_size
|
||||
self.sr_scale = sr_scale
|
||||
self.latent_spatial_down_factor = latent_spatial_down_factor
|
||||
self.num_outputs = num_outputs
|
||||
self.interval = interval
|
||||
|
||||
z_to_patch_ratio = (sr_scale * latent_spatial_down_factor) / patch_size
|
||||
self.z_to_patch_ratio = z_to_patch_ratio
|
||||
if z_to_patch_ratio >= 1:
|
||||
self.latent_fold_factor = 0
|
||||
latent_proj_in_ch = latent_channels
|
||||
else:
|
||||
fold_factor = int(1 / z_to_patch_ratio)
|
||||
assert fold_factor * z_to_patch_ratio == 1.0
|
||||
self.latent_fold_factor = fold_factor
|
||||
latent_proj_in_ch = latent_channels * fold_factor * fold_factor
|
||||
|
||||
layers = [
|
||||
operations.Conv2d(latent_proj_in_ch, hidden_dim, kernel_size=3, padding=1, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Conv2d(hidden_dim, hidden_dim, kernel_size=3, padding=1, dtype=dtype, device=device),
|
||||
]
|
||||
for _ in range(num_res_blocks):
|
||||
layers.append(ResBlock(hidden_dim, dtype=dtype, device=device, operations=operations))
|
||||
self.latent_proj = nn.Sequential(*layers)
|
||||
|
||||
self.output_heads = nn.ModuleList(
|
||||
[operations.Linear(hidden_dim, out_dim, dtype=dtype, device=device) for _ in range(num_outputs)]
|
||||
)
|
||||
self.gate_modules = nn.ModuleList(
|
||||
[SigmaAwareGatePerTokenPerDim(out_dim, dtype=dtype, device=device, operations=operations)
|
||||
for _ in range(num_outputs)]
|
||||
)
|
||||
|
||||
def is_gate_active(self, block_idx: int) -> bool:
|
||||
return block_idx % self.interval == 0
|
||||
|
||||
def output_index(self, block_idx: int) -> int:
|
||||
return block_idx // self.interval
|
||||
|
||||
def gate(self, x: torch.Tensor, lq_feature: torch.Tensor, sigma: torch.Tensor, out_idx: int) -> torch.Tensor:
|
||||
return self.gate_modules[out_idx](x, lq_feature, sigma)
|
||||
|
||||
def _align_latent_to_patch_grid(self, lq_latent: torch.Tensor, pH: int, pW: int) -> torch.Tensor:
|
||||
B, z_dim = lq_latent.shape[:2]
|
||||
if self.z_to_patch_ratio >= 1:
|
||||
if lq_latent.shape[2] != pH or lq_latent.shape[3] != pW:
|
||||
z_aligned = F.interpolate(lq_latent, size=(pH, pW), mode="nearest")
|
||||
else:
|
||||
z_aligned = lq_latent
|
||||
else:
|
||||
f = self.latent_fold_factor
|
||||
zH_expected, zW_expected = pH * f, pW * f
|
||||
if lq_latent.shape[2] != zH_expected or lq_latent.shape[3] != zW_expected:
|
||||
lq_latent = F.interpolate(lq_latent, size=(zH_expected, zW_expected), mode="nearest")
|
||||
z_aligned = lq_latent.reshape(B, z_dim, pH, f, pW, f).permute(0, 1, 3, 5, 2, 4)
|
||||
z_aligned = z_aligned.reshape(B, z_dim * f * f, pH, pW)
|
||||
return self.latent_proj(z_aligned)
|
||||
|
||||
def forward(self, lq_latent: torch.Tensor, target_pH: int, target_pW: int) -> List[torch.Tensor]:
|
||||
feat = self._align_latent_to_patch_grid(lq_latent, target_pH, target_pW)
|
||||
B, C, H, W = feat.shape
|
||||
tokens = feat.permute(0, 2, 3, 1).contiguous().view(B, H * W, C)
|
||||
return [head(tokens) for head in self.output_heads]
|
||||
|
||||
|
||||
class PidNet(PixDiT_T2I):
|
||||
"""PixDiT_T2I + LQ injection (one sigma-gated feature inserted before each patch block)."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
lq_latent_channels: int = 16,
|
||||
lq_hidden_dim: int = 512,
|
||||
lq_num_res_blocks: int = 4,
|
||||
lq_interval: int = 2,
|
||||
sr_scale: int = 4,
|
||||
latent_spatial_down_factor: int = 8,
|
||||
rope_ref_h: int = 1024, # NTK ref resolution in PIXEL units: 1024px / patch=16 -> grid_ref=64.
|
||||
rope_ref_w: int = 1024,
|
||||
image_model=None,
|
||||
dtype=None, device=None, operations=None,
|
||||
**pixdit_kwargs,
|
||||
):
|
||||
super().__init__(dtype=dtype, device=device, operations=operations, **pixdit_kwargs)
|
||||
|
||||
self.rope_ref_grid_h = rope_ref_h // self.patch_size
|
||||
self.rope_ref_grid_w = rope_ref_w // self.patch_size
|
||||
|
||||
# Parent's PiTBlocks were built with plain RoPE — swap in NTK-aware.
|
||||
def _pit_rope_fn(head_dim, h, w, device=None, dtype=torch.float32, **rope_opts):
|
||||
return precompute_freqs_cis_2d(head_dim, h, w, ref_grid_h=self.rope_ref_grid_h, ref_grid_w=self.rope_ref_grid_w, device=device, dtype=dtype, **rope_opts)
|
||||
for blk in self.pixel_blocks:
|
||||
blk._rope_fn = _pit_rope_fn
|
||||
|
||||
num_lq_outputs = (self.patch_depth + lq_interval - 1) // lq_interval
|
||||
self.lq_proj = LQProjection2D(
|
||||
latent_channels=lq_latent_channels,
|
||||
hidden_dim=lq_hidden_dim,
|
||||
out_dim=self.hidden_size,
|
||||
patch_size=self.patch_size,
|
||||
sr_scale=sr_scale,
|
||||
latent_spatial_down_factor=latent_spatial_down_factor,
|
||||
num_res_blocks=lq_num_res_blocks,
|
||||
num_outputs=num_lq_outputs,
|
||||
interval=lq_interval,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
def _fetch_patch_pos(self, height, width, device, dtype, **rope_opts):
|
||||
return precompute_freqs_cis_2d(
|
||||
self.hidden_size // self.num_groups,
|
||||
height, width,
|
||||
ref_grid_h=self.rope_ref_grid_h, ref_grid_w=self.rope_ref_grid_w,
|
||||
device=device, dtype=dtype, **rope_opts,
|
||||
)
|
||||
|
||||
def _pre_patch_block(self, s, i, pid_lq_features, pid_degrade_sigma, **kwargs):
|
||||
if not self.lq_proj.is_gate_active(i):
|
||||
return s
|
||||
out_idx = self.lq_proj.output_index(i)
|
||||
if out_idx >= len(pid_lq_features):
|
||||
return s
|
||||
return self.lq_proj.gate(s, pid_lq_features[out_idx], pid_degrade_sigma, out_idx)
|
||||
|
||||
def _forward(self, x, timesteps, context=None, attention_mask=None, transformer_options={}, lq_latent=None, degrade_sigma=None, **kwargs):
|
||||
if lq_latent is None:
|
||||
raise ValueError("PidNet requires lq_latent — attach via PiDConditioning")
|
||||
expected_c = self.lq_proj.latent_channels
|
||||
if lq_latent.shape[1] != expected_c:
|
||||
raise ValueError(
|
||||
f"Input latent has {lq_latent.shape[1]} channels, this model variant expects {expected_c}. "
|
||||
f"Flux1/SD3 = 16 channels, Flux2 = 128 channels."
|
||||
)
|
||||
B = x.shape[0]
|
||||
Hs = x.shape[2] // self.patch_size
|
||||
Ws = x.shape[3] // self.patch_size
|
||||
|
||||
degrade_sigma = degrade_sigma.to(device=x.device, dtype=torch.float32).reshape(-1)
|
||||
if degrade_sigma.numel() == 1 and B > 1:
|
||||
degrade_sigma = degrade_sigma.expand(B).contiguous()
|
||||
|
||||
lq_features = self.lq_proj(lq_latent=lq_latent.to(x), target_pH=Hs, target_pW=Ws)
|
||||
|
||||
return super()._forward(
|
||||
x, timesteps,
|
||||
context=context, attention_mask=attention_mask,
|
||||
transformer_options=transformer_options,
|
||||
pid_lq_features=lq_features,
|
||||
pid_degrade_sigma=degrade_sigma,
|
||||
**kwargs,
|
||||
)
|
||||
@ -16,7 +16,6 @@
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import comfy.memory_management
|
||||
import comfy.utils
|
||||
import comfy.model_management
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
import math
|
||||
import ctypes
|
||||
import threading
|
||||
import dataclasses
|
||||
import torch
|
||||
from typing import NamedTuple
|
||||
@ -10,7 +9,7 @@ from comfy.quant_ops import QuantizedTensor
|
||||
|
||||
class TensorFileSlice(NamedTuple):
|
||||
file_ref: object
|
||||
thread_id: int
|
||||
lock: object
|
||||
offset: int
|
||||
size: int
|
||||
|
||||
@ -43,7 +42,6 @@ def read_tensor_file_slice_into(tensor, destination, stream=None, destination2=N
|
||||
file_obj = info.file_ref
|
||||
if (destination.device.type != "cpu"
|
||||
or file_obj is None
|
||||
or threading.get_ident() != info.thread_id
|
||||
or destination.numel() * destination.element_size() < info.size
|
||||
or tensor.numel() * tensor.element_size() != info.size
|
||||
or tensor.storage_offset() != 0
|
||||
@ -57,27 +55,29 @@ def read_tensor_file_slice_into(tensor, destination, stream=None, destination2=N
|
||||
if hostbuf is not None:
|
||||
stream_ptr = getattr(stream, "cuda_stream", 0) if stream is not None else 0
|
||||
device_ptr = destination2.data_ptr() if destination2 is not None else 0
|
||||
hostbuf.read_file_slice(file_obj, info.offset, info.size,
|
||||
offset=destination.data_ptr() - hostbuf.get_raw_address(),
|
||||
stream=stream_ptr,
|
||||
device_ptr=device_ptr,
|
||||
device=None if destination2 is None else destination2.device.index)
|
||||
with info.lock:
|
||||
hostbuf.read_file_slice(file_obj, info.offset, info.size,
|
||||
offset=destination.data_ptr() - hostbuf.get_raw_address(),
|
||||
stream=stream_ptr,
|
||||
device_ptr=device_ptr,
|
||||
device=None if destination2 is None else destination2.device.index)
|
||||
return True
|
||||
|
||||
buf_type = ctypes.c_ubyte * info.size
|
||||
view = memoryview(buf_type.from_address(destination.data_ptr()))
|
||||
|
||||
try:
|
||||
file_obj.seek(info.offset)
|
||||
done = 0
|
||||
while done < info.size:
|
||||
try:
|
||||
n = file_obj.readinto(view[done:])
|
||||
except OSError:
|
||||
return False
|
||||
if n <= 0:
|
||||
return False
|
||||
done += n
|
||||
with info.lock:
|
||||
file_obj.seek(info.offset)
|
||||
done = 0
|
||||
while done < info.size:
|
||||
try:
|
||||
n = file_obj.readinto(view[done:])
|
||||
except OSError:
|
||||
return False
|
||||
if n <= 0:
|
||||
return False
|
||||
done += n
|
||||
return True
|
||||
finally:
|
||||
view.release()
|
||||
|
||||
@ -35,6 +35,7 @@ import comfy.ldm.hydit.models
|
||||
import comfy.ldm.audio.dit
|
||||
import comfy.ldm.audio.embedders
|
||||
import comfy.ldm.flux.model
|
||||
import comfy.ldm.lens.model
|
||||
import comfy.ldm.lightricks.model
|
||||
import comfy.ldm.hunyuan_video.model
|
||||
import comfy.ldm.cosmos.model
|
||||
@ -48,6 +49,8 @@ import comfy.ldm.hunyuan3d.model
|
||||
import comfy.ldm.hidream.model
|
||||
import comfy.ldm.chroma.model
|
||||
import comfy.ldm.chroma_radiance.model
|
||||
import comfy.ldm.pixeldit.model
|
||||
import comfy.ldm.pixeldit.pid
|
||||
import comfy.ldm.ace.model
|
||||
import comfy.ldm.omnigen.omnigen2
|
||||
import comfy.ldm.qwen_image.model
|
||||
@ -1058,6 +1061,27 @@ class Flux2(Flux):
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
return out
|
||||
|
||||
|
||||
class Lens(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
|
||||
super().__init__(
|
||||
model_config, model_type, device=device,
|
||||
unet_model=comfy.ldm.lens.model.LensTransformer2DModel,
|
||||
)
|
||||
|
||||
def encode_adm(self, **kwargs):
|
||||
return None # Lens has no pooled/ADM conditioning.
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
cross_attn = kwargs.get("cross_attn", None)
|
||||
if cross_attn is not None:
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
attention_mask = kwargs.get("attention_mask", None)
|
||||
if attention_mask is not None:
|
||||
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
|
||||
return out
|
||||
|
||||
class GenmoMochi(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.genmo.joint_model.asymm_models_joint.AsymmDiTJoint)
|
||||
@ -1375,6 +1399,36 @@ class ZImagePixelSpace(Lumina2):
|
||||
BaseModel.__init__(self, model_config, model_type, device=device, unet_model=comfy.ldm.lumina.model.NextDiTPixelSpace)
|
||||
self.memory_usage_factor_conds = ("ref_latents",)
|
||||
|
||||
|
||||
class PixelDiTT2I(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device,
|
||||
unet_model=comfy.ldm.pixeldit.model.PixDiT_T2I)
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
attention_mask = kwargs.get("attention_mask", None)
|
||||
if attention_mask is not None:
|
||||
out["attention_mask"] = comfy.conds.CONDRegular(attention_mask)
|
||||
return out
|
||||
|
||||
|
||||
class PiD(PixelDiTT2I):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
BaseModel.__init__(self, model_config, model_type, device=device,
|
||||
unet_model=comfy.ldm.pixeldit.pid.PidNet)
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
lq_latent = kwargs.get("lq_latent", None)
|
||||
if lq_latent is not None:
|
||||
out["lq_latent"] = comfy.conds.CONDRegular(lq_latent)
|
||||
degrade_sigma = kwargs.get("degrade_sigma", None)
|
||||
if degrade_sigma is not None:
|
||||
out["degrade_sigma"] = comfy.conds.CONDRegular(degrade_sigma)
|
||||
return out
|
||||
|
||||
|
||||
class WAN21(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)
|
||||
|
||||
@ -463,6 +463,23 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["extra_per_block_abs_pos_emb_type"] = "learnable"
|
||||
return dit_config
|
||||
|
||||
# PiD (Pixel Diffusion Decoder). Must check BEFORE plain PixelDiT_T2I.
|
||||
_lq_w_key = '{}lq_proj.latent_proj.0.weight'.format(key_prefix)
|
||||
if _lq_w_key in state_dict_keys:
|
||||
in_ch = int(state_dict[_lq_w_key].shape[1])
|
||||
_gate_prefix = '{}lq_proj.gate_modules.'.format(key_prefix)
|
||||
num_gates = len({k[len(_gate_prefix):].split('.')[0]
|
||||
for k in state_dict_keys if k.startswith(_gate_prefix)})
|
||||
dit_config = {"image_model": "pid",
|
||||
"lq_latent_channels": in_ch,
|
||||
"latent_spatial_down_factor": 16 if in_ch >= 64 else 8}
|
||||
if num_gates > 0:
|
||||
dit_config["lq_interval"] = (14 + num_gates - 1) // num_gates
|
||||
return dit_config
|
||||
|
||||
if '{}core.pixel_embedder.proj.weight'.format(key_prefix) in state_dict_keys: # PixelDiT T2I
|
||||
return {"image_model": "pixeldit_t2i"}
|
||||
|
||||
if '{}cap_embedder.1.weight'.format(key_prefix) in state_dict_keys and '{}noise_refiner.0.attention.k_norm.weight'.format(key_prefix) in state_dict_keys: # Lumina 2
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "lumina2"
|
||||
@ -755,6 +772,30 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["timestep_scale"] = 1000.0
|
||||
return dit_config
|
||||
|
||||
if '{}transformer_blocks.0.attn.norm_added_q.weight'.format(key_prefix) in state_dict_keys \
|
||||
and '{}transformer_blocks.0.img_mlp.w1.weight'.format(key_prefix) in state_dict_keys: # Lens
|
||||
img_in_w = state_dict['{}img_in.weight'.format(key_prefix)]
|
||||
proj_out_w = state_dict['{}proj_out.weight'.format(key_prefix)]
|
||||
multi_layer = '{}txt_norm.0.weight'.format(key_prefix) in state_dict_keys
|
||||
if multi_layer:
|
||||
enc_hidden_dim = state_dict['{}txt_norm.0.weight'.format(key_prefix)].shape[0]
|
||||
# Indices are TE-side; the DiT just consumes L layers in order.
|
||||
selected_layer_index = tuple(range(count_blocks(state_dict_keys, '{}txt_norm.'.format(key_prefix) + '{}.')))
|
||||
else:
|
||||
enc_hidden_dim = state_dict['{}txt_norm.weight'.format(key_prefix)].shape[0]
|
||||
selected_layer_index = (0,)
|
||||
|
||||
return {
|
||||
"image_model": "lens",
|
||||
"in_channels": img_in_w.shape[1],
|
||||
"out_channels": proj_out_w.shape[0] // 4, # patch_size ** 2 (=2² default)
|
||||
"num_layers": count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.'),
|
||||
"num_attention_heads": img_in_w.shape[0] // 64, # // attention_head_dim default
|
||||
"enc_hidden_dim": enc_hidden_dim,
|
||||
"multi_layer_encoder_feature": multi_layer,
|
||||
"selected_layer_index": selected_layer_index,
|
||||
}
|
||||
|
||||
if '{}txt_norm.weight'.format(key_prefix) in state_dict_keys: # Qwen Image
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "qwen_image"
|
||||
|
||||
@ -15,6 +15,7 @@
|
||||
You should have received a copy of the GNU General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import psutil
|
||||
import logging
|
||||
@ -27,13 +28,18 @@ import platform
|
||||
import weakref
|
||||
import gc
|
||||
import os
|
||||
from contextlib import nullcontext
|
||||
from contextlib import contextmanager, nullcontext
|
||||
import comfy.memory_management
|
||||
import comfy.utils
|
||||
import comfy.quant_ops
|
||||
import comfy_aimdo.host_buffer
|
||||
import comfy_aimdo.vram_buffer
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
from comfy.model_patcher import ModelPatcher
|
||||
|
||||
|
||||
class VRAMState(Enum):
|
||||
DISABLED = 0 #No vram present: no need to move models to vram
|
||||
NO_VRAM = 1 #Very low vram: enable all the options to save vram
|
||||
@ -204,6 +210,107 @@ def get_torch_device():
|
||||
else:
|
||||
return torch.device(torch.cuda.current_device())
|
||||
|
||||
def get_all_torch_devices(exclude_current=False):
|
||||
global cpu_state
|
||||
devices = []
|
||||
if cpu_state == CPUState.GPU:
|
||||
# NVIDIA + AMD/ROCm both expose their GPUs through torch.cuda.*;
|
||||
# without the AMD arm, single-GPU ROCm users get an empty list
|
||||
# which silently turns unload_all_models() into a no-op.
|
||||
if is_nvidia() or is_amd():
|
||||
for i in range(torch.cuda.device_count()):
|
||||
devices.append(torch.device("cuda", i))
|
||||
elif is_intel_xpu():
|
||||
for i in range(torch.xpu.device_count()):
|
||||
devices.append(torch.device("xpu", i))
|
||||
elif is_ascend_npu():
|
||||
for i in range(torch.npu.device_count()):
|
||||
devices.append(torch.device("npu", i))
|
||||
elif is_mlu():
|
||||
for i in range(torch.mlu.device_count()):
|
||||
devices.append(torch.device("mlu", i))
|
||||
else:
|
||||
# Fallback for unhandled GPU backends (e.g. DirectML): at least
|
||||
# report the current device so callers like unload_all_models()
|
||||
# do not silently no-op.
|
||||
devices.append(get_torch_device())
|
||||
else:
|
||||
devices.append(get_torch_device())
|
||||
if exclude_current:
|
||||
current = get_torch_device()
|
||||
if current in devices:
|
||||
devices.remove(current)
|
||||
return devices
|
||||
|
||||
def get_gpu_device_options():
|
||||
"""Return list of device option strings for node widgets.
|
||||
|
||||
Always includes "default" and "cpu". When multiple GPUs are present,
|
||||
adds "gpu:0", "gpu:1", etc. (vendor-agnostic labels).
|
||||
"""
|
||||
options = ["default", "cpu"]
|
||||
devices = get_all_torch_devices()
|
||||
if len(devices) > 1:
|
||||
for i in range(len(devices)):
|
||||
options.append(f"gpu:{i}")
|
||||
return options
|
||||
|
||||
def get_gpu_device_options_no_cpu():
|
||||
"""Variant of get_gpu_device_options that omits "cpu".
|
||||
|
||||
Intended for components like the VAE selector where running on CPU
|
||||
is impractical and should not be offered as a choice.
|
||||
"""
|
||||
return [o for o in get_gpu_device_options() if o != "cpu"]
|
||||
|
||||
def resolve_gpu_device_option(option: str):
|
||||
"""Resolve a device option string to a torch.device.
|
||||
|
||||
Returns None for "default" (let the caller use its normal default).
|
||||
Returns torch.device("cpu") for "cpu".
|
||||
For "gpu:N", returns the Nth torch device. Returns None if the
|
||||
index is out of range, the option string is malformed, or
|
||||
unrecognized (callers are expected to log their own context-rich
|
||||
message before falling back to the default device).
|
||||
"""
|
||||
if option is None or option == "default":
|
||||
return None
|
||||
if option == "cpu":
|
||||
return torch.device("cpu")
|
||||
if option.startswith("gpu:"):
|
||||
try:
|
||||
idx = int(option[4:])
|
||||
except ValueError:
|
||||
return None
|
||||
devices = get_all_torch_devices()
|
||||
if 0 <= idx < len(devices):
|
||||
return devices[idx]
|
||||
return None
|
||||
|
||||
@contextmanager
|
||||
def cuda_device_context(device):
|
||||
"""Context manager that sets torch.cuda.current_device to match *device*.
|
||||
|
||||
Used when running operations on a non-default CUDA device so that custom
|
||||
CUDA kernels (e.g. comfy_kitchen fp8 quantization) pick up the correct
|
||||
device index. The previous device is restored on exit.
|
||||
|
||||
No-op when *device* is not CUDA, has no explicit index, or already matches
|
||||
the current device.
|
||||
"""
|
||||
prev = None
|
||||
if device.type == "cuda" and device.index is not None:
|
||||
prev = torch.cuda.current_device()
|
||||
if prev != device.index:
|
||||
torch.cuda.set_device(device)
|
||||
else:
|
||||
prev = None
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
if prev is not None:
|
||||
torch.cuda.set_device(prev)
|
||||
|
||||
def get_total_memory(dev=None, torch_total_too=False):
|
||||
global directml_enabled
|
||||
if dev is None:
|
||||
@ -492,9 +599,13 @@ try:
|
||||
logging.info("Device: {}".format(get_torch_device_name(get_torch_device())))
|
||||
except:
|
||||
logging.warning("Could not pick default device.")
|
||||
try:
|
||||
for device in get_all_torch_devices(exclude_current=True):
|
||||
logging.info("Device: {}".format(get_torch_device_name(device)))
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
current_loaded_models = []
|
||||
current_loaded_models: list[LoadedModel] = []
|
||||
|
||||
DIRTY_MMAPS = set()
|
||||
|
||||
@ -554,7 +665,7 @@ def ensure_pin_registerable(size, evict_active=False):
|
||||
return shortfall <= REGISTERABLE_PIN_HYSTERESIS
|
||||
|
||||
class LoadedModel:
|
||||
def __init__(self, model):
|
||||
def __init__(self, model: ModelPatcher):
|
||||
self._set_model(model)
|
||||
self.device = model.load_device
|
||||
self.real_model = None
|
||||
@ -562,7 +673,7 @@ class LoadedModel:
|
||||
self.model_finalizer = None
|
||||
self._patcher_finalizer = None
|
||||
|
||||
def _set_model(self, model):
|
||||
def _set_model(self, model: ModelPatcher):
|
||||
self._model = weakref.ref(model)
|
||||
if model.parent is not None:
|
||||
self._parent_model = weakref.ref(model.parent)
|
||||
@ -573,6 +684,7 @@ class LoadedModel:
|
||||
model = self._parent_model()
|
||||
if model is not None:
|
||||
self._set_model(model)
|
||||
self.device = model.load_device
|
||||
|
||||
@property
|
||||
def model(self):
|
||||
@ -1848,7 +1960,34 @@ def soft_empty_cache(force=False):
|
||||
torch.cuda.ipc_collect()
|
||||
|
||||
def unload_all_models():
|
||||
free_memory(1e30, get_torch_device())
|
||||
for device in get_all_torch_devices():
|
||||
free_memory(1e30, device)
|
||||
|
||||
def unload_model_and_clones(model: ModelPatcher, unload_additional_models=True, all_devices=False):
|
||||
'Unload only model and its clones - primarily for multigpu cloning purposes.'
|
||||
initial_keep_loaded: list[LoadedModel] = current_loaded_models.copy()
|
||||
additional_models = []
|
||||
if unload_additional_models:
|
||||
additional_models = model.get_nested_additional_models()
|
||||
keep_loaded = []
|
||||
for loaded_model in initial_keep_loaded:
|
||||
if loaded_model.model is not None:
|
||||
if model.clone_base_uuid == loaded_model.model.clone_base_uuid:
|
||||
continue
|
||||
# check additional models if they are a match
|
||||
skip = False
|
||||
for add_model in additional_models:
|
||||
if add_model.clone_base_uuid == loaded_model.model.clone_base_uuid:
|
||||
skip = True
|
||||
break
|
||||
if skip:
|
||||
continue
|
||||
keep_loaded.append(loaded_model)
|
||||
if not all_devices:
|
||||
free_memory(1e30, get_torch_device(), keep_loaded)
|
||||
else:
|
||||
for device in get_all_torch_devices():
|
||||
free_memory(1e30, device, keep_loaded)
|
||||
|
||||
def debug_memory_summary():
|
||||
if is_amd() or is_nvidia():
|
||||
|
||||
@ -78,12 +78,15 @@ def set_model_options_pre_cfg_function(model_options, pre_cfg_function, disable_
|
||||
def create_model_options_clone(orig_model_options: dict):
|
||||
return comfy.patcher_extension.copy_nested_dicts(orig_model_options)
|
||||
|
||||
def create_hook_patches_clone(orig_hook_patches):
|
||||
def create_hook_patches_clone(orig_hook_patches, copy_tuples=False):
|
||||
new_hook_patches = {}
|
||||
for hook_ref in orig_hook_patches:
|
||||
new_hook_patches[hook_ref] = {}
|
||||
for k in orig_hook_patches[hook_ref]:
|
||||
new_hook_patches[hook_ref][k] = orig_hook_patches[hook_ref][k][:]
|
||||
if copy_tuples:
|
||||
for i in range(len(new_hook_patches[hook_ref][k])):
|
||||
new_hook_patches[hook_ref][k][i] = tuple(new_hook_patches[hook_ref][k][i])
|
||||
return new_hook_patches
|
||||
|
||||
def wipe_lowvram_weight(m):
|
||||
@ -329,7 +332,10 @@ class ModelPatcher:
|
||||
self.is_clip = False
|
||||
self.hook_mode = comfy.hooks.EnumHookMode.MaxSpeed
|
||||
|
||||
self.cached_patcher_init: tuple[Callable, tuple] | None = None
|
||||
self.cached_patcher_init: tuple[Callable, tuple] | tuple[Callable, tuple, int] | None = None
|
||||
self.is_multigpu_base_clone = False
|
||||
self.clone_base_uuid = uuid.uuid4()
|
||||
|
||||
if not hasattr(self.model, 'model_loaded_weight_memory'):
|
||||
self.model.model_loaded_weight_memory = 0
|
||||
|
||||
@ -366,7 +372,8 @@ class ModelPatcher:
|
||||
#than pays for CFG. So return everything both torch and Aimdo could give us
|
||||
aimdo_mem = 0
|
||||
if comfy.memory_management.aimdo_enabled:
|
||||
aimdo_mem = comfy_aimdo.model_vbar.vbars_analyze()
|
||||
aimdo_device = device.index if getattr(device, "type", None) == "cuda" else None
|
||||
aimdo_mem = comfy_aimdo.model_vbar.vbars_analyze(aimdo_device)
|
||||
return comfy.model_management.get_free_memory(device) + aimdo_mem
|
||||
|
||||
def get_clone_model_override(self):
|
||||
@ -380,6 +387,8 @@ class ModelPatcher:
|
||||
if self.cached_patcher_init is None:
|
||||
raise RuntimeError("Cannot create non-dynamic delegate: cached_patcher_init is not initialized.")
|
||||
temp_model_patcher = self.cached_patcher_init[0](*self.cached_patcher_init[1], disable_dynamic=True)
|
||||
if len(self.cached_patcher_init) > 2:
|
||||
temp_model_patcher = temp_model_patcher[self.cached_patcher_init[2]]
|
||||
model_override = temp_model_patcher.get_clone_model_override()
|
||||
if model_override is None:
|
||||
model_override = self.get_clone_model_override()
|
||||
@ -438,19 +447,113 @@ class ModelPatcher:
|
||||
n.hook_mode = self.hook_mode
|
||||
|
||||
n.cached_patcher_init = self.cached_patcher_init
|
||||
n.is_multigpu_base_clone = self.is_multigpu_base_clone
|
||||
n.clone_base_uuid = self.clone_base_uuid
|
||||
|
||||
for callback in self.get_all_callbacks(CallbacksMP.ON_CLONE):
|
||||
callback(self, n)
|
||||
return n
|
||||
|
||||
def deepclone_multigpu(self, new_load_device=None, models_cache: dict[uuid.UUID,ModelPatcher]=None):
|
||||
logging.info(f"Creating deepclone of {self.model.__class__.__name__} for {new_load_device if new_load_device else self.load_device}.")
|
||||
if self.cached_patcher_init is None:
|
||||
raise RuntimeError(
|
||||
f"Cannot create multigpu deepclone of {self.model.__class__.__name__}: "
|
||||
"the loader that produced this model does not support multigpu "
|
||||
"(cached_patcher_init is not initialized). Use a core loader "
|
||||
"(CheckpointLoaderSimple, UNETLoader, CLIPLoader/DualCLIPLoader, VAELoader), "
|
||||
"or have the custom loader register a cached_patcher_init factory."
|
||||
)
|
||||
comfy.model_management.unload_model_and_clones(self)
|
||||
# Produce a freshly-loaded patcher from the loader factory so the multigpu
|
||||
# clone owns its own untainted model weights (rather than relying on
|
||||
# copy.deepcopy of an already-patched/already-loaded module).
|
||||
temp_model_patcher: ModelPatcher | list[ModelPatcher] = self.cached_patcher_init[0](*self.cached_patcher_init[1])
|
||||
if len(self.cached_patcher_init) > 2:
|
||||
temp_model_patcher = temp_model_patcher[self.cached_patcher_init[2]]
|
||||
# Override clone()'s normal "share self.model + share backup containers" with
|
||||
# the pristine model from temp_model_patcher plus empty backup containers --
|
||||
# the fresh model has no patches applied, so any deepcopy of self's stale
|
||||
# backup/object_patches_backup/pinned would just propagate dead state that
|
||||
# no longer corresponds to anything in n.model.
|
||||
model_override = (temp_model_patcher.model, ({}, {}, {}, set()))
|
||||
n = self.clone(model_override=model_override)
|
||||
# clone() copies hook_backup by reference from self; reset since model is pristine.
|
||||
n.hook_backup = {}
|
||||
# set load device, if present
|
||||
if new_load_device is not None:
|
||||
n.load_device = new_load_device
|
||||
# Ensure any per-device bookkeeping (e.g. ModelPatcherDynamic.dynamic_pins)
|
||||
# has an entry for n.load_device on the freshly-loaded n.model. temp_model_patcher's
|
||||
# __init__ only registered its own (default) load_device.
|
||||
if hasattr(n, "register_load_device"):
|
||||
n.register_load_device(n.load_device)
|
||||
# multigpu clone should not have multigpu additional_models entry
|
||||
n.remove_additional_models("multigpu")
|
||||
# multigpu_clone all stored additional_models; make sure circular references are properly handled
|
||||
if models_cache is None:
|
||||
models_cache = {}
|
||||
for key, model_list in n.additional_models.items():
|
||||
for i in range(len(model_list)):
|
||||
add_model = n.additional_models[key][i]
|
||||
if add_model.clone_base_uuid not in models_cache:
|
||||
models_cache[add_model.clone_base_uuid] = add_model.deepclone_multigpu(new_load_device=new_load_device, models_cache=models_cache)
|
||||
n.additional_models[key][i] = models_cache[add_model.clone_base_uuid]
|
||||
for callback in self.get_all_callbacks(CallbacksMP.ON_DEEPCLONE_MULTIGPU):
|
||||
callback(self, n)
|
||||
return n
|
||||
|
||||
def match_multigpu_clones(self):
|
||||
multigpu_models = self.get_additional_models_with_key("multigpu")
|
||||
if len(multigpu_models) > 0:
|
||||
new_multigpu_models = []
|
||||
for mm in multigpu_models:
|
||||
# clone main model, but bring over relevant props from existing multigpu clone
|
||||
n = self.clone()
|
||||
n.load_device = mm.load_device
|
||||
n.backup = mm.backup
|
||||
n.object_patches_backup = mm.object_patches_backup
|
||||
n.hook_backup = mm.hook_backup
|
||||
n.model = mm.model
|
||||
n.is_multigpu_base_clone = mm.is_multigpu_base_clone
|
||||
n.remove_additional_models("multigpu")
|
||||
orig_additional_models: dict[str, list[ModelPatcher]] = comfy.patcher_extension.copy_nested_dicts(n.additional_models)
|
||||
n.additional_models = comfy.patcher_extension.copy_nested_dicts(mm.additional_models)
|
||||
# figure out which additional models are not present in multigpu clone
|
||||
models_cache = {}
|
||||
for mm_add_model in mm.get_additional_models():
|
||||
models_cache[mm_add_model.clone_base_uuid] = mm_add_model
|
||||
remove_models_uuids = set(list(models_cache.keys()))
|
||||
for key, model_list in orig_additional_models.items():
|
||||
for orig_add_model in model_list:
|
||||
if orig_add_model.clone_base_uuid not in models_cache:
|
||||
models_cache[orig_add_model.clone_base_uuid] = orig_add_model.deepclone_multigpu(new_load_device=n.load_device, models_cache=models_cache)
|
||||
existing_list = n.get_additional_models_with_key(key)
|
||||
existing_list.append(models_cache[orig_add_model.clone_base_uuid])
|
||||
n.set_additional_models(key, existing_list)
|
||||
if orig_add_model.clone_base_uuid in remove_models_uuids:
|
||||
remove_models_uuids.remove(orig_add_model.clone_base_uuid)
|
||||
# remove duplicate additional models
|
||||
for key, model_list in n.additional_models.items():
|
||||
new_model_list = [x for x in model_list if x.clone_base_uuid not in remove_models_uuids]
|
||||
n.set_additional_models(key, new_model_list)
|
||||
for callback in self.get_all_callbacks(CallbacksMP.ON_MATCH_MULTIGPU_CLONES):
|
||||
callback(self, n)
|
||||
new_multigpu_models.append(n)
|
||||
self.set_additional_models("multigpu", new_multigpu_models)
|
||||
|
||||
def is_clone(self, other):
|
||||
if hasattr(other, 'model') and self.model is other.model:
|
||||
return True
|
||||
return False
|
||||
|
||||
def clone_has_same_weights(self, clone: 'ModelPatcher'):
|
||||
if not self.is_clone(clone):
|
||||
return False
|
||||
def clone_has_same_weights(self, clone: ModelPatcher, allow_multigpu=False):
|
||||
if allow_multigpu:
|
||||
if self.clone_base_uuid != clone.clone_base_uuid:
|
||||
return False
|
||||
else:
|
||||
if not self.is_clone(clone):
|
||||
return False
|
||||
|
||||
if self.current_hooks != clone.current_hooks:
|
||||
return False
|
||||
@ -1232,7 +1335,7 @@ class ModelPatcher:
|
||||
return self.additional_models.get(key, [])
|
||||
|
||||
def get_additional_models(self):
|
||||
all_models = []
|
||||
all_models: list[ModelPatcher] = []
|
||||
for models in self.additional_models.values():
|
||||
all_models.extend(models)
|
||||
return all_models
|
||||
@ -1286,9 +1389,18 @@ class ModelPatcher:
|
||||
for callback in self.get_all_callbacks(CallbacksMP.ON_PRE_RUN):
|
||||
callback(self)
|
||||
|
||||
def prepare_state(self, timestep):
|
||||
def prepare_state(self, timestep, model_options):
|
||||
ignore_multigpu = model_options.get("ignore_multigpu", False)
|
||||
for callback in self.get_all_callbacks(CallbacksMP.ON_PREPARE_STATE):
|
||||
callback(self, timestep)
|
||||
callback(self, timestep, model_options)
|
||||
if not ignore_multigpu and "multigpu_clones" in model_options:
|
||||
model_options["ignore_multigpu"] = True
|
||||
try:
|
||||
for p in model_options["multigpu_clones"].values():
|
||||
p: ModelPatcher
|
||||
p.prepare_state(timestep, model_options)
|
||||
finally:
|
||||
model_options.pop("ignore_multigpu", None)
|
||||
|
||||
def restore_hook_patches(self):
|
||||
if self.hook_patches_backup is not None:
|
||||
@ -1301,12 +1413,18 @@ class ModelPatcher:
|
||||
def prepare_hook_patches_current_keyframe(self, t: torch.Tensor, hook_group: comfy.hooks.HookGroup, model_options: dict[str]):
|
||||
curr_t = t[0]
|
||||
reset_current_hooks = False
|
||||
multigpu_kf_changed_cache = None
|
||||
transformer_options = model_options.get("transformer_options", {})
|
||||
for hook in hook_group.hooks:
|
||||
changed = hook.hook_keyframe.prepare_current_keyframe(curr_t=curr_t, transformer_options=transformer_options)
|
||||
# if keyframe changed, remove any cached HookGroups that contain hook with the same hook_ref;
|
||||
# this will cause the weights to be recalculated when sampling
|
||||
if changed:
|
||||
# cache changed for multigpu usage
|
||||
if "multigpu_clones" in model_options:
|
||||
if multigpu_kf_changed_cache is None:
|
||||
multigpu_kf_changed_cache = []
|
||||
multigpu_kf_changed_cache.append(hook)
|
||||
# reset current_hooks if contains hook that changed
|
||||
if self.current_hooks is not None:
|
||||
for current_hook in self.current_hooks.hooks:
|
||||
@ -1318,6 +1436,28 @@ class ModelPatcher:
|
||||
self.cached_hook_patches.pop(cached_group)
|
||||
if reset_current_hooks:
|
||||
self.patch_hooks(None)
|
||||
if "multigpu_clones" in model_options:
|
||||
for p in model_options["multigpu_clones"].values():
|
||||
p: ModelPatcher
|
||||
p._handle_changed_hook_keyframes(multigpu_kf_changed_cache)
|
||||
|
||||
def _handle_changed_hook_keyframes(self, kf_changed_cache: list[comfy.hooks.Hook]):
|
||||
'Used to handle multigpu behavior inside prepare_hook_patches_current_keyframe.'
|
||||
if kf_changed_cache is None:
|
||||
return
|
||||
reset_current_hooks = False
|
||||
# reset current_hooks if contains hook that changed
|
||||
for hook in kf_changed_cache:
|
||||
if self.current_hooks is not None:
|
||||
for current_hook in self.current_hooks.hooks:
|
||||
if current_hook == hook:
|
||||
reset_current_hooks = True
|
||||
break
|
||||
for cached_group in list(self.cached_hook_patches.keys()):
|
||||
if cached_group.contains(hook):
|
||||
self.cached_hook_patches.pop(cached_group)
|
||||
if reset_current_hooks:
|
||||
self.patch_hooks(None)
|
||||
|
||||
def register_all_hook_patches(self, hooks: comfy.hooks.HookGroup, target_dict: dict[str], model_options: dict=None,
|
||||
registered: comfy.hooks.HookGroup = None):
|
||||
@ -1566,16 +1706,27 @@ class ModelPatcherDynamic(ModelPatcher):
|
||||
self.model.dynamic_vbars = {}
|
||||
if not hasattr(self.model, "dynamic_pins"):
|
||||
self.model.dynamic_pins = {}
|
||||
if self.load_device not in self.model.dynamic_pins:
|
||||
self.model.dynamic_pins[self.load_device] = {
|
||||
self.register_load_device(self.load_device)
|
||||
self.non_dynamic_delegate_model = None
|
||||
assert load_device is not None
|
||||
|
||||
def register_load_device(self, device):
|
||||
"""Ensure dynamic_pins has an entry for *device*.
|
||||
|
||||
Called from __init__ and also from any code that retargets an
|
||||
already-constructed patcher to a new load_device (e.g. the
|
||||
Select{Model,CLIP,VAE}Device selector nodes); without this entry
|
||||
partially_unload_ram() raises KeyError when it tries to read the
|
||||
per-device pin state.
|
||||
"""
|
||||
if device not in self.model.dynamic_pins:
|
||||
self.model.dynamic_pins[device] = {
|
||||
"weights": (comfy_aimdo.host_buffer.HostBuffer(0, 0, 0), [], [-1], [0]),
|
||||
"patches": (comfy_aimdo.host_buffer.HostBuffer(0, 0, 0), [], [-1], [0]),
|
||||
"hostbufs_initialized": False,
|
||||
"failed": False,
|
||||
"active": False,
|
||||
}
|
||||
self.non_dynamic_delegate_model = None
|
||||
assert load_device is not None
|
||||
|
||||
def is_dynamic(self):
|
||||
return True
|
||||
|
||||
248
comfy/multigpu.py
Normal file
248
comfy/multigpu.py
Normal file
@ -0,0 +1,248 @@
|
||||
from __future__ import annotations
|
||||
import queue
|
||||
import threading
|
||||
import torch
|
||||
import logging
|
||||
|
||||
from collections import namedtuple
|
||||
from typing import TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
from comfy.model_patcher import ModelPatcher
|
||||
import comfy.utils
|
||||
import comfy.patcher_extension
|
||||
import comfy.model_management
|
||||
|
||||
|
||||
class MultiGPUThreadPool:
|
||||
"""Persistent thread pool for multi-GPU work distribution.
|
||||
|
||||
Maintains one worker thread per extra GPU device. Each thread calls
|
||||
torch.cuda.set_device() once at startup so that compiled kernel caches
|
||||
(inductor/triton) stay warm across diffusion steps.
|
||||
"""
|
||||
|
||||
def __init__(self, devices: list[torch.device]):
|
||||
self._workers: list[threading.Thread] = []
|
||||
self._work_queues: dict[torch.device, queue.Queue] = {}
|
||||
self._result_queues: dict[torch.device, queue.Queue] = {}
|
||||
|
||||
for device in devices:
|
||||
wq = queue.Queue()
|
||||
rq = queue.Queue()
|
||||
self._work_queues[device] = wq
|
||||
self._result_queues[device] = rq
|
||||
t = threading.Thread(target=self._worker_loop, args=(device, wq, rq), daemon=True)
|
||||
t.start()
|
||||
self._workers.append(t)
|
||||
|
||||
def _worker_loop(self, device: torch.device, work_q: queue.Queue, result_q: queue.Queue):
|
||||
try:
|
||||
torch.cuda.set_device(device)
|
||||
except Exception as e:
|
||||
logging.error(f"MultiGPUThreadPool: failed to set device {device}: {e}")
|
||||
while True:
|
||||
item = work_q.get()
|
||||
if item is None:
|
||||
return
|
||||
result_q.put((None, e))
|
||||
return
|
||||
while True:
|
||||
item = work_q.get()
|
||||
if item is None:
|
||||
break
|
||||
fn, args, kwargs = item
|
||||
try:
|
||||
result = fn(*args, **kwargs)
|
||||
result_q.put((result, None))
|
||||
except Exception as e:
|
||||
result_q.put((None, e))
|
||||
|
||||
def submit(self, device: torch.device, fn, *args, **kwargs):
|
||||
self._work_queues[device].put((fn, args, kwargs))
|
||||
|
||||
def get_result(self, device: torch.device):
|
||||
return self._result_queues[device].get()
|
||||
|
||||
@property
|
||||
def devices(self) -> list[torch.device]:
|
||||
return list(self._work_queues.keys())
|
||||
|
||||
def shutdown(self):
|
||||
for wq in self._work_queues.values():
|
||||
wq.put(None) # sentinel
|
||||
for t in self._workers:
|
||||
t.join(timeout=5.0)
|
||||
|
||||
|
||||
class GPUOptions:
|
||||
def __init__(self, device_index: int, relative_speed: float):
|
||||
self.device_index = device_index
|
||||
self.relative_speed = relative_speed
|
||||
|
||||
def clone(self):
|
||||
return GPUOptions(self.device_index, self.relative_speed)
|
||||
|
||||
def create_dict(self):
|
||||
return {
|
||||
"relative_speed": self.relative_speed
|
||||
}
|
||||
|
||||
class GPUOptionsGroup:
|
||||
def __init__(self):
|
||||
self.options: dict[int, GPUOptions] = {}
|
||||
|
||||
def add(self, info: GPUOptions):
|
||||
self.options[info.device_index] = info
|
||||
|
||||
def clone(self):
|
||||
c = GPUOptionsGroup()
|
||||
for opt in self.options.values():
|
||||
c.add(opt)
|
||||
return c
|
||||
|
||||
def register(self, model: ModelPatcher):
|
||||
opts_dict = {}
|
||||
# get devices that are valid for this model
|
||||
devices: list[torch.device] = [model.load_device]
|
||||
for extra_model in model.get_additional_models_with_key("multigpu"):
|
||||
extra_model: ModelPatcher
|
||||
devices.append(extra_model.load_device)
|
||||
# create dictionary with actual device mapped to its GPUOptions
|
||||
device_opts_list: list[GPUOptions] = []
|
||||
for device in devices:
|
||||
device_opts = self.options.get(device.index, GPUOptions(device_index=device.index, relative_speed=1.0))
|
||||
opts_dict[device] = device_opts.create_dict()
|
||||
device_opts_list.append(device_opts)
|
||||
# make relative_speed relative to 1.0
|
||||
min_speed = min([x.relative_speed for x in device_opts_list])
|
||||
for value in opts_dict.values():
|
||||
value['relative_speed'] /= min_speed
|
||||
model.model_options['multigpu_options'] = opts_dict
|
||||
|
||||
|
||||
def create_multigpu_deepclones(model: ModelPatcher, max_gpus: int, gpu_options: GPUOptionsGroup=None, reuse_loaded=False):
|
||||
'Prepare ModelPatcher to contain deepclones of its BaseModel and related properties.'
|
||||
model = model.clone()
|
||||
# check if multigpu is already prepared - get the load devices from them if possible to exclude
|
||||
skip_devices = set()
|
||||
multigpu_models = model.get_additional_models_with_key("multigpu")
|
||||
if len(multigpu_models) > 0:
|
||||
for mm in multigpu_models:
|
||||
skip_devices.add(mm.load_device)
|
||||
skip_devices = list(skip_devices)
|
||||
|
||||
# Exclude the primary model's actual device, not the global current device:
|
||||
# after SelectModelDevice(gpu:N) the primary may not live on the process's
|
||||
# current CUDA device, and excluding the wrong device picks bad extras.
|
||||
all_devices = comfy.model_management.get_all_torch_devices(exclude_current=False)
|
||||
full_extra_devices = [d for d in all_devices if d != model.load_device]
|
||||
limit_extra_devices = full_extra_devices[:max_gpus-1]
|
||||
extra_devices = limit_extra_devices.copy()
|
||||
# exclude skipped devices
|
||||
for skip in skip_devices:
|
||||
if skip in extra_devices:
|
||||
extra_devices.remove(skip)
|
||||
# create new deepclones
|
||||
if len(extra_devices) > 0:
|
||||
for device in extra_devices:
|
||||
device_patcher = None
|
||||
if reuse_loaded:
|
||||
# Only reuse a previously-loaded MultiGPU clone. A SelectModelDevice
|
||||
# patcher on the same device shares clone_base_uuid but has
|
||||
# is_multigpu_base_clone=False, which would later be filtered out by
|
||||
# prepare_model_patcher_multigpu_clones() and silently shrink the
|
||||
# work split back to one GPU.
|
||||
loaded_models: list[ModelPatcher] = comfy.model_management.loaded_models()
|
||||
for lm in loaded_models:
|
||||
if lm.model is None:
|
||||
continue
|
||||
if lm.load_device != device:
|
||||
continue
|
||||
if lm.clone_base_uuid != model.clone_base_uuid:
|
||||
continue
|
||||
if not getattr(lm, "is_multigpu_base_clone", False):
|
||||
continue
|
||||
device_patcher = lm.clone()
|
||||
logging.info(f"Reusing loaded multigpu deepclone of {device_patcher.model.__class__.__name__} for {device}")
|
||||
break
|
||||
if device_patcher is None:
|
||||
device_patcher = model.deepclone_multigpu(new_load_device=device)
|
||||
# Always flag the clone; whether reused or freshly deepcloned, it must
|
||||
# advertise itself as a MultiGPU base clone so the cond scheduler picks
|
||||
# it up in prepare_model_patcher_multigpu_clones().
|
||||
device_patcher.is_multigpu_base_clone = True
|
||||
multigpu_models = model.get_additional_models_with_key("multigpu")
|
||||
multigpu_models.append(device_patcher)
|
||||
model.set_additional_models("multigpu", multigpu_models)
|
||||
model.match_multigpu_clones()
|
||||
if gpu_options is None:
|
||||
gpu_options = GPUOptionsGroup()
|
||||
gpu_options.register(model)
|
||||
else:
|
||||
logging.info("No extra torch devices need initialization, skipping initializing MultiGPU Work Units.")
|
||||
# only keep model clones that don't go 'past' the intended max_gpu count;
|
||||
# this prunes any inherited multigpu clones whose load_device is no longer allowed
|
||||
# when max_gpus is lowered between runs.
|
||||
allowed_devices = set(limit_extra_devices)
|
||||
allowed_devices.add(model.load_device)
|
||||
multigpu_models = model.get_additional_models_with_key("multigpu")
|
||||
new_multigpu_models = [m for m in multigpu_models if m.load_device in allowed_devices]
|
||||
if len(new_multigpu_models) != len(multigpu_models):
|
||||
model.set_additional_models("multigpu", new_multigpu_models)
|
||||
model.match_multigpu_clones()
|
||||
return model
|
||||
|
||||
|
||||
LoadBalance = namedtuple('LoadBalance', ['work_per_device', 'idle_time'])
|
||||
def load_balance_devices(model_options: dict[str], total_work: int, return_idle_time=False, work_normalized: int=None):
|
||||
'Optimize work assigned to different devices, accounting for their relative speeds and splittable work.'
|
||||
opts_dict = model_options['multigpu_options']
|
||||
devices = list(model_options['multigpu_clones'].keys())
|
||||
speed_per_device = []
|
||||
work_per_device = []
|
||||
# get sum of each device's relative_speed
|
||||
total_speed = 0.0
|
||||
for opts in opts_dict.values():
|
||||
total_speed += opts['relative_speed']
|
||||
# get relative work for each device;
|
||||
# obtained by w = (W*r)/R
|
||||
for device in devices:
|
||||
relative_speed = opts_dict[device]['relative_speed']
|
||||
relative_work = (total_work*relative_speed) / total_speed
|
||||
speed_per_device.append(relative_speed)
|
||||
work_per_device.append(relative_work)
|
||||
# relative work must be expressed in whole numbers, but likely is a decimal;
|
||||
# perform rounding while maintaining total sum equal to total work (sum of relative works)
|
||||
work_per_device = round_preserved(work_per_device)
|
||||
dict_work_per_device = {}
|
||||
for device, relative_work in zip(devices, work_per_device):
|
||||
dict_work_per_device[device] = relative_work
|
||||
if not return_idle_time:
|
||||
return LoadBalance(dict_work_per_device, None)
|
||||
# divide relative work by relative speed to get estimated completion time of said work by each device;
|
||||
# time here is relative and does not correspond to real-world units
|
||||
completion_time = [w/r for w,r in zip(work_per_device, speed_per_device)]
|
||||
# calculate relative time spent by the devices waiting on each other after their work is completed
|
||||
idle_time = abs(min(completion_time) - max(completion_time))
|
||||
# if need to compare work idle time, need to normalize to a common total work
|
||||
if work_normalized:
|
||||
idle_time *= (work_normalized/total_work)
|
||||
|
||||
return LoadBalance(dict_work_per_device, idle_time)
|
||||
|
||||
def round_preserved(values: list[float]):
|
||||
'Round all values in a list, preserving the combined sum of values.'
|
||||
# get floor of values; casting to int does it too
|
||||
floored = [int(x) for x in values]
|
||||
total_floored = sum(floored)
|
||||
# get remainder to distribute
|
||||
remainder = round(sum(values)) - total_floored
|
||||
# pair values with fractional portions
|
||||
fractional = [(i, x-floored[i]) for i, x in enumerate(values)]
|
||||
# sort by fractional part in descending order
|
||||
fractional.sort(key=lambda x: x[1], reverse=True)
|
||||
# distribute the remainder
|
||||
for i in range(remainder):
|
||||
index = fractional[i][0]
|
||||
floored[index] += 1
|
||||
return floored
|
||||
476
comfy/ops.py
476
comfy/ops.py
@ -18,6 +18,7 @@
|
||||
|
||||
import torch
|
||||
import logging
|
||||
import contextlib
|
||||
import comfy.model_management
|
||||
from comfy.cli_args import args, PerformanceFeature
|
||||
import comfy.float
|
||||
@ -1047,6 +1048,144 @@ class QuantLinearFunc(torch.autograd.Function):
|
||||
|
||||
return grad_input, grad_weight, grad_bias, None, None, None
|
||||
|
||||
# Quantized-weight module helpers
|
||||
|
||||
def _quantized_apply(module, fn, recurse=True):
|
||||
"""Re-wrap Parameters after fn so .to()/.cuda() propagate through QuantizedTensor weights."""
|
||||
if recurse:
|
||||
for child in module.children():
|
||||
child._apply(fn)
|
||||
for key, param in module._parameters.items():
|
||||
if param is None:
|
||||
continue
|
||||
p = fn(param)
|
||||
if (not torch.is_inference_mode_enabled()) and p.is_inference():
|
||||
p = p.clone()
|
||||
module.register_parameter(key, torch.nn.Parameter(p, requires_grad=False))
|
||||
for key, buf in module._buffers.items():
|
||||
if buf is not None:
|
||||
module._buffers[key] = fn(buf)
|
||||
return module
|
||||
|
||||
|
||||
def _load_quantized_module(module, super_load, state_dict, prefix, local_metadata, strict,
|
||||
missing_keys, unexpected_keys, error_msgs, load_extra_params=False):
|
||||
"""Shared _load_from_state_dict body for quantized-weight modules.
|
||||
|
||||
Pops weight (+ scales, +/- extras), populates module.weight as a Parameter
|
||||
or Parameter-wrapped QuantizedTensor, then calls super_load and strips
|
||||
consumed keys from missing_keys. Reads compute_dtype from factory_kwargs
|
||||
and disabled formats from module._disabled_formats.
|
||||
"""
|
||||
device = module.factory_kwargs["device"]
|
||||
compute_dtype = module.factory_kwargs["dtype"]
|
||||
disabled_formats = module._disabled_formats
|
||||
layer_name = prefix.rstrip('.')
|
||||
|
||||
weight = state_dict.pop(f"{prefix}weight", None)
|
||||
if weight is None:
|
||||
logging.warning(f"Missing weight for layer {layer_name}")
|
||||
module.weight = None
|
||||
return
|
||||
manually_loaded_keys = [f"{prefix}weight"]
|
||||
|
||||
def pop_scale(name, dtype=None):
|
||||
key = f"{prefix}{name}"
|
||||
v = state_dict.pop(key, None)
|
||||
if v is not None:
|
||||
v = v.to(device=device)
|
||||
if dtype is not None:
|
||||
v = v.view(dtype=dtype)
|
||||
manually_loaded_keys.append(key)
|
||||
return v
|
||||
|
||||
layer_conf = state_dict.pop(f"{prefix}comfy_quant", None)
|
||||
if layer_conf is not None:
|
||||
layer_conf = json.loads(layer_conf.numpy().tobytes())
|
||||
|
||||
if layer_conf is None:
|
||||
module.weight = torch.nn.Parameter(weight.to(device=device, dtype=compute_dtype), requires_grad=False)
|
||||
else:
|
||||
module.quant_format = layer_conf.get("format", None)
|
||||
module._full_precision_mm_config = layer_conf.get("full_precision_matrix_mult", False)
|
||||
if not module._full_precision_mm:
|
||||
module._full_precision_mm = module._full_precision_mm_config
|
||||
if module.quant_format in disabled_formats:
|
||||
module._full_precision_mm = True
|
||||
if module.quant_format is None:
|
||||
raise ValueError(f"Unknown quantization format for layer {layer_name}")
|
||||
|
||||
qconfig = QUANT_ALGOS[module.quant_format]
|
||||
module.layout_type = qconfig["comfy_tensor_layout"]
|
||||
layout_cls = get_layout_class(module.layout_type)
|
||||
|
||||
# Per-format scales; fp8 dtype views handle both legacy uint8-on-disk and native fp8.
|
||||
if module.quant_format in ("float8_e4m3fn", "float8_e5m2"):
|
||||
scales = {"scale": pop_scale("weight_scale")}
|
||||
elif module.quant_format == "mxfp8":
|
||||
bs = pop_scale("weight_scale", torch.float8_e8m0fnu)
|
||||
if bs is None:
|
||||
raise ValueError(f"Missing MXFP8 block scales for layer {layer_name}")
|
||||
scales = {"scale": bs}
|
||||
elif module.quant_format == "nvfp4":
|
||||
ts = pop_scale("weight_scale_2")
|
||||
bs = pop_scale("weight_scale", torch.float8_e4m3fn)
|
||||
if ts is None or bs is None:
|
||||
raise ValueError(f"Missing NVFP4 scales for layer {layer_name}")
|
||||
scales = {"scale": ts, "block_scale": bs}
|
||||
else:
|
||||
raise ValueError(f"Unsupported quantization format: {module.quant_format}")
|
||||
|
||||
params = layout_cls.Params(**scales, orig_dtype=compute_dtype, orig_shape=module._orig_shape)
|
||||
module.weight = torch.nn.Parameter(
|
||||
QuantizedTensor(weight.to(device=device, dtype=qconfig["storage_t"]), module.layout_type, params),
|
||||
requires_grad=False,
|
||||
)
|
||||
|
||||
if load_extra_params:
|
||||
for param_name in qconfig["parameters"]:
|
||||
if param_name in {"weight_scale", "weight_scale_2"}:
|
||||
continue
|
||||
param_key = f"{prefix}{param_name}"
|
||||
_v = state_dict.pop(param_key, None)
|
||||
if _v is None:
|
||||
continue
|
||||
module.register_parameter(param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
|
||||
manually_loaded_keys.append(param_key)
|
||||
|
||||
super_load(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
|
||||
for key in manually_loaded_keys:
|
||||
if key in missing_keys:
|
||||
missing_keys.remove(key)
|
||||
|
||||
|
||||
def _quantized_weight_state_dict(module, sd, prefix, extra_quant_conf=None, extra_quant_params=()):
|
||||
"""Shared state_dict body. extra_quant_conf merges into the comfy_quant JSON;
|
||||
extra_quant_params names attributes written as additional top-level keys."""
|
||||
if not hasattr(module, 'weight'):
|
||||
logging.warning(f"Warning: state dict on uninitialized op {prefix}")
|
||||
return sd
|
||||
bias = getattr(module, 'bias', None)
|
||||
if bias is not None:
|
||||
sd[f"{prefix}bias"] = bias
|
||||
if module.weight is None:
|
||||
return sd
|
||||
if isinstance(module.weight, QuantizedTensor):
|
||||
sd.update(module.weight.state_dict(f"{prefix}weight"))
|
||||
quant_conf = {"format": module.quant_format}
|
||||
if getattr(module, '_full_precision_mm_config', False):
|
||||
quant_conf["full_precision_matrix_mult"] = True
|
||||
if extra_quant_conf:
|
||||
quant_conf.update(extra_quant_conf)
|
||||
sd[f"{prefix}comfy_quant"] = torch.tensor(list(json.dumps(quant_conf).encode("utf-8")), dtype=torch.uint8)
|
||||
for name in extra_quant_params:
|
||||
value = getattr(module, name, None)
|
||||
if value is not None:
|
||||
sd[f"{prefix}{name}"] = value
|
||||
else:
|
||||
sd[f"{prefix}weight"] = module.weight
|
||||
return sd
|
||||
|
||||
|
||||
def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False, disabled=[]):
|
||||
class MixedPrecisionOps(manual_cast):
|
||||
@ -1056,21 +1195,16 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
_disabled = disabled
|
||||
|
||||
class Linear(torch.nn.Module, CastWeightBiasOp):
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
bias: bool = True,
|
||||
device=None,
|
||||
dtype=None,
|
||||
) -> None:
|
||||
_disabled_formats = disabled
|
||||
|
||||
def __init__(self, in_features: int, out_features: int, bias: bool = True, device=None, dtype=None):
|
||||
super().__init__()
|
||||
|
||||
self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype}
|
||||
# self.factory_kwargs = {"device": device, "dtype": dtype}
|
||||
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
self._orig_shape = (out_features, in_features)
|
||||
if bias:
|
||||
self.bias = torch.nn.Parameter(torch.empty(out_features, **self.factory_kwargs))
|
||||
else:
|
||||
@ -1083,151 +1217,12 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
def reset_parameters(self):
|
||||
return None
|
||||
|
||||
def _load_scale_param(self, state_dict, prefix, param_name, device, manually_loaded_keys, dtype=None):
|
||||
key = f"{prefix}{param_name}"
|
||||
value = state_dict.pop(key, None)
|
||||
if value is not None:
|
||||
value = value.to(device=device)
|
||||
if dtype is not None:
|
||||
value = value.view(dtype=dtype)
|
||||
manually_loaded_keys.append(key)
|
||||
return value
|
||||
|
||||
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
|
||||
strict, missing_keys, unexpected_keys, error_msgs):
|
||||
|
||||
device = self.factory_kwargs["device"]
|
||||
layer_name = prefix.rstrip('.')
|
||||
weight_key = f"{prefix}weight"
|
||||
weight = state_dict.pop(weight_key, None)
|
||||
if weight is None:
|
||||
logging.warning(f"Missing weight for layer {layer_name}")
|
||||
self.weight = None
|
||||
return
|
||||
|
||||
manually_loaded_keys = [weight_key]
|
||||
|
||||
layer_conf = state_dict.pop(f"{prefix}comfy_quant", None)
|
||||
if layer_conf is not None:
|
||||
layer_conf = json.loads(layer_conf.numpy().tobytes())
|
||||
|
||||
if layer_conf is None:
|
||||
self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
|
||||
else:
|
||||
self.quant_format = layer_conf.get("format", None)
|
||||
self._full_precision_mm_config = layer_conf.get("full_precision_matrix_mult", False)
|
||||
if not self._full_precision_mm:
|
||||
self._full_precision_mm = self._full_precision_mm_config
|
||||
|
||||
if self.quant_format in MixedPrecisionOps._disabled:
|
||||
self._full_precision_mm = True
|
||||
|
||||
if self.quant_format is None:
|
||||
raise ValueError(f"Unknown quantization format for layer {layer_name}")
|
||||
|
||||
qconfig = QUANT_ALGOS[self.quant_format]
|
||||
self.layout_type = qconfig["comfy_tensor_layout"]
|
||||
layout_cls = get_layout_class(self.layout_type)
|
||||
|
||||
# Load format-specific parameters
|
||||
if self.quant_format in ["float8_e4m3fn", "float8_e5m2"]:
|
||||
# FP8: single tensor scale
|
||||
scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys)
|
||||
|
||||
params = layout_cls.Params(
|
||||
scale=scale,
|
||||
orig_dtype=MixedPrecisionOps._compute_dtype,
|
||||
orig_shape=(self.out_features, self.in_features),
|
||||
)
|
||||
|
||||
elif self.quant_format == "mxfp8":
|
||||
# MXFP8: E8M0 block scales stored as uint8 in safetensors
|
||||
block_scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys,
|
||||
dtype=torch.uint8)
|
||||
|
||||
if block_scale is None:
|
||||
raise ValueError(f"Missing MXFP8 block scales for layer {layer_name}")
|
||||
|
||||
block_scale = block_scale.view(torch.float8_e8m0fnu)
|
||||
|
||||
params = layout_cls.Params(
|
||||
scale=block_scale,
|
||||
orig_dtype=MixedPrecisionOps._compute_dtype,
|
||||
orig_shape=(self.out_features, self.in_features),
|
||||
)
|
||||
|
||||
elif self.quant_format == "nvfp4":
|
||||
# NVFP4: tensor_scale (weight_scale_2) + block_scale (weight_scale)
|
||||
tensor_scale = self._load_scale_param(state_dict, prefix, "weight_scale_2", device, manually_loaded_keys)
|
||||
block_scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys,
|
||||
dtype=torch.float8_e4m3fn)
|
||||
|
||||
if tensor_scale is None or block_scale is None:
|
||||
raise ValueError(f"Missing NVFP4 scales for layer {layer_name}")
|
||||
|
||||
params = layout_cls.Params(
|
||||
scale=tensor_scale,
|
||||
block_scale=block_scale,
|
||||
orig_dtype=MixedPrecisionOps._compute_dtype,
|
||||
orig_shape=(self.out_features, self.in_features),
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported quantization format: {self.quant_format}")
|
||||
|
||||
self.weight = torch.nn.Parameter(
|
||||
QuantizedTensor(weight.to(device=device, dtype=qconfig["storage_t"]), self.layout_type, params),
|
||||
requires_grad=False
|
||||
)
|
||||
|
||||
for param_name in qconfig["parameters"]:
|
||||
if param_name in {"weight_scale", "weight_scale_2"}:
|
||||
continue # Already handled above
|
||||
|
||||
param_key = f"{prefix}{param_name}"
|
||||
_v = state_dict.pop(param_key, None)
|
||||
if _v is None:
|
||||
continue
|
||||
self.register_parameter(param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
|
||||
manually_loaded_keys.append(param_key)
|
||||
|
||||
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
|
||||
|
||||
for key in manually_loaded_keys:
|
||||
if key in missing_keys:
|
||||
missing_keys.remove(key)
|
||||
def _load_from_state_dict(self, *args):
|
||||
_load_quantized_module(self, super()._load_from_state_dict, *args, load_extra_params=True)
|
||||
|
||||
def state_dict(self, *args, destination=None, prefix="", **kwargs):
|
||||
if destination is not None:
|
||||
sd = destination
|
||||
else:
|
||||
sd = {}
|
||||
|
||||
if not hasattr(self, 'weight'):
|
||||
logging.warning("Warning: state dict on uninitialized op {}".format(prefix))
|
||||
return sd
|
||||
|
||||
if self.bias is not None:
|
||||
sd["{}bias".format(prefix)] = self.bias
|
||||
|
||||
if self.weight is None:
|
||||
return sd
|
||||
|
||||
if isinstance(self.weight, QuantizedTensor):
|
||||
sd_out = self.weight.state_dict("{}weight".format(prefix))
|
||||
for k in sd_out:
|
||||
sd[k] = sd_out[k]
|
||||
|
||||
quant_conf = {"format": self.quant_format}
|
||||
if self._full_precision_mm_config:
|
||||
quant_conf["full_precision_matrix_mult"] = True
|
||||
sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8)
|
||||
|
||||
input_scale = getattr(self, 'input_scale', None)
|
||||
if input_scale is not None:
|
||||
sd["{}input_scale".format(prefix)] = input_scale
|
||||
else:
|
||||
sd["{}weight".format(prefix)] = self.weight
|
||||
return sd
|
||||
sd = destination if destination is not None else {}
|
||||
return _quantized_weight_state_dict(self, sd, prefix, extra_quant_params=("input_scale",))
|
||||
|
||||
def _forward(self, input, weight, bias):
|
||||
return torch.nn.functional.linear(input, weight, bias)
|
||||
@ -1317,25 +1312,126 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
self.weight = torch.nn.Parameter(weight, requires_grad=False)
|
||||
|
||||
def _apply(self, fn, recurse=True): # This is to get torch.compile + moving weights to another device working
|
||||
if recurse:
|
||||
for module in self.children():
|
||||
module._apply(fn)
|
||||
return _quantized_apply(self, fn, recurse)
|
||||
|
||||
for key, param in self._parameters.items():
|
||||
if param is None:
|
||||
continue
|
||||
p = fn(param)
|
||||
if (not torch.is_inference_mode_enabled()) and p.is_inference():
|
||||
p = p.clone()
|
||||
self.register_parameter(key, torch.nn.Parameter(p, requires_grad=False))
|
||||
for key, buf in self._buffers.items():
|
||||
if buf is not None:
|
||||
self._buffers[key] = fn(buf)
|
||||
return self
|
||||
class MoEExperts(torch.nn.Module, CastWeightBiasOp):
|
||||
"""Container for E quantized expert weights, indexed via expert_weight(i).
|
||||
|
||||
The bank lives on self.weight as a single 3D tensor — either a
|
||||
compute_dtype Parameter or a Parameter wrapping a QuantizedTensor
|
||||
with leading expert dim.
|
||||
|
||||
State-dict layout matches mixed_precision_ops.Linear with a leading
|
||||
expert dim:
|
||||
{prefix}.weight quant data (storage_t), leading dim = E
|
||||
{prefix}.weight_scale block / per-tensor scale
|
||||
{prefix}.weight_scale_2 [E] or scalar NVFP4 only
|
||||
{prefix}.bias [E, out_features] optional, compute_dtype
|
||||
{prefix}.comfy_quant json -> {{"format": "...", "num_experts": E}}
|
||||
|
||||
Without comfy_quant the weight loads as a plain compute_dtype 3D Parameter [E, out, in].
|
||||
"""
|
||||
|
||||
_disabled_formats = disabled
|
||||
|
||||
def __init__(self, num_experts: int, in_features: int, out_features: int, bias: bool = True, device=None, dtype=None):
|
||||
super().__init__()
|
||||
self.num_experts = num_experts
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
self._orig_shape = (num_experts, out_features, in_features)
|
||||
self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype}
|
||||
if bias:
|
||||
self.bias = torch.nn.Parameter(torch.empty(num_experts, out_features, **self.factory_kwargs))
|
||||
else:
|
||||
self.register_parameter("bias", None)
|
||||
|
||||
# Populated by _load_from_state_dict:
|
||||
self.weight = None
|
||||
self.quant_format = None
|
||||
self.layout_type = None
|
||||
self._full_precision_mm = MixedPrecisionOps._full_precision_mm
|
||||
self._full_precision_mm_config = False
|
||||
self._resident_bank = None
|
||||
|
||||
def reset_parameters(self):
|
||||
return None
|
||||
|
||||
def _apply(self, fn, recurse=True):
|
||||
return _quantized_apply(self, fn, recurse)
|
||||
|
||||
def _load_from_state_dict(self, *args):
|
||||
_load_quantized_module(self, super()._load_from_state_dict, *args, load_extra_params=False)
|
||||
|
||||
def expert_weight(self, i: int):
|
||||
"""Expert i's weight (Tensor or per-expert QuantizedTensor view)."""
|
||||
if isinstance(self.weight, QuantizedTensor):
|
||||
return self._expert_qt_from(self.weight, i)
|
||||
return self.weight[i]
|
||||
|
||||
@contextlib.contextmanager
|
||||
def bank_resident(self, input):
|
||||
"""Cast the whole bank once; expert_linear inside reuses the cast.
|
||||
Not re-entrant — do not nest calls on the same instance.
|
||||
"""
|
||||
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
||||
self._resident_bank = (weight, bias)
|
||||
try:
|
||||
yield self
|
||||
finally:
|
||||
self._resident_bank = None
|
||||
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
|
||||
def expert_linear(self, input: torch.Tensor, i: int) -> torch.Tensor:
|
||||
"""Linear against expert i's weight (with optional bias)."""
|
||||
resident = getattr(self, "_resident_bank", None)
|
||||
if resident is not None:
|
||||
weight, bias = resident
|
||||
return self._expert_linear_impl(input, weight, bias, i)
|
||||
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
||||
try:
|
||||
return self._expert_linear_impl(input, weight, bias, i)
|
||||
finally:
|
||||
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
|
||||
def _expert_linear_impl(self, input, weight, bias, i):
|
||||
if isinstance(weight, QuantizedTensor):
|
||||
qw = self._expert_qt_from(weight, i)
|
||||
else:
|
||||
qw = weight[i]
|
||||
b = cast_to_input(bias[i], input, copy=False) if bias is not None else None
|
||||
|
||||
if isinstance(qw, QuantizedTensor):
|
||||
use_fast = (
|
||||
not self._full_precision_mm
|
||||
and qw.layout_cls.supports_fast_matmul()
|
||||
and input.dim() == 2
|
||||
)
|
||||
if use_fast:
|
||||
qin = QuantizedTensor.from_float(input, self.layout_type)
|
||||
return torch.nn.functional.linear(qin, qw, b)
|
||||
out = input @ qw.dequantize().t()
|
||||
return out + b if b is not None else out
|
||||
return torch.nn.functional.linear(input, qw, b)
|
||||
|
||||
def _expert_qt_from(self, weight: QuantizedTensor, i: int) -> QuantizedTensor:
|
||||
"""Build a per-expert QuantizedTensor by indexing into a resident bank."""
|
||||
params = weight._params
|
||||
kwargs = {
|
||||
"scale": params.scale[i] if params.scale.dim() else params.scale,
|
||||
"orig_dtype": params.orig_dtype,
|
||||
"orig_shape": (self.out_features, self.in_features),
|
||||
}
|
||||
if hasattr(params, "block_scale"): # NVFP4
|
||||
kwargs["block_scale"] = params.block_scale[i]
|
||||
return QuantizedTensor(weight._qdata[i], weight._layout_cls, type(params)(**kwargs))
|
||||
|
||||
def state_dict(self, *args, destination=None, prefix="", **kwargs):
|
||||
sd = destination if destination is not None else {}
|
||||
return _quantized_weight_state_dict(self, sd, prefix, extra_quant_conf={"num_experts": self.num_experts})
|
||||
|
||||
class Embedding(manual_cast.Embedding):
|
||||
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
|
||||
strict, missing_keys, unexpected_keys, error_msgs):
|
||||
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
|
||||
weight_key = f"{prefix}weight"
|
||||
layer_conf = state_dict.pop(f"{prefix}comfy_quant", None)
|
||||
if layer_conf is not None:
|
||||
@ -1343,14 +1439,16 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
|
||||
# Only fp8 makes sense for embeddings (per-row dequant via index select).
|
||||
# Block-scaled formats (NVFP4, MXFP8) can't do per-row lookup efficiently.
|
||||
quant_format = layer_conf.get("format", None) if layer_conf is not None else None
|
||||
if quant_format in ["float8_e4m3fn", "float8_e5m2"] and weight_key in state_dict:
|
||||
quant_format = layer_conf.get("format") if layer_conf is not None else None
|
||||
manually_loaded_keys = []
|
||||
|
||||
if quant_format in ("float8_e4m3fn", "float8_e5m2") and weight_key in state_dict:
|
||||
self.quant_format = quant_format
|
||||
qconfig = QUANT_ALGOS[quant_format]
|
||||
self.layout_type = qconfig["comfy_tensor_layout"]
|
||||
layout_cls = get_layout_class(self.layout_type)
|
||||
weight = state_dict.pop(weight_key)
|
||||
manually_loaded_keys = [weight_key]
|
||||
manually_loaded_keys.append(weight_key)
|
||||
|
||||
scale_key = f"{prefix}weight_scale"
|
||||
scale = state_dict.pop(scale_key, None)
|
||||
@ -1366,35 +1464,19 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
self.weight = torch.nn.Parameter(
|
||||
QuantizedTensor(weight.to(dtype=qconfig["storage_t"]), qconfig["comfy_tensor_layout"], params),
|
||||
requires_grad=False)
|
||||
elif layer_conf is not None:
|
||||
# Unsupported format — restore the marker so it round-trips; fall through to default load.
|
||||
state_dict[f"{prefix}comfy_quant"] = torch.tensor(
|
||||
list(json.dumps(layer_conf).encode('utf-8')), dtype=torch.uint8)
|
||||
|
||||
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
|
||||
for k in manually_loaded_keys:
|
||||
if k in missing_keys:
|
||||
missing_keys.remove(k)
|
||||
else:
|
||||
if layer_conf is not None:
|
||||
state_dict[f"{prefix}comfy_quant"] = torch.tensor(list(json.dumps(layer_conf).encode('utf-8')), dtype=torch.uint8)
|
||||
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
|
||||
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
|
||||
for k in manually_loaded_keys:
|
||||
if k in missing_keys:
|
||||
missing_keys.remove(k)
|
||||
|
||||
def state_dict(self, *args, destination=None, prefix="", **kwargs):
|
||||
if destination is not None:
|
||||
sd = destination
|
||||
else:
|
||||
sd = {}
|
||||
|
||||
if not hasattr(self, 'weight') or self.weight is None:
|
||||
return sd
|
||||
|
||||
if isinstance(self.weight, QuantizedTensor):
|
||||
sd_out = self.weight.state_dict("{}weight".format(prefix))
|
||||
for k in sd_out:
|
||||
sd[k] = sd_out[k]
|
||||
|
||||
quant_conf = {"format": self.quant_format}
|
||||
sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8)
|
||||
else:
|
||||
sd["{}weight".format(prefix)] = self.weight
|
||||
return sd
|
||||
sd = destination if destination is not None else {}
|
||||
return _quantized_weight_state_dict(self, sd, prefix)
|
||||
|
||||
def forward_comfy_cast_weights(self, input, out_dtype=None):
|
||||
weight = self.weight
|
||||
|
||||
@ -1,8 +1,9 @@
|
||||
from __future__ import annotations
|
||||
from typing import Callable
|
||||
|
||||
class CallbacksMP:
|
||||
ON_CLONE = "on_clone"
|
||||
ON_DEEPCLONE_MULTIGPU = "on_deepclone_multigpu"
|
||||
ON_MATCH_MULTIGPU_CLONES = "on_match_multigpu_clones"
|
||||
ON_LOAD = "on_load_after"
|
||||
ON_DETACH = "on_detach_after"
|
||||
ON_CLEANUP = "on_cleanup"
|
||||
|
||||
@ -1,16 +1,18 @@
|
||||
from __future__ import annotations
|
||||
import torch
|
||||
import uuid
|
||||
import math
|
||||
import collections
|
||||
import comfy.model_management
|
||||
import comfy.conds
|
||||
import comfy.model_patcher
|
||||
import comfy.utils
|
||||
import comfy.hooks
|
||||
import comfy.patcher_extension
|
||||
from typing import TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
from comfy.model_patcher import ModelPatcher
|
||||
from comfy.model_base import BaseModel
|
||||
from comfy.model_patcher import ModelPatcher
|
||||
from comfy.controlnet import ControlBase
|
||||
|
||||
def prepare_mask(noise_mask, shape, device):
|
||||
@ -119,6 +121,47 @@ def cleanup_additional_models(models):
|
||||
if hasattr(m, 'cleanup'):
|
||||
m.cleanup()
|
||||
|
||||
def preprocess_multigpu_conds(conds: dict[str, list[dict[str]]], model: ModelPatcher, model_options: dict[str]):
|
||||
'''If multigpu acceleration required, creates deepclones of ControlNets and GLIGEN per device.'''
|
||||
multigpu_models: list[ModelPatcher] = model.get_additional_models_with_key("multigpu")
|
||||
if len(multigpu_models) == 0:
|
||||
return
|
||||
extra_devices = [x.load_device for x in multigpu_models]
|
||||
# handle controlnets
|
||||
controlnets: set[ControlBase] = set()
|
||||
for k in conds:
|
||||
for kk in conds[k]:
|
||||
if 'control' in kk:
|
||||
controlnets.add(kk['control'])
|
||||
if len(controlnets) > 0:
|
||||
# first, unload all controlnet clones
|
||||
for cnet in list(controlnets):
|
||||
cnet_models = cnet.get_models()
|
||||
for cm in cnet_models:
|
||||
comfy.model_management.unload_model_and_clones(cm, unload_additional_models=True)
|
||||
|
||||
# next, make sure each controlnet has a deepclone for all relevant devices
|
||||
for cnet in controlnets:
|
||||
curr_cnet = cnet
|
||||
while curr_cnet is not None:
|
||||
for device in extra_devices:
|
||||
if device not in curr_cnet.multigpu_clones:
|
||||
curr_cnet.deepclone_multigpu(device, autoregister=True)
|
||||
curr_cnet = curr_cnet.previous_controlnet
|
||||
# since all device clones are now present, recreate the linked list for cloned cnets per device
|
||||
for cnet in controlnets:
|
||||
curr_cnet = cnet
|
||||
while curr_cnet is not None:
|
||||
prev_cnet = curr_cnet.previous_controlnet
|
||||
for device in extra_devices:
|
||||
device_cnet = curr_cnet.get_instance_for_device(device)
|
||||
prev_device_cnet = None
|
||||
if prev_cnet is not None:
|
||||
prev_device_cnet = prev_cnet.get_instance_for_device(device)
|
||||
device_cnet.set_previous_controlnet(prev_device_cnet)
|
||||
curr_cnet = prev_cnet
|
||||
# potentially handle gligen - since not widely used, ignored for now
|
||||
|
||||
def estimate_memory(model, noise_shape, conds):
|
||||
cond_shapes = collections.defaultdict(list)
|
||||
cond_shapes_min = {}
|
||||
@ -143,7 +186,8 @@ def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None
|
||||
return executor.execute(model, noise_shape, conds, model_options=model_options, force_full_load=force_full_load, force_offload=force_offload)
|
||||
|
||||
def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False, force_offload=False):
|
||||
real_model: BaseModel = None
|
||||
model.match_multigpu_clones()
|
||||
preprocess_multigpu_conds(conds, model, model_options)
|
||||
models, inference_memory = get_additional_models(conds, model.model_dtype())
|
||||
models += get_additional_models_from_model_options(model_options)
|
||||
models += model.get_nested_additional_models() # TODO: does this require inference_memory update?
|
||||
@ -155,7 +199,7 @@ def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=Non
|
||||
memory_required += inference_memory
|
||||
minimum_memory_required += inference_memory
|
||||
comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required, minimum_memory_required=minimum_memory_required, force_full_load=force_full_load)
|
||||
real_model = model.model
|
||||
real_model: BaseModel = model.model
|
||||
|
||||
return real_model, conds, models
|
||||
|
||||
@ -201,3 +245,18 @@ def prepare_model_patcher(model: ModelPatcher, conds, model_options: dict):
|
||||
comfy.patcher_extension.merge_nested_dicts(to_load_options.setdefault(wc_name, {}), model_options["transformer_options"][wc_name],
|
||||
copy_dict1=False)
|
||||
return to_load_options
|
||||
|
||||
def prepare_model_patcher_multigpu_clones(model_patcher: ModelPatcher, loaded_models: list[ModelPatcher], model_options: dict):
|
||||
'''
|
||||
In case multigpu acceleration is enabled, prep ModelPatchers for each device.
|
||||
'''
|
||||
multigpu_patchers: list[ModelPatcher] = [x for x in loaded_models if x.is_multigpu_base_clone]
|
||||
if len(multigpu_patchers) > 0:
|
||||
multigpu_dict: dict[torch.device, ModelPatcher] = {}
|
||||
multigpu_dict[model_patcher.load_device] = model_patcher
|
||||
for x in multigpu_patchers:
|
||||
x.hook_patches = comfy.model_patcher.create_hook_patches_clone(model_patcher.hook_patches, copy_tuples=True)
|
||||
x.hook_mode = model_patcher.hook_mode # match main model's hook_mode
|
||||
multigpu_dict[x.load_device] = x
|
||||
model_options["multigpu_clones"] = multigpu_dict
|
||||
return multigpu_patchers
|
||||
|
||||
@ -1,7 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import comfy.model_management
|
||||
from .k_diffusion import sampling as k_diffusion_sampling
|
||||
from .extra_samplers import uni_pc
|
||||
from typing import TYPE_CHECKING, Callable, NamedTuple
|
||||
from typing import TYPE_CHECKING, Callable, NamedTuple, Any
|
||||
if TYPE_CHECKING:
|
||||
from comfy.model_patcher import ModelPatcher
|
||||
from comfy.model_base import BaseModel
|
||||
@ -16,6 +18,7 @@ import comfy.model_patcher
|
||||
import comfy.patcher_extension
|
||||
import comfy.hooks
|
||||
import comfy.context_windows
|
||||
import comfy.multigpu
|
||||
import comfy.utils
|
||||
import scipy.stats
|
||||
import numpy
|
||||
@ -141,7 +144,7 @@ def can_concat_cond(c1, c2):
|
||||
|
||||
return cond_equal_size(c1.conditioning, c2.conditioning)
|
||||
|
||||
def cond_cat(c_list):
|
||||
def cond_cat(c_list, device=None):
|
||||
temp = {}
|
||||
for x in c_list:
|
||||
for k in x:
|
||||
@ -153,6 +156,8 @@ def cond_cat(c_list):
|
||||
for k in temp:
|
||||
conds = temp[k]
|
||||
out[k] = conds[0].concat(conds[1:])
|
||||
if device is not None and hasattr(out[k], 'to'):
|
||||
out[k] = out[k].to(device)
|
||||
|
||||
return out
|
||||
|
||||
@ -212,7 +217,12 @@ def _calc_cond_batch_outer(model: BaseModel, conds: list[list[dict]], x_in: torc
|
||||
)
|
||||
return executor.execute(model, conds, x_in, timestep, model_options)
|
||||
|
||||
def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
|
||||
def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
|
||||
# NOTE: keep in sync with _calc_cond_batch_multigpu below. Shared logic
|
||||
# (hooked_to_run accumulation, memory-fit batching, per-chunk output
|
||||
# aggregation) is duplicated there with per-device scheduling layered on top.
|
||||
if 'multigpu_clones' in model_options:
|
||||
return _calc_cond_batch_multigpu(model, conds, x_in, timestep, model_options)
|
||||
out_conds = []
|
||||
out_counts = []
|
||||
# separate conds by matching hooks
|
||||
@ -244,7 +254,7 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens
|
||||
if has_default_conds:
|
||||
finalize_default_conds(model, hooked_to_run, default_conds, x_in, timestep, model_options)
|
||||
|
||||
model.current_patcher.prepare_state(timestep)
|
||||
model.current_patcher.prepare_state(timestep, model_options)
|
||||
|
||||
# run every hooked_to_run separately
|
||||
for hooks, to_run in hooked_to_run.items():
|
||||
@ -344,6 +354,239 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens
|
||||
|
||||
return out_conds
|
||||
|
||||
def _calc_cond_batch_multigpu(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
|
||||
# NOTE: keep in sync with _calc_cond_batch above. Same conds-by-hooks
|
||||
# accumulation, memory-fit batching, and output aggregation, but adds a
|
||||
# per-device scheduler, per-device patcher/control lookup, tensor .to(device)
|
||||
# placement, and MultiGPUThreadPool dispatch around the inner loop.
|
||||
out_conds = []
|
||||
out_counts = []
|
||||
# separate conds by matching hooks
|
||||
hooked_to_run: dict[comfy.hooks.HookGroup,list[tuple[tuple,int]]] = {}
|
||||
default_conds = []
|
||||
has_default_conds = False
|
||||
|
||||
output_device = x_in.device
|
||||
|
||||
for i in range(len(conds)):
|
||||
out_conds.append(torch.zeros_like(x_in))
|
||||
out_counts.append(torch.ones_like(x_in) * 1e-37)
|
||||
|
||||
cond = conds[i]
|
||||
default_c = []
|
||||
if cond is not None:
|
||||
for x in cond:
|
||||
if 'default' in x:
|
||||
default_c.append(x)
|
||||
has_default_conds = True
|
||||
continue
|
||||
p = get_area_and_mult(x, x_in, timestep)
|
||||
if p is None:
|
||||
continue
|
||||
if p.hooks is not None:
|
||||
model.current_patcher.prepare_hook_patches_current_keyframe(timestep, p.hooks, model_options)
|
||||
hooked_to_run.setdefault(p.hooks, list())
|
||||
hooked_to_run[p.hooks] += [(p, i)]
|
||||
default_conds.append(default_c)
|
||||
|
||||
if has_default_conds:
|
||||
finalize_default_conds(model, hooked_to_run, default_conds, x_in, timestep, model_options)
|
||||
|
||||
model.current_patcher.prepare_state(timestep, model_options)
|
||||
|
||||
devices = list(model_options['multigpu_clones'].keys())
|
||||
device_batched_hooked_to_run: dict[torch.device, list[tuple[comfy.hooks.HookGroup, tuple]]] = {}
|
||||
# Track conds currently scheduled per device; single source of truth for capacity checks.
|
||||
device_load: dict[torch.device, int] = {d: 0 for d in devices}
|
||||
|
||||
total_conds = sum(len(to_run) for to_run in hooked_to_run.values())
|
||||
conds_per_device = max(1, math.ceil(total_conds / len(devices)))
|
||||
|
||||
def next_available_device(start: int) -> tuple[int, torch.device]:
|
||||
"""Return (index, device) for the next device with remaining capacity, starting at `start`.
|
||||
|
||||
Scans at most len(devices) positions, so this always terminates. Raises if no device
|
||||
has remaining capacity, which would indicate a bug in conds_per_device accounting.
|
||||
"""
|
||||
for offset in range(len(devices)):
|
||||
i = (start + offset) % len(devices)
|
||||
if device_load[devices[i]] < conds_per_device:
|
||||
return i, devices[i]
|
||||
raise RuntimeError(
|
||||
f"MultiGPU scheduler: all {len(devices)} devices at capacity "
|
||||
f"({conds_per_device}) but conds remain to schedule"
|
||||
)
|
||||
|
||||
# run every hooked_to_run separately
|
||||
index_device = 0
|
||||
for hooks, to_run in hooked_to_run.items():
|
||||
while len(to_run) > 0:
|
||||
index_device, current_device = next_available_device(index_device)
|
||||
remaining_capacity = conds_per_device - device_load[current_device]
|
||||
|
||||
first = to_run[0]
|
||||
first_shape = first[0][0].shape
|
||||
# collect candidate indices that can be concatenated with `first`, up to remaining capacity
|
||||
to_batch_temp = []
|
||||
for x in range(len(to_run)):
|
||||
if can_concat_cond(to_run[x][0], first[0]) and len(to_batch_temp) < remaining_capacity:
|
||||
to_batch_temp += [x]
|
||||
|
||||
to_batch_temp.reverse()
|
||||
to_batch = to_batch_temp[:1]
|
||||
|
||||
free_memory = comfy.model_management.get_free_memory(current_device)
|
||||
for i in range(1, len(to_batch_temp) + 1):
|
||||
batch_amount = to_batch_temp[:len(to_batch_temp)//i]
|
||||
input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
|
||||
cond_shapes = collections.defaultdict(list)
|
||||
for tt in batch_amount:
|
||||
for k, v in to_run[tt][0].conditioning.items():
|
||||
cond_shapes[k].append(v.size())
|
||||
if model.memory_required(input_shape, cond_shapes=cond_shapes) * 1.5 < free_memory:
|
||||
to_batch = batch_amount
|
||||
break
|
||||
|
||||
conds_to_batch = [to_run.pop(x) for x in to_batch]
|
||||
device_load[current_device] += len(conds_to_batch)
|
||||
device_batched_hooked_to_run.setdefault(current_device, []).append((hooks, conds_to_batch))
|
||||
|
||||
if device_load[current_device] >= conds_per_device:
|
||||
index_device += 1
|
||||
|
||||
class thread_result(NamedTuple):
|
||||
output: Any
|
||||
mult: Any
|
||||
area: Any
|
||||
batch_chunks: int
|
||||
cond_or_uncond: Any
|
||||
error: Exception = None
|
||||
|
||||
def _handle_batch(device: torch.device, batch_tuple: tuple[comfy.hooks.HookGroup, tuple], results: list[thread_result]):
|
||||
try:
|
||||
# TODO: non-NVIDIA support -- guard with `if device.type == "cuda":` once
|
||||
# we extend multigpu QA beyond CUDA. Unconditional call crashes on
|
||||
# XPU/NPU/MPS/CPU/DirectML backends.
|
||||
torch.cuda.set_device(device)
|
||||
model_current: BaseModel = model_options["multigpu_clones"][device].model
|
||||
# run every hooked_to_run separately
|
||||
with torch.no_grad():
|
||||
for hooks, to_batch in batch_tuple:
|
||||
input_x = []
|
||||
mult = []
|
||||
c = []
|
||||
cond_or_uncond = []
|
||||
uuids = []
|
||||
area = []
|
||||
control: ControlBase = None
|
||||
patches = None
|
||||
for x in to_batch:
|
||||
o = x
|
||||
p = o[0]
|
||||
input_x.append(p.input_x)
|
||||
mult.append(p.mult)
|
||||
c.append(p.conditioning)
|
||||
area.append(p.area)
|
||||
cond_or_uncond.append(o[1])
|
||||
uuids.append(p.uuid)
|
||||
control = p.control
|
||||
patches = p.patches
|
||||
|
||||
batch_chunks = len(cond_or_uncond)
|
||||
input_x = torch.cat(input_x).to(device)
|
||||
c = cond_cat(c, device=device)
|
||||
timestep_ = torch.cat([timestep.to(device)] * batch_chunks)
|
||||
|
||||
transformer_options = model_current.current_patcher.apply_hooks(hooks=hooks)
|
||||
if 'transformer_options' in model_options:
|
||||
transformer_options = comfy.patcher_extension.merge_nested_dicts(transformer_options,
|
||||
model_options['transformer_options'],
|
||||
copy_dict1=False)
|
||||
|
||||
if patches is not None:
|
||||
transformer_options["patches"] = comfy.patcher_extension.merge_nested_dicts(
|
||||
transformer_options.get("patches", {}),
|
||||
patches
|
||||
)
|
||||
|
||||
transformer_options["cond_or_uncond"] = cond_or_uncond[:]
|
||||
transformer_options["uuids"] = uuids[:]
|
||||
transformer_options["sigmas"] = timestep.to(device)
|
||||
transformer_options["sample_sigmas"] = transformer_options["sample_sigmas"].to(device)
|
||||
transformer_options["multigpu_thread_device"] = device
|
||||
|
||||
cast_transformer_options(transformer_options, device=device)
|
||||
c['transformer_options'] = transformer_options
|
||||
|
||||
if control is not None:
|
||||
device_control = control.get_instance_for_device(device)
|
||||
c['control'] = device_control.get_control(input_x, timestep_, c, len(cond_or_uncond), transformer_options)
|
||||
|
||||
if 'model_function_wrapper' in model_options:
|
||||
output = model_options['model_function_wrapper'](model_current.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).to(output_device).chunk(batch_chunks)
|
||||
else:
|
||||
output = model_current.apply_model(input_x, timestep_, **c).to(output_device).chunk(batch_chunks)
|
||||
# TODO: non-NVIDIA support -- the `.to(output_device)` copies
|
||||
# above are async on CUDA, so the main thread's aggregation
|
||||
# could race with in-flight transfers. CUDA-only QA has not
|
||||
# surfaced this in practice, but before extending multigpu
|
||||
# beyond NVIDIA add a `torch.cuda.synchronize(output_device)`
|
||||
# here (guarded by `output_device.type == "cuda"`).
|
||||
results.append(thread_result(output, mult, area, batch_chunks, cond_or_uncond))
|
||||
except Exception as e:
|
||||
results.append(thread_result(None, None, None, None, None, error=e))
|
||||
raise
|
||||
|
||||
|
||||
def _handle_batch_pooled(device, batch_tuple):
|
||||
worker_results = []
|
||||
_handle_batch(device, batch_tuple, worker_results)
|
||||
return worker_results
|
||||
|
||||
results: list[thread_result] = []
|
||||
thread_pool: comfy.multigpu.MultiGPUThreadPool = model_options.get("multigpu_thread_pool")
|
||||
|
||||
# Submit all GPU work to pool threads
|
||||
pool_devices = []
|
||||
for device, batch_tuple in device_batched_hooked_to_run.items():
|
||||
if thread_pool is not None:
|
||||
thread_pool.submit(device, _handle_batch_pooled, device, batch_tuple)
|
||||
pool_devices.append(device)
|
||||
else:
|
||||
# Fallback: no pool, run everything on main thread
|
||||
_handle_batch(device, batch_tuple, results)
|
||||
|
||||
# Collect results from pool workers
|
||||
for device in pool_devices:
|
||||
worker_results, error = thread_pool.get_result(device)
|
||||
if error is not None:
|
||||
raise error
|
||||
results.extend(worker_results)
|
||||
|
||||
for output, mult, area, batch_chunks, cond_or_uncond, error in results:
|
||||
if error is not None:
|
||||
raise error
|
||||
for o in range(batch_chunks):
|
||||
cond_index = cond_or_uncond[o]
|
||||
a = area[o]
|
||||
if a is None:
|
||||
out_conds[cond_index] += output[o] * mult[o]
|
||||
out_counts[cond_index] += mult[o]
|
||||
else:
|
||||
out_c = out_conds[cond_index]
|
||||
out_cts = out_counts[cond_index]
|
||||
dims = len(a) // 2
|
||||
for i in range(dims):
|
||||
out_c = out_c.narrow(i + 2, a[i + dims], a[i])
|
||||
out_cts = out_cts.narrow(i + 2, a[i + dims], a[i])
|
||||
out_c += output[o] * mult[o]
|
||||
out_cts += mult[o]
|
||||
|
||||
for i in range(len(out_conds)):
|
||||
out_conds[i] /= out_counts[i]
|
||||
|
||||
return out_conds
|
||||
|
||||
def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options): #TODO: remove
|
||||
logging.warning("WARNING: The comfy.samplers.calc_cond_uncond_batch function is deprecated please use the calc_cond_batch one instead.")
|
||||
return tuple(calc_cond_batch(model, [cond, uncond], x_in, timestep, model_options))
|
||||
@ -642,12 +885,21 @@ def calculate_start_end_timesteps(model, conds):
|
||||
|
||||
def pre_run_control(model, conds):
|
||||
s = model.model_sampling
|
||||
# Per-device model lookup so multigpu control clones get the matching
|
||||
# diffusion_model (e.g. QwenFunControlNet stashes it into extra_args).
|
||||
device_models: dict = {}
|
||||
patcher = getattr(model, "current_patcher", None)
|
||||
if patcher is not None:
|
||||
for p in patcher.get_additional_models_with_key("multigpu"):
|
||||
device_models[p.load_device] = p.model
|
||||
for t in range(len(conds)):
|
||||
x = conds[t]
|
||||
|
||||
percent_to_timestep_function = lambda a: s.percent_to_sigma(a)
|
||||
if 'control' in x:
|
||||
x['control'].pre_run(model, percent_to_timestep_function)
|
||||
for device, device_cnet in x['control'].multigpu_clones.items():
|
||||
device_cnet.pre_run(device_models.get(device, model), percent_to_timestep_function)
|
||||
|
||||
def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func):
|
||||
cond_cnets = []
|
||||
@ -890,7 +1142,9 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None):
|
||||
to_load_options = model_options.get("to_load_options", None)
|
||||
if to_load_options is None:
|
||||
return
|
||||
cast_transformer_options(to_load_options, device, dtype)
|
||||
|
||||
def cast_transformer_options(transformer_options: dict[str], device=None, dtype=None):
|
||||
casts = []
|
||||
if device is not None:
|
||||
casts.append(device)
|
||||
@ -899,18 +1153,17 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None):
|
||||
# if nothing to apply, do nothing
|
||||
if len(casts) == 0:
|
||||
return
|
||||
|
||||
# try to call .to on patches
|
||||
if "patches" in to_load_options:
|
||||
patches = to_load_options["patches"]
|
||||
if "patches" in transformer_options:
|
||||
patches = transformer_options["patches"]
|
||||
for name in patches:
|
||||
patch_list = patches[name]
|
||||
for i in range(len(patch_list)):
|
||||
if hasattr(patch_list[i], "to"):
|
||||
for cast in casts:
|
||||
patch_list[i] = patch_list[i].to(cast)
|
||||
if "patches_replace" in to_load_options:
|
||||
patches = to_load_options["patches_replace"]
|
||||
if "patches_replace" in transformer_options:
|
||||
patches = transformer_options["patches_replace"]
|
||||
for name in patches:
|
||||
patch_list = patches[name]
|
||||
for k in patch_list:
|
||||
@ -920,8 +1173,8 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None):
|
||||
# try to call .to on any wrappers/callbacks
|
||||
wrappers_and_callbacks = ["wrappers", "callbacks"]
|
||||
for wc_name in wrappers_and_callbacks:
|
||||
if wc_name in to_load_options:
|
||||
wc: dict[str, list] = to_load_options[wc_name]
|
||||
if wc_name in transformer_options:
|
||||
wc: dict[str, list] = transformer_options[wc_name]
|
||||
for wc_dict in wc.values():
|
||||
for wc_list in wc_dict.values():
|
||||
for i in range(len(wc_list)):
|
||||
@ -929,7 +1182,6 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None):
|
||||
for cast in casts:
|
||||
wc_list[i] = wc_list[i].to(cast)
|
||||
|
||||
|
||||
class CFGGuider:
|
||||
def __init__(self, model_patcher: ModelPatcher):
|
||||
self.model_patcher = model_patcher
|
||||
@ -984,16 +1236,32 @@ class CFGGuider:
|
||||
self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options)
|
||||
device = self.model_patcher.load_device
|
||||
|
||||
noise = noise.to(device=device, dtype=torch.float32)
|
||||
latent_image = latent_image.to(device=device, dtype=torch.float32)
|
||||
sigmas = sigmas.to(device)
|
||||
cast_to_load_options(self.model_options, device=device, dtype=self.model_patcher.model_dtype())
|
||||
multigpu_patchers = comfy.sampler_helpers.prepare_model_patcher_multigpu_clones(self.model_patcher, self.loaded_models, self.model_options)
|
||||
|
||||
try:
|
||||
self.model_patcher.pre_run()
|
||||
output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=latent_shapes)
|
||||
finally:
|
||||
self.model_patcher.cleanup()
|
||||
# Create persistent thread pool for all GPU devices (main + extras)
|
||||
if multigpu_patchers:
|
||||
extra_devices = [p.load_device for p in multigpu_patchers]
|
||||
all_devices = [device] + extra_devices
|
||||
self.model_options["multigpu_thread_pool"] = comfy.multigpu.MultiGPUThreadPool(all_devices)
|
||||
|
||||
with comfy.model_management.cuda_device_context(device):
|
||||
try:
|
||||
noise = noise.to(device=device, dtype=torch.float32)
|
||||
latent_image = latent_image.to(device=device, dtype=torch.float32)
|
||||
sigmas = sigmas.to(device)
|
||||
cast_to_load_options(self.model_options, device=device, dtype=self.model_patcher.model_dtype())
|
||||
|
||||
self.model_patcher.pre_run()
|
||||
for multigpu_patcher in multigpu_patchers:
|
||||
multigpu_patcher.pre_run()
|
||||
output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=latent_shapes)
|
||||
finally:
|
||||
thread_pool = self.model_options.pop("multigpu_thread_pool", None)
|
||||
if thread_pool is not None:
|
||||
thread_pool.shutdown()
|
||||
self.model_patcher.cleanup()
|
||||
for multigpu_patcher in multigpu_patchers:
|
||||
multigpu_patcher.cleanup()
|
||||
|
||||
comfy.sampler_helpers.cleanup_models(self.conds, self.loaded_models)
|
||||
del self.inner_model
|
||||
|
||||
402
comfy/sd.py
402
comfy/sd.py
@ -1,4 +1,3 @@
|
||||
from __future__ import annotations
|
||||
import json
|
||||
import torch
|
||||
from enum import Enum
|
||||
@ -50,6 +49,7 @@ import comfy.text_encoders.lt
|
||||
import comfy.text_encoders.hunyuan_video
|
||||
import comfy.text_encoders.cosmos
|
||||
import comfy.text_encoders.lumina2
|
||||
import comfy.text_encoders.pixeldit
|
||||
import comfy.text_encoders.wan
|
||||
import comfy.text_encoders.hidream
|
||||
import comfy.text_encoders.ace
|
||||
@ -69,6 +69,7 @@ import comfy.text_encoders.ernie
|
||||
import comfy.text_encoders.gemma4
|
||||
import comfy.text_encoders.cogvideo
|
||||
import comfy.text_encoders.sa3
|
||||
import comfy.text_encoders.gpt_oss
|
||||
|
||||
import comfy.model_patcher
|
||||
import comfy.lora
|
||||
@ -335,41 +336,43 @@ class CLIP:
|
||||
self.cond_stage_model.set_clip_options({"projected_pooled": False})
|
||||
|
||||
self.load_model(tokens)
|
||||
self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
|
||||
device = self.patcher.load_device
|
||||
self.cond_stage_model.set_clip_options({"execution_device": device})
|
||||
all_hooks.reset()
|
||||
self.patcher.patch_hooks(None)
|
||||
if show_pbar:
|
||||
pbar = ProgressBar(len(scheduled_keyframes))
|
||||
|
||||
for scheduled_opts in scheduled_keyframes:
|
||||
t_range = scheduled_opts[0]
|
||||
# don't bother encoding any conds outside of start_percent and end_percent bounds
|
||||
if "start_percent" in add_dict:
|
||||
if t_range[1] < add_dict["start_percent"]:
|
||||
continue
|
||||
if "end_percent" in add_dict:
|
||||
if t_range[0] > add_dict["end_percent"]:
|
||||
continue
|
||||
hooks_keyframes = scheduled_opts[1]
|
||||
for hook, keyframe in hooks_keyframes:
|
||||
hook.hook_keyframe._current_keyframe = keyframe
|
||||
# apply appropriate hooks with values that match new hook_keyframe
|
||||
self.patcher.patch_hooks(all_hooks)
|
||||
# perform encoding as normal
|
||||
o = self.cond_stage_model.encode_token_weights(tokens)
|
||||
cond, pooled = o[:2]
|
||||
pooled_dict = {"pooled_output": pooled}
|
||||
# add clip_start_percent and clip_end_percent in pooled
|
||||
pooled_dict["clip_start_percent"] = t_range[0]
|
||||
pooled_dict["clip_end_percent"] = t_range[1]
|
||||
# add/update any keys with the provided add_dict
|
||||
pooled_dict.update(add_dict)
|
||||
# add hooks stored on clip
|
||||
self.add_hooks_to_dict(pooled_dict)
|
||||
all_cond_pooled.append([cond, pooled_dict])
|
||||
if show_pbar:
|
||||
pbar.update(1)
|
||||
model_management.throw_exception_if_processing_interrupted()
|
||||
with model_management.cuda_device_context(device):
|
||||
for scheduled_opts in scheduled_keyframes:
|
||||
t_range = scheduled_opts[0]
|
||||
# don't bother encoding any conds outside of start_percent and end_percent bounds
|
||||
if "start_percent" in add_dict:
|
||||
if t_range[1] < add_dict["start_percent"]:
|
||||
continue
|
||||
if "end_percent" in add_dict:
|
||||
if t_range[0] > add_dict["end_percent"]:
|
||||
continue
|
||||
hooks_keyframes = scheduled_opts[1]
|
||||
for hook, keyframe in hooks_keyframes:
|
||||
hook.hook_keyframe._current_keyframe = keyframe
|
||||
# apply appropriate hooks with values that match new hook_keyframe
|
||||
self.patcher.patch_hooks(all_hooks)
|
||||
# perform encoding as normal
|
||||
o = self.cond_stage_model.encode_token_weights(tokens)
|
||||
cond, pooled = o[:2]
|
||||
pooled_dict = {"pooled_output": pooled}
|
||||
# add clip_start_percent and clip_end_percent in pooled
|
||||
pooled_dict["clip_start_percent"] = t_range[0]
|
||||
pooled_dict["clip_end_percent"] = t_range[1]
|
||||
# add/update any keys with the provided add_dict
|
||||
pooled_dict.update(add_dict)
|
||||
# add hooks stored on clip
|
||||
self.add_hooks_to_dict(pooled_dict)
|
||||
all_cond_pooled.append([cond, pooled_dict])
|
||||
if show_pbar:
|
||||
pbar.update(1)
|
||||
model_management.throw_exception_if_processing_interrupted()
|
||||
all_hooks.reset()
|
||||
return all_cond_pooled
|
||||
|
||||
@ -383,8 +386,12 @@ class CLIP:
|
||||
self.cond_stage_model.set_clip_options({"projected_pooled": False})
|
||||
|
||||
self.load_model(tokens)
|
||||
self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
|
||||
o = self.cond_stage_model.encode_token_weights(tokens)
|
||||
device = self.patcher.load_device
|
||||
self.cond_stage_model.set_clip_options({"execution_device": device})
|
||||
|
||||
with model_management.cuda_device_context(device):
|
||||
o = self.cond_stage_model.encode_token_weights(tokens)
|
||||
|
||||
cond, pooled = o[:2]
|
||||
if return_dict:
|
||||
out = {"cond": cond, "pooled_output": pooled}
|
||||
@ -446,9 +453,12 @@ class CLIP:
|
||||
self.cond_stage_model.reset_clip_options()
|
||||
|
||||
self.load_model(tokens)
|
||||
device = self.patcher.load_device
|
||||
self.cond_stage_model.set_clip_options({"layer": None})
|
||||
self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
|
||||
return self.cond_stage_model.generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed, presence_penalty=presence_penalty)
|
||||
self.cond_stage_model.set_clip_options({"execution_device": device})
|
||||
|
||||
with model_management.cuda_device_context(device):
|
||||
return self.cond_stage_model.generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed, presence_penalty=presence_penalty)
|
||||
|
||||
def decode(self, token_ids, skip_special_tokens=True):
|
||||
return self.tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
|
||||
@ -1026,50 +1036,52 @@ class VAE:
|
||||
do_tile = False
|
||||
if self.latent_dim == 2 and samples_in.ndim == 5:
|
||||
samples_in = samples_in[:, :, 0]
|
||||
try:
|
||||
memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
|
||||
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
|
||||
free_memory = self.patcher.get_free_memory(self.device)
|
||||
batch_number = int(free_memory / memory_used)
|
||||
batch_number = max(1, batch_number)
|
||||
|
||||
# Pre-allocate output for VAEs that support direct buffer writes
|
||||
preallocated = False
|
||||
if getattr(self.first_stage_model, 'comfy_has_chunked_io', False):
|
||||
pixel_samples = torch.empty(self.first_stage_model.decode_output_shape(samples_in.shape), device=self.output_device, dtype=self.vae_output_dtype())
|
||||
preallocated = True
|
||||
with model_management.cuda_device_context(self.device):
|
||||
try:
|
||||
memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
|
||||
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
|
||||
free_memory = self.patcher.get_free_memory(self.device)
|
||||
batch_number = int(free_memory / memory_used)
|
||||
batch_number = max(1, batch_number)
|
||||
|
||||
for x in range(0, samples_in.shape[0], batch_number):
|
||||
samples = samples_in[x:x + batch_number].to(device=self.device, dtype=self.vae_dtype)
|
||||
if preallocated:
|
||||
self.first_stage_model.decode(samples, output_buffer=pixel_samples[x:x+batch_number], **vae_options)
|
||||
else:
|
||||
out = self.first_stage_model.decode(samples, **vae_options).to(device=self.output_device, dtype=self.vae_output_dtype(), copy=True)
|
||||
if pixel_samples is None:
|
||||
pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
|
||||
pixel_samples[x:x+batch_number].copy_(out)
|
||||
del out
|
||||
self.process_output(pixel_samples[x:x+batch_number])
|
||||
except Exception as e:
|
||||
model_management.raise_non_oom(e)
|
||||
logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
|
||||
#NOTE: We don't know what tensors were allocated to stack variables at the time of the
|
||||
#exception and the exception itself refs them all until we get out of this except block.
|
||||
#So we just set a flag for tiler fallback so that tensor gc can happen once the
|
||||
#exception is fully off the books.
|
||||
do_tile = True
|
||||
# Pre-allocate output for VAEs that support direct buffer writes
|
||||
preallocated = False
|
||||
if getattr(self.first_stage_model, 'comfy_has_chunked_io', False):
|
||||
pixel_samples = torch.empty(self.first_stage_model.decode_output_shape(samples_in.shape), device=self.output_device, dtype=self.vae_output_dtype())
|
||||
preallocated = True
|
||||
|
||||
if do_tile:
|
||||
comfy.model_management.soft_empty_cache()
|
||||
dims = samples_in.ndim - 2
|
||||
if dims == 1 or self.extra_1d_channel is not None:
|
||||
pixel_samples = self.decode_tiled_1d(samples_in)
|
||||
elif dims == 2:
|
||||
pixel_samples = self.decode_tiled_(samples_in)
|
||||
elif dims == 3:
|
||||
tile = 256 // self.spacial_compression_decode()
|
||||
overlap = tile // 4
|
||||
pixel_samples = self.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
|
||||
for x in range(0, samples_in.shape[0], batch_number):
|
||||
samples = samples_in[x:x + batch_number].to(device=self.device, dtype=self.vae_dtype)
|
||||
if preallocated:
|
||||
self.first_stage_model.decode(samples, output_buffer=pixel_samples[x:x+batch_number], **vae_options)
|
||||
else:
|
||||
out = self.first_stage_model.decode(samples, **vae_options).to(device=self.output_device, dtype=self.vae_output_dtype(), copy=True)
|
||||
if pixel_samples is None:
|
||||
pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
|
||||
pixel_samples[x:x+batch_number].copy_(out)
|
||||
del out
|
||||
self.process_output(pixel_samples[x:x+batch_number])
|
||||
except Exception as e:
|
||||
model_management.raise_non_oom(e)
|
||||
logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
|
||||
#NOTE: We don't know what tensors were allocated to stack variables at the time of the
|
||||
#exception and the exception itself refs them all until we get out of this except block.
|
||||
#So we just set a flag for tiler fallback so that tensor gc can happen once the
|
||||
#exception is fully off the books.
|
||||
do_tile = True
|
||||
|
||||
if do_tile:
|
||||
comfy.model_management.soft_empty_cache()
|
||||
dims = samples_in.ndim - 2
|
||||
if dims == 1 or self.extra_1d_channel is not None:
|
||||
pixel_samples = self.decode_tiled_1d(samples_in)
|
||||
elif dims == 2:
|
||||
pixel_samples = self.decode_tiled_(samples_in)
|
||||
elif dims == 3:
|
||||
tile = 256 // self.spacial_compression_decode()
|
||||
overlap = tile // 4
|
||||
pixel_samples = self.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
|
||||
|
||||
pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
|
||||
return pixel_samples
|
||||
@ -1087,20 +1099,21 @@ class VAE:
|
||||
if overlap is not None:
|
||||
args["overlap"] = overlap
|
||||
|
||||
if dims == 1 or self.extra_1d_channel is not None:
|
||||
args.pop("tile_y")
|
||||
output = self.decode_tiled_1d(samples, **args)
|
||||
elif dims == 2:
|
||||
output = self.decode_tiled_(samples, **args)
|
||||
elif dims == 3:
|
||||
if overlap_t is None:
|
||||
args["overlap"] = (1, overlap, overlap)
|
||||
else:
|
||||
args["overlap"] = (max(1, overlap_t), overlap, overlap)
|
||||
if tile_t is not None:
|
||||
args["tile_t"] = max(2, tile_t)
|
||||
with model_management.cuda_device_context(self.device):
|
||||
if dims == 1 or self.extra_1d_channel is not None:
|
||||
args.pop("tile_y")
|
||||
output = self.decode_tiled_1d(samples, **args)
|
||||
elif dims == 2:
|
||||
output = self.decode_tiled_(samples, **args)
|
||||
elif dims == 3:
|
||||
if overlap_t is None:
|
||||
args["overlap"] = (1, overlap, overlap)
|
||||
else:
|
||||
args["overlap"] = (max(1, overlap_t), overlap, overlap)
|
||||
if tile_t is not None:
|
||||
args["tile_t"] = max(2, tile_t)
|
||||
|
||||
output = self.decode_tiled_3d(samples, **args)
|
||||
output = self.decode_tiled_3d(samples, **args)
|
||||
return output.movedim(1, -1)
|
||||
|
||||
def encode(self, pixel_samples):
|
||||
@ -1113,44 +1126,46 @@ class VAE:
|
||||
pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
|
||||
else:
|
||||
pixel_samples = pixel_samples.unsqueeze(2)
|
||||
try:
|
||||
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
|
||||
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
|
||||
free_memory = self.patcher.get_free_memory(self.device)
|
||||
batch_number = int(free_memory / max(1, memory_used))
|
||||
batch_number = max(1, batch_number)
|
||||
samples = None
|
||||
for x in range(0, pixel_samples.shape[0], batch_number):
|
||||
pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype)
|
||||
if getattr(self.first_stage_model, 'comfy_has_chunked_io', False):
|
||||
out = self.first_stage_model.encode(pixels_in, device=self.device)
|
||||
|
||||
with model_management.cuda_device_context(self.device):
|
||||
try:
|
||||
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
|
||||
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
|
||||
free_memory = self.patcher.get_free_memory(self.device)
|
||||
batch_number = int(free_memory / max(1, memory_used))
|
||||
batch_number = max(1, batch_number)
|
||||
samples = None
|
||||
for x in range(0, pixel_samples.shape[0], batch_number):
|
||||
pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype)
|
||||
if getattr(self.first_stage_model, 'comfy_has_chunked_io', False):
|
||||
out = self.first_stage_model.encode(pixels_in, device=self.device)
|
||||
else:
|
||||
pixels_in = pixels_in.to(self.device)
|
||||
out = self.first_stage_model.encode(pixels_in)
|
||||
out = out.to(self.output_device).to(dtype=self.vae_output_dtype())
|
||||
if samples is None:
|
||||
samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
|
||||
samples[x:x + batch_number] = out
|
||||
|
||||
except Exception as e:
|
||||
model_management.raise_non_oom(e)
|
||||
logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
|
||||
#NOTE: We don't know what tensors were allocated to stack variables at the time of the
|
||||
#exception and the exception itself refs them all until we get out of this except block.
|
||||
#So we just set a flag for tiler fallback so that tensor gc can happen once the
|
||||
#exception is fully off the books.
|
||||
do_tile = True
|
||||
|
||||
if do_tile:
|
||||
comfy.model_management.soft_empty_cache()
|
||||
if self.latent_dim == 3:
|
||||
tile = 256
|
||||
overlap = tile // 4
|
||||
samples = self.encode_tiled_3d(pixel_samples, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
|
||||
elif self.latent_dim == 1 or self.extra_1d_channel is not None:
|
||||
samples = self.encode_tiled_1d(pixel_samples)
|
||||
else:
|
||||
pixels_in = pixels_in.to(self.device)
|
||||
out = self.first_stage_model.encode(pixels_in)
|
||||
out = out.to(self.output_device).to(dtype=self.vae_output_dtype())
|
||||
if samples is None:
|
||||
samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
|
||||
samples[x:x + batch_number] = out
|
||||
|
||||
except Exception as e:
|
||||
model_management.raise_non_oom(e)
|
||||
logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
|
||||
#NOTE: We don't know what tensors were allocated to stack variables at the time of the
|
||||
#exception and the exception itself refs them all until we get out of this except block.
|
||||
#So we just set a flag for tiler fallback so that tensor gc can happen once the
|
||||
#exception is fully off the books.
|
||||
do_tile = True
|
||||
|
||||
if do_tile:
|
||||
comfy.model_management.soft_empty_cache()
|
||||
if self.latent_dim == 3:
|
||||
tile = 256
|
||||
overlap = tile // 4
|
||||
samples = self.encode_tiled_3d(pixel_samples, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
|
||||
elif self.latent_dim == 1 or self.extra_1d_channel is not None:
|
||||
samples = self.encode_tiled_1d(pixel_samples)
|
||||
else:
|
||||
samples = self.encode_tiled_(pixel_samples)
|
||||
samples = self.encode_tiled_(pixel_samples)
|
||||
|
||||
return samples
|
||||
|
||||
@ -1176,26 +1191,27 @@ class VAE:
|
||||
if overlap is not None:
|
||||
args["overlap"] = overlap
|
||||
|
||||
if dims == 1:
|
||||
args.pop("tile_y")
|
||||
samples = self.encode_tiled_1d(pixel_samples, **args)
|
||||
elif dims == 2:
|
||||
samples = self.encode_tiled_(pixel_samples, **args)
|
||||
elif dims == 3:
|
||||
if tile_t is not None:
|
||||
tile_t_latent = max(2, self.downscale_ratio[0](tile_t))
|
||||
else:
|
||||
tile_t_latent = 9999
|
||||
args["tile_t"] = self.upscale_ratio[0](tile_t_latent)
|
||||
with model_management.cuda_device_context(self.device):
|
||||
if dims == 1:
|
||||
args.pop("tile_y")
|
||||
samples = self.encode_tiled_1d(pixel_samples, **args)
|
||||
elif dims == 2:
|
||||
samples = self.encode_tiled_(pixel_samples, **args)
|
||||
elif dims == 3:
|
||||
if tile_t is not None:
|
||||
tile_t_latent = max(2, self.downscale_ratio[0](tile_t))
|
||||
else:
|
||||
tile_t_latent = 9999
|
||||
args["tile_t"] = self.upscale_ratio[0](tile_t_latent)
|
||||
|
||||
if overlap_t is None:
|
||||
args["overlap"] = (1, overlap, overlap)
|
||||
else:
|
||||
args["overlap"] = (self.upscale_ratio[0](max(1, min(tile_t_latent // 2, self.downscale_ratio[0](overlap_t)))), overlap, overlap)
|
||||
maximum = pixel_samples.shape[2]
|
||||
maximum = self.upscale_ratio[0](self.downscale_ratio[0](maximum))
|
||||
if overlap_t is None:
|
||||
args["overlap"] = (1, overlap, overlap)
|
||||
else:
|
||||
args["overlap"] = (self.upscale_ratio[0](max(1, min(tile_t_latent // 2, self.downscale_ratio[0](overlap_t)))), overlap, overlap)
|
||||
maximum = pixel_samples.shape[2]
|
||||
maximum = self.upscale_ratio[0](self.downscale_ratio[0](maximum))
|
||||
|
||||
samples = self.encode_tiled_3d(pixel_samples[:,:,:maximum], **args)
|
||||
samples = self.encode_tiled_3d(pixel_samples[:,:,:maximum], **args)
|
||||
|
||||
return samples
|
||||
|
||||
@ -1269,6 +1285,8 @@ class CLIPType(Enum):
|
||||
FLUX2 = 25
|
||||
LONGCAT_IMAGE = 26
|
||||
COGVIDEOX = 27
|
||||
LENS = 28
|
||||
PIXELDIT = 29
|
||||
|
||||
|
||||
|
||||
@ -1321,6 +1339,7 @@ class TEModel(Enum):
|
||||
GEMMA_4_E2B = 30
|
||||
GEMMA_4_31B = 31
|
||||
T5_GEMMA = 32
|
||||
GPT_OSS_20B = 33
|
||||
|
||||
|
||||
def detect_te_model(sd):
|
||||
@ -1362,6 +1381,9 @@ def detect_te_model(sd):
|
||||
else:
|
||||
return TEModel.GEMMA_3_4B
|
||||
return TEModel.GEMMA_2_2B
|
||||
# Must precede the Qwen2.5-7B k_proj.bias=512 check (GPT-OSS also has 8*64=512).
|
||||
if "layers.0.self_attn.sinks" in sd and "layers.0.mlp.experts.gate_up_proj.weight" in sd:
|
||||
return TEModel.GPT_OSS_20B
|
||||
if 'model.layers.0.self_attn.k_proj.bias' in sd:
|
||||
weight = sd['model.layers.0.self_attn.k_proj.bias']
|
||||
if weight.shape[0] == 256:
|
||||
@ -1508,8 +1530,12 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
||||
clip_target.tokenizer = variant.tokenizer
|
||||
tokenizer_data["tokenizer_json"] = clip_data[0].get("tokenizer_json", None)
|
||||
elif te_model == TEModel.GEMMA_2_2B:
|
||||
clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.lumina2.LuminaTokenizer
|
||||
if clip_type == CLIPType.PIXELDIT:
|
||||
clip_target.clip = comfy.text_encoders.pixeldit.pixeldit_te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.pixeldit.PixelDiTGemma2Tokenizer
|
||||
else:
|
||||
clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.lumina2.LuminaTokenizer
|
||||
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
|
||||
elif te_model == TEModel.GEMMA_3_4B:
|
||||
clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data), model_type="gemma3_4b")
|
||||
@ -1544,6 +1570,10 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
||||
clip_target.clip = comfy.text_encoders.flux.flux2_te(**llama_detect(clip_data), pruned=te_model == TEModel.MISTRAL3_24B_PRUNED_FLUX2)
|
||||
clip_target.tokenizer = comfy.text_encoders.flux.Flux2Tokenizer
|
||||
tokenizer_data["tekken_model"] = clip_data[0].get("tekken_model", None)
|
||||
elif te_model == TEModel.GPT_OSS_20B:
|
||||
clip_target.clip = comfy.text_encoders.gpt_oss.lens_te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.gpt_oss.LensTokenizer
|
||||
tokenizer_data["tokenizer_json"] = clip_data[0].get("tokenizer_json", None)
|
||||
elif te_model == TEModel.QWEN3_4B:
|
||||
if clip_type == CLIPType.FLUX or clip_type == CLIPType.FLUX2:
|
||||
clip_target.clip = comfy.text_encoders.flux.klein_te(**llama_detect(clip_data), model_type="qwen3_4b")
|
||||
@ -1710,12 +1740,52 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
||||
out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata, disable_dynamic=disable_dynamic)
|
||||
if out is None:
|
||||
raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(ckpt_path, model_detection_error_hint(ckpt_path, sd)))
|
||||
if output_model and out[0] is not None:
|
||||
out[0].cached_patcher_init = (load_checkpoint_guess_config_model_only, (ckpt_path, embedding_directory, model_options, te_model_options))
|
||||
if output_clip and out[1] is not None:
|
||||
out[1].patcher.cached_patcher_init = (load_checkpoint_guess_config_clip_only, (ckpt_path, embedding_directory, model_options, te_model_options))
|
||||
if out[0] is not None:
|
||||
out[0].cached_patcher_init = (load_checkpoint_guess_config, (ckpt_path, False, False, False, embedding_directory, output_model, model_options, te_model_options), 0)
|
||||
# Register reload factories for the CLIP and VAE produced by the same checkpoint so
|
||||
# ModelPatcher.deepclone_multigpu can spawn per-device copies (Select{CLIP,VAE}Device,
|
||||
# MultiGPU work-units, etc.) without falling back to copy.deepcopy of an
|
||||
# already-loaded module.
|
||||
if out[1] is not None and getattr(out[1], "patcher", None) is not None:
|
||||
out[1].patcher.cached_patcher_init = (load_checkpoint_clip_patcher, (ckpt_path, embedding_directory, model_options, te_model_options))
|
||||
if out[2] is not None and getattr(out[2], "patcher", None) is not None:
|
||||
out[2].patcher.cached_patcher_init = (load_checkpoint_vae_patcher, (ckpt_path, embedding_directory, model_options, te_model_options))
|
||||
return out
|
||||
|
||||
|
||||
def load_checkpoint_clip_patcher(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False):
|
||||
"""Reload only the CLIP patcher from a checkpoint. Used as the cached_patcher_init
|
||||
factory for the CLIP returned by load_checkpoint_guess_config."""
|
||||
_, clip, _, _ = load_checkpoint_guess_config(
|
||||
ckpt_path,
|
||||
output_vae=False,
|
||||
output_clip=True,
|
||||
output_clipvision=False,
|
||||
embedding_directory=embedding_directory,
|
||||
output_model=False,
|
||||
model_options=model_options,
|
||||
te_model_options=te_model_options,
|
||||
disable_dynamic=disable_dynamic,
|
||||
)
|
||||
return clip.patcher
|
||||
|
||||
|
||||
def load_checkpoint_vae_patcher(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False):
|
||||
"""Reload only the VAE patcher from a checkpoint. Used as the cached_patcher_init
|
||||
factory for the VAE returned by load_checkpoint_guess_config."""
|
||||
_, _, vae, _ = load_checkpoint_guess_config(
|
||||
ckpt_path,
|
||||
output_vae=True,
|
||||
output_clip=False,
|
||||
output_clipvision=False,
|
||||
embedding_directory=embedding_directory,
|
||||
output_model=False,
|
||||
model_options=model_options,
|
||||
te_model_options=te_model_options,
|
||||
disable_dynamic=disable_dynamic,
|
||||
)
|
||||
return vae.patcher
|
||||
|
||||
def load_checkpoint_guess_config_model_only(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False):
|
||||
model, *_ = load_checkpoint_guess_config(ckpt_path, False, False, False,
|
||||
embedding_directory=embedding_directory,
|
||||
@ -1742,7 +1812,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
|
||||
diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)
|
||||
parameters = comfy.utils.calculate_parameters(sd, diffusion_model_prefix)
|
||||
weight_dtype = comfy.utils.weight_dtype(sd, diffusion_model_prefix)
|
||||
load_device = model_management.get_torch_device()
|
||||
load_device = model_options.get("load_device", model_management.get_torch_device())
|
||||
|
||||
custom_operations = model_options.get("custom_operations", None)
|
||||
if custom_operations is None:
|
||||
@ -1782,13 +1852,15 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
|
||||
inital_load_device = model_management.unet_inital_load_device(parameters, unet_dtype)
|
||||
model = model_config.get_model(sd, diffusion_model_prefix, device=inital_load_device)
|
||||
ModelPatcher = comfy.model_patcher.ModelPatcher if disable_dynamic else comfy.model_patcher.CoreModelPatcher
|
||||
model_patcher = ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device())
|
||||
offload_device = model_options.get("offload_device", model_management.unet_offload_device())
|
||||
model_patcher = ModelPatcher(model, load_device=load_device, offload_device=offload_device)
|
||||
model.load_model_weights(sd, diffusion_model_prefix, assign=model_patcher.is_dynamic())
|
||||
|
||||
if output_vae:
|
||||
vae_sd = comfy.utils.state_dict_prefix_replace(sd, {k: "" for k in model_config.vae_key_prefix}, filter_keys=True)
|
||||
vae_sd = model_config.process_vae_state_dict(vae_sd)
|
||||
vae = VAE(sd=vae_sd, metadata=metadata)
|
||||
vae_device = model_options.get("load_device", None)
|
||||
vae = VAE(sd=vae_sd, metadata=metadata, device=vae_device)
|
||||
|
||||
if output_clip:
|
||||
if te_model_options.get("custom_operations", None) is None:
|
||||
@ -1872,7 +1944,7 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None, disable
|
||||
parameters = comfy.utils.calculate_parameters(sd)
|
||||
weight_dtype = comfy.utils.weight_dtype(sd)
|
||||
|
||||
load_device = model_management.get_torch_device()
|
||||
load_device = model_options.get("load_device", model_management.get_torch_device())
|
||||
model_config = model_detection.model_config_from_unet(sd, "", metadata=metadata)
|
||||
|
||||
if model_config is not None:
|
||||
@ -1897,7 +1969,7 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None, disable
|
||||
else:
|
||||
logging.warning("{} {}".format(diffusers_keys[k], k))
|
||||
|
||||
offload_device = model_management.unet_offload_device()
|
||||
offload_device = model_options.get("offload_device", model_management.unet_offload_device())
|
||||
unet_weight_dtype = list(model_config.supported_inference_dtypes)
|
||||
if model_config.quant_config is not None:
|
||||
weight_dtype = None
|
||||
@ -1939,6 +2011,26 @@ def load_diffusion_model(unet_path, model_options={}, disable_dynamic=False):
|
||||
model.cached_patcher_init = (load_diffusion_model, (unet_path, model_options))
|
||||
return model
|
||||
|
||||
|
||||
def load_vae_patcher(vae_path, metadata=None, device=None, disable_dynamic=False):
|
||||
"""Reload a disk-backed VAE from ``vae_path`` and return its patcher.
|
||||
|
||||
Used as the ``cached_patcher_init`` factory on ``VAE.patcher`` so
|
||||
:meth:`comfy.model_patcher.ModelPatcher.deepclone_multigpu` can produce a
|
||||
fresh, untainted VAE patcher (no inherited per-device load state, no
|
||||
in-place quantization fallout) for multigpu work-units and the
|
||||
SelectVAEDevice node. The optional ``device`` matches the source loader's
|
||||
VAE initialization path; the deepclone's ``load_device`` still controls
|
||||
where the cloned patcher is targeted.
|
||||
"""
|
||||
if metadata is None:
|
||||
sd, metadata = comfy.utils.load_torch_file(vae_path, return_metadata=True)
|
||||
else:
|
||||
sd = comfy.utils.load_torch_file(vae_path)
|
||||
vae = VAE(sd=sd, metadata=metadata, device=device)
|
||||
vae.throw_exception_if_invalid()
|
||||
return vae.patcher
|
||||
|
||||
def load_unet(unet_path, dtype=None):
|
||||
logging.warning("The load_unet function has been deprecated and will be removed please switch to: load_diffusion_model")
|
||||
return load_diffusion_model(unet_path, model_options={"dtype": dtype})
|
||||
|
||||
@ -30,6 +30,7 @@ import comfy.text_encoders.longcat_image
|
||||
import comfy.text_encoders.ernie
|
||||
import comfy.text_encoders.cogvideo
|
||||
import comfy.text_encoders.hidream_o1
|
||||
import comfy.text_encoders.pixeldit
|
||||
|
||||
from . import supported_models_base
|
||||
from . import latent_formats
|
||||
@ -829,6 +830,50 @@ class Flux2(Flux):
|
||||
|
||||
return None
|
||||
|
||||
|
||||
class Lens(supported_models_base.BASE):
|
||||
"""Microsoft Lens (3.8B dual-stream MMDiT, GPT-OSS-20B text features, Flux2 VAE)."""
|
||||
|
||||
unet_config = {
|
||||
"image_model": "lens",
|
||||
}
|
||||
|
||||
sampling_settings = {
|
||||
"shift": 1.829, # Default mu for 1440x1440 (and any seq_len > 4300
|
||||
}
|
||||
|
||||
unet_extra_config = {}
|
||||
latent_format = latent_formats.Flux2
|
||||
|
||||
memory_usage_factor = 4.0
|
||||
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float32] # fp16 causes NaNs
|
||||
|
||||
vae_key_prefix = ["vae."]
|
||||
text_encoder_key_prefix = ["text_encoders."]
|
||||
|
||||
def __init__(self, unet_config):
|
||||
super().__init__(unet_config)
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
return model_base.Lens(self, model_type=model_base.ModelType.FLUX, device=device)
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
pref = self.text_encoder_key_prefix[0]
|
||||
for hint in ("gpt_oss.transformer.", ""):
|
||||
full_prefix = "{}{}".format(pref, hint)
|
||||
if "{}layers.0.self_attn.sinks".format(full_prefix) in state_dict:
|
||||
detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, full_prefix)
|
||||
return supported_models_base.ClipTarget(
|
||||
comfy.text_encoders.gpt_oss.LensTokenizer,
|
||||
comfy.text_encoders.gpt_oss.lens_te(**detect),
|
||||
)
|
||||
return supported_models_base.ClipTarget(
|
||||
comfy.text_encoders.gpt_oss.LensTokenizer,
|
||||
comfy.text_encoders.gpt_oss.lens_te(),
|
||||
)
|
||||
|
||||
|
||||
class GenmoMochi(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "mochi_preview",
|
||||
@ -1159,6 +1204,72 @@ class ZImagePixelSpace(ZImage):
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
return model_base.ZImagePixelSpace(self, device=device)
|
||||
|
||||
class PixelDiTT2I(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "pixeldit_t2i",
|
||||
}
|
||||
|
||||
unet_extra_config = {}
|
||||
|
||||
sampling_settings = {
|
||||
"shift": 4.0, # 1024px stage 3 default; 2.0 for 512px
|
||||
}
|
||||
|
||||
latent_format = latent_formats.PixelDiTPixel
|
||||
memory_usage_factor = 0.04
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float32]
|
||||
|
||||
vae_key_prefix = ["vae."]
|
||||
text_encoder_key_prefix = ["text_encoders."]
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
return model_base.PixelDiTT2I(self, device=device)
|
||||
|
||||
def process_unet_state_dict(self, state_dict):
|
||||
# pixel_dim from pixel_embedder.proj.weight = (pixel_dim, in_channels); p2 derived per-weight from total // (6 * pixel_dim).
|
||||
pixel_dim = next(v for k, v in state_dict.items() if k.endswith("pixel_embedder.proj.weight")).shape[0]
|
||||
|
||||
out = {}
|
||||
marker = ".adaLN_modulation.0."
|
||||
for k, v in state_dict.items():
|
||||
if k.startswith("_repa_projector") or k.startswith("net_ema."):
|
||||
continue
|
||||
if k.startswith("core."):
|
||||
k = k[len("core."):]
|
||||
elif k.startswith("net."):
|
||||
k = k[len("net."):]
|
||||
if "pixel_blocks." in k and marker in k:
|
||||
# Split into msa (chunks 0-2) and mlp (chunks 3-5) for the two-Linear PiTBlock to reduce peak VRAM
|
||||
p2 = v.shape[0] // (6 * pixel_dim)
|
||||
trail = v.shape[1:] # () for bias, (in_dim,) for weight
|
||||
vv = v.view(p2, 6, pixel_dim, *trail)
|
||||
base, suffix = k.split(marker)
|
||||
out[f"{base}.adaLN_modulation_msa.{suffix}"] = vv[:, 0:3].reshape(3 * p2 * pixel_dim, *trail).contiguous()
|
||||
out[f"{base}.adaLN_modulation_mlp.{suffix}"] = vv[:, 3:6].reshape(3 * p2 * pixel_dim, *trail).contiguous()
|
||||
else:
|
||||
out[k] = v
|
||||
return out
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
return supported_models_base.ClipTarget(
|
||||
comfy.text_encoders.pixeldit.PixelDiTGemma2Tokenizer,
|
||||
comfy.text_encoders.pixeldit.PixelDiTGemma2TE,
|
||||
)
|
||||
|
||||
class PiD(PixelDiTT2I):
|
||||
unet_config = {
|
||||
"image_model": "pid",
|
||||
}
|
||||
|
||||
sampling_settings = {
|
||||
"shift": 1.5, # close approximation of the original distill 4 steps [0.999, 0.866, 0.634, 0.342, 0]
|
||||
}
|
||||
|
||||
memory_usage_factor = 0.04
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
return model_base.PiD(self, device=device)
|
||||
|
||||
class WAN21_T2V(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "wan2.1",
|
||||
@ -2069,6 +2180,8 @@ models = [
|
||||
CosmosI2VPredict2,
|
||||
ZImagePixelSpace,
|
||||
ZImage,
|
||||
PiD,
|
||||
PixelDiTT2I,
|
||||
Lumina2,
|
||||
WAN22_T2V,
|
||||
WAN21_CausalAR_T2V,
|
||||
@ -2096,6 +2209,7 @@ models = [
|
||||
Omnigen2,
|
||||
QwenImage,
|
||||
Flux2,
|
||||
Lens,
|
||||
Kandinsky5Image,
|
||||
Kandinsky5,
|
||||
Anima,
|
||||
|
||||
600
comfy/text_encoders/gpt_oss.py
Normal file
600
comfy/text_encoders/gpt_oss.py
Normal file
@ -0,0 +1,600 @@
|
||||
"""GPT-OSS text encoder for Lens."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, List, Optional, Sequence
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import comfy.ops
|
||||
from comfy import sd1_clip
|
||||
from comfy.ldm.modules.attention import TORCH_HAS_GQA, optimized_attention_for_device
|
||||
from comfy.text_encoders.llama import RMSNorm, apply_rope
|
||||
|
||||
|
||||
@dataclass
|
||||
class GptOss20BConfig:
|
||||
vocab_size: int = 201088
|
||||
hidden_size: int = 2880
|
||||
intermediate_size: int = 2880
|
||||
num_hidden_layers: int = 24
|
||||
num_attention_heads: int = 64
|
||||
num_key_value_heads: int = 8
|
||||
head_dim: int = 64
|
||||
num_local_experts: int = 32
|
||||
num_experts_per_tok: int = 4
|
||||
sliding_window: int = 128
|
||||
original_max_position_embeddings: int = 4096
|
||||
rope_theta: float = 150000.0
|
||||
rope_factor: float = 32.0
|
||||
rope_beta_fast: float = 32.0
|
||||
rope_beta_slow: float = 1.0
|
||||
rope_truncate: bool = False
|
||||
rms_norm_eps: float = 1e-5
|
||||
attention_bias: bool = True
|
||||
layer_types: Optional[List[str]] = None
|
||||
moe_alpha: float = 1.702
|
||||
moe_limit: float = 7.0
|
||||
|
||||
def __post_init__(self):
|
||||
if self.layer_types is None:
|
||||
self.layer_types = [
|
||||
"sliding_attention" if (i + 1) % 2 else "full_attention"
|
||||
for i in range(self.num_hidden_layers)
|
||||
]
|
||||
|
||||
|
||||
def _yarn_inv_freq(head_dim: int, base: float, factor: float, beta_fast: float, beta_slow: float,
|
||||
original_max_position_embeddings: int, truncate: bool, device=None) -> tuple[torch.Tensor, float]:
|
||||
"""YARN inv_freq + attention scaling (matches transformers)."""
|
||||
dim = head_dim
|
||||
|
||||
def find_correction_dim(num_rotations: float) -> float:
|
||||
return (dim * math.log(original_max_position_embeddings / (num_rotations * 2 * math.pi))) / (
|
||||
2 * math.log(base)
|
||||
)
|
||||
|
||||
def find_correction_range() -> tuple[float, float]:
|
||||
low = find_correction_dim(beta_fast)
|
||||
high = find_correction_dim(beta_slow)
|
||||
if truncate:
|
||||
low = math.floor(low)
|
||||
high = math.ceil(high)
|
||||
return max(low, 0), min(high, dim - 1)
|
||||
|
||||
def linear_ramp_factor(min_: float, max_: float, n: int) -> torch.Tensor:
|
||||
if min_ == max_:
|
||||
max_ += 0.001
|
||||
linear = (torch.arange(n, dtype=torch.float32, device=device) - min_) / (max_ - min_)
|
||||
return torch.clamp(linear, 0, 1)
|
||||
|
||||
def get_mscale(scale: float) -> float:
|
||||
if scale <= 1:
|
||||
return 1.0
|
||||
return 0.1 * math.log(scale) + 1.0
|
||||
|
||||
attention_scaling = get_mscale(factor)
|
||||
|
||||
pos_freqs = base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)
|
||||
inv_freq_extrapolation = 1.0 / pos_freqs
|
||||
inv_freq_interpolation = 1.0 / (factor * pos_freqs)
|
||||
|
||||
low, high = find_correction_range()
|
||||
extrap_factor = 1 - linear_ramp_factor(low, high, dim // 2)
|
||||
inv_freq = inv_freq_interpolation * (1 - extrap_factor) + inv_freq_extrapolation * extrap_factor
|
||||
return inv_freq, attention_scaling
|
||||
|
||||
|
||||
def _build_freqs_cis(inv_freq: torch.Tensor, attention_scaling: float, position_ids: torch.Tensor, dtype: torch.dtype,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
inv_freq_e = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
|
||||
pos_e = position_ids[:, None, :].float()
|
||||
freqs = (inv_freq_e @ pos_e).transpose(1, 2)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = (emb.cos() * attention_scaling).to(dtype).unsqueeze(1)
|
||||
sin = (emb.sin() * attention_scaling).to(dtype).unsqueeze(1)
|
||||
sin_split = sin.shape[-1] // 2
|
||||
return cos, sin[..., :sin_split], -sin[..., sin_split:]
|
||||
|
||||
|
||||
def _attention_with_sinks(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, sinks: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor], num_heads: int, num_kv_groups: int) -> torch.Tensor:
|
||||
"""Attention with per-head sinks.
|
||||
|
||||
Sinks add a learned term to each row's softmax denominator but contribute
|
||||
nothing to the output. We fake this by appending one zero k/v position and
|
||||
putting the sink logit in the mask at that column.
|
||||
"""
|
||||
|
||||
if num_kv_groups > 1 and not TORCH_HAS_GQA:
|
||||
k = k.repeat_interleave(num_kv_groups, dim=1)
|
||||
v = v.repeat_interleave(num_kv_groups, dim=1)
|
||||
|
||||
B, _, S_q, D = q.shape
|
||||
H_kv = k.shape[1]
|
||||
S_kv = k.shape[-2]
|
||||
|
||||
k = torch.cat([k, k.new_zeros(B, H_kv, 1, D)], dim=-2)
|
||||
v = torch.cat([v, v.new_zeros(B, H_kv, 1, D)], dim=-2)
|
||||
|
||||
sinks_col = sinks.to(q.dtype).view(1, num_heads, 1, 1).expand(B, num_heads, S_q, 1)
|
||||
if attention_mask is not None:
|
||||
mask_left = attention_mask[..., :S_kv].expand(B, num_heads, S_q, S_kv)
|
||||
else:
|
||||
mask_left = q.new_zeros(B, num_heads, S_q, S_kv)
|
||||
mask = torch.cat([mask_left, sinks_col], dim=-1)
|
||||
|
||||
op = optimized_attention_for_device(q.device, mask=True, small_input=True)
|
||||
return op(q, k, v, num_heads, mask=mask, skip_reshape=True, enable_gqa=True)
|
||||
|
||||
|
||||
class GptOssAttention(nn.Module):
|
||||
def __init__(self, config: GptOss20BConfig, layer_idx: int, device=None, dtype=None, ops: Any = None):
|
||||
super().__init__()
|
||||
self.layer_idx = layer_idx
|
||||
self.layer_type = config.layer_types[layer_idx]
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.num_kv_heads = config.num_key_value_heads
|
||||
self.num_kv_groups = self.num_heads // self.num_kv_heads
|
||||
self.head_dim = config.head_dim
|
||||
self.hidden_size = config.hidden_size
|
||||
self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None
|
||||
|
||||
bias = config.attention_bias
|
||||
self.q_proj = ops.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=bias, device=device, dtype=dtype)
|
||||
self.k_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=bias, device=device, dtype=dtype)
|
||||
self.v_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=bias, device=device, dtype=dtype)
|
||||
self.o_proj = ops.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=bias, device=device, dtype=dtype)
|
||||
self.sinks = nn.Parameter(torch.empty(self.num_heads, device=device, dtype=dtype))
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor], freqs_cis) -> torch.Tensor:
|
||||
B, S, _ = hidden_states.shape
|
||||
|
||||
q = self.q_proj(hidden_states).view(B, S, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
k = self.k_proj(hidden_states).view(B, S, self.num_kv_heads, self.head_dim).transpose(1, 2)
|
||||
v = self.v_proj(hidden_states).view(B, S, self.num_kv_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
q, k = apply_rope(q, k, freqs_cis)
|
||||
|
||||
out = _attention_with_sinks(q, k, v, self.sinks, attention_mask, self.num_heads, self.num_kv_groups)
|
||||
return self.o_proj(out)
|
||||
|
||||
|
||||
# Mixture of Experts
|
||||
|
||||
class GptOssTopKRouter(nn.Module):
|
||||
def __init__(self, config: GptOss20BConfig, device=None, dtype=None):
|
||||
super().__init__()
|
||||
self.top_k = config.num_experts_per_tok
|
||||
self.num_experts = config.num_local_experts
|
||||
self.weight = nn.Parameter(torch.empty(config.num_local_experts, config.hidden_size, device=device, dtype=dtype))
|
||||
self.bias = nn.Parameter(torch.empty(config.num_local_experts, device=device, dtype=dtype))
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
weight = comfy.ops.cast_to_input(self.weight, hidden_states, copy=False)
|
||||
bias = comfy.ops.cast_to_input(self.bias, hidden_states, copy=False)
|
||||
logits = F.linear(hidden_states, weight, bias)
|
||||
top_vals, top_idx = torch.topk(logits, self.top_k, dim=-1)
|
||||
# Softmax over top-k slice only
|
||||
scores = F.softmax(top_vals, dim=-1, dtype=top_vals.dtype)
|
||||
return scores, top_idx
|
||||
|
||||
|
||||
class GptOssExperts(nn.Module):
|
||||
def __init__(self, config: GptOss20BConfig, device=None, dtype=None, ops: Any = None):
|
||||
super().__init__()
|
||||
self.num_experts = config.num_local_experts
|
||||
self.hidden_size = config.hidden_size
|
||||
self.intermediate_size = config.intermediate_size
|
||||
self.alpha = config.moe_alpha
|
||||
self.limit = config.moe_limit
|
||||
|
||||
E = self.num_experts
|
||||
H = self.hidden_size
|
||||
I = self.intermediate_size
|
||||
|
||||
self.gate_up_proj = ops.MoEExperts(num_experts=E, in_features=H, out_features=2 * I, bias=True, device=device, dtype=dtype)
|
||||
self.down_proj = ops.MoEExperts(num_experts=E, in_features=I, out_features=H, bias=True, device=device, dtype=dtype)
|
||||
|
||||
def _apply_gate(self, gate_up: torch.Tensor) -> torch.Tensor:
|
||||
gate = gate_up[..., ::2]
|
||||
up = gate_up[..., 1::2]
|
||||
gate = gate.clamp(max=self.limit)
|
||||
up = up.clamp(min=-self.limit, max=self.limit)
|
||||
glu = gate * torch.sigmoid(gate * self.alpha)
|
||||
return torch.addcmul(glu, up, glu)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, router_indices: torch.Tensor, routing_weights: torch.Tensor) -> torch.Tensor:
|
||||
N = hidden_states.shape[0]
|
||||
top_k = router_indices.shape[-1]
|
||||
H = hidden_states.shape[-1]
|
||||
|
||||
per_pair = torch.zeros((N * top_k, H), dtype=hidden_states.dtype, device=hidden_states.device)
|
||||
|
||||
expert_mask = F.one_hot(router_indices, num_classes=self.num_experts).permute(2, 1, 0)
|
||||
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
|
||||
|
||||
with self.gate_up_proj.bank_resident(hidden_states) as gate_up_bank, \
|
||||
self.down_proj.bank_resident(hidden_states) as down_bank:
|
||||
for ei in expert_hit:
|
||||
expert_idx = int(ei.item())
|
||||
top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
|
||||
current = hidden_states[token_idx]
|
||||
|
||||
gate_up = gate_up_bank.expert_linear(current, expert_idx)
|
||||
gated = self._apply_gate(gate_up)
|
||||
expert_out = down_bank.expert_linear(gated, expert_idx)
|
||||
|
||||
weighted = expert_out * routing_weights[token_idx, top_k_pos, None]
|
||||
|
||||
flat_idx = token_idx * top_k + top_k_pos
|
||||
per_pair[flat_idx] = weighted.to(per_pair.dtype)
|
||||
|
||||
return per_pair.view(N, top_k, H).sum(dim=1)
|
||||
|
||||
|
||||
class GptOssMLP(nn.Module):
|
||||
def __init__(self, config: GptOss20BConfig, device=None, dtype=None, ops: Any = None):
|
||||
super().__init__()
|
||||
self.router = GptOssTopKRouter(config, device=device, dtype=dtype)
|
||||
self.experts = GptOssExperts(config, device=device, dtype=dtype, ops=ops)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
B, S, H = hidden_states.shape
|
||||
flat = hidden_states.reshape(-1, H)
|
||||
scores, idx = self.router(flat)
|
||||
out = self.experts(flat, idx, scores)
|
||||
return out.reshape(B, S, H)
|
||||
|
||||
|
||||
# Decoder layer + model
|
||||
|
||||
class GptOssDecoderLayer(nn.Module):
|
||||
def __init__(self, config: GptOss20BConfig, layer_idx: int, device=None, dtype=None, ops: Any = None):
|
||||
super().__init__()
|
||||
self.self_attn = GptOssAttention(config, layer_idx, device=device, dtype=dtype, ops=ops)
|
||||
self.mlp = GptOssMLP(config, device=device, dtype=dtype, ops=ops)
|
||||
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype)
|
||||
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype)
|
||||
self.layer_type = config.layer_types[layer_idx]
|
||||
|
||||
def forward(self, x: torch.Tensor, attention_masks: dict[str, Optional[torch.Tensor]], freqs_cis) -> torch.Tensor:
|
||||
residual = x
|
||||
x = self.input_layernorm(x)
|
||||
x = self.self_attn(x, attention_masks[self.layer_type], freqs_cis)
|
||||
x = residual + x
|
||||
|
||||
residual = x
|
||||
x = self.post_attention_layernorm(x)
|
||||
x = self.mlp(x)
|
||||
x = residual + x
|
||||
return x
|
||||
|
||||
|
||||
def _make_full_causal_mask(B: int, S: int, key_padding_mask: Optional[torch.Tensor], dtype, device):
|
||||
neg = torch.finfo(dtype).min
|
||||
mask = torch.full((S, S), neg, dtype=dtype, device=device).triu_(1)
|
||||
mask = mask.unsqueeze(0).unsqueeze(0).expand(B, 1, S, S).contiguous()
|
||||
if key_padding_mask is not None:
|
||||
kp = key_padding_mask.to(dtype=dtype)
|
||||
kp = (1.0 - kp).reshape(B, 1, 1, S) * neg
|
||||
mask = mask + kp
|
||||
return mask
|
||||
|
||||
|
||||
def _make_sliding_causal_mask(B: int, S: int, window: int, key_padding_mask: Optional[torch.Tensor], dtype, device):
|
||||
neg = torch.finfo(dtype).min
|
||||
i = torch.arange(S, device=device).view(-1, 1)
|
||||
j = torch.arange(S, device=device).view(1, -1)
|
||||
keep = (j <= i) & (j > i - window)
|
||||
mask = torch.where(keep, torch.zeros((), dtype=dtype, device=device), torch.full((), neg, dtype=dtype, device=device))
|
||||
mask = mask.unsqueeze(0).unsqueeze(0).expand(B, 1, S, S).contiguous()
|
||||
if key_padding_mask is not None:
|
||||
kp = key_padding_mask.to(dtype=dtype)
|
||||
kp = (1.0 - kp).reshape(B, 1, 1, S) * neg
|
||||
mask = mask + kp
|
||||
return mask
|
||||
|
||||
|
||||
class GptOssModel(nn.Module):
|
||||
"""GPT-OSS decoder with multi-layer hidden-state capture + early exit."""
|
||||
|
||||
def __init__(self, config: GptOss20BConfig, device=None, dtype=None, ops: Any = None):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.dtype = dtype
|
||||
self.embed_tokens = ops.Embedding(config.vocab_size, config.hidden_size, device=device, dtype=dtype)
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
GptOssDecoderLayer(config, i, device=device, dtype=dtype, ops=ops)
|
||||
for i in range(config.num_hidden_layers)
|
||||
]
|
||||
)
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype)
|
||||
|
||||
# Always build on CPU so the buffer survives meta-device construction.
|
||||
inv_freq, attn_scaling = _yarn_inv_freq(
|
||||
head_dim=config.head_dim,
|
||||
base=config.rope_theta,
|
||||
factor=config.rope_factor,
|
||||
beta_fast=config.rope_beta_fast,
|
||||
beta_slow=config.rope_beta_slow,
|
||||
original_max_position_embeddings=config.original_max_position_embeddings,
|
||||
truncate=config.rope_truncate,
|
||||
device=torch.device("cpu"),
|
||||
)
|
||||
self.register_buffer("rope_inv_freq", inv_freq, persistent=False)
|
||||
self.rope_attention_scaling = float(attn_scaling)
|
||||
|
||||
@property
|
||||
def num_layers(self) -> int:
|
||||
return self.config.num_hidden_layers
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.embed_tokens
|
||||
|
||||
def _build_attention_masks(self, B: int, S: int, attention_mask: Optional[torch.Tensor], dtype: torch.dtype, device,
|
||||
) -> dict[str, torch.Tensor]:
|
||||
full = _make_full_causal_mask(B, S, attention_mask, dtype, device)
|
||||
masks = {"full_attention": full}
|
||||
if any(t == "sliding_attention" for t in self.config.layer_types):
|
||||
masks["sliding_attention"] = _make_sliding_causal_mask(
|
||||
B, S, self.config.sliding_window, attention_mask, dtype, device
|
||||
)
|
||||
return masks
|
||||
|
||||
def forward(self, input_ids: torch.LongTensor, attention_mask: Optional[torch.Tensor] = None,
|
||||
capture_layers: Optional[Sequence[int]] = None) -> dict[str, Any]:
|
||||
B, S = input_ids.shape
|
||||
device = input_ids.device
|
||||
dtype = self.dtype
|
||||
|
||||
hidden_states = self.embed_tokens(input_ids, out_dtype=dtype)
|
||||
|
||||
position_ids = torch.arange(S, device=device).unsqueeze(0).expand(B, -1)
|
||||
freqs_cis = _build_freqs_cis(self.rope_inv_freq.to(device=device), self.rope_attention_scaling, position_ids, dtype)
|
||||
|
||||
attn_masks = self._build_attention_masks(B, S, attention_mask, dtype, device)
|
||||
|
||||
capture_layers = list(capture_layers) if capture_layers else None
|
||||
if capture_layers:
|
||||
max_layer = max(capture_layers)
|
||||
wanted = {idx: pos for pos, idx in enumerate(capture_layers)}
|
||||
captured: List[Optional[torch.Tensor]] = [None] * len(capture_layers)
|
||||
else:
|
||||
max_layer = self.config.num_hidden_layers - 1
|
||||
wanted = None
|
||||
captured = None
|
||||
|
||||
for i, layer in enumerate(self.layers):
|
||||
hidden_states = layer(hidden_states, attn_masks, freqs_cis)
|
||||
if wanted is not None and i in wanted:
|
||||
captured[wanted[i]] = hidden_states
|
||||
if i >= max_layer:
|
||||
break
|
||||
|
||||
if captured is not None:
|
||||
return {"hidden_states": captured}
|
||||
return {"last_hidden_state": self.norm(hidden_states)}
|
||||
|
||||
|
||||
# Lens chat-template constants (verbatim from the reference pipeline).
|
||||
_LENS_CHAT_SYSTEM = (
|
||||
"Describe the image by detailing the color, shape, size, texture, "
|
||||
"quantity, text, spatial relationships of the objects and background."
|
||||
)
|
||||
_LENS_CHAT_ASSISTANT_THINKING = "Need to generate one image according to the description."
|
||||
LENS_TXT_OFFSET = 97
|
||||
LENS_SELECTED_LAYERS = (5, 11, 17, 23)
|
||||
LENS_MAX_TOKENS = 512
|
||||
|
||||
|
||||
# The reference GPT-OSS Harmony template injects today's date here
|
||||
_LENS_CHAT_DATE = "2026-05-23"
|
||||
|
||||
|
||||
def _lens_render_chat(prompt: str) -> str:
|
||||
"""Render the Lens prompt in GPT-OSS Harmony format."""
|
||||
return (
|
||||
f"<|start|>system<|message|>"
|
||||
f"You are ChatGPT, a large language model trained by OpenAI.\n"
|
||||
f"Knowledge cutoff: 2024-06\n"
|
||||
f"Current date: {_LENS_CHAT_DATE}\n\n"
|
||||
f"Reasoning: medium\n\n"
|
||||
f"# Valid channels: analysis, commentary, final. "
|
||||
f"Channel must be included for every message.<|end|>"
|
||||
f"<|start|>developer<|message|># Instructions\n\n"
|
||||
f"{_LENS_CHAT_SYSTEM}\n\n<|end|>"
|
||||
f"<|start|>user<|message|>{prompt}<|end|>"
|
||||
f"<|start|>assistant<|channel|>analysis<|message|>"
|
||||
f"{_LENS_CHAT_ASSISTANT_THINKING}<|end|>"
|
||||
f"<|start|>assistant<|channel|>final<|message|>"
|
||||
)
|
||||
|
||||
|
||||
# GPT-OSS-20B fixed token IDs (from the tokenizer's added-tokens table).
|
||||
_LENS_PAD_TOKEN_ID = 199999 # <|endoftext|>
|
||||
|
||||
|
||||
class _GptOssRawTokenizer:
|
||||
"""Raw ``tokenizers.Tokenizer`` wrapper.
|
||||
|
||||
The tokenizer JSON ships as a byte tensor inside the encoder checkpoint
|
||||
(``tokenizer_json`` key) rather than as a committed file. Extracted
|
||||
it in ``sd.py`` and passes it here via ``tokenizer_data``.
|
||||
"""
|
||||
|
||||
def __init__(self, tokenizer_json_bytes=None, **kwargs):
|
||||
from tokenizers import Tokenizer
|
||||
if isinstance(tokenizer_json_bytes, torch.Tensor):
|
||||
tokenizer_json_bytes = bytes(tokenizer_json_bytes.tolist())
|
||||
if tokenizer_json_bytes is None:
|
||||
raise ValueError(
|
||||
"Lens tokenizer requires the ``tokenizer_json`` byte tensor in the "
|
||||
"encoder state dict. Re-bundle the encoder via bundle_te.py so it "
|
||||
"embeds the tokenizer."
|
||||
)
|
||||
self.tokenizer = Tokenizer.from_str(tokenizer_json_bytes.decode("utf-8"))
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, tokenizer_data, **kwargs):
|
||||
return cls(tokenizer_json_bytes=tokenizer_data, **kwargs)
|
||||
|
||||
def __call__(self, text):
|
||||
return {"input_ids": self.tokenizer.encode(text, add_special_tokens=False).ids}
|
||||
|
||||
def get_vocab(self):
|
||||
return self.tokenizer.get_vocab()
|
||||
|
||||
def convert_tokens_to_ids(self, tokens):
|
||||
return [self.tokenizer.token_to_id(t) for t in tokens]
|
||||
|
||||
def decode(self, ids, **kwargs):
|
||||
return self.tokenizer.decode(ids, skip_special_tokens=kwargs.get("skip_special_tokens", False))
|
||||
|
||||
|
||||
class LensGptOssTokenizer(sd1_clip.SDTokenizer):
|
||||
tokenizer_json_data = None
|
||||
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
tokenizer_json = tokenizer_data.get("tokenizer_json", None)
|
||||
self.tokenizer_json_data = tokenizer_json
|
||||
super().__init__(
|
||||
tokenizer_json,
|
||||
embedding_directory=embedding_directory,
|
||||
pad_with_end=False,
|
||||
embedding_size=2880,
|
||||
embedding_key="gpt_oss",
|
||||
tokenizer_class=_GptOssRawTokenizer,
|
||||
has_start_token=False,
|
||||
has_end_token=False,
|
||||
pad_to_max_length=False,
|
||||
max_length=99999999,
|
||||
min_length=1,
|
||||
pad_left=False,
|
||||
disable_weights=True,
|
||||
tokenizer_data=tokenizer_data,
|
||||
)
|
||||
self.pad_token_id = _LENS_PAD_TOKEN_ID
|
||||
|
||||
def tokenize_with_weights(self, text: str, return_word_ids=False, **kwargs):
|
||||
# Empty prompt -> empty list; encode_token_weights returns zeros (uncond).
|
||||
if not text or not text.strip():
|
||||
return [[]]
|
||||
rendered = _lens_render_chat(text)
|
||||
ids = self.tokenizer(rendered)["input_ids"]
|
||||
if len(ids) > LENS_MAX_TOKENS:
|
||||
ids = ids[:LENS_MAX_TOKENS]
|
||||
return [[(int(t), 1.0) for t in ids]]
|
||||
|
||||
def state_dict(self):
|
||||
if self.tokenizer_json_data is not None:
|
||||
return {"tokenizer_json": self.tokenizer_json_data}
|
||||
return {}
|
||||
|
||||
|
||||
class LensTokenizer(sd1_clip.SD1Tokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(
|
||||
embedding_directory=embedding_directory,
|
||||
tokenizer_data=tokenizer_data,
|
||||
name="gpt_oss",
|
||||
tokenizer=LensGptOssTokenizer,
|
||||
)
|
||||
|
||||
|
||||
class LensGptOssClipModel(nn.Module):
|
||||
"""SDClipModel-shaped Lens GPT-OSS encoder (multi-layer feature extractor)."""
|
||||
|
||||
def __init__(self, device="cpu", dtype=None, model_options=None, **kwargs):
|
||||
super().__init__()
|
||||
model_options = dict(model_options or {})
|
||||
|
||||
operations = model_options.get("custom_operations")
|
||||
if operations is None:
|
||||
quant_config = model_options.get("quantization_metadata") or {}
|
||||
operations = comfy.ops.mixed_precision_ops(quant_config, dtype, full_precision_mm=True)
|
||||
self.operations = operations
|
||||
|
||||
cfg_overrides = model_options.get("gpt_oss_config", {})
|
||||
self.config = GptOss20BConfig(**cfg_overrides)
|
||||
self.selected_layers = tuple(model_options.get("selected_layers", LENS_SELECTED_LAYERS))
|
||||
self.txt_offset = int(model_options.get("txt_offset", LENS_TXT_OFFSET))
|
||||
|
||||
self.transformer = GptOssModel(self.config, device=device, dtype=dtype, ops=operations)
|
||||
self.num_layers = self.config.num_hidden_layers
|
||||
self.dtype = dtype
|
||||
self.execution_device = None
|
||||
self._pad_token_id = _LENS_PAD_TOKEN_ID
|
||||
|
||||
def set_clip_options(self, options):
|
||||
self.execution_device = options.get("execution_device", self.execution_device)
|
||||
|
||||
def reset_clip_options(self):
|
||||
self.execution_device = None
|
||||
|
||||
def _gather_tokens(self, token_weight_pairs):
|
||||
ids_list = [[int(t[0]) for t in batch] for batch in token_weight_pairs]
|
||||
pad_id = self._pad_token_id
|
||||
max_len = max(len(x) for x in ids_list)
|
||||
device = self.execution_device
|
||||
ids = torch.full((len(ids_list), max_len), pad_id, dtype=torch.long, device=device)
|
||||
mask = torch.zeros((len(ids_list), max_len), dtype=torch.long, device=device)
|
||||
for i, x in enumerate(ids_list):
|
||||
ids[i, : len(x)] = torch.tensor(x, dtype=torch.long, device=device)
|
||||
mask[i, : len(x)] = 1
|
||||
return ids, mask
|
||||
|
||||
def encode_token_weights(self, token_weight_pairs):
|
||||
# Empty negative: emit zero-length features + zero mask
|
||||
if all(len(batch) == 0 for batch in token_weight_pairs):
|
||||
device = self.execution_device
|
||||
B = len(token_weight_pairs)
|
||||
L = len(self.selected_layers)
|
||||
H = self.config.hidden_size
|
||||
flat = torch.zeros(B, 0, L * H, dtype=self.dtype, device=device)
|
||||
mask = torch.zeros(B, 0, dtype=torch.long, device=device)
|
||||
return flat, None, {"attention_mask": mask, "num_layers_stacked": L}
|
||||
|
||||
input_ids, attn_mask = self._gather_tokens(token_weight_pairs)
|
||||
out = self.transformer(input_ids, attention_mask=attn_mask, capture_layers=self.selected_layers)
|
||||
layers = out["hidden_states"] # list of L × [B, S, H]
|
||||
stacked = torch.stack(layers, dim=2) # [B, S, L, H]
|
||||
|
||||
offset = self.txt_offset
|
||||
if stacked.shape[1] > offset:
|
||||
stacked = stacked[:, offset:].contiguous()
|
||||
mask_trim = attn_mask[:, offset:]
|
||||
else:
|
||||
stacked = stacked[:, :0]
|
||||
mask_trim = attn_mask[:, :0]
|
||||
|
||||
B, S, L, H = stacked.shape
|
||||
flat = stacked.reshape(B, S, L * H)
|
||||
extra = {"attention_mask": mask_trim, "num_layers_stacked": L}
|
||||
return flat, None, extra
|
||||
|
||||
def load_sd(self, sd):
|
||||
return self.transformer.load_state_dict(sd, strict=False, assign=True)
|
||||
|
||||
|
||||
class LensTEModel(sd1_clip.SD1ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options=None):
|
||||
super().__init__(device=device, dtype=dtype, name="gpt_oss", clip_model=LensGptOssClipModel, model_options=model_options or {})
|
||||
|
||||
|
||||
def lens_te(dtype_llama=None, llama_quantization_metadata=None):
|
||||
class LensTEModel_(LensTEModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options=None):
|
||||
mo = dict(model_options or {})
|
||||
if llama_quantization_metadata is not None:
|
||||
mo["quantization_metadata"] = llama_quantization_metadata
|
||||
if dtype is None and dtype_llama is not None:
|
||||
dtype = dtype_llama
|
||||
super().__init__(device=device, dtype=dtype, model_options=mo)
|
||||
|
||||
return LensTEModel_
|
||||
104
comfy/text_encoders/pixeldit.py
Normal file
104
comfy/text_encoders/pixeldit.py
Normal file
@ -0,0 +1,104 @@
|
||||
import torch
|
||||
|
||||
from comfy import sd1_clip
|
||||
from .lumina2 import Gemma2BTokenizer, LuminaModel
|
||||
import comfy.text_encoders.llama
|
||||
|
||||
|
||||
class PixelDiTGemma2_2BModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
|
||||
llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
|
||||
if llama_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["quantization_metadata"] = llama_quantization_metadata
|
||||
|
||||
super().__init__(
|
||||
device=device, layer=layer, layer_idx=layer_idx,
|
||||
textmodel_json_config={}, dtype=dtype,
|
||||
special_tokens={"start": 2, "pad": 0},
|
||||
layer_norm_hidden_state=False,
|
||||
model_class=comfy.text_encoders.llama.Gemma2_2B,
|
||||
enable_attention_masks=attention_mask,
|
||||
return_attention_masks=attention_mask,
|
||||
model_options=model_options,
|
||||
)
|
||||
|
||||
|
||||
_PIXELDIT_CHI_PROMPT = (
|
||||
'Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions '
|
||||
"suitable for image generation. Evaluate the level of detail in the user prompt:\n"
|
||||
"- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, "
|
||||
"and spatial relationships to create vivid and concrete scenes.\n"
|
||||
"- If the prompt is already detailed, refine and enhance the existing details slightly without "
|
||||
"overcomplicating.\n"
|
||||
"Here are examples of how to transform or refine prompts:\n"
|
||||
"- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, "
|
||||
"sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.\n"
|
||||
"- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring "
|
||||
"glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus "
|
||||
"passing by towering glass skyscrapers.\n"
|
||||
"Please generate only the enhanced description for the prompt below and avoid including any "
|
||||
"additional commentary or evaluations:\n"
|
||||
"User Prompt: "
|
||||
)
|
||||
|
||||
_PIXELDIT_MAX_LENGTH = 300
|
||||
_PIXELDIT_CHI_PROMPT_DETECT_PREFIX = 'Given a user prompt, generate an "Enhanced prompt"'
|
||||
|
||||
|
||||
class PixelDiTGemma2Tokenizer(sd1_clip.SD1Tokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data=None):
|
||||
if tokenizer_data is None:
|
||||
tokenizer_data = {}
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data,
|
||||
name="gemma2_2b", tokenizer=Gemma2BTokenizer)
|
||||
|
||||
def tokenize_with_weights(self, text, return_word_ids=False, **kwargs):
|
||||
if not text.strip():
|
||||
return super().tokenize_with_weights("", return_word_ids=return_word_ids, disable_weights=True, min_length=_PIXELDIT_MAX_LENGTH)
|
||||
|
||||
chi_token_count = len(self.gemma2_2b.tokenizer(_PIXELDIT_CHI_PROMPT)["input_ids"])
|
||||
combined = text if text.startswith(_PIXELDIT_CHI_PROMPT_DETECT_PREFIX) else _PIXELDIT_CHI_PROMPT + text
|
||||
max_length_all = chi_token_count + _PIXELDIT_MAX_LENGTH - 2
|
||||
out = super().tokenize_with_weights(combined, return_word_ids=return_word_ids,
|
||||
disable_weights=True, min_length=max_length_all)
|
||||
out["gemma2_2b"] = [out["gemma2_2b"][0][:max_length_all]]
|
||||
return out
|
||||
|
||||
def untokenize(self, token_weight_pair):
|
||||
return self.gemma2_2b.untokenize(token_weight_pair)
|
||||
|
||||
def state_dict(self):
|
||||
return self.gemma2_2b.state_dict()
|
||||
|
||||
|
||||
class PixelDiTGemma2TE(LuminaModel):
|
||||
# PixelDiT's select_index: keep BOS + last 299 embeddings of the padded sequence.
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
super().__init__(device=device, dtype=dtype, name="gemma2_2b",
|
||||
clip_model=PixelDiTGemma2_2BModel, model_options=model_options)
|
||||
|
||||
def encode_token_weights(self, token_weight_pairs):
|
||||
result = super().encode_token_weights(token_weight_pairs)
|
||||
cond, pooled = result[0], result[1]
|
||||
extra = result[2] if len(result) > 2 else None
|
||||
if cond.shape[1] > _PIXELDIT_MAX_LENGTH:
|
||||
cond = torch.cat([cond[:, :1], cond[:, -(_PIXELDIT_MAX_LENGTH - 1):]], dim=1)
|
||||
if extra is not None and "attention_mask" in extra:
|
||||
am = extra["attention_mask"]
|
||||
extra["attention_mask"] = torch.cat([am[..., :1], am[..., -(_PIXELDIT_MAX_LENGTH - 1):]], dim=-1)
|
||||
if extra is not None:
|
||||
return cond, pooled, extra
|
||||
return cond, pooled
|
||||
|
||||
|
||||
def pixeldit_te(dtype_llama=None, llama_quantization_metadata=None):
|
||||
class PixelDiTTE_(PixelDiTGemma2TE):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if llama_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["llama_quantization_metadata"] = llama_quantization_metadata
|
||||
if dtype_llama is not None:
|
||||
dtype = dtype_llama
|
||||
super().__init__(device=device, dtype=dtype, model_options=model_options)
|
||||
return PixelDiTTE_
|
||||
@ -86,6 +86,7 @@ def load_safetensors(ckpt):
|
||||
import comfy_aimdo.model_mmap
|
||||
|
||||
f = open(ckpt, "rb", buffering=0)
|
||||
file_lock = threading.Lock()
|
||||
model_mmap = comfy_aimdo.model_mmap.ModelMMAP(ckpt)
|
||||
file_size = os.path.getsize(ckpt)
|
||||
mv = memoryview((ctypes.c_uint8 * file_size).from_address(model_mmap.get()))
|
||||
@ -111,7 +112,7 @@ def load_safetensors(ckpt):
|
||||
storage = tensor.untyped_storage()
|
||||
setattr(storage,
|
||||
"_comfy_tensor_file_slice",
|
||||
comfy.memory_management.TensorFileSlice(f, threading.get_ident(), data_base_offset + start, end - start))
|
||||
comfy.memory_management.TensorFileSlice(f, file_lock, data_base_offset + start, end - start))
|
||||
setattr(storage, "_comfy_tensor_mmap_refs", (model_mmap, mv))
|
||||
sd[name] = tensor
|
||||
|
||||
|
||||
@ -1,5 +1,3 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING
|
||||
from comfy_api.internal import ComfyAPIBase
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
from __future__ import annotations
|
||||
from av.container import InputContainer
|
||||
from av.subtitles.stream import SubtitleStream
|
||||
from fractions import Fraction
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
from __future__ import annotations
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from fractions import Fraction
|
||||
|
||||
1
comfy_api_nodes/apis/__init__.py
generated
1
comfy_api_nodes/apis/__init__.py
generated
@ -3,7 +3,6 @@
|
||||
# timestamp: 2025-07-30T08:54:00+00:00
|
||||
|
||||
# pylint: disable
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import date, datetime
|
||||
from enum import Enum
|
||||
|
||||
@ -1,5 +1,3 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
|
||||
@ -1,5 +1,3 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
from __future__ import annotations
|
||||
from typing import Type, Literal
|
||||
|
||||
import nodes
|
||||
|
||||
@ -1,5 +1,3 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TypedDict, Dict, Optional, Tuple
|
||||
from typing_extensions import override
|
||||
from PIL import Image
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
from __future__ import annotations
|
||||
from comfy_api.latest import IO
|
||||
|
||||
|
||||
|
||||
@ -2,7 +2,6 @@
|
||||
+ weighted Procrustes solver. Computes the 4x4 facial transformation matrix.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
import numpy as np
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
"""Pure-PyTorch port of MediaPipe's face_landmarker_v2_with_blendshapes.task:
|
||||
BlazeFace detector → FaceMesh v2 → ARKit-52 blendshapes."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from functools import lru_cache
|
||||
|
||||
@ -1,5 +1,3 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import av
|
||||
import torchaudio
|
||||
import torch
|
||||
|
||||
@ -57,24 +57,55 @@ class CFGNorm(io.ComfyNode):
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
io.Float.Input("strength", default=1.0, min=0.0, max=100.0, step=0.01),
|
||||
io.Boolean.Input(
|
||||
"pre_cfg",
|
||||
default=False,
|
||||
optional=True,
|
||||
tooltip=(
|
||||
"If true, rescale the combined noise BEFORE the sampler's CFG combine, "
|
||||
"without clamping (can amplify). Matches the norm-scaled CFG used by "
|
||||
"models like Lens. Default false keeps the original post-CFG x0-space "
|
||||
"attenuate-only behavior."
|
||||
),
|
||||
),
|
||||
],
|
||||
outputs=[io.Model.Output(display_name="patched_model")],
|
||||
is_experimental=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model, strength) -> io.NodeOutput:
|
||||
def execute(cls, model, strength, pre_cfg=False) -> io.NodeOutput:
|
||||
m = model.clone()
|
||||
def cfg_norm(args):
|
||||
cond_p = args['cond_denoised']
|
||||
pred_text_ = args["denoised"]
|
||||
if pre_cfg:
|
||||
def cfg_norm_pre(args):
|
||||
cond = args["cond"]
|
||||
uncond = args["uncond"]
|
||||
cond_scale = args["cond_scale"]
|
||||
comb = uncond + cond_scale * (cond - uncond)
|
||||
cond_norm = torch.linalg.vector_norm(cond, dim=1, keepdim=True)
|
||||
comb_norm = torch.linalg.vector_norm(comb, dim=1, keepdim=True)
|
||||
rescale = torch.where(
|
||||
comb_norm > 0,
|
||||
cond_norm / comb_norm.clamp_min(1e-12),
|
||||
torch.ones_like(comb_norm),
|
||||
)
|
||||
rescaled = comb * rescale
|
||||
# strength blends back toward standard linear CFG (1.0 = full rescale).
|
||||
if strength != 1.0:
|
||||
rescaled = strength * rescaled + (1.0 - strength) * comb
|
||||
return rescaled
|
||||
m.set_model_sampler_cfg_function(cfg_norm_pre)
|
||||
else:
|
||||
def cfg_norm(args):
|
||||
cond_p = args['cond_denoised']
|
||||
pred_text_ = args["denoised"]
|
||||
|
||||
norm_full_cond = torch.norm(cond_p, dim=1, keepdim=True)
|
||||
norm_pred_text = torch.norm(pred_text_, dim=1, keepdim=True)
|
||||
scale = (norm_full_cond / (norm_pred_text + 1e-8)).clamp(min=0.0, max=1.0)
|
||||
return pred_text_ * scale * strength
|
||||
norm_full_cond = torch.norm(cond_p, dim=1, keepdim=True)
|
||||
norm_pred_text = torch.norm(pred_text_, dim=1, keepdim=True)
|
||||
scale = (norm_full_cond / (norm_pred_text + 1e-8)).clamp(min=0.0, max=1.0)
|
||||
return pred_text_ * scale * strength
|
||||
|
||||
m.set_model_sampler_post_cfg_function(cfg_norm)
|
||||
m.set_model_sampler_post_cfg_function(cfg_norm)
|
||||
return io.NodeOutput(m)
|
||||
|
||||
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
from __future__ import annotations
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
import comfy.context_windows
|
||||
import nodes
|
||||
|
||||
@ -1,5 +1,3 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import numpy as np
|
||||
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
@ -1,5 +1,3 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import nodes
|
||||
import folder_paths
|
||||
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
from __future__ import annotations
|
||||
from typing import TypedDict
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
@ -226,10 +226,20 @@ def get_noise_mask(latent):
|
||||
noise_mask = noise_mask.clone()
|
||||
return noise_mask
|
||||
|
||||
def get_keyframe_idxs(cond):
|
||||
def get_keyframe_idxs(cond, latent_shape=None):
|
||||
keyframe_idxs = conditioning_get_any_value(cond, "keyframe_idxs", None)
|
||||
if keyframe_idxs is None:
|
||||
return None, 0
|
||||
# Get number of keyframes from latent_shape or guide_attention_entries if available
|
||||
if latent_shape is not None and len(latent_shape) == 5:
|
||||
tokens_per_frame = latent_shape[-2] * latent_shape[-1]
|
||||
num_keyframes = keyframe_idxs.shape[2] // tokens_per_frame
|
||||
return keyframe_idxs, num_keyframes
|
||||
entries = conditioning_get_any_value(cond, "guide_attention_entries", None)
|
||||
if entries:
|
||||
num_keyframes = sum(e["latent_shape"][0] for e in entries)
|
||||
return keyframe_idxs, num_keyframes
|
||||
# fallback, may under-count if keyframes share t-start
|
||||
# keyframe_idxs contains start/end positions (last dimension), checking for unqiue values only for start
|
||||
num_keyframes = torch.unique(keyframe_idxs[:, 0, :, 0]).shape[0]
|
||||
return keyframe_idxs, num_keyframes
|
||||
@ -322,9 +332,9 @@ class LTXVAddGuide(io.ComfyNode):
|
||||
return factor
|
||||
|
||||
@classmethod
|
||||
def get_latent_index(cls, cond, latent_length, guide_length, frame_idx, scale_factors):
|
||||
def get_latent_index(cls, cond, latent_length, guide_length, frame_idx, scale_factors, latent_shape=None):
|
||||
time_scale_factor, _, _ = scale_factors
|
||||
_, num_keyframes = get_keyframe_idxs(cond)
|
||||
_, num_keyframes = get_keyframe_idxs(cond, latent_shape)
|
||||
latent_count = latent_length - num_keyframes
|
||||
frame_idx = frame_idx if frame_idx >= 0 else max((latent_count - 1) * time_scale_factor + 1 + frame_idx, 0)
|
||||
if guide_length > 1 and frame_idx != 0:
|
||||
@ -436,7 +446,7 @@ class LTXVAddGuide(io.ComfyNode):
|
||||
num_frames_to_keep = ((image.shape[0] - 1) // time_scale_factor) * time_scale_factor + 1
|
||||
resolved_frame_idx = frame_idx
|
||||
if frame_idx < 0:
|
||||
_, num_keyframes = get_keyframe_idxs(positive)
|
||||
_, num_keyframes = get_keyframe_idxs(positive, latent_image.shape)
|
||||
resolved_frame_idx = max((latent_length - num_keyframes - 1) * time_scale_factor + 1 + frame_idx, 0)
|
||||
causal_fix = resolved_frame_idx == 0 or num_frames_to_keep == 1
|
||||
|
||||
@ -454,7 +464,7 @@ class LTXVAddGuide(io.ComfyNode):
|
||||
if latent_downscale_factor > 1:
|
||||
t, guide_mask = cls.dilate_latent(t, latent_downscale_factor)
|
||||
|
||||
frame_idx, latent_idx = cls.get_latent_index(positive, latent_length, len(image), frame_idx, scale_factors)
|
||||
frame_idx, latent_idx = cls.get_latent_index(positive, latent_length, len(image), frame_idx, scale_factors, latent_shape=latent_image.shape)
|
||||
assert latent_idx + t.shape[2] <= latent_length, "Conditioning frames exceed the length of the latent sequence."
|
||||
|
||||
positive, negative, latent_image, noise_mask = cls.append_keyframe(
|
||||
@ -506,7 +516,7 @@ class LTXVCropGuides(io.ComfyNode):
|
||||
latent_image = latent["samples"].clone()
|
||||
noise_mask = get_noise_mask(latent)
|
||||
|
||||
_, num_keyframes = get_keyframe_idxs(positive)
|
||||
_, num_keyframes = get_keyframe_idxs(positive, latent_image.shape)
|
||||
if num_keyframes == 0:
|
||||
return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask},)
|
||||
|
||||
|
||||
@ -4,7 +4,6 @@ Provides a ComfyMathExpression node that evaluates math expressions
|
||||
against dynamically-grown numeric inputs.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
import string
|
||||
|
||||
@ -10,7 +10,6 @@ Custom IO types:
|
||||
MediaPipeFaceLandmarker also emits the core BOUNDING_BOX type — pair with DrawBBoxes.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
"""ComfyUI nodes for the native MoGe (Monocular Geometry Estimation) integration."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
408
comfy_extras/nodes_multigpu.py
Normal file
408
comfy_extras/nodes_multigpu.py
Normal file
@ -0,0 +1,408 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import logging
|
||||
from inspect import cleandoc
|
||||
from typing import TYPE_CHECKING
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from comfy.model_patcher import ModelPatcher
|
||||
from comfy.sd import CLIP, VAE
|
||||
import torch
|
||||
|
||||
import comfy.model_management
|
||||
import comfy.multigpu
|
||||
|
||||
|
||||
class MultiGPUCFGSplitNode(io.ComfyNode):
|
||||
"""
|
||||
Prepares model to have sampling accelerated via splitting work units.
|
||||
|
||||
Should be placed after nodes that modify the model object itself, such as compile or attention-switch nodes.
|
||||
|
||||
Other than those exceptions, this node can be placed in any order.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="MultiGPU_WorkUnits",
|
||||
display_name="MultiGPU CFG Split",
|
||||
category="advanced/multigpu",
|
||||
description=cleandoc(cls.__doc__),
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
io.Int.Input("max_gpus", default=2, min=1, step=1),
|
||||
],
|
||||
outputs=[
|
||||
io.Model.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model: ModelPatcher, max_gpus: int) -> io.NodeOutput:
|
||||
model = comfy.multigpu.create_multigpu_deepclones(model, max_gpus, reuse_loaded=True)
|
||||
return io.NodeOutput(model)
|
||||
|
||||
|
||||
def _force_supported_compute_dtype(patcher: ModelPatcher, device: torch.device):
|
||||
"""Cast compute dtype to one the device supports; no-op if already supported."""
|
||||
weight_dtype = patcher.model_dtype()
|
||||
cast_dtype = comfy.model_management.unet_manual_cast(weight_dtype, device)
|
||||
if cast_dtype is None:
|
||||
return
|
||||
logging.info(f"Select Model Device: using {cast_dtype} compute dtype on {device} (model weight dtype was {weight_dtype}).")
|
||||
patcher.set_model_compute_dtype(cast_dtype)
|
||||
|
||||
|
||||
def _remember_base_devices(patcher: ModelPatcher):
|
||||
"""Stash the original load/offload device on the underlying model.
|
||||
|
||||
Stored on patcher.model (which is shared with the input patcher), so
|
||||
later "default" selections can recover the loader's original routing.
|
||||
Only the first Select on a given chain writes these attrs; subsequent
|
||||
deepclones inherit them onto their freshly-loaded model below.
|
||||
"""
|
||||
if not hasattr(patcher.model, "_select_base_load_device"):
|
||||
patcher.model._select_base_load_device = patcher.load_device
|
||||
patcher.model._select_base_offload_device = patcher.offload_device
|
||||
|
||||
|
||||
def _propagate_base_devices(src_model, dst_model):
|
||||
"""Carry the loader-original device attrs onto the freshly-deepcloned model."""
|
||||
if hasattr(src_model, "_select_base_load_device") and not hasattr(dst_model, "_select_base_load_device"):
|
||||
dst_model._select_base_load_device = src_model._select_base_load_device
|
||||
dst_model._select_base_offload_device = src_model._select_base_offload_device
|
||||
|
||||
|
||||
def _retarget_patcher(patcher: ModelPatcher, target_load_device, target_offload_device):
|
||||
"""Return a patcher whose actual model weights live on *target_load_device*.
|
||||
|
||||
If *patcher* is already on *target_load_device* we just retarget the
|
||||
(already-cloned) patcher's metadata in place. Otherwise we call
|
||||
:meth:`ModelPatcher.deepclone_multigpu` to spawn a fresh model from
|
||||
the loader's ``cached_patcher_init`` factory -- the only safe way to
|
||||
move weights that may already be partially loaded onto another device.
|
||||
|
||||
NOTE: reusing the input patcher's model when the requested device
|
||||
matches its current load_device is a deliberate fast path. Anything
|
||||
that has already mutated the original model (e.g. a prior KSampler
|
||||
invocation on the same model) will be observed here. This is by
|
||||
design and documented on the SelectXDeviceNode docstrings -- placing
|
||||
Select X Device after a node that consumes the same model is not
|
||||
recommended.
|
||||
"""
|
||||
if patcher.load_device == target_load_device:
|
||||
# Fast path: weights already on the desired device, just update offload.
|
||||
patcher.offload_device = target_offload_device
|
||||
return patcher
|
||||
src_model = patcher.model
|
||||
patcher = patcher.deepclone_multigpu(new_load_device=target_load_device)
|
||||
patcher.offload_device = target_offload_device
|
||||
_propagate_base_devices(src_model, patcher.model)
|
||||
if hasattr(patcher, "register_load_device"):
|
||||
patcher.register_load_device(patcher.load_device)
|
||||
return patcher
|
||||
|
||||
|
||||
def _apply_patcher_device(patcher: ModelPatcher, resolved, base_offload_override=None):
|
||||
"""Resolve the requested device and produce a patcher routed there.
|
||||
|
||||
For "default" we restore the loader's original load/offload pair.
|
||||
For CPU we pin both load and offload to CPU (and, on a dynamic
|
||||
patcher, downgrade to a plain ModelPatcher so the dynamic-only
|
||||
code paths are bypassed).
|
||||
For an explicit GPU we keep the loader's original offload but
|
||||
target the requested load device; if that differs from the current
|
||||
load device the patcher is deepcloned onto the new device.
|
||||
"""
|
||||
_remember_base_devices(patcher)
|
||||
base_load = patcher.model._select_base_load_device
|
||||
base_offload = base_offload_override if base_offload_override is not None else patcher.model._select_base_offload_device
|
||||
|
||||
if resolved is None:
|
||||
# "default" -> route back to the loader's original devices.
|
||||
return _retarget_patcher(patcher, base_load, base_offload)
|
||||
if resolved.type == "cpu":
|
||||
if patcher.is_dynamic():
|
||||
# clone(disable_dynamic=True) requires cached_patcher_init; let the
|
||||
# exception surface to the caller (Select*DeviceNode.execute), which
|
||||
# will translate it into a passthrough+log so unsupported loaders
|
||||
# don't hard-fail the workflow.
|
||||
patcher = patcher.clone(disable_dynamic=True)
|
||||
patcher.load_device = resolved
|
||||
patcher.offload_device = resolved
|
||||
return patcher
|
||||
return _retarget_patcher(patcher, resolved, base_offload)
|
||||
|
||||
|
||||
def _prune_multigpu_collision(model: ModelPatcher, primary_device):
|
||||
"""Drop any multigpu clone whose load_device matches *primary_device*.
|
||||
|
||||
Without pruning, MultiGPU CFG Split would have stacked a clone on
|
||||
the same device the primary now occupies (i.e. the workflow places
|
||||
MultiGPU CFG Split before Select Model Device). Keeps the clone set
|
||||
consistent with the new primary placement.
|
||||
"""
|
||||
multigpu_models = model.get_additional_models_with_key("multigpu")
|
||||
if not multigpu_models:
|
||||
return
|
||||
filtered = [m for m in multigpu_models if m.load_device != primary_device]
|
||||
if len(filtered) != len(multigpu_models):
|
||||
logging.info(f"Select Model Device: pruning MultiGPU clone on {primary_device} that now collides with the primary model.")
|
||||
model.set_additional_models("multigpu", filtered)
|
||||
if hasattr(model, "match_multigpu_clones"):
|
||||
model.match_multigpu_clones()
|
||||
|
||||
|
||||
class SelectModelDeviceNode(io.ComfyNode):
|
||||
"""
|
||||
Place the diffusion model on a specific device (default / cpu / gpu:N).
|
||||
|
||||
- "default" restores the device assigned by the loader (even after a
|
||||
prior Select Model Device call).
|
||||
- "cpu" pins both the load and offload device to CPU.
|
||||
- "gpu:N" pins the load device to the Nth available GPU; the offload
|
||||
device is restored to the loader's original choice.
|
||||
|
||||
When the requested device differs from the device the input model is
|
||||
already on, a fresh model is spawned via the loader's reload factory
|
||||
(cached_patcher_init) so the new patcher owns independent weights on
|
||||
the new device. Loaders that don't support multigpu (no factory) will
|
||||
cause the node to pass through unchanged with a warning.
|
||||
|
||||
If the workflow already has MultiGPU CFG Split applied and the chosen
|
||||
GPU collides with one of the existing multigpu clones, that clone is
|
||||
dropped so two patchers don't end up bound to the same device.
|
||||
|
||||
When the selected device does not exist on the current machine
|
||||
(e.g. a workflow built on a 2-GPU box opened on a 1-GPU box),
|
||||
the node passes the model through unchanged and logs a message
|
||||
instead of failing.
|
||||
|
||||
NOTE: Placing Select Model Device *after* a node that has already
|
||||
consumed the same model (e.g. a KSampler that ran on this model on
|
||||
the original device) is not recommended -- any state the prior
|
||||
consumer mutated on the original model will be observed when the
|
||||
selected device matches the original (fast path). Place Select Model
|
||||
Device before any consumer of the model.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="SelectModelDevice",
|
||||
display_name="Select Model Device",
|
||||
category="advanced/multigpu",
|
||||
description=cleandoc(cls.__doc__),
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
io.Combo.Input("device", options=comfy.model_management.get_gpu_device_options()),
|
||||
],
|
||||
outputs=[
|
||||
io.Model.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def validate_inputs(cls, device="default"):
|
||||
# Allow unknown gpu:N values so portable workflows do not error
|
||||
# at validation time; runtime fallback will handle them.
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model: ModelPatcher, device: str = "default") -> io.NodeOutput:
|
||||
model = model.clone()
|
||||
resolved = comfy.model_management.resolve_gpu_device_option(device)
|
||||
if resolved is None and device not in (None, "default"):
|
||||
logging.info(f"Select Model Device: requested device '{device}' not available, passing through unchanged.")
|
||||
return io.NodeOutput(model)
|
||||
try:
|
||||
model = _apply_patcher_device(model, resolved)
|
||||
except RuntimeError as e:
|
||||
logging.warning(f"Select Model Device: cannot retarget model, passing through unchanged. ({e})")
|
||||
return io.NodeOutput(model)
|
||||
if resolved is not None:
|
||||
_force_supported_compute_dtype(model, resolved)
|
||||
_prune_multigpu_collision(model, model.load_device)
|
||||
return io.NodeOutput(model)
|
||||
|
||||
|
||||
class SelectCLIPDeviceNode(io.ComfyNode):
|
||||
"""
|
||||
Place the CLIP text encoder on a specific device (default / cpu / gpu:N).
|
||||
|
||||
- "default" restores the device assigned by the loader.
|
||||
- "cpu" pins both the load and offload device to CPU.
|
||||
- "gpu:N" pins the load device to the Nth available GPU.
|
||||
|
||||
When the selected device does not exist on the current machine
|
||||
(e.g. a workflow built on a 2-GPU box opened on a 1-GPU box),
|
||||
the node passes the CLIP through unchanged and logs a message
|
||||
instead of failing.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="SelectCLIPDevice",
|
||||
display_name="Select CLIP Device",
|
||||
category="advanced/multigpu",
|
||||
description=cleandoc(cls.__doc__),
|
||||
inputs=[
|
||||
io.Clip.Input("clip"),
|
||||
io.Combo.Input("device", options=comfy.model_management.get_gpu_device_options()),
|
||||
],
|
||||
outputs=[
|
||||
io.Clip.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def validate_inputs(cls, device="default"):
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def execute(cls, clip: CLIP, device: str = "default") -> io.NodeOutput:
|
||||
clip = clip.clone()
|
||||
resolved = comfy.model_management.resolve_gpu_device_option(device)
|
||||
if resolved is None and device not in (None, "default"):
|
||||
logging.info(f"Select CLIP Device: requested device '{device}' not available, passing through unchanged.")
|
||||
return io.NodeOutput(clip)
|
||||
try:
|
||||
clip.patcher = _apply_patcher_device(clip.patcher, resolved)
|
||||
except RuntimeError as e:
|
||||
logging.warning(f"Select CLIP Device: cannot retarget CLIP, passing through unchanged. ({e})")
|
||||
return io.NodeOutput(clip)
|
||||
|
||||
|
||||
class SelectVAEDeviceNode(io.ComfyNode):
|
||||
"""
|
||||
Place the VAE on a specific device (default / gpu:N).
|
||||
|
||||
- "default" restores the device assigned by the loader.
|
||||
- "gpu:N" pins the load device to the Nth available GPU; the offload
|
||||
device is set to the standard VAE offload device.
|
||||
|
||||
CPU is intentionally not exposed in the UI for the VAE; if a workflow
|
||||
supplies "cpu" anyway (e.g. opened from another machine), the request
|
||||
is dropped with a log message and the VAE is passed through unchanged.
|
||||
|
||||
When the selected device does not exist on the current machine
|
||||
(e.g. a workflow built on a 2-GPU box opened on a 1-GPU box),
|
||||
the node passes the VAE through unchanged and logs a message
|
||||
instead of failing.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="SelectVAEDevice",
|
||||
display_name="Select VAE Device",
|
||||
category="advanced/multigpu",
|
||||
description=cleandoc(cls.__doc__),
|
||||
inputs=[
|
||||
io.Vae.Input("vae"),
|
||||
io.Combo.Input("device", options=comfy.model_management.get_gpu_device_options_no_cpu()),
|
||||
],
|
||||
outputs=[
|
||||
io.Vae.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def validate_inputs(cls, device="default"):
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def execute(cls, vae: VAE, device: str = "default") -> io.NodeOutput:
|
||||
# VAE has no .clone(); shallow-copy the wrapper and clone the patcher
|
||||
# so we can retarget load/offload device without affecting the input VAE.
|
||||
vae = copy.copy(vae)
|
||||
vae.patcher = vae.patcher.clone()
|
||||
resolved = comfy.model_management.resolve_gpu_device_option(device)
|
||||
if resolved is None and device not in (None, "default"):
|
||||
logging.info(f"Select VAE Device: requested device '{device}' not available, passing through unchanged.")
|
||||
return io.NodeOutput(vae)
|
||||
if resolved is not None and resolved.type == "cpu":
|
||||
logging.info("Select VAE Device: CPU is not a supported choice, passing through unchanged.")
|
||||
return io.NodeOutput(vae)
|
||||
if not hasattr(vae, "_select_base_device"):
|
||||
vae._select_base_device = vae.device
|
||||
try:
|
||||
vae.patcher = _apply_patcher_device(
|
||||
vae.patcher, resolved,
|
||||
base_offload_override=comfy.model_management.vae_offload_device(),
|
||||
)
|
||||
except RuntimeError as e:
|
||||
logging.warning(f"Select VAE Device: cannot retarget VAE, passing through unchanged. ({e})")
|
||||
return io.NodeOutput(vae)
|
||||
# Keep VAE wrapper in sync with whatever model the patcher now owns;
|
||||
# deepclone_multigpu may have produced a fresh first_stage_model.
|
||||
vae.first_stage_model = vae.patcher.model
|
||||
vae.device = vae._select_base_device if resolved is None else resolved
|
||||
return io.NodeOutput(vae)
|
||||
|
||||
|
||||
class MultiGPUOptionsNode(io.ComfyNode):
|
||||
"""
|
||||
Select the relative speed of GPUs in the special case they have significantly different performance from one another.
|
||||
|
||||
NOTE (not registered yet, see MultiGPUExtension.get_node_list below):
|
||||
The output GPUOptionsGroup is plumbed through create_multigpu_deepclones() and stored on
|
||||
model.model_options['multigpu_options'] via GPUOptionsGroup.register(), but the cond
|
||||
scheduler in comfy/samplers.py (calc_cond_batch_outer_multigpu) does NOT yet consult
|
||||
relative_speed when distributing conds across devices; it uses a uniform conds_per_device
|
||||
round-robin via next_available_device(). Before re-enabling this node, wire its
|
||||
relative_speed into the scheduler (e.g. via comfy.multigpu.load_balance_devices(),
|
||||
which already implements the proportional split) so the input actually affects work
|
||||
distribution.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="MultiGPU_Options",
|
||||
display_name="MultiGPU Options",
|
||||
category="advanced/multigpu",
|
||||
description=cleandoc(cls.__doc__),
|
||||
inputs=[
|
||||
io.Int.Input("device_index", default=0, min=0, max=64),
|
||||
io.Float.Input("relative_speed", default=1.0, min=0.0, step=0.01),
|
||||
io.Custom("GPU_OPTIONS").Input("gpu_options", optional=True),
|
||||
],
|
||||
outputs=[
|
||||
io.Custom("GPU_OPTIONS").Output(),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, device_index: int, relative_speed: float, gpu_options: comfy.multigpu.GPUOptionsGroup = None) -> io.NodeOutput:
|
||||
if not gpu_options:
|
||||
gpu_options = comfy.multigpu.GPUOptionsGroup()
|
||||
else:
|
||||
gpu_options = gpu_options.clone()
|
||||
|
||||
opt = comfy.multigpu.GPUOptions(device_index=device_index, relative_speed=relative_speed)
|
||||
gpu_options.add(opt)
|
||||
|
||||
return io.NodeOutput(gpu_options)
|
||||
|
||||
|
||||
class MultiGPUExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
MultiGPUCFGSplitNode,
|
||||
SelectModelDeviceNode,
|
||||
SelectCLIPDeviceNode,
|
||||
SelectVAEDeviceNode,
|
||||
# MultiGPUOptionsNode,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> MultiGPUExtension:
|
||||
return MultiGPUExtension()
|
||||
@ -4,7 +4,6 @@ Provides a single node that converts INT, FLOAT, STRING, and BOOL
|
||||
inputs into FLOAT and INT outputs.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
|
||||
|
||||
@ -1,5 +1,3 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import os
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user