Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2025-12-21 20:10:48 +08:00)

Merge branch 'master' into dr-support-pip-cm

Commit a58c4fbf68
21  .github/PULL_REQUEST_TEMPLATE/api-node.md  (vendored, new file)
@@ -0,0 +1,21 @@
<!-- API_NODE_PR_CHECKLIST: do not remove -->

## API Node PR Checklist

### Scope
- [ ] **Is API Node Change**

### Pricing & Billing
- [ ] **Need pricing update**
- [ ] **No pricing update**

If **Need pricing update**:
- [ ] Metronome rate cards updated
- [ ] Auto‑billing tests updated and passing

### QA
- [ ] **QA done**
- [ ] **QA not required**

### Comms
- [ ] Informed **@Kosinkadink**
58  .github/workflows/api-node-template.yml  (vendored, new file)
@@ -0,0 +1,58 @@
name: Append API Node PR template

on:
  pull_request_target:
    types: [opened, reopened, synchronize, edited, ready_for_review]
    paths:
      - 'comfy_api_nodes/**'  # only run if these files changed

permissions:
  contents: read
  pull-requests: write

jobs:
  inject:
    runs-on: ubuntu-latest
    steps:
      - name: Ensure template exists and append to PR body
        uses: actions/github-script@v7
        with:
          script: |
            const { owner, repo } = context.repo;
            const number = context.payload.pull_request.number;
            const templatePath = '.github/PULL_REQUEST_TEMPLATE/api-node.md';
            const marker = '<!-- API_NODE_PR_CHECKLIST: do not remove -->';

            const { data: pr } = await github.rest.pulls.get({ owner, repo, pull_number: number });

            let templateText;
            try {
              const res = await github.rest.repos.getContent({
                owner,
                repo,
                path: templatePath,
                ref: pr.base.ref
              });
              const buf = Buffer.from(res.data.content, res.data.encoding || 'base64');
              templateText = buf.toString('utf8');
            } catch (e) {
              core.setFailed(`Required PR template not found at "${templatePath}" on ${pr.base.ref}. Please add it to the repo.`);
              return;
            }

            // Enforce the presence of the marker inside the template (for idempotence)
            if (!templateText.includes(marker)) {
              core.setFailed(`Template at "${templatePath}" does not contain the required marker:\n${marker}\nAdd it so we can detect duplicates safely.`);
              return;
            }

            // If the PR already contains the marker, do not append again.
            const body = pr.body || '';
            if (body.includes(marker)) {
              core.info('Template already present in PR body; nothing to inject.');
              return;
            }

            const newBody = (body ? body + '\n\n' : '') + templateText + '\n';
            await github.rest.pulls.update({ owner, repo, pull_number: number, body: newBody });
            core.notice('API Node template appended to PR description.');
20  .github/workflows/test-ci.yml  (vendored)
@@ -21,14 +21,15 @@ jobs:
       fail-fast: false
       matrix:
         # os: [macos, linux, windows]
-        os: [macos, linux]
-        python_version: ["3.9", "3.10", "3.11", "3.12"]
+        # os: [macos, linux]
+        os: [linux]
+        python_version: ["3.10", "3.11", "3.12"]
         cuda_version: ["12.1"]
         torch_version: ["stable"]
         include:
-          - os: macos
-            runner_label: [self-hosted, macOS]
-            flags: "--use-pytorch-cross-attention"
+          # - os: macos
+          #   runner_label: [self-hosted, macOS]
+          #   flags: "--use-pytorch-cross-attention"
           - os: linux
             runner_label: [self-hosted, Linux]
             flags: ""

@@ -73,14 +74,15 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        os: [macos, linux]
+        # os: [macos, linux]
+        os: [linux]
         python_version: ["3.11"]
         cuda_version: ["12.1"]
         torch_version: ["nightly"]
         include:
-          - os: macos
-            runner_label: [self-hosted, macOS]
-            flags: "--use-pytorch-cross-attention"
+          # - os: macos
+          #   runner_label: [self-hosted, macOS]
+          #   flags: "--use-pytorch-cross-attention"
           - os: linux
             runner_label: [self-hosted, Linux]
             flags: ""
168  QUANTIZATION.md  (new file)
@@ -0,0 +1,168 @@
# The Comfy guide to Quantization

## How does quantization work?

Quantization aims to map a high-precision value x_f to a lower-precision format with minimal loss in accuracy. These smaller formats reduce the model's memory footprint and increase throughput by using specialized hardware.

When simply converting a value from FP16 to FP8 with the round-to-nearest method we can hit two issues:
- The dynamic range of FP16 (-65,504, 65,504) far exceeds that of FP8 formats like E4M3 (-448, 448) or E5M2 (-57,344, 57,344), potentially resulting in clipped values.
- The original values are concentrated in a small range (e.g. -1, 1), leaving many FP8 bits "unused".

By using a scaling factor, we map these values into the range of the quantized dtype, making use of the full spectrum. One of the easiest and most common approaches is per-tensor absolute-maximum scaling.

```
absmax = max(abs(tensor))
scale = absmax / max_dynamic_range_low_precision

# Quantization
tensor_q = (tensor / scale).to(low_precision_dtype)

# De-Quantization
tensor_dq = tensor_q.to(fp16) * scale

tensor_dq ~ tensor
```

Given that additional information (the scaling factor) is needed to "interpret" the quantized values, we describe these as derived datatypes.
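
As a concrete illustration, here is a minimal PyTorch sketch of the per-tensor absmax recipe above. It is a standalone helper written for this guide (not ComfyUI's implementation, which lives in `comfy/quant_ops.py`), using `float8_e4m3fn` as the target dtype:

```python
import torch

def absmax_quantize_fp8(tensor, qdtype=torch.float8_e4m3fn):
    # Per-tensor absolute-maximum scaling into the FP8 E4M3 range (max ~448).
    finfo = torch.finfo(qdtype)
    absmax = tensor.abs().max().float()
    scale = (absmax / finfo.max).clamp(min=1e-12)  # avoid division by zero
    tensor_q = (tensor.float() / scale).clamp(finfo.min, finfo.max).to(qdtype)
    return tensor_q, scale

def absmax_dequantize(tensor_q, scale, dtype=torch.float16):
    # De-quantize back to a high-precision dtype.
    return (tensor_q.to(torch.float32) * scale).to(dtype)

w = torch.randn(256, 256, dtype=torch.float16)
w_q, w_scale = absmax_quantize_fp8(w)
w_dq = absmax_dequantize(w_q, w_scale)
print((w - w_dq).abs().max())  # quantization error stays small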

## Quantization in Comfy

```
QuantizedTensor (torch.Tensor subclass)
        ↓ __torch_dispatch__
Two-Level Registry (generic + layout handlers)
        ↓
MixedPrecisionOps + Metadata Detection
```

### Representation

To represent these derived datatypes, ComfyUI uses a subclass of torch.Tensor: the `QuantizedTensor` class found in `comfy/quant_ops.py`.

A `Layout` class defines how a specific quantization format behaves:
- Required parameters
- Quantize method
- De-Quantize method

```python
from comfy.quant_ops import QuantizedLayout

class MyLayout(QuantizedLayout):
    @classmethod
    def quantize(cls, tensor, **kwargs):
        # Convert to quantized format
        qdata = ...
        params = {'scale': ..., 'orig_dtype': tensor.dtype}
        return qdata, params

    @staticmethod
    def dequantize(qdata, scale, orig_dtype, **kwargs):
        return qdata.to(orig_dtype) * scale
```

To run operations on these QuantizedTensors, we use two registry systems that define the supported operations.
The first is a **generic registry** that handles operations common to all quantized formats (e.g., `.to()`, `.clone()`, `.reshape()`).

The second registry is layout-specific and allows implementing fast paths such as nn.Linear.

```python
from comfy.quant_ops import register_layout_op

@register_layout_op(torch.ops.aten.linear.default, MyLayout)
def my_linear(func, args, kwargs):
    # Extract tensors, call optimized kernel
    ...
```

When `torch.nn.functional.linear()` is called with QuantizedTensor arguments, `__torch_dispatch__` automatically routes to the registered implementation.
For any unsupported operation, QuantizedTensor falls back to `dequantize` and dispatches to the high-precision implementation.
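
The dequantize-and-fall-back behaviour can be illustrated with a small, self-contained tensor subclass. This is a simplified stand-in written for this guide, not ComfyUI's `QuantizedTensor` (which keeps its registries and layouts in `comfy/quant_ops.py`):

```python
import torch

class DequantFallbackTensor(torch.Tensor):
    # Minimal illustration of the dispatch/fallback pattern: registered ops
    # would get a fast path, everything else dequantizes first.
    _FAST_PATHS = {}

    @staticmethod
    def __new__(cls, qdata, scale, orig_dtype):
        return torch.Tensor._make_wrapper_subclass(cls, qdata.shape, dtype=orig_dtype, device=qdata.device)

    def __init__(self, qdata, scale, orig_dtype):
        self._qdata, self._scale, self._orig_dtype = qdata, scale, orig_dtype

    def dequantize(self):
        return self._qdata.to(self._orig_dtype) * self._scale

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func in cls._FAST_PATHS:
            return cls._FAST_PATHS[func](func, args, kwargs)
        # Fallback: dequantize every quantized argument and re-dispatch.
        args = [a.dequantize() if isinstance(a, cls) else a for a in args]
        return func(*args, **kwargs)

qdata = (torch.randn(8, 8) * 10).to(torch.float8_e4m3fn)
t = DequantFallbackTensor(qdata, scale=torch.tensor(0.1), orig_dtype=torch.float32)
out = torch.add(t, 1.0)  # no fast path registered -> dequantize, then run aten.add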

### Mixed Precision

The `MixedPrecisionOps` class (lines 542-648 in `comfy/ops.py`) enables per-layer quantization decisions, allowing different layers in a model to use different precisions. It is activated when a model config contains a `layer_quant_config` dictionary that specifies which layers should be quantized and how.

**Architecture:**

```python
class MixedPrecisionOps(disable_weight_init):
    _layer_quant_config = {}          # Maps layer names to quantization configs
    _compute_dtype = torch.bfloat16   # Default compute / dequantize precision
```

**Key mechanism:**

The custom `Linear._load_from_state_dict()` method inspects each layer during model loading:
- If the layer name is **not** in `_layer_quant_config`: load the weight as a regular tensor in `_compute_dtype`.
- If the layer name **is** in `_layer_quant_config`:
  - Load the weight as a `QuantizedTensor` with the specified layout (e.g., `TensorCoreFP8Layout`).
  - Load the associated quantization parameters (scales, block_size, etc.).

**Why it's needed:**

Not all layers tolerate quantization equally. Sensitive operations like final projections can be kept in higher precision, while compute-heavy matmuls are quantized. This provides most of the performance benefit while maintaining quality.

The system is selected in `pick_operations()` when `model_config.layer_quant_config` is present, making it the highest-priority operation mode.
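
For illustration, a `layer_quant_config` might look like the sketch below. The layer names and the per-layer dict schema are assumptions made for this example, not the exact structure consumed by `comfy/ops.py`; in practice the mapping is derived from the checkpoint's `_quantization_metadata` (see the next section):

```python
# Hypothetical example only: quantized layers reference a format defined in
# QUANT_ALGOS, everything else stays in the compute dtype (e.g. bfloat16).
layer_quant_config = {
    "double_blocks.0.img_mlp.0": {"format": "float8_e4m3fn"},  # compute-heavy matmul
    "double_blocks.0.img_mlp.2": {"format": "float8_e4m3fn"},
    # "final_layer.linear" is deliberately absent: sensitive projections stay
    # in the higher-precision _compute_dtype.
}
```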

## Checkpoint Format

Quantized checkpoints are stored as standard safetensors files with quantized weight tensors and their associated scaling parameters, plus a `_quantization_metadata` JSON entry describing the quantization scheme.

A quantized checkpoint contains the same layers as the original checkpoint, but:
- The weights are stored as quantized values, sometimes using a different storage datatype (e.g. a uint8 container for fp8).
- For each quantized weight, a number of additional scaling parameters are stored alongside it, depending on the recipe.
- The safetensors metadata contains a `_quantization_metadata` entry describing which layers are quantized and which layout has been used.

### Scaling Parameters details

We define 4 possible scaling parameters that should cover most recipes in the near future:
- **weight_scale**: quantization scalers for the weights
- **weight_scale_2**: global scalers in the context of double scaling
- **pre_quant_scale**: scalers used for smoothing salient weights
- **input_scale**: quantization scalers for the activations

| Format | Storage dtype | weight_scale | weight_scale_2 | pre_quant_scale | input_scale |
|--------|---------------|--------------|----------------|-----------------|-------------|
| float8_e4m3fn | float8_e4m3fn | float32 (scalar) | - | - | float32 (scalar) |

You can find the defined formats in `comfy/quant_ops.py` (`QUANT_ALGOS`).

### Quantization Metadata

The metadata stored alongside the checkpoint contains:
- **format_version**: a string defining the version of the standard
- **layers**: a dictionary mapping layer names to their quantization format. The format string maps to the definitions found in `QUANT_ALGOS`.

Example:

```json
{
  "_quantization_metadata": {
    "format_version": "1.0",
    "layers": {
      "model.layers.0.mlp.up_proj": "float8_e4m3fn",
      "model.layers.0.mlp.down_proj": "float8_e4m3fn",
      "model.layers.1.mlp.up_proj": "float8_e4m3fn"
    }
  }
}
```
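
For example, the metadata of such a checkpoint can be inspected with the `safetensors` library. This is a small sketch: the file name is a placeholder, and the key layout follows the example above:

```python
import json
from safetensors import safe_open

# Inspect which layers of a checkpoint are quantized, assuming a
# "_quantization_metadata" key in the safetensors header metadata.
with safe_open("model_fp8.safetensors", framework="pt", device="cpu") as f:
    header_meta = f.metadata() or {}
    quant_meta = json.loads(header_meta["_quantization_metadata"])

print(quant_meta["format_version"])
for layer, fmt in quant_meta["layers"].items():
    print(f"{layer}: {fmt}")
```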

## Creating Quantized Checkpoints

To create compatible checkpoints, you can use any quantization tool, provided the output follows the checkpoint format described above and uses a layout defined in `QUANT_ALGOS`.

### Weight Quantization

Weight quantization is straightforward: compute the scaling factor directly from the weight tensor using the absolute-maximum method described earlier. Each layer's weights are quantized independently and stored with their corresponding `weight_scale` parameter.
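
A minimal sketch of this step, assuming the absmax recipe from the beginning of this guide and illustrative layer names (real tooling will also handle different storage containers, per-block scales, and so on):

```python
import json
import torch
from safetensors.torch import save_file

def quantize_weight_fp8(w):
    # Per-tensor absmax scaling into float8_e4m3fn, as described earlier.
    finfo = torch.finfo(torch.float8_e4m3fn)
    scale = (w.abs().max().float() / finfo.max).clamp(min=1e-12)
    return (w.float() / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn), scale

# Illustrative layers; a real tool would iterate over the full state dict.
state_dict = {
    "model.layers.0.mlp.up_proj.weight": torch.randn(4096, 1024, dtype=torch.float16),
    "model.layers.0.mlp.down_proj.weight": torch.randn(1024, 4096, dtype=torch.float16),
}

quantized, layers_meta = {}, {}
for name, weight in state_dict.items():
    w_q, scale = quantize_weight_fp8(weight)
    prefix = name.rsplit(".weight", 1)[0]
    quantized[name] = w_q
    quantized[f"{prefix}.weight_scale"] = scale.to(torch.float32)  # scalar scale
    layers_meta[prefix] = "float8_e4m3fn"

metadata = {"_quantization_metadata": json.dumps({"format_version": "1.0", "layers": layers_meta})}
save_file(quantized, "model_fp8.safetensors", metadata=metadata)
```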

### Calibration (for Activation Quantization)

Activation quantization (e.g., for FP8 Tensor Core operations) requires `input_scale` parameters that cannot be determined from the static weights alone. Since activation values depend on the actual inputs, we use **post-training calibration (PTQ)**:

1. **Collect statistics**: Run inference on N representative samples
2. **Track activations**: Record the absolute maximum (`amax`) of the inputs to each quantized layer
3. **Compute scales**: Derive `input_scale` from the collected statistics
4. **Store in checkpoint**: Save the `input_scale` parameters alongside the weights

The calibration dataset should be representative of your target use case. For diffusion models, this typically means a diverse set of prompts and generation parameters.
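
One possible calibration loop, sketched with generic PyTorch forward pre-hooks (illustrative only; `fp8_max=448.0` corresponds to float8_e4m3fn, and the model/batch handling is an assumption about your pipeline):

```python
import torch

def calibrate_input_scales(model, quantized_layer_names, calibration_batches, fp8_max=448.0):
    # Track the per-layer absolute maximum of the inputs seen during calibration.
    amax = {name: torch.zeros(()) for name in quantized_layer_names}
    hooks = []

    def make_hook(name):
        def hook(module, inputs):
            amax[name] = torch.maximum(amax[name], inputs[0].detach().abs().max().cpu().float())
        return hook

    for name, module in model.named_modules():
        if name in quantized_layer_names:
            hooks.append(module.register_forward_pre_hook(make_hook(name)))

    with torch.no_grad():
        for batch in calibration_batches:  # N representative samples
            model(batch)

    for h in hooks:
        h.remove()

    # input_scale maps the observed activation range onto the FP8 range.
    return {name: (a / fp8_max).clamp(min=1e-12) for name, a in amax.items()}
```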

10  README.md

@@ -173,7 +173,7 @@ There is a portable standalone build for Windows that should work for running on

 ### [Direct link to download](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia.7z)

-Simply download, extract with [7-Zip](https://7-zip.org) and run. Make sure you put your Stable Diffusion checkpoints/models (the huge ckpt/safetensors files) in: ComfyUI\models\checkpoints
+Simply download, extract with [7-Zip](https://7-zip.org) or with Windows Explorer on recent Windows versions, and run. For smaller models you normally only need to put the checkpoints (the huge ckpt/safetensors files) in: ComfyUI\models\checkpoints, but many of the larger models have multiple files. Make sure to follow the instructions to know which subfolder to put them in under ComfyUI\models\

 If you have trouble extracting it, right click the file -> properties -> unblock

@@ -200,7 +200,7 @@ comfy install

 ## Manual Install (Windows, Linux)

-Python 3.14 will work if you comment out the `kornia` dependency in the requirements.txt file (breaks the canny node) but it is not recommended.
+Python 3.14 works, but you may encounter issues with the torch compile node. The free-threaded variant is still missing some dependencies.

 Python 3.13 is very well supported. If you have trouble with some custom node dependencies on 3.13 you can try 3.12

@@ -242,7 +242,7 @@ RDNA 4 (RX 9000 series):

 ### Intel GPUs (Windows and Linux)

-(Option 1) Intel Arc GPU users can install native PyTorch with torch.xpu support using pip. More information can be found [here](https://pytorch.org/docs/main/notes/get_start_xpu.html)
+Intel Arc GPU users can install native PyTorch with torch.xpu support using pip. More information can be found [here](https://pytorch.org/docs/main/notes/get_start_xpu.html)

 1. To install PyTorch xpu, use the following command:

@@ -252,10 +252,6 @@ This is the command to install the Pytorch xpu nightly which might have some per

 ```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/xpu```

-(Option 2) Alternatively, Intel GPUs supported by Intel Extension for PyTorch (IPEX) can leverage IPEX for improved performance.
-
-1. visit [Installation](https://intel.github.io/intel-extension-for-pytorch/index.html#installation?platform=gpu) for more information.
-
 ### NVIDIA

 Nvidia users should install stable pytorch using this command:
@@ -1,15 +1,15 @@
 import torch
 from torch import Tensor, nn

-from comfy.ldm.flux.math import attention
 from comfy.ldm.flux.layers import (
     MLPEmbedder,
     RMSNorm,
-    QKNorm,
-    SelfAttention,
     ModulationOut,
 )

+# TODO: remove this in a few months
+SingleStreamBlock = None
+DoubleStreamBlock = None
+

 class ChromaModulationOut(ModulationOut):
@@ -48,124 +48,6 @@ class Approximator(nn.Module):
         return x


-class DoubleStreamBlock(nn.Module):
-    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, dtype=None, device=None, operations=None):
-        super().__init__()
-
-        mlp_hidden_dim = int(hidden_size * mlp_ratio)
-        self.num_heads = num_heads
-        self.hidden_size = hidden_size
-        self.img_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
-        self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)
-
-        self.img_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
-        self.img_mlp = nn.Sequential(
-            operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
-            nn.GELU(approximate="tanh"),
-            operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
-        )
-
-        self.txt_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
-        self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)
-
-        self.txt_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
-        self.txt_mlp = nn.Sequential(
-            operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
-            nn.GELU(approximate="tanh"),
-            operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
-        )
-        self.flipped_img_txt = flipped_img_txt
-
-    def forward(self, img: Tensor, txt: Tensor, pe: Tensor, vec: Tensor, attn_mask=None, transformer_options={}):
-        (img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec
-
-        # prepare image for attention
-        img_modulated = torch.addcmul(img_mod1.shift, 1 + img_mod1.scale, self.img_norm1(img))
-        img_qkv = self.img_attn.qkv(img_modulated)
-        img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
-        img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
-
-        # prepare txt for attention
-        txt_modulated = torch.addcmul(txt_mod1.shift, 1 + txt_mod1.scale, self.txt_norm1(txt))
-        txt_qkv = self.txt_attn.qkv(txt_modulated)
-        txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
-        txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
-
-        # run actual attention
-        attn = attention(torch.cat((txt_q, img_q), dim=2),
-                         torch.cat((txt_k, img_k), dim=2),
-                         torch.cat((txt_v, img_v), dim=2),
-                         pe=pe, mask=attn_mask, transformer_options=transformer_options)
-
-        txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
-
-        # calculate the img bloks
-        img.addcmul_(img_mod1.gate, self.img_attn.proj(img_attn))
-        img.addcmul_(img_mod2.gate, self.img_mlp(torch.addcmul(img_mod2.shift, 1 + img_mod2.scale, self.img_norm2(img))))
-
-        # calculate the txt bloks
-        txt.addcmul_(txt_mod1.gate, self.txt_attn.proj(txt_attn))
-        txt.addcmul_(txt_mod2.gate, self.txt_mlp(torch.addcmul(txt_mod2.shift, 1 + txt_mod2.scale, self.txt_norm2(txt))))
-
-        if txt.dtype == torch.float16:
-            txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504)
-
-        return img, txt
-
-
-class SingleStreamBlock(nn.Module):
-    """
-    A DiT block with parallel linear layers as described in
-    https://arxiv.org/abs/2302.05442 and adapted modulation interface.
-    """
-
-    def __init__(
-        self,
-        hidden_size: int,
-        num_heads: int,
-        mlp_ratio: float = 4.0,
-        qk_scale: float = None,
-        dtype=None,
-        device=None,
-        operations=None
-    ):
-        super().__init__()
-        self.hidden_dim = hidden_size
-        self.num_heads = num_heads
-        head_dim = hidden_size // num_heads
-        self.scale = qk_scale or head_dim**-0.5
-
-        self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
-        # qkv and mlp_in
-        self.linear1 = operations.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim, dtype=dtype, device=device)
-        # proj and mlp_out
-        self.linear2 = operations.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, dtype=dtype, device=device)
-
-        self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations)
-
-        self.hidden_size = hidden_size
-        self.pre_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
-
-        self.mlp_act = nn.GELU(approximate="tanh")
-
-    def forward(self, x: Tensor, pe: Tensor, vec: Tensor, attn_mask=None, transformer_options={}) -> Tensor:
-        mod = vec
-        x_mod = torch.addcmul(mod.shift, 1 + mod.scale, self.pre_norm(x))
-        qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
-
-        q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
-        q, k = self.norm(q, k, v)
-
-        # compute attention
-        attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
-        # compute activation in mlp stream, cat again and run second linear layer
-        output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
-        x.addcmul_(mod.gate, output)
-        if x.dtype == torch.float16:
-            x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
-        return x
-
-
 class LastLayer(nn.Module):
     def __init__(self, hidden_size: int, patch_size: int, out_channels: int, dtype=None, device=None, operations=None):
         super().__init__()
@@ -11,12 +11,12 @@ import comfy.ldm.common_dit
 from comfy.ldm.flux.layers import (
     EmbedND,
     timestep_embedding,
+    DoubleStreamBlock,
+    SingleStreamBlock,
 )

 from .layers import (
-    DoubleStreamBlock,
     LastLayer,
-    SingleStreamBlock,
     Approximator,
     ChromaModulationOut,
 )

@@ -90,6 +90,7 @@ class Chroma(nn.Module):
                     self.num_heads,
                     mlp_ratio=params.mlp_ratio,
                     qkv_bias=params.qkv_bias,
+                    modulation=False,
                     dtype=dtype, device=device, operations=operations
                 )
                 for _ in range(params.depth)

@@ -98,7 +99,7 @@ class Chroma(nn.Module):

         self.single_blocks = nn.ModuleList(
             [
-                SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, dtype=dtype, device=device, operations=operations)
+                SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, modulation=False, dtype=dtype, device=device, operations=operations)
                 for _ in range(params.depth_single_blocks)
             ]
         )
@@ -10,12 +10,10 @@ from torch import Tensor, nn
 from einops import repeat
 import comfy.ldm.common_dit

-from comfy.ldm.flux.layers import EmbedND
+from comfy.ldm.flux.layers import EmbedND, DoubleStreamBlock, SingleStreamBlock

 from comfy.ldm.chroma.model import Chroma, ChromaParams
 from comfy.ldm.chroma.layers import (
-    DoubleStreamBlock,
-    SingleStreamBlock,
     Approximator,
 )
 from .layers import (

@@ -89,7 +87,6 @@ class ChromaRadiance(Chroma):
             dtype=dtype, device=device, operations=operations
         )
-
         self.double_blocks = nn.ModuleList(
             [
                 DoubleStreamBlock(

@@ -97,6 +94,7 @@ class ChromaRadiance(Chroma):
                     self.num_heads,
                     mlp_ratio=params.mlp_ratio,
                     qkv_bias=params.qkv_bias,
+                    modulation=False,
                     dtype=dtype, device=device, operations=operations
                 )
                 for _ in range(params.depth)

@@ -109,6 +107,7 @@ class ChromaRadiance(Chroma):
                     self.hidden_size,
                     self.num_heads,
                     mlp_ratio=params.mlp_ratio,
+                    modulation=False,
                     dtype=dtype, device=device, operations=operations,
                 )
                 for _ in range(params.depth_single_blocks)
@@ -130,13 +130,17 @@ def apply_mod(tensor, m_mult, m_add=None, modulation_dims=None):


 class DoubleStreamBlock(nn.Module):
-    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, dtype=None, device=None, operations=None):
+    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, modulation=True, dtype=None, device=None, operations=None):
         super().__init__()

         mlp_hidden_dim = int(hidden_size * mlp_ratio)
         self.num_heads = num_heads
         self.hidden_size = hidden_size
-        self.img_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)
+        self.modulation = modulation
+
+        if self.modulation:
+            self.img_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)

         self.img_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
         self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)

@@ -147,7 +151,9 @@ class DoubleStreamBlock(nn.Module):
             operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
         )

-        self.txt_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)
+        if self.modulation:
+            self.txt_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)

         self.txt_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
         self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)

@@ -160,46 +166,65 @@ class DoubleStreamBlock(nn.Module):
         self.flipped_img_txt = flipped_img_txt

     def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None, transformer_options={}):
-        img_mod1, img_mod2 = self.img_mod(vec)
-        txt_mod1, txt_mod2 = self.txt_mod(vec)
+        if self.modulation:
+            img_mod1, img_mod2 = self.img_mod(vec)
+            txt_mod1, txt_mod2 = self.txt_mod(vec)
+        else:
+            (img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec

         # prepare image for attention
         img_modulated = self.img_norm1(img)
         img_modulated = apply_mod(img_modulated, (1 + img_mod1.scale), img_mod1.shift, modulation_dims_img)
         img_qkv = self.img_attn.qkv(img_modulated)
+        del img_modulated
         img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+        del img_qkv
         img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)

         # prepare txt for attention
         txt_modulated = self.txt_norm1(txt)
         txt_modulated = apply_mod(txt_modulated, (1 + txt_mod1.scale), txt_mod1.shift, modulation_dims_txt)
         txt_qkv = self.txt_attn.qkv(txt_modulated)
+        del txt_modulated
         txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+        del txt_qkv
         txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)

         if self.flipped_img_txt:
+            q = torch.cat((img_q, txt_q), dim=2)
+            del img_q, txt_q
+            k = torch.cat((img_k, txt_k), dim=2)
+            del img_k, txt_k
+            v = torch.cat((img_v, txt_v), dim=2)
+            del img_v, txt_v
             # run actual attention
-            attn = attention(torch.cat((img_q, txt_q), dim=2),
-                             torch.cat((img_k, txt_k), dim=2),
-                             torch.cat((img_v, txt_v), dim=2),
+            attn = attention(q, k, v,
                              pe=pe, mask=attn_mask, transformer_options=transformer_options)
+            del q, k, v

             img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1]:]
         else:
+            q = torch.cat((txt_q, img_q), dim=2)
+            del txt_q, img_q
+            k = torch.cat((txt_k, img_k), dim=2)
+            del txt_k, img_k
+            v = torch.cat((txt_v, img_v), dim=2)
+            del txt_v, img_v
             # run actual attention
-            attn = attention(torch.cat((txt_q, img_q), dim=2),
-                             torch.cat((txt_k, img_k), dim=2),
-                             torch.cat((txt_v, img_v), dim=2),
+            attn = attention(q, k, v,
                              pe=pe, mask=attn_mask, transformer_options=transformer_options)
+            del q, k, v

             txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]

         # calculate the img bloks
         img += apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img)
+        del img_attn
         img += apply_mod(self.img_mlp(apply_mod(self.img_norm2(img), (1 + img_mod2.scale), img_mod2.shift, modulation_dims_img)), img_mod2.gate, None, modulation_dims_img)

         # calculate the txt bloks
         txt += apply_mod(self.txt_attn.proj(txt_attn), txt_mod1.gate, None, modulation_dims_txt)
+        del txt_attn
         txt += apply_mod(self.txt_mlp(apply_mod(self.txt_norm2(txt), (1 + txt_mod2.scale), txt_mod2.shift, modulation_dims_txt)), txt_mod2.gate, None, modulation_dims_txt)

         if txt.dtype == torch.float16:

@@ -220,6 +245,7 @@ class SingleStreamBlock(nn.Module):
         num_heads: int,
         mlp_ratio: float = 4.0,
         qk_scale: float = None,
+        modulation=True,
         dtype=None,
         device=None,
         operations=None

@@ -242,19 +268,29 @@ class SingleStreamBlock(nn.Module):
         self.pre_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)

         self.mlp_act = nn.GELU(approximate="tanh")
-        self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations)
+        if modulation:
+            self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations)
+        else:
+            self.modulation = None

     def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims=None, transformer_options={}) -> Tensor:
-        mod, _ = self.modulation(vec)
+        if self.modulation:
+            mod, _ = self.modulation(vec)
+        else:
+            mod = vec
+
         qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)

         q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+        del qkv
         q, k = self.norm(q, k, v)

         # compute attention
         attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
+        del q, k, v
         # compute activation in mlp stream, cat again and run second linear layer
-        output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
+        mlp = self.mlp_act(mlp)
+        output = self.linear2(torch.cat((attn, mlp), 2))
         x += apply_mod(output, mod.gate, None, modulation_dims)
         if x.dtype == torch.float16:
             x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
@@ -236,10 +236,10 @@ class QwenImageTransformerBlock(nn.Module):
         img_mod1, img_mod2 = img_mod_params.chunk(2, dim=-1)
         txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1)

-        img_normed = self.img_norm1(hidden_states)
-        img_modulated, img_gate1 = self._modulate(img_normed, img_mod1)
-        txt_normed = self.txt_norm1(encoder_hidden_states)
-        txt_modulated, txt_gate1 = self._modulate(txt_normed, txt_mod1)
+        img_modulated, img_gate1 = self._modulate(self.img_norm1(hidden_states), img_mod1)
+        del img_mod1
+        txt_modulated, txt_gate1 = self._modulate(self.txt_norm1(encoder_hidden_states), txt_mod1)
+        del txt_mod1

         img_attn_output, txt_attn_output = self.attn(
             hidden_states=img_modulated,

@@ -248,16 +248,20 @@ class QwenImageTransformerBlock(nn.Module):
             image_rotary_emb=image_rotary_emb,
             transformer_options=transformer_options,
         )
+        del img_modulated
+        del txt_modulated

         hidden_states = hidden_states + img_gate1 * img_attn_output
         encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output
+        del img_attn_output
+        del txt_attn_output
+        del img_gate1
+        del txt_gate1

-        img_normed2 = self.img_norm2(hidden_states)
-        img_modulated2, img_gate2 = self._modulate(img_normed2, img_mod2)
+        img_modulated2, img_gate2 = self._modulate(self.img_norm2(hidden_states), img_mod2)
         hidden_states = torch.addcmul(hidden_states, img_gate2, self.img_mlp(img_modulated2))

-        txt_normed2 = self.txt_norm2(encoder_hidden_states)
-        txt_modulated2, txt_gate2 = self._modulate(txt_normed2, txt_mod2)
+        txt_modulated2, txt_gate2 = self._modulate(self.txt_norm2(encoder_hidden_states), txt_mod2)
         encoder_hidden_states = torch.addcmul(encoder_hidden_states, txt_gate2, self.txt_mlp(txt_modulated2))

         return encoder_hidden_states, hidden_states
@@ -503,10 +503,7 @@ class LoadedModel:
             use_more_vram = lowvram_model_memory
             if use_more_vram == 0:
                 use_more_vram = 1e32
-            if use_more_vram > 0:
-                self.model_use_more_vram(use_more_vram, force_patch_weights=force_patch_weights)
-            else:
-                self.model.partially_unload(self.model.offload_device, -use_more_vram, force_patch_weights=force_patch_weights)
+            self.model_use_more_vram(use_more_vram, force_patch_weights=force_patch_weights)

         real_model = self.model.model

@@ -1107,6 +1104,9 @@ def pin_memory(tensor):
     if MAX_PINNED_MEMORY <= 0:
         return False

+    if type(tensor) is not torch.nn.parameter.Parameter:
+        return False
+
     if not is_device_cpu(tensor.device):
         return False

@@ -1116,6 +1116,9 @@ def pin_memory(tensor):
         #on the GPU async. So dont trust the CUDA API and guard here
         return False

+    if not tensor.is_contiguous():
+        return False
+
     size = tensor.numel() * tensor.element_size()
     if (TOTAL_PINNED_MEMORY + size) > MAX_PINNED_MEMORY:
         return False
@@ -928,6 +928,9 @@ class ModelPatcher:
             extra_memory += (used - self.model.model_loaded_weight_memory)

         self.patch_model(load_weights=False)
+        if extra_memory < 0 and not unpatch_weights:
+            self.partially_unload(self.offload_device, -extra_memory, force_patch_weights=force_patch_weights)
+            return 0
         full_load = False
         if self.model.model_lowvram == False and self.model.model_loaded_weight_memory > 0:
             self.apply_hooks(self.forced_hooks, force_apply=True)
35  comfy/ops.py

@@ -77,6 +77,9 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
     # will add async-offload support to your cast and improve performance.
     if input is not None:
         if dtype is None:
-            dtype = input.dtype
+            if isinstance(input, QuantizedTensor):
+                dtype = input._layout_params["orig_dtype"]
+            else:
+                dtype = input.dtype
         if bias_dtype is None:
             bias_dtype = dtype

@@ -534,18 +537,7 @@ if CUBLAS_IS_AVAILABLE:
 # ==============================================================================
 # Mixed Precision Operations
 # ==============================================================================
-from .quant_ops import QuantizedTensor
-
-QUANT_FORMAT_MIXINS = {
-    "float8_e4m3fn": {
-        "dtype": torch.float8_e4m3fn,
-        "layout_type": "TensorCoreFP8Layout",
-        "parameters": {
-            "weight_scale": torch.nn.Parameter(torch.zeros((), dtype=torch.float32), requires_grad=False),
-            "input_scale": torch.nn.Parameter(torch.zeros((), dtype=torch.float32), requires_grad=False),
-        }
-    }
-}
+from .quant_ops import QuantizedTensor, QUANT_ALGOS

 class MixedPrecisionOps(disable_weight_init):
     _layer_quant_config = {}

@@ -596,23 +588,24 @@ class MixedPrecisionOps(disable_weight_init):
             if quant_format is None:
                 raise ValueError(f"Unknown quantization format for layer {layer_name}")

-            mixin = QUANT_FORMAT_MIXINS[quant_format]
-            self.layout_type = mixin["layout_type"]
+            qconfig = QUANT_ALGOS[quant_format]
+            self.layout_type = qconfig["comfy_tensor_layout"]

-            scale_key = f"{prefix}weight_scale"
+            weight_scale_key = f"{prefix}weight_scale"
             layout_params = {
-                'scale': state_dict.pop(scale_key, None),
-                'orig_dtype': MixedPrecisionOps._compute_dtype
+                'scale': state_dict.pop(weight_scale_key, None),
+                'orig_dtype': MixedPrecisionOps._compute_dtype,
+                'block_size': qconfig.get("group_size", None),
             }
             if layout_params['scale'] is not None:
-                manually_loaded_keys.append(scale_key)
+                manually_loaded_keys.append(weight_scale_key)

             self.weight = torch.nn.Parameter(
-                QuantizedTensor(weight.to(device=device, dtype=mixin["dtype"]), self.layout_type, layout_params),
+                QuantizedTensor(weight.to(device=device), self.layout_type, layout_params),
                 requires_grad=False
             )

-            for param_name, param_value in mixin["parameters"].items():
+            for param_name in qconfig["parameters"]:
                 param_key = f"{prefix}{param_name}"
                 _v = state_dict.pop(param_key, None)
                 if _v is None:

@@ -643,7 +636,7 @@ class MixedPrecisionOps(disable_weight_init):
             if (getattr(self, 'layout_type', None) is not None and
                 getattr(self, 'input_scale', None) is not None and
                 not isinstance(input, QuantizedTensor)):
-                input = QuantizedTensor.from_float(input, self.layout_type, scale=self.input_scale, fp8_dtype=self.weight.dtype)
+                input = QuantizedTensor.from_float(input, self.layout_type, scale=self.input_scale, dtype=self.weight.dtype)
             return self._forward(input, self.weight, self.bias)

@@ -74,6 +74,12 @@ def _copy_layout_params(params):
             new_params[k] = v
     return new_params

+def _copy_layout_params_inplace(src, dst, non_blocking=False):
+    for k, v in src.items():
+        if isinstance(v, torch.Tensor):
+            dst[k].copy_(v, non_blocking=non_blocking)
+        else:
+            dst[k] = v

 class QuantizedLayout:
     """

@@ -318,13 +324,13 @@ def generic_to_dtype_layout(func, args, kwargs):
 def generic_copy_(func, args, kwargs):
     qt_dest = args[0]
     src = args[1]
+    non_blocking = args[2] if len(args) > 2 else False
     if isinstance(qt_dest, QuantizedTensor):
         if isinstance(src, QuantizedTensor):
             # Copy from another quantized tensor
-            qt_dest._qdata.copy_(src._qdata)
+            qt_dest._qdata.copy_(src._qdata, non_blocking=non_blocking)
             qt_dest._layout_type = src._layout_type
-            qt_dest._layout_params = _copy_layout_params(src._layout_params)
+            _copy_layout_params_inplace(src._layout_params, qt_dest._layout_params, non_blocking=non_blocking)
         else:
             # Copy from regular tensor - just copy raw data
             qt_dest._qdata.copy_(src)

@@ -336,6 +342,26 @@ def generic_copy_(func, args, kwargs):
 def generic_has_compatible_shallow_copy_type(func, args, kwargs):
     return True

+
+@register_generic_util(torch.ops.aten.empty_like.default)
+def generic_empty_like(func, args, kwargs):
+    """Empty_like operation - creates an empty tensor with the same quantized structure."""
+    qt = args[0]
+    if isinstance(qt, QuantizedTensor):
+        # Create empty tensor with same shape and dtype as the quantized data
+        hp_dtype = kwargs.pop('dtype', qt._layout_params["orig_dtype"])
+        new_qdata = torch.empty_like(qt._qdata, **kwargs)
+
+        # Handle device transfer for layout params
+        target_device = kwargs.get('device', new_qdata.device)
+        new_params = _move_layout_params_to_device(qt._layout_params, target_device)
+
+        # Update orig_dtype if dtype is specified
+        new_params['orig_dtype'] = hp_dtype
+
+        return QuantizedTensor(new_qdata, qt._layout_type, new_params)
+    return func(*args, **kwargs)
+
 # ==============================================================================
 # FP8 Layout + Operation Handlers
 # ==============================================================================

@@ -378,6 +404,13 @@ class TensorCoreFP8Layout(QuantizedLayout):
     def get_plain_tensors(cls, qtensor):
         return qtensor._qdata, qtensor._layout_params['scale']

+QUANT_ALGOS = {
+    "float8_e4m3fn": {
+        "storage_t": torch.float8_e4m3fn,
+        "parameters": {"weight_scale", "input_scale"},
+        "comfy_tensor_layout": "TensorCoreFP8Layout",
+    },
+}
+
 LAYOUTS = {
     "TensorCoreFP8Layout": TensorCoreFP8Layout,
@@ -2,6 +2,7 @@ import os
 import sys
 import asyncio
 import traceback
+import time

 import nodes
 import folder_paths

@@ -739,6 +740,7 @@ class PromptServer():
            for sensitive_val in execution.SENSITIVE_EXTRA_DATA_KEYS:
                if sensitive_val in extra_data:
                    sensitive[sensitive_val] = extra_data.pop(sensitive_val)
+           extra_data["create_time"] = int(time.time() * 1000)  # timestamp in milliseconds
            self.prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute, sensitive))
            response = {"prompt_id": prompt_id, "number": number, "node_errors": valid[3]}
            return web.json_response(response)